sef43 commented Nov 11, 2025

As discussed with @giadefa, this PR adds the changes AIMNet2 needs so that the model can be compiled with torch.compile with CUDA graphs enabled.

This significantly speeds up small-molecule MD. The new example script ase_md.py demonstrates the speedup: the original runs 10000 steps in 76.3 seconds, while the compiled version runs them in 15.1 seconds, roughly a 5x speedup.

Original:

Torch version: 2.8.0+cu128
CUDA available, version 12.8, device: NVIDIA GeForce RTX 4090
running without torch_compile+cudagraphs
energy: -79772.80223186754
Time[ps]      Etot[eV]     Epot[eV]     Ekin[eV]    T[K]
0.0000       -79772.802   -79772.802        0.000     0.0
1.0000       -79765.760   -79769.961        4.201   287.6
2.0000       -79766.232   -79770.634        4.402   301.4
3.0000       -79765.783   -79770.592        4.809   329.2
4.0000       -79766.131   -79770.321        4.191   286.9
5.0000       -79764.888   -79770.354        5.466   374.2
6.0000       -79765.719   -79770.835        5.116   350.3
7.0000       -79766.672   -79771.049        4.377   299.7
8.0000       -79766.566   -79770.707        4.141   283.5
9.0000       -79766.490   -79770.862        4.371   299.3
10.0000      -79767.403   -79771.511        4.109   281.3
Completed MD in 76.3 s (7.634 ms/step)

New with torch.compile(self.model, fullgraph=True, options={'triton.cudagraphs':True}):

Torch version: 2.8.0+cu128
CUDA available, version 12.8, device: NVIDIA GeForce RTX 4090
running with torch_compile+cudagraphs
energy: -79772.8125
Time[ps]      Etot[eV]     Epot[eV]     Ekin[eV]    T[K]
0.0000       -79772.812   -79772.812        0.000     0.0
1.0000       -79766.053   -79770.633        4.580   313.5
2.0000       -79766.829   -79770.547        3.718   254.5
3.0000       -79766.958   -79771.672        4.714   322.7
4.0000       -79766.121   -79770.664        4.543   311.0
5.0000       -79766.836   -79771.133        4.297   294.2
6.0000       -79765.915   -79770.078        4.164   285.1
7.0000       -79766.040   -79771.055        5.015   343.3
8.0000       -79766.454   -79770.727        4.273   292.5
9.0000       -79766.441   -79770.531        4.090   280.0
10.0000      -79766.390   -79771.016        4.625   316.7
Completed MD in 15.1 s (1.511 ms/step)

Compilation with CUDA graphs is currently only implemented for nb_mode=0 and a single molecule.

The key change is replacing data-dependent control flow with control flow that is constant at compile time. To do this, I have added a setup_for_compile_cudagraphs method to some modules.
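
To illustrate the idea, here is a minimal sketch on a toy module. It is not AIMNet2 code: ToyPairEnergy, its use_charge flag, the energy formula, and the compile_for_md helper are all hypothetical, and only the method name setup_for_compile_cudagraphs and the torch.compile arguments are taken from this PR. The sketch assumes that a branch which reads a tensor value on the host is what blocks fullgraph=True capture, and that freezing the decision as a plain Python bool lets the compiler treat it as a constant.

```python
import torch
from torch import nn


class ToyPairEnergy(nn.Module):
    """Illustrative toy module (hypothetical, not AIMNet2 code)."""

    def __init__(self, cutoff: float = 5.0) -> None:
        super().__init__()
        self.cutoff = cutoff
        # None means: decide from the input data on every call.
        self.use_charge = None

    def setup_for_compile_cudagraphs(self, use_charge: bool) -> None:
        # Freeze the branch decision as a plain Python bool, so that
        # tracing sees a constant instead of a tensor-dependent condition.
        self.use_charge = use_charge

    def forward(self, coord: torch.Tensor, charge: torch.Tensor) -> torch.Tensor:
        n = coord.shape[0]
        diff = coord[:, None, :] - coord[None, :, :]
        dist = diff.norm(dim=-1)
        self_pair = torch.eye(n, dtype=torch.bool, device=coord.device)
        keep = (dist < self.cutoff) & ~self_pair
        # torch.where keeps the (n, n) shape fixed; boolean indexing such as
        # dist[keep] would produce a data-dependent shape instead.
        pair_term = torch.where(keep, torch.exp(-dist), torch.zeros_like(dist))
        energy = 0.5 * pair_term.sum()

        if self.use_charge is None:
            # Data-dependent branch: converts a tensor value to a Python bool.
            use_charge = bool((charge != 0).any())
        else:
            # Compile-time constant branch.
            use_charge = self.use_charge
        if use_charge:
            energy = energy + 0.1 * charge.sum() ** 2
        return energy


def compile_for_md(model: nn.Module) -> nn.Module:
    # Same torch.compile arguments as quoted in this PR. Assumed to need a
    # CUDA device, so it is defined here but not called.
    return torch.compile(model, fullgraph=True, options={"triton.cudagraphs": True})


torch.manual_seed(0)
coord = torch.rand(8, 3) * 4.0
charge = torch.zeros(1)

model = ToyPairEnergy()
e_dynamic = model(coord, charge)  # branch decided from the data
model.setup_for_compile_cudagraphs(use_charge=False)
e_static = model(coord, charge)  # branch fixed ahead of time
```

In eager mode both paths return the same energy; the difference only matters once the module is passed through something like compile_for_md on a GPU, where the frozen flag avoids the host-side read of charge.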

The feature is supported by the ASE calculator interface.

giadefa commented Nov 11, 2025

great!
