Skip to content

Commit 5315551

Browse files
committed
checkpoints: acc pack + cuda aware mpi working
1 parent 2d1da49 commit 5315551

File tree

1 file changed

+74
-17
lines changed

1 file changed

+74
-17
lines changed

src/framework/mpas_halo.F

Lines changed: 74 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -280,6 +280,12 @@ subroutine mpas_halo_exch_group_complete(domain, groupName, iErr)
280280

281281
call refactor_lists(domain, groupName, iErr)
282282

283+
if ( newGroup% nGroupSendNeighbors <=0 ) then
284+
!call mpas_log_write('No send neighbors for halo exchange group '//trim(groupName))
285+
return
286+
end if
287+
288+
283289
! Always copy in the main data member first
284290
!$acc enter data copyin(newGroup)
285291
! Then the data in the members of the type
@@ -541,6 +547,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
541547
use mpas_derived_types, only : domain_type, mpas_halo_group, MPAS_HALO_REAL, MPAS_LOG_CRIT
542548
use mpas_pool_routines, only : mpas_pool_get_array
543549
use mpas_log, only : mpas_log_write
550+
use mpas_kind_types, only : RKIND
544551

545552
! Parameters
546553
#ifdef MPAS_USE_MPI_F08
@@ -588,7 +595,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
588595
integer :: maxNRecvList
589596
integer, dimension(:,:,:), CONTIGUOUS pointer :: recvListSrc, recvListDst
590597
integer, dimension(:), CONTIGUOUS pointer :: unpackOffsets
591-
598+
real (kind=RKIND), dimension(:), pointer :: sendBufptr, recvBufptr
592599

593600
if (present(iErr)) then
594601
iErr = 0
@@ -611,6 +618,10 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
611618
messageType=MPAS_LOG_CRIT)
612619
end if
613620

