# Source code for tltorch.factorized_conv._tt_conv

"""
Higher Order Convolution with Tensor-Train decompositon
"""

# Author: Jean Kossaifi
# License: BSD 3 clause

from ._base_conv import Conv1D, BaseFactorizedConv
from .. import init

from tensorly import validate_tt_rank
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

import tensorly as tl
from tensorly.decomposition import tensor_train
from tensorly import random
from tensorly import tenalg
tl.set_backend('pytorch')


class TTConv(BaseFactorizedConv):
    """Create a convolution of arbitrary order with a Tensor-Train (TT) kernel.

    The kernel, of shape ``kernel_shape`` (output channels first), is
    parametrized by TT factors over the permuted shape ``tt_shape``, in which
    the output-channel dimension is moved to the end.

    Parameters
    ----------
    in_channels : int
    out_channels : int
    kernel_size : int or int list
        if int, order MUST be specified
        if int list, then the conv will use order = len(kernel_size)
    rank : int
        rank of the factorized kernel
    order : int, optional if kernel_size is a list
        see kernel_size
    implementation = {'factorized', 'reconstructed'}
        strategy to use for the forward pass
        - factorized : the TT conv is expressed as a series of 1D convolutions
        - reconstructed : full kernel is reconstructed from the decomposition.
          the reconstruction is used to perform a regular forward pass
        NOTE(review): this argument is accepted but not used by this class;
        the subclasses TTConvFactorized / TTConvReconstructed implement the
        two strategies.
    stride : int, default is 1
    padding : int, default is 0
    dilation : int, default is 1
        NOTE(review): not forwarded to ``BaseFactorizedConv.__init__`` here;
        confirm against the base-class signature before relying on it.
    bias : bool, default is False

    Attributes
    ----------
    kernel_shape : int tuple
        shape of the kernel weight parametrizing the full convolution
    tt_shape : int tuple
        ``kernel_shape`` with its first (output-channel) dimension moved last
    rank : int tuple
        rank of the TT decomposition

    See Also
    --------
    TuckerConv
    CPConv

    References
    ----------
    .. [2] Jean Kossaifi, Antoine Toisoul, Adrian Bulat, Yannis Panagakis,
       Timothy M. Hospedales, Maja Pantic;
       Proceedings of the IEEE/CVF Conference on Computer Vision and
       Pattern Recognition (CVPR), 2020, pp. 6060-6069
    """
    def __init__(self, in_channels, out_channels, kernel_size, rank, order=None,
                 implementation=None, stride=1, padding=0, dilation=1, bias=False):
        super().__init__(in_channels, out_channels, kernel_size, rank, order=order,
                         padding=padding, stride=stride, bias=bias)

        # For the TT case, the decomposition has a different shape than the
        # kernel: the output-channel dimension (first in kernel_shape) goes last.
        tt_shape = list(self.kernel_shape)
        self.n_conv = len(tt_shape)
        tt_shape.append(tt_shape.pop(0))
        self.tt_shape = tuple(tt_shape)
        self.rank = tl.tt_tensor.validate_tt_rank(self.tt_shape, rank)

        # Factor i has shape (rank[i], tt_shape[i], rank[i+1]).
        self.factors = nn.ParameterList()
        for i, s in enumerate(self.tt_shape):
            self.factors.append(nn.Parameter(torch.empty(self.rank[i], s, self.rank[i + 1])))

        # tt_shape and the factors are only defined here, not in the base
        # class, so the initialization has to happen after they are set up.
        self.init_from_random(decompose_full_weight=False)

    def init_from_tensor(self, kernel_tensor, bias=None, decomposition_kwargs=None):
        """Initialize the factorized convolutional layer from a full kernel.

        Parameters
        ----------
        kernel_tensor : torch.Tensor
            full convolutional kernel, laid out like ``self.kernel_shape``
            (output channels first)
        bias : torch.Tensor, optional
            copied into ``self.bias`` if this layer has a bias
        decomposition_kwargs : dict, optional
            extra keyword arguments forwarded to ``tensor_train``
        """
        if decomposition_kwargs is None:
            decomposition_kwargs = {}
        with torch.no_grad():
            # Put output channels at the end, matching the layout of tt_shape.
            kernel_tensor = tl.moveaxis(kernel_tensor, 0, -1)
            self.rank = validate_tt_rank(kernel_tensor.shape, rank=self.rank)
            tt_tensor = tensor_train(kernel_tensor, rank=self.rank, **decomposition_kwargs)
            self.init_from_decomposition(tt_tensor, bias=bias)

    def init_from_random(self, decompose_full_weight=True):
        """Initialize the factorized convolution's parameters randomly.

        Parameters
        ----------
        decompose_full_weight : bool
            If True, a full weight is randomly created and decomposed to
            initialize the parameters (slower).
            Otherwise, the parameters are initialized directly (faster) so
            the reconstruction has a set variance.
        """
        if self.bias is not None:
            self.bias.data.zero_()
        if decompose_full_weight:
            # init_from_tensor expects the kernel layout (output channels
            # first), i.e. kernel_shape -- not the permuted tt_shape.
            full_weight = torch.normal(0.0, 0.02, size=tuple(self.kernel_shape))
            self.init_from_tensor(full_weight)
        else:
            init.tt_init(self.factors)

    def init_from_decomposition(self, tt_tensor, bias=None):
        """Load TT factors (in ``tt_shape`` layout) into the layer's parameters.

        Parameters
        ----------
        tt_tensor : TT tensor or list of 3rd-order factors
            must reconstruct a tensor of shape ``self.tt_shape``
        bias : torch.Tensor, optional
            copied into ``self.bias`` if this layer has a bias

        Raises
        ------
        ValueError
            if the decomposition's shape differs from ``self.tt_shape``
        """
        shape, rank = tl.tt_tensor._validate_tt_tensor(tt_tensor)
        if tuple(shape) != tuple(self.tt_shape):
            raise ValueError(f'Expected a shape of {self.tt_shape} but got {shape}.')
        self.rank = rank

        with torch.no_grad():
            # Snapshot the incoming factors first: they may be the very
            # Parameter objects held in self.factors (e.g. from transduct),
            # and rebinding .data in order would otherwise overwrite a factor
            # before it is read.
            new_factors = [f.detach() for f in tt_tensor]
            for i, f in enumerate(new_factors):
                if i < len(self.factors):
                    self.factors[i].data = f
                else:
                    # The decomposition has more factors than before
                    # (e.g. after transduct added a dimension).
                    self.factors.append(nn.Parameter(f))
            if self.bias is not None and bias is not None:
                self.bias.data = bias

    def get_decomposition(self, return_bias=False):
        """Return the TT factors of the layer.

        Parameters
        ----------
        return_bias : bool, default is False
            if True also return the bias

        Returns
        -------
        factors, bias if return_bias:
        factors otherwise
        """
        if return_bias:
            return self.factors, self.bias
        else:
            return self.factors

    def full_weights(self):
        """Return the reconstructed full convolutional kernel (output channels first)."""
        factors = self.get_decomposition(return_bias=False)
        kernel = tl.tt_to_tensor(factors)
        return tl.moveaxis(kernel, -1, 0)

    @staticmethod
    def _insert_at(seq, index, value):
        """Return ``seq`` as a tuple with ``value`` inserted at ``index``."""
        return tuple(seq[:index]) + (value,) + tuple(seq[index:])

    def transduct(self, kernel_size, mode=0, padding=0, stride=1, dilation=1):
        """Transduction of the factorized convolution to add a new dimension.

        A new TT core of shape ``(r, kernel_size, r)`` is inserted, where
        every slice ``[:, i, :]`` is the identity matrix.

        Parameters
        ----------
        kernel_size : int
            size of the additional dimension
        mode : int, default is 0
            where to insert the new dimension, after the channels.
            By default, insert the new dimension before the existing ones
            (e.g. add time before height and width)
        padding : int, default is 0
        stride : int, default is 1
        dilation : int, default is 1
            only recorded if ``self.dilation`` is stored as a sequence

        Returns
        -------
        self
        """
        factors, bias = self.get_decomposition(return_bias=True)
        factors = list(factors)

        # Increase the order of the convolution
        self.order += 1
        self.n_conv += 1
        self.padding = self._insert_at(self.padding, mode, padding)
        self.stride = self._insert_at(self.stride, mode, stride)
        self.kernel_size = self._insert_at(self.kernel_size, mode, kernel_size)
        if isinstance(getattr(self, 'dilation', None), (list, tuple)):
            self.dilation = self._insert_at(self.dilation, mode, dilation)
        # kernel_shape is (out_channels, in_channels, spatial_dims...)
        self.kernel_shape = self._insert_at(self.kernel_shape, mode + 2, kernel_size)
        # tt_shape is (in_channels, spatial_dims..., out_channels)
        self.tt_shape = self._insert_at(self.tt_shape, mode + 1, kernel_size)

        # The new core sits between factors `mode` and `mode + 1`, so both of
        # its ranks equal the existing rank at that boundary (rank[0] = rank[-1] = 1).
        new_rank = self.rank[mode + 1]
        self.rank = self._insert_at(self.rank, mode + 1, new_rank)

        reference = factors[0]
        eye = torch.eye(new_rank, dtype=reference.dtype, device=reference.device)
        new_factor = eye.unsqueeze(1).repeat(1, kernel_size, 1)
        # +1 -- account for input channels
        factors.insert(mode + 1, new_factor)

        self.init_from_decomposition(factors, bias)
        return self
