import torch
import tensorly as tl
tl.set_backend('pytorch')
from tltorch.factorized_tensors.tensorized_matrices import TuckerTensorized, DenseTensorized, CPTensorized, BlockTT
from .complex_factorized_tensors import ComplexHandler
# Author: Jean Kossaifi
# License: BSD 3 clause
class ComplexDenseTensorized(ComplexHandler, DenseTensorized, name='ComplexDense'):
    """Complex DenseTensorized Factorization

    Complex-valued counterpart of ``DenseTensorized``. It is declared with the
    class keyword ``name='ComplexDense'``, which the parent class hierarchy
    consumes (presumably to register it for lookup by factorization name).
    """

    # Names of the parameters handed to ComplexHandler; presumably the ones it
    # stores/treats as complex -- confirm in complex_factorized_tensors.
    _complex_params = ['tensor']

    @classmethod
    def new(cls, tensorized_shape, rank=None, device=None, dtype=torch.cfloat, **kwargs):
        """Create a new complex dense tensorized factorization.

        Thin wrapper around the parent ``new``. It only changes the default
        ``dtype`` to ``torch.cfloat``, so the result is complex-valued unless
        the caller asks otherwise.

        Parameters
        ----------
        tensorized_shape
            Tensorized shape, forwarded unchanged to the parent ``new``.
        rank : optional
            Forwarded positionally to the parent ``new``; default is None.
        device : optional
            Forwarded to the parent ``new``; default is None.
        dtype : torch.dtype, default is torch.cfloat
            Data type forwarded to the parent ``new``.
        **kwargs
            Any extra keyword arguments, forwarded to the parent ``new``.

        Returns
        -------
        Whatever the parent ``new`` returns (presumably an instance of ``cls``).
        """
        return super().new(tensorized_shape, rank, device=device, dtype=dtype, **kwargs)


class ComplexTuckerTensorized(ComplexHandler, TuckerTensorized, name='ComplexTucker'):
    """Complex TuckerTensorized Factorization

    Complex-valued counterpart of ``TuckerTensorized``. It is declared with the
    class keyword ``name='ComplexTucker'``, which the parent class hierarchy
    consumes (presumably to register it for lookup by factorization name).
    """

    # Names of the parameters handed to ComplexHandler (the Tucker core and
    # factors); presumably the ones it stores/treats as complex -- confirm in
    # complex_factorized_tensors.
    _complex_params = ['core', 'factors']

    @classmethod
    def new(cls, tensorized_shape, rank=None, device=None, dtype=torch.cfloat, **kwargs):
        """Create a new complex Tucker tensorized factorization.

        Thin wrapper around the parent ``new``. It only changes the default
        ``dtype`` to ``torch.cfloat``, so the result is complex-valued unless
        the caller asks otherwise.

        Parameters
        ----------
        tensorized_shape
            Tensorized shape, forwarded unchanged to the parent ``new``.
        rank : optional
            Forwarded positionally to the parent ``new``; default is None.
        device : optional
            Forwarded to the parent ``new``; default is None.
        dtype : torch.dtype, default is torch.cfloat
            Data type forwarded to the parent ``new``.
        **kwargs
            Any extra keyword arguments, forwarded to the parent ``new``.

        Returns
        -------
        Whatever the parent ``new`` returns (presumably an instance of ``cls``).
        """
        return super().new(tensorized_shape, rank, device=device, dtype=dtype, **kwargs)


# NOTE(review): the sibling classes register as "Complex" + their class name
# minus "Tensorized" (ComplexDense, ComplexTucker, ComplexCP), which would make
# this one 'ComplexBlockTT'. It is kept as 'ComplexTT' because callers may
# already look it up by that name -- confirm which name is intended.
class ComplexBlockTT(ComplexHandler, BlockTT, name='ComplexTT'):
    """Complex BlockTT Factorization

    Complex-valued counterpart of ``BlockTT``. It is declared with the class
    keyword ``name='ComplexTT'``, which the parent class hierarchy consumes
    (presumably to register it for lookup by factorization name).
    """

    # Names of the parameters handed to ComplexHandler; presumably the ones it
    # stores/treats as complex -- confirm in complex_factorized_tensors.
    _complex_params = ['factors']

    @classmethod
    def new(cls, tensorized_shape, rank=None, device=None, dtype=torch.cfloat, **kwargs):
        """Create a new complex block tensor-train factorization.

        Thin wrapper around the parent ``new``. It only changes the default
        ``dtype`` to ``torch.cfloat``, so the result is complex-valued unless
        the caller asks otherwise.

        Parameters
        ----------
        tensorized_shape
            Tensorized shape, forwarded unchanged to the parent ``new``.
        rank : optional
            Forwarded positionally to the parent ``new``; default is None.
        device : optional
            Forwarded to the parent ``new``; default is None.
        dtype : torch.dtype, default is torch.cfloat
            Data type forwarded to the parent ``new``.
        **kwargs
            Any extra keyword arguments, forwarded to the parent ``new``.

        Returns
        -------
        Whatever the parent ``new`` returns (presumably an instance of ``cls``).
        """
        return super().new(tensorized_shape, rank, device=device, dtype=dtype, **kwargs)


class ComplexCPTensorized(ComplexHandler, CPTensorized, name='ComplexCP'):
    """Complex Tensorized CP Factorization

    Complex-valued counterpart of ``CPTensorized``. It is declared with the
    class keyword ``name='ComplexCP'``, which the parent class hierarchy
    consumes (presumably to register it for lookup by factorization name).
    """

    # Names of the parameters handed to ComplexHandler (the CP weights and
    # factors); presumably the ones it stores/treats as complex -- confirm in
    # complex_factorized_tensors.
    _complex_params = ['weights', 'factors']

    @classmethod
    def new(cls, tensorized_shape, rank=None, device=None, dtype=torch.cfloat, **kwargs):
        """Create a new complex CP tensorized factorization.

        Thin wrapper around the parent ``new``. It only changes the default
        ``dtype`` to ``torch.cfloat``, so the result is complex-valued unless
        the caller asks otherwise.

        Parameters
        ----------
        tensorized_shape
            Tensorized shape, forwarded unchanged to the parent ``new``.
        rank : optional
            Forwarded positionally to the parent ``new``; default is None.
        device : optional
            Forwarded to the parent ``new``; default is None.
        dtype : torch.dtype, default is torch.cfloat
            Data type forwarded to the parent ``new``.
        **kwargs
            Any extra keyword arguments, forwarded to the parent ``new``.

        Returns
        -------
        Whatever the parent ``new`` returns (presumably an instance of ``cls``).
        """
        return super().new(tensorized_shape, rank, device=device, dtype=dtype, **kwargs)