# Source code for tltorch.factorized_tensors.complex_factorized_tensors


import torch
from torch import nn

import tensorly as tl
tl.set_backend('pytorch')
from tltorch.factorized_tensors.factorized_tensors import TuckerTensor, CPTensor, TTTensor, DenseTensor
from tltorch.utils.parameter_list import FactorList, ComplexFactorList


# Author: Jean Kossaifi
# License: BSD 3 clause

class ComplexHandler():
    """Mixin that stores complex-valued tensors as real views.

    Complex parameters and buffers are registered as their
    ``torch.view_as_real`` counterparts (an extra trailing dimension of size 2
    holding real and imaginary parts). When read back as attributes, they are
    converted with ``torch.view_as_complex``. ``FactorList`` values are wrapped
    in a ``ComplexFactorList``.

    It must come before a ``torch.nn.Module`` subclass in the MRO (e.g.
    ``class ComplexDenseTensor(ComplexHandler, DenseTensor, ...)``). It
    delegates the actual storage to ``super()``'s ``__setattr__``,
    ``__getattr__``, ``register_parameter`` and ``register_buffer``.
    """

    def __setattr__(self, key, value):
        """Dispatch an attribute assignment based on the type of ``value``.

        * ``FactorList`` -> wrapped in ``ComplexFactorList``, then set normally.
        * ``nn.Parameter`` -> :meth:`register_parameter` (stored as a real view).
        * any other tensor -> :meth:`register_buffer` (stored as a real view).
        * anything else -> default attribute assignment.
        """
        if isinstance(value, FactorList):
            super().__setattr__(key, ComplexFactorList(value))
        # nn.Parameter is itself a tensor, so it must be tested before the
        # generic tensor branch below.
        elif isinstance(value, nn.Parameter):
            self.register_parameter(key, value)
        elif torch.is_tensor(value):
            self.register_buffer(key, value)
        else:
            super().__setattr__(key, value)

    def __getattr__(self, key):
        """Return attribute ``key``, converting stored real views back to complex.

        Python only calls ``__getattr__`` when normal lookup fails. For an
        ``nn.Module``, that is the case for registered parameters and buffers,
        which this mixin stores as real views.
        """
        value = super().__getattr__(key)
        if torch.is_tensor(value):
            value = torch.view_as_complex(value)
        return value

    def register_parameter(self, key, value):
        """Register complex ``value`` as a parameter stored as its real view.

        Parameters
        ----------
        key : str
            Name of the parameter.
        value : complex tensor, nn.Parameter or None
            ``None`` is forwarded unchanged, since ``nn.Module`` accepts
            ``None`` placeholders. For an incoming ``nn.Parameter``, its
            ``requires_grad`` flag is preserved. Any other tensor gets
            ``requires_grad=True``, as before.
        """
        if value is not None:
            requires_grad = (value.requires_grad
                             if isinstance(value, nn.Parameter) else True)
            value = nn.Parameter(torch.view_as_real(value),
                                 requires_grad=requires_grad)
        super().register_parameter(key, value)

    def register_buffer(self, key, value, persistent=True):
        """Register complex ``value`` as a buffer stored as its real view.

        Parameters
        ----------
        key : str
            Name of the buffer.
        value : complex tensor or None
            ``None`` is forwarded unchanged.
        persistent : bool, default True
            Forwarded to ``nn.Module.register_buffer``. A non-persistent buffer
            is excluded from ``state_dict``.
        """
        if value is not None:
            value = torch.view_as_real(value)
        super().register_buffer(key, value, persistent=persistent)


class ComplexDenseTensor(ComplexHandler, DenseTensor, name='ComplexDense'):
    """Complex Dense Factorization

    ``DenseTensor`` with :class:`ComplexHandler` mixed in. Complex parameters
    and buffers are stored as real views.
    """

    @classmethod
    def new(cls, shape, rank=None, device=None, dtype=torch.cfloat, **kwargs):
        """Create a new complex dense tensor.

        Same as ``DenseTensor.new``, except that ``dtype`` defaults to
        ``torch.cfloat``.
        """
        return super().new(shape, rank, device=device, dtype=dtype, **kwargs)


class ComplexTuckerTensor(ComplexHandler, TuckerTensor, name='ComplexTucker'):
    """Complex Tucker Factorization

    ``TuckerTensor`` with :class:`ComplexHandler` mixed in. Complex parameters
    and buffers are stored as real views.
    """

    @classmethod
    def new(cls, shape, rank='same', fixed_rank_modes=None,
            device=None, dtype=torch.cfloat, **kwargs):
        """Create a new complex Tucker tensor.

        Same as ``TuckerTensor.new``, except that ``dtype`` defaults to
        ``torch.cfloat``. ``fixed_rank_modes`` is forwarded to the parent.
        """
        return super().new(shape, rank, fixed_rank_modes=fixed_rank_modes,
                           device=device, dtype=dtype, **kwargs)


class ComplexTTTensor(ComplexHandler, TTTensor, name='ComplexTT'):
    """Complex TT Factorization

    ``TTTensor`` with :class:`ComplexHandler` mixed in. Complex parameters and
    buffers are stored as real views.
    """

    @classmethod
    def new(cls, shape, rank='same', fixed_rank_modes=None,
            device=None, dtype=torch.cfloat, **kwargs):
        """Create a new complex TT tensor.

        Same as ``TTTensor.new``, except that ``dtype`` defaults to
        ``torch.cfloat``.

        ``fixed_rank_modes`` is accepted but NOT forwarded to the parent, so it
        currently has no effect. Presumably it is kept so the signature matches
        the Tucker variant; confirm before relying on it.
        """
        return super().new(shape, rank, device=device, dtype=dtype, **kwargs)


class ComplexCPTensor(ComplexHandler, CPTensor, name='ComplexCP'):
    """Complex CP Factorization

    ``CPTensor`` with :class:`ComplexHandler` mixed in. Complex parameters and
    buffers are stored as real views.
    """

    @classmethod
    def new(cls, shape, rank='same', fixed_rank_modes=None,
            device=None, dtype=torch.cfloat, **kwargs):
        """Create a new complex CP tensor.

        Same as ``CPTensor.new``, except that ``dtype`` defaults to
        ``torch.cfloat``.

        ``fixed_rank_modes`` is accepted but NOT forwarded to the parent, so it
        currently has no effect. Presumably it is kept so the signature matches
        the Tucker variant; confirm before relying on it.
        """
        return super().new(shape, rank, device=device, dtype=dtype, **kwargs)