Skip to content

Commit 785a5b4

Browse files
authored
[MLIR][OpenMP] Add LLVM translation support for OpenMP UserDefinedMappers (#124746)
This patch adds OpenMPToLLVMIRTranslation support for the OpenMP Declare Mapper directive. Since both MLIR and Clang now support custom mappers, I've changed the respective function params to no longer be optional as well. Depends on #121005
1 parent d6ab12c commit 785a5b4

File tree

7 files changed

+478
-138
lines changed

7 files changed

+478
-138
lines changed

clang/lib/CodeGen/CGOpenMPRuntime.cpp

+14-13
Original file line numberDiff line numberDiff line change
@@ -8879,17 +8879,17 @@ static void emitOffloadingArraysAndArgs(
88798879
};
88808880

88818881
auto CustomMapperCB = [&](unsigned int I) {
8882-
llvm::Value *MFunc = nullptr;
8882+
llvm::Function *MFunc = nullptr;
88838883
if (CombinedInfo.Mappers[I]) {
88848884
Info.HasMapper = true;
88858885
MFunc = CGM.getOpenMPRuntime().getOrCreateUserDefinedMapperFunc(
88868886
cast<OMPDeclareMapperDecl>(CombinedInfo.Mappers[I]));
88878887
}
88888888
return MFunc;
88898889
};
8890-
OMPBuilder.emitOffloadingArraysAndArgs(
8891-
AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, IsNonContiguous,
8892-
ForEndCall, DeviceAddrCB, CustomMapperCB);
8890+
cantFail(OMPBuilder.emitOffloadingArraysAndArgs(
8891+
AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, CustomMapperCB,
8892+
IsNonContiguous, ForEndCall, DeviceAddrCB));
88938893
}
88948894

88958895
/// Check for inner distribute directive.
@@ -9082,24 +9082,24 @@ void CGOpenMPRuntime::emitUserDefinedMapper(const OMPDeclareMapperDecl *D,
90829082
return CombinedInfo;
90839083
};
90849084

