Basic tensor operations

This example shows how to use tensorly.base to perform basic tensor operations.

import numpy as np
from tensorly.base import unfold, fold
# T gives access to the backend's tensor constructor and testing helpers
import tensorly.backend as T

With the NumPy backend, a tensor is simply a NumPy array. Here we build a 3 x 4 x 2 tensor holding the values 0 to 23:

tensor = T.tensor(np.arange(24).reshape((3, 4, 2)))
print('* original tensor:\n{}'.format(tensor))

Out:

* original tensor:
[[[ 0.  1.]
  [ 2.  3.]
  [ 4.  5.]
  [ 6.  7.]]

 [[ 8.  9.]
  [10. 11.]
  [12. 13.]
  [14. 15.]]

 [[16. 17.]
  [18. 19.]
  [20. 21.]
  [22. 23.]]]
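To make the layout of the printed tensor concrete: with NumPy's default row-major (C) ordering, the entry at index (i, j, k) of np.arange(24).reshape((3, 4, 2)) holds the value 8*i + 2*j + k, because the last index varies fastest. The sketch below checks this with plain NumPy only (no tensorly), so it does not depend on the backend used above.

```python
import numpy as np

# Same data as above, built with plain NumPy (no tensorly backend involved).
tensor = np.arange(24).reshape((3, 4, 2))

print(tensor.shape)  # (3, 4, 2)
print(tensor.ndim)   # 3

# Row-major order: the last index varies fastest, so entry (i, j, k)
# sits at flat position i * (4 * 2) + j * 2 + k.
for i in range(3):
    for j in range(4):
        for k in range(2):
            assert tensor[i, j, k] == 8 * i + 2 * j + k

# Fixing the first index returns one of the three 4 x 2 blocks printed above.
print(tensor[1])
```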

Unfolding a tensor is easy: the mode-n unfolding rearranges the tensor into a matrix whose rows are indexed by mode n.

for mode in range(tensor.ndim):
    print('* mode-{} unfolding:\n{}'.format(mode, unfold(tensor, mode)))

Out:

* mode-0 unfolding:
[[ 0.  1.  2.  3.  4.  5.  6.  7.]
 [ 8.  9. 10. 11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20. 21. 22. 23.]]
* mode-1 unfolding:
[[ 0.  1.  8.  9. 16. 17.]
 [ 2.  3. 10. 11. 18. 19.]
 [ 4.  5. 12. 13. 20. 21.]
 [ 6.  7. 14. 15. 22. 23.]]
* mode-2 unfolding:
[[ 0.  2.  4.  6.  8. 10. 12. 14. 16. 18. 20. 22.]
 [ 1.  3.  5.  7.  9. 11. 13. 15. 17. 19. 21. 23.]]
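The unfoldings above follow a simple rule: the mode-n unfolding has tensor.shape[n] rows, and row r is the slice with index r along mode n, flattened in row-major order. A minimal NumPy-only sketch of that rule, assuming this row-major convention, moves mode n to the front and reshapes. The helper name unfold_np is hypothetical and not part of tensorly, but it reproduces the three matrices printed above.

```python
import numpy as np

def unfold_np(tensor, mode):
    """Mode-`mode` unfolding (hypothetical helper, row-major convention).

    Moves axis `mode` to the front, then flattens the remaining axes so the
    result has shape (tensor.shape[mode], product of the other dimensions).
    """
    return np.reshape(np.moveaxis(tensor, mode, 0), (tensor.shape[mode], -1))

tensor = np.arange(24).reshape((3, 4, 2))

for mode in range(tensor.ndim):
    matrix = unfold_np(tensor, mode)
    # Row r is the slice with index r along `mode`, flattened row-major.
    for r in range(tensor.shape[mode]):
        assert np.array_equal(matrix[r], np.take(tensor, r, axis=mode).ravel())
    print('mode-{} unfolding has shape {}'.format(mode, matrix.shape))
```

For this tensor the three shapes are (3, 8), (4, 6) and (2, 12). Other conventions, such as column-major ordering, arrange the columns differently, so the column order is a choice of convention rather than a property of the tensor.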

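Refolding is just as easy: given an unfolding, its mode and the original shape, fold recovers the original tensor. We check this for every mode: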

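for mode in range(tensor.ndim):
    unfolding = unfold(tensor, mode)
    # fold needs the original shape to undo the unfolding
    folded = fold(unfolding, mode, tensor.shape)
    # Raises an AssertionError if folding did not recover the original tensor
    T.assert_array_equal(folded, tensor)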
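Folding needs the original shape because an unfolding alone is ambiguous: a 4 x 6 matrix could be the mode-1 unfolding of a 3 x 4 x 2 tensor or of a 2 x 4 x 3 one. Assuming the row-major convention visible in the unfoldings printed above, the inverse reshapes the matrix back to the permuted shape and then moves the leading axis back into place. The names unfold_np and fold_np are hypothetical NumPy-only helpers for illustration, not tensorly functions.

```python
import numpy as np

def unfold_np(tensor, mode):
    # Move `mode` to the front and flatten the other axes (row-major).
    return np.reshape(np.moveaxis(tensor, mode, 0), (tensor.shape[mode], -1))

def fold_np(unfolded, mode, shape):
    """Inverse of unfold_np: rebuild a tensor of `shape` from its mode-`mode` unfolding."""
    # Shape after the moveaxis in unfold_np: `mode` first, other axes in order.
    moved_shape = [shape[mode]] + [dim for i, dim in enumerate(shape) if i != mode]
    # Undo the reshape, then undo the moveaxis.
    return np.moveaxis(np.reshape(unfolded, moved_shape), 0, mode)

tensor = np.arange(24).reshape((3, 4, 2))
for mode in range(tensor.ndim):
    refolded = fold_np(unfold_np(tensor, mode), mode, tensor.shape)
    assert np.array_equal(refolded, tensor)

# The same 4 x 6 matrix folds into differently shaped tensors,
# depending on the shape we ask for.
matrix = unfold_np(tensor, 1)
print(fold_np(matrix, 1, (3, 4, 2)).shape)
print(fold_np(matrix, 1, (2, 4, 3)).shape)
```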

Generated by Sphinx-Gallery