# Source code for tltorch.tensor_hooks._tensor_lasso

import tensorly as tl
tl.set_backend('pytorch')

import warnings
import torch
from torch import nn
from torch.nn import functional as F

from ..factorized_tensors import TuckerTensor, TTTensor, CPTensor
from ..utils import ParameterList

# Author: Jean Kossaifi
# License: BSD 3 clause

class TensorLasso:
    """Generalized Tensor Lasso on factorized tensors

    Applies a generalized Lasso (l1 regularization) on a factorized tensor.

    This is the base class: each concrete subclass declares the factorized
    tensor class it handles (``class X(TensorLasso, factorization=SomeTensor)``)
    and is registered automatically, so that it can be retrieved with
    :meth:`from_factorization` or :meth:`from_factorization_name`.

    Parameters
    ----------
    penalty : float, default is 0.01
        scaling factor for the loss

    clamp_weights : bool, default is True
        if True, the lasso weights are clamped between -1 and 1

    threshold : float, default is 1e-6
        if a lasso weight is lower than the set threshold, it is set to 0

    normalize_loss : bool, default is True
        If True, the accumulated loss is divided by the number of lasso
        weights seen so far (with clamped weights, the unscaled loss then
        lies between 0 and 1).
        Otherwise, the raw accumulated sum is returned.

    Examples
    --------

    First you need to create an instance of the regularizer:

    >>> regularizer = tensor_lasso(factorization='cp')

    You can apply the regularizer to one or several layers:

    >>> trl = TRL((5, 5), (5, 5), rank='same')
    >>> trl2 = TRL((5, 5), (2, ), rank='same')
    >>> regularizer.apply(trl.weight)
    >>> regularizer.apply(trl2.weight)

    The lasso is automatically applied:

    >>> x = trl(x)
    >>> pred = trl2(x)
    >>> loss = your_loss_function(pred)

    Add the Lasso loss:

    >>> loss = loss + regularizer.loss

    You can now backpropagate through your loss as usual:

    >>> loss.backward()

    After you finish updating the weights, don't forget to reset the regularizer,
    otherwise it will keep accumulating values!

    >>> regularizer.reset()

    You can also remove the regularizer with `remove_tensor_lasso(trl.weight)`.
    """
    # Registry shared by all subclasses: maps the ``__name__`` of a factorized
    # tensor class (e.g. 'CPTensor') to the TensorLasso subclass handling it.
    _factorizations = dict()

    def __init_subclass__(cls, factorization=None, **kwargs):
        """When a subclass is created, register it in _factorizations

        Parameters
        ----------
        factorization : type, optional
            factorized tensor class handled by the new subclass; its
            ``__name__`` is used as the registry key. When omitted (e.g. when
            further specializing an existing lasso class), nothing is registered.
        """
        # Chain to the parent hook so cooperative multiple inheritance keeps working.
        super().__init_subclass__(**kwargs)
        if factorization is not None:
            cls._factorizations[factorization.__name__] = cls

    def __init__(self, penalty=0.01, clamp_weights=True, threshold=1e-6, normalize_loss=True):
        self.penalty = penalty
        self.clamp_weights = clamp_weights
        self.threshold = threshold
        self.normalize_loss = normalize_loss

        # Initialize the counters
        self.reset()

    def reset(self):
        """Reset the loss, should be called at the end of each iteration.
        """
        self._loss = 0
        self.n_element = 0

    @property
    def loss(self):
        """Returns the current Lasso (l1) loss for the layers that have been called so far.

        Returns
        -------
        float or torch.Tensor
            ``penalty`` times the accumulated l1 term, divided by the number of
            lasso weights seen if ``normalize_loss`` is True.
            Returns 0 (and warns) if nothing has been accumulated yet.
        """
        if self.n_element == 0:
            warnings.warn('The L1Regularization was not applied to any weights.')
            return 0
        elif self.normalize_loss:
            return self.penalty*self._loss/self.n_element
        else:
            return self.penalty*self._loss

    def __call__(self, module, input, tucker_tensor):
        """Forward hook; implemented by the factorization-specific subclasses."""
        raise NotImplementedError

    def apply_lasso(self, tucker_tensor, lasso_weights):
        """Applies the lasso to a decomposed tensor
        """
        raise NotImplementedError

    @classmethod
    def from_factorization(cls, factorization, penalty=0.01, clamp_weights=True, threshold=1e-6, normalize_loss=True):
        """Create the lasso matching the class of the given factorized tensor instance."""
        return cls.from_factorization_name(factorization.__class__.__name__, penalty=penalty,
                                           clamp_weights=clamp_weights, threshold=threshold, normalize_loss=normalize_loss)

    @classmethod
    def from_factorization_name(cls, factorization_name, penalty=0.01, clamp_weights=True, threshold=1e-6, normalize_loss=True):
        """Create the lasso registered for a factorized tensor class name (e.g. 'CPTensor').

        Raises
        ------
        KeyError
            if no lasso subclass is registered under `factorization_name`.
        """
        try:
            lasso_cls = cls._factorizations[factorization_name]
        except KeyError:
            raise KeyError(
                f'No TensorLasso registered for factorization {factorization_name!r}; '
                f'available: {sorted(cls._factorizations)}.') from None
        return lasso_cls(penalty=penalty, clamp_weights=clamp_weights,
                         threshold=threshold, normalize_loss=normalize_loss)

    def remove(self, module):
        """Remove the lasso weights from `module`; implemented by subclasses."""
        raise NotImplementedError


