@@ -598,37 +598,37 @@ def _calc_onsite_gate(
598
598
)
599
599
blue_36 = blue_36 .reshape (d ** 9 , d ** 9 )
600
600
601
- green_12 = jnp .kron (g_e , Id_other_sites )
602
- green_12 = green_12 .reshape (
601
+ green_base = jnp .kron (g_e , Id_other_sites )
602
+ green_base = green_base .reshape (
603
603
d , d , d , d , d , d , d , d , d , d , d , d , d , d , d , d , d , d
604
604
)
605
605
606
- green_24 = green_12 .transpose (
606
+ green_24 = green_base .transpose (
607
607
2 , 0 , 3 , 1 , 4 , 5 , 6 , 7 , 8 , 11 , 9 , 12 , 10 , 13 , 14 , 15 , 16 , 17
608
608
)
609
609
green_24 = green_24 .reshape (d ** 9 , d ** 9 )
610
610
611
- green_45 = green_12 .transpose (
611
+ green_45 = green_base .transpose (
612
612
2 , 3 , 4 , 0 , 1 , 5 , 6 , 7 , 8 , 11 , 12 , 13 , 9 , 10 , 14 , 15 , 16 , 17
613
613
)
614
614
green_45 = green_45 .reshape (d ** 9 , d ** 9 )
615
615
616
- green_46 = green_12 .transpose (
616
+ green_46 = green_base .transpose (
617
617
2 , 3 , 4 , 0 , 5 , 1 , 6 , 7 , 8 , 11 , 12 , 13 , 9 , 14 , 10 , 15 , 16 , 17
618
618
)
619
619
green_46 = green_46 .reshape (d ** 9 , d ** 9 )
620
620
621
- green_37 = green_12 .transpose (
621
+ green_37 = green_base .transpose (
622
622
2 , 3 , 0 , 4 , 5 , 6 , 1 , 7 , 8 , 11 , 12 , 9 , 13 , 14 , 15 , 10 , 16 , 17
623
623
)
624
624
green_37 = green_37 .reshape (d ** 9 , d ** 9 )
625
625
626
- green_78 = green_12 .transpose (
626
+ green_78 = green_base .transpose (
627
627
2 , 3 , 4 , 5 , 6 , 7 , 0 , 1 , 8 , 11 , 12 , 13 , 14 , 15 , 16 , 9 , 10 , 17
628
628
)
629
629
green_78 = green_78 .reshape (d ** 9 , d ** 9 )
630
630
631
- green_79 = green_12 .transpose (
631
+ green_79 = green_base .transpose (
632
632
2 , 3 , 4 , 5 , 6 , 7 , 0 , 8 , 1 , 11 , 12 , 13 , 14 , 15 , 16 , 9 , 17 , 10
633
633
)
634
634
green_79 = green_79 .reshape (d ** 9 , d ** 9 )
@@ -869,7 +869,9 @@ def __post_init__(self) -> None:
869
869
)
870
870
871
871
if self .is_spiral_peps :
872
- raise NotImplementedError
872
+ self ._spiral_D , self ._spiral_sigma = jnp .linalg .eigh (
873
+ self .spiral_unitary_operator
874
+ )
873
875
874
876
def __call__ (
875
877
self ,
@@ -891,16 +893,146 @@ def __call__(
891
893
working_onsite_gates = tuple (
892
894
o for e in self ._onsite_single_gates for o in e
893
895
)
894
- working_h_single_gates = tuple (
895
- h for e in self ._right_single_gates for h in e
896
+
897
+ if self .is_spiral_peps :
898
+ if isinstance (spiral_vectors , jnp .ndarray ):
899
+ spiral_vectors = (
900
+ spiral_vectors ,
901
+ spiral_vectors ,
902
+ spiral_vectors ,
903
+ )
904
+ if len (spiral_vectors ) == 1 :
905
+ spiral_vectors = (
906
+ spiral_vectors [0 ],
907
+ spiral_vectors [0 ],
908
+ None ,
909
+ None ,
910
+ None ,
911
+ None ,
912
+ None ,
913
+ None ,
914
+ spiral_vectors [0 ],
915
+ )
916
+ if len (spiral_vectors ) == 4 :
917
+ spiral_vectors = (
918
+ spiral_vectors [0 ],
919
+ spiral_vectors [1 ],
920
+ None ,
921
+ None ,
922
+ None ,
923
+ None ,
924
+ None ,
925
+ None ,
926
+ spiral_vectors [2 ],
927
+ )
928
+ if len (spiral_vectors ) != 9 :
929
+ raise ValueError ("Length mismatch for spiral vectors!" )
930
+
931
+ working_h_gates = tuple (
932
+ apply_unitary (
933
+ h ,
934
+ jnp .array ((0 , 1 )),
935
+ spiral_vectors [0 :9 :8 ],
936
+ self ._spiral_D ,
937
+ self ._spiral_sigma ,
938
+ self .real_d ,
939
+ 3 ,
940
+ (1 , 2 ),
941
+ varipeps_config .spiral_wavevector_type ,
942
+ )
943
+ for h in self ._right_tuple
896
944
)
897
- working_v_single_gates = tuple (
898
- v for e in self ._down_single_gates for v in e
945
+ working_v_gates = tuple (
946
+ apply_unitary (
947
+ v ,
948
+ jnp .array ((1 , 0 )),
949
+ spiral_vectors [:2 ],
950
+ self ._spiral_D ,
951
+ self ._spiral_sigma ,
952
+ self .real_d ,
953
+ 4 ,
954
+ (2 , 3 ),
955
+ varipeps_config .spiral_wavevector_type ,
956
+ )
957
+ for v in self ._down_tuple
899
958
)
900
- working_d_single_gates = tuple (
901
- d for e in self ._diagonal_single_gates for d in e
959
+ working_d_gates = tuple (
960
+ apply_unitary (
961
+ d ,
962
+ jnp .array ((1 , 1 )),
963
+ spiral_vectors [:1 ],
964
+ self ._spiral_D ,
965
+ self ._spiral_sigma ,
966
+ self .real_d ,
967
+ 3 ,
968
+ (2 ,),
969
+ varipeps_config .spiral_wavevector_type ,
970
+ )
971
+ for d in self ._diagonal_tuple
902
972
)
903
973
974
+ if return_single_gate_results :
975
+ working_h_single_gates = tuple (
976
+ apply_unitary (
977
+ h ,
978
+ jnp .array ((0 , 1 )),
979
+ spiral_vectors [0 :9 :8 ],
980
+ self ._spiral_D ,
981
+ self ._spiral_sigma ,
982
+ self .real_d ,
983
+ 3 ,
984
+ (1 , 2 ),
985
+ varipeps_config .spiral_wavevector_type ,
986
+ )
987
+ for e in self ._right_single_gates
988
+ for h in e
989
+ )
990
+ working_v_single_gates = tuple (
991
+ apply_unitary (
992
+ v ,
993
+ jnp .array ((1 , 0 )),
994
+ spiral_vectors [:2 ],
995
+ self ._spiral_D ,
996
+ self ._spiral_sigma ,
997
+ self .real_d ,
998
+ 4 ,
999
+ (2 , 3 ),
1000
+ varipeps_config .spiral_wavevector_type ,
1001
+ )
1002
+ for e in self ._down_single_gates
1003
+ for v in e
1004
+ )
1005
+ working_d_single_gates = tuple (
1006
+ apply_unitary (
1007
+ d ,
1008
+ jnp .array ((1 , 1 )),
1009
+ spiral_vectors [:1 ],
1010
+ self ._spiral_D ,
1011
+ self ._spiral_sigma ,
1012
+ self .real_d ,
1013
+ 3 ,
1014
+ (2 ,),
1015
+ varipeps_config .spiral_wavevector_type ,
1016
+ )
1017
+ for e in self ._diagonal_single_gates
1018
+ for d in e
1019
+ )
1020
+ else :
1021
+ working_h_gates = self ._right_tuple
1022
+ working_v_gates = self ._down_tuple
1023
+ working_d_gates = self ._diagonal_tuple
1024
+
1025
+ if return_single_gate_results :
1026
+ working_h_single_gates = tuple (
1027
+ h for e in self ._right_single_gates for h in e
1028
+ )
1029
+ working_v_single_gates = tuple (
1030
+ v for e in self ._down_single_gates for v in e
1031
+ )
1032
+ working_d_single_gates = tuple (
1033
+ d for e in self ._diagonal_single_gates for d in e
1034
+ )
1035
+
904
1036
for x , iter_rows in unitcell .iter_all_rows (only_unique = only_unique ):
905
1037
for y , view in iter_rows :
906
1038
# On site term
@@ -937,14 +1069,14 @@ def __call__(
937
1069
step_result_horizontal = _two_site_workhorse (
938
1070
density_matrix_left ,
939
1071
density_matrix_right ,
940
- self . _right_tuple + working_h_single_gates ,
1072
+ working_h_gates + working_h_single_gates ,
941
1073
self ._result_type is jnp .float64 ,
942
1074
)
943
1075
else :
944
1076
step_result_horizontal = _two_site_workhorse (
945
1077
density_matrix_left ,
946
1078
density_matrix_right ,
947
- self . _right_tuple ,
1079
+ working_h_gates ,
948
1080
self ._result_type is jnp .float64 ,
949
1081
)
950
1082
@@ -964,14 +1096,14 @@ def __call__(
964
1096
step_result_vertical = _two_site_workhorse (
965
1097
density_matrix_top ,
966
1098
density_matrix_bottom ,
967
- self . _down_tuple + working_v_single_gates ,
1099
+ working_v_gates + working_v_single_gates ,
968
1100
self ._result_type is jnp .float64 ,
969
1101
)
970
1102
else :
971
1103
step_result_vertical = _two_site_workhorse (
972
1104
density_matrix_top ,
973
1105
density_matrix_bottom ,
974
- self . _down_tuple ,
1106
+ working_v_gates ,
975
1107
self ._result_type is jnp .float64 ,
976
1108
)
977
1109
@@ -1011,7 +1143,7 @@ def __call__(
1011
1143
density_matrix_bottom_right ,
1012
1144
traced_density_matrix_top_right ,
1013
1145
traced_density_matrix_bottom_left ,
1014
- self . _diagonal_tuple + working_d_single_gates ,
1146
+ working_d_gates + working_d_single_gates ,
1015
1147
self ._result_type is jnp .float64 ,
1016
1148
)
1017
1149
else :
@@ -1020,7 +1152,7 @@ def __call__(
1020
1152
density_matrix_bottom_right ,
1021
1153
traced_density_matrix_top_right ,
1022
1154
traced_density_matrix_bottom_left ,
1023
- self . _diagonal_tuple ,
1155
+ working_d_gates ,
1024
1156
self ._result_type is jnp .float64 ,
1025
1157
)
1026
1158
0 commit comments