@@ -82,6 +82,7 @@ class Triangular_Expectation_Value(Expectation_Model):
82
82
"""
83
83
84
84
nearest_neighbor_gates : Sequence [jnp .ndarray ]
85
+ real_d : int
85
86
normalization_factor : int = 1
86
87
87
88
is_spiral_peps : bool = False
@@ -90,6 +91,22 @@ class Triangular_Expectation_Value(Expectation_Model):
90
91
def __post_init__ (self ) -> None :
91
92
if isinstance (self .nearest_neighbor_gates , jnp .ndarray ):
92
93
self .nearest_neighbor_gates = (self .nearest_neighbor_gates ,)
94
+ else :
95
+ self .nearest_neighbor_gates = tuple (self .nearest_neighbor_gates )
96
+
97
+ self ._result_type = (
98
+ jnp .float64
99
+ if all (
100
+ jnp .allclose (g , g .T .conj ())
101
+ for g in self .nearest_neighbor_gates
102
+ )
103
+ else jnp .complex128
104
+ )
105
+
106
+ if self .is_spiral_peps :
107
+ self ._spiral_D , self ._spiral_sigma = jnp .linalg .eigh (
108
+ self .spiral_unitary_operator
109
+ )
93
110
94
111
def __call__ (
95
112
self ,
@@ -101,32 +118,84 @@ def __call__(
101
118
only_unique : bool = True ,
102
119
return_single_gate_results : bool = False ,
103
120
) -> Union [jnp .ndarray , List [jnp .ndarray ]]:
104
- result_type = (
105
- jnp .float64
106
- if all (jnp .allclose (g , jnp .real (g )) for g in self .nearest_neighbor_gates )
107
- else jnp .complex128
108
- )
109
121
result = [
110
- jnp .array (0 , dtype = result_type )
122
+ jnp .array (0 , dtype = self . _result_type )
111
123
for _ in range (len (self .nearest_neighbor_gates ))
112
124
]
113
125
126
+ if self .is_spiral_peps :
127
+ if (
128
+ isinstance (spiral_vectors , collections .abc .Sequence )
129
+ and len (spiral_vectors ) == 1
130
+ ):
131
+ spiral_vectors = spiral_vectors [0 ]
132
+
133
+ if not isinstance (spiral_vectors , jnp .ndarray ):
134
+ raise ValueError ("Expect spiral vector as single jax.numpy array." )
135
+
136
+ working_h_gates = tuple (
137
+ apply_unitary (
138
+ h ,
139
+ jnp .array ((0 , 1 )),
140
+ (spiral_vectors ,),
141
+ self ._spiral_D ,
142
+ self ._spiral_sigma ,
143
+ self .real_d ,
144
+ 2 ,
145
+ (1 ,),
146
+ varipeps_config .spiral_wavevector_type ,
147
+ )
148
+ for h in self .nearest_neighbor_gates
149
+ )
150
+ working_v_gates = tuple (
151
+ apply_unitary (
152
+ v ,
153
+ jnp .array ((1 , 0 )),
154
+ (spiral_vectors ,),
155
+ self ._spiral_D ,
156
+ self ._spiral_sigma ,
157
+ self .real_d ,
158
+ 2 ,
159
+ (1 ,),
160
+ varipeps_config .spiral_wavevector_type ,
161
+ )
162
+ for v in self .nearest_neighbor_gates
163
+ )
164
+ working_d_gates = tuple (
165
+ apply_unitary (
166
+ d ,
167
+ jnp .array ((1 , 1 )),
168
+ (spiral_vectors ,),
169
+ self ._spiral_D ,
170
+ self ._spiral_sigma ,
171
+ self .real_d ,
172
+ 2 ,
173
+ (1 ,),
174
+ varipeps_config .spiral_wavevector_type ,
175
+ )
176
+ for d in self .nearest_neighbor_gates
177
+ )
178
+ else :
179
+ working_h_gates = self .nearest_neighbor_gates
180
+ working_v_gates = self .nearest_neighbor_gates
181
+ working_d_gates = self .nearest_neighbor_gates
182
+
114
183
for x , iter_rows in unitcell .iter_all_rows (only_unique = only_unique ):
115
184
for y , view in iter_rows :
116
185
x_tensors_i = view .get_indices ((slice (0 , 2 , None ), 0 ))
117
186
x_tensors = [peps_tensors [i ] for j in x_tensors_i for i in j ]
118
187
x_tensor_objs = [t for tl in view [:2 , 0 ] for t in tl ]
119
188
120
189
step_result_x = calc_two_sites_vertical_multiple_gates (
121
- x_tensors , x_tensor_objs , self . nearest_neighbor_gates
190
+ x_tensors , x_tensor_objs , working_v_gates
122
191
)
123
192
124
193
y_tensors_i = view .get_indices ((0 , slice (0 , 2 , None )))
125
194
y_tensors = [peps_tensors [i ] for j in y_tensors_i for i in j ]
126
195
y_tensor_objs = [t for tl in view [0 , :2 ] for t in tl ]
127
196
128
197
step_result_y = calc_two_sites_horizontal_multiple_gates (
129
- y_tensors , y_tensor_objs , self . nearest_neighbor_gates
198
+ y_tensors , y_tensor_objs , working_h_gates
130
199
)
131
200
132
201
diagonal_tensors_i = view .get_indices (
@@ -141,7 +210,7 @@ def __call__(
141
210
calc_two_sites_diagonal_top_left_bottom_right_multiple_gates (
142
211
diagonal_tensors ,
143
212
diagonal_tensor_objs ,
144
- self . nearest_neighbor_gates ,
213
+ working_d_gates ,
145
214
)
146
215
)
147
216
0 commit comments