"""
Higher Order Convolution with CP decomposition
"""
# Author: Jean Kossaifi
# License: BSD 3 clause
from ._base_conv import Conv1D, BaseFactorizedConv
from .. import init
from tensorly import validate_cp_rank
import torch
import torch.nn as nn
import torch.nn.functional as F
import tensorly as tl
from tensorly.decomposition import parafac
from tensorly import random
from tensorly import tenalg
tl.set_backend('pytorch')
class CPConv(BaseFactorizedConv):
    """Create a Factorized CP convolution of arbitrary order.
    Parameters
    ----------
    
    in_channels : int
    out_channels : int
    kernel_size : int or int list
        if int, order MUST be specified
        if int list, then the conv will use order = len(kernel_size)
    rank : int
        rank of the factorized kernel
    order : int, optional if kernel_size is a list
        see kernel_size
    implementation : {'factorized', 'reconstructed', 'mobilenet'}
        Strategy to use for the forward pass. Options are:

        * factorized : the CP convolution is expressed as a series of 1D convolutions
        * reconstructed : the full kernel is reconstructed from the decomposition
          and used to perform a regular forward pass
        * mobilenet : the equivalent formulation of CP as a MobileNet block
          (1x1 conv, depthwise convolutions, 1x1 conv) is used
        
    stride : int, default is 1
    padding : int, default is 0
    dilation : int, default is 1

    Attributes
    ----------
    kernel_shape : int tuple
        shape of the kernel weight parametrizing the full convolution
    rank : int
        rank of the CP decomposition

    References
    ----------
    .. [1] Vadim Lebedev, Yaroslav Ganin, Maksim Rakhuba, Ivan V. Oseledets, and Victor S. Lempitsky.
        Speeding-up convolutional neural networks using fine-tuned CP-decomposition. In ICLR, 2015.
    .. [2] Jean Kossaifi, Antoine Toisoul, Adrian Bulat, Yannis Panagakis, Timothy M. Hospedales, Maja Pantic.
        Factorized higher-order CNNs with an application to spatio-temporal emotion estimation.
        In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2020, pp. 6060-6069.

    See Also
    --------
    TuckerConv
    TTConv
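
    Examples
    --------
    A minimal usage sketch (hypothetical shapes; assumes the base class
    dispatches the forward pass according to ``implementation``):

    >>> conv = CPConv(16, 32, kernel_size=(3, 3), rank=8)
    >>> x = torch.randn(4, 16, 28, 28)
    >>> y = conv(x)  # same output shape as a regular Conv2d with padding=0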
    """
    def __init__(self, in_channels, out_channels, kernel_size, rank, order=None,
                 implementation='reconstructed', stride=1, padding=0, dilation=1, bias=False):
        super().__init__(in_channels, out_channels, kernel_size, rank, order=order, padding=padding, stride=stride, bias=bias)
        self.rank = validate_cp_rank(self.kernel_shape, rank=self.rank)
        self.weights = nn.Parameter(torch.Tensor(self.rank))
        self.factors = nn.ParameterList([nn.Parameter(torch.Tensor(s, self.rank)) for s in self.kernel_shape])
        self.init_from_random(decompose_full_weight=False)
    def init_from_random(self, decompose_full_weight=True):
        """Initialize the factorized convolution's parameter randomly
        
        Parameters
        ----------
        decompose_full_weight : bool 
            If True, a full weight is randomly created and decomposed to intialize the parameters (slower)
            Otherwise, the parameters are initialized directly (faster) so the reconstruction has a set variance. 
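
        Examples
        --------
        A sketch of the two strategies, assuming ``conv`` is an existing ``CPConv``:

        >>> conv.init_from_random(decompose_full_weight=False)  # direct initialization of the factors
        >>> conv.init_from_random(decompose_full_weight=True)   # decompose a randomly drawn full weight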
        """
        if self.bias is not None:
            self.bias.data.zero_()
        if decompose_full_weight:
            full_weight = torch.normal(0.0, 0.02, size=self.kernel_shape)
            self.init_from_tensor(full_weight)
        else:
            init.cp_init(self.weights, self.factors) 
    def init_from_decomposition(self, cp_tensor, bias=None):
        """Transpose the factors from a CP Tensor to the factorized version
    
        Parameters
        ----------
        factors : cp_tensor
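
        Examples
        --------
        A minimal sketch, initializing from a random CP tensor of matching shape and rank:

        >>> cp_tensor = random.random_cp(conv.kernel_shape, rank=conv.rank)
        >>> conv.init_from_decomposition(cp_tensor)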
        """
        shape, rank = tl.cp_tensor._validate_cp_tensor(cp_tensor)
        if shape != self.kernel_shape:
            raise ValueError(f'Expected a shape of {self.kernel_shape} but got {shape}.')
        if rank != self.rank:
            raise ValueError(f'Expected a cp_tensor of rank {self.rank} but got {rank}.')
        weights, factors = cp_tensor
    
        with torch.no_grad():
            for i, f in enumerate(factors):
                self.factors[i].data = f
            self.weights.data = weights
            if self.bias is not None and bias is not None:
                self.bias.data = bias.data 
    def init_from_tensor(self, kernel_tensor, bias=None, decomposition_kwargs=None):
        """Initialize the factorized convolutional layer from a full tensor

        Parameters
        ----------
        kernel_tensor : torch.Tensor
            full kernel tensor of shape ``self.kernel_shape``, decomposed with ``parafac``
        bias : torch.Tensor, optional
        decomposition_kwargs : dict, optional
            keyword arguments passed on to ``parafac``
        """
        if decomposition_kwargs is None:
            decomposition_kwargs = {}
        with torch.no_grad():
            cp_tensor = parafac(kernel_tensor, rank=self.rank, **decomposition_kwargs)
            self.init_from_decomposition(cp_tensor, bias=bias) 
    def get_decomposition(self, return_bias=False):
        """Returns a CP Tensor parametrizing the convolution
        
        Parameters
        ----------
        return_bias : bool, default is False
            if True also return the bias
        Returns
        -------
        weights, factors, bias
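
        Examples
        --------
        A usage sketch; the decomposition can be expanded back to the full kernel:

        >>> (weights, factors), bias = conv.get_decomposition(return_bias=True)
        >>> full_kernel = tl.cp_to_tensor((weights, factors))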
        """
        if return_bias:
            if self.bias is not None:
                bias = nn.Parameter(self.bias.data)
            else:
                bias = None
            return (self.weights, self.factors), bias
        else:
            return (self.weights, self.factors) 
    def transduct(self, kernel_size, mode=0, padding=0, stride=1):
        """Transduction of the factorized convolution to add a new dimension
        Parameters
        ----------
        kernel_size : int
            size of the additional dimension
        mode : where to insert the new dimension, after the channels, default is 0
            by default, insert the new dimensions before the existing ones
            (e.g. add time before height and width)
        padding : int, default is 0
        stride : int: default is 1
        Returns
        -------
        self
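
        Examples
        --------
        A sketch turning a 2D factorized conv into a 3D one, inserting a new
        (e.g. temporal) dimension of size 3 before height and width:

        >>> conv = conv.transduct(kernel_size=3, mode=0, padding=1)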
        """
        (weights, factors), bias = self.get_decomposition(return_bias=True)
        self.order += 1
        self.padding = self.padding[:mode] + (padding,) + self.padding[mode:]
        self.stride = self.stride[:mode] + (stride,) + self.stride[mode:]
        self.kernel_size = self.kernel_size[:mode] + (kernel_size,) + self.kernel_size[mode:]
        self.kernel_shape = self.kernel_shape[:mode+2] + (kernel_size,) + self.kernel_shape[mode+2:]
        factors = list(factors)
        # The new factor is a centred delta kernel so that, at initialization,
        # the transducted convolution behaves like the original one
        new_factor = torch.zeros(kernel_size, self.rank)
        new_factor[kernel_size//2, :] = 1
        # mode + 2: skip the factors for the output and input channels
        factors.insert(mode+2, nn.Parameter(new_factor.to(self.factors[0].device)))
        self.init_from_decomposition((weights, factors), bias)
        return self  
    

class CPConvFactorized(CPConv):
    """Create a convolution of arbitrary order, parametrized with CP and
    computed as a series of 1D convolutions

    Parameters
    ----------
    in_channels : int
    out_channels : int
    kernel_size : int or int list
        if int, order MUST be specified
        if int list, then the conv will use order = len(kernel_size)
    rank : int
        rank of the factorized kernel
    order : int, optional if kernel_size is a list
        see kernel_size
    stride : int, default is 1
    padding : int, default is 0
    dilation : int, default is 1

    Attributes
    ----------
    kernel_shape : int tuple
        shape of the kernel weight parametrizing the full convolution
    rank : int
        rank of the CP decomposition
    """
    def __init__(self, in_channels, out_channels, kernel_size, rank, order=None, 
                 stride=1, padding=0, dilation=1, bias=False):
        super().__init__(in_channels, out_channels, kernel_size, rank, order=order, padding=padding, stride=stride, bias=bias)
        self.n_conv = len(self.kernel_shape)
    
    def forward(self, x):
        """Perform a factorized CP convolution
        Parameters
        ----------
        x : torch.tensor
            tensor of shape (batch_size, C, I_2, I_3, ..., I_N)
        Returns
        -------
        NDConv(x) with an CP kernel
        """
        weights, factors = self._process_decomposition()
        _, rank = tl.cp_tensor._validate_cp_tensor((weights, factors))
        batch_size = x.shape[0]
        # Change the number of channels to the rank
        x_shape = list(x.shape)
        x = x.reshape((batch_size, x_shape[1], -1)).contiguous()
        # First conv == tensor contraction
        # from (in_channels, rank) to (rank == out_channels, in_channels, 1)
        x = F.conv1d(x, tl.transpose(factors[1]).unsqueeze(2))
        x_shape[1] = rank
        x = x.reshape(x_shape)
        # convolve over non-channels
        for i in range(self.order):
            # From (kernel_size, rank) to (rank, 1, kernel_size)
            kernel = tl.transpose(factors[i+2]).unsqueeze(1)             
            x = Conv1D(x.contiguous(), kernel, i+2, stride=self.stride[i], padding=self.padding[i], groups=rank)
        # Revert back number of channels from rank to output_channels
        x_shape = list(x.shape)
        x = x.reshape((batch_size, x_shape[1], -1))                
        # Last conv == tensor contraction
        # From (out_channels, rank) to (out_channels, in_channels == rank, 1)
        x = F.conv1d(x*weights.unsqueeze(1).unsqueeze(0), factors[0].unsqueeze(2))
        if self.bias is not None:
            x += self.bias.unsqueeze(0).unsqueeze(2)
        x_shape[1] = self.out_channels
        x = x.reshape(x_shape)
        return x


class CPConvMobileNet(CPConv):
    """Create a convolution of arbitrary order, parametrized with CP and
    computed using the equivalent MobileNet-block formulation

    Parameters
    ----------
    in_channels : int
    out_channels : int
    kernel_size : int or int list
        if int, order MUST be specified
        if int list, then the conv will use order = len(kernel_size)
    rank : int
        rank of the factorized kernel
    order : int, optional if kernel_size is a list
        see kernel_size
    stride : int, default is 1
    padding : int, default is 0
    dilation : int, default is 1

    Attributes
    ----------
    kernel_shape : int tuple
        shape of the kernel weight parametrizing the full convolution
    rank : int
        rank of the CP decomposition
    """
    def __init__(self, in_channels, out_channels, kernel_size, rank, order=None, 
                 stride=1, padding=0, dilation=1, bias=False):
        super().__init__(in_channels, out_channels, kernel_size, rank, order=order, padding=padding, stride=stride, bias=bias)
        self.n_conv = len(self.kernel_shape)
    def forward(self, x):
        """Perform a factorized CP convolution
        Parameters
        ----------
        x : torch.tensor
            tensor of shape (batch_size, C, I_2, I_3, ..., I_N)
        Returns
        -------
        NDConv(x) with an CP kernel
        """
        weights, factors = self._process_decomposition()
        _, rank = tl.cp_tensor._validate_cp_tensor((weights, factors))
        batch_size = x.shape[0]
        # Change the number of channels to the rank
        x_shape = list(x.shape)
        x = x.reshape((batch_size, x_shape[1], -1)).contiguous()
        # First conv == tensor contraction
        # from (in_channels, rank) to (rank == out_channels, in_channels, 1)
        x = F.conv1d(x, tl.transpose(factors[1]).unsqueeze(2))
        x_shape[1] = rank
        x = x.reshape(x_shape)
        # Depthwise convolutions over the spatial dimensions
        # Each factor is transposed from (kernel_size, rank) to (rank, kernel_size)
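        # For order >= 2, batched_tensor_dot builds the outer product of the
        # transposed factors, one separable N-D kernel per rank component;
        # unsqueeze(1) then adds the singleton channel dim needed for groups=rank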
        if self.order == 1:
            weight = tl.transpose(factors[2]).unsqueeze(1)
            x = F.conv1d(x.contiguous(), weight, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=rank)
        elif self.order == 2:
            weight = tenalg.batched_tensor_dot(tl.transpose(factors[2]), tl.transpose(factors[3])).unsqueeze(1)
            x = F.conv2d(x.contiguous(), weight, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=rank)
        elif self.order == 3:
            weight = tenalg.batched_tensor_dot(tl.transpose(factors[2]),
                        tenalg.batched_tensor_dot(tl.transpose(factors[3]), tl.transpose(factors[4]))).unsqueeze(1)
            x = F.conv3d(x.contiguous(), weight, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=rank)
        else:
            raise ValueError(f'{self.__class__.__name__} is currently implemented only for 1D to 3D convs, but got order={self.order}.')
        # Revert back number of channels from rank to output_channels
        x_shape = list(x.shape)
        x = x.reshape((batch_size, x_shape[1], -1))
        # Last conv == tensor contraction
        # From (out_channels, rank) to (out_channels, in_channels == rank, 1)
        x = F.conv1d(x*weights.unsqueeze(1).unsqueeze(0), factors[0].unsqueeze(2))
        if self.bias is not None:
            x += self.bias.unsqueeze(0).unsqueeze(2)
        x_shape[1] = self.out_channels
        x = x.reshape(x_shape)
        return x


class CPConvReconstructed(CPConv):
    """Create a convolution of arbitrary order, parametrized with CP and
    computed using the kernel reconstructed from the decomposition

    Parameters
    ----------
    in_channels : int
    out_channels : int
    kernel_size : int or int list
        if int, order MUST be specified
        if int list, then the conv will use order = len(kernel_size)
    rank : int
        rank of the factorized kernel
    order : int, optional if kernel_size is a list
        see kernel_size
    stride : int, default is 1
    padding : int, default is 0
    dilation : int, default is 1

    Attributes
    ----------
    kernel_shape : int tuple
        shape of the kernel weight parametrizing the full convolution
    rank : int
        rank of the CP decomposition
    """
    def __init__(self, in_channels, out_channels, kernel_size, rank, order=None, 
                 stride=1, padding=0, dilation=1, bias=False):
        super().__init__(in_channels, out_channels, kernel_size, rank, order=order, padding=padding, stride=stride, bias=bias)
        self.n_conv = len(self.kernel_shape)
        if self.order == 1: 
            self.conv = F.conv1d
        elif self.order == 2:
            self.conv = F.conv2d
        elif self.order == 3:
            self.conv = F.conv3d
        else:
            raise ValueError(f'{self.__class__.__name__} is currently implemented only for 1D to 3D convs, but got order={self.order}.')
    def forward(self, x):
        """Perform a convolution using the reconstructed full weightss
        Parameters
        ----------
        x : torch.tensor
            tensor of shape (batch_size, C, I_2, I_3, ..., I_N)
        Returns
        -------
        NDConv(x) with a CP kernel
        """
        rec = tl.cp_to_tensor(self._process_decomposition())
        return self.conv(x, rec, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation)
    def transduct(self, *args, **kwargs):
        super().transduct(*args, **kwargs)
        if self.order == 1:
            self.conv = F.conv1d
        elif self.order == 2:
            self.conv = F.conv2d
        elif self.order == 3:
            self.conv = F.conv3d
        else:
            raise ValueError(f'{self.__class__.__name__} is currently implemented only for 1D to 3D convs, but got order={self.order}.')
        return self