tltorch.tensor_hooks.tensor_lasso
- tltorch.tensor_hooks.tensor_lasso(factorization='CP', penalty=0.01, clamp_weights=True, threshold=1e-06, normalize_loss=True)
Generalized Tensor Lasso for factorized tensors
Applies a generalized Lasso (l1 regularization) on a factorized tensor.
- Parameters:
- factorization : str, default is 'CP'
type of factorization on which to apply the lasso
- penalty : float, default is 0.01
scaling factor for the loss
- clamp_weights : bool, default is True
if True, the lasso weights are clamped between -1 and 1
- threshold : float, default is 1e-6
if a lasso weight is lower than the set threshold, it is set to 0
- normalize_loss : bool, default is True
If True, the loss will be between 0 and 1. Otherwise, the raw sum of absolute weights will be returned.
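To make the interplay of these parameters concrete, here is a rough sketch of the quantity accumulated for one set of lasso weights (the vector w is a hypothetical stand-in for the internal weights, and the normalization is shown as a mean purely for illustration):
>>> import torch
>>> w = torch.randn(8)   # hypothetical stand-in for the internal lasso weights
>>> w = w.clamp(-1, 1)   # clamp_weights=True: weights clamped between -1 and 1
>>> w = torch.where(w.abs() < 1e-6, torch.zeros_like(w), w)   # threshold: small weights set to 0
>>> l1 = w.abs().mean()  # normalize_loss=True: since |w| <= 1, the mean lies in [0, 1]
>>> loss_term = 0.01 * l1   # scaled by penalty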
Examples
Let’s say you have a set of factorized (here, CP) tensors:
>>> import torch
>>> from tltorch import FactorizedTensor
>>> tensor = FactorizedTensor.new((3, 4, 2), rank='same', factorization='CP').normal_()
>>> tensor2 = FactorizedTensor.new((5, 6, 7), rank=0.5, factorization='CP').normal_()
First you need to create an instance of the regularizer:
>>> from tltorch.tensor_hooks import TensorLasso
>>> regularizer = TensorLasso(factorization='cp', penalty=0.01)
You can apply the regularizer to one or several layers:
>>> regularizer.apply(tensor)
>>> regularizer.apply(tensor2)
The lasso is automatically applied:
>>> x = torch.sum(tensor()) + torch.sum(tensor2())
You can access the Lasso loss from your instance:
>>> l1_loss = regularizer.loss
You can optimize and backpropagate through your loss as usual.
After you finish updating the weights, don’t forget to reset the regularizer, otherwise it will keep accumulating values!
>>> regularizer.reset()
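Putting the steps together, a minimal end-to-end sketch (assumptions: FactorizedTensor behaves as an nn.Module whose factors are exposed through .parameters(), and the data term below is an arbitrary toy objective):
>>> optimizer = torch.optim.SGD(tensor.parameters(), lr=0.1)
>>> for _ in range(10):
...     optimizer.zero_grad()
...     loss = torch.sum(tensor()**2) + regularizer.loss  # reconstructing tensor() triggers the hook
...     loss.backward()
...     optimizer.step()
...     regularizer.reset()  # reset so the loss does not accumulate across iterations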
You can also remove the regularizer with regularizer.remove(tensor), or remove_tensor_lasso(tensor).
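For instance (assuming remove_tensor_lasso is importable from tltorch.tensor_hooks, the module documented here):
>>> from tltorch.tensor_hooks import remove_tensor_lasso
>>> regularizer.remove(tensor)
>>> remove_tensor_lasso(tensor2)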