Source code for tltorch.init
"""Module for initializing tensor decompositions
"""
# Author: Jean Kossaifi
# License: BSD 3 clause
import torch
import math
import numpy as np
import tensorly as tl
tl.set_backend('pytorch')
[docs]def cp_init(weights, factors, std=0.02):
"""Initializes directly the weights and factors of a CP decomposition so the reconstruction has the specified std and 0 mean
Parameters
----------
weights : 1D tensor
factors : list of 2D factors of size (dim_i, rank)
std : float, default is 0.02
the desired standard deviation of the full (reconstructed) tensor
Notes
-----
We assume the given (weights, factors) form a correct CP decomposition, no checks are done here.
"""
rank = factors[0].shape[1] # We assume we are given a valid CP
order = len(factors)
std_factors = (std/math.sqrt(rank))**(1/order)
with torch.no_grad():
weights.fill_(1)
for i in range(len(factors)):
factors[i].normal_(0, std_factors)
return weights, factors
[docs]def tucker_init(core, factors, std=0.02):
"""Initializes directly the weights and factors of a Tucker decomposition so the reconstruction has the specified std and 0 mean
Parameters
----------
weights : 1D tensor
factors : list of 2D factors of size (dim_i, rank)
std : float, default is 0.02
the desired standard deviation of the full (reconstructed) tensor
Notes
-----
We assume the given (core, factors) form a correct Tucker decomposition, no checks are done here.
"""
order = len(factors)
rank = tl.shape(core)
r = np.prod([math.sqrt(r) for r in rank])
std_factors = (std/r)**(1/(order+1))
with torch.no_grad():
core.normal_(0, std_factors)
for i in range(len(factors)):
factors[i].normal_(0, std_factors)
return core, factors
[docs]def tt_init(factors, std=0.02):
"""Initializes directly the weights and factors of a TT decomposition so the reconstruction has the specified std and 0 mean
Parameters
----------
weights : 1D tensor
factors : list of 2D factors of size (dim_i, rank)
std : float, default is 0.02
the desired standard deviation of the full (reconstructed) tensor
Notes
-----
We assume the given factors form a correct TT decomposition, no checks are done here.
"""
order = len(factors)
r = np.prod([math.sqrt(f.shape[2]) for f in factors[:-1]])
std_factors = (std/r)**(1/order)
with torch.no_grad():
for i in range(len(factors)):
factors[i].normal_(0, std_factors)
return factors
def tt_matrix_init(factors, std=0.02):
"""Initializes directly the weights and factors of a TT-Matrix decomposition so the reconstruction has the specified std and 0 mean
Parameters
----------
weights : 1D tensor
factors : list of 2D factors of size (dim_i, rank)
std : float, default is 0.02
the desired standard deviation of the full (reconstructed) tensor
Notes
-----
We assume the given factors form a correct TT-Matrix decomposition, no checks are done here.
"""
return tt_init(factors, std=std)