diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp index 9f7db25a15bec..c2289985e9519 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -8879,7 +8879,7 @@ static void emitOffloadingArraysAndArgs( }; auto CustomMapperCB = [&](unsigned int I) { - llvm::Value *MFunc = nullptr; + llvm::Function *MFunc = nullptr; if (CombinedInfo.Mappers[I]) { Info.HasMapper = true; MFunc = CGM.getOpenMPRuntime().getOrCreateUserDefinedMapperFunc( @@ -8887,9 +8887,9 @@ static void emitOffloadingArraysAndArgs( } return MFunc; }; - OMPBuilder.emitOffloadingArraysAndArgs( - AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, IsNonContiguous, - ForEndCall, DeviceAddrCB, CustomMapperCB); + cantFail(OMPBuilder.emitOffloadingArraysAndArgs( + AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, CustomMapperCB, + IsNonContiguous, ForEndCall, DeviceAddrCB)); } /// Check for inner distribute directive. @@ -9082,15 +9082,15 @@ void CGOpenMPRuntime::emitUserDefinedMapper(const OMPDeclareMapperDecl *D, return CombinedInfo; }; - auto CustomMapperCB = [&](unsigned I, llvm::Function **MapperFunc) { + auto CustomMapperCB = [&](unsigned I) { + llvm::Function *MapperFunc = nullptr; if (CombinedInfo.Mappers[I]) { // Call the corresponding mapper function. - *MapperFunc = getOrCreateUserDefinedMapperFunc( + MapperFunc = getOrCreateUserDefinedMapperFunc( cast(CombinedInfo.Mappers[I])); - assert(*MapperFunc && "Expect a valid mapper function is available."); - return true; + assert(MapperFunc && "Expect a valid mapper function is available."); } - return false; + return MapperFunc; }; SmallString<64> TyStr; @@ -9098,8 +9098,8 @@ void CGOpenMPRuntime::emitUserDefinedMapper(const OMPDeclareMapperDecl *D, CGM.getCXXABI().getMangleContext().mangleCanonicalTypeName(Ty, Out); std::string Name = getName({"omp_mapper", TyStr, D->getName()}); - auto *NewFn = OMPBuilder.emitUserDefinedMapper(PrivatizeAndGenMapInfoCB, - ElemTy, Name, CustomMapperCB); + llvm::Function *NewFn = cantFail(OMPBuilder.emitUserDefinedMapper( + PrivatizeAndGenMapInfoCB, ElemTy, Name, CustomMapperCB)); UDMMap.try_emplace(D, NewFn); if (CGF) FunctionUDMMap[CGF->CurFn].push_back(D); @@ -10073,7 +10073,7 @@ void CGOpenMPRuntime::emitTargetDataCalls( }; auto CustomMapperCB = [&](unsigned int I) { - llvm::Value *MFunc = nullptr; + llvm::Function *MFunc = nullptr; if (CombinedInfo.Mappers[I]) { Info.HasMapper = true; MFunc = CGF.CGM.getOpenMPRuntime().getOrCreateUserDefinedMapperFunc( @@ -10093,7 +10093,8 @@ void CGOpenMPRuntime::emitTargetDataCalls( llvm::OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(OMPBuilder.createTargetData( OmpLoc, AllocaIP, CodeGenIP, DeviceID, IfCondVal, Info, GenMapInfoCB, - /*MapperFunc=*/nullptr, BodyCB, DeviceAddrCB, CustomMapperCB, RTLoc)); + CustomMapperCB, + /*MapperFunc=*/nullptr, BodyCB, DeviceAddrCB, RTLoc)); CGF.Builder.restoreIP(AfterIP); } diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index d25077cae63e4..33b3d7bad4a71 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -2399,6 +2399,7 @@ class OpenMPIRBuilder { CurInfo.NonContigInfo.Strides.end()); } }; + using MapInfosOrErrorTy = Expected; /// Callback function type for functions emitting the host fallback code that /// is executed when the kernel launch fails. It takes an insertion point as @@ -2407,6 +2408,11 @@ class OpenMPIRBuilder { using EmitFallbackCallbackTy = function_ref; + // Callback function type for emitting and fetching user defined custom + // mappers. + using CustomMapperCallbackTy = + function_ref(unsigned int)>; + /// Generate a target region entry call and host fallback call. /// /// \param Loc The location at which the request originated and is fulfilled. @@ -2473,11 +2479,11 @@ class OpenMPIRBuilder { /// return nullptr by reference. Accepts a reference to a MapInfosTy object /// that contains information generated for mappable clauses, /// including base pointers, pointers, sizes, map types, user-defined mappers. - void emitOffloadingArrays( + Error emitOffloadingArrays( InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo, - TargetDataInfo &Info, bool IsNonContiguous = false, - function_ref DeviceAddrCB = nullptr, - function_ref CustomMapperCB = nullptr); + TargetDataInfo &Info, CustomMapperCallbackTy CustomMapperCB, + bool IsNonContiguous = false, + function_ref DeviceAddrCB = nullptr); /// Allocates memory for and populates the arrays required for offloading /// (offload_{baseptrs|ptrs|mappers|sizes|maptypes|mapnames}). Then, it @@ -2485,12 +2491,12 @@ class OpenMPIRBuilder { /// library. In essence, this function is a combination of /// emitOffloadingArrays and emitOffloadingArraysArgument and should arguably /// be preferred by clients of OpenMPIRBuilder. - void emitOffloadingArraysAndArgs( + Error emitOffloadingArraysAndArgs( InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info, TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo, - bool IsNonContiguous = false, bool ForEndCall = false, - function_ref DeviceAddrCB = nullptr, - function_ref CustomMapperCB = nullptr); + CustomMapperCallbackTy CustomMapperCB, bool IsNonContiguous = false, + bool ForEndCall = false, + function_ref DeviceAddrCB = nullptr); /// Creates offloading entry for the provided entry ID \a ID, address \a /// Addr, size \a Size, and flags \a Flags. @@ -2950,12 +2956,12 @@ class OpenMPIRBuilder { /// \param FuncName Optional param to specify mapper function name. /// \param CustomMapperCB Optional callback to generate code related to /// custom mappers. - Function *emitUserDefinedMapper( - function_ref + Expected emitUserDefinedMapper( + function_ref PrivAndGenMapInfoCB, llvm::Type *ElemTy, StringRef FuncName, - function_ref CustomMapperCB = nullptr); + CustomMapperCallbackTy CustomMapperCB); /// Generator for '#omp target data' /// @@ -2969,21 +2975,21 @@ class OpenMPIRBuilder { /// \param IfCond Value which corresponds to the if clause condition. /// \param Info Stores all information realted to the Target Data directive. /// \param GenMapInfoCB Callback that populates the MapInfos and returns. + /// \param CustomMapperCB Callback to generate code related to + /// custom mappers. /// \param BodyGenCB Optional Callback to generate the region code. /// \param DeviceAddrCB Optional callback to generate code related to /// use_device_ptr and use_device_addr. - /// \param CustomMapperCB Optional callback to generate code related to - /// custom mappers. InsertPointOrErrorTy createTargetData( const LocationDescription &Loc, InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond, TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB, + CustomMapperCallbackTy CustomMapperCB, omp::RuntimeFunction *MapperFunc = nullptr, function_ref BodyGenCB = nullptr, function_ref DeviceAddrCB = nullptr, - function_ref CustomMapperCB = nullptr, Value *SrcLocInfo = nullptr); using TargetBodyGenCallbackTy = function_ref &Inputs, GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB, TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, + CustomMapperCallbackTy CustomMapperCB, SmallVector Dependencies = {}, bool HasNowait = false); /// Returns __kmpc_for_static_init_* runtime function for the specified diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 7ba23b0bd377e..18bc82fc827f7 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -6555,12 +6555,11 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData( const LocationDescription &Loc, InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond, TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB, - omp::RuntimeFunction *MapperFunc, + CustomMapperCallbackTy CustomMapperCB, omp::RuntimeFunction *MapperFunc, function_ref BodyGenCB, - function_ref DeviceAddrCB, - function_ref CustomMapperCB, Value *SrcLocInfo) { + function_ref DeviceAddrCB, Value *SrcLocInfo) { if (!updateToLocation(Loc)) return InsertPointTy(); @@ -6585,9 +6584,10 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData( auto BeginThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) -> Error { MapInfo = &GenMapInfoCB(Builder.saveIP()); - emitOffloadingArrays(AllocaIP, Builder.saveIP(), *MapInfo, Info, - /*IsNonContiguous=*/true, DeviceAddrCB, - CustomMapperCB); + if (Error Err = emitOffloadingArrays( + AllocaIP, Builder.saveIP(), *MapInfo, Info, CustomMapperCB, + /*IsNonContiguous=*/true, DeviceAddrCB)) + return Err; TargetDataRTArgs RTArgs; emitOffloadingArraysArgument(Builder, RTArgs, Info); @@ -7486,26 +7486,31 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask( return Builder.saveIP(); } -void OpenMPIRBuilder::emitOffloadingArraysAndArgs( +Error OpenMPIRBuilder::emitOffloadingArraysAndArgs( InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info, - TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo, bool IsNonContiguous, - bool ForEndCall, function_ref DeviceAddrCB, - function_ref CustomMapperCB) { - emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info, IsNonContiguous, - DeviceAddrCB, CustomMapperCB); + TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo, + CustomMapperCallbackTy CustomMapperCB, bool IsNonContiguous, + bool ForEndCall, function_ref DeviceAddrCB) { + if (Error Err = + emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info, + CustomMapperCB, IsNonContiguous, DeviceAddrCB)) + return Err; emitOffloadingArraysArgument(Builder, RTArgs, Info, ForEndCall); + return Error::success(); } static void emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, OpenMPIRBuilder::InsertPointTy AllocaIP, + OpenMPIRBuilder::TargetDataInfo &Info, const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs, const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID, SmallVectorImpl &Args, OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB, - SmallVector Dependencies = {}, - bool HasNoWait = false) { + OpenMPIRBuilder::CustomMapperCallbackTy CustomMapperCB, + SmallVector Dependencies, + bool HasNoWait) { // Generate a function call to the host fallback implementation of the target // region. This is called by the host when no offload entry was generated for // the target region and when the offloading call fails at runtime. @@ -7576,16 +7581,13 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, auto &&EmitTargetCallThen = [&](OpenMPIRBuilder::InsertPointTy AllocaIP, OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error { - OpenMPIRBuilder::TargetDataInfo Info( - /*RequiresDevicePointerInfo=*/false, - /*SeparateBeginEndCalls=*/true); - OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP()); OpenMPIRBuilder::TargetDataRTArgs RTArgs; - OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info, - RTArgs, MapInfo, - /*IsNonContiguous=*/true, - /*ForEndCall=*/false); + if (Error Err = OMPBuilder.emitOffloadingArraysAndArgs( + AllocaIP, Builder.saveIP(), Info, RTArgs, MapInfo, CustomMapperCB, + /*IsNonContiguous=*/true, + /*ForEndCall=*/false)) + return Err; SmallVector NumTeamsC; for (auto [DefaultVal, RuntimeVal] : @@ -7687,13 +7689,15 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP, - InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, + InsertPointTy CodeGenIP, TargetDataInfo &Info, + TargetRegionEntryInfo &EntryInfo, const TargetKernelDefaultAttrs &DefaultAttrs, const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond, - SmallVectorImpl &Args, GenMapInfoCallbackTy GenMapInfoCB, + SmallVectorImpl &Inputs, GenMapInfoCallbackTy GenMapInfoCB, OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc, OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, - SmallVector Dependencies, bool HasNowait) { + CustomMapperCallbackTy CustomMapperCB, SmallVector Dependencies, + bool HasNowait) { if (!updateToLocation(Loc)) return InsertPointTy(); @@ -7707,16 +7711,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( // and ArgAccessorFuncCB if (Error Err = emitTargetOutlinedFunction( *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn, - OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB)) + OutlinedFnID, Inputs, CBFunc, ArgAccessorFuncCB)) return Err; // If we are not on the target device, then we need to generate code // to make a remote call (offload) to the previously outlined function // that represents the target region. Do that now. if (!Config.isTargetDevice()) - emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs, IfCond, - OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies, - HasNowait); + emitTargetCall(*this, Builder, AllocaIP, Info, DefaultAttrs, RuntimeAttrs, + IfCond, OutlinedFn, OutlinedFnID, Inputs, GenMapInfoCB, + CustomMapperCB, Dependencies, HasNowait); return Builder.saveIP(); } @@ -8041,12 +8045,11 @@ void OpenMPIRBuilder::emitUDMapperArrayInitOrDel( OffloadingArgs); } -Function *OpenMPIRBuilder::emitUserDefinedMapper( - function_ref +Expected OpenMPIRBuilder::emitUserDefinedMapper( + function_ref GenMapInfoCB, - Type *ElemTy, StringRef FuncName, - function_ref CustomMapperCB) { + Type *ElemTy, StringRef FuncName, CustomMapperCallbackTy CustomMapperCB) { SmallVector Params; Params.emplace_back(Builder.getPtrTy()); Params.emplace_back(Builder.getPtrTy()); @@ -8117,7 +8120,9 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper( PtrPHI->addIncoming(PtrBegin, HeadBB); // Get map clause information. Fill up the arrays with all mapped variables. - MapInfosTy &Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn); + MapInfosOrErrorTy Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn); + if (!Info) + return Info.takeError(); // Call the runtime API __tgt_mapper_num_components to get the number of // pre-existing components. @@ -8129,20 +8134,20 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper( Builder.CreateShl(PreviousSize, Builder.getInt64(getFlagMemberOffset())); // Fill up the runtime mapper handle for all components. - for (unsigned I = 0; I < Info.BasePointers.size(); ++I) { + for (unsigned I = 0; I < Info->BasePointers.size(); ++I) { Value *CurBaseArg = - Builder.CreateBitCast(Info.BasePointers[I], Builder.getPtrTy()); + Builder.CreateBitCast(Info->BasePointers[I], Builder.getPtrTy()); Value *CurBeginArg = - Builder.CreateBitCast(Info.Pointers[I], Builder.getPtrTy()); - Value *CurSizeArg = Info.Sizes[I]; - Value *CurNameArg = Info.Names.size() - ? Info.Names[I] + Builder.CreateBitCast(Info->Pointers[I], Builder.getPtrTy()); + Value *CurSizeArg = Info->Sizes[I]; + Value *CurNameArg = Info->Names.size() + ? Info->Names[I] : Constant::getNullValue(Builder.getPtrTy()); // Extract the MEMBER_OF field from the map type. Value *OriMapType = Builder.getInt64( static_cast>( - Info.Types[I])); + Info->Types[I])); Value *MemberMapType = Builder.CreateNUWAdd(OriMapType, ShiftedPreviousSize); @@ -8224,10 +8229,13 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper( Value *OffloadingArgs[] = {MapperHandle, CurBaseArg, CurBeginArg, CurSizeArg, CurMapType, CurNameArg}; - Function *ChildMapperFn = nullptr; - if (CustomMapperCB && CustomMapperCB(I, &ChildMapperFn)) { + + auto ChildMapperFn = CustomMapperCB(I); + if (!ChildMapperFn) + return ChildMapperFn.takeError(); + if (*ChildMapperFn) { // Call the corresponding mapper function. - Builder.CreateCall(ChildMapperFn, OffloadingArgs)->setDoesNotThrow(); + Builder.CreateCall(*ChildMapperFn, OffloadingArgs)->setDoesNotThrow(); } else { // Call the runtime API __tgt_push_mapper_component to fill up the runtime // data structure. @@ -8261,18 +8269,18 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper( return MapperFn; } -void OpenMPIRBuilder::emitOffloadingArrays( +Error OpenMPIRBuilder::emitOffloadingArrays( InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo, - TargetDataInfo &Info, bool IsNonContiguous, - function_ref DeviceAddrCB, - function_ref CustomMapperCB) { + TargetDataInfo &Info, CustomMapperCallbackTy CustomMapperCB, + bool IsNonContiguous, + function_ref DeviceAddrCB) { // Reset the array information. Info.clearArrayInfo(); Info.NumberOfPtrs = CombinedInfo.BasePointers.size(); if (Info.NumberOfPtrs == 0) - return; + return Error::success(); Builder.restoreIP(AllocaIP); // Detect if we have any capture size requiring runtime evaluation of the @@ -8436,9 +8444,13 @@ void OpenMPIRBuilder::emitOffloadingArrays( // Fill up the mapper array. unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(0); Value *MFunc = ConstantPointerNull::get(PtrTy); - if (CustomMapperCB) - if (Value *CustomMFunc = CustomMapperCB(I)) - MFunc = Builder.CreatePointerCast(CustomMFunc, PtrTy); + + auto CustomMFunc = CustomMapperCB(I); + if (!CustomMFunc) + return CustomMFunc.takeError(); + if (*CustomMFunc) + MFunc = Builder.CreatePointerCast(*CustomMFunc, PtrTy); + Value *MAddr = Builder.CreateInBoundsGEP( MappersArray->getAllocatedType(), MappersArray, {Builder.getIntN(IndexSize, 0), Builder.getIntN(IndexSize, I)}); @@ -8448,8 +8460,9 @@ void OpenMPIRBuilder::emitOffloadingArrays( if (!IsNonContiguous || CombinedInfo.NonContigInfo.Offsets.empty() || Info.NumberOfPtrs == 0) - return; + return Error::success(); emitNonContiguousDescriptor(AllocaIP, CodeGenIP, CombinedInfo, Info); + return Error::success(); } void OpenMPIRBuilder::emitBranch(BasicBlock *Target) { diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index 83c8f7e932b2b..a1ea7849d7c0c 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -5928,6 +5928,7 @@ TEST_F(OpenMPIRBuilderTest, TargetEnterData) { return CombinedInfo; }; + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; llvm::OpenMPIRBuilder::TargetDataInfo Info( /*RequiresDevicePointerInfo=*/false, /*SeparateBeginEndCalls=*/true); @@ -5939,7 +5940,7 @@ TEST_F(OpenMPIRBuilderTest, TargetEnterData) { OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTargetData( Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID), - /* IfCond= */ nullptr, Info, GenMapInfoCB, &RTLFunc)); + /* IfCond= */ nullptr, Info, GenMapInfoCB, CustomMapperCB, &RTLFunc)); Builder.restoreIP(AfterIP); CallInst *TargetDataCall = dyn_cast(&BB->back()); @@ -5990,6 +5991,7 @@ TEST_F(OpenMPIRBuilderTest, TargetExitData) { return CombinedInfo; }; + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; llvm::OpenMPIRBuilder::TargetDataInfo Info( /*RequiresDevicePointerInfo=*/false, /*SeparateBeginEndCalls=*/true); @@ -6001,7 +6003,7 @@ TEST_F(OpenMPIRBuilderTest, TargetExitData) { OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTargetData( Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID), - /* IfCond= */ nullptr, Info, GenMapInfoCB, &RTLFunc)); + /* IfCond= */ nullptr, Info, GenMapInfoCB, CustomMapperCB, &RTLFunc)); Builder.restoreIP(AfterIP); CallInst *TargetDataCall = dyn_cast(&BB->back()); @@ -6074,6 +6076,7 @@ TEST_F(OpenMPIRBuilderTest, TargetDataRegion) { return CombinedInfo; }; + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; llvm::OpenMPIRBuilder::TargetDataInfo Info( /*RequiresDevicePointerInfo=*/true, /*SeparateBeginEndCalls=*/true); @@ -6110,9 +6113,10 @@ TEST_F(OpenMPIRBuilderTest, TargetDataRegion) { ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, TargetDataIP1, - OMPBuilder.createTargetData( - Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID), - /* IfCond= */ nullptr, Info, GenMapInfoCB, nullptr, BodyCB)); + OMPBuilder.createTargetData(Loc, AllocaIP, Builder.saveIP(), + Builder.getInt64(DeviceID), + /* IfCond= */ nullptr, Info, GenMapInfoCB, + CustomMapperCB, nullptr, BodyCB)); Builder.restoreIP(TargetDataIP1); CallInst *TargetDataCall = dyn_cast(&BB->back()); @@ -6138,9 +6142,10 @@ TEST_F(OpenMPIRBuilderTest, TargetDataRegion) { }; ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, TargetDataIP2, - OMPBuilder.createTargetData( - Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID), - /* IfCond= */ nullptr, Info, GenMapInfoCB, nullptr, BodyTargetCB)); + OMPBuilder.createTargetData(Loc, AllocaIP, Builder.saveIP(), + Builder.getInt64(DeviceID), + /* IfCond= */ nullptr, Info, GenMapInfoCB, + CustomMapperCB, nullptr, BodyTargetCB)); Builder.restoreIP(TargetDataIP2); EXPECT_TRUE(CheckDevicePassBodyGen); @@ -6241,6 +6246,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { return CombinedInfos; }; + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; + llvm::OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); + TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17); OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL}); OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; @@ -6254,9 +6264,10 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), - Builder.saveIP(), EntryInfo, DefaultAttrs, + Builder.saveIP(), Info, EntryInfo, DefaultAttrs, RuntimeAttrs, /*IfCond=*/nullptr, Inputs, - GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); + GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB, + CustomMapperCB, {}, false)); EXPECT_EQ(DL, Builder.getCurrentDebugLocation()); Builder.restoreIP(AfterIP); @@ -6400,6 +6411,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { return CombinedInfos; }; + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy AllocaIP, OpenMPIRBuilder::InsertPointTy CodeGenIP) -> OpenMPIRBuilder::InsertPointTy { @@ -6419,13 +6431,17 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + llvm::OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, - EntryInfo, DefaultAttrs, RuntimeAttrs, + Info, EntryInfo, DefaultAttrs, RuntimeAttrs, /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB, - BodyGenCB, SimpleArgAccessorCB)); + BodyGenCB, SimpleArgAccessorCB, CustomMapperCB, + {}, false)); EXPECT_EQ(DL, Builder.getCurrentDebugLocation()); Builder.restoreIP(AfterIP); @@ -6549,6 +6565,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) { F->setName("func"); IRBuilder<> Builder(BB); + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; auto BodyGenCB = [&](InsertPointTy, InsertPointTy CodeGenIP) -> InsertPointTy { Builder.restoreIP(CodeGenIP); @@ -6576,13 +6593,17 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) { /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; RuntimeAttrs.LoopTripCount = Builder.getInt64(1000); + llvm::OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), - Builder.saveIP(), EntryInfo, DefaultAttrs, + Builder.saveIP(), Info, EntryInfo, DefaultAttrs, RuntimeAttrs, /*IfCond=*/nullptr, Inputs, - GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); + GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB, + CustomMapperCB)); Builder.restoreIP(AfterIP); OMPBuilder.finalize(); @@ -6663,6 +6684,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) { return CombinedInfos; }; + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy, OpenMPIRBuilder::InsertPointTy CodeGenIP) -> OpenMPIRBuilder::InsertPointTy { @@ -6679,13 +6701,16 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) { OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + llvm::OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, - EntryInfo, DefaultAttrs, RuntimeAttrs, + Info, EntryInfo, DefaultAttrs, RuntimeAttrs, /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB, - BodyGenCB, SimpleArgAccessorCB)); + BodyGenCB, SimpleArgAccessorCB, CustomMapperCB)); Builder.restoreIP(AfterIP); Builder.CreateRetVoid(); @@ -6779,6 +6804,7 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) { llvm::Value *RaiseAlloca = nullptr; + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy AllocaIP, OpenMPIRBuilder::InsertPointTy CodeGenIP) -> OpenMPIRBuilder::InsertPointTy { @@ -6799,13 +6825,17 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) { OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + llvm::OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, - EntryInfo, DefaultAttrs, RuntimeAttrs, + Info, EntryInfo, DefaultAttrs, RuntimeAttrs, /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB, - BodyGenCB, SimpleArgAccessorCB)); + BodyGenCB, SimpleArgAccessorCB, CustomMapperCB, + {}, false)); EXPECT_EQ(DL, Builder.getCurrentDebugLocation()); Builder.restoreIP(AfterIP); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 51a3cbdbb5e7f..b9d88a68410ee 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -2809,13 +2809,23 @@ getRefPtrIfDeclareTarget(mlir::Value value, } namespace { +// Append customMappers information to existing MapInfosTy +struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy { + SmallVector Mappers; + + /// Append arrays in \a CurInfo. + void append(MapInfosTy &curInfo) { + Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end()); + llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo); + } +}; // A small helper structure to contain data gathered // for map lowering and coalese it into one area and // avoiding extra computations such as searches in the // llvm module for lowered mapped variables or checking // if something is declare target (and retrieving the // value) more than neccessary. -struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy { +struct MapInfoData : MapInfosTy { llvm::SmallVector IsDeclareTarget; llvm::SmallVector IsAMember; // Identify if mapping was added by mapClause or use_device clauses. @@ -2834,7 +2844,7 @@ struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy { OriginalValue.append(CurInfo.OriginalValue.begin(), CurInfo.OriginalValue.end()); BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end()); - llvm::OpenMPIRBuilder::MapInfosTy::append(CurInfo); + MapInfosTy::append(CurInfo); } }; } // namespace @@ -2955,6 +2965,12 @@ static void collectMapDataFromMapOperands( mapData.Names.push_back(LLVM::createMappingInformation( mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder())); mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None); + if (mapOp.getMapperId()) + mapData.Mappers.push_back( + SymbolTable::lookupNearestSymbolFrom( + mapOp, mapOp.getMapperIdAttr())); + else + mapData.Mappers.push_back(nullptr); mapData.IsAMapping.push_back(true); mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp)); } @@ -2999,6 +3015,7 @@ static void collectMapDataFromMapOperands( mapData.Names.push_back(LLVM::createMappingInformation( mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder())); mapData.DevicePointers.push_back(devInfoTy); + mapData.Mappers.push_back(nullptr); mapData.IsAMapping.push_back(false); mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp)); } @@ -3164,9 +3181,8 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation, // inside of CGOpenMPRuntime.cpp static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, - llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, - llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData, - uint64_t mapDataIndex, bool isTargetParams) { + llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, + MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams) { // Map the first segment of our structure combinedInfo.Types.emplace_back( isTargetParams @@ -3174,6 +3190,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE); combinedInfo.DevicePointers.emplace_back( mapData.DevicePointers[mapDataIndex]); + combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]); combinedInfo.Names.emplace_back(LLVM::createMappingInformation( mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]); @@ -3237,6 +3254,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( combinedInfo.Types.emplace_back(mapFlag); combinedInfo.DevicePointers.emplace_back( llvm::OpenMPIRBuilder::DeviceInfoTy::None); + combinedInfo.Mappers.emplace_back(nullptr); combinedInfo.Names.emplace_back(LLVM::createMappingInformation( mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]); @@ -3270,9 +3288,9 @@ static bool checkIfPointerMap(omp::MapInfoOp mapOp) { // This function is intended to add explicit mappings of members static void processMapMembersWithParent( LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, - llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, - llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData, - uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) { + llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, + MapInfoData &mapData, uint64_t mapDataIndex, + llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) { auto parentClause = llvm::cast(mapData.MapClause[mapDataIndex]); @@ -3300,6 +3318,7 @@ static void processMapMembersWithParent( combinedInfo.Types.emplace_back(mapFlag); combinedInfo.DevicePointers.emplace_back( llvm::OpenMPIRBuilder::DeviceInfoTy::None); + combinedInfo.Mappers.emplace_back(nullptr); combinedInfo.Names.emplace_back( LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder)); combinedInfo.BasePointers.emplace_back( @@ -3322,6 +3341,7 @@ static void processMapMembersWithParent( combinedInfo.Types.emplace_back(mapFlag); combinedInfo.DevicePointers.emplace_back( mapData.DevicePointers[memberDataIdx]); + combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]); combinedInfo.Names.emplace_back( LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder)); uint64_t basePointerIndex = @@ -3341,10 +3361,9 @@ static void processMapMembersWithParent( } } -static void -processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, - llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, - bool isTargetParams, int mapDataParentIdx = -1) { +static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, + MapInfosTy &combinedInfo, bool isTargetParams, + int mapDataParentIdx = -1) { // Declare Target Mappings are excluded from being marked as // OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're // marked with OMP_MAP_PTR_AND_OBJ instead. @@ -3374,16 +3393,18 @@ processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]); combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]); + combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]); combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]); combinedInfo.Types.emplace_back(mapFlag); combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]); } -static void processMapWithMembersOf( - LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, - llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, - llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData, - uint64_t mapDataIndex, bool isTargetParams) { +static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, + llvm::IRBuilderBase &builder, + llvm::OpenMPIRBuilder &ompBuilder, + DataLayout &dl, MapInfosTy &combinedInfo, + MapInfoData &mapData, uint64_t mapDataIndex, + bool isTargetParams) { auto parentClause = llvm::cast(mapData.MapClause[mapDataIndex]); @@ -3488,8 +3509,7 @@ createAlteredByCaptureMap(MapInfoData &mapData, // Generate all map related information and fill the combinedInfo. static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, - DataLayout &dl, - llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, + DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, bool isTargetParams = false) { // We wish to modify some of the methods in which arguments are // passed based on their capture type by the target region, this can @@ -3529,6 +3549,78 @@ static void genMapInfos(llvm::IRBuilderBase &builder, } } +static llvm::Expected +emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + llvm::StringRef mapperFuncName); + +static llvm::Expected +getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + auto declMapperOp = cast(op); + std::string mapperFuncName = + moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName( + {"omp_mapper", declMapperOp.getSymName()}); + + if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName)) + return lookupFunc; + + return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation, + mapperFuncName); +} + +static llvm::Expected +emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + llvm::StringRef mapperFuncName) { + auto declMapperOp = cast(op); + auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo(); + DataLayout dl = DataLayout(declMapperOp->getParentOfType()); + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType()); + SmallVector mapVars = declMapperInfoOp.getMapVars(); + + using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy; + + // Fill up the arrays with all the mapped variables. + MapInfosTy combinedInfo; + auto genMapInfoCB = + [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI, + llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy { + builder.restoreIP(codeGenIP); + moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI); + moduleTranslation.mapBlock(&declMapperOp.getRegion().front(), + builder.GetInsertBlock()); + if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(), + /*ignoreArguments=*/true, + builder))) + return llvm::make_error(); + MapInfoData mapData; + collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl, + builder); + genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData); + + // Drop the mapping that is no longer necessary so that the same region can + // be processed multiple times. + moduleTranslation.forgetMapping(declMapperOp.getRegion()); + return combinedInfo; + }; + + auto customMapperCB = [&](unsigned i) -> llvm::Expected { + if (!combinedInfo.Mappers[i]) + return nullptr; + return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder, + moduleTranslation); + }; + + llvm::Expected newFn = ompBuilder->emitUserDefinedMapper( + genMapInfoCB, varType, mapperFuncName, customMapperCB); + if (!newFn) + return newFn.takeError(); + moduleTranslation.mapFunction(mapperFuncName, *newFn); + return *newFn; +} + static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { @@ -3640,9 +3732,8 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, builder, useDevicePtrVars, useDeviceAddrVars); // Fill up the arrays with all the mapped variables. - llvm::OpenMPIRBuilder::MapInfosTy combinedInfo; - auto genMapInfoCB = - [&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & { + MapInfosTy combinedInfo; + auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & { builder.restoreIP(codeGenIP); genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData); return combinedInfo; @@ -3685,6 +3776,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy; auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType) -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy { + builder.restoreIP(codeGenIP); assert(isa(op) && "BodyGen requested for non TargetDataOp"); auto blockArgIface = cast(op); @@ -3693,8 +3785,6 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, case BodyGenTy::Priv: // Check if any device ptr/addr info is available if (!info.DevicePtrInfoMap.empty()) { - builder.restoreIP(codeGenIP); - mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address, blockArgIface.getUseDeviceAddrBlockArgs(), useDeviceAddrVars, mapData, @@ -3724,7 +3814,6 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, case BodyGenTy::NoPriv: // If device info is available then region has already been generated if (info.DevicePtrInfoMap.empty()) { - builder.restoreIP(codeGenIP); // For device pass, if use_device_ptr(addr) mappings were present, // we need to link them here before codegen. if (ompBuilder->Config.IsTargetDevice.value_or(false)) { @@ -3745,17 +3834,28 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, return builder.saveIP(); }; + auto customMapperCB = + [&](unsigned int i) -> llvm::Expected { + if (!combinedInfo.Mappers[i]) + return nullptr; + info.HasMapper = true; + return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder, + moduleTranslation); + }; + llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); llvm::OpenMPIRBuilder::InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation); llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() { if (isa(op)) - return ompBuilder->createTargetData( - ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), - ifCond, info, genMapInfoCB, nullptr, bodyGenCB); - return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(), - builder.getInt64(deviceID), ifCond, - info, genMapInfoCB, &RTLFn); + return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(), + builder.getInt64(deviceID), ifCond, + info, genMapInfoCB, customMapperCB, + /*MapperFunc=*/nullptr, bodyGenCB, + /*DeviceAddrCB=*/nullptr); + return ompBuilder->createTargetData( + ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond, + info, genMapInfoCB, customMapperCB, &RTLFn); }(); if (failed(handleError(afterIP, *op))) @@ -4367,9 +4467,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl, builder); - llvm::OpenMPIRBuilder::MapInfosTy combinedInfos; - auto genMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) - -> llvm::OpenMPIRBuilder::MapInfosTy & { + MapInfosTy combinedInfos; + auto genMapInfoCB = + [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & { builder.restoreIP(codeGenIP); genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, true); return combinedInfos; @@ -4438,15 +4538,28 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, findAllocaInsertPoint(builder, moduleTranslation); llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); + llvm::OpenMPIRBuilder::TargetDataInfo info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); + + auto customMapperCB = + [&](unsigned int i) -> llvm::Expected { + if (!combinedInfos.Mappers[i]) + return nullptr; + info.HasMapper = true; + return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder, + moduleTranslation); + }; + llvm::Value *ifCond = nullptr; if (Value targetIfCond = targetOp.getIfExpr()) ifCond = moduleTranslation.lookupValue(targetIfCond); llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = moduleTranslation.getOpenMPBuilder()->createTarget( - ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo, + ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo, defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB, - argAccessorCB, dds, targetOp.getNowait()); + argAccessorCB, customMapperCB, dds, targetOp.getNowait()); if (failed(handleError(afterIP, opInst))) return failure(); @@ -4673,12 +4786,15 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, .Case([&](omp::TaskwaitOp op) { return convertOmpTaskwaitOp(op, builder, moduleTranslation); }) - .Case([](auto op) { // `yield` and `terminator` can be just omitted. The block structure // was created in the region that handles their parent operation. // `declare_reduction` will be used by reductions and is not // converted directly, skip it. + // `declare_mapper` and `declare_mapper.info` are handled whenever they + // are referred to through a `map` clause. // `critical.declare` is only used to declare names of critical // sections which will be used by `critical` ops and hence can be // ignored for lowering. The OpenMP IRBuilder will create unique diff --git a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir index 7f21095763a39..02b84ff66a0d3 100644 --- a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir @@ -485,3 +485,120 @@ llvm.func @_QPopenmp_target_data_update() { // CHECK: call void @__tgt_target_data_update_mapper(ptr @2, i64 -1, i32 1, ptr %[[BASEPTRS_VAL_2]], ptr %[[PTRS_VAL_2]], ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr null) // CHECK: ret void + +// ----- + +omp.declare_mapper @_QQFmy_testmy_mapper : !llvm.struct<"_QFmy_testTmy_type", (i32)> { +^bb0(%arg0: !llvm.ptr): + %0 = llvm.mlir.constant(0 : i32) : i32 + %1 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"_QFmy_testTmy_type", (i32)> + %2 = omp.map.info var_ptr(%1 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "var%data"} + %3 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.struct<"_QFmy_testTmy_type", (i32)>) map_clauses(tofrom) capture(ByRef) members(%2 : [0] : !llvm.ptr) -> !llvm.ptr {name = "var", partial_map = true} + omp.declare_mapper.info map_entries(%3, %2 : !llvm.ptr, !llvm.ptr) +} + +llvm.func @_QPopenmp_target_data_mapper() { + %0 = llvm.mlir.constant(1 : i64) : i64 + %1 = llvm.alloca %0 x !llvm.struct<"_QFmy_testTmy_type", (i32)> {bindc_name = "a"} : (i64) -> !llvm.ptr + %2 = omp.map.info var_ptr(%1 : !llvm.ptr, !llvm.struct<"_QFmy_testTmy_type", (i32)>) mapper(@_QQFmy_testmy_mapper) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "a"} + omp.target_data map_entries(%2 : !llvm.ptr) { + %3 = llvm.mlir.constant(10 : i32) : i32 + %4 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"_QFmy_testTmy_type", (i32)> + llvm.store %3, %4 : i32, !llvm.ptr + omp.terminator + } + llvm.return +} + +// CHECK: @.offload_sizes = private unnamed_addr constant [1 x i64] [i64 4] +// CHECK: @.offload_maptypes = private unnamed_addr constant [1 x i64] [i64 3] +// CHECK-LABEL: define void @_QPopenmp_target_data_mapper +// CHECK: %[[VAL_0:.*]] = alloca [1 x ptr], align 8 +// CHECK: %[[VAL_1:.*]] = alloca [1 x ptr], align 8 +// CHECK: %[[VAL_2:.*]] = alloca [1 x ptr], align 8 +// CHECK: %[[VAL_3:.*]] = alloca %[[VAL_4:.*]], i64 1, align 8 +// CHECK: br label %[[VAL_5:.*]] +// CHECK: entry: ; preds = %[[VAL_6:.*]] +// CHECK: %[[VAL_7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_0]], i32 0, i32 0 +// CHECK: store ptr %[[VAL_3]], ptr %[[VAL_7]], align 8 +// CHECK: %[[VAL_8:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_1]], i32 0, i32 0 +// CHECK: store ptr %[[VAL_3]], ptr %[[VAL_8]], align 8 +// CHECK: %[[VAL_9:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_2]], i64 0, i64 0 +// CHECK: store ptr @.omp_mapper._QQFmy_testmy_mapper, ptr %[[VAL_9]], align 8 +// CHECK: %[[VAL_10:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_0]], i32 0, i32 0 +// CHECK: %[[VAL_11:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_1]], i32 0, i32 0 +// CHECK: call void @__tgt_target_data_begin_mapper(ptr @4, i64 -1, i32 1, ptr %[[VAL_10]], ptr %[[VAL_11]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr %[[VAL_2]]) +// CHECK: %[[VAL_12:.*]] = getelementptr %[[VAL_4]], ptr %[[VAL_3]], i32 0, i32 0 +// CHECK: store i32 10, ptr %[[VAL_12]], align 4 +// CHECK: %[[VAL_13:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_0]], i32 0, i32 0 +// CHECK: %[[VAL_14:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_1]], i32 0, i32 0 +// CHECK: call void @__tgt_target_data_end_mapper(ptr @4, i64 -1, i32 1, ptr %[[VAL_13]], ptr %[[VAL_14]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr %[[VAL_2]]) +// CHECK: ret void + +// CHECK-LABEL: define internal void @.omp_mapper._QQFmy_testmy_mapper +// CHECK: entry: +// CHECK: %[[VAL_15:.*]] = udiv exact i64 %[[VAL_16:.*]], 4 +// CHECK: %[[VAL_17:.*]] = getelementptr %[[VAL_18:.*]], ptr %[[VAL_19:.*]], i64 %[[VAL_15]] +// CHECK: %[[VAL_20:.*]] = icmp sgt i64 %[[VAL_15]], 1 +// CHECK: %[[VAL_21:.*]] = and i64 %[[VAL_22:.*]], 8 +// CHECK: %[[VAL_23:.*]] = icmp ne ptr %[[VAL_24:.*]], %[[VAL_19]] +// CHECK: %[[VAL_25:.*]] = and i64 %[[VAL_22]], 16 +// CHECK: %[[VAL_26:.*]] = icmp ne i64 %[[VAL_25]], 0 +// CHECK: %[[VAL_27:.*]] = and i1 %[[VAL_23]], %[[VAL_26]] +// CHECK: %[[VAL_28:.*]] = or i1 %[[VAL_20]], %[[VAL_27]] +// CHECK: %[[VAL_29:.*]] = icmp eq i64 %[[VAL_21]], 0 +// CHECK: %[[VAL_30:.*]] = and i1 %[[VAL_28]], %[[VAL_29]] +// CHECK: br i1 %[[VAL_30]], label %[[VAL_31:.*]], label %[[VAL_32:.*]] +// CHECK: .omp.array..init: ; preds = %[[VAL_33:.*]] +// CHECK: %[[VAL_34:.*]] = mul nuw i64 %[[VAL_15]], 4 +// CHECK: %[[VAL_35:.*]] = and i64 %[[VAL_22]], -4 +// CHECK: %[[VAL_36:.*]] = or i64 %[[VAL_35]], 512 +// CHECK: call void @__tgt_push_mapper_component(ptr %[[VAL_37:.*]], ptr %[[VAL_24]], ptr %[[VAL_19]], i64 %[[VAL_34]], i64 %[[VAL_36]], ptr %[[VAL_38:.*]]) +// CHECK: br label %[[VAL_32]] +// CHECK: omp.arraymap.head: ; preds = %[[VAL_31]], %[[VAL_33]] +// CHECK: %[[VAL_39:.*]] = icmp eq ptr %[[VAL_19]], %[[VAL_17]] +// CHECK: br i1 %[[VAL_39]], label %[[VAL_40:.*]], label %[[VAL_41:.*]] +// CHECK: omp.arraymap.body: ; preds = %[[VAL_42:.*]], %[[VAL_32]] +// CHECK: %[[VAL_43:.*]] = phi ptr [ %[[VAL_19]], %[[VAL_32]] ], [ %[[VAL_44:.*]], %[[VAL_42]] ] +// CHECK: %[[VAL_45:.*]] = getelementptr %[[VAL_18]], ptr %[[VAL_43]], i32 0, i32 0 +// CHECK: %[[VAL_46:.*]] = call i64 @__tgt_mapper_num_components(ptr %[[VAL_37]]) +// CHECK: %[[VAL_47:.*]] = shl i64 %[[VAL_46]], 48 +// CHECK: %[[VAL_48:.*]] = add nuw i64 3, %[[VAL_47]] +// CHECK: %[[VAL_49:.*]] = and i64 %[[VAL_22]], 3 +// CHECK: %[[VAL_50:.*]] = icmp eq i64 %[[VAL_49]], 0 +// CHECK: br i1 %[[VAL_50]], label %[[VAL_51:.*]], label %[[VAL_52:.*]] +// CHECK: omp.type.alloc: ; preds = %[[VAL_41]] +// CHECK: %[[VAL_53:.*]] = and i64 %[[VAL_48]], -4 +// CHECK: br label %[[VAL_42]] +// CHECK: omp.type.alloc.else: ; preds = %[[VAL_41]] +// CHECK: %[[VAL_54:.*]] = icmp eq i64 %[[VAL_49]], 1 +// CHECK: br i1 %[[VAL_54]], label %[[VAL_55:.*]], label %[[VAL_56:.*]] +// CHECK: omp.type.to: ; preds = %[[VAL_52]] +// CHECK: %[[VAL_57:.*]] = and i64 %[[VAL_48]], -3 +// CHECK: br label %[[VAL_42]] +// CHECK: omp.type.to.else: ; preds = %[[VAL_52]] +// CHECK: %[[VAL_58:.*]] = icmp eq i64 %[[VAL_49]], 2 +// CHECK: br i1 %[[VAL_58]], label %[[VAL_59:.*]], label %[[VAL_42]] +// CHECK: omp.type.from: ; preds = %[[VAL_56]] +// CHECK: %[[VAL_60:.*]] = and i64 %[[VAL_48]], -2 +// CHECK: br label %[[VAL_42]] +// CHECK: omp.type.end: ; preds = %[[VAL_59]], %[[VAL_56]], %[[VAL_55]], %[[VAL_51]] +// CHECK: %[[VAL_61:.*]] = phi i64 [ %[[VAL_53]], %[[VAL_51]] ], [ %[[VAL_57]], %[[VAL_55]] ], [ %[[VAL_60]], %[[VAL_59]] ], [ %[[VAL_48]], %[[VAL_56]] ] +// CHECK: call void @__tgt_push_mapper_component(ptr %[[VAL_37]], ptr %[[VAL_43]], ptr %[[VAL_45]], i64 4, i64 %[[VAL_61]], ptr @2) +// CHECK: %[[VAL_44]] = getelementptr %[[VAL_18]], ptr %[[VAL_43]], i32 1 +// CHECK: %[[VAL_62:.*]] = icmp eq ptr %[[VAL_44]], %[[VAL_17]] +// CHECK: br i1 %[[VAL_62]], label %[[VAL_63:.*]], label %[[VAL_41]] +// CHECK: omp.arraymap.exit: ; preds = %[[VAL_42]] +// CHECK: %[[VAL_64:.*]] = icmp sgt i64 %[[VAL_15]], 1 +// CHECK: %[[VAL_65:.*]] = and i64 %[[VAL_22]], 8 +// CHECK: %[[VAL_66:.*]] = icmp ne i64 %[[VAL_65]], 0 +// CHECK: %[[VAL_67:.*]] = and i1 %[[VAL_64]], %[[VAL_66]] +// CHECK: br i1 %[[VAL_67]], label %[[VAL_68:.*]], label %[[VAL_40]] +// CHECK: .omp.array..del: ; preds = %[[VAL_63]] +// CHECK: %[[VAL_69:.*]] = mul nuw i64 %[[VAL_15]], 4 +// CHECK: %[[VAL_70:.*]] = and i64 %[[VAL_22]], -4 +// CHECK: %[[VAL_71:.*]] = or i64 %[[VAL_70]], 512 +// CHECK: call void @__tgt_push_mapper_component(ptr %[[VAL_37]], ptr %[[VAL_24]], ptr %[[VAL_19]], i64 %[[VAL_69]], i64 %[[VAL_71]], ptr %[[VAL_38]]) +// CHECK: br label %[[VAL_40]] +// CHECK: omp.done: ; preds = %[[VAL_68]], %[[VAL_63]], %[[VAL_32]] +// CHECK: ret void diff --git a/offload/test/offloading/fortran/target-custom-mapper.f90 b/offload/test/offloading/fortran/target-custom-mapper.f90 new file mode 100644 index 0000000000000..9c527861c87b5 --- /dev/null +++ b/offload/test/offloading/fortran/target-custom-mapper.f90 @@ -0,0 +1,53 @@ +! Offloading test checking lowering of arrays with dynamic extents. +! REQUIRES: flang, amdgpu + +! RUN: %libomptarget-compile-fortran-run-and-check-generic + +program test_openmp_mapper + implicit none + integer, parameter :: n = 1024 + type :: mytype + integer :: data(n) + end type mytype + + type :: mytype2 + type(mytype) :: my_data + end type mytype2 + + ! Declare custom mappers for the derived type `mytype` + !$omp declare mapper(my_mapper1 : mytype :: t) map(to: t%data(1 : n)) + + ! Declare custom mappers for the derived type `mytype2` + !$omp declare mapper(my_mapper2 : mytype2 :: t) map(mapper(my_mapper1): t%my_data) + + type(mytype2) :: obj + integer :: i, sum_host, sum_device + + ! Initialize the host data + do i = 1, n + obj%my_data%data(i) = 1 + end do + + ! Compute the sum on the host for verification + sum_host = sum(obj%my_data%data) + + ! Offload computation to the device using the named mapper `my_mapper2` + sum_device = 0 + !$omp target map(tofrom: sum_device) map(mapper(my_mapper2) : obj) + do i = 1, n + sum_device = sum_device + obj%my_data%data(i) + end do + !$omp end target + + ! Check results + print *, "Sum on host: ", sum_host + print *, "Sum on device: ", sum_device + + if (sum_device == sum_host) then + print *, "Test passed!" + else + print *, "Test failed!" + end if + end program test_openmp_mapper + +! CHECK: Test passed!