Skip to content

Commit 2b5ef6c

Browse files
TIFitisronlieb
authored andcommitted
[MLIR][OpenMP] Add LLVM translation support for OpenMP UserDefinedMappers (llvm#124746)
This patch adds OpenMPToLLVMIRTranslation support for the OpenMP Declare Mapper directive. Since both MLIR and Clang now support custom mappers, I've changed the respective function params to no longer be optional as well. Depends on llvm#121005
1 parent 8ea3392 commit 2b5ef6c

File tree

6 files changed

+434
-141
lines changed

6 files changed

+434
-141
lines changed

clang/lib/CodeGen/CGOpenMPRuntime.cpp

+14-13
Original file line numberDiff line numberDiff line change
@@ -8921,17 +8921,17 @@ static void emitOffloadingArraysAndArgs(
89218921
};
89228922

89238923
auto CustomMapperCB = [&](unsigned int I) {
8924-
llvm::Value *MFunc = nullptr;
8924+
llvm::Function *MFunc = nullptr;
89258925
if (CombinedInfo.Mappers[I]) {
89268926
Info.HasMapper = true;
89278927
MFunc = CGM.getOpenMPRuntime().getOrCreateUserDefinedMapperFunc(
89288928
cast<OMPDeclareMapperDecl>(CombinedInfo.Mappers[I]));
89298929
}
89308930
return MFunc;
89318931
};
8932-
OMPBuilder.emitOffloadingArraysAndArgs(
8933-
AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, IsNonContiguous,
8934-
ForEndCall, DeviceAddrCB, CustomMapperCB);
8932+
cantFail(OMPBuilder.emitOffloadingArraysAndArgs(
8933+
AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, CustomMapperCB,
8934+
IsNonContiguous, ForEndCall, DeviceAddrCB));
89358935
}
89368936

89378937
/// Check for inner distribute directive.
@@ -9124,24 +9124,24 @@ void CGOpenMPRuntime::emitUserDefinedMapper(const OMPDeclareMapperDecl *D,
91249124
return CombinedInfo;
91259125
};
91269126

