# Source code for tlda.second_order_cumulant
import tensorly as tl
from sklearn.decomposition import IncrementalPCA
try:
import cuml
except ImportError:
pass
[docs]
class SecondOrderCumulant():
    """
    Class to compute the second order cumulant.

    Wraps an incremental PCA (scikit-learn on the numpy backend, cuML on the
    cupy backend). After every fit the transposed PCA components are stored
    as ``projection_weights_`` and the rescaled explained variance is stored
    as ``whitening_weights_``.

    Attributes
    ----------
    n_eigenvec : int
        Number of PCA components kept.
    alpha_0 : int
        Mixing parameter for the topic weights.
    batch_size : int
        Batch size handed to the incremental PCA.
    n_docs : int
        Number of documents the current PCA state has been fitted on.
        Used for normalization of the explained variance.
    pca : IncrementalPCA
        The backend-specific incremental PCA estimator.
    projection_weights_ : tensor
        Transpose of ``pca.components_``. Only set after ``fit`` or
        ``partial_fit`` has been called.
    whitening_weights_ : tensor
        ``pca.explained_variance_ * (n_docs - 1) / n_docs``. Only set after
        ``fit`` or ``partial_fit`` has been called.
    """

    def __init__(self, n_eigenvec, alpha_0, batch_size):  # n_eigenvec here corresponds to n_topic in the LDA
        """
        Set up the incremental PCA used to process centered batches of data.

        Parameters
        ----------
        n_eigenvec : int
            Corresponds to the number of topics in the Tensor LDA.
        alpha_0 : int
            Mixing parameter for the topic weights.
        batch_size : int
            Size of the batch to use for online learning.

        Raises
        ------
        ValueError
            If the active TensorLy backend is neither "numpy" nor "cupy".
        """
        self.n_eigenvec = n_eigenvec
        self.alpha_0 = alpha_0
        self.batch_size = batch_size
        self.n_docs = 0

        backend = tl.get_backend()
        if backend == "numpy":
            pca_cls = IncrementalPCA
        elif backend == "cupy":
            # `cuml` is imported optionally at module level; if that import
            # failed, selecting the cupy backend raises NameError here.
            pca_cls = cuml.IncrementalPCA
        else:
            # Fail at construction time instead of leaving `self.pca` unset,
            # which would only surface later as an AttributeError in fit().
            raise ValueError(
                f"Unsupported TensorLy backend {backend!r}: "
                "SecondOrderCumulant supports only 'numpy' and 'cupy'."
            )
        self.pca = pca_cls(n_components=self.n_eigenvec,
                           batch_size=self.batch_size)

    def _update_weights(self):
        """Refresh projection and whitening weights from the PCA state."""
        self.projection_weights_ = tl.transpose(self.pca.components_)
        # Rescale the PCA's explained variance by (n - 1) / n.
        self.whitening_weights_ = (
            self.pca.explained_variance_ * (self.n_docs - 1) / self.n_docs
        )

    def fit(self, X):
        '''
        Method to fit the entire data to get the projection weights (singular vectors) and
        whitening weights (scaled explained variance) of a centered input dataset X.

        Any state from earlier ``fit`` / ``partial_fit`` calls is replaced.

        Parameters
        ----------
        X : tensor of shape (n_samples, vocabulary_size)
            Tensor containing all input documents
        '''
        # IncrementalPCA.fit() discards previously seen samples, so the
        # document count must restart too; otherwise the normalization in
        # _update_weights() would use a count the PCA no longer reflects.
        self.n_docs = X.shape[0]
        self.pca.fit(X * tl.sqrt(self.alpha_0 + 1))
        self._update_weights()

    def partial_fit(self, X_batch):
        '''
        Fit a batch of data and update the projection weights (singular vectors) and
        whitening weights (scaled explained variance) accordingly using a centered
        batch of the input dataset X.

        Parameters
        ----------
        X_batch : tensor of shape (batch_size, vocabulary_size)
            Tensor containing a batch of input documents
        '''
        # partial_fit() keeps the PCA's running state, so the count accumulates.
        self.n_docs += X_batch.shape[0]
        self.pca.partial_fit(X_batch * tl.sqrt(self.alpha_0 + 1))
        self._update_weights()