8
8
from jax import jit
9
9
import jax .util
10
10
11
+ from varipeps import varipeps_config
11
12
import varipeps .config
12
13
from varipeps .peps import PEPS_Tensor , PEPS_Unit_Cell
13
14
from varipeps .contractions import apply_contraction , Definitions
14
15
from varipeps .expectation .model import Expectation_Model
15
16
from varipeps .expectation .one_site import calc_one_site_multi_gates
16
17
from varipeps .expectation .two_sites import _two_site_workhorse
18
+ from varipeps .expectation .spiral_helpers import apply_unitary
17
19
from varipeps .typing import Tensor
18
20
from varipeps .mapping import Map_To_PEPS_Model
19
21
from varipeps .utils .random import PEPS_Random_Number_Generator
@@ -111,8 +113,18 @@ def __post_init__(self) -> None:
111
113
self ._y_tuple = tuple (self .y_gates )
112
114
self ._z_tuple = tuple (self .z_gates )
113
115
116
+ self ._result_type = (
117
+ jnp .float64
118
+ if all (jnp .allclose (g , g .T .conj ()) for g in self .x_gates )
119
+ and all (jnp .allclose (g , g .T .conj ()) for g in self .y_gates )
120
+ and all (jnp .allclose (g , g .T .conj ()) for g in self .z_gates )
121
+ else jnp .complex128
122
+ )
123
+
114
124
if self .is_spiral_peps :
115
- raise NotImplementedError
125
+ self ._spiral_D , self ._spiral_sigma = jnp .linalg .eigh (
126
+ self .spiral_unitary_operator
127
+ )
116
128
117
129
def __call__ (
118
130
self ,
@@ -124,15 +136,8 @@ def __call__(
124
136
only_unique : bool = True ,
125
137
return_single_gate_results : bool = False ,
126
138
) -> Union [jnp .ndarray , List [jnp .ndarray ]]:
127
- result_type = (
128
- jnp .float64
129
- if all (jnp .allclose (g , jnp .real (g )) for g in self .x_gates )
130
- and all (jnp .allclose (g , jnp .real (g )) for g in self .y_gates )
131
- and all (jnp .allclose (g , jnp .real (g )) for g in self .z_gates )
132
- else jnp .complex128
133
- )
134
139
result = [
135
- jnp .array (0 , dtype = result_type )
140
+ jnp .array (0 , dtype = self . _result_type )
136
141
for _ in range (
137
142
max (
138
143
len (self .x_gates ),
@@ -142,6 +147,44 @@ def __call__(
142
147
)
143
148
]
144
149
150
+ if self .is_spiral_peps :
151
+ if isinstance (spiral_vectors , jnp .ndarray ):
152
+ spiral_vectors = (spiral_vectors ,)
153
+ if len (spiral_vectors ) != 1 :
154
+ raise ValueError ("Length mismatch for spiral vectors!" )
155
+
156
+ working_h_gates = tuple (
157
+ apply_unitary (
158
+ h ,
159
+ jnp .array ((0 , 1 )),
160
+ spiral_vectors ,
161
+ self ._spiral_D ,
162
+ self ._spiral_sigma ,
163
+ self .real_d ,
164
+ 2 ,
165
+ (1 ,),
166
+ varipeps_config .spiral_wavevector_type ,
167
+ )
168
+ for h in self ._y_tuple
169
+ )
170
+ working_v_gates = tuple (
171
+ apply_unitary (
172
+ v ,
173
+ jnp .array ((1 , 0 )),
174
+ spiral_vectors ,
175
+ self ._spiral_D ,
176
+ self ._spiral_sigma ,
177
+ self .real_d ,
178
+ 2 ,
179
+ (1 ,),
180
+ varipeps_config .spiral_wavevector_type ,
181
+ )
182
+ for v in self ._z_tuple
183
+ )
184
+ else :
185
+ working_h_gates = self ._y_tuple
186
+ working_v_gates = self ._z_tuple
187
+
145
188
for x , iter_rows in unitcell .iter_all_rows (only_unique = only_unique ):
146
189
for y , view in iter_rows :
147
190
# On site x term
@@ -196,8 +239,8 @@ def __call__(
196
239
step_result_y = _two_site_workhorse (
197
240
density_matrix_left ,
198
241
density_matrix_right ,
199
- self . _y_tuple ,
200
- result_type is jnp .float64 ,
242
+ working_h_gates ,
243
+ self . _result_type is jnp .float64 ,
201
244
)
202
245
203
246
for sr_i , sr in enumerate (step_result_y ):
@@ -241,8 +284,8 @@ def __call__(
241
284
step_result_z = _two_site_workhorse (
242
285
density_matrix_top ,
243
286
density_matrix_bottom ,
244
- self . _z_tuple ,
245
- result_type is jnp .float64 ,
287
+ working_v_gates ,
288
+ self . _result_type is jnp .float64 ,
246
289
)
247
290
248
291
for sr_i , sr in enumerate (step_result_z ):
0 commit comments