tltorch.tensor_hooks.remove_tensor_dropout

tltorch.tensor_hooks.remove_tensor_dropout(factorized_tensor)

Removes the tensor dropout that was applied to a factorized tensor (TensorModule) with tensor_dropout.

Parameters:
factorized_tensor : tltorch.FactorizedTensor

the tensor module, parametrized by a tensor decomposition, from which to remove tensor dropout

Examples

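>>> from tltorch import FactorizedTensor
>>> from tltorch.tensor_hooks import tensor_dropout, remove_tensor_dropout
>>> # Create a CP-factorized tensor and initialize its factors
>>> tensor = FactorizedTensor.new((3, 4, 2), rank=0.5, factorization='CP').normal_()
>>> tensor = tensor_dropout(tensor, p=0.5)  # attach tensor dropout
>>> remove_tensor_dropout(tensor)           # detach it again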
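The pure-Python toy below illustrates the add/remove pattern conceptually; it is not tltorch's implementation. It assumes, as the module name tensor_hooks suggests, that tensor dropout is attached to the factorized tensor as a forward hook kept in a hook registry, so removing it amounts to deleting that hook. The factorized tensor is reduced to a list of per-component CP weights, and dropout zeroes whole rank components. ToyModule, ToyDropout, toy_tensor_dropout and toy_remove_tensor_dropout are hypothetical names.

```python
from collections import OrderedDict
import random


class ToyModule:
    """Hypothetical stand-in for a factorized tensor: a CP-style tensor
    reduced to its per-component weights, plus a forward-hook registry."""

    def __init__(self, weights):
        self.weights = list(weights)
        self._forward_hooks = OrderedDict()  # hook id -> hook callable
        self._next_hook_id = 0

    def register_forward_hook(self, hook):
        hook_id = self._next_hook_id
        self._forward_hooks[hook_id] = hook
        self._next_hook_id += 1
        return hook_id

    def __call__(self):
        # Each registered hook may replace the output, in registration order.
        output = list(self.weights)
        for hook in self._forward_hooks.values():
            result = hook(self, output)
            if result is not None:
                output = result
        return output


class ToyDropout:
    """Hypothetical dropout hook: zeroes each rank component with probability p."""

    def __init__(self, p, seed=0):
        self.p = p
        self.rng = random.Random(seed)

    def __call__(self, module, output):
        return [0.0 if self.rng.random() < self.p else w for w in output]


def toy_tensor_dropout(module, p):
    """Attach a dropout hook (hypothetical analogue of tensor_dropout)."""
    module.register_forward_hook(ToyDropout(p))
    return module


def toy_remove_tensor_dropout(module):
    """Delete the first dropout hook found (hypothetical analogue of remove_tensor_dropout)."""
    for hook_id, hook in module._forward_hooks.items():
        if isinstance(hook, ToyDropout):
            # Safe: iteration stops immediately after the deletion.
            del module._forward_hooks[hook_id]
            return module
    raise ValueError("no tensor dropout hook found on this module")


m = toy_tensor_dropout(ToyModule([1.0, 2.0, 3.0]), p=1.0)
print(m())  # p=1.0 drops every component: [0.0, 0.0, 0.0]
toy_remove_tensor_dropout(m)
print(m())  # dropout removed: [1.0, 2.0, 3.0]
```

Because the dropout lives in a separate hook rather than in the weights themselves, removing it leaves the underlying parameters untouched; only the forward output changes.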