class CPLasso(TensorLasso, factorization=CPTensor):
    """Decomposition Hook for Tensor Lasso on CP tensors

    Parameters
    ----------
    penalty : float, default is 0.01
        scaling factor for the loss

    clamp_weights : bool, default is True
        if True, the lasso weights are clamped between -1 and 1

    threshold : float, default is 1e-6
        if a lasso weight is lower than the set threshold, it is set to 0

    normalize_loss : bool, default is True
        If True, the loss will be between 0 and 1.
        Otherwise, the raw sum of absolute weights will be returned.
    """
    def __call__(self, module, input, cp_tensor):
        """Forward hook: project the lasso weights, then accumulate their l1 norm.

        CP already includes weights, we'll just take their l1 norm; the
        forward output `cp_tensor` is returned unchanged.
        """
        weights = getattr(module, 'lasso_weights')

        # Projection steps are not part of the autograd graph.
        with torch.no_grad():
            if self.clamp_weights:
                weights.data = torch.clamp(weights.data, -1, 1)

            if self.threshold:
                # F.threshold keeps values > threshold and writes `value` (0) elsewhere.
                F.threshold(weights.data, threshold=self.threshold, value=0, inplace=True)

        self.n_element += weights.numel()
        # The penalty is applied once, by the `loss` property (as in TuckerLasso/TTLasso);
        # multiplying by it here as well would scale the CP loss by penalty**2.
        self._loss = self._loss + torch.sum(torch.abs(weights))
        # NOTE(review): lasso_weights are not multiplied into cp_tensor here (unlike
        # TuckerLasso/TTLasso), so they only enter the loss -- confirm this is intended.
        return cp_tensor

    def apply(self, module):
        """Apply an instance of the L1Regularizer to a tensor module

        Parameters
        ----------
        module : TensorModule
            module on which to add the regularization

        Returns
        -------
        TensorModule (with Regularization hook)
        """
        # One weight per CP component, created with the factors' device/dtype.
        context = tl.context(module.factors[0])
        lasso_weights = nn.Parameter(torch.ones(module.rank, **context))
        setattr(module, 'lasso_weights', lasso_weights)

        module.register_forward_hook(self)
        return module

    def remove(self, module):
        """Delete the lasso weights from `module` (the forward hook itself is not removed)."""
        delattr(module, 'lasso_weights')

    def set_weights(self, module, value):
        """Fill every lasso weight of `module` with `value`, outside autograd."""
        with torch.no_grad():
            module.lasso_weights.data.fill_(value)