class TTConvFactorized(TTConv):
    """Create a convolution of arbitrary order whose forward pass is computed
    directly from the TT factors, as a series of 1D convolutions.

    Parameters
    ----------
    in_channels : int
    out_channels : int
    kernel_size : int or int list
        if int, order MUST be specified
        if int list, then the conv will use order = len(kernel_size)
    rank : int
        rank of the factorized kernel
    order : int, optional if kernel_size is a list
        see kernel_size
    stride : int, default is 1
    padding : int, default is 0
    dilation : int, default is 1
        NOTE(review): not applied by this factorized forward pass.
    bias : bool, default is False

    Attributes
    ----------
    kernel_shape : int tuple
        shape of the kernel weight parametrizing the full convolution
    rank : int tuple
        rank of the tt decomposition
    """
    def __init__(self, in_channels, out_channels, kernel_size, rank, order=None,
                 stride=1, padding=0, dilation=1, bias=False):
        super().__init__(in_channels, out_channels, kernel_size, rank, order=order,
                         padding=padding, stride=stride, dilation=dilation, bias=bias)

    def forward(self, x):
        """Perform a factorized tt convolution

        Parameters
        ----------
        x : torch.tensor
            tensor of shape (batch_size, C, I_2, I_3, ..., I_N)

        Returns
        -------
        NDConv(x) with an tt kernel
        """
        factors = self._process_decomposition()
        _, rank = tl.tt_tensor._validate_tt_tensor(factors)
        batch_size = x.shape[0]

        # First conv == tensor contraction from in_channels to rank[1]:
        # factor (1, in_channels, rank) -> 1x1 kernel (rank, in_channels, 1)
        x_shape = list(x.shape)
        x = x.reshape((batch_size, x_shape[1], -1)).contiguous()
        x = F.conv1d(x, tl.transpose(factors[0], [2, 1, 0]))
        x_shape[1] = rank[1]
        x = x.reshape(x_shape)

        # Convolve over the non-channel dimensions, one at a time.
        for i in range(self.n_conv - 2):
            # (in_rank, kernel_size, out_rank) -> (out_rank, in_rank, kernel_size)
            kernel = tl.transpose(factors[i + 1], [2, 0, 1])
            x = Conv1D(x.contiguous(), kernel, i + 2,
                       stride=self.stride[i], padding=self.padding[i])

        # Last conv == tensor contraction from rank back to out_channels:
        # factor (rank, out_channels, 1) -> 1x1 kernel (out_channels, rank, 1)
        x_shape = list(x.shape)
        x = x.reshape((batch_size, x_shape[1], -1))
        x = F.conv1d(x, tl.transpose(factors[-1], [1, 0, 2]))
        if self.bias is not None:
            x += self.bias.unsqueeze(0).unsqueeze(2)
        x_shape[1] = self.out_channels
        return x.reshape(x_shape)


