Skip to content

Commit 788039f

Browse files
committed
Add new MPI_Op -> MPI_LOR
1 parent cf03a11 commit 788039f

File tree

4 files changed

+71
-6
lines changed

4 files changed

+71
-6
lines changed

src/mpi.f90

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ module mpi
2121
real(8), parameter :: MPI_IN_PLACE = -1002
2222
integer, parameter :: MPI_SUM = -2300
2323
integer, parameter :: MPI_MAX = -2301
24+
integer, parameter :: MPI_LOR = -2302
2425
integer, parameter :: MPI_INFO_NULL = -2000
2526
integer, parameter :: MPI_STATUS_SIZE = 5
2627
integer :: MPI_STATUS_IGNORE = 0
@@ -99,6 +100,7 @@ module mpi
99100
module procedure MPI_Allreduce_1D_recv_proc
100101
module procedure MPI_Allreduce_1D_real_proc
101102
module procedure MPI_Allreduce_1D_int_proc
103+
module procedure MPI_Allreduce_scalar_logical_proc
102104
end interface
103105

104106
interface MPI_Gatherv
@@ -168,14 +170,16 @@ module mpi
168170
contains
169171

170172
integer(kind=MPI_HANDLE_KIND) function handle_mpi_op_f2c(op_f) result(c_op)
171-
use mpi_c_bindings, only: c_mpi_op_f2c, c_mpi_sum, c_mpi_max
173+
use mpi_c_bindings, only: c_mpi_op_f2c, c_mpi_sum, c_mpi_max, c_mpi_lor
172174
integer, intent(in) :: op_f
173175
if (op_f == MPI_SUM) then
174176
c_op = c_mpi_sum
175177
else if (op_f == MPI_MAX) then
176178
c_op = c_MPI_MAX
179+
else if (op_f == MPI_LOR) then
180+
c_op = c_mpi_lor
177181
else
178-
c_op = c_mpi_op_f2c(op_f)
182+
c_op = c_mpi_op_f2c(op_f) ! For other operations, use the C binding
179183
end if
180184
end function
181185

@@ -795,6 +799,35 @@ subroutine MPI_Allreduce_1D_int_proc(sendbuf, recvbuf, count, datatype, op, comm
795799
end if
796800
end subroutine MPI_Allreduce_1D_int_proc
797801

802+
subroutine MPI_Allreduce_scalar_logical_proc(sendbuf, recvbuf, count, datatype, op, comm, ierror)
803+
use iso_c_binding, only: c_int, c_ptr, c_loc
804+
use mpi_c_bindings, only: c_mpi_allreduce, c_mpi_comm_f2c
805+
logical, intent(in), target :: sendbuf
806+
logical, intent(out), target :: recvbuf
807+
integer, intent(in) :: count, datatype, op, comm
808+
integer, intent(out), optional :: ierror
809+
type(c_ptr) :: sendbuf_ptr, recvbuf_ptr
810+
integer(kind=MPI_HANDLE_KIND) :: c_datatype, c_op, c_comm
811+
integer(c_int) :: local_ierr
812+
813+
sendbuf_ptr = c_loc(sendbuf)
814+
recvbuf_ptr = c_loc(recvbuf)
815+
c_datatype = handle_mpi_datatype_f2c(datatype)
816+
c_op = handle_mpi_op_f2c(op)
817+
818+
c_comm = handle_mpi_comm_f2c(comm)
819+
820+
local_ierr = c_mpi_allreduce(sendbuf_ptr, recvbuf_ptr, count, c_datatype, c_op, c_comm)
821+
822+
if (present(ierror)) then
823+
ierror = local_ierr
824+
else
825+
if (local_ierr /= MPI_SUCCESS) then
826+
print *, "MPI_Allreduce_1D_recv_proc failed with error code: ", local_ierr
827+
end if
828+
end if
829+
end subroutine MPI_Allreduce_scalar_logical_proc
830+
798831
function MPI_Wtime_proc() result(time)
799832
use mpi_c_bindings, only: c_mpi_wtime
800833
real(8) :: time

src/mpi_c_bindings.f90

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,21 @@ module mpi_c_bindings
1212
type(c_ptr), bind(C, name="c_MPI_STATUSES_IGNORE") :: c_mpi_statuses_ignore
1313
type(c_ptr), bind(C, name="c_MPI_IN_PLACE") :: c_mpi_in_place
1414
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_INFO_NULL") :: c_mpi_info_null
15+
1516
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_DOUBLE") :: c_mpi_double
1617
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_FLOAT") :: c_mpi_float
1718
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_REAL") :: c_mpi_real
1819
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_INT") :: c_mpi_int
19-
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_WORLD") :: c_mpi_comm_world
20-
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_NULL") :: c_mpi_comm_null
21-
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_SUM") :: c_mpi_sum
22-
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_MAX") :: c_mpi_max
2320
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_LOGICAL") :: c_mpi_logical
2421
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_CHARACTER") :: c_mpi_character
2522

23+
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_SUM") :: c_mpi_sum
24+
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_MAX") :: c_mpi_max
25+
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_LOR") :: c_mpi_lor
26+
27+
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_WORLD") :: c_mpi_comm_world
28+
integer(kind=MPI_HANDLE_KIND), bind(C, name="c_MPI_COMM_NULL") :: c_mpi_comm_null
29+
2630
interface
2731

2832
function c_mpi_comm_f2c(comm_f) bind(C, name="MPI_Comm_f2c")

src/mpi_constants.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ MPI_Op c_MPI_SUM = MPI_SUM;
2626

2727
MPI_Op c_MPI_MAX = MPI_MAX;
2828

29+
MPI_Op c_MPI_LOR = MPI_LOR;
30+
2931
// Communicators Declarations
3032

3133
MPI_Comm c_MPI_COMM_NULL = MPI_COMM_NULL;

tests/allreduce_lor.f90

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
program mre_mpi_lor_allreduce
2+
use mpi
3+
implicit none
4+
5+
integer :: ierr, rank, size
6+
logical :: local_flag, global_flag
7+
8+
call MPI_INIT(ierr)
9+
if (ierr /= MPI_SUCCESS) error stop "MPI_INIT failed"
10+
11+
call MPI_COMM_RANK(MPI_COMM_WORLD, rank, ierr)
12+
call MPI_COMM_SIZE(MPI_COMM_WORLD, size, ierr)
13+
14+
! Initialize the local flag: True if this is the 0th rank, False otherwise
15+
local_flag = (rank == 0)
16+
17+
! Perform logical OR reduction across all processes
18+
call MPI_ALLREDUCE(local_flag, global_flag, 1, MPI_LOGICAL, MPI_LOR, MPI_COMM_WORLD, ierr)
19+
if (global_flag /= .true.) error stop "MPI_ALLREDUCE failed"
20+
21+
print *, 'Rank', rank, ': global_flag =', global_flag
22+
23+
call MPI_FINALIZE(ierr)
24+
if (ierr /= MPI_SUCCESS) error stop "MPI_FINALIZE failed"
25+
26+
end program mre_mpi_lor_allreduce

0 commit comments

Comments
 (0)