Commit 9280c3a

Merge pull request #231 from patrick-kidger/dev
Version 0.3.0
2 parents: 05d03d8 + 813fe0f


65 files changed: +2474 −1991 lines

.github/workflows/run_tests.yml (+1 −1)

@@ -7,7 +7,7 @@ jobs:
   run-tests:
     strategy:
       matrix:
-        python-version: [ 3.7, 3.8, 3.9 ]
+        python-version: [ 3.8, 3.9 ]
         os: [ ubuntu-latest ]
       fail-fast: false
     runs-on: ${{ matrix.os }}

.pre-commit-config.yaml (+2 −2)

@@ -4,13 +4,13 @@ repos:
     hooks:
       - id: black
   - repo: https://github.com/nbQA-dev/nbQA
-    rev: 1.2.3
+    rev: 1.6.3
    hooks:
       - id: nbqa-black
       - id: nbqa-isort
       - id: nbqa-flake8
   - repo: https://github.com/PyCQA/isort
-    rev: 5.10.1
+    rev: 5.12.0
     hooks:
       - id: isort
   - repo: https://github.com/pycqa/flake8

README.md (+3 −1)

@@ -21,7 +21,7 @@ _From a technical point of view, the internal structure of the library is pretty
 pip install diffrax
 ```

-Requires Python >=3.7 and JAX >=0.3.4.
+Requires Python 3.8+, JAX 0.4.3+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.10.0+.

 ## Documentation

@@ -65,4 +65,6 @@ Neural networks: [Equinox](https://github.com/patrick-kidger/equinox).

 Type annotations and runtime checking for PyTrees and shape/dtype of JAX arrays: [jaxtyping](https://github.com/google/jaxtyping).

+Computer vision models: [Eqxvision](https://github.com/paganpasta/eqxvision).
+
 SymPy<->JAX conversion; train symbolic expressions via gradient descent: [sympy2jax](https://github.com/google/sympy2jax).
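For orientation, a minimal end-to-end check against the bumped requirements. This snippet is illustrative only (it mirrors the README's quick-start pattern, not anything changed in this diff):

```python
import jax.numpy as jnp
from diffrax import diffeqsolve, Dopri5, ODETerm

def f(t, y, args):
    return -y  # dy/dt = -y, so y(t) = y0 * exp(-t)

# Solve from t=0 to t=1 with initial step size 0.1.
sol = diffeqsolve(ODETerm(f), Dopri5(), t0=0, t1=1, dt0=0.1, y0=jnp.array([2.0, 3.0]))
print(sol.ys)  # approximately [2.0, 3.0] * exp(-1)
```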

benchmarks/compile_times.py (+39 −9)

@@ -3,7 +3,6 @@

 import diffrax as dfx
 import equinox as eqx
-import fire
 import jax
 import jax.numpy as jnp
 import jax.random as jr
@@ -31,12 +30,12 @@ def __call__(self, t, y, args):
         return jnp.stack(y)


-def main(inline: bool, scan_stages: bool, grad: bool, adjoint: str):
-    if adjoint == "direct":
+def run(inline: bool, scan_stages: bool, grad: bool, adjoint_name: str):
+    if adjoint_name == "direct":
         adjoint = dfx.DirectAdjoint()
-    elif adjoint == "recursive":
+    elif adjoint_name == "recursive":
         adjoint = dfx.RecursiveCheckpointAdjoint()
-    elif adjoint == "backsolve":
+    elif adjoint_name == "backsolve":
         adjoint = dfx.BacksolveAdjoint()
     else:
         raise ValueError
@@ -72,9 +71,40 @@ def solve(y0):
         return jnp.sum(sol.ys)

     solve_ = ft.partial(solve, jnp.array([1.0]))
-    print("Compile+run time", timeit.timeit(solve_, number=1))
-    print("Run time", timeit.timeit(solve_, number=1))
+    compile_time = timeit.timeit(solve_, number=1)
+    print(
+        f"{inline=}, {scan_stages=}, {grad=}, adjoint={adjoint_name}, {compile_time=}"
+    )


-if __name__ == "__main__":
-    fire.Fire(main)
+run(inline=False, scan_stages=False, grad=False, adjoint_name="direct")
+run(inline=False, scan_stages=False, grad=False, adjoint_name="recursive")
+run(inline=False, scan_stages=False, grad=False, adjoint_name="backsolve")
+
+run(inline=False, scan_stages=False, grad=True, adjoint_name="direct")
+run(inline=False, scan_stages=False, grad=True, adjoint_name="recursive")
+run(inline=False, scan_stages=False, grad=True, adjoint_name="backsolve")
+
+run(inline=False, scan_stages=True, grad=False, adjoint_name="direct")
+run(inline=False, scan_stages=True, grad=False, adjoint_name="recursive")
+run(inline=False, scan_stages=True, grad=False, adjoint_name="backsolve")
+
+run(inline=False, scan_stages=True, grad=True, adjoint_name="direct")
+run(inline=False, scan_stages=True, grad=True, adjoint_name="recursive")
+run(inline=False, scan_stages=True, grad=True, adjoint_name="backsolve")
+
+run(inline=True, scan_stages=False, grad=False, adjoint_name="direct")
+run(inline=True, scan_stages=False, grad=False, adjoint_name="recursive")
+run(inline=True, scan_stages=False, grad=False, adjoint_name="backsolve")
+
+run(inline=True, scan_stages=False, grad=True, adjoint_name="direct")
+run(inline=True, scan_stages=False, grad=True, adjoint_name="recursive")
+run(inline=True, scan_stages=False, grad=True, adjoint_name="backsolve")
+
+run(inline=True, scan_stages=True, grad=False, adjoint_name="direct")
+run(inline=True, scan_stages=True, grad=False, adjoint_name="recursive")
+run(inline=True, scan_stages=True, grad=False, adjoint_name="backsolve")
+
+run(inline=True, scan_stages=True, grad=True, adjoint_name="direct")
+run(inline=True, scan_stages=True, grad=True, adjoint_name="recursive")
+run(inline=True, scan_stages=True, grad=True, adjoint_name="backsolve")
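The three adjoint names in the benchmark grid map onto the diffrax classes shown in the diff. A minimal sketch of what each mode means, illustrative only (the toy ODE and `Tsit5` are stand-ins, not part of the benchmark):

```python
import diffrax as dfx
import jax.numpy as jnp

def f(t, y, args):
    return -y

# DirectAdjoint: differentiate directly through the solver's internals (new in 0.3.0,
#   replacing NoAdjoint).
# RecursiveCheckpointAdjoint: checkpointed backprop through the solve (the library default).
# BacksolveAdjoint: solve the adjoint ODE backwards in time ("optimise-then-discretise").
for adjoint in (
    dfx.DirectAdjoint(),
    dfx.RecursiveCheckpointAdjoint(),
    dfx.BacksolveAdjoint(),
):
    dfx.diffeqsolve(
        dfx.ODETerm(f), dfx.Tsit5(), t0=0, t1=1, dt0=0.1,
        y0=jnp.array([1.0]), adjoint=adjoint,
    )
```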

benchmarks/scan_stages.py (+12 −10)

@@ -1,14 +1,14 @@
 """Benchmarks the effect of `diffrax.AbstractRungeKutta(scan_stages=...)`.

-On my CPU-only machine:
+On my relatively beefy CPU-only machine:
 ```
-bash> python scan_stages.py False
-Compile+run time 24.38062646985054
-Run time 0.0018830380868166685
+scan_stages=True
+Compile+run time 1.8253102810122073
+Run time 0.00017526978626847267

-bash> python scan_stages.py True
-Compile+run time 11.418417416978627
-Run time 0.0014536201488226652
+scan_stages=False
+Compile+run time 10.679616351146251
+Run time 0.00021236995235085487
 ```
 """

@@ -17,7 +17,6 @@

 import diffrax as dfx
 import equinox as eqx
-import fire
 import jax.numpy as jnp
 import jax.random as jr

@@ -44,7 +43,7 @@ def __call__(self, t, y, args):
         return jnp.stack(y)


-def main(scan_stages):
+def run(scan_stages):
     vf = VectorField(1, 1, 16, 2, key=jr.PRNGKey(0))
     term = dfx.ODETerm(vf)
     solver = dfx.Dopri8(scan_stages=scan_stages)
@@ -60,8 +59,11 @@ def solve(y0):
     )

     solve_ = ft.partial(solve, jnp.array([1.0]))
+    print(f"scan_stages={scan_stages}")
     print("Compile+run time", timeit.timeit(solve_, number=1))
     print("Run time", timeit.timeit(solve_, number=1))


-fire.Fire(main)
+run(scan_stages=True)
+print()
+run(scan_stages=False)
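The timings in the updated docstring are the point of this file: `scan_stages=True` rolls the Runge-Kutta stages into a scanned loop rather than unrolling them, which cuts compile time substantially for a small runtime cost. A minimal illustrative toggle (toy ODE, not the benchmark's neural vector field):

```python
import diffrax as dfx
import jax.numpy as jnp

for scan_stages in (True, False):
    # scan_stages=True trades a little runtime for much faster compilation.
    solver = dfx.Dopri8(scan_stages=scan_stages)
    dfx.diffeqsolve(
        dfx.ODETerm(lambda t, y, args: -y), solver,
        t0=0, t1=1, dt0=0.1, y0=jnp.array([1.0]),
    )
```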

benchmarks/scan_stages_cnf.py (+11 −7)

@@ -32,7 +32,6 @@

 import diffrax
 import equinox as eqx
-import fire
 import jax
 import jax.nn as jnn
 import jax.numpy as jnp
@@ -50,7 +49,7 @@ def vector_field_prob(t, input, model):
     return f, logp


-@eqx.filter_vmap(args=(None, 0, None, None))
+@eqx.filter_vmap(in_axes=(None, 0, None, None))
 def log_prob(model, y0, scan_stages, backsolve):
     term = diffrax.ODETerm(vector_field_prob)
     solver = diffrax.Dopri5(scan_stages=scan_stages)
@@ -80,13 +79,18 @@ def solve(model, inputs, scan_stages, backsolve):
     return -log_prob(model, inputs, scan_stages, backsolve).mean()


-def main(scan_stages, backsolve):
+def run(scan_stages, backsolve):
     mkey, dkey = jr.split(jr.PRNGKey(0), 2)
     model = eqx.nn.MLP(2, 2, 10, 2, activation=jnn.gelu, key=mkey)
     x = jr.normal(dkey, (256, 2))
-    solve_ = ft.partial(solve, model, x, scan_stages, backsolve)
-    print("Compile+run time", timeit.timeit(solve_, number=1))
-    print("Run time", timeit.timeit(solve_, number=1))
+    solve2 = ft.partial(solve, model, x, scan_stages, backsolve)
+    print(f"scan_stages={scan_stages}, backsolve={backsolve}")
+    print("Compile+run time", timeit.timeit(solve2, number=1))
+    print("Run time", timeit.timeit(solve2, number=1))
+    print()


-fire.Fire(main)
+run(scan_stages=False, backsolve=False)
+run(scan_stages=False, backsolve=True)
+run(scan_stages=True, backsolve=False)
+run(scan_stages=True, backsolve=True)
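The decorator change here tracks Equinox renaming `filter_vmap`'s `args` argument to the `jax.vmap`-style `in_axes`. A toy sketch of the new spelling (illustrative; `scaled` is a made-up function, not part of the benchmark):

```python
import equinox as eqx
import jax.numpy as jnp

# Broadcast `scale` (axis None) while mapping over axis 0 of `x`.
@eqx.filter_vmap(in_axes=(None, 0))
def scaled(scale, x):
    return scale * x

print(scaled(2.0, jnp.arange(4.0)))  # [0. 2. 4. 6.]
```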

benchmarks/small_neural_ode.py (+16 −12)

@@ -1,9 +1,10 @@
+"""Benchmarks Diffrax vs torchdiffeq vs jax.experimental.ode.odeint"""
+
 import gc
 import time

 import diffrax
 import equinox as eqx
-import fire
 import jax
 import jax.experimental.ode as experimental
 import jax.nn as jnn
@@ -166,7 +167,7 @@ def time_jax(neural_ode_jax, y0, t1, grad):
     _eval_jax(neural_ode_jax, y0, t1)


-def main(batch_size=64, t1=100, multiple=False, grad=False):
+def run(multiple, grad, batch_size=64, t1=100):
     neural_ode_torch = NeuralODETorch(multiple)
     neural_ode_diffrax = NeuralODEDiffrax(multiple)
     neural_ode_experimental = NeuralODEExperimental(multiple)
@@ -180,25 +181,28 @@ def main(batch_size=64, t1=100, multiple=False, grad=False):
     func_torch[2].bias.copy_(torch.tensor(np.asarray(func_jax.layers[1].bias)))

     y0_jax = jrandom.normal(jrandom.PRNGKey(1), (batch_size, 4))
-    y0_torch = torch.tensor(y0_jax.to_py())
+    y0_torch = torch.tensor(np.asarray(y0_jax))

     time_torch(neural_ode_torch, y0_torch, t1, grad)
     torch_time = time_torch(neural_ode_torch, y0_torch, t1, grad)

-    time_jax(neural_ode_diffrax, y0_jax, t1, grad)
-    diffrax_time = time_jax(neural_ode_diffrax, y0_jax, t1, grad)
+    time_jax(neural_ode_diffrax, jnp.copy(y0_jax), t1, grad)
+    diffrax_time = time_jax(neural_ode_diffrax, jnp.copy(y0_jax), t1, grad)

-    time_jax(neural_ode_experimental, y0_jax, t1, grad)
-    experimental_time = time_jax(neural_ode_experimental, y0_jax, t1, grad)
+    time_jax(neural_ode_experimental, jnp.copy(y0_jax), t1, grad)
+    experimental_time = time_jax(neural_ode_experimental, jnp.copy(y0_jax), t1, grad)

     print(
-        f"""
-torch_time={torch_time}
-diffrax_time={diffrax_time}
-experimetnal_time={experimental_time}
+        f""" multiple={multiple}, grad={grad}
+torch_time={torch_time}
+diffrax_time={diffrax_time}
+experimental_time={experimental_time}
 """
     )


 if __name__ == "__main__":
-    fire.Fire(main)
+    run(multiple=False, grad=False)
+    run(multiple=True, grad=False)
+    run(multiple=False, grad=True)
+    run(multiple=True, grad=True)
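The `y0_torch` change is needed because `Array.to_py()` was removed from recent JAX; converting through NumPy is the supported route. A standalone sketch of the pattern (illustrative):

```python
import jax.numpy as jnp
import numpy as np
import torch

y0_jax = jnp.ones((4,))
# np.asarray() accepts JAX arrays and replaces the removed Array.to_py();
# torch.tensor() then copies the buffer into a new torch tensor.
y0_torch = torch.tensor(np.asarray(y0_jax))
```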

diffrax/__init__.py (+5 −5)

@@ -1,10 +1,11 @@
 from .adjoint import (
     AbstractAdjoint,
     BacksolveAdjoint,
+    DirectAdjoint,
     ImplicitAdjoint,
-    NoAdjoint,
     RecursiveCheckpointAdjoint,
 )
+from .autocitation import citation, citation_rules
 from .brownian import AbstractBrownianPath, UnsafeBrownianPath, VirtualBrownianTree
 from .event import (
     AbstractDiscreteTerminatingEvent,
@@ -27,14 +28,14 @@
     LocalLinearInterpolation,
     ThirdOrderHermitePolynomialInterpolation,
 )
-from .misc import adjoint_rms_seminorm, sde_kl_divergence
+from .misc import adjoint_rms_seminorm
 from .nonlinear_solver import (
     AbstractNonlinearSolver,
     NewtonNonlinearSolver,
     NonlinearSolution,
 )
 from .path import AbstractPath
-from .saveat import SaveAt
+from .saveat import SaveAt, SubSaveAt
 from .solution import is_event, is_okay, is_successful, RESULTS, Solution
 from .solver import (
     AbstractAdaptiveSolver,
@@ -55,7 +56,6 @@
     Dopri8,
     Euler,
     EulerHeun,
-    Fehlberg2,
     HalfSolver,
     Heun,
     ImplicitEuler,
@@ -87,4 +87,4 @@
 )


-__version__ = "0.2.2"
+__version__ = "0.3.0"

diffrax/misc/ad.py → diffrax/ad.py

File renamed without changes.