class TTConvReconstructed(TTConv):
    """Create a convolution of arbitrary order whose forward pass first
    reconstructs the full kernel from the TT factors, then runs a regular
    1D, 2D or 3D convolution.

    Parameters
    ----------
    in_channels : int
    out_channels : int
    kernel_size : int or int list
        if int, order MUST be specified
        if int list, then the conv will use order = len(kernel_size)
    rank : int
        rank of the factorized kernel
    order : int, optional if kernel_size is a list
        see kernel_size
    stride : int, default is 1
    padding : int, default is 0
    dilation : int, default is 1
    bias : bool, default is False

    Attributes
    ----------
    kernel_shape : int tuple
        shape of the kernel weight parametrizing the full convolution
    rank : int tuple
        rank of the tt decomposition

    Raises
    ------
    ValueError
        if the order of the convolution is not 1, 2 or 3
    """
    # Functional convolution used for each supported order.
    _CONV_BY_ORDER = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d}

    def __init__(self, in_channels, out_channels, kernel_size, rank, order=None,
                 stride=1, padding=0, dilation=1, bias=False):
        super().__init__(in_channels, out_channels, kernel_size, rank, order=order,
                         padding=padding, stride=stride, dilation=dilation, bias=bias)
        self.conv = self._select_conv(self.order)

    def _select_conv(self, order):
        """Return the functional convolution for `order`, or raise ValueError."""
        try:
            return self._CONV_BY_ORDER[order]
        except KeyError:
            raise ValueError(f'{self.__class__.__name__} currently implemented only for 1D to 3D convs, but got {order}') from None

    def forward(self, x):
        """Perform a tt convolution with the reconstructed full kernel

        Parameters
        ----------
        x : torch.tensor
            tensor of shape (batch_size, C, I_2, I_3, ..., I_N)

        Returns
        -------
        NDConv(x) with an tt kernel
        """
        # The TT reconstruction has output channels last; move them first.
        rec = tl.moveaxis(tl.tt_to_tensor(self._process_decomposition()), -1, 0)
        return self.conv(x, rec, bias=self.bias, stride=self.stride,
                         padding=self.padding, dilation=self.dilation)

    def transduct(self, *args, **kwargs):
        """Add a new dimension (see ``TTConv.transduct``) and update the conv.

        Returns
        -------
        self
        """
        # Validate the new order before mutating any state, so an unsupported
        # order leaves the layer untouched.
        conv = self._select_conv(self.order + 1)
        super().transduct(*args, **kwargs)
        self.conv = conv
        return self