"""Complex-valued tensorized factorizations (tltorch.factorized_tensors.complex_tensorized_matrices)."""

import torch
import tensorly as tl
tl.set_backend('pytorch')
from tltorch.factorized_tensors.tensorized_matrices import TuckerTensorized, DenseTensorized, CPTensorized, BlockTT
from .complex_factorized_tensors import ComplexHandler

# Author: Jean Kossaifi
# License: BSD 3 clause


class ComplexDenseTensorized(ComplexHandler, DenseTensorized, name='ComplexDense'):
    """Complex DenseTensorized Factorization

    Combines ``ComplexHandler`` with ``DenseTensorized``. The class is declared
    with ``name='ComplexDense'``; that keyword is consumed by a base class's
    subclass hook (presumably a factorization registry — confirm in
    ``tensorized_matrices``).
    """

    # Attribute names handed to ComplexHandler. Assumed to be the parameters it
    # treats as complex-valued — confirm against ComplexHandler.
    _complex_params = ['tensor']

    @classmethod
    def new(cls, tensorized_shape, rank=None, device=None, dtype=torch.cfloat, **kwargs):
        """Build a new instance through the parent class's ``new``.

        Parameters
        ----------
        tensorized_shape
            Forwarded unchanged (positionally) to the parent ``new``.
        rank
            Forwarded positionally to the parent ``new``; defaults to ``None``.
        device
            Forwarded as ``device=``.
        dtype
            Forwarded as ``dtype=``; defaults to ``torch.cfloat`` so the
            factorization is complex unless the caller overrides it.
        **kwargs
            Any further keyword arguments, forwarded unchanged.
        """
        parent = super(ComplexDenseTensorized, cls)
        return parent.new(tensorized_shape, rank, device=device, dtype=dtype, **kwargs)
class ComplexTuckerTensorized(ComplexHandler, TuckerTensorized, name='ComplexTucker'):
    """Complex TuckerTensorized Factorization

    Combines ``ComplexHandler`` with ``TuckerTensorized``. The class is declared
    with ``name='ComplexTucker'``; that keyword is consumed by a base class's
    subclass hook (presumably a factorization registry — confirm in
    ``tensorized_matrices``).
    """

    # Attribute names handed to ComplexHandler: the Tucker core and its factors.
    # Assumed to be the parameters treated as complex-valued — confirm against
    # ComplexHandler.
    _complex_params = ['core', 'factors']

    @classmethod
    def new(cls, tensorized_shape, rank=None, device=None, dtype=torch.cfloat, **kwargs):
        """Build a new instance through the parent class's ``new``.

        ``tensorized_shape`` and ``rank`` are passed positionally. ``device``,
        ``dtype`` (default ``torch.cfloat``) and any extra keyword arguments are
        passed as keywords, in that order.
        """
        # Keep the same keyword order as a direct call: device, dtype, then extras.
        options = {'device': device, 'dtype': dtype, **kwargs}
        return super().new(tensorized_shape, rank, **options)
class ComplexBlockTT(ComplexHandler, BlockTT, name='ComplexTT'):
    """Complex BlockTT Factorization

    Combines ``ComplexHandler`` with ``BlockTT``. Note that the declared name is
    ``'ComplexTT'`` (not ``'ComplexBlockTT'``). That keyword is consumed by a
    base class's subclass hook (presumably a factorization registry — confirm
    in ``tensorized_matrices``), so it is kept as-is for callers that may look
    it up.
    """

    # Attribute names handed to ComplexHandler: the TT cores. Assumed to be the
    # parameters treated as complex-valued — confirm against ComplexHandler.
    _complex_params = ['factors']

    @classmethod
    def new(cls, tensorized_shape, rank=None, device=None, dtype=torch.cfloat, **kwargs):
        """Build a new instance through the parent class's ``new``.

        All arguments are forwarded unchanged. ``dtype`` defaults to
        ``torch.cfloat``, which makes the factorization complex unless
        overridden.
        """
        build = super().new
        return build(tensorized_shape, rank, device=device, dtype=dtype, **kwargs)
class ComplexCPTensorized(ComplexHandler, CPTensorized, name='ComplexCP'):
    """Complex Tensorized CP Factorization

    Combines ``ComplexHandler`` with ``CPTensorized``. The class is declared
    with ``name='ComplexCP'``; that keyword is consumed by a base class's
    subclass hook (presumably a factorization registry — confirm in
    ``tensorized_matrices``).
    """

    # Attribute names handed to ComplexHandler: the CP weights and factors.
    # Assumed to be the parameters treated as complex-valued — confirm against
    # ComplexHandler.
    _complex_params = ['weights', 'factors']

    @classmethod
    def new(cls, tensorized_shape, rank=None, device=None, dtype=torch.cfloat, **kwargs):
        """Build a new instance through the parent class's ``new``.

        Parameters
        ----------
        tensorized_shape
            Forwarded positionally.
        rank
            Forwarded positionally; defaults to ``None``.
        device, dtype
            Forwarded as keywords; ``dtype`` defaults to ``torch.cfloat``.
        **kwargs
            Extra keyword arguments, forwarded after ``device`` and ``dtype``.
        """
        forwarded = {'device': device, 'dtype': dtype}
        forwarded.update(kwargs)
        return super(ComplexCPTensorized, cls).new(tensorized_shape, rank, **forwarded)