9085-
auto CustomMapperCB = [&](unsigned I, llvm::Function **MapperFunc) {
9085+
auto CustomMapperCB = [&](unsigned I) {
9086+
llvm::Function *MapperFunc = nullptr;
90869087
if (CombinedInfo.Mappers[I]) {
90879088
// Call the corresponding mapper function.
9088-
*MapperFunc = getOrCreateUserDefinedMapperFunc(
9089+
MapperFunc = getOrCreateUserDefinedMapperFunc(
90899090
cast<OMPDeclareMapperDecl>(CombinedInfo.Mappers[I]));
9090-
assert(*MapperFunc && "Expect a valid mapper function is available.");
9091-
return true;
9091+
assert(MapperFunc && "Expect a valid mapper function is available.");
90929092
}
9093-
return false;
9093+
return MapperFunc;
90949094
};
90959095

90969096
SmallString<64> TyStr;
90979097
llvm::raw_svector_ostream Out(TyStr);
90989098
CGM.getCXXABI().getMangleContext().mangleCanonicalTypeName(Ty, Out);
90999099
std::string Name = getName({"omp_mapper", TyStr, D->getName()});
91009100

9101-
auto *NewFn = OMPBuilder.emitUserDefinedMapper(PrivatizeAndGenMapInfoCB,
9102-
ElemTy, Name, CustomMapperCB);
9101+
llvm::Function *NewFn = cantFail(OMPBuilder.emitUserDefinedMapper(
9102+
PrivatizeAndGenMapInfoCB, ElemTy, Name, CustomMapperCB));
91039103
UDMMap.try_emplace(D, NewFn);
91049104
if (CGF)
91059105
FunctionUDMMap[CGF->CurFn].push_back(D);
@@ -10073,7 +10073,7 @@ void CGOpenMPRuntime::emitTargetDataCalls(
1007310073
};
1007410074

1007510075
auto CustomMapperCB = [&](unsigned int I) {
10076-
llvm::Value *MFunc = nullptr;
10076+
llvm::Function *MFunc = nullptr;
1007710077
if (CombinedInfo.Mappers[I]) {
1007810078
Info.HasMapper = true;
1007910079
MFunc = CGF.CGM.getOpenMPRuntime().getOrCreateUserDefinedMapperFunc(
@@ -10093,7 +10093,8 @@ void CGOpenMPRuntime::emitTargetDataCalls(
1009310093
llvm::OpenMPIRBuilder::InsertPointTy AfterIP =
1009410094
cantFail(OMPBuilder.createTargetData(
1009510095
OmpLoc, AllocaIP, CodeGenIP, DeviceID, IfCondVal, Info, GenMapInfoCB,
10096-
/*MapperFunc=*/nullptr, BodyCB, DeviceAddrCB, CustomMapperCB, RTLoc));
10096+
CustomMapperCB,
10097+
/*MapperFunc=*/nullptr, BodyCB, DeviceAddrCB, RTLoc));
1009710098
CGF.Builder.restoreIP(AfterIP);
1009810099
}
1009910100

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

+26-16
Original file line numberDiff line numberDiff line change
@@ -2399,6 +2399,7 @@ class OpenMPIRBuilder {
23992399
CurInfo.NonContigInfo.Strides.end());
24002400
}
24012401
};
2402+
using MapInfosOrErrorTy = Expected<MapInfosTy &>;
24022403

24032404
/// Callback function type for functions emitting the host fallback code that
24042405
/// is executed when the kernel launch fails. It takes an insertion point as
@@ -2407,6 +2408,11 @@ class OpenMPIRBuilder {
24072408
using EmitFallbackCallbackTy =
24082409
function_ref<InsertPointOrErrorTy(InsertPointTy)>;
24092410

2411+
// Callback function type for emitting and fetching user defined custom
2412+
// mappers.
2413+
using CustomMapperCallbackTy =
2414+
function_ref<Expected<Function *>(unsigned int)>;
2415+
24102416
/// Generate a target region entry call and host fallback call.
24112417
///
24122418
/// \param Loc The location at which the request originated and is fulfilled.
@@ -2473,24 +2479,24 @@ class OpenMPIRBuilder {
24732479
/// return nullptr by reference. Accepts a reference to a MapInfosTy object
24742480
/// that contains information generated for mappable clauses,
24752481
/// including base pointers, pointers, sizes, map types, user-defined mappers.
2476-
void emitOffloadingArrays(
2482+
Error emitOffloadingArrays(
24772483
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
2478-
TargetDataInfo &Info, bool IsNonContiguous = false,
2479-
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
2480-
function_ref<Value *(unsigned int)> CustomMapperCB = nullptr);
2484+
TargetDataInfo &Info, CustomMapperCallbackTy CustomMapperCB,
2485+
bool IsNonContiguous = false,
2486+
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr);
24812487

24822488
/// Allocates memory for and populates the arrays required for offloading
24832489
/// (offload_{baseptrs|ptrs|mappers|sizes|maptypes|mapnames}). Then, it
24842490
/// emits their base addresses as arguments to be passed to the runtime
24852491
/// library. In essence, this function is a combination of
24862492
/// emitOffloadingArrays and emitOffloadingArraysArgument and should arguably
24872493
/// be preferred by clients of OpenMPIRBuilder.
2488-
void emitOffloadingArraysAndArgs(
2494+
Error emitOffloadingArraysAndArgs(
24892495
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
24902496
TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
2491-
bool IsNonContiguous = false, bool ForEndCall = false,
2492-
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
2493-
function_ref<Value *(unsigned int)> CustomMapperCB = nullptr);
2497+
CustomMapperCallbackTy CustomMapperCB, bool IsNonContiguous = false,
2498+
bool ForEndCall = false,
2499+
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr);
24942500

24952501
/// Creates offloading entry for the provided entry ID \a ID, address \a
24962502
/// Addr, size \a Size, and flags \a Flags.
@@ -2950,12 +2956,12 @@ class OpenMPIRBuilder {
29502956
/// \param FuncName Optional param to specify mapper function name.
29512957
/// \param CustomMapperCB Optional callback to generate code related to
29522958
/// custom mappers.
2953-
Function *emitUserDefinedMapper(
2954-
function_ref<MapInfosTy &(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
2955-
llvm::Value *BeginArg)>
2959+
Expected<Function *> emitUserDefinedMapper(
2960+
function_ref<MapInfosOrErrorTy(
2961+
InsertPointTy CodeGenIP, llvm::Value *PtrPHI, llvm::Value *BeginArg)>
29562962
PrivAndGenMapInfoCB,
29572963
llvm::Type *ElemTy, StringRef FuncName,
2958-
function_ref<bool(unsigned int, Function **)> CustomMapperCB = nullptr);
2964+
CustomMapperCallbackTy CustomMapperCB);
29592965

29602966
/// Generator for '#omp target data'
29612967
///
@@ -2969,21 +2975,21 @@ class OpenMPIRBuilder {
29692975
/// \param IfCond Value which corresponds to the if clause condition.
29702976
/// \param Info Stores all information realted to the Target Data directive.
29712977
/// \param GenMapInfoCB Callback that populates the MapInfos and returns.
2978+
/// \param CustomMapperCB Callback to generate code related to
2979+
/// custom mappers.
29722980
/// \param BodyGenCB Optional Callback to generate the region code.
29732981
/// \param DeviceAddrCB Optional callback to generate code related to
29742982
/// use_device_ptr and use_device_addr.
2975-
/// \param CustomMapperCB Optional callback to generate code related to
2976-
/// custom mappers.
29772983
InsertPointOrErrorTy createTargetData(
29782984
const LocationDescription &Loc, InsertPointTy AllocaIP,
29792985
InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
29802986
TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
2987+
CustomMapperCallbackTy CustomMapperCB,
29812988
omp::RuntimeFunction *MapperFunc = nullptr,
29822989
function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
29832990
BodyGenTy BodyGenType)>
29842991
BodyGenCB = nullptr,
29852992
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
2986-
function_ref<Value *(unsigned int)> CustomMapperCB = nullptr,
29872993
Value *SrcLocInfo = nullptr);
29882994

29892995
using TargetBodyGenCallbackTy = function_ref<InsertPointOrErrorTy(
@@ -2999,6 +3005,7 @@ class OpenMPIRBuilder {
29993005
/// \param IsOffloadEntry whether it is an offload entry.
30003006
/// \param CodeGenIP The insertion point where the call to the outlined
30013007
/// function should be emitted.
3008+
/// \param Info Stores all information realted to the Target directive.
30023009
/// \param EntryInfo The entry information about the function.
30033010
/// \param DefaultAttrs Structure containing the default attributes, including
30043011
/// numbers of threads and teams to launch the kernel with.
@@ -3010,20 +3017,23 @@ class OpenMPIRBuilder {
30103017
/// \param BodyGenCB Callback that will generate the region code.
30113018
/// \param ArgAccessorFuncCB Callback that will generate accessors
30123019
/// instructions for passed in target arguments where neccessary
3020+
/// \param CustomMapperCB Callback to generate code related to
3021+
/// custom mappers.
30133022
/// \param Dependencies A vector of DependData objects that carry
30143023
/// dependency information as passed in the depend clause
30153024
/// \param HasNowait Whether the target construct has a `nowait` clause or
30163025
/// not.
30173026
InsertPointOrErrorTy createTarget(
30183027
const LocationDescription &Loc, bool IsOffloadEntry,
30193028
OpenMPIRBuilder::InsertPointTy AllocaIP,
3020-
OpenMPIRBuilder::InsertPointTy CodeGenIP,
3029+
OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetDataInfo &Info,
30213030
TargetRegionEntryInfo &EntryInfo,
30223031
const TargetKernelDefaultAttrs &DefaultAttrs,
30233032
const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
30243033
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
30253034
TargetBodyGenCallbackTy BodyGenCB,
30263035
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
3036+
CustomMapperCallbackTy CustomMapperCB,
30273037
SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
30283038

30293039
/// Returns __kmpc_for_static_init_* runtime function for the specified

0 commit comments

Comments
 (0)