From 6e4729697d5e0b3343b8f19fd244aedfeadd4152 Mon Sep 17 00:00:00 2001
From: Abishek Gopal
Date: Tue, 13 May 2025 11:23:11 -0600
Subject: [PATCH 01/30] Consolidating OpenACC device-host memory transfers

This PR consolidates much of the OpenACC host and device data transfers made
during dynamical execution into two subroutines, mpas_atm_pre_dynamics_h2d
and mpas_atm_post_dynamics_d2h, which are called before and after the call to
the atm_srk3 subroutine. Because atm_compute_solve_diagnostics is also called
once before the start of the model run, we also have a pair of subroutines,
mpas_atm_pre_computesolvediag_h2d and mpas_atm_post_computesolvediag_d2h, to
handle data movements around the first call to atm_compute_solve_diagnostics.
Any fields copied onto the device in these subroutines are removed from
explicit data movement statements in the dynamical core.

The mesh/time-invariant fields are still copied onto the device in
mpas_atm_dynamics_init and removed from the device in
mpas_atm_dynamics_finalize, with the exception of select fields moved in
mpas_atm_pre_computesolvediag_h2d and mpas_atm_post_computesolvediag_d2h.
This special case arises because atm_compute_solve_diagnostics is called for
the first time before the call to mpas_atm_dynamics_init.

This PR also adds explicit host-device data transfers in the mpas_atm_iau,
mpas_atmphys_interface, and mpas_atmphys_todynamics modules to ensure that
the physics and IAU regions, which run on the CPU, use the latest values from
the dynamical core running on GPUs, and vice versa. In addition, this PR
includes explicit data transfers around halo exchanges in the atm_srk3
subroutine (a minimal sketch of this pattern appears at the end of this patch
excerpt). These data-transfer subroutines and the acc update statements are
an interim solution until we have a book-keeping method in place.

This PR also introduces a couple of new timers to keep track of the cost of
data transfers.
---
 .../dynamics/mpas_atm_boundaries.F            |   32 -
 src/core_atmosphere/dynamics/mpas_atm_iau.F   |   14 +
 .../dynamics/mpas_atm_time_integration.F      | 1359 ++++++++++++----
 src/core_atmosphere/mpas_atm_core.F           |    5 +-
 .../physics/mpas_atmphys_interface.F          |   20 +
 .../physics/mpas_atmphys_todynamics.F         |   17 +
 6 files changed, 1071 insertions(+), 376 deletions(-)

diff --git a/src/core_atmosphere/dynamics/mpas_atm_boundaries.F b/src/core_atmosphere/dynamics/mpas_atm_boundaries.F
index 787e7719a1..6c19ed7931 100644
--- a/src/core_atmosphere/dynamics/mpas_atm_boundaries.F
+++ b/src/core_atmosphere/dynamics/mpas_atm_boundaries.F
@@ -395,18 +395,14 @@ subroutine mpas_atm_get_bdy_tend(clock, block, vertDim, horizDim, field, delta_t
       nullify(tend)
       call mpas_pool_get_array(lbc, 'lbc_'//trim(field), tend, 1)
 
-      MPAS_ACC_TIMER_START('mpas_atm_get_bdy_tend [ACC_data_xfer]')
       if (associated(tend)) then
-         !$acc enter data copyin(tend)
       else
         call mpas_pool_get_array(lbc, 'lbc_scalars', tend_scalars, 1)
-         !$acc enter data copyin(tend_scalars)
 
         ! 
Ensure the integer pointed to by idx_ptr is copied to the gpu device call mpas_pool_get_dimension(lbc, 'index_'//trim(field), idx_ptr) idx = idx_ptr end if - MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_tend [ACC_data_xfer]') !$acc parallel default(present) if (associated(tend)) then @@ -426,13 +422,6 @@ subroutine mpas_atm_get_bdy_tend(clock, block, vertDim, horizDim, field, delta_t end if !$acc end parallel - MPAS_ACC_TIMER_START('mpas_atm_get_bdy_tend [ACC_data_xfer]') - if (associated(tend)) then - !$acc exit data delete(tend) - else - !$acc exit data delete(tend_scalars) - end if - MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_tend [ACC_data_xfer]') end subroutine mpas_atm_get_bdy_tend @@ -533,9 +522,6 @@ subroutine mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, del ! query the field as a scalar constituent ! if (associated(tend) .and. associated(state)) then - MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_2d [ACC_data_xfer]') - !$acc enter data copyin(tend, state) - MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_2d [ACC_data_xfer]') !$acc parallel default(present) !$acc loop gang vector collapse(2) @@ -546,9 +532,6 @@ subroutine mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, del end do !$acc end parallel - MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_2d [ACC_data_xfer]') - !$acc exit data delete(tend, state) - MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_2d [ACC_data_xfer]') else call mpas_pool_get_array(lbc, 'lbc_scalars', tend_scalars, 1) call mpas_pool_get_array(lbc, 'lbc_scalars', state_scalars, 2) @@ -556,10 +539,6 @@ subroutine mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, del idx=idx_ptr ! Avoid non-array pointer for OpenACC - MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_2d [ACC_data_xfer]') - !$acc enter data copyin(tend_scalars, state_scalars) - MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_2d [ACC_data_xfer]') - !$acc parallel default(present) !$acc loop gang vector collapse(2) do i=1, horizDim+1 @@ -569,9 +548,6 @@ subroutine mpas_atm_get_bdy_state_2d(clock, block, vertDim, horizDim, field, del end do !$acc end parallel - MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_2d [ACC_data_xfer]') - !$acc exit data delete(tend_scalars, state_scalars) - MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_2d [ACC_data_xfer]') end if end subroutine mpas_atm_get_bdy_state_2d @@ -652,10 +628,6 @@ subroutine mpas_atm_get_bdy_state_3d(clock, block, innerDim, vertDim, horizDim, call mpas_pool_get_array(lbc, 'lbc_'//trim(field), tend, 1) call mpas_pool_get_array(lbc, 'lbc_'//trim(field), state, 2) - MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_3d [ACC_data_xfer]') - !$acc enter data copyin(tend, state) - MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_3d [ACC_data_xfer]') - !$acc parallel default(present) !$acc loop gang vector collapse(3) do i=1, horizDim+1 @@ -667,10 +639,6 @@ subroutine mpas_atm_get_bdy_state_3d(clock, block, innerDim, vertDim, horizDim, end do !$acc end parallel - MPAS_ACC_TIMER_START('mpas_atm_get_bdy_state_3d [ACC_data_xfer]') - !$acc exit data delete(tend, state) - MPAS_ACC_TIMER_STOP('mpas_atm_get_bdy_state_3d [ACC_data_xfer]') - end subroutine mpas_atm_get_bdy_state_3d diff --git a/src/core_atmosphere/dynamics/mpas_atm_iau.F b/src/core_atmosphere/dynamics/mpas_atm_iau.F index 654fd3ae82..b380e3c0e8 100644 --- a/src/core_atmosphere/dynamics/mpas_atm_iau.F +++ b/src/core_atmosphere/dynamics/mpas_atm_iau.F @@ -13,9 +13,20 @@ module mpas_atm_iau use mpas_dmpar use mpas_constants use mpas_log, only : mpas_log_write + use mpas_timer !public :: 
atm_compute_iau_coef, atm_add_tend_anal_incr + + #ifdef MPAS_OPENACC + #define MPAS_ACC_TIMER_START(X) call mpas_timer_start(X) + #define MPAS_ACC_TIMER_STOP(X) call mpas_timer_stop(X) + #else + #define MPAS_ACC_TIMER_START(X) + #define MPAS_ACC_TIMER_STOP(X) + #endif + + contains !================================================================================================== @@ -137,6 +148,7 @@ subroutine atm_add_tend_anal_incr (configs, structs, itimestep, dt, tend_ru, ten call mpas_pool_get_array(state, 'scalars', scalars, 1) call mpas_pool_get_array(state, 'rho_zz', rho_zz, 2) call mpas_pool_get_array(diag , 'rho_edge', rho_edge) + !$acc update self(theta_m, scalars, rho_zz, rho_edge) call mpas_pool_get_dimension(state, 'moist_start', moist_start) call mpas_pool_get_dimension(state, 'moist_end', moist_end) @@ -149,6 +161,8 @@ subroutine atm_add_tend_anal_incr (configs, structs, itimestep, dt, tend_ru, ten ! call mpas_pool_get_array(tend, 'rho_zz', tend_rho) ! call mpas_pool_get_array(tend, 'theta_m', tend_theta) call mpas_pool_get_array(tend, 'scalars_tend', tend_scalars) + !$acc update self(tend_scalars) + MPAS_ACC_TIMER_STOP('atm_srk3: physics ACC_data_xfer') call mpas_pool_get_array(tend_iau, 'theta', theta_amb) call mpas_pool_get_array(tend_iau, 'rho', rho_amb) diff --git a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F index 4fe2faefc4..de3565637b 100644 --- a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F +++ b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F @@ -272,6 +272,8 @@ subroutine mpas_atm_dynamics_init(domain) real (kind=RKIND), dimension(:), pointer :: angleEdge real (kind=RKIND), dimension(:), pointer :: meshScalingDel2 real (kind=RKIND), dimension(:), pointer :: meshScalingDel4 + real (kind=RKIND), dimension(:), pointer :: u_init, v_init, qv_init + real (kind=RKIND), dimension(:,:), pointer :: t_init #endif #ifdef MPAS_CAM_DYCORE @@ -292,6 +294,7 @@ subroutine mpas_atm_dynamics_init(domain) nullify(mesh) call mpas_pool_get_subpool(domain % blocklist % structs, 'mesh', mesh) + MPAS_ACC_TIMER_START('mpas_dynamics_init [ACC_data_xfer]') call mpas_pool_get_array(mesh, 'dvEdge', dvEdge) !$acc enter data copyin(dvEdge) @@ -456,9 +459,904 @@ subroutine mpas_atm_dynamics_init(domain) call mpas_pool_get_array(mesh, 'meshScalingDel4', meshScalingDel4) !$acc enter data copyin(meshScalingDel4) + + call mpas_pool_get_array(mesh, 'u_init', u_init) + !$acc enter data copyin(u_init) + call mpas_pool_get_array(mesh, 'v_init', v_init) + !$acc enter data copyin(v_init) + call mpas_pool_get_array(mesh, 't_init', t_init) + !$acc enter data copyin(t_init) + call mpas_pool_get_array(mesh, 'qv_init', qv_init) + !$acc enter data copyin(qv_init) + + MPAS_ACC_TIMER_STOP('mpas_dynamics_init [ACC_data_xfer]') +#endif + + end subroutine mpas_atm_dynamics_init + + subroutine mpas_atm_pre_computesolvediag_h2d(block) + + implicit none + + type (block_type), intent(inout) :: block + + +#ifdef MPAS_OPENACC + type (mpas_pool_type), pointer :: mesh + type (mpas_pool_type), pointer :: diag + type (mpas_pool_type), pointer :: state + type (mpas_pool_type), pointer :: tend_physics + real (kind=RKIND), dimension(:,:), pointer :: rthdynten + + real (kind=RKIND), dimension(:,:), pointer :: h_edge, v, vorticity, ke, pv_edge, & + pv_vertex, pv_cell, gradPVn, gradPVt, divergence + real (kind=RKIND), dimension(:,:), pointer :: u, h + + real (kind=RKIND), dimension(:,:), pointer :: zz + real (kind=RKIND), dimension(:,:,:), 
pointer :: zb_cell + real (kind=RKIND), dimension(:,:,:), pointer :: zb3_cell + real (kind=RKIND), dimension(:), pointer :: fzm + real (kind=RKIND), dimension(:), pointer :: fzp + real (kind=RKIND), dimension(:,:,:), pointer :: zb + real (kind=RKIND), dimension(:,:,:), pointer :: zb3 + + + real (kind=RKIND), dimension(:), pointer :: dvEdge + integer, dimension(:,:), pointer :: cellsOnCell + integer, dimension(:,:), pointer :: cellsOnEdge + integer, dimension(:,:), pointer :: advCellsForEdge + integer, dimension(:,:), pointer :: edgesOnCell + integer, dimension(:), pointer :: nAdvCellsForEdge + integer, dimension(:), pointer :: nEdgesOnCell + real (kind=RKIND), dimension(:,:), pointer :: adv_coefs + real (kind=RKIND), dimension(:,:), pointer :: adv_coefs_3rd + real (kind=RKIND), dimension(:,:), pointer :: edgesOnCell_sign + real (kind=RKIND), dimension(:), pointer :: invAreaCell + integer, dimension(:), pointer :: bdyMaskCell + integer, dimension(:), pointer :: bdyMaskEdge + real (kind=RKIND), dimension(:), pointer :: specZoneMaskEdge + real (kind=RKIND), dimension(:), pointer :: invDvEdge + real (kind=RKIND), dimension(:), pointer :: dcEdge + real (kind=RKIND), dimension(:), pointer :: invDcEdge + integer, dimension(:,:), pointer :: edgesOnEdge + integer, dimension(:,:), pointer :: edgesOnVertex + real (kind=RKIND), dimension(:,:), pointer :: edgesOnVertex_sign + integer, dimension(:), pointer :: nEdgesOnEdge + real (kind=RKIND), dimension(:,:), pointer :: weightsOnEdge + integer, dimension(:,:), pointer :: cellsOnVertex + integer, dimension(:,:), pointer :: verticesOnCell + integer, dimension(:,:), pointer :: verticesOnEdge + real (kind=RKIND), dimension(:), pointer :: invAreaTriangle + integer, dimension(:,:), pointer :: kiteForCell + real (kind=RKIND), dimension(:,:), pointer :: kiteAreasOnVertex + real (kind=RKIND), dimension(:), pointer :: fEdge + real (kind=RKIND), dimension(:), pointer :: fVertex + + nullify(mesh) + call mpas_pool_get_subpool(block % structs, 'mesh', mesh) + nullify(state) + call mpas_pool_get_subpool(block % structs, 'state', state) + nullify(diag) + call mpas_pool_get_subpool(block % structs, 'diag', diag) + + MPAS_ACC_TIMER_START('first_compute_solve_diagnostics [ACC_data_xfer]') + call mpas_pool_get_array(state, 'rho_zz', h, 1) + !$acc enter data create(h) + call mpas_pool_get_array(state, 'u', u, 1) + !$acc enter data copyin(u) + + call mpas_pool_get_array(diag, 'v', v) + !$acc enter data copyin(v) + call mpas_pool_get_array(diag, 'rho_edge', h_edge) + !$acc enter data copyin(h_edge) + call mpas_pool_get_array(diag, 'vorticity', vorticity) + !$acc enter data copyin(vorticity) + call mpas_pool_get_array(diag, 'divergence', divergence) + !$acc enter data copyin(divergence) + call mpas_pool_get_array(diag, 'ke', ke) + !$acc enter data copyin(ke) + call mpas_pool_get_array(diag, 'pv_edge', pv_edge) + !$acc enter data copyin(pv_edge) + call mpas_pool_get_array(diag, 'pv_vertex', pv_vertex) + !$acc enter data copyin(pv_vertex) + call mpas_pool_get_array(diag, 'pv_cell', pv_cell) + !$acc enter data copyin(pv_cell) + call mpas_pool_get_array(diag, 'gradPVn', gradPVn) + !$acc enter data copyin(gradPVn) + call mpas_pool_get_array(diag, 'gradPVt', gradPVt) + !$acc enter data copyin(gradPVt) + + ! 
Required by atm_init_coupled_diagnostics + call mpas_pool_get_array(mesh, 'zz', zz) + !$acc enter data copyin(zz) + + call mpas_pool_get_array(mesh, 'zb_cell', zb_cell) + !$acc enter data copyin(zb_cell) + + call mpas_pool_get_array(mesh, 'zb3_cell', zb3_cell) + !$acc enter data copyin(zb3_cell) + + call mpas_pool_get_array(mesh, 'fzm', fzm) + !$acc enter data copyin(fzm) + + call mpas_pool_get_array(mesh, 'fzp', fzp) + !$acc enter data copyin(fzp) + + call mpas_pool_get_array(mesh, 'zb', zb) + !$acc enter data copyin(zb) + + call mpas_pool_get_array(mesh, 'zb3', zb3) + !$acc enter data copyin(zb3) + + ! Required by atm_compute_solve_diagnostics + call mpas_pool_get_array(mesh, 'dvEdge', dvEdge) + !$acc enter data copyin(dvEdge) + + call mpas_pool_get_array(mesh, 'cellsOnEdge', cellsOnEdge) + !$acc enter data copyin(cellsOnEdge) + + call mpas_pool_get_array(mesh, 'edgesOnCell', edgesOnCell) + !$acc enter data copyin(edgesOnCell) + + call mpas_pool_get_array(mesh, 'nEdgesOnCell', nEdgesOnCell) + !$acc enter data copyin(nEdgesOnCell) + + call mpas_pool_get_array(mesh, 'edgesOnCell_sign', edgesOnCell_sign) + !$acc enter data copyin(edgesOnCell_sign) + + call mpas_pool_get_array(mesh, 'invAreaCell', invAreaCell) + !$acc enter data copyin(invAreaCell) + + call mpas_pool_get_array(mesh, 'invDvEdge', invDvEdge) + !$acc enter data copyin(invDvEdge) + + call mpas_pool_get_array(mesh, 'dcEdge', dcEdge) + !$acc enter data copyin(dcEdge) + + call mpas_pool_get_array(mesh, 'invDcEdge', invDcEdge) + !$acc enter data copyin(invDcEdge) + + call mpas_pool_get_array(mesh, 'edgesOnEdge', edgesOnEdge) + !$acc enter data copyin(edgesOnEdge) + + call mpas_pool_get_array(mesh, 'edgesOnVertex', edgesOnVertex) + !$acc enter data copyin(edgesOnVertex) + + call mpas_pool_get_array(mesh, 'edgesOnVertex_sign', edgesOnVertex_sign) + !$acc enter data copyin(edgesOnVertex_sign) + + call mpas_pool_get_array(mesh, 'nEdgesOnEdge', nEdgesOnEdge) + !$acc enter data copyin(nEdgesOnEdge) + + call mpas_pool_get_array(mesh, 'weightsOnEdge', weightsOnEdge) + !$acc enter data copyin(weightsOnEdge) + + call mpas_pool_get_array(mesh, 'verticesOnCell', verticesOnCell) + !$acc enter data copyin(verticesOnCell) + + call mpas_pool_get_array(mesh, 'verticesOnEdge', verticesOnEdge) + !$acc enter data copyin(verticesOnEdge) + + call mpas_pool_get_array(mesh, 'invAreaTriangle', invAreaTriangle) + !$acc enter data copyin(invAreaTriangle) + + call mpas_pool_get_array(mesh, 'kiteForCell', kiteForCell) + !$acc enter data copyin(kiteForCell) + + call mpas_pool_get_array(mesh, 'kiteAreasOnVertex', kiteAreasOnVertex) + !$acc enter data copyin(kiteAreasOnVertex) + + call mpas_pool_get_array(mesh, 'fVertex', fVertex) + !$acc enter data copyin(fVertex) + + MPAS_ACC_TIMER_STOP('first_compute_solve_diagnostics [ACC_data_xfer]') +#endif + + end subroutine mpas_atm_pre_computesolvediag_h2d + + + subroutine mpas_atm_post_computesolvediag_d2h(block) + + implicit none + + type (block_type), intent(inout) :: block + + +#ifdef MPAS_OPENACC + type (mpas_pool_type), pointer :: mesh + type (mpas_pool_type), pointer :: diag + type (mpas_pool_type), pointer :: state + type (mpas_pool_type), pointer :: tend_physics + real (kind=RKIND), dimension(:,:), pointer :: rthdynten + + real (kind=RKIND), dimension(:,:), pointer :: h_edge, v, vorticity, ke, pv_edge, & + pv_vertex, pv_cell, gradPVn, gradPVt, divergence + real (kind=RKIND), dimension(:,:), pointer :: u, h + + real (kind=RKIND), dimension(:,:), pointer :: zz + real (kind=RKIND), dimension(:,:,:), pointer :: 
zb_cell + real (kind=RKIND), dimension(:,:,:), pointer :: zb3_cell + real (kind=RKIND), dimension(:), pointer :: fzm + real (kind=RKIND), dimension(:), pointer :: fzp + real (kind=RKIND), dimension(:,:,:), pointer :: zb + real (kind=RKIND), dimension(:,:,:), pointer :: zb3 + + + real (kind=RKIND), dimension(:), pointer :: dvEdge + integer, dimension(:,:), pointer :: cellsOnCell + integer, dimension(:,:), pointer :: cellsOnEdge + integer, dimension(:,:), pointer :: advCellsForEdge + integer, dimension(:,:), pointer :: edgesOnCell + integer, dimension(:), pointer :: nAdvCellsForEdge + integer, dimension(:), pointer :: nEdgesOnCell + real (kind=RKIND), dimension(:,:), pointer :: adv_coefs + real (kind=RKIND), dimension(:,:), pointer :: adv_coefs_3rd + real (kind=RKIND), dimension(:,:), pointer :: edgesOnCell_sign + real (kind=RKIND), dimension(:), pointer :: invAreaCell + integer, dimension(:), pointer :: bdyMaskCell + integer, dimension(:), pointer :: bdyMaskEdge + real (kind=RKIND), dimension(:), pointer :: specZoneMaskEdge + real (kind=RKIND), dimension(:), pointer :: invDvEdge + real (kind=RKIND), dimension(:), pointer :: dcEdge + real (kind=RKIND), dimension(:), pointer :: invDcEdge + integer, dimension(:,:), pointer :: edgesOnEdge + integer, dimension(:,:), pointer :: edgesOnVertex + real (kind=RKIND), dimension(:,:), pointer :: edgesOnVertex_sign + integer, dimension(:), pointer :: nEdgesOnEdge + real (kind=RKIND), dimension(:,:), pointer :: weightsOnEdge + integer, dimension(:,:), pointer :: cellsOnVertex + integer, dimension(:,:), pointer :: verticesOnCell + integer, dimension(:,:), pointer :: verticesOnEdge + real (kind=RKIND), dimension(:), pointer :: invAreaTriangle + integer, dimension(:,:), pointer :: kiteForCell + real (kind=RKIND), dimension(:,:), pointer :: kiteAreasOnVertex + real (kind=RKIND), dimension(:), pointer :: fEdge + real (kind=RKIND), dimension(:), pointer :: fVertex + + nullify(mesh) + call mpas_pool_get_subpool(block % structs, 'mesh', mesh) + nullify(state) + call mpas_pool_get_subpool(block % structs, 'state', state) + nullify(diag) + call mpas_pool_get_subpool(block % structs, 'diag', diag) + + MPAS_ACC_TIMER_START('first_compute_solve_diagnostics [ACC_data_xfer]') + + call mpas_pool_get_array(state, 'rho_zz', h, 1) + !$acc exit data copyout(h) + call mpas_pool_get_array(state, 'u', u, 1) + !$acc exit data copyout(u) + + call mpas_pool_get_array(diag, 'v', v) + !$acc exit data copyout(v) + call mpas_pool_get_array(diag, 'rho_edge', h_edge) + !$acc exit data copyout(h_edge) + call mpas_pool_get_array(diag, 'vorticity', vorticity) + !$acc exit data copyout(vorticity) + call mpas_pool_get_array(diag, 'divergence', divergence) + !$acc exit data copyout(divergence) + call mpas_pool_get_array(diag, 'ke', ke) + !$acc exit data copyout(ke) + call mpas_pool_get_array(diag, 'pv_edge', pv_edge) + !$acc exit data copyout(pv_edge) + call mpas_pool_get_array(diag, 'pv_vertex', pv_vertex) + !$acc exit data copyout(pv_vertex) + call mpas_pool_get_array(diag, 'pv_cell', pv_cell) + !$acc exit data copyout(pv_cell) + call mpas_pool_get_array(diag, 'gradPVn', gradPVn) + !$acc exit data copyout(gradPVn) + call mpas_pool_get_array(diag, 'gradPVt', gradPVt) + !$acc exit data copyout(gradPVt) + + ! 
Required by atm_init_coupled_diagnostics + call mpas_pool_get_array(mesh, 'zz', zz) + !$acc exit data delete(zz) + + call mpas_pool_get_array(mesh, 'zb_cell', zb_cell) + !$acc exit data delete(zb_cell) + + call mpas_pool_get_array(mesh, 'zb3_cell', zb3_cell) + !$acc exit data delete(zb3_cell) + + call mpas_pool_get_array(mesh, 'fzm', fzm) + !$acc exit data delete(fzm) + + call mpas_pool_get_array(mesh, 'fzp', fzp) + !$acc exit data delete(fzp) + + call mpas_pool_get_array(mesh, 'zb', zb) + !$acc exit data delete(zb) + + call mpas_pool_get_array(mesh, 'zb3', zb3) + !$acc exit data delete(zb3) + + + call mpas_pool_get_array(mesh, 'dvEdge', dvEdge) + !$acc exit data delete(dvEdge) + + call mpas_pool_get_array(mesh, 'cellsOnEdge', cellsOnEdge) + !$acc exit data delete(cellsOnEdge) + + call mpas_pool_get_array(mesh, 'edgesOnCell', edgesOnCell) + !$acc exit data delete(edgesOnCell) + + call mpas_pool_get_array(mesh, 'nEdgesOnCell', nEdgesOnCell) + !$acc exit data delete(nEdgesOnCell) + + call mpas_pool_get_array(mesh, 'edgesOnCell_sign', edgesOnCell_sign) + !$acc exit data delete(edgesOnCell_sign) + + call mpas_pool_get_array(mesh, 'invAreaCell', invAreaCell) + !$acc exit data delete(invAreaCell) + + call mpas_pool_get_array(mesh, 'invDvEdge', invDvEdge) + !$acc exit data delete(invDvEdge) + + call mpas_pool_get_array(mesh, 'dcEdge', dcEdge) + !$acc exit data delete(dcEdge) + + call mpas_pool_get_array(mesh, 'invDcEdge', invDcEdge) + !$acc exit data delete(invDcEdge) + + call mpas_pool_get_array(mesh, 'edgesOnEdge', edgesOnEdge) + !$acc exit data delete(edgesOnEdge) + + call mpas_pool_get_array(mesh, 'edgesOnVertex', edgesOnVertex) + !$acc exit data delete(edgesOnVertex) + + call mpas_pool_get_array(mesh, 'edgesOnVertex_sign', edgesOnVertex_sign) + !$acc exit data delete(edgesOnVertex_sign) + + call mpas_pool_get_array(mesh, 'nEdgesOnEdge', nEdgesOnEdge) + !$acc exit data delete(nEdgesOnEdge) + + call mpas_pool_get_array(mesh, 'weightsOnEdge', weightsOnEdge) + !$acc exit data delete(weightsOnEdge) + + call mpas_pool_get_array(mesh, 'verticesOnCell', verticesOnCell) + !$acc exit data delete(verticesOnCell) + + call mpas_pool_get_array(mesh, 'verticesOnEdge', verticesOnEdge) + !$acc exit data delete(verticesOnEdge) + + call mpas_pool_get_array(mesh, 'invAreaTriangle', invAreaTriangle) + !$acc exit data delete(invAreaTriangle) + + call mpas_pool_get_array(mesh, 'kiteForCell', kiteForCell) + !$acc exit data delete(kiteForCell) + + call mpas_pool_get_array(mesh, 'kiteAreasOnVertex', kiteAreasOnVertex) + !$acc exit data delete(kiteAreasOnVertex) + + call mpas_pool_get_array(mesh, 'fVertex', fVertex) + !$acc exit data delete(fVertex) + + MPAS_ACC_TIMER_STOP('first_compute_solve_diagnostics [ACC_data_xfer]') +#endif + + end subroutine mpas_atm_post_computesolvediag_d2h + + subroutine mpas_atm_pre_dynamics_h2d(domain) + + implicit none + + type (domain_type), intent(inout) :: domain + + +#ifdef MPAS_OPENACC + type (mpas_pool_type), pointer :: state + type (mpas_pool_type), pointer :: diag + type (mpas_pool_type), pointer :: tend + type (mpas_pool_type), pointer :: tend_physics + type (mpas_pool_type), pointer :: lbc + + + real (kind=RKIND), dimension(:,:), pointer :: ru, ru_p + real (kind=RKIND), dimension(:,:), pointer :: ru_save + real (kind=RKIND), dimension(:,:), pointer :: rw, rw_p + real (kind=RKIND), dimension(:,:), pointer :: rw_save + real (kind=RKIND), dimension(:,:), pointer :: rtheta_p + real (kind=RKIND), dimension(:,:), pointer :: exner, exner_base + real (kind=RKIND), dimension(:,:), 
pointer :: rtheta_base, rho_base + real (kind=RKIND), dimension(:,:), pointer :: rtheta_p_save + real (kind=RKIND), dimension(:,:), pointer :: rho_p, rho_pp, rho, theta, theta_base + real (kind=RKIND), dimension(:,:), pointer :: rho_p_save + real (kind=RKIND), dimension(:,:), pointer :: rho_zz_old_split + real (kind=RKIND), dimension(:,:), pointer :: cqw, rtheta_pp_old, rtheta_pp + real (kind=RKIND), dimension(:,:), pointer :: cqu, pressure_base, pressure_p, pressure, v + real (kind=RKIND), dimension(:,:), pointer :: kdiff, pv_edge, pv_vertex, pv_cell, rho_edge, h_divergence, ke + real (kind=RKIND), dimension(:,:), pointer :: cofwr, cofwz, coftz, cofwt, a_tri, alpha_tri, gamma_tri + real (kind=RKIND), dimension(:), pointer :: cofrz + real (kind=RKIND), dimension(:,:), pointer :: gradPVn, gradPVt + + + real (kind=RKIND), dimension(:,:), pointer :: u_1, u_2 + real (kind=RKIND), dimension(:,:), pointer :: w_1, w_2 + real (kind=RKIND), dimension(:,:), pointer :: theta_m_1, theta_m_2 + real (kind=RKIND), dimension(:,:), pointer :: rho_zz_1, rho_zz_2 + real (kind=RKIND), dimension(:,:,:), pointer :: scalars_1, scalars_2 + real (kind=RKIND), dimension(:,:), pointer :: ruAvg, wwAvg, ruAvg_split, wwAvg_split + + real (kind=RKIND), dimension(:,:), pointer :: tend_ru, tend_rt, tend_rho, tend_rw, rt_diabatic_tend + real (kind=RKIND), dimension(:,:), pointer :: tend_u_euler, tend_w_euler, tend_theta_euler + real(kind=RKIND), dimension(:,:), pointer :: tend_w_pgf, tend_w_buoy + real(kind=RKIND), dimension(:,:,:), pointer :: scalar_tend_save + + real (kind=RKIND), dimension(:,:), pointer :: rthdynten, divergence, vorticity + + real (kind=RKIND), dimension(:,:), pointer :: lbc_u, lbc_w, lbc_ru, lbc_rho_edge, lbc_rho, lbc_rtheta_m, lbc_rho_zz, lbc_theta + real (kind=RKIND), dimension(:,:), pointer :: lbc_tend_u, lbc_tend_w, lbc_tend_ru, lbc_tend_rho_edge, lbc_tend_rho + real (kind=RKIND), dimension(:,:), pointer :: lbc_tend_rtheta_m, lbc_tend_rho_zz, lbc_tend_theta + + real (kind=RKIND), dimension(:,:,:), pointer :: lbc_scalars, lbc_tend_scalars + + nullify(state) + nullify(diag) + nullify(tend) + nullify(tend_physics) + nullify(lbc) + call mpas_pool_get_subpool(domain % blocklist % structs, 'state', state) + call mpas_pool_get_subpool(domain % blocklist % structs, 'diag', diag) + call mpas_pool_get_subpool(domain % blocklist % structs, 'tend', tend) + call mpas_pool_get_subpool(domain % blocklist % structs, 'tend_physics', tend_physics) + call mpas_pool_get_subpool(domain % blocklist % structs, 'lbc', lbc) + + MPAS_ACC_TIMER_START('atm_srk3 [ACC_data_xfer]') + call mpas_pool_get_array(diag, 'ru', ru) + !$acc enter data copyin(ru) ! use values from atm_init_coupled_diagnostics + call mpas_pool_get_array(diag, 'ru_p', ru_p) + !$acc enter data copyin(ru_p) + call mpas_pool_get_array(diag, 'ru_save', ru_save) + !$acc enter data copyin(ru_save) + call mpas_pool_get_array(diag, 'rw', rw) + !$acc enter data copyin(rw) ! use values from atm_init_coupled_diagnostics + call mpas_pool_get_array(diag, 'rw_p', rw_p) + !$acc enter data copyin(rw_p) + call mpas_pool_get_array(diag, 'rw_save', rw_save) + !$acc enter data copyin(rw_save) + call mpas_pool_get_array(diag, 'rtheta_p', rtheta_p) + !$acc enter data copyin(rtheta_p) ! use values from atm_init_coupled_diagnostics + call mpas_pool_get_array(diag, 'rtheta_p_save', rtheta_p_save) + !$acc enter data copyin(rtheta_p_save) + call mpas_pool_get_array(diag, 'exner', exner) + !$acc enter data copyin(exner) ! 
use values from atm_init_coupled_diagnostics + call mpas_pool_get_array(diag, 'exner_base', exner_base) + !$acc enter data copyin(exner_base) ! use values from atm_init_coupled_diagnostics + call mpas_pool_get_array(diag, 'rtheta_base', rtheta_base) + !$acc enter data copyin(rtheta_base) ! use values from atm_init_coupled_diagnostics + call mpas_pool_get_array(diag, 'rho_base', rho_base) + !$acc enter data copyin(rho_base) + call mpas_pool_get_array(diag, 'rho', rho) + !$acc enter data copyin(rho) + call mpas_pool_get_array(diag, 'theta', theta) + !$acc enter data copyin(theta) + call mpas_pool_get_array(diag, 'theta_base', theta_base) + !$acc enter data copyin(theta_base) + call mpas_pool_get_array(diag, 'rho_p', rho_p) + !$acc enter data copyin(rho_p) ! use values from atm_init_coupled_diagnostics + call mpas_pool_get_array(diag, 'rho_p_save', rho_p_save) + !$acc enter data copyin(rho_p_save) + call mpas_pool_get_array(diag, 'rho_pp', rho_pp) + !$acc enter data copyin(rho_pp) + call mpas_pool_get_array(diag, 'rho_zz_old_split', rho_zz_old_split) + !$acc enter data copyin(rho_zz_old_split) + call mpas_pool_get_array(diag, 'cqw', cqw) + !$acc enter data copyin(cqw) + call mpas_pool_get_array(diag, 'cqu', cqu) + !$acc enter data copyin(cqu) + call mpas_pool_get_array(diag, 'pressure_p', pressure_p) + !$acc enter data copyin(pressure_p) ! use values from atm_init_coupled_diagnostics + call mpas_pool_get_array(diag, 'pressure_base', pressure_base) + !$acc enter data copyin(pressure_base) ! use values from atm_init_coupled_diagnostics + call mpas_pool_get_array(diag, 'pressure', pressure) + !$acc enter data copyin(pressure) + call mpas_pool_get_array(diag, 'v', v) + !$acc enter data copyin(v) ! use values from atm_compute_solve_diagnostics + call mpas_pool_get_array(diag, 'rtheta_pp', rtheta_pp) + !$acc enter data copyin(rtheta_pp) + call mpas_pool_get_array(diag, 'rtheta_pp_old', rtheta_pp_old) + !$acc enter data copyin(rtheta_pp_old) + call mpas_pool_get_array(diag, 'kdiff', kdiff) + !$acc enter data copyin(kdiff) + call mpas_pool_get_array(diag, 'pv_edge', pv_edge) + !$acc enter data copyin(pv_edge) ! use values from atm_compute_solve_diagnostics + call mpas_pool_get_array(diag, 'pv_vertex', pv_vertex) + !$acc enter data copyin(pv_vertex) + call mpas_pool_get_array(diag, 'pv_cell', pv_cell) + !$acc enter data copyin(pv_cell) + call mpas_pool_get_array(diag, 'rho_edge', rho_edge) + !$acc enter data copyin(rho_edge) ! use values from atm_compute_solve_diagnostics + call mpas_pool_get_array(diag, 'h_divergence', h_divergence) + !$acc enter data copyin(h_divergence) + call mpas_pool_get_array(diag, 'ke', ke) + !$acc enter data copyin(ke) ! 
use values from atm_compute_solve_diagnostics + call mpas_pool_get_array(diag, 'gradPVn', gradPVn) + !$acc enter data copyin(gradPVn) + call mpas_pool_get_array(diag, 'gradPVt', gradPVt) + !$acc enter data copyin(gradPVt) + + call mpas_pool_get_array(diag, 'alpha_tri', alpha_tri) + !$acc enter data copyin(alpha_tri) + call mpas_pool_get_array(diag, 'gamma_tri', gamma_tri) + !$acc enter data copyin(gamma_tri) + call mpas_pool_get_array(diag, 'a_tri', a_tri) + !$acc enter data copyin(a_tri) + call mpas_pool_get_array(diag, 'cofwr', cofwr) + !$acc enter data copyin(cofwr) + call mpas_pool_get_array(diag, 'cofwz', cofwz) + !$acc enter data copyin(cofwz) + call mpas_pool_get_array(diag, 'coftz', coftz) + !$acc enter data copyin(coftz) + call mpas_pool_get_array(diag, 'cofwt', cofwt) + !$acc enter data copyin(cofwt) + call mpas_pool_get_array(diag, 'cofrz', cofrz) + !$acc enter data copyin(cofrz) + call mpas_pool_get_array(diag, 'vorticity', vorticity) + !$acc enter data copyin(vorticity) + call mpas_pool_get_array(diag, 'divergence', divergence) + !$acc enter data copyin(divergence) + call mpas_pool_get_array(diag, 'ruAvg', ruAvg) + !$acc enter data copyin(ruAvg) + call mpas_pool_get_array(diag, 'ruAvg_split', ruAvg_split) + !$acc enter data copyin(ruAvg_split) + call mpas_pool_get_array(diag, 'wwAvg', wwAvg) + !$acc enter data copyin(wwAvg) + call mpas_pool_get_array(diag, 'wwAvg_split', wwAvg_split) + !$acc enter data copyin(wwAvg_split) + + call mpas_pool_get_array(state, 'u', u_1, 1) + !$acc enter data copyin(u_1) + call mpas_pool_get_array(state, 'u', u_2, 2) + !$acc enter data copyin(u_2) + call mpas_pool_get_array(state, 'w', w_1, 1) + !$acc enter data copyin(w_1) + call mpas_pool_get_array(state, 'w', w_2, 2) + !$acc enter data copyin(w_2) + call mpas_pool_get_array(state, 'theta_m', theta_m_1, 1) + !$acc enter data copyin(theta_m_1) ! 
use values from atm_init_coupled_diagnostics + call mpas_pool_get_array(state, 'theta_m', theta_m_2, 2) + !$acc enter data copyin(theta_m_2) + call mpas_pool_get_array(state, 'rho_zz', rho_zz_1, 1) + !$acc enter data copyin(rho_zz_1) + call mpas_pool_get_array(state, 'rho_zz', rho_zz_2, 2) + !$acc enter data copyin(rho_zz_2) + call mpas_pool_get_array(state, 'scalars', scalars_1, 1) + !$acc enter data copyin(scalars_1) + call mpas_pool_get_array(state, 'scalars', scalars_2, 2) + !$acc enter data copyin(scalars_2) + + + call mpas_pool_get_array(tend, 'u', tend_ru) + !$acc enter data copyin(tend_ru) + call mpas_pool_get_array(tend, 'rho_zz', tend_rho) + !$acc enter data copyin(tend_rho) + call mpas_pool_get_array(tend, 'theta_m', tend_rt) + !$acc enter data copyin(tend_rt) + call mpas_pool_get_array(tend, 'w', tend_rw) + !$acc enter data copyin(tend_rw) + call mpas_pool_get_array(tend, 'rt_diabatic_tend', rt_diabatic_tend) + !$acc enter data copyin(rt_diabatic_tend) + call mpas_pool_get_array(tend, 'u_euler', tend_u_euler) + !$acc enter data copyin(tend_u_euler) + call mpas_pool_get_array(tend, 'theta_euler', tend_theta_euler) + !$acc enter data copyin(tend_theta_euler) + call mpas_pool_get_array(tend, 'w_euler', tend_w_euler) + !$acc enter data copyin(tend_w_euler) + call mpas_pool_get_array(tend, 'w_pgf', tend_w_pgf) + !$acc enter data copyin(tend_w_pgf) + call mpas_pool_get_array(tend, 'w_buoy', tend_w_buoy) + !$acc enter data copyin(tend_w_buoy) + call mpas_pool_get_array(tend, 'scalars_tend', scalar_tend_save) + !$acc enter data copyin(scalar_tend_save) + + + call mpas_pool_get_array(lbc, 'lbc_u', lbc_u, 2) + !$acc enter data copyin(lbc_u) + call mpas_pool_get_array(lbc, 'lbc_w', lbc_w, 2) + !$acc enter data copyin(lbc_w) + call mpas_pool_get_array(lbc, 'lbc_ru', lbc_ru, 2) + !$acc enter data copyin(lbc_ru) + call mpas_pool_get_array(lbc, 'lbc_rho_edge', lbc_rho_edge, 2) + !$acc enter data copyin(lbc_rho_edge) + call mpas_pool_get_array(lbc, 'lbc_theta', lbc_theta, 2) + !$acc enter data copyin(lbc_theta) + call mpas_pool_get_array(lbc, 'lbc_rtheta_m', lbc_rtheta_m, 2) + !$acc enter data copyin(lbc_rtheta_m) + call mpas_pool_get_array(lbc, 'lbc_rho_zz', lbc_rho_zz, 2) + !$acc enter data copyin(lbc_rho_zz) + call mpas_pool_get_array(lbc, 'lbc_rho', lbc_rho, 2) + !$acc enter data copyin(lbc_rho) + call mpas_pool_get_array(lbc, 'lbc_scalars', lbc_scalars, 2) + !$acc enter data copyin(lbc_scalars) + + + call mpas_pool_get_array(lbc, 'lbc_u', lbc_tend_u, 1) + !$acc enter data copyin(lbc_tend_u) + call mpas_pool_get_array(lbc, 'lbc_ru', lbc_tend_ru, 1) + !$acc enter data copyin(lbc_tend_ru) + call mpas_pool_get_array(lbc, 'lbc_rho_edge', lbc_tend_rho_edge, 1) + !$acc enter data copyin(lbc_tend_rho_edge) + call mpas_pool_get_array(lbc, 'lbc_w', lbc_tend_w, 1) + !$acc enter data copyin(lbc_tend_w) + call mpas_pool_get_array(lbc, 'lbc_theta', lbc_tend_theta, 1) + !$acc enter data copyin(lbc_tend_theta) + call mpas_pool_get_array(lbc, 'lbc_rtheta_m', lbc_tend_rtheta_m, 1) + !$acc enter data copyin(lbc_tend_rtheta_m) + call mpas_pool_get_array(lbc, 'lbc_rho_zz', lbc_tend_rho_zz, 1) + !$acc enter data copyin(lbc_tend_rho_zz) + call mpas_pool_get_array(lbc, 'lbc_rho', lbc_tend_rho, 1) + !$acc enter data copyin(lbc_tend_rho) + call mpas_pool_get_array(lbc, 'lbc_scalars', lbc_tend_scalars, 1) + !$acc enter data copyin(lbc_tend_scalars) + + call mpas_pool_get_array(tend_physics, 'rthdynten', rthdynten) + !$acc enter data copyin(rthdynten) + + MPAS_ACC_TIMER_STOP('atm_srk3 [ACC_data_xfer]') #endif - end 
subroutine mpas_atm_dynamics_init + end subroutine mpas_atm_pre_dynamics_h2d + + + subroutine mpas_atm_post_dynamics_d2h(domain) + + implicit none + + type (domain_type), intent(inout) :: domain + + +#ifdef MPAS_OPENACC + type (mpas_pool_type), pointer :: state + type (mpas_pool_type), pointer :: diag + type (mpas_pool_type), pointer :: tend + type (mpas_pool_type), pointer :: tend_physics + type (mpas_pool_type), pointer :: lbc + + + real (kind=RKIND), dimension(:,:), pointer :: ru, ru_p + real (kind=RKIND), dimension(:,:), pointer :: ru_save + real (kind=RKIND), dimension(:,:), pointer :: rw, rw_p + real (kind=RKIND), dimension(:,:), pointer :: rw_save + real (kind=RKIND), dimension(:,:), pointer :: rtheta_p + real (kind=RKIND), dimension(:,:), pointer :: exner, exner_base + real (kind=RKIND), dimension(:,:), pointer :: rtheta_base, rho_base + real (kind=RKIND), dimension(:,:), pointer :: rtheta_p_save + real (kind=RKIND), dimension(:,:), pointer :: rho_p, rho_pp, rho, theta, theta_base + real (kind=RKIND), dimension(:,:), pointer :: rho_p_save + real (kind=RKIND), dimension(:,:), pointer :: rho_zz_old_split + real (kind=RKIND), dimension(:,:), pointer :: cqw, rtheta_pp_old, rtheta_pp + real (kind=RKIND), dimension(:,:), pointer :: cqu, pressure_base, pressure_p, pressure, v + real (kind=RKIND), dimension(:,:), pointer :: kdiff, pv_edge, pv_vertex, pv_cell, rho_edge, h_divergence, ke + real (kind=RKIND), dimension(:,:), pointer :: cofwr, cofwz, coftz, cofwt, a_tri, alpha_tri, gamma_tri + real (kind=RKIND), dimension(:), pointer :: cofrz + real (kind=RKIND), dimension(:,:), pointer :: gradPVn, gradPVt + + + real (kind=RKIND), dimension(:,:), pointer :: u_1, u_2 + real (kind=RKIND), dimension(:,:), pointer :: w_1, w_2 + real (kind=RKIND), dimension(:,:), pointer :: theta_m_1, theta_m_2 + real (kind=RKIND), dimension(:,:), pointer :: rho_zz_1, rho_zz_2 + real (kind=RKIND), dimension(:,:,:), pointer :: scalars_1, scalars_2 + real (kind=RKIND), dimension(:,:), pointer :: ruAvg, wwAvg, ruAvg_split, wwAvg_split + + real (kind=RKIND), dimension(:,:), pointer :: tend_ru, tend_rt, tend_rho, tend_rw, rt_diabatic_tend + real (kind=RKIND), dimension(:,:), pointer :: tend_u_euler, tend_w_euler, tend_theta_euler + real(kind=RKIND), dimension(:,:), pointer :: tend_w_pgf, tend_w_buoy + real(kind=RKIND), dimension(:,:,:), pointer :: scalar_tend_save + + real (kind=RKIND), dimension(:,:), pointer :: rthdynten, divergence, vorticity + + real (kind=RKIND), dimension(:,:), pointer :: lbc_u, lbc_w, lbc_ru, lbc_rho_edge, lbc_rho, lbc_rtheta_m, lbc_rho_zz, lbc_theta + real (kind=RKIND), dimension(:,:), pointer :: lbc_tend_u, lbc_tend_w, lbc_tend_ru, lbc_tend_rho_edge, lbc_tend_rho + real (kind=RKIND), dimension(:,:), pointer :: lbc_tend_rtheta_m, lbc_tend_rho_zz, lbc_tend_theta + + real (kind=RKIND), dimension(:,:,:), pointer :: lbc_scalars, lbc_tend_scalars + + nullify(state) + nullify(diag) + nullify(tend) + nullify(tend_physics) + nullify(lbc) + call mpas_pool_get_subpool(domain % blocklist % structs, 'state', state) + call mpas_pool_get_subpool(domain % blocklist % structs, 'diag', diag) + call mpas_pool_get_subpool(domain % blocklist % structs, 'tend', tend) + call mpas_pool_get_subpool(domain % blocklist % structs, 'tend_physics', tend_physics) + call mpas_pool_get_subpool(domain % blocklist % structs, 'lbc', lbc) + + MPAS_ACC_TIMER_START('atm_srk3 [ACC_data_xfer]') + call mpas_pool_get_array(diag, 'ru', ru) + !$acc exit data copyout(ru) ! 
use values from atm_init_coupled_diagnostics + call mpas_pool_get_array(diag, 'ru_p', ru_p) + !$acc exit data copyout(ru_p) + call mpas_pool_get_array(diag, 'ru_save', ru_save) + !$acc exit data delete(ru_save) + call mpas_pool_get_array(diag, 'rw', rw) + !$acc exit data copyout(rw) ! use values from atm_init_coupled_diagnostics + call mpas_pool_get_array(diag, 'rw_p', rw_p) + !$acc exit data copyout(rw_p) + call mpas_pool_get_array(diag, 'rw_save', rw_save) + !$acc exit data delete(rw_save) + call mpas_pool_get_array(diag, 'rtheta_p', rtheta_p) + !$acc exit data copyout(rtheta_p) ! use values from atm_init_coupled_diagnostics + call mpas_pool_get_array(diag, 'rtheta_p_save', rtheta_p_save) + !$acc exit data delete(rtheta_p_save) + call mpas_pool_get_array(diag, 'exner', exner) + !$acc exit data copyout(exner) ! use values from atm_init_coupled_diagnostics + call mpas_pool_get_array(diag, 'exner_base', exner_base) + !$acc exit data copyout(exner_base) ! use values from atm_init_coupled_diagnostics + call mpas_pool_get_array(diag, 'rtheta_base', rtheta_base) + !$acc exit data copyout(rtheta_base) ! use values from atm_init_coupled_diagnostics + call mpas_pool_get_array(diag, 'rho_base', rho_base) + !$acc exit data copyout(rho_base) + call mpas_pool_get_array(diag, 'rho', rho) + !$acc exit data copyout(rho) + call mpas_pool_get_array(diag, 'theta', theta) + !$acc exit data copyout(theta) + call mpas_pool_get_array(diag, 'theta_base', theta_base) + !$acc exit data copyout(theta_base) + call mpas_pool_get_array(diag, 'rho_p', rho_p) + !$acc exit data copyout(rho_p) ! use values from atm_init_coupled_diagnostics + call mpas_pool_get_array(diag, 'rho_p_save', rho_p_save) + !$acc exit data delete(rho_p_save) + call mpas_pool_get_array(diag, 'rho_pp', rho_pp) + !$acc exit data copyout(rho_pp) + call mpas_pool_get_array(diag, 'rho_zz_old_split', rho_zz_old_split) + !$acc exit data delete(rho_zz_old_split) + call mpas_pool_get_array(diag, 'cqw', cqw) + !$acc exit data delete(cqw) + call mpas_pool_get_array(diag, 'cqu', cqu) + !$acc exit data copyout(cqu) + call mpas_pool_get_array(diag, 'pressure_p', pressure_p) + !$acc exit data copyout(pressure_p) ! use values from atm_init_coupled_diagnostics + call mpas_pool_get_array(diag, 'pressure_base', pressure_base) + !$acc exit data copyout(pressure_base) ! use values from atm_init_coupled_diagnostics + call mpas_pool_get_array(diag, 'pressure', pressure) + !$acc exit data copyout(pressure) + call mpas_pool_get_array(diag, 'v', v) + !$acc exit data copyout(v) ! use values from atm_compute_solve_diagnostics + call mpas_pool_get_array(diag, 'rtheta_pp', rtheta_pp) + !$acc exit data copyout(rtheta_pp) + call mpas_pool_get_array(diag, 'rtheta_pp_old', rtheta_pp_old) + !$acc exit data copyout(rtheta_pp_old) + call mpas_pool_get_array(diag, 'kdiff', kdiff) + !$acc exit data copyout(kdiff) + call mpas_pool_get_array(diag, 'pv_edge', pv_edge) + !$acc exit data copyout(pv_edge) ! use values from atm_compute_solve_diagnostics + call mpas_pool_get_array(diag, 'pv_vertex', pv_vertex) + !$acc exit data copyout(pv_vertex) + call mpas_pool_get_array(diag, 'pv_cell', pv_cell) + !$acc exit data delete(pv_cell) + call mpas_pool_get_array(diag, 'rho_edge', rho_edge) + !$acc exit data copyout(rho_edge) ! use values from atm_compute_solve_diagnostics + call mpas_pool_get_array(diag, 'h_divergence', h_divergence) + !$acc exit data copyout(h_divergence) + call mpas_pool_get_array(diag, 'ke', ke) + !$acc exit data copyout(ke) ! 
use values from atm_compute_solve_diagnostics + call mpas_pool_get_array(diag, 'gradPVn', gradPVn) + !$acc exit data delete(gradPVn) + call mpas_pool_get_array(diag, 'gradPVt', gradPVt) + !$acc exit data delete(gradPVt) + + call mpas_pool_get_array(diag, 'alpha_tri', alpha_tri) + !$acc exit data delete(alpha_tri) + call mpas_pool_get_array(diag, 'gamma_tri', gamma_tri) + !$acc exit data delete(gamma_tri) + call mpas_pool_get_array(diag, 'a_tri', a_tri) + !$acc exit data delete(a_tri) + call mpas_pool_get_array(diag, 'cofwr', cofwr) + !$acc exit data delete(cofwr) + call mpas_pool_get_array(diag, 'cofwz', cofwz) + !$acc exit data delete(cofwz) + call mpas_pool_get_array(diag, 'coftz', coftz) + !$acc exit data delete(coftz) + call mpas_pool_get_array(diag, 'cofwt', cofwt) + !$acc exit data delete(cofwt) + call mpas_pool_get_array(diag, 'cofrz', cofrz) + !$acc exit data delete(cofrz) + call mpas_pool_get_array(diag, 'vorticity', vorticity) + !$acc exit data copyout(vorticity) + call mpas_pool_get_array(diag, 'divergence', divergence) + !$acc exit data copyout(divergence) + call mpas_pool_get_array(diag, 'ruAvg', ruAvg) + !$acc exit data copyout(ruAvg) + call mpas_pool_get_array(diag, 'ruAvg_split', ruAvg_split) + !$acc exit data copyout(ruAvg_split) + call mpas_pool_get_array(diag, 'wwAvg', wwAvg) + !$acc exit data copyout(wwAvg) + call mpas_pool_get_array(diag, 'wwAvg_split', wwAvg_split) + !$acc exit data copyout(wwAvg_split) + + call mpas_pool_get_array(state, 'u', u_1, 1) + !$acc exit data copyout(u_1) + call mpas_pool_get_array(state, 'u', u_2, 2) + !$acc exit data delete(u_2) + call mpas_pool_get_array(state, 'w', w_1, 1) + !$acc exit data copyout(w_1) + call mpas_pool_get_array(state, 'w', w_2, 2) + !$acc exit data delete(w_2) + call mpas_pool_get_array(state, 'theta_m', theta_m_1, 1) + !$acc exit data copyout(theta_m_1) ! use values from atm_init_coupled_diagnostics + call mpas_pool_get_array(state, 'theta_m', theta_m_2, 2) + !$acc exit data copyout(theta_m_2) ! Delete gives incorrect results + call mpas_pool_get_array(state, 'rho_zz', rho_zz_1, 1) + !$acc exit data copyout(rho_zz_1) + call mpas_pool_get_array(state, 'rho_zz', rho_zz_2, 2) + !$acc exit data delete(rho_zz_2) + call mpas_pool_get_array(state, 'scalars', scalars_1, 1) + !$acc exit data copyout(scalars_1) + call mpas_pool_get_array(state, 'scalars', scalars_2, 2) + !$acc exit data copyout(scalars_2) ! 
Delete gives incorrect results + + + call mpas_pool_get_array(tend, 'u', tend_ru) + !$acc exit data copyout(tend_ru) + call mpas_pool_get_array(tend, 'rho_zz', tend_rho) + !$acc exit data copyout(tend_rho) + call mpas_pool_get_array(tend, 'theta_m', tend_rt) + !$acc exit data copyout(tend_rt) + call mpas_pool_get_array(tend, 'w', tend_rw) + !$acc exit data copyout(tend_rw) + call mpas_pool_get_array(tend, 'rt_diabatic_tend', rt_diabatic_tend) + !$acc exit data copyout(rt_diabatic_tend) + call mpas_pool_get_array(tend, 'u_euler', tend_u_euler) + !$acc exit data copyout(tend_u_euler) + call mpas_pool_get_array(tend, 'theta_euler', tend_theta_euler) + !$acc exit data copyout(tend_theta_euler) + call mpas_pool_get_array(tend, 'w_euler', tend_w_euler) + !$acc exit data copyout(tend_w_euler) + call mpas_pool_get_array(tend, 'w_pgf', tend_w_pgf) + !$acc exit data copyout(tend_w_pgf) + call mpas_pool_get_array(tend, 'w_buoy', tend_w_buoy) + !$acc exit data copyout(tend_w_buoy) + call mpas_pool_get_array(tend, 'scalars_tend', scalar_tend_save) + !$acc exit data copyout(scalar_tend_save) + + + call mpas_pool_get_array(lbc, 'lbc_u', lbc_u, 2) + !$acc exit data delete(lbc_u) + call mpas_pool_get_array(lbc, 'lbc_w', lbc_w, 2) + !$acc exit data delete(lbc_w) + call mpas_pool_get_array(lbc, 'lbc_ru', lbc_ru, 2) + !$acc exit data delete(lbc_ru) + call mpas_pool_get_array(lbc, 'lbc_rho_edge', lbc_rho_edge, 2) + !$acc exit data delete(lbc_rho_edge) + call mpas_pool_get_array(lbc, 'lbc_theta', lbc_theta, 2) + !$acc exit data delete(lbc_theta) + call mpas_pool_get_array(lbc, 'lbc_rtheta_m', lbc_rtheta_m, 2) + !$acc exit data delete(lbc_rtheta_m) + call mpas_pool_get_array(lbc, 'lbc_rho_zz', lbc_rho_zz, 2) + !$acc exit data delete(lbc_rho_zz) + call mpas_pool_get_array(lbc, 'lbc_rho', lbc_rho, 2) + !$acc exit data delete(lbc_rho) + call mpas_pool_get_array(lbc, 'lbc_scalars', lbc_scalars, 2) + !$acc exit data delete(lbc_scalars) + + + call mpas_pool_get_array(lbc, 'lbc_u', lbc_tend_u, 1) + !$acc exit data delete(lbc_tend_u) + call mpas_pool_get_array(lbc, 'lbc_ru', lbc_tend_ru, 1) + !$acc exit data delete(lbc_tend_ru) + call mpas_pool_get_array(lbc, 'lbc_rho_edge', lbc_tend_rho_edge, 1) + !$acc exit data delete(lbc_tend_rho_edge) + call mpas_pool_get_array(lbc, 'lbc_w', lbc_tend_w, 1) + !$acc exit data delete(lbc_tend_w) + call mpas_pool_get_array(lbc, 'lbc_theta', lbc_tend_theta, 1) + !$acc exit data delete(lbc_tend_theta) + call mpas_pool_get_array(lbc, 'lbc_rtheta_m', lbc_tend_rtheta_m, 1) + !$acc exit data delete(lbc_tend_rtheta_m) + call mpas_pool_get_array(lbc, 'lbc_rho_zz', lbc_tend_rho_zz, 1) + !$acc exit data delete(lbc_tend_rho_zz) + call mpas_pool_get_array(lbc, 'lbc_rho', lbc_tend_rho, 1) + !$acc exit data delete(lbc_tend_rho) + call mpas_pool_get_array(lbc, 'lbc_scalars', lbc_tend_scalars, 1) + !$acc exit data delete(lbc_tend_scalars) + + call mpas_pool_get_array(tend_physics, 'rthdynten', rthdynten) + !$acc exit data copyout(rthdynten) + MPAS_ACC_TIMER_STOP('atm_srk3 [ACC_data_xfer]') +#endif + + end subroutine mpas_atm_post_dynamics_d2h !---------------------------------------------------------------------------- @@ -774,12 +1672,14 @@ subroutine atm_timestep(domain, dt, nowTime, itimestep, exchange_halo_group) config_apply_lbcs = config_apply_lbcs_ptr + call mpas_atm_pre_dynamics_h2d(domain) if (trim(config_time_integration) == 'SRK3') then call atm_srk3(domain, dt, itimestep, exchange_halo_group) else call mpas_log_write('Unknown time integration option '//trim(config_time_integration), 
messageType=MPAS_LOG_ERR) call mpas_log_write('Currently, only ''SRK3'' is supported.', messageType=MPAS_LOG_CRIT) end if + call mpas_atm_post_dynamics_d2h(domain) call mpas_set_timeInterval(dtInterval, dt=dt) currTime = nowTime + dtInterval @@ -873,6 +1773,8 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) real (kind=RKIND), dimension(:,:,:), pointer :: scalars, scalars_1, scalars_2 real (kind=RKIND), dimension(:,:), pointer :: rqvdynten, rthdynten, theta_m + real (kind=RKIND), dimension(:,:), pointer :: pressure_p, rtheta_p, exner, tend_u + real (kind=RKIND), dimension(:,:), pointer :: rho_pp, rtheta_pp, ru_p, rw_p, pv_edge, rho_edge real (kind=RKIND) :: theta_local, fac_m #ifndef MPAS_CAM_DYCORE @@ -1040,7 +1942,15 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) ! ! Communicate halos for theta_m, scalars, pressure_p, and rtheta_p ! + MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') + call mpas_pool_get_array(state, 'theta_m', theta_m, 1) + call mpas_pool_get_array(state, 'scalars', scalars_1, 1) + call mpas_pool_get_array(diag, 'pressure_p', pressure_p) + call mpas_pool_get_array(diag, 'rtheta_p', rtheta_p) + !$acc update self(theta_m,scalars_1,pressure_p,rtheta_p) call exchange_halo_group(domain, 'dynamics:theta_m,scalars,pressure_p,rtheta_p') + !$acc update device(theta_m,scalars_1,pressure_p,rtheta_p) + MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_timer_start('atm_rk_integration_setup') @@ -1102,6 +2012,8 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) tend_ru_physics, tend_rtheta_physics, tend_rho_physics) end if + !$acc enter data copyin(tend_rtheta_physics,tend_rho_physics,tend_ru_physics) + DYNAMICS_SUBSTEPS : do dynamics_substep = 1, dynamics_split @@ -1121,8 +2033,12 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) !$OMP END PARALLEL DO call mpas_timer_stop('atm_compute_vert_imp_coefs') + MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') + call mpas_pool_get_array(diag, 'exner', exner) + !$acc update self(exner) call exchange_halo_group(domain, 'dynamics:exner') - + !$acc update device(exner) + MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ! BEGIN Runge-Kutta loop @@ -1200,7 +2116,12 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) !*********************************** ! tend_u + MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') + call mpas_pool_get_array(tend, 'u', tend_u) + !$acc update self(tend_u) call exchange_halo_group(domain, 'dynamics:tend_u') + !$acc update device(tend_u) + MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_timer_start('small_step_prep') @@ -1276,7 +2197,12 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) do small_step = 1, number_sub_steps(rk_step) + MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') + call mpas_pool_get_array(diag, 'rho_pp', rho_pp) + !$acc update self(rho_pp) call exchange_halo_group(domain, 'dynamics:rho_pp') + !$acc update device(rho_pp) + MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_timer_start('atm_advance_acoustic_step') @@ -1298,8 +2224,12 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) ! rtheta_pp ! 
This is the only communications needed during the acoustic steps because we solve for u on all edges of owned cells - + MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') + call mpas_pool_get_array(diag, 'rtheta_pp', rtheta_pp) + !$acc update self(rtheta_pp) call exchange_halo_group(domain, 'dynamics:rtheta_pp') + !$acc update device(rtheta_pp) + MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') ! complete update of horizontal momentum by including 3d divergence damping at the end of the acoustic step @@ -1319,7 +2249,15 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) ! ! Communicate halos for rw_p[1,2], ru_p[1,2], rho_pp[1,2], rtheta_pp[2] ! + MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') + call mpas_pool_get_array(diag, 'ru_p', ru_p) + call mpas_pool_get_array(diag, 'rw_p', rw_p) + call mpas_pool_get_array(diag, 'rho_pp', rho_pp) + call mpas_pool_get_array(diag, 'rtheta_pp', rtheta_pp) + !$acc update self(rw_p,ru_p,rho_pp,rtheta_pp) call exchange_halo_group(domain, 'dynamics:rw_p,ru_p,rho_pp,rtheta_pp') + !$acc update device(rw_p,ru_p,rho_pp,rtheta_pp) + MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_timer_start('atm_recover_large_step_variables') @@ -1354,7 +2292,6 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) call mpas_atm_get_bdy_state(clock, block, nVertLevels, nEdges, 'u', time_dyn_step, ru_driving_values) ! do this inline at present - it is simple enough - !$acc enter data copyin(u) !$acc parallel default(present) !$acc loop gang worker do iEdge = 1, nEdgesSolve @@ -1366,12 +2303,10 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) end if end do !$acc end parallel - !$acc exit data copyout(u) call mpas_atm_get_bdy_state(clock, block, nVertLevels, nEdges, 'ru', time_dyn_step, ru_driving_values) call mpas_pool_get_array(diag, 'ru', u) ! do this inline at present - it is simple enough - !$acc enter data copyin(u) !$acc parallel default(present) !$acc loop gang worker do iEdge = 1, nEdges @@ -1383,7 +2318,6 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) end if end do !$acc end parallel - !$acc exit data copyout(u) deallocate(ru_driving_values) @@ -1391,12 +2325,17 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) !------------------------------------------------------------------- + MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') + call mpas_pool_get_array(state, 'u', u, 2) + !$acc update self(u) ! u if (config_apply_lbcs) then call exchange_halo_group(domain, 'dynamics:u_123') else call exchange_halo_group(domain, 'dynamics:u_3') end if + !$acc update device(u) + MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') ! scalar advection: RK3 scheme of Skamarock and Gassmann (2011). ! PD or monotonicity constraints applied only on the final Runge-Kutta substep. @@ -1408,7 +2347,12 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) if (config_apply_lbcs) then ! 
adjust boundary tendencies for regional_MPAS scalar transport + MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') + call mpas_pool_get_array(state, 'scalars', scalars_2, 2) + !$acc update self(scalars_2) call exchange_halo_group(domain, 'dynamics:scalars') + !$acc update device(scalars_2) + MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') allocate(scalars_driving(num_scalars,nVertLevels,nCells+1)) @@ -1460,17 +2404,27 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) call mpas_timer_stop('atm_compute_solve_diagnostics') + MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') + call mpas_pool_get_array(state, 'w', w, 2) + call mpas_pool_get_array(diag, 'pv_edge', pv_edge) + call mpas_pool_get_array(diag, 'rho_edge', rho_edge) + !$acc update self(w,pv_edge,rho_edge) if (config_scalar_advection .and. (.not. config_split_dynamics_transport) ) then ! ! Communicate halos for w[1,2], pv_edge[1,2], rho_edge[1,2], scalars[1,2] ! + call mpas_pool_get_array(state, 'scalars', scalars_2, 2) + !$acc update self(scalars_2) call exchange_halo_group(domain, 'dynamics:w,pv_edge,rho_edge,scalars') + !$acc update device(scalars_2) else ! ! Communicate halos for w[1,2], pv_edge[1,2], rho_edge[1,2] ! call exchange_halo_group(domain, 'dynamics:w,pv_edge,rho_edge') end if + !$acc update device(w,pv_edge,rho_edge) + MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') ! set the zero-gradient condition on w for regional_MPAS @@ -1483,8 +2437,13 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) end do !$OMP END PARALLEL DO + MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') ! w halo values needs resetting after regional boundary update + call mpas_pool_get_array(state, 'w', w, 2) + !$acc update self(w) call exchange_halo_group(domain, 'dynamics:w') + !$acc update device(w) + MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') end if ! end of regional_MPAS addition @@ -1495,7 +2454,14 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) ! ! Communicate halos for theta_m[1,2], pressure_p[1,2], and rtheta_p[1,2] ! + MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') + call mpas_pool_get_array(state, 'theta_m', theta_m, 2) + call mpas_pool_get_array(diag, 'pressure_p', pressure_p) + call mpas_pool_get_array(diag, 'rtheta_p', rtheta_p) + !$acc update self(theta_m,pressure_p,rtheta_p) call exchange_halo_group(domain, 'dynamics:theta_m,pressure_p,rtheta_p') + !$acc update device(theta_m,pressure_p,rtheta_p) + MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') ! ! Note: A halo exchange for 'exner' here as well as after the call @@ -1532,6 +2498,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) deallocate(qtot) ! we are finished with these now + !$acc exit data delete(tend_rtheta_physics,tend_rho_physics,tend_ru_physics) #ifndef MPAS_CAM_DYCORE call mpas_deallocate_scratch_field(tend_rtheta_physicsField) call mpas_deallocate_scratch_field(tend_rho_physicsField) @@ -1559,8 +2526,13 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) if (config_apply_lbcs) then ! adjust boundary tendencies for regional_MPAS scalar transport + MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') ! 
need to fill halo for horizontal filter + call mpas_pool_get_array(state, 'scalars', scalars_2, 2) + !$acc update self(scalars_2) call exchange_halo_group(domain, 'dynamics:scalars') + !$acc update device(scalars_2) + MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') allocate(scalars_driving(num_scalars,nVertLevels,nCells+1)) @@ -1586,7 +2558,12 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) !------------------------------------------------------------------------------------------------------------------------ if (rk_step < 3) then + MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') + call mpas_pool_get_array(state, 'scalars', scalars_2, 2) + !$acc update self(scalars_2) call exchange_halo_group(domain, 'dynamics:scalars') + !$acc update device(scalars_2) + MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') end if end do RK3_SPLIT_TRANSPORT @@ -1618,16 +2595,25 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) ! #ifdef DO_PHYSICS + MPAS_ACC_TIMER_START('atm_srk3: physics ACC_data_xfer') call mpas_pool_get_array(state, 'scalars', scalars_1, 1) + !$acc update self(scalars_1) call mpas_pool_get_array(state, 'scalars', scalars_2, 2) + !$acc update self(scalars_2) + MPAS_ACC_TIMER_STOP('atm_srk3: physics ACC_data_xfer') if(config_convection_scheme == 'cu_grell_freitas' .or. & config_convection_scheme == 'cu_ntiedtke') then + MPAS_ACC_TIMER_START('atm_srk3: physics ACC_data_xfer') call mpas_pool_get_array(tend_physics, 'rqvdynten', rqvdynten) call mpas_pool_get_array(state, 'theta_m', theta_m, 2) + !$acc update self(theta_m) call mpas_pool_get_array(tend_physics, 'rthdynten', rthdynten) + !$acc update self(rthdynten) + MPAS_ACC_TIMER_STOP('atm_srk3: physics ACC_data_xfer') + !NOTE: The calculation of the tendency due to horizontal and vertical advection for the water vapor mixing ratio !requires that the subroutine atm_advance_scalars_mono was called on the third Runge Kutta step, so that a halo @@ -1652,6 +2638,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) where ( scalars_2(:,:,:) < 0.0) & scalars_2(:,:,:) = 0.0 + MPAS_ACC_TIMER_START('atm_srk3: physics ACC_data_xfer') + !$acc update device(scalars_2, rthdynten) + MPAS_ACC_TIMER_STOP('atm_srk3: physics ACC_data_xfer') !call microphysics schemes: if (trim(config_microp_scheme) /= 'off') then call mpas_timer_start('microphysics') @@ -1699,7 +2688,12 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) if (config_apply_lbcs) then ! 
adjust boundary values for regional_MPAS scalar transport + MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') + call mpas_pool_get_array(state, 'scalars', scalars_2, 2) + !$acc update self(scalars_2) call exchange_halo_group(domain, 'dynamics:scalars') + !$acc update device(scalars_2) + MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') allocate(scalars_driving(num_scalars,nVertLevels,nCells+1)) @@ -1976,12 +2970,6 @@ subroutine atm_rk_integration_setup( state, diag, nVertLevels, num_scalars, & call mpas_pool_get_array(state, 'scalars', scalars_1, 1) call mpas_pool_get_array(state, 'scalars', scalars_2, 2) - MPAS_ACC_TIMER_START('atm_rk_integration_setup [ACC_data_xfer]') - !$acc enter data create(ru_save, u_2, rw_save, rtheta_p_save, rho_p_save, & - !$acc w_2, theta_m_2, rho_zz_2, rho_zz_old_split, scalars_2) & - !$acc copyin(ru, rw, rtheta_p, rho_p, u_1, w_1, theta_m_1, & - !$acc rho_zz_1, scalars_1) - MPAS_ACC_TIMER_STOP('atm_rk_integration_setup [ACC_data_xfer]') !$acc kernels theta_m_2(:,cellEnd+1) = 0.0_RKIND @@ -2029,12 +3017,6 @@ subroutine atm_rk_integration_setup( state, diag, nVertLevels, num_scalars, & end do !$acc end parallel - MPAS_ACC_TIMER_START('atm_rk_integration_setup [ACC_data_xfer]') - !$acc exit data copyout(ru_save, rw_save, rtheta_p_save, rho_p_save, u_2, & - !$acc w_2, theta_m_2, rho_zz_2, rho_zz_old_split, scalars_2) & - !$acc delete(ru, rw, rtheta_p, rho_p, u_1, w_1, theta_m_1, & - !$acc rho_zz_1, scalars_1) - MPAS_ACC_TIMER_STOP('atm_rk_integration_setup [ACC_data_xfer]') end subroutine atm_rk_integration_setup @@ -2085,11 +3067,6 @@ subroutine atm_compute_moist_coefficients( dims, state, diag, mesh, & moist_start = moist_start_ptr moist_end = moist_end_ptr - MPAS_ACC_TIMER_START('atm_compute_moist_coefficients [ACC_data_xfer]') - !$acc enter data create(cqw, cqu) & - !$acc copyin(scalars) - MPAS_ACC_TIMER_STOP('atm_compute_moist_coefficients [ACC_data_xfer]') - !$acc parallel default(present) !$acc loop gang worker ! do iCell = cellSolveStart,cellSolveEnd @@ -2138,10 +3115,6 @@ subroutine atm_compute_moist_coefficients( dims, state, diag, mesh, & end do !$acc end parallel - MPAS_ACC_TIMER_START('atm_compute_moist_coefficients [ACC_data_xfer]') - !$acc exit data copyout(cqw, cqu) & - !$acc delete(scalars) - MPAS_ACC_TIMER_STOP('atm_compute_moist_coefficients [ACC_data_xfer]') end subroutine atm_compute_moist_coefficients @@ -2274,9 +3247,7 @@ subroutine atm_compute_vert_imp_coefs_work(nCells, moist_start, moist_end, dts, real (kind=RKIND), dimension( nVertLevels ) :: b_tri, c_tri MPAS_ACC_TIMER_START('atm_compute_vert_imp_coefs_work [ACC_data_xfer]') - !$acc enter data copyin(cqw, p, t, rb, rtb, rt, pb) - !$acc enter data create(cofrz, cofwr, cofwz, coftz, cofwt, a_tri, b_tri, & - !$acc c_tri, alpha_tri, gamma_tri) + !$acc enter data create(b_tri, c_tri) MPAS_ACC_TIMER_STOP('atm_compute_vert_imp_coefs_work [ACC_data_xfer]') ! 
set coefficients @@ -2358,9 +3329,7 @@ subroutine atm_compute_vert_imp_coefs_work(nCells, moist_start, moist_end, dts, !$acc end parallel MPAS_ACC_TIMER_START('atm_compute_vert_imp_coefs_work [ACC_data_xfer]') - !$acc exit data copyout(cofrz, cofwr, cofwz, coftz, cofwt, a_tri, b_tri, & - !$acc c_tri, alpha_tri, gamma_tri) - !$acc exit data delete(cqw, p, t, rb, rtb, rt, pb) + !$acc exit data copyout(b_tri, c_tri) MPAS_ACC_TIMER_STOP('atm_compute_vert_imp_coefs_work [ACC_data_xfer]') end subroutine atm_compute_vert_imp_coefs_work @@ -2465,9 +3434,6 @@ subroutine atm_set_smlstep_pert_variables_work(nCells, nEdges, & integer :: iCell, iEdge, i, k real (kind=RKIND) :: flux - MPAS_ACC_TIMER_START('atm_set_smlstep_pert_variables [ACC_data_xfer]') - !$acc enter data copyin(u_tend, w_tend) - MPAS_ACC_TIMER_STOP('atm_set_smlstep_pert_variables [ACC_data_xfer]') ! we solve for omega instead of w (see Klemp et al MWR 2007), ! so here we change the w_p tendency to an omega_p tendency @@ -2500,10 +3466,6 @@ subroutine atm_set_smlstep_pert_variables_work(nCells, nEdges, & end do !$acc end parallel - MPAS_ACC_TIMER_START('atm_set_smlstep_pert_variables [ACC_data_xfer]') - !$acc exit data delete(u_tend) - !$acc exit data copyout(w_tend) - MPAS_ACC_TIMER_STOP('atm_set_smlstep_pert_variables [ACC_data_xfer]') end subroutine atm_set_smlstep_pert_variables_work @@ -2736,17 +3698,6 @@ subroutine atm_advance_acoustic_step_work(nCells, nEdges, nCellsSolve, cellStart resm = (1.0 - epssm) / (1.0 + epssm) rdts = 1./dts - MPAS_ACC_TIMER_START('atm_advance_acoustic_step [ACC_data_xfer]') - !$acc enter data copyin(exner,cqu,cofwt,coftz,cofrz,cofwr,cofwz, & - !$acc a_tri,alpha_tri,gamma_tri,rho_zz,theta_m,w, & - !$acc tend_ru,tend_rho,tend_rt,tend_rw,rw,rw_save) - !$acc enter data create(rtheta_pp_old) - if(small_step == 1) then - !$acc enter data create(ru_p,ruAvg,rho_pp,rtheta_pp,wwAvg,rw_p) - else - !$acc enter data copyin(ru_p,ruAvg,rho_pp,rtheta_pp,wwAvg,rw_p) - end if - MPAS_ACC_TIMER_STOP('atm_advance_acoustic_step [ACC_data_xfer]') if(small_step /= 1) then ! not needed on first small step @@ -2973,13 +3924,6 @@ subroutine atm_advance_acoustic_step_work(nCells, nEdges, nCellsSolve, cellStart end do ! end of loop over cells !$acc end parallel - MPAS_ACC_TIMER_START('atm_advance_acoustic_step [ACC_data_xfer]') - !$acc exit data delete(exner,cqu,cofwt,coftz,cofrz,cofwr,cofwz, & - !$acc a_tri,alpha_tri,gamma_tri,rho_zz,theta_m,w, & - !$acc tend_ru,tend_rho,tend_rt,tend_rw,rw,rw_save) - !$acc exit data copyout(rtheta_pp_old,ru_p,ruAvg,rho_pp, & - !$acc rtheta_pp,wwAvg,rw_p) - MPAS_ACC_TIMER_STOP('atm_advance_acoustic_step [ACC_data_xfer]') end subroutine atm_advance_acoustic_step_work @@ -3031,9 +3975,6 @@ subroutine atm_divergence_damping_3d( state, diag, mesh, configs, dts, edgeStart nCellsSolve = nCellsSolve_ptr nVertLevels = nVertLevels_ptr - MPAS_ACC_TIMER_START('atm_divergence_damping_3d [ACC_data_xfer]') - !$acc enter data copyin(ru_p, rtheta_pp, rtheta_pp_old, theta_m) - MPAS_ACC_TIMER_STOP('atm_divergence_damping_3d [ACC_data_xfer]') !$acc parallel default(present) !$acc loop gang worker @@ -3066,10 +4007,6 @@ subroutine atm_divergence_damping_3d( state, diag, mesh, configs, dts, edgeStart end do ! 
end loop over edges !$acc end parallel - MPAS_ACC_TIMER_START('atm_divergence_damping_3d [ACC_data_xfer]') - !$acc exit data copyout(ru_p) & - !$acc delete(rtheta_pp, rtheta_pp_old, theta_m) - MPAS_ACC_TIMER_STOP('atm_divergence_damping_3d [ACC_data_xfer]') end subroutine atm_divergence_damping_3d @@ -3260,17 +4197,6 @@ subroutine atm_recover_large_step_variables_work(nCells, nEdges, nCellsSolve, nE integer :: i, iCell, iEdge, k, cell1, cell2 real (kind=RKIND) :: invNs, rcv, p0, flux - MPAS_ACC_TIMER_START('atm_recover_large_step_variables [ACC_data_xfer]') - !$acc enter data copyin(rho_p_save,rho_pp,rho_base,rw_save,rw_p, & - !$acc rtheta_p_save,rtheta_pp,rtheta_base, & - !$acc ru_save,ru_p,wwAvg,ruAvg) & - !$acc create(rho_zz,rho_p,rw,w,rtheta_p,theta_m, & - !$acc ru,u) - if (rk_step == 3) then - !$acc enter data copyin(rt_diabatic_tend,exner_base) & - !$acc create(exner,pressure_p) - end if - MPAS_ACC_TIMER_STOP('atm_recover_large_step_variables [ACC_data_xfer]') rcv = rgas/(cp-rgas) p0 = 1.0e+05 ! this should come from somewhere else... @@ -3416,17 +4342,6 @@ subroutine atm_recover_large_step_variables_work(nCells, nEdges, nCellsSolve, nE end do !$acc end parallel - MPAS_ACC_TIMER_START('atm_recover_large_step_variables [ACC_data_xfer]') - !$acc exit data delete(rho_p_save,rho_pp,rho_base,rw_save,rw_p, & - !$acc rtheta_p_save,rtheta_pp,rtheta_base, & - !$acc ru_save,ru_p) & - !$acc copyout(rho_zz,rho_p,rw,w,rtheta_p,theta_m, & - !$acc ru,u,wwAvg,ruAvg) - if (rk_step == 3) then - !$acc exit data delete(rt_diabatic_tend,exner_base) & - !$acc copyout(exner,pressure_p) - end if - MPAS_ACC_TIMER_STOP('atm_recover_large_step_variables [ACC_data_xfer]') end subroutine atm_recover_large_step_variables_work @@ -3661,10 +4576,6 @@ subroutine atm_advance_scalars_work(nCells, num_scalars, dt, & weight_time_old = 1. - weight_time_new - MPAS_ACC_TIMER_START('atm_advance_scalars [ACC_data_xfer]') - !$acc enter data copyin(uhAvg, scalar_new) - MPAS_ACC_TIMER_STOP('atm_advance_scalars [ACC_data_xfer]') - !$acc parallel async !$acc loop gang worker private(scalar_weight2, ica) do iEdge=edgeStart,edgeEnd @@ -3759,12 +4670,6 @@ subroutine atm_advance_scalars_work(nCells, num_scalars, dt, & ! 
MPAS_ACC_TIMER_START('atm_advance_scalars [ACC_data_xfer]') -#ifndef DO_PHYSICS - !$acc enter data create(scalar_tend_save) -#else - !$acc enter data copyin(scalar_tend_save) -#endif - !$acc enter data copyin(scalar_old, fnm, fnp, rdnw, wwAvg, rho_zz_old, rho_zz_new) !$acc enter data create(scalar_tend_column) MPAS_ACC_TIMER_STOP('atm_advance_scalars [ACC_data_xfer]') @@ -3847,9 +4752,7 @@ subroutine atm_advance_scalars_work(nCells, num_scalars, dt, & !$acc end parallel MPAS_ACC_TIMER_START('atm_advance_scalars [ACC_data_xfer]') - !$acc exit data copyout(scalar_new) - !$acc exit data delete(scalar_tend_column, uhAvg, wwAvg, scalar_old, fnm, fnp, & - !$acc rdnw, rho_zz_old, rho_zz_new, scalar_tend_save) + !$acc exit data delete(scalar_tend_column) MPAS_ACC_TIMER_STOP('atm_advance_scalars [ACC_data_xfer]') end subroutine atm_advance_scalars_work @@ -4108,19 +5011,9 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]') - !$acc data present(nEdgesOnCell, edgesOnCell, edgesOnCell_sign, & - !$acc invAreaCell, cellsOnCell, cellsOnEdge, nAdvCellsForEdge, & - !$acc advCellsForEdge, adv_coefs, adv_coefs_3rd, dvEdge, bdyMaskCell) - -#ifdef DO_PHYSICS - !$acc enter data copyin(scalar_tend) -#else - !$acc enter data create(scalar_tend) -#endif if (local_advance_density) then !$acc enter data copyin(rho_zz_int) end if - !$acc enter data copyin(scalars_old, rho_zz_old, rdnw, uhAvg, wwAvg) MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]') !$acc parallel @@ -4145,8 +5038,6 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge !$acc end parallel MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]') - !$acc exit data copyout(scalar_tend) - !$acc update self(scalars_old) MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]') @@ -4210,10 +5101,6 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge end if MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]') - if (.not. 
local_advance_density) then - !$acc enter data copyin(rho_zz_new) - end if - !$acc enter data copyin(scalars_new, fnm, fnp) !$acc enter data create(scale_arr) MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]') @@ -4721,14 +5608,8 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]') if (local_advance_density) then !$acc exit data copyout(rho_zz_int) - else - !$acc exit data delete(rho_zz_new) end if - !$acc exit data copyout(scalars_new) - !$acc exit data delete(scalars_old, scale_arr, rho_zz_old, wwAvg, & - !$acc uhAvg, fnm, fnp, rdnw) - - !$acc end data + !$acc exit data delete(scale_arr) MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]') end subroutine atm_advance_scalars_mono_work @@ -5162,43 +6043,10 @@ subroutine atm_compute_dyn_tend_work(nCells, nEdges, nVertices, nVertLevels_dumm MPAS_ACC_TIMER_START('atm_compute_dyn_tend_work [ACC_data_xfer]') - if (rk_step == 1) then - !$acc enter data create(tend_w_euler) - !$acc enter data create(tend_u_euler) - !$acc enter data create(tend_theta_euler) - !$acc enter data create(tend_rho) - - !$acc enter data create(kdiff) - !$acc enter data copyin(tend_rho_physics) - !$acc enter data copyin(rb, rr_save) - !$acc enter data copyin(divergence, vorticity) - !$acc enter data copyin(v) - !$acc enter data copyin(u_init, v_init) - else - !$acc enter data copyin(tend_w_euler) - !$acc enter data copyin(tend_u_euler) - !$acc enter data copyin(tend_theta_euler) - !$acc enter data copyin(tend_rho) - end if - !$acc enter data create(tend_u) - !$acc enter data copyin(cqu, pp, u, w, pv_edge, rho_edge, ke) - !$acc enter data create(h_divergence) - !$acc enter data copyin(ru, rw) !$acc enter data create(rayleigh_damp_coef) - !$acc enter data copyin(tend_ru_physics) - !$acc enter data create(tend_w) - !$acc enter data copyin(rho_zz) - !$acc enter data create(tend_theta) - !$acc enter data copyin(theta_m) - !$acc enter data copyin(ru_save, theta_m_save) - !$acc enter data copyin(cqw) - !$acc enter data copyin(tend_rtheta_physics) - !$acc enter data copyin(rw_save, rt_diabatic_tend) - !$acc enter data create(rthdynten) - !$acc enter data copyin(t_init) -#ifdef CURVATURE + #ifdef CURVATURE !$acc enter data copyin(ur_cell, vr_cell) -#endif + #endif MPAS_ACC_TIMER_STOP('atm_compute_dyn_tend_work [ACC_data_xfer]') prandtl_inv = 1.0_RKIND / prandtl @@ -6198,43 +7046,10 @@ subroutine atm_compute_dyn_tend_work(nCells, nEdges, nVertices, nVertLevels_dumm !$acc end parallel MPAS_ACC_TIMER_START('atm_compute_dyn_tend_work [ACC_data_xfer]') - if (rk_step == 1) then - !$acc exit data copyout(tend_w_euler) - !$acc exit data copyout(tend_u_euler) - !$acc exit data copyout(tend_theta_euler) - !$acc exit data copyout(tend_rho) - - !$acc exit data delete(kdiff) - !$acc exit data delete(tend_rho_physics) - !$acc exit data delete(rb, rr_save) - !$acc exit data delete(divergence, vorticity) - !$acc exit data delete(v) - !$acc exit data delete(u_init, v_init) - else - !$acc exit data delete(tend_w_euler) - !$acc exit data delete(tend_u_euler) - !$acc exit data delete(tend_theta_euler) - !$acc exit data delete(tend_rho) - end if - !$acc exit data copyout(tend_u) - !$acc exit data delete(cqu, pp, u, w, pv_edge, rho_edge, ke) - !$acc exit data copyout(h_divergence) - !$acc exit data delete(ru, rw) !$acc exit data delete(rayleigh_damp_coef) - !$acc exit data delete(tend_ru_physics) - !$acc exit data copyout(tend_w) - !$acc exit data delete(rho_zz) - !$acc exit data 
copyout(tend_theta) - !$acc exit data delete(theta_m) - !$acc exit data delete(ru_save, theta_m_save) - !$acc exit data delete(cqw) - !$acc exit data delete(tend_rtheta_physics) - !$acc exit data delete(rw_save, rt_diabatic_tend) - !$acc exit data copyout(rthdynten) - !$acc exit data delete(t_init) -#ifdef CURVATURE + #ifdef CURVATURE !$acc exit data delete(ur_cell, vr_cell) -#endif + #endif MPAS_ACC_TIMER_STOP('atm_compute_dyn_tend_work [ACC_data_xfer]') end subroutine atm_compute_dyn_tend_work @@ -6403,26 +7218,10 @@ subroutine atm_compute_solve_diagnostics_work(nCells, nEdges, nVertices, & logical :: reconstruct_v - MPAS_ACC_TIMER_START('atm_compute_solve_diagnostics [ACC_data_xfer]') - !$acc enter data copyin(cellsOnEdge,dcEdge,dvEdge, & - !$acc edgesOnVertex,edgesOnVertex_sign,invAreaTriangle, & - !$acc nEdgesOnCell,edgesOnCell, & - !$acc edgesOnCell_sign,invAreaCell, & - !$acc invAreaTriangle,edgesOnVertex, & - !$acc verticesOnCell,kiteForCell,kiteAreasOnVertex, & - !$acc nEdgesOnEdge,edgesOnEdge,weightsOnEdge, & - !$acc fVertex, & - !$acc verticesOnEdge, & - !$acc invDvEdge,invDcEdge) - !$acc enter data copyin(u,h) - MPAS_ACC_TIMER_STOP('atm_compute_solve_diagnostics [ACC_data_xfer]') ! ! Compute height on cell edges at velocity locations ! - MPAS_ACC_TIMER_START('atm_compute_solve_diagnostics [ACC_data_xfer]') - !$acc enter data create(h_edge,vorticity,divergence) - MPAS_ACC_TIMER_STOP('atm_compute_solve_diagnostics [ACC_data_xfer]') !$acc parallel default(present) !$acc loop gang do iEdge=edgeStart,edgeEnd @@ -6507,9 +7306,6 @@ subroutine atm_compute_solve_diagnostics_work(nCells, nEdges, nVertices, & ! ! Replace 2.0 with 2 in exponentiation to avoid outside chance that ! compiler will actually allow "float raised to float" operation - MPAS_ACC_TIMER_START('atm_compute_solve_diagnostics [ACC_data_xfer]') - !$acc enter data create(ke) - MPAS_ACC_TIMER_STOP('atm_compute_solve_diagnostics [ACC_data_xfer]') !$acc parallel default(present) !$acc loop gang do iCell=cellStart,cellEnd @@ -6604,14 +7400,6 @@ subroutine atm_compute_solve_diagnostics_work(nCells, nEdges, nVertices, & if(rk_step /= 3) reconstruct_v = .false. end if - MPAS_ACC_TIMER_START('atm_compute_solve_diagnostics [ACC_data_xfer]') - if (reconstruct_v) then - !$acc enter data create(v) - else - !$acc enter data copyin(v) - end if - MPAS_ACC_TIMER_STOP('atm_compute_solve_diagnostics [ACC_data_xfer]') - if (reconstruct_v) then !$acc parallel default(present) !$acc loop gang @@ -6639,9 +7427,6 @@ subroutine atm_compute_solve_diagnostics_work(nCells, nEdges, nVertices, & ! ! Avoid dividing h_vertex by areaTriangle and move areaTriangle into ! numerator for the pv_vertex calculation - MPAS_ACC_TIMER_START('atm_compute_solve_diagnostics [ACC_data_xfer]') - !$acc enter data create(pv_vertex) - MPAS_ACC_TIMER_STOP('atm_compute_solve_diagnostics [ACC_data_xfer]') !$acc parallel default(present) !$acc loop collapse(2) do iVertex = vertexStart,vertexEnd @@ -6665,9 +7450,6 @@ subroutine atm_compute_solve_diagnostics_work(nCells, nEdges, nVertices, & ! Compute pv at the edges ! ( this computes pv_edge at all edges bounding real cells ) ! - MPAS_ACC_TIMER_START('atm_compute_solve_diagnostics [ACC_data_xfer]') - !$acc enter data create(pv_edge) - MPAS_ACC_TIMER_STOP('atm_compute_solve_diagnostics [ACC_data_xfer]') !$acc parallel default(present) !$acc loop collapse(2) do iEdge = edgeStart,edgeEnd @@ -6685,9 +7467,6 @@ subroutine atm_compute_solve_diagnostics_work(nCells, nEdges, nVertices, & ! 
( this computes pv_cell for all real cells ) ! only needed for APVM upwinding ! - MPAS_ACC_TIMER_START('atm_compute_solve_diagnostics [ACC_data_xfer]') - !$acc enter data create(pv_cell) - MPAS_ACC_TIMER_STOP('atm_compute_solve_diagnostics [ACC_data_xfer]') !$acc parallel default(present) !$acc loop gang do iCell=cellStart,cellEnd @@ -6726,9 +7505,6 @@ subroutine atm_compute_solve_diagnostics_work(nCells, nEdges, nVertices, & ! Merged loops for calculating gradPVt, gradPVn and pv_edge ! Also precomputed inverses of dvEdge and dcEdge to avoid repeated divisions ! - MPAS_ACC_TIMER_START('atm_compute_solve_diagnostics [ACC_data_xfer]') - !$acc enter data create(gradPVt,gradPVn) - MPAS_ACC_TIMER_STOP('atm_compute_solve_diagnostics [ACC_data_xfer]') r = config_apvm_upwinding * dt !$acc parallel default(present) !$acc loop gang @@ -6745,31 +7521,10 @@ subroutine atm_compute_solve_diagnostics_work(nCells, nEdges, nVertices, & end do !$acc end parallel - MPAS_ACC_TIMER_START('atm_compute_solve_diagnostics [ACC_data_xfer]') - !$acc exit data delete(pv_cell,gradPVt,gradPVn) - MPAS_ACC_TIMER_STOP('atm_compute_solve_diagnostics [ACC_data_xfer]') end if ! apvm upwinding - MPAS_ACC_TIMER_START('atm_compute_solve_diagnostics [ACC_data_xfer]') - !$acc exit data delete(cellsOnEdge,dcEdge,dvEdge, & - !$acc edgesOnVertex,edgesOnVertex_sign,invAreaTriangle, & - !$acc nEdgesOnCell,edgesOnCell, & - !$acc edgesOnCell_sign,invAreaCell, & - !$acc invAreaTriangle,edgesOnVertex, & - !$acc verticesOnCell,kiteForCell,kiteAreasOnVertex, & - !$acc nEdgesOnEdge,edgesOnEdge,weightsOnEdge, & - !$acc verticesOnEdge, & - !$acc fVertex,invDvEdge,invDcEdge) - !$acc exit data delete(u,h) - !$acc exit data copyout(h_edge,vorticity,divergence, & - !$acc ke, & - !$acc v, & - !$acc pv_vertex, & - !$acc pv_edge) - MPAS_ACC_TIMER_STOP('atm_compute_solve_diagnostics [ACC_data_xfer]') - end subroutine atm_compute_solve_diagnostics_work @@ -6858,17 +7613,13 @@ subroutine atm_init_coupled_diagnostics(state, time_lev, diag, mesh, configs, & call mpas_pool_get_array(mesh, 'zb3_cell', zb3_cell) MPAS_ACC_TIMER_START('atm_init_coupled_diagnostics [ACC_data_xfer]') - ! copyin invariant fields - !$acc enter data copyin(cellsOnEdge,nEdgesOnCell,edgesOnCell, & - !$acc edgesOnCell_sign,zz,fzm,fzp,zb,zb3, & - !$acc zb_cell,zb3_cell) ! copyin the data that is only on the right-hand side - !$acc enter data copyin(scalars(index_qv,:,:),u,w,rho,theta, & + !$acc enter data copyin(scalars(index_qv,:,:),w,rho,theta, & !$acc rho_base,theta_base) ! copyin the data that will be modified in this routine - !$acc enter data create(theta_m,rho_zz,ru,rw,rho_p,rtheta_base, & + !$acc enter data create(theta_m,ru,rw,rho_p,rtheta_base, & !$acc rtheta_p,exner,exner_base,pressure_p, & !$acc pressure_base) MPAS_ACC_TIMER_STOP('atm_init_coupled_diagnostics [ACC_data_xfer]') @@ -6992,17 +7743,12 @@ subroutine atm_init_coupled_diagnostics(state, time_lev, diag, mesh, configs, & !$acc end parallel MPAS_ACC_TIMER_START('atm_init_coupled_diagnostics [ACC_data_xfer]') - ! delete invariant fields - !$acc exit data delete(cellsOnEdge,nEdgesOnCell,edgesOnCell, & - !$acc edgesOnCell_sign,zz,fzm,fzp,zb,zb3, & - !$acc zb_cell,zb3_cell) - ! delete the data that is only on the right-hand side - !$acc exit data delete(scalars(index_qv,:,:),u,w,rho,theta, & + !$acc exit data delete(scalars(index_qv,:,:),w,rho,theta, & !$acc rho_base,theta_base) ! 
copyout the data that will be modified in this routine - !$acc exit data copyout(theta_m,rho_zz,ru,rw,rho_p,rtheta_base, & + !$acc exit data copyout(theta_m,ru,rw,rho_p,rtheta_base, & !$acc rtheta_p,exner,exner_base,pressure_p, & !$acc pressure_base) MPAS_ACC_TIMER_STOP('atm_init_coupled_diagnostics [ACC_data_xfer]') @@ -7069,13 +7815,6 @@ subroutine atm_rk_dynamics_substep_finish( state, diag, nVertLevels, dynamics_su call mpas_pool_get_array(state, 'rho_zz', rho_zz_1, 1) call mpas_pool_get_array(state, 'rho_zz', rho_zz_2, 2) - MPAS_ACC_TIMER_START('atm_rk_dynamics_substep_finish [ACC_data_xfer]') - !$acc enter data create(ru_save, u_1, rtheta_p_save, theta_m_1, rho_p_save, rw_save, & - !$acc w_1, rho_zz_1) & - !$acc copyin(ru, u_2, rtheta_p, rho_p, theta_m_2, rho_zz_2, rw, & - !$acc w_2, ruAvg, wwAvg, ruAvg_split, wwAvg_split, rho_zz_old_split) - MPAS_ACC_TIMER_STOP('atm_rk_dynamics_substep_finish [ACC_data_xfer]') - ! Interim fix for the atm_compute_dyn_tend_work subroutine accessing uninitialized values ! in garbage cells of theta_m !$acc kernels @@ -7180,13 +7919,6 @@ subroutine atm_rk_dynamics_substep_finish( state, diag, nVertLevels, dynamics_su !$acc end parallel end if - MPAS_ACC_TIMER_START('atm_rk_dynamics_substep_finish [ACC_data_xfer]') - !$acc exit data copyout(ru_save, u_1, rtheta_p_save, rho_p_save, rw_save, & - !$acc w_1, theta_m_1, rho_zz_1, ruAvg, wwAvg, ruAvg_split, & - !$acc wwAvg_split) & - !$acc delete(ru, u_2, rtheta_p, rho_p, theta_m_2, rho_zz_2, rw, & - !$acc w_2, rho_zz_old_split) - MPAS_ACC_TIMER_STOP('atm_rk_dynamics_substep_finish [ACC_data_xfer]') end subroutine atm_rk_dynamics_substep_finish @@ -7241,9 +7973,6 @@ subroutine atm_zero_gradient_w_bdy_work( w, bdyMaskCell, nearestRelaxationCell, integer :: iCell, k - MPAS_ACC_TIMER_START('atm_zero_gradient_w_bdy_work [ACC_data_xfer]') - !$acc enter data copyin(w) - MPAS_ACC_TIMER_STOP('atm_zero_gradient_w_bdy_work [ACC_data_xfer]') !$acc parallel default(present) !$acc loop gang worker @@ -7259,9 +7988,6 @@ subroutine atm_zero_gradient_w_bdy_work( w, bdyMaskCell, nearestRelaxationCell, end do !$acc end parallel - MPAS_ACC_TIMER_START('atm_zero_gradient_w_bdy_work [ACC_data_xfer]') - !$acc exit data copyout(w) - MPAS_ACC_TIMER_STOP('atm_zero_gradient_w_bdy_work [ACC_data_xfer]') end subroutine atm_zero_gradient_w_bdy_work @@ -7302,11 +8028,6 @@ subroutine atm_bdy_adjust_dynamics_speczone_tend( tend, mesh, config, nVertLevel call mpas_pool_get_array(mesh, 'bdyMaskEdge', bdyMaskEdge) call mpas_pool_get_array(tend, 'rt_diabatic_tend', rt_diabatic_tend) - MPAS_ACC_TIMER_START('atm_bdy_adjust_dynamics_speczone_tend [ACC_data_xfer]') - !$acc enter data copyin(tend_ru,tend_rho,tend_rt,tend_rw, & - !$acc rt_diabatic_tend) - MPAS_ACC_TIMER_STOP('atm_bdy_adjust_dynamics_speczone_tend [ACC_data_xfer]') - !$acc parallel default(present) !$acc loop gang worker do iCell = cellSolveStart, cellSolveEnd @@ -7333,11 +8054,6 @@ subroutine atm_bdy_adjust_dynamics_speczone_tend( tend, mesh, config, nVertLevel end if end do !$acc end parallel - - MPAS_ACC_TIMER_START('atm_bdy_adjust_dynamics_speczone_tend [ACC_data_xfer]') - !$acc exit data copyout(tend_ru,tend_rho,tend_rt, & - !$acc tend_rw,rt_diabatic_tend) - MPAS_ACC_TIMER_STOP('atm_bdy_adjust_dynamics_speczone_tend [ACC_data_xfer]') end subroutine atm_bdy_adjust_dynamics_speczone_tend @@ -7424,7 +8140,6 @@ subroutine atm_bdy_adjust_dynamics_relaxzone_tend( config, tend, state, diag, me vertexDegree = vertexDegree_ptr MPAS_ACC_TIMER_START('atm_bdy_adjust_dynamics_relaxzone_tend 
[ACC_data_xfer]') - !$acc enter data copyin(tend_rho, tend_rt, rho_zz, theta_m, tend_ru, ru) !$acc enter data create(divergence1, divergence2, vorticity1, vorticity2) MPAS_ACC_TIMER_STOP('atm_bdy_adjust_dynamics_relaxzone_tend [ACC_data_xfer]') @@ -7572,9 +8287,7 @@ subroutine atm_bdy_adjust_dynamics_relaxzone_tend( config, tend, state, diag, me !$acc end parallel MPAS_ACC_TIMER_START('atm_bdy_adjust_dynamics_relaxzone_tend [ACC_data_xfer]') - !$acc exit data copyout(tend_rho, tend_rt, tend_ru) - !$acc exit data delete(rho_zz, theta_m, ru, & - !$acc divergence1, divergence2, vorticity1, vorticity2) + !$acc exit data delete(divergence1, divergence2, vorticity1, vorticity2) MPAS_ACC_TIMER_STOP('atm_bdy_adjust_dynamics_relaxzone_tend [ACC_data_xfer]') end subroutine atm_bdy_adjust_dynamics_relaxzone_tend @@ -7609,10 +8322,6 @@ subroutine atm_bdy_reset_speczone_values( state, diag, mesh, nVertLevels, & call mpas_pool_get_array(state, 'theta_m', theta_m, 2) call mpas_pool_get_array(diag, 'rtheta_p', rtheta_p) call mpas_pool_get_array(diag, 'rtheta_base', rtheta_base) - - MPAS_ACC_TIMER_START('atm_bdy_reset_speczone_values [ACC_data_xfer]') - !$acc enter data copyin(rtheta_base, theta_m, rtheta_p) - MPAS_ACC_TIMER_STOP('atm_bdy_reset_speczone_values [ACC_data_xfer]') !$acc parallel default(present) !$acc loop gang worker @@ -7627,11 +8336,6 @@ subroutine atm_bdy_reset_speczone_values( state, diag, mesh, nVertLevels, & end do !$acc end parallel - MPAS_ACC_TIMER_START('atm_bdy_reset_speczone_values [ACC_data_xfer]') - !$acc exit data copyout(theta_m, rtheta_p) & - !$acc delete(rtheta_base) - MPAS_ACC_TIMER_STOP('atm_bdy_reset_speczone_values [ACC_data_xfer]') - end subroutine atm_bdy_reset_speczone_values !------------------------------------------------------------------------- @@ -7721,8 +8425,7 @@ subroutine atm_bdy_adjust_scalars_work( scalars_new, scalars_driving, dt, dt_rk, !--- MPAS_ACC_TIMER_START('atm_bdy_adjust_scalars [ACC_data_xfer]') - !$acc enter data create(scalars_tmp) & - !$acc copyin(scalars_new) + !$acc enter data create(scalars_tmp) MPAS_ACC_TIMER_STOP('atm_bdy_adjust_scalars [ACC_data_xfer]') !$acc parallel default(present) @@ -7806,8 +8509,7 @@ subroutine atm_bdy_adjust_scalars_work( scalars_new, scalars_driving, dt, dt_rk, !$acc end parallel MPAS_ACC_TIMER_START('atm_bdy_adjust_scalars [ACC_data_xfer]') - !$acc exit data delete(scalars_tmp) & - !$acc copyout(scalars_new) + !$acc exit data delete(scalars_tmp) MPAS_ACC_TIMER_STOP('atm_bdy_adjust_scalars [ACC_data_xfer]') end subroutine atm_bdy_adjust_scalars_work @@ -7878,10 +8580,6 @@ subroutine atm_bdy_set_scalars_work( scalars_driving, scalars_new, & !--- - MPAS_ACC_TIMER_START('atm_bdy_set_scalars_work [ACC_data_xfer]') - !$acc enter data copyin(scalars_new) - MPAS_ACC_TIMER_STOP('atm_bdy_set_scalars_work [ACC_data_xfer]') - !$acc parallel default(present) !$acc loop gang worker do iCell = cellSolveStart, cellSolveEnd ! threaded over cells @@ -7902,10 +8600,6 @@ subroutine atm_bdy_set_scalars_work( scalars_driving, scalars_new, & end do ! 
updates now in temp storage !$acc end parallel - - MPAS_ACC_TIMER_START('atm_bdy_set_scalars_work [ACC_data_xfer]') - !$acc exit data copyout(scalars_new) - MPAS_ACC_TIMER_STOP('atm_bdy_set_scalars_work [ACC_data_xfer]') end subroutine atm_bdy_set_scalars_work @@ -7975,16 +8669,6 @@ subroutine summarize_timestep(domain) nVertLevels = nVertLevels_ptr num_scalars = num_scalars_ptr - MPAS_ACC_TIMER_START('summarize_timestep [ACC_data_xfer]') - if (config_print_detailed_minmax_vel) then - !$acc enter data copyin(w,u,v) - else if (config_print_global_minmax_vel) then - !$acc enter data copyin(w,u) - end if - if (config_print_global_minmax_sca) then - !$acc enter data copyin(scalars) - end if - MPAS_ACC_TIMER_STOP('summarize_timestep [ACC_data_xfer]') if (config_print_detailed_minmax_vel) then call mpas_log_write('') @@ -8343,17 +9027,6 @@ subroutine summarize_timestep(domain) end if - MPAS_ACC_TIMER_START('summarize_timestep [ACC_data_xfer]') - if (config_print_detailed_minmax_vel) then - !$acc exit data delete(w,u,v) - else if (config_print_global_minmax_vel) then - !$acc exit data delete(w,u) - end if - if (config_print_global_minmax_sca) then - !$acc exit data delete(scalars) - end if - MPAS_ACC_TIMER_STOP('summarize_timestep [ACC_data_xfer]') - end subroutine summarize_timestep end module atm_time_integration diff --git a/src/core_atmosphere/mpas_atm_core.F b/src/core_atmosphere/mpas_atm_core.F index f7d04a1f0c..d1b9931c6c 100644 --- a/src/core_atmosphere/mpas_atm_core.F +++ b/src/core_atmosphere/mpas_atm_core.F @@ -43,7 +43,8 @@ function atm_core_init(domain, startTimeStamp) result(ierr) use mpas_atm_dimensions, only : mpas_atm_set_dims use mpas_atm_diagnostics_manager, only : mpas_atm_diag_setup use mpas_atm_threading, only : mpas_atm_threading_init - use atm_time_integration, only : mpas_atm_dynamics_init + use atm_time_integration, only : mpas_atm_dynamics_init, & + mpas_atm_pre_dynamics_h2d, mpas_atm_post_dynamics_d2h use mpas_timer, only : mpas_timer_start, mpas_timer_stop use mpas_attlist, only : mpas_modify_att use mpas_string_utils, only : mpas_string_replace @@ -509,6 +510,7 @@ subroutine atm_mpas_init_block(dminfo, stream_manager, block, mesh, dt) call mpas_pool_get_dimension(block % dimensions, 'edgeSolveThreadStart', edgeSolveThreadStart) call mpas_pool_get_dimension(block % dimensions, 'edgeSolveThreadEnd', edgeSolveThreadEnd) + call mpas_atm_pre_computesolvediag_h2d(block) !$OMP PARALLEL DO do thread=1,nThreads if (.not. config_do_restart .or. (config_do_restart .and. 
config_do_DAcycling)) then @@ -527,6 +529,7 @@ subroutine atm_mpas_init_block(dminfo, stream_manager, block, mesh, dt) edgeThreadStart(thread), edgeThreadEnd(thread)) end do !$OMP END PARALLEL DO + call mpas_atm_post_computesolvediag_d2h(block) deallocate(ke_vertex) deallocate(ke_edge) diff --git a/src/core_atmosphere/physics/mpas_atmphys_interface.F b/src/core_atmosphere/physics/mpas_atmphys_interface.F index 71e46dfcd2..afd8ed2810 100644 --- a/src/core_atmosphere/physics/mpas_atmphys_interface.F +++ b/src/core_atmosphere/physics/mpas_atmphys_interface.F @@ -12,6 +12,16 @@ module mpas_atmphys_interface use mpas_atmphys_constants use mpas_atmphys_vars + use mpas_timer + +#ifdef MPAS_OPENACC +#define MPAS_ACC_TIMER_START(X) call mpas_timer_start(X) +#define MPAS_ACC_TIMER_STOP(X) call mpas_timer_stop(X) +#else +#define MPAS_ACC_TIMER_START(X) +#define MPAS_ACC_TIMER_STOP(X) +#endif + implicit none private @@ -588,6 +598,7 @@ subroutine microphysics_from_MPAS(configs,mesh,state,time_lev,diag,diag_physics, call mpas_pool_get_array(mesh,'zgrid',zgrid) call mpas_pool_get_array(mesh,'zz' ,zz ) + MPAS_ACC_TIMER_START('atm_srk3: physics ACC_data_xfer') call mpas_pool_get_array(diag,'exner' ,exner ) call mpas_pool_get_array(diag,'pressure_base',pressure_b) call mpas_pool_get_array(diag,'pressure_p' ,pressure_p) @@ -595,11 +606,14 @@ subroutine microphysics_from_MPAS(configs,mesh,state,time_lev,diag,diag_physics, call mpas_pool_get_array(state,'rho_zz' ,rho_zz ,time_lev) call mpas_pool_get_array(state,'theta_m',theta_m,time_lev) call mpas_pool_get_array(state,'w' ,w ,time_lev) + !$acc update host(exner, pressure_b, pressure_p, rho_zz, theta_m, w) call mpas_pool_get_dimension(state,'index_qv',index_qv) call mpas_pool_get_dimension(state,'index_qc',index_qc) call mpas_pool_get_dimension(state,'index_qr',index_qr) call mpas_pool_get_array(state,'scalars',scalars,time_lev) + !$acc update host(scalars) + MPAS_ACC_TIMER_STOP('atm_srk3: physics ACC_data_xfer') qv => scalars(index_qv,:,:) qc => scalars(index_qc,:,:) qr => scalars(index_qr,:,:) @@ -1040,6 +1054,12 @@ subroutine microphysics_to_MPAS(configs,mesh,state,time_lev,diag,diag_physics,te case default end select mp_tend_select + MPAS_ACC_TIMER_START('atm_srk3: physics ACC_data_xfer') + !$acc update device(exner, exner_b, pressure_b, pressure_p, rtheta_b) + !$acc update device(rtheta_p, rho_zz, theta_m, scalars) + !$acc update device(rt_diabatic_tend) + MPAS_ACC_TIMER_STOP('atm_srk3: physics ACC_data_xfer') + end subroutine microphysics_to_MPAS !================================================================================================================= diff --git a/src/core_atmosphere/physics/mpas_atmphys_todynamics.F b/src/core_atmosphere/physics/mpas_atmphys_todynamics.F index 284b072851..290cc56330 100644 --- a/src/core_atmosphere/physics/mpas_atmphys_todynamics.F +++ b/src/core_atmosphere/physics/mpas_atmphys_todynamics.F @@ -13,11 +13,20 @@ module mpas_atmphys_todynamics use mpas_atm_dimensions use mpas_atmphys_constants, only: R_d,R_v,degrad + use mpas_timer implicit none private public:: physics_get_tend +#ifdef MPAS_OPENACC +#define MPAS_ACC_TIMER_START(X) call mpas_timer_start(X) +#define MPAS_ACC_TIMER_STOP(X) call mpas_timer_stop(X) +#else +#define MPAS_ACC_TIMER_START(X) +#define MPAS_ACC_TIMER_STOP(X) +#endif + !Interface between the physics parameterizations and the non-hydrostatic dynamical core. !Laura D. Fowler (send comments to laura@ucar.edu). 
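!
! The hunk below applies the host/device refresh idiom used throughout this
! PR: before a CPU-only region reads fields last written by the GPU dynamics,
! '!$acc update self' refreshes the host copies, and after the CPU region has
! written its results, '!$acc update device' pushes them back to the device.
! A minimal sketch of the idiom, assuming a device-resident field; the field
! name and the host-side work are illustrative, not taken from this patch:
!
      subroutine cpu_region_with_refresh(diag)
         use mpas_derived_types, only : mpas_pool_type
         use mpas_pool_routines, only : mpas_pool_get_array
         use mpas_kind_types, only : RKIND
         implicit none
         type (mpas_pool_type), intent(in) :: diag
         real (kind=RKIND), dimension(:,:), pointer :: some_field

         call mpas_pool_get_array(diag, 'some_field', some_field)
         !$acc update self(some_field)                       ! device -> host before the CPU code reads
         some_field(:,:) = max(some_field(:,:), 0.0_RKIND)   ! stand-in for CPU-only physics work
         !$acc update device(some_field)                     ! host -> device before GPU kernels resume
      end subroutine cpu_region_with_refresh
!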
@@ -127,12 +136,14 @@ subroutine physics_get_tend(block,mesh,state,diag,tend,tend_physics,configs,rk_s call mpas_pool_get_config(configs,'config_radt_lw_scheme' ,radt_lw_scheme ) call mpas_pool_get_config(configs,'config_radt_sw_scheme' ,radt_sw_scheme ) + MPAS_ACC_TIMER_START('atm_srk3: physics ACC_data_xfer') call mpas_pool_get_array(state,'theta_m' ,theta_m,1) call mpas_pool_get_array(state,'scalars' ,scalars,1) call mpas_pool_get_array(state,'rho_zz' ,mass,2 ) call mpas_pool_get_array(diag ,'rho_edge',mass_edge) call mpas_pool_get_array(diag ,'tend_u_phys',tend_u_phys) + !$acc update self(theta_m, scalars, mass, mass_edge) call mpas_pool_get_dimension(state,'index_qv',index_qv) call mpas_pool_get_dimension(state,'index_qc',index_qc) call mpas_pool_get_dimension(state,'index_qr',index_qr) @@ -170,6 +181,8 @@ subroutine physics_get_tend(block,mesh,state,diag,tend,tend_physics,configs,rk_s call mpas_pool_get_array(tend_physics,'rthratensw',rthratensw) call mpas_pool_get_array(tend,'scalars_tend',tend_scalars) +!$acc update self(tend_scalars) ! Probably not needed +MPAS_ACC_TIMER_STOP('atm_srk3: physics ACC_data_xfer') !initialize the tendency for the potential temperature and all scalars due to PBL, convection, @@ -219,6 +232,10 @@ subroutine physics_get_tend(block,mesh,state,diag,tend,tend_physics,configs,rk_s tend_th,tend_rtheta_physics,tend_scalars,tend_ru_physics,tend_u_phys, & exchange_halo_group) +MPAS_ACC_TIMER_START('atm_srk3: physics ACC_data_xfer') +!$acc update device(tend_scalars) +MPAS_ACC_TIMER_STOP('atm_srk3: physics ACC_data_xfer') + !clean up any pointers that were allocated with zero size before the call to physics_get_tend_work: if(size(rucuten) == 0 ) deallocate(rucuten ) if(size(rvcuten) == 0 ) deallocate(rvcuten ) From ac59b66b8d2b13e9603f6bd1870486651ffcf0c1 Mon Sep 17 00:00:00 2001 From: Abishek Gopal Date: Thu, 22 May 2025 13:56:08 -0600 Subject: [PATCH 02/30] Fixing bug associated with rho_zz_2 not being copied out at the end of dynamics --- src/core_atmosphere/dynamics/mpas_atm_time_integration.F | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F index de3565637b..18dbe434ea 100644 --- a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F +++ b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F @@ -1281,7 +1281,7 @@ subroutine mpas_atm_post_dynamics_d2h(domain) call mpas_pool_get_array(state, 'rho_zz', rho_zz_1, 1) !$acc exit data copyout(rho_zz_1) call mpas_pool_get_array(state, 'rho_zz', rho_zz_2, 2) - !$acc exit data delete(rho_zz_2) + !$acc exit data copyout(rho_zz_2) call mpas_pool_get_array(state, 'scalars', scalars_1, 1) !$acc exit data copyout(scalars_1) call mpas_pool_get_array(state, 'scalars', scalars_2, 2) From 2ccf89d5137f5b3a869f34f4cc49ccc024332267 Mon Sep 17 00:00:00 2001 From: Abishek Gopal Date: Fri, 13 Jun 2025 18:19:15 -0600 Subject: [PATCH 03/30] Moving some OpenACC data movements to subroutines --- src/core_atmosphere/dynamics/mpas_atm_iau.F | 57 ++++++++--- .../dynamics/mpas_atm_time_integration.F | 7 +- .../physics/mpas_atmphys_interface.F | 96 ++++++++++++++++--- .../physics/mpas_atmphys_todynamics.F | 81 ++++++++++++---- 4 files changed, 194 insertions(+), 47 deletions(-) diff --git a/src/core_atmosphere/dynamics/mpas_atm_iau.F b/src/core_atmosphere/dynamics/mpas_atm_iau.F index b380e3c0e8..d5999e18c7 100644 --- a/src/core_atmosphere/dynamics/mpas_atm_iau.F +++ 
b/src/core_atmosphere/dynamics/mpas_atm_iau.F @@ -5,6 +5,15 @@ ! Additional copyright and license information can be found in the LICENSE file ! distributed with this code, or at http://mpas-dev.github.com/license.html ! + + #ifdef MPAS_OPENACC + #define MPAS_ACC_TIMER_START(X) call mpas_timer_start(X) + #define MPAS_ACC_TIMER_STOP(X) call mpas_timer_stop(X) + #else + #define MPAS_ACC_TIMER_START(X) + #define MPAS_ACC_TIMER_STOP(X) + #endif + module mpas_atm_iau use mpas_derived_types @@ -15,17 +24,7 @@ module mpas_atm_iau use mpas_log, only : mpas_log_write use mpas_timer - !public :: atm_compute_iau_coef, atm_add_tend_anal_incr - - - #ifdef MPAS_OPENACC - #define MPAS_ACC_TIMER_START(X) call mpas_timer_start(X) - #define MPAS_ACC_TIMER_STOP(X) call mpas_timer_stop(X) - #else - #define MPAS_ACC_TIMER_START(X) - #define MPAS_ACC_TIMER_STOP(X) - #endif - + !public :: atm_compute_iau_coef, atm_add_tend_anal_incr contains @@ -87,6 +86,39 @@ real (kind=RKIND) function atm_iau_coef(configs, itimestep, dt) result(wgt_iau) end if end function atm_iau_coef + +!================================================================================================== + subroutine update_d2h_pre_add_tend_anal_incr(configs,structs) +!================================================================================================== + + implicit none + + type (mpas_pool_type), intent(in) :: configs + type (mpas_pool_type), intent(inout) :: structs + + type (mpas_pool_type), pointer :: tend + type (mpas_pool_type), pointer :: state + type (mpas_pool_type), pointer :: diag + + real (kind=RKIND), dimension(:,:), pointer :: rho_edge, rho_zz, theta_m + real(kind=RKIND),dimension(:,:,:), pointer :: scalars, tend_scalars + + call mpas_pool_get_subpool(structs, 'tend', tend) + call mpas_pool_get_subpool(structs, 'state', state) + call mpas_pool_get_subpool(structs, 'diag', diag) + + MPAS_ACC_TIMER_START('atm_srk3: physics ACC_data_xfer') + call mpas_pool_get_array(state, 'theta_m', theta_m, 1) + call mpas_pool_get_array(state, 'scalars', scalars, 1) + call mpas_pool_get_array(state, 'rho_zz', rho_zz, 2) + call mpas_pool_get_array(diag , 'rho_edge', rho_edge) + !$acc update self(theta_m, scalars, rho_zz, rho_edge) + + call mpas_pool_get_array(tend, 'scalars_tend', tend_scalars) + !$acc update self(tend_scalars) + MPAS_ACC_TIMER_STOP('atm_srk3: physics ACC_data_xfer') + + end subroutine update_d2h_pre_add_tend_anal_incr !================================================================================================== subroutine atm_add_tend_anal_incr (configs, structs, itimestep, dt, tend_ru, tend_rtheta, tend_rho) @@ -148,7 +180,6 @@ subroutine atm_add_tend_anal_incr (configs, structs, itimestep, dt, tend_ru, ten call mpas_pool_get_array(state, 'scalars', scalars, 1) call mpas_pool_get_array(state, 'rho_zz', rho_zz, 2) call mpas_pool_get_array(diag , 'rho_edge', rho_edge) - !$acc update self(theta_m, scalars, rho_zz, rho_edge) call mpas_pool_get_dimension(state, 'moist_start', moist_start) call mpas_pool_get_dimension(state, 'moist_end', moist_end) @@ -161,8 +192,6 @@ subroutine atm_add_tend_anal_incr (configs, structs, itimestep, dt, tend_ru, ten ! call mpas_pool_get_array(tend, 'rho_zz', tend_rho) ! 
call mpas_pool_get_array(tend, 'theta_m', tend_theta) call mpas_pool_get_array(tend, 'scalars_tend', tend_scalars) - !$acc update self(tend_scalars) - MPAS_ACC_TIMER_STOP('atm_srk3: physics ACC_data_xfer') call mpas_pool_get_array(tend_iau, 'theta', theta_amb) call mpas_pool_get_array(tend_iau, 'rho', rho_amb) diff --git a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F index 18dbe434ea..4bcc59cb66 100644 --- a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F +++ b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F @@ -29,6 +29,7 @@ module atm_time_integration #ifdef DO_PHYSICS use mpas_atmphys_driver_microphysics + use mpas_atmphys_interface, only: update_d2h_pre_microphysics, update_h2d_post_microphysics use mpas_atmphys_todynamics use mpas_atmphys_utilities #endif @@ -1985,6 +1986,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) call mpas_timer_stop('atm_compute_moist_coefficients') #ifdef DO_PHYSICS + call update_d2h_pre_physics_get_tend(block % configs, state, diag, tend) call mpas_timer_start('physics_get_tend') rk_step = 1 dynamics_substep = 1 @@ -1993,6 +1995,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) tend_ru_physics, tend_rtheta_physics, tend_rho_physics, & exchange_halo_group ) call mpas_timer_stop('physics_get_tend') + call update_h2d_post_physics_get_tend(block % configs, state, diag, tend) #else #ifndef MPAS_CAM_DYCORE ! @@ -2008,6 +2011,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) ! IAU - Incremental Analysis Update ! if (trim(config_IAU_option) /= 'off') then + call update_d2h_pre_add_tend_anal_incr(block % configs, block % structs) call atm_add_tend_anal_incr(block % configs, block % structs, itimestep, dt, & tend_ru_physics, tend_rtheta_physics, tend_rho_physics) end if @@ -2614,7 +2618,6 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) !$acc update self(rthdynten) MPAS_ACC_TIMER_STOP('atm_srk3: physics ACC_data_xfer') - !NOTE: The calculation of the tendency due to horizontal and vertical advection for the water vapor mixing ratio !requires that the subroutine atm_advance_scalars_mono was called on the third Runge Kutta step, so that a halo !update for the scalars at time_levs(1) is applied. A halo update for the scalars at time_levs(2) is done above. @@ -2643,6 +2646,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_STOP('atm_srk3: physics ACC_data_xfer') !call microphysics schemes: if (trim(config_microp_scheme) /= 'off') then + call update_d2h_pre_microphysics( block % configs, state, diag, 2) call mpas_timer_start('microphysics') !$OMP PARALLEL DO do thread=1,nThreads @@ -2651,6 +2655,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) end do !$OMP END PARALLEL DO call mpas_timer_stop('microphysics') + call update_h2d_post_microphysics( block % configs, state, diag, tend, 2) end if ! diff --git a/src/core_atmosphere/physics/mpas_atmphys_interface.F b/src/core_atmosphere/physics/mpas_atmphys_interface.F index afd8ed2810..4e65cffd7a 100644 --- a/src/core_atmosphere/physics/mpas_atmphys_interface.F +++ b/src/core_atmosphere/physics/mpas_atmphys_interface.F @@ -6,13 +6,6 @@ ! distributed with this code, or at http://mpas-dev.github.com/license.html ! 
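!
! Note on the relocated timer-macro block below: with MPAS_OPENACC defined,
! the preprocessor turns each macro into a plain timer call, so a line such as
!
!     MPAS_ACC_TIMER_START('atm_srk3: physics ACC_data_xfer')
!
! compiles as
!
!     call mpas_timer_start('atm_srk3: physics ACC_data_xfer')
!
! and expands to nothing when MPAS_OPENACC is not defined, so CPU-only builds
! compile the same code paths with no timer overhead around the transfers.
!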
!================================================================================================================= - module mpas_atmphys_interface - use mpas_kind_types - use mpas_pool_routines - - use mpas_atmphys_constants - use mpas_atmphys_vars - use mpas_timer #ifdef MPAS_OPENACC #define MPAS_ACC_TIMER_START(X) call mpas_timer_start(X) @@ -22,6 +15,13 @@ module mpas_atmphys_interface #define MPAS_ACC_TIMER_STOP(X) #endif + module mpas_atmphys_interface + use mpas_kind_types + use mpas_pool_routines + + use mpas_atmphys_constants + use mpas_atmphys_vars + use mpas_timer implicit none private @@ -555,6 +555,40 @@ subroutine MPAS_to_physics(configs,mesh,state,time_lev,diag,diag_physics,its,ite end subroutine MPAS_to_physics +!================================================================================================================= + subroutine update_d2h_pre_microphysics(configs,state,diag,time_lev) +!================================================================================================================= + +!input variables: + type(mpas_pool_type),intent(in):: configs + type(mpas_pool_type),intent(in):: state + type(mpas_pool_type),intent(in):: diag + + integer:: time_lev + +!local pointers: + real(kind=RKIND),dimension(:,:),pointer :: exner,pressure_b,w + real(kind=RKIND),dimension(:,:),pointer :: rho_zz,theta_m,pressure_p + real(kind=RKIND),dimension(:,:,:),pointer:: scalars + + + MPAS_ACC_TIMER_START('update_d2h_pre_microphysics [ACC_data_xfer]') + call mpas_pool_get_array(diag,'exner' ,exner ) + call mpas_pool_get_array(diag,'pressure_base',pressure_b) + call mpas_pool_get_array(diag,'pressure_p' ,pressure_p) + + call mpas_pool_get_array(state,'rho_zz' ,rho_zz ,time_lev) + call mpas_pool_get_array(state,'theta_m',theta_m,time_lev) + call mpas_pool_get_array(state,'w' ,w ,time_lev) + !$acc update host(exner, pressure_b, pressure_p, rho_zz, theta_m, w) + + call mpas_pool_get_array(state,'scalars',scalars,time_lev) + !$acc update host(scalars) + + MPAS_ACC_TIMER_STOP('update_d2h_pre_microphysics [ACC_data_xfer]') + +end subroutine update_d2h_pre_microphysics + !================================================================================================================= subroutine microphysics_from_MPAS(configs,mesh,state,time_lev,diag,diag_physics,tend_physics,its,ite) !================================================================================================================= @@ -598,7 +632,6 @@ subroutine microphysics_from_MPAS(configs,mesh,state,time_lev,diag,diag_physics, call mpas_pool_get_array(mesh,'zgrid',zgrid) call mpas_pool_get_array(mesh,'zz' ,zz ) - MPAS_ACC_TIMER_START('atm_srk3: physics ACC_data_xfer') call mpas_pool_get_array(diag,'exner' ,exner ) call mpas_pool_get_array(diag,'pressure_base',pressure_b) call mpas_pool_get_array(diag,'pressure_p' ,pressure_p) @@ -606,14 +639,11 @@ subroutine microphysics_from_MPAS(configs,mesh,state,time_lev,diag,diag_physics, call mpas_pool_get_array(state,'rho_zz' ,rho_zz ,time_lev) call mpas_pool_get_array(state,'theta_m',theta_m,time_lev) call mpas_pool_get_array(state,'w' ,w ,time_lev) - !$acc update host(exner, pressure_b, pressure_p, rho_zz, theta_m, w) call mpas_pool_get_dimension(state,'index_qv',index_qv) call mpas_pool_get_dimension(state,'index_qc',index_qc) call mpas_pool_get_dimension(state,'index_qr',index_qr) - call mpas_pool_get_array(state,'scalars',scalars,time_lev) - !$acc update host(scalars) - MPAS_ACC_TIMER_STOP('atm_srk3: physics ACC_data_xfer') + call 
mpas_pool_get_array(state,'scalars',scalars,time_lev) qv => scalars(index_qv,:,:) qc => scalars(index_qc,:,:) qr => scalars(index_qr,:,:) @@ -1054,13 +1084,49 @@ subroutine microphysics_to_MPAS(configs,mesh,state,time_lev,diag,diag_physics,te case default end select mp_tend_select - MPAS_ACC_TIMER_START('atm_srk3: physics ACC_data_xfer') + end subroutine microphysics_to_MPAS + + !================================================================================================================= + subroutine update_h2d_post_microphysics(configs,state,diag,tend,time_lev) +!================================================================================================================= + +!input variables: + type(mpas_pool_type),intent(in):: configs + type(mpas_pool_type),intent(in):: state + type(mpas_pool_type),intent(in):: diag + type(mpas_pool_type),intent(inout):: tend + + + integer:: time_lev + +!local pointers: + real(kind=RKIND),dimension(:,:),pointer :: exner,exner_b,pressure_b,rtheta_p,rtheta_b + real(kind=RKIND),dimension(:,:),pointer :: rho_zz,theta_m,pressure_p + real(kind=RKIND),dimension(:,:,:),pointer:: scalars + real(kind=RKIND),dimension(:,:),pointer :: rt_diabatic_tend + + call mpas_pool_get_array(diag,'exner' ,exner ) + call mpas_pool_get_array(diag,'exner_base' ,exner_b ) + call mpas_pool_get_array(diag,'pressure_base',pressure_b) + call mpas_pool_get_array(diag,'pressure_p' ,pressure_p) + call mpas_pool_get_array(diag,'rtheta_base' ,rtheta_b ) + call mpas_pool_get_array(diag,'rtheta_p' ,rtheta_p ) + + call mpas_pool_get_array(state,'rho_zz' ,rho_zz ,time_lev) + call mpas_pool_get_array(state,'theta_m',theta_m,time_lev) + + call mpas_pool_get_array(state,'scalars',scalars,time_lev) + + call mpas_pool_get_array(tend,'rt_diabatic_tend',rt_diabatic_tend) + + + MPAS_ACC_TIMER_START('update_h2d_post_microphysics [ACC_data_xfer]') !$acc update device(exner, exner_b, pressure_b, pressure_p, rtheta_b) !$acc update device(rtheta_p, rho_zz, theta_m, scalars) !$acc update device(rt_diabatic_tend) - MPAS_ACC_TIMER_STOP('atm_srk3: physics ACC_data_xfer') + MPAS_ACC_TIMER_STOP('update_h2d_post_microphysics [ACC_data_xfer]') - end subroutine microphysics_to_MPAS +end subroutine update_h2d_post_microphysics !================================================================================================================= end module mpas_atmphys_interface diff --git a/src/core_atmosphere/physics/mpas_atmphys_todynamics.F b/src/core_atmosphere/physics/mpas_atmphys_todynamics.F index 290cc56330..2cb94a7ba5 100644 --- a/src/core_atmosphere/physics/mpas_atmphys_todynamics.F +++ b/src/core_atmosphere/physics/mpas_atmphys_todynamics.F @@ -6,6 +6,15 @@ ! distributed with this code, or at http://mpas-dev.github.com/license.html ! 
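!
! The two wrapper routines added below are paired with physics_get_tend at its
! call site in atm_srk3 (see the mpas_atm_time_integration.F hunk earlier in
! this commit). A sketch of the resulting calling sequence, condensed from
! that diff (argument list abbreviated):
!
!     call update_d2h_pre_physics_get_tend(block % configs, state, diag, tend)
!     call mpas_timer_start('physics_get_tend')
!     call physics_get_tend(..., tend_ru_physics, tend_rtheta_physics, &
!                           tend_rho_physics, exchange_halo_group)   ! CPU physics against host copies
!     call mpas_timer_stop('physics_get_tend')
!     call update_h2d_post_physics_get_tend(block % configs, state, diag, tend)
!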
!================================================================================================================= + +#ifdef MPAS_OPENACC +#define MPAS_ACC_TIMER_START(X) call mpas_timer_start(X) +#define MPAS_ACC_TIMER_STOP(X) call mpas_timer_stop(X) +#else +#define MPAS_ACC_TIMER_START(X) +#define MPAS_ACC_TIMER_STOP(X) +#endif + module mpas_atmphys_todynamics use mpas_kind_types use mpas_pool_routines @@ -17,15 +26,7 @@ module mpas_atmphys_todynamics implicit none private - public:: physics_get_tend - -#ifdef MPAS_OPENACC -#define MPAS_ACC_TIMER_START(X) call mpas_timer_start(X) -#define MPAS_ACC_TIMER_STOP(X) call mpas_timer_stop(X) -#else -#define MPAS_ACC_TIMER_START(X) -#define MPAS_ACC_TIMER_STOP(X) -#endif + public:: physics_get_tend, update_d2h_pre_physics_get_tend, update_h2d_post_physics_get_tend !Interface between the physics parameterizations and the non-hydrostatic dynamical core. @@ -69,6 +70,40 @@ end subroutine halo_exchange_routine contains + +!================================================================================================================= + subroutine update_d2h_pre_physics_get_tend(configs,state,diag,tend) +!================================================================================================================= + +!input variables: + type(mpas_pool_type),intent(in):: configs + type(mpas_pool_type),intent(in):: state + type(mpas_pool_type),intent(in):: diag + type(mpas_pool_type),intent(in):: tend + +!local variables: + real(kind=RKIND),dimension(:,:),pointer:: mass ! time level 2 rho_zz + real(kind=RKIND),dimension(:,:),pointer:: mass_edge ! diag rho_edge + real(kind=RKIND),dimension(:,:),pointer:: theta_m ! time level 1 + real(kind=RKIND),dimension(:,:,:),pointer:: scalars + + real(kind=RKIND),dimension(:,:),pointer:: tend_u_phys + real(kind=RKIND),dimension(:,:,:),pointer:: tend_scalars + + MPAS_ACC_TIMER_START('atm_srk3: physics ACC_data_xfer') + call mpas_pool_get_array(state,'theta_m' ,theta_m,1) + call mpas_pool_get_array(state,'scalars' ,scalars,1) + call mpas_pool_get_array(state,'rho_zz' ,mass,2 ) + call mpas_pool_get_array(diag ,'rho_edge',mass_edge) + call mpas_pool_get_array(diag ,'tend_u_phys',tend_u_phys) + + !$acc update self(theta_m, scalars, mass, mass_edge) + + call mpas_pool_get_array(tend,'scalars_tend',tend_scalars) + !$acc update self(tend_scalars) ! 
Probably not needed + MPAS_ACC_TIMER_STOP('atm_srk3: physics ACC_data_xfer') + + end subroutine update_d2h_pre_physics_get_tend !================================================================================================================= subroutine physics_get_tend(block,mesh,state,diag,tend,tend_physics,configs,rk_step,dynamics_substep, & @@ -136,14 +171,12 @@ subroutine physics_get_tend(block,mesh,state,diag,tend,tend_physics,configs,rk_s call mpas_pool_get_config(configs,'config_radt_lw_scheme' ,radt_lw_scheme ) call mpas_pool_get_config(configs,'config_radt_sw_scheme' ,radt_sw_scheme ) - MPAS_ACC_TIMER_START('atm_srk3: physics ACC_data_xfer') call mpas_pool_get_array(state,'theta_m' ,theta_m,1) call mpas_pool_get_array(state,'scalars' ,scalars,1) call mpas_pool_get_array(state,'rho_zz' ,mass,2 ) call mpas_pool_get_array(diag ,'rho_edge',mass_edge) call mpas_pool_get_array(diag ,'tend_u_phys',tend_u_phys) - !$acc update self(theta_m, scalars, mass, mass_edge) call mpas_pool_get_dimension(state,'index_qv',index_qv) call mpas_pool_get_dimension(state,'index_qc',index_qc) call mpas_pool_get_dimension(state,'index_qr',index_qr) @@ -181,8 +214,6 @@ subroutine physics_get_tend(block,mesh,state,diag,tend,tend_physics,configs,rk_s call mpas_pool_get_array(tend_physics,'rthratensw',rthratensw) call mpas_pool_get_array(tend,'scalars_tend',tend_scalars) -!$acc update self(tend_scalars) ! Probably not needed -MPAS_ACC_TIMER_STOP('atm_srk3: physics ACC_data_xfer') !initialize the tendency for the potential temperature and all scalars due to PBL, convection, @@ -232,10 +263,6 @@ subroutine physics_get_tend(block,mesh,state,diag,tend,tend_physics,configs,rk_s tend_th,tend_rtheta_physics,tend_scalars,tend_ru_physics,tend_u_phys, & exchange_halo_group) -MPAS_ACC_TIMER_START('atm_srk3: physics ACC_data_xfer') -!$acc update device(tend_scalars) -MPAS_ACC_TIMER_STOP('atm_srk3: physics ACC_data_xfer') - !clean up any pointers that were allocated with zero size before the call to physics_get_tend_work: if(size(rucuten) == 0 ) deallocate(rucuten ) if(size(rvcuten) == 0 ) deallocate(rvcuten ) @@ -262,6 +289,26 @@ subroutine physics_get_tend(block,mesh,state,diag,tend,tend_physics,configs,rk_s end subroutine physics_get_tend + !================================================================================================================= + subroutine update_h2d_post_physics_get_tend(configs,state,diag,tend) +!================================================================================================================= + +!input variables: + type(mpas_pool_type),intent(in):: configs + type(mpas_pool_type),intent(in):: state + type(mpas_pool_type),intent(in):: diag + type(mpas_pool_type),intent(in):: tend + +!local variables: + real(kind=RKIND),dimension(:,:,:),pointer:: tend_scalars + + MPAS_ACC_TIMER_START('atm_srk3: physics ACC_data_xfer') + call mpas_pool_get_array(tend,'scalars_tend',tend_scalars) + !$acc update device(tend_scalars) + MPAS_ACC_TIMER_STOP('atm_srk3: physics ACC_data_xfer') + + end subroutine update_h2d_post_physics_get_tend + !================================================================================================================= subroutine physics_get_tend_work( & block,mesh,nCells,nEdges,nCellsSolve,nEdgesSolve,rk_step,dynamics_substep, & From f55d6509a8f30f4c7325a366460dbf7510e893e4 Mon Sep 17 00:00:00 2001 From: Abishek Gopal Date: Thu, 3 Jul 2025 17:10:06 -0600 Subject: [PATCH 04/30] Removing acc data xfer timers for device variables using create/delete --- 
.../dynamics/mpas_atm_time_integration.F | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F index 4bcc59cb66..1a221b22ef 100644 --- a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F +++ b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F @@ -3251,9 +3251,7 @@ subroutine atm_compute_vert_imp_coefs_work(nCells, moist_start, moist_end, dts, real (kind=RKIND) :: dtseps, c2, qtotal, rcv real (kind=RKIND), dimension( nVertLevels ) :: b_tri, c_tri - MPAS_ACC_TIMER_START('atm_compute_vert_imp_coefs_work [ACC_data_xfer]') !$acc enter data create(b_tri, c_tri) - MPAS_ACC_TIMER_STOP('atm_compute_vert_imp_coefs_work [ACC_data_xfer]') ! set coefficients dtseps = .5*dts*(1.+epssm) @@ -3333,9 +3331,7 @@ subroutine atm_compute_vert_imp_coefs_work(nCells, moist_start, moist_end, dts, end do ! loop over cells !$acc end parallel - MPAS_ACC_TIMER_START('atm_compute_vert_imp_coefs_work [ACC_data_xfer]') - !$acc exit data copyout(b_tri, c_tri) - MPAS_ACC_TIMER_STOP('atm_compute_vert_imp_coefs_work [ACC_data_xfer]') + !$acc exit data delete(b_tri, c_tri) end subroutine atm_compute_vert_imp_coefs_work @@ -8144,9 +8140,7 @@ subroutine atm_bdy_adjust_dynamics_relaxzone_tend( config, tend, state, diag, me divdamp_coef = divdamp_coef_ptr vertexDegree = vertexDegree_ptr - MPAS_ACC_TIMER_START('atm_bdy_adjust_dynamics_relaxzone_tend [ACC_data_xfer]') !$acc enter data create(divergence1, divergence2, vorticity1, vorticity2) - MPAS_ACC_TIMER_STOP('atm_bdy_adjust_dynamics_relaxzone_tend [ACC_data_xfer]') ! First, Rayleigh damping terms for ru, rtheta_m and rho_zz !$acc parallel default(present) @@ -8291,9 +8285,7 @@ subroutine atm_bdy_adjust_dynamics_relaxzone_tend( config, tend, state, diag, me end do ! 
end of loop over edges !$acc end parallel - MPAS_ACC_TIMER_START('atm_bdy_adjust_dynamics_relaxzone_tend [ACC_data_xfer]') !$acc exit data delete(divergence1, divergence2, vorticity1, vorticity2) - MPAS_ACC_TIMER_STOP('atm_bdy_adjust_dynamics_relaxzone_tend [ACC_data_xfer]') end subroutine atm_bdy_adjust_dynamics_relaxzone_tend @@ -8429,9 +8421,7 @@ subroutine atm_bdy_adjust_scalars_work( scalars_new, scalars_driving, dt, dt_rk, integer :: iCell, iEdge, iScalar, i, k, cell1, cell2 !--- - MPAS_ACC_TIMER_START('atm_bdy_adjust_scalars [ACC_data_xfer]') !$acc enter data create(scalars_tmp) - MPAS_ACC_TIMER_STOP('atm_bdy_adjust_scalars [ACC_data_xfer]') !$acc parallel default(present) !$acc loop gang worker @@ -8513,9 +8503,7 @@ subroutine atm_bdy_adjust_scalars_work( scalars_new, scalars_driving, dt, dt_rk, end do !$acc end parallel - MPAS_ACC_TIMER_START('atm_bdy_adjust_scalars [ACC_data_xfer]') !$acc exit data delete(scalars_tmp) - MPAS_ACC_TIMER_STOP('atm_bdy_adjust_scalars [ACC_data_xfer]') end subroutine atm_bdy_adjust_scalars_work From cf373f52fdb286304500ae54ce06e3e01583e394 Mon Sep 17 00:00:00 2001 From: Abishek Gopal Date: Mon, 7 Jul 2025 19:10:23 -0600 Subject: [PATCH 05/30] Using acc declare create for rho_zz_int and corresponding cleanup --- .../dynamics/mpas_atm_time_integration.F | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F index 1a221b22ef..b26c4344db 100644 --- a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F +++ b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F @@ -86,6 +86,7 @@ end subroutine halo_exchange_routine !$acc declare create(s_max_arr, s_min_arr) !$acc declare create(flux_array, flux_upwind_tmp_arr) !$acc declare create(flux_tmp_arr, wdtn_arr) + !$acc declare create(rho_zz_int) real (kind=RKIND), dimension(:,:), allocatable :: ru_driving_tend ! regional_MPAS addition real (kind=RKIND), dimension(:,:), allocatable :: rt_driving_tend ! regional_MPAS addition @@ -5011,12 +5012,6 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge ! The transport will maintain this positive definite solution and optionally, shape preservation (monotonicity). - MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]') - if (local_advance_density) then - !$acc enter data copyin(rho_zz_int) - end if - MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]') - !$acc parallel !$acc loop gang worker @@ -5607,9 +5602,6 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge end do ! 
loop over scalars MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]') - if (local_advance_density) then - !$acc exit data copyout(rho_zz_int) - end if !$acc exit data delete(scale_arr) MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]') From 34d4c8c389a4987e9eaf89e105d29e749e37f0a4 Mon Sep 17 00:00:00 2001 From: Abishek Gopal Date: Mon, 7 Jul 2025 19:12:52 -0600 Subject: [PATCH 06/30] Removing atm_advance_scalars_mono ACC_data_xfer timers around create/delete --- src/core_atmosphere/dynamics/mpas_atm_time_integration.F | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F index b26c4344db..6684f46a09 100644 --- a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F +++ b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F @@ -5096,9 +5096,7 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge end if - MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]') !$acc enter data create(scale_arr) - MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]') do iScalar = 1, num_scalars @@ -5601,9 +5599,7 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge end do ! loop over scalars - MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]') !$acc exit data delete(scale_arr) - MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]') end subroutine atm_advance_scalars_mono_work From 9d3c3fcd6398b1f523a5d502dd859c01e87467fc Mon Sep 17 00:00:00 2001 From: Abishek Gopal Date: Tue, 8 Jul 2025 16:31:09 -0600 Subject: [PATCH 07/30] Simplifying OpenACC data transfers around the call to mpas_reconstruct_2d This commit introduces two OpenACC data transfer routines, mpas_reconstruct_2d_h2d and mpas_reconstruct_2d_d2h in order to remove the data transfers from the mpas_reconstruct_2d routine itself. This also allows us to remove extraneous data movements within the atm_srk3 routine. mpas_reconstruct_2d_h2d and mpas_reconstruct_2d_d2h are called before and after the call to mpas_reconstruct in atm_mpas_init_block. And the reconstructed vector fields are also copied to and from the device before and after every dynamics call in mpas_atm_pre_dynamics_h2d and mpas_atm_post_dynamics_d2h. 
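As a rough sketch of the bracketing pattern this introduces (plain arrays
stand in for the MPAS pool fields here; these are not the actual
interfaces):

    program h2d_compute_d2h_sketch
       implicit none
       integer, parameter :: nVertLevels = 4, nCells = 8
       real, dimension(nVertLevels,nCells) :: uReconX
       integer :: k, iCell

       ! h2d step: allocate the output field on the device for the owned cells
       !$acc enter data create(uReconX(:,1:nCells))

       ! compute step: stands in for mpas_reconstruct running on device data
       !$acc parallel loop collapse(2) default(present)
       do iCell = 1, nCells
          do k = 1, nVertLevels
             uReconX(k,iCell) = real(k*iCell)
          end do
       end do

       ! d2h step: return the reconstructed field to the host
       !$acc exit data copyout(uReconX(:,1:nCells))
    end program h2d_compute_d2h_sketch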
--- .../dynamics/mpas_atm_time_integration.F | 40 ++++++ src/core_atmosphere/mpas_atm_core.F | 4 + src/operators/mpas_vector_reconstruction.F | 115 ++++++++++++++++-- 3 files changed, 148 insertions(+), 11 deletions(-) diff --git a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F index 6684f46a09..b24d4cebfd 100644 --- a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F +++ b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F @@ -846,6 +846,7 @@ subroutine mpas_atm_pre_dynamics_h2d(domain) #ifdef MPAS_OPENACC + type (mpas_pool_type), pointer :: mesh type (mpas_pool_type), pointer :: state type (mpas_pool_type), pointer :: diag type (mpas_pool_type), pointer :: tend @@ -879,6 +880,10 @@ subroutine mpas_atm_pre_dynamics_h2d(domain) real (kind=RKIND), dimension(:,:,:), pointer :: scalars_1, scalars_2 real (kind=RKIND), dimension(:,:), pointer :: ruAvg, wwAvg, ruAvg_split, wwAvg_split + integer, pointer :: nCells_ptr + integer :: nCells + real (kind=RKIND), dimension(:,:), pointer :: uReconstructZonal, uReconstructMeridional, uReconstructX, uReconstructY, uReconstructZ + real (kind=RKIND), dimension(:,:), pointer :: tend_ru, tend_rt, tend_rho, tend_rw, rt_diabatic_tend real (kind=RKIND), dimension(:,:), pointer :: tend_u_euler, tend_w_euler, tend_theta_euler real(kind=RKIND), dimension(:,:), pointer :: tend_w_pgf, tend_w_buoy @@ -892,11 +897,13 @@ subroutine mpas_atm_pre_dynamics_h2d(domain) real (kind=RKIND), dimension(:,:,:), pointer :: lbc_scalars, lbc_tend_scalars + nullify(mesh) nullify(state) nullify(diag) nullify(tend) nullify(tend_physics) nullify(lbc) + call mpas_pool_get_subpool(domain % blocklist % structs, 'mesh', mesh) call mpas_pool_get_subpool(domain % blocklist % structs, 'state', state) call mpas_pool_get_subpool(domain % blocklist % structs, 'diag', diag) call mpas_pool_get_subpool(domain % blocklist % structs, 'tend', tend) @@ -1006,6 +1013,19 @@ subroutine mpas_atm_pre_dynamics_h2d(domain) call mpas_pool_get_array(diag, 'wwAvg_split', wwAvg_split) !$acc enter data copyin(wwAvg_split) + call mpas_pool_get_dimension(mesh, 'nCellsSolve', nCells_ptr) + nCells = nCells_ptr + call mpas_pool_get_array(diag, 'uReconstructX', uReconstructX) + !$acc enter data create(uReconstructX(:,1:nCells)) + call mpas_pool_get_array(diag, 'uReconstructY', uReconstructY) + !$acc enter data create(uReconstructY(:,1:nCells)) + call mpas_pool_get_array(diag, 'uReconstructZ', uReconstructZ) + !$acc enter data create(uReconstructZ(:,1:nCells)) + call mpas_pool_get_array(diag, 'uReconstructZonal', uReconstructZonal) + !$acc enter data create(uReconstructZonal(:,1:nCells)) + call mpas_pool_get_array(diag, 'uReconstructMeridional', uReconstructMeridional) + !$acc enter data create(uReconstructMeridional(:,1:nCells)) + call mpas_pool_get_array(state, 'u', u_1, 1) !$acc enter data copyin(u_1) call mpas_pool_get_array(state, 'u', u_2, 2) @@ -1108,6 +1128,7 @@ subroutine mpas_atm_post_dynamics_d2h(domain) #ifdef MPAS_OPENACC + type (mpas_pool_type), pointer :: mesh type (mpas_pool_type), pointer :: state type (mpas_pool_type), pointer :: diag type (mpas_pool_type), pointer :: tend @@ -1141,6 +1162,10 @@ subroutine mpas_atm_post_dynamics_d2h(domain) real (kind=RKIND), dimension(:,:,:), pointer :: scalars_1, scalars_2 real (kind=RKIND), dimension(:,:), pointer :: ruAvg, wwAvg, ruAvg_split, wwAvg_split + integer, pointer :: nCells_ptr + integer :: nCells + real (kind=RKIND), dimension(:,:), pointer :: uReconstructZonal, 
uReconstructMeridional, uReconstructX, uReconstructY, uReconstructZ + real (kind=RKIND), dimension(:,:), pointer :: tend_ru, tend_rt, tend_rho, tend_rw, rt_diabatic_tend real (kind=RKIND), dimension(:,:), pointer :: tend_u_euler, tend_w_euler, tend_theta_euler real(kind=RKIND), dimension(:,:), pointer :: tend_w_pgf, tend_w_buoy @@ -1154,11 +1179,13 @@ subroutine mpas_atm_post_dynamics_d2h(domain) real (kind=RKIND), dimension(:,:,:), pointer :: lbc_scalars, lbc_tend_scalars + nullify(mesh) nullify(state) nullify(diag) nullify(tend) nullify(tend_physics) nullify(lbc) + call mpas_pool_get_subpool(domain % blocklist % structs, 'mesh', mesh) call mpas_pool_get_subpool(domain % blocklist % structs, 'state', state) call mpas_pool_get_subpool(domain % blocklist % structs, 'diag', diag) call mpas_pool_get_subpool(domain % blocklist % structs, 'tend', tend) @@ -1268,6 +1295,19 @@ subroutine mpas_atm_post_dynamics_d2h(domain) call mpas_pool_get_array(diag, 'wwAvg_split', wwAvg_split) !$acc exit data copyout(wwAvg_split) + call mpas_pool_get_dimension(mesh, 'nCellsSolve', nCells_ptr) + nCells = nCells_ptr + call mpas_pool_get_array(diag, 'uReconstructX', uReconstructX) + !$acc exit data copyout(uReconstructX(:,1:nCells)) + call mpas_pool_get_array(diag, 'uReconstructY', uReconstructY) + !$acc exit data copyout(uReconstructY(:,1:nCells)) + call mpas_pool_get_array(diag, 'uReconstructZ', uReconstructZ) + !$acc exit data copyout(uReconstructZ(:,1:nCells)) + call mpas_pool_get_array(diag, 'uReconstructZonal', uReconstructZonal) + !$acc exit data copyout(uReconstructZonal(:,1:nCells)) + call mpas_pool_get_array(diag, 'uReconstructMeridional', uReconstructMeridional) + !$acc exit data copyout(uReconstructMeridional(:,1:nCells)) + call mpas_pool_get_array(state, 'u', u_1, 1) !$acc exit data copyout(u_1) call mpas_pool_get_array(state, 'u', u_2, 2) diff --git a/src/core_atmosphere/mpas_atm_core.F b/src/core_atmosphere/mpas_atm_core.F index d1b9931c6c..087cfc2f2c 100644 --- a/src/core_atmosphere/mpas_atm_core.F +++ b/src/core_atmosphere/mpas_atm_core.F @@ -543,6 +543,8 @@ subroutine atm_mpas_init_block(dminfo, stream_manager, block, mesh, dt) call mpas_pool_get_array(diag, 'uReconstructZ', uReconstructZ) call mpas_pool_get_array(diag, 'uReconstructZonal', uReconstructZonal) call mpas_pool_get_array(diag, 'uReconstructMeridional', uReconstructMeridional) + call mpas_reconstruct_2d_h2d(mesh, u, uReconstructX, uReconstructY, uReconstructZ, & + uReconstructZonal, uReconstructMeridional) call mpas_reconstruct(mesh, u, & uReconstructX, & uReconstructY, & @@ -550,6 +552,8 @@ subroutine atm_mpas_init_block(dminfo, stream_manager, block, mesh, dt) uReconstructZonal, & uReconstructMeridional & ) + call mpas_reconstruct_2d_d2h(mesh, u, uReconstructX, uReconstructY, uReconstructZ, & + uReconstructZonal, uReconstructMeridional) #ifdef DO_PHYSICS !proceed with initialization of physics parameterization if moist_physics is set to true: diff --git a/src/operators/mpas_vector_reconstruction.F b/src/operators/mpas_vector_reconstruction.F index 605da9cd6d..2aa4ca2aee 100644 --- a/src/operators/mpas_vector_reconstruction.F +++ b/src/operators/mpas_vector_reconstruction.F @@ -258,16 +258,6 @@ subroutine mpas_reconstruct_2d(meshPool, u, uReconstructX, uReconstructY, uRecon call mpas_pool_get_config(meshPool, 'on_a_sphere', on_a_sphere) - MPAS_ACC_TIMER_START('mpas_reconstruct_2d [ACC_data_xfer]') - ! 
Only use sections needed, nCells may be all cells or only non-halo cells - !$acc enter data copyin(coeffs_reconstruct(:,:,1:nCells),nEdgesOnCell(1:nCells), & - !$acc edgesOnCell(:,1:nCells),latCell(1:nCells),lonCell(1:nCells)) - !$acc enter data copyin(u(:,:)) - !$acc enter data create(uReconstructX(:,1:nCells),uReconstructY(:,1:nCells), & - !$acc uReconstructZ(:,1:nCells),uReconstructZonal(:,1:nCells), & - !$acc uReconstructMeridional(:,1:nCells)) - MPAS_ACC_TIMER_STOP('mpas_reconstruct_2d [ACC_data_xfer]') - ! loop over cell centers !$omp do schedule(runtime) !$acc parallel default(present) @@ -337,6 +327,109 @@ subroutine mpas_reconstruct_2d(meshPool, u, uReconstructX, uReconstructY, uRecon !$omp end do end if + end subroutine mpas_reconstruct_2d!}}} + + + subroutine mpas_reconstruct_2d_h2d(meshPool, u, uReconstructX, uReconstructY, uReconstructZ, uReconstructZonal, uReconstructMeridional, includeHalos)!{{{ + + implicit none + + type (mpas_pool_type), intent(in) :: meshPool !< Input: Mesh information + real (kind=RKIND), dimension(:,:), intent(in) :: u !< Input: Velocity field on edges + real (kind=RKIND), dimension(:,:), intent(out) :: uReconstructX !< Output: X Component of velocity reconstructed to cell centers + real (kind=RKIND), dimension(:,:), intent(out) :: uReconstructY !< Output: Y Component of velocity reconstructed to cell centers + real (kind=RKIND), dimension(:,:), intent(out) :: uReconstructZ !< Output: Z Component of velocity reconstructed to cell centers + real (kind=RKIND), dimension(:,:), intent(out) :: uReconstructZonal !< Output: Zonal Component of velocity reconstructed to cell centers + real (kind=RKIND), dimension(:,:), intent(out) :: uReconstructMeridional !< Output: Meridional Component of velocity reconstructed to cell centers + logical, optional, intent(in) :: includeHalos !< Input: Optional logical that allows reconstruction over halo regions + + logical :: includeHalosLocal + integer, dimension(:,:), pointer :: edgesOnCell + integer, dimension(:), pointer :: nEdgesOnCell + integer :: nCells + integer, pointer :: nCells_ptr + real(kind=RKIND), dimension(:), pointer :: latCell, lonCell + real (kind=RKIND), dimension(:,:,:), pointer :: coeffs_reconstruct + + if ( present(includeHalos) ) then + includeHalosLocal = includeHalos + else + includeHalosLocal = .false. + end if + + ! stored arrays used during compute procedure + call mpas_pool_get_array(meshPool, 'coeffs_reconstruct', coeffs_reconstruct) + + ! temporary variables + call mpas_pool_get_array(meshPool, 'nEdgesOnCell', nEdgesOnCell) + call mpas_pool_get_array(meshPool, 'edgesOnCell', edgesOnCell) + call mpas_pool_get_array(meshPool, 'latCell', latCell) + call mpas_pool_get_array(meshPool, 'lonCell', lonCell) + + if ( includeHalosLocal ) then + call mpas_pool_get_dimension(meshPool, 'nCells', nCells_ptr) + else + call mpas_pool_get_dimension(meshPool, 'nCellsSolve', nCells_ptr) + end if + nCells = nCells_ptr + + MPAS_ACC_TIMER_START('mpas_reconstruct_2d [ACC_data_xfer]') + ! 
Only use sections needed, nCells may be all cells or only non-halo cells + !$acc enter data copyin(coeffs_reconstruct(:,:,1:nCells),nEdgesOnCell(1:nCells), & + !$acc edgesOnCell(:,1:nCells),latCell(1:nCells),lonCell(1:nCells)) + !$acc enter data copyin(u(:,:)) + !$acc enter data create(uReconstructX(:,1:nCells),uReconstructY(:,1:nCells), & + !$acc uReconstructZ(:,1:nCells),uReconstructZonal(:,1:nCells), & + !$acc uReconstructMeridional(:,1:nCells)) + MPAS_ACC_TIMER_STOP('mpas_reconstruct_2d [ACC_data_xfer]') + + end subroutine mpas_reconstruct_2d_h2d + + + + subroutine mpas_reconstruct_2d_d2h(meshPool, u, uReconstructX, uReconstructY, uReconstructZ, uReconstructZonal, uReconstructMeridional, includeHalos)!{{{ + + implicit none + + type (mpas_pool_type), intent(in) :: meshPool !< Input: Mesh information + real (kind=RKIND), dimension(:,:), intent(in) :: u !< Input: Velocity field on edges + real (kind=RKIND), dimension(:,:), intent(out) :: uReconstructX !< Output: X Component of velocity reconstructed to cell centers + real (kind=RKIND), dimension(:,:), intent(out) :: uReconstructY !< Output: Y Component of velocity reconstructed to cell centers + real (kind=RKIND), dimension(:,:), intent(out) :: uReconstructZ !< Output: Z Component of velocity reconstructed to cell centers + real (kind=RKIND), dimension(:,:), intent(out) :: uReconstructZonal !< Output: Zonal Component of velocity reconstructed to cell centers + real (kind=RKIND), dimension(:,:), intent(out) :: uReconstructMeridional !< Output: Meridional Component of velocity reconstructed to cell centers + logical, optional, intent(in) :: includeHalos !< Input: Optional logical that allows reconstruction over halo regions + + logical :: includeHalosLocal + integer, dimension(:,:), pointer :: edgesOnCell + integer, dimension(:), pointer :: nEdgesOnCell + integer :: nCells + integer, pointer :: nCells_ptr + real(kind=RKIND), dimension(:), pointer :: latCell, lonCell + real (kind=RKIND), dimension(:,:,:), pointer :: coeffs_reconstruct + + if ( present(includeHalos) ) then + includeHalosLocal = includeHalos + else + includeHalosLocal = .false. + end if + + ! stored arrays used during compute procedure + call mpas_pool_get_array(meshPool, 'coeffs_reconstruct', coeffs_reconstruct) + + ! 
temporary variables + call mpas_pool_get_array(meshPool, 'nEdgesOnCell', nEdgesOnCell) + call mpas_pool_get_array(meshPool, 'edgesOnCell', edgesOnCell) + call mpas_pool_get_array(meshPool, 'latCell', latCell) + call mpas_pool_get_array(meshPool, 'lonCell', lonCell) + + if ( includeHalosLocal ) then + call mpas_pool_get_dimension(meshPool, 'nCells', nCells_ptr) + else + call mpas_pool_get_dimension(meshPool, 'nCellsSolve', nCells_ptr) + end if + nCells = nCells_ptr + MPAS_ACC_TIMER_START('mpas_reconstruct_2d [ACC_data_xfer]') !$acc exit data delete(coeffs_reconstruct(:,:,1:nCells),nEdgesOnCell(1:nCells), & !$acc edgesOnCell(:,1:nCells),latCell(1:nCells),lonCell(1:nCells)) @@ -346,7 +439,7 @@ subroutine mpas_reconstruct_2d(meshPool, u, uReconstructX, uReconstructY, uRecon !$acc uReconstructMeridional(:,1:nCells)) MPAS_ACC_TIMER_STOP('mpas_reconstruct_2d [ACC_data_xfer]') - end subroutine mpas_reconstruct_2d!}}} + end subroutine mpas_reconstruct_2d_d2h !*********************************************************************** From 4b7137d9c29719afb006961060f259230449e993 Mon Sep 17 00:00:00 2001 From: Abishek Gopal Date: Thu, 14 Aug 2025 10:31:53 -0600 Subject: [PATCH 08/30] Need to copyout u_2 and w_2 at the end of dynamics --- src/core_atmosphere/dynamics/mpas_atm_time_integration.F | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F index b24d4cebfd..5cb15624f2 100644 --- a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F +++ b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F @@ -1311,11 +1311,11 @@ subroutine mpas_atm_post_dynamics_d2h(domain) call mpas_pool_get_array(state, 'u', u_1, 1) !$acc exit data copyout(u_1) call mpas_pool_get_array(state, 'u', u_2, 2) - !$acc exit data delete(u_2) + !$acc exit data copyout(u_2) call mpas_pool_get_array(state, 'w', w_1, 1) !$acc exit data copyout(w_1) call mpas_pool_get_array(state, 'w', w_2, 2) - !$acc exit data delete(w_2) + !$acc exit data copyout(w_2) call mpas_pool_get_array(state, 'theta_m', theta_m_1, 1) !$acc exit data copyout(theta_m_1) ! use values from atm_init_coupled_diagnostics call mpas_pool_get_array(state, 'theta_m', theta_m_2, 2) From 7be031569748d5046ff9b980439a740ca6db425c Mon Sep 17 00:00:00 2001 From: "G. 
Dylan Dickerson" Date: Wed, 7 May 2025 11:50:32 -0600 Subject: [PATCH 09/30] Add data movement for some fields under the mpas_halo_groups --- src/framework/mpas_halo.F | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index 4ab8817c23..a2c75327a3 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -280,6 +280,13 @@ subroutine mpas_halo_exch_group_complete(domain, groupName, iErr) call refactor_lists(domain, groupName, iErr) + !$acc enter data copyin(newGroup) + !$acc enter data copyin(newGroup % fields(:), newGroup % sendBuf(:)) + do i = 1, newGroup % nFields + !$acc enter data copyin(newGroup % fields(i)) + !$acc enter data copyin(newGroup % fields(i) % sendListSrc(:,:,:)) + end do + end subroutine mpas_halo_exch_group_complete @@ -350,6 +357,7 @@ subroutine mpas_halo_exch_group_destroy(domain, groupName, iErr) deallocate(cursor % fields(i) % compactSendLists) deallocate(cursor % fields(i) % compactRecvLists) deallocate(cursor % fields(i) % nSendLists) + !$acc exit data delete(cursor % fields(i) % sendListSrc(:,:,:)) deallocate(cursor % fields(i) % sendListSrc) deallocate(cursor % fields(i) % sendListDst) deallocate(cursor % fields(i) % packOffsets) @@ -357,7 +365,10 @@ subroutine mpas_halo_exch_group_destroy(domain, groupName, iErr) deallocate(cursor % fields(i) % recvListSrc) deallocate(cursor % fields(i) % recvListDst) deallocate(cursor % fields(i) % unpackOffsets) + !$acc exit data delete(cursor % fields(i)) end do + ! Use finalize here in-case the copyins in ..._complete increment the reference counter + !$acc exit data finalize delete(cursor % fields(:)) deallocate(cursor % fields) deallocate(cursor % groupPackOffsets) deallocate(cursor % groupSendNeighbors) @@ -368,10 +379,12 @@ subroutine mpas_halo_exch_group_destroy(domain, groupName, iErr) deallocate(cursor % groupToFieldRecvIdx) deallocate(cursor % groupRecvOffsets) deallocate(cursor % groupRecvCounts) + !$acc exit data delete(cursor % sendBuf(:)) deallocate(cursor % sendBuf) deallocate(cursor % recvBuf) deallocate(cursor % sendRequests) deallocate(cursor % recvRequests) + !$acc exit data delete(cursor) deallocate(cursor) end subroutine mpas_halo_exch_group_destroy From 912612837f3129f373755738fa747c600708d7aa Mon Sep 17 00:00:00 2001 From: "G. Dylan Dickerson" Date: Wed, 7 May 2025 12:30:01 -0600 Subject: [PATCH 10/30] Add a data region and acc kernels to the 2D packing code --- src/framework/mpas_halo.F | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index a2c75327a3..aef3759d88 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -676,6 +676,15 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Pack send buffer for all neighbors for current field ! + + ! Use data regions for specificity and so the reference or attachment counters are easier to make sense of + ! Present should also cause an attach action. OpenACC Spec2.7 Section 2.7.2 describes 'attach action' + !$acc data present(group) present(group % fields(i)) present(group % sendBuf(:), group % fields(i) % sendListSrc(:,:,:)) + !$acc data copyin(group % fields(i) % r2arr(:,:)) & + !$acc copyin(group % fields(i) % sendListSrc(:,:,:), group % fields(i) % sendListDst(:,:,:), group % fields(i) % nSendLists(:,:), group % fields(i) % packOffsets(:)) + + ! 
Kernels is good enough, use default present to force a run-time error if programmer forgot something + !$acc kernels default(present) do iEndp = 1, nSendEndpts do iHalo = 1, nHalos do j = 1, maxNSendList @@ -688,6 +697,9 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do end do end do + !$acc end kernels + !$acc end data + !$acc end data ! ! Packing code for 3-d real-valued fields From 38ee36f52411c18aafc33cd634862ec9b9b15cb2 Mon Sep 17 00:00:00 2001 From: "G. Dylan Dickerson" Date: Wed, 7 May 2025 13:31:41 -0600 Subject: [PATCH 11/30] Add the update directives that should have been part of the last commit This commit does work and matches the previous results! --- src/framework/mpas_halo.F | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index aef3759d88..778e27649c 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -680,6 +680,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! Use data regions for specificity and so the reference or attachment counters are easier to make sense of ! Present should also cause an attach action. OpenACC Spec2.7 Section 2.7.2 describes 'attach action' !$acc data present(group) present(group % fields(i)) present(group % sendBuf(:), group % fields(i) % sendListSrc(:,:,:)) + !$acc update device(group % sendBuf(:)) !$acc data copyin(group % fields(i) % r2arr(:,:)) & !$acc copyin(group % fields(i) % sendListSrc(:,:,:), group % fields(i) % sendListDst(:,:,:), group % fields(i) % nSendLists(:,:), group % fields(i) % packOffsets(:)) @@ -699,6 +700,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do !$acc end kernels !$acc end data + !$acc update host(group % sendBuf(:)) !$acc end data ! From fa984fe04622d693309c46c24f0908c649907763 Mon Sep 17 00:00:00 2001 From: "G. Dylan Dickerson" Date: Wed, 7 May 2025 13:34:26 -0600 Subject: [PATCH 12/30] Comment out data present region, see if this causes an error --- src/framework/mpas_halo.F | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index 778e27649c..3b77b8fbb1 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -679,7 +679,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! Use data regions for specificity and so the reference or attachment counters are easier to make sense of ! Present should also cause an attach action. OpenACC Spec2.7 Section 2.7.2 describes 'attach action' - !$acc data present(group) present(group % fields(i)) present(group % sendBuf(:), group % fields(i) % sendListSrc(:,:,:)) + ! !$acc data present(group) present(group % fields(i)) present(group % sendBuf(:), group % fields(i) % sendListSrc(:,:,:)) !$acc update device(group % sendBuf(:)) !$acc data copyin(group % fields(i) % r2arr(:,:)) & !$acc copyin(group % fields(i) % sendListSrc(:,:,:), group % fields(i) % sendListDst(:,:,:), group % fields(i) % nSendLists(:,:), group % fields(i) % packOffsets(:)) @@ -701,7 +701,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) !$acc end kernels !$acc end data !$acc update host(group % sendBuf(:)) - !$acc end data + ! !$acc end data ! ! Packing code for 3-d real-valued fields From ec7991150f459231231a03d9890dc8229954baef Mon Sep 17 00:00:00 2001 From: "G. 
Dylan Dickerson" Date: Wed, 7 May 2025 14:24:48 -0600 Subject: [PATCH 13/30] Expand the data managed on the GPU for the halo exchange NOTE: The last commit was successful! --- src/framework/mpas_halo.F | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index 3b77b8fbb1..44736740af 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -280,11 +280,21 @@ subroutine mpas_halo_exch_group_complete(domain, groupName, iErr) call refactor_lists(domain, groupName, iErr) + ! Always copy in the main data member first !$acc enter data copyin(newGroup) - !$acc enter data copyin(newGroup % fields(:), newGroup % sendBuf(:)) + ! Then the data in the members of the type + !$acc enter data copyin(newGroup % recvBuf(:), newGroup % sendBuf(:)) + !$acc enter data copyin(newGroup % fields(:)) do i = 1, newGroup % nFields !$acc enter data copyin(newGroup % fields(i)) + !$acc enter data copyin(newGroup % fields(i) % nSendLists(:,:)) + !$acc enter data copyin(newGroup % fields(i) % packOffsets(:)) !$acc enter data copyin(newGroup % fields(i) % sendListSrc(:,:,:)) + !$acc enter data copyin(newGroup % fields(i) % sendListDst(:,:,:)) + !$acc enter data copyin(newGroup % fields(i) % nRecvLists(:,:)) + !$acc enter data copyin(newGroup % fields(i) % unpackOffsets(:)) + !$acc enter data copyin(newGroup % fields(i) % recvListSrc(:,:,:)) + !$acc enter data copyin(newGroup % fields(i) % recvListDst(:,:,:)) end do end subroutine mpas_halo_exch_group_complete @@ -356,14 +366,21 @@ subroutine mpas_halo_exch_group_destroy(domain, groupName, iErr) deallocate(cursor % fields(i) % compactHaloInfo) deallocate(cursor % fields(i) % compactSendLists) deallocate(cursor % fields(i) % compactRecvLists) + !$acc exit data delete(cursor % fields(i) % nSendLists(:,:)) deallocate(cursor % fields(i) % nSendLists) !$acc exit data delete(cursor % fields(i) % sendListSrc(:,:,:)) deallocate(cursor % fields(i) % sendListSrc) + !$acc exit data delete(cursor % fields(i) % sendListDst(:,:,:)) deallocate(cursor % fields(i) % sendListDst) + !$acc exit data delete(cursor % fields(i) % packOffsets(:)) deallocate(cursor % fields(i) % packOffsets) + !$acc exit data delete(cursor % fields(i) % nRecvLists(:,:)) deallocate(cursor % fields(i) % nRecvLists) + !$acc exit data delete(cursor % fields(i) % recvListSrc(:,:,:)) deallocate(cursor % fields(i) % recvListSrc) + !$acc exit data delete(cursor % fields(i) % recvListDst(:,:,:)) deallocate(cursor % fields(i) % recvListDst) + !$acc exit data delete(cursor % fields(i) % unpackOffsets(:)) deallocate(cursor % fields(i) % unpackOffsets) !$acc exit data delete(cursor % fields(i)) end do @@ -381,10 +398,12 @@ subroutine mpas_halo_exch_group_destroy(domain, groupName, iErr) deallocate(cursor % groupRecvCounts) !$acc exit data delete(cursor % sendBuf(:)) deallocate(cursor % sendBuf) + !$acc exit data delete(cursor % recvBuf(:)) deallocate(cursor % recvBuf) deallocate(cursor % sendRequests) deallocate(cursor % recvRequests) - !$acc exit data delete(cursor) + ! Finalize here as well, just in-case + !$acc exit data finalize delete(cursor) deallocate(cursor) end subroutine mpas_halo_exch_group_destroy @@ -680,9 +699,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! Use data regions for specificity and so the reference or attachment counters are easier to make sense of ! Present should also cause an attach action. OpenACC Spec2.7 Section 2.7.2 describes 'attach action' ! 
!$acc data present(group) present(group % fields(i)) present(group % sendBuf(:), group % fields(i) % sendListSrc(:,:,:)) - !$acc update device(group % sendBuf(:)) - !$acc data copyin(group % fields(i) % r2arr(:,:)) & - !$acc copyin(group % fields(i) % sendListSrc(:,:,:), group % fields(i) % sendListDst(:,:,:), group % fields(i) % nSendLists(:,:), group % fields(i) % packOffsets(:)) + !$acc data copyin(group % fields(i) % r2arr(:,:)) ! Kernels is good enough, use default present to force a run-time error if programmer forgot something !$acc kernels default(present) From e98a57e226824224e3bd5da8ba8441bdaeb96bb7 Mon Sep 17 00:00:00 2001 From: "G. Dylan Dickerson" Date: Wed, 7 May 2025 15:01:58 -0600 Subject: [PATCH 14/30] Remove the OpenACC management of recvBuf Last commit had differences from the baseline. It's either this, or the change dropping 'update device(group % sendBuf(:)' in the last commit --- src/framework/mpas_halo.F | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index 44736740af..d6d7b503db 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -283,7 +283,8 @@ subroutine mpas_halo_exch_group_complete(domain, groupName, iErr) ! Always copy in the main data member first !$acc enter data copyin(newGroup) ! Then the data in the members of the type - !$acc enter data copyin(newGroup % recvBuf(:), newGroup % sendBuf(:)) + ! !$acc enter data copyin(newGroup % recvBuf(:), newGroup % sendBuf(:)) + !$acc enter data copyin(newGroup % sendBuf(:)) !$acc enter data copyin(newGroup % fields(:)) do i = 1, newGroup % nFields !$acc enter data copyin(newGroup % fields(i)) @@ -398,7 +399,7 @@ subroutine mpas_halo_exch_group_destroy(domain, groupName, iErr) deallocate(cursor % groupRecvCounts) !$acc exit data delete(cursor % sendBuf(:)) deallocate(cursor % sendBuf) - !$acc exit data delete(cursor % recvBuf(:)) + ! !$acc exit data delete(cursor % recvBuf(:)) deallocate(cursor % recvBuf) deallocate(cursor % sendRequests) deallocate(cursor % recvRequests) From 4ee67d55a3051f90871158b99535f5f5d81835ca Mon Sep 17 00:00:00 2001 From: "G. Dylan Dickerson" Date: Wed, 7 May 2025 15:25:45 -0600 Subject: [PATCH 15/30] Add update host(sendBuf) back, address answer diff Last commit still had answer differences --- src/framework/mpas_halo.F | 1 + 1 file changed, 1 insertion(+) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index d6d7b503db..bf9003c87a 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -700,6 +700,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! Use data regions for specificity and so the reference or attachment counters are easier to make sense of ! Present should also cause an attach action. OpenACC Spec2.7 Section 2.7.2 describes 'attach action' ! !$acc data present(group) present(group % fields(i)) present(group % sendBuf(:), group % fields(i) % sendListSrc(:,:,:)) + !$acc update device(group % sendBuf(:)) !$acc data copyin(group % fields(i) % r2arr(:,:)) ! Kernels is good enough, use default present to force a run-time error if programmer forgot something From 40aecd68c7fbb8ecf89a473722dfae52a460d8bd Mon Sep 17 00:00:00 2001 From: "G. 
Dylan Dickerson" Date: Wed, 7 May 2025 15:56:47 -0600 Subject: [PATCH 16/30] Expand to other packing kernels, only update sendBuf after packing finishes --- src/framework/mpas_halo.F | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index bf9003c87a..f84dfeb1dc 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -672,9 +672,11 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) call mpas_pool_get_array(domain % blocklist % allFields, trim(group % fields(i) % fieldName), & group % fields(i) % r1arr, timeLevel=group % fields(i) % timeLevel) + !$acc data copyin(group % fields(i) % r1arr(:)) ! ! Pack send buffer for all neighbors for current field ! + !$acc kernels default(present) do iEndp = 1, nSendEndpts do iHalo = 1, nHalos do j = 1, maxNSendList @@ -685,6 +687,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do end do end do + !$acc end kernels + !$acc end data ! ! Packing code for 2-d real-valued fields @@ -700,7 +704,6 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! Use data regions for specificity and so the reference or attachment counters are easier to make sense of ! Present should also cause an attach action. OpenACC Spec2.7 Section 2.7.2 describes 'attach action' ! !$acc data present(group) present(group % fields(i)) present(group % sendBuf(:), group % fields(i) % sendListSrc(:,:,:)) - !$acc update device(group % sendBuf(:)) !$acc data copyin(group % fields(i) % r2arr(:,:)) ! Kernels is good enough, use default present to force a run-time error if programmer forgot something @@ -719,7 +722,6 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do !$acc end kernels !$acc end data - !$acc update host(group % sendBuf(:)) ! !$acc end data ! @@ -728,10 +730,12 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) case (3) call mpas_pool_get_array(domain % blocklist % allFields, trim(group % fields(i) % fieldName), & group % fields(i) % r3arr, group % fields(i) % timeLevel) + !$acc data copyin(group % fields(i) % r3arr(:,:,:)) ! ! Pack send buffer for all neighbors for current field ! + !$acc kernels default(present) do iEndp = 1, nSendEndpts do iHalo = 1, nHalos do j = 1, maxNSendList @@ -747,10 +751,13 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do end do end do + !$acc end kernels + !$acc end data end select end if end do + !$acc update host(group % sendBuf(:)) ! ! Initiate non-blocking sends to all neighbors From dc6eae73560928f540668fb59e1a448195612f34 Mon Sep 17 00:00:00 2001 From: "G. Dylan Dickerson" Date: Wed, 7 May 2025 16:44:49 -0600 Subject: [PATCH 17/30] Change to simple integers to access the buffers and the field arrays This should make the dependency analysis easier on the compiler. NOTE: The last commit succeeded and had no diffs after 1 timestep compared to a reference run! --- src/framework/mpas_halo.F | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index f84dfeb1dc..2244f20fb8 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -564,6 +564,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! 
Local variables integer :: i, bufstart, bufend + integer :: idxBuf, idxArr integer :: dim1, dim2 integer :: i1, i2, j, iNeighbor, iReq integer :: iHalo, iEndp @@ -681,8 +682,9 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) do iHalo = 1, nHalos do j = 1, maxNSendList if (j <= nSendLists(iHalo,iEndp)) then - group % sendBuf(packOffsets(iEndp) + sendListDst(j,iHalo,iEndp)) = & - group % fields(i) % r1arr(sendListSrc(j,iHalo,iEndp)) + bufIdx = packOffsets(iEndp) + sendListDst(j,iHalo,iEndp) + arrIdx = sendListSrc(j,iHalo,iEndp) + group % sendBuf(bufIdx) = group % fields(i) % r1arr(arrIdx) end if end do end do @@ -713,8 +715,9 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) do j = 1, maxNSendList do i1 = 1, dim1 if (j <= nSendLists(iHalo,iEndp)) then - group % sendBuf(packOffsets(iEndp) + dim1 * (sendListDst(j,iHalo,iEndp) - 1) + i1) = & - group % fields(i) % r2arr(i1, sendListSrc(j,iHalo,iEndp)) + bufIdx = packOffsets(iEndp) + dim1 * (sendListDst(j,iHalo,iEndp) - 1) + i1 + arrIdx = sendListSrc(j,iHalo,iEndp) + group % sendBuf(bufIdx) = group % fields(i) % r2arr(i1,arrIdx) end if end do end do @@ -742,9 +745,10 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) do i2 = 1, dim2 do i1 = 1, dim1 if (j <= nSendLists(iHalo,iEndp)) then - group % sendBuf(packOffsets(iEndp) + dim1*dim2*(sendListDst(j,iHalo,iEndp) - 1) & - + dim1*(i2-1) + i1) = & - group % fields(i) % r3arr(i1, i2, sendListSrc(j,iHalo,iEndp)) + bufIdx = packOffsets(iEndp) + dim1*dim2*(sendListDst(j,iHalo,iEndp) - 1) & + + dim1*(i2-1) + i1 + arrIdx = sendListSrc(j,iHalo,iEndp) + group % sendBuf(bufIdx) = group % fields(i) % r3arr(i1,i2,arrIdx) end if end do end do @@ -827,8 +831,9 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) do iHalo = 1, nHalos do j = 1, maxNRecvList if (j <= nRecvLists(iHalo,iEndp)) then - group % fields(i) % r1arr(recvListDst(j,iHalo,iEndp)) = & - group % recvBuf(unpackOffsets(iEndp) + recvListSrc(j,iHalo,iEndp)) + arrIdx = recvListDst(j,iHalo,iEndp) + bufIdx = unpackOffsets(iEndp) + recvListSrc(j,iHalo,iEndp) + group % fields(i) % r1arr(arrIdx) = group % recvBuf(bufIdx) end if end do end do @@ -844,8 +849,9 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) do j = 1, maxNRecvList do i1 = 1, dim1 if (j <= nRecvLists(iHalo,iEndp)) then - group % fields(i) % r2arr(i1, recvListDst(j,iHalo,iEndp)) = & - group % recvBuf(unpackOffsets(iEndp) + dim1 * (recvListSrc(j,iHalo,iEndp) - 1) + i1) + arrIdx = recvListDst(j,iHalo,iEndp) + bufIdx = unpackOffsets(iEndp) + dim1 * (recvListSrc(j,iHalo,iEndp) - 1) + i1 + group % fields(i) % r2arr(i1, arrIdx) = group % recvBuf(bufIdx) end if end do end do @@ -863,9 +869,10 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) do i2 = 1, dim2 do i1 = 1, dim1 if (j <= nRecvLists(iHalo,iEndp)) then - group % fields(i) % r3arr(i1, i2, recvListDst(j,iHalo,iEndp)) = & - group % recvBuf(unpackOffsets(iEndp) + dim1*dim2*(recvListSrc(j,iHalo,iEndp) - 1) & - + dim1*(i2-1) + i1) + arrIdx = recvListDst(j,iHalo,iEndp) + bufIdx = unpackOffsets(iEndp) + dim1*dim2*(recvListSrc(j,iHalo,iEndp) - 1) & + + dim1*(i2-1) + i1 + group % fields(i) % r3arr(i1, i2, arrIdx) = group % recvBuf(bufIdx) end if end do end do From 130d75e09886710190023c5916103f8c39769f60 Mon Sep 17 00:00:00 2001 From: "G. 
Dylan Dickerson" Date: Wed, 7 May 2025 18:29:38 -0600 Subject: [PATCH 18/30] Add kernels to unpacking loops and use a data present region to try to force GPUDirect MPI NOTE: The last commit ran successfully and matched previous 1 step results --- src/framework/mpas_halo.F | 57 ++++++++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 22 deletions(-) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index 2244f20fb8..4fc6ca0e76 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -283,8 +283,8 @@ subroutine mpas_halo_exch_group_complete(domain, groupName, iErr) ! Always copy in the main data member first !$acc enter data copyin(newGroup) ! Then the data in the members of the type - ! !$acc enter data copyin(newGroup % recvBuf(:), newGroup % sendBuf(:)) - !$acc enter data copyin(newGroup % sendBuf(:)) + !$acc enter data copyin(newGroup % recvBuf(:), newGroup % sendBuf(:)) + ! !$acc enter data copyin(newGroup % sendBuf(:)) !$acc enter data copyin(newGroup % fields(:)) do i = 1, newGroup % nFields !$acc enter data copyin(newGroup % fields(i)) @@ -399,7 +399,7 @@ subroutine mpas_halo_exch_group_destroy(domain, groupName, iErr) deallocate(cursor % groupRecvCounts) !$acc exit data delete(cursor % sendBuf(:)) deallocate(cursor % sendBuf) - ! !$acc exit data delete(cursor % recvBuf(:)) + !$acc exit data delete(cursor % recvBuf(:)) deallocate(cursor % recvBuf) deallocate(cursor % sendRequests) deallocate(cursor % recvRequests) @@ -623,6 +623,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) #endif rank = group % fields(1) % compactHaloInfo(8) + !$acc data present(group % recvBuf(:), group % sendBuf(:)) ! ! Initiate non-blocking MPI receives for all neighbors @@ -682,9 +683,9 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) do iHalo = 1, nHalos do j = 1, maxNSendList if (j <= nSendLists(iHalo,iEndp)) then - bufIdx = packOffsets(iEndp) + sendListDst(j,iHalo,iEndp) - arrIdx = sendListSrc(j,iHalo,iEndp) - group % sendBuf(bufIdx) = group % fields(i) % r1arr(arrIdx) + idxBuf = packOffsets(iEndp) + sendListDst(j,iHalo,iEndp) + idxArr = sendListSrc(j,iHalo,iEndp) + group % sendBuf(idxBuf) = group % fields(i) % r1arr(idxArr) end if end do end do @@ -715,9 +716,9 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) do j = 1, maxNSendList do i1 = 1, dim1 if (j <= nSendLists(iHalo,iEndp)) then - bufIdx = packOffsets(iEndp) + dim1 * (sendListDst(j,iHalo,iEndp) - 1) + i1 - arrIdx = sendListSrc(j,iHalo,iEndp) - group % sendBuf(bufIdx) = group % fields(i) % r2arr(i1,arrIdx) + idxBuf = packOffsets(iEndp) + dim1 * (sendListDst(j,iHalo,iEndp) - 1) + i1 + idxArr = sendListSrc(j,iHalo,iEndp) + group % sendBuf(idxBuf) = group % fields(i) % r2arr(i1,idxArr) end if end do end do @@ -745,10 +746,10 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) do i2 = 1, dim2 do i1 = 1, dim1 if (j <= nSendLists(iHalo,iEndp)) then - bufIdx = packOffsets(iEndp) + dim1*dim2*(sendListDst(j,iHalo,iEndp) - 1) & + idxBuf = packOffsets(iEndp) + dim1*dim2*(sendListDst(j,iHalo,iEndp) - 1) & + dim1*(i2-1) + i1 - arrIdx = sendListSrc(j,iHalo,iEndp) - group % sendBuf(bufIdx) = group % fields(i) % r3arr(i1,i2,arrIdx) + idxArr = sendListSrc(j,iHalo,iEndp) + group % sendBuf(idxBuf) = group % fields(i) % r3arr(i1,i2,idxArr) end if end do end do @@ -761,7 +762,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end select end if end do - !$acc update host(group % sendBuf(:)) + ! 
!$acc update host(group % sendBuf(:)) ! ! Initiate non-blocking sends to all neighbors @@ -828,15 +829,18 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Unpack recv buffer from all neighbors for current field ! + !$acc kernels default(present) do iHalo = 1, nHalos do j = 1, maxNRecvList if (j <= nRecvLists(iHalo,iEndp)) then - arrIdx = recvListDst(j,iHalo,iEndp) - bufIdx = unpackOffsets(iEndp) + recvListSrc(j,iHalo,iEndp) - group % fields(i) % r1arr(arrIdx) = group % recvBuf(bufIdx) + idxArr = recvListDst(j,iHalo,iEndp) + idxBuf = unpackOffsets(iEndp) + recvListSrc(j,iHalo,iEndp) + group % fields(i) % r1arr(idxArr) = group % recvBuf(idxBuf) end if end do end do + !$acc end kernels + !$acc exit data copyout(group % fields(i) % r1arr(:)) ! ! Unpacking code for 2-d real-valued fields @@ -845,17 +849,20 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Unpack recv buffer from all neighbors for current field ! + !$acc kernels default(present) do iHalo = 1, nHalos do j = 1, maxNRecvList do i1 = 1, dim1 if (j <= nRecvLists(iHalo,iEndp)) then - arrIdx = recvListDst(j,iHalo,iEndp) - bufIdx = unpackOffsets(iEndp) + dim1 * (recvListSrc(j,iHalo,iEndp) - 1) + i1 - group % fields(i) % r2arr(i1, arrIdx) = group % recvBuf(bufIdx) + idxArr = recvListDst(j,iHalo,iEndp) + idxBuf = unpackOffsets(iEndp) + dim1 * (recvListSrc(j,iHalo,iEndp) - 1) + i1 + group % fields(i) % r2arr(i1, idxArr) = group % recvBuf(idxBuf) end if end do end do end do + !$acc end kernels + !$acc exit data copyout(group % fields(i) % r2arr(:,:)) ! ! Unpacking code for 3-d real-valued fields @@ -864,26 +871,32 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Unpack recv buffer from all neighbors for current field ! + !$acc kernels default(present) do iHalo = 1, nHalos do j = 1, maxNRecvList do i2 = 1, dim2 do i1 = 1, dim1 if (j <= nRecvLists(iHalo,iEndp)) then - arrIdx = recvListDst(j,iHalo,iEndp) - bufIdx = unpackOffsets(iEndp) + dim1*dim2*(recvListSrc(j,iHalo,iEndp) - 1) & + idxArr = recvListDst(j,iHalo,iEndp) + idxBuf = unpackOffsets(iEndp) + dim1*dim2*(recvListSrc(j,iHalo,iEndp) - 1) & + dim1*(i2-1) + i1 - group % fields(i) % r3arr(i1, i2, arrIdx) = group % recvBuf(bufIdx) + group % fields(i) % r3arr(i1, i2, idxArr) = group % recvBuf(idxBuf) end if end do end do end do end do + !$acc end kernels + !$acc exit data copyout(group % fields(i) % r3arr(:,:,:)) end select end if end do end do + ! For the present(group % recvBuf(:), group % sendBuf(:)) + !$acc end data + ! ! Nullify array pointers - not necessary for correctness, but helpful when debugging ! to not leave pointers to what might later be incorrect targets From e27a75f3bec6be4eacd2566046af4e78536d3ab1 Mon Sep 17 00:00:00 2001 From: "G. Dylan Dickerson" Date: Wed, 7 May 2025 19:05:10 -0600 Subject: [PATCH 19/30] Change from data copyin regions to enter/exit directives for the r?arr variables Last run failed with CUDA_ERROR_ILLEGAL_ADDRESS, I think keeping these on the GPU would help! 
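For reference, a minimal sketch (separate from this patch) of the lifetime
difference between the two forms: a structured 'data' region releases its
device copies at 'end data', while unstructured 'enter data' keeps them
resident until a matching 'exit data', which is what the later kernels
need here.

    program data_lifetime_sketch
       implicit none
       integer, parameter :: n = 16
       real, dimension(n) :: a
       integer :: i

       a = 1.0

       ! Structured: the device copy of 'a' exists only inside the region,
       ! and with copyin the device results are discarded at 'end data'.
       !$acc data copyin(a)
       !$acc parallel loop default(present)
       do i = 1, n
          a(i) = a(i) + 1.0
       end do
       !$acc end data

       ! Unstructured: the device copy persists beyond this construct and
       ! stays valid for later kernels until the matching 'exit data'.
       !$acc enter data copyin(a)
       !$acc parallel loop default(present)
       do i = 1, n
          a(i) = a(i) + 1.0
       end do
       !$acc exit data delete(a)
    end program data_lifetime_sketch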
--- src/framework/mpas_halo.F | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index 4fc6ca0e76..fb51f3a9db 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -674,7 +674,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) call mpas_pool_get_array(domain % blocklist % allFields, trim(group % fields(i) % fieldName), & group % fields(i) % r1arr, timeLevel=group % fields(i) % timeLevel) - !$acc data copyin(group % fields(i) % r1arr(:)) + ! !$acc data copyin(group % fields(i) % r1arr(:)) + !$acc enter data copyin(group % fields(i) % r1arr(:)) ! ! Pack send buffer for all neighbors for current field ! @@ -691,7 +692,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do end do !$acc end kernels - !$acc end data + ! !$acc end data ! ! Packing code for 2-d real-valued fields @@ -707,7 +708,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! Use data regions for specificity and so the reference or attachment counters are easier to make sense of ! Present should also cause an attach action. OpenACC Spec2.7 Section 2.7.2 describes 'attach action' ! !$acc data present(group) present(group % fields(i)) present(group % sendBuf(:), group % fields(i) % sendListSrc(:,:,:)) - !$acc data copyin(group % fields(i) % r2arr(:,:)) + ! !$acc data copyin(group % fields(i) % r2arr(:,:)) + !$acc enter data copyin(group % fields(i) % r2arr(:,:)) ! Kernels is good enough, use default present to force a run-time error if programmer forgot something !$acc kernels default(present) @@ -725,7 +727,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do end do !$acc end kernels - !$acc end data + ! !$acc end data ! !$acc end data ! @@ -734,7 +736,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) case (3) call mpas_pool_get_array(domain % blocklist % allFields, trim(group % fields(i) % fieldName), & group % fields(i) % r3arr, group % fields(i) % timeLevel) - !$acc data copyin(group % fields(i) % r3arr(:,:,:)) + ! !$acc data copyin(group % fields(i) % r3arr(:,:,:)) + !$acc enter data copyin(group % fields(i) % r3arr(:,:,:)) ! ! Pack send buffer for all neighbors for current field @@ -757,7 +760,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do end do !$acc end kernels - !$acc end data + ! !$acc end data end select end if From 52b34507b04c246572ed71a4d703d05cc7bc2d21 Mon Sep 17 00:00:00 2001 From: "G. Dylan Dickerson" Date: Wed, 7 May 2025 19:50:34 -0600 Subject: [PATCH 20/30] Re-enable update host for sendBuf, add update device recvBuf Last commit gave me some big differences, let's see if this helps. If this helps, then that means I wasn't using GPU-aware MPI routines like I thought... --- src/framework/mpas_halo.F | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index fb51f3a9db..e576e6c13f 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -765,7 +765,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end select end if end do - ! !$acc update host(group % sendBuf(:)) + !$acc update host(group % sendBuf(:)) ! ! Initiate non-blocking sends to all neighbors From 4285b921c3f5d0fa243039f793f74df7ff5a5614 Mon Sep 17 00:00:00 2001 From: "G. 
Dylan Dickerson" Date: Wed, 7 May 2025 20:02:30 -0600 Subject: [PATCH 21/30] Remove update directives, use acc host_data use_device(...) near MPI calls instead Last commit still had answer differences. NOTE: This commit does too --- src/framework/mpas_halo.F | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index e576e6c13f..d3b77c780b 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -633,9 +633,11 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) bufstart = group % groupRecvOffsets(i) bufend = group % groupRecvOffsets(i) + group % groupRecvCounts(i) - 1 !TO DO: how do we determine appropriate type here? + !$acc host_data use_device(group % recvBuf) call MPI_Irecv(group % recvBuf(bufstart:bufend), group % groupRecvCounts(i), MPI_REALKIND, & group % groupRecvNeighbors(i), group % groupRecvNeighbors(i), comm, & group % recvRequests(i), mpi_ierr) + !$acc end host_data else group % recvRequests(i) = MPI_REQUEST_NULL end if @@ -765,7 +767,6 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end select end if end do - !$acc update host(group % sendBuf(:)) ! ! Initiate non-blocking sends to all neighbors @@ -775,9 +776,11 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) bufstart = group % groupSendOffsets(i) bufend = group % groupSendOffsets(i) + group % groupSendCounts(i) - 1 !TO DO: how do we determine appropriate type here? + !$acc host_data use_device(group % sendBuf) call MPI_Isend(group % sendBuf(bufstart:bufend), group % groupSendCounts(i), MPI_REALKIND, & group % groupSendNeighbors(i), rank, comm, & group % sendRequests(i), mpi_ierr) + !$acc end host_data else group % sendRequests(i) = MPI_REQUEST_NULL end if From fda45c428b74831273b301e25731aaf136c2c549 Mon Sep 17 00:00:00 2001 From: Abishek Gopal Date: Wed, 6 Aug 2025 16:33:26 -0600 Subject: [PATCH 22/30] checkpoints: acc pack + cuda aware mpi working --- src/framework/mpas_halo.F | 91 +++++++++++++++++++++++++++++++-------- 1 file changed, 74 insertions(+), 17 deletions(-) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index d3b77c780b..cb432b05ef 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -280,6 +280,12 @@ subroutine mpas_halo_exch_group_complete(domain, groupName, iErr) call refactor_lists(domain, groupName, iErr) + if ( newGroup% nGroupSendNeighbors <=0 ) then + !call mpas_log_write('No send neighbors for halo exchange group '//trim(groupName)) + return + end if + + ! Always copy in the main data member first !$acc enter data copyin(newGroup) ! Then the data in the members of the type @@ -541,6 +547,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) use mpas_derived_types, only : domain_type, mpas_halo_group, MPAS_HALO_REAL, MPAS_LOG_CRIT use mpas_pool_routines, only : mpas_pool_get_array use mpas_log, only : mpas_log_write + use mpas_kind_types, only : RKIND ! 
Parameters #ifdef MPAS_USE_MPI_F08 @@ -588,7 +595,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) integer :: maxNRecvList integer, dimension(:,:,:), CONTIGUOUS pointer :: recvListSrc, recvListDst integer, dimension(:), CONTIGUOUS pointer :: unpackOffsets - + real (kind=RKIND), dimension(:), pointer :: sendBufptr, recvBufptr if (present(iErr)) then iErr = 0 @@ -611,6 +618,10 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) messageType=MPAS_LOG_CRIT) end if + if ( group% nGroupSendNeighbors <=0 ) then + !call mpas_log_write('group has no halo exchanges: '//trim(groupName)) + return + end if ! ! Get the rank of this task and the MPI communicator to use from the first field in ! the group; all fields should be using the same communicator, so this should not @@ -623,7 +634,11 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) #endif rank = group % fields(1) % compactHaloInfo(8) - !$acc data present(group % recvBuf(:), group % sendBuf(:)) + sendBufptr => group % sendBuf + recvBufptr => group % recvBuf + + !!!$acc data present(group % recvBuf(:), group % sendBuf(:)) + !$acc data present(sendBufptr,recvBufptr) ! ! Initiate non-blocking MPI receives for all neighbors @@ -633,8 +648,12 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) bufstart = group % groupRecvOffsets(i) bufend = group % groupRecvOffsets(i) + group % groupRecvCounts(i) - 1 !TO DO: how do we determine appropriate type here? - !$acc host_data use_device(group % recvBuf) - call MPI_Irecv(group % recvBuf(bufstart:bufend), group % groupRecvCounts(i), MPI_REALKIND, & + ! !$acc host_data use_device(group % recvBuf) + ! call MPI_Irecv(group % recvBuf(bufstart:bufend), group % groupRecvCounts(i), MPI_REALKIND, & + ! group % groupRecvNeighbors(i), group % groupRecvNeighbors(i), comm, & + ! group % recvRequests(i), mpi_ierr) + !$acc host_data use_device(recvBufptr) + call MPI_Irecv(recvBufptr(bufstart:bufend), group % groupRecvCounts(i), MPI_REALKIND, & group % groupRecvNeighbors(i), group % groupRecvNeighbors(i), comm, & group % recvRequests(i), mpi_ierr) !$acc end host_data @@ -695,7 +714,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do !$acc end kernels ! !$acc end data - + !!$acc update device(group % sendBuf(:)) ! ! Packing code for 2-d real-valued fields ! @@ -731,7 +750,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) !$acc end kernels ! !$acc end data ! !$acc end data - + !!$acc update device(group % sendBuf(:)) ! ! Packing code for 3-d real-valued fields ! @@ -763,11 +782,25 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do !$acc end kernels ! !$acc end data + !!$acc update device(group % sendBuf(:)) end select end if end do + do i = 1, group % nFields + if (group % fields(i) % fieldType == MPAS_HALO_REAL) then + select case (group % fields(i) % nDims) + case (1) + !$acc exit data delete(group % fields(i) % r1arr(:)) + case (2) + !$acc exit data delete(group % fields(i) % r2arr(:,:)) + case (3) + !$acc exit data delete(group % fields(i) % r3arr(:,:,:)) + end select + end if + end do + ! ! Initiate non-blocking sends to all neighbors ! @@ -776,8 +809,12 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) bufstart = group % groupSendOffsets(i) bufend = group % groupSendOffsets(i) + group % groupSendCounts(i) - 1 !TO DO: how do we determine appropriate type here? 
- !$acc host_data use_device(group % sendBuf) - call MPI_Isend(group % sendBuf(bufstart:bufend), group % groupSendCounts(i), MPI_REALKIND, & + ! !$acc host_data use_device(group % sendBuf) + ! call MPI_Isend(group % sendBuf(bufstart:bufend), group % groupSendCounts(i), MPI_REALKIND, & + ! group % groupSendNeighbors(i), rank, comm, & + ! group % sendRequests(i), mpi_ierr) + !$acc host_data use_device(sendBufptr) + call MPI_Isend(sendBufptr(bufstart:bufend), group % groupSendCounts(i), MPI_REALKIND, & group % groupSendNeighbors(i), rank, comm, & group % sendRequests(i), mpi_ierr) !$acc end host_data @@ -835,7 +872,9 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Unpack recv buffer from all neighbors for current field ! - !$acc kernels default(present) + !$acc update host(group % recvBuf(:)) + !$acc wait + !!$acc kernels default(present) do iHalo = 1, nHalos do j = 1, maxNRecvList if (j <= nRecvLists(iHalo,iEndp)) then @@ -845,8 +884,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end if end do end do - !$acc end kernels - !$acc exit data copyout(group % fields(i) % r1arr(:)) + !!$acc end kernels + !!$acc exit data copyout(group % fields(i) % r1arr(:)) ! ! Unpacking code for 2-d real-valued fields @@ -855,7 +894,9 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Unpack recv buffer from all neighbors for current field ! - !$acc kernels default(present) + !$acc update host(group % recvBuf(:)) + !$acc wait + !!$acc kernels default(present) do iHalo = 1, nHalos do j = 1, maxNRecvList do i1 = 1, dim1 @@ -867,8 +908,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do end do end do - !$acc end kernels - !$acc exit data copyout(group % fields(i) % r2arr(:,:)) + !!$acc end kernels + !!$acc exit data copyout(group % fields(i) % r2arr(:,:)) ! ! Unpacking code for 3-d real-valued fields @@ -877,7 +918,9 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Unpack recv buffer from all neighbors for current field ! - !$acc kernels default(present) + !$acc update host(group % recvBuf(:)) + !$acc wait + !!$acc kernels default(present) do iHalo = 1, nHalos do j = 1, maxNRecvList do i2 = 1, dim2 @@ -892,8 +935,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do end do end do - !$acc end kernels - !$acc exit data copyout(group % fields(i) % r3arr(:,:,:)) + !!$acc end kernels + !!$acc exit data copyout(group % fields(i) % r3arr(:,:,:)) end select end if @@ -903,6 +946,20 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! For the present(group % recvBuf(:), group % sendBuf(:)) !$acc end data + ! do i = 1, group % nFields + ! if (group % fields(i) % fieldType == MPAS_HALO_REAL) then + ! select case (group % fields(i) % nDims) + ! case (1) + ! !$acc exit data copyout(group % fields(i) % r1arr(:)) + ! case (2) + ! !$acc exit data copyout(group % fields(i) % r2arr(:,:)) + ! case (3) + ! !$acc exit data copyout(group % fields(i) % r3arr(:,:,:)) + ! end select + ! end if + ! end do + + ! ! Nullify array pointers - not necessary for correctness, but helpful when debugging ! 
to not leave pointers to what might later be incorrect targets From 250c4d91290a4ecf9aa7e4c3c98ea90ff08ad07e Mon Sep 17 00:00:00 2001 From: Abishek Gopal Date: Thu, 7 Aug 2025 10:10:55 -0600 Subject: [PATCH 23/30] seems to be working --- src/framework/mpas_halo.F | 67 +++++++++++++++++++++++---------------- 1 file changed, 40 insertions(+), 27 deletions(-) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index cb432b05ef..9062bc4537 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -788,19 +788,6 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end if end do - do i = 1, group % nFields - if (group % fields(i) % fieldType == MPAS_HALO_REAL) then - select case (group % fields(i) % nDims) - case (1) - !$acc exit data delete(group % fields(i) % r1arr(:)) - case (2) - !$acc exit data delete(group % fields(i) % r2arr(:,:)) - case (3) - !$acc exit data delete(group % fields(i) % r3arr(:,:,:)) - end select - end if - end do - ! ! Initiate non-blocking sends to all neighbors ! @@ -872,9 +859,9 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Unpack recv buffer from all neighbors for current field ! - !$acc update host(group % recvBuf(:)) - !$acc wait - !!$acc kernels default(present) + !!$acc update host(group % recvBuf(:)) + !!$acc wait + !$acc kernels default(present) do iHalo = 1, nHalos do j = 1, maxNRecvList if (j <= nRecvLists(iHalo,iEndp)) then @@ -884,7 +871,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end if end do end do - !!$acc end kernels + !$acc end kernels !!$acc exit data copyout(group % fields(i) % r1arr(:)) ! @@ -894,9 +881,9 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Unpack recv buffer from all neighbors for current field ! - !$acc update host(group % recvBuf(:)) - !$acc wait - !!$acc kernels default(present) + !!$acc update host(group % recvBuf(:)) + !!$acc wait + !$acc kernels default(present) do iHalo = 1, nHalos do j = 1, maxNRecvList do i1 = 1, dim1 @@ -908,7 +895,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do end do end do - !!$acc end kernels + !$acc end kernels !!$acc exit data copyout(group % fields(i) % r2arr(:,:)) ! @@ -918,9 +905,9 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Unpack recv buffer from all neighbors for current field ! 
- !$acc update host(group % recvBuf(:)) - !$acc wait - !!$acc kernels default(present) + !!$acc update host(group % recvBuf(:)) + !!$acc wait + !$acc kernels default(present) do iHalo = 1, nHalos do j = 1, maxNRecvList do i2 = 1, dim2 @@ -935,7 +922,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do end do end do - !!$acc end kernels + !$acc end kernels !!$acc exit data copyout(group % fields(i) % r3arr(:,:,:)) end select @@ -943,9 +930,35 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do end do + do i = 1, group % nFields + if (group % fields(i) % fieldType == MPAS_HALO_REAL) then + select case (group % fields(i) % nDims) + case (1) + !$acc update self(group % fields(i) % r1arr(:)) + case (2) + !$acc update self(group % fields(i) % r2arr(:,:)) + case (3) + !$acc update self(group % fields(i) % r3arr(:,:,:)) + end select + end if + end do + + do i = 1, group % nFields + if (group % fields(i) % fieldType == MPAS_HALO_REAL) then + select case (group % fields(i) % nDims) + case (1) + !$acc exit data delete(group % fields(i) % r1arr(:)) + case (2) + !$acc exit data delete(group % fields(i) % r2arr(:,:)) + case (3) + !$acc exit data delete(group % fields(i) % r3arr(:,:,:)) + end select + end if + end do + ! For the present(group % recvBuf(:), group % sendBuf(:)) !$acc end data - + ! !$acc wait ! do i = 1, group % nFields ! if (group % fields(i) % fieldType == MPAS_HALO_REAL) then ! select case (group % fields(i) % nDims) @@ -958,7 +971,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! end select ! end if ! end do - + ! !$acc wait ! ! Nullify array pointers - not necessary for correctness, but helpful when debugging From 0d3709ee4016c94330e1a60b181fbadb8f7a7cce Mon Sep 17 00:00:00 2001 From: Abishek Gopal Date: Thu, 7 Aug 2025 17:17:45 -0600 Subject: [PATCH 24/30] Optimized packing and unpacking loops. Adding timers and other cleanup --- src/framework/mpas_halo.F | 142 ++++++++++++++++++++------------------ 1 file changed, 73 insertions(+), 69 deletions(-) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index 9062bc4537..7f96e0d397 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -17,6 +17,15 @@ !> communicating the halos of all fields in a group. ! !----------------------------------------------------------------------- + +#ifdef MPAS_OPENACC +#define MPAS_ACC_TIMER_START(X) call mpas_timer_start(X) +#define MPAS_ACC_TIMER_STOP(X) call mpas_timer_stop(X) +#else +#define MPAS_ACC_TIMER_START(X) +#define MPAS_ACC_TIMER_STOP(X) +#endif + module mpas_halo implicit none @@ -281,9 +290,8 @@ subroutine mpas_halo_exch_group_complete(domain, groupName, iErr) call refactor_lists(domain, groupName, iErr) if ( newGroup% nGroupSendNeighbors <=0 ) then - !call mpas_log_write('No send neighbors for halo exchange group '//trim(groupName)) return - end if + end if ! Always copy in the main data member first @@ -547,7 +555,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) use mpas_derived_types, only : domain_type, mpas_halo_group, MPAS_HALO_REAL, MPAS_LOG_CRIT use mpas_pool_routines, only : mpas_pool_get_array use mpas_log, only : mpas_log_write - use mpas_kind_types, only : RKIND + use mpas_timer, only : mpas_timer_start, mpas_timer_stop + ! 
Parameters #ifdef MPAS_USE_MPI_F08 @@ -595,12 +604,13 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) integer :: maxNRecvList integer, dimension(:,:,:), CONTIGUOUS pointer :: recvListSrc, recvListDst integer, dimension(:), CONTIGUOUS pointer :: unpackOffsets - real (kind=RKIND), dimension(:), pointer :: sendBufptr, recvBufptr + if (present(iErr)) then iErr = 0 end if + ! ! Find this halo exhange group in the list of groups ! @@ -618,10 +628,11 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) messageType=MPAS_LOG_CRIT) end if - if ( group% nGroupSendNeighbors <=0 ) then - !call mpas_log_write('group has no halo exchanges: '//trim(groupName)) + if ( group% nGroupSendNeighbors <= 0 ) then return - end if + end if + + call mpas_timer_start('full_halo_exch') ! ! Get the rank of this task and the MPI communicator to use from the first field in ! the group; all fields should be using the same communicator, so this should not @@ -634,11 +645,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) #endif rank = group % fields(1) % compactHaloInfo(8) - sendBufptr => group % sendBuf - recvBufptr => group % recvBuf - - !!!$acc data present(group % recvBuf(:), group % sendBuf(:)) - !$acc data present(sendBufptr,recvBufptr) + !$acc data present(group % recvBuf(:), group % sendBuf(:)) ! ! Initiate non-blocking MPI receives for all neighbors @@ -648,12 +655,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) bufstart = group % groupRecvOffsets(i) bufend = group % groupRecvOffsets(i) + group % groupRecvCounts(i) - 1 !TO DO: how do we determine appropriate type here? - ! !$acc host_data use_device(group % recvBuf) - ! call MPI_Irecv(group % recvBuf(bufstart:bufend), group % groupRecvCounts(i), MPI_REALKIND, & - ! group % groupRecvNeighbors(i), group % groupRecvNeighbors(i), comm, & - ! group % recvRequests(i), mpi_ierr) - !$acc host_data use_device(recvBufptr) - call MPI_Irecv(recvBufptr(bufstart:bufend), group % groupRecvCounts(i), MPI_REALKIND, & + !$acc host_data use_device(group % recvBuf) + call MPI_Irecv(group % recvBuf(bufstart:bufend), group % groupRecvCounts(i), MPI_REALKIND, & group % groupRecvNeighbors(i), group % groupRecvNeighbors(i), comm, & group % recvRequests(i), mpi_ierr) !$acc end host_data @@ -695,14 +698,18 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) call mpas_pool_get_array(domain % blocklist % allFields, trim(group % fields(i) % fieldName), & group % fields(i) % r1arr, timeLevel=group % fields(i) % timeLevel) - ! !$acc data copyin(group % fields(i) % r1arr(:)) + MPAS_ACC_TIMER_START('halo_exch [ACC_data_xfer]') !$acc enter data copyin(group % fields(i) % r1arr(:)) + MPAS_ACC_TIMER_STOP('halo_exch [ACC_data_xfer]') ! ! Pack send buffer for all neighbors for current field ! - !$acc kernels default(present) + call mpas_timer_start('packing_halo_exch') + !$acc parallel default(present) + !$acc loop gang collapse(2) do iEndp = 1, nSendEndpts do iHalo = 1, nHalos + !$acc loop vector do j = 1, maxNSendList if (j <= nSendLists(iHalo,iEndp)) then idxBuf = packOffsets(iEndp) + sendListDst(j,iHalo,iEndp) @@ -712,9 +719,9 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do end do end do - !$acc end kernels - ! !$acc end data - !!$acc update device(group % sendBuf(:)) + !$acc end parallel + call mpas_timer_stop('packing_halo_exch') + ! ! Packing code for 2-d real-valued fields ! 
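
[Note on the hunk above] Patch 24 trades the `!$acc kernels` regions, which leave loop scheduling to the compiler, for explicit `!$acc parallel` regions: the (endpoint, halo-layer) pairs are collapsed across gangs and the per-neighbor list index runs across vector lanes, with the `if (j <= nSendLists(iHalo,iEndp))` guard keeping the rectangular collapse legal even though list lengths vary. A minimal, self-contained sketch of the same mapping — argument names here are placeholders standing in for sendBuf, r1arr, sendListSrc/Dst, nSendLists, and packOffsets, not the actual group members:

    subroutine pack_1d_sketch(buf, arr, srcList, dstList, nList, offs, &
                              nEndpts, nHalos, maxList)
       ! Sketch only; all arrays are assumed to be device-resident already.
       use mpas_kind_types, only : RKIND
       real (kind=RKIND), dimension(:), intent(inout) :: buf
       real (kind=RKIND), dimension(:), intent(in) :: arr
       integer, dimension(:,:,:), intent(in) :: srcList, dstList
       integer, dimension(:,:), intent(in) :: nList
       integer, dimension(:), intent(in) :: offs
       integer, intent(in) :: nEndpts, nHalos, maxList
       integer :: iEndp, iHalo, j

       ! One gang per (endpoint, halo layer) pair; the guarded vector loop
       ! runs to the longest list so the collapse stays rectangular.
       !$acc parallel default(present)
       !$acc loop gang collapse(2)
       do iEndp = 1, nEndpts
          do iHalo = 1, nHalos
             !$acc loop vector
             do j = 1, maxList
                if (j <= nList(iHalo,iEndp)) then
                   buf(offs(iEndp) + dstList(j,iHalo,iEndp)) = arr(srcList(j,iHalo,iEndp))
                end if
             end do
          end do
       end do
       !$acc end parallel
    end subroutine pack_1d_sketch

Compared with kernels, the explicit mapping pins down the parallelization across compilers, and default(present) still turns a missing device copy into a runtime error rather than a silent implicit transfer.
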
@@ -725,18 +732,23 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Pack send buffer for all neighbors for current field ! - + ! Use data regions for specificity and so the reference or attachment counters are easier to make sense of ! Present should also cause an attach action. OpenACC Spec2.7 Section 2.7.2 describes 'attach action' ! !$acc data present(group) present(group % fields(i)) present(group % sendBuf(:), group % fields(i) % sendListSrc(:,:,:)) ! !$acc data copyin(group % fields(i) % r2arr(:,:)) + MPAS_ACC_TIMER_START('halo_exch [ACC_data_xfer]') !$acc enter data copyin(group % fields(i) % r2arr(:,:)) + MPAS_ACC_TIMER_STOP('halo_exch [ACC_data_xfer]') + call mpas_timer_start('packing_halo_exch') ! Kernels is good enough, use default present to force a run-time error if programmer forgot something - !$acc kernels default(present) + !$acc parallel default(present) + !$acc loop gang collapse(3) do iEndp = 1, nSendEndpts do iHalo = 1, nHalos do j = 1, maxNSendList + !$acc loop vector do i1 = 1, dim1 if (j <= nSendLists(iHalo,iEndp)) then idxBuf = packOffsets(iEndp) + dim1 * (sendListDst(j,iHalo,iEndp) - 1) + i1 @@ -747,27 +759,30 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do end do end do - !$acc end kernels - ! !$acc end data - ! !$acc end data - !!$acc update device(group % sendBuf(:)) + !$acc end parallel + call mpas_timer_stop('packing_halo_exch') + ! ! Packing code for 3-d real-valued fields ! case (3) call mpas_pool_get_array(domain % blocklist % allFields, trim(group % fields(i) % fieldName), & group % fields(i) % r3arr, group % fields(i) % timeLevel) - ! !$acc data copyin(group % fields(i) % r3arr(:,:,:)) + MPAS_ACC_TIMER_START('halo_exch [ACC_data_xfer]') !$acc enter data copyin(group % fields(i) % r3arr(:,:,:)) + MPAS_ACC_TIMER_STOP('halo_exch [ACC_data_xfer]') ! ! Pack send buffer for all neighbors for current field ! - !$acc kernels default(present) + call mpas_timer_start('packing_halo_exch') + !$acc parallel default(present) + !$acc loop gang collapse(4) do iEndp = 1, nSendEndpts do iHalo = 1, nHalos do j = 1, maxNSendList do i2 = 1, dim2 + !$acc loop vector do i1 = 1, dim1 if (j <= nSendLists(iHalo,iEndp)) then idxBuf = packOffsets(iEndp) + dim1*dim2*(sendListDst(j,iHalo,iEndp) - 1) & @@ -780,9 +795,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do end do end do - !$acc end kernels - ! !$acc end data - !!$acc update device(group % sendBuf(:)) + !$acc end parallel + call mpas_timer_stop('packing_halo_exch') end select end if @@ -796,12 +810,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) bufstart = group % groupSendOffsets(i) bufend = group % groupSendOffsets(i) + group % groupSendCounts(i) - 1 !TO DO: how do we determine appropriate type here? - ! !$acc host_data use_device(group % sendBuf) - ! call MPI_Isend(group % sendBuf(bufstart:bufend), group % groupSendCounts(i), MPI_REALKIND, & - ! group % groupSendNeighbors(i), rank, comm, & - ! group % sendRequests(i), mpi_ierr) - !$acc host_data use_device(sendBufptr) - call MPI_Isend(sendBufptr(bufstart:bufend), group % groupSendCounts(i), MPI_REALKIND, & + !$acc host_data use_device(group % sendBuf) + call MPI_Isend(group % sendBuf(bufstart:bufend), group % groupSendCounts(i), MPI_REALKIND, & group % groupSendNeighbors(i), rank, comm, & group % sendRequests(i), mpi_ierr) !$acc end host_data @@ -859,10 +869,11 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! 
Unpack recv buffer from all neighbors for current field ! - !!$acc update host(group % recvBuf(:)) - !!$acc wait - !$acc kernels default(present) + call mpas_timer_start('unpacking_halo_exch') + !$acc parallel default(present) + !$acc loop gang do iHalo = 1, nHalos + !$acc loop vector do j = 1, maxNRecvList if (j <= nRecvLists(iHalo,iEndp)) then idxArr = recvListDst(j,iHalo,iEndp) @@ -871,8 +882,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end if end do end do - !$acc end kernels - !!$acc exit data copyout(group % fields(i) % r1arr(:)) + !$acc end parallel + call mpas_timer_stop('unpacking_halo_exch') ! ! Unpacking code for 2-d real-valued fields @@ -881,11 +892,13 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Unpack recv buffer from all neighbors for current field ! - !!$acc update host(group % recvBuf(:)) - !!$acc wait - !$acc kernels default(present) + call mpas_timer_start('unpacking_halo_exch') + !$acc parallel default(present) + !$acc loop gang do iHalo = 1, nHalos + !$acc loop worker do j = 1, maxNRecvList + !$acc loop vector do i1 = 1, dim1 if (j <= nRecvLists(iHalo,iEndp)) then idxArr = recvListDst(j,iHalo,iEndp) @@ -895,8 +908,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do end do end do - !$acc end kernels - !!$acc exit data copyout(group % fields(i) % r2arr(:,:)) + !$acc end parallel + call mpas_timer_stop('unpacking_halo_exch') ! ! Unpacking code for 3-d real-valued fields @@ -905,11 +918,12 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! ! Unpack recv buffer from all neighbors for current field ! - !!$acc update host(group % recvBuf(:)) - !!$acc wait - !$acc kernels default(present) + call mpas_timer_start('unpacking_halo_exch') + !$acc parallel default(present) + !$acc loop gang collapse(2) do iHalo = 1, nHalos do j = 1, maxNRecvList + !$acc loop vector collapse(2) do i2 = 1, dim2 do i1 = 1, dim1 if (j <= nRecvLists(iHalo,iEndp)) then @@ -922,14 +936,15 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) end do end do end do - !$acc end kernels - !!$acc exit data copyout(group % fields(i) % r3arr(:,:,:)) + !$acc end parallel + call mpas_timer_stop('unpacking_halo_exch') end select end if end do end do + MPAS_ACC_TIMER_START('halo_exch [ACC_data_xfer]') do i = 1, group % nFields if (group % fields(i) % fieldType == MPAS_HALO_REAL) then select case (group % fields(i) % nDims) @@ -958,20 +973,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! For the present(group % recvBuf(:), group % sendBuf(:)) !$acc end data - ! !$acc wait - ! do i = 1, group % nFields - ! if (group % fields(i) % fieldType == MPAS_HALO_REAL) then - ! select case (group % fields(i) % nDims) - ! case (1) - ! !$acc exit data copyout(group % fields(i) % r1arr(:)) - ! case (2) - ! !$acc exit data copyout(group % fields(i) % r2arr(:,:)) - ! case (3) - ! !$acc exit data copyout(group % fields(i) % r3arr(:,:,:)) - ! end select - ! end if - ! end do - ! !$acc wait + MPAS_ACC_TIMER_STOP('halo_exch [ACC_data_xfer]') ! ! Nullify array pointers - not necessary for correctness, but helpful when debugging @@ -992,6 +994,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! 
call MPI_Waitall(group % nGroupSendNeighbors, group % sendRequests, MPI_STATUSES_IGNORE, mpi_ierr) + call mpas_timer_stop('full_halo_exch') + end subroutine mpas_halo_exch_group_full_halo_exch From f51f57accc45c3b46120c772d22cd17b7d2df4fe Mon Sep 17 00:00:00 2001 From: Abishek Gopal Date: Tue, 12 Aug 2025 17:13:52 -0600 Subject: [PATCH 25/30] Working savepoint --- .../dynamics/mpas_atm_time_integration.F | 107 +++++++++--------- src/core_atmosphere/mpas_atm_core.F | 3 +- src/core_atmosphere/mpas_atm_halos.F | 3 +- .../physics/mpas_atmphys_todynamics.F | 3 +- src/framework/mpas_dmpar.F | 11 +- src/framework/mpas_halo.F | 67 +++++++---- 6 files changed, 112 insertions(+), 82 deletions(-) diff --git a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F index 5cb15624f2..b656d6b4dd 100644 --- a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F +++ b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F @@ -43,12 +43,13 @@ module atm_time_integration ! in a named group ! abstract interface - subroutine halo_exchange_routine(domain, halo_group, ierr) + subroutine halo_exchange_routine(domain, halo_group, withGPUAwareMPI, ierr) use mpas_derived_types, only : domain_type type (domain_type), intent(inout) :: domain character(len=*), intent(in) :: halo_group + logical, intent(in), optional :: withGPUAwareMPI integer, intent(out), optional :: ierr end subroutine halo_exchange_routine @@ -1989,9 +1990,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) call mpas_pool_get_array(state, 'scalars', scalars_1, 1) call mpas_pool_get_array(diag, 'pressure_p', pressure_p) call mpas_pool_get_array(diag, 'rtheta_p', rtheta_p) - !$acc update self(theta_m,scalars_1,pressure_p,rtheta_p) - call exchange_halo_group(domain, 'dynamics:theta_m,scalars,pressure_p,rtheta_p') - !$acc update device(theta_m,scalars_1,pressure_p,rtheta_p) + !!$acc update self(theta_m,scalars_1,pressure_p,rtheta_p) + call exchange_halo_group(domain, 'dynamics:theta_m,scalars,pressure_p,rtheta_p', .true.) + !!$acc update device(theta_m,scalars_1,pressure_p,rtheta_p) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_timer_start('atm_rk_integration_setup') @@ -2080,9 +2081,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(diag, 'exner', exner) - !$acc update self(exner) - call exchange_halo_group(domain, 'dynamics:exner') - !$acc update device(exner) + !!$acc update self(exner) + call exchange_halo_group(domain, 'dynamics:exner', .true.) + !!$acc update device(exner) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -2163,9 +2164,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) ! tend_u MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(tend, 'u', tend_u) - !$acc update self(tend_u) - call exchange_halo_group(domain, 'dynamics:tend_u') - !$acc update device(tend_u) + !!$acc update self(tend_u) + call exchange_halo_group(domain, 'dynamics:tend_u', .true.) 
+ !!$acc update device(tend_u) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_timer_start('small_step_prep') @@ -2244,9 +2245,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(diag, 'rho_pp', rho_pp) - !$acc update self(rho_pp) - call exchange_halo_group(domain, 'dynamics:rho_pp') - !$acc update device(rho_pp) + !!$acc update self(rho_pp) + call exchange_halo_group(domain, 'dynamics:rho_pp', .true.) + !!$acc update device(rho_pp) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_timer_start('atm_advance_acoustic_step') @@ -2271,9 +2272,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) ! This is the only communications needed during the acoustic steps because we solve for u on all edges of owned cells MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(diag, 'rtheta_pp', rtheta_pp) - !$acc update self(rtheta_pp) - call exchange_halo_group(domain, 'dynamics:rtheta_pp') - !$acc update device(rtheta_pp) + !!$acc update self(rtheta_pp) + call exchange_halo_group(domain, 'dynamics:rtheta_pp', .true.) + !!$acc update device(rtheta_pp) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') ! complete update of horizontal momentum by including 3d divergence damping at the end of the acoustic step @@ -2299,9 +2300,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) call mpas_pool_get_array(diag, 'rw_p', rw_p) call mpas_pool_get_array(diag, 'rho_pp', rho_pp) call mpas_pool_get_array(diag, 'rtheta_pp', rtheta_pp) - !$acc update self(rw_p,ru_p,rho_pp,rtheta_pp) - call exchange_halo_group(domain, 'dynamics:rw_p,ru_p,rho_pp,rtheta_pp') - !$acc update device(rw_p,ru_p,rho_pp,rtheta_pp) + !!$acc update self(rw_p,ru_p,rho_pp,rtheta_pp) + call exchange_halo_group(domain, 'dynamics:rw_p,ru_p,rho_pp,rtheta_pp', .true.) + !!$acc update device(rw_p,ru_p,rho_pp,rtheta_pp) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_timer_start('atm_recover_large_step_variables') @@ -2372,14 +2373,14 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(state, 'u', u, 2) - !$acc update self(u) + !!$acc update self(u) ! u if (config_apply_lbcs) then - call exchange_halo_group(domain, 'dynamics:u_123') + call exchange_halo_group(domain, 'dynamics:u_123', .true.) else - call exchange_halo_group(domain, 'dynamics:u_3') + call exchange_halo_group(domain, 'dynamics:u_3', .true.) end if - !$acc update device(u) + !!$acc update device(u) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') ! scalar advection: RK3 scheme of Skamarock and Gassmann (2011). @@ -2394,9 +2395,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(state, 'scalars', scalars_2, 2) - !$acc update self(scalars_2) - call exchange_halo_group(domain, 'dynamics:scalars') - !$acc update device(scalars_2) + !!$acc update self(scalars_2) + call exchange_halo_group(domain, 'dynamics:scalars', .true.) 
+ !!$acc update device(scalars_2) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') allocate(scalars_driving(num_scalars,nVertLevels,nCells+1)) @@ -2453,22 +2454,22 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) call mpas_pool_get_array(state, 'w', w, 2) call mpas_pool_get_array(diag, 'pv_edge', pv_edge) call mpas_pool_get_array(diag, 'rho_edge', rho_edge) - !$acc update self(w,pv_edge,rho_edge) + !!$acc update self(w,pv_edge,rho_edge) if (config_scalar_advection .and. (.not. config_split_dynamics_transport) ) then ! ! Communicate halos for w[1,2], pv_edge[1,2], rho_edge[1,2], scalars[1,2] ! call mpas_pool_get_array(state, 'scalars', scalars_2, 2) - !$acc update self(scalars_2) - call exchange_halo_group(domain, 'dynamics:w,pv_edge,rho_edge,scalars') - !$acc update device(scalars_2) + !!$acc update self(scalars_2) + call exchange_halo_group(domain, 'dynamics:w,pv_edge,rho_edge,scalars', .true.) + !!$acc update device(scalars_2) else ! ! Communicate halos for w[1,2], pv_edge[1,2], rho_edge[1,2] ! - call exchange_halo_group(domain, 'dynamics:w,pv_edge,rho_edge') + call exchange_halo_group(domain, 'dynamics:w,pv_edge,rho_edge', .true.) end if - !$acc update device(w,pv_edge,rho_edge) + !!$acc update device(w,pv_edge,rho_edge) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') ! set the zero-gradient condition on w for regional_MPAS @@ -2485,9 +2486,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') ! w halo values needs resetting after regional boundary update call mpas_pool_get_array(state, 'w', w, 2) - !$acc update self(w) - call exchange_halo_group(domain, 'dynamics:w') - !$acc update device(w) + !!$acc update self(w) + call exchange_halo_group(domain, 'dynamics:w', .true.) + !!$acc update device(w) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') end if ! end of regional_MPAS addition @@ -2503,9 +2504,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) call mpas_pool_get_array(state, 'theta_m', theta_m, 2) call mpas_pool_get_array(diag, 'pressure_p', pressure_p) call mpas_pool_get_array(diag, 'rtheta_p', rtheta_p) - !$acc update self(theta_m,pressure_p,rtheta_p) - call exchange_halo_group(domain, 'dynamics:theta_m,pressure_p,rtheta_p') - !$acc update device(theta_m,pressure_p,rtheta_p) + !!$acc update self(theta_m,pressure_p,rtheta_p) + call exchange_halo_group(domain, 'dynamics:theta_m,pressure_p,rtheta_p', .true.) + !!$acc update device(theta_m,pressure_p,rtheta_p) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') ! @@ -2574,9 +2575,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') ! need to fill halo for horizontal filter call mpas_pool_get_array(state, 'scalars', scalars_2, 2) - !$acc update self(scalars_2) - call exchange_halo_group(domain, 'dynamics:scalars') - !$acc update device(scalars_2) + !!$acc update self(scalars_2) + call exchange_halo_group(domain, 'dynamics:scalars', .true.) 
+ !!$acc update device(scalars_2) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') allocate(scalars_driving(num_scalars,nVertLevels,nCells+1)) @@ -2605,9 +2606,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) if (rk_step < 3) then MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(state, 'scalars', scalars_2, 2) - !$acc update self(scalars_2) - call exchange_halo_group(domain, 'dynamics:scalars') - !$acc update device(scalars_2) + !!$acc update self(scalars_2) + call exchange_halo_group(domain, 'dynamics:scalars', .true.) + !!$acc update device(scalars_2) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') end if @@ -2736,9 +2737,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(state, 'scalars', scalars_2, 2) - !$acc update self(scalars_2) - call exchange_halo_group(domain, 'dynamics:scalars') - !$acc update device(scalars_2) + !!$acc update self(scalars_2) + call exchange_halo_group(domain, 'dynamics:scalars', .true.) + !!$acc update device(scalars_2) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') allocate(scalars_driving(num_scalars,nVertLevels,nCells+1)) @@ -5074,17 +5075,17 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge !$acc end parallel MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]') - !$acc update self(scalars_old) + !!$acc update self(scalars_old) MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]') !$OMP BARRIER !$OMP MASTER - call exchange_halo_group(block % domain, 'dynamics:'//trim(field_name)//'_old') + call exchange_halo_group(block % domain, 'dynamics:'//trim(field_name)//'_old', .true.) !$OMP END MASTER !$OMP BARRIER MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]') - !$acc update device(scalars_old) + !!$acc update device(scalars_old) MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]') ! @@ -5481,17 +5482,17 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge ! MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]') - !$acc update self(scale_arr) + !!$acc update self(scale_arr) MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]') !$OMP BARRIER !$OMP MASTER - call exchange_halo_group(block % domain, 'dynamics:scale') + call exchange_halo_group(block % domain, 'dynamics:scale', .true.) !$OMP END MASTER !$OMP BARRIER MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]') - !$acc update device(scale_arr) + !!$acc update device(scale_arr) MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]') !$acc parallel diff --git a/src/core_atmosphere/mpas_atm_core.F b/src/core_atmosphere/mpas_atm_core.F index 087cfc2f2c..8ab4b60928 100644 --- a/src/core_atmosphere/mpas_atm_core.F +++ b/src/core_atmosphere/mpas_atm_core.F @@ -18,12 +18,13 @@ module atm_core ! in a named group ! 
abstract interface - subroutine halo_exchange_routine(domain, halo_group, ierr) + subroutine halo_exchange_routine(domain, halo_group, withGPUAwareMPI, ierr) use mpas_derived_types, only : domain_type type (domain_type), intent(inout) :: domain character(len=*), intent(in) :: halo_group + logical, intent(in), optional :: withGPUAwareMPI integer, intent(out), optional :: ierr end subroutine halo_exchange_routine diff --git a/src/core_atmosphere/mpas_atm_halos.F b/src/core_atmosphere/mpas_atm_halos.F index df02ee30a2..955f9b5ea0 100644 --- a/src/core_atmosphere/mpas_atm_halos.F +++ b/src/core_atmosphere/mpas_atm_halos.F @@ -15,12 +15,13 @@ module mpas_atm_halos ! in a named group ! abstract interface - subroutine halo_exchange_routine(domain, halo_group, ierr) + subroutine halo_exchange_routine(domain, halo_group, withGPUAwareMPI, ierr) use mpas_derived_types, only : domain_type type (domain_type), intent(inout) :: domain character(len=*), intent(in) :: halo_group + logical, intent(in), optional :: withGPUAwareMPI integer, intent(out), optional :: ierr end subroutine halo_exchange_routine diff --git a/src/core_atmosphere/physics/mpas_atmphys_todynamics.F b/src/core_atmosphere/physics/mpas_atmphys_todynamics.F index 2cb94a7ba5..71f37eb550 100644 --- a/src/core_atmosphere/physics/mpas_atmphys_todynamics.F +++ b/src/core_atmosphere/physics/mpas_atmphys_todynamics.F @@ -56,12 +56,13 @@ module mpas_atmphys_todynamics ! in a named group ! abstract interface - subroutine halo_exchange_routine(domain, halo_group, ierr) + subroutine halo_exchange_routine(domain, halo_group, withGPUAwareMPI, ierr) use mpas_derived_types, only : domain_type type (domain_type), intent(inout) :: domain character(len=*), intent(in) :: halo_group + logical, intent(in), optional :: withGPUAwareMPI integer, intent(out), optional :: ierr end subroutine halo_exchange_routine diff --git a/src/framework/mpas_dmpar.F b/src/framework/mpas_dmpar.F index 6d68c0c656..5d9b48d53b 100644 --- a/src/framework/mpas_dmpar.F +++ b/src/framework/mpas_dmpar.F @@ -7448,19 +7448,28 @@ end subroutine mpas_dmpar_exch_group_end_halo_exch!}}} !> exchange is complete. ! !----------------------------------------------------------------------- - subroutine mpas_dmpar_exch_group_full_halo_exch(domain, groupName, iErr)!{{{ + subroutine mpas_dmpar_exch_group_full_halo_exch(domain, groupName, withGPUAwareMPI, iErr)!{{{ type (domain_type), intent(inout) :: domain character (len=*), intent(in) :: groupName + logical, optional, intent(in) :: withGPUAwareMPI integer, optional, intent(out) :: iErr type (mpas_exchange_group), pointer :: exchGroupPtr integer :: nLen + logical :: useGPUAwareMPI if ( present(iErr) ) then iErr = MPAS_DMPAR_NOERR end if + useGPUAwareMPI = .false. + if (present(withGPUAwareMPI)) then + if (withGPUAwareMPI) then + call mpas_log_write(' GPU-aware MPI not implemented in this module', MPAS_LOG_CRIT) + end if + end if + nLen = len_trim(groupName) DMPAR_DEBUG_WRITE(' -- Trying to perform a full exchange for group ' // trim(groupName)) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index 7f96e0d397..ead5e42b35 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -542,7 +542,7 @@ end subroutine mpas_halo_exch_group_add_field !> exchange group. ! 
!----------------------------------------------------------------------- - subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) + subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMPI, iErr) #ifdef MPAS_USE_MPI_F08 use mpi_f08, only : MPI_Datatype, MPI_Comm @@ -576,6 +576,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! Arguments type (domain_type), intent(inout) :: domain character (len=*), intent(in) :: groupName + logical, optional, intent(in) :: withGPUAwareMPI integer, optional, intent(out) :: iErr ! Local variables @@ -592,6 +593,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) integer :: comm #endif integer :: mpi_ierr + logical:: useGPUAwareMPI type (mpas_halo_group), pointer :: group integer, dimension(:), pointer :: compactHaloInfo integer, dimension(:), pointer :: compactSendLists @@ -605,11 +607,17 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) integer, dimension(:,:,:), CONTIGUOUS pointer :: recvListSrc, recvListDst integer, dimension(:), CONTIGUOUS pointer :: unpackOffsets + if (present(iErr)) then iErr = 0 end if + useGPUAwareMPI = .false. + if (present(withGPUAwareMPI)) then + useGPUAwareMPI = withGPUAwareMPI + end if + ! ! Find this halo exhange group in the list of groups @@ -645,7 +653,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) #endif rank = group % fields(1) % compactHaloInfo(8) - !$acc data present(group % recvBuf(:), group % sendBuf(:)) + !$acc data present(group % recvBuf(:), group % sendBuf(:)) if(useGPUAwareMPI) ! ! Initiate non-blocking MPI receives for all neighbors @@ -654,8 +662,11 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) if (group % groupRecvCounts(i) > 0) then bufstart = group % groupRecvOffsets(i) bufend = group % groupRecvOffsets(i) + group % groupRecvCounts(i) - 1 + + !!$acc update self(group % recvBuf(bufstart:bufend)) if(useGPUAwareMPI) + !TO DO: how do we determine appropriate type here? - !$acc host_data use_device(group % recvBuf) + !$acc host_data use_device(group % recvBuf) if(useGPUAwareMPI) call MPI_Irecv(group % recvBuf(bufstart:bufend), group % groupRecvCounts(i), MPI_REALKIND, & group % groupRecvNeighbors(i), group % groupRecvNeighbors(i), comm, & group % recvRequests(i), mpi_ierr) @@ -698,14 +709,14 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) call mpas_pool_get_array(domain % blocklist % allFields, trim(group % fields(i) % fieldName), & group % fields(i) % r1arr, timeLevel=group % fields(i) % timeLevel) - MPAS_ACC_TIMER_START('halo_exch [ACC_data_xfer]') - !$acc enter data copyin(group % fields(i) % r1arr(:)) - MPAS_ACC_TIMER_STOP('halo_exch [ACC_data_xfer]') + !MPAS_ACC_TIMER_START('halo_exch [ACC_data_xfer]') + !$acc enter data copyin(group % fields(i) % r1arr(:)) if(useGPUAwareMPI) + !MPAS_ACC_TIMER_STOP('halo_exch [ACC_data_xfer]') ! ! Pack send buffer for all neighbors for current field ! call mpas_timer_start('packing_halo_exch') - !$acc parallel default(present) + !$acc parallel default(present) if(useGPUAwareMPI) !$acc loop gang collapse(2) do iEndp = 1, nSendEndpts do iHalo = 1, nHalos @@ -737,13 +748,13 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! Present should also cause an attach action. OpenACC Spec2.7 Section 2.7.2 describes 'attach action' ! !$acc data present(group) present(group % fields(i)) present(group % sendBuf(:), group % fields(i) % sendListSrc(:,:,:)) ! 
!$acc data copyin(group % fields(i) % r2arr(:,:)) - MPAS_ACC_TIMER_START('halo_exch [ACC_data_xfer]') - !$acc enter data copyin(group % fields(i) % r2arr(:,:)) - MPAS_ACC_TIMER_STOP('halo_exch [ACC_data_xfer]') + ! MPAS_ACC_TIMER_START('halo_exch [ACC_data_xfer]') + !$acc enter data copyin(group % fields(i) % r2arr(:,:)) if (useGPUAwareMPI) + ! MPAS_ACC_TIMER_STOP('halo_exch [ACC_data_xfer]') call mpas_timer_start('packing_halo_exch') ! Kernels is good enough, use default present to force a run-time error if programmer forgot something - !$acc parallel default(present) + !$acc parallel default(present) if(useGPUAwareMPI) !$acc loop gang collapse(3) do iEndp = 1, nSendEndpts do iHalo = 1, nHalos @@ -768,15 +779,15 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) case (3) call mpas_pool_get_array(domain % blocklist % allFields, trim(group % fields(i) % fieldName), & group % fields(i) % r3arr, group % fields(i) % timeLevel) - MPAS_ACC_TIMER_START('halo_exch [ACC_data_xfer]') - !$acc enter data copyin(group % fields(i) % r3arr(:,:,:)) - MPAS_ACC_TIMER_STOP('halo_exch [ACC_data_xfer]') + ! MPAS_ACC_TIMER_START('halo_exch [ACC_data_xfer]') + !$acc enter data copyin(group % fields(i) % r3arr(:,:,:)) if (useGPUAwareMPI) + ! MPAS_ACC_TIMER_STOP('halo_exch [ACC_data_xfer]') ! ! Pack send buffer for all neighbors for current field ! call mpas_timer_start('packing_halo_exch') - !$acc parallel default(present) + !$acc parallel default(present) if(useGPUAwareMPI) !$acc loop gang collapse(4) do iEndp = 1, nSendEndpts do iHalo = 1, nHalos @@ -809,8 +820,10 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) if (group % groupSendCounts(i) > 0) then bufstart = group % groupSendOffsets(i) bufend = group % groupSendOffsets(i) + group % groupSendCounts(i) - 1 + + !!$acc update self(group % sendBuf(bufstart:bufend)) if(.not. useGPUAwareMPI) !TO DO: how do we determine appropriate type here? - !$acc host_data use_device(group % sendBuf) + !$acc host_data use_device(group % sendBuf) if(useGPUAwareMPI) call MPI_Isend(group % sendBuf(bufstart:bufend), group % groupSendCounts(i), MPI_REALKIND, & group % groupSendNeighbors(i), rank, comm, & group % sendRequests(i), mpi_ierr) @@ -860,6 +873,10 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) dim1 = compactHaloInfo(2) dim2 = compactHaloInfo(3) + bufstart = group % groupSendOffsets(iNeighbor) + bufend = group % groupSendOffsets(iNeighbor) + group % groupSendCounts(iNeighbor) - 1 + !!!$acc update device(group % recvBuf(bufstart:bufend)) if(.not. useGPUAwareMPI) + select case (group % fields(i) % nDims) ! @@ -870,7 +887,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! Unpack recv buffer from all neighbors for current field ! call mpas_timer_start('unpacking_halo_exch') - !$acc parallel default(present) + !$acc parallel default(present) if(useGPUAwareMPI) !$acc loop gang do iHalo = 1, nHalos !$acc loop vector @@ -893,7 +910,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! Unpack recv buffer from all neighbors for current field ! call mpas_timer_start('unpacking_halo_exch') - !$acc parallel default(present) + !$acc parallel default(present) if(useGPUAwareMPI) !$acc loop gang do iHalo = 1, nHalos !$acc loop worker @@ -919,7 +936,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) ! Unpack recv buffer from all neighbors for current field ! 
call mpas_timer_start('unpacking_halo_exch') - !$acc parallel default(present) + !$acc parallel default(present) if(useGPUAwareMPI) !$acc loop gang collapse(2) do iHalo = 1, nHalos do j = 1, maxNRecvList @@ -949,11 +966,11 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) if (group % fields(i) % fieldType == MPAS_HALO_REAL) then select case (group % fields(i) % nDims) case (1) - !$acc update self(group % fields(i) % r1arr(:)) + !$acc update self(group % fields(i) % r1arr(:)) if (useGPUAwareMPI) case (2) - !$acc update self(group % fields(i) % r2arr(:,:)) + !$acc update self(group % fields(i) % r2arr(:,:)) if (useGPUAwareMPI) case (3) - !$acc update self(group % fields(i) % r3arr(:,:,:)) + !$acc update self(group % fields(i) % r3arr(:,:,:)) if (useGPUAwareMPI) end select end if end do @@ -962,11 +979,11 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr) if (group % fields(i) % fieldType == MPAS_HALO_REAL) then select case (group % fields(i) % nDims) case (1) - !$acc exit data delete(group % fields(i) % r1arr(:)) + !$acc exit data delete(group % fields(i) % r1arr(:)) if (useGPUAwareMPI) case (2) - !$acc exit data delete(group % fields(i) % r2arr(:,:)) + !$acc exit data delete(group % fields(i) % r2arr(:,:)) if (useGPUAwareMPI) case (3) - !$acc exit data delete(group % fields(i) % r3arr(:,:,:)) + !$acc exit data delete(group % fields(i) % r3arr(:,:,:)) if (useGPUAwareMPI) end select end if end do From 7b1ff2c9853c3f02d19e07cef110e2e60f0657ab Mon Sep 17 00:00:00 2001 From: Abishek Gopal Date: Wed, 13 Aug 2025 09:27:06 -0600 Subject: [PATCH 26/30] u_2 and w_2 need to be copied out after dynamics + cleanup --- src/framework/mpas_halo.F | 68 +++++++++------------------------------ 1 file changed, 16 insertions(+), 52 deletions(-) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index ead5e42b35..034e222576 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -556,6 +556,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP use mpas_pool_routines, only : mpas_pool_get_array use mpas_log, only : mpas_log_write use mpas_timer, only : mpas_timer_start, mpas_timer_stop + use openacc ! Parameters @@ -636,6 +637,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP messageType=MPAS_LOG_CRIT) end if + ! Logic to return early if there no neighbors to send to if ( group% nGroupSendNeighbors <= 0 ) then return end if @@ -662,9 +664,6 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP if (group % groupRecvCounts(i) > 0) then bufstart = group % groupRecvOffsets(i) bufend = group % groupRecvOffsets(i) + group % groupRecvCounts(i) - 1 - - !!$acc update self(group % recvBuf(bufstart:bufend)) if(useGPUAwareMPI) - !TO DO: how do we determine appropriate type here? 
!$acc host_data use_device(group % recvBuf) if(useGPUAwareMPI) call MPI_Irecv(group % recvBuf(bufstart:bufend), group % groupRecvCounts(i), MPI_REALKIND, & @@ -708,10 +707,10 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP case (1) call mpas_pool_get_array(domain % blocklist % allFields, trim(group % fields(i) % fieldName), & group % fields(i) % r1arr, timeLevel=group % fields(i) % timeLevel) - - !MPAS_ACC_TIMER_START('halo_exch [ACC_data_xfer]') - !$acc enter data copyin(group % fields(i) % r1arr(:)) if(useGPUAwareMPI) - !MPAS_ACC_TIMER_STOP('halo_exch [ACC_data_xfer]') + + if( useGPUAwareMPI ) then + call acc_attach(group % fields(i) % r1arr) + end if ! ! Pack send buffer for all neighbors for current field ! @@ -742,15 +741,15 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP ! ! Pack send buffer for all neighbors for current field - ! - + ! ! Use data regions for specificity and so the reference or attachment counters are easier to make sense of ! Present should also cause an attach action. OpenACC Spec2.7 Section 2.7.2 describes 'attach action' ! !$acc data present(group) present(group % fields(i)) present(group % sendBuf(:), group % fields(i) % sendListSrc(:,:,:)) - ! !$acc data copyin(group % fields(i) % r2arr(:,:)) - ! MPAS_ACC_TIMER_START('halo_exch [ACC_data_xfer]') - !$acc enter data copyin(group % fields(i) % r2arr(:,:)) if (useGPUAwareMPI) - ! MPAS_ACC_TIMER_STOP('halo_exch [ACC_data_xfer]') + + if( useGPUAwareMPI ) then + call acc_attach(group % fields(i) % r2arr) + end if + call mpas_timer_start('packing_halo_exch') ! Kernels is good enough, use default present to force a run-time error if programmer forgot something @@ -778,10 +777,10 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP ! case (3) call mpas_pool_get_array(domain % blocklist % allFields, trim(group % fields(i) % fieldName), & - group % fields(i) % r3arr, group % fields(i) % timeLevel) - ! MPAS_ACC_TIMER_START('halo_exch [ACC_data_xfer]') - !$acc enter data copyin(group % fields(i) % r3arr(:,:,:)) if (useGPUAwareMPI) - ! MPAS_ACC_TIMER_STOP('halo_exch [ACC_data_xfer]') + group % fields(i) % r3arr, group % fields(i) % timeLevel) + if( useGPUAwareMPI ) then + call acc_attach(group % fields(i) % r3arr) + end if ! ! Pack send buffer for all neighbors for current field @@ -820,8 +819,6 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP if (group % groupSendCounts(i) > 0) then bufstart = group % groupSendOffsets(i) bufend = group % groupSendOffsets(i) + group % groupSendCounts(i) - 1 - - !!$acc update self(group % sendBuf(bufstart:bufend)) if(.not. useGPUAwareMPI) !TO DO: how do we determine appropriate type here? !$acc host_data use_device(group % sendBuf) if(useGPUAwareMPI) call MPI_Isend(group % sendBuf(bufstart:bufend), group % groupSendCounts(i), MPI_REALKIND, & @@ -873,10 +870,6 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP dim1 = compactHaloInfo(2) dim2 = compactHaloInfo(3) - bufstart = group % groupSendOffsets(iNeighbor) - bufend = group % groupSendOffsets(iNeighbor) + group % groupSendCounts(iNeighbor) - 1 - !!!$acc update device(group % recvBuf(bufstart:bufend)) if(.not. useGPUAwareMPI) - select case (group % fields(i) % nDims) ! 
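
[Note on the acc_attach calls above] These address a subtlety of device-resident derived types: the halo group, including the descriptor for each field pointer, was copied to the device once in mpas_halo_exch_group_complete, but mpas_pool_get_array re-associates the host pointer on every exchange. The acc_attach routine (from the standard openacc module) rewrites the device copy of the descriptor to point at the device copy of the new target, which must itself already be present. A minimal sketch with made-up names — group_sketch_t stands in for mpas_halo_group:

    module attach_sketch
       ! Sketch only; the caller is assumed to have done
       ! "!$acc enter data copyin(grp)" and to keep every possible target
       ! of r1arr resident on the device.
       use openacc, only : acc_attach
       implicit none

       type :: group_sketch_t
          real, dimension(:), pointer :: r1arr => null()
       end type group_sketch_t

    contains

       subroutine repoint_field(grp, new_target)
          type (group_sketch_t), intent(inout) :: grp
          real, dimension(:), target, intent(inout) :: new_target

          ! Host-side re-association alone leaves the device copy of grp
          ! pointing at the old target; kernels would read stale data.
          grp % r1arr => new_target
          call acc_attach(grp % r1arr)
       end subroutine repoint_field

    end module attach_sketch
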
@@ -960,37 +953,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP end if end do end do - - MPAS_ACC_TIMER_START('halo_exch [ACC_data_xfer]') - do i = 1, group % nFields - if (group % fields(i) % fieldType == MPAS_HALO_REAL) then - select case (group % fields(i) % nDims) - case (1) - !$acc update self(group % fields(i) % r1arr(:)) if (useGPUAwareMPI) - case (2) - !$acc update self(group % fields(i) % r2arr(:,:)) if (useGPUAwareMPI) - case (3) - !$acc update self(group % fields(i) % r3arr(:,:,:)) if (useGPUAwareMPI) - end select - end if - end do - - do i = 1, group % nFields - if (group % fields(i) % fieldType == MPAS_HALO_REAL) then - select case (group % fields(i) % nDims) - case (1) - !$acc exit data delete(group % fields(i) % r1arr(:)) if (useGPUAwareMPI) - case (2) - !$acc exit data delete(group % fields(i) % r2arr(:,:)) if (useGPUAwareMPI) - case (3) - !$acc exit data delete(group % fields(i) % r3arr(:,:,:)) if (useGPUAwareMPI) - end select - end if - end do - ! For the present(group % recvBuf(:), group % sendBuf(:)) !$acc end data - MPAS_ACC_TIMER_STOP('halo_exch [ACC_data_xfer]') ! ! Nullify array pointers - not necessary for correctness, but helpful when debugging From ef09e0b2a88bf3949863a117153423a3e4a4697d Mon Sep 17 00:00:00 2001 From: Abishek Gopal Date: Wed, 13 Aug 2025 10:05:02 -0600 Subject: [PATCH 27/30] using attach in a directive instead of the acc_attach library call --- src/framework/mpas_halo.F | 28 ++++------------------------ 1 file changed, 4 insertions(+), 24 deletions(-) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index 034e222576..8048661acb 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -17,15 +17,6 @@ !> communicating the halos of all fields in a group. ! !----------------------------------------------------------------------- - -#ifdef MPAS_OPENACC -#define MPAS_ACC_TIMER_START(X) call mpas_timer_start(X) -#define MPAS_ACC_TIMER_STOP(X) call mpas_timer_stop(X) -#else -#define MPAS_ACC_TIMER_START(X) -#define MPAS_ACC_TIMER_STOP(X) -#endif - module mpas_halo implicit none @@ -556,8 +547,6 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP use mpas_pool_routines, only : mpas_pool_get_array use mpas_log, only : mpas_log_write use mpas_timer, only : mpas_timer_start, mpas_timer_stop - use openacc - ! Parameters #ifdef MPAS_USE_MPI_F08 @@ -607,7 +596,6 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP integer :: maxNRecvList integer, dimension(:,:,:), CONTIGUOUS pointer :: recvListSrc, recvListDst integer, dimension(:), CONTIGUOUS pointer :: unpackOffsets - if (present(iErr)) then @@ -618,7 +606,6 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP if (present(withGPUAwareMPI)) then useGPUAwareMPI = withGPUAwareMPI end if - ! ! Find this halo exhange group in the list of groups @@ -708,9 +695,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP call mpas_pool_get_array(domain % blocklist % allFields, trim(group % fields(i) % fieldName), & group % fields(i) % r1arr, timeLevel=group % fields(i) % timeLevel) - if( useGPUAwareMPI ) then - call acc_attach(group % fields(i) % r1arr) - end if + !$acc enter data attach(group % fields(i) % r1arr) if(useGPUAwareMPI) ! ! Pack send buffer for all neighbors for current field ! 
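
[Note on the enter data attach directive above] This is the directive form of the acc_attach call, with the advantage that it accepts an if clause, so the attachment is skipped entirely on non-GPU-aware runs. One caveat, hedged here because it rests on the attachment-counter semantics the comments cite (OpenACC spec 2.7, section 2.7.2): each enter data attach increments the attachment counter and strictly wants a matching detach, which is what the next patch addresses by moving the clause onto the parallel construct itself, where detachment happens automatically at region exit. In sketch form, reusing group_sketch_t from the earlier sketch:

    subroutine repoint_field_directive(grp, new_target, useGPUAwareMPI)
       ! Sketch only: same effect as the acc_attach call, but conditional.
       use attach_sketch, only : group_sketch_t
       type (group_sketch_t), intent(inout) :: grp
       real, dimension(:), target, intent(inout) :: new_target
       logical, intent(in) :: useGPUAwareMPI

       grp % r1arr => new_target
       ! No-op when useGPUAwareMPI is .false.; otherwise attaches the
       ! device descriptor to the (already present) device target.
       !$acc enter data attach(grp % r1arr) if(useGPUAwareMPI)
    end subroutine repoint_field_directive
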
@@ -738,7 +723,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP case (2) call mpas_pool_get_array(domain % blocklist % allFields, trim(group % fields(i) % fieldName), & group % fields(i) % r2arr, timeLevel=group % fields(i) % timeLevel) - + + !$acc enter data attach(group % fields(i) % r2arr) if(useGPUAwareMPI) ! ! Pack send buffer for all neighbors for current field ! @@ -746,10 +732,6 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP ! Present should also cause an attach action. OpenACC Spec2.7 Section 2.7.2 describes 'attach action' ! !$acc data present(group) present(group % fields(i)) present(group % sendBuf(:), group % fields(i) % sendListSrc(:,:,:)) - if( useGPUAwareMPI ) then - call acc_attach(group % fields(i) % r2arr) - end if - call mpas_timer_start('packing_halo_exch') ! Kernels is good enough, use default present to force a run-time error if programmer forgot something @@ -778,10 +760,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP case (3) call mpas_pool_get_array(domain % blocklist % allFields, trim(group % fields(i) % fieldName), & group % fields(i) % r3arr, group % fields(i) % timeLevel) - if( useGPUAwareMPI ) then - call acc_attach(group % fields(i) % r3arr) - end if + !$acc enter data attach(group % fields(i) % r3arr) if(useGPUAwareMPI) ! ! Pack send buffer for all neighbors for current field ! From 3af6f9d705e920c55953d2bed108fb1d02ad6506 Mon Sep 17 00:00:00 2001 From: Abishek Gopal Date: Wed, 13 Aug 2025 10:45:13 -0600 Subject: [PATCH 28/30] Using attach clause in parallel region will also auto detach at end of region --- src/framework/mpas_halo.F | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index 8048661acb..e88dd5d021 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -695,12 +695,11 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP call mpas_pool_get_array(domain % blocklist % allFields, trim(group % fields(i) % fieldName), & group % fields(i) % r1arr, timeLevel=group % fields(i) % timeLevel) - !$acc enter data attach(group % fields(i) % r1arr) if(useGPUAwareMPI) ! ! Pack send buffer for all neighbors for current field ! call mpas_timer_start('packing_halo_exch') - !$acc parallel default(present) if(useGPUAwareMPI) + !$acc parallel default(present) attach(group % fields(i) % r1arr) if(useGPUAwareMPI) !$acc loop gang collapse(2) do iEndp = 1, nSendEndpts do iHalo = 1, nHalos @@ -724,7 +723,6 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP call mpas_pool_get_array(domain % blocklist % allFields, trim(group % fields(i) % fieldName), & group % fields(i) % r2arr, timeLevel=group % fields(i) % timeLevel) - !$acc enter data attach(group % fields(i) % r2arr) if(useGPUAwareMPI) ! ! Pack send buffer for all neighbors for current field ! @@ -735,7 +733,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP call mpas_timer_start('packing_halo_exch') ! 
Kernels is good enough, use default present to force a run-time error if programmer forgot something - !$acc parallel default(present) if(useGPUAwareMPI) + !$acc parallel default(present) attach(group % fields(i) % r2arr) if(useGPUAwareMPI) !$acc loop gang collapse(3) do iEndp = 1, nSendEndpts do iHalo = 1, nHalos @@ -760,13 +758,12 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP case (3) call mpas_pool_get_array(domain % blocklist % allFields, trim(group % fields(i) % fieldName), & group % fields(i) % r3arr, group % fields(i) % timeLevel) - - !$acc enter data attach(group % fields(i) % r3arr) if(useGPUAwareMPI) + ! ! Pack send buffer for all neighbors for current field ! call mpas_timer_start('packing_halo_exch') - !$acc parallel default(present) if(useGPUAwareMPI) + !$acc parallel default(present) attach(group % fields(i) % r3arr) if(useGPUAwareMPI) !$acc loop gang collapse(4) do iEndp = 1, nSendEndpts do iHalo = 1, nHalos @@ -860,7 +857,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP ! Unpack recv buffer from all neighbors for current field ! call mpas_timer_start('unpacking_halo_exch') - !$acc parallel default(present) if(useGPUAwareMPI) + !$acc parallel default(present) attach(group % fields(i) % r1arr) if(useGPUAwareMPI) !$acc loop gang do iHalo = 1, nHalos !$acc loop vector @@ -883,7 +880,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP ! Unpack recv buffer from all neighbors for current field ! call mpas_timer_start('unpacking_halo_exch') - !$acc parallel default(present) if(useGPUAwareMPI) + !$acc parallel default(present) attach(group % fields(i) % r2arr) if(useGPUAwareMPI) !$acc loop gang do iHalo = 1, nHalos !$acc loop worker @@ -909,7 +906,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP ! Unpack recv buffer from all neighbors for current field ! call mpas_timer_start('unpacking_halo_exch') - !$acc parallel default(present) if(useGPUAwareMPI) + !$acc parallel default(present) attach(group % fields(i) % r3arr) if(useGPUAwareMPI) !$acc loop gang collapse(2) do iHalo = 1, nHalos do j = 1, maxNRecvList From 00c92b0cad89b518816b42a1ae7f65d6612ed862 Mon Sep 17 00:00:00 2001 From: Abishek Gopal Date: Wed, 13 Aug 2025 11:43:26 -0600 Subject: [PATCH 29/30] Reverting the indexing in loops and comment cleanup --- src/framework/mpas_halo.F | 39 ++++++++++++++++----------------------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F index e88dd5d021..8877d6330d 100644 --- a/src/framework/mpas_halo.F +++ b/src/framework/mpas_halo.F @@ -280,6 +280,7 @@ subroutine mpas_halo_exch_group_complete(domain, groupName, iErr) call refactor_lists(domain, groupName, iErr) + ! Logic to return early if there are no neighbors to send to if ( newGroup% nGroupSendNeighbors <=0 ) then return end if @@ -571,7 +572,6 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP ! Local variables integer :: i, bufstart, bufend - integer :: idxBuf, idxArr integer :: dim1, dim2 integer :: i1, i2, j, iNeighbor, iReq integer :: iHalo, iEndp @@ -624,7 +624,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP messageType=MPAS_LOG_CRIT) end if - ! Logic to return early if there no neighbors to send to + ! 
From 00c92b0cad89b518816b42a1ae7f65d6612ed862 Mon Sep 17 00:00:00 2001
From: Abishek Gopal
Date: Wed, 13 Aug 2025 11:43:26 -0600
Subject: [PATCH 29/30] Reverting to inline indexing in the pack/unpack loops
 and cleaning up comments

---
 src/framework/mpas_halo.F | 39 ++++++++++++++++-----------------------
 1 file changed, 16 insertions(+), 23 deletions(-)

diff --git a/src/framework/mpas_halo.F b/src/framework/mpas_halo.F
index e88dd5d021..8877d6330d 100644
--- a/src/framework/mpas_halo.F
+++ b/src/framework/mpas_halo.F
@@ -280,6 +280,7 @@ subroutine mpas_halo_exch_group_complete(domain, groupName, iErr)
 
       call refactor_lists(domain, groupName, iErr)
 
+      ! Logic to return early if there are no neighbors to send to
       if ( newGroup% nGroupSendNeighbors <=0 ) then
         return
       end if
@@ -571,7 +572,6 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP
 
       ! Local variables
       integer :: i, bufstart, bufend
-      integer :: idxBuf, idxArr
       integer :: dim1, dim2
       integer :: i1, i2, j, iNeighbor, iReq
       integer :: iHalo, iEndp
@@ -624,7 +624,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP
                              messageType=MPAS_LOG_CRIT)
       end if
 
-      ! Logic to return early if there no neighbors to send to
+      ! Logic to return early if there are no neighbors to send to
       if ( group% nGroupSendNeighbors <= 0 ) then
          return
       end if
@@ -706,9 +706,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP
                !$acc loop vector
                do j = 1, maxNSendList
                   if (j <= nSendLists(iHalo,iEndp)) then
-                     idxBuf = packOffsets(iEndp) + sendListDst(j,iHalo,iEndp)
-                     idxArr = sendListSrc(j,iHalo,iEndp)
-                     group % sendBuf(idxBuf) = group % fields(i) % r1arr(idxArr)
+                     group % sendBuf(packOffsets(iEndp) + sendListDst(j,iHalo,iEndp)) = &
+                        group % fields(i) % r1arr(sendListSrc(j,iHalo,iEndp))
                   end if
                end do
             end do
@@ -732,7 +731,6 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP
 
             call mpas_timer_start('packing_halo_exch')
 
-            ! Kernels is good enough, use default present to force a run-time error if programmer forgot something
             !$acc parallel default(present) attach(group % fields(i) % r2arr) if(useGPUAwareMPI)
             !$acc loop gang collapse(3)
             do iEndp = 1, nSendEndpts
@@ -741,9 +739,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP
                   !$acc loop vector
                   do i1 = 1, dim1
                      if (j <= nSendLists(iHalo,iEndp)) then
-                        idxBuf = packOffsets(iEndp) + dim1 * (sendListDst(j,iHalo,iEndp) - 1) + i1
-                        idxArr = sendListSrc(j,iHalo,iEndp)
-                        group % sendBuf(idxBuf) = group % fields(i) % r2arr(i1,idxArr)
+                        group % sendBuf(packOffsets(iEndp) + dim1 * (sendListDst(j,iHalo,iEndp) - 1) + i1) = &
+                           group % fields(i) % r2arr(i1, sendListSrc(j,iHalo,iEndp))
                      end if
                   end do
                end do
@@ -772,10 +769,9 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP
                   !$acc loop vector
                   do i1 = 1, dim1
                      if (j <= nSendLists(iHalo,iEndp)) then
-                        idxBuf = packOffsets(iEndp) + dim1*dim2*(sendListDst(j,iHalo,iEndp) - 1) &
-                                 + dim1*(i2-1) + i1
-                        idxArr = sendListSrc(j,iHalo,iEndp)
-                        group % sendBuf(idxBuf) = group % fields(i) % r3arr(i1,i2,idxArr)
+                        group % sendBuf(packOffsets(iEndp) + dim1*dim2*(sendListDst(j,iHalo,iEndp) - 1) &
+                                        + dim1*(i2-1) + i1) = &
+                           group % fields(i) % r3arr(i1, i2, sendListSrc(j,iHalo,iEndp))
                      end if
                   end do
               end do
@@ -863,9 +859,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP
            !$acc loop vector
            do j = 1, maxNRecvList
               if (j <= nRecvLists(iHalo,iEndp)) then
-                  idxArr = recvListDst(j,iHalo,iEndp)
-                  idxBuf = unpackOffsets(iEndp) + recvListSrc(j,iHalo,iEndp)
-                  group % fields(i) % r1arr(idxArr) = group % recvBuf(idxBuf)
+                  group % fields(i) % r1arr(recvListDst(j,iHalo,iEndp)) = &
+                     group % recvBuf(unpackOffsets(iEndp) + recvListSrc(j,iHalo,iEndp))
               end if
            end do
         end do
@@ -888,9 +883,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP
              !$acc loop vector
              do i1 = 1, dim1
                 if (j <= nRecvLists(iHalo,iEndp)) then
-                     idxArr = recvListDst(j,iHalo,iEndp)
-                     idxBuf = unpackOffsets(iEndp) + dim1 * (recvListSrc(j,iHalo,iEndp) - 1) + i1
-                     group % fields(i) % r2arr(i1, idxArr) = group % recvBuf(idxBuf)
+                     group % fields(i) % r2arr(i1, recvListDst(j,iHalo,iEndp)) = &
+                        group % recvBuf(unpackOffsets(iEndp) + dim1 * (recvListSrc(j,iHalo,iEndp) - 1) + i1)
                 end if
             end do
         end do
@@ -914,10 +908,9 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, withGPUAwareMP
           do i2 = 1, dim2
              do i1 = 1, dim1
                 if (j <= nRecvLists(iHalo,iEndp)) then
-                     idxArr = recvListDst(j,iHalo,iEndp)
-                     idxBuf = unpackOffsets(iEndp) + dim1*dim2*(recvListSrc(j,iHalo,iEndp) - 1) &
-                              + dim1*(i2-1) + i1
-                     group % fields(i) % r3arr(i1, i2, idxArr) = group % recvBuf(idxBuf)
+                     group % fields(i) % r3arr(i1, i2, recvListDst(j,iHalo,iEndp)) = &
+                        group % recvBuf(unpackOffsets(iEndp) + dim1*dim2*(recvListSrc(j,iHalo,iEndp) - 1) &
+                                        + dim1*(i2-1) + i1)
                 end if
             end do
          end do
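The final commit below introduces config_gpu_aware_mpi and applies a single pattern around every halo exchange in the dynamics. A minimal sketch of that pattern ('field' and the group name are placeholders): when the option is false, the field is staged through the host around a CPU-side exchange; when true, both update directives are skipped at run time via their if clauses, and the logical passed to exchange_halo_group selects the device-buffer path:

    ! device -> host staging; skipped at run time when GPU-aware MPI is in use
    !$acc update self(field) if (.not. config_gpu_aware_mpi)
    ! the logical argument tells the exchange whether buffers are device-resident
    call exchange_halo_group(domain, 'dynamics:field', config_gpu_aware_mpi)
    ! host -> device restaging; likewise skipped for GPU-aware MPI
    !$acc update device(field) if (.not. config_gpu_aware_mpi)
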
From a8fda92ccc4288c6fb74f7cc33d0e71587a3e3ab Mon Sep 17 00:00:00 2001
From: Abishek Gopal
Date: Wed, 13 Aug 2025 15:48:42 -0600
Subject: [PATCH 30/30] New namelist option to switch GPU-aware MPI on or off

Introducing a new namelist option (under development), config_gpu_aware_mpi,
which controls whether an OpenACC run of MPAS on GPUs uses GPU-aware MPI or
performs device<->host updates of variables around each call to a purely
CPU-based halo exchange.

Note: this feature is not available when config_halo_exch_method is set to
'mpas_dmpar'
---
 src/core_atmosphere/Registry.xml              |   4 +
 .../dynamics/mpas_atm_time_integration.F      | 123 +++++++++---------
 src/core_atmosphere/mpas_atm_halos.F          |   5 +
 3 files changed, 73 insertions(+), 59 deletions(-)

diff --git a/src/core_atmosphere/Registry.xml b/src/core_atmosphere/Registry.xml
index 4281c40bba..1d9d80c89b 100644
--- a/src/core_atmosphere/Registry.xml
+++ b/src/core_atmosphere/Registry.xml
@@ -392,6 +392,10 @@
                 units="-"
                 description="Method to use for exchanging halos"
                 possible_values="`mpas_dmpar', `mpas_halo'"/>
+
diff --git a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F
index b656d6b4dd..4c15c5169c 100644
--- a/src/core_atmosphere/dynamics/mpas_atm_time_integration.F
+++ b/src/core_atmosphere/dynamics/mpas_atm_time_integration.F
@@ -1793,6 +1793,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group)
       logical, pointer :: config_scalar_advection
       logical, pointer :: config_positive_definite
       logical, pointer :: config_monotonic
+      logical, pointer :: config_gpu_aware_mpi
       character (len=StrKIND), pointer :: config_microp_scheme
       character (len=StrKIND), pointer :: config_convection_scheme
@@ -1837,6 +1838,7 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group)
       call mpas_pool_get_config(block % configs, 'config_scalar_advection', config_scalar_advection)
       call mpas_pool_get_config(block % configs, 'config_positive_definite', config_positive_definite)
       call mpas_pool_get_config(block % configs, 'config_monotonic', config_monotonic)
+      call mpas_pool_get_config(block % configs, 'config_gpu_aware_mpi', config_gpu_aware_mpi)
       call mpas_pool_get_config(block % configs, 'config_IAU_option', config_IAU_option)
       ! config variables for dynamics-transport splitting, WCS 18 November 2014
       call mpas_pool_get_config(block % configs, 'config_split_dynamics_transport', config_split_dynamics_transport)
@@ -1990,9 +1992,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group)
       call mpas_pool_get_array(state, 'scalars', scalars_1, 1)
       call mpas_pool_get_array(diag, 'pressure_p', pressure_p)
       call mpas_pool_get_array(diag, 'rtheta_p', rtheta_p)
-      !!$acc update self(theta_m,scalars_1,pressure_p,rtheta_p)
-      call exchange_halo_group(domain, 'dynamics:theta_m,scalars,pressure_p,rtheta_p', .true.)
-      !!$acc update device(theta_m,scalars_1,pressure_p,rtheta_p)
+      !$acc update self(theta_m,scalars_1,pressure_p,rtheta_p) if (.not. config_gpu_aware_mpi)
+      call exchange_halo_group(domain, 'dynamics:theta_m,scalars,pressure_p,rtheta_p', config_gpu_aware_mpi)
+      !$acc update device(theta_m,scalars_1,pressure_p,rtheta_p) if (.not. 
config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_timer_start('atm_rk_integration_setup') @@ -2081,9 +2083,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(diag, 'exner', exner) - !!$acc update self(exner) - call exchange_halo_group(domain, 'dynamics:exner', .true.) - !!$acc update device(exner) + !$acc update self(exner) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:exner', config_gpu_aware_mpi) + !$acc update device(exner) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -2164,9 +2166,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) ! tend_u MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(tend, 'u', tend_u) - !!$acc update self(tend_u) - call exchange_halo_group(domain, 'dynamics:tend_u', .true.) - !!$acc update device(tend_u) + !$acc update self(tend_u) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:tend_u', config_gpu_aware_mpi) + !$acc update device(tend_u) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_timer_start('small_step_prep') @@ -2245,9 +2247,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(diag, 'rho_pp', rho_pp) - !!$acc update self(rho_pp) - call exchange_halo_group(domain, 'dynamics:rho_pp', .true.) - !!$acc update device(rho_pp) + !$acc update self(rho_pp) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:rho_pp', config_gpu_aware_mpi) + !$acc update device(rho_pp) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_timer_start('atm_advance_acoustic_step') @@ -2272,9 +2274,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) ! This is the only communications needed during the acoustic steps because we solve for u on all edges of owned cells MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(diag, 'rtheta_pp', rtheta_pp) - !!$acc update self(rtheta_pp) - call exchange_halo_group(domain, 'dynamics:rtheta_pp', .true.) - !!$acc update device(rtheta_pp) + !$acc update self(rtheta_pp) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:rtheta_pp', config_gpu_aware_mpi) + !$acc update device(rtheta_pp) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') ! complete update of horizontal momentum by including 3d divergence damping at the end of the acoustic step @@ -2300,9 +2302,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) call mpas_pool_get_array(diag, 'rw_p', rw_p) call mpas_pool_get_array(diag, 'rho_pp', rho_pp) call mpas_pool_get_array(diag, 'rtheta_pp', rtheta_pp) - !!$acc update self(rw_p,ru_p,rho_pp,rtheta_pp) - call exchange_halo_group(domain, 'dynamics:rw_p,ru_p,rho_pp,rtheta_pp', .true.) - !!$acc update device(rw_p,ru_p,rho_pp,rtheta_pp) + !$acc update self(rw_p,ru_p,rho_pp,rtheta_pp) if (.not. 
config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:rw_p,ru_p,rho_pp,rtheta_pp', config_gpu_aware_mpi) + !$acc update device(rw_p,ru_p,rho_pp,rtheta_pp) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_timer_start('atm_recover_large_step_variables') @@ -2373,14 +2375,14 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(state, 'u', u, 2) - !!$acc update self(u) + !$acc update self(u) if (.not. config_gpu_aware_mpi) ! u if (config_apply_lbcs) then - call exchange_halo_group(domain, 'dynamics:u_123', .true.) + call exchange_halo_group(domain, 'dynamics:u_123', config_gpu_aware_mpi) else - call exchange_halo_group(domain, 'dynamics:u_3', .true.) + call exchange_halo_group(domain, 'dynamics:u_3', config_gpu_aware_mpi) end if - !!$acc update device(u) + !$acc update device(u) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') ! scalar advection: RK3 scheme of Skamarock and Gassmann (2011). @@ -2389,15 +2391,15 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) if (config_scalar_advection .and. (.not. config_split_dynamics_transport) ) then call advance_scalars('scalars', domain, rk_step, rk_timestep, config_monotonic, config_positive_definite, & - config_time_integration_order, config_split_dynamics_transport, exchange_halo_group) + config_time_integration_order, config_split_dynamics_transport, config_gpu_aware_mpi, exchange_halo_group) if (config_apply_lbcs) then ! adjust boundary tendencies for regional_MPAS scalar transport MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(state, 'scalars', scalars_2, 2) - !!$acc update self(scalars_2) - call exchange_halo_group(domain, 'dynamics:scalars', .true.) - !!$acc update device(scalars_2) + !$acc update self(scalars_2) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:scalars', config_gpu_aware_mpi) + !$acc update device(scalars_2) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') allocate(scalars_driving(num_scalars,nVertLevels,nCells+1)) @@ -2454,22 +2456,22 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) call mpas_pool_get_array(state, 'w', w, 2) call mpas_pool_get_array(diag, 'pv_edge', pv_edge) call mpas_pool_get_array(diag, 'rho_edge', rho_edge) - !!$acc update self(w,pv_edge,rho_edge) + !$acc update self(w,pv_edge,rho_edge) if (.not. config_gpu_aware_mpi) if (config_scalar_advection .and. (.not. config_split_dynamics_transport) ) then ! ! Communicate halos for w[1,2], pv_edge[1,2], rho_edge[1,2], scalars[1,2] ! call mpas_pool_get_array(state, 'scalars', scalars_2, 2) - !!$acc update self(scalars_2) - call exchange_halo_group(domain, 'dynamics:w,pv_edge,rho_edge,scalars', .true.) - !!$acc update device(scalars_2) + !$acc update self(scalars_2) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:w,pv_edge,rho_edge,scalars', config_gpu_aware_mpi) + !$acc update device(scalars_2) if (.not. config_gpu_aware_mpi) else ! ! Communicate halos for w[1,2], pv_edge[1,2], rho_edge[1,2] ! - call exchange_halo_group(domain, 'dynamics:w,pv_edge,rho_edge', .true.) + call exchange_halo_group(domain, 'dynamics:w,pv_edge,rho_edge', config_gpu_aware_mpi) end if - !!$acc update device(w,pv_edge,rho_edge) + !$acc update device(w,pv_edge,rho_edge) if (.not. 
config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') ! set the zero-gradient condition on w for regional_MPAS @@ -2486,9 +2488,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') ! w halo values needs resetting after regional boundary update call mpas_pool_get_array(state, 'w', w, 2) - !!$acc update self(w) - call exchange_halo_group(domain, 'dynamics:w', .true.) - !!$acc update device(w) + !$acc update self(w) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:w', config_gpu_aware_mpi) + !$acc update device(w) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') end if ! end of regional_MPAS addition @@ -2504,9 +2506,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) call mpas_pool_get_array(state, 'theta_m', theta_m, 2) call mpas_pool_get_array(diag, 'pressure_p', pressure_p) call mpas_pool_get_array(diag, 'rtheta_p', rtheta_p) - !!$acc update self(theta_m,pressure_p,rtheta_p) - call exchange_halo_group(domain, 'dynamics:theta_m,pressure_p,rtheta_p', .true.) - !!$acc update device(theta_m,pressure_p,rtheta_p) + !$acc update self(theta_m,pressure_p,rtheta_p) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:theta_m,pressure_p,rtheta_p', config_gpu_aware_mpi) + !$acc update device(theta_m,pressure_p,rtheta_p) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') ! @@ -2568,16 +2570,16 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) call advance_scalars('scalars', domain, rk_step, rk_timestep, config_monotonic, config_positive_definite, & - config_time_integration_order, config_split_dynamics_transport, exchange_halo_group) + config_time_integration_order, config_split_dynamics_transport, config_gpu_aware_mpi, exchange_halo_group) if (config_apply_lbcs) then ! adjust boundary tendencies for regional_MPAS scalar transport MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') ! need to fill halo for horizontal filter call mpas_pool_get_array(state, 'scalars', scalars_2, 2) - !!$acc update self(scalars_2) - call exchange_halo_group(domain, 'dynamics:scalars', .true.) - !!$acc update device(scalars_2) + !$acc update self(scalars_2) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:scalars', config_gpu_aware_mpi) + !$acc update device(scalars_2) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') allocate(scalars_driving(num_scalars,nVertLevels,nCells+1)) @@ -2606,9 +2608,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) if (rk_step < 3) then MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(state, 'scalars', scalars_2, 2) - !!$acc update self(scalars_2) - call exchange_halo_group(domain, 'dynamics:scalars', .true.) - !!$acc update device(scalars_2) + !$acc update self(scalars_2) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:scalars', config_gpu_aware_mpi) + !$acc update device(scalars_2) if (.not. 
config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') end if @@ -2737,9 +2739,9 @@ subroutine atm_srk3(domain, dt, itimestep, exchange_halo_group) MPAS_ACC_TIMER_START('atm_srk3: halo_exchanges + ACC_data_xfer') call mpas_pool_get_array(state, 'scalars', scalars_2, 2) - !!$acc update self(scalars_2) - call exchange_halo_group(domain, 'dynamics:scalars', .true.) - !!$acc update device(scalars_2) + !$acc update self(scalars_2) if (.not. config_gpu_aware_mpi) + call exchange_halo_group(domain, 'dynamics:scalars', config_gpu_aware_mpi) + !$acc update device(scalars_2) if (.not. config_gpu_aware_mpi) MPAS_ACC_TIMER_STOP('atm_srk3: halo_exchanges + ACC_data_xfer') allocate(scalars_driving(num_scalars,nVertLevels,nCells+1)) @@ -2793,7 +2795,7 @@ end subroutine atm_srk3 ! !----------------------------------------------------------------------- subroutine advance_scalars(field_name, domain, rk_step, rk_timestep, config_monotonic, config_positive_definite, & - config_time_integration_order, config_split_dynamics_transport, exchange_halo_group) + config_time_integration_order, config_split_dynamics_transport, config_gpu_aware_mpi, exchange_halo_group) implicit none @@ -2806,6 +2808,7 @@ subroutine advance_scalars(field_name, domain, rk_step, rk_timestep, config_mono logical, intent(in) :: config_positive_definite integer, intent(in) :: config_time_integration_order logical, intent(in) :: config_split_dynamics_transport + logical, intent(in) :: config_gpu_aware_mpi procedure (halo_exchange_routine) :: exchange_halo_group ! Local variables @@ -2937,7 +2940,7 @@ subroutine advance_scalars(field_name, domain, rk_step, rk_timestep, config_mono edgeThreadStart(thread), edgeThreadEnd(thread), & cellSolveThreadStart(thread), cellSolveThreadEnd(thread), & scalar_old_arr, scalar_new_arr, s_max_arr, s_min_arr, wdtn_arr, & - flux_array, flux_upwind_tmp_arr, flux_tmp_arr, & + flux_array, flux_upwind_tmp_arr, flux_tmp_arr, config_gpu_aware_mpi, & exchange_halo_group, & advance_density=config_split_dynamics_transport, rho_zz_int=rho_zz_int) end if @@ -4816,7 +4819,7 @@ subroutine atm_advance_scalars_mono(field_name, block, tend, state, diag, mesh, cellStart, cellEnd, edgeStart, edgeEnd, & cellSolveStart, cellSolveEnd, & scalar_old, scalar_new, s_max, s_min, wdtn, flux_arr, & - flux_upwind_tmp, flux_tmp, exchange_halo_group, advance_density, rho_zz_int) + flux_upwind_tmp, flux_tmp, config_gpu_aware_mpi, exchange_halo_group, advance_density, rho_zz_int) implicit none @@ -4837,6 +4840,7 @@ subroutine atm_advance_scalars_mono(field_name, block, tend, state, diag, mesh, real (kind=RKIND), dimension(:,:), intent(inout) :: wdtn real (kind=RKIND), dimension(:,:), intent(inout) :: flux_arr real (kind=RKIND), dimension(:,:), intent(inout) :: flux_upwind_tmp, flux_tmp + logical, intent(in) :: config_gpu_aware_mpi procedure (halo_exchange_routine) :: exchange_halo_group logical, intent(in), optional :: advance_density real (kind=RKIND), dimension(:,:), intent(inout), optional :: rho_zz_int @@ -4915,7 +4919,7 @@ subroutine atm_advance_scalars_mono(field_name, block, tend, state, diag, mesh, edgesOnCell, edgesOnCell_sign, nEdgesOnCell, fnm, fnp, rdnw, nAdvCellsForEdge, & advCellsForEdge, adv_coefs, adv_coefs_3rd, scalar_old, scalar_new, s_max, s_min, & wdtn, scale_arr, flux_arr, flux_upwind_tmp, flux_tmp, & - bdyMaskCell, bdyMaskEdge, & + bdyMaskCell, bdyMaskEdge, config_gpu_aware_mpi, & exchange_halo_group, advance_density, rho_zz_int) call mpas_deallocate_scratch_field(scale) @@ -4963,7 +4967,7 
7 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge
                            edgesOnCell, edgesOnCell_sign, nEdgesOnCell, fnm, fnp, rdnw, nAdvCellsForEdge, &
                            advCellsForEdge, adv_coefs, adv_coefs_3rd, scalar_old, scalar_new, s_max, s_min, &
                            wdtn, scale_arr, flux_arr, flux_upwind_tmp, flux_tmp, &
-                           bdyMaskCell, bdyMaskEdge, &
+                           bdyMaskCell, bdyMaskEdge, config_gpu_aware_mpi, &
                            exchange_halo_group, advance_density, rho_zz_int)
 
       use mpas_atm_dimensions, only : nVertLevels
@@ -4979,6 +4983,7 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge
       real (kind=RKIND), intent(in) :: dt
       integer, intent(in) :: cellStart, cellEnd, edgeStart, edgeEnd
       integer, intent(in) :: cellSolveStart, cellSolveEnd
+      logical, intent(in) :: config_gpu_aware_mpi
       procedure (halo_exchange_routine) :: exchange_halo_group
       logical, intent(in), optional :: advance_density
       real (kind=RKIND), dimension(:,:), intent(inout), optional :: rho_zz_int
@@ -5075,17 +5080,17 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge
       !$acc end parallel
 
       MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]')
-      !!$acc update self(scalars_old)
+      !$acc update self(scalars_old) if (.not. config_gpu_aware_mpi)
       MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]')
 
!$OMP BARRIER
!$OMP MASTER
-      call exchange_halo_group(block % domain, 'dynamics:'//trim(field_name)//'_old', .true.)
+      call exchange_halo_group(block % domain, 'dynamics:'//trim(field_name)//'_old', config_gpu_aware_mpi)
!$OMP END MASTER
!$OMP BARRIER
 
       MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]')
-      !!$acc update device(scalars_old)
+      !$acc update device(scalars_old) if (.not. config_gpu_aware_mpi)
       MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]')
 
       !
@@ -5482,17 +5487,17 @@ subroutine atm_advance_scalars_mono_work(field_name, block, state, nCells, nEdge
       !
 
       MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]')
-      !!$acc update self(scale_arr)
+      !$acc update self(scale_arr) if (.not. config_gpu_aware_mpi)
       MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]')
 
!$OMP BARRIER
!$OMP MASTER
-      call exchange_halo_group(block % domain, 'dynamics:scale', .true.)
+      call exchange_halo_group(block % domain, 'dynamics:scale', config_gpu_aware_mpi)
!$OMP END MASTER
!$OMP BARRIER
 
       MPAS_ACC_TIMER_START('atm_advance_scalars_mono [ACC_data_xfer]')
-      !!$acc update device(scale_arr)
+      !$acc update device(scale_arr) if (.not. config_gpu_aware_mpi)
       MPAS_ACC_TIMER_STOP('atm_advance_scalars_mono [ACC_data_xfer]')
 
       !$acc parallel
diff --git a/src/core_atmosphere/mpas_atm_halos.F b/src/core_atmosphere/mpas_atm_halos.F
index 955f9b5ea0..19c4a5be46 100644
--- a/src/core_atmosphere/mpas_atm_halos.F
+++ b/src/core_atmosphere/mpas_atm_halos.F
@@ -62,18 +62,23 @@ subroutine atm_build_halo_groups(domain, ierr)
 
       ! Local variables
       character(len=StrKIND), pointer :: config_halo_exch_method
+      logical, pointer :: config_gpu_aware_mpi
 
       !
       ! Determine from the namelist option config_halo_exch_method which halo exchange method to employ
       !
      call mpas_pool_get_config(domain % blocklist % configs, 'config_halo_exch_method', config_halo_exch_method)
+      call mpas_pool_get_config(domain % blocklist % configs, 'config_gpu_aware_mpi', config_gpu_aware_mpi)
 
       if (trim(config_halo_exch_method) == 'mpas_dmpar') then
 
          call mpas_log_write('')
          call mpas_log_write('*** Using ''mpas_dmpar'' routines for exchanging halos')
          call mpas_log_write('')
 
+         if (config_gpu_aware_mpi) then
+            call mpas_log_write('GPU-aware MPI is not presently supported with config_halo_exch_method = mpas_dmpar', messageType=MPAS_LOG_CRIT)
+         end if
 
          !
          ! Set up halo exchange groups used during atmosphere core initialization
         !
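The four lines this commit adds to Registry.xml do not survive in the hunk above (the XML element appears to have been stripped as markup in transit). Purely as a sketch modeled on the neighboring entries — the default value, description wording, and namelist record name below are assumptions, not recovered from the patch; only the option name and its logical type are certain:

    <nml_option name="config_gpu_aware_mpi" type="logical" default_value="false"
                units="-"
                description="Whether GPU runs pass device-resident buffers directly to a GPU-aware MPI library during halo exchanges"
                possible_values="true or false"/>

A run would then enable the option from the namelist (assuming it lives in a record such as &development), with config_halo_exch_method left at 'mpas_halo':

    &development
        config_gpu_aware_mpi = true
    /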