tltorch.factorized_layers.FactorizedEmbedding

class tltorch.factorized_layers.FactorizedEmbedding(num_embeddings, embedding_dim, auto_tensorize=True, n_tensorized_modes=3, tensorized_num_embeddings=None, tensorized_embedding_dim=None, factorization='blocktt', rank=8, n_layers=1, device=None, dtype=None)

Tensorized drop-in replacement for torch.nn.Embedding, after the paper "Tensorized Embedding Layers for Efficient Model Compression".
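To make the compression concrete, the sketch below counts parameters for a dense table and for a tensor-train-matrix factorization of the same table. It assumes that 'blocktt' stores cores of shape (r_{k-1}, i_k, j_k, r_k) with boundary ranks of 1 and every internal rank equal to rank; dense_param_count and tt_matrix_param_count are hypothetical helpers, not part of tltorch.

```python
from math import prod


def dense_param_count(num_embeddings, embedding_dim):
    # A regular embedding table stores every entry of its weight matrix.
    return num_embeddings * embedding_dim


def tt_matrix_param_count(row_shape, col_shape, rank):
    """Parameters in TT-matrix cores of shape (r_{k-1}, i_k, j_k, r_k).

    Boundary ranks are 1 and every internal rank equals `rank` (a
    simplification: ranks may differ per core in general).
    """
    if len(row_shape) != len(col_shape):
        raise ValueError("row_shape and col_shape need the same number of modes")
    ranks = [1] + [rank] * (len(row_shape) - 1) + [1]
    return sum(
        ranks[k] * row_shape[k] * col_shape[k] * ranks[k + 1]
        for k in range(len(row_shape))
    )


# A 1000 x 64 table, each dimension split into 3 modes, rank 8.
rows, cols = (10, 10, 10), (4, 4, 4)
assert prod(rows) == 1000 and prod(cols) == 64
dense = dense_param_count(prod(rows), prod(cols))
factorized = tt_matrix_param_count(rows, cols, rank=8)
print(dense, factorized, dense / factorized)  # 64000 3200 20.0
```

In this example the factorized form needs 20 times fewer parameters. The ratio depends on the chosen tensorized shapes and shrinks as the rank grows, since the internal cores scale with the square of the rank.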

Parameters:
num_embeddings: int

number of entries in the lookup table

embedding_dim: int

number of dimensions per entry

auto_tensorize: bool

whether to automatically choose the tensorized shapes of both embedding table dimensions

n_tensorized_modes: int or int tuple

number of modes into which each of the two embedding table dimensions is reshaped

tensorized_num_embeddings: int tuple

tensorized shape of the first embedding table dimension (num_embeddings)

tensorized_embedding_dim: int tuple

tensorized shape of the second embedding table dimension (embedding_dim)

factorization: str

type of tensor factorization used for the weight (default 'blocktt')

rank: int tuple or str

rank of the tensor factorization
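When auto_tensorize is True the layer picks the tensorized shapes itself; when they are passed explicitly, they presumably have to multiply back to num_embeddings and embedding_dim and have the same number of modes. The sketch below illustrates both ideas with a hypothetical greedy splitter. split_dim and check_shapes are illustrative names, not tltorch functions, and tltorch's own shape-selection rule may differ.

```python
from math import prod


def prime_factors(n):
    """Prime factors of n in non-decreasing order, by trial division."""
    factors, p = [], 2
    while p * p <= n:
        while n % p == 0:
            factors.append(p)
            n //= p
        p += 1
    if n > 1:
        factors.append(n)
    return factors


def split_dim(n, n_modes):
    """Split n into n_modes factors whose product is n, keeping them balanced.

    Greedy heuristic: give each prime factor, largest first, to the currently
    smallest mode. A prime n cannot be split and comes back as (n, 1, ..., 1).
    """
    modes = [1] * n_modes
    for p in sorted(prime_factors(n), reverse=True):
        smallest = modes.index(min(modes))
        modes[smallest] *= p
    return tuple(sorted(modes, reverse=True))


def check_shapes(num_embeddings, embedding_dim, rows, cols):
    """Consistency checks implied by the parameter descriptions (assumed)."""
    if prod(rows) != num_embeddings or prod(cols) != embedding_dim:
        raise ValueError("tensorized shapes must multiply back to the table shape")
    if len(rows) != len(cols):
        raise ValueError("both tensorized shapes need the same number of modes")


rows = split_dim(1000, 3)  # (10, 10, 10)
cols = split_dim(64, 3)    # (4, 4, 4)
check_shapes(1000, 64, rows, cols)
```

Balanced modes keep every core small; a table size with a large prime factor leaves one mode large, which limits how much the factorization can save.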

Methods

forward(input[, indices])

Look up the embeddings for the indices in input.

from_embedding(embedding_layer[, rank, ...])

Create a tensorized embedding layer from a regular embedding layer

from_embedding_list(embedding_layer_list[, ...])

Create a single tensorized embedding layer from a list of regular embedding layers

get_embedding

reset_parameters

forward(input, indices=0)

Look up the embeddings for the indices in input, returning a tensor of shape (*input.shape, embedding_dim) as torch.nn.Embedding does.

Note

Call the module instance itself rather than its forward method: calling the instance runs any registered hooks, while calling forward directly silently skips them.
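One way a factorized lookup can be computed is sketched below in plain Python: each requested row is assembled by contracting the cores, so the full table is never built. It assumes a tensor-train-matrix layout (core k has shape (r_{k-1}, i_k, j_k, r_k)) and row-major mapping from flat indices to tensorized indices; both are assumptions, not a description of tltorch's implementation, and all function names are hypothetical.

```python
import random


def unravel(index, shape):
    """Row-major (C-order) mixed-radix digits of `index` for the given shape."""
    digits = []
    for size in reversed(shape):
        index, digit = divmod(index, size)
        digits.append(digit)
    return digits[::-1]


def random_cores(rows, cols, rank, seed=0):
    """Random TT-matrix cores; core k is nested as core[a][i_k][j_k][b]."""
    rng = random.Random(seed)
    ranks = [1] + [rank] * (len(rows) - 1) + [1]
    return [
        [[[[rng.gauss(0.0, 1.0) for _ in range(ranks[k + 1])]
           for _ in range(cols[k])]
          for _ in range(rows[k])]
         for _ in range(ranks[k])]
        for k in range(len(rows))
    ]


def contract(vec, core, i_k, j_k):
    # Multiply the running rank vector by the (r_{k-1} x r_k) slice core[:, i_k, j_k, :].
    r_out = len(core[0][0][0])
    return [sum(vec[a] * core[a][i_k][j_k][b] for a in range(len(vec)))
            for b in range(r_out)]


def dense_entry(cores, rows, cols, r, c):
    """Reference value W[r, c]: one chain of slice products per entry."""
    vec = [1.0]
    for core, i_k, j_k in zip(cores, unravel(r, rows), unravel(c, cols)):
        vec = contract(vec, core, i_k, j_k)
    return vec[0]


def lookup_row(cores, rows, cols, r):
    """Embedding row r, built from the cores alone."""
    # partial holds one running rank vector per column prefix (j_1, ..., j_k),
    # enumerated in C order, so the final list is the row in column order.
    partial = [[1.0]]
    for core, i_k, n_cols in zip(cores, unravel(r, rows), cols):
        partial = [contract(vec, core, i_k, j) for vec in partial for j in range(n_cols)]
    return [vec[0] for vec in partial]  # the boundary rank is 1


def embed(cores, rows, cols, input_ids):
    # Embedding-style lookup on a 1-D list of ids: one row per id.
    return [lookup_row(cores, rows, cols, r) for r in input_ids]


rows, cols = (4, 5), (2, 3)  # a 20 x 6 table, each dimension split into 2 modes
cores = random_cores(rows, cols, rank=3)
out = embed(cores, rows, cols, [0, 7, 19])
print(len(out), len(out[0]))  # 3 6
```

The cost of a lookup grows with the number of requested rows rather than with num_embeddings, which is what makes a factorized table usable in place of a dense one.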

classmethod from_embedding(embedding_layer, rank=8, factorization='blocktt', n_tensorized_modes=2, decompose_weights=True, auto_tensorize=True, decomposition_kwargs={}, **kwargs)

Create a tensorized embedding layer from a regular embedding layer

Parameters:
embedding_layer: torch.nn.Embedding

embedding layer to tensorize

rank: int tuple or str

rank of the tensor decomposition

factorization: str

type of tensor factorization (default 'blocktt')

decompose_weights: bool

whether to decompose the weights of embedding_layer and use the result to initialize the new layer

auto_tensorize: bool

if True, automatically reshape dimensions for TensorizedTensor

decomposition_kwargs: dict

additional keyword arguments passed to the decomposition

classmethod from_embedding_list(embedding_layer_list, rank=8, factorization='blocktt', n_tensorized_modes=2, decompose_weights=True, auto_tensorize=True, decomposition_kwargs={}, **kwargs)

Create a single tensorized embedding layer from a list of regular embedding layers

Parameters:
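embedding_layer_list: list of torch.nn.Embedding

embedding layers to combine into a single tensorized layer

rank: int tuple or str

rank of the tensor decomposition

factorization: str

type of tensor factorization (default 'blocktt')

decompose_weights: bool

whether to decompose the weights of the embedding layers and use the result to initialize the new layer

auto_tensorize: bool

if True, automatically reshape dimensions for TensorizedTensor

decomposition_kwargs: dict

additional keyword arguments passed to the decomposition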
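A plausible reading of from_embedding_list, assumed here because the page does not spell it out: all layers share one (num_embeddings, embedding_dim) shape, their weights are stacked along a new leading axis of size n_layers (the n_layers argument in the class signature), and the indices argument of forward selects which original table to read from. The sketch shows only that stacking and lookup on plain dense tables; in the real layer the stacked tensor would then be factorized. stack_tables and lookup are hypothetical names.

```python
def stack_tables(tables):
    """Stack equally shaped 2-D tables (lists of rows) along a new leading axis.

    Returns a nested list of shape (n_layers, num_embeddings, embedding_dim).
    """
    if not tables:
        raise ValueError("need at least one embedding table")
    num_embeddings, embedding_dim = len(tables[0]), len(tables[0][0])
    for table in tables:
        if len(table) != num_embeddings or any(len(row) != embedding_dim for row in table):
            raise ValueError("all embedding tables must have the same shape")
    return [[list(row) for row in table] for table in tables]


def lookup(stacked, input_ids, layer=0):
    """Hypothetical analogue of forward(input, indices=layer) on a 1-D input."""
    return [list(stacked[layer][i]) for i in input_ids]


# Two 3 x 2 tables stacked into one (2, 3, 2) weight.
table_a = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]
table_b = [[5.0, 5.1], [6.0, 6.1], [7.0, 7.1]]
stacked = stack_tables([table_a, table_b])
print(lookup(stacked, [2, 0], layer=1))  # [[7.0, 7.1], [5.0, 5.1]]
```

Requiring a common shape is what lets the tables share one tensorized weight instead of one factorization per table.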