Skip to content

Commit f5a7287

Browse files
committed
Optimized packing and unpacking loops. Adding timers and other cleanup
1 parent ab470da commit f5a7287

File tree

1 file changed

+73
-69
lines changed

1 file changed

+73
-69
lines changed

src/framework/mpas_halo.F

Lines changed: 73 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@
1717
!> communicating the halos of all fields in a group.
1818
!
1919
!-----------------------------------------------------------------------
20+
21+
#ifdef MPAS_OPENACC
22+
#define MPAS_ACC_TIMER_START(X) call mpas_timer_start(X)
23+
#define MPAS_ACC_TIMER_STOP(X) call mpas_timer_stop(X)
24+
#else
25+
#define MPAS_ACC_TIMER_START(X)
26+
#define MPAS_ACC_TIMER_STOP(X)
27+
#endif
28+
2029
module mpas_halo
2130

2231
implicit none
@@ -281,9 +290,8 @@ subroutine mpas_halo_exch_group_complete(domain, groupName, iErr)
281290
call refactor_lists(domain, groupName, iErr)
282291

283292
if ( newGroup% nGroupSendNeighbors <=0 ) then
284-
!call mpas_log_write('No send neighbors for halo exchange group '//trim(groupName))
285293
return
286-
end if
294+
end if
287295

288296

289297
! Always copy in the main data member first
@@ -547,7 +555,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
547555
use mpas_derived_types, only : domain_type, mpas_halo_group, MPAS_HALO_REAL, MPAS_LOG_CRIT
548556
use mpas_pool_routines, only : mpas_pool_get_array
549557
use mpas_log, only : mpas_log_write
550-
use mpas_kind_types, only : RKIND
558+
use mpas_timer, only : mpas_timer_start, mpas_timer_stop
559+
551560

552561
! Parameters
553562
#ifdef MPAS_USE_MPI_F08
@@ -595,12 +604,13 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
595604
integer :: maxNRecvList
596605
integer, dimension(:,:,:), CONTIGUOUS pointer :: recvListSrc, recvListDst
597606
integer, dimension(:), CONTIGUOUS pointer :: unpackOffsets
598-
real (kind=RKIND), dimension(:), pointer :: sendBufptr, recvBufptr
607+
599608

600609
if (present(iErr)) then
601610
iErr = 0
602611
end if
603612

613+
604614
!
605615
! Find this halo exchange group in the list of groups
606616
!
@@ -618,10 +628,11 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
618628
messageType=MPAS_LOG_CRIT)
619629
end if
620630

