@@ -64,12 +64,15 @@ class Triangular_Expectation_Value(Expectation_Model):
64
64
\\
65
65
66
66
Args:
67
- nearest_neighbor_gates (:term:`sequence` of :obj:`jax.numpy.ndarray`):
68
- Sequence with the gates that should be applied to each nearest
67
+ horizontal_gates (:term:`sequence` of :obj:`jax.numpy.ndarray`):
68
+ Sequence with the gates that should be applied to each nearest horizontal
69
+ neighbor.
70
+ vertical_gates (:term:`sequence` of :obj:`jax.numpy.ndarray`):
71
+ Sequence with the gates that should be applied to each nearest vertical
72
+ neighbor.
73
+ diagonal_gates (:term:`sequence` of :obj:`jax.numpy.ndarray`):
74
+ Sequence with the gates that should be applied to each nearest diagonal
69
75
neighbor.
70
- downward_triangle_gates (:term:`sequence` of :obj:`jax.numpy.ndarray`):
71
- Sequence with the gates that should be applied to the downward
72
- triangles.
73
76
normalization_factor (:obj:`int`):
74
77
Factor which should be used to normalize the calculated values.
75
78
If for example three sites are mapped into one PEPS site this
@@ -81,24 +84,38 @@ class Triangular_Expectation_Value(Expectation_Model):
81
84
if spiral iPEPS ansatz is used.
82
85
"""
83
86
84
- nearest_neighbor_gates : Sequence [jnp .ndarray ]
87
+ horizontal_gates : Sequence [jnp .ndarray ]
88
+ vertical_gates : Sequence [jnp .ndarray ]
89
+ diagonal_gates : Sequence [jnp .ndarray ]
85
90
real_d : int
86
91
normalization_factor : int = 1
87
92
88
93
is_spiral_peps : bool = False
89
94
spiral_unitary_operator : Optional [jnp .ndarray ] = None
90
95
91
96
def __post_init__ (self ) -> None :
92
- if isinstance (self .nearest_neighbor_gates , jnp .ndarray ):
93
- self .nearest_neighbor_gates = (self .nearest_neighbor_gates ,)
97
+ if isinstance (self .horizontal_gates , jnp .ndarray ):
98
+ self .horizontal_gates = (self .horizontal_gates ,)
99
+ else :
100
+ self .horizontal_gates = tuple (self .horizontal_gates )
101
+
102
+ if isinstance (self .vertical_gates , jnp .ndarray ):
103
+ self .vertical_gates = (self .vertical_gates ,)
94
104
else :
95
- self .nearest_neighbor_gates = tuple (self .nearest_neighbor_gates )
105
+ self .vertical_gates = tuple (self .vertical_gates )
106
+
107
+ if isinstance (self .diagonal_gates , jnp .ndarray ):
108
+ self .diagonal_gates = (self .diagonal_gates ,)
109
+ else :
110
+ self .diagonal_gates = tuple (self .diagonal_gates )
96
111
97
112
self ._result_type = (
98
113
jnp .float64
99
114
if all (
100
115
jnp .allclose (g , g .T .conj ())
101
- for g in self .nearest_neighbor_gates
116
+ for g in self .horizontal_gates
117
+ + self .vertical_gates
118
+ + self .diagonal_gates
102
119
)
103
120
else jnp .complex128
104
121
)
@@ -120,7 +137,7 @@ def __call__(
120
137
) -> Union [jnp .ndarray , List [jnp .ndarray ]]:
121
138
result = [
122
139
jnp .array (0 , dtype = self ._result_type )
123
- for _ in range (len (self .nearest_neighbor_gates ))
140
+ for _ in range (len (self .horizontal_gates ))
124
141
]
125
142
126
143
if self .is_spiral_peps :
@@ -145,7 +162,7 @@ def __call__(
145
162
(1 ,),
146
163
varipeps_config .spiral_wavevector_type ,
147
164
)
148
- for h in self .nearest_neighbor_gates
165
+ for h in self .horizontal_gates
149
166
)
150
167
working_v_gates = tuple (
151
168
apply_unitary (
@@ -159,7 +176,7 @@ def __call__(
159
176
(1 ,),
160
177
varipeps_config .spiral_wavevector_type ,
161
178
)
162
- for v in self .nearest_neighbor_gates
179
+ for v in self .vertical_gates
163
180
)
164
181
working_d_gates = tuple (
165
182
apply_unitary (
@@ -173,12 +190,12 @@ def __call__(
173
190
(1 ,),
174
191
varipeps_config .spiral_wavevector_type ,
175
192
)
176
- for d in self .nearest_neighbor_gates
193
+ for d in self .diagonal_gates
177
194
)
178
195
else :
179
- working_h_gates = self .nearest_neighbor_gates
180
- working_v_gates = self .nearest_neighbor_gates
181
- working_d_gates = self .nearest_neighbor_gates
196
+ working_h_gates = self .horizontal_gates
197
+ working_v_gates = self .vertical_gates
198
+ working_d_gates = self .diagonal_gates
182
199
183
200
for x , iter_rows in unitcell .iter_all_rows (only_unique = only_unique ):
184
201
for y , view in iter_rows :
0 commit comments