# Source code for tensorly.decomposition._tr_svd

import tensorly as tl
from ._base_decomposition import DecompositionMixin
from ..tr_tensor import validate_tr_rank, TRTensor
from ..tenalg.svd import svd_interface


def tensor_ring(input_tensor, rank, mode=0, svd="truncated_svd", verbose=False):
    """Tensor Ring decomposition via recursive SVD

    Decomposes `input_tensor` into a sequence of order-3 tensors (factors) [1]_.

    Parameters
    ----------
    input_tensor : tensorly.tensor
    rank : Union[int, List[int]]
        maximum allowable TR rank of the factors
        if int, then this is the same for all the factors
        if int list, then rank[k] is the rank of the kth factor
    mode : int, default is 0
        index of the first factor to compute; must satisfy
        ``0 <= mode < input_tensor.ndim``
    svd : str, default is 'truncated_svd'
        function to use to compute the SVD, acceptable values in tensorly.SVD_FUNS
    verbose : boolean, optional
        level of verbosity; when truthy, the shape of each factor is printed
        as soon as it is computed

    Returns
    -------
    factors : TR factors
        order-3 tensors of the TR decomposition, wrapped in a ``TRTensor``.
        Factor ``k`` has shape ``(rank[k], input_tensor.shape[k], rank[k + 1])``.
        Ranks other than the first pair may be reduced so that they never
        exceed the dimensions of the corresponding unfolding.

    Raises
    ------
    ValueError
        If `mode` is not a valid mode index of `input_tensor`, or if
        ``rank[mode] * rank[mode + 1]`` exceeds the smaller dimension of the
        first unfolding (the first factor cannot then be computed).

    References
    ----------
    .. [1] Qibin Zhao et al. "Tensor Ring Decomposition" arXiv preprint arXiv:1606.05535, (2016).
    """
    # Work on a private list: the loop below overwrites entries of `rank`.
    rank = list(validate_tr_rank(tl.shape(input_tensor), rank=rank))
    n_dim = len(tl.shape(input_tensor))

    if not 0 <= mode < n_dim:
        raise ValueError(
            f"mode={mode} is out of range for a tensor of order {n_dim}; "
            f"expected 0 <= mode < {n_dim}."
        )

    # Rotate the modes (and the ranks along with them) so that `mode` is
    # processed first; the factors are rotated back before returning.
    if mode:
        order = tuple(range(mode, n_dim)) + tuple(range(mode))
        input_tensor = tl.transpose(input_tensor, order)
        rank = rank[mode:] + rank[:mode]

    tensor_size = tl.shape(input_tensor)
    factors = [None] * n_dim

    # First factor: SVD of the mode-0 unfolding. rank[0] * rank[1] singular
    # vectors are kept so they can be split over both boundary ranks.
    unfolding = tl.reshape(input_tensor, (tensor_size[0], -1))
    n_row, n_column = tl.shape(unfolding)
    if rank[0] * rank[1] > min(n_row, n_column):
        raise ValueError(
            f"rank[{mode}] * rank[{mode + 1}] = {rank[0] * rank[1]} is larger than "
            f"first matricization dimension {n_row}×{n_column}.\n"
            "Failed to compute first factor with specified rank. "
            "Reduce specified ranks or change first matricization `mode`."
        )

    U, S, V = svd_interface(unfolding, n_eigenvecs=rank[0] * rank[1], method=svd)

    # Split U's columns into (rank[0], rank[1]) and move rank[0] to the front
    # to get the (rank[0], tensor_size[0], rank[1]) layout of a TR factor.
    factor = tl.reshape(U, (tensor_size[0], rank[0], rank[1]))
    factors[0] = tl.transpose(factor, (1, 0, 2))
    if verbose:
        # Report the shape of the stored (transposed) factor, like the others.
        print(f"TR factor {mode} computed with shape {factors[0].shape}")

    # Carry S @ V forward. rank[0] is moved to the last axis so that it
    # becomes the trailing rank of the last factor, closing the ring.
    unfolding = tl.reshape(S, (-1, 1)) * V
    unfolding = tl.reshape(unfolding, (rank[0], rank[1], -1))
    unfolding = tl.transpose(unfolding, (1, 2, 0))

    # Middle factors 1 .. n_dim - 2.
    for k in range(1, n_dim - 1):
        n_row = int(rank[k] * tensor_size[k])
        unfolding = tl.reshape(unfolding, (n_row, -1))

        n_row, n_column = tl.shape(unfolding)
        # Cap the rank by the unfolding's dimensions and store the (possibly
        # reduced) value so the next reshape uses the actual rank.
        current_rank = min(n_row, n_column, rank[k + 1])
        U, S, V = svd_interface(unfolding, n_eigenvecs=current_rank, method=svd)
        rank[k + 1] = current_rank

        factors[k] = tl.reshape(U, (rank[k], tensor_size[k], rank[k + 1]))
        if verbose:
            print(
                f"TR factor {(mode + k) % n_dim} computed with shape {factors[k].shape}"
            )

        unfolding = tl.reshape(S, (-1, 1)) * V

    # Last factor: the remainder, with rank[0] as its trailing rank.
    prev_rank = tl.shape(unfolding)[0]
    factors[-1] = tl.reshape(unfolding, (prev_rank, -1, rank[0]))
    if verbose:
        print(
            f"TR factor {(mode - 1) % n_dim} computed with shape {factors[-1].shape}"
        )

    # Undo the initial rotation so that factors[i] matches mode i of the input.
    if mode:
        factors = factors[-mode:] + factors[:-mode]

    return TRTensor(factors)
class TensorRing(DecompositionMixin):
    """Tensor Ring decomposition via recursive SVD

    Estimator-style wrapper around :func:`tensor_ring`: the decomposition
    settings are stored at construction time and applied by
    :meth:`fit_transform`, which decomposes a tensor into a sequence of
    order-3 tensors (factors) [1]_.

    Parameters
    ----------
    rank : Union[int, List[int]]
        maximum allowable TR rank of the factors
        if int, then this is the same for all the factors
        if int list, then rank[k] is the rank of the kth factor
    mode : int, default is 0
        index of the first factor to compute
    svd : str, default is 'truncated_svd'
        function to use to compute the SVD, acceptable values in tensorly.SVD_FUNS
    verbose : boolean, optional
        level of verbosity, forwarded to :func:`tensor_ring`

    Attributes
    ----------
    decomposition_ : TRTensor
        result of the most recent call to :meth:`fit_transform`

    References
    ----------
    .. [1] Qibin Zhao et al. "Tensor Ring Decomposition" arXiv preprint arXiv:1606.05535, (2016).
    """

    def __init__(self, rank, mode=0, svd="truncated_svd", verbose=False):
        # Hyper-parameters are stored as given; validation happens in tensor_ring.
        self.rank, self.mode, self.svd, self.verbose = rank, mode, svd, verbose

    def fit_transform(self, tensor):
        """Decompose `tensor` with the stored settings.

        Parameters
        ----------
        tensor : tensorly.tensor
            tensor to decompose

        Returns
        -------
        TRTensor
            the TR factors, also stored in ``self.decomposition_``
        """
        settings = {
            "rank": self.rank,
            "mode": self.mode,
            "svd": self.svd,
            "verbose": self.verbose,
        }
        result = tensor_ring(tensor, **settings)
        self.decomposition_ = result
        return result