Skip to content

Commit 30e3d2f

Browse files
committed
Refine MPS optimizer API and clean up PEPS helpers
1 parent 528b24b commit 30e3d2f

8 files changed

Lines changed: 550 additions & 174 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,6 @@ store/
1313
nohup.out
1414

1515
*.egg-info/
16+
.eggs/
17+
build/
18+
dist/

__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

src/pepsy/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
reset_default_backends,
3737
set_default_array_backend,
3838
set_default_grad_backend,
39-
tn_applied,
39+
tns_align,
4040
)
4141
from .linalg_registrations import reg_complex_svd_jax, reg_complex_svd_torch
4242
from .debug import plot_sweep_diagnostics, plot_inner_loss, plot_global_loss_trajectory
@@ -84,7 +84,7 @@
8484
"plot_sweep_diagnostics",
8585
"plot_inner_loss",
8686
"plot_global_loss_trajectory",
87-
"tn_applied",
87+
"tns_align",
8888
"gen_long_range_swap_path",
8989
"apply_2dtn_",
9090
"apply_gates",
@@ -206,14 +206,14 @@ def __getattr__(name):
206206
"su4": su4,
207207
}[name]
208208

209-
if name in ("tn_applied", "product_state_peps"):
209+
if name in ("tns_align", "product_state_peps"):
210210
from .core import ( # pylint: disable=import-outside-toplevel
211211
product_state_peps,
212-
tn_applied,
212+
tns_align,
213213
)
214214

215215
return {
216-
"tn_applied": tn_applied,
216+
"tns_align": tns_align,
217217
"product_state_peps": product_state_peps,
218218
}[name]
219219

src/pepsy/core.py

Lines changed: 88 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import importlib.util
44
import math
5+
import os
6+
import tempfile
57
import warnings
68
from typing import Any
79

@@ -25,12 +27,74 @@
2527
"build_optimizer",
2628
"build_compressed_optimizer",
2729
"tn_fidelity",
30+
"tns_align",
31+
"product_state_peps",
2832
]
2933

3034
_DEFAULT_ARRAY_BACKEND = None
3135
_DEFAULT_GRAD_BACKEND = None
3236

3337

