Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][OpenMP] Add LLVM translation support for OpenMP UserDefinedMappers #124746

Merged
merged 8 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions clang/lib/CodeGen/CGOpenMPRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8879,17 +8879,17 @@ static void emitOffloadingArraysAndArgs(
};

auto CustomMapperCB = [&](unsigned int I) {
llvm::Value *MFunc = nullptr;
llvm::Function *MFunc = nullptr;
if (CombinedInfo.Mappers[I]) {
Info.HasMapper = true;
MFunc = CGM.getOpenMPRuntime().getOrCreateUserDefinedMapperFunc(
cast<OMPDeclareMapperDecl>(CombinedInfo.Mappers[I]));
}
return MFunc;
};
OMPBuilder.emitOffloadingArraysAndArgs(
AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, IsNonContiguous,
ForEndCall, DeviceAddrCB, CustomMapperCB);
cantFail(OMPBuilder.emitOffloadingArraysAndArgs(
AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, CustomMapperCB,
IsNonContiguous, ForEndCall, DeviceAddrCB));
}

/// Check for inner distribute directive.
Expand Down Expand Up @@ -9082,24 +9082,24 @@ void CGOpenMPRuntime::emitUserDefinedMapper(const OMPDeclareMapperDecl *D,
return CombinedInfo;
};

