Skip to content

Commit fdcb0fb

Browse files
committed
Implement spiral iPEPS for Floret-Pentagon lattice
1 parent 8df5271 commit fdcb0fb

File tree

4 files changed

+161
-29
lines changed

4 files changed

+161
-29
lines changed

Diff for: docs/source/images/floret_pentagon_structure.pdf

0 Bytes
Binary file not shown.

Diff for: docs/source/images/floret_pentagon_structure.svg

+4-4
Loading

Diff for: docs/source/images/floret_pentagon_structure.tex

+4-4
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,11 @@
182182
\node[ucirc, label=200:{\small 2}] at (pa 0a0b0) {};
183183
\node[ucirc, label={\small 3}] at (pa 5a0b0) {};
184184
\node[ucirc, label={\small 4}] at (pc 5a0b0) {};
185-
\node[ucirc, label={\small 5}] at (pb 5a0b0) {};
186-
\node[ucirc, label={\small 6}] at (pa 4a0b1) {};
185+
\node[ucirc, label={\small 6}] at (pb 5a0b0) {};
186+
\node[ucirc, label={\small 5}] at (pa 4a0b1) {};
187187
\node[ucirc, label={\small 7}] at (pc 4a0b0) {};
188-
\node[ucirc, label={\small 8}] at (pb 4a0b0) {};
189-
\node[ucirc, label={\small 9}] at (pb 0a-1b0) {};
188+
\node[ucirc, label={\small 9}] at (pb 4a0b0) {};
189+
\node[ucirc, label={\small 8}] at (pb 0a-1b0) {};
190190
\end{pgfonlayer}{v2}
191191
};
192192
}

Diff for: varipeps/mapping/florett_pentagon.py

+153-21
Original file line numberDiff line numberDiff line change
@@ -598,37 +598,37 @@ def _calc_onsite_gate(
598598
)
599599
blue_36 = blue_36.reshape(d**9, d**9)
600600

