import pickle
from pathlib import Path
import tensorly as tl
from .second_order_cumulant import SecondOrderCumulant
from .third_order_cumulant import ThirdOrderCumulant
class TLDA:
    """
    Class to learn the topic-word distribution from a corpus of documents
    """
    def __init__(self, n_topic, alpha_0, n_iter_train, n_iter_test, learning_rate,
                 pca_batch_size=10000, third_order_cumulant_batch=1000, gamma_shape=1.0, smoothing=1e-6,
                 theta=1, ortho_loss_criterion=1000, n_eigenvec=None, random_seed=None):
        """
        Parameters
        ----------
        n_topic : int
            number of topics to learn
        alpha_0 : float
            sum of the Dirichlet prior over the topic proportions
        n_iter_train : int
        n_iter_test : int
        learning_rate : float
        pca_batch_size : int, default is 10000
            batch size used to fit the second order cumulant
        third_order_cumulant_batch : int, default is 1000
            batch size used to fit the third order cumulant
        gamma_shape : float, default is 1.0
        smoothing : float, default is 1e-6
        theta : int, default is 1
        ortho_loss_criterion : int, default is 1000
        n_eigenvec : int, optional, default is None
            number of eigenvectors kept for whitening; if None, defaults to n_topic
        random_seed : int, optional, default is None
        """
self.n_topic = n_topic
self.alpha_0 = alpha_0
self.smoothing = smoothing
self.third_order_cumulant_batch = third_order_cumulant_batch
if n_eigenvec is None:
n_eigenvec = n_topic
self.n_eigenvec = n_eigenvec
self.weights_ = tl.ones(self.n_topic)
self.vocab = 0
self.n_documents = 0
self.mean = None
self.unwhitened_factors_ = None
self.second_order = SecondOrderCumulant(n_eigenvec, alpha_0, pca_batch_size)
        self.third_order = ThirdOrderCumulant(n_topic, alpha_0, n_iter_train, n_iter_test,
                                              third_order_cumulant_batch, learning_rate, gamma_shape,
                                              theta, ortho_loss_criterion, random_seed,
                                              n_eigenvec=n_eigenvec)
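
    # A minimal construction sketch; the hyperparameter values below are
    # purely illustrative, not recommendations:
    #
    #     tlda = TLDA(n_topic=20, alpha_0=0.01, n_iter_train=1000,
    #                 n_iter_test=10, learning_rate=1e-4)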
    def fit(self, X, order=None):
        """
        Compute the word-topic distribution for the entire dataset at once.
        Assumes that the whole dataset, and the tensors required to compute
        its word-topic distribution, fit in memory.

        Parameters
        ----------
        X : tensor of shape (n_documents, vocab)
            all documents used to fit the word-topic distribution
        order : int, optional, default is None
            which cumulant to fit (1, 2 or 3); if None, all three are fit in turn
        """
if order is None or order == 1:
self.n_documents = X.shape[0]
self.vocab = X.shape[1]
self.mean = tl.mean(X, axis=0)
if order is None or order == 2:
self.second_order.fit(X - self.mean)
if order is None or order == 3:
X_whit = self.second_order.transform(X - self.mean)
            self.third_order.fit(X_whit, verbose=False)
del X_whit
del X
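
    # Hedged usage sketch for `fit`, assuming `X` is a (n_documents, vocab)
    # document-term count tensor that fits in memory and `tlda` is a TLDA
    # instance as sketched above:
    #
    #     tlda.fit(X)                       # fits orders 1, 2 and 3 in turn
    #     topics = tlda.unwhitened_factors  # recovered word-topic factors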
def _partial_fit_first_order(self, X_batch):
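        """Update the vocabulary size, running mean, and document count from one batch."""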
if self.mean is None:
self.vocab = X_batch.shape[1]
self.mean = tl.mean(X_batch, axis=0)
else:
self.mean = ((self.mean * self.n_documents) + tl.sum(X_batch, axis=0)) / (self.n_documents + X_batch.shape[0])
self.n_documents += X_batch.shape[0]
del X_batch
def _partial_fit_second_order(self, X_batch):
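        """Feed one centered batch, in sub-batches of `self.second_order.batch_size`, to the second order cumulant."""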
for j in range(0, len(X_batch), self.second_order.batch_size):
y = X_batch[j:j+self.second_order.batch_size]
self.second_order.partial_fit(y - self.mean)
del y
del X_batch
def _partial_fit_third_order(self, X_batch):
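        """Run third order gradient updates on one batch (expected to be already whitened)."""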
for j in range(0, len(X_batch), self.third_order_cumulant_batch):
y = X_batch[j:j+self.third_order_cumulant_batch]
self.third_order.partial_fit(y)
del y
del X_batch
def partial_fit(self, X_batch, batch_index, save_folder=None):
"""
Update the word-topic distribution using a batch of documents. For a given batch, the
first and second order cumulants need to be fit once, but the third order cumulant should
be fit many times.
Parameters
----------
X_batch : tensor of shape (batch_size, self.vocab)
batch_index : int
index of the current batch.
This is used to know whether to update the first and second moment or just whiten
save_folder : str, default is None
Folder in which to store the whitened batches.
If None, the whitened batches will be recomputed at each iteration
instead of being catched.
"""
if not hasattr(self, "seen_batches"):
self.seen_batches = dict()
if batch_index in self.seen_batches:
# We've seen the batch at least once
            if self.seen_batches[batch_index] != 0:
                # We already whitened this batch: load the cached version if
                # one was saved, otherwise whiten it again
                if save_folder:
                    save_file = self.seen_batches[batch_index]
                    with open(Path(save_folder).joinpath(save_file), 'rb') as f:
                        X_batch = pickle.load(f)
else:
X_batch = self.second_order.transform(X_batch - self.mean)
            else:
                # Second time we see this batch: the moments are now fit, so
                # compute the whitened version (and cache it if requested)
                X_batch = self.second_order.transform(X_batch - self.mean)
                if save_folder is not None:
                    save_file = f'_whitened_batch_{batch_index}'
                    self.seen_batches[batch_index] = save_file
                    with open(Path(save_folder).joinpath(save_file), 'wb') as f:
                        pickle.dump(X_batch, f)
else:
self.seen_batches[batch_index] = 1
self._partial_fit_third_order(X_batch)
else:
            # First time we see this batch: fit the moments; whitening happens on later passes
self._partial_fit_first_order(X_batch)
self._partial_fit_second_order(X_batch)
self.seen_batches[batch_index] = 0
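
    # Hedged training-loop sketch for `partial_fit`: every batch must be seen
    # at least twice, since the first pass over a batch only fits the mean and
    # second order cumulant, and later passes run the third order updates.
    # `batches` and `n_epoch` are illustrative names:
    #
    #     for epoch in range(n_epoch):
    #         for i, X_batch in enumerate(batches):
    #             tlda.partial_fit(X_batch, batch_index=i, save_folder='whitened_cache')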
    def partial_fit_online(self, X_batch):
        """
        Update the word-topic distribution using a batch of documents, in a fully online fashion.
        Meant for very large datasets: only one gradient update is done per batch in the third
        order cumulant calculation.

        Parameters
        ----------
        X_batch : tensor of shape (batch_size, vocab)
        """
self._partial_fit_first_order(X_batch)
self._partial_fit_second_order(X_batch)
X_whit = self.second_order.transform(X_batch - self.mean)
del X_batch
self._partial_fit_third_order(X_whit)
del X_whit
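
    # Hedged streaming sketch for `partial_fit_online`: a single pass, with one
    # third order gradient update per batch (`document_stream` is an
    # illustrative iterable of (batch_size, vocab) count tensors):
    #
    #     for X_batch in document_stream:
    #         tlda.partial_fit_online(X_batch)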
def _unwhiten_factors(self):
"""Unwhitens self.third_order.factors_, then uncenters and unnormalizes"""
factors_unwhitened = self.second_order.reverse_transform(self.third_order.factors_.T).T
# Un-centers the data
factors_unwhitened += tl.reshape(self.mean,(self.vocab,1))
factors_unwhitened [factors_unwhitened < 0.] = 0. # remove non-negative probabilities
# Save unwhitened factors before postprocessing
self.unwhitened_factors_raw_ = tl.copy(factors_unwhitened)
# Smoothing
factors_unwhitened *= (1. - self.smoothing)
factors_unwhitened += (self.smoothing / factors_unwhitened.shape[1])
# Calculate the eigenvalues from the whitened factors
eig_vals = tl.tensor([tl.norm(k)**3 for k in self.third_order.factors_ ])
alpha = eig_vals**(-2)
# Recover the topic weights
alpha_norm = (alpha / alpha.sum()) * self.alpha_0
self.weights_ = tl.tensor(alpha_norm)
# Normalize the factors
factors_unwhitened /= factors_unwhitened.sum(axis=0)
return factors_unwhitened
@property
    def unwhitened_factors(self):
        """Unwhitened learned factors, of shape (vocabulary_size, n_topic).
On the first call, this will compute and store the unwhitened factors.
Subsequent calls will simply return the stored value.
"""
if self.unwhitened_factors_ is None:
self.unwhitened_factors_ = self._unwhiten_factors()
return self.unwhitened_factors_
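

if __name__ == "__main__":
    # Hedged smoke-test sketch on synthetic counts. The Poisson data and every
    # hyperparameter below are illustrative only, and convergence with so few
    # iterations is not guaranteed; assumes a NumPy-compatible tensorly backend.
    import numpy as np

    rng = np.random.default_rng(0)
    X = tl.tensor(rng.poisson(1.0, size=(500, 100)).astype(float))

    tlda = TLDA(n_topic=5, alpha_0=0.01, n_iter_train=100, n_iter_test=10,
                learning_rate=1e-4)
    tlda.fit(X)
    # Each column of the recovered factor matrix is a distribution over the vocabulary
    print(tlda.unwhitened_factors.shape)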