Skip to content

Commit 2caae41

Browse files
committed
Support dimension-aware routed swaps
1 parent d85da1a commit 2caae41

4 files changed

Lines changed: 316 additions & 33 deletions

File tree

PLAN.md

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# PLAN.md — Belief Propagation, quimb integration, and Symmetric Tensor Networks
22

33
Status: draft / living document
4-
Last updated: 2026-06-22
4+
Last updated: 2026-06-24
55
Owners: pepsy maintainers + coding agents
66

77
This document plans three related workstreams to extend `pepsy`:
@@ -159,6 +159,26 @@ Goal: avoid reinventing BP / gauging / contraction; wrap and adapt quimb.
159159
`contraction_opt="auto-hq"`) are reused for the loop-excitation sub-networks.
160160
- Do **not** add seed kwargs to optimizer builders (tests assert their absence).
161161

162+
### B0. Gate-routing audit: dimension-aware SWAPs
163+
164+
Status: implemented as a small prerequisite for Tensy PF PEPS replay.
165+
166+
- Finding: quimb's adjacent `tensor_network_gate_inds(..., contract="split")`,
167+
`tensor_network_gate_inds(..., contract="reduce-split")`, and
168+
`gate_simple_` accept rectangular two-site tensors whose output physical
169+
dimensions differ from their input dimensions. That is enough to represent a
170+
mixed-dimension SWAP with shape `(d_b, d_a, d_a, d_b)`.
171+
- Requirement: pepsy's long-range `gate` / `gate_simple` SWAP routing must infer
172+
the **current** physical index dimensions before each forward and reverse
173+
SWAP. A single cached `qu.swap(dim=2)` is only valid for binary sites.
174+
- Scope: keep the public API stable; make dimension-aware SWAPs the internal
175+
default for 1D/2D/3D routed `split` / `reduce-split` paths and simple-update
176+
routing. For binary sites this is behavior-preserving.
177+
- Validation target: mixed physical dimensions such as Tensy PF fused sites
178+
(`dim=4` frame, `dim=2` measurement, and possible larger selector sites)
179+
should route through spectator sites and swap back to the original layout for
180+
both direct split/reduce-split and simple-update replay.
181+
162182
Deliverable: a short `learning/quimb.md` mapping "pepsy concept → quimb API",
163183
plus adapters under the new contraction subpackage.
164184

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# 2026-06-24 — Dimension-aware SWAP routing
2+
3+
- Milestone: B0 — gate-routing audit / quimb integration prerequisite
4+
- Branch / commit: `main` working tree, uncommitted
5+
6+
## What changed
7+
- Added dimension-aware routed SWAP construction in `src/pepsy/operators/gates.py`.
8+
- Routed `gate(..., contract="split"|"reduce-split")` in 1D/2D/3D now infers
9+
the live physical dimension at each adjacent SWAP step.
10+
- Routed `gate_simple(...)` now does the same for simple-update SWAP chains.
11+
- Added mixed-dimension PEPS tests for direct `split`, direct `reduce-split`,
12+
and simple-update routing through a spectator site.
13+
- Updated `PLAN.md` with the quimb audit finding and the implementation scope.
14+
15+
## Why
16+
- Tensy PF PEPS replay has mixed physical dimensions (`dim=4` frame sites,
17+
`dim=2` measurement sites, and potentially larger selector sites). The old
18+
hard-coded `qu.swap(dim=2)` route was only valid for DEM-style binary sites.
19+
- Quimb can already apply adjacent rectangular two-site tensors with
20+
`split`, `reduce-split`, and `gate_simple_`, so Pepsy only needed to build
21+
the correct SWAP tensor per live adjacent pair.
22+
23+
## How it was validated
24+
- `python -m pyflakes src/pepsy/operators/gates.py tests/test_gate.py` -> passed.
25+
- `python -m pytest -q tests/test_gate.py` -> `92 passed`, 2 warnings.
26+
- `python -m pytest -q tests/test_public_api.py tests/test_package_layout.py` ->
27+
`404 passed`.
28+
29+
## Decisions / findings
30+
- No new public flag was added. Dimension-aware SWAPs are now the internal
31+
default for routed paths; binary sites still use an exact binary SWAP.
32+
- Generic/mocked tensor networks that cannot report physical index sizes keep
33+
the previous binary fallback for compatibility.
34+
35+
## Next step (do this first next time)
36+
- Wire Tensy PF `to_2dpeps(..., gate_method="simple_update")` or direct
37+
`gate_opts={"contract": "reduce-split"}` back to this Pepsy route and run the
38+
`sf_2dpeps_pf.ipynb` PEPS replay cell on a small layout.
39+
40+
## Open questions / blockers
41+
- None for Pepsy routing. Tensy still needs the notebook/runtime-side choice of
42+
direct `reduce-split` versus simple-update for the PF workflow.

