tltorch._tensor_dropout.tt_dropout

tltorch._tensor_dropout.tt_dropout(module, p)

Tensor dropout for modules parametrized by a tensor-train (TT) decomposition.

Parameters:
module : tltorch.TensorModule

the tensor module, parametrized by a tensor decomposition, to which tensor dropout is applied

p : float

dropout probability. If 0, no dropout is applied; if 1, all components but one are dropped in the latent space.

Returns:
TensorModule

the module to which tensor dropout has been attached

Examples

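>>> import tltorch
>>> from tltorch._tensor_dropout import tt_dropout, remove_tt_dropout
>>> trl = tltorch.TensorTrainTRL((10, 10), (10,), rank='same')
>>> trl = tt_dropout(trl, p=0.5)  # attach tensor dropout with p = 0.5
>>> remove_tt_dropout(trl)  # detach it again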
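To make "dropping components in the latent space" concrete, here is a minimal NumPy sketch, assuming a tensor train stored as a list of cores where core k has shape (r_k, n_k, r_{k+1}) and r_0 = r_N = 1. The names tt_rank_dropout and tt_to_full are hypothetical and are not part of tltorch; the sketch is not the library's implementation, ignores train/eval mode, and applies no rescaling (the library's behaviour there is not stated in this page). Each internal rank index is kept independently with probability 1 - p, and at least one index per rank is always kept, which reproduces the p = 0 and p = 1 cases described in the parameter list.

```python
import numpy as np


def tt_rank_dropout(cores, p, rng):
    """Hypothetical sketch: dropout on the internal TT ranks.

    cores[k] has shape (r_k, n_k, r_{k+1}). For every internal rank
    (shared by cores[k - 1] and cores[k]) each index survives with
    probability 1 - p; at least one index is always kept.
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be in [0, 1]")
    cores = [c.copy() for c in cores]
    for k in range(1, len(cores)):
        rank = cores[k].shape[0]
        # rng.random draws from [0, 1), so p = 0 keeps every index
        keep = np.flatnonzero(rng.random(rank) >= p)
        if keep.size == 0:
            # never drop the whole rank: keep one random index
            keep = rng.integers(rank, size=1)
        # slice the two cores that share this rank consistently
        cores[k - 1] = cores[k - 1][:, :, keep]
        cores[k] = cores[k][keep]
    return cores


def tt_to_full(cores):
    """Contract the TT cores back into a dense tensor."""
    full = cores[0]
    for core in cores[1:]:
        full = np.tensordot(full, core, axes=([-1], [0]))
    return full.reshape(full.shape[1:-1])  # drop the boundary ranks r_0 = r_N = 1


rng = np.random.default_rng(0)
shape, rank = (4, 5, 6), (1, 3, 3, 1)
cores = [rng.standard_normal((rank[k], shape[k], rank[k + 1])) for k in range(len(shape))]
dropped = tt_rank_dropout(cores, p=0.5, rng=rng)
print([c.shape for c in dropped])  # internal ranks shrink at random
print(tt_to_full(dropped).shape)   # (4, 5, 6)
```

Dropping rank indices only shrinks the cores; the reconstructed tensor keeps its full shape, so the input and output dimensions of the parametrized layer are unchanged.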