# Source code for tensorly.tenalg.core_tenalg.contraction

import numpy as np
from ... import backend as tl

# Author: Jean Kossaifi

# License: BSD 3 clause

def _normalize_modes(modes, ndim, name):
    """Map possibly negative ``modes`` onto distinct indices in ``[0, ndim)``.

    Raises IndexError for a mode outside ``[-ndim, ndim)`` and ValueError
    when the same axis is listed more than once.
    """
    normalized = []
    for mode in modes:
        if not -ndim <= mode < ndim:
            raise IndexError(
                'Mode {} in {} is out of range for a tensor of order {}'.format(
                    mode, name, ndim))
        # Python's modulo maps -1 to ndim - 1, -2 to ndim - 2, and so on.
        normalized.append(mode % ndim)
    if len(set(normalized)) != len(normalized):
        raise ValueError(
            'Each mode can only be contracted once, '
            'but got repeated modes in {}={}'.format(name, list(modes)))
    return normalized


def contract(tensor1, modes1, tensor2, modes2):
    """Tensor contraction between two tensors on specified modes

    The k-th mode of ``modes1`` is contracted with the k-th mode of
    ``modes2``. Negative modes count from the last axis.

    Parameters
    ----------
    tensor1 : tl.tensor
    modes1 : int list or int
        modes on which to contract tensor1
    tensor2 : tl.tensor
    modes2 : int list or int
        modes on which to contract tensor2

    Returns
    -------
    contraction : tensor1 contracted with tensor2 on the specified modes.
        Its shape is the free (non-contracted) dimensions of tensor1, in
        their original order, followed by the free dimensions of tensor2.

    Raises
    ------
    ValueError
        If ``modes1`` and ``modes2`` have different lengths, if a mode is
        repeated, or if the paired modes have different sizes.
    IndexError
        If a mode is out of range for its tensor.
    """
    # Accept NumPy integer scalars as well as Python ints for a single mode.
    if isinstance(modes1, (int, np.integer)):
        modes1 = [modes1]
    if isinstance(modes2, (int, np.integer)):
        modes2 = [modes2]
    modes1 = list(modes1)
    modes2 = list(modes2)

    if len(modes1) != len(modes2):
        raise ValueError(
            'Can only contract two tensors along the same number of modes '
            '(len(modes1) == len(modes2)). '
            'However, got {} modes for tensor 1 and {} modes for tensor 2 '
            '(modes1={}, and modes2={})'.format(
                len(modes1), len(modes2), modes1, modes2))

    # Without normalisation a negative mode would not be excluded from the
    # free modes below, and the transpose would receive a repeated axis.
    modes1 = _normalize_modes(modes1, tl.ndim(tensor1), 'modes1')
    modes2 = _normalize_modes(modes2, tl.ndim(tensor2), 'modes2')

    shape1 = tl.shape(tensor1)
    shape2 = tl.shape(tensor2)
    contraction_dims = [shape1[i] for i in modes1]
    contraction_dims2 = [shape2[i] for i in modes2]
    if contraction_dims != contraction_dims2:
        raise ValueError(
            'Trying to contract tensors over modes of different sizes '
            '(contracting modes of sizes {} and {})'.format(
                contraction_dims, contraction_dims2))
    # np.prod of an empty list is 1.0, so contracting over no modes
    # degenerates to an outer product.
    shared_dim = int(np.prod(contraction_dims))

    # Matricize tensor1 as (free dims, contracted dims).
    modes1_free = [i for i in range(tl.ndim(tensor1)) if i not in modes1]
    free_shape1 = [shape1[i] for i in modes1_free]
    tensor1 = tl.reshape(tl.transpose(tensor1, modes1_free + modes1),
                         (int(np.prod(free_shape1)), shared_dim))

    # Matricize tensor2 as (contracted dims, free dims) so one matrix
    # product performs the whole contraction.
    modes2_free = [i for i in range(tl.ndim(tensor2)) if i not in modes2]
    free_shape2 = [shape2[i] for i in modes2_free]
    tensor2 = tl.reshape(tl.transpose(tensor2, modes2 + modes2_free),
                         (shared_dim, int(np.prod(free_shape2))))

    res = tl.dot(tensor1, tensor2)
    return tl.reshape(res, tuple(free_shape1 + free_shape2))