From efc81857be83d80883aba43835b697ad28fd35e6 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 30 Jul 2024 13:49:08 -0700 Subject: [PATCH 1/6] initial rust mpi support --- enzyme/Enzyme/ActivityAnalysis.cpp | 2 ++ enzyme/Enzyme/AdjointGenerator.h | 5 +++-- enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp | 7 ++++--- enzyme/Enzyme/Utils.cpp | 5 +++-- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index e69145815ee..1e828629cb9 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -111,6 +111,8 @@ static const StringSet<> InactiveGlobals = { "jl_small_typeof", "ompi_request_null", "ompi_mpi_double", + "RSMPI_DOUBLE", + "RSMPI_FLOAT", "ompi_mpi_comm_world", "__cxa_thread_atexit_impl", "stderr", diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 19a832dd8e8..1b9fe136b6c 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -145,9 +145,10 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> { C = CE->getOperand(0); } if (auto GV = dyn_cast<GlobalVariable>(C)) { - if (GV->getName() == "ompi_mpi_double") { + auto name = GV->getName(); + if (name == "ompi_mpi_double" || name == "RSMPI_DOUBLE") { return ConstantInt::get(intType, 8, false); - } else if (GV->getName() == "ompi_mpi_float") { + } else if (name == "ompi_mpi_float" || name == "RSMPI_FLOAT") { return ConstantInt::get(intType, 4, false); } } diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index bf481c13b93..21fbe612daa 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -4785,11 +4785,12 @@ void TypeAnalyzer::visitCallBase(CallBase &call) { C = CE->getOperand(0); } if (auto GV = dyn_cast<GlobalVariable>(C)) { - if (GV->getName() == "ompi_mpi_double") { + auto name = GV->getName(); + if (name == "ompi_mpi_double" || name == "RSMPI_DOUBLE") 
{ buf.insert({0}, Type::getDoubleTy(C->getContext())); - } else if (GV->getName() == "ompi_mpi_float") { + } else if (name == "ompi_mpi_float" || name == "RSMPI_FLOAT") { buf.insert({0}, Type::getFloatTy(C->getContext())); - } else if (GV->getName() == "ompi_mpi_cxx_bool") { + } else if (name == "ompi_mpi_cxx_bool") { buf.insert({0}, BaseType::Integer); } } else if (auto CI = dyn_cast<ConstantInt>(C)) { diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 1e87bf2a9a0..7f29078b1af 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -2344,9 +2344,10 @@ bool writesToMemoryReadBy(llvm::AAResults &AA, llvm::TargetLibraryInfo &TLI, C = CE->getOperand(0); } if (auto GV = dyn_cast<GlobalVariable>(C)) { - if (GV->getName() == "ompi_mpi_double") { + auto name = GV->getName(); + if (name == "ompi_mpi_double" || name == "RSMPI_DOUBLE") { type = ConcreteType(Type::getDoubleTy(C->getContext())); - } else if (GV->getName() == "ompi_mpi_float") { + } else if (name == "ompi_mpi_float" || name == "RSMPI_FLOAT") { type = ConcreteType(Type::getFloatTy(C->getContext())); } } From bd94ef8e8a1d6d662d520fef06975e0ea48c65da Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 30 Jul 2024 15:41:44 -0700 Subject: [PATCH 2/6] add rust mlir mpi support --- enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp index bf9b67f83f8..b9d99844253 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp @@ -36,6 +36,7 @@ static const char *KnownInactiveFunctionsContains[] = { static const std::set<std::string> InactiveGlobals = { "ompi_request_null", "ompi_mpi_double", "ompi_mpi_comm_world", "stderr", "stdout", "stdin", "_ZSt3cin", "_ZSt4cout", "_ZSt5wcout", "_ZSt4cerr", + "RSMPI_DOUBLE", "RSMPI_FLOAT", "_ZTVNSt7__cxx1115basic_stringbufIcSt11char_traitsIcESaIcEEE", 
"_ZTVSt15basic_streambufIcSt11char_traitsIcEE", "_ZTVSt9basic_iosIcSt11char_traitsIcEE", From 92e8308d1d209d0877fefcc54bb4cc8c3b832d8c Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 30 Jul 2024 16:22:03 -0700 Subject: [PATCH 3/6] add mpi rust type test --- enzyme/test/Enzyme/ReverseMode/mpi_rust.ll | 115 +++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 enzyme/test/Enzyme/ReverseMode/mpi_rust.ll diff --git a/enzyme/test/Enzyme/ReverseMode/mpi_rust.ll b/enzyme/test/Enzyme/ReverseMode/mpi_rust.ll new file mode 100644 index 00000000000..75d060921f3 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/mpi_rust.ll @@ -0,0 +1,115 @@ +; RUN: if [ %llvmver -eq 15 ]; then %opt < %s %loadEnzyme -enzyme -opaque-pointers=1 -S | FileCheck %s; fi +; RUN: if [ %llvmver -ge 15 ]; then %opt < %s %newLoadEnzyme -passes="enzyme" -opaque-pointers=1 -S | FileCheck %s; fi + +; ModuleID = 'enzyme-repro.ll' +source_filename = "dot_enzyme.3df87ea89a38df43-cgu.0" +target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@RSMPI_DOUBLE = external local_unnamed_addr global ptr +@RSMPI_COMM_WORLD = external local_unnamed_addr global ptr +@RSMPI_COMM_SELF = external local_unnamed_addr global ptr + +; Function Attrs: noinline nonlazybind sanitize_hwaddress uwtable +define hidden noundef "enzyme_type"="{[-1]:Float@double}" double @_ZN10dot_enzyme12dot_parallel17h7dfcd86d9e8c176bE(ptr noalias nocapture noundef readonly align 8 dereferenceable(16) "enzyme_type"="{[-1]:Pointer}" %0, ptr noalias nocapture noundef nonnull readonly align 8 "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" %1, i64 noundef "enzyme_type"="{[-1]:Integer}" %2, ptr noalias nocapture noundef nonnull readonly align 8 "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" %3, i64 noundef "enzyme_type"="{[-1]:Integer}" %4, ptr noundef "enzyme_type"="{[0]:Pointer}" %5) unnamed_addr #1 personality ptr 
@rust_eh_personality { + %7 = alloca double, align 8 + %8 = alloca double, align 8 + call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %8) + %9 = alloca double, align 8 + store double 1.000, ptr %8, align 8 + call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %7) + store double 0.000000e+00, ptr %7, align 8 + tail call void @llvm.experimental.noalias.scope.decl(metadata !7) + %10 = load ptr, ptr @RSMPI_DOUBLE, align 8, !noalias !10, !noundef !13 + %11 = load i64, ptr %0, align 8, !range !14, !alias.scope !15, !noalias !18, !noundef !13 + switch i64 %11, label %12 [ + i64 0, label %20 + i64 1, label %13 + i64 2, label %14 + i64 3, label %16 + i64 4, label %18 + ] + +12: ; preds = %6 + unreachable + +13: ; preds = %6 + br label %20 + +14: ; preds = %6 + %15 = getelementptr inbounds { i64, ptr }, ptr %0, i64 0, i32 1 + br label %20 + +16: ; preds = %6 + %17 = getelementptr inbounds { i64, ptr }, ptr %0, i64 0, i32 1 + br label %20 + +18: ; preds = %6 + %19 = getelementptr inbounds { i64, ptr }, ptr %0, i64 0, i32 1 + br label %20 + +20: ; preds = %18, %16, %14, %13, %6 + %21 = phi ptr [ %19, %18 ], [ %17, %16 ], [ %15, %14 ], [ @RSMPI_COMM_WORLD, %13 ], [ @RSMPI_COMM_SELF, %6 ] + %22 = load ptr, ptr %21, align 8, !noalias !18, !noundef !13 + %23 = alloca i32, align 4 + ;%23 = call noundef i32 @MPI_Allreduce(ptr noundef nonnull %8, ptr noundef nonnull %7, i32 noundef 1, ptr noundef %10, ptr noundef %5, ptr noundef %22), !noalias !7 + %24 = load double, ptr %7, align 8, !noundef !13 + call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %7) + call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %8) + ret double %24 +} + +; Function Attrs: nonlazybind sanitize_hwaddress uwtable +declare noundef i32 @MPI_Allreduce(ptr noundef, ptr noundef, i32 noundef, ptr noundef, ptr noundef, ptr noundef) unnamed_addr #2 + +; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: readwrite) +declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture) #3 + +; 
Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: readwrite) +declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture) #3 + +; Function Attrs: nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: readwrite) +declare void @llvm.experimental.noalias.scope.decl(metadata) #4 + +; Function Attrs: nounwind nonlazybind sanitize_hwaddress uwtable +declare noundef i32 @rust_eh_personality(i32 noundef, i32 noundef, i64, ptr noundef, ptr noundef) unnamed_addr #5 + +declare double @__enzyme_autodiff(...) + +define double @enzyme_opt_helper_0(ptr %0, ptr %1, i64 %2, ptr %3, i64 %4, ptr %5) { + %7 = call double (...) @__enzyme_autodiff(ptr @_ZN10dot_enzyme12dot_parallel17h7dfcd86d9e8c176bE, metadata !"enzyme_const", ptr %0, metadata !"enzyme_dup", ptr %1, ptr %1, metadata !"enzyme_const", i64 %2, metadata !"enzyme_dup", ptr %3, ptr %3, metadata !"enzyme_const", i64 %4, metadata !"enzyme_const", ptr %5) + ret double %7 +} + +attributes #0 = { noinline nounwind nonlazybind sanitize_hwaddress uwtable "probe-stack"="inline-asm" "target-cpu"="x86-64" } +attributes #1 = { noinline nonlazybind sanitize_hwaddress uwtable "probe-stack"="inline-asm" "target-cpu"="x86-64" } +attributes #2 = { nonlazybind sanitize_hwaddress uwtable "probe-stack"="inline-asm" "target-cpu"="x86-64" } +attributes #3 = { nocallback nofree nosync nounwind willreturn memory(argmem: readwrite) } +attributes #4 = { nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: readwrite) } +attributes #5 = { nounwind nonlazybind sanitize_hwaddress uwtable "probe-stack"="inline-asm" "target-cpu"="x86-64" } + +!llvm.module.flags = !{!0, !1, !2, !3, !4, !5} +!llvm.ident = !{!6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, 
!6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6, !6} +!llvm.dbg.cu = !{} + +!0 = !{i32 8, !"PIC Level", i32 2} +!1 = !{i32 7, !"PIE Level", i32 2} +!2 = !{i32 2, !"RtLibUseGOT", i32 1} +!3 = !{i32 1, !"LTOPostLink", i32 1} +!4 = !{i32 2, !"Dwarf Version", i32 4} +!5 = !{i32 2, !"Debug Info Version", i32 3} +!6 = !{!"rustc version 1.77.0-nightly (ecb2f9cdf 2024-07-30)"} +!7 = !{!8} +!8 = distinct !{!8, !9, !"_ZN3mpi10collective23CommunicatorCollectives15all_reduce_into17h5bd43ff3d0a82648E: argument 0"} +!9 = distinct !{!9, !"_ZN3mpi10collective23CommunicatorCollectives15all_reduce_into17h5bd43ff3d0a82648E"} +!10 = !{!8, !11, !12} +!11 = distinct !{!11, !9, !"_ZN3mpi10collective23CommunicatorCollectives15all_reduce_into17h5bd43ff3d0a82648E: argument 1"} +!12 = distinct !{!12, !9, !"_ZN3mpi10collective23CommunicatorCollectives15all_reduce_into17h5bd43ff3d0a82648E: argument 2"} +!13 = !{} +!14 = !{i64 0, i64 5} +!15 = !{!16, !8} +!16 = distinct !{!16, !17, !"_ZN69_$LT$mpi..topology..SimpleCommunicator$u20$as$u20$mpi..raw..AsRaw$GT$6as_raw17h5ddd9d255d268465E: argument 0"} +!17 = distinct !{!17, !"_ZN69_$LT$mpi..topology..SimpleCommunicator$u20$as$u20$mpi..raw..AsRaw$GT$6as_raw17h5ddd9d255d268465E"} +!18 = !{!11, !12} + +; CHECK: define internal void @diffe_ZN10dot_enzyme12dot_parallel17h7dfcd86d9e8c176bE(ptr noalias nocapture noundef readonly align 8 dereferenceable(16) "enzyme_type"="{[-1]:Pointer}" %0, ptr 
noalias nocapture noundef nonnull readonly align 8 "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" %1, ptr nocapture align 8 "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" %"'", i64 noundef "enzyme_type"="{[-1]:Integer}" %2, ptr noalias nocapture noundef nonnull readonly align 8 "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" %3, ptr nocapture align 8 "enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@double}" %"'1", i64 noundef "enzyme_type"="{[-1]:Integer}" %4, ptr noundef "enzyme_type"="{[0]:Pointer}" %5, double %differeturn) From 946dc9f403b252be43eb76bb14d7a9eb93c7b636 Mon Sep 17 00:00:00 2001 From: Jed Brown Date: Fri, 9 Aug 2024 11:07:06 -0600 Subject: [PATCH 4/6] add RSMPI_SUM --- enzyme/Enzyme/ActivityAnalysis.cpp | 1 + enzyme/Enzyme/CallDerivatives.cpp | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 1e828629cb9..d3337a04694 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -113,6 +113,7 @@ static const StringSet<> InactiveGlobals = { "ompi_mpi_double", "RSMPI_DOUBLE", "RSMPI_FLOAT", + "RSMPI_SUM", "ompi_mpi_comm_world", "__cxa_thread_atexit_impl", "stderr", diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index b8154840988..b495f53c8de 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -1150,7 +1150,8 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called, C = CE->getOperand(0); } if (auto GV = dyn_cast<GlobalVariable>(C)) { - if (GV->getName() == "ompi_mpi_op_sum") { + if (GV->getName() == "ompi_mpi_op_sum" || + GV->getName() == "RSMPI_SUM") { isSum = true; } } @@ -1391,7 +1392,8 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called, C = CE->getOperand(0); } if (auto GV = dyn_cast<GlobalVariable>(C)) { - if (GV->getName() == "ompi_mpi_op_sum") { + if (GV->getName() == "ompi_mpi_op_sum" || + GV->getName() == 
"RSMPI_SUM") { isSum = true; } } From cb2d7396ec9494b04cc0890f9f1ae2d76da0be54 Mon Sep 17 00:00:00 2001 From: Jed Brown Date: Sun, 11 Aug 2024 07:54:24 -0600 Subject: [PATCH 5/6] RSMPI_SUM is a global load --- enzyme/Enzyme/CallDerivatives.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index b495f53c8de..813bfed0622 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -1404,6 +1404,11 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called, } } } + if (auto LI = dyn_cast<LoadInst>(orig_op)) { + if (auto GV = dyn_cast<GlobalVariable>(LI->getPointerOperand())) + if (GV->getName() == "RSMPI_SUM") + isSum = true; + } if (!isSum) { std::string s; llvm::raw_string_ostream ss(s); From 49d60149dd771f0df76e93842cdb8a568715e443 Mon Sep 17 00:00:00 2001 From: Jed Brown Date: Sun, 11 Aug 2024 23:56:07 -0600 Subject: [PATCH 6/6] RSMPI_COMM_WORLD and RSMPI_COMM_SELF are inactive --- enzyme/Enzyme/ActivityAnalysis.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index d3337a04694..7da4d01e303 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -114,6 +114,8 @@ static const StringSet<> InactiveGlobals = { "RSMPI_DOUBLE", "RSMPI_FLOAT", "RSMPI_SUM", + "RSMPI_COMM_WORLD", + "RSMPI_COMM_SELF", "ompi_mpi_comm_world", "__cxa_thread_atexit_impl", "stderr",