auto CustomMapperCB = [&](unsigned I, llvm::Function **MapperFunc) {
auto CustomMapperCB = [&](unsigned I) {
llvm::Function *MapperFunc = nullptr;
if (CombinedInfo.Mappers[I]) {
// Call the corresponding mapper function.
*MapperFunc = getOrCreateUserDefinedMapperFunc(
MapperFunc = getOrCreateUserDefinedMapperFunc(
cast<OMPDeclareMapperDecl>(CombinedInfo.Mappers[I]));
assert(*MapperFunc && "Expect a valid mapper function is available.");
return true;
assert(MapperFunc && "Expect a valid mapper function is available.");
}
return false;
return MapperFunc;
};

SmallString<64> TyStr;
llvm::raw_svector_ostream Out(TyStr);
CGM.getCXXABI().getMangleContext().mangleCanonicalTypeName(Ty, Out);
std::string Name = getName({"omp_mapper", TyStr, D->getName()});

auto *NewFn = OMPBuilder.emitUserDefinedMapper(PrivatizeAndGenMapInfoCB,
ElemTy, Name, CustomMapperCB);
llvm::Function *NewFn = cantFail(OMPBuilder.emitUserDefinedMapper(
PrivatizeAndGenMapInfoCB, ElemTy, Name, CustomMapperCB));
UDMMap.try_emplace(D, NewFn);
if (CGF)
FunctionUDMMap[CGF->CurFn].push_back(D);
Expand Down Expand Up @@ -10073,7 +10073,7 @@ void CGOpenMPRuntime::emitTargetDataCalls(
};

auto CustomMapperCB = [&](unsigned int I) {
llvm::Value *MFunc = nullptr;
llvm::Function *MFunc = nullptr;
if (CombinedInfo.Mappers[I]) {
Info.HasMapper = true;
MFunc = CGF.CGM.getOpenMPRuntime().getOrCreateUserDefinedMapperFunc(
Expand All @@ -10093,7 +10093,8 @@ void CGOpenMPRuntime::emitTargetDataCalls(
llvm::OpenMPIRBuilder::InsertPointTy AfterIP =
cantFail(OMPBuilder.createTargetData(
OmpLoc, AllocaIP, CodeGenIP, DeviceID, IfCondVal, Info, GenMapInfoCB,
/*MapperFunc=*/nullptr, BodyCB, DeviceAddrCB, CustomMapperCB, RTLoc));
CustomMapperCB,
/*MapperFunc=*/nullptr, BodyCB, DeviceAddrCB, RTLoc));
CGF.Builder.restoreIP(AfterIP);
}

Expand Down
42 changes: 26 additions & 16 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -2399,6 +2399,7 @@ class OpenMPIRBuilder {
CurInfo.NonContigInfo.Strides.end());
}
};
using MapInfosOrErrorTy = Expected<MapInfosTy &>;

/// Callback function type for functions emitting the host fallback code that
/// is executed when the kernel launch fails. It takes an insertion point as
Expand All @@ -2407,6 +2408,11 @@ class OpenMPIRBuilder {
using EmitFallbackCallbackTy =
function_ref<InsertPointOrErrorTy(InsertPointTy)>;

// Callback function type for emitting and fetching user defined custom
// mappers.
using CustomMapperCallbackTy =
function_ref<Expected<Function *>(unsigned int)>;

/// Generate a target region entry call and host fallback call.
///
/// \param Loc The location at which the request originated and is fulfilled.
Expand Down Expand Up @@ -2473,24 +2479,24 @@ class OpenMPIRBuilder {
/// return nullptr by reference. Accepts a reference to a MapInfosTy object
/// that contains information generated for mappable clauses,
/// including base pointers, pointers, sizes, map types, user-defined mappers.
void emitOffloadingArrays(
Error emitOffloadingArrays(
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
TargetDataInfo &Info, bool IsNonContiguous = false,
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
function_ref<Value *(unsigned int)> CustomMapperCB = nullptr);
TargetDataInfo &Info, CustomMapperCallbackTy CustomMapperCB,
bool IsNonContiguous = false,
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr);

/// Allocates memory for and populates the arrays required for offloading
/// (offload_{baseptrs|ptrs|mappers|sizes|maptypes|mapnames}). Then, it
/// emits their base addresses as arguments to be passed to the runtime
/// library. In essence, this function is a combination of
/// emitOffloadingArrays and emitOffloadingArraysArgument and should arguably
/// be preferred by clients of OpenMPIRBuilder.
void emitOffloadingArraysAndArgs(
Error emitOffloadingArraysAndArgs(
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
bool IsNonContiguous = false, bool ForEndCall = false,
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
function_ref<Value *(unsigned int)> CustomMapperCB = nullptr);
CustomMapperCallbackTy CustomMapperCB, bool IsNonContiguous = false,
bool ForEndCall = false,
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr);

/// Creates offloading entry for the provided entry ID \a ID, address \a
/// Addr, size \a Size, and flags \a Flags.
Expand Down Expand Up @@ -2950,12 +2956,12 @@ class OpenMPIRBuilder {
/// \param FuncName Optional param to specify mapper function name.
/// \param CustomMapperCB Optional callback to generate code related to
/// custom mappers.
Function *emitUserDefinedMapper(
function_ref<MapInfosTy &(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
llvm::Value *BeginArg)>
Expected<Function *> emitUserDefinedMapper(
function_ref<MapInfosOrErrorTy(
InsertPointTy CodeGenIP, llvm::Value *PtrPHI, llvm::Value *BeginArg)>
PrivAndGenMapInfoCB,
llvm::Type *ElemTy, StringRef FuncName,
function_ref<bool(unsigned int, Function **)> CustomMapperCB = nullptr);
CustomMapperCallbackTy CustomMapperCB);

/// Generator for '#omp target data'
///
Expand All @@ -2969,21 +2975,21 @@ class OpenMPIRBuilder {
/// \param IfCond Value which corresponds to the if clause condition.
/// \param Info Stores all information realted to the Target Data directive.
/// \param GenMapInfoCB Callback that populates the MapInfos and returns.
/// \param CustomMapperCB Callback to generate code related to
/// custom mappers.
/// \param BodyGenCB Optional Callback to generate the region code.
/// \param DeviceAddrCB Optional callback to generate code related to
/// use_device_ptr and use_device_addr.
/// \param CustomMapperCB Optional callback to generate code related to
/// custom mappers.
InsertPointOrErrorTy createTargetData(
const LocationDescription &Loc, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
CustomMapperCallbackTy CustomMapperCB,
omp::RuntimeFunction *MapperFunc = nullptr,
function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
BodyGenTy BodyGenType)>
BodyGenCB = nullptr,
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr,
function_ref<Value *(unsigned int)> CustomMapperCB = nullptr,
Value *SrcLocInfo = nullptr);

using TargetBodyGenCallbackTy = function_ref<InsertPointOrErrorTy(
Expand All @@ -2999,6 +3005,7 @@ class OpenMPIRBuilder {
/// \param IsOffloadEntry whether it is an offload entry.
/// \param CodeGenIP The insertion point where the call to the outlined
/// function should be emitted.
/// \param Info Stores all information realted to the Target directive.
/// \param EntryInfo The entry information about the function.
/// \param DefaultAttrs Structure containing the default attributes, including
/// numbers of threads and teams to launch the kernel with.
Expand All @@ -3010,20 +3017,23 @@ class OpenMPIRBuilder {
/// \param BodyGenCB Callback that will generate the region code.
/// \param ArgAccessorFuncCB Callback that will generate accessors
/// instructions for passed in target arguments where neccessary
/// \param CustomMapperCB Callback to generate code related to
/// custom mappers.
/// \param Dependencies A vector of DependData objects that carry
/// dependency information as passed in the depend clause
/// \param HasNowait Whether the target construct has a `nowait` clause or
/// not.
InsertPointOrErrorTy createTarget(
const LocationDescription &Loc, bool IsOffloadEntry,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP, TargetDataInfo &Info,
TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
TargetBodyGenCallbackTy BodyGenCB,
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
CustomMapperCallbackTy CustomMapperCB,
SmallVector<DependData> Dependencies = {}, bool HasNowait = false);

/// Returns __kmpc_for_static_init_* runtime function for the specified
Expand Down
Loading