Skip to content

Commit 0698bab

Browse files
committed
Add dedicated FD solver module and public FDSolver API
1 parent 7ee3f4f commit 0698bab

8 files changed

Lines changed: 1327 additions & 18 deletions

File tree

src/pepsy/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
boundary_sweeps,
1717
core,
1818
fit,
19+
ft_solver,
1920
gate,
2021
ham,
2122
gradient_solver,
@@ -103,6 +104,7 @@
103104
from .optimize_energy import EnergyOptimizer
104105
from .optimize_mps import MpsOptimizer
105106
from .optimize_mpo import MpoOptimizer
107+
from .gradient_solver import FDSolver
106108

107109
__all__ = [
108110
"__version__",
@@ -123,6 +125,7 @@
123125
"register_torch_linalg",
124126
"reset_default_backends",
125127
"SweepOptimizer",
128+
"FDSolver",
126129
"EnergyOptimizer",
127130
"tns_align",
128131
"measure_obs",
@@ -182,6 +185,7 @@
182185
"optimize_mps",
183186
"optimize_mpo",
184187
"gradient_solver",
188+
"ft_solver",
185189
"ham",
186190
"boundary_metrics",
187191
"boundary_states",
@@ -201,6 +205,7 @@ def __getattr__(name):
201205
"boundary_sweeps",
202206
"ham",
203207
"gradient_solver",
208+
"ft_solver",
204209
"optimize_mps",
205210
"optimize_global",
206211
"optimize_sweep",
@@ -476,6 +481,11 @@ def __getattr__(name):
476481

477482
return SweepOptimizer
478483

484+
if name == "FDSolver":
485+
from .gradient_solver import FDSolver # pylint: disable=import-outside-toplevel
486+
487+
return FDSolver
488+
479489
if name == "EnergyOptimizer":
480490
from .optimize_energy import EnergyOptimizer # pylint: disable=import-outside-toplevel
481491

src/pepsy/core.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -773,8 +773,9 @@ def backend_cupy(device=None, dtype=None):
773773
----------
774774
device : int | cupy.cuda.Device | None, optional
775775
Target CUDA device. If ``None``, use CuPy's current device.
776-
dtype : dtype-like | None, optional
777-
Target CuPy dtype. If ``None``, infer from input.
776+
dtype : dtype-like | torch.dtype | None, optional
777+
Target CuPy dtype. If ``None``, infer from input. Torch dtypes are
778+
accepted and internally mapped to CuPy-compatible dtypes.
778779
"""
779780
try:
780781
import cupy as cp # pylint: disable=import-outside-toplevel
@@ -790,6 +791,26 @@ def backend_cupy(device=None, dtype=None):
790791
if isinstance(target_device, int):
791792
target_device = cp.cuda.Device(target_device)
792793

794+
if torch is not None and isinstance(dtype, torch.dtype):
795+
torch_to_cupy = {
796+
torch.complex128: cp.complex128,
797+
torch.complex64: cp.complex64,
798+
torch.float64: cp.float64,
799+
torch.float32: cp.float32,
800+
torch.float16: cp.float16,
801+
torch.int64: cp.int64,
802+
torch.int32: cp.int32,
803+
torch.int16: cp.int16,
804+
torch.int8: cp.int8,
805+
torch.uint8: cp.uint8,
806+
torch.bool: cp.bool_,
807+
}
808+
if dtype not in torch_to_cupy:
809+
raise ValueError(
810+
f"backend_cupy does not support torch dtype {dtype!r}."
811+
)
812+
dtype = torch_to_cupy[dtype]
813+
793814
def cast_array(x, device=target_device, dtype=dtype):
794815
if device is None:
795816
return cp.asarray(x, dtype=dtype)

0 commit comments

Comments
 (0)