[MLIR][OpenMP] Add LLVM translation support for OpenMP UserDefinedMappers #124746

Merged (8 commits) on Feb 18, 2025

@TIFitis TIFitis commented Jan 28, 2025

This patch adds OpenMPToLLVMIRTranslation support for the OpenMP Declare Mapper directive.
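
For readers unfamiliar with the directive, here is a minimal C++ sketch of what a declare mapper looks like at the source level. The `Buffer` type and the `sumOnDevice` function are hypothetical and not part of this patch (the offload test it adds is written in Fortran). When OpenMP is not enabled the pragmas are ignored and the loop simply runs on the host.

```cpp
#include <cstddef>

// Hypothetical aggregate with a pointer member; a plain map(buf) would copy
// only the struct itself, not the array that `data` points to.
struct Buffer {
  std::size_t len;
  double *data;
};

// Default mapper for Buffer: map the struct and the pointed-to array section.
#pragma omp declare mapper(Buffer b) map(b, b.data[0 : b.len])

// Sums the buffer inside a target region; the mapper above is applied
// implicitly because `buf` has type Buffer.
double sumOnDevice(Buffer &buf) {
  double sum = 0.0;
#pragma omp target map(tofrom : buf) map(tofrom : sum)
  for (std::size_t i = 0; i < buf.len; ++i)
    sum += buf.data[i];
  return sum;
}
```

The mapper is what lets a single `map(tofrom : buf)` clause carry the array along with the struct, which is the behaviour the translation in this patch has to lower.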

Since both MLIR and Clang now support custom mappers, I've also made the corresponding `CustomMapperCB` function params required rather than optional.
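
As a rough illustration of what a required callback means for callers, the sketch below models it with standard-library types only. `Value`, `CustomMapperCallback` and `collectMappers` are hypothetical stand-ins, not the OpenMPIRBuilder API, and treating a null result as "no custom mapper" is an assumption of this sketch.

```cpp
#include <functional>
#include <vector>

// Hypothetical stand-in for llvm::Value; only pointer identity matters here.
struct Value {};

using CustomMapperCallback = std::function<Value *(unsigned)>;

// Simplified model of an emit function after this change: the mapper callback
// is a required parameter placed before the defaulted ones.
std::vector<Value *> collectMappers(unsigned NumItems,
                                    const CustomMapperCallback &CustomMapperCB,
                                    bool IsNonContiguous = false) {
  (void)IsNonContiguous; // unused in this sketch
  std::vector<Value *> Mappers;
  Mappers.reserve(NumItems);
  for (unsigned I = 0; I < NumItems; ++I)
    Mappers.push_back(CustomMapperCB(I)); // in this sketch, nullptr = no mapper
  return Mappers;
}
```

A caller with no custom mappers now has to say so explicitly, for example by passing a lambda that returns `nullptr` for every index, much as the updated unit tests in the diff do.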

Depends on #121005

@llvmbot added the clang, clang:codegen, mlir:llvm, mlir, mlir:openmp, flang:openmp, clang:openmp, and offload labels Jan 28, 2025

llvmbot commented Jan 28, 2025

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-offload
@llvm/pr-subscribers-clang
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-openmp

Author: Akash Banerjee (TIFitis)

Changes

This patch adds OpenMPToLLVMIRTranslation support for the OpenMP Declare Mapper directive.

Since both MLIR and Clang now support custom mappers, I've also made the relevant params required instead of optional.

Depends on #121005


Patch is 55.37 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/124746.diff

7 Files Affected:

  • (modified) clang/lib/CodeGen/CGOpenMPRuntime.cpp (+12-8)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+28-21)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+41-37)
  • (modified) llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp (+36-20)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+163-35)
  • (modified) mlir/test/Target/LLVMIR/omptarget-llvm.mlir (+117)
  • (added) offload/test/offloading/fortran/target-custom-mapper.f90 (+46)
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index 30c3834de139c3..0a13581dcb1700 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -32,10 +32,12 @@
 #include "llvm/Bitcode/BitcodeReader.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
 #include "llvm/IR/GlobalValue.h"
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Value.h"
 #include "llvm/Support/AtomicOrdering.h"
+#include "llvm/Support/Error.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cassert>
 #include <cstdint>
@@ -8888,8 +8890,8 @@ static void emitOffloadingArraysAndArgs(
     return MFunc;
   };
   OMPBuilder.emitOffloadingArraysAndArgs(
-      AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, IsNonContiguous,
-      ForEndCall, DeviceAddrCB, CustomMapperCB);
+      AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, CustomMapperCB,
+      IsNonContiguous, ForEndCall, DeviceAddrCB);
 }
 
 /// Check for inner distribute directive.
@@ -9098,9 +9100,10 @@ void CGOpenMPRuntime::emitUserDefinedMapper(const OMPDeclareMapperDecl *D,
   CGM.getCXXABI().getMangleContext().mangleCanonicalTypeName(Ty, Out);
   std::string Name = getName({"omp_mapper", TyStr, D->getName()});
 
-  auto *NewFn = OMPBuilder.emitUserDefinedMapper(PrivatizeAndGenMapInfoCB,
-                                                 ElemTy, Name, CustomMapperCB);
-  UDMMap.try_emplace(D, NewFn);
+  llvm::Expected<llvm::Function *> NewFn = OMPBuilder.emitUserDefinedMapper(
+      PrivatizeAndGenMapInfoCB, ElemTy, Name, CustomMapperCB);
+  assert(NewFn && "Unexpected error in emitUserDefinedMapper");
+  UDMMap.try_emplace(D, *NewFn);
   if (CGF)
     FunctionUDMMap[CGF->CurFn].push_back(D);
 }
@@ -10092,9 +10095,10 @@ void CGOpenMPRuntime::emitTargetDataCalls(
                           CGF.Builder.GetInsertPoint());
   llvm::OpenMPIRBuilder::LocationDescription OmpLoc(CodeGenIP);
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
-      OMPBuilder.createTargetData(
-          OmpLoc, AllocaIP, CodeGenIP, DeviceID, IfCondVal, Info, GenMapInfoCB,
-          /*MapperFunc=*/nullptr, BodyCB, DeviceAddrCB, CustomMapperCB, RTLoc);
+      OMPBuilder.createTargetData(OmpLoc, AllocaIP, CodeGenIP, DeviceID,
+                                  IfCondVal, Info, GenMapInfoCB, CustomMapperCB,
+                                  /*MapperFunc=*/nullptr, BodyCB, DeviceAddrCB,
+                                  RTLoc);
   assert(AfterIP && "unexpected error creating target data");
   CGF.Builder.restoreIP(*AfterIP);
 }
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 4ce47b1c05d9b0..4e80bff6db4553 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -22,6 +22,7 @@
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/Allocator.h"
+#include "llvm/Support/Error.h"
 #include "llvm/TargetParser/Triple.h"
 #include <forward_list>
 #include <map>
@@ -2355,6 +2356,7 @@ class OpenMPIRBuilder {
                                    CurInfo.NonContigInfo.Strides.end());
     }
   };
+  using MapInfosOrErrorTy = Expected<MapInfosTy &>;
 
   /// Callback function type for functions emitting the host fallback code that
   /// is executed when the kernel launch fails. It takes an insertion point as
@@ -2431,9 +2433,9 @@ class OpenMPIRBuilder {
   /// including base pointers, pointers, sizes, map types, user-defined mappers.
   void emitOffloadingArrays(
       InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
-      TargetDataInfo &Info, bool IsNonContiguous = false,
-      function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
-      function_ref<Value *(unsigned int)> CustomMapperCB = nullptr);
+      TargetDataInfo &Info, function_ref<Value *(unsigned int)> CustomMapperCB,
+      bool IsNonContiguous = false,
+      function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr);
 
   /// Allocates memory for and populates the arrays required for offloading
   /// (offload_{baseptrs|ptrs|mappers|sizes|maptypes|mapnames}). Then, it
@@ -2444,9 +2446,9 @@ class OpenMPIRBuilder {
   void emitOffloadingArraysAndArgs(
       InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
       TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
+      function_ref<Value *(unsigned int)> CustomMapperCB,
       bool IsNonContiguous = false, bool ForEndCall = false,
-      function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
-      function_ref<Value *(unsigned int)> CustomMapperCB = nullptr);
+      function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr);
 
   /// Creates offloading entry for the provided entry ID \a ID, address \a
   /// Addr, size \a Size, and flags \a Flags.
@@ -2911,12 +2913,12 @@ class OpenMPIRBuilder {
   /// \param FuncName Optional param to specify mapper function name.
   /// \param CustomMapperCB Optional callback to generate code related to
   /// custom mappers.
-  Function *emitUserDefinedMapper(
-      function_ref<MapInfosTy &(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
-                                llvm::Value *BeginArg)>
+  Expected<Function *> emitUserDefinedMapper(
+      function_ref<MapInfosOrErrorTy(
+          InsertPointTy CodeGenIP, llvm::Value *PtrPHI, llvm::Value *BeginArg)>
           PrivAndGenMapInfoCB,
       llvm::Type *ElemTy, StringRef FuncName,
-      function_ref<bool(unsigned int, Function **)> CustomMapperCB = nullptr);
+      function_ref<bool(unsigned int, Function **)> CustomMapperCB);
 
   /// Generator for '#omp target data'
   ///
@@ -2930,21 +2932,21 @@ class OpenMPIRBuilder {
   /// \param IfCond Value which corresponds to the if clause condition.
   /// \param Info Stores all information realted to the Target Data directive.
   /// \param GenMapInfoCB Callback that populates the MapInfos and returns.
+  /// \param CustomMapperCB Callback to generate code related to
+  /// custom mappers.
   /// \param BodyGenCB Optional Callback to generate the region code.
   /// \param DeviceAddrCB Optional callback to generate code related to
   /// use_device_ptr and use_device_addr.
-  /// \param CustomMapperCB Optional callback to generate code related to
-  /// custom mappers.
   InsertPointOrErrorTy createTargetData(
       const LocationDescription &Loc, InsertPointTy AllocaIP,
       InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
       TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
+      function_ref<Value *(unsigned int)> CustomMapperCB,
       omp::RuntimeFunction *MapperFunc = nullptr,
       function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
                                         BodyGenTy BodyGenType)>
           BodyGenCB = nullptr,
       function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
-      function_ref<Value *(unsigned int)> CustomMapperCB = nullptr,
       Value *SrcLocInfo = nullptr);
 
   using TargetBodyGenCallbackTy = function_ref<InsertPointOrErrorTy(
@@ -2960,6 +2962,7 @@ class OpenMPIRBuilder {
   /// \param IsOffloadEntry whether it is an offload entry.
   /// \param CodeGenIP The insertion point where the call to the outlined
   /// function should be emitted.
+  /// \param Info Stores all information realted to the Target directive.
   /// \param EntryInfo The entry information about the function.
   /// \param NumTeams Number of teams specified in the num_teams clause.
   /// \param NumThreads Number of teams specified in the thread_limit clause.
@@ -2968,18 +2971,22 @@ class OpenMPIRBuilder {
   /// \param BodyGenCB Callback that will generate the region code.
   /// \param ArgAccessorFuncCB Callback that will generate accessors
   /// instructions for passed in target arguments where neccessary
+  /// \param CustomMapperCB Callback to generate code related to
+  /// custom mappers.
   /// \param Dependencies A vector of DependData objects that carry
   // dependency information as passed in the depend clause
   // \param HasNowait Whether the target construct has a `nowait` clause or not.
-  InsertPointOrErrorTy createTarget(
-      const LocationDescription &Loc, bool IsOffloadEntry,
-      OpenMPIRBuilder::InsertPointTy AllocaIP,
-      OpenMPIRBuilder::InsertPointTy CodeGenIP,
-      TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
-      ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
-      GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB,
-      TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
-      SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
+  InsertPointOrErrorTy
+  createTarget(const LocationDescription &Loc, bool IsOffloadEntry,
+               OpenMPIRBuilder::InsertPointTy AllocaIP,
+               OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetDataInfo &Info,
+               TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
+               ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
+               GenMapInfoCallbackTy GenMapInfoCB,
+               TargetBodyGenCallbackTy BodyGenCB,
+               TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
+               function_ref<Value *(unsigned int)> CustomMapperCB,
+               SmallVector<DependData> Dependencies, bool HasNowait);
 
   /// Returns __kmpc_for_static_init_* runtime function for the specified
   /// size \a IVSize and sign \a IVSigned. Will create a distribute call
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 0d8dbbe3a8a718..be53dbbf8addf3 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -47,6 +47,7 @@
 #include "llvm/IR/Value.h"
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Error.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FileSystem.h"
 #include "llvm/Target/TargetMachine.h"
@@ -6480,12 +6481,12 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData(
     const LocationDescription &Loc, InsertPointTy AllocaIP,
     InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
     TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
+    function_ref<Value *(unsigned int)> CustomMapperCB,
     omp::RuntimeFunction *MapperFunc,
     function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
                                       BodyGenTy BodyGenType)>
         BodyGenCB,
-    function_ref<void(unsigned int, Value *)> DeviceAddrCB,
-    function_ref<Value *(unsigned int)> CustomMapperCB, Value *SrcLocInfo) {
+    function_ref<void(unsigned int, Value *)> DeviceAddrCB, Value *SrcLocInfo) {
   if (!updateToLocation(Loc))
     return InsertPointTy();
 
@@ -6511,8 +6512,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData(
                           InsertPointTy CodeGenIP) -> Error {
     MapInfo = &GenMapInfoCB(Builder.saveIP());
     emitOffloadingArrays(AllocaIP, Builder.saveIP(), *MapInfo, Info,
-                         /*IsNonContiguous=*/true, DeviceAddrCB,
-                         CustomMapperCB);
+                         CustomMapperCB,
+                         /*IsNonContiguous=*/true, DeviceAddrCB);
 
     TargetDataRTArgs RTArgs;
     emitOffloadingArraysArgument(Builder, RTArgs, Info);
@@ -7304,22 +7305,24 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
 
 void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
     InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
-    TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo, bool IsNonContiguous,
-    bool ForEndCall, function_ref<void(unsigned int, Value *)> DeviceAddrCB,
-    function_ref<Value *(unsigned int)> CustomMapperCB) {
-  emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info, IsNonContiguous,
-                       DeviceAddrCB, CustomMapperCB);
+    TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
+    function_ref<Value *(unsigned int)> CustomMapperCB, bool IsNonContiguous,
+    bool ForEndCall, function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
+  emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info, CustomMapperCB,
+                       IsNonContiguous, DeviceAddrCB);
   emitOffloadingArraysArgument(Builder, RTArgs, Info, ForEndCall);
 }
 
 static void
 emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
-               OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
+               OpenMPIRBuilder::InsertPointTy AllocaIP,
+               OpenMPIRBuilder::TargetDataInfo &Info, Function *OutlinedFn,
                Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
                ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
                OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
-               SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
-               bool HasNoWait = false) {
+               function_ref<Value *(unsigned int)> CustomMapperCB,
+               SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies,
+               bool HasNoWait) {
   // Generate a function call to the host fallback implementation of the target
   // region. This is called by the host when no offload entry was generated for
   // the target region and when the offloading call fails at runtime.
@@ -7384,14 +7387,10 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
     return;
   }
 
-  OpenMPIRBuilder::TargetDataInfo Info(
-      /*RequiresDevicePointerInfo=*/false,
-      /*SeparateBeginEndCalls=*/true);
-
   OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
   OpenMPIRBuilder::TargetDataRTArgs RTArgs;
   OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
-                                         RTArgs, MapInfo,
+                                         RTArgs, MapInfo, CustomMapperCB,
                                          /*IsNonContiguous=*/true,
                                          /*ForEndCall=*/false);
 
@@ -7439,11 +7438,13 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
 
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
     const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
-    InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
-    ArrayRef<int32_t> NumTeams, ArrayRef<int32_t> NumThreads,
-    SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
+    InsertPointTy CodeGenIP, TargetDataInfo &Info,
+    TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
+    ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
+    GenMapInfoCallbackTy GenMapInfoCB,
     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
+    function_ref<Value *(unsigned int)> CustomMapperCB,
     SmallVector<DependData> Dependencies, bool HasNowait) {
 
   if (!updateToLocation(Loc))
@@ -7458,15 +7459,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
   // and ArgAccessorFuncCB
   if (Error Err = emitTargetOutlinedFunction(
           *this, Builder, IsOffloadEntry, EntryInfo, OutlinedFn, OutlinedFnID,
-          Args, CBFunc, ArgAccessorFuncCB))
+          Inputs, CBFunc, ArgAccessorFuncCB))
     return Err;
 
   // If we are not on the target device, then we need to generate code
   // to make a remote call (offload) to the previously outlined function
   // that represents the target region. Do that now.
   if (!Config.isTargetDevice())
-    emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
-                   NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait);
+    emitTargetCall(*this, Builder, AllocaIP, Info, OutlinedFn, OutlinedFnID,
+                   NumTeams, NumThreads, Inputs, GenMapInfoCB, CustomMapperCB,
+                   Dependencies, HasNowait);
   return Builder.saveIP();
 }
 
@@ -7791,9 +7793,9 @@ void OpenMPIRBuilder::emitUDMapperArrayInitOrDel(
       OffloadingArgs);
 }
 
-Function *OpenMPIRBuilder::emitUserDefinedMapper(
-    function_ref<MapInfosTy &(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
-                              llvm::Value *BeginArg)>
+Expected<Function *> OpenMPIRBuilder::emitUserDefinedMapper(
+    function_ref<MapInfosOrErrorTy(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
+                                   llvm::Value *BeginArg)>
         GenMapInfoCB,
     Type *ElemTy, StringRef FuncName,
     function_ref<bool(unsigned int, Function **)> CustomMapperCB) {
@@ -7867,7 +7869,9 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper(
   PtrPHI->addIncoming(PtrBegin, HeadBB);
 
   // Get map clause information. Fill up the arrays with all mapped variables.
-  MapInfosTy &Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn);
+  MapInfosOrErrorTy Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn);
+  if (!Info)
+    return Info.takeError();
 
   // Call the runtime API __tgt_mapper_num_components to get the number of
   // pre-existing components.
@@ -7879,20 +7883,20 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper(
       Builder.CreateShl(PreviousSize, Builder.getInt64(getFlagMemberOffset()));
 
   // Fill up the runtime mapper handle for all components.
-  for (unsigned I = 0; I < Info.BasePointers.size(); ++I) {
+  for (unsigned I = 0; I < Info->BasePointers.size(); ++I) {
     Value *CurBaseArg =
-        Builder.CreateBitCast(Info.BasePointers[I], Builder.getPtrTy());
+        Builder.CreateBitCast(Info->BasePointers[I], Builder.getPtrTy());
     Value *CurBeginArg =
-        Builder.CreateBitCast(Info.Pointers[I], Builder.getPtrTy());
-    Value *CurSizeArg = Info.Sizes[I];
-    Value *CurNameArg = Info.Names.size()
-                            ? Info.Names[I]
+        Builder.CreateBitCast(Info->Pointers[I], Builder.getPtrTy());
+    Value *CurSizeArg = Info->Sizes[I];
+    Value *CurNameArg = Info->Names.size()
+                            ? Info->Names[I]
                             : Constant::getNullValue(Builder.getPtrTy());
 
     // Extract the MEMBER_OF field from the map type.
     Value *OriMapType = Builder.getInt64(
         static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
-            Info.Types[I]));
+            Info->Types[I]));
     Value *MemberMapType =
         Builder.CreateNUWAdd(OriMapType, ShiftedPreviousSize);
 
@@ -8013,9 +8017,9 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper(
 
 void OpenMPIRBuilder::emitOffloadingArrays(
     InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
-    TargetDataInfo &Info, bool IsNonContiguous,
-    function_ref<void(unsigned int, Value *)> DeviceAddrCB,
-    function_ref<Value *(unsigned int)> CustomMapperCB) {
+    TargetDataInfo &Info, function_ref<Value *(unsigned int)> CustomMapperCB,
+    bool IsNonContiguous,
+    function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
 
   // Reset the array information.
   Info.clearArrayInfo();
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index d7ac1082491180..a33e1533dede43 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5876,6 +5876,7 @@ TEST_F(OpenMPIRBuilderTest, TargetEnterData) {
     return CombinedInfo;
   };
 
+  auto CustomMapperCB = [&](unsigned int I) { return nullptr; };
   llvm::OpenMPIRBuilder::TargetDataInfo Info(
       /*RequiresDevicePointerInfo=*/false,
       /*SeparateBeginEndCalls=*/true);
@@ -5885,7 +5886,7 @@ TEST_F(OpenMPIRBuilderTest, TargetEnterData) {
   llvm::omp::RuntimeFunction RTLFunc = OMPRTL___tgt_target_data_begin_mapper;
   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTargetData(
       Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID),
-      /* IfCond= */ nullptr, Info, GenMapInfoCB, &RTLFunc);
+      /* IfCond= */ nullptr, Info, GenMapInfoCB, CustomMapperCB, &RTLFunc);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
 
@@ -5937,6 +5938,7 @@ TEST_F(OpenMPIRBuilderTest, TargetExitData) {
     return CombinedInfo;
   };
 
+  aut...
[truncated]
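
In the diff above, `emitUserDefinedMapper` now returns `Expected<Function *>` and its map-info callback returns `MapInfosOrErrorTy`, so a failure inside the callback propagates to the caller. The following standalone sketch mirrors that control flow using `std::variant` as a stand-in for `llvm::Expected`. All names here (`MapInfos`, `MapInfosOrError`, `CountOrError`, `emitMapperSketch`) are hypothetical, and the component count it returns merely stands in for the generated function.

```cpp
#include <cstddef>
#include <functional>
#include <string>
#include <variant>
#include <vector>

// Hypothetical, heavily reduced model of the map information.
struct MapInfos {
  std::vector<std::string> Names;
};

// Either a pointer to the map information or an error message.
using MapInfosOrError = std::variant<MapInfos *, std::string>;
// Either the number of mapped components or an error message.
using CountOrError = std::variant<std::size_t, std::string>;

// Calls the callback and forwards its error, if any, instead of assuming
// that map-info generation always succeeds.
CountOrError
emitMapperSketch(const std::function<MapInfosOrError()> &GenMapInfoCB) {
  MapInfosOrError Info = GenMapInfoCB();
  if (const std::string *Err = std::get_if<std::string>(&Info))
    return *Err; // plays the role of `return Info.takeError();`
  MapInfos *Infos = std::get<MapInfos *>(Info);
  return Infos->Names.size();
}
```

Unlike this sketch, a real `llvm::Expected` must have its error state checked or consumed by the caller; the Clang call site in the diff checks it with an `assert`.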


llvmbot commented Jan 28, 2025

@llvm/pr-subscribers-clang-codegen

     InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
     TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
+    function_ref<Value *(unsigned int)> CustomMapperCB,
     omp::RuntimeFunction *MapperFunc,
     function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
                                       BodyGenTy BodyGenType)>
         BodyGenCB,
-    function_ref<void(unsigned int, Value *)> DeviceAddrCB,
-    function_ref<Value *(unsigned int)> CustomMapperCB, Value *SrcLocInfo) {
+    function_ref<void(unsigned int, Value *)> DeviceAddrCB, Value *SrcLocInfo) {
   if (!updateToLocation(Loc))
     return InsertPointTy();
 
@@ -6511,8 +6512,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData(
                           InsertPointTy CodeGenIP) -> Error {
     MapInfo = &GenMapInfoCB(Builder.saveIP());
     emitOffloadingArrays(AllocaIP, Builder.saveIP(), *MapInfo, Info,
-                         /*IsNonContiguous=*/true, DeviceAddrCB,
-                         CustomMapperCB);
+                         CustomMapperCB,
+                         /*IsNonContiguous=*/true, DeviceAddrCB);
 
     TargetDataRTArgs RTArgs;
     emitOffloadingArraysArgument(Builder, RTArgs, Info);
@@ -7304,22 +7305,24 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
 
 void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
     InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
-    TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo, bool IsNonContiguous,
-    bool ForEndCall, function_ref<void(unsigned int, Value *)> DeviceAddrCB,
-    function_ref<Value *(unsigned int)> CustomMapperCB) {
-  emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info, IsNonContiguous,
-                       DeviceAddrCB, CustomMapperCB);
+    TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
+    function_ref<Value *(unsigned int)> CustomMapperCB, bool IsNonContiguous,
+    bool ForEndCall, function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
+  emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info, CustomMapperCB,
+                       IsNonContiguous, DeviceAddrCB);
   emitOffloadingArraysArgument(Builder, RTArgs, Info, ForEndCall);
 }
 
 static void
 emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
-               OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
+               OpenMPIRBuilder::InsertPointTy AllocaIP,
+               OpenMPIRBuilder::TargetDataInfo &Info, Function *OutlinedFn,
                Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
                ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
                OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
-               SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
-               bool HasNoWait = false) {
+               function_ref<Value *(unsigned int)> CustomMapperCB,
+               SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies,
+               bool HasNoWait) {
   // Generate a function call to the host fallback implementation of the target
   // region. This is called by the host when no offload entry was generated for
   // the target region and when the offloading call fails at runtime.
@@ -7384,14 +7387,10 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
     return;
   }
 
-  OpenMPIRBuilder::TargetDataInfo Info(
-      /*RequiresDevicePointerInfo=*/false,
-      /*SeparateBeginEndCalls=*/true);
-
   OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
   OpenMPIRBuilder::TargetDataRTArgs RTArgs;
   OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
-                                         RTArgs, MapInfo,
+                                         RTArgs, MapInfo, CustomMapperCB,
                                          /*IsNonContiguous=*/true,
                                          /*ForEndCall=*/false);
 
@@ -7439,11 +7438,13 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
 
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
     const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
-    InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
-    ArrayRef<int32_t> NumTeams, ArrayRef<int32_t> NumThreads,
-    SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
+    InsertPointTy CodeGenIP, TargetDataInfo &Info,
+    TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
+    ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
+    GenMapInfoCallbackTy GenMapInfoCB,
     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
+    function_ref<Value *(unsigned int)> CustomMapperCB,
     SmallVector<DependData> Dependencies, bool HasNowait) {
 
   if (!updateToLocation(Loc))
@@ -7458,15 +7459,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
   // and ArgAccessorFuncCB
   if (Error Err = emitTargetOutlinedFunction(
           *this, Builder, IsOffloadEntry, EntryInfo, OutlinedFn, OutlinedFnID,
-          Args, CBFunc, ArgAccessorFuncCB))
+          Inputs, CBFunc, ArgAccessorFuncCB))
     return Err;
 
   // If we are not on the target device, then we need to generate code
   // to make a remote call (offload) to the previously outlined function
   // that represents the target region. Do that now.
   if (!Config.isTargetDevice())
-    emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
-                   NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait);
+    emitTargetCall(*this, Builder, AllocaIP, Info, OutlinedFn, OutlinedFnID,
+                   NumTeams, NumThreads, Inputs, GenMapInfoCB, CustomMapperCB,
+                   Dependencies, HasNowait);
   return Builder.saveIP();
 }
 
@@ -7791,9 +7793,9 @@ void OpenMPIRBuilder::emitUDMapperArrayInitOrDel(
       OffloadingArgs);
 }
 
-Function *OpenMPIRBuilder::emitUserDefinedMapper(
-    function_ref<MapInfosTy &(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
-                              llvm::Value *BeginArg)>
+Expected<Function *> OpenMPIRBuilder::emitUserDefinedMapper(
+    function_ref<MapInfosOrErrorTy(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
+                                   llvm::Value *BeginArg)>
         GenMapInfoCB,
     Type *ElemTy, StringRef FuncName,
     function_ref<bool(unsigned int, Function **)> CustomMapperCB) {
@@ -7867,7 +7869,9 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper(
   PtrPHI->addIncoming(PtrBegin, HeadBB);
 
   // Get map clause information. Fill up the arrays with all mapped variables.
-  MapInfosTy &Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn);
+  MapInfosOrErrorTy Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn);
+  if (!Info)
+    return Info.takeError();
 
   // Call the runtime API __tgt_mapper_num_components to get the number of
   // pre-existing components.
@@ -7879,20 +7883,20 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper(
       Builder.CreateShl(PreviousSize, Builder.getInt64(getFlagMemberOffset()));
 
   // Fill up the runtime mapper handle for all components.
-  for (unsigned I = 0; I < Info.BasePointers.size(); ++I) {
+  for (unsigned I = 0; I < Info->BasePointers.size(); ++I) {
     Value *CurBaseArg =
-        Builder.CreateBitCast(Info.BasePointers[I], Builder.getPtrTy());
+        Builder.CreateBitCast(Info->BasePointers[I], Builder.getPtrTy());
     Value *CurBeginArg =
-        Builder.CreateBitCast(Info.Pointers[I], Builder.getPtrTy());
-    Value *CurSizeArg = Info.Sizes[I];
-    Value *CurNameArg = Info.Names.size()
-                            ? Info.Names[I]
+        Builder.CreateBitCast(Info->Pointers[I], Builder.getPtrTy());
+    Value *CurSizeArg = Info->Sizes[I];
+    Value *CurNameArg = Info->Names.size()
+                            ? Info->Names[I]
                             : Constant::getNullValue(Builder.getPtrTy());
 
     // Extract the MEMBER_OF field from the map type.
     Value *OriMapType = Builder.getInt64(
         static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
-            Info.Types[I]));
+            Info->Types[I]));
     Value *MemberMapType =
         Builder.CreateNUWAdd(OriMapType, ShiftedPreviousSize);
 
@@ -8013,9 +8017,9 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper(
 
 void OpenMPIRBuilder::emitOffloadingArrays(
     InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
-    TargetDataInfo &Info, bool IsNonContiguous,
-    function_ref<void(unsigned int, Value *)> DeviceAddrCB,
-    function_ref<Value *(unsigned int)> CustomMapperCB) {
+    TargetDataInfo &Info, function_ref<Value *(unsigned int)> CustomMapperCB,
+    bool IsNonContiguous,
+    function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
 
   // Reset the array information.
   Info.clearArrayInfo();
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index d7ac1082491180..a33e1533dede43 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5876,6 +5876,7 @@ TEST_F(OpenMPIRBuilderTest, TargetEnterData) {
     return CombinedInfo;
   };
 
+  auto CustomMapperCB = [&](unsigned int I) { return nullptr; };
   llvm::OpenMPIRBuilder::TargetDataInfo Info(
       /*RequiresDevicePointerInfo=*/false,
       /*SeparateBeginEndCalls=*/true);
@@ -5885,7 +5886,7 @@ TEST_F(OpenMPIRBuilderTest, TargetEnterData) {
   llvm::omp::RuntimeFunction RTLFunc = OMPRTL___tgt_target_data_begin_mapper;
   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTargetData(
       Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID),
-      /* IfCond= */ nullptr, Info, GenMapInfoCB, &RTLFunc);
+      /* IfCond= */ nullptr, Info, GenMapInfoCB, CustomMapperCB, &RTLFunc);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
 
@@ -5937,6 +5938,7 @@ TEST_F(OpenMPIRBuilderTest, TargetExitData) {
     return CombinedInfo;
   };
 
+  aut...
[truncated]

@llvmbot
Member

llvmbot commented Jan 28, 2025

@llvm/pr-subscribers-mlir-llvm

Comment on lines +2713 to +2814
struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
SmallVector<Operation *, 4> Mappers;
Member

@ergawy ergawy Jan 30, 2025

Can we add Mappers to MapInfoData and only use MapInfoData throughout the whole file?

Member Author

Technically it might be possible, but it's probably not a good idea. MapInfoData is meant to be used when coalescing the mapping data from the various map and bound ops.

llvm::OpenMPIRBuilder::MapInfosTy is used across MLIR, Clang and OMPIRBuilder for target codegen.

Member

@ergawy ergawy Jan 31, 2025

I understand we cannot merge all of this into llvm::OpenMPIRBuilder::MapInfosTy. I am talking about the new struct MapInfosTy that sits in between llvm::OpenMPIRBuilder::MapInfosTy and MapInfoData. I think we can make this new struct and MapInfoData one struct and use it in this file. I gave it a try and it should be fine.

Member Author

Yes, we can do that, but I don't think we should. Semantically they are used in different stages for different purposes. We don't stand to gain much from merging these two types, and using the same type might only lead to confusion in the future.

If anything, using the two different types should have a slightly smaller memory footprint, as we don't need any of the MapInfoData-specific members during CodeGen in OMPIRBuilder.

Member

We gain better readability in this already complex file. I am not fully on board with the memory footprint argument, since I don't think a module will have that many instances of that struct that need to be created. Not a blocker from my side, though, since other reviewers are fine with it.

llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::DenseMap<const Operation *, llvm::Function *> userDefMapperMap;
auto iter = userDefMapperMap.find(declMapperOp);
Member

userDefMapperMap is empty at this point. Is it supposed to be managed globally outside emitUserDefinedMapper (for the whole mlir::ModuleOp, for example)?

Member Author

Yes, the mapping is intended to be at the Module scope, across multiple emitUserDefinedMapper calls within the same Module.

Member

@ergawy ergawy Jan 31, 2025

Yes, but you get a new instance of userDefMapperMap with every invocation of getOrCreateUserDefinedMapperFunc and you directly search it for the user-defined mapper of declMapperOp. So nothing is actually cached. Did I miss something?

Member Author

Sorry, I intended to make userDefMapperMap a static variable. Fixed now.
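
The caching problem raised in this thread can be shown with a small standalone sketch. It does not use the MLIR or LLVM APIs: `DeclMapper`, `emitMapper`, `emitCount`, `lookupWithLocalCache`, and `lookupWithStaticCache` are hypothetical names, and an `int` stands in for the emitted `llvm::Function *`. A function-local map is rebuilt on every call, so every lookup misses; a `static` map persists across calls.

```cpp
#include <cassert>
#include <unordered_map>

// Hypothetical stand-in for a declare mapper operation.
struct DeclMapper {
  int id;
};

// Counts how many times a mapper body is "emitted".
static int emitCount = 0;

// Stand-in for the expensive emission step; returns a fake function handle.
static int emitMapper(const DeclMapper &op) {
  ++emitCount;
  return op.id * 100;
}

// Problematic variant: the map is a fresh local on every call, so find()
// can never hit and the mapper is emitted again each time.
int lookupWithLocalCache(const DeclMapper &op) {
  std::unordered_map<const DeclMapper *, int> cache;
  auto it = cache.find(&op);
  if (it != cache.end())
    return it->second;
  int fn = emitMapper(op);
  cache.emplace(&op, fn);
  return fn;
}

// Variant with a static map: the map outlives the call, so a second lookup
// for the same operation returns the cached handle.
int lookupWithStaticCache(const DeclMapper &op) {
  static std::unordered_map<const DeclMapper *, int> cache;
  auto it = cache.find(&op);
  if (it != cache.end())
    return it->second;
  int fn = emitMapper(op);
  cache.emplace(&op, fn);
  return fn;
}
```

Calling `lookupWithLocalCache` twice on the same object increments `emitCount` twice, while `lookupWithStaticCache` increments it once. Note that a function-local `static` lives for the whole process and is keyed by pointer here, so a map owned by per-module state would be another option; that is a general observation about the sketch, not a description of what the patch does.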

auto customMapperCB = [&](unsigned i, llvm::Function **mapperFunc) {
if (combinedInfo.Mappers[i]) {
// Call the corresponding mapper function.
llvm::Expected<llvm::Function *> newFn = getOrCreateUserDefinedMapperFunc(
Member

Can you help me understand how emitUserDefinedMapper and getOrCreateUserDefinedMapperFunc work together? It seems they are mutually recursive, but the recursion is not happening because combinedInfo.Mappers[i] is false. Can you explain the sequence of events here a bit?

Member Author

getOrCreateUserDefinedMapperFunc is the interface function for fetching/creating mapperFuncs. It does a lookup for the mapperFunc and emits one if not already present through the emitUserDefinedMapper function.

A declare mapper may refer to another mapper in its mapping scheme; as such, emitUserDefinedMapper may again make calls to getOrCreateUserDefinedMapperFunc. But I believe code flow disallows cycles, so we shouldn't go into any cyclic recursions.
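
As a rough illustration of the lookup-then-emit flow described above, here is a standalone sketch that does not use the real MLIR or OMPIRBuilder types. `MapperDecl`, `MapperEmitter`, `getOrCreate`, and `emit` are hypothetical names chosen for this example; an `int` stands in for the emitted function. `getOrCreate` consults the cache and calls `emit` on a miss, and `emit` calls back into `getOrCreate` for every mapper that the current one refers to.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical description of a declare mapper: its name and the names of
// the mappers that its map clauses refer to.
struct MapperDecl {
  std::string name;
  std::vector<std::string> nested;
};

struct MapperEmitter {
  std::map<std::string, MapperDecl> decls; // all known mappers
  std::map<std::string, int> emitted;      // cache: name -> fake handle
  std::vector<std::string> emissionOrder;  // order in which bodies finish

  // Interface function: look up the mapper and emit it only on a miss.
  int getOrCreate(const std::string &name) {
    auto it = emitted.find(name);
    if (it != emitted.end())
      return it->second;
    int handle = emit(decls.at(name));
    emitted.emplace(name, handle);
    return handle;
  }

  // Emits one mapper; nested mappers are requested through getOrCreate,
  // which is where the mutual recursion comes from.
  int emit(const MapperDecl &decl) {
    for (const std::string &inner : decl.nested)
      getOrCreate(inner);
    emissionOrder.push_back(decl.name);
    return static_cast<int>(emissionOrder.size());
  }
};
```

With `my_mapper2` referring to `my_mapper1`, requesting `my_mapper2` first finishes `my_mapper1` and then `my_mapper2`, and a later request for `my_mapper1` is served from the cache. The sketch has no cycle detection: a cyclic reference would recurse without bound, which is why the argument above depends on cycles being impossible in legal code.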

Member

@ergawy ergawy Jan 31, 2025

A declare mapper may refer to another mapper in its mapping scheme ....

Can you please add a test that activates this recursion to better understand how it works? I believe the 2 tests added in the PR don't achieve that. I am just a bit uncomfortable doing this without being sure how it works.

In general, instead of the recursion, we can try to come up with a worklist algorithm that builds up the list of mapper functions that need to be generated. The MLIR guide advises against using recursion generally: https://mlir.llvm.org/getting_started/DeveloperGuide/.
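
For comparison, a worklist formulation could look roughly like the sketch below. It is standalone and hypothetical: `MapperDeps` and `emissionOrder` are names invented for this example, mappers are identified by strings, and the references between mappers are assumed to be acyclic. An explicit stack replaces the call stack and yields an order in which every mapper comes after the mappers it refers to.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical dependency table: mapper name -> mappers it refers to.
using MapperDeps = std::map<std::string, std::vector<std::string>>;

// Returns the mappers reachable from `root`, ordered so that each mapper
// appears after the mappers it refers to. Assumes acyclic references and
// that every referenced mapper has an entry in `deps`.
std::vector<std::string> emissionOrder(const MapperDeps &deps,
                                       const std::string &root) {
  std::vector<std::string> order;
  std::set<std::string> visited;
  // Each entry: (mapper name, index of the next dependency to visit).
  std::vector<std::pair<std::string, std::size_t>> stack;
  stack.emplace_back(root, 0);
  visited.insert(root);
  while (!stack.empty()) {
    std::pair<std::string, std::size_t> &top = stack.back();
    const std::vector<std::string> &children = deps.at(top.first);
    if (top.second < children.size()) {
      // Advance the cursor before pushing; `top` is not used afterwards
      // because emplace_back may reallocate the stack.
      const std::string &child = children[top.second++];
      if (visited.insert(child).second)
        stack.emplace_back(child, 0);
    } else {
      // All dependencies are done, so this mapper can be emitted.
      order.push_back(top.first);
      stack.pop_back();
    }
  }
  return order;
}
```

The result matches what the recursive approach produces for an acyclic set of mappers; the trade-off is extra bookkeeping in exchange for bounded call depth.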

Member Author

I've updated the offloading test to trigger this recursion.

A worklist for something like this might be overkill. Clang implements declare mapper in a similar recursive manner. Also, cyclic recursions are not possible in legal code.

Here's what the generated LLVM IR looks like for the new test:

define void @_QQmain() {
  %.offload_baseptrs = alloca [2 x ptr], align 8
  %.offload_ptrs = alloca [2 x ptr], align 8
  %.offload_mappers = alloca [2 x ptr], align 8
  %1 = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, align 8
  %2 = alloca i32, i64 1, align 4
  %3 = alloca i32, i64 1, align 4
  %4 = alloca i32, i64 1, align 4
  br label %5

5:                                                ; preds = %9, %0
  %6 = phi i32 [ %18, %9 ], [ 1, %0 ]
  %7 = phi i64 [ %19, %9 ], [ 1024, %0 ]
  %8 = icmp sgt i64 %7, 0
  br i1 %8, label %9, label %20

9:                                                ; preds = %5
  store i32 %6, ptr %4, align 4
  %10 = load i32, ptr %4, align 4
  %11 = sext i32 %10 to i64
  %12 = sub nsw i64 %11, 1
  %13 = mul nsw i64 %12, 1
  %14 = mul nsw i64 %13, 1
  %15 = add nsw i64 %14, 0
  %16 = getelementptr i32, ptr @_QFEobj, i64 %15
  store i32 1, ptr %16, align 4
  %17 = load i32, ptr %4, align 4
  %18 = add nsw i32 %17, 1
  %19 = sub i64 %7, 1
  br label %5

20:                                               ; preds = %5
  store i32 %6, ptr %4, align 4
  store { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] } { ptr @_QFEobj, i64 ptrtoint (ptr getelementptr (i32, ptr null, i32 1) to i64), i32 20240719, i8 1, i8 9, i8 0, i8 0, [1 x [3 x i64]] [[3 x i64] [i64 1, i64 1024, i64 ptrtoint (ptr getelementptr (i32, ptr null, i32 1) to i64)]] }, ptr %1, align 8
  %21 = call i32 @_FortranASumInteger4(ptr %1, ptr @_QQclXf46b0d060c890540d012b521bc3468aa, i32 21, i32 0, ptr null)
  store i32 %21, ptr %2, align 4
  store i32 0, ptr %3, align 4
  %22 = getelementptr inbounds [2 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
  store ptr %3, ptr %22, align 8
  %23 = getelementptr inbounds [2 x ptr], ptr %.offload_ptrs, i32 0, i32 0
  store ptr %3, ptr %23, align 8
  %24 = getelementptr inbounds [2 x ptr], ptr %.offload_mappers, i64 0, i64 0
  store ptr null, ptr %24, align 8
  %25 = getelementptr inbounds [2 x ptr], ptr %.offload_baseptrs, i32 0, i32 1
  store ptr @_QFEobj, ptr %25, align 8
  %26 = getelementptr inbounds [2 x ptr], ptr %.offload_ptrs, i32 0, i32 1
  store ptr @_QFEobj, ptr %26, align 8
  %27 = getelementptr inbounds [2 x ptr], ptr %.offload_mappers, i64 0, i64 1
  store ptr @.omp_mapper._QQFmy_mapper2, ptr %27, align 8
  %28 = getelementptr inbounds [2 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
  %29 = getelementptr inbounds [2 x ptr], ptr %.offload_ptrs, i32 0, i32 0
  call void @__tgt_target_data_begin_mapper(ptr @7, i64 -1, i32 2, ptr %28, ptr %29, ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr %.offload_mappers)
  br label %omp.data.region

omp.data.region3:                                 ; preds = %omp.data.region1
  store i32 %43, ptr %4, align 4
  br label %omp.region.cont

omp.data.region2:                                 ; preds = %omp.data.region1
  store i32 %43, ptr %4, align 4
  %30 = load i32, ptr %3, align 4
  %31 = load i32, ptr %4, align 4
  %32 = sext i32 %31 to i64
  %33 = sub nsw i64 %32, 1
  %34 = mul nsw i64 %33, 1
  %35 = mul nsw i64 %34, 1
  %36 = add nsw i64 %35, 0
  %37 = getelementptr i32, ptr @_QFEobj, i64 %36
  %38 = load i32, ptr %37, align 4
  %39 = add i32 %30, %38
  store i32 %39, ptr %3, align 4
  %40 = load i32, ptr %4, align 4
  %41 = add nsw i32 %40, 1
  %42 = sub i64 %44, 1
  br label %omp.data.region1

omp.data.region1:                                 ; preds = %omp.data.region2, %omp.data.region
  %43 = phi i32 [ %41, %omp.data.region2 ], [ 1, %omp.data.region ]
  %44 = phi i64 [ %42, %omp.data.region2 ], [ 1024, %omp.data.region ]
  %45 = icmp sgt i64 %44, 0
  br i1 %45, label %omp.data.region2, label %omp.data.region3

omp.data.region:                                  ; preds = %20
  br label %omp.data.region1

omp.region.cont:                                  ; preds = %omp.data.region3
  %46 = getelementptr inbounds [2 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
  %47 = getelementptr inbounds [2 x ptr], ptr %.offload_ptrs, i32 0, i32 0
  call void @__tgt_target_data_end_mapper(ptr @7, i64 -1, i32 2, ptr %46, ptr %47, ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr %.offload_mappers)
  %48 = call ptr @_FortranAioBeginExternalListOutput(i32 6, ptr @_QQclXf46b0d060c890540d012b521bc3468aa, i32 32)
  %49 = call i1 @_FortranAioOutputAscii(ptr %48, ptr @_QQclX53756D206F6E20686F73743A20202020, i64 16)
  %50 = load i32, ptr %2, align 4
  %51 = call i1 @_FortranAioOutputInteger32(ptr %48, i32 %50)
  %52 = call i32 @_FortranAioEndIoStatement(ptr %48)
  %53 = call ptr @_FortranAioBeginExternalListOutput(i32 6, ptr @_QQclXf46b0d060c890540d012b521bc3468aa, i32 33)
  %54 = call i1 @_FortranAioOutputAscii(ptr %53, ptr @_QQclX53756D206F6E206465766963653A2020, i64 16)
  %55 = load i32, ptr %3, align 4
  %56 = call i1 @_FortranAioOutputInteger32(ptr %53, i32 %55)
  %57 = call i32 @_FortranAioEndIoStatement(ptr %53)
  %58 = load i32, ptr %3, align 4
  %59 = load i32, ptr %2, align 4
  %60 = icmp eq i32 %58, %59
  br i1 %60, label %61, label %65

61:                                               ; preds = %omp.region.cont
  %62 = call ptr @_FortranAioBeginExternalListOutput(i32 6, ptr @_QQclXf46b0d060c890540d012b521bc3468aa, i32 36)
  %63 = call i1 @_FortranAioOutputAscii(ptr %62, ptr @_QQclX546573742070617373656421, i64 12)
  %64 = call i32 @_FortranAioEndIoStatement(ptr %62)
  br label %69

65:                                               ; preds = %omp.region.cont
  %66 = call ptr @_FortranAioBeginExternalListOutput(i32 6, ptr @_QQclXf46b0d060c890540d012b521bc3468aa, i32 38)
  %67 = call i1 @_FortranAioOutputAscii(ptr %66, ptr @_QQclX54657374206661696C656421, i64 12)
  %68 = call i32 @_FortranAioEndIoStatement(ptr %66)
  br label %69

69:                                               ; preds = %61, %65
  ret void
}

; Function Attrs: noinline nounwind
define internal void @.omp_mapper._QQFmy_mapper2(ptr noundef %0, ptr noundef %1, ptr noundef %2, i64 noundef %3, i64 noundef %4, ptr noundef %5) #0 {
entry:
  %6 = udiv exact i64 %3, 4096
  %7 = getelementptr %_QFTmytype, ptr %2, i64 %6
  %omp.arrayinit.isarray = icmp sgt i64 %6, 1
  %8 = and i64 %4, 8
  %9 = icmp ne ptr %1, %2
  %10 = and i64 %4, 16
  %11 = icmp ne i64 %10, 0
  %12 = and i1 %9, %11
  %13 = or i1 %omp.arrayinit.isarray, %12
  %.omp.array..init..delete = icmp eq i64 %8, 0
  %14 = and i1 %13, %.omp.array..init..delete
  br i1 %14, label %.omp.array..init, label %omp.arraymap.head

.omp.array..init:                                 ; preds = %entry
  %15 = mul nuw i64 %6, 4096
  %16 = and i64 %4, -4
  %17 = or i64 %16, 512
  call void @__tgt_push_mapper_component(ptr %0, ptr %1, ptr %2, i64 %15, i64 %17, ptr %5)
  br label %omp.arraymap.head

omp.arraymap.head:                                ; preds = %.omp.array..init, %entry
  %omp.arraymap.isempty = icmp eq ptr %2, %7
  br i1 %omp.arraymap.isempty, label %omp.done, label %omp.arraymap.body

omp.arraymap.body:                                ; preds = %omp.type.end, %omp.arraymap.head
  %omp.arraymap.ptrcurrent = phi ptr [ %2, %omp.arraymap.head ], [ %omp.arraymap.next, %omp.type.end ]
  %18 = getelementptr %_QFTmytype, ptr %omp.arraymap.ptrcurrent, i32 0, i32 0
  %array_offset = getelementptr inbounds [1024 x i32], ptr %18, i64 0, i64 0
  %19 = call i64 @__tgt_mapper_num_components(ptr %0)
  %20 = shl i64 %19, 48
  %21 = add nuw i64 2, %20
  %22 = and i64 %4, 3
  %23 = icmp eq i64 %22, 0
  br i1 %23, label %omp.type.alloc, label %omp.type.alloc.else

omp.type.alloc:                                   ; preds = %omp.arraymap.body
  %24 = and i64 %21, -4
  br label %omp.type.end

omp.type.alloc.else:                              ; preds = %omp.arraymap.body
  %25 = icmp eq i64 %22, 1
  br i1 %25, label %omp.type.to, label %omp.type.to.else

omp.type.to:                                      ; preds = %omp.type.alloc.else
  %26 = and i64 %21, -3
  br label %omp.type.end

omp.type.to.else:                                 ; preds = %omp.type.alloc.else
  %27 = icmp eq i64 %22, 2
  br i1 %27, label %omp.type.from, label %omp.type.end

omp.type.from:                                    ; preds = %omp.type.to.else
  %28 = and i64 %21, -2
  br label %omp.type.end

omp.type.end:                                     ; preds = %omp.type.from, %omp.type.to.else, %omp.type.to, %omp.type.alloc
  %omp.maptype = phi i64 [ %24, %omp.type.alloc ], [ %26, %omp.type.to ], [ %28, %omp.type.from ], [ %21, %omp.type.to.else ]
  call void @.omp_mapper._QQFmy_mapper1(ptr %0, ptr %omp.arraymap.ptrcurrent, ptr %array_offset, i64 4096, i64 %omp.maptype, ptr @3) #1
  %omp.arraymap.next = getelementptr %_QFTmytype, ptr %omp.arraymap.ptrcurrent, i32 1
  %omp.arraymap.isdone = icmp eq ptr %omp.arraymap.next, %7
  br i1 %omp.arraymap.isdone, label %omp.arraymap.exit, label %omp.arraymap.body

omp.arraymap.exit:                                ; preds = %omp.type.end
  %omp.arrayinit.isarray1 = icmp sgt i64 %6, 1
  %29 = and i64 %4, 8
  %.omp.array..del..delete = icmp ne i64 %29, 0
  %30 = and i1 %omp.arrayinit.isarray1, %.omp.array..del..delete
  br i1 %30, label %.omp.array..del, label %omp.done

.omp.array..del:                                  ; preds = %omp.arraymap.exit
  %31 = mul nuw i64 %6, 4096
  %32 = and i64 %4, -4
  %33 = or i64 %32, 512
  call void @__tgt_push_mapper_component(ptr %0, ptr %1, ptr %2, i64 %31, i64 %33, ptr %5)
  br label %omp.done

omp.done:                                         ; preds = %.omp.array..del, %omp.arraymap.exit, %omp.arraymap.head
  ret void
}

; Function Attrs: noinline nounwind
define internal void @.omp_mapper._QQFmy_mapper1(ptr noundef %0, ptr noundef %1, ptr noundef %2, i64 noundef %3, i64 noundef %4, ptr noundef %5) #0 {
entry:
  %6 = udiv exact i64 %3, 4096
  %7 = getelementptr %_QFTmytype, ptr %2, i64 %6
  %omp.arrayinit.isarray = icmp sgt i64 %6, 1
  %8 = and i64 %4, 8
  %9 = icmp ne ptr %1, %2
  %10 = and i64 %4, 16
  %11 = icmp ne i64 %10, 0
  %12 = and i1 %9, %11
  %13 = or i1 %omp.arrayinit.isarray, %12
  %.omp.array..init..delete = icmp eq i64 %8, 0
  %14 = and i1 %13, %.omp.array..init..delete
  br i1 %14, label %.omp.array..init, label %omp.arraymap.head

.omp.array..init:                                 ; preds = %entry
  %15 = mul nuw i64 %6, 4096
  %16 = and i64 %4, -4
  %17 = or i64 %16, 512
  call void @__tgt_push_mapper_component(ptr %0, ptr %1, ptr %2, i64 %15, i64 %17, ptr %5)
  br label %omp.arraymap.head

omp.arraymap.head:                                ; preds = %.omp.array..init, %entry
  %omp.arraymap.isempty = icmp eq ptr %2, %7
  br i1 %omp.arraymap.isempty, label %omp.done, label %omp.arraymap.body

omp.arraymap.body:                                ; preds = %omp.type.end, %omp.arraymap.head
  %omp.arraymap.ptrcurrent = phi ptr [ %2, %omp.arraymap.head ], [ %omp.arraymap.next, %omp.type.end ]
  %18 = getelementptr %_QFTmytype, ptr %omp.arraymap.ptrcurrent, i32 0, i32 0
  %array_offset = getelementptr inbounds [1024 x i32], ptr %18, i64 0, i64 0
  %19 = call i64 @__tgt_mapper_num_components(ptr %0)
  %20 = shl i64 %19, 48
  %21 = add nuw i64 1, %20
  %22 = and i64 %4, 3
  %23 = icmp eq i64 %22, 0
  br i1 %23, label %omp.type.alloc, label %omp.type.alloc.else

omp.type.alloc:                                   ; preds = %omp.arraymap.body
  %24 = and i64 %21, -4
  br label %omp.type.end

omp.type.alloc.else:                              ; preds = %omp.arraymap.body
  %25 = icmp eq i64 %22, 1
  br i1 %25, label %omp.type.to, label %omp.type.to.else

omp.type.to:                                      ; preds = %omp.type.alloc.else
  %26 = and i64 %21, -3
  br label %omp.type.end

omp.type.to.else:                                 ; preds = %omp.type.alloc.else
  %27 = icmp eq i64 %22, 2
  br i1 %27, label %omp.type.from, label %omp.type.end

omp.type.from:                                    ; preds = %omp.type.to.else
  %28 = and i64 %21, -2
  br label %omp.type.end

omp.type.end:                                     ; preds = %omp.type.from, %omp.type.to.else, %omp.type.to, %omp.type.alloc
  %omp.maptype = phi i64 [ %24, %omp.type.alloc ], [ %26, %omp.type.to ], [ %28, %omp.type.from ], [ %21, %omp.type.to.else ]
  call void @__tgt_push_mapper_component(ptr %0, ptr %omp.arraymap.ptrcurrent, ptr %array_offset, i64 4096, i64 %omp.maptype, ptr @5)
  %omp.arraymap.next = getelementptr %_QFTmytype, ptr %omp.arraymap.ptrcurrent, i32 1
  %omp.arraymap.isdone = icmp eq ptr %omp.arraymap.next, %7
  br i1 %omp.arraymap.isdone, label %omp.arraymap.exit, label %omp.arraymap.body

omp.arraymap.exit:                                ; preds = %omp.type.end
  %omp.arrayinit.isarray1 = icmp sgt i64 %6, 1
  %29 = and i64 %4, 8
  %.omp.array..del..delete = icmp ne i64 %29, 0
  %30 = and i1 %omp.arrayinit.isarray1, %.omp.array..del..delete
  br i1 %30, label %.omp.array..del, label %omp.done

.omp.array..del:                                  ; preds = %omp.arraymap.exit
  %31 = mul nuw i64 %6, 4096
  %32 = and i64 %4, -4
  %33 = or i64 %32, 512
  call void @__tgt_push_mapper_component(ptr %0, ptr %1, ptr %2, i64 %31, i64 %33, ptr %5)
  br label %omp.done

omp.done:                                         ; preds = %.omp.array..del, %omp.arraymap.exit, %omp.arraymap.head
  ret void
}
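
The map-type arithmetic in the `omp.arraymap.body` and `omp.type.*` blocks of the two mapper functions above can be paraphrased in C++ as follows. This is a reading of the IR shown here, not code from the patch: `combineMapType` and the constant names are hypothetical, the values `1`, `2`, and `48` are taken from the `and`, `icmp`, and `shl` instructions, and the interpretation of bit `1` as "to" and bit `2` as "from" follows the `omp.type.to` and `omp.type.from` block labels.

```cpp
#include <cassert>
#include <cstdint>

// Bit values as they appear in the IR above.
constexpr std::uint64_t kMapTo = 0x1;
constexpr std::uint64_t kMapFrom = 0x2;
constexpr unsigned kMemberOfShift = 48; // from `shl i64 %19, 48`

// componentType:      map type of the component inside the mapper (the
//                     constant added in the IR, 1 or 2 in the two mappers).
// previousComponents: result of __tgt_mapper_num_components.
// incomingType:       map type passed to the mapper function (argument %4).
std::uint64_t combineMapType(std::uint64_t componentType,
                             std::uint64_t previousComponents,
                             std::uint64_t incomingType) {
  // `add nuw` of the component type and the shifted component count.
  std::uint64_t memberType =
      componentType + (previousComponents << kMemberOfShift);
  // `and i64 %4, 3` selects the to/from bits of the incoming map type.
  std::uint64_t direction = incomingType & (kMapTo | kMapFrom);
  if (direction == 0)
    return memberType & ~(kMapTo | kMapFrom); // omp.type.alloc: and -4
  if (direction == kMapTo)
    return memberType & ~kMapFrom;            // omp.type.to: and -3
  if (direction == kMapFrom)
    return memberType & ~kMapTo;              // omp.type.from: and -2
  return memberType;                          // both bits set: unchanged
}
```

For example, a component typed `2` reached through a mapper invoked with only the "to" bit set ends up with both direction bits cleared, while the same component reached with both bits set keeps its type.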

@TIFitis TIFitis force-pushed the users/akash/mapper_llvm_lower branch from 908669a to 18037d1 Compare January 31, 2025 10:35
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Jan 31, 2025
@TIFitis TIFitis force-pushed the users/akash/mapper_llvm_lower branch 2 times, most recently from 8fc9843 to 35e6331 Compare January 31, 2025 10:53
@TIFitis TIFitis force-pushed the users/akash/mapper_llvm_dialect branch from 8a90d20 to 5fe7a97 Compare January 31, 2025 11:27
@TIFitis TIFitis force-pushed the users/akash/mapper_llvm_lower branch from 35e6331 to 982315c Compare January 31, 2025 11:31
@TIFitis TIFitis force-pushed the users/akash/mapper_llvm_dialect branch from 5fe7a97 to 3302435 Compare February 7, 2025 17:42
@TIFitis TIFitis force-pushed the users/akash/mapper_llvm_lower branch from 982315c to 6d74d0f Compare February 7, 2025 18:07
@TIFitis TIFitis force-pushed the users/akash/mapper_llvm_dialect branch from af3dcbf to 3ddc972 Compare February 11, 2025 16:52
@TIFitis TIFitis force-pushed the users/akash/mapper_llvm_lower branch from 8a0192d to c4b7bc7 Compare February 12, 2025 17:14
Member

@skatrak skatrak left a comment

Thanks again Akash. I have a couple of small code simplification suggestions and nits, but otherwise LGTM.

@TIFitis TIFitis force-pushed the users/akash/mapper_llvm_dialect branch from 3ddc972 to 55e38d7 Compare February 18, 2025 16:33
@TIFitis TIFitis force-pushed the users/akash/mapper_llvm_lower branch from 4f2fa9e to 2730303 Compare February 18, 2025 16:35
@TIFitis TIFitis force-pushed the users/akash/mapper_llvm_dialect branch from 55e38d7 to ffd2012 Compare February 18, 2025 17:46
Base automatically changed from users/akash/mapper_llvm_dialect to main February 18, 2025 17:47
…pers

This patch adds OpenMPToLLVMIRTranslation support for the OpenMP Declare Mapper directive.

Since both MLIR and Clang now support custom mappers, I've made the respective params required instead of optional as well.

Depends on #121005
@TIFitis TIFitis force-pushed the users/akash/mapper_llvm_lower branch from 2730303 to b51ad10 Compare February 18, 2025 17:55
@TIFitis TIFitis merged commit 785a5b4 into main Feb 18, 2025
5 of 7 checks passed
@TIFitis TIFitis deleted the users/akash/mapper_llvm_lower branch February 18, 2025 17:55
wldfngrs pushed a commit to wldfngrs/llvm-project that referenced this pull request Feb 19, 2025
…pers (llvm#124746)

This patch adds OpenMPToLLVMIRTranslation support for the OpenMP Declare
Mapper directive.

Since both MLIR and Clang now support custom mappers, I've changed the
respective function params to no longer be optional as well.

Depends on llvm#121005
searlmc1 pushed a commit to ROCm/llvm-project that referenced this pull request Feb 20, 2025
land and revert:
785a5b4 [MLIR][OpenMP] Add LLVM translation support for OpenMP UserDefinedMappers (llvm#124746)
d6ab12c [MLIR][OpenMP] Add conversion support from FIR to LLVM Dialect for OMP DeclareMapper (llvm#121005)
886b2ed [MLIR][OpenMP] Add Lowering support for OpenMP custom mappers in map clause (llvm#121001)
ee17955 [MLIR][OpenMP] Add OMP Mapper field to MapInfoOp (llvm#120994)
TIFitis added a commit to TIFitis/llvm-project that referenced this pull request Feb 20, 2025
searlmc1 pushed a commit to ROCm/llvm-project that referenced this pull request Feb 22, 2025
Labels
clang:codegen IR generation bugs: mangling, exceptions, etc. clang:openmp OpenMP related changes to Clang clang Clang issues not falling into any other category flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category mlir:llvm mlir:openmp mlir offload