Skip to content

Commit 7ee3f4f

Browse files
committed
Apply pending pepsy updates
1 parent 07ad4e4 commit 7ee3f4f

14 files changed

Lines changed: 197 additions & 546 deletions

.github/workflows/ci.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@ jobs:
1717
run: |
1818
python -m pip install --upgrade pip
1919
pip install -e .[dev]
20+
- name: Validate source syntax
21+
run: python -m compileall -q -f src/pepsy
2022
- name: Run tests
21-
run: pytest -q
23+
run: |
24+
NUMBA_CACHE_DIR=/tmp PYTHONPYCACHEPREFIX=/tmp \
25+
pytest -q
2226
2327
docs:
2428
runs-on: ubuntu-latest

README.md

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,9 @@ Current package version: `0.1.1` (from `pyproject.toml` / `pepsy.__version__`).
1010
- `src/pepsy/`: installable library code
1111
- `boundary_states.py`: boundary state initialization (`BdyMPS`)
1212
- `boundary_sweeps.py`: sweep/contraction runner (`CompBdy`)
13-
- `boundary_metrics.py`: input preparation + contraction (`prepare_boundary_inputs`, `ContractBoundary`)
14-
- `optimize_sweep.py`, `optimize_global.py`, `gate.py`, `gradient_solver.py`, `debug.py`
15-
- `fit.py`, `core.py`, `linalg_registrations.py`
16-
- `example/`: example notebooks
13+
- `boundary_metrics.py`: input preparation + contraction (`build_bra_ket`, `contract_boundary`, `BoundaryContractResult`)
14+
- `optimize_sweep.py`, `optimize_global.py`, `optimize_energy.py`, `gate.py`, `gradient_solver.py`
15+
- `fit.py`, `core.py`, `_backend_utils.py`, `_backend_linalg.py`
1716
- `docs/`: Sphinx documentation source
1817
- `tests/`: package tests
1918

@@ -25,23 +24,32 @@ pip install -e .
2524
# Optional backends:
2625
# pip install -e .[torch]
2726
# pip install -e .[solvers]
27+
# jax backend (manual, platform-specific wheels):
28+
# pip install jax jaxlib
2829
# Optional plotting helpers:
2930
# pip install -e .[viz]
3031
```
3132

3233
## Quick Usage
3334
```python
3435
import pepsy
35-
from pepsy import BdyMPS, CompBdy, ContractBoundary, prepare_boundary_inputs
36+
import quimb.tensor as qtn
3637

