Skip to content

Commit 528b24b

Browse files
committed
mps optimization and gate API cleanup
1 parent 33c3e3b commit 528b24b

14 files changed

Lines changed: 457 additions & 77 deletions

__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .src.pepsy.optimize_mps import MpsOptimizer

src/pepsy/__init__.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,20 @@
4444
from .gate import (
4545
apply_2dtn_,
4646
apply_gates,
47-
canonize_mps,
4847
gate_1d,
4948
gen_long_range_swap_path,
49+
rx,
50+
ry,
51+
rz,
52+
rxx,
53+
ryy,
54+
rzz,
55+
u3,
56+
su4,
5057
)
5158
from .optimize_global import GlobalOptimizer
5259
from .optimize_sweep import PEPSSweepOptimizer, SweepResult
60+
from .optimize_mps import MpsOptimizer
5361

5462
__all__ = [
5563
"__version__",
@@ -81,7 +89,14 @@
8189
"apply_2dtn_",
8290
"apply_gates",
8391
"gate_1d",
84-
"canonize_mps",
92+
"rx",
93+
"ry",
94+
"rz",
95+
"rxx",
96+
"ryy",
97+
"rzz",
98+
"u3",
99+
"su4",
85100
"product_state_peps",
86101
"optimize_global",
87102
"optimize_sweep",
@@ -93,6 +108,7 @@
93108
"core",
94109
"dmrg_fit",
95110
"debug",
111+
"MpsOptimizer",
96112
]
97113

98114

@@ -105,6 +121,7 @@ def __getattr__(name):
105121
"debug",
106122
"gate",
107123
"gradient_solver",
124+
"optimize_mps",
108125
"optimize_global",
109126
"optimize_sweep",
110127
"core",
@@ -150,22 +167,43 @@ def __getattr__(name):
150167
"apply_2dtn_",
151168
"apply_gates",
152169
"gate_1d",
153-
"canonize_mps",
170+
"rx",
171+
"ry",
172+
"rz",
173+
"rxx",
174+
"ryy",
175+
"rzz",
176+
"u3",
177+
"su4",
154178
):
155179
from .gate import ( # pylint: disable=import-outside-toplevel
156180
apply_2dtn_,
157181
apply_gates,
158-
canonize_mps,
159182
gate_1d,
160183
gen_long_range_swap_path,
184+
rx,
185+
ry,
186+
rz,
187+
rxx,
188+
ryy,
189+
rzz,
190+
u3,
191+
su4,
161192
)
162193

163194
return {
164195
"gen_long_range_swap_path": gen_long_range_swap_path,
165196
"apply_2dtn_": apply_2dtn_,
166197
"apply_gates": apply_gates,
167198
"gate_1d": gate_1d,
168-
"canonize_mps": canonize_mps,
199+
"rx": rx,
200+
"ry": ry,
201+
"rz": rz,
202+
"rxx": rxx,
203+
"ryy": ryy,
204+
"rzz": rzz,
205+
"u3": u3,
206+
"su4": su4,
169207
}[name]
170208

171209
if name in ("tn_applied", "product_state_peps"):
@@ -256,4 +294,9 @@ def __getattr__(name):
256294
"SweepResult": SweepResult,
257295
}[name]
258296

297+
if name == "MpsOptimizer":
298+
from .optimize_mps import MpsOptimizer # pylint: disable=import-outside-toplevel
299+
300+
return MpsOptimizer
301+
259302
raise AttributeError(f"module 'pepsy' has no attribute {name!r}")

src/pepsy/boundary_norm.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,6 @@ class BoundaryContractResult:
2929
max_separation: int
3030

3131