601-
green_12 = jnp.kron(g_e, Id_other_sites)
602-
green_12 = green_12.reshape(
601+
green_base = jnp.kron(g_e, Id_other_sites)
602+
green_base = green_base.reshape(
603603
d, d, d, d, d, d, d, d, d, d, d, d, d, d, d, d, d, d
604604
)
605605

606-
green_24 = green_12.transpose(
606+
green_24 = green_base.transpose(
607607
2, 0, 3, 1, 4, 5, 6, 7, 8, 11, 9, 12, 10, 13, 14, 15, 16, 17
608608
)
609609
green_24 = green_24.reshape(d**9, d**9)
610610

611-
green_45 = green_12.transpose(
611+
green_45 = green_base.transpose(
612612
2, 3, 4, 0, 1, 5, 6, 7, 8, 11, 12, 13, 9, 10, 14, 15, 16, 17
613613
)
614614
green_45 = green_45.reshape(d**9, d**9)
615615

616-
green_46 = green_12.transpose(
616+
green_46 = green_base.transpose(
617617
2, 3, 4, 0, 5, 1, 6, 7, 8, 11, 12, 13, 9, 14, 10, 15, 16, 17
618618
)
619619
green_46 = green_46.reshape(d**9, d**9)
620620

621-
green_37 = green_12.transpose(
621+
green_37 = green_base.transpose(
622622
2, 3, 0, 4, 5, 6, 1, 7, 8, 11, 12, 9, 13, 14, 15, 10, 16, 17
623623
)
624624
green_37 = green_37.reshape(d**9, d**9)
625625

626-
green_78 = green_12.transpose(
626+
green_78 = green_base.transpose(
627627
2, 3, 4, 5, 6, 7, 0, 1, 8, 11, 12, 13, 14, 15, 16, 9, 10, 17
628628
)
629629
green_78 = green_78.reshape(d**9, d**9)
630630

631-
green_79 = green_12.transpose(
631+
green_79 = green_base.transpose(
632632
2, 3, 4, 5, 6, 7, 0, 8, 1, 11, 12, 13, 14, 15, 16, 9, 17, 10
633633
)
634634
green_79 = green_79.reshape(d**9, d**9)
@@ -869,7 +869,9 @@ def __post_init__(self) -> None:
869869
)
870870

871871
if self.is_spiral_peps:
872-
raise NotImplementedError
872+
self._spiral_D, self._spiral_sigma = jnp.linalg.eigh(
873+
self.spiral_unitary_operator
874+
)
873875

874876
def __call__(
875877
self,
@@ -891,16 +893,146 @@ def __call__(
891893
working_onsite_gates = tuple(
892894
o for e in self._onsite_single_gates for o in e
893895
)
894-
working_h_single_gates = tuple(
895-
h for e in self._right_single_gates for h in e
896+
897+
if self.is_spiral_peps:
898+
if isinstance(spiral_vectors, jnp.ndarray):
899+
spiral_vectors = (
900+
spiral_vectors,
901+
spiral_vectors,
902+
spiral_vectors,
903+
)
904+
if len(spiral_vectors) == 1:
905+
spiral_vectors = (
906+
spiral_vectors[0],
907+
spiral_vectors[0],
908+
None,
909+
None,
910+
None,
911+
None,
912+
None,
913+
None,
914+
spiral_vectors[0],
915+
)
916+
if len(spiral_vectors) == 4:
917+
spiral_vectors = (
918+
spiral_vectors[0],
919+
spiral_vectors[1],
920+
None,
921+
None,
922+
None,
923+
None,
924+
None,
925+
None,
926+
spiral_vectors[2],
927+
)
928+
if len(spiral_vectors) != 9:
929+
raise ValueError("Length mismatch for spiral vectors!")
930+
931+
working_h_gates = tuple(
932+
apply_unitary(
933+
h,
934+
jnp.array((0, 1)),
935+
spiral_vectors[0:9:8],
936+
self._spiral_D,
937+
self._spiral_sigma,
938+
self.real_d,
939+
3,
940+
(1, 2),
941+
varipeps_config.spiral_wavevector_type,
942+
)
943+
for h in self._right_tuple
896944
)
897-
working_v_single_gates = tuple(
898-
v for e in self._down_single_gates for v in e
945+
working_v_gates = tuple(
946+
apply_unitary(
947+
v,
948+
jnp.array((1, 0)),
949+
spiral_vectors[:2],
950+
self._spiral_D,
951+
self._spiral_sigma,
952+
self.real_d,
953+
4,
954+
(2, 3),
955+
varipeps_config.spiral_wavevector_type,
956+
)
957+
for v in self._down_tuple
899958
)
900-
working_d_single_gates = tuple(
901-
d for e in self._diagonal_single_gates for d in e
959+
working_d_gates = tuple(
960+
apply_unitary(
961+
d,
962+
jnp.array((1, 1)),
963+
spiral_vectors[:1],
964+
self._spiral_D,
965+
self._spiral_sigma,
966+
self.real_d,
967+
3,
968+
(2,),
969+
varipeps_config.spiral_wavevector_type,
970+
)
971+
for d in self._diagonal_tuple
902972
)
903973

974+
if return_single_gate_results:
975+
working_h_single_gates = tuple(
976+
apply_unitary(
977+
h,
978+
jnp.array((0, 1)),
979+
spiral_vectors[0:9:8],
980+
self._spiral_D,
981+
self._spiral_sigma,
982+
self.real_d,
983+
3,
984+
(1, 2),
985+
varipeps_config.spiral_wavevector_type,
986+
)
987+
for e in self._right_single_gates
988+
for h in e
989+
)
990+
working_v_single_gates = tuple(
991+
apply_unitary(
992+
v,
993+
jnp.array((1, 0)),
994+
spiral_vectors[:2],
995+
self._spiral_D,
996+
self._spiral_sigma,
997+
self.real_d,
998+
4,
999+
(2, 3),
1000+
varipeps_config.spiral_wavevector_type,
1001+
)
1002+
for e in self._down_single_gates
1003+
for v in e
1004+
)
1005+
working_d_single_gates = tuple(
1006+
apply_unitary(
1007+
d,
1008+
jnp.array((1, 1)),
1009+
spiral_vectors[:1],
1010+
self._spiral_D,
1011+
self._spiral_sigma,
1012+
self.real_d,
1013+
3,
1014+
(2,),
1015+
varipeps_config.spiral_wavevector_type,
1016+
)
1017+
for e in self._diagonal_single_gates
1018+
for d in e
1019+
)
1020+
else:
1021+
working_h_gates = self._right_tuple
1022+
working_v_gates = self._down_tuple
1023+
working_d_gates = self._diagonal_tuple
1024+
1025+
if return_single_gate_results:
1026+
working_h_single_gates = tuple(
1027+
h for e in self._right_single_gates for h in e
1028+
)
1029+
working_v_single_gates = tuple(
1030+
v for e in self._down_single_gates for v in e
1031+
)
1032+
working_d_single_gates = tuple(
1033+
d for e in self._diagonal_single_gates for d in e
1034+
)
1035+
9041036
for x, iter_rows in unitcell.iter_all_rows(only_unique=only_unique):
9051037
for y, view in iter_rows:
9061038
# On site term
@@ -937,14 +1069,14 @@ def __call__(
9371069
step_result_horizontal = _two_site_workhorse(
9381070
density_matrix_left,
9391071
density_matrix_right,
940-
self._right_tuple + working_h_single_gates,
1072+
working_h_gates + working_h_single_gates,
9411073
self._result_type is jnp.float64,
9421074
)
9431075
else:
9441076
step_result_horizontal = _two_site_workhorse(
9451077
density_matrix_left,
9461078
density_matrix_right,
947-
self._right_tuple,
1079+
working_h_gates,
9481080
self._result_type is jnp.float64,
9491081
)
9501082

@@ -964,14 +1096,14 @@ def __call__(
9641096
step_result_vertical = _two_site_workhorse(
9651097
density_matrix_top,
9661098
density_matrix_bottom,
967-
self._down_tuple + working_v_single_gates,
1099+
working_v_gates + working_v_single_gates,
9681100
self._result_type is jnp.float64,
9691101
)
9701102
else:
9711103
step_result_vertical = _two_site_workhorse(
9721104
density_matrix_top,
9731105
density_matrix_bottom,
974-
self._down_tuple,
1106+
working_v_gates,
9751107
self._result_type is jnp.float64,
9761108
)
9771109

@@ -1011,7 +1143,7 @@ def __call__(
10111143
density_matrix_bottom_right,
10121144
traced_density_matrix_top_right,
10131145
traced_density_matrix_bottom_left,
1014-
self._diagonal_tuple + working_d_single_gates,
1146+
working_d_gates + working_d_single_gates,
10151147
self._result_type is jnp.float64,
10161148
)
10171149
else:
@@ -1020,7 +1152,7 @@ def __call__(
10201152
density_matrix_bottom_right,
10211153
traced_density_matrix_top_right,
10221154
traced_density_matrix_bottom_left,
1023-
self._diagonal_tuple,
1155+
working_d_gates,
10241156
self._result_type is jnp.float64,
10251157
)
10261158

0 commit comments

Comments
 (0)