class TuckerLasso(TensorLasso, factorization=TuckerTensor):
    """Decomposition Hook for Tensor Lasso on Tucker tensors

    Applies a generalized Lasso (l1 regularization) on the tensor layers it is applied to:
    each Tucker factor is scaled column-wise by a vector of lasso weights.

    Parameters
    ----------
    penalty : float, default is 0.01
        scaling factor for the loss

    clamp_weights : bool, default is True
        if True, the lasso weights are clamped between -1 and 1

    threshold : float, default is 1e-6
        if a lasso weight is lower than the set threshold, it is set to 0

    normalize_loss : bool, default is True
        If True, the loss will be between 0 and 1.
        Otherwise, the raw sum of absolute weights will be returned.
    """
    # NOTE(review): not read or written anywhere in this class; kept for backward compatibility.
    _log = []

    def __call__(self, module, input, tucker_tensor):
        """Forward hook: project the lasso weights, accumulate their l1 norm,
        and return the Tucker tensor with the weights applied to its factors."""
        lasso_weights = getattr(module, 'lasso_weights')

        # Projection steps are done in place on the parameters' data, outside
        # autograd, so the registered Parameters themselves are never replaced.
        with torch.no_grad():
            for weight in lasso_weights:
                if self.clamp_weights:
                    weight.data = torch.clamp(weight.data, -1, 1)

                if self.threshold:
                    # F.threshold keeps values > threshold and writes `value` (0) elsewhere.
                    F.threshold(weight.data, threshold=self.threshold, value=0, inplace=True)

        for weight in lasso_weights:
            self.n_element += weight.numel()
            self._loss = self._loss + torch.sum(torch.abs(weight))

        return self.apply_lasso(tucker_tensor, lasso_weights)

    def apply_lasso(self, tucker_tensor, lasso_weights):
        """Applies the lasso to a decomposed tensor

        Each factor is multiplied by its weight vector, which broadcasts over
        the factor's last (rank) dimension.
        """
        factors = tucker_tensor.factors
        factors = [factor*w  for (factor, w) in zip(factors, lasso_weights)]
        return TuckerTensor(tucker_tensor.core, factors)

    def apply(self, module):
        """Apply an instance of the L1Regularizer to a tensor module

        Parameters
        ----------
        module : TensorModule
            module on which to add the regularization

        Returns
        -------
        TensorModule (with Regularization hook)
        """
        rank = module.rank
        # One weight vector per mode, created with the core's device/dtype.
        context = tl.context(module.core)
        lasso_weights = ParameterList([nn.Parameter(torch.ones(r, **context)) for r in rank])
        setattr(module, 'lasso_weights', lasso_weights)
        module.register_forward_hook(self)

        return module

    def remove(self, module):
        """Delete the lasso weights from `module` (the forward hook itself is not removed)."""
        delattr(module, 'lasso_weights')

    def set_weights(self, module, value):
        """Fill every lasso weight of `module` with `value`, outside autograd."""
        with torch.no_grad():
            for weight in module.lasso_weights:
                weight.data.fill_(value)


class TTLasso(TensorLasso, factorization=TTTensor):
    """Decomposition Hook for Tensor Lasso on TT tensors

    One lasso weight vector is attached to each internal TT rank: every core
    except the last is scaled along its last (outgoing rank) dimension.

    Parameters
    ----------
    penalty : float, default is 0.01
        scaling factor for the loss

    clamp_weights : bool, default is True
        if True, the lasso weights are clamped between -1 and 1

    threshold : float, default is 1e-6
        if a lasso weight is lower than the set threshold, it is set to 0

    normalize_loss : bool, default is True
        If True, the loss will be between 0 and 1.
        Otherwise, the raw sum of absolute weights will be returned.
    """
    def __call__(self, module, input, tt_tensor):
        """Forward hook: project the lasso weights, accumulate their l1 norm,
        and return the TT tensor with the weights applied to its cores."""
        lasso_weights = getattr(module, 'lasso_weights')

        # Projection steps are done in place on the parameters' data, outside
        # autograd, so the registered Parameters themselves are never replaced.
        with torch.no_grad():
            for weight in lasso_weights:
                if self.clamp_weights:
                    weight.data = torch.clamp(weight.data, -1, 1)

                if self.threshold:
                    # F.threshold keeps values > threshold and writes `value` (0) elsewhere.
                    F.threshold(weight.data, threshold=self.threshold, value=0, inplace=True)

        for weight in lasso_weights:
            self.n_element += weight.numel()
            self._loss = self._loss + torch.sum(torch.abs(weight))

        return self.apply_lasso(tt_tensor, lasso_weights)

    def apply_lasso(self, tt_tensor, lasso_weights):
        """Applies the lasso to a decomposed tensor

        There is one weight per core except the last, so `zip` stops before the
        last core, which is appended unchanged.
        """
        factors = tt_tensor.factors
        factors = [factor*w  for (factor, w) in zip(factors, lasso_weights)] + [factors[-1]]
        return TTTensor(factors)

    def apply(self, module):
        """Apply an instance of the L1Regularizer to a tensor module

        Parameters
        ----------
        module : TensorModule
            module on which to add the regularization

        Returns
        -------
        TensorModule (with Regularization hook)
        """
        # Boundary ranks are dropped: only internal ranks get weights
        # (presumably rank is (1, r_1, ..., r_{N-1}, 1) -- TODO confirm).
        rank = module.rank[1:-1]
        # Create the weights with the cores' device/dtype, like CPLasso/TuckerLasso,
        # so GPU or non-float32 modules do not get mismatched CPU float32 weights.
        context = tl.context(module.factors[0])
        # Shape (1, 1, r) broadcasts over a core's last (rank) dimension.
        lasso_weights = ParameterList([nn.Parameter(torch.ones(1, 1, r, **context)) for r in rank])
        setattr(module, 'lasso_weights', lasso_weights)
        module.register_forward_hook(self)
        return module

    def remove(self, module):
        """Remove the Regularization from a module.

        Only the lasso weights are deleted; the forward hook itself is not removed.
        """
        delattr(module, 'lasso_weights')

    def set_weights(self, module, value):
        """Fill every lasso weight of `module` with `value`, outside autograd."""
        with torch.no_grad():
            for weight in module.lasso_weights:
                weight.data.fill_(value)