32-
def _validate_tensor_network_tags(p):
33-
"""Ensure PEPS lattice/site tags are present for shape inference."""
34-
validate_tensor_network_tags(p)
35-
36-
37-
def _normalize_retag_for_direction(direction, re_tag):
38-
"""Normalize ``re_tag`` flag for direction-specific calls."""
39-
_ = direction
40-
return bool(re_tag)
41-
42-
4332
def _warn_nonstandard_physical_outer_inds(tn, role):
4433
"""Warn when outer physical indices don't match ``k<int>[,<int>...]`` or ``b<int>[,<int>...]``."""
4534
bad = [
@@ -103,7 +92,7 @@ def prepare_boundary_inputs(
10392
if ket is None:
10493
raise ValueError("Provide ket.")
10594

106-
_validate_tensor_network_tags(ket)
95+
validate_tensor_network_tags(ket)
10796

10897
ket_tagged = ket
10998
auto_bra = bra is None
@@ -196,7 +185,7 @@ def ContractBoundary(
196185
if not isinstance(mps_boundaries, dict):
197186
raise TypeError("mps_boundaries must be a dictionary of boundary states.")
198187

199-
re_tag = _normalize_retag_for_direction(direction, re_tag)
188+
re_tag = bool(re_tag)
200189
norm_tagged = norm.copy()
201190

202191
comp_bdy = CompBdy(

src/pepsy/boundary_states.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,6 @@ def __init__(
113113
default_array_backend = get_default_array_backend()
114114
if default_array_backend is not None:
115115
self.array_backend = default_array_backend
116-
# Backward-compatible alias used by older tests/callers.
117-
self.to_backend = self.array_backend
118116

119117
self.numpy_backend = make_numpy_array_caster(dtype=dtype_name)
120118

src/pepsy/boundary_sweeps.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,7 @@
1010
from .core import build_optimizer, tn_fidelity
1111
from .dmrg_fit import FIT
1212

13-
__all__ = ["BdyMPS", "CompBdy", "tn_fidelity", "build_optimizer", "opt_"]
14-
15-
# Backward-compatible alias.
16-
opt_ = build_optimizer
13+
__all__ = ["BdyMPS", "CompBdy", "tn_fidelity", "build_optimizer"]
1714

1815

1916
@dataclass(frozen=True)
@@ -97,11 +94,6 @@ def __init__(
9794
self.Lx = 1 + max_x # pylint: disable=invalid-name
9895
self._update_separation()
9996

100-
@property
101-
def fidelity(self):
102-
"""Alias for ``self.fidel``."""
103-
return self.fidel
104-
10597
def _reset_fidelity_history(self):
10698
"""Reset stored fidelity values for a fresh public call."""
10799
self.fidel = []

src/pepsy/dmrg_fit.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,10 @@
1818
__all__ = [
1919
"FIT",
2020
"build_optimizer",
21-
"opt_",
2221
"tn_fidelity",
2322
"internal_inds",
2423
]
2524

26-
# Backward-compatible alias.
27-
opt_ = build_optimizer
28-
2925

3026
def internal_inds(psi):
3127
"""Return all internal (non-open) indices of ``psi``."""

src/pepsy/gate.py

Lines changed: 100 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,14 @@
1515
__all__ = [
1616
"apply_gates",
1717
"gate_1d",
18-
"canonize_mps",
18+
"rx",
19+
"ry",
20+
"rz",
21+
"rxx",
22+
"ryy",
23+
"rzz",
24+
"u3",
25+
"su4",
1926
]
2027

2128

@@ -63,6 +70,98 @@ def _to_torch(x, dtype=target_dtype, device=target_device):
6370
return None
6471

6572

73+
def rx(theta):
74+
"""Return a one-qubit RX gate for angle ``theta``.
75+
76+
Parameters
77+
----------
78+
theta : float
79+
Rotation angle.
80+
"""
81+
return qtn.circuit.rx_gate_param_gen([theta])
82+
83+
84+
def ry(theta):
85+
"""Return a one-qubit RY gate for angle ``theta``.
86+
87+
Parameters
88+
----------
89+
theta : float
90+
Rotation angle.
91+
"""
92+
return qtn.circuit.ry_gate_param_gen([theta])
93+
94+
95+
def rz(theta):
96+
"""Return a one-qubit RZ gate for angle ``theta``.
97+
98+
Parameters
99+
----------
100+
theta : float
101+
Rotation angle.
102+
"""
103+
return qtn.circuit.rz_gate_param_gen([theta])
104+
105+
106+
def rzz(theta):
107+
"""Return a two-qubit RZZ gate for angle ``theta``.
108+
109+
Parameters
110+
----------
111+
theta : float
112+
Rotation angle.
113+
"""
114+
return qtn.circuit.rzz_param_gen([theta])
115+
116+
117+
def rxx(theta):
118+
"""Return a two-qubit RXX gate for angle ``theta``.
119+
120+
Parameters
121+
----------
122+
theta : float
123+
Rotation angle.
124+
"""
125+
return qtn.circuit.rxx_param_gen([theta])
126+
127+
128+
def ryy(theta):
129+
"""Return a two-qubit RYY gate for angle ``theta``.
130+
131+
Parameters
132+
----------
133+
theta : float
134+
Rotation angle.
135+
"""
136+
return qtn.circuit.ryy_param_gen([theta])
137+
138+
139+
def su4(params):
140+
"""Return a two-qubit SU(4) gate from 15 parameters.
141+
142+
Parameters
143+
----------
144+
params : sequence
145+
Sequence of exactly 15 parameters.
146+
"""
147+
if len(params) != 15:
148+
raise ValueError("su4 expects exactly 15 parameters.")
149+
return qtn.circuit.su4_gate_param_gen(params)
150+
151+
152+
def u3(params):
153+
"""Return a one-qubit U3 gate from 3 parameters.
154+
155+
Parameters
156+
----------
157+
params : sequence
158+
Sequence of exactly 3 parameters.
159+
"""
160+
if len(params) != 3:
161+
raise ValueError("u3 expects exactly 3 parameters.")
162+
return qtn.circuit.u3_gate_param_gen(params)
163+
164+
66165
def gen_long_range_swap_path( # pylint: disable=too-many-branches,too-many-locals,too-many-statements
67166
ij_a, ij_b, sequence=None
68167
):
@@ -427,15 +526,6 @@ def apply_gates( # pylint: disable=too-many-arguments,too-many-positional-argum
427526
return peps
428527

429528

430-
def canonize_mps(p, where, cur_orthog):
431-
xmin, xmax = sorted(where)
432-
p.canonize([xmin, xmax], cur_orthog=cur_orthog,
433-
#info=info_c
434-
)
435-
# update cur_orthog in place (preserving reference)
436-
cur_orthog[:] = [xmin, xmax]
437-
438-
439529
def gate_1d(tn, where, G, ind_id="k{}", site_tags="I{}",
440530
cutoff=1.e-12, contract='split-gate',
441531
inplace=False):

src/pepsy/optimize_global.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,17 @@ def _merge_opts(base, extra):
6161
merged.update(dict(extra))
6262
return merged
6363

64+
@staticmethod
65+
def _normalize_optimizer_name(optimizer):
66+
"""Normalize common optimizer aliases for ``qtn.TNOptimizer``."""
67+
if not isinstance(optimizer, str):
68+
return optimizer
69+
70+
key = optimizer.strip().lower().replace("_", "-")
71+
if key in {"lbfgs", "l-bfgs", "lbfgsb", "l-bfgs-b"}:
72+
return "L-BFGS-B"
73+
return optimizer
74+
6475
@classmethod
6576
def _pick_known_keys(cls, options, allowed_keys, *, warn_unknown=True):
6677
incoming = dict(options or {})
@@ -123,7 +134,7 @@ def _norm_peps( # pylint: disable=too-many-arguments,too-many-positional-argume
123134
sequence = ["xmin", "xmax", "ymin", "ymax"]
124135
if (mode == "hyper") and (copt is None):
125136
warnings.warn(
126-
"mode='hyper' requested but copt is None; provide copt_() for stable behavior.",
137+
"mode='hyper' requested but copt is None; provide a compressed optimizer via `copt`.",
127138
RuntimeWarning,
128139
stacklevel=2,
129140
)
@@ -203,7 +214,7 @@ def _loss_peps( # pylint: disable=too-many-arguments,too-many-positional-argume
203214
"""Compute overlap-based loss between trainable and target PEPS."""
204215
if (mode == "hyper") and (copt is None):
205216
warnings.warn(
206-
"mode='hyper' requested but copt is None; provide copt_() for stable behavior.",
217+
"mode='hyper' requested but copt is None; provide a compressed optimizer via `copt`.",
207218
RuntimeWarning,
208219
stacklevel=2,
209220
)
@@ -395,6 +406,7 @@ def make_tn_optimizer(
395406
):
396407
"""Construct a configured :class:`quimb.tensor.TNOptimizer`."""
397408
merged_loss_kwargs = self._merge_opts(self.loss_kwargs, loss_kwargs)
409+
optimizer = self._normalize_optimizer_name(optimizer)
398410

399411
constants = {}
400412
if self.peps_target is not None:

0 commit comments

Comments
 (0)