621+
if ( group% nGroupSendNeighbors <=0 ) then
622+
!call mpas_log_write('group has no halo exchanges: '//trim(groupName))
623+
return
624+
end if
614625
!
615626
! Get the rank of this task and the MPI communicator to use from the first field in
616627
! the group; all fields should be using the same communicator, so this should not
@@ -623,7 +634,11 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
623634
#endif
624635
rank = group % fields(1) % compactHaloInfo(8)
625636

626-
!$acc data present(group % recvBuf(:), group % sendBuf(:))
637+
sendBufptr => group % sendBuf
638+
recvBufptr => group % recvBuf
639+
640+
!!!$acc data present(group % recvBuf(:), group % sendBuf(:))
641+
!$acc data present(sendBufptr,recvBufptr)
627642

628643
!
629644
! Initiate non-blocking MPI receives for all neighbors
@@ -633,8 +648,12 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
633648
bufstart = group % groupRecvOffsets(i)
634649
bufend = group % groupRecvOffsets(i) + group % groupRecvCounts(i) - 1
635650
!TO DO: how do we determine appropriate type here?
636-
!$acc host_data use_device(group % recvBuf)
637-
call MPI_Irecv(group % recvBuf(bufstart:bufend), group % groupRecvCounts(i), MPI_REALKIND, &
651+
! !$acc host_data use_device(group % recvBuf)
652+
! call MPI_Irecv(group % recvBuf(bufstart:bufend), group % groupRecvCounts(i), MPI_REALKIND, &
653+
! group % groupRecvNeighbors(i), group % groupRecvNeighbors(i), comm, &
654+
! group % recvRequests(i), mpi_ierr)
655+
!$acc host_data use_device(recvBufptr)
656+
call MPI_Irecv(recvBufptr(bufstart:bufend), group % groupRecvCounts(i), MPI_REALKIND, &
638657
group % groupRecvNeighbors(i), group % groupRecvNeighbors(i), comm, &
639658
group % recvRequests(i), mpi_ierr)
640659
!$acc end host_data
@@ -695,7 +714,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
695714
end do
696715
!$acc end kernels
697716
! !$acc end data
698-
717+
!!$acc update device(group % sendBuf(:))
699718
!
700719
! Packing code for 2-d real-valued fields
701720
!
@@ -731,7 +750,7 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
731750
!$acc end kernels
732751
! !$acc end data
733752
! !$acc end data
734-
753+
!!$acc update device(group % sendBuf(:))
735754
!
736755
! Packing code for 3-d real-valued fields
737756
!
@@ -763,11 +782,25 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
763782
end do
764783
!$acc end kernels
765784
! !$acc end data
785+
!!$acc update device(group % sendBuf(:))
766786

767787
end select
768788
end if
769789
end do
770790

791+
do i = 1, group % nFields
792+
if (group % fields(i) % fieldType == MPAS_HALO_REAL) then
793+
select case (group % fields(i) % nDims)
794+
case (1)
795+
!$acc exit data delete(group % fields(i) % r1arr(:))
796+
case (2)
797+
!$acc exit data delete(group % fields(i) % r2arr(:,:))
798+
case (3)
799+
!$acc exit data delete(group % fields(i) % r3arr(:,:,:))
800+
end select
801+
end if
802+
end do
803+
771804
!
772805
! Initiate non-blocking sends to all neighbors
773806
!
@@ -776,8 +809,12 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
776809
bufstart = group % groupSendOffsets(i)
777810
bufend = group % groupSendOffsets(i) + group % groupSendCounts(i) - 1
778811
!TO DO: how do we determine appropriate type here?
779-
!$acc host_data use_device(group % sendBuf)
780-
call MPI_Isend(group % sendBuf(bufstart:bufend), group % groupSendCounts(i), MPI_REALKIND, &
812+
! !$acc host_data use_device(group % sendBuf)
813+
! call MPI_Isend(group % sendBuf(bufstart:bufend), group % groupSendCounts(i), MPI_REALKIND, &
814+
! group % groupSendNeighbors(i), rank, comm, &
815+
! group % sendRequests(i), mpi_ierr)
816+
!$acc host_data use_device(sendBufptr)
817+
call MPI_Isend(sendBufptr(bufstart:bufend), group % groupSendCounts(i), MPI_REALKIND, &
781818
group % groupSendNeighbors(i), rank, comm, &
782819
group % sendRequests(i), mpi_ierr)
783820
!$acc end host_data
@@ -835,7 +872,9 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
835872
!
836873
! Unpack recv buffer from all neighbors for current field
837874
!
838-
!$acc kernels default(present)
875+
!$acc update host(group % recvBuf(:))
876+
!$acc wait
877+
!!$acc kernels default(present)
839878
do iHalo = 1, nHalos
840879
do j = 1, maxNRecvList
841880
if (j <= nRecvLists(iHalo,iEndp)) then
@@ -845,8 +884,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
845884
end if
846885
end do
847886
end do
848-
!$acc end kernels
849-
!$acc exit data copyout(group % fields(i) % r1arr(:))
887+
!!$acc end kernels
888+
!!$acc exit data copyout(group % fields(i) % r1arr(:))
850889

851890
!
852891
! Unpacking code for 2-d real-valued fields
@@ -855,7 +894,9 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
855894
!
856895
! Unpack recv buffer from all neighbors for current field
857896
!
858-
!$acc kernels default(present)
897+
!$acc update host(group % recvBuf(:))
898+
!$acc wait
899+
!!$acc kernels default(present)
859900
do iHalo = 1, nHalos
860901
do j = 1, maxNRecvList
861902
do i1 = 1, dim1
@@ -867,8 +908,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
867908
end do
868909
end do
869910
end do
870-
!$acc end kernels
871-
!$acc exit data copyout(group % fields(i) % r2arr(:,:))
911+
!!$acc end kernels
912+
!!$acc exit data copyout(group % fields(i) % r2arr(:,:))
872913

873914
!
874915
! Unpacking code for 3-d real-valued fields
@@ -877,7 +918,9 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
877918
!
878919
! Unpack recv buffer from all neighbors for current field
879920
!
880-
!$acc kernels default(present)
921+
!$acc update host(group % recvBuf(:))
922+
!$acc wait
923+
!!$acc kernels default(present)
881924
do iHalo = 1, nHalos
882925
do j = 1, maxNRecvList
883926
do i2 = 1, dim2
@@ -892,8 +935,8 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
892935
end do
893936
end do
894937
end do
895-
!$acc end kernels
896-
!$acc exit data copyout(group % fields(i) % r3arr(:,:,:))
938+
!!$acc end kernels
939+
!!$acc exit data copyout(group % fields(i) % r3arr(:,:,:))
897940

898941
end select
899942
end if
@@ -903,6 +946,20 @@ subroutine mpas_halo_exch_group_full_halo_exch(domain, groupName, iErr)
903946
! Ends the data region opened with present(sendBufptr,recvBufptr); those pointers alias group % sendBuf and group % recvBuf
904947
!$acc end data
905948

949+
! do i = 1, group % nFields
950+
! if (group % fields(i) % fieldType == MPAS_HALO_REAL) then
951+
! select case (group % fields(i) % nDims)
952+
! case (1)
953+
! !$acc exit data copyout(group % fields(i) % r1arr(:))
954+
! case (2)
955+
! !$acc exit data copyout(group % fields(i) % r2arr(:,:))
956+
! case (3)
957+
! !$acc exit data copyout(group % fields(i) % r3arr(:,:,:))
958+
! end select
959+
! end if
960+
! end do
961+
962+
906963
!
907964
! Nullify array pointers - not necessary for correctness, but helpful when debugging
908965
! to not leave pointers to what might later be incorrect targets

0 commit comments

Comments
 (0)