tltorch._tensor_dropout.remove_cp_dropout

tltorch._tensor_dropout.remove_cp_dropout(module)

Removes the CP tensor dropout, such as that applied with cp_dropout, from a TensorModule.

Parameters:
module : tltorch.TensorModule

the tensor module, parametrized by a tensor decomposition, from which to remove tensor dropout

Examples

>>> import tltorch
>>> from tltorch._tensor_dropout import cp_dropout, remove_cp_dropout
>>> trl = tltorch.CPTRL((10, 10), (10, ), rank='same')
>>> trl = cp_dropout(trl, p=0.5)
>>> remove_cp_dropout(trl)
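tltorch's own implementation is not shown on this page. To illustrate what removing CP tensor dropout means, here is a minimal, framework-free sketch under stated assumptions: a weight is reconstructed from a CP decomposition (one weight per rank-1 component plus factor matrices); CP dropout is modeled as a hook that randomly zeroes whole rank-1 components while in training mode; removal deletes that hook so the full decomposition is used again. ToyCPModule, CPDropoutHook, toy_cp_dropout and toy_remove_cp_dropout are hypothetical names, not part of the tltorch API, and inverted-dropout rescaling is omitted for brevity.

```python
import random


class ToyCPModule:
    """Hypothetical stand-in for a CP-factorized module (not tltorch API)."""

    def __init__(self, weights, factors):
        self.weights = list(weights)  # one weight per rank-1 component (length R)
        self.factors = factors        # two factor matrices, each n_k x R (list of rows)
        self.pre_hooks = {}           # hook id -> callable(module, weights) -> weights
        self.training = True
        self._next_id = 0

    def effective_weights(self):
        # Pass the CP weights through every registered hook before use.
        w = list(self.weights)
        for hook in self.pre_hooks.values():
            w = hook(self, w)
        return w

    def to_matrix(self):
        # Rebuild a 2-way CP tensor: M[i][j] = sum_r w_r * A[i][r] * B[j][r]
        A, B = self.factors
        w = self.effective_weights()
        rank = len(w)
        return [
            [sum(w[r] * A[i][r] * B[j][r] for r in range(rank)) for j in range(len(B))]
            for i in range(len(A))
        ]


class CPDropoutHook:
    """Hypothetical hook: drops each rank-1 component with probability p."""

    def __init__(self, p, rng):
        self.p = p
        self.rng = rng

    def __call__(self, module, weights):
        if not module.training:
            return weights  # no dropout outside training mode
        # Keep a component only if the uniform draw in [0, 1) is >= p.
        return [w if self.rng.random() >= self.p else 0.0 for w in weights]


def toy_cp_dropout(module, p, seed=None):
    # Register a CP-dropout hook on the module and return it.
    hook_id = module._next_id
    module._next_id += 1
    module.pre_hooks[hook_id] = CPDropoutHook(p, random.Random(seed))
    return module


def toy_remove_cp_dropout(module):
    # Delete every CP-dropout hook; any other hooks are left in place.
    to_delete = [k for k, h in module.pre_hooks.items() if isinstance(h, CPDropoutHook)]
    for hook_id in to_delete:
        del module.pre_hooks[hook_id]
    return module
```

For example, with p=1.0 every component is dropped during training, so the reconstructed matrix is all zeros; after toy_remove_cp_dropout the full CP reconstruction comes back. Keeping dropout as a detachable hook, rather than baking it into the weights, is what makes this kind of removal a cheap, lossless operation in the sketch.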