1111from itertools import count
1212
1313import autoray as ar
14+ import numpy as np
1415import quimb as qu
1516import quimb .tensor as qtn
1617
@@ -1143,6 +1144,95 @@ def _format_ind_id(ind_id, site):
11431144 ) from exc
11441145
11451146
1147+ def _physical_index_for_site (tn , site , ind_id = None ):
1148+ """Return the current physical index name for a lattice site."""
1149+ if ind_id is not None :
1150+ return _format_ind_id (ind_id , site )
1151+
1152+ site_ind = getattr (tn , "site_ind" , None )
1153+ if callable (site_ind ):
1154+ try :
1155+ if isinstance (site , (tuple , list )):
1156+ return site_ind (* site )
1157+ return site_ind (site )
1158+ except TypeError :
1159+ return site_ind (site )
1160+
1161+ if isinstance (site , (tuple , list )):
1162+ if len (site ) == 2 :
1163+ return _format_ind_id ("k{},{}" , site )
1164+ if len (site ) == 3 :
1165+ return _format_ind_id ("k{},{},{}" , site )
1166+ elif isinstance (site , Integral ):
1167+ return _format_ind_id ("k{}" , site )
1168+
1169+ raise ValueError (
1170+ "Cannot infer physical index for routed SWAP. Pass ind_id explicitly "
1171+ "or use a tensor network with a site_ind method."
1172+ )
1173+
1174+
1175+ def _physical_index_size_for_site (tn , site , ind_id = None ):
1176+ """Return the live physical-index dimension for one lattice site."""
1177+ ix = _physical_index_for_site (tn , site , ind_id )
1178+
1179+ ind_size = getattr (tn , "ind_size" , None )
1180+ if callable (ind_size ):
1181+ try :
1182+ return int (ind_size (ix ))
1183+ except (KeyError , TypeError , ValueError ):
1184+ pass
1185+
1186+ tensor = _site_tensor_for_coord (tn , site )
1187+ if tensor is not None and ix in getattr (tensor , "inds" , ()):
1188+ return _tensor_index_size (tensor , ix )
1189+
1190+ # Keep compatibility with generic or mocked TNs that cannot report physical
1191+ # index sizes. This matches the old routed-SWAP assumption while real PEPS
1192+ # and MPS objects take the dimension-aware path above.
1193+ return 2
1194+
1195+
1196+ def _rectangular_swap_gate (dim_a , dim_b , * , dtype = "complex128" ):
1197+ """Build the exact SWAP from d_a x d_b to d_b x d_a."""
1198+ dim_a = int (dim_a )
1199+ dim_b = int (dim_b )
1200+ if dim_a <= 0 or dim_b <= 0 :
1201+ raise ValueError ("SWAP dimensions must be positive integers." )
1202+
1203+ swap_gate = np .zeros ((dim_b , dim_a , dim_a , dim_b ), dtype = dtype )
1204+ for ia in range (dim_a ):
1205+ for ib in range (dim_b ):
1206+ swap_gate [ib , ia , ia , ib ] = 1
1207+ return swap_gate
1208+
1209+
1210+ def _convert_internal_gate_to_backend (gate , inferred_converter ):
1211+ """Best-effort conversion for internally generated exact gates."""
1212+ if inferred_converter is None :
1213+ return gate
1214+ try :
1215+ return inferred_converter (gate )
1216+ except (TypeError , ValueError ):
1217+ return gate
1218+
1219+
1220+ def _swap_gate_for_site_pair (
1221+ tn ,
1222+ site_a ,
1223+ site_b ,
1224+ * ,
1225+ ind_id = None ,
1226+ dtype = "complex128" ,
1227+ inferred_converter = None ,
1228+ ):
1229+ """Return a SWAP tensor matching the sites' current physical dimensions."""
1230+ dim_a = _physical_index_size_for_site (tn , site_a , ind_id )
1231+ dim_b = _physical_index_size_for_site (tn , site_b , ind_id )
1232+ swap_gate = _rectangular_swap_gate (dim_a , dim_b , dtype = dtype )
1233+ return _convert_internal_gate_to_backend (swap_gate , inferred_converter )
1234+
1235+
11461236def _normalize_gate_which (which ):
11471237 """Normalize an upper/lower layer selector."""
11481238 if which is None :
@@ -1536,8 +1626,9 @@ def gate(tn, gates, where=None, which=None, **kwargs):
15361626 gate tensors. Provide TN and gate tensors on compatible backends explicitly.
15371627 For one-site gates, ``contract`` is normalized to a boolean mode:
15381628 non-boolean values are treated as ``True``.
1539- Internal SWAP tensors used for long-range routing are backend-aligned from
1540- the TN sample data when available.
1629+ Internal SWAP tensors used for long-range routing infer the current
1630+ physical dimensions of each adjacent pair and are backend-aligned from the
1631+ TN sample data when available.
15411632 For nonlocal two-site gates, long-range SWAP routing is used in 1D/2D/3D
15421633 when ``contract`` is ``"split"`` or ``"reduce-split"``. For other contract
15431634 modes, the gate is applied directly to the requested endpoints.
@@ -1717,7 +1808,8 @@ def gate_simple(
17171808 (works for 1D / 2D / 3D ``where`` coordinates).
17181809 * ``which``/``ind_id`` selection for vector-like networks whose physical
17191810 site-index family is not the default ``k...`` family.
1720- * Backend alignment of internal SWAP tensors with the TN sample data.
1811+ * Dimension-aware, backend-aligned internal SWAP tensors for long-range
1812+ routing through mixed physical dimensions.
17211813 * Optional out-of-place semantics via ``inplace=False``.
17221814
17231815 The ``gauges`` dictionary is mutated in place by ``gate_simple_`` and is
@@ -1966,19 +2058,15 @@ def _gate_simple_one_with_current_site_ind_id(
19662058 )
19672059 return tn_work
19682060
1969- # Non-adjacent: route through a SWAP chain. Align the SWAP tensor to the
1970- # TN sample backend so the gate_simple_ call sees consistent dtypes.
1971- swap_gate = qu . swap ( dim = 2 , dtype = "complex128" ). reshape ( 2 , 2 , 2 , 2 )
2061+ # Non-adjacent: route through a SWAP chain. Each SWAP is built from the
2062+ # live physical dimensions because routed mixed-dimensional sites exchange
2063+ # their physical index sizes as they move along the path.
19722064 backend_sample = resolve_backend_sample_data_from_tn (tn_work )
19732065 inferred_converter = infer_backend_converter_from_sample (
19742066 backend_sample ,
19752067 cast_complex_to_real = True ,
19762068 )
1977- if inferred_converter is not None :
1978- try :
1979- swap_gate = inferred_converter (swap_gate )
1980- except (TypeError , ValueError ):
1981- pass
2069+ swap_ind_id = getattr (tn_work , "site_ind_id" , None )
19822070
19832071 ndim = len (site_a ) if isinstance (site_a , (tuple , list )) else 1
19842072 if ndim == 1 :
@@ -2016,6 +2104,14 @@ def _gate_simple_one_with_current_site_ind_id(
20162104
20172105 # Forward SWAPs.
20182106 for pair in swaps :
2107+ swap_gate = _swap_gate_for_site_pair (
2108+ tn_work ,
2109+ pair [0 ],
2110+ pair [1 ],
2111+ ind_id = swap_ind_id ,
2112+ dtype = "complex128" ,
2113+ inferred_converter = inferred_converter ,
2114+ )
20192115 tn_work .gate_simple_ (
20202116 swap_gate , where = pair , gauges = gauges ,
20212117 renorm = renorm , smudge = smudge , inplace = True ,
@@ -2031,6 +2127,14 @@ def _gate_simple_one_with_current_site_ind_id(
20312127
20322128 # Reverse SWAPs.
20332129 for pair in reversed (swaps ):
2130+ swap_gate = _swap_gate_for_site_pair (
2131+ tn_work ,
2132+ pair [0 ],
2133+ pair [1 ],
2134+ ind_id = swap_ind_id ,
2135+ dtype = "complex128" ,
2136+ inferred_converter = inferred_converter ,
2137+ )
20342138 tn_work .gate_simple_ (
20352139 swap_gate , where = pair , gauges = gauges ,
20362140 renorm = renorm , smudge = smudge , inplace = True ,
@@ -2199,13 +2303,6 @@ def _apply_gate_2d(
21992303 backend_sample ,
22002304 cast_complex_to_real = True ,
22012305 )
2202- swap = qu .swap (dim = 2 , dtype = dtype ).reshape (2 , 2 , 2 , 2 )
2203- if inferred_converter is not None :
2204- try :
2205- swap = inferred_converter (swap )
2206- except (TypeError , ValueError ):
2207- pass
2208-
22092306 lx_use = Lx
22102307 ly_use = Ly
22112308 if cyclic and (lx_use is None or ly_use is None ):
@@ -2234,6 +2331,14 @@ def _apply_gate_2d(
22342331 x_ , y_ = pair
22352332 i_ , j_ = x_
22362333 m_ , n_ = y_
2334+ swap = _swap_gate_for_site_pair (
2335+ peps ,
2336+ x_ ,
2337+ y_ ,
2338+ ind_id = ind_id ,
2339+ dtype = dtype ,
2340+ inferred_converter = inferred_converter ,
2341+ )
22372342 qtn .tensor_network_gate_inds (
22382343 peps ,
22392344 swap ,
@@ -2267,6 +2372,14 @@ def _apply_gate_2d(
22672372 x_ , y_ = pair
22682373 i_ , j_ = x_
22692374 m_ , n_ = y_
2375+ swap = _swap_gate_for_site_pair (
2376+ peps ,
2377+ x_ ,
2378+ y_ ,
2379+ ind_id = ind_id ,
2380+ dtype = dtype ,
2381+ inferred_converter = inferred_converter ,
2382+ )
22702383 qtn .tensor_network_gate_inds (
22712384 peps ,
22722385 swap ,
@@ -2463,13 +2576,6 @@ def _apply_gate_3d(
24632576 backend_sample ,
24642577 cast_complex_to_real = True ,
24652578 )
2466- swap = qu .swap (dim = 2 , dtype = dtype ).reshape (2 , 2 , 2 , 2 )
2467- if inferred_converter is not None :
2468- try :
2469- swap = inferred_converter (swap )
2470- except (TypeError , ValueError ):
2471- pass
2472-
24732579 lx_use = Lx
24742580 ly_use = Ly
24752581 lz_use = Lz
@@ -2501,6 +2607,14 @@ def _apply_gate_3d(
25012607 x_ , y_ = pair
25022608 i_ , j_ , k_ = x_
25032609 m_ , n_ , p_ = y_
2610+ swap = _swap_gate_for_site_pair (
2611+ tn ,
2612+ x_ ,
2613+ y_ ,
2614+ ind_id = ind_id ,
2615+ dtype = dtype ,
2616+ inferred_converter = inferred_converter ,
2617+ )
25042618 qtn .tensor_network_gate_inds (
25052619 tn ,
25062620 swap ,
@@ -2534,6 +2648,14 @@ def _apply_gate_3d(
25342648 x_ , y_ = pair
25352649 i_ , j_ , k_ = x_
25362650 m_ , n_ , p_ = y_
2651+ swap = _swap_gate_for_site_pair (
2652+ tn ,
2653+ x_ ,
2654+ y_ ,
2655+ ind_id = ind_id ,
2656+ dtype = dtype ,
2657+ inferred_converter = inferred_converter ,
2658+ )
25372659 qtn .tensor_network_gate_inds (
25382660 tn ,
25392661 swap ,
@@ -2842,13 +2964,6 @@ def _apply_gate_1d(
28422964 backend_sample ,
28432965 cast_complex_to_real = True ,
28442966 )
2845- swap = qu .swap (dim = 2 , dtype = dtype ).reshape (2 , 2 , 2 , 2 )
2846- if inferred_converter is not None :
2847- try :
2848- swap = inferred_converter (swap )
2849- except (TypeError , ValueError ):
2850- pass
2851-
28522967 path_pairs = list (gen_long_range_swap_path_1d (x , y ))
28532968 * swaps , final = path_pairs
28542969 _maybe_canonize_path (
@@ -2860,6 +2975,14 @@ def _apply_gate_1d(
28602975 )
28612976
28622977 for i_ , j_ in swaps :
2978+ swap = _swap_gate_for_site_pair (
2979+ tn ,
2980+ i_ ,
2981+ j_ ,
2982+ ind_id = ind_id ,
2983+ dtype = dtype ,
2984+ inferred_converter = inferred_converter ,
2985+ )
28632986 tn = qtn .tensor_network_gate_inds (
28642987 tn ,
28652988 swap ,
@@ -2884,6 +3007,14 @@ def _apply_gate_1d(
28843007 )
28853008
28863009 for i_ , j_ in reversed (swaps ):
3010+ swap = _swap_gate_for_site_pair (
3011+ tn ,
3012+ i_ ,
3013+ j_ ,
3014+ ind_id = ind_id ,
3015+ dtype = dtype ,
3016+ inferred_converter = inferred_converter ,
3017+ )
28873018 tn = qtn .tensor_network_gate_inds (
28883019 tn ,
28893020 swap ,
0 commit comments