621-
if ( group% nGroupSendNeighbors <=0 ) then
622-
!call mpas_log_write('group has no halo exchanges: '//trim(groupName))
631+
if ( group% nGroupSendNeighbors <= 0 ) then
623632
return
624-
end if
633+
end if
634+
635+
call mpas_timer_start('full_halo_exch')
625636
!
626637
! Get the rank of this task and the MPI communicator to use from the first field in
627638
! the group; all fields should be using the same communicator, so this should not
@@ -634,11 +645,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
634645
#endif
635646
rank = group % fields(1) % compactHaloInfo(8)
636647

637-
sendBufptr => group % sendBuf
638-
recvBufptr => group % recvBuf
639-
640-
!!!$acc data present(group % recvBuf(:), group % sendBuf(:))
641-
!$acc data present(sendBufptr,recvBufptr)
648+
!$acc data present(group % recvBuf(:), group % sendBuf(:))
642649

643650
!
644651
! Initiate non-blocking MPI receives for all neighbors
@@ -648,12 +655,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
648655
bufstart = group % groupRecvOffsets(i)
649656
bufend = group % groupRecvOffsets(i) + group % groupRecvCounts(i) - 1
650657
!TO DO: how do we determine appropriate type here?
651-
! !$acc host_data use_device(group % recvBuf)
652-
! call MPI_Irecv(group % recvBuf(bufstart:bufend), group % groupRecvCounts(i), MPI_REALKIND, &
653-
! group % groupRecvNeighbors(i), group % groupRecvNeighbors(i), comm, &
654-
! group % recvRequests(i), mpi_ierr)
655-
!$acc host_data use_device(recvBufptr)
656-
call MPI_Irecv(recvBufptr(bufstart:bufend), group % groupRecvCounts(i), MPI_REALKIND, &
658+
!$acc host_data use_device(group % recvBuf)
659+
call MPI_Irecv(group % recvBuf(bufstart:bufend), group % groupRecvCounts(i), MPI_REALKIND, &
657660
group % groupRecvNeighbors(i), group % groupRecvNeighbors(i), comm, &
658661
group % recvRequests(i), mpi_ierr)
659662
!$acc end host_data
@@ -695,14 +698,18 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
695698
call mpas_pool_get_array(domain % blocklist % allFields, trim(group % fields(i) % fieldName), &
696699
group % fields(i) % r1arr, timeLevel=group % fields(i) % timeLevel)
697700

698-
! !$acc data copyin(group % fields(i) % r1arr(:))
701+
MPAS_ACC_TIMER_START('halo_exch [ACC_data_xfer]')
699702
!$acc enter data copyin(group % fields(i) % r1arr(:))
703+
MPAS_ACC_TIMER_STOP('halo_exch [ACC_data_xfer]')
700704
!
701705
! Pack send buffer for all neighbors for current field
702706
!
703-
!$acc kernels default(present)
707+
call mpas_timer_start('packing_halo_exch')
708+
!$acc parallel default(present)
709+
!$acc loop gang collapse(2)
704710
do iEndp = 1, nSendEndpts
705711
do iHalo = 1, nHalos
712+
!$acc loop vector
706713
do j = 1, maxNSendList
707714
if (j <= nSendLists(iHalo,iEndp)) then
708715
idxBuf = packOffsets(iEndp) + sendListDst(j,iHalo,iEndp)
@@ -712,9 +719,9 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
712719
end do
713720
end do
714721
end do
715-
!$acc end kernels
716-
! !$acc end data
717-
!!$acc update device(group % sendBuf(:))
722+
!$acc end parallel
723+
call mpas_timer_stop('packing_halo_exch')
724+
718725
!
719726
! Packing code for 2-d real-valued fields
720727
!
@@ -725,18 +732,23 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
725732
!
726733
! Pack send buffer for all neighbors for current field
727734
!
728-
735+
729736
! Use data regions for specificity and so the reference or attachment counters are easier to make sense of
730737
! Present should also cause an attach action. OpenACC Spec 2.7, Section 2.7.2 describes the 'attach action'
731738
! !$acc data present(group) present(group % fields(i)) present(group % sendBuf(:), group % fields(i) % sendListSrc(:,:,:))
732739
! !$acc data copyin(group % fields(i) % r2arr(:,:))
740+
MPAS_ACC_TIMER_START('halo_exch [ACC_data_xfer]')
733741
!$acc enter data copyin(group % fields(i) % r2arr(:,:))
742+
MPAS_ACC_TIMER_STOP('halo_exch [ACC_data_xfer]')
734743

744+
call mpas_timer_start('packing_halo_exch')
735745
! Kernels is good enough, use default present to force a run-time error if programmer forgot something
736-
!$acc kernels default(present)
746+
!$acc parallel default(present)
747+
!$acc loop gang collapse(3)
737748
do iEndp = 1, nSendEndpts
738749
do iHalo = 1, nHalos
739750
do j = 1, maxNSendList
751+
!$acc loop vector
740752
do i1 = 1, dim1
741753
if (j <= nSendLists(iHalo,iEndp)) then
742754
idxBuf = packOffsets(iEndp) + dim1 * (sendListDst(j,iHalo,iEndp) - 1) + i1
@@ -747,27 +759,30 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
747759
end do
748760
end do
749761
end do
750-
!$acc end kernels
751-
! !$acc end data
752-
! !$acc end data
753-
!!$acc update device(group % sendBuf(:))
762+
!$acc end parallel
763+
call mpas_timer_stop('packing_halo_exch')
764+
754765
!
755766
! Packing code for 3-d real-valued fields
756767
!
757768
case (3)
758769
call mpas_pool_get_array(domain % blocklist % allFields, trim(group % fields(i) % fieldName), &
759770
group % fields(i) % r3arr, group % fields(i) % timeLevel)
760-
! !$acc data copyin(group % fields(i) % r3arr(:,:,:))
771+
MPAS_ACC_TIMER_START('halo_exch [ACC_data_xfer]')
761772
!$acc enter data copyin(group % fields(i) % r3arr(:,:,:))
773+
MPAS_ACC_TIMER_STOP('halo_exch [ACC_data_xfer]')
762774

763775
!
764776
! Pack send buffer for all neighbors for current field
765777
!
766-
!$acc kernels default(present)
778+
call mpas_timer_start('packing_halo_exch')
779+
!$acc parallel default(present)
780+
!$acc loop gang collapse(4)
767781
do iEndp = 1, nSendEndpts
768782
do iHalo = 1, nHalos
769783
do j = 1, maxNSendList
770784
do i2 = 1, dim2
785+
!$acc loop vector
771786
do i1 = 1, dim1
772787
if (j <= nSendLists(iHalo,iEndp)) then
773788
idxBuf = packOffsets(iEndp) + dim1*dim2*(sendListDst(j,iHalo,iEndp) - 1) &
@@ -780,9 +795,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
780795
end do
781796
end do
782797
end do
783-
!$acc end kernels
784-
! !$acc end data
785-
!!$acc update device(group % sendBuf(:))
798+
!$acc end parallel
799+
call mpas_timer_stop('packing_halo_exch')
786800

787801
end select
788802
end if
@@ -796,12 +810,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
796810
bufstart = group % groupSendOffsets(i)
797811
bufend = group % groupSendOffsets(i) + group % groupSendCounts(i) - 1
798812
!TO DO: how do we determine appropriate type here?
799-
! !$acc host_data use_device(group % sendBuf)
800-
! call MPI_Isend(group % sendBuf(bufstart:bufend), group % groupSendCounts(i), MPI_REALKIND, &
801-
! group % groupSendNeighbors(i), rank, comm, &
802-
! group % sendRequests(i), mpi_ierr)
803-
!$acc host_data use_device(sendBufptr)
804-
call MPI_Isend(sendBufptr(bufstart:bufend), group % groupSendCounts(i), MPI_REALKIND, &
813+
!$acc host_data use_device(group % sendBuf)
814+
call MPI_Isend(group % sendBuf(bufstart:bufend), group % groupSendCounts(i), MPI_REALKIND, &
805815
group % groupSendNeighbors(i), rank, comm, &
806816
group % sendRequests(i), mpi_ierr)
807817
!$acc end host_data
@@ -859,10 +869,11 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
859869
!
860870
! Unpack recv buffer from all neighbors for current field
861871
!
862-
!!$acc update host(group % recvBuf(:))
863-
!!$acc wait
864-
!$acc kernels default(present)
872+
call mpas_timer_start('unpacking_halo_exch')
873+
!$acc parallel default(present)
874+
!$acc loop gang
865875
do iHalo = 1, nHalos
876+
!$acc loop vector
866877
do j = 1, maxNRecvList
867878
if (j <= nRecvLists(iHalo,iEndp)) then
868879
idxArr = recvListDst(j,iHalo,iEndp)
@@ -871,8 +882,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
871882
end if
872883
end do
873884
end do
874-
!$acc end kernels
875-
!!$acc exit data copyout(group % fields(i) % r1arr(:))
885+
!$acc end parallel
886+
call mpas_timer_stop('unpacking_halo_exch')
876887

877888
!
878889
! Unpacking code for 2-d real-valued fields
@@ -881,11 +892,13 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
881892
!
882893
! Unpack recv buffer from all neighbors for current field
883894
!
884-
!!$acc update host(group % recvBuf(:))
885-
!!$acc wait
886-
!$acc kernels default(present)
895+
call mpas_timer_start('unpacking_halo_exch')
896+
!$acc parallel default(present)
897+
!$acc loop gang
887898
do iHalo = 1, nHalos
899+
!$acc loop worker
888900
do j = 1, maxNRecvList
901+
!$acc loop vector
889902
do i1 = 1, dim1
890903
if (j <= nRecvLists(iHalo,iEndp)) then
891904
idxArr = recvListDst(j,iHalo,iEndp)
@@ -895,8 +908,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
895908
end do
896909
end do
897910
end do
898-
!$acc end kernels
899-
!!$acc exit data copyout(group % fields(i) % r2arr(:,:))
911+
!$acc end parallel
912+
call mpas_timer_stop('unpacking_halo_exch')
900913

901914
!
902915
! Unpacking code for 3-d real-valued fields
@@ -905,11 +918,12 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
905918
!
906919
! Unpack recv buffer from all neighbors for current field
907920
!
908-
!!$acc update host(group % recvBuf(:))
909-
!!$acc wait
910-
!$acc kernels default(present)
921+
call mpas_timer_start('unpacking_halo_exch')
922+
!$acc parallel default(present)
923+
!$acc loop gang collapse(2)
911924
do iHalo = 1, nHalos
912925
do j = 1, maxNRecvList
926+
!$acc loop vector collapse(2)
913927
do i2 = 1, dim2
914928
do i1 = 1, dim1
915929
if (j <= nRecvLists(iHalo,iEndp)) then
@@ -922,14 +936,15 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
922936
end do
923937
end do
924938
end do
925-
!$acc end kernels
926-
!!$acc exit data copyout(group % fields(i) % r3arr(:,:,:))
939+
!$acc end parallel
940+
call mpas_timer_stop('unpacking_halo_exch')
927941

928942
end select
929943
end if
930944
end do
931945
end do
932946

947+
MPAS_ACC_TIMER_START('halo_exch [ACC_data_xfer]')
933948
do i = 1, group % nFields
934949
if (group % fields(i) % fieldType == MPAS_HALO_REAL) then
935950
select case (group % fields(i) % nDims)
@@ -958,20 +973,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
958973

959974
! For the present(group % recvBuf(:), group % sendBuf(:))
960975
!$acc end data
961-
! !$acc wait
962-
! do i = 1, group % nFields
963-
! if (group % fields(i) % fieldType == MPAS_HALO_REAL) then
964-
! select case (group % fields(i) % nDims)
965-
! case (1)
966-
! !$acc exit data copyout(group % fields(i) % r1arr(:))
967-
! case (2)
968-
! !$acc exit data copyout(group % fields(i) % r2arr(:,:))
969-
! case (3)
970-
! !$acc exit data copyout(group % fields(i) % r3arr(:,:,:))
971-
! end select
972-
! end if
973-
! end do
974-
! !$acc wait
976+
MPAS_ACC_TIMER_STOP('halo_exch [ACC_data_xfer]')
975977

976978
!
977979
! Nullify array pointers - not necessary for correctness, but helpful when debugging
@@ -992,6 +994,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
992994
!
993995
call MPI_Waitall(group % nGroupSendNeighbors, group % sendRequests, MPI_STATUSES_IGNORE, mpi_ierr)
994996

997+
call mpas_timer_stop('full_halo_exch')
998+
995999
end subroutine mpas_halo_exch_group_full_halo_exch
9961000

9971001

0 commit comments

Comments
 (0)