# Source code for tltorch.factorized_layers.tensor_regression_layers

"""Tensor Regression Layers
"""

# Author: Jean Kossaifi
# License: BSD 3 clause

import torch
import torch.nn as nn

import tensorly as tl
tl.set_backend('pytorch')
from ..functional.tensor_regression import trl

from ..factorized_tensors import FactorizedTensor

class TRL(nn.Module):
    """Tensor Regression Layers

    Parameters
    ----------
    input_shape : int iterable
        shape of the input, excluding batch size
    output_shape : int iterable
        shape of the output, excluding batch size
    bias : bool, default is False
        if True, a learnable bias of shape ``output_shape`` is created
    verbose : int, default is 0
        level of verbosity
    factorization : str or FactorizedTensor, default is 'cp'
        either the name of the factorization forwarded to ``FactorizedTensor.new``
        to build the regression weight, or an already-built FactorizedTensor
        which is used as the weight as-is (moved to ``device``/``dtype``)
    rank : default is 'same'
        rank of the factorization, forwarded to ``FactorizedTensor.new``
    n_layers : int or int iterable, default is 1
        if different from 1, these extra leading dimension(s) are prepended
        to the shape of the factorized weight
    device : optional
        device of the created parameters
    dtype : optional
        dtype of the created parameters

    References
    ----------
    .. [1] Tensor Regression Networks, Jean Kossaifi, Zachary C. Lipton, Arinbjorn Kolbeinsson,
        Aran Khanna, Tommaso Furlanello, Anima Anandkumar, JMLR, 2020.
    """
    def __init__(self, input_shape, output_shape, bias=False, verbose=0,
                 factorization='cp', rank='same', n_layers=1,
                 device=None, dtype=None, **kwargs):
        super().__init__(**kwargs)
        self.verbose = verbose

        # Accept either a single int or any iterable of ints for each shape.
        self.input_shape = (input_shape, ) if isinstance(input_shape, int) else tuple(input_shape)
        self.output_shape = (output_shape, ) if isinstance(output_shape, int) else tuple(output_shape)

        self.n_input = len(self.input_shape)
        self.n_output = len(self.output_shape)
        # The regression weight maps the input modes onto the output modes.
        self.weight_shape = self.input_shape + self.output_shape
        self.order = len(self.weight_shape)

        if bias:
            self.bias = nn.Parameter(torch.empty(self.output_shape, device=device, dtype=dtype))
        else:
            self.bias = None

        # Extra leading dimension(s) of the factorized weight when n_layers != 1.
        if n_layers == 1:
            leading_dims = ()
        elif isinstance(n_layers, int):
            if n_layers < 1:
                raise ValueError(f'n_layers must be a positive integer, but got {n_layers}.')
            leading_dims = (n_layers, )
        else:
            try:
                leading_dims = tuple(n_layers)
            except TypeError:
                raise TypeError(f'n_layers must be an int or an iterable of ints, '
                                f'but got {n_layers!r}.') from None
        self.factorization_shape = leading_dims + self.weight_shape

        if isinstance(factorization, FactorizedTensor):
            # Keep the user-provided factors; only move them to the requested device/dtype.
            self.weight = factorization.to(device).to(dtype)
            if self.bias is not None:
                # The bias was allocated with torch.empty: initialise it the same way
                # init_from_random does instead of leaving uninitialised memory.
                with torch.no_grad():
                    self.bias.uniform_(-1, 1)
        else:
            self.weight = FactorizedTensor.new(self.factorization_shape, rank=rank,
                                               factorization=factorization,
                                               device=device, dtype=dtype)
            self.init_from_random()

        self.factorization = self.weight.name

    def forward(self, x):
        """Performs a forward pass.

        Parameters
        ----------
        x : torch.Tensor
            input batch; per the class docstring, each sample has shape ``input_shape``

        Returns
        -------
        the result of ``trl(x, self.weight, bias=self.bias)``
        """
        return trl(x, self.weight, bias=self.bias)

    def init_from_random(self, decompose_full_weight=False):
        """Initialize the module randomly.

        Parameters
        ----------
        decompose_full_weight : bool, default is False
            if True, constructs a full weight tensor (normal, std 0.02) and decomposes
            it to initialize the factors;
            otherwise, the factors are directly initialized randomly

        Notes
        -----
        The bias, if any, is drawn uniformly in [-1, 1].
        """
        with torch.no_grad():
            if decompose_full_weight:
                # Match the shape of the factorized weight (including any n_layers dims)
                # and the device/dtype of the module's existing parameters.
                shape = getattr(self, 'factorization_shape', self.weight_shape)
                reference = next(self.parameters(), None)
                factory_kwargs = {} if reference is None else {
                    'device': reference.device, 'dtype': reference.dtype}
                full_weight = torch.normal(0.0, 0.02, size=shape, **factory_kwargs)
                self.weight.init_from_tensor(full_weight)
            else:
                self.weight.normal_()
            if self.bias is not None:
                self.bias.uniform_(-1, 1)

    def init_from_linear(self, linear, unsqueezed_modes=None, **kwargs):
        """Initialise the TRL from the weights of a fully connected layer.

        Parameters
        ----------
        linear : torch.nn.Linear
        unsqueezed_modes : int list or None
            For Tucker factorization, this allows to replace pooling layers and instead
            learn the average pooling for the specified modes ("unsqueezed_modes").
            Each mode must be a distinct input mode in ``1 .. n_input - 1``.
            **for factorization='Tucker' only**
        **kwargs
            forwarded to ``self.weight.init_from_tensor``

        Raises
        ------
        ValueError
            if ``unsqueezed_modes`` is used with a non-Tucker factorization, or
            contains duplicates, mode 0, negative modes or non-input modes.
        """
        if unsqueezed_modes is not None:
            if self.factorization.lower() != 'tucker':
                raise ValueError(f'unsqueezed_modes is only supported for factorization="tucker" but factorization is {self.factorization}.')

            unsqueezed_modes = sorted(unsqueezed_modes)
            if len(set(unsqueezed_modes)) != len(unsqueezed_modes):
                raise ValueError(f'unsqueezed_modes must not contain duplicates, but got {unsqueezed_modes}.')

            weight_shape = list(self.weight_shape)
            # Pop from the highest mode down so the remaining indices stay valid.
            for mode in unsqueezed_modes[::-1]:
                if mode == 0:
                    raise ValueError('Cannot learn pooling for mode-0 (channels).')
                if mode < 0:
                    raise ValueError(f'unsqueezed_modes must be non-negative, but got mode {mode}.')
                # Valid input modes are 0 .. n_input - 1; mode n_input is already an output mode.
                if mode >= self.n_input:
                    msg = 'Can only learn pooling for the input tensor. '
                    msg += f'The input has only {self.n_input} modes, yet got a unsqueezed_mode for mode {mode}.'
                    raise ValueError(msg)
                weight_shape.pop(mode)
            kwargs['unsqueezed_modes'] = unsqueezed_modes
        else:
            weight_shape = self.weight_shape

        with torch.no_grad():
            # nn.Linear stores its weight as (out_features, in_features): transpose so the
            # input modes come first, matching weight_shape = input_shape + output_shape.
            weight = torch.t(linear.weight).contiguous().view(weight_shape)
            self.weight.init_from_tensor(weight, **kwargs)

            if self.bias is not None:
                if linear.bias is None:
                    # A bias-free linear layer is equivalent to a zero bias.
                    self.bias.zero_()
                else:
                    # Copy (not alias) and keep the bias in its output_shape layout.
                    self.bias.copy_(linear.bias.reshape(self.output_shape))