@@ -99,6 +99,24 @@ def calc_ctmrg_expectation(
99
99
peps_tensors , unitcell , spiral_vectors = _map_tensors (
100
100
input_tensors , unitcell , convert_to_unitcell_func , True
101
101
)
102
+
103
+ if any (i .size == 1 for i in spiral_vectors ):
104
+ spiral_vectors_x = additional_input .get ("spiral_vectors_x" )
105
+ spiral_vectors_y = additional_input .get ("spiral_vectors_y" )
106
+ if spiral_vectors_x is not None :
107
+ if isinstance (spiral_vectors_x , jnp .ndarray ):
108
+ spiral_vectors_x = (spiral_vectors_x ,)
109
+ spiral_vectors = tuple (
110
+ jnp .array ((sx , sy ))
111
+ for sx , sy in safe_zip (spiral_vectors_x , spiral_vectors )
112
+ )
113
+ elif spiral_vectors_y is not None :
114
+ if isinstance (spiral_vectors_y , jnp .ndarray ):
115
+ spiral_vectors_y = (spiral_vectors_y ,)
116
+ spiral_vectors = tuple (
117
+ jnp .array ((sx , sy ))
118
+ for sx , sy in safe_zip (spiral_vectors , spiral_vectors_y )
119
+ )
102
120
else :
103
121
peps_tensors , unitcell = _map_tensors (
104
122
input_tensors , unitcell , convert_to_unitcell_func , False
@@ -175,6 +193,24 @@ def calc_preconverged_ctmrg_value_and_grad(
175
193
peps_tensors , unitcell , spiral_vectors = _map_tensors (
176
194
input_tensors , unitcell , convert_to_unitcell_func , True
177
195
)
196
+
197
+ if any (i .size == 1 for i in spiral_vectors ):
198
+ spiral_vectors_x = additional_input .get ("spiral_vectors_x" )
199
+ spiral_vectors_y = additional_input .get ("spiral_vectors_y" )
200
+ if spiral_vectors_x is not None :
201
+ if isinstance (spiral_vectors_x , jnp .ndarray ):
202
+ spiral_vectors_x = (spiral_vectors_x ,)
203
+ spiral_vectors = tuple (
204
+ jnp .array ((sx , sy ))
205
+ for sx , sy in safe_zip (spiral_vectors_x , spiral_vectors )
206
+ )
207
+ elif spiral_vectors_y is not None :
208
+ if isinstance (spiral_vectors_y , jnp .ndarray ):
209
+ spiral_vectors_y = (spiral_vectors_y ,)
210
+ spiral_vectors = tuple (
211
+ jnp .array ((sx , sy ))
212
+ for sx , sy in safe_zip (spiral_vectors , spiral_vectors_y )
213
+ )
178
214
else :
179
215
peps_tensors , unitcell = _map_tensors (
180
216
input_tensors , unitcell , convert_to_unitcell_func , False
0 commit comments