tltorch.factorized_layers.FactorizedEmbedding

class tltorch.factorized_layers.FactorizedEmbedding(num_embeddings, embedding_dim, auto_reshape=True, d=3, tensorized_num_embeddings=None, tensorized_embedding_dim=None, factorization='blocktt', rank=8, device=None, dtype=None)

Tensorized Embedding Layers For Efficient Model Compression

Tensorized drop-in replacement for torch.nn.Embedding
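To make the compression claim concrete, the sketch below counts the parameters of a dense table and of a tensor-train matrix (TT-matrix) with the same overall shape. It assumes that 'blocktt' stores d cores of shape (r_{k-1}, m_k, n_k, r_k) with boundary ranks r_0 = r_d = 1, and that a single integer rank is used for every internal rank. The helper names are hypothetical, and tltorch's own rank handling may differ.

```python
from math import prod

def dense_params(num_embeddings: int, embedding_dim: int) -> int:
    """Parameters in a regular torch.nn.Embedding table."""
    return num_embeddings * embedding_dim

def tt_matrix_params(row_shape, col_shape, rank: int) -> int:
    """Parameters in a TT-matrix whose k-th core has shape (r_{k-1}, m_k, n_k, r_k).

    Assumption: every internal rank equals `rank`; the boundary ranks are 1.
    """
    if len(row_shape) != len(col_shape):
        raise ValueError("row_shape and col_shape need the same number of factors")
    d = len(row_shape)
    ranks = [1] + [rank] * (d - 1) + [1]
    return sum(ranks[k] * m * n * ranks[k + 1]
               for k, (m, n) in enumerate(zip(row_shape, col_shape)))

# A 1000 x 512 table, tensorized as (10, 10, 10) x (8, 8, 8) with rank 8
row_shape, col_shape = (10, 10, 10), (8, 8, 8)
dense = dense_params(prod(row_shape), prod(col_shape))
compressed = tt_matrix_params(row_shape, col_shape, rank=8)
print(dense, compressed, dense / compressed)  # 512000 6400 80.0
```

The middle core dominates the count because it carries a rank factor on both sides, so the cost grows roughly quadratically with the rank.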

Parameters
num_embeddings : int, number of entries in the lookup table
embedding_dim : int, number of dimensions per entry
auto_reshape : bool, whether to choose the tensorized shapes of the embedding table dimensions automatically
d : int or tuple of int, number of reshape dimensions for each of the two embedding table dimensions
tensorized_num_embeddings : tuple of int, tensorized shape of the first embedding table dimension (num_embeddings)
tensorized_embedding_dim : tuple of int, tensorized shape of the second embedding table dimension (embedding_dim)
factorization : str, tensor factorization type (default 'blocktt')
rank : int, tuple of int, or str, tensor rank
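With auto_reshape=True, each table dimension has to be split into d integer factors whose product equals the original size. The helper below is a hypothetical illustration of one way to do this, using a greedy prime-factor heuristic. It is not tltorch's own reshaping routine, whose algorithm may differ.

```python
def prime_factors(n: int) -> list[int]:
    """Prime factors of n in ascending order, with multiplicity."""
    factors, p = [], 2
    while p * p <= n:
        while n % p == 0:
            factors.append(p)
            n //= p
        p += 1
    if n > 1:
        factors.append(n)
    return factors

def balanced_factors(n: int, d: int) -> list[int]:
    """Split n into d integer factors whose product is n, keeping them close in size.

    Greedy heuristic: hand out prime factors, largest first, to the currently smallest bucket.
    """
    if n < 1 or d < 1:
        raise ValueError("n and d must be positive")
    buckets = [1] * d
    for p in sorted(prime_factors(n), reverse=True):
        smallest = buckets.index(min(buckets))
        buckets[smallest] *= p
    return sorted(buckets)

print(balanced_factors(1000, 3))  # [10, 10, 10]
print(balanced_factors(512, 3))   # [8, 8, 8]
print(balanced_factors(1009, 3))  # [1, 1, 1009]
```

A prime size such as 1009 cannot be split and leaves factors of 1, which gives no compression. This is one reason a real implementation might pad the table or search nearby sizes.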

Methods

forward(input)

Defines the computation performed at every call.

from_embedding(embedding_layer[, rank, ...])

Create a tensorized embedding layer from a regular embedding layer

reset_parameters

forward(input)

Defines the computation performed at every call.

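Looks up the rows of the tensorized embedding table for the indices in input and returns a tensor of shape (*input.shape, embedding_dim), as torch.nn.Embedding does.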

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
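Conceptually, a lookup never needs the full dense table. With a TT-matrix, row i is obtained by splitting i into one digit per core and contracting only the matching core slices. The pure-Python sketch below shows this for a single index. It assumes row-major index ordering and the core layout (r_{k-1}, m_k, n_k, r_k). tltorch's actual forward is vectorized in PyTorch, and the function names here are hypothetical.

```python
def unravel(index: int, shape) -> list[int]:
    """Split a flat index into mixed-radix digits, most significant first (row-major)."""
    digits = []
    for size in reversed(shape):
        index, digit = divmod(index, size)
        digits.append(digit)
    if index != 0:
        raise IndexError("index out of range for the tensorized shape")
    return digits[::-1]

def tt_matrix_row(cores, row_shape, index: int) -> list[float]:
    """Materialize one row of a TT-matrix without building the full table.

    cores[k] is a nested list of shape (r_{k-1}, m_k, n_k, r_k) with r_0 = r_d = 1.
    """
    digits = unravel(index, row_shape)
    # One rank vector per column-digit prefix, kept in row-major column order.
    partial = [[1.0]]  # boundary rank r_0 = 1
    for core, i_k in zip(cores, digits):
        r_in, n_k, r_out = len(core), len(core[0][i_k]), len(core[0][i_k][0])
        next_partial = []
        for vec in partial:
            for j in range(n_k):
                next_partial.append([
                    sum(vec[a] * core[a][i_k][j][b] for a in range(r_in))
                    for b in range(r_out)
                ])
        partial = next_partial
    return [vec[0] for vec in partial]  # boundary rank r_d = 1

# Rank-1 example: the TT-matrix reduces to kron(a, b), a 4 x 6 table.
a = [[1, 2], [3, 4]]
b = [[5, 6, 7], [8, 9, 10]]
core1 = [[[[a[i][j]] for j in range(2)] for i in range(2)]]  # shape (1, 2, 2, 1)
core2 = [[[[b[i][j]] for j in range(3)] for i in range(2)]]  # shape (1, 2, 3, 1)
print(tt_matrix_row([core1, core2], (2, 2), 3))  # [24.0, 27.0, 30.0, 32.0, 36.0, 40.0]
```

Each row costs one small contraction per core instead of storing num_embeddings * embedding_dim values.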

classmethod from_embedding(embedding_layer, rank=8, factorization='blocktt', decompose_weights=True, auto_reshape=True, decomposition_kwargs={}, **kwargs)

Create a tensorized embedding layer from a regular embedding layer

Parameters
embedding_layer : torch.nn.Embedding, the embedding layer to tensorize
rank : int, tuple of int, or str, tensor rank
factorization : str, tensor factorization type
decompose_weights : bool, whether to decompose the weights of embedding_layer and use them to initialize the factors
auto_reshape : bool, whether to reshape the dimensions for the TensorizedTensor automatically
decomposition_kwargs : dict, keyword arguments passed to the decomposition
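When decompose_weights=True, the dense weight of embedding_layer has to be brought into a form that a TT-matrix decomposition can consume. A common approach, assumed here rather than taken from tltorch's internals, is to split the row and column indices into digits and interleave them as (i_1, j_1, i_2, j_2, ...) before running TT-SVD. The hypothetical helper below performs only that index permutation in pure Python. It does not perform the decomposition itself.

```python
from itertools import product

def ravel(digits, shape) -> int:
    """Combine mixed-radix digits (most significant first) into a flat index."""
    index = 0
    for digit, size in zip(digits, shape):
        index = index * size + digit
    return index

def interleave_for_tt_matrix(weight, row_shape, col_shape):
    """Reorder a dense (rows x cols) table into the interleaved shape (m_1, n_1, ..., m_d, n_d).

    Returns the interleaved shape and the entries flattened in row-major order.
    """
    interleaved_shape = [s for pair in zip(row_shape, col_shape) for s in pair]
    flat = []
    for multi in product(*(range(s) for s in interleaved_shape)):
        i = ravel(multi[0::2], row_shape)  # row digits sit at even positions
        j = ravel(multi[1::2], col_shape)  # column digits sit at odd positions
        flat.append(weight[i][j])
    return interleaved_shape, flat

# A 4 x 6 table viewed as (2, 2) x (2, 3); entry (i, j) stores 10 * i + j
weight = [[10 * i + j for j in range(6)] for i in range(4)]
shape, flat = interleave_for_tt_matrix(weight, (2, 2), (2, 3))
print(shape)      # [2, 2, 2, 3]
print(flat[:6])   # [0, 1, 2, 10, 11, 12]
```

Merging each (m_k, n_k) pair into a single mode of size m_k * n_k turns the result into an ordinary d-way tensor, to which TT-SVD applies directly.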