Skip to content

Commit a1dbddf

Browse files
committed
Default PEPS gates to reduce split
1 parent 79a1653 commit a1dbddf

4 files changed

Lines changed: 83 additions & 16 deletions

File tree

src/pepsy/operators/gates.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1633,8 +1633,10 @@ def gate(tn, gates, where=None, which=None, **kwargs):
16331633
when ``contract`` is ``"split"`` or ``"reduce-split"``. For other contract
16341634
modes, the gate is applied directly to the requested endpoints.
16351635
By default, 2D/3D routing uses ``sequence="auto"`` to choose a shortest
1636-
route with the smallest current virtual-bond bottleneck. Pass an explicit
1637-
deterministic sequence to force a particular route.
1636+
route with the smallest current virtual-bond bottleneck, and
1637+
``contract="reduce-split"`` for quimb's reduced two-site split path. Pass
1638+
an explicit deterministic sequence or contract mode to force a particular
1639+
route or split strategy.
16381640
The efficient routed pattern is::
16391641
16401642
gate(
@@ -1730,7 +1732,7 @@ def gate(tn, gates, where=None, which=None, **kwargs):
17301732

17311733
if arity == 2:
17321734
opts_local = dict(opts)
1733-
opts_local.setdefault("contract", "split")
1735+
opts_local.setdefault("contract", "reduce-split")
17341736
if which_payload is not None:
17351737
opts_local["ind_id"] = _ind_id_from_which(which_payload, 2)
17361738
elif (which_default is not None) and ("ind_id" not in opts_local):
@@ -1748,7 +1750,7 @@ def gate(tn, gates, where=None, which=None, **kwargs):
17481750

17491751
if arity == 3:
17501752
opts_local = dict(opts)
1751-
opts_local.setdefault("contract", "split")
1753+
opts_local.setdefault("contract", "reduce-split")
17521754
if which_payload is not None:
17531755
opts_local["ind_id"] = _ind_id_from_which(which_payload, 3)
17541756
elif (which_default is not None) and ("ind_id" not in opts_local):
@@ -2203,7 +2205,7 @@ def _apply_gate_2d(
22032205
bond_dim=None,
22042206
max_bond=None,
22052207
bra=False,
2206-
contract="split",
2208+
contract="reduce-split",
22072209
tags=None,
22082210
dtype="complex128",
22092211
cutoff=1.0e-12,
@@ -2478,7 +2480,7 @@ def _apply_gate_3d(
24782480
bond_dim=None,
24792481
max_bond=None,
24802482
bra=False,
2481-
contract="split",
2483+
contract="reduce-split",
24822484
tags=None,
24832485
dtype="complex128",
24842486
cutoff=1.0e-12,
@@ -2692,7 +2694,7 @@ def build_pepo_from_gates(
26922694
dtype="complex128",
26932695
max_bond=16,
26942696
sequence="auto",
2695-
contract="split",
2697+
contract="reduce-split",
26962698
ind_id="k{},{}",
26972699
):
26982700
"""Build a PEPO from gate-style input on top of a PEPO identity.
@@ -2725,8 +2727,9 @@ def build_pepo_from_gates(
27252727
2D SWAP-path preference for long-range two-site gates. Defaults to
27262728
``"auto"`` for the same lower-bond smart routing used by
27272729
:func:`gate`.
2728-
contract : str, default="split"
2729-
Gate contraction mode.
2730+
contract : str, default="reduce-split"
2731+
Gate contraction mode. The default uses quimb's reduced two-site split
2732+
path, which is usually cheaper than ``"split"`` for PEPO/PEPS tensors.
27302733
ind_id : str, default="k{},{}"
27312734
Physical index format used for PEPO ket-family indices.
27322735
@@ -2806,7 +2809,7 @@ def build_mpo_from_gates(
28062809
mpo_=None,
28072810
dtype="complex128",
28082811
max_bond=16,
2809-
contract="split",
2812+
contract="reduce-split",
28102813
ind_id="k{}",
28112814
):
28122815
"""Build an MPO from gate-style input on top of an MPO identity.
@@ -2835,8 +2838,9 @@ def build_mpo_from_gates(
28352838
max_bond : int, default=16
28362839
Per-gate local split truncation cap and fallback construction
28372840
compression cap.
2838-
contract : str, default="split"
2839-
Gate contraction mode.
2841+
contract : str, default="reduce-split"
2842+
Gate contraction mode. The default uses quimb's reduced two-site split
2843+
path, which is usually cheaper than ``"split"`` for MPO tensors.
28402844
ind_id : str, default="k{}"
28412845
Physical index format used for MPO ket-family indices.
28422846

src/pepsy/tensors/symmetric.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,7 @@ def apply_gates(
689689
self,
690690
gates,
691691
*,
692-
contract="split",
692+
contract="auto",
693693
max_bond=None,
694694
cutoff=1e-10,
695695
normalize=False,
@@ -702,6 +702,7 @@ def apply_gates(
702702
"""Apply a bundled local gate stream to this state."""
703703
target = self if inplace else self.copy()
704704
method = str(method).strip().lower()
705+
contract_auto = contract is None or str(contract).strip().lower() == "auto"
705706
if max_bond is not None:
706707
compress_opts.setdefault("max_bond", max_bond)
707708
if cutoff is not None:
@@ -711,7 +712,8 @@ def apply_gates(
711712
from ..operators import gate as pepsy_gate
712713

713714
opts = dict(compress_opts)
714-
opts.setdefault("contract", contract)
715+
if not contract_auto:
716+
opts.setdefault("contract", contract)
715717
opts.update({} if gate_kwargs is None else dict(gate_kwargs))
716718
target.network = pepsy_gate(
717719
target.network,
@@ -752,7 +754,7 @@ def apply_gates(
752754
target.network,
753755
gate,
754756
inds,
755-
contract=contract,
757+
contract="split" if contract_auto else contract,
756758
tags=[],
757759
info=None,
758760
inplace=True,
@@ -775,7 +777,7 @@ def time_evolve(
775777
max_bond=None,
776778
cutoff=1e-10,
777779
normalize=None,
778-
contract="split",
780+
contract="auto",
779781
inplace=True,
780782
method="direct",
781783
gauges=None,

tests/test_gate.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,38 @@ def _fake_bulk_helper(tn_i, G_arg, where=None, **kwargs):
611611
assert has_inplace_kw is (dim == "1d")
612612

613613

614+
def test_gate_2d_3d_default_contract_is_reduce_split(monkeypatch):
615+
"""PEPS-like direct gate routing should default to quimb's reduce-split."""
616+
calls = []
617+
618+
def _fake_gate_2d(tn, G_arg, where=None, **kwargs):
619+
calls.append(("2d", where, kwargs.get("contract")))
620+
return tn
621+
622+
def _fake_gate_3d(tn, G_arg, where=None, **kwargs):
623+
calls.append(("3d", where, kwargs.get("contract")))
624+
return tn
625+
626+
monkeypatch.setattr("pepsy.operators.gates._apply_gate_2d", _fake_gate_2d)
627+
monkeypatch.setattr("pepsy.operators.gates._apply_gate_3d", _fake_gate_3d)
628+
629+
class _Dummy3DTN: # pylint: disable=too-few-public-methods
630+
Lx = 1
631+
Ly = 1
632+
Lz = 2
633+
634+
gate_op = np.eye(4, dtype=np.complex128).reshape(2, 2, 2, 2)
635+
peps = ps_to_peps(1, 2, dtype="complex128")
636+
tn3d = _Dummy3DTN()
637+
638+
assert apply_gate(peps, gate_op, ((0, 0), (0, 1))) is peps
639+
assert apply_gate(tn3d, gate_op, ((0, 0, 0), (0, 0, 1))) is tn3d
640+
assert calls == [
641+
("2d", ((0, 0), (0, 1)), "reduce-split"),
642+
("3d", ((0, 0, 0), (0, 0, 1)), "reduce-split"),
643+
]
644+
645+
614646
@pytest.mark.parametrize("dim", ("1d", "2d", "3d"))
615647
def test_gate_sequence_dispatch_inplace_false_copies(monkeypatch, dim):
616648
"""Bundled streams should copy TN first when ``inplace=False``."""
@@ -1071,6 +1103,7 @@ def _fake_gate(tn, gate, where, **kwargs):
10711103
where, kwargs = calls[0]
10721104
assert where == (0, 2)
10731105
assert kwargs["max_bond"] == 5
1106+
assert kwargs["contract"] == "reduce-split"
10741107
assert "bond_dim" not in kwargs
10751108
assert kwargs["ind_id"] == "k{}"
10761109

@@ -1123,6 +1156,7 @@ def _fake_gate(tn, gate, where, **kwargs):
11231156
where, kwargs = calls[0]
11241157
assert where == ((0, 0), (1, 2))
11251158
assert kwargs["max_bond"] == 6
1159+
assert kwargs["contract"] == "reduce-split"
11261160
assert "bond_dim" not in kwargs
11271161
assert kwargs["sequence"] == "auto"
11281162
assert kwargs["ind_id"] == "k{},{}"

tests/test_symmetric_tensors.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,33 @@ def test_sympeps_gate_stream_runs_pepsy_gate_and_gate_simple():
388388
assert np.isfinite(np.real(out_simple.norm()))
389389

390390

391+
def test_sympeps_gate_method_preserves_pepsy_gate_contract_default(monkeypatch):
392+
"""SymPEPS method='gate' should not override pepsy.gate's default."""
393+
state = SymPEPS.for_model(
394+
"heisenberg",
395+
2,
396+
2,
397+
bond_dim=2,
398+
seed=12,
399+
dtype="complex128",
400+
)
401+
calls = []
402+
403+
def _fake_gate(tn, gates, **kwargs):
404+
calls.append((gates, kwargs.copy()))
405+
return tn
406+
407+
monkeypatch.setattr("pepsy.operators.gate", _fake_gate)
408+
409+
out = state.copy().apply_gates(
410+
((np.eye(2, dtype=np.complex128), ((0, 0),)),),
411+
method="gate",
412+
)
413+
414+
assert out.tn is not None
415+
assert "contract" not in calls[0][1]
416+
417+
391418
def test_sympeps_raw_pepsy_gate_functions_accept_symmray_streams():
392419
"""The plain gate functions should accept a SymGateStream directly."""
393420
state = SymPEPS.for_model("itf", 2, 2, bond_dim=2, seed=10, dtype="complex128")

0 commit comments

Comments
 (0)