diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp index b8ee943754bd6..97d0dc5180291 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -8921,7 +8921,7 @@ static void emitOffloadingArraysAndArgs( }; auto CustomMapperCB = [&](unsigned int I) { - llvm::Value *MFunc = nullptr; + llvm::Function *MFunc = nullptr; if (CombinedInfo.Mappers[I]) { Info.HasMapper = true; MFunc = CGM.getOpenMPRuntime().getOrCreateUserDefinedMapperFunc( @@ -8929,9 +8929,9 @@ static void emitOffloadingArraysAndArgs( } return MFunc; }; - OMPBuilder.emitOffloadingArraysAndArgs( - AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, IsNonContiguous, - ForEndCall, DeviceAddrCB, CustomMapperCB); + cantFail(OMPBuilder.emitOffloadingArraysAndArgs( + AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, CustomMapperCB, + IsNonContiguous, ForEndCall, DeviceAddrCB)); } /// Check for inner distribute directive. @@ -9124,15 +9124,15 @@ void CGOpenMPRuntime::emitUserDefinedMapper(const OMPDeclareMapperDecl *D, return CombinedInfo; }; - auto CustomMapperCB = [&](unsigned I, llvm::Function **MapperFunc) { + auto CustomMapperCB = [&](unsigned I) { + llvm::Function *MapperFunc = nullptr; if (CombinedInfo.Mappers[I]) { // Call the corresponding mapper function. 
- *MapperFunc = getOrCreateUserDefinedMapperFunc( + MapperFunc = getOrCreateUserDefinedMapperFunc( cast(CombinedInfo.Mappers[I])); - assert(*MapperFunc && "Expect a valid mapper function is available."); - return true; + assert(MapperFunc && "Expect a valid mapper function is available."); } - return false; + return MapperFunc; }; SmallString<64> TyStr; @@ -9140,8 +9140,8 @@ void CGOpenMPRuntime::emitUserDefinedMapper(const OMPDeclareMapperDecl *D, CGM.getCXXABI().getMangleContext().mangleCanonicalTypeName(Ty, Out); std::string Name = getName({"omp_mapper", TyStr, D->getName()}); - auto *NewFn = OMPBuilder.emitUserDefinedMapper(PrivatizeAndGenMapInfoCB, - ElemTy, Name, CustomMapperCB); + llvm::Function *NewFn = cantFail(OMPBuilder.emitUserDefinedMapper( + PrivatizeAndGenMapInfoCB, ElemTy, Name, CustomMapperCB)); UDMMap.try_emplace(D, NewFn); if (CGF) FunctionUDMMap[CGF->CurFn].push_back(D); @@ -10493,7 +10493,7 @@ void CGOpenMPRuntime::emitTargetDataCalls( }; auto CustomMapperCB = [&](unsigned int I) { - llvm::Value *MFunc = nullptr; + llvm::Function *MFunc = nullptr; if (CombinedInfo.Mappers[I]) { Info.HasMapper = true; MFunc = CGF.CGM.getOpenMPRuntime().getOrCreateUserDefinedMapperFunc( @@ -10513,7 +10513,8 @@ void CGOpenMPRuntime::emitTargetDataCalls( llvm::OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(OMPBuilder.createTargetData( OmpLoc, AllocaIP, CodeGenIP, DeviceID, IfCondVal, Info, GenMapInfoCB, - /*MapperFunc=*/nullptr, BodyCB, DeviceAddrCB, CustomMapperCB, RTLoc)); + CustomMapperCB, + /*MapperFunc=*/nullptr, BodyCB, DeviceAddrCB, RTLoc)); CGF.Builder.restoreIP(AfterIP); } diff --git a/flang/include/flang/Lower/OpenMP/Utils.h b/flang/include/flang/Lower/OpenMP/Utils.h index f2e378443e5f2..3943eb633b04e 100644 --- a/flang/include/flang/Lower/OpenMP/Utils.h +++ b/flang/include/flang/Lower/OpenMP/Utils.h @@ -116,7 +116,8 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc, llvm::ArrayRef members, mlir::ArrayAttr membersIndex, uint64_t 
mapType, mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy, - bool partialMap = false); + bool partialMap = false, + mlir::FlatSymbolRefAttr mapperId = mlir::FlatSymbolRefAttr()); void insertChildMapInfoIntoParent( Fortran::lower::AbstractConverter &converter, diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index a946d44c5d3b7..0707ebef1f0d9 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -969,8 +969,11 @@ void ClauseProcessor::processMapObjects( llvm::omp::OpenMPOffloadMappingFlags mapTypeBits, std::map &parentMemberIndices, llvm::SmallVectorImpl &mapVars, - llvm::SmallVectorImpl &mapSyms) const { + llvm::SmallVectorImpl &mapSyms, + llvm::StringRef mapperIdNameRef) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::FlatSymbolRefAttr mapperId; + std::string mapperIdName = mapperIdNameRef.str(); for (const omp::Object &object : objects) { llvm::SmallVector bounds; @@ -1003,6 +1006,20 @@ void ClauseProcessor::processMapObjects( } } + if (!mapperIdName.empty()) { + if (mapperIdName == "default") { + auto &typeSpec = object.sym()->owner().IsDerivedType() + ? 
*object.sym()->owner().derivedTypeSpec() + : object.sym()->GetType()->derivedTypeSpec(); + mapperIdName = typeSpec.name().ToString() + ".default"; + mapperIdName = converter.mangleName(mapperIdName, *typeSpec.GetScope()); + } + assert(converter.getModuleOp().lookupSymbol(mapperIdName) && + "mapper not found"); + mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(), + mapperIdName); + mapperIdName.clear(); + } // Explicit map captures are captured ByRef by default, // optimisation passes may alter this to ByCopy or other capture // types to optimise @@ -1016,7 +1033,8 @@ void ClauseProcessor::processMapObjects( static_cast< std::underlying_type_t>( mapTypeBits), - mlir::omp::VariableCaptureKind::ByRef, baseOp.getType()); + mlir::omp::VariableCaptureKind::ByRef, baseOp.getType(), + /*partialMap=*/false, mapperId); if (parentObj.has_value()) { parentMemberIndices[parentObj.value()].addChildIndexAndMapToParent( @@ -1047,6 +1065,7 @@ bool ClauseProcessor::processMap( const auto &[mapType, typeMods, mappers, iterator, objects] = clause.t; llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; + std::string mapperIdName; // If the map type is specified, then process it else Tofrom is the // default. 
Map::MapType type = mapType.value_or(Map::MapType::Tofrom); @@ -1090,13 +1109,17 @@ bool ClauseProcessor::processMap( "Support for iterator modifiers is not implemented yet"); } if (mappers) { - TODO(currentLocation, - "Support for mapper modifiers is not implemented yet"); + assert(mappers->size() == 1 && "more than one mapper"); + mapperIdName = mappers->front().v.id().symbol->name().ToString(); + if (mapperIdName != "default") + mapperIdName = converter.mangleName( + mapperIdName, mappers->front().v.id().symbol->owner()); } processMapObjects(stmtCtx, clauseLocation, std::get(clause.t), mapTypeBits, - parentMemberIndices, result.mapVars, *ptrMapSyms); + parentMemberIndices, result.mapVars, *ptrMapSyms, + mapperIdName); }; bool clauseFound = findRepeatableClause(process); diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index 4d2825725c3d2..3ea8b72117186 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -176,7 +176,8 @@ class ClauseProcessor { llvm::omp::OpenMPOffloadMappingFlags mapTypeBits, std::map &parentMemberIndices, llvm::SmallVectorImpl &mapVars, - llvm::SmallVectorImpl &mapSyms) const; + llvm::SmallVectorImpl &mapSyms, + llvm::StringRef mapperIdNameRef = "") const; lower::AbstractConverter &converter; semantics::SemanticsContext &semaCtx; diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp index b436b4ad1eb39..72884cef7cfa0 100644 --- a/flang/lib/Lower/OpenMP/Utils.cpp +++ b/flang/lib/Lower/OpenMP/Utils.cpp @@ -130,7 +130,7 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc, llvm::ArrayRef members, mlir::ArrayAttr membersIndex, uint64_t mapType, mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy, - bool partialMap) { + bool partialMap, mlir::FlatSymbolRefAttr mapperId) { if (auto boxTy = llvm::dyn_cast(baseAddr.getType())) { baseAddr = builder.create(loc, baseAddr); retTy = baseAddr.getType(); @@ 
-149,6 +149,7 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc, mlir::omp::MapInfoOp op = builder.create( loc, retTy, baseAddr, varType, varPtrPtr, members, membersIndex, bounds, builder.getIntegerAttr(builder.getIntegerType(64, false), mapType), + mapperId, builder.getAttr(mapCaptureType), builder.getStringAttr(name), builder.getBoolAttr(partialMap)); return op; diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp index 9d18536df3af2..2501cd7c2919f 100644 --- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp +++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp @@ -51,7 +51,8 @@ mlir::omp::MapInfoOp createMapInfoOp( mlir::Value varPtrPtr, std::string name, llvm::ArrayRef bounds, llvm::ArrayRef members, mlir::ArrayAttr membersIndex, uint64_t mapType, mlir::omp::VariableCaptureKind mapCaptureType, - mlir::Type retTy, bool partialMap = false) { + mlir::Type retTy, bool partialMap = false, + mlir::FlatSymbolRefAttr mapperId = mlir::FlatSymbolRefAttr()) { if (auto boxTy = llvm::dyn_cast(baseAddr.getType())) { baseAddr = builder.create(loc, baseAddr); retTy = baseAddr.getType(); @@ -70,6 +71,7 @@ mlir::omp::MapInfoOp createMapInfoOp( mlir::omp::MapInfoOp op = builder.create( loc, retTy, baseAddr, varType, varPtrPtr, members, membersIndex, bounds, builder.getIntegerAttr(builder.getIntegerType(64, false), mapType), + mapperId, builder.getAttr(mapCaptureType), builder.getStringAttr(name), builder.getBoolAttr(partialMap)); diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp index 7828a466d265d..2ad54fed07602 100644 --- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp +++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp @@ -184,6 +184,7 @@ class MapInfoFinalizationPass /*members=*/mlir::SmallVector{}, /*membersIndex=*/mlir::ArrayAttr{}, bounds, builder.getIntegerAttr(builder.getIntegerType(64, false), 
mapType), + /*mapperId*/ mlir::FlatSymbolRefAttr(), builder.getAttr( mlir::omp::VariableCaptureKind::ByRef), /*name=*/builder.getStringAttr(""), @@ -331,7 +332,8 @@ class MapInfoFinalizationPass builder.getIntegerAttr( builder.getIntegerType(64, false), getDescriptorMapType(op.getMapType().value_or(0), target)), - op.getMapCaptureTypeAttr(), op.getNameAttr(), + /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getMapCaptureTypeAttr(), + op.getNameAttr(), /*partial_map=*/builder.getBoolAttr(false)); op.replaceAllUsesWith(newDescParentMapOp.getResult()); op->erase(); @@ -629,6 +631,7 @@ class MapInfoFinalizationPass // /*members=*/mlir::ValueRange{}, // /*members_index=*/mlir::ArrayAttr{}, // /*bounds=*/bounds, op.getMapTypeAttr(), + // /*mapperId*/ mlir::FlatSymbolRefAttr(), // builder.getAttr( // mlir::omp::VariableCaptureKind::ByRef), // builder.getStringAttr(op.getNameAttr().strref() + "." + diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp index 963ae863c1fc5..97ea463a3c495 100644 --- a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp +++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp @@ -91,6 +91,7 @@ class MapsForPrivatizedSymbolsPass /*bounds=*/ValueRange{}, builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false), mapTypeTo), + /*mapperId*/ mlir::FlatSymbolRefAttr(), builder.getAttr( omp::VariableCaptureKind::ByRef), StringAttr(), builder.getBoolAttr(false)); diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir index 22867cae5a7a5..1985a62523d6b 100644 --- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir +++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir @@ -936,9 +936,9 @@ func.func @omp_map_info_descriptor_type_conversion(%arg0 : !fir.ref>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.llvm_ptr> {name = ""} // CHECK: %[[DESC_MAP:.*]] = omp.map.info var_ptr(%[[ARG_0]] : !llvm.ptr, 
!llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(always, delete) capture(ByRef) members(%[[MEMBER_MAP]] : [0] : !llvm.ptr) -> !llvm.ptr {name = ""} %2 = omp.map.info var_ptr(%arg0 : !fir.ref>>, !fir.box>) map_clauses(always, delete) capture(ByRef) members(%1 : [0] : !fir.llvm_ptr>) -> !fir.ref>> {name = ""} - // CHECK: omp.target_exit_data map_entries(%[[DESC_MAP]] : !llvm.ptr) + // CHECK: omp.target_exit_data map_entries(%[[DESC_MAP]] : !llvm.ptr) omp.target_exit_data map_entries(%2 : !fir.ref>>) - return + return } // ----- @@ -956,8 +956,8 @@ func.func @omp_map_info_derived_type_explicit_member_conversion(%arg0 : !fir.ref %3 = fir.field_index real, !fir.type<_QFderived_type{real:f32,array:!fir.array<10xi32>,int:i32}> %4 = fir.coordinate_of %arg0, %3 : (!fir.ref,int:i32}>>, !fir.field) -> !fir.ref // CHECK: %[[MAP_MEMBER_2:.*]] = omp.map.info var_ptr(%[[GEP_2]] : !llvm.ptr, f32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "dtype%real"} - %5 = omp.map.info var_ptr(%4 : !fir.ref, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "dtype%real"} - // CHECK: %[[MAP_PARENT:.*]] = omp.map.info var_ptr(%[[ARG_0]] : !llvm.ptr, !llvm.struct<"_QFderived_type", (f32, array<10 x i32>, i32)>) map_clauses(tofrom) capture(ByRef) members(%[[MAP_MEMBER_1]], %[[MAP_MEMBER_2]] : [2], [0] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "dtype", partial_map = true} + %5 = omp.map.info var_ptr(%4 : !fir.ref, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "dtype%real"} + // CHECK: %[[MAP_PARENT:.*]] = omp.map.info var_ptr(%[[ARG_0]] : !llvm.ptr, !llvm.struct<"_QFderived_type", (f32, array<10 x i32>, i32)>) map_clauses(tofrom) capture(ByRef) members(%[[MAP_MEMBER_1]], %[[MAP_MEMBER_2]] : [2], [0] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "dtype", partial_map = true} %6 = omp.map.info var_ptr(%arg0 : !fir.ref,int:i32}>>, !fir.type<_QFderived_type{real:f32,array:!fir.array<10xi32>,int:i32}>) map_clauses(tofrom) capture(ByRef) members(%2, %5 : 
[2], [0] : !fir.ref, !fir.ref) -> !fir.ref,int:i32}>> {name = "dtype", partial_map = true} // CHECK: omp.target map_entries(%[[MAP_MEMBER_1]] -> %[[ARG_1:.*]], %[[MAP_MEMBER_2]] -> %[[ARG_2:.*]], %[[MAP_PARENT]] -> %[[ARG_3:.*]] : !llvm.ptr, !llvm.ptr, !llvm.ptr) { omp.target map_entries(%2 -> %arg1, %5 -> %arg2, %6 -> %arg3 : !fir.ref, !fir.ref, !fir.ref,int:i32}>>) { @@ -1279,3 +1279,22 @@ func.func @map_nested_dtype_alloca_mem2(%arg0 : !fir.ref { +omp.declare_mapper @my_mapper : !fir.type<_QFdeclare_mapperTmy_type{data:i32}> { +// CHECK: ^bb0(%[[VAL_0:.*]]: !llvm.ptr): +^bb0(%0: !fir.ref>): +// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : i32) : i32 + %1 = fir.field_index data, !fir.type<_QFdeclare_mapperTmy_type{data:i32}> +// CHECK: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"_QFdeclare_mapperTmy_type", (i32)> + %2 = fir.coordinate_of %0, %1 : (!fir.ref>, !fir.field) -> !fir.ref +// CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "var%[[VAL_4:.*]]"} + %3 = omp.map.info var_ptr(%2 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "var%data"} +// CHECK: %[[VAL_5:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !llvm.ptr, !llvm.struct<"_QFdeclare_mapperTmy_type", (i32)>) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]] : [0] : !llvm.ptr) -> !llvm.ptr {name = "var", partial_map = true} + %4 = omp.map.info var_ptr(%0 : !fir.ref>, !fir.type<_QFdeclare_mapperTmy_type{data:i32}>) map_clauses(tofrom) capture(ByRef) members(%3 : [0] : !fir.ref) -> !fir.ref> {name = "var", partial_map = true} +// CHECK: omp.declare_mapper.info map_entries(%[[VAL_5]], %[[VAL_3]] : !llvm.ptr, !llvm.ptr) + omp.declare_mapper.info map_entries(%4, %3 : !fir.ref>, !fir.ref) +// CHECK: } +} diff --git a/flang/test/Lower/OpenMP/declare-mapper.f90 b/flang/test/Lower/OpenMP/declare-mapper.f90 index f271233cff8fd..efb9f4b024112 100644 --- 
a/flang/test/Lower/OpenMP/declare-mapper.f90 +++ b/flang/test/Lower/OpenMP/declare-mapper.f90 @@ -1,8 +1,9 @@ ! This test checks lowering of OpenMP declare mapper Directive. -! XFAIL: * + ! RUN: split-file %s %t ! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-1.f90 -o - | FileCheck %t/omp-declare-mapper-1.f90 ! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-2.f90 -o - | FileCheck %t/omp-declare-mapper-2.f90 +! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-3.f90 -o - | FileCheck %t/omp-declare-mapper-3.f90 !--- omp-declare-mapper-1.f90 subroutine declare_mapper_1 @@ -40,9 +41,11 @@ subroutine declare_mapper_1 !CHECK: %[[VAL_16:.*]] = omp.map.bounds lower_bound(%[[VAL_10]] : index) upper_bound(%[[VAL_15]] : index) extent(%[[VAL_6]]#1 : index) stride(%[[VAL_8]] : index) start_idx(%[[VAL_6]]#0 : index) !CHECK: %[[VAL_17:.*]] = arith.constant 1 : index !CHECK: %[[VAL_18:.*]] = fir.coordinate_of %[[VAL_1]]#0, %[[VAL_17]] : (!fir.ref<[[MY_TYPE]]>, index) -> !fir.ref>>> - !CHECK: %[[VAL_19:.*]] = omp.map.info var_ptr(%[[VAL_18]] : !fir.ref>>>, !fir.box>>) map_clauses(tofrom) capture(ByRef) bounds(%[[VAL_16]]) -> !fir.ref>>> {name = "var%[[VAL_20:.*]](1:var%[[VAL_21:.*]])"} - !CHECK: %[[VAL_22:.*]] = omp.map.info var_ptr(%[[VAL_1]]#1 : !fir.ref<[[MY_TYPE]]>, [[MY_TYPE]]) map_clauses(tofrom) capture(ByRef) members(%[[VAL_19]] : [1] : !fir.ref>>>) -> !fir.ref<[[MY_TYPE]]> {name = "var"} - !CHECK: omp.declare_mapper.info map_entries(%[[VAL_22]] : !fir.ref<[[MY_TYPE]]>) + !CHECK: %[[VAL_19:.*]] = fir.box_offset %[[VAL_18]] base_addr : (!fir.ref>>>) -> !fir.llvm_ptr>> + !CHECK: %[[VAL_20:.*]] = omp.map.info var_ptr(%[[VAL_18]] : !fir.ref>>>, i32) var_ptr_ptr(%[[VAL_19]] : !fir.llvm_ptr>>) map_clauses(tofrom) capture(ByRef) bounds(%[[VAL_16]]) -> !fir.llvm_ptr>> {name = ""} + !CHECK: %[[VAL_21:.*]] = omp.map.info var_ptr(%[[VAL_18]] : !fir.ref>>>, !fir.box>>) map_clauses(always, to) 
capture(ByRef) -> !fir.ref>>> {name = "var%[[VAL_22:.*]](1:var%[[VAL_23:.*]])"} + !CHECK: %[[VAL_24:.*]] = omp.map.info var_ptr(%[[VAL_1]]#1 : !fir.ref<[[MY_TYPE]]>, [[MY_TYPE]]) map_clauses(tofrom) capture(ByRef) members(%[[VAL_21]], %[[VAL_20]] : [1], [1, 0] : !fir.ref>>>, !fir.llvm_ptr>>) -> !fir.ref<[[MY_TYPE]]> {name = "var"} + !CHECK: omp.declare_mapper.info map_entries(%[[VAL_24]], %[[VAL_21]], %[[VAL_20]] : !fir.ref<[[MY_TYPE]]>, !fir.ref>>>, !fir.llvm_ptr>>) !CHECK: } !$omp declare mapper (my_type :: var) map (var, var%values (1:var%num_vals)) end subroutine declare_mapper_1 @@ -77,7 +80,66 @@ subroutine declare_mapper_2 !CHECK: %[[VAL_11:.*]] = hlfir.designate %[[VAL_1]]#0{"temp"} : (!fir.ref<[[MY_TYPE]]>) -> !fir.ref>>}>> !CHECK: %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref>>}>>, !fir.type<_QFdeclare_mapper_2Tmy_type{num_vals:i32,values:!fir.box>>}>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref>>}>> {name = "v%[[VAL_13:.*]]"} !CHECK: %[[VAL_14:.*]] = omp.map.info var_ptr(%[[VAL_1]]#1 : !fir.ref<[[MY_TYPE]]>, [[MY_TYPE]]) map_clauses(tofrom) capture(ByRef) members(%[[VAL_9]], %[[VAL_12]] : [3], [1] : !fir.ref>, !fir.ref>>}>>) -> !fir.ref<[[MY_TYPE]]> {name = "v", partial_map = true} - !CHECK: omp.declare_mapper.info map_entries(%[[VAL_14]] : !fir.ref<[[MY_TYPE]]>) + !CHECK: omp.declare_mapper.info map_entries(%[[VAL_14]], %[[VAL_9]], %[[VAL_12]] : !fir.ref<[[MY_TYPE]]>, !fir.ref>, !fir.ref>>}>>) !CHECK: } !$omp declare mapper (my_mapper : my_type2 :: v) map (v%arr) map (alloc : v%temp) end subroutine declare_mapper_2 + +!--- omp-declare-mapper-3.f90 +subroutine declare_mapper_3 + type my_type + integer :: num_vals + integer, allocatable :: values(:) + end type + + type my_type2 + type(my_type) :: my_type_var + real, dimension(250) :: arr + end type + + !CHECK: omp.declare_mapper @[[MY_TYPE_MAPPER2:_QQFdeclare_mapper_3my_mapper2]] : 
[[MY_TYPE2:!fir\.type<_QFdeclare_mapper_3Tmy_type2\{my_type_var:!fir\.type<_QFdeclare_mapper_3Tmy_type\{num_vals:i32,values:!fir\.box>>}>,arr:!fir\.array<250xf32>}>]] { + !CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<[[MY_TYPE2]]>): + !CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFdeclare_mapper_3Ev"} : (!fir.ref<[[MY_TYPE2]]>) -> (!fir.ref<[[MY_TYPE2]]>, !fir.ref<[[MY_TYPE2]]>) + !CHECK: %[[VAL_2:.*]] = hlfir.designate %[[VAL_1]]#0{"my_type_var"} : (!fir.ref<[[MY_TYPE2]]>) -> !fir.ref<[[MY_TYPE:!fir\.type<_QFdeclare_mapper_3Tmy_type\{num_vals:i32,values:!fir\.box>>}>]]> + !CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<[[MY_TYPE]]>, [[MY_TYPE]]) mapper(@[[MY_TYPE_MAPPER:_QQFdeclare_mapper_3my_mapper]]) map_clauses(tofrom) capture(ByRef) -> !fir.ref<[[MY_TYPE]]> {name = "v%[[VAL_4:.*]]"} + !CHECK: %[[VAL_5:.*]] = arith.constant 250 : index + !CHECK: %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (index) -> !fir.shape<1> + !CHECK: %[[VAL_7:.*]] = hlfir.designate %[[VAL_1]]#0{"arr"} shape %[[VAL_6]] : (!fir.ref<[[MY_TYPE2]]>, !fir.shape<1>) -> !fir.ref> + !CHECK: %[[VAL_8:.*]] = arith.constant 1 : index + !CHECK: %[[VAL_9:.*]] = arith.constant 0 : index + !CHECK: %[[VAL_10:.*]] = arith.subi %[[VAL_5]], %[[VAL_8]] : index + !CHECK: %[[VAL_11:.*]] = omp.map.bounds lower_bound(%[[VAL_9]] : index) upper_bound(%[[VAL_10]] : index) extent(%[[VAL_5]] : index) stride(%[[VAL_8]] : index) start_idx(%[[VAL_8]] : index) + !CHECK: %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_7]] : !fir.ref>, !fir.array<250xf32>) map_clauses(tofrom) capture(ByRef) bounds(%[[VAL_11]]) -> !fir.ref> {name = "v%[[VAL_13:.*]]"} + !CHECK: %[[VAL_14:.*]] = omp.map.info var_ptr(%[[VAL_1]]#1 : !fir.ref<[[MY_TYPE2]]>, [[MY_TYPE2]]) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_12]] : [0], [1] : !fir.ref<[[MY_TYPE]]>, !fir.ref>) -> !fir.ref<[[MY_TYPE2]]> {name = "v", partial_map = true} + !CHECK: omp.declare_mapper.info map_entries(%[[VAL_14]], %[[VAL_3]], %[[VAL_12]] : 
!fir.ref<[[MY_TYPE2]]>, !fir.ref<[[MY_TYPE]]>, !fir.ref>) + !CHECK: } + + !CHECK: omp.declare_mapper @[[MY_TYPE_MAPPER]] : [[MY_TYPE]] { + !CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<[[MY_TYPE]]>): + !CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFdeclare_mapper_3Evar"} : (!fir.ref<[[MY_TYPE]]>) -> (!fir.ref<[[MY_TYPE]]>, !fir.ref<[[MY_TYPE]]>) + !CHECK: %[[VAL_2:.*]] = hlfir.designate %[[VAL_1]]#0{"values"} {fortran_attrs = #fir.var_attrs} : (!fir.ref<[[MY_TYPE]]>) -> !fir.ref>>> + !CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_2]] : !fir.ref>>> + !CHECK: %[[VAL_4:.*]] = fir.box_addr %[[VAL_3]] : (!fir.box>>) -> !fir.heap> + !CHECK: %[[VAL_5:.*]] = arith.constant 0 : index + !CHECK: %[[VAL_6:.*]]:3 = fir.box_dims %[[VAL_3]], %[[VAL_5]] : (!fir.box>>, index) -> (index, index, index) + !CHECK: %[[VAL_7:.*]] = arith.constant 0 : index + !CHECK: %[[VAL_8:.*]] = arith.constant 1 : index + !CHECK: %[[VAL_9:.*]] = arith.constant 1 : index + !CHECK: %[[VAL_10:.*]] = arith.subi %[[VAL_9]], %[[VAL_6]]#0 : index + !CHECK: %[[VAL_11:.*]] = hlfir.designate %[[VAL_1]]#0{"num_vals"} : (!fir.ref<[[MY_TYPE]]>) -> !fir.ref + !CHECK: %[[VAL_12:.*]] = fir.load %[[VAL_11]] : !fir.ref + !CHECK: %[[VAL_13:.*]] = fir.convert %[[VAL_12]] : (i32) -> i64 + !CHECK: %[[VAL_14:.*]] = fir.convert %[[VAL_13]] : (i64) -> index + !CHECK: %[[VAL_15:.*]] = arith.subi %[[VAL_14]], %[[VAL_6]]#0 : index + !CHECK: %[[VAL_16:.*]] = omp.map.bounds lower_bound(%[[VAL_10]] : index) upper_bound(%[[VAL_15]] : index) extent(%[[VAL_6]]#1 : index) stride(%[[VAL_8]] : index) start_idx(%[[VAL_6]]#0 : index) + !CHECK: %[[VAL_17:.*]] = arith.constant 1 : index + !CHECK: %[[VAL_18:.*]] = fir.coordinate_of %[[VAL_1]]#0, %[[VAL_17]] : (!fir.ref<[[MY_TYPE]]>, index) -> !fir.ref>>> + !CHECK: %[[VAL_19:.*]] = fir.box_offset %[[VAL_18]] base_addr : (!fir.ref>>>) -> !fir.llvm_ptr>> + !CHECK: %[[VAL_20:.*]] = omp.map.info var_ptr(%[[VAL_18]] : !fir.ref>>>, i32) var_ptr_ptr(%[[VAL_19]] : !fir.llvm_ptr>>) 
map_clauses(tofrom) capture(ByRef) bounds(%[[VAL_16]]) -> !fir.llvm_ptr>> {name = ""} + !CHECK: %[[VAL_21:.*]] = omp.map.info var_ptr(%[[VAL_18]] : !fir.ref>>>, !fir.box>>) map_clauses(always, to) capture(ByRef) -> !fir.ref>>> {name = "var%[[VAL_22:.*]](1:var%[[VAL_23:.*]])"} + !CHECK: %[[VAL_24:.*]] = omp.map.info var_ptr(%[[VAL_1]]#1 : !fir.ref<[[MY_TYPE]]>, [[MY_TYPE]]) map_clauses(tofrom) capture(ByRef) members(%[[VAL_21]], %[[VAL_20]] : [1], [1, 0] : !fir.ref>>>, !fir.llvm_ptr>>) -> !fir.ref<[[MY_TYPE]]> {name = "var"} + !CHECK: omp.declare_mapper.info map_entries(%[[VAL_24]], %[[VAL_21]], %[[VAL_20]] : !fir.ref<[[MY_TYPE]]>, !fir.ref>>>, !fir.llvm_ptr>>) + !CHECK: } + !$omp declare mapper (my_mapper : my_type :: var) map (var, var%values (1:var%num_vals)) + !$omp declare mapper (my_mapper2 : my_type2 :: v) map (mapper(my_mapper) : v%my_type_var) map (tofrom : v%arr) +end subroutine declare_mapper_3 diff --git a/flang/test/Lower/OpenMP/map-mapper.f90 b/flang/test/Lower/OpenMP/map-mapper.f90 new file mode 100644 index 0000000000000..0d8fe7344bfab --- /dev/null +++ b/flang/test/Lower/OpenMP/map-mapper.f90 @@ -0,0 +1,30 @@ +! 
RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %s -o - | FileCheck %s +program p + integer, parameter :: n = 256 + type t1 + integer :: x(256) + end type t1 + + !$omp declare mapper(xx : t1 :: nn) map(to: nn, nn%x) + !$omp declare mapper(t1 :: nn) map(from: nn) + + !CHECK-LABEL: omp.declare_mapper @_QQFt1.default : !fir.type<_QFTt1{x:!fir.array<256xi32>}> + !CHECK-LABEL: omp.declare_mapper @_QQFxx : !fir.type<_QFTt1{x:!fir.array<256xi32>}> + + type(t1) :: a, b + !CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%{{.*}} : {{.*}}, {{.*}}) mapper(@_QQFxx) map_clauses(tofrom) capture(ByRef) -> {{.*}} {name = "a"} + !CHECK: omp.target map_entries(%[[MAP_A]] -> %{{.*}}, %{{.*}} -> %{{.*}} : {{.*}}, {{.*}}) { + !$omp target map(mapper(xx) : a) + do i = 1, n + a%x(i) = i + end do + !$omp end target + + !CHECK: %[[MAP_B:.*]] = omp.map.info var_ptr(%{{.*}} : {{.*}}, {{.*}}) mapper(@_QQFt1.default) map_clauses(tofrom) capture(ByRef) -> {{.*}} {name = "b"} + !CHECK: omp.target map_entries(%[[MAP_B]] -> %{{.*}}, %{{.*}} -> %{{.*}} : {{.*}}, {{.*}}) { + !$omp target map(mapper(default) : b) + do i = 1, n + b%x(i) = i + end do + !$omp end target +end program p diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index bddf9aedf9d9d..fc63acd36b05f 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -2434,6 +2434,7 @@ class OpenMPIRBuilder { CurInfo.NonContigInfo.Strides.end()); } }; + using MapInfosOrErrorTy = Expected; /// Callback function type for functions emitting the host fallback code that /// is executed when the kernel launch fails. It takes an insertion point as @@ -2442,6 +2443,11 @@ class OpenMPIRBuilder { using EmitFallbackCallbackTy = function_ref; + // Callback function type for emitting and fetching user defined custom + // mappers. 
+ using CustomMapperCallbackTy = + function_ref(unsigned int)>; + /// Generate a target region entry call and host fallback call. /// /// \param Loc The location at which the request originated and is fulfilled. @@ -2508,11 +2514,11 @@ class OpenMPIRBuilder { /// return nullptr by reference. Accepts a reference to a MapInfosTy object /// that contains information generated for mappable clauses, /// including base pointers, pointers, sizes, map types, user-defined mappers. - void emitOffloadingArrays( + Error emitOffloadingArrays( InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo, - TargetDataInfo &Info, bool IsNonContiguous = false, - function_ref DeviceAddrCB = nullptr, - function_ref CustomMapperCB = nullptr); + TargetDataInfo &Info, CustomMapperCallbackTy CustomMapperCB, + bool IsNonContiguous = false, + function_ref DeviceAddrCB = nullptr); /// Allocates memory for and populates the arrays required for offloading /// (offload_{baseptrs|ptrs|mappers|sizes|maptypes|mapnames}). Then, it @@ -2520,12 +2526,12 @@ class OpenMPIRBuilder { /// library. In essence, this function is a combination of /// emitOffloadingArrays and emitOffloadingArraysArgument and should arguably /// be preferred by clients of OpenMPIRBuilder. - void emitOffloadingArraysAndArgs( + Error emitOffloadingArraysAndArgs( InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info, TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo, - bool IsNonContiguous = false, bool ForEndCall = false, - function_ref DeviceAddrCB = nullptr, - function_ref CustomMapperCB = nullptr); + CustomMapperCallbackTy CustomMapperCB, bool IsNonContiguous = false, + bool ForEndCall = false, + function_ref DeviceAddrCB = nullptr); /// Creates offloading entry for the provided entry ID \a ID, address \a /// Addr, size \a Size, and flags \a Flags. @@ -2993,12 +2999,12 @@ class OpenMPIRBuilder { /// \param FuncName Optional param to specify mapper function name. 
/// \param CustomMapperCB Optional callback to generate code related to /// custom mappers. - Function *emitUserDefinedMapper( - function_ref + Expected emitUserDefinedMapper( + function_ref PrivAndGenMapInfoCB, llvm::Type *ElemTy, StringRef FuncName, - function_ref CustomMapperCB = nullptr); + CustomMapperCallbackTy CustomMapperCB); /// Generator for '#omp target data' /// @@ -3012,21 +3018,21 @@ class OpenMPIRBuilder { /// \param IfCond Value which corresponds to the if clause condition. /// \param Info Stores all information related to the Target Data directive. /// \param GenMapInfoCB Callback that populates the MapInfos and returns. + /// \param CustomMapperCB Callback to generate code related to + /// custom mappers. /// \param BodyGenCB Optional Callback to generate the region code. /// \param DeviceAddrCB Optional callback to generate code related to /// use_device_ptr and use_device_addr. - /// \param CustomMapperCB Optional callback to generate code related to - /// custom mappers. 
InsertPointOrErrorTy createTargetData( const LocationDescription &Loc, InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond, TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB, + CustomMapperCallbackTy CustomMapperCB, omp::RuntimeFunction *MapperFunc = nullptr, function_ref BodyGenCB = nullptr, function_ref DeviceAddrCB = nullptr, - function_ref CustomMapperCB = nullptr, Value *SrcLocInfo = nullptr); using TargetBodyGenCallbackTy = function_ref &Inputs, GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB, TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, + CustomMapperCallbackTy CustomMapperCB, SmallVector Dependencies = {}, bool HasNowait = false); /// Returns __kmpc_for_static_init_* runtime function for the specified diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 2a1948a540e4e..9d3c3c580d6cc 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -6375,7 +6375,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit( int32_t MaxThreadsVal = Attrs.MaxThreads.front(); #if FIX_NUM_THREADS_ISSUE - //breaks 534.hpgmg + //breaks 534.hpgmg // If MaxThreads not set, select the maximum between the default workgroup // size and the MinThreads value. 
if (MaxThreadsVal < 0) @@ -6703,12 +6703,11 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData( const LocationDescription &Loc, InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond, TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB, - omp::RuntimeFunction *MapperFunc, + CustomMapperCallbackTy CustomMapperCB, omp::RuntimeFunction *MapperFunc, function_ref BodyGenCB, - function_ref DeviceAddrCB, - function_ref CustomMapperCB, Value *SrcLocInfo) { + function_ref DeviceAddrCB, Value *SrcLocInfo) { if (!updateToLocation(Loc)) return InsertPointTy(); @@ -6732,9 +6731,10 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData( auto BeginThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) -> Error { MapInfo = &GenMapInfoCB(Builder.saveIP()); - emitOffloadingArrays(AllocaIP, Builder.saveIP(), *MapInfo, Info, - /*IsNonContiguous=*/true, DeviceAddrCB, - CustomMapperCB); + if (Error Err = emitOffloadingArrays( + AllocaIP, Builder.saveIP(), *MapInfo, Info, CustomMapperCB, + /*IsNonContiguous=*/true, DeviceAddrCB)) + return Err; TargetDataRTArgs RTArgs; emitOffloadingArraysArgument(Builder, RTArgs, Info); @@ -7657,26 +7657,31 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask( return Builder.saveIP(); } -void OpenMPIRBuilder::emitOffloadingArraysAndArgs( +Error OpenMPIRBuilder::emitOffloadingArraysAndArgs( InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info, - TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo, bool IsNonContiguous, - bool ForEndCall, function_ref DeviceAddrCB, - function_ref CustomMapperCB) { - emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info, IsNonContiguous, - DeviceAddrCB, CustomMapperCB); + TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo, + CustomMapperCallbackTy CustomMapperCB, bool IsNonContiguous, + bool ForEndCall, function_ref DeviceAddrCB) { + if (Error Err = + emitOffloadingArrays(AllocaIP, CodeGenIP, 
CombinedInfo, Info, + CustomMapperCB, IsNonContiguous, DeviceAddrCB)) + return Err; emitOffloadingArraysArgument(Builder, RTArgs, Info, ForEndCall); + return Error::success(); } static void emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, OpenMPIRBuilder::InsertPointTy AllocaIP, + OpenMPIRBuilder::TargetDataInfo &Info, const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs, const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID, SmallVectorImpl &Args, OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB, - SmallVector Dependencies = {}, - bool HasNoWait = false) { + OpenMPIRBuilder::CustomMapperCallbackTy CustomMapperCB, + SmallVector Dependencies, + bool HasNoWait) { // Generate a function call to the host fallback implementation of the target // region. This is called by the host when no offload entry was generated for // the target region and when the offloading call fails at runtime. @@ -7747,15 +7752,13 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, auto &&EmitTargetCallThen = [&](OpenMPIRBuilder::InsertPointTy AllocaIP, OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error { - OpenMPIRBuilder::TargetDataInfo Info( - /*RequiresDevicePointerInfo=*/false, - /*SeparateBeginEndCalls=*/true); - OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(CodeGenIP); OpenMPIRBuilder::TargetDataRTArgs RTArgs; - OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, CodeGenIP, Info, RTArgs, - MapInfo, /*IsNonContiguous=*/true, - /*ForEndCall=*/false); + if (Error Err = OMPBuilder.emitOffloadingArraysAndArgs( + AllocaIP, CodeGenIP, Info, RTArgs, MapInfo, CustomMapperCB, + /*IsNonContiguous=*/true, + /*ForEndCall=*/false)) + return Err; SmallVector NumTeamsC; for (auto [DefaultVal, RuntimeVal] : @@ -7857,13 +7860,15 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( const 
LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP, - InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, + InsertPointTy CodeGenIP, TargetDataInfo &Info, + TargetRegionEntryInfo &EntryInfo, const TargetKernelDefaultAttrs &DefaultAttrs, const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond, - SmallVectorImpl &Args, GenMapInfoCallbackTy GenMapInfoCB, + SmallVectorImpl &Inputs, GenMapInfoCallbackTy GenMapInfoCB, OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc, OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, - SmallVector Dependencies, bool HasNowait) { + CustomMapperCallbackTy CustomMapperCB, SmallVector Dependencies, + bool HasNowait) { if (!updateToLocation(Loc)) return InsertPointTy(); @@ -7877,16 +7882,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( // and ArgAccessorFuncCB if (Error Err = emitTargetOutlinedFunction( *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn, - OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB)) + OutlinedFnID, Inputs, CBFunc, ArgAccessorFuncCB)) return Err; // If we are not on the target device, then we need to generate code // to make a remote call (offload) to the previously outlined function // that represents the target region. Do that now. 
if (!Config.isTargetDevice()) - emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs, IfCond, - OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies, - HasNowait); + emitTargetCall(*this, Builder, AllocaIP, Info, DefaultAttrs, RuntimeAttrs, + IfCond, OutlinedFn, OutlinedFnID, Inputs, GenMapInfoCB, + CustomMapperCB, Dependencies, HasNowait); return Builder.saveIP(); } @@ -8211,12 +8216,11 @@ void OpenMPIRBuilder::emitUDMapperArrayInitOrDel( OffloadingArgs); } -Function *OpenMPIRBuilder::emitUserDefinedMapper( - function_ref +Expected OpenMPIRBuilder::emitUserDefinedMapper( + function_ref GenMapInfoCB, - Type *ElemTy, StringRef FuncName, - function_ref CustomMapperCB) { + Type *ElemTy, StringRef FuncName, CustomMapperCallbackTy CustomMapperCB) { SmallVector Params; Params.emplace_back(Builder.getPtrTy()); Params.emplace_back(Builder.getPtrTy()); @@ -8287,7 +8291,9 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper( PtrPHI->addIncoming(PtrBegin, HeadBB); // Get map clause information. Fill up the arrays with all mapped variables. - MapInfosTy &Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn); + MapInfosOrErrorTy Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn); + if (!Info) + return Info.takeError(); // Call the runtime API __tgt_mapper_num_components to get the number of // pre-existing components. @@ -8299,20 +8305,20 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper( Builder.CreateShl(PreviousSize, Builder.getInt64(getFlagMemberOffset())); // Fill up the runtime mapper handle for all components. - for (unsigned I = 0; I < Info.BasePointers.size(); ++I) { + for (unsigned I = 0; I < Info->BasePointers.size(); ++I) { Value *CurBaseArg = - Builder.CreateBitCast(Info.BasePointers[I], Builder.getPtrTy()); + Builder.CreateBitCast(Info->BasePointers[I], Builder.getPtrTy()); Value *CurBeginArg = - Builder.CreateBitCast(Info.Pointers[I], Builder.getPtrTy()); - Value *CurSizeArg = Info.Sizes[I]; - Value *CurNameArg = Info.Names.size() - ? 
Info.Names[I] + Builder.CreateBitCast(Info->Pointers[I], Builder.getPtrTy()); + Value *CurSizeArg = Info->Sizes[I]; + Value *CurNameArg = Info->Names.size() + ? Info->Names[I] : Constant::getNullValue(Builder.getPtrTy()); // Extract the MEMBER_OF field from the map type. Value *OriMapType = Builder.getInt64( static_cast>( - Info.Types[I])); + Info->Types[I])); Value *MemberMapType = Builder.CreateNUWAdd(OriMapType, ShiftedPreviousSize); @@ -8394,10 +8400,13 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper( Value *OffloadingArgs[] = {MapperHandle, CurBaseArg, CurBeginArg, CurSizeArg, CurMapType, CurNameArg}; - Function *ChildMapperFn = nullptr; - if (CustomMapperCB && CustomMapperCB(I, &ChildMapperFn)) { + + auto ChildMapperFn = CustomMapperCB(I); + if (!ChildMapperFn) + return ChildMapperFn.takeError(); + if (*ChildMapperFn) { // Call the corresponding mapper function. - Builder.CreateCall(ChildMapperFn, OffloadingArgs)->setDoesNotThrow(); + Builder.CreateCall(*ChildMapperFn, OffloadingArgs)->setDoesNotThrow(); } else { // Call the runtime API __tgt_push_mapper_component to fill up the runtime // data structure. @@ -8431,18 +8440,18 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper( return MapperFn; } -void OpenMPIRBuilder::emitOffloadingArrays( +Error OpenMPIRBuilder::emitOffloadingArrays( InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo, - TargetDataInfo &Info, bool IsNonContiguous, - function_ref DeviceAddrCB, - function_ref CustomMapperCB) { + TargetDataInfo &Info, CustomMapperCallbackTy CustomMapperCB, + bool IsNonContiguous, + function_ref DeviceAddrCB) { // Reset the array information. Info.clearArrayInfo(); Info.NumberOfPtrs = CombinedInfo.BasePointers.size(); if (Info.NumberOfPtrs == 0) - return; + return Error::success(); Builder.restoreIP(AllocaIP); // Detect if we have any capture size requiring runtime evaluation of the @@ -8606,9 +8615,13 @@ void OpenMPIRBuilder::emitOffloadingArrays( // Fill up the mapper array. 
unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(0); Value *MFunc = ConstantPointerNull::get(PtrTy); - if (CustomMapperCB) - if (Value *CustomMFunc = CustomMapperCB(I)) - MFunc = Builder.CreatePointerCast(CustomMFunc, PtrTy); + + auto CustomMFunc = CustomMapperCB(I); + if (!CustomMFunc) + return CustomMFunc.takeError(); + if (*CustomMFunc) + MFunc = Builder.CreatePointerCast(*CustomMFunc, PtrTy); + Value *MAddr = Builder.CreateInBoundsGEP( MappersArray->getAllocatedType(), MappersArray, {Builder.getIntN(IndexSize, 0), Builder.getIntN(IndexSize, I)}); @@ -8618,8 +8631,9 @@ void OpenMPIRBuilder::emitOffloadingArrays( if (!IsNonContiguous || CombinedInfo.NonContigInfo.Offsets.empty() || Info.NumberOfPtrs == 0) - return; + return Error::success(); emitNonContiguousDescriptor(AllocaIP, CodeGenIP, CombinedInfo, Info); + return Error::success(); } void OpenMPIRBuilder::emitBranch(BasicBlock *Target) { diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index a8c6e7d6a1ac6..d2c4cc7acaec4 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -5945,6 +5945,7 @@ TEST_F(OpenMPIRBuilderTest, TargetEnterData) { return CombinedInfo; }; + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; llvm::OpenMPIRBuilder::TargetDataInfo Info( /*RequiresDevicePointerInfo=*/false, /*SeparateBeginEndCalls=*/true); @@ -5956,7 +5957,7 @@ TEST_F(OpenMPIRBuilderTest, TargetEnterData) { OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTargetData( Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID), - /* IfCond= */ nullptr, Info, GenMapInfoCB, &RTLFunc)); + /* IfCond= */ nullptr, Info, GenMapInfoCB, CustomMapperCB, &RTLFunc)); Builder.restoreIP(AfterIP); CallInst *TargetDataCall = dyn_cast(&BB->back()); @@ -6007,6 +6008,7 @@ TEST_F(OpenMPIRBuilderTest, TargetExitData) { return CombinedInfo; }; + auto CustomMapperCB = [&](unsigned int I) { 
return nullptr; }; llvm::OpenMPIRBuilder::TargetDataInfo Info( /*RequiresDevicePointerInfo=*/false, /*SeparateBeginEndCalls=*/true); @@ -6018,7 +6020,7 @@ TEST_F(OpenMPIRBuilderTest, TargetExitData) { OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTargetData( Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID), - /* IfCond= */ nullptr, Info, GenMapInfoCB, &RTLFunc)); + /* IfCond= */ nullptr, Info, GenMapInfoCB, CustomMapperCB, &RTLFunc)); Builder.restoreIP(AfterIP); CallInst *TargetDataCall = dyn_cast(&BB->back()); @@ -6091,6 +6093,7 @@ TEST_F(OpenMPIRBuilderTest, TargetDataRegion) { return CombinedInfo; }; + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; llvm::OpenMPIRBuilder::TargetDataInfo Info( /*RequiresDevicePointerInfo=*/true, /*SeparateBeginEndCalls=*/true); @@ -6127,9 +6130,10 @@ TEST_F(OpenMPIRBuilderTest, TargetDataRegion) { ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, TargetDataIP1, - OMPBuilder.createTargetData( - Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID), - /* IfCond= */ nullptr, Info, GenMapInfoCB, nullptr, BodyCB)); + OMPBuilder.createTargetData(Loc, AllocaIP, Builder.saveIP(), + Builder.getInt64(DeviceID), + /* IfCond= */ nullptr, Info, GenMapInfoCB, + CustomMapperCB, nullptr, BodyCB)); Builder.restoreIP(TargetDataIP1); CallInst *TargetDataCall = dyn_cast(&BB->back()); @@ -6155,9 +6159,10 @@ TEST_F(OpenMPIRBuilderTest, TargetDataRegion) { }; ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, TargetDataIP2, - OMPBuilder.createTargetData( - Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID), - /* IfCond= */ nullptr, Info, GenMapInfoCB, nullptr, BodyTargetCB)); + OMPBuilder.createTargetData(Loc, AllocaIP, Builder.saveIP(), + Builder.getInt64(DeviceID), + /* IfCond= */ nullptr, Info, GenMapInfoCB, + CustomMapperCB, nullptr, BodyTargetCB)); Builder.restoreIP(TargetDataIP2); EXPECT_TRUE(CheckDevicePassBodyGen); @@ -6258,6 +6263,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { 
return CombinedInfos; }; + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; + llvm::OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); + TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17); OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL}); OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; @@ -6271,9 +6281,10 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), - Builder.saveIP(), EntryInfo, DefaultAttrs, + Builder.saveIP(), Info, EntryInfo, DefaultAttrs, RuntimeAttrs, /*IfCond=*/nullptr, Inputs, - GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); + GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB, + CustomMapperCB, {}, false)); EXPECT_EQ(DL, Builder.getCurrentDebugLocation()); Builder.restoreIP(AfterIP); @@ -6418,6 +6429,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { return CombinedInfos; }; + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy AllocaIP, OpenMPIRBuilder::InsertPointTy CodeGenIP) -> OpenMPIRBuilder::InsertPointTy { @@ -6437,13 +6449,17 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + llvm::OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, - EntryInfo, DefaultAttrs, RuntimeAttrs, + Info, EntryInfo, DefaultAttrs, RuntimeAttrs, /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB, - BodyGenCB, SimpleArgAccessorCB)); + BodyGenCB, SimpleArgAccessorCB, CustomMapperCB, + 
{}, false)); EXPECT_EQ(DL, Builder.getCurrentDebugLocation()); Builder.restoreIP(AfterIP); @@ -6567,6 +6583,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) { F->setName("func"); IRBuilder<> Builder(BB); + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; auto BodyGenCB = [&](InsertPointTy, InsertPointTy CodeGenIP) -> InsertPointTy { Builder.restoreIP(CodeGenIP); @@ -6594,13 +6611,17 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) { /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; RuntimeAttrs.LoopTripCount = Builder.getInt64(1000); + llvm::OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), - Builder.saveIP(), EntryInfo, DefaultAttrs, + Builder.saveIP(), Info, EntryInfo, DefaultAttrs, RuntimeAttrs, /*IfCond=*/nullptr, Inputs, - GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); + GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB, + CustomMapperCB)); Builder.restoreIP(AfterIP); OMPBuilder.finalize(); @@ -6682,6 +6703,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) { return CombinedInfos; }; + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy, OpenMPIRBuilder::InsertPointTy CodeGenIP) -> OpenMPIRBuilder::InsertPointTy { @@ -6698,13 +6720,16 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) { OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + llvm::OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(Loc, 
/*IsOffloadEntry=*/true, EntryIP, EntryIP, - EntryInfo, DefaultAttrs, RuntimeAttrs, + Info, EntryInfo, DefaultAttrs, RuntimeAttrs, /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB, - BodyGenCB, SimpleArgAccessorCB)); + BodyGenCB, SimpleArgAccessorCB, CustomMapperCB)); Builder.restoreIP(AfterIP); Builder.CreateRetVoid(); @@ -6799,6 +6824,7 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) { llvm::Value *RaiseAlloca = nullptr; + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy AllocaIP, OpenMPIRBuilder::InsertPointTy CodeGenIP) -> OpenMPIRBuilder::InsertPointTy { @@ -6819,13 +6845,17 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) { OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + llvm::OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, - EntryInfo, DefaultAttrs, RuntimeAttrs, + Info, EntryInfo, DefaultAttrs, RuntimeAttrs, /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB, - BodyGenCB, SimpleArgAccessorCB)); + BodyGenCB, SimpleArgAccessorCB, CustomMapperCB, + {}, false)); EXPECT_EQ(DL, Builder.getCurrentDebugLocation()); Builder.restoreIP(AfterIP); diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index e1eef30c0dd06..2d8e022190f62 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1023,6 +1023,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> { OptionalAttr:$members_index, Variadic:$bounds, /* rank-0 to rank-{n-1} */ OptionalAttr:$map_type, + OptionalAttr:$mapper_id, OptionalAttr:$map_capture_type, OptionalAttr:$name, 
DefaultValuedAttr:$partial_map); @@ -1076,6 +1077,8 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> { - 'map_type': OpenMP map type for this map capture, for example: from, to and always. It's a bitfield composed of the OpenMP runtime flags stored in OpenMPOffloadMappingFlags. + - 'mapper_id': OpenMP mapper map type modifier for this map capture. It's used to + specify a user defined mapper to be used for mapping. - 'map_capture_type': Capture type for the variable e.g. this, byref, byvalue, byvla this can affect how the variable is lowered. - `name`: Holds the name of variable as specified in user clause (including bounds). @@ -1087,6 +1090,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> { `var_ptr` `(` $var_ptr `:` type($var_ptr) `,` $var_type `)` oilist( `var_ptr_ptr` `(` $var_ptr_ptr `:` type($var_ptr_ptr) `)` + | `mapper` `(` $mapper_id `)` | `map_clauses` `(` custom($map_type) `)` | `capture` `(` custom($map_capture_type) `)` | `members` `(` $members `:` custom($members_index) `:` type($members) `)` diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index 12e3c07669839..7888745dc6920 100644 --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -186,6 +186,32 @@ struct MapInfoOpConversion : public ConvertOpToLLVMPattern { } }; +struct DeclMapperOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(omp::DeclareMapperOp curOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); + SmallVector newAttrs; + newAttrs.emplace_back(curOp.getSymNameAttrName(), curOp.getSymNameAttr()); + newAttrs.emplace_back( + curOp.getTypeAttrName(), + TypeAttr::get(converter->convertType(curOp.getType()))); + + auto newOp = 
rewriter.create( + curOp.getLoc(), TypeRange(), adaptor.getOperands(), newAttrs); + rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); + if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), + *this->getTypeConverter()))) + return failure(); + + rewriter.eraseOp(curOp); + return success(); + } +}; + template struct MultiRegionOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -225,19 +251,21 @@ void mlir::configureOpenMPToLLVMConversionLegality( ConversionTarget &target, const LLVMTypeConverter &typeConverter) { target.addDynamicallyLegalOp< omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp, - omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp, - omp::MapInfoOp, omp::OrderedOp, omp::ScanOp, omp::TargetEnterDataOp, - omp::TargetExitDataOp, omp::TargetUpdateOp, omp::ThreadprivateOp, - omp::YieldOp>([&](Operation *op) { - return typeConverter.isLegal(op->getOperandTypes()) && - typeConverter.isLegal(op->getResultTypes()); - }); + omp::CancelOp, omp::CriticalDeclareOp, omp::DeclareMapperInfoOp, + omp::FlushOp, omp::MapBoundsOp, omp::MapInfoOp, omp::OrderedOp, + omp::ScanOp, omp::TargetEnterDataOp, omp::TargetExitDataOp, + omp::TargetUpdateOp, omp::ThreadprivateOp, omp::YieldOp>( + [&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); target.addDynamicallyLegalOp< - omp::AtomicUpdateOp, omp::CriticalOp, omp::DeclareReductionOp, - omp::DistributeOp, omp::LoopNestOp, omp::LoopOp, omp::MasterOp, - omp::OrderedRegionOp, omp::ParallelOp, omp::SectionOp, omp::SectionsOp, - omp::SimdOp, omp::SingleOp, omp::TargetDataOp, omp::TargetOp, - omp::TaskgroupOp, omp::TaskloopOp, omp::TaskOp, omp::TeamsOp, + omp::AtomicUpdateOp, omp::CriticalOp, omp::DeclareMapperOp, + omp::DeclareReductionOp, omp::DistributeOp, omp::LoopNestOp, omp::LoopOp, + omp::MasterOp, 
omp::OrderedRegionOp, omp::ParallelOp, + omp::PrivateClauseOp, omp::SectionOp, omp::SectionsOp, omp::SimdOp, + omp::SingleOp, omp::TargetDataOp, omp::TargetOp, omp::TaskgroupOp, + omp::TaskloopOp, omp::TaskOp, omp::TeamsOp, omp::WsloopOp>([&](Operation *op) { return std::all_of(op->getRegions().begin(), op->getRegions().end(), [&](Region ®ion) { @@ -267,12 +295,13 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, [&](omp::MapBoundsType type) -> Type { return type; }); patterns.add< - AtomicReadOpConversion, MapInfoOpConversion, + AtomicReadOpConversion, DeclMapperOpConversion, MapInfoOpConversion, MultiRegionOpConversion, MultiRegionOpConversion, RegionLessOpConversion, RegionLessOpConversion, RegionLessOpConversion, + RegionLessOpConversion, RegionLessOpConversion, RegionLessOpConversion, RegionLessOpConversion, diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index cd59593b2daf8..ed4e9231dd4f8 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1639,7 +1639,13 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) { to ? 
updateToVars.insert(updateVar) : updateFromVars.insert(updateVar); } - } else { + + if (mapInfoOp.getMapperId() && + !SymbolTable::lookupNearestSymbolFrom( + mapInfoOp, mapInfoOp.getMapperIdAttr())) { + return emitError(op->getLoc(), "invalid mapper id"); + } + } else if (!isa(op)) { emitError(op->getLoc(), "map argument is not a map entry operation"); } } diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 3a29ffb2770db..3b9a5495a4071 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -2987,13 +2987,23 @@ getRefPtrIfDeclareTarget(mlir::Value value, } namespace { +// Append customMappers information to existing MapInfosTy +struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy { + SmallVector Mappers; + + /// Append arrays in \a curInfo. + void append(MapInfosTy &curInfo) { + Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end()); + llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo); + } +}; // A small helper structure to contain data gathered // for map lowering and coalesce it into one area and // avoiding extra computations such as searches in the // llvm module for lowered mapped variables or checking // if something is declare target (and retrieving the // value) more than necessary. -struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy { +struct MapInfoData : MapInfosTy { llvm::SmallVector IsDeclareTarget; llvm::SmallVector IsAMember; // Identify if mapping was added by mapClause or use_device clauses. 
@@ -3012,7 +3022,7 @@ struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy { OriginalValue.append(CurInfo.OriginalValue.begin(), CurInfo.OriginalValue.end()); BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end()); - llvm::OpenMPIRBuilder::MapInfosTy::append(CurInfo); + MapInfosTy::append(CurInfo); } }; @@ -3161,6 +3171,12 @@ static void collectMapDataFromMapOperands( mapData.Names.push_back(LLVM::createMappingInformation( mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder())); mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None); + if (mapOp.getMapperId()) + mapData.Mappers.push_back( + SymbolTable::lookupNearestSymbolFrom( + mapOp, mapOp.getMapperIdAttr())); + else + mapData.Mappers.push_back(nullptr); mapData.IsAMapping.push_back(true); mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp)); } @@ -3205,6 +3221,7 @@ static void collectMapDataFromMapOperands( mapData.Names.push_back(LLVM::createMappingInformation( mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder())); mapData.DevicePointers.push_back(devInfoTy); + mapData.Mappers.push_back(nullptr); mapData.IsAMapping.push_back(false); mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp)); } @@ -3475,11 +3492,12 @@ static bool checkIfPointerMap(omp::MapInfoOp mapOp) { // // This function borrows a lot from Clang's emitCombinedEntry function // inside of CGOpenMPRuntime.cpp -static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( - LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, - llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, - llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData, - uint64_t mapDataIndex, TargetDirective targetDirective) { +static llvm::omp::OpenMPOffloadMappingFlags +mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, + llvm::IRBuilderBase &builder, + llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, + MapInfosTy &combinedInfo, MapInfoData &mapData, + uint64_t 
mapDataIndex, TargetDirective targetDirective) { // Map the first segment of our structure const size_t parentIndex = combinedInfo.Types.size(); combinedInfo.Types.emplace_back( @@ -3489,6 +3507,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE); combinedInfo.DevicePointers.emplace_back( mapData.DevicePointers[mapDataIndex]); + combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]); combinedInfo.Names.emplace_back(LLVM::createMappingInformation( mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]); @@ -3558,6 +3577,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( combinedInfo.Types.emplace_back(mapFlag); combinedInfo.DevicePointers.emplace_back( mapData.DevicePointers[mapDataIndex]); + combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]); combinedInfo.Names.emplace_back(LLVM::createMappingInformation( mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); combinedInfo.BasePointers.emplace_back( @@ -3592,6 +3612,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( combinedInfo.Types.emplace_back(mapFlag); combinedInfo.DevicePointers.emplace_back( mapData.DevicePointers[mapDataOverlapIdx]); + combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataOverlapIdx]); combinedInfo.Names.emplace_back(LLVM::createMappingInformation( mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); combinedInfo.BasePointers.emplace_back( @@ -3613,6 +3634,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( combinedInfo.Types.emplace_back(mapFlag); combinedInfo.DevicePointers.emplace_back( mapData.DevicePointers[mapDataIndex]); + combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]); combinedInfo.Names.emplace_back(LLVM::createMappingInformation( mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); combinedInfo.BasePointers.emplace_back( @@ 
-3629,9 +3651,9 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( // This function is intended to add explicit mappings of members static void processMapMembersWithParent( LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, - llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, - llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData, - uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, + llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, + MapInfoData &mapData, uint64_t mapDataIndex, + llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirective targetDirective) { auto parentClause = @@ -3660,6 +3682,7 @@ static void processMapMembersWithParent( combinedInfo.Types.emplace_back(mapFlag); combinedInfo.DevicePointers.emplace_back( llvm::OpenMPIRBuilder::DeviceInfoTy::None); + combinedInfo.Mappers.emplace_back(nullptr); combinedInfo.Names.emplace_back( LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder)); combinedInfo.BasePointers.emplace_back( @@ -3689,6 +3712,7 @@ static void processMapMembersWithParent( combinedInfo.Types.emplace_back(mapFlag); combinedInfo.DevicePointers.emplace_back( llvm::OpenMPIRBuilder::DeviceInfoTy::None); + combinedInfo.Mappers.emplace_back(nullptr); combinedInfo.Names.emplace_back( LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder)); uint64_t basePointerIndex = @@ -3708,11 +3732,10 @@ static void processMapMembersWithParent( } } -static void -processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, - llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, - TargetDirective targetDirective, - int mapDataParentIdx = -1) { +static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, + MapInfosTy &combinedInfo, + TargetDirective targetDirective, + int mapDataParentIdx = -1) { // Declare Target Mappings are excluded from being marked as // OMP_MAP_TARGET_PARAM as they are not passed as parameters, 
they're // marked with OMP_MAP_PTR_AND_OBJ instead. @@ -3743,16 +3766,18 @@ processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]); combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]); + combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]); combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]); combinedInfo.Types.emplace_back(mapFlag); combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]); } -static void processMapWithMembersOf( - LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, - llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, - llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData, - uint64_t mapDataIndex, TargetDirective targetDirective) { +static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, + llvm::IRBuilderBase &builder, + llvm::OpenMPIRBuilder &ompBuilder, + DataLayout &dl, MapInfosTy &combinedInfo, + MapInfoData &mapData, uint64_t mapDataIndex, + TargetDirective targetDirective) { auto parentClause = llvm::cast(mapData.MapClause[mapDataIndex]); @@ -3858,8 +3883,7 @@ createAlteredByCaptureMap(MapInfoData &mapData, // Generate all map related information and fill the combinedInfo. 
static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, - DataLayout &dl, - llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, + DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirective targetDirective) { // We wish to modify some of the methods in which arguments are // passed based on their capture type by the target region, this can @@ -3899,6 +3923,79 @@ static void genMapInfos(llvm::IRBuilderBase &builder, } } +static llvm::Expected +emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + llvm::StringRef mapperFuncName); + +static llvm::Expected +getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + auto declMapperOp = cast(op); + std::string mapperFuncName = + moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName( + {"omp_mapper", declMapperOp.getSymName()}); + + if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName)) + return lookupFunc; + + return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation, + mapperFuncName); +} + +static llvm::Expected +emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + llvm::StringRef mapperFuncName) { + auto declMapperOp = cast(op); + auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo(); + DataLayout dl = DataLayout(declMapperOp->getParentOfType()); + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType()); + SmallVector mapVars = declMapperInfoOp.getMapVars(); + + using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy; + + // Fill up the arrays with all the mapped variables. 
+ MapInfosTy combinedInfo; + auto genMapInfoCB = + [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI, + llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy { + builder.restoreIP(codeGenIP); + moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI); + moduleTranslation.mapBlock(&declMapperOp.getRegion().front(), + builder.GetInsertBlock()); + if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(), + /*ignoreArguments=*/true, + builder))) + return llvm::make_error(); + MapInfoData mapData; + collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl, + builder); + genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData, + TargetDirective::None); + + // Drop the mapping that is no longer necessary so that the same region can + // be processed multiple times. + moduleTranslation.forgetMapping(declMapperOp.getRegion()); + return combinedInfo; + }; + + auto customMapperCB = [&](unsigned i) -> llvm::Expected { + if (!combinedInfo.Mappers[i]) + return nullptr; + return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder, + moduleTranslation); + }; + + llvm::Expected newFn = ompBuilder->emitUserDefinedMapper( + genMapInfoCB, varType, mapperFuncName, customMapperCB); + if (!newFn) + return newFn.takeError(); + moduleTranslation.mapFunction(mapperFuncName, *newFn); + return *newFn; +} + static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { @@ -4011,9 +4108,8 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, builder, useDevicePtrVars, useDeviceAddrVars); // Fill up the arrays with all the mapped variables. 
- llvm::OpenMPIRBuilder::MapInfosTy combinedInfo; - auto genMapInfoCB = - [&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & { + MapInfosTy combinedInfo; + auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & { builder.restoreIP(codeGenIP); genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData, targetDirective); @@ -4057,6 +4153,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy; auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType) -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy { + builder.restoreIP(codeGenIP); assert(isa(op) && "BodyGen requested for non TargetDataOp"); auto blockArgIface = cast(op); @@ -4065,8 +4162,6 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, case BodyGenTy::Priv: // Check if any device ptr/addr info is available if (!info.DevicePtrInfoMap.empty()) { - builder.restoreIP(codeGenIP); - mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address, blockArgIface.getUseDeviceAddrBlockArgs(), useDeviceAddrVars, mapData, @@ -4101,7 +4196,6 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, case BodyGenTy::NoPriv: // If device info is available then region has already been generated if (info.DevicePtrInfoMap.empty()) { - builder.restoreIP(codeGenIP); // For device pass, if use_device_ptr(addr) mappings were present, // we need to link them here before codegen. 
if (ompBuilder->Config.IsTargetDevice.value_or(false)) { @@ -4127,17 +4221,28 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, return builder.saveIP(); }; + auto customMapperCB = + [&](unsigned int i) -> llvm::Expected { + if (!combinedInfo.Mappers[i]) + return nullptr; + info.HasMapper = true; + return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder, + moduleTranslation); + }; + llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); llvm::OpenMPIRBuilder::InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation); llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() { if (isa(op)) - return ompBuilder->createTargetData( - ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), - ifCond, info, genMapInfoCB, nullptr, bodyGenCB); - return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(), - builder.getInt64(deviceID), ifCond, - info, genMapInfoCB, &RTLFn); + return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(), + builder.getInt64(deviceID), ifCond, + info, genMapInfoCB, customMapperCB, + /*MapperFunc=*/nullptr, bodyGenCB, + /*DeviceAddrCB=*/nullptr); + return ompBuilder->createTargetData( + ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond, + info, genMapInfoCB, customMapperCB, &RTLFn); }(); if (failed(handleError(afterIP, *op))) @@ -4952,9 +5057,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl, builder); - llvm::OpenMPIRBuilder::MapInfosTy combinedInfos; - auto genMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) - -> llvm::OpenMPIRBuilder::MapInfosTy & { + MapInfosTy combinedInfos; + auto genMapInfoCB = + [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & { builder.restoreIP(codeGenIP); genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, targetDirective); @@ -5024,15 +5129,28 @@ convertOmpTarget(Operation 
&opInst, llvm::IRBuilderBase &builder, findAllocaInsertPoint(builder, moduleTranslation); llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); + llvm::OpenMPIRBuilder::TargetDataInfo info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); + + auto customMapperCB = + [&](unsigned int i) -> llvm::Expected { + if (!combinedInfos.Mappers[i]) + return nullptr; + info.HasMapper = true; + return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder, + moduleTranslation); + }; + llvm::Value *ifCond = nullptr; if (Value targetIfCond = targetOp.getIfExpr()) ifCond = moduleTranslation.lookupValue(targetIfCond); llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = moduleTranslation.getOpenMPBuilder()->createTarget( - ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo, + ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo, defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB, - argAccessorCB, dds, targetOp.getNowait()); + argAccessorCB, customMapperCB, dds, targetOp.getNowait()); if (failed(handleError(afterIP, opInst))) return failure(); @@ -5272,12 +5390,15 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, .Case([&](omp::TaskwaitOp op) { return convertOmpTaskwaitOp(op, builder, moduleTranslation); }) - .Case([](auto op) { // `yield` and `terminator` can be just omitted. The block structure // was created in the region that handles their parent operation. // `declare_reduction` will be used by reductions and is not // converted directly, skip it. + // `declare_mapper` and `declare_mapper.info` are handled whenever + // they are referred to through a `map` clause. // `critical.declare` is only used to declare names of critical // sections which will be used by `critical` ops and hence can be // ignored for lowering. 
The OpenMP IRBuilder will create unique diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir index 6f1ed73e778b4..d69de998346b5 100644 --- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir @@ -601,3 +601,16 @@ func.func @omp_taskloop(%arg0: index, %arg1 : memref) { } return } + +// ----- + +// CHECK-LABEL: omp.declare_mapper @my_mapper : !llvm.struct<"_QFdeclare_mapperTmy_type", (i32)> { +omp.declare_mapper @my_mapper : !llvm.struct<"_QFdeclare_mapperTmy_type", (i32)> { +^bb0(%arg0: !llvm.ptr): + %0 = llvm.mlir.constant(0 : i32) : i32 + %1 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"_QFdeclare_mapperTmy_type", (i32)> + %2 = omp.map.info var_ptr(%1 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "var%data"} + %3 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.struct<"_QFdeclare_mapperTmy_type", (i32)>) map_clauses(tofrom) capture(ByRef) members(%2 : [0] : !llvm.ptr) -> !llvm.ptr {name = "var", partial_map = true} + // CHECK: omp.declare_mapper.info map_entries(%{{.*}}, %{{.*}} : !llvm.ptr, !llvm.ptr) + omp.declare_mapper.info map_entries(%3, %2 : !llvm.ptr, !llvm.ptr) +} diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 0793ca274fea2..532eb9775a74f 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -2843,3 +2843,13 @@ func.func @missing_workshare(%idx : index) { ^bb0(%arg0: !llvm.ptr): omp.terminator } + +// ----- +llvm.func @invalid_mapper(%0 : !llvm.ptr) { + %1 = omp.map.info var_ptr(%0 : !llvm.ptr, !llvm.struct<"my_type", (i32)>) mapper(@my_mapper) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""} + // expected-error @below {{invalid mapper id}} + omp.target_data map_entries(%1 : !llvm.ptr) { + omp.terminator + } + llvm.return +} diff --git 
a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index f00d3d3426631..e318afbebbf0c 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -2546,13 +2546,13 @@ func.func @omp_targets_with_map_bounds(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> () // CHECK: %[[C_12:.*]] = llvm.mlir.constant(2 : index) : i64 // CHECK: %[[C_13:.*]] = llvm.mlir.constant(2 : index) : i64 // CHECK: %[[BOUNDS1:.*]] = omp.map.bounds lower_bound(%[[C_11]] : i64) upper_bound(%[[C_10]] : i64) stride(%[[C_12]] : i64) start_idx(%[[C_13]] : i64) - // CHECK: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ARG1]] : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%[[BOUNDS1]]) -> !llvm.ptr {name = ""} + // CHECK: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ARG1]] : !llvm.ptr, !llvm.array<10 x i32>) mapper(@my_mapper) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%[[BOUNDS1]]) -> !llvm.ptr {name = ""} %6 = llvm.mlir.constant(9 : index) : i64 %7 = llvm.mlir.constant(1 : index) : i64 %8 = llvm.mlir.constant(2 : index) : i64 %9 = llvm.mlir.constant(2 : index) : i64 %10 = omp.map.bounds lower_bound(%7 : i64) upper_bound(%6 : i64) stride(%8 : i64) start_idx(%9 : i64) - %mapv2 = omp.map.info var_ptr(%arg1 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%10) -> !llvm.ptr {name = ""} + %mapv2 = omp.map.info var_ptr(%arg1 : !llvm.ptr, !llvm.array<10 x i32>) mapper(@my_mapper) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%10) -> !llvm.ptr {name = ""} // CHECK: omp.target map_entries(%[[MAP0]] -> {{.*}}, %[[MAP1]] -> {{.*}} : !llvm.ptr, !llvm.ptr) omp.target map_entries(%mapv1 -> %arg2, %mapv2 -> %arg3 : !llvm.ptr, !llvm.ptr) { diff --git a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir index 7f21095763a39..02b84ff66a0d3 100644 --- a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir +++ 
b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir @@ -485,3 +485,120 @@ llvm.func @_QPopenmp_target_data_update() { // CHECK: call void @__tgt_target_data_update_mapper(ptr @2, i64 -1, i32 1, ptr %[[BASEPTRS_VAL_2]], ptr %[[PTRS_VAL_2]], ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr null) // CHECK: ret void + +// ----- + +omp.declare_mapper @_QQFmy_testmy_mapper : !llvm.struct<"_QFmy_testTmy_type", (i32)> { +^bb0(%arg0: !llvm.ptr): + %0 = llvm.mlir.constant(0 : i32) : i32 + %1 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"_QFmy_testTmy_type", (i32)> + %2 = omp.map.info var_ptr(%1 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "var%data"} + %3 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.struct<"_QFmy_testTmy_type", (i32)>) map_clauses(tofrom) capture(ByRef) members(%2 : [0] : !llvm.ptr) -> !llvm.ptr {name = "var", partial_map = true} + omp.declare_mapper.info map_entries(%3, %2 : !llvm.ptr, !llvm.ptr) +} + +llvm.func @_QPopenmp_target_data_mapper() { + %0 = llvm.mlir.constant(1 : i64) : i64 + %1 = llvm.alloca %0 x !llvm.struct<"_QFmy_testTmy_type", (i32)> {bindc_name = "a"} : (i64) -> !llvm.ptr + %2 = omp.map.info var_ptr(%1 : !llvm.ptr, !llvm.struct<"_QFmy_testTmy_type", (i32)>) mapper(@_QQFmy_testmy_mapper) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "a"} + omp.target_data map_entries(%2 : !llvm.ptr) { + %3 = llvm.mlir.constant(10 : i32) : i32 + %4 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"_QFmy_testTmy_type", (i32)> + llvm.store %3, %4 : i32, !llvm.ptr + omp.terminator + } + llvm.return +} + +// CHECK: @.offload_sizes = private unnamed_addr constant [1 x i64] [i64 4] +// CHECK: @.offload_maptypes = private unnamed_addr constant [1 x i64] [i64 3] +// CHECK-LABEL: define void @_QPopenmp_target_data_mapper +// CHECK: %[[VAL_0:.*]] = alloca [1 x ptr], align 8 +// CHECK: %[[VAL_1:.*]] = alloca [1 x ptr], align 8 +// CHECK: %[[VAL_2:.*]] = alloca [1 x ptr], align 8 +// CHECK: 
%[[VAL_3:.*]] = alloca %[[VAL_4:.*]], i64 1, align 8 +// CHECK: br label %[[VAL_5:.*]] +// CHECK: entry: ; preds = %[[VAL_6:.*]] +// CHECK: %[[VAL_7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_0]], i32 0, i32 0 +// CHECK: store ptr %[[VAL_3]], ptr %[[VAL_7]], align 8 +// CHECK: %[[VAL_8:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_1]], i32 0, i32 0 +// CHECK: store ptr %[[VAL_3]], ptr %[[VAL_8]], align 8 +// CHECK: %[[VAL_9:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_2]], i64 0, i64 0 +// CHECK: store ptr @.omp_mapper._QQFmy_testmy_mapper, ptr %[[VAL_9]], align 8 +// CHECK: %[[VAL_10:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_0]], i32 0, i32 0 +// CHECK: %[[VAL_11:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_1]], i32 0, i32 0 +// CHECK: call void @__tgt_target_data_begin_mapper(ptr @4, i64 -1, i32 1, ptr %[[VAL_10]], ptr %[[VAL_11]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr %[[VAL_2]]) +// CHECK: %[[VAL_12:.*]] = getelementptr %[[VAL_4]], ptr %[[VAL_3]], i32 0, i32 0 +// CHECK: store i32 10, ptr %[[VAL_12]], align 4 +// CHECK: %[[VAL_13:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_0]], i32 0, i32 0 +// CHECK: %[[VAL_14:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_1]], i32 0, i32 0 +// CHECK: call void @__tgt_target_data_end_mapper(ptr @4, i64 -1, i32 1, ptr %[[VAL_13]], ptr %[[VAL_14]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr %[[VAL_2]]) +// CHECK: ret void + +// CHECK-LABEL: define internal void @.omp_mapper._QQFmy_testmy_mapper +// CHECK: entry: +// CHECK: %[[VAL_15:.*]] = udiv exact i64 %[[VAL_16:.*]], 4 +// CHECK: %[[VAL_17:.*]] = getelementptr %[[VAL_18:.*]], ptr %[[VAL_19:.*]], i64 %[[VAL_15]] +// CHECK: %[[VAL_20:.*]] = icmp sgt i64 %[[VAL_15]], 1 +// CHECK: %[[VAL_21:.*]] = and i64 %[[VAL_22:.*]], 8 +// CHECK: %[[VAL_23:.*]] = icmp ne ptr %[[VAL_24:.*]], %[[VAL_19]] +// CHECK: %[[VAL_25:.*]] = and i64 %[[VAL_22]], 16 +// CHECK: 
%[[VAL_26:.*]] = icmp ne i64 %[[VAL_25]], 0 +// CHECK: %[[VAL_27:.*]] = and i1 %[[VAL_23]], %[[VAL_26]] +// CHECK: %[[VAL_28:.*]] = or i1 %[[VAL_20]], %[[VAL_27]] +// CHECK: %[[VAL_29:.*]] = icmp eq i64 %[[VAL_21]], 0 +// CHECK: %[[VAL_30:.*]] = and i1 %[[VAL_28]], %[[VAL_29]] +// CHECK: br i1 %[[VAL_30]], label %[[VAL_31:.*]], label %[[VAL_32:.*]] +// CHECK: .omp.array..init: ; preds = %[[VAL_33:.*]] +// CHECK: %[[VAL_34:.*]] = mul nuw i64 %[[VAL_15]], 4 +// CHECK: %[[VAL_35:.*]] = and i64 %[[VAL_22]], -4 +// CHECK: %[[VAL_36:.*]] = or i64 %[[VAL_35]], 512 +// CHECK: call void @__tgt_push_mapper_component(ptr %[[VAL_37:.*]], ptr %[[VAL_24]], ptr %[[VAL_19]], i64 %[[VAL_34]], i64 %[[VAL_36]], ptr %[[VAL_38:.*]]) +// CHECK: br label %[[VAL_32]] +// CHECK: omp.arraymap.head: ; preds = %[[VAL_31]], %[[VAL_33]] +// CHECK: %[[VAL_39:.*]] = icmp eq ptr %[[VAL_19]], %[[VAL_17]] +// CHECK: br i1 %[[VAL_39]], label %[[VAL_40:.*]], label %[[VAL_41:.*]] +// CHECK: omp.arraymap.body: ; preds = %[[VAL_42:.*]], %[[VAL_32]] +// CHECK: %[[VAL_43:.*]] = phi ptr [ %[[VAL_19]], %[[VAL_32]] ], [ %[[VAL_44:.*]], %[[VAL_42]] ] +// CHECK: %[[VAL_45:.*]] = getelementptr %[[VAL_18]], ptr %[[VAL_43]], i32 0, i32 0 +// CHECK: %[[VAL_46:.*]] = call i64 @__tgt_mapper_num_components(ptr %[[VAL_37]]) +// CHECK: %[[VAL_47:.*]] = shl i64 %[[VAL_46]], 48 +// CHECK: %[[VAL_48:.*]] = add nuw i64 3, %[[VAL_47]] +// CHECK: %[[VAL_49:.*]] = and i64 %[[VAL_22]], 3 +// CHECK: %[[VAL_50:.*]] = icmp eq i64 %[[VAL_49]], 0 +// CHECK: br i1 %[[VAL_50]], label %[[VAL_51:.*]], label %[[VAL_52:.*]] +// CHECK: omp.type.alloc: ; preds = %[[VAL_41]] +// CHECK: %[[VAL_53:.*]] = and i64 %[[VAL_48]], -4 +// CHECK: br label %[[VAL_42]] +// CHECK: omp.type.alloc.else: ; preds = %[[VAL_41]] +// CHECK: %[[VAL_54:.*]] = icmp eq i64 %[[VAL_49]], 1 +// CHECK: br i1 %[[VAL_54]], label %[[VAL_55:.*]], label %[[VAL_56:.*]] +// CHECK: omp.type.to: ; preds = %[[VAL_52]] +// CHECK: %[[VAL_57:.*]] = and i64 %[[VAL_48]], -3 +// 
CHECK: br label %[[VAL_42]] +// CHECK: omp.type.to.else: ; preds = %[[VAL_52]] +// CHECK: %[[VAL_58:.*]] = icmp eq i64 %[[VAL_49]], 2 +// CHECK: br i1 %[[VAL_58]], label %[[VAL_59:.*]], label %[[VAL_42]] +// CHECK: omp.type.from: ; preds = %[[VAL_56]] +// CHECK: %[[VAL_60:.*]] = and i64 %[[VAL_48]], -2 +// CHECK: br label %[[VAL_42]] +// CHECK: omp.type.end: ; preds = %[[VAL_59]], %[[VAL_56]], %[[VAL_55]], %[[VAL_51]] +// CHECK: %[[VAL_61:.*]] = phi i64 [ %[[VAL_53]], %[[VAL_51]] ], [ %[[VAL_57]], %[[VAL_55]] ], [ %[[VAL_60]], %[[VAL_59]] ], [ %[[VAL_48]], %[[VAL_56]] ] +// CHECK: call void @__tgt_push_mapper_component(ptr %[[VAL_37]], ptr %[[VAL_43]], ptr %[[VAL_45]], i64 4, i64 %[[VAL_61]], ptr @2) +// CHECK: %[[VAL_44]] = getelementptr %[[VAL_18]], ptr %[[VAL_43]], i32 1 +// CHECK: %[[VAL_62:.*]] = icmp eq ptr %[[VAL_44]], %[[VAL_17]] +// CHECK: br i1 %[[VAL_62]], label %[[VAL_63:.*]], label %[[VAL_41]] +// CHECK: omp.arraymap.exit: ; preds = %[[VAL_42]] +// CHECK: %[[VAL_64:.*]] = icmp sgt i64 %[[VAL_15]], 1 +// CHECK: %[[VAL_65:.*]] = and i64 %[[VAL_22]], 8 +// CHECK: %[[VAL_66:.*]] = icmp ne i64 %[[VAL_65]], 0 +// CHECK: %[[VAL_67:.*]] = and i1 %[[VAL_64]], %[[VAL_66]] +// CHECK: br i1 %[[VAL_67]], label %[[VAL_68:.*]], label %[[VAL_40]] +// CHECK: .omp.array..del: ; preds = %[[VAL_63]] +// CHECK: %[[VAL_69:.*]] = mul nuw i64 %[[VAL_15]], 4 +// CHECK: %[[VAL_70:.*]] = and i64 %[[VAL_22]], -4 +// CHECK: %[[VAL_71:.*]] = or i64 %[[VAL_70]], 512 +// CHECK: call void @__tgt_push_mapper_component(ptr %[[VAL_37]], ptr %[[VAL_24]], ptr %[[VAL_19]], i64 %[[VAL_69]], i64 %[[VAL_71]], ptr %[[VAL_38]]) +// CHECK: br label %[[VAL_40]] +// CHECK: omp.done: ; preds = %[[VAL_68]], %[[VAL_63]], %[[VAL_32]] +// CHECK: ret void