|
2 | 2 |
|
3 | 3 | Default is ``np.float32``, which matches the precision of MD trajectory |
4 | 4 | coordinates (mdtraj stores ``traj.xyz`` as float32) and is sufficient |
5 | | -for all analysis operations in this package. |
6 | | -
|
7 | | -Float64 appears in the analysis pipeline only where it is genuinely |
8 | | -necessary or where an external library forces it: |
9 | | -
|
10 | | -- **Numba JIT kernels** (``_backends/_distances.distances_numba``, |
11 | | - ``_backends/_rmsd_matrix._pairwise_rmsd``): compiled kernels output |
12 | | - float64 because Numba's ``float()`` cast maps to C ``double``. |
13 | | - Numba runs on CPU where float64 is at ~50% of float32 throughput, |
14 | | - so the cost is negligible and the extra precision is useful for the |
15 | | - QCP Newton-Raphson subtraction ``G_a + G_b - 2*lambda``. Callers |
16 | | - cast the result to the resolved user dtype afterward. |
| 5 | +for all analysis operations in this package. **Every** compute function |
| 6 | +returns float32 by default; users who want float64 must opt in either |
| 7 | +globally via :func:`set_default_dtype` or per-call via ``dtype=np.float64``. |
| 8 | +
|
| 9 | +Design rules for new compute code |
| 10 | +--------------------------------- |
| 11 | +
|
| 12 | +1. The public function's last keyword argument is |
| 13 | + ``dtype: DtypeArg = None``. |
| 14 | +2. Call ``resolved = resolve_dtype(dtype)`` at the top. |
| 15 | +3. Pass ``resolved`` through to every downstream buffer allocation and |
| 16 | + cast outputs via ``np.asarray(result, dtype=resolved)`` / |
| 17 | + ``result.astype(resolved, copy=False)`` so same-dtype returns do |
| 18 | + not duplicate memory. |
| 19 | +4. **Backend kernels** (numba/torch/jax/cupy) return their native dtype |
| 20 | + (``NDArray[np.floating]``) and should prefer float32 output unless |
| 21 | + external precision is required. The public wrapper's ``copy=False`` |
| 22 | + cast becomes a no-op when the kernel already returns the resolved |
| 23 | + dtype, which is essential at large N where each redundant N^2 copy |
| 24 | + can cost tens of GB. |
| 25 | +
|
| 26 | +Where float64 still appears (and why) |
| 27 | +------------------------------------- |
| 28 | +
|
| 29 | +These are the only places fp64 remains in the compute pipeline; each is |
| 30 | +either an O(1)-to-O(n) scalar buffer (not an OOM risk) or forced by an |
| 31 | +external library: |
| 32 | +
|
| 33 | +- **QCP Newton-Raphson scalars** in ``_backends/_rmsd_matrix._pairwise_rmsd`` |
| 34 | + and the ``traces`` buffer in ``_center_and_traces``: accumulators |
| 35 | + (``Sxx`` etc.) and the ``(G_a + G_b - 2*lambda)`` subtraction run in |
| 36 | + double precision because Numba's ``0.0`` literal maps to C |
| 37 | + ``double``. Only the final ``result[i, j] = val`` store truncates |
| 38 | + to float32 so the O(N^2) output matrix is half the memory of the |
| 39 | + old float64 output (58 GB saved at n=120k). The ``traces`` buffer |
| 40 | + is O(n_frames) so the fp64 cost is negligible. |
17 | 41 | - **GPU backends** (``_backends/_distances`` and |
18 | 42 | ``_backends/_rmsd_matrix`` ``torch``/``jax``/``cupy`` variants): |
19 | | - compute **internally in float32** because consumer and workstation |
20 | | - NVIDIA GPUs run float64 at 1/36 -- 1/64 the throughput of float32. |
21 | | - Since 2026-04-11 these backends also **return native float32** |
22 | | - (the ``RMSDMatrixBackendFn`` / ``DistanceBackendFn`` Protocols were |
23 | | - widened from ``NDArray[np.float64]`` to ``NDArray[np.floating]`` |
24 | | - so backends can report their natural dtype). The public |
25 | | - ``compute_*`` wrappers then cast with ``astype(resolved, copy=False)`` |
26 | | - so when the resolved dtype is also float32 (the package default) |
27 | | - **no additional copy is made** -- critical for large N where |
28 | | - every redundant copy of the ``(n_frames, n_frames)`` RMSD matrix |
29 | | - costs tens of GB (57 GB at n=120k). Float32 QCP agrees with the |
| 43 | + compute internally in float32 because consumer and workstation |
| 44 | + NVIDIA GPUs run float64 at 1/36 -- 1/64 the throughput of float32, |
| 45 | + and return native float32 directly. Float32 QCP agrees with the |
30 | 46 | float64 numba reference to ~1e-6 nm on realistic trajectories. |
31 | 47 | - **Deeptime TICA** (``decomposition.compute_tica``): deeptime upcasts |
32 | | - to float64 internally for covariance estimation -- no explicit cast |
33 | | - is needed from our side. |
| 48 | + to float64 internally for covariance estimation -- external to us. |
| 49 | + The output is cast back to the resolved dtype by the wrapper. |
34 | 50 | - **``np.histogram2d``** (``fes.compute_fes_2d``): returns float64 |
35 | | - probability density regardless of input dtype (edges follow the |
36 | | - input dtype); the downstream log and energy arithmetic therefore |
37 | | - runs in float64 naturally. |
| 51 | + probability density regardless of input dtype; the downstream log |
| 52 | + and energy arithmetic therefore runs in float64 naturally. Output |
| 53 | + is O(bins^2), tiny. |
38 | 54 | - **``np.mean`` on boolean arrays** (contacts, h-bonds): NumPy defaults |
39 | | - to float64 for boolean reductions. |
| 55 | + to float64 for boolean reductions. Output is O(n), tiny. |
40 | 56 | - **``jax.config.update("jax_enable_x64", True)``** in |
41 | 57 | ``_backends/_imports.require_jax``: enables float64 support in JAX |
42 | | - so ``jnp.float64`` arrays can round-trip through the JIT. The |
43 | | - actual JAX compute still runs in float32 on GPU. |
| 58 | + so ``jnp.float64`` arrays can round-trip through the JIT when the |
| 59 | + user explicitly opts in. The actual JAX compute still runs in |
| 60 | + float32 on GPU by default. |
| 61 | +
|
| 62 | +Opting into float64 |
| 63 | +------------------- |
44 | 64 |
|
45 | 65 | Use ``set_default_dtype(np.float64)`` to switch globally, or pass |
46 | | -``dtype=np.float64`` to individual functions. |
| 66 | +``dtype=np.float64`` to individual functions. Be aware that float64 |
| 67 | +doubles the memory of every O(N^2) or O(N*M) intermediate, which will |
| 68 | +OOM at trajectory sizes above ~40k frames on a 128 GB host. |
47 | 69 | """ |
48 | 70 |
|
49 | 71 | from __future__ import annotations |
|
0 commit comments