src/pepsy/operators/gates.py

Lines changed: 163 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from itertools import count
1212

1313
import autoray as ar
14+
import numpy as np
1415
import quimb as qu
1516
import quimb.tensor as qtn
1617

@@ -1143,6 +1144,95 @@ def _format_ind_id(ind_id, site):
11431144
) from exc
11441145

11451146

1147+
def _physical_index_for_site(tn, site, ind_id=None):
1148+
"""Return the current physical index name for a lattice site."""
1149+
if ind_id is not None:
1150+
return _format_ind_id(ind_id, site)
1151+
1152+
site_ind = getattr(tn, "site_ind", None)
1153+
if callable(site_ind):
1154+
try:
1155+
if isinstance(site, (tuple, list)):
1156+
return site_ind(*site)
1157+
return site_ind(site)
1158+
except TypeError:
1159+
return site_ind(site)
1160+
1161+
if isinstance(site, (tuple, list)):
1162+
if len(site) == 2:
1163+
return _format_ind_id("k{},{}", site)
1164+
if len(site) == 3:
1165+
return _format_ind_id("k{},{},{}", site)
1166+
elif isinstance(site, Integral):
1167+
return _format_ind_id("k{}", site)
1168+
1169+
raise ValueError(
1170+
"Cannot infer physical index for routed SWAP. Pass ind_id explicitly "
1171+
"or use a tensor network with a site_ind method."
1172+
)
1173+
1174+
1175+
def _physical_index_size_for_site(tn, site, ind_id=None):
1176+
"""Return the live physical-index dimension for one lattice site."""
1177+
ix = _physical_index_for_site(tn, site, ind_id)
1178+
1179+
ind_size = getattr(tn, "ind_size", None)
1180+
if callable(ind_size):
1181+
try:
1182+
return int(ind_size(ix))
1183+
except (KeyError, TypeError, ValueError):
1184+
pass
1185+
1186+
tensor = _site_tensor_for_coord(tn, site)
1187+
if tensor is not None and ix in getattr(tensor, "inds", ()):
1188+
return _tensor_index_size(tensor, ix)
1189+
1190+
# Keep compatibility with generic or mocked TNs that cannot report physical
1191+
# index sizes. This matches the old routed-SWAP assumption while real PEPS
1192+
# and MPS objects take the dimension-aware path above.
1193+
return 2
1194+
1195+
1196+
def _rectangular_swap_gate(dim_a, dim_b, *, dtype="complex128"):
1197+
"""Build the exact SWAP from d_a x d_b to d_b x d_a."""
1198+
dim_a = int(dim_a)
1199+
dim_b = int(dim_b)
1200+
if dim_a <= 0 or dim_b <= 0:
1201+
raise ValueError("SWAP dimensions must be positive integers.")
1202+
1203+
swap_gate = np.zeros((dim_b, dim_a, dim_a, dim_b), dtype=dtype)
1204+
for ia in range(dim_a):
1205+
for ib in range(dim_b):
1206+
swap_gate[ib, ia, ia, ib] = 1
1207+
return swap_gate
1208+
1209+
1210+
def _convert_internal_gate_to_backend(gate, inferred_converter):
1211+
"""Best-effort conversion for internally generated exact gates."""
1212+
if inferred_converter is None:
1213+
return gate
1214+
try:
1215+
return inferred_converter(gate)
1216+
except (TypeError, ValueError):
1217+
return gate
1218+
1219+
1220+
def _swap_gate_for_site_pair(
1221+
tn,
1222+
site_a,
1223+
site_b,
1224+
*,
1225+
ind_id=None,
1226+
dtype="complex128",
1227+
inferred_converter=None,
1228+
):
1229+
"""Return a SWAP tensor matching the sites' current physical dimensions."""
1230+
dim_a = _physical_index_size_for_site(tn, site_a, ind_id)
1231+
dim_b = _physical_index_size_for_site(tn, site_b, ind_id)
1232+
swap_gate = _rectangular_swap_gate(dim_a, dim_b, dtype=dtype)
1233+
return _convert_internal_gate_to_backend(swap_gate, inferred_converter)
1234+
1235+
11461236
def _normalize_gate_which(which):
11471237
"""Normalize an upper/lower layer selector."""
11481238
if which is None:
@@ -1536,8 +1626,9 @@ def gate(tn, gates, where=None, which=None, **kwargs):
15361626
gate tensors. Provide TN and gate tensors on compatible backends explicitly.
15371627
For one-site gates, ``contract`` is normalized to a boolean mode:
15381628
non-boolean values are treated as ``True``.
1539-
Internal SWAP tensors used for long-range routing are backend-aligned from
1540-
the TN sample data when available.
1629+
Internal SWAP tensors used for long-range routing infer the current
1630+
physical dimensions of each adjacent pair and are backend-aligned from the
1631+
TN sample data when available.
15411632
For nonlocal two-site gates, long-range SWAP routing is used in 1D/2D/3D
15421633
when ``contract`` is ``"split"`` or ``"reduce-split"``. For other contract
15431634
modes, the gate is applied directly to the requested endpoints.
@@ -1717,7 +1808,8 @@ def gate_simple(
17171808
(works for 1D / 2D / 3D ``where`` coordinates).
17181809
* ``which``/``ind_id`` selection for vector-like networks whose physical
17191810
site-index family is not the default ``k...`` family.
1720-
* Backend alignment of internal SWAP tensors with the TN sample data.
1811+
* Dimension-aware, backend-aligned internal SWAP tensors for long-range
1812+
routing through mixed physical dimensions.
17211813
* Optional out-of-place semantics via ``inplace=False``.
17221814
17231815
The ``gauges`` dictionary is mutated in place by ``gate_simple_`` and is
@@ -1966,19 +2058,15 @@ def _gate_simple_one_with_current_site_ind_id(
19662058
)
19672059
return tn_work
19682060

1969-
# Non-adjacent: route through a SWAP chain. Align the SWAP tensor to the
1970-
# TN sample backend so the gate_simple_ call sees consistent dtypes.
1971-
swap_gate = qu.swap(dim=2, dtype="complex128").reshape(2, 2, 2, 2)
2061+
# Non-adjacent: route through a SWAP chain. Each SWAP is built from the
2062+
# live physical dimensions because routed mixed-dimensional sites exchange
2063+
# their physical index sizes as they move along the path.
19722064
backend_sample = resolve_backend_sample_data_from_tn(tn_work)
19732065
inferred_converter = infer_backend_converter_from_sample(
19742066
backend_sample,
19752067
cast_complex_to_real=True,
19762068
)
1977-
if inferred_converter is not None:
1978-
try:
1979-
swap_gate = inferred_converter(swap_gate)
1980-
except (TypeError, ValueError):
1981-
pass
2069+
swap_ind_id = getattr(tn_work, "site_ind_id", None)
19822070

19832071
ndim = len(site_a) if isinstance(site_a, (tuple, list)) else 1
19842072
if ndim == 1:
@@ -2016,6 +2104,14 @@ def _gate_simple_one_with_current_site_ind_id(
20162104

20172105
# Forward SWAPs.
20182106
for pair in swaps:
2107+
swap_gate = _swap_gate_for_site_pair(
2108+
tn_work,
2109+
pair[0],
2110+
pair[1],
2111+
ind_id=swap_ind_id,
2112+
dtype="complex128",
2113+
inferred_converter=inferred_converter,
2114+
)
20192115
tn_work.gate_simple_(
20202116
swap_gate, where=pair, gauges=gauges,
20212117
renorm=renorm, smudge=smudge, inplace=True,
@@ -2031,6 +2127,14 @@ def _gate_simple_one_with_current_site_ind_id(
20312127

20322128
# Reverse SWAPs.
20332129
for pair in reversed(swaps):
2130+
swap_gate = _swap_gate_for_site_pair(
2131+
tn_work,
2132+
pair[0],
2133+
pair[1],
2134+
ind_id=swap_ind_id,
2135+
dtype="complex128",
2136+
inferred_converter=inferred_converter,
2137+
)
20342138
tn_work.gate_simple_(
20352139
swap_gate, where=pair, gauges=gauges,
20362140
renorm=renorm, smudge=smudge, inplace=True,
@@ -2199,13 +2303,6 @@ def _apply_gate_2d(
21992303
backend_sample,
22002304
cast_complex_to_real=True,
22012305
)
2202-
swap = qu.swap(dim=2, dtype=dtype).reshape(2, 2, 2, 2)
2203-
if inferred_converter is not None:
2204-
try:
2205-
swap = inferred_converter(swap)
2206-
except (TypeError, ValueError):
2207-
pass
2208-
22092306
lx_use = Lx
22102307
ly_use = Ly
22112308
if cyclic and (lx_use is None or ly_use is None):
@@ -2234,6 +2331,14 @@ def _apply_gate_2d(
22342331
x_, y_ = pair
22352332
i_, j_ = x_
22362333
m_, n_ = y_
2334+
swap = _swap_gate_for_site_pair(
2335+
peps,
2336+
x_,
2337+
y_,
2338+
ind_id=ind_id,
2339+
dtype=dtype,
2340+
inferred_converter=inferred_converter,
2341+
)
22372342
qtn.tensor_network_gate_inds(
22382343
peps,
22392344
swap,
@@ -2267,6 +2372,14 @@ def _apply_gate_2d(
22672372
x_, y_ = pair
22682373
i_, j_ = x_
22692374
m_, n_ = y_
2375+
swap = _swap_gate_for_site_pair(
2376+
peps,
2377+
x_,
2378+
y_,
2379+
ind_id=ind_id,
2380+
dtype=dtype,
2381+
inferred_converter=inferred_converter,
2382+
)
22702383
qtn.tensor_network_gate_inds(
22712384
peps,
22722385
swap,
@@ -2463,13 +2576,6 @@ def _apply_gate_3d(
24632576
backend_sample,
24642577
cast_complex_to_real=True,
24652578
)
2466-
swap = qu.swap(dim=2, dtype=dtype).reshape(2, 2, 2, 2)
2467-
if inferred_converter is not None:
2468-
try:
2469-
swap = inferred_converter(swap)
2470-
except (TypeError, ValueError):
2471-
pass
2472-
24732579
lx_use = Lx
24742580
ly_use = Ly
24752581
lz_use = Lz
@@ -2501,6 +2607,14 @@ def _apply_gate_3d(
25012607
x_, y_ = pair
25022608
i_, j_, k_ = x_
25032609
m_, n_, p_ = y_
2610+
swap = _swap_gate_for_site_pair(
2611+
tn,
2612+
x_,
2613+
y_,
2614+
ind_id=ind_id,
2615+
dtype=dtype,
2616+
inferred_converter=inferred_converter,
2617+
)
25042618
qtn.tensor_network_gate_inds(
25052619
tn,
25062620
swap,
@@ -2534,6 +2648,14 @@ def _apply_gate_3d(
25342648
x_, y_ = pair
25352649
i_, j_, k_ = x_
25362650
m_, n_, p_ = y_
2651+
swap = _swap_gate_for_site_pair(
2652+
tn,
2653+
x_,
2654+
y_,
2655+
ind_id=ind_id,
2656+
dtype=dtype,
2657+
inferred_converter=inferred_converter,
2658+
)
25372659
qtn.tensor_network_gate_inds(
25382660
tn,
25392661
swap,
@@ -2842,13 +2964,6 @@ def _apply_gate_1d(
28422964
backend_sample,
28432965
cast_complex_to_real=True,
28442966
)
2845-
swap = qu.swap(dim=2, dtype=dtype).reshape(2, 2, 2, 2)
2846-
if inferred_converter is not None:
2847-
try:
2848-
swap = inferred_converter(swap)
2849-
except (TypeError, ValueError):
2850-
pass
2851-
28522967
path_pairs = list(gen_long_range_swap_path_1d(x, y))
28532968
*swaps, final = path_pairs
28542969
_maybe_canonize_path(
@@ -2860,6 +2975,14 @@ def _apply_gate_1d(
28602975
)
28612976

28622977
for i_, j_ in swaps:
2978+
swap = _swap_gate_for_site_pair(
2979+
tn,
2980+
i_,
2981+
j_,
2982+
ind_id=ind_id,
2983+
dtype=dtype,
2984+
inferred_converter=inferred_converter,
2985+
)
28632986
tn = qtn.tensor_network_gate_inds(
28642987
tn,
28652988
swap,
@@ -2884,6 +3007,14 @@ def _apply_gate_1d(
28843007
)
28853008

28863009
for i_, j_ in reversed(swaps):
3010+
swap = _swap_gate_for_site_pair(
3011+
tn,
3012+
i_,
3013+
j_,
3014+
ind_id=ind_id,
3015+
dtype=dtype,
3016+
inferred_converter=inferred_converter,
3017+
)
28873018
tn = qtn.tensor_network_gate_inds(
28883019
tn,
28893020
swap,

0 commit comments

Comments
 (0)