|
3 | 3 |
|
4 | 4 | import diffrax as dfx
|
5 | 5 | import equinox as eqx
|
6 |
| -import fire |
7 | 6 | import jax
|
8 | 7 | import jax.numpy as jnp
|
9 | 8 | import jax.random as jr
|
@@ -31,12 +30,12 @@ def __call__(self, t, y, args):
|
31 | 30 | return jnp.stack(y)
|
32 | 31 |
|
33 | 32 |
|
34 |
| -def main(inline: bool, scan_stages: bool, grad: bool, adjoint: str): |
35 |
| - if adjoint == "direct": |
| 33 | +def run(inline: bool, scan_stages: bool, grad: bool, adjoint_name: str): |
| 34 | + if adjoint_name == "direct": |
36 | 35 | adjoint = dfx.DirectAdjoint()
|
37 |
| - elif adjoint == "recursive": |
| 36 | + elif adjoint_name == "recursive": |
38 | 37 | adjoint = dfx.RecursiveCheckpointAdjoint()
|
39 |
| - elif adjoint == "backsolve": |
| 38 | + elif adjoint_name == "backsolve": |
40 | 39 | adjoint = dfx.BacksolveAdjoint()
|
41 | 40 | else:
|
42 | 41 | raise ValueError
|
@@ -72,9 +71,40 @@ def solve(y0):
|
72 | 71 | return jnp.sum(sol.ys)
|
73 | 72 |
|
74 | 73 | solve_ = ft.partial(solve, jnp.array([1.0]))
|
75 |
| - print("Compile+run time", timeit.timeit(solve_, number=1)) |
76 |
| - print("Run time", timeit.timeit(solve_, number=1)) |
| 74 | + compile_time = timeit.timeit(solve_, number=1) |
| 75 | + print( |
| 76 | + f"{inline=}, {scan_stages=}, {grad=}, adjoint={adjoint_name}, {compile_time=}" |
| 77 | + ) |
77 | 78 |
|
78 | 79 |
|
79 |
| -if __name__ == "__main__": |
80 |
| - fire.Fire(main) |
| 80 | +run(inline=False, scan_stages=False, grad=False, adjoint_name="direct") |
| 81 | +run(inline=False, scan_stages=False, grad=False, adjoint_name="recursive") |
| 82 | +run(inline=False, scan_stages=False, grad=False, adjoint_name="backsolve") |
| 83 | + |
| 84 | +run(inline=False, scan_stages=False, grad=True, adjoint_name="direct") |
| 85 | +run(inline=False, scan_stages=False, grad=True, adjoint_name="recursive") |
| 86 | +run(inline=False, scan_stages=False, grad=True, adjoint_name="backsolve") |
| 87 | + |
| 88 | +run(inline=False, scan_stages=True, grad=False, adjoint_name="direct") |
| 89 | +run(inline=False, scan_stages=True, grad=False, adjoint_name="recursive") |
| 90 | +run(inline=False, scan_stages=True, grad=False, adjoint_name="backsolve") |
| 91 | + |
| 92 | +run(inline=False, scan_stages=True, grad=True, adjoint_name="direct") |
| 93 | +run(inline=False, scan_stages=True, grad=True, adjoint_name="recursive") |
| 94 | +run(inline=False, scan_stages=True, grad=True, adjoint_name="backsolve") |
| 95 | + |
| 96 | +run(inline=True, scan_stages=False, grad=False, adjoint_name="direct") |
| 97 | +run(inline=True, scan_stages=False, grad=False, adjoint_name="recursive") |
| 98 | +run(inline=True, scan_stages=False, grad=False, adjoint_name="backsolve") |
| 99 | + |
| 100 | +run(inline=True, scan_stages=False, grad=True, adjoint_name="direct") |
| 101 | +run(inline=True, scan_stages=False, grad=True, adjoint_name="recursive") |
| 102 | +run(inline=True, scan_stages=False, grad=True, adjoint_name="backsolve") |
| 103 | + |
| 104 | +run(inline=True, scan_stages=True, grad=False, adjoint_name="direct") |
| 105 | +run(inline=True, scan_stages=True, grad=False, adjoint_name="recursive") |
| 106 | +run(inline=True, scan_stages=True, grad=False, adjoint_name="backsolve") |
| 107 | + |
| 108 | +run(inline=True, scan_stages=True, grad=True, adjoint_name="direct") |
| 109 | +run(inline=True, scan_stages=True, grad=True, adjoint_name="recursive") |
| 110 | +run(inline=True, scan_stages=True, grad=True, adjoint_name="backsolve") |
0 commit comments