tltorch._tensor_dropout.remove_tt_dropout
tltorch._tensor_dropout.remove_tt_dropout(module)
Removes the tensor dropout from a TensorModule.
Parameters:
    module : tltorch.TensorModule
        the tensor module, parametrized by a tensor decomposition, from which to remove tensor dropout
Examples
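>>> import tltorch
>>> # tt_dropout is assumed to be importable from the same module as remove_tt_dropout
>>> from tltorch._tensor_dropout import tt_dropout, remove_tt_dropout
>>> trl = tltorch.TensorTrainTRL((10, 10), (10, ), rank='same')
>>> trl = tt_dropout(trl, p=0.5)
>>> remove_tt_dropout(trl)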
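
The attach-then-remove pattern in the example can be illustrated with a small pure-Python sketch. This is not tltorch's implementation: `ToyModule`, `ToyDropout`, `attach_dropout` and `remove_dropout` are hypothetical names, and the sketch assumes that dropout is attached as a hook that post-processes the module's output, which this page does not state.

```python
import random


class ToyModule:
    """Hypothetical stand-in for a tensor module; not a tltorch class."""

    def __init__(self, factors):
        self.factors = list(factors)
        self.forward_hooks = []  # callables applied to the output, in order

    def forward(self):
        out = list(self.factors)
        # Each hook may transform the output, as a dropout hook would.
        for hook in self.forward_hooks:
            out = hook(out)
        return out


class ToyDropout:
    """Zeroes each entry with probability p and rescales the survivors."""

    def __init__(self, p, seed=0):
        self.p = p
        self.rng = random.Random(seed)

    def __call__(self, values):
        keep = 1.0 - self.p
        return [v / keep if self.rng.random() < keep else 0.0 for v in values]


def attach_dropout(module, p):
    """Register a ToyDropout hook and return the module."""
    module.forward_hooks.append(ToyDropout(p))
    return module


def remove_dropout(module):
    """Delete the first ToyDropout hook found, leaving other hooks alone."""
    for hook in module.forward_hooks:
        if isinstance(hook, ToyDropout):
            module.forward_hooks.remove(hook)
            break  # stop right after mutating the list


toy = attach_dropout(ToyModule([1.0, 2.0, 3.0, 4.0]), p=0.5)
remove_dropout(toy)
restored = toy.forward()  # no hook left, so the factors pass through unchanged
```

In this sketch the removal function mutates the module in place and returns nothing, which matches how `remove_tt_dropout(trl)` is called without reassignment in the example above.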