Source code for tltorch.factorized_layers.factorized_linear

import math
import numpy as np
from torch import nn
import torch
from torch.utils import checkpoint

from ..functional import factorized_linear
from ..factorized_tensors import TensorizedTensor
from tltorch.utils import get_tensorized_shape

# Author: Jean Kossaifi
# License: BSD 3 clause


class FactorizedLinear(nn.Module):
    """Tensorized Fully-Connected Layers

    The weight matrix is tensorized to a tensor of size
    `(*out_tensorized_features, *in_tensorized_features)`.
    That tensor is expressed as a low-rank tensor.

    During inference, the full tensor is reconstructed, and unfolded back into a matrix,
    used for the forward pass in a regular linear layer.

    Parameters
    ----------
    in_tensorized_features : int tuple
        shape to which the input_features dimension is tensorized,
        e.g. if in_features is 8, in_tensorized_features could be (2, 2, 2);
        should verify prod(in_tensorized_features) == in_features
    out_tensorized_features : int tuple
        shape to which the output_features dimension is tensorized;
        should verify prod(out_tensorized_features) == out_features
    factorization : str, default is 'cp'
    rank : int tuple or str
    implementation : {'factorized', 'reconstructed'}, default is 'factorized'
        which implementation to use for the forward function:
        - if 'factorized', the input is directly contracted with the factors of the decomposition
        - if 'reconstructed', the full weight matrix is reconstructed from the factorized
          version and used for a regular linear layer forward pass
    n_layers : int, default is 1
        number of linear layers to be parametrized with a single factorized tensor
    bias : bool, default is True
    checkpointing : bool
        whether to enable gradient checkpointing to save memory during training-mode forward,
        default is False
    device : PyTorch device to use, default is None
    dtype : PyTorch dtype, default is None
    """
    def __init__(self, in_tensorized_features, out_tensorized_features, bias=True,
                 factorization='cp', rank='same', implementation='factorized', n_layers=1,
                 checkpointing=False, device=None, dtype=None):
        super().__init__()
        if factorization == 'TTM' and n_layers != 1:
            raise ValueError(f'TTM factorization only supports single factorized layers but got n_layers={n_layers}.')

        self.in_features = np.prod(in_tensorized_features)
        self.out_features = np.prod(out_tensorized_features)
        self.in_tensorized_features = in_tensorized_features
        self.out_tensorized_features = out_tensorized_features
        self.tensorized_shape = out_tensorized_features + in_tensorized_features
        self.weight_shape = (self.out_features, self.in_features)
        self.input_rank = rank
        self.implementation = implementation
        self.checkpointing = checkpointing

        if bias:
            if n_layers == 1:
                self.bias = nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
                self.has_bias = True
            else:
                self.bias = nn.Parameter(torch.empty((n_layers, self.out_features), device=device, dtype=dtype))
                self.has_bias = np.zeros(n_layers)
        else:
            self.register_parameter('bias', None)

        self.rank = rank
        self.n_layers = n_layers
        if n_layers > 1:
            tensor_shape = (n_layers, out_tensorized_features, in_tensorized_features)
        else:
            tensor_shape = (out_tensorized_features, in_tensorized_features)

        if isinstance(factorization, TensorizedTensor):
            self.weight = factorization.to(device).to(dtype)
        else:
            self.weight = TensorizedTensor.new(tensor_shape, rank=rank, factorization=factorization,
                                               device=device, dtype=dtype)
            self.reset_parameters()

        self.rank = self.weight.rank

    def reset_parameters(self):
        with torch.no_grad():
            self.weight.normal_(0, math.sqrt(5) / math.sqrt(self.in_features))
            if self.bias is not None:
                fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
                bound = 1 / math.sqrt(fan_in)
                torch.nn.init.uniform_(self.bias, -bound, bound)
    def forward(self, x, indices=0):
        if self.n_layers == 1:
            if indices == 0:
                weight, bias = self.weight(), self.bias
            else:
                raise ValueError(f'Only one linear layer was parametrized (n_layers=1) but tried to access {indices}.')

        elif isinstance(self.n_layers, int):
            if not isinstance(indices, int):
                raise ValueError(f'Expected indices to be an int but got indices={indices}'
                                 f', while this layer was created with n_layers={self.n_layers}.')
            weight = self.weight(indices)
            bias = self.bias[indices] if self.bias is not None else None

        elif len(indices) != len(self.n_layers):
            raise ValueError(f'Got indices={indices}, but this layer was created with n_layers={self.n_layers}.')
        else:
            weight = self.weight(indices)
            bias = self.bias[indices] if self.bias is not None else None

        # weight() is computed outside _inner_forward to avoid registered hooks
        # being executed twice during checkpoint recomputation
        def _inner_forward(x):
            return factorized_linear(x, weight, bias=bias, in_features=self.in_features,
                                     implementation=self.implementation)

        if self.checkpointing and x.requires_grad:
            x = checkpoint.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)

        return x
    def get_linear(self, indices):
        if self.n_layers == 1:
            raise ValueError('A single linear is parametrized, directly use the main class.')
        return SubFactorizedLinear(self, indices)

    def __getitem__(self, indices):
        return self.get_linear(indices)
    @classmethod
    def from_linear(cls, linear, rank='same', auto_tensorize=True, n_tensorized_modes=3,
                    in_tensorized_features=None, out_tensorized_features=None, bias=True,
                    factorization='CP', implementation="reconstructed", checkpointing=False,
                    decomposition_kwargs=dict(), verbose=False):
        """Class method to create an instance from an existing linear layer

        Parameters
        ----------
        linear : torch.nn.Linear
            layer to tensorize
        auto_tensorize : bool, default is True
            if True, automatically find values for the tensorized shapes
        n_tensorized_modes : int, default is 3
            order (number of dims) of the tensorized weights if auto_tensorize is True
        in_tensorized_features, out_tensorized_features : tuple
            shape to which the weight matrix is tensorized;
            must verify np.prod(tensorized_shape) == np.prod(linear.weight.shape)
        factorization : str, default is 'CP'
        implementation : str
            which implementation to use for the forward function,
            'factorized' or 'reconstructed', default is 'reconstructed'
        checkpointing : bool
            whether to enable gradient checkpointing to save memory during training-mode forward,
            default is False
        rank : {rank of the decomposition, 'same', float}
            if float, percentage of parameters of the original weights to use;
            if 'same', use the same number of parameters
        bias : bool, default is True
        verbose : bool, default is False
        """
        out_features, in_features = linear.weight.shape

        if auto_tensorize:
            if out_tensorized_features is not None and in_tensorized_features is not None:
                raise ValueError(
                    "Either use auto_tensorize or specify out_tensorized_features and in_tensorized_features.")

            in_tensorized_features, out_tensorized_features = get_tensorized_shape(
                in_features=in_features, out_features=out_features, order=n_tensorized_modes,
                min_dim=2, verbose=verbose)
        else:
            assert out_features == np.prod(out_tensorized_features)
            assert in_features == np.prod(in_tensorized_features)

        instance = cls(in_tensorized_features, out_tensorized_features, bias=bias,
                       factorization=factorization, rank=rank, implementation=implementation,
                       n_layers=1, checkpointing=checkpointing,
                       device=linear.weight.device, dtype=linear.weight.dtype)

        instance.weight.init_from_matrix(linear.weight.data, **decomposition_kwargs)

        if bias and linear.bias is not None:
            instance.bias.data = linear.bias.data

        return instance
    @classmethod
    def from_linear_list(cls, linear_list, in_tensorized_features, out_tensorized_features, rank,
                         bias=True, factorization='CP', implementation="reconstructed",
                         checkpointing=False, decomposition_kwargs=dict(init='random')):
        """Class method to create an instance from a list of existing linear layers,
        jointly parametrized with a single factorized tensor

        Parameters
        ----------
        linear_list : list of torch.nn.Linear
            layers to tensorize
        in_tensorized_features, out_tensorized_features : tuple
            shape to which each weight matrix is tensorized;
            must verify np.prod(tensorized_shape) == np.prod(linear.weight.shape)
        factorization : str, default is 'CP'
        implementation : str
            which implementation to use for the forward function,
            'factorized' or 'reconstructed', default is 'reconstructed'
        checkpointing : bool
            whether to enable gradient checkpointing to save memory during training-mode forward,
            default is False
        rank : {rank of the decomposition, 'same', float}
            if float, percentage of parameters of the original weights to use;
            if 'same', use the same number of parameters
        bias : bool, default is True
        """
        if factorization == 'TTM' and len(linear_list) > 1:
            raise ValueError(f'TTM factorization only supports single factorized layers but got {len(linear_list)} layers.')

        for linear in linear_list:
            out_features, in_features = linear.weight.shape
            assert out_features == np.prod(out_tensorized_features)
            assert in_features == np.prod(in_tensorized_features)

        instance = cls(in_tensorized_features, out_tensorized_features, bias=bias,
                       factorization=factorization, rank=rank, implementation=implementation,
                       n_layers=len(linear_list), checkpointing=checkpointing,
                       device=linear.weight.device, dtype=linear.weight.dtype)

        weight_tensor = torch.stack([layer.weight.data for layer in linear_list])
        instance.weight.init_from_matrix(weight_tensor, **decomposition_kwargs)

        if bias:
            for i, layer in enumerate(linear_list):
                if layer.bias is not None:
                    instance.bias.data[i] = layer.bias.data
                    instance.has_bias[i] = 1

        return instance
    def __repr__(self):
        msg = (f'{self.__class__.__name__}(in_features={self.in_features}, out_features={self.out_features},'
               f' weight of size ({self.out_features}, {self.in_features}) tensorized to'
               f' ({self.out_tensorized_features}, {self.in_tensorized_features}),'
               f' factorization={self.weight._name}, rank={self.rank}, implementation={self.implementation}')
        if self.bias is None:
            msg += ', bias=False'
        if self.n_layers == 1:
            msg += ', with a single layer parametrized, '
            return msg
        msg += f' with {self.n_layers} layers jointly parametrized.'
        return msg
class SubFactorizedLinear(nn.Module):
    """Class representing one of the linear layers from the mother joint factorized linear layer

    Parameters
    ----------
    main_linear : FactorizedLinear
        the jointly parametrized layer this sub-layer is taken from
    indices : int
        index of this sub-layer within the main factorized layer

    Notes
    -----
    This relies on the fact that nn.Parameters are not duplicated:
    if the same nn.Parameter is assigned to multiple modules,
    they all point to the same data, which is shared.
    """
    def __init__(self, main_linear, indices):
        super().__init__()
        self.main_linear = main_linear
        self.indices = indices

    def forward(self, x):
        return self.main_linear(x, self.indices)

    def extra_repr(self):
        msg = f'in_features={self.main_linear.in_features}, out_features={self.main_linear.out_features}'
        if self.main_linear.has_bias[self.indices]:
            msg += ', bias=True'
        return msg

    def __repr__(self):
        msg = f' {self.__class__.__name__} {self.indices} from main factorized layer.'
        msg += f'\n{self.__class__.__name__}('
        msg += self.extra_repr()
        msg += ')'
        return msg
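

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the library source). It shows
# constructing a FactorizedLinear from tensorized shapes and tensorizing an
# existing nn.Linear via `from_linear`; the shapes, rank and demo function name
# are arbitrary choices for the example.
def _demo_factorized_linear():
    # 64 input features tensorized as 4 * 4 * 4, 128 output features as 8 * 4 * 4
    fact_linear = FactorizedLinear(in_tensorized_features=(4, 4, 4),
                                   out_tensorized_features=(8, 4, 4),
                                   factorization='cp', rank='same')
    x = torch.randn(16, 64)
    y = fact_linear(x)  # output of shape (16, 128)

    # Tensorize an existing dense layer; auto_tensorize picks the tensorized shapes
    dense = nn.Linear(64, 128)
    tensorized = FactorizedLinear.from_linear(dense, rank='same', auto_tensorize=True)
    return y, tensorized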
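

# A second sketch (again an illustration, not library code): several linear
# layers of the same shape jointly parametrized by one factorized tensor via
# `from_linear_list`, with individual layers accessed by indexing.
def _demo_joint_parametrization():
    layers = [nn.Linear(64, 128) for _ in range(3)]
    joint = FactorizedLinear.from_linear_list(layers,
                                              in_tensorized_features=(4, 4, 4),
                                              out_tensorized_features=(8, 4, 4),
                                              rank='same')
    x = torch.randn(16, 64)
    y0 = joint(x, indices=0)   # forward through the first parametrized layer
    first = joint[0]           # SubFactorizedLinear sharing the joint parameters
    y0_again = first(x)
    return y0, y0_again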