38+
def _default_cache_root():
39+
"""Return default cache root for pepsy artifacts."""
40+
env_cache = os.environ.get("PEPSY_CACHE_DIR")
41+
if env_cache:
42+
return env_cache
43+
44+
try:
45+
from platformdirs import user_cache_dir # pylint: disable=import-outside-toplevel
46+
47+
return user_cache_dir("pepsy")
48+
except Exception: # pragma: no cover - optional dependency fallback
49+
return os.path.join(os.path.expanduser("~"), ".cache", "pepsy")
50+
51+
52+
def _resolve_cache_directory(directory, subdir):
53+
"""Resolve cache directory, honoring global disable and env override."""
54+
if directory is not None:
55+
return _ensure_cache_directory(directory, warn=True)
56+
57+
disable_cache = os.environ.get("PEPSY_DISABLE_CACHE", "").strip().lower()
58+
if disable_cache in {"1", "true", "yes", "on"}:
59+
return None
60+
61+
default_cache = _ensure_cache_directory(os.path.join(_default_cache_root(), subdir), warn=False)
62+
if default_cache is not None:
63+
return default_cache
64+
65+
# Fallback for restricted environments where user-cache is not writable.
66+
fallback_cache = _ensure_cache_directory(
67+
os.path.join(tempfile.gettempdir(), "pepsy-cache", subdir),
68+
warn=False,
69+
)
70+
if fallback_cache is not None:
71+
return fallback_cache
72+
73+
warnings.warn(
74+
"No writable cache directory available. Disabling optimizer cache.",
75+
RuntimeWarning,
76+
stacklevel=2,
77+
)
78+
return None
79+
80+
81+
def _ensure_cache_directory(path, warn=False):
82+
"""Create cache directory when possible, else return ``None``."""
83+
if path is None:
84+
return None
85+
try:
86+
os.makedirs(path, exist_ok=True)
87+
return path
88+
except OSError:
89+
if warn:
90+
warnings.warn(
91+
f"Cache directory '{path}' is not writable. Disabling optimizer cache.",
92+
RuntimeWarning,
93+
stacklevel=2,
94+
)
95+
return None
96+
97+
3498
def _validate_backend_callable(name, fn):
3599
if fn is not None and not callable(fn):
36100
raise TypeError(f"{name} must be callable or None")
@@ -183,17 +247,26 @@ def build_optimizer(
183247
max_repeats=2**6,
184248
parallel=True,
185249
optlib="cmaes",
186-
directory="cash/",
250+
directory=None,
187251
hash_method="b",
188252
):
189-
"""Build and return a reusable cotengra contraction optimizer."""
253+
"""Build and return a reusable cotengra contraction optimizer.
254+
255+
Parameters
256+
----------
257+
directory : str | None, optional
258+
Cache directory for optimizer artifacts. If ``None``, defaults to
259+
``$PEPSY_CACHE_DIR/cotengra`` (or OS user-cache dir fallback). Set
260+
environment variable ``PEPSY_DISABLE_CACHE=1`` to force ``None``.
261+
"""
190262
selected_optlib = optlib
191263
if selected_optlib == "cmaes" and importlib.util.find_spec("cmaes") is None:
192264
warnings.warn(
193265
"Package 'cmaes' not found. Falling back to optlib='random'.",
194266
RuntimeWarning,
195267
)
196268
selected_optlib = "random"
269+
cache_dir = _resolve_cache_directory(directory, "cotengra")
197270
opt = ctg.ReusableHyperOptimizer(
198271
minimize=f"combo-{int(alpha)}",
199272
slicing_opts={"target_size": 2**40},
@@ -204,7 +277,7 @@ def build_optimizer(
204277
optlib=selected_optlib,
205278
max_time=max_time,
206279
hash_method=hash_method,
207-
directory=directory,
280+
directory=cache_dir,
208281
progbar=progbar,
209282
on_trial_error="ignore",
210283
)
@@ -218,14 +291,23 @@ def build_compressed_optimizer(
218291
max_repeats=2**8,
219292
max_time="rate:1e8",
220293
):
221-
"""Build and return a reusable cotengra compressed optimizer."""
294+
"""Build and return a reusable cotengra compressed optimizer.
295+
296+
Parameters
297+
----------
298+
directory : str | None, optional
299+
Cache directory for optimizer artifacts. If ``None``, defaults to
300+
``$PEPSY_CACHE_DIR/cotengra-compressed`` (or OS user-cache dir
301+
fallback). Set ``PEPSY_DISABLE_CACHE=1`` to force ``None``.
302+
"""
303+
cache_dir = _resolve_cache_directory(directory, "cotengra-compressed")
222304
copt = ctg.ReusableHyperCompressedOptimizer(
223305
chi,
224306
max_repeats=max_repeats,
225307
minimize="combo-compressed",
226308
progbar=progbar,
227309
max_time=max_time,
228-
directory=directory,
310+
directory=cache_dir,
229311
)
230312
return copt
231313

@@ -259,7 +341,7 @@ def add_cycle(peps, bond_dim, cylinder=False):
259341
return peps
260342

261343

262-
def tn_applied(p, pepo):
344+
def tns_align(p, pepo):
263345
r"""Apply a PEPO operator to a PEPS ket: :math:`\hat{O}|\psi\rangle`.
264346
265347
The PEPO ``k``-indices contract with the PEPS ``k``-indices on join.

src/pepsy/dmrg_fit.py

Lines changed: 20 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,15 @@ def __init__(
5555
self,
5656
tn: qtn.TensorNetwork,
5757
p: Optional[qtn.TensorNetwork] = None,
58-
cutoffs: float = 1.e-12,
58+
cutoffs: float = 1e-12,
5959
backend: Optional[str] = None,
6060
site_tag_id: str = "I{}",
6161
opt: str = "auto-hq",
6262
range_int: Optional[Sequence[int]] = None,
6363
re_tag: bool = False,
6464
info: Optional[Dict[str, Any]] = None,
6565
warning: bool = False,
66-
inplace=False,
66+
inplace: bool = False,
6767
): # pylint: disable=too-many-arguments,too-many-positional-arguments
6868

6969
if p is None:
@@ -79,19 +79,19 @@ def __init__(
7979

8080
self.tn = tn.copy()
8181

82-
8382
if site_tag_id:
83+
site_ind_id = getattr(self.p, "site_ind_id", None)
8484
self.p.view_as_(
8585
qtn.MatrixProductState,
8686
L=self.L,
8787
site_tag_id=site_tag_id,
88-
site_ind_id=None,
88+
site_ind_id=site_ind_id,
8989
cyclic=False,
9090
)
9191

9292
self.site_tag_id = site_tag_id
9393

94-
# cotengra path finder
94+
# Contraction path optimizer spec.
9595
self.opt = opt
9696

9797
# cutoffs and underlying backend
@@ -101,9 +101,9 @@ def __init__(
101101
# warnings being printed or not
102102
self.warning = warning
103103

104-
# store cost function results
105-
self.loss: List[float] = []
106-
self.loss_: List[float] = []
104+
# Diagnostics collected during sweeps.
105+
self.fidelity_trace: List[float] = []
106+
self.local_norm_trace: List[float] = []
107107
self.info: Dict[str, Any] = info or {}
108108
self.range_int: List[int] = list(range_int) if range_int is not None else []
109109
if self.range_int:
@@ -114,7 +114,6 @@ def __init__(
114114
raise ValueError("range_int must satisfy start < stop.")
115115

116116

117-
# Is there a better solution?
118117
# Reindex tensor network with random UUIDs for internal indices
119118
self.tn.reindex_({idx: qtn.rand_uuid() for idx in self.tn.inner_inds()})
120119

@@ -125,7 +124,6 @@ def __init__(
125124
if re_tag:
126125
self._re_tag()
127126

128-
129127
def visual(
130128
self,
131129
figsize=(14, 14),
@@ -215,7 +213,6 @@ def _re_tag(self):
215213
)
216214
self._deep_tag()
217215

218-
219216
def run(self, n_iter=6, verbose=False):
220217
"""Run basic left-to-right local fitting sweeps.
221218
@@ -224,7 +221,7 @@ def run(self, n_iter=6, verbose=False):
224221
n_iter : int
225222
Number of complete sweeps.
226223
verbose : bool
227-
If ``True``, append per-sweep fidelity values to ``self.loss``.
224+
If ``True``, append per-sweep fidelity values to ``self.fidelity_trace``.
228225
"""
229226
if self.p is None:
230227
raise ValueError("Initial state `p` must be provided.")
@@ -250,15 +247,15 @@ def run(self, n_iter=6, verbose=False):
250247
f = f.transpose(*psi[site].inds)
251248

252249
norm_f = (f.H & f).contract(all) ** 0.5
253-
self.loss_.append(complex(norm_f).real)
250+
self.local_norm_trace.append(complex(norm_f).real)
254251

255252
# Update tensor data
256253
psi[site].modify(data=f.data)
257254

258255
# Compute fidelity if verbose mode is enabled
259256
if verbose:
260257
fidelity = tn_fidelity(self.tn, psi)
261-
self.loss.append(ar.do("real", fidelity))
258+
self.fidelity_trace.append(ar.do("real", fidelity))
262259

263260
def _build_env_right(self, psi, env_right):
264261
"""
@@ -286,9 +283,6 @@ def _build_env_right(self, psi, env_right):
286283
t |= env_right[site_tag_id.format(i + 1)]
287284
env_right[site_tag_id.format(i)] = t.contract(all, optimize=opt)
288285

289-
290-
291-
292286
def _right_range(self, psi, env_right, start, stop):
293287
"""
294288
Build right environments env_right["I{i}"] for i in 0..L-1.
@@ -355,8 +349,6 @@ def _left_range(self, psi, site, count, env_left):
355349
t |= env_left[site_tag_id.format(site - 1)]
356350
env_left[site_tag_id.format(site)] = t.contract(all, optimize=opt)
357351

358-
359-
360352
def _update_env_left(self, psi, site: int, env_left):
361353
"""Update left environment incrementally for current site."""
362354

@@ -376,8 +368,11 @@ def _update_env_left(self, psi, site: int, env_left):
376368
t |= env_left[site_tag_id.format(site - 1)]
377369
env_left[site_tag_id.format(site)] = t.contract(all, optimize=opt)
378370

379-
380-
def run_eff(self, n_iter=6, verbose=False): # pylint: disable=too-many-branches,too-many-locals,too-many-statements
371+
def run_eff(
372+
self,
373+
n_iter=6,
374+
verbose=False,
375+
): # pylint: disable=too-many-branches,too-many-locals,too-many-statements
381376
"""Run environment-based fitting sweeps with cached left/right blocks.
382377
383378
This method avoids rebuilding full contractions at each site by
@@ -456,7 +451,7 @@ def run_eff(self, n_iter=6, verbose=False): # pylint: disable=too-many-branches
456451
raise TypeError("Unexpected effective tensor type during run_eff.")
457452

458453
norm_f = (f.H & f).contract(all) ** 0.5
459-
self.loss_.append(complex(norm_f).real)
454+
self.local_norm_trace.append(complex(norm_f).real)
460455

461456
# Contract and normalize
462457
# Update tensor data
@@ -465,9 +460,7 @@ def run_eff(self, n_iter=6, verbose=False): # pylint: disable=too-many-branches
465460
# Compute fidelity if verbose mode is enabled
466461
if verbose:
467462
fidelity = tn_fidelity(self.tn, psi)
468-
self.loss.append(ar.do("real", fidelity))
469-
470-
463+
self.fidelity_trace.append(ar.do("real", fidelity))
471464

472465
def run_gate(
473466
self, n_iter=6, verbose=False
@@ -525,7 +518,6 @@ def run_gate(
525518
tn = env_right[site_tag_id.format(site + 1)]
526519

527520
if 0 < site < L - 1:
528-
529521
# Boundary consistency: the left and right indices must match between tn and p
530522
if count_ == 0:
531523
indx = psi.bond(start - 1, start)
@@ -579,10 +571,7 @@ def run_gate(
579571
raise TypeError("Unexpected effective tensor type during run_gate.")
580572

581573
norm_f = (f.H & f).contract(all) ** 0.5
582-
583-
# norm_f = ar.do("norm", f.data)
584-
585-
self.loss_.append(complex(norm_f).real)
574+
self.local_norm_trace.append(complex(norm_f).real)
586575

587576
# Contract and normalize
588577
# Update tensor data
@@ -591,8 +580,7 @@ def run_gate(
591580
if site < stop:
592581
psi.left_canonize_site(site, bra=None)
593582

594-
595583
# Compute fidelity if verbose mode is enabled
596584
if verbose:
597585
fidelity = tn_fidelity(self.tn, psi)
598-
self.loss.append(ar.do("real", fidelity))
586+
self.fidelity_trace.append(ar.do("real", fidelity))

0 commit comments

Comments
 (0)