tltorch._tensor_dropout.tt_dropout
tltorch._tensor_dropout.tt_dropout(module, p)
TT Dropout: applies tensor dropout to a module parametrized by a tensor-train (TT) decomposition.
Parameters:
- module : tltorch.TensorModule
  the tensor module, parametrized by the tensor decomposition to which tensor dropout is applied
- p : float
  dropout probability. If 0, no dropout is applied; if 1, all components but one are dropped in the latent space.
Returns:
- TensorModule
  the module to which tensor dropout has been attached
Examples
>>> import tltorch
>>> from tltorch._tensor_dropout import tt_dropout, remove_tt_dropout
>>> trl = tltorch.TensorTrainTRL((10, 10), (10, ), rank='same')
>>> trl = tt_dropout(trl, p=0.5)
>>> remove_tt_dropout(trl)
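To make the meaning of `p` concrete, here is a minimal NumPy sketch of what dropping TT components "in the latent space" can look like. It assumes the TT cores are stored as a list of 3-D arrays of shape `(r_{k-1}, n_k, r_k)` and that each internal rank component is kept independently with probability `1 - p`, with one component always kept so the chain never breaks. The helpers `tt_dropout_cores` and `tt_to_full` are hypothetical illustrations, not part of tltorch, and this sketch does not claim to match tltorch's internal implementation (for example, it applies no rescaling).

```python
import numpy as np

def tt_dropout_cores(cores, p, rng=None):
    """Hypothetical helper: randomly drop TT-rank components.

    cores[k] has shape (r_{k-1}, n_k, r_k). Each component of every internal
    rank is kept with probability 1 - p; if none survive, one is kept at
    random, so p = 1 leaves exactly one component per internal rank.
    """
    rng = np.random.default_rng() if rng is None else rng
    cores = [c.copy() for c in cores]
    if p <= 0:
        return cores  # p = 0: no dropout
    for k in range(len(cores) - 1):
        rank = cores[k].shape[2]
        keep = np.flatnonzero(rng.random(rank) >= p)  # keep w.p. 1 - p
        if keep.size == 0:
            keep = np.array([rng.integers(rank)])
        # The rank index is shared by two neighbouring cores: slice both.
        cores[k] = cores[k][:, :, keep]
        cores[k + 1] = cores[k + 1][keep, :, :]
    return cores

def tt_to_full(cores):
    """Contract a list of TT cores back into the full tensor."""
    full = cores[0]
    for core in cores[1:]:
        full = np.tensordot(full, core, axes=([-1], [0]))
    return full.squeeze(axis=(0, -1))  # drop the boundary ranks r_0 = r_d = 1

rng = np.random.default_rng(0)
shape, ranks = (4, 5, 6), (1, 3, 3, 1)
cores = [rng.standard_normal((ranks[k], shape[k], ranks[k + 1])) for k in range(3)]
dropped = tt_dropout_cores(cores, p=1.0, rng=rng)
print([c.shape for c in dropped])  # [(1, 4, 1), (1, 5, 1), (1, 6, 1)]
print(tt_to_full(dropped).shape)   # (4, 5, 6)
```

Whatever `p` is, only the internal ranks shrink; the reconstructed tensor keeps its original shape, which is why the module can still be used unchanged during training.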