Source code for tlquantum.tt_contraction
from opt_einsum.parser import get_symbol
from collections import Counter
# Author: Taylor Lee Patti <taylorpatti@g.harvard.edu>
# Author: Jean Kossaifi <jkossaifi@nvidia.com>
# License: BSD 3 clause
def contraction_eq(nqubits, nlayers, kept_inds=None, to_ket=False, to_operator=False):
    """Generates einsum contraction equation.

    The equation is assembled from up to three groups of operands, each written
    as a string of ``opt_einsum`` symbols:

    * ``tt_idx``: one 3-index term per qubit,
    * ``factors_idx``: one 4-index term per qubit and per layer,
    * ``measure_idx``: one 3-index term per qubit (default mode only).

    Parameters
    ----------
    nqubits : int, number of qubits to contract over (must be at least 1)
    nlayers : int, number of layers to contract over
    kept_inds : list of inds, qubit indices to keep. If not None, then ptrace equation is generated
    to_ket : bool, if True, the operands are the layer terms followed by the
        per-qubit terms, and every index occurring exactly once is kept in the
        output. Checked before ``to_operator``.
    to_operator : bool, if True, the operands are the layer terms only, and
        every index occurring exactly once is kept in the output.

    Returns
    -------
    string of the contraction equation.

    Raises
    ------
    ValueError
        If ``nqubits`` is smaller than 1.
    """
    if nqubits < 1:
        raise ValueError('nqubits must be at least 1, got {}'.format(nqubits))

    start = 1

    # Per-qubit terms: three consecutive symbols, with neighbouring qubits
    # sharing their outer symbol (term q ends where term q+1 begins).
    tt_idx = [
        ''.join(get_symbol(start + 2*q + k) for k in range(3))
        for q in range(nqubits)
    ]

    # Last symbol used by the per-qubit terms; layer symbols are numbered from here.
    start2 = start + 2*nqubits
    # Offset added to a kept index so that it no longer matches its partner.
    max_ind = 2*start2 + (nlayers+1)*nqubits + 2

    factors_idx = []
    for tt in range(nlayers):
        offset = 2*tt*nqubits
        for i in range(nqubits):
            if i == 0:
                idx = [start2+offset, start2+1+offset, start+1+offset, start2+2+offset]
                if tt == 0:
                    # start2 is already taken by the last per-qubit term, so
                    # the first layer uses symbol 0 for its outer index.
                    idx[0] = 0
                # NOTE(review): with nqubits == 1 only this branch runs, so the
                # last index is not tied back to the outer one -- confirm intended.
            elif i == nqubits-1:
                # The last index wraps around to the first index of qubit 0.
                idx = [start2+2*i+offset, start2+2*i+1+offset, start+1+2*i+offset, start2+offset]
                if tt == 0:
                    idx[-1] = 0
            else:
                idx = [start2+2*i+offset, start2+2*i+1+offset, start+1+2*i+offset, start2+2*i+2+offset]
            # In the middle layer, shift the third index of every kept qubit so
            # it appears only once and therefore survives into the output.
            if (kept_inds is not None) and (tt == nlayers // 2) and (i in kept_inds):
                idx[2] += max_ind
            factors_idx.append(''.join(get_symbol(j) for j in idx))

    if to_ket:
        counts = Counter(''.join(tt_idx) + ''.join(factors_idx))
        out_idx = ''.join(ind for ind, count in counts.items() if count == 1)
        return ','.join(factors_idx) + ',' + ','.join(tt_idx) + '->' + out_idx

    if to_operator:
        counts = Counter(''.join(factors_idx))
        out_idx = ''.join(ind for ind, count in counts.items() if count == 1)
        if nlayers == 1:
            # Single layer: emit the odd-position free indices first, then the
            # even-position ones.
            out_idx = out_idx[1::2] + out_idx[0::2]
        return ','.join(factors_idx) + '->' + out_idx

    # Measurement terms: the middle symbol matches the second symbol of the
    # last layer; the outer symbols are fresh ones shared between neighbours.
    start_phys = start2 + 1 + 2*(nlayers-1)*nqubits
    start_virt = start2 + 2*nlayers*nqubits
    measure_idx = []
    for i in range(nqubits):
        idx = [start_virt+i-1, start_phys+2*i, start_virt+i]
        if i == 0:
            idx[0] = start
        if i == nqubits-1:
            idx[-1] = start2
        measure_idx.append(''.join(get_symbol(j) for j in idx))

    operands = ','.join(measure_idx) + ',' + ','.join(factors_idx) + ',' + ','.join(tt_idx)
    if kept_inds is not None:
        counts = Counter(''.join(tt_idx) + ''.join(factors_idx) + ''.join(measure_idx))
        out_idx = ''.join(ind for ind, count in counts.items() if count == 1)
        return operands + '->' + out_idx
    return operands + '-> b'