1414 * Copyright (c) 2016 Research Organization for Information Science
1515 * and Technology (RIST). All rights reserved.
1616 * Copyright (c) 2017 IBM Corporation. All rights reserved.
17+ * Copyright (c) 2025 Triad National Security, LLC. All rights reserved.
1718 * $COPYRIGHT$
1819 *
1920 * Additional copyrights may follow
3435#include "coll_base_topo.h"
3536#include "coll_base_util.h"
3637
/*
 * Saturating subtraction on unsigned sizes.
 *
 * Returns a - b when a is strictly greater than b, and 0 in every other
 * case, so the result never wraps around below zero.
 */
static inline size_t
rectify_diff(size_t a, size_t b)
{
    if (a <= b) {
        return 0;
    }
    return a - b;
}
46+
3747int
3848ompi_coll_base_bcast_intra_generic ( void * buffer ,
3949 size_t original_count ,
@@ -811,8 +821,8 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
811821 if (vrank & mask ) {
812822 int parent = (rank - mask + comm_size ) % comm_size ;
813823 /* Compute an upper bound on recv block size */
814- recv_count = count - vrank * scatter_count ;
815- if (recv_count < = 0 ) {
824+ recv_count = rectify_diff ( count , ( size_t )( vrank * scatter_count )) ;
825+ if (recv_count = = 0 ) {
816826 curr_count = 0 ;
817827 } else {
818828 /* Recv data from parent */
@@ -832,7 +842,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
832842 mask >>= 1 ;
833843 while (mask > 0 ) {
834844 if (vrank + mask < comm_size ) {
835- send_count = curr_count - scatter_count * mask ;
845+ send_count = rectify_diff ( curr_count , ( size_t )( scatter_count * mask )) ;
836846 if (send_count > 0 ) {
837847 int child = (rank + mask ) % comm_size ;
838848 err = MCA_PML_CALL (send ((char * )buf + (ptrdiff_t )scatter_count * (vrank + mask ) * extent ,
@@ -850,10 +860,8 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
850860 * Allgather by recursive doubling
851861 * Each process has the curr_count elems in the buf[vrank * scatter_count, ...]
852862 */
853- size_t rem_count = count - vrank * scatter_count ;
863+ size_t rem_count = rectify_diff ( count , ( size_t )( vrank * scatter_count )) ;
854864 curr_count = (scatter_count < rem_count ) ? scatter_count : rem_count ;
855- if (curr_count < 0 )
856- curr_count = 0 ;
857865
858866 mask = 0x1 ;
859867 while (mask < comm_size ) {
@@ -866,9 +874,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
866874 if (vremote < comm_size ) {
867875 ptrdiff_t send_offset = vrank_tree_root * scatter_count * extent ;
868876 ptrdiff_t recv_offset = vremote_tree_root * scatter_count * extent ;
869- recv_count = count - vremote_tree_root * scatter_count ;
870- if (recv_count < 0 )
871- recv_count = 0 ;
877+ recv_count = rectify_diff (count , (size_t )(vremote_tree_root * scatter_count ));
872878 err = ompi_coll_base_sendrecv ((char * )buf + send_offset ,
873879 curr_count , datatype , remote ,
874880 MCA_COLL_BASE_TAG_BCAST ,
@@ -877,7 +883,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
877883 MCA_COLL_BASE_TAG_BCAST ,
878884 comm , & status , rank );
879885 if (MPI_SUCCESS != err ) { goto cleanup_and_return ; }
880- recv_count = (int )(status ._ucount / datatype_size );
886+ recv_count = (size_t )(status ._ucount / datatype_size );
881887 curr_count += recv_count ;
882888 }
883889
@@ -913,7 +919,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
913919 MCA_COLL_BASE_TAG_BCAST ,
914920 comm , & status ));
915921 if (MPI_SUCCESS != err ) { goto cleanup_and_return ; }
916- recv_count = (int )(status ._ucount / datatype_size );
922+ recv_count = (size_t )(status ._ucount / datatype_size );
917923 curr_count += recv_count ;
918924 }
919925 }
@@ -988,8 +994,8 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
988994 if (vrank & mask ) {
989995 int parent = (rank - mask + comm_size ) % comm_size ;
990996 /* Compute an upper bound on recv block size */
991- recv_count = count - vrank * scatter_count ;
992- if (recv_count <= 0 ) {
997+ recv_count = rectify_diff ( count , ( size_t )( vrank * scatter_count )) ;
998+ if (0 == recv_count ) {
993999 curr_count = 0 ;
9941000 } else {
9951001 /* Recv data from parent */
@@ -1009,7 +1015,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
10091015 mask >>= 1 ;
10101016 while (mask > 0 ) {
10111017 if (vrank + mask < comm_size ) {
1012- send_count = curr_count - scatter_count * mask ;
1018+ send_count = rectify_diff ( curr_count , ( size_t )( scatter_count * mask )) ;
10131019 if (send_count > 0 ) {
10141020 int child = (rank + mask ) % comm_size ;
10151021 err = MCA_PML_CALL (send ((char * )buf + (ptrdiff_t )scatter_count * (vrank + mask ) * extent ,
@@ -1023,33 +1029,41 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
10231029 mask >>= 1 ;
10241030 }
10251031
1026- /* Allgather by a ring algorithm */
1032+ /* Allgather by a ring algorithm, using only unsigned types */
10271033 int left = (rank - 1 + comm_size ) % comm_size ;
10281034 int right = (rank + 1 ) % comm_size ;
1035+
1036+ /* The block we will send/recv in each step */
10291037 int send_block = vrank ;
10301038 int recv_block = (vrank - 1 + comm_size ) % comm_size ;
10311039
1032- for (int i = 1 ; i < comm_size ; i ++ ) {
1033- recv_count = (scatter_count < count - recv_block * scatter_count ) ?
1034- scatter_count : count - recv_block * scatter_count ;
1035- if (recv_count < 0 )
1036- recv_count = 0 ;
1037- ptrdiff_t recv_offset = recv_block * scatter_count * extent ;
1038-
1039- send_count = (scatter_count < count - send_block * scatter_count ) ?
1040- scatter_count : count - send_block * scatter_count ;
1041- if (send_count < 0 )
1042- send_count = 0 ;
1043- ptrdiff_t send_offset = send_block * scatter_count * extent ;
1044-
1045- err = ompi_coll_base_sendrecv ((char * )buf + send_offset , send_count ,
1040+ for (int i = 1 ; i < comm_size ; ++ i ) {
1041+ /* how many elements remain in recv_block? */
1042+ size_t recv_offset_elems = recv_block * scatter_count ;
1043+ size_t recv_remaining = rectify_diff (count , recv_offset_elems );
1044+ recv_count = (recv_remaining < scatter_count ) ?
1045+ recv_remaining : scatter_count ;
1046+ size_t recv_offset = recv_offset_elems * extent ;
1047+
1048+ /* same logic for send */
1049+ size_t send_offset_elems = send_block * scatter_count ;
1050+ size_t send_remaining = rectify_diff (count , send_offset_elems );
1051+ send_count = (send_remaining < scatter_count ) ?
1052+ send_remaining : scatter_count ;
1053+ size_t send_offset = send_offset_elems * extent ;
1054+
1055+ err = ompi_coll_base_sendrecv ((char * )buf + send_offset , send_count ,
10461056 datatype , right , MCA_COLL_BASE_TAG_BCAST ,
1047- (char * )buf + recv_offset , recv_count ,
1057+ (char * )buf + recv_offset , recv_count ,
10481058 datatype , left , MCA_COLL_BASE_TAG_BCAST ,
10491059 comm , MPI_STATUS_IGNORE , rank );
1050- if (MPI_SUCCESS != err ) { goto cleanup_and_return ; }
1060+ if (MPI_SUCCESS != err ) {
1061+ goto cleanup_and_return ;
1062+ }
1063+
1064+ /* rotate blocks */
10511065 send_block = recv_block ;
1052- recv_block = (recv_block - 1 + comm_size ) % comm_size ;
1066+ recv_block = (recv_block + comm_size - 1 ) % comm_size ;
10531067 }
10541068
10551069cleanup_and_return :
0 commit comments