Skip to content

Conversation

Pperezhogin
Copy link

This PR has two main goals:

  • Create a new ANN inference method in Real(4) precision ANN_apply_array_sio_r4 (1.5-2 times faster than ANN_apply_array_sio)
  • Optimize the code of MOM_Zanna_Bolton to take the most advantage of faster ANN inference

In the global ocean model OM4, the new algorithm takes less than 5% of runtime:

Tabulating mpp_clock statistics across   1125 PEs...

                                          tmin          tmax          tavg          tstd  tfrac grain pemin pemax
Total runtime                     12386.301825  12386.303514  12386.302729      0.000324  1.000     0     0  1124
...
(Ocean Zanna-Bolton-2020)           478.779163    739.875660    555.399389     35.328823  0.045    31     0  1124
...
(ZB2020 ANN inference)              248.012372    309.466150    300.482666      3.802576  0.024    41     0  1124
(ZB2020 ANN features)               102.508038    135.846642    126.231459      4.047266  0.010    41     0  1124
...
(ZB2020 MPI exchanges)               26.768119    289.738849     88.037939     41.579598  0.007    41     0  1124
...

Regression changed compared to the initial commit to dev/m2lines for a few reasons:

  • Different order of bias addition in MOM_ANN.F90 (introduced by Alistair)
  • Real(4) instead of Real(8)
  • Getting rid of the marching halo in compute_stress_ANN_collocated in favour of grouping all exchanges into one function call. This does not affect idealized simulations but affects the prediction on tripolar fold where rotational invariance of all expressions is required to have the same regression
  • Precomputing functions like 1/sqrt() in compute_stress_ANN_collocated

MOM_ANN.F90:
- Allocate, deallocate and initialize weights and biases in real(4)
- Add activation function in real(4)
- Add add inference in real(4)
- Add unit test

tiime_MOM_ANN.F90
- Add timing for new real(4) inference function
- Make inference simultaneously for each 2D slice in real(4) precision
- Get rid of three unnecessary 3D arrays, and preallocate one 3D array (Txy_h)
- Get rid of 6 unnecessary MPI exchanges and group the only three remaining
- Replace RESHAPE with a loop
- Fuse computation of norm and assembling of vector of input features into a single loop
- More accurate measure of timing
@Pperezhogin Pperezhogin requested a review from adcroft June 22, 2025 19:32
- Memory passed to subroutine and allocated on stack was different
@Pperezhogin Pperezhogin marked this pull request as draft June 25, 2025 14:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant