Source code for tltorch.factorized_tensors.init

"""Module for initializing tensor decompositions
"""

# Author: Jean Kossaifi
# License: BSD 3 clause

import torch
import math
import numpy as np

import tensorly as tl
tl.set_backend('pytorch')

def tensor_init(tensor, std=0.02):
    """Directly initializes the parameters of a factorized tensor so that the
    reconstruction has the specified standard deviation and 0 mean

    Parameters
    ----------
    tensor : torch.Tensor or FactorizedTensor
    std : float, default is 0.02
        the desired standard deviation of the full (reconstructed) tensor
    """
    from .factorized_tensors import FactorizedTensor

    if isinstance(tensor, FactorizedTensor):
        tensor.normal_(0, std)
    elif torch.is_tensor(tensor):
        tensor.normal_(0, std)
    else:
        raise ValueError(f'Got tensor of class {tensor.__class__.__name__} '
                         'but expected torch.Tensor or FactorizedTensor.')
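A minimal usage sketch (the shape and std below are arbitrary illustrative choices, not part of the module): for a plain torch.Tensor the call reduces to an in-place normal_ draw, while a FactorizedTensor is expected to dispatch to one of the decomposition-specific initializers defined below:

    import torch
    from tltorch.factorized_tensors.init import tensor_init

    w = torch.empty(64, 32, 16)              # illustrative shape
    tensor_init(w, std=0.02)                 # in-place draw from N(0, 0.02**2)
    print(float(w.mean()), float(w.std()))   # roughly 0.0 and 0.02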
def cp_init(cp_tensor, std=0.02):
    """Directly initializes the weights and factors of a CP decomposition so that
    the reconstruction has the specified std and 0 mean

    Parameters
    ----------
    cp_tensor : CPTensor
    std : float, default is 0.02
        the desired standard deviation of the full (reconstructed) tensor

    Notes
    -----
    We assume the given (weights, factors) form a correct CP decomposition,
    no checks are done here.
    """
    rank = cp_tensor.rank  # We assume we are given a valid CP decomposition
    order = cp_tensor.order
    std_factors = (std / math.sqrt(rank)) ** (1 / order)

    with torch.no_grad():
        cp_tensor.weights.fill_(1)
        for factor in cp_tensor.factors:
            factor.normal_(0, std_factors)
    return cp_tensor
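To see where std_factors comes from: each entry of the reconstruction is a sum of rank products of order independent factor entries, so with factor entries drawn i.i.d. from N(0, s**2) the entry variance is rank * s**(2 * order); solving rank * s**(2 * order) = std**2 gives s = (std / sqrt(rank))**(1 / order). A self-contained sanity check of this scaling (pure PyTorch; the shapes, rank, and variable names are illustrative, not part of the module):

    import math
    import torch

    std, rank, order = 0.02, 16, 3
    shape = (40, 50, 60)
    s = (std / math.sqrt(rank)) ** (1 / order)        # same formula as cp_init
    factors = [torch.randn(dim, rank) * s for dim in shape]
    full = torch.einsum('ir,jr,kr->ijk', *factors)    # CP reconstruction, weights = 1
    print(float(full.std()))                          # close to 0.02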
def tucker_init(tucker_tensor, std=0.02):
    """Directly initializes the core and factors of a Tucker decomposition so that
    the reconstruction has the specified std and 0 mean

    Parameters
    ----------
    tucker_tensor : TuckerTensor
    std : float, default is 0.02
        the desired standard deviation of the full (reconstructed) tensor

    Notes
    -----
    We assume the given (core, factors) form a correct Tucker decomposition,
    no checks are done here.
    """
    order = tucker_tensor.order
    rank = tucker_tensor.rank
    r = np.prod([math.sqrt(r) for r in rank])
    std_factors = (std / r) ** (1 / (order + 1))

    with torch.no_grad():
        tucker_tensor.core.normal_(0, std_factors)
        for factor in tucker_tensor.factors:
            factor.normal_(0, std_factors)
    return tucker_tensor
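The same argument gives the Tucker formula: an entry of the reconstruction sums prod(rank) terms, each a product of order + 1 independent entries (one core entry plus one entry per factor), so the variance is prod(rank) * s**(2 * (order + 1)); solving for std yields the expression used above. An illustrative pure-PyTorch check (shapes and ranks are arbitrary choices):

    import math
    import numpy as np
    import torch

    std, rank, shape = 0.02, (8, 9, 10), (40, 50, 60)
    order = len(shape)
    s = float((std / np.prod([math.sqrt(r) for r in rank])) ** (1 / (order + 1)))  # as in tucker_init
    core = torch.randn(*rank) * s
    factors = [torch.randn(dim, r) * s for dim, r in zip(shape, rank)]
    full = torch.einsum('abc,ia,jb,kc->ijk', core, *factors)   # Tucker reconstruction
    print(float(full.std()))                                   # close to 0.02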
def tt_init(tt_tensor, std=0.02):
    """Directly initializes the factors of a TT decomposition so that the
    reconstruction has the specified std and 0 mean

    Parameters
    ----------
    tt_tensor : TTTensor
    std : float, default is 0.02
        the desired standard deviation of the full (reconstructed) tensor

    Notes
    -----
    We assume the given factors form a correct TT decomposition,
    no checks are done here.
    """
    order = tt_tensor.order
    r = np.prod(tt_tensor.rank)
    std_factors = (std / r) ** (1 / order)

    with torch.no_grad():
        for factor in tt_tensor.factors:
            factor.normal_(0, std_factors)
    return tt_tensor
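For completeness, an analogous pure-PyTorch sketch for the TT case (the shapes, ranks, and the boundary ranks of 1 are illustrative choices): each entry of the reconstruction sums np.prod(rank) products of order core entries, so drawing every core with the std_factors above yields a reconstruction whose std is roughly sqrt(np.prod(rank)) * std_factors**order:

    import numpy as np
    import torch

    std, shape, rank = 0.02, (40, 50, 60), (1, 8, 8, 1)   # rank includes the boundary ranks of 1
    order = len(shape)
    s = float((std / np.prod(rank)) ** (1 / order))        # same formula as tt_init
    cores = [torch.randn(rank[i], shape[i], rank[i + 1]) * s for i in range(order)]
    full = torch.einsum('aib,bjc,ckd->ijk', *cores)        # TT reconstruction
    print(float(full.std()))                               # roughly sqrt(np.prod(rank)) * s**order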
def block_tt_init(block_tt, std=0.02):
    """Directly initializes the factors of a BlockTT decomposition so that the
    reconstruction has the specified std and 0 mean

    Parameters
    ----------
    block_tt : BlockTT
        matrix in the tensor-train format
    std : float, default is 0.02
        the desired standard deviation of the full (reconstructed) tensor

    Notes
    -----
    We assume the given factors form a correct Block-TT decomposition,
    no checks are done here.
    """
    return tt_init(block_tt, std=std)