|
| 1 | +from functools import partial |
| 2 | + |
| 3 | +import jax.numpy as jnp |
| 4 | +from jax import jit |
| 5 | + |
| 6 | +from varipeps.peps import PEPS_Tensor |
| 7 | +from varipeps.contractions import apply_contraction_jitted |
| 8 | + |
| 9 | +from typing import Sequence, List, Tuple, Literal |
| 10 | + |
| 11 | +Corner_Literal = Literal["top-left", "top-right", "bottom-left", "bottom-right"] |
| 12 | + |
| 13 | + |
| 14 | +@partial(jit, static_argnums=(5,)) |
| 15 | +def _four_sites_quadrat_workhorse( |
| 16 | + top_left: jnp.ndarray, |
| 17 | + top_right: jnp.ndarray, |
| 18 | + bottom_left: jnp.ndarray, |
| 19 | + bottom_right: jnp.ndarray, |
| 20 | + gates: Tuple[jnp.ndarray, ...], |
| 21 | + real_result: bool = False, |
| 22 | +) -> List[jnp.ndarray]: |
| 23 | + density_matrix = jnp.tensordot(top_left, top_right, ((5, 6, 7), (2, 3, 4))) |
| 24 | + density_matrix = jnp.tensordot(density_matrix, bottom_left, ((2, 3, 4), (5, 6, 7))) |
| 25 | + density_matrix = jnp.tensordot( |
| 26 | + density_matrix, bottom_right, ((4, 5, 6, 9, 10, 11), (2, 3, 4, 5, 6, 7)) |
| 27 | + ) |
| 28 | + |
| 29 | + density_matrix = density_matrix.transpose(0, 2, 4, 6, 1, 3, 5, 7) |
| 30 | + density_matrix = density_matrix.reshape( |
| 31 | + density_matrix.shape[0] |
| 32 | + * density_matrix.shape[1] |
| 33 | + * density_matrix.shape[2] |
| 34 | + * density_matrix.shape[3], |
| 35 | + density_matrix.shape[4] |
| 36 | + * density_matrix.shape[5] |
| 37 | + * density_matrix.shape[6] |
| 38 | + * density_matrix.shape[7], |
| 39 | + ) |
| 40 | + |
| 41 | + norm = jnp.trace(density_matrix) |
| 42 | + |
| 43 | + if real_result: |
| 44 | + return [ |
| 45 | + jnp.real(jnp.tensordot(density_matrix, g, ((0, 1), (0, 1))) / norm) |
| 46 | + for g in gates |
| 47 | + ] |
| 48 | + else: |
| 49 | + return [ |
| 50 | + jnp.tensordot(density_matrix, g, ((0, 1), (0, 1))) / norm for g in gates |
| 51 | + ] |
| 52 | + |
| 53 | + |
| 54 | +def calc_four_sites_quadrat_multiple_gates( |
| 55 | + peps_tensors: Sequence[jnp.ndarray], |
| 56 | + peps_tensor_objs: Sequence[PEPS_Tensor], |
| 57 | + gates: Sequence[jnp.ndarray], |
| 58 | +) -> List[jnp.ndarray]: |
| 59 | + """ |
| 60 | + Calculate the four site expectation values for three as quadrat ordered |
| 61 | + PEPS tensor and their environment. |
| 62 | +
|
| 63 | + The order of the PEPS sequence have to be |
| 64 | + [top-left, top-right, bottom-left, bottom-right]. |
| 65 | +
|
| 66 | + The gate is applied in the order [top-left, top-right, bottom-left, bottom-right]. |
| 67 | +
|
| 68 | + Args: |
| 69 | + peps_tensors (:term:`sequence` of :obj:`jax.numpy.ndarray`): |
| 70 | + The PEPS tensor arrays. Have to be the same objects as the tensor |
| 71 | + attribute of the `peps_tensor_obj` argument. |
| 72 | + peps_tensor_objs (:term:`sequence` of :obj:`~varipeps.peps.PEPS_Tensor`): |
| 73 | + PEPS tensor objects. |
| 74 | + gates (:term:`sequence` of :obj:`jax.numpy.ndarray`): |
| 75 | + Sequence with the gates which should be applied to the PEPS tensors. |
| 76 | + Gates are expected to be a matrix with first axis corresponding to |
| 77 | + the Hilbert space and the second axis corresponding to the dual room. |
| 78 | + Returns: |
| 79 | + :obj:`list` of :obj:`jax.numpy.ndarray`: |
| 80 | + List with the calculated expectation values of each gate. |
| 81 | + """ |
| 82 | + density_matrix_top_left = apply_contraction_jitted( |
| 83 | + "density_matrix_four_sites_top_left", |
| 84 | + [peps_tensors[0]], |
| 85 | + [peps_tensor_objs[0]], |
| 86 | + [], |
| 87 | + ) |
| 88 | + |
| 89 | + density_matrix_top_right = apply_contraction_jitted( |
| 90 | + "density_matrix_four_sites_top_right", |
| 91 | + [peps_tensors[1]], |
| 92 | + [peps_tensor_objs[1]], |
| 93 | + [], |
| 94 | + ) |
| 95 | + |
| 96 | + density_matrix_bottom_left = apply_contraction_jitted( |
| 97 | + "density_matrix_four_sites_bottom_left", |
| 98 | + [peps_tensors[2]], |
| 99 | + [peps_tensor_objs[2]], |
| 100 | + [], |
| 101 | + ) |
| 102 | + |
| 103 | + density_matrix_bottom_right = apply_contraction_jitted( |
| 104 | + "density_matrix_four_sites_bottom_right", |
| 105 | + [peps_tensors[3]], |
| 106 | + [peps_tensor_objs[3]], |
| 107 | + [], |
| 108 | + ) |
| 109 | + |
| 110 | + real_result = all(jnp.allclose(g, g.T.conj()) for g in gates) |
| 111 | + |
| 112 | + return _four_sites_quadrat_workhorse( |
| 113 | + density_matrix_top_left, |
| 114 | + density_matrix_top_right, |
| 115 | + density_matrix_bottom_left, |
| 116 | + density_matrix_bottom_right, |
| 117 | + tuple(gates), |
| 118 | + real_result, |
| 119 | + ) |
0 commit comments