def tensor_lasso(factorization='CP', penalty=0.01, clamp_weights=True, threshold=1e-6, normalize_loss=True):
    """Generalized Tensor Lasso from a factorized tensors

    Applies a generalized Lasso (l1 regularization) on a factorized tensor.

    Parameters
    ----------
    factorization : str
        one of 'cp', 'tucker' or 'tt' (case insensitive)

    penalty : float, default is 0.01
        scaling factor for the loss

    clamp_weights : bool, default is True
        if True, the lasso weights are clamped between -1 and 1

    threshold : float, default is 1e-6
        if a lasso weight is lower than the set threshold, it is set to 0

    normalize_loss : bool, default is True
        If True, the loss will be between 0 and 1.
        Otherwise, the raw sum of absolute weights will be returned.

    Returns
    -------
    TensorLasso
        the lasso instance registered for the requested factorization

    Raises
    ------
    KeyError
        if `factorization` is not one of the supported names

    Examples
    --------

    Let's say you have a set of factorized (here, CP) tensors:

    >>> tensor = FactorizedTensor.new((3, 4, 2), rank='same', factorization='CP').normal_()
    >>> tensor2 = FactorizedTensor.new((5, 6, 7), rank=0.5, factorization='CP').normal_()

    First you need to create an instance of the regularizer:

    >>> regularizer = tensor_lasso(factorization='cp', penalty=0.01)

    You can apply the regularizer to one or several layers:

    >>> regularizer.apply(tensor)
    >>> regularizer.apply(tensor2)

    The lasso is automatically applied:

    >>> out = torch.sum(tensor() + tensor2())

    You can access the Lasso loss from your instance:

    >>> l1_loss = regularizer.loss

    You can optimize and backpropagate through your loss as usual.

    After you finish updating the weights, don't forget to reset the regularizer,
    otherwise it will keep accumulating values!

    >>> regularizer.reset()

    You can remove the regularizer with `remove_tensor_lasso(tensor)`, which deletes
    both the lasso weights and the forward hook. (`regularizer.remove(tensor)` only
    deletes the lasso weights and leaves the hook registered.)
    """
    factorization = factorization.lower()
    mapping = dict(cp='CPTensor', tucker='TuckerTensor', tt='TTTensor')
    try:
        factorization_name = mapping[factorization]
    except KeyError:
        raise KeyError(f'Unknown factorization {factorization!r} for tensor lasso; '
                       f'expected one of {sorted(mapping)}.') from None
    return TensorLasso.from_factorization_name(factorization_name, penalty=penalty,
                                               clamp_weights=clamp_weights, threshold=threshold,
                                               normalize_loss=normalize_loss)
def remove_tensor_lasso(factorized_tensor):
    """Removes the tensor lasso from a TensorModule

    Both the lasso weights (via the lasso's `remove`) and the forward hook that
    applies them are removed. Only the first TensorLasso hook found is removed.

    Parameters
    ----------
    factorized_tensor : tltorch.FactorizedTensor
        the tensor module parametrized by the tensor decomposition to which the
        tensor lasso was applied

    Returns
    -------
    tltorch.FactorizedTensor
        the same module, without the tensor lasso

    Raises
    ------
    ValueError
        if no TensorLasso hook is registered on `factorized_tensor`

    Examples
    --------
    >>> tensor = FactorizedTensor.new((3, 4, 2), rank=0.5, factorization='CP').normal_()
    >>> lasso = tensor_lasso(factorization='cp')
    >>> tensor = lasso.apply(tensor)
    >>> remove_tensor_lasso(tensor)
    """
    # Iterate over a snapshot, since the matching entry is deleted from the dict below.
    for key, hook in list(factorized_tensor._forward_hooks.items()):
        if isinstance(hook, TensorLasso):
            hook.remove(factorized_tensor)
            del factorized_tensor._forward_hooks[key]
            return factorized_tensor

    raise ValueError(f'TensorLasso not found in factorized tensor {factorized_tensor}')