37-
print(pepsy.__version__)
38+
ket = qtn.PEPS.rand(Lx=3, Ly=3, bond_dim=2, seed=1, dtype="complex128")
39+
ket_tagged, norm = pepsy.build_bra_ket(ket=ket)
40+
41+
bdy = pepsy.BdyMPS(tn_flat=ket_tagged, tn_double=norm, chi=32, single_layer=False)
42+
res = pepsy.contract_boundary(norm=norm, bdy=bdy, direction="y", n_iter=2)
43+
44+
print(pepsy.__version__, res.cost)
3845
```
3946

4047
## Documentation
4148
Build docs locally:
4249

4350
```bash
4451
pip install -e .[docs]
52+
NUMBA_CACHE_DIR=/tmp PYTHONPYCACHEPREFIX=/tmp \
4553
sphinx-build -W -b html docs docs/_build/html
4654
```
4755

@@ -52,10 +60,6 @@ Main docs sections:
5260
- `howto/`
5361
- `api/`
5462

55-
Guided notebook example:
56-
57-
- `example/norm.ipynb`
58-
5963
## Notes
6064
- `.gitattributes` marks notebooks as binary to avoid noisy diffs.
6165
- `.gitignore` excludes checkpoints, caches, `cash/`, and `nohup.out`.

docs/examples.md

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# Examples
22

3-
## Primary notebook
3+
## Primary walkthrough
44

5-
- `example/norm.ipynb`
5+
- [Tutorial: Contract a PEPS Norm](tutorials/contract_norm.md)
66

7-
This is the recommended first notebook for interactive usage.
7+
This is the recommended first end-to-end walkthrough for interactive usage.
88

99
It demonstrates:
1010

@@ -13,9 +13,10 @@ It demonstrates:
1313
- running `contract_boundary(...)`
1414
- inspecting `BoundaryContractResult.cost` and `.fidel`
1515

16-
## Additional notebooks
16+
## Additional walkthroughs
1717

18-
- `example/peps_norm_.ipynb`: historical exploratory notebook
19-
- `example/peps_boundary_states.ipynb`: boundary-state focused experiments
18+
- [Tutorial: Fidelity Diagnostics](tutorials/fidelity_diagnostics.md)
19+
- [How-To: Choose Parameters](howto/choose_parameters.md)
20+
- [How-To: Tune Sweep Solvers](howto/solver_tuning.md)
2021

2122
For a cleaner, docs-first narrative, start from [tutorials](tutorials/index.md).

docs/howto/troubleshooting.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,14 @@ Cause:
4343
Fix:
4444

4545
- Set `track_boundary_fidelity=True`.
46+
47+
## `RuntimeError: cannot cache function ... quimb/core.py`
48+
49+
Cause:
50+
51+
- Some environments do not allow Numba cache writes in default locations.
52+
53+
Fix:
54+
55+
- Run commands with explicit cache env vars, for example:
56+
`NUMBA_CACHE_DIR=/tmp PYTHONPYCACHEPREFIX=/tmp pytest -q`

src/pepsy/__init__.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
set_default_grad_backend,
5454
tns_align,
5555
)
56-
from .linalg_registrations import reg_complex_svd_jax, reg_complex_svd_torch
5756
from .fit import FIT
5857
from .gate import (
5958
gate,
@@ -122,8 +121,6 @@
122121
"set_default_grad_backend",
123122
"get_default_grad_backend",
124123
"register_torch_linalg",
125-
"reg_complex_svd_torch",
126-
"reg_complex_svd_jax",
127124
"reset_default_backends",
128125
"SweepOptimizer",
129126
"EnergyOptimizer",
@@ -469,17 +466,6 @@ def __getattr__(name):
469466
"reset_default_backends": reset_default_backends,
470467
}[name]
471468

472-
if name in ("reg_complex_svd_torch", "reg_complex_svd_jax"):
473-
from .linalg_registrations import ( # pylint: disable=import-outside-toplevel
474-
reg_complex_svd_jax,
475-
reg_complex_svd_torch,
476-
)
477-
478-
return {
479-
"reg_complex_svd_torch": reg_complex_svd_torch,
480-
"reg_complex_svd_jax": reg_complex_svd_jax,
481-
}[name]
482-
483469
if name == "CompBdy":
484470
from .boundary_sweeps import CompBdy # pylint: disable=import-outside-toplevel
485471

src/pepsy/_backend_utils.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,56 @@ def _to_cupy(
175175
return _to_cupy
176176

177177

178+
def _build_to_jax(sample_data, dtype_name, *, cast_complex_to_real=False):
179+
import jax # pylint: disable=import-outside-toplevel
180+
import jax.numpy as jnp # pylint: disable=import-outside-toplevel
181+
182+
dtype_map = {
183+
"complex128": jnp.complex128,
184+
"complex64": jnp.complex64,
185+
"float64": jnp.float64,
186+
"float32": jnp.float32,
187+
"float16": jnp.float16,
188+
"int64": jnp.int64,
189+
"int32": jnp.int32,
190+
}
191+
if dtype_name not in dtype_map:
192+
raise ValueError(f"Unsupported dtype '{dtype_name}' for jax backend.")
193+
194+
dtype = getattr(sample_data, "dtype", None) or dtype_map[dtype_name]
195+
device = getattr(sample_data, "device", None)
196+
197+
def _to_jax(
198+
x,
199+
dtype=dtype,
200+
device=device,
201+
cast_complex_to_real=cast_complex_to_real,
202+
):
203+
# Torch tensors need explicit host conversion before jnp.asarray.
204+
try:
205+
import torch # pylint: disable=import-outside-toplevel
206+
except ImportError: # pragma: no cover - optional dependency
207+
torch = None
208+
if torch is not None and isinstance(x, torch.Tensor):
209+
x = x.detach().cpu().numpy()
210+
211+
arr = jnp.asarray(x)
212+
213+
if cast_complex_to_real and jnp.issubdtype(dtype, jnp.floating) and jnp.iscomplexobj(arr):
214+
arr = arr.real
215+
216+
target_dtype = dtype
217+
if (not cast_complex_to_real) and jnp.issubdtype(target_dtype, jnp.floating) and jnp.iscomplexobj(arr):
218+
target_dtype = jnp.result_type(target_dtype, jnp.complex64)
219+
220+
out = jnp.asarray(arr, dtype=target_dtype)
221+
if device is not None:
222+
out = jax.device_put(out, device)
223+
return out
224+
225+
return _to_jax
226+
227+
178228
def dispatch_backend_converter(
179229
*,
180230
backend,
@@ -204,6 +254,12 @@ def dispatch_backend_converter(
204254
dtype_name,
205255
cast_complex_to_real=cast_complex_to_real,
206256
)
257+
if backend == "jax":
258+
return _build_to_jax(
259+
sample_data,
260+
dtype_name,
261+
cast_complex_to_real=cast_complex_to_real,
262+
)
207263

208264
raise ValueError(f"Unsupported backend: {backend}")
209265

src/pepsy/boundary_states.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,15 @@ def _build_to_cupy(sample_data, dtype_name):
151151
cast_complex_to_real=True,
152152
)
153153

154+
@staticmethod
155+
def _build_to_jax(sample_data, dtype_name):
156+
return dispatch_backend_converter(
157+
backend="jax",
158+
dtype_name=dtype_name,
159+
sample_data=sample_data,
160+
cast_complex_to_real=True,
161+
)
162+
154163
def _dispatch_backend_converter(self, backend, dtype_name, sample_data):
155164
"""Return a conversion callable for the detected backend."""
156165
if sample_data is None:
@@ -164,6 +173,8 @@ def _dispatch_backend_converter(self, backend, dtype_name, sample_data):
164173
return self._build_to_torch(sample_data, dtype_name)
165174
if backend == "cupy":
166175
return self._build_to_cupy(sample_data, dtype_name)
176+
if backend == "jax":
177+
return self._build_to_jax(sample_data, dtype_name)
167178

168179
raise ValueError(f"Unsupported backend: {backend}")
169180

src/pepsy/fit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ def run_gate(
584584
raise TypeError("Unexpected effective tensor type during run_gate.")
585585

586586
norm_f = (f.H & f).contract(all) ** 0.5
587-
self.local_norm_trace.append(complex(norm_f).real)
587+
self.local_norm_trace.append(ar.do("real", norm_f))
588588

589589
# Update tensor data
590590
psi[site].modify(data=f.data)

src/pepsy/gate.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,13 +1105,10 @@ def gate(tn, gates, where=None, **kwargs):
11051105
tn : qtn.TensorNetwork
11061106
Input tensor network.
11071107
gates : array_like | tuple | sequence
1108-
Gate payload in one of these forms:
1109-
1) single gate with explicit ``where`` argument:
1110-
``gate(tn, G, where=...)``
1111-
2) canonical bundled stream:
1112-
``gate(tn, ((G1, where1), (G2, where2), ...))``
1113-
The single bundled alias ``(gate, where)`` is intentionally rejected
1114-
to avoid ambiguity with a plain rank-2 gate tensor.
1108+
Gate payload. Use ``gate(tn, G, where=...)`` for a single gate, or
1109+
``gate(tn, ((G1, where1), (G2, where2), ...))`` for bundled gates.
1110+
The single bundled alias ``(gate, where)`` is rejected to avoid
1111+
ambiguity with a plain rank-2 gate tensor.
11151112
where : object, optional
11161113
Target location for single-gate form.
11171114
- 1D: ``1``, ``(1,)``, ``(1, 2)``
@@ -1623,13 +1620,12 @@ def build_pepo_from_gates(
16231620
Parameters
16241621
----------
16251622
gates : array_like | sequence | tuple
1626-
Gate payload in one of these forms:
1627-
1) single gate with explicit ``where`` argument:
1628-
``build_pepo_from_gates(G, where=...)``
1629-
2) canonical bundled stream:
1630-
``build_pepo_from_gates(((G1, where1), (G2, where2), ...))``
1631-
3) legacy parallel form with ``wheres``:
1632-
``build_pepo_from_gates([G1, G2], [where1, where2])``
1623+
Gate payload. Accepted forms are:
1624+
``build_pepo_from_gates(G, where=...)`` for a single gate,
1625+
``build_pepo_from_gates(((G1, where1), (G2, where2), ...))`` for the
1626+
canonical bundled stream, and
1627+
``build_pepo_from_gates([G1, G2], [where1, where2])`` for the legacy
1628+
parallel ``wheres`` form.
16331629
wheres : sequence[tuple] | None, optional
16341630
Legacy parallel where stream aligned with ``gates``.
16351631
where : object, optional
@@ -1735,13 +1731,12 @@ def build_mpo_from_gates(
17351731
Parameters
17361732
----------
17371733
gates : array_like | sequence | tuple
1738-
Gate payload in one of these forms:
1739-
1) single gate with explicit ``where`` argument:
1740-
``build_mpo_from_gates(G, where=...)``
1741-
2) canonical bundled stream:
1742-
``build_mpo_from_gates(((G1, where1), (G2, where2), ...))``
1743-
3) legacy parallel form with ``wheres``:
1744-
``build_mpo_from_gates([G1, G2], [where1, where2])``
1734+
Gate payload. Accepted forms are:
1735+
``build_mpo_from_gates(G, where=...)`` for a single gate,
1736+
``build_mpo_from_gates(((G1, where1), (G2, where2), ...))`` for the
1737+
canonical bundled stream, and
1738+
``build_mpo_from_gates([G1, G2], [where1, where2])`` for the legacy
1739+
parallel ``wheres`` form.
17451740
wheres : sequence[tuple[int, ...]] | None, optional
17461741
Legacy parallel where stream aligned with ``gates``.
17471742
where : object, optional

src/pepsy/gradient_solver.py

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -90,32 +90,9 @@ def _require_optax():
9090
class GradSolverResult:
9191
"""Structured optimization result returned by :class:`GradientOptimizer`.
9292
93-
Attributes
94-
----------
95-
params : dict[str, Any]
96-
Optimised parameters on their original devices, detached from
97-
autograd. Concrete element type is backend-dependent: a
98-
``torch.Tensor`` for the torch/scipy/nlopt solvers shipped here,
99-
a ``jax.Array`` / ``jnp.ndarray`` for a JAX-based solver, or any
100-
other array-like a future backend produces.
101-
history : list[float]
102-
Per-step loss trace. Length matches ``n_steps``.
103-
solver : str
104-
Normalised solver name actually used (e.g. ``"scipy"``).
105-
n_steps : int
106-
Number of entries recorded in ``history`` (iterations for scipy,
107-
function evaluations for nlopt/torch).
108-
best_loss : float
109-
Lowest loss value observed during the run.
110-
final_loss : float
111-
Loss after the final restored state (== ``best_loss`` when
112-
``restore_best=True``).
113-
convergence_reason : str
114-
``"maxiter"``, ``"patience"``, ``"bad_max"``, ``"empty_params"``,
115-
``"nan_x"``, or ``"nlopt_error:<ExcType>"``.
116-
n_evals : int
117-
Number of objective evaluations (``-1`` if the backend does not
118-
report it separately).
93+
This dataclass stores backend-agnostic outputs (parameters, loss history,
94+
convergence metadata, and evaluation counts). Field docs are derived from
95+
the dataclass members listed below.
11996
"""
12097

12198
# Backend-agnostic: holds torch.Tensor, jax.Array, np.ndarray, etc.

0 commit comments

Comments
 (0)