tltorch.tensor_hooks.remove_tensor_dropout
- tltorch.tensor_hooks.remove_tensor_dropout(factorized_tensor)
Removes the tensor dropout from a factorized tensor (TensorModule).
- Parameters:
- factorized_tensor : tltorch.FactorizedTensor
the tensor module, parametrized by a tensor decomposition, from which to remove tensor dropout
Examples
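>>> from tltorch import FactorizedTensor
>>> from tltorch.tensor_hooks import tensor_dropout, remove_tensor_dropout
>>> tensor = FactorizedTensor.new((3, 4, 2), rank=0.5, factorization='CP').normal_()
>>> tensor = tensor_dropout(tensor, p=0.5)
>>> remove_tensor_dropout(tensor)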
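Because the module is named tensor_hooks, one plausible reading is that tensor dropout is attached to the factorized tensor as a forward hook and that removal looks up and deletes that hook. The pure-Python sketch below illustrates that idea under this assumption only; it does not reproduce the tltorch source. `HookedTensor`, `ComponentDropout`, and `remove_dropout` are hypothetical names. The dropout here zeroes a fixed set of rank-1 component weights so the effect is deterministic.

```python
from collections import OrderedDict


class HookedTensor:
    """Hypothetical stand-in for a factorized tensor that stores forward hooks."""

    def __init__(self, weights):
        self.weights = list(weights)          # e.g. CP component weights
        self._forward_hooks = OrderedDict()   # key -> hook, applied in order
        self._next_key = 0

    def register_forward_hook(self, hook):
        key = self._next_key
        self._next_key += 1
        self._forward_hooks[key] = hook
        return key

    def __call__(self):
        # Run every registered hook over the weights on each "forward" call.
        output = list(self.weights)
        for hook in self._forward_hooks.values():
            output = hook(output)
        return output


class ComponentDropout:
    """Hypothetical dropout hook: zeroes the rank-1 components listed in `drop`."""

    def __init__(self, drop):
        self.drop = set(drop)

    def __call__(self, weights):
        return [0.0 if i in self.drop else w for i, w in enumerate(weights)]


def remove_dropout(tensor):
    """Delete the first ComponentDropout hook; raise if none is attached."""
    for key, hook in list(tensor._forward_hooks.items()):
        if isinstance(hook, ComponentDropout):
            del tensor._forward_hooks[key]
            return tensor
    raise ValueError("no dropout hook attached to this tensor")


t = HookedTensor([1.0, 2.0, 3.0])
t.register_forward_hook(ComponentDropout({1}))
print(t())           # component 1 is dropped
remove_dropout(t)
print(t())           # all components are back
```

Matching the hook by type, rather than by the key returned at registration, means the caller does not need to keep a handle around, which matches the signature above that takes only the factorized tensor.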