tltorch.tensor_hooks.remove_tensor_lasso

tltorch.tensor_hooks.remove_tensor_lasso(factorized_tensor)[source]

Removes the tensor lasso from a TensorModule

Parameters:
factorized_tensor : tltorch.FactorizedTensor

the tensor module parametrized by the tensor decomposition from which to remove the tensor lasso

Examples

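>>> from tltorch import FactorizedTensor
>>> from tltorch.tensor_hooks import tensor_lasso, remove_tensor_lasso
>>> tensor = FactorizedTensor.new((3, 4, 2), rank=0.5, factorization='CP').normal_()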
>>> tensor = tensor_lasso(tensor, p=0.5)
>>> remove_tensor_lasso(tensor)