9127-
auto CustomMapperCB = [&](unsigned I, llvm::Function **MapperFunc) {
9127+
auto CustomMapperCB = [&](unsigned I) {
9128+
llvm::Function *MapperFunc = nullptr;
91289129
if (CombinedInfo.Mappers[I]) {
91299130
// Call the corresponding mapper function.
9130-
*MapperFunc = getOrCreateUserDefinedMapperFunc(
9131+
MapperFunc = getOrCreateUserDefinedMapperFunc(
91319132
cast<OMPDeclareMapperDecl>(CombinedInfo.Mappers[I]));
9132-
assert(*MapperFunc && "Expect a valid mapper function is available.");
9133-
return true;
9133+
assert(MapperFunc && "Expect a valid mapper function is available.");
91349134
}
9135-
return false;
9135+
return MapperFunc;
91369136
};
91379137

91389138
SmallString<64> TyStr;
91399139
llvm::raw_svector_ostream Out(TyStr);
91409140
CGM.getCXXABI().getMangleContext().mangleCanonicalTypeName(Ty, Out);
91419141
std::string Name = getName({"omp_mapper", TyStr, D->getName()});
91429142

9143-
auto *NewFn = OMPBuilder.emitUserDefinedMapper(PrivatizeAndGenMapInfoCB,
9144-
ElemTy, Name, CustomMapperCB);
9143+
llvm::Function *NewFn = cantFail(OMPBuilder.emitUserDefinedMapper(
9144+
PrivatizeAndGenMapInfoCB, ElemTy, Name, CustomMapperCB));
91459145
UDMMap.try_emplace(D, NewFn);
91469146
if (CGF)
91479147
FunctionUDMMap[CGF->CurFn].push_back(D);
@@ -10493,7 +10493,7 @@ void CGOpenMPRuntime::emitTargetDataCalls(
1049310493
};
1049410494

1049510495
auto CustomMapperCB = [&](unsigned int I) {
10496-
llvm::Value *MFunc = nullptr;
10496+
llvm::Function *MFunc = nullptr;
1049710497
if (CombinedInfo.Mappers[I]) {
1049810498
Info.HasMapper = true;
1049910499
MFunc = CGF.CGM.getOpenMPRuntime().getOrCreateUserDefinedMapperFunc(
@@ -10513,7 +10513,8 @@ void CGOpenMPRuntime::emitTargetDataCalls(
1051310513
llvm::OpenMPIRBuilder::InsertPointTy AfterIP =
1051410514
cantFail(OMPBuilder.createTargetData(
1051510515
OmpLoc, AllocaIP, CodeGenIP, DeviceID, IfCondVal, Info, GenMapInfoCB,
10516-
/*MapperFunc=*/nullptr, BodyCB, DeviceAddrCB, CustomMapperCB, RTLoc));
10516+
CustomMapperCB,
10517+
/*MapperFunc=*/nullptr, BodyCB, DeviceAddrCB, RTLoc));
1051710518
CGF.Builder.restoreIP(AfterIP);
1051810519
}
1051910520

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

+26-16
Original file line numberDiff line numberDiff line change
@@ -2434,6 +2434,7 @@ class OpenMPIRBuilder {
24342434
CurInfo.NonContigInfo.Strides.end());
24352435
}
24362436
};
2437+
using MapInfosOrErrorTy = Expected<MapInfosTy &>;
24372438

24382439
/// Callback function type for functions emitting the host fallback code that
24392440
/// is executed when the kernel launch fails. It takes an insertion point as
@@ -2442,6 +2443,11 @@ class OpenMPIRBuilder {
24422443
using EmitFallbackCallbackTy =
24432444
function_ref<InsertPointOrErrorTy(InsertPointTy)>;
24442445

2446+
// Callback function type for emitting and fetching user defined custom
2447+
// mappers.
2448+
using CustomMapperCallbackTy =
2449+
function_ref<Expected<Function *>(unsigned int)>;
2450+
24452451
/// Generate a target region entry call and host fallback call.
24462452
///
24472453
/// \param Loc The location at which the request originated and is fulfilled.
@@ -2508,24 +2514,24 @@ class OpenMPIRBuilder {
25082514
/// return nullptr by reference. Accepts a reference to a MapInfosTy object
25092515
/// that contains information generated for mappable clauses,
25102516
/// including base pointers, pointers, sizes, map types, user-defined mappers.
2511-
void emitOffloadingArrays(
2517+
Error emitOffloadingArrays(
25122518
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
2513-
TargetDataInfo &Info, bool IsNonContiguous = false,
2514-
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
2515-
function_ref<Value *(unsigned int)> CustomMapperCB = nullptr);
2519+
TargetDataInfo &Info, CustomMapperCallbackTy CustomMapperCB,
2520+
bool IsNonContiguous = false,
2521+
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr);
25162522

25172523
/// Allocates memory for and populates the arrays required for offloading
25182524
/// (offload_{baseptrs|ptrs|mappers|sizes|maptypes|mapnames}). Then, it
25192525
/// emits their base addresses as arguments to be passed to the runtime
25202526
/// library. In essence, this function is a combination of
25212527
/// emitOffloadingArrays and emitOffloadingArraysArgument and should arguably
25222528
/// be preferred by clients of OpenMPIRBuilder.
2523-
void emitOffloadingArraysAndArgs(
2529+
Error emitOffloadingArraysAndArgs(
25242530
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
25252531
TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
2526-
bool IsNonContiguous = false, bool ForEndCall = false,
2527-
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
2528-
function_ref<Value *(unsigned int)> CustomMapperCB = nullptr);
2532+
CustomMapperCallbackTy CustomMapperCB, bool IsNonContiguous = false,
2533+
bool ForEndCall = false,
2534+
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr);
25292535

25302536
/// Creates offloading entry for the provided entry ID \a ID, address \a
25312537
/// Addr, size \a Size, and flags \a Flags.
@@ -2993,12 +2999,12 @@ class OpenMPIRBuilder {
29932999
/// \param FuncName Optional param to specify mapper function name.
29943000
/// \param CustomMapperCB Optional callback to generate code related to
29953001
/// custom mappers.
2996-
Function *emitUserDefinedMapper(
2997-
function_ref<MapInfosTy &(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
2998-
llvm::Value *BeginArg)>
3002+
Expected<Function *> emitUserDefinedMapper(
3003+
function_ref<MapInfosOrErrorTy(
3004+
InsertPointTy CodeGenIP, llvm::Value *PtrPHI, llvm::Value *BeginArg)>
29993005
PrivAndGenMapInfoCB,
30003006
llvm::Type *ElemTy, StringRef FuncName,
3001-
function_ref<bool(unsigned int, Function **)> CustomMapperCB = nullptr);
3007+
CustomMapperCallbackTy CustomMapperCB);
30023008

30033009
/// Generator for '#omp target data'
30043010
///
@@ -3012,21 +3018,21 @@ class OpenMPIRBuilder {
30123018
/// \param IfCond Value which corresponds to the if clause condition.
30133019
/// \param Info Stores all information realted to the Target Data directive.
30143020
/// \param GenMapInfoCB Callback that populates the MapInfos and returns.
3021+
/// \param CustomMapperCB Callback to generate code related to
3022+
/// custom mappers.
30153023
/// \param BodyGenCB Optional Callback to generate the region code.
30163024
/// \param DeviceAddrCB Optional callback to generate code related to
30173025
/// use_device_ptr and use_device_addr.
3018-
/// \param CustomMapperCB Optional callback to generate code related to
3019-
/// custom mappers.
30203026
InsertPointOrErrorTy createTargetData(
30213027
const LocationDescription &Loc, InsertPointTy AllocaIP,
30223028
InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
30233029
TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
3030+
CustomMapperCallbackTy CustomMapperCB,
30243031
omp::RuntimeFunction *MapperFunc = nullptr,
30253032
function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
30263033
BodyGenTy BodyGenType)>
30273034
BodyGenCB = nullptr,
30283035
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
3029-
function_ref<Value *(unsigned int)> CustomMapperCB = nullptr,
30303036
Value *SrcLocInfo = nullptr);
30313037

30323038
using TargetBodyGenCallbackTy = function_ref<InsertPointOrErrorTy(
@@ -3042,6 +3048,7 @@ class OpenMPIRBuilder {
30423048
/// \param IsOffloadEntry whether it is an offload entry.
30433049
/// \param CodeGenIP The insertion point where the call to the outlined
30443050
/// function should be emitted.
3051+
/// \param Info Stores all information realted to the Target directive.
30453052
/// \param EntryInfo The entry information about the function.
30463053
/// \param DefaultAttrs Structure containing the default attributes, including
30473054
/// numbers of threads and teams to launch the kernel with.
@@ -3053,20 +3060,23 @@ class OpenMPIRBuilder {
30533060
/// \param BodyGenCB Callback that will generate the region code.
30543061
/// \param ArgAccessorFuncCB Callback that will generate accessors
30553062
/// instructions for passed in target arguments where neccessary
3063+
/// \param CustomMapperCB Callback to generate code related to
3064+
/// custom mappers.
30563065
/// \param Dependencies A vector of DependData objects that carry
30573066
/// dependency information as passed in the depend clause
30583067
/// \param HasNowait Whether the target construct has a `nowait` clause or
30593068
/// not.
30603069
InsertPointOrErrorTy createTarget(
30613070
const LocationDescription &Loc, bool IsOffloadEntry,
30623071
OpenMPIRBuilder::InsertPointTy AllocaIP,
3063-
OpenMPIRBuilder::InsertPointTy CodeGenIP,
3072+
OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetDataInfo &Info,
30643073
TargetRegionEntryInfo &EntryInfo,
30653074
const TargetKernelDefaultAttrs &DefaultAttrs,
30663075
const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
30673076
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
30683077
TargetBodyGenCallbackTy BodyGenCB,
30693078
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
3079+
CustomMapperCallbackTy CustomMapperCB,
30703080
SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
30713081

30723082
/// Returns __kmpc_for_static_init_* runtime function for the specified

0 commit comments

Comments
 (0)