Skip to content

Commit 64b0908

Browse files
committed
Added customMapper error propagation. Updated test.
1 parent a38c84d commit 64b0908

File tree

5 files changed

+92
-67
lines changed

5 files changed

+92
-67
lines changed

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8879,17 +8879,17 @@ static void emitOffloadingArraysAndArgs(
88798879
};
88808880

88818881
auto CustomMapperCB = [&](unsigned int I) {
8882-
llvm::Value *MFunc = nullptr;
8882+
llvm::Function *MFunc = nullptr;
88838883
if (CombinedInfo.Mappers[I]) {
88848884
Info.HasMapper = true;
88858885
MFunc = CGM.getOpenMPRuntime().getOrCreateUserDefinedMapperFunc(
88868886
cast<OMPDeclareMapperDecl>(CombinedInfo.Mappers[I]));
88878887
}
88888888
return MFunc;
88898889
};
8890-
OMPBuilder.emitOffloadingArraysAndArgs(
8890+
cantFail(OMPBuilder.emitOffloadingArraysAndArgs(
88918891
AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, CustomMapperCB,
8892-
IsNonContiguous, ForEndCall, DeviceAddrCB);
8892+
IsNonContiguous, ForEndCall, DeviceAddrCB));
88938893
}
88948894

88958895
/// Check for inner distribute directive.
@@ -9082,26 +9082,25 @@ void CGOpenMPRuntime::emitUserDefinedMapper(const OMPDeclareMapperDecl *D,
90829082
return CombinedInfo;
90839083
};
90849084

