Skip to content

Commit 5ecd2cc

Browse files
committed
Implement next-nearest interaction for triangular lattice
1 parent 45e9506 commit 5ecd2cc

File tree

8 files changed

+718
-6
lines changed

8 files changed

+718
-6
lines changed
+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
.. _varipeps_expectation_four_sites:
2+
3+
.. currentmodule:: varipeps.expectation.four_sites
4+
5+
Calculation of four sites expectation values
6+
============================================
7+
8+
.. automodule:: varipeps.expectation.four_sites
9+
:members:
10+
:undoc-members:
11+
:show-inheritance:
12+
:special-members: __call__

docs/source/api/expectation/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Calculation of expectation values (:mod:`varipeps.expectation`)
1414
one_site
1515
two_sites
1616
three_sites
17+
four_sites
1718
spiral_helpers
1819

1920
.. automodule:: varipeps.expectation

varipeps/contractions/definitions.py

+61
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,67 @@ def _prepare_defs(cls):
503503
],
504504
}
505505

506+
density_matrix_two_sites_horizontal_rectangle: Definition = {
507+
"tensors": [
508+
["tensor", "tensor_conj", "T1"], # top_middle
509+
["tensor", "tensor_conj", "T3"], # bottom_middle
510+
"top_left",
511+
"top_right",
512+
"bottom_left",
513+
"bottom_right",
514+
],
515+
"network": [
516+
[
517+
(5, 11, 9, 19, 6), # tensor
518+
(7, 13, 9, 20, 8), # tensor_conj
519+
(4, 6, 8, 18), # T1
520+
],
521+
[
522+
(10, 16, 14, 25, 11), # tensor
523+
(12, 17, 14, 26, 13), # tensor_conj
524+
(15, 24, 17, 16), # T3
525+
],
526+
(-1, -3, 1, 2, 3, 4, 5, 7), # top_left
527+
(18, 19, 20, 21, 22, 23), # top_right
528+
(
529+
15,
530+
12,
531+
10,
532+
1,
533+
2,
534+
3,
535+
), # bottom_left
536+
(-2, -4, 21, 22, 23, 24, 26, 25), # bottom_right
537+
],
538+
}
539+
540+
density_matrix_two_sites_vertical_rectangle: Definition = {
541+
"tensors": [
542+
["tensor", "tensor_conj", "T4"],
543+
["tensor", "tensor_conj", "T2"],
544+
"top_left",
545+
"top_right",
546+
"bottom_left",
547+
"bottom_right",
548+
],
549+
"network": [
550+
[
551+
(6, 19, 9, 11, 5), # tensor
552+
(8, 20, 9, 13, 7), # tensor_conj
553+
(18, 8, 6, 4), # T4
554+
],
555+
[
556+
(11, 22, 14, 16, 10), # tensor
557+
(13, 23, 14, 17, 12), # tensor_conj
558+
(16, 17, 21, 15), # T2
559+
],
560+
(-1, -3, 4, 5, 7, 1, 2, 3), # top_left
561+
(1, 2, 3, 15, 12, 10), # top_right
562+
(24, 25, 26, 18, 19, 20), # bottom_left
563+
(-2, -4, 21, 23, 22, 24, 25, 26), # bottom_right
564+
],
565+
}
566+
506567
density_matrix_four_sites_top_left: Definition = {
507568
"tensors": [["tensor", "tensor_conj", "T4", "C1", "T1"]],
508569
"network": [

varipeps/expectation/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from . import one_site
55
from . import two_sites
66
from . import three_sites
7+
from . import four_sites
78

89
from .model import Expectation_Model
910
from .one_site import One_Site_Expectation_Value

varipeps/expectation/four_sites.py

+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
from functools import partial
2+
3+
import jax.numpy as jnp
4+
from jax import jit
5+
6+
from varipeps.peps import PEPS_Tensor
7+
from varipeps.contractions import apply_contraction_jitted
8+
9+
from typing import Sequence, List, Tuple, Literal
10+
11+
Corner_Literal = Literal["top-left", "top-right", "bottom-left", "bottom-right"]
12+
13+
14+
@partial(jit, static_argnums=(5,))
15+
def _four_sites_quadrat_workhorse(
16+
top_left: jnp.ndarray,
17+
top_right: jnp.ndarray,
18+
bottom_left: jnp.ndarray,
19+
bottom_right: jnp.ndarray,
20+
gates: Tuple[jnp.ndarray, ...],
21+
real_result: bool = False,
22+
) -> List[jnp.ndarray]:
23+
density_matrix = jnp.tensordot(top_left, top_right, ((5, 6, 7), (2, 3, 4)))
24+
density_matrix = jnp.tensordot(density_matrix, bottom_left, ((2, 3, 4), (5, 6, 7)))
25+
density_matrix = jnp.tensordot(
26+
density_matrix, bottom_right, ((4, 5, 6, 9, 10, 11), (2, 3, 4, 5, 6, 7))
27+
)
28+
29+
density_matrix = density_matrix.transpose(0, 2, 4, 6, 1, 3, 5, 7)
30+
density_matrix = density_matrix.reshape(
31+
density_matrix.shape[0]
32+
* density_matrix.shape[1]
33+
* density_matrix.shape[2]
34+
* density_matrix.shape[3],
35+
density_matrix.shape[4]
36+
* density_matrix.shape[5]
37+
* density_matrix.shape[6]
38+
* density_matrix.shape[7],
39+
)
40+
41+
norm = jnp.trace(density_matrix)
42+
43+
if real_result:
44+
return [
45+
jnp.real(jnp.tensordot(density_matrix, g, ((0, 1), (0, 1))) / norm)
46+
for g in gates
47+
]
48+
else:
49+
return [
50+
jnp.tensordot(density_matrix, g, ((0, 1), (0, 1))) / norm for g in gates
51+
]
52+
53+
54+
def calc_four_sites_quadrat_multiple_gates(
55+
peps_tensors: Sequence[jnp.ndarray],
56+
peps_tensor_objs: Sequence[PEPS_Tensor],
57+
gates: Sequence[jnp.ndarray],
58+
) -> List[jnp.ndarray]:
59+
"""
60+
Calculate the four site expectation values for three as quadrat ordered
61+
PEPS tensor and their environment.
62+
63+
The order of the PEPS sequence have to be
64+
[top-left, top-right, bottom-left, bottom-right].
65+
66+
The gate is applied in the order [top-left, top-right, bottom-left, bottom-right].
67+
68+
Args:
69+
peps_tensors (:term:`sequence` of :obj:`jax.numpy.ndarray`):
70+
The PEPS tensor arrays. Have to be the same objects as the tensor
71+
attribute of the `peps_tensor_obj` argument.
72+
peps_tensor_objs (:term:`sequence` of :obj:`~varipeps.peps.PEPS_Tensor`):
73+
PEPS tensor objects.
74+
gates (:term:`sequence` of :obj:`jax.numpy.ndarray`):
75+
Sequence with the gates which should be applied to the PEPS tensors.
76+
Gates are expected to be a matrix with first axis corresponding to
77+
the Hilbert space and the second axis corresponding to the dual room.
78+
Returns:
79+
:obj:`list` of :obj:`jax.numpy.ndarray`:
80+
List with the calculated expectation values of each gate.
81+
"""
82+
density_matrix_top_left = apply_contraction_jitted(
83+
"density_matrix_four_sites_top_left",
84+
[peps_tensors[0]],
85+
[peps_tensor_objs[0]],
86+
[],
87+
)
88+
89+
density_matrix_top_right = apply_contraction_jitted(
90+
"density_matrix_four_sites_top_right",
91+
[peps_tensors[1]],
92+
[peps_tensor_objs[1]],
93+
[],
94+
)
95+
96+
density_matrix_bottom_left = apply_contraction_jitted(
97+
"density_matrix_four_sites_bottom_left",
98+
[peps_tensors[2]],
99+
[peps_tensor_objs[2]],
100+
[],
101+
)
102+
103+
density_matrix_bottom_right = apply_contraction_jitted(
104+
"density_matrix_four_sites_bottom_right",
105+
[peps_tensors[3]],
106+
[peps_tensor_objs[3]],
107+
[],
108+
)
109+
110+
real_result = all(jnp.allclose(g, g.T.conj()) for g in gates)
111+
112+
return _four_sites_quadrat_workhorse(
113+
density_matrix_top_left,
114+
density_matrix_top_right,
115+
density_matrix_bottom_left,
116+
density_matrix_bottom_right,
117+
tuple(gates),
118+
real_result,
119+
)

varipeps/expectation/spiral_helpers.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ def apply_unitary(
2929
gate (:obj:`jax.numpy.ndarray`):
3030
The gate which should be updated with the unitary operator.
3131
delta_r (:obj:`jax.numpy.ndarray`):
32-
Vector for the spatial difference.
32+
Vector for the spatial difference. Can be a sequence if the spatial
33+
difference are different for the single indices.
3334
q (:term:`sequence` of :obj:`jax.numpy.ndarray`):
3435
Sequence with the relevant wavevector for the different indices of
3536
the gate.
@@ -52,13 +53,17 @@ def apply_unitary(
5253
:obj:`jax.numpy.ndarray`:
5354
The updated gate with the unitary applied.
5455
"""
55-
if len(q) != len(apply_to_index):
56+
if isinstance(delta_r, jnp.ndarray):
57+
delta_r = (delta_r,) * len(apply_to_index)
58+
59+
if len(q) != len(apply_to_index) or len(q) != len(delta_r):
5660
raise ValueError("Length mismatch!")
5761

5862
working_gate = gate.reshape((phys_d,) * 2 * number_sites)
5963

6064
for index, i in enumerate(apply_to_index):
6165
w_q = q[index]
66+
w_r = delta_r[index]
6267

6368
if w_q.ndim == 0:
6469
w_q = jnp.array((w_q, w_q))
@@ -72,8 +77,8 @@ def apply_unitary(
7277
else:
7378
raise ValueError("Unknown wavevector type!")
7479

75-
# U = jsp.linalg.expm(1j * jnp.pi * jnp.dot(w_q, delta_r) * unitary_operator)
76-
U = jnp.exp(1j * jnp.pi * jnp.dot(w_q, delta_r) * unitary_operator_D)
80+
# U = jsp.linalg.expm(1j * jnp.pi * jnp.dot(w_q, w_r) * unitary_operator)
81+
U = jnp.exp(1j * jnp.pi * jnp.dot(w_q, w_r) * unitary_operator_D)
7782
U = jnp.dot(
7883
unitary_operator_sigma * U[jnp.newaxis, :], unitary_operator_sigma.T.conj()
7984
)

0 commit comments

Comments
 (0)