Skip to content

Commit 09e0ae6

Browse files
committed
Implement spiral ansatz for Honeycomb
1 parent f1af7f5 commit 09e0ae6

File tree

1 file changed

+56
-13
lines changed

1 file changed

+56
-13
lines changed

varipeps/mapping/honeycomb.py

+56-13
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
from jax import jit
99
import jax.util
1010

11+
from varipeps import varipeps_config
1112
import varipeps.config
1213
from varipeps.peps import PEPS_Tensor, PEPS_Unit_Cell
1314
from varipeps.contractions import apply_contraction, Definitions
1415
from varipeps.expectation.model import Expectation_Model
1516
from varipeps.expectation.one_site import calc_one_site_multi_gates
1617
from varipeps.expectation.two_sites import _two_site_workhorse
18+
from varipeps.expectation.spiral_helpers import apply_unitary
1719
from varipeps.typing import Tensor
1820
from varipeps.mapping import Map_To_PEPS_Model
1921
from varipeps.utils.random import PEPS_Random_Number_Generator
@@ -111,8 +113,18 @@ def __post_init__(self) -> None:
111113
self._y_tuple = tuple(self.y_gates)
112114
self._z_tuple = tuple(self.z_gates)
113115

116+
self._result_type = (
117+
jnp.float64
118+
if all(jnp.allclose(g, g.T.conj()) for g in self.x_gates)
119+
and all(jnp.allclose(g, g.T.conj()) for g in self.y_gates)
120+
and all(jnp.allclose(g, g.T.conj()) for g in self.z_gates)
121+
else jnp.complex128
122+
)
123+
114124
if self.is_spiral_peps:
115-
raise NotImplementedError
125+
self._spiral_D, self._spiral_sigma = jnp.linalg.eigh(
126+
self.spiral_unitary_operator
127+
)
116128

117129
def __call__(
118130
self,
@@ -124,15 +136,8 @@ def __call__(
124136
only_unique: bool = True,
125137
return_single_gate_results: bool = False,
126138
) -> Union[jnp.ndarray, List[jnp.ndarray]]:
127-
result_type = (
128-
jnp.float64
129-
if all(jnp.allclose(g, jnp.real(g)) for g in self.x_gates)
130-
and all(jnp.allclose(g, jnp.real(g)) for g in self.y_gates)
131-
and all(jnp.allclose(g, jnp.real(g)) for g in self.z_gates)
132-
else jnp.complex128
133-
)
134139
result = [
135-
jnp.array(0, dtype=result_type)
140+
jnp.array(0, dtype=self._result_type)
136141
for _ in range(
137142
max(
138143
len(self.x_gates),
@@ -142,6 +147,44 @@ def __call__(
142147
)
143148
]
144149

150+
if self.is_spiral_peps:
151+
if isinstance(spiral_vectors, jnp.ndarray):
152+
spiral_vectors = (spiral_vectors,)
153+
if len(spiral_vectors) != 1:
154+
raise ValueError("Length mismatch for spiral vectors!")
155+
156+
working_h_gates = tuple(
157+
apply_unitary(
158+
h,
159+
jnp.array((0, 1)),
160+
spiral_vectors,
161+
self._spiral_D,
162+
self._spiral_sigma,
163+
self.real_d,
164+
2,
165+
(1,),
166+
varipeps_config.spiral_wavevector_type,
167+
)
168+
for h in self._y_tuple
169+
)
170+
working_v_gates = tuple(
171+
apply_unitary(
172+
v,
173+
jnp.array((1, 0)),
174+
spiral_vectors,
175+
self._spiral_D,
176+
self._spiral_sigma,
177+
self.real_d,
178+
2,
179+
(1,),
180+
varipeps_config.spiral_wavevector_type,
181+
)
182+
for v in self._z_tuple
183+
)
184+
else:
185+
working_h_gates = self._y_tuple
186+
working_v_gates = self._z_tuple
187+
145188
for x, iter_rows in unitcell.iter_all_rows(only_unique=only_unique):
146189
for y, view in iter_rows:
147190
# On site x term
@@ -196,8 +239,8 @@ def __call__(
196239
step_result_y = _two_site_workhorse(
197240
density_matrix_left,
198241
density_matrix_right,
199-
self._y_tuple,
200-
result_type is jnp.float64,
242+
working_h_gates,
243+
self._result_type is jnp.float64,
201244
)
202245

203246
for sr_i, sr in enumerate(step_result_y):
@@ -241,8 +284,8 @@ def __call__(
241284
step_result_z = _two_site_workhorse(
242285
density_matrix_top,
243286
density_matrix_bottom,
244-
self._z_tuple,
245-
result_type is jnp.float64,
287+
working_v_gates,
288+
self._result_type is jnp.float64,
246289
)
247290

248291
for sr_i, sr in enumerate(step_result_z):

0 commit comments

Comments
 (0)