import tensorly as tl
# Author: Jean Kossaifi <jkossaifi@nvidia.com>
# Author: Taylor Lee Patti <taylorpatti@g.harvard.edu>
# License: BSD 3 clause
[docs]def tt_sum(t1, t2):
"""Sums two TT tensors in decomposed form
Parameters
----------
t1 : tt-tensor
t2 : tt-tensor
Returns
-------
tt-tensor sum of t1 and t2
Notes
-----
The solution can be easily seen by writing the element-wise expression.
The sum of two third order cores A and B becomes a new core::
| A(i) 0 |
| 0 B(i)|
In the code, we first form the two columns which we then concatenate::
| A(i) | | 0 |
| 0 | | B(i)|
"""
new_tt, n_cores, device = [], len(t1), t1[0].device
for i, (core1, core2) in enumerate(zip(t1, t2)):
if i == 0: # First core is (1, I_1, R_0)
core = tl.concatenate((core1, core2), axis=2)
elif i == (n_cores - 1): # Last core is (I_N, R_N, 1)
core = tl.concatenate((core1, core2), axis=0)
else: # 3rd order cores (R_k, I_k, R_{k+1})
padded_c1 = tl.concatenate(
(core1, tl.zeros((t2.rank[i], core1.shape[1], t1.rank[i+1]), device=device)),
axis=0
)
padded_c2 = tl.concatenate(
(tl.zeros((t1.rank[i], core1.shape[1], t2.rank[i+1]), device=device), core2),
axis=0
)
core = tl.concatenate((padded_c1, padded_c2), axis=2)
new_tt.append(core)
return tl.tt_tensor.TTTensor(new_tt)
[docs]def tt_matrix_sum(t1, t2):
"""Sums two TT matrices in decomposed form
Parameters
----------
t1 : tt-tensor matrix
t2 : tt-tensor matrix
Returns
-------
tt-tensor matrix sum of t1 and t2
"""
if t1 == []:
return tl.tt_matrix.TTMatrix(t2)
if t2 == []:
return tl.tt_matrix.TTMatrix(t1)
t1, t2 = tl.tt_matrix.TTMatrix(t1), tl.tt_matrix.TTMatrix(t2)
new_tt, n_cores, device = [], len(t1), t1[0].device
for i, (core1, core2) in enumerate(zip(t1, t2)):
if i == 0:
core = tl.concatenate((core1, core2), axis=3)
elif i == (n_cores - 1):
core = tl.concatenate((core1, core2), axis=0)
else:
padded_c1 = tl.concatenate(
(core1, tl.zeros((t2.rank[i], core1.shape[1], core1.shape[2], t1.rank[i+1]), device=device)),
axis=0
)
padded_c2 = tl.concatenate(
(tl.zeros((t1.rank[i], core1.shape[1], core1.shape[2], t2.rank[i+1]), device=device), core2),
axis=0
)
core = tl.concatenate((padded_c1, padded_c2), axis=3)
new_tt.append(core)
return tl.tt_matrix.TTMatrix(new_tt)