9085-
auto CustomMapperCB = [&](unsigned I, llvm::Function **MapperFunc) {
9085+
auto CustomMapperCB = [&](unsigned I) {
9086+
llvm::Function *MapperFunc = nullptr;
90869087
if (CombinedInfo.Mappers[I]) {
90879088
// Call the corresponding mapper function.
9088-
*MapperFunc = getOrCreateUserDefinedMapperFunc(
9089+
MapperFunc = getOrCreateUserDefinedMapperFunc(
90899090
cast<OMPDeclareMapperDecl>(CombinedInfo.Mappers[I]));
9090-
assert(*MapperFunc && "Expect a valid mapper function is available.");
9091-
return true;
9091+
assert(MapperFunc && "Expect a valid mapper function is available.");
90929092
}
9093-
return false;
9093+
return MapperFunc;
90949094
};
90959095

90969096
SmallString<64> TyStr;
90979097
llvm::raw_svector_ostream Out(TyStr);
90989098
CGM.getCXXABI().getMangleContext().mangleCanonicalTypeName(Ty, Out);
90999099
std::string Name = getName({"omp_mapper", TyStr, D->getName()});
91009100

9101-
llvm::Expected<llvm::Function *> NewFn = OMPBuilder.emitUserDefinedMapper(
9102-
PrivatizeAndGenMapInfoCB, ElemTy, Name, CustomMapperCB);
9103-
assert(NewFn && "Unexpected error in emitUserDefinedMapper");
9104-
UDMMap.try_emplace(D, *NewFn);
9101+
llvm::Function *NewFn = cantFail(OMPBuilder.emitUserDefinedMapper(
9102+
PrivatizeAndGenMapInfoCB, ElemTy, Name, CustomMapperCB));
9103+
UDMMap.try_emplace(D, NewFn);
91059104
if (CGF)
91069105
FunctionUDMMap[CGF->CurFn].push_back(D);
91079106
}
@@ -10074,7 +10073,7 @@ void CGOpenMPRuntime::emitTargetDataCalls(
1007410073
};
1007510074

1007610075
auto CustomMapperCB = [&](unsigned int I) {
10077-
llvm::Value *MFunc = nullptr;
10076+
llvm::Function *MFunc = nullptr;
1007810077
if (CombinedInfo.Mappers[I]) {
1007910078
Info.HasMapper = true;
1008010079
MFunc = CGF.CGM.getOpenMPRuntime().getOrCreateUserDefinedMapperFunc(

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2408,6 +2408,11 @@ class OpenMPIRBuilder {
24082408
using EmitFallbackCallbackTy =
24092409
function_ref<InsertPointOrErrorTy(InsertPointTy)>;
24102410

2411+
// Callback function type for emitting and fetching user defined custom
2412+
// mappers.
2413+
using CustomMapperCallbackTy =
2414+
function_ref<Expected<Function *>(unsigned int)>;
2415+
24112416
/// Generate a target region entry call and host fallback call.
24122417
///
24132418
/// \param Loc The location at which the request originated and is fulfilled.
@@ -2474,9 +2479,9 @@ class OpenMPIRBuilder {
24742479
/// return nullptr by reference. Accepts a reference to a MapInfosTy object
24752480
/// that contains information generated for mappable clauses,
24762481
/// including base pointers, pointers, sizes, map types, user-defined mappers.
2477-
void emitOffloadingArrays(
2482+
Error emitOffloadingArrays(
24782483
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
2479-
TargetDataInfo &Info, function_ref<Value *(unsigned int)> CustomMapperCB,
2484+
TargetDataInfo &Info, CustomMapperCallbackTy CustomMapperCB,
24802485
bool IsNonContiguous = false,
24812486
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr);
24822487

@@ -2486,11 +2491,11 @@ class OpenMPIRBuilder {
24862491
/// library. In essence, this function is a combination of
24872492
/// emitOffloadingArrays and emitOffloadingArraysArgument and should arguably
24882493
/// be preferred by clients of OpenMPIRBuilder.
2489-
void emitOffloadingArraysAndArgs(
2494+
Error emitOffloadingArraysAndArgs(
24902495
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
24912496
TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
2492-
function_ref<Value *(unsigned int)> CustomMapperCB,
2493-
bool IsNonContiguous = false, bool ForEndCall = false,
2497+
CustomMapperCallbackTy CustomMapperCB, bool IsNonContiguous = false,
2498+
bool ForEndCall = false,
24942499
function_ref<void(unsigned int, Value *)> DeviceAddrCB = nullptr);
24952500

24962501
/// Creates offloading entry for the provided entry ID \a ID, address \a
@@ -2956,7 +2961,7 @@ class OpenMPIRBuilder {
29562961
InsertPointTy CodeGenIP, llvm::Value *PtrPHI, llvm::Value *BeginArg)>
29572962
PrivAndGenMapInfoCB,
29582963
llvm::Type *ElemTy, StringRef FuncName,
2959-
function_ref<bool(unsigned int, Function **)> CustomMapperCB);
2964+
CustomMapperCallbackTy CustomMapperCB);
29602965

29612966
/// Generator for '#omp target data'
29622967
///
@@ -2979,7 +2984,7 @@ class OpenMPIRBuilder {
29792984
const LocationDescription &Loc, InsertPointTy AllocaIP,
29802985
InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
29812986
TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
2982-
function_ref<Value *(unsigned int)> CustomMapperCB,
2987+
CustomMapperCallbackTy CustomMapperCB,
29832988
omp::RuntimeFunction *MapperFunc = nullptr,
29842989
function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
29852990
BodyGenTy BodyGenType)>
@@ -3028,7 +3033,7 @@ class OpenMPIRBuilder {
30283033
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
30293034
TargetBodyGenCallbackTy BodyGenCB,
30303035
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
3031-
function_ref<Value *(unsigned int)> CustomMapperCB,
3036+
CustomMapperCallbackTy CustomMapperCB,
30323037
SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
30333038

30343039
/// Returns __kmpc_for_static_init_* runtime function for the specified

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6555,8 +6555,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData(
65556555
const LocationDescription &Loc, InsertPointTy AllocaIP,
65566556
InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
65576557
TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
6558-
function_ref<Value *(unsigned int)> CustomMapperCB,
6559-
omp::RuntimeFunction *MapperFunc,
6558+
CustomMapperCallbackTy CustomMapperCB, omp::RuntimeFunction *MapperFunc,
65606559
function_ref<InsertPointOrErrorTy(InsertPointTy CodeGenIP,
65616560
BodyGenTy BodyGenType)>
65626561
BodyGenCB,
@@ -6585,9 +6584,10 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData(
65856584
auto BeginThenGen = [&](InsertPointTy AllocaIP,
65866585
InsertPointTy CodeGenIP) -> Error {
65876586
MapInfo = &GenMapInfoCB(Builder.saveIP());
6588-
emitOffloadingArrays(AllocaIP, Builder.saveIP(), *MapInfo, Info,
6589-
CustomMapperCB,
6590-
/*IsNonContiguous=*/true, DeviceAddrCB);
6587+
if (Error Err = emitOffloadingArrays(
6588+
AllocaIP, Builder.saveIP(), *MapInfo, Info, CustomMapperCB,
6589+
/*IsNonContiguous=*/true, DeviceAddrCB))
6590+
return Err;
65916591

65926592
TargetDataRTArgs RTArgs;
65936593
emitOffloadingArraysArgument(Builder, RTArgs, Info);
@@ -7486,14 +7486,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask(
74867486
return Builder.saveIP();
74877487
}
74887488

7489-
void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
7489+
Error OpenMPIRBuilder::emitOffloadingArraysAndArgs(
74907490
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
74917491
TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo,
7492-
function_ref<Value *(unsigned int)> CustomMapperCB, bool IsNonContiguous,
7492+
CustomMapperCallbackTy CustomMapperCB, bool IsNonContiguous,
74937493
bool ForEndCall, function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
7494-
emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info, CustomMapperCB,
7495-
IsNonContiguous, DeviceAddrCB);
7494+
if (Error Err =
7495+
emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info,
7496+
CustomMapperCB, IsNonContiguous, DeviceAddrCB))
7497+
return Err;
74967498
emitOffloadingArraysArgument(Builder, RTArgs, Info, ForEndCall);
7499+
return Error::success();
74977500
}
74987501

74997502
static void
@@ -7505,7 +7508,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
75057508
Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
75067509
SmallVectorImpl<Value *> &Args,
75077510
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7508-
function_ref<Value *(unsigned int)> CustomMapperCB,
7511+
OpenMPIRBuilder::CustomMapperCallbackTy CustomMapperCB,
75097512
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies,
75107513
bool HasNoWait) {
75117514
// Generate a function call to the host fallback implementation of the target
@@ -7580,10 +7583,11 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
75807583
OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
75817584
OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
75827585
OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7583-
OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
7584-
RTArgs, MapInfo, CustomMapperCB,
7585-
/*IsNonContiguous=*/true,
7586-
/*ForEndCall=*/false);
7586+
if (Error Err = OMPBuilder.emitOffloadingArraysAndArgs(
7587+
AllocaIP, Builder.saveIP(), Info, RTArgs, MapInfo, CustomMapperCB,
7588+
/*IsNonContiguous=*/true,
7589+
/*ForEndCall=*/false))
7590+
return Err;
75877591

75887592
SmallVector<Value *, 3> NumTeamsC;
75897593
for (auto [DefaultVal, RuntimeVal] :
@@ -7692,8 +7696,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
76927696
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
76937697
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
76947698
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
7695-
function_ref<Value *(unsigned int)> CustomMapperCB,
7696-
SmallVector<DependData> Dependencies, bool HasNowait) {
7699+
CustomMapperCallbackTy CustomMapperCB, SmallVector<DependData> Dependencies,
7700+
bool HasNowait) {
76977701

76987702
if (!updateToLocation(Loc))
76997703
return InsertPointTy();
@@ -8045,8 +8049,7 @@ Expected<Function *> OpenMPIRBuilder::emitUserDefinedMapper(
80458049
function_ref<MapInfosOrErrorTy(InsertPointTy CodeGenIP, llvm::Value *PtrPHI,
80468050
llvm::Value *BeginArg)>
80478051
GenMapInfoCB,
8048-
Type *ElemTy, StringRef FuncName,
8049-
function_ref<bool(unsigned int, Function **)> CustomMapperCB) {
8052+
Type *ElemTy, StringRef FuncName, CustomMapperCallbackTy CustomMapperCB) {
80508053
SmallVector<Type *> Params;
80518054
Params.emplace_back(Builder.getPtrTy());
80528055
Params.emplace_back(Builder.getPtrTy());
@@ -8226,17 +8229,19 @@ Expected<Function *> OpenMPIRBuilder::emitUserDefinedMapper(
82268229

82278230
Value *OffloadingArgs[] = {MapperHandle, CurBaseArg, CurBeginArg,
82288231
CurSizeArg, CurMapType, CurNameArg};
8229-
Function *ChildMapperFn = nullptr;
8230-
if (CustomMapperCB && CustomMapperCB(I, &ChildMapperFn)) {
8232+
8233+
auto ChildMapperFn = CustomMapperCB(I);
8234+
if (!ChildMapperFn)
8235+
return ChildMapperFn.takeError();
8236+
if (*ChildMapperFn)
82318237
// Call the corresponding mapper function.
8232-
Builder.CreateCall(ChildMapperFn, OffloadingArgs)->setDoesNotThrow();
8233-
} else {
8238+
Builder.CreateCall(*ChildMapperFn, OffloadingArgs)->setDoesNotThrow();
8239+
else
82348240
// Call the runtime API __tgt_push_mapper_component to fill up the runtime
82358241
// data structure.
82368242
Builder.CreateCall(
82378243
getOrCreateRuntimeFunction(M, OMPRTL___tgt_push_mapper_component),
82388244
OffloadingArgs);
8239-
}
82408245
}
82418246

82428247
// Update the pointer to point to the next element that needs to be mapped,
@@ -8263,9 +8268,9 @@ Expected<Function *> OpenMPIRBuilder::emitUserDefinedMapper(
82638268
return MapperFn;
82648269
}
82658270

8266-
void OpenMPIRBuilder::emitOffloadingArrays(
8271+
Error OpenMPIRBuilder::emitOffloadingArrays(
82678272
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
8268-
TargetDataInfo &Info, function_ref<Value *(unsigned int)> CustomMapperCB,
8273+
TargetDataInfo &Info, CustomMapperCallbackTy CustomMapperCB,
82698274
bool IsNonContiguous,
82708275
function_ref<void(unsigned int, Value *)> DeviceAddrCB) {
82718276

@@ -8274,7 +8279,7 @@ void OpenMPIRBuilder::emitOffloadingArrays(
82748279
Info.NumberOfPtrs = CombinedInfo.BasePointers.size();
82758280

82768281
if (Info.NumberOfPtrs == 0)
8277-
return;
8282+
return Error::success();
82788283

82798284
Builder.restoreIP(AllocaIP);
82808285
// Detect if we have any capture size requiring runtime evaluation of the
@@ -8438,9 +8443,13 @@ void OpenMPIRBuilder::emitOffloadingArrays(
84388443
// Fill up the mapper array.
84398444
unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(0);
84408445
Value *MFunc = ConstantPointerNull::get(PtrTy);
8441-
if (CustomMapperCB)
8442-
if (Value *CustomMFunc = CustomMapperCB(I))
8443-
MFunc = Builder.CreatePointerCast(CustomMFunc, PtrTy);
8446+
8447+
auto CustomMFunc = CustomMapperCB(I);
8448+
if (!CustomMFunc)
8449+
return CustomMFunc.takeError();
8450+
if (*CustomMFunc)
8451+
MFunc = Builder.CreatePointerCast(*CustomMFunc, PtrTy);
8452+
84448453
Value *MAddr = Builder.CreateInBoundsGEP(
84458454
MappersArray->getAllocatedType(), MappersArray,
84468455
{Builder.getIntN(IndexSize, 0), Builder.getIntN(IndexSize, I)});
@@ -8450,8 +8459,9 @@ void OpenMPIRBuilder::emitOffloadingArrays(
84508459

84518460
if (!IsNonContiguous || CombinedInfo.NonContigInfo.Offsets.empty() ||
84528461
Info.NumberOfPtrs == 0)
8453-
return;
8462+
return Error::success();
84548463
emitNonContiguousDescriptor(AllocaIP, CodeGenIP, CombinedInfo, Info);
8464+
return Error::success();
84558465
}
84568466

84578467
void OpenMPIRBuilder::emitBranch(BasicBlock *Target) {

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3608,16 +3608,17 @@ emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
36083608
return combinedInfo;
36093609
};
36103610

3611-
auto customMapperCB = [&](unsigned i, llvm::Function **mapperFunc) {
3611+
auto customMapperCB = [&](unsigned i) -> llvm::Expected<llvm::Function *> {
3612+
llvm::Function *mapperFunc = nullptr;
36123613
if (combinedInfo.Mappers[i]) {
36133614
// Call the corresponding mapper function.
36143615
llvm::Expected<llvm::Function *> newFn = getOrCreateUserDefinedMapperFunc(
36153616
combinedInfo.Mappers[i], builder, moduleTranslation);
3616-
assert(newFn && "Expect a valid mapper function is available");
3617-
*mapperFunc = *newFn;
3618-
return true;
3617+
if (!newFn)
3618+
return newFn.takeError();
3619+
mapperFunc = *newFn;
36193620
}
3620-
return false;
3621+
return mapperFunc;
36213622
};
36223623

36233624
llvm::Expected<llvm::Function *> newFn = ompBuilder->emitUserDefinedMapper(
@@ -3840,13 +3841,15 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
38403841
return builder.saveIP();
38413842
};
38423843

3843-
auto customMapperCB = [&](unsigned int i) {
3844+
auto customMapperCB =
3845+
[&](unsigned int i) -> llvm::Expected<llvm::Function *> {
38443846
llvm::Function *mapperFunc = nullptr;
38453847
if (combinedInfo.Mappers[i]) {
38463848
info.HasMapper = true;
38473849
llvm::Expected<llvm::Function *> newFn = getOrCreateUserDefinedMapperFunc(
38483850
combinedInfo.Mappers[i], builder, moduleTranslation);
3849-
assert(newFn && "Expect a valid mapper function is available");
3851+
if (!newFn)
3852+
return newFn.takeError();
38503853
mapperFunc = *newFn;
38513854
}
38523855
return mapperFunc;
@@ -4551,13 +4554,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
45514554
/*RequiresDevicePointerInfo=*/false,
45524555
/*SeparateBeginEndCalls=*/true);
45534556

4554-
auto customMapperCB = [&](unsigned int i) {
4555-
llvm::Value *mapperFunc = nullptr;
4557+
auto customMapperCB =
4558+
[&](unsigned int i) -> llvm::Expected<llvm::Function *> {
4559+
llvm::Function *mapperFunc = nullptr;
45564560
if (combinedInfos.Mappers[i]) {
45574561
info.HasMapper = true;
45584562
llvm::Expected<llvm::Function *> newFn = getOrCreateUserDefinedMapperFunc(
45594563
combinedInfos.Mappers[i], builder, moduleTranslation);
4560-
assert(newFn && "Expect a valid mapper function is available");
4564+
if (!newFn)
4565+
return newFn.takeError();
45614566
mapperFunc = *newFn;
45624567
}
45634568
return mapperFunc;

offload/test/offloading/fortran/target-custom-mapper.f90

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,26 +10,32 @@ program test_openmp_mapper
1010
integer :: data(n)
1111
end type mytype
1212

13+
type :: mytype2
14+
type(mytype) :: my_data
15+
end type mytype2
16+
1317
! Declare custom mappers for the derived type `mytype`
14-
!$omp declare mapper(my_mapper1 : mytype :: t) map(to: t%data)
15-
!$omp declare mapper(my_mapper2 : mytype :: t) map(mapper(my_mapper1): t%data)
18+
!$omp declare mapper(my_mapper1 : mytype :: t) map(to: t%data(1 : n))
19+
20+
! Declare custom mappers for the derived type `mytype2`
21+
!$omp declare mapper(my_mapper2 : mytype2 :: t) map(mapper(my_mapper1): t%my_data)
1622

17-
type(mytype) :: obj
23+
type(mytype2) :: obj
1824
integer :: i, sum_host, sum_device
1925

2026
! Initialize the host data
2127
do i = 1, n
22-
obj%data(i) = 1
28+
obj%my_data%data(i) = 1
2329
end do
2430

2531
! Compute the sum on the host for verification
26-
sum_host = sum(obj%data)
32+
sum_host = sum(obj%my_data%data)
2733

2834
! Offload computation to the device using the named mapper `my_mapper2`
2935
sum_device = 0
3036
!$omp target map(tofrom: sum_device) map(mapper(my_mapper2) : obj)
3137
do i = 1, n
32-
sum_device = sum_device + obj%data(i)
38+
sum_device = sum_device + obj%my_data%data(i)
3339
end do
3440
!$omp end target
3541

0 commit comments

Comments
 (0)