Skip to content

Commit 99df5bc

Browse files
committed
Measure MPS norms with tn_norm
1 parent 2caae41 commit 99df5bc

4 files changed

Lines changed: 115 additions & 24 deletions

File tree

history/2026-06-24-dimension-aware-swap-routing.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# 2026-06-24 — Dimension-aware SWAP routing
22

33
- Milestone: B0 — gate-routing audit / quimb integration prerequisite
4-
- Branch / commit: `main` working tree, uncommitted
4+
- Branch / commit: `main` / `2caae41`
55

66
## What changed
77
- Added dimension-aware routed SWAP construction in `src/pepsy/operators/gates.py`.
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# 2026-06-24 — MPS normalization via `tn_norm`
2+
3+
- Milestone: M1 — non-unitary MPS diagnostics and normalization hardening
4+
- Branch / commit: `main` working tree, pending commit
5+
6+
## What changed
7+
- Switched `MpsOptimizer._canonical_span_norm(...)` from contracting a selected
8+
canonical span into a dense tensor to measuring the raw working-data norm with
9+
`pepsy.tensors.core.tn_norm(..., strip_exponent=True)`.
10+
- Temporarily clears and restores `p.exponent` while measuring the raw tensor
11+
data, so automatic normalization and `track_norm_infidelity=True` do not
12+
accidentally include the represented global scale.
13+
- Expanded non-unitary norm-infidelity smoke coverage across `dmrg`, `mpo`,
14+
`swap`, and `svd` modes, with an explicit skip for older quimb versions that
15+
do not expose `gate_with_auto_swap_`.
16+
17+
## Why
18+
- The old span contraction was correct only when the canonical span stayed
19+
small. For long-range gates or swap/sweep paths, the span can retain many
20+
physical legs; contracting it as one dense block can explode even though the
21+
norm itself is just a scalar double-layer contraction.
22+
- `tn_norm(..., strip_exponent=True)` keeps the measurement consistent with the
23+
Pepsy normalization model: normalize raw data locally, store the removed scale
24+
in `p.exponent`, and leave `p.norm()` to report the represented state norm.
25+
26+
## How it was validated
27+
- `python -m pytest -q tests/test_optimize_mps.py` -> `48 passed`.
28+
29+
## Decisions / findings
30+
- The `fallback` argument remains accepted for compatibility, but the measured
31+
path now uses `tn_norm` directly rather than the old dense-span fallback.
32+
- True overlap fidelity still goes through `tn_fidelity`; this change only
33+
affects raw norm measurements used by normalization and norm-infidelity
34+
diagnostics.
35+
36+
## Next step
37+
- Keep the Tensy `to_mps(..., contraction_opt=optimizer_cotengra)` path wired
38+
through this same contraction optimizer so large DEM/PF non-unitary streams
39+
use the intended double-layer contraction route for norm diagnostics.

src/pepsy/optimizers/mps/optimizer.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
from ...fitting.local import FIT
2727
from ...operators.gates import _normalize_gate_entries, gate as apply_gate
28-
from ...tensors.core import tn_fidelity
28+
from ...tensors.core import tn_fidelity, tn_norm
2929

3030
__all__ = ["MpsOptimizer"]
3131

@@ -69,7 +69,7 @@ class MpsOptimizer: # pylint: disable=too-many-instance-attributes
6969
----------
7070
normalizations : list[dict]
7171
Automatic normalization events recorded during :meth:`run`. Each entry
72-
stores the 1-based gate step, previous local squared norm,
72+
stores the 1-based gate step, previous raw squared norm,
7373
canonicalization span, tensor site where the normalization factor was
7474
inserted, and resulting base-10 ``p.exponent``.
7575
The raw tensor data are rescaled; the represented norm remains
@@ -637,28 +637,29 @@ def _normalize_span(where):
637637
raise ValueError("where must be an int, (int,), or (int, int).")
638638

639639
def _canonical_span_norm(self, p, where, *, fallback=True):
640-
"""Return ``||p||`` from a canonical center/span block.
640+
"""Return the raw working-data norm without densifying wide spans.
641641
642-
This assumes tensors outside ``where`` are already isometric. The span
643-
itself can contain multiple tensors per site, e.g. after ``split-gate``.
642+
Normalization needs the norm of the current tensor data, excluding the
643+
accumulated ``p.exponent`` scale. Contracting a canonical span into a
644+
dense block can explode for long-range gates because the span retains
645+
all physical legs, so use ``tn_norm``'s double-layer contraction instead.
644646
"""
645-
xmin, xmax = self._normalize_span(where)
647+
_ = where, fallback
648+
exponent = getattr(p, "exponent", None)
646649
try:
647-
tags = [p.site_tag(i) for i in range(xmin, xmax + 1)]
648-
block = p.select(tags, which="any")
649-
if isinstance(block, qtn.TensorNetwork):
650-
if block.num_tensors == 0:
651-
raise ValueError("canonical span selected no tensors.")
652-
block = block.contract(
653-
all,
654-
output_inds=block.outer_inds(),
655-
optimize=self.contraction_opt,
656-
)
657-
return ar.do("linalg.norm", block.data)
658-
except Exception:
659-
if not fallback:
660-
raise
661-
return p.norm(optimize=self.contraction_opt)
650+
if exponent is not None:
651+
p.exponent = 0.0
652+
mantissa, exponent_sq = tn_norm(
653+
p,
654+
contraction_opt=self.contraction_opt,
655+
strip_exponent=True,
656+
)
657+
return ar.do("sqrt", ar.do("abs", mantissa)) * 10 ** (
658+
float(exponent_sq) / 2.0
659+
)
660+
finally:
661+
if exponent is not None:
662+
p.exponent = exponent
662663

663664
@staticmethod
664665
def _norm_ratio_fidelity(approx_norm, target_norm):

tests/test_optimize_mps.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import quimb.tensor as qtn
1111

1212
import pepsy as py
13+
import pepsy.optimizers.mps.optimizer as mps_optimizer_module
1314

1415

1516
def _non_unitary_entangling_gate():
@@ -386,6 +387,54 @@ def test_mps_optimizer_canonical_span_norm_matches_full_target_norm(where):
386387
assert local_norm == pytest.approx(target.norm())
387388

388389

390+
def test_mps_optimizer_canonical_span_norm_ignores_stored_exponent():
391+
"""Internal normalization should measure raw data, not represented scale."""
392+
p0 = qtn.MPS_rand_state(4, bond_dim=2, phys_dim=2, dtype="complex128")
393+
opt = py.MpsOptimizer(p0.copy(), gates=[], chi=8, mode="svd")
394+
opt.p.exponent = 3.0
395+
396+
raw = opt.p.copy()
397+
raw.exponent = 0.0
398+
measured = opt._canonical_span_norm(opt.p, (0, 3)) # pylint: disable=protected-access
399+
400+
assert measured == pytest.approx(raw.norm())
401+
assert opt.p.exponent == pytest.approx(3.0)
402+
403+
404+
def test_mps_optimizer_norm_infidelity_uses_tn_norm_strip_exponent(monkeypatch):
405+
"""Norm-infidelity diagnostics should measure raw norms through ``tn_norm``."""
406+
calls = []
407+
original_tn_norm = mps_optimizer_module.tn_norm
408+
409+
def _spy_tn_norm(*args, **kwargs):
410+
calls.append(kwargs.copy())
411+
return original_tn_norm(*args, **kwargs)
412+
413+
monkeypatch.setattr(mps_optimizer_module, "tn_norm", _spy_tn_norm)
414+
415+
p0 = qtn.MPS_computational_state("0000", dtype="complex128")
416+
gates = [
417+
(qu.hadamard(), (0,)),
418+
(qu.hadamard(), (1,)),
419+
(_non_unitary_entangling_gate(), (0, 1)),
420+
]
421+
422+
opt = py.MpsOptimizer(p0.copy(), gates=gates, chi=1, mode="mpo")
423+
opt.run(
424+
progbar=False,
425+
cutoff=1e-12,
426+
non_unitary=True,
427+
normalize_final=True,
428+
track_norm_infidelity=True,
429+
)
430+
431+
samples = opt.get_norm_infidelity_samples()
432+
assert len(samples) == 1
433+
assert len(calls) >= 2
434+
assert all(call["strip_exponent"] is True for call in calls)
435+
assert all(call["contraction_opt"] == opt.contraction_opt for call in calls)
436+
437+
389438
def test_mps_optimizer_non_unitary_norm_infidelity_matches_svd_target():
390439
"""SVD non-unitary proxy should match quimb's target infidelity."""
391440
p0 = qtn.MPS_computational_state("0000", dtype="complex128")
@@ -664,9 +713,9 @@ def test_mps_optimizer_dmrg_non_unitary_matches_mpo_accuracy():
664713
)
665714

666715

667-
@pytest.mark.parametrize("mode", ["dmrg", "mpo"])
716+
@pytest.mark.parametrize("mode", ["dmrg", "mpo", "swap", "svd"])
668717
def test_mps_optimizer_non_unitary_norm_infidelity_smoke_other_modes(mode):
669-
"""Other compressed modes should expose a bounded non-unitary proxy."""
718+
"""All compressed modes should expose a bounded non-unitary proxy."""
670719
p0 = qtn.MPS_computational_state("0000", dtype="complex128")
671720
gates = [
672721
(qu.hadamard(), (0,)),
@@ -675,6 +724,8 @@ def test_mps_optimizer_non_unitary_norm_infidelity_smoke_other_modes(mode):
675724
]
676725

677726
opt = py.MpsOptimizer(p0.copy(), gates=gates, chi=1, mode=mode)
727+
if mode == "swap" and not hasattr(opt.p, "gate_with_auto_swap_"):
728+
pytest.skip("swap mode requires gate_with_auto_swap_ in this quimb version.")
678729
opt.run(
679730
progbar=False,
680731
cutoff=1e-12,

0 commit comments

Comments
 (0)