From 6820bf72070bb6451064a83440ab8ceac5a675a8 Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Thu, 20 Feb 2025 12:54:56 +0000 Subject: [PATCH 1/5] Fix flang/test/Lower/OpenMP/declare-mapper.f90 --- flang/test/Lower/OpenMP/declare-mapper.f90 | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/flang/test/Lower/OpenMP/declare-mapper.f90 b/flang/test/Lower/OpenMP/declare-mapper.f90 index f271233cff8fd..6f3ae9f3074dc 100644 --- a/flang/test/Lower/OpenMP/declare-mapper.f90 +++ b/flang/test/Lower/OpenMP/declare-mapper.f90 @@ -1,5 +1,5 @@ ! This test checks lowering of OpenMP declare mapper Directive. -! XFAIL: * + ! RUN: split-file %s %t ! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-1.f90 -o - | FileCheck %t/omp-declare-mapper-1.f90 ! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-2.f90 -o - | FileCheck %t/omp-declare-mapper-2.f90 @@ -40,9 +40,11 @@ subroutine declare_mapper_1 !CHECK: %[[VAL_16:.*]] = omp.map.bounds lower_bound(%[[VAL_10]] : index) upper_bound(%[[VAL_15]] : index) extent(%[[VAL_6]]#1 : index) stride(%[[VAL_8]] : index) start_idx(%[[VAL_6]]#0 : index) !CHECK: %[[VAL_17:.*]] = arith.constant 1 : index !CHECK: %[[VAL_18:.*]] = fir.coordinate_of %[[VAL_1]]#0, %[[VAL_17]] : (!fir.ref<[[MY_TYPE]]>, index) -> !fir.ref>>> - !CHECK: %[[VAL_19:.*]] = omp.map.info var_ptr(%[[VAL_18]] : !fir.ref>>>, !fir.box>>) map_clauses(tofrom) capture(ByRef) bounds(%[[VAL_16]]) -> !fir.ref>>> {name = "var%[[VAL_20:.*]](1:var%[[VAL_21:.*]])"} - !CHECK: %[[VAL_22:.*]] = omp.map.info var_ptr(%[[VAL_1]]#1 : !fir.ref<[[MY_TYPE]]>, [[MY_TYPE]]) map_clauses(tofrom) capture(ByRef) members(%[[VAL_19]] : [1] : !fir.ref>>>) -> !fir.ref<[[MY_TYPE]]> {name = "var"} - !CHECK: omp.declare_mapper.info map_entries(%[[VAL_22]] : !fir.ref<[[MY_TYPE]]>) + !CHECK: %[[VAL_19:.*]] = fir.box_offset %[[VAL_18]] base_addr : (!fir.ref>>>) -> !fir.llvm_ptr>> + !CHECK: %[[VAL_20:.*]] = omp.map.info var_ptr(%[[VAL_18]] : !fir.ref>>>, i32) var_ptr_ptr(%[[VAL_19]] : !fir.llvm_ptr>>) map_clauses(tofrom) capture(ByRef) bounds(%[[VAL_16]]) -> !fir.llvm_ptr>> {name = ""} + !CHECK: %[[VAL_21:.*]] = omp.map.info var_ptr(%[[VAL_18]] : !fir.ref>>>, !fir.box>>) map_clauses(always, to) capture(ByRef) -> !fir.ref>>> {name = "var%[[VAL_22:.*]](1:var%[[VAL_23:.*]])"} + !CHECK: %[[VAL_24:.*]] = omp.map.info var_ptr(%[[VAL_1]]#1 : !fir.ref<[[MY_TYPE]]>, [[MY_TYPE]]) map_clauses(tofrom) capture(ByRef) members(%[[VAL_21]], %[[VAL_20]] : [1], [1, 0] : !fir.ref>>>, !fir.llvm_ptr>>) -> !fir.ref<[[MY_TYPE]]> {name = "var"} + !CHECK: omp.declare_mapper.info map_entries(%[[VAL_24]], %[[VAL_21]], %[[VAL_20]] : !fir.ref<[[MY_TYPE]]>, !fir.ref>>>, !fir.llvm_ptr>>) !CHECK: } !$omp declare mapper (my_type :: var) map (var, var%values (1:var%num_vals)) end subroutine declare_mapper_1 @@ -77,7 +79,7 @@ subroutine declare_mapper_2 !CHECK: %[[VAL_11:.*]] = hlfir.designate %[[VAL_1]]#0{"temp"} : (!fir.ref<[[MY_TYPE]]>) -> !fir.ref>>}>> !CHECK: %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_11]] : !fir.ref>>}>>, !fir.type<_QFdeclare_mapper_2Tmy_type{num_vals:i32,values:!fir.box>>}>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> !fir.ref>>}>> {name = "v%[[VAL_13:.*]]"} !CHECK: %[[VAL_14:.*]] = omp.map.info var_ptr(%[[VAL_1]]#1 : !fir.ref<[[MY_TYPE]]>, [[MY_TYPE]]) map_clauses(tofrom) capture(ByRef) members(%[[VAL_9]], %[[VAL_12]] : [3], [1] : !fir.ref>, !fir.ref>>}>>) -> !fir.ref<[[MY_TYPE]]> {name = "v", partial_map = true} - !CHECK: omp.declare_mapper.info map_entries(%[[VAL_14]] : !fir.ref<[[MY_TYPE]]>) + !CHECK: omp.declare_mapper.info map_entries(%[[VAL_14]], %[[VAL_9]], %[[VAL_12]] : !fir.ref<[[MY_TYPE]]>, !fir.ref>, !fir.ref>>}>>) !CHECK: } !$omp declare mapper (my_mapper : my_type2 :: v) map (v%arr) map (alloc : v%temp) end subroutine declare_mapper_2 From 8d8470ff4fbf18868e19945737a9784fbd8cc095 Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Tue, 18 Feb 2025 17:27:48 +0000 Subject: [PATCH 2/5] [MLIR][OpenMP] Add OMP Mapper field to MapInfoOp (#120994) This patch adds the mapper field to the omp.map.info op. Depends on #117046. --- flang/include/flang/Lower/OpenMP/Utils.h | 3 ++- flang/lib/Lower/OpenMP/Utils.cpp | 3 ++- flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp | 4 +++- flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp | 5 ++++- .../lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp | 1 + mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 4 ++++ mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 8 +++++++- mlir/test/Dialect/OpenMP/invalid.mlir | 10 ++++++++++ mlir/test/Dialect/OpenMP/ops.mlir | 4 ++-- 9 files changed, 35 insertions(+), 7 deletions(-) diff --git a/flang/include/flang/Lower/OpenMP/Utils.h b/flang/include/flang/Lower/OpenMP/Utils.h index f2e378443e5f2..3943eb633b04e 100644 --- a/flang/include/flang/Lower/OpenMP/Utils.h +++ b/flang/include/flang/Lower/OpenMP/Utils.h @@ -116,7 +116,8 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc, llvm::ArrayRef members, mlir::ArrayAttr membersIndex, uint64_t mapType, mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy, - bool partialMap = false); + bool partialMap = false, + mlir::FlatSymbolRefAttr mapperId = mlir::FlatSymbolRefAttr()); void insertChildMapInfoIntoParent( Fortran::lower::AbstractConverter &converter, diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp index b436b4ad1eb39..72884cef7cfa0 100644 --- a/flang/lib/Lower/OpenMP/Utils.cpp +++ b/flang/lib/Lower/OpenMP/Utils.cpp @@ -130,7 +130,7 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc, llvm::ArrayRef members, mlir::ArrayAttr membersIndex, uint64_t mapType, mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy, - bool partialMap) { + bool partialMap, mlir::FlatSymbolRefAttr mapperId) { if (auto boxTy = llvm::dyn_cast(baseAddr.getType())) { baseAddr = builder.create(loc, baseAddr); retTy = baseAddr.getType(); @@ -149,6 +149,7 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc, mlir::omp::MapInfoOp op = builder.create( loc, retTy, baseAddr, varType, varPtrPtr, members, membersIndex, bounds, builder.getIntegerAttr(builder.getIntegerType(64, false), mapType), + mapperId, builder.getAttr(mapCaptureType), builder.getStringAttr(name), builder.getBoolAttr(partialMap)); return op; diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp index 9d18536df3af2..2501cd7c2919f 100644 --- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp +++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp @@ -51,7 +51,8 @@ mlir::omp::MapInfoOp createMapInfoOp( mlir::Value varPtrPtr, std::string name, llvm::ArrayRef bounds, llvm::ArrayRef members, mlir::ArrayAttr membersIndex, uint64_t mapType, mlir::omp::VariableCaptureKind mapCaptureType, - mlir::Type retTy, bool partialMap = false) { + mlir::Type retTy, bool partialMap = false, + mlir::FlatSymbolRefAttr mapperId = mlir::FlatSymbolRefAttr()) { if (auto boxTy = llvm::dyn_cast(baseAddr.getType())) { baseAddr = builder.create(loc, baseAddr); retTy = baseAddr.getType(); @@ -70,6 +71,7 @@ mlir::omp::MapInfoOp createMapInfoOp( mlir::omp::MapInfoOp op = builder.create( loc, retTy, baseAddr, varType, varPtrPtr, members, membersIndex, bounds, builder.getIntegerAttr(builder.getIntegerType(64, false), mapType), + mapperId, builder.getAttr(mapCaptureType), builder.getStringAttr(name), builder.getBoolAttr(partialMap)); diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp index 7828a466d265d..2ad54fed07602 100644 --- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp +++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp @@ -184,6 +184,7 @@ class MapInfoFinalizationPass /*members=*/mlir::SmallVector{}, /*membersIndex=*/mlir::ArrayAttr{}, bounds, builder.getIntegerAttr(builder.getIntegerType(64, false), mapType), + /*mapperId*/ mlir::FlatSymbolRefAttr(), builder.getAttr( mlir::omp::VariableCaptureKind::ByRef), /*name=*/builder.getStringAttr(""), @@ -331,7 +332,8 @@ class MapInfoFinalizationPass builder.getIntegerAttr( builder.getIntegerType(64, false), getDescriptorMapType(op.getMapType().value_or(0), target)), - op.getMapCaptureTypeAttr(), op.getNameAttr(), + /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getMapCaptureTypeAttr(), + op.getNameAttr(), /*partial_map=*/builder.getBoolAttr(false)); op.replaceAllUsesWith(newDescParentMapOp.getResult()); op->erase(); @@ -629,6 +631,7 @@ class MapInfoFinalizationPass // /*members=*/mlir::ValueRange{}, // /*members_index=*/mlir::ArrayAttr{}, // /*bounds=*/bounds, op.getMapTypeAttr(), + // /*mapperId*/ mlir::FlatSymbolRefAttr(), // builder.getAttr( // mlir::omp::VariableCaptureKind::ByRef), // builder.getStringAttr(op.getNameAttr().strref() + "." + diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp index 963ae863c1fc5..97ea463a3c495 100644 --- a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp +++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp @@ -91,6 +91,7 @@ class MapsForPrivatizedSymbolsPass /*bounds=*/ValueRange{}, builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false), mapTypeTo), + /*mapperId*/ mlir::FlatSymbolRefAttr(), builder.getAttr( omp::VariableCaptureKind::ByRef), StringAttr(), builder.getBoolAttr(false)); diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index e1eef30c0dd06..2d8e022190f62 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1023,6 +1023,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> { OptionalAttr:$members_index, Variadic:$bounds, /* rank-0 to rank-{n-1} */ OptionalAttr:$map_type, + OptionalAttr:$mapper_id, OptionalAttr:$map_capture_type, OptionalAttr:$name, DefaultValuedAttr:$partial_map); @@ -1076,6 +1077,8 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> { - 'map_type': OpenMP map type for this map capture, for example: from, to and always. It's a bitfield composed of the OpenMP runtime flags stored in OpenMPOffloadMappingFlags. + - 'mapper_id': OpenMP mapper map type modifier for this map capture. It's used to + specify a user defined mapper to be used for mapping. - 'map_capture_type': Capture type for the variable e.g. this, byref, byvalue, byvla this can affect how the variable is lowered. - `name`: Holds the name of variable as specified in user clause (including bounds). @@ -1087,6 +1090,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> { `var_ptr` `(` $var_ptr `:` type($var_ptr) `,` $var_type `)` oilist( `var_ptr_ptr` `(` $var_ptr_ptr `:` type($var_ptr_ptr) `)` + | `mapper` `(` $mapper_id `)` | `map_clauses` `(` custom($map_type) `)` | `capture` `(` custom($map_capture_type) `)` | `members` `(` $members `:` custom($members_index) `:` type($members) `)` diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index cd59593b2daf8..ed4e9231dd4f8 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1639,7 +1639,13 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) { to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar); } - } else { + + if (mapInfoOp.getMapperId() && + !SymbolTable::lookupNearestSymbolFrom( + mapInfoOp, mapInfoOp.getMapperIdAttr())) { + return emitError(op->getLoc(), "invalid mapper id"); + } + } else if (!isa(op)) { emitError(op->getLoc(), "map argument is not a map entry operation"); } } diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 0793ca274fea2..532eb9775a74f 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -2843,3 +2843,13 @@ func.func @missing_workshare(%idx : index) { ^bb0(%arg0: !llvm.ptr): omp.terminator } + +// ----- +llvm.func @invalid_mapper(%0 : !llvm.ptr) { + %1 = omp.map.info var_ptr(%0 : !llvm.ptr, !llvm.struct<"my_type", (i32)>) mapper(@my_mapper) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""} + // expected-error @below {{invalid mapper id}} + omp.target_data map_entries(%1 : !llvm.ptr) { + omp.terminator + } + llvm.return +} diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index f00d3d3426631..e318afbebbf0c 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -2546,13 +2546,13 @@ func.func @omp_targets_with_map_bounds(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> () // CHECK: %[[C_12:.*]] = llvm.mlir.constant(2 : index) : i64 // CHECK: %[[C_13:.*]] = llvm.mlir.constant(2 : index) : i64 // CHECK: %[[BOUNDS1:.*]] = omp.map.bounds lower_bound(%[[C_11]] : i64) upper_bound(%[[C_10]] : i64) stride(%[[C_12]] : i64) start_idx(%[[C_13]] : i64) - // CHECK: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ARG1]] : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%[[BOUNDS1]]) -> !llvm.ptr {name = ""} + // CHECK: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ARG1]] : !llvm.ptr, !llvm.array<10 x i32>) mapper(@my_mapper) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%[[BOUNDS1]]) -> !llvm.ptr {name = ""} %6 = llvm.mlir.constant(9 : index) : i64 %7 = llvm.mlir.constant(1 : index) : i64 %8 = llvm.mlir.constant(2 : index) : i64 %9 = llvm.mlir.constant(2 : index) : i64 %10 = omp.map.bounds lower_bound(%7 : i64) upper_bound(%6 : i64) stride(%8 : i64) start_idx(%9 : i64) - %mapv2 = omp.map.info var_ptr(%arg1 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%10) -> !llvm.ptr {name = ""} + %mapv2 = omp.map.info var_ptr(%arg1 : !llvm.ptr, !llvm.array<10 x i32>) mapper(@my_mapper) map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%10) -> !llvm.ptr {name = ""} // CHECK: omp.target map_entries(%[[MAP0]] -> {{.*}}, %[[MAP1]] -> {{.*}} : !llvm.ptr, !llvm.ptr) omp.target map_entries(%mapv1 -> %arg2, %mapv2 -> %arg3 : !llvm.ptr, !llvm.ptr) { From f1d138fb9060a1f5ad591fc550f981d52ae65f1e Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Tue, 18 Feb 2025 17:40:25 +0000 Subject: [PATCH 3/5] [MLIR][OpenMP] Add Lowering support for OpenMP custom mappers in map clause (#121001) Add Lowering support for OpenMP mapper field in mapInfoOp. NOTE: This patch only supports explicit mapper lowering. I'll add a separate PR soon which handles implicit default mapper recognition. Depends on #120994. --- flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 33 ++++++++++-- flang/lib/Lower/OpenMP/ClauseProcessor.h | 3 +- flang/test/Lower/OpenMP/declare-mapper.f90 | 60 ++++++++++++++++++++++ flang/test/Lower/OpenMP/map-mapper.f90 | 30 +++++++++++ 4 files changed, 120 insertions(+), 6 deletions(-) create mode 100644 flang/test/Lower/OpenMP/map-mapper.f90 diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index a946d44c5d3b7..0707ebef1f0d9 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -969,8 +969,11 @@ void ClauseProcessor::processMapObjects( llvm::omp::OpenMPOffloadMappingFlags mapTypeBits, std::map &parentMemberIndices, llvm::SmallVectorImpl &mapVars, - llvm::SmallVectorImpl &mapSyms) const { + llvm::SmallVectorImpl &mapSyms, + llvm::StringRef mapperIdNameRef) const { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::FlatSymbolRefAttr mapperId; + std::string mapperIdName = mapperIdNameRef.str(); for (const omp::Object &object : objects) { llvm::SmallVector bounds; @@ -1003,6 +1006,20 @@ void ClauseProcessor::processMapObjects( } } + if (!mapperIdName.empty()) { + if (mapperIdName == "default") { + auto &typeSpec = object.sym()->owner().IsDerivedType() + ? *object.sym()->owner().derivedTypeSpec() + : object.sym()->GetType()->derivedTypeSpec(); + mapperIdName = typeSpec.name().ToString() + ".default"; + mapperIdName = converter.mangleName(mapperIdName, *typeSpec.GetScope()); + } + assert(converter.getModuleOp().lookupSymbol(mapperIdName) && + "mapper not found"); + mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(), + mapperIdName); + mapperIdName.clear(); + } // Explicit map captures are captured ByRef by default, // optimisation passes may alter this to ByCopy or other capture // types to optimise @@ -1016,7 +1033,8 @@ void ClauseProcessor::processMapObjects( static_cast< std::underlying_type_t>( mapTypeBits), - mlir::omp::VariableCaptureKind::ByRef, baseOp.getType()); + mlir::omp::VariableCaptureKind::ByRef, baseOp.getType(), + /*partialMap=*/false, mapperId); if (parentObj.has_value()) { parentMemberIndices[parentObj.value()].addChildIndexAndMapToParent( @@ -1047,6 +1065,7 @@ bool ClauseProcessor::processMap( const auto &[mapType, typeMods, mappers, iterator, objects] = clause.t; llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; + std::string mapperIdName; // If the map type is specified, then process it else Tofrom is the // default. Map::MapType type = mapType.value_or(Map::MapType::Tofrom); @@ -1090,13 +1109,17 @@ bool ClauseProcessor::processMap( "Support for iterator modifiers is not implemented yet"); } if (mappers) { - TODO(currentLocation, - "Support for mapper modifiers is not implemented yet"); + assert(mappers->size() == 1 && "more than one mapper"); + mapperIdName = mappers->front().v.id().symbol->name().ToString(); + if (mapperIdName != "default") + mapperIdName = converter.mangleName( + mapperIdName, mappers->front().v.id().symbol->owner()); } processMapObjects(stmtCtx, clauseLocation, std::get(clause.t), mapTypeBits, - parentMemberIndices, result.mapVars, *ptrMapSyms); + parentMemberIndices, result.mapVars, *ptrMapSyms, + mapperIdName); }; bool clauseFound = findRepeatableClause(process); diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index 4d2825725c3d2..3ea8b72117186 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -176,7 +176,8 @@ class ClauseProcessor { llvm::omp::OpenMPOffloadMappingFlags mapTypeBits, std::map &parentMemberIndices, llvm::SmallVectorImpl &mapVars, - llvm::SmallVectorImpl &mapSyms) const; + llvm::SmallVectorImpl &mapSyms, + llvm::StringRef mapperIdNameRef = "") const; lower::AbstractConverter &converter; semantics::SemanticsContext &semaCtx; diff --git a/flang/test/Lower/OpenMP/declare-mapper.f90 b/flang/test/Lower/OpenMP/declare-mapper.f90 index 6f3ae9f3074dc..efb9f4b024112 100644 --- a/flang/test/Lower/OpenMP/declare-mapper.f90 +++ b/flang/test/Lower/OpenMP/declare-mapper.f90 @@ -3,6 +3,7 @@ ! RUN: split-file %s %t ! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-1.f90 -o - | FileCheck %t/omp-declare-mapper-1.f90 ! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-2.f90 -o - | FileCheck %t/omp-declare-mapper-2.f90 +! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-3.f90 -o - | FileCheck %t/omp-declare-mapper-3.f90 !--- omp-declare-mapper-1.f90 subroutine declare_mapper_1 @@ -83,3 +84,62 @@ subroutine declare_mapper_2 !CHECK: } !$omp declare mapper (my_mapper : my_type2 :: v) map (v%arr) map (alloc : v%temp) end subroutine declare_mapper_2 + +!--- omp-declare-mapper-3.f90 +subroutine declare_mapper_3 + type my_type + integer :: num_vals + integer, allocatable :: values(:) + end type + + type my_type2 + type(my_type) :: my_type_var + real, dimension(250) :: arr + end type + + !CHECK: omp.declare_mapper @[[MY_TYPE_MAPPER2:_QQFdeclare_mapper_3my_mapper2]] : [[MY_TYPE2:!fir\.type<_QFdeclare_mapper_3Tmy_type2\{my_type_var:!fir\.type<_QFdeclare_mapper_3Tmy_type\{num_vals:i32,values:!fir\.box>>}>,arr:!fir\.array<250xf32>}>]] { + !CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<[[MY_TYPE2]]>): + !CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFdeclare_mapper_3Ev"} : (!fir.ref<[[MY_TYPE2]]>) -> (!fir.ref<[[MY_TYPE2]]>, !fir.ref<[[MY_TYPE2]]>) + !CHECK: %[[VAL_2:.*]] = hlfir.designate %[[VAL_1]]#0{"my_type_var"} : (!fir.ref<[[MY_TYPE2]]>) -> !fir.ref<[[MY_TYPE:!fir\.type<_QFdeclare_mapper_3Tmy_type\{num_vals:i32,values:!fir\.box>>}>]]> + !CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<[[MY_TYPE]]>, [[MY_TYPE]]) mapper(@[[MY_TYPE_MAPPER:_QQFdeclare_mapper_3my_mapper]]) map_clauses(tofrom) capture(ByRef) -> !fir.ref<[[MY_TYPE]]> {name = "v%[[VAL_4:.*]]"} + !CHECK: %[[VAL_5:.*]] = arith.constant 250 : index + !CHECK: %[[VAL_6:.*]] = fir.shape %[[VAL_5]] : (index) -> !fir.shape<1> + !CHECK: %[[VAL_7:.*]] = hlfir.designate %[[VAL_1]]#0{"arr"} shape %[[VAL_6]] : (!fir.ref<[[MY_TYPE2]]>, !fir.shape<1>) -> !fir.ref> + !CHECK: %[[VAL_8:.*]] = arith.constant 1 : index + !CHECK: %[[VAL_9:.*]] = arith.constant 0 : index + !CHECK: %[[VAL_10:.*]] = arith.subi %[[VAL_5]], %[[VAL_8]] : index + !CHECK: %[[VAL_11:.*]] = omp.map.bounds lower_bound(%[[VAL_9]] : index) upper_bound(%[[VAL_10]] : index) extent(%[[VAL_5]] : index) stride(%[[VAL_8]] : index) start_idx(%[[VAL_8]] : index) + !CHECK: %[[VAL_12:.*]] = omp.map.info var_ptr(%[[VAL_7]] : !fir.ref>, !fir.array<250xf32>) map_clauses(tofrom) capture(ByRef) bounds(%[[VAL_11]]) -> !fir.ref> {name = "v%[[VAL_13:.*]]"} + !CHECK: %[[VAL_14:.*]] = omp.map.info var_ptr(%[[VAL_1]]#1 : !fir.ref<[[MY_TYPE2]]>, [[MY_TYPE2]]) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_12]] : [0], [1] : !fir.ref<[[MY_TYPE]]>, !fir.ref>) -> !fir.ref<[[MY_TYPE2]]> {name = "v", partial_map = true} + !CHECK: omp.declare_mapper.info map_entries(%[[VAL_14]], %[[VAL_3]], %[[VAL_12]] : !fir.ref<[[MY_TYPE2]]>, !fir.ref<[[MY_TYPE]]>, !fir.ref>) + !CHECK: } + + !CHECK: omp.declare_mapper @[[MY_TYPE_MAPPER]] : [[MY_TYPE]] { + !CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<[[MY_TYPE]]>): + !CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFdeclare_mapper_3Evar"} : (!fir.ref<[[MY_TYPE]]>) -> (!fir.ref<[[MY_TYPE]]>, !fir.ref<[[MY_TYPE]]>) + !CHECK: %[[VAL_2:.*]] = hlfir.designate %[[VAL_1]]#0{"values"} {fortran_attrs = #fir.var_attrs} : (!fir.ref<[[MY_TYPE]]>) -> !fir.ref>>> + !CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_2]] : !fir.ref>>> + !CHECK: %[[VAL_4:.*]] = fir.box_addr %[[VAL_3]] : (!fir.box>>) -> !fir.heap> + !CHECK: %[[VAL_5:.*]] = arith.constant 0 : index + !CHECK: %[[VAL_6:.*]]:3 = fir.box_dims %[[VAL_3]], %[[VAL_5]] : (!fir.box>>, index) -> (index, index, index) + !CHECK: %[[VAL_7:.*]] = arith.constant 0 : index + !CHECK: %[[VAL_8:.*]] = arith.constant 1 : index + !CHECK: %[[VAL_9:.*]] = arith.constant 1 : index + !CHECK: %[[VAL_10:.*]] = arith.subi %[[VAL_9]], %[[VAL_6]]#0 : index + !CHECK: %[[VAL_11:.*]] = hlfir.designate %[[VAL_1]]#0{"num_vals"} : (!fir.ref<[[MY_TYPE]]>) -> !fir.ref + !CHECK: %[[VAL_12:.*]] = fir.load %[[VAL_11]] : !fir.ref + !CHECK: %[[VAL_13:.*]] = fir.convert %[[VAL_12]] : (i32) -> i64 + !CHECK: %[[VAL_14:.*]] = fir.convert %[[VAL_13]] : (i64) -> index + !CHECK: %[[VAL_15:.*]] = arith.subi %[[VAL_14]], %[[VAL_6]]#0 : index + !CHECK: %[[VAL_16:.*]] = omp.map.bounds lower_bound(%[[VAL_10]] : index) upper_bound(%[[VAL_15]] : index) extent(%[[VAL_6]]#1 : index) stride(%[[VAL_8]] : index) start_idx(%[[VAL_6]]#0 : index) + !CHECK: %[[VAL_17:.*]] = arith.constant 1 : index + !CHECK: %[[VAL_18:.*]] = fir.coordinate_of %[[VAL_1]]#0, %[[VAL_17]] : (!fir.ref<[[MY_TYPE]]>, index) -> !fir.ref>>> + !CHECK: %[[VAL_19:.*]] = fir.box_offset %[[VAL_18]] base_addr : (!fir.ref>>>) -> !fir.llvm_ptr>> + !CHECK: %[[VAL_20:.*]] = omp.map.info var_ptr(%[[VAL_18]] : !fir.ref>>>, i32) var_ptr_ptr(%[[VAL_19]] : !fir.llvm_ptr>>) map_clauses(tofrom) capture(ByRef) bounds(%[[VAL_16]]) -> !fir.llvm_ptr>> {name = ""} + !CHECK: %[[VAL_21:.*]] = omp.map.info var_ptr(%[[VAL_18]] : !fir.ref>>>, !fir.box>>) map_clauses(always, to) capture(ByRef) -> !fir.ref>>> {name = "var%[[VAL_22:.*]](1:var%[[VAL_23:.*]])"} + !CHECK: %[[VAL_24:.*]] = omp.map.info var_ptr(%[[VAL_1]]#1 : !fir.ref<[[MY_TYPE]]>, [[MY_TYPE]]) map_clauses(tofrom) capture(ByRef) members(%[[VAL_21]], %[[VAL_20]] : [1], [1, 0] : !fir.ref>>>, !fir.llvm_ptr>>) -> !fir.ref<[[MY_TYPE]]> {name = "var"} + !CHECK: omp.declare_mapper.info map_entries(%[[VAL_24]], %[[VAL_21]], %[[VAL_20]] : !fir.ref<[[MY_TYPE]]>, !fir.ref>>>, !fir.llvm_ptr>>) + !CHECK: } + !$omp declare mapper (my_mapper : my_type :: var) map (var, var%values (1:var%num_vals)) + !$omp declare mapper (my_mapper2 : my_type2 :: v) map (mapper(my_mapper) : v%my_type_var) map (tofrom : v%arr) +end subroutine declare_mapper_3 diff --git a/flang/test/Lower/OpenMP/map-mapper.f90 b/flang/test/Lower/OpenMP/map-mapper.f90 new file mode 100644 index 0000000000000..0d8fe7344bfab --- /dev/null +++ b/flang/test/Lower/OpenMP/map-mapper.f90 @@ -0,0 +1,30 @@ +! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %s -o - | FileCheck %s +program p + integer, parameter :: n = 256 + type t1 + integer :: x(256) + end type t1 + + !$omp declare mapper(xx : t1 :: nn) map(to: nn, nn%x) + !$omp declare mapper(t1 :: nn) map(from: nn) + + !CHECK-LABEL: omp.declare_mapper @_QQFt1.default : !fir.type<_QFTt1{x:!fir.array<256xi32>}> + !CHECK-LABEL: omp.declare_mapper @_QQFxx : !fir.type<_QFTt1{x:!fir.array<256xi32>}> + + type(t1) :: a, b + !CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%{{.*}} : {{.*}}, {{.*}}) mapper(@_QQFxx) map_clauses(tofrom) capture(ByRef) -> {{.*}} {name = "a"} + !CHECK: omp.target map_entries(%[[MAP_A]] -> %{{.*}}, %{{.*}} -> %{{.*}} : {{.*}}, {{.*}}) { + !$omp target map(mapper(xx) : a) + do i = 1, n + a%x(i) = i + end do + !$omp end target + + !CHECK: %[[MAP_B:.*]] = omp.map.info var_ptr(%{{.*}} : {{.*}}, {{.*}}) mapper(@_QQFt1.default) map_clauses(tofrom) capture(ByRef) -> {{.*}} {name = "b"} + !CHECK: omp.target map_entries(%[[MAP_B]] -> %{{.*}}, %{{.*}} -> %{{.*}} : {{.*}}, {{.*}}) { + !$omp target map(mapper(default) : b) + do i = 1, n + b%x(i) = i + end do + !$omp end target +end program p From 9c68917c24e1219552540879d13dd021e1cb23e5 Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Tue, 18 Feb 2025 17:47:32 +0000 Subject: [PATCH 4/5] [MLIR][OpenMP] Add conversion support from FIR to LLVM Dialect for OMP DeclareMapper (#121005) Add conversion support from FIR to LLVM Dialect for OMP DeclareMapper. Depends on #121001 --- .../Fir/convert-to-llvm-openmp-and-fir.fir | 27 +++++++-- .../Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp | 55 ++++++++++++++----- .../OpenMPToLLVM/convert-to-llvmir.mlir | 13 +++++ 3 files changed, 78 insertions(+), 17 deletions(-) diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir index 22867cae5a7a5..1985a62523d6b 100644 --- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir +++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir @@ -936,9 +936,9 @@ func.func @omp_map_info_descriptor_type_conversion(%arg0 : !fir.ref>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.llvm_ptr> {name = ""} // CHECK: %[[DESC_MAP:.*]] = omp.map.info var_ptr(%[[ARG_0]] : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(always, delete) capture(ByRef) members(%[[MEMBER_MAP]] : [0] : !llvm.ptr) -> !llvm.ptr {name = ""} %2 = omp.map.info var_ptr(%arg0 : !fir.ref>>, !fir.box>) map_clauses(always, delete) capture(ByRef) members(%1 : [0] : !fir.llvm_ptr>) -> !fir.ref>> {name = ""} - // CHECK: omp.target_exit_data map_entries(%[[DESC_MAP]] : !llvm.ptr) + // CHECK: omp.target_exit_data map_entries(%[[DESC_MAP]] : !llvm.ptr) omp.target_exit_data map_entries(%2 : !fir.ref>>) - return + return } // ----- @@ -956,8 +956,8 @@ func.func @omp_map_info_derived_type_explicit_member_conversion(%arg0 : !fir.ref %3 = fir.field_index real, !fir.type<_QFderived_type{real:f32,array:!fir.array<10xi32>,int:i32}> %4 = fir.coordinate_of %arg0, %3 : (!fir.ref,int:i32}>>, !fir.field) -> !fir.ref // CHECK: %[[MAP_MEMBER_2:.*]] = omp.map.info var_ptr(%[[GEP_2]] : !llvm.ptr, f32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "dtype%real"} - %5 = omp.map.info var_ptr(%4 : !fir.ref, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "dtype%real"} - // CHECK: %[[MAP_PARENT:.*]] = omp.map.info var_ptr(%[[ARG_0]] : !llvm.ptr, !llvm.struct<"_QFderived_type", (f32, array<10 x i32>, i32)>) map_clauses(tofrom) capture(ByRef) members(%[[MAP_MEMBER_1]], %[[MAP_MEMBER_2]] : [2], [0] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "dtype", partial_map = true} + %5 = omp.map.info var_ptr(%4 : !fir.ref, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "dtype%real"} + // CHECK: %[[MAP_PARENT:.*]] = omp.map.info var_ptr(%[[ARG_0]] : !llvm.ptr, !llvm.struct<"_QFderived_type", (f32, array<10 x i32>, i32)>) map_clauses(tofrom) capture(ByRef) members(%[[MAP_MEMBER_1]], %[[MAP_MEMBER_2]] : [2], [0] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "dtype", partial_map = true} %6 = omp.map.info var_ptr(%arg0 : !fir.ref,int:i32}>>, !fir.type<_QFderived_type{real:f32,array:!fir.array<10xi32>,int:i32}>) map_clauses(tofrom) capture(ByRef) members(%2, %5 : [2], [0] : !fir.ref, !fir.ref) -> !fir.ref,int:i32}>> {name = "dtype", partial_map = true} // CHECK: omp.target map_entries(%[[MAP_MEMBER_1]] -> %[[ARG_1:.*]], %[[MAP_MEMBER_2]] -> %[[ARG_2:.*]], %[[MAP_PARENT]] -> %[[ARG_3:.*]] : !llvm.ptr, !llvm.ptr, !llvm.ptr) { omp.target map_entries(%2 -> %arg1, %5 -> %arg2, %6 -> %arg3 : !fir.ref, !fir.ref, !fir.ref,int:i32}>>) { @@ -1279,3 +1279,22 @@ func.func @map_nested_dtype_alloca_mem2(%arg0 : !fir.ref { +omp.declare_mapper @my_mapper : !fir.type<_QFdeclare_mapperTmy_type{data:i32}> { +// CHECK: ^bb0(%[[VAL_0:.*]]: !llvm.ptr): +^bb0(%0: !fir.ref>): +// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(0 : i32) : i32 + %1 = fir.field_index data, !fir.type<_QFdeclare_mapperTmy_type{data:i32}> +// CHECK: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"_QFdeclare_mapperTmy_type", (i32)> + %2 = fir.coordinate_of %0, %1 : (!fir.ref>, !fir.field) -> !fir.ref +// CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "var%[[VAL_4:.*]]"} + %3 = omp.map.info var_ptr(%2 : !fir.ref, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref {name = "var%data"} +// CHECK: %[[VAL_5:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !llvm.ptr, !llvm.struct<"_QFdeclare_mapperTmy_type", (i32)>) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]] : [0] : !llvm.ptr) -> !llvm.ptr {name = "var", partial_map = true} + %4 = omp.map.info var_ptr(%0 : !fir.ref>, !fir.type<_QFdeclare_mapperTmy_type{data:i32}>) map_clauses(tofrom) capture(ByRef) members(%3 : [0] : !fir.ref) -> !fir.ref> {name = "var", partial_map = true} +// CHECK: omp.declare_mapper.info map_entries(%[[VAL_5]], %[[VAL_3]] : !llvm.ptr, !llvm.ptr) + omp.declare_mapper.info map_entries(%4, %3 : !fir.ref>, !fir.ref) +// CHECK: } +} diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index 12e3c07669839..7888745dc6920 100644 --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -186,6 +186,32 @@ struct MapInfoOpConversion : public ConvertOpToLLVMPattern { } }; +struct DeclMapperOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(omp::DeclareMapperOp curOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter(); + SmallVector newAttrs; + newAttrs.emplace_back(curOp.getSymNameAttrName(), curOp.getSymNameAttr()); + newAttrs.emplace_back( + curOp.getTypeAttrName(), + TypeAttr::get(converter->convertType(curOp.getType()))); + + auto newOp = rewriter.create( + curOp.getLoc(), TypeRange(), adaptor.getOperands(), newAttrs); + rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); + if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), + *this->getTypeConverter()))) + return failure(); + + rewriter.eraseOp(curOp); + return success(); + } +}; + template struct MultiRegionOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -225,19 +251,21 @@ void mlir::configureOpenMPToLLVMConversionLegality( ConversionTarget &target, const LLVMTypeConverter &typeConverter) { target.addDynamicallyLegalOp< omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp, - omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp, - omp::MapInfoOp, omp::OrderedOp, omp::ScanOp, omp::TargetEnterDataOp, - omp::TargetExitDataOp, omp::TargetUpdateOp, omp::ThreadprivateOp, - omp::YieldOp>([&](Operation *op) { - return typeConverter.isLegal(op->getOperandTypes()) && - typeConverter.isLegal(op->getResultTypes()); - }); + omp::CancelOp, omp::CriticalDeclareOp, omp::DeclareMapperInfoOp, + omp::FlushOp, omp::MapBoundsOp, omp::MapInfoOp, omp::OrderedOp, + omp::ScanOp, omp::TargetEnterDataOp, omp::TargetExitDataOp, + omp::TargetUpdateOp, omp::ThreadprivateOp, omp::YieldOp>( + [&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); target.addDynamicallyLegalOp< - omp::AtomicUpdateOp, omp::CriticalOp, omp::DeclareReductionOp, - omp::DistributeOp, omp::LoopNestOp, omp::LoopOp, omp::MasterOp, - omp::OrderedRegionOp, omp::ParallelOp, omp::SectionOp, omp::SectionsOp, - omp::SimdOp, omp::SingleOp, omp::TargetDataOp, omp::TargetOp, - omp::TaskgroupOp, omp::TaskloopOp, omp::TaskOp, omp::TeamsOp, + omp::AtomicUpdateOp, omp::CriticalOp, omp::DeclareMapperOp, + omp::DeclareReductionOp, omp::DistributeOp, omp::LoopNestOp, omp::LoopOp, + omp::MasterOp, omp::OrderedRegionOp, omp::ParallelOp, + omp::PrivateClauseOp, omp::SectionOp, omp::SectionsOp, omp::SimdOp, + omp::SingleOp, omp::TargetDataOp, omp::TargetOp, omp::TaskgroupOp, + omp::TaskloopOp, omp::TaskOp, omp::TeamsOp, omp::WsloopOp>([&](Operation *op) { return std::all_of(op->getRegions().begin(), op->getRegions().end(), [&](Region ®ion) { @@ -267,12 +295,13 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, [&](omp::MapBoundsType type) -> Type { return type; }); patterns.add< - AtomicReadOpConversion, MapInfoOpConversion, + AtomicReadOpConversion, DeclMapperOpConversion, MapInfoOpConversion, MultiRegionOpConversion, MultiRegionOpConversion, RegionLessOpConversion, RegionLessOpConversion, RegionLessOpConversion, + RegionLessOpConversion, RegionLessOpConversion, RegionLessOpConversion, RegionLessOpConversion, diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir index 6f1ed73e778b4..d69de998346b5 100644 --- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir @@ -601,3 +601,16 @@ func.func @omp_taskloop(%arg0: index, %arg1 : memref) { } return } + +// ----- + +// CHECK-LABEL: omp.declare_mapper @my_mapper : !llvm.struct<"_QFdeclare_mapperTmy_type", (i32)> { +omp.declare_mapper @my_mapper : !llvm.struct<"_QFdeclare_mapperTmy_type", (i32)> { +^bb0(%arg0: !llvm.ptr): + %0 = llvm.mlir.constant(0 : i32) : i32 + %1 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"_QFdeclare_mapperTmy_type", (i32)> + %2 = omp.map.info var_ptr(%1 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "var%data"} + %3 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.struct<"_QFdeclare_mapperTmy_type", (i32)>) map_clauses(tofrom) capture(ByRef) members(%2 : [0] : !llvm.ptr) -> !llvm.ptr {name = "var", partial_map = true} + // CHECK: omp.declare_mapper.info map_entries(%{{.*}}, %{{.*}} : !llvm.ptr, !llvm.ptr) + omp.declare_mapper.info map_entries(%3, %2 : !llvm.ptr, !llvm.ptr) +} From 1b9e2a8a9d32793709b71ea963f6e663f89423dc Mon Sep 17 00:00:00 2001 From: Akash Banerjee Date: Tue, 18 Feb 2025 17:55:48 +0000 Subject: [PATCH 5/5] [MLIR][OpenMP] Add LLVM translation support for OpenMP UserDefinedMappers (#124746) This patch adds OpenMPToLLVMIRTranslation support for the OpenMP Declare Mapper directive. Since both MLIR and Clang now support custom mappers, I've changed the respective function params to no longer be optional as well. Depends on #121005 --- clang/lib/CodeGen/CGOpenMPRuntime.cpp | 27 +-- .../llvm/Frontend/OpenMP/OMPIRBuilder.h | 42 ++-- llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 122 ++++++----- .../Frontend/OpenMPIRBuilderTest.cpp | 66 ++++-- .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 201 ++++++++++++++---- mlir/test/Target/LLVMIR/omptarget-llvm.mlir | 117 ++++++++++ 6 files changed, 434 insertions(+), 141 deletions(-) diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp index b8ee943754bd6..97d0dc5180291 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -8921,7 +8921,7 @@ static void emitOffloadingArraysAndArgs( }; auto CustomMapperCB = [&](unsigned int I) { - llvm::Value *MFunc = nullptr; + llvm::Function *MFunc = nullptr; if (CombinedInfo.Mappers[I]) { Info.HasMapper = true; MFunc = CGM.getOpenMPRuntime().getOrCreateUserDefinedMapperFunc( @@ -8929,9 +8929,9 @@ static void emitOffloadingArraysAndArgs( } return MFunc; }; - OMPBuilder.emitOffloadingArraysAndArgs( - AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, IsNonContiguous, - ForEndCall, DeviceAddrCB, CustomMapperCB); + cantFail(OMPBuilder.emitOffloadingArraysAndArgs( + AllocaIP, CodeGenIP, Info, Info.RTArgs, CombinedInfo, CustomMapperCB, + IsNonContiguous, ForEndCall, DeviceAddrCB)); } /// Check for inner distribute directive. @@ -9124,15 +9124,15 @@ void CGOpenMPRuntime::emitUserDefinedMapper(const OMPDeclareMapperDecl *D, return CombinedInfo; }; - auto CustomMapperCB = [&](unsigned I, llvm::Function **MapperFunc) { + auto CustomMapperCB = [&](unsigned I) { + llvm::Function *MapperFunc = nullptr; if (CombinedInfo.Mappers[I]) { // Call the corresponding mapper function. - *MapperFunc = getOrCreateUserDefinedMapperFunc( + MapperFunc = getOrCreateUserDefinedMapperFunc( cast(CombinedInfo.Mappers[I])); - assert(*MapperFunc && "Expect a valid mapper function is available."); - return true; + assert(MapperFunc && "Expect a valid mapper function is available."); } - return false; + return MapperFunc; }; SmallString<64> TyStr; @@ -9140,8 +9140,8 @@ void CGOpenMPRuntime::emitUserDefinedMapper(const OMPDeclareMapperDecl *D, CGM.getCXXABI().getMangleContext().mangleCanonicalTypeName(Ty, Out); std::string Name = getName({"omp_mapper", TyStr, D->getName()}); - auto *NewFn = OMPBuilder.emitUserDefinedMapper(PrivatizeAndGenMapInfoCB, - ElemTy, Name, CustomMapperCB); + llvm::Function *NewFn = cantFail(OMPBuilder.emitUserDefinedMapper( + PrivatizeAndGenMapInfoCB, ElemTy, Name, CustomMapperCB)); UDMMap.try_emplace(D, NewFn); if (CGF) FunctionUDMMap[CGF->CurFn].push_back(D); @@ -10493,7 +10493,7 @@ void CGOpenMPRuntime::emitTargetDataCalls( }; auto CustomMapperCB = [&](unsigned int I) { - llvm::Value *MFunc = nullptr; + llvm::Function *MFunc = nullptr; if (CombinedInfo.Mappers[I]) { Info.HasMapper = true; MFunc = CGF.CGM.getOpenMPRuntime().getOrCreateUserDefinedMapperFunc( @@ -10513,7 +10513,8 @@ void CGOpenMPRuntime::emitTargetDataCalls( llvm::OpenMPIRBuilder::InsertPointTy AfterIP = cantFail(OMPBuilder.createTargetData( OmpLoc, AllocaIP, CodeGenIP, DeviceID, IfCondVal, Info, GenMapInfoCB, - /*MapperFunc=*/nullptr, BodyCB, DeviceAddrCB, CustomMapperCB, RTLoc)); + CustomMapperCB, + /*MapperFunc=*/nullptr, BodyCB, DeviceAddrCB, RTLoc)); CGF.Builder.restoreIP(AfterIP); } diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index bddf9aedf9d9d..fc63acd36b05f 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -2434,6 +2434,7 @@ class OpenMPIRBuilder { CurInfo.NonContigInfo.Strides.end()); } }; + using MapInfosOrErrorTy = Expected; /// Callback function type for functions emitting the host fallback code that /// is executed when the kernel launch fails. It takes an insertion point as @@ -2442,6 +2443,11 @@ class OpenMPIRBuilder { using EmitFallbackCallbackTy = function_ref; + // Callback function type for emitting and fetching user defined custom + // mappers. + using CustomMapperCallbackTy = + function_ref(unsigned int)>; + /// Generate a target region entry call and host fallback call. /// /// \param Loc The location at which the request originated and is fulfilled. @@ -2508,11 +2514,11 @@ class OpenMPIRBuilder { /// return nullptr by reference. Accepts a reference to a MapInfosTy object /// that contains information generated for mappable clauses, /// including base pointers, pointers, sizes, map types, user-defined mappers. - void emitOffloadingArrays( + Error emitOffloadingArrays( InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo, - TargetDataInfo &Info, bool IsNonContiguous = false, - function_ref DeviceAddrCB = nullptr, - function_ref CustomMapperCB = nullptr); + TargetDataInfo &Info, CustomMapperCallbackTy CustomMapperCB, + bool IsNonContiguous = false, + function_ref DeviceAddrCB = nullptr); /// Allocates memory for and populates the arrays required for offloading /// (offload_{baseptrs|ptrs|mappers|sizes|maptypes|mapnames}). Then, it @@ -2520,12 +2526,12 @@ class OpenMPIRBuilder { /// library. In essence, this function is a combination of /// emitOffloadingArrays and emitOffloadingArraysArgument and should arguably /// be preferred by clients of OpenMPIRBuilder. - void emitOffloadingArraysAndArgs( + Error emitOffloadingArraysAndArgs( InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info, TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo, - bool IsNonContiguous = false, bool ForEndCall = false, - function_ref DeviceAddrCB = nullptr, - function_ref CustomMapperCB = nullptr); + CustomMapperCallbackTy CustomMapperCB, bool IsNonContiguous = false, + bool ForEndCall = false, + function_ref DeviceAddrCB = nullptr); /// Creates offloading entry for the provided entry ID \a ID, address \a /// Addr, size \a Size, and flags \a Flags. @@ -2993,12 +2999,12 @@ class OpenMPIRBuilder { /// \param FuncName Optional param to specify mapper function name. /// \param CustomMapperCB Optional callback to generate code related to /// custom mappers. - Function *emitUserDefinedMapper( - function_ref + Expected emitUserDefinedMapper( + function_ref PrivAndGenMapInfoCB, llvm::Type *ElemTy, StringRef FuncName, - function_ref CustomMapperCB = nullptr); + CustomMapperCallbackTy CustomMapperCB); /// Generator for '#omp target data' /// @@ -3012,21 +3018,21 @@ class OpenMPIRBuilder { /// \param IfCond Value which corresponds to the if clause condition. /// \param Info Stores all information realted to the Target Data directive. /// \param GenMapInfoCB Callback that populates the MapInfos and returns. + /// \param CustomMapperCB Callback to generate code related to + /// custom mappers. /// \param BodyGenCB Optional Callback to generate the region code. /// \param DeviceAddrCB Optional callback to generate code related to /// use_device_ptr and use_device_addr. - /// \param CustomMapperCB Optional callback to generate code related to - /// custom mappers. InsertPointOrErrorTy createTargetData( const LocationDescription &Loc, InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond, TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB, + CustomMapperCallbackTy CustomMapperCB, omp::RuntimeFunction *MapperFunc = nullptr, function_ref BodyGenCB = nullptr, function_ref DeviceAddrCB = nullptr, - function_ref CustomMapperCB = nullptr, Value *SrcLocInfo = nullptr); using TargetBodyGenCallbackTy = function_ref &Inputs, GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB, TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, + CustomMapperCallbackTy CustomMapperCB, SmallVector Dependencies = {}, bool HasNowait = false); /// Returns __kmpc_for_static_init_* runtime function for the specified diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 2a1948a540e4e..9d3c3c580d6cc 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -6375,7 +6375,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit( int32_t MaxThreadsVal = Attrs.MaxThreads.front(); #if FIX_NUM_THREADS_ISSUE - //breaks 534.hpgmg + //breaks 534.hpgmg // If MaxThreads not set, select the maximum between the default workgroup // size and the MinThreads value. if (MaxThreadsVal < 0) @@ -6703,12 +6703,11 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData( const LocationDescription &Loc, InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond, TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB, - omp::RuntimeFunction *MapperFunc, + CustomMapperCallbackTy CustomMapperCB, omp::RuntimeFunction *MapperFunc, function_ref BodyGenCB, - function_ref DeviceAddrCB, - function_ref CustomMapperCB, Value *SrcLocInfo) { + function_ref DeviceAddrCB, Value *SrcLocInfo) { if (!updateToLocation(Loc)) return InsertPointTy(); @@ -6732,9 +6731,10 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTargetData( auto BeginThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) -> Error { MapInfo = &GenMapInfoCB(Builder.saveIP()); - emitOffloadingArrays(AllocaIP, Builder.saveIP(), *MapInfo, Info, - /*IsNonContiguous=*/true, DeviceAddrCB, - CustomMapperCB); + if (Error Err = emitOffloadingArrays( + AllocaIP, Builder.saveIP(), *MapInfo, Info, CustomMapperCB, + /*IsNonContiguous=*/true, DeviceAddrCB)) + return Err; TargetDataRTArgs RTArgs; emitOffloadingArraysArgument(Builder, RTArgs, Info); @@ -7657,26 +7657,31 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask( return Builder.saveIP(); } -void OpenMPIRBuilder::emitOffloadingArraysAndArgs( +Error OpenMPIRBuilder::emitOffloadingArraysAndArgs( InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info, - TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo, bool IsNonContiguous, - bool ForEndCall, function_ref DeviceAddrCB, - function_ref CustomMapperCB) { - emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info, IsNonContiguous, - DeviceAddrCB, CustomMapperCB); + TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo, + CustomMapperCallbackTy CustomMapperCB, bool IsNonContiguous, + bool ForEndCall, function_ref DeviceAddrCB) { + if (Error Err = + emitOffloadingArrays(AllocaIP, CodeGenIP, CombinedInfo, Info, + CustomMapperCB, IsNonContiguous, DeviceAddrCB)) + return Err; emitOffloadingArraysArgument(Builder, RTArgs, Info, ForEndCall); + return Error::success(); } static void emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, OpenMPIRBuilder::InsertPointTy AllocaIP, + OpenMPIRBuilder::TargetDataInfo &Info, const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs, const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID, SmallVectorImpl &Args, OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB, - SmallVector Dependencies = {}, - bool HasNoWait = false) { + OpenMPIRBuilder::CustomMapperCallbackTy CustomMapperCB, + SmallVector Dependencies, + bool HasNoWait) { // Generate a function call to the host fallback implementation of the target // region. This is called by the host when no offload entry was generated for // the target region and when the offloading call fails at runtime. @@ -7747,15 +7752,13 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, auto &&EmitTargetCallThen = [&](OpenMPIRBuilder::InsertPointTy AllocaIP, OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error { - OpenMPIRBuilder::TargetDataInfo Info( - /*RequiresDevicePointerInfo=*/false, - /*SeparateBeginEndCalls=*/true); - OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(CodeGenIP); OpenMPIRBuilder::TargetDataRTArgs RTArgs; - OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, CodeGenIP, Info, RTArgs, - MapInfo, /*IsNonContiguous=*/true, - /*ForEndCall=*/false); + if (Error Err = OMPBuilder.emitOffloadingArraysAndArgs( + AllocaIP, CodeGenIP, Info, RTArgs, MapInfo, CustomMapperCB, + /*IsNonContiguous=*/true, + /*ForEndCall=*/false)) + return Err; SmallVector NumTeamsC; for (auto [DefaultVal, RuntimeVal] : @@ -7857,13 +7860,15 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP, - InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, + InsertPointTy CodeGenIP, TargetDataInfo &Info, + TargetRegionEntryInfo &EntryInfo, const TargetKernelDefaultAttrs &DefaultAttrs, const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond, - SmallVectorImpl &Args, GenMapInfoCallbackTy GenMapInfoCB, + SmallVectorImpl &Inputs, GenMapInfoCallbackTy GenMapInfoCB, OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc, OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB, - SmallVector Dependencies, bool HasNowait) { + CustomMapperCallbackTy CustomMapperCB, SmallVector Dependencies, + bool HasNowait) { if (!updateToLocation(Loc)) return InsertPointTy(); @@ -7877,16 +7882,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget( // and ArgAccessorFuncCB if (Error Err = emitTargetOutlinedFunction( *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn, - OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB)) + OutlinedFnID, Inputs, CBFunc, ArgAccessorFuncCB)) return Err; // If we are not on the target device, then we need to generate code // to make a remote call (offload) to the previously outlined function // that represents the target region. Do that now. if (!Config.isTargetDevice()) - emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs, IfCond, - OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies, - HasNowait); + emitTargetCall(*this, Builder, AllocaIP, Info, DefaultAttrs, RuntimeAttrs, + IfCond, OutlinedFn, OutlinedFnID, Inputs, GenMapInfoCB, + CustomMapperCB, Dependencies, HasNowait); return Builder.saveIP(); } @@ -8211,12 +8216,11 @@ void OpenMPIRBuilder::emitUDMapperArrayInitOrDel( OffloadingArgs); } -Function *OpenMPIRBuilder::emitUserDefinedMapper( - function_ref +Expected OpenMPIRBuilder::emitUserDefinedMapper( + function_ref GenMapInfoCB, - Type *ElemTy, StringRef FuncName, - function_ref CustomMapperCB) { + Type *ElemTy, StringRef FuncName, CustomMapperCallbackTy CustomMapperCB) { SmallVector Params; Params.emplace_back(Builder.getPtrTy()); Params.emplace_back(Builder.getPtrTy()); @@ -8287,7 +8291,9 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper( PtrPHI->addIncoming(PtrBegin, HeadBB); // Get map clause information. Fill up the arrays with all mapped variables. - MapInfosTy &Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn); + MapInfosOrErrorTy Info = GenMapInfoCB(Builder.saveIP(), PtrPHI, BeginIn); + if (!Info) + return Info.takeError(); // Call the runtime API __tgt_mapper_num_components to get the number of // pre-existing components. @@ -8299,20 +8305,20 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper( Builder.CreateShl(PreviousSize, Builder.getInt64(getFlagMemberOffset())); // Fill up the runtime mapper handle for all components. - for (unsigned I = 0; I < Info.BasePointers.size(); ++I) { + for (unsigned I = 0; I < Info->BasePointers.size(); ++I) { Value *CurBaseArg = - Builder.CreateBitCast(Info.BasePointers[I], Builder.getPtrTy()); + Builder.CreateBitCast(Info->BasePointers[I], Builder.getPtrTy()); Value *CurBeginArg = - Builder.CreateBitCast(Info.Pointers[I], Builder.getPtrTy()); - Value *CurSizeArg = Info.Sizes[I]; - Value *CurNameArg = Info.Names.size() - ? Info.Names[I] + Builder.CreateBitCast(Info->Pointers[I], Builder.getPtrTy()); + Value *CurSizeArg = Info->Sizes[I]; + Value *CurNameArg = Info->Names.size() + ? Info->Names[I] : Constant::getNullValue(Builder.getPtrTy()); // Extract the MEMBER_OF field from the map type. Value *OriMapType = Builder.getInt64( static_cast>( - Info.Types[I])); + Info->Types[I])); Value *MemberMapType = Builder.CreateNUWAdd(OriMapType, ShiftedPreviousSize); @@ -8394,10 +8400,13 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper( Value *OffloadingArgs[] = {MapperHandle, CurBaseArg, CurBeginArg, CurSizeArg, CurMapType, CurNameArg}; - Function *ChildMapperFn = nullptr; - if (CustomMapperCB && CustomMapperCB(I, &ChildMapperFn)) { + + auto ChildMapperFn = CustomMapperCB(I); + if (!ChildMapperFn) + return ChildMapperFn.takeError(); + if (*ChildMapperFn) { // Call the corresponding mapper function. - Builder.CreateCall(ChildMapperFn, OffloadingArgs)->setDoesNotThrow(); + Builder.CreateCall(*ChildMapperFn, OffloadingArgs)->setDoesNotThrow(); } else { // Call the runtime API __tgt_push_mapper_component to fill up the runtime // data structure. @@ -8431,18 +8440,18 @@ Function *OpenMPIRBuilder::emitUserDefinedMapper( return MapperFn; } -void OpenMPIRBuilder::emitOffloadingArrays( +Error OpenMPIRBuilder::emitOffloadingArrays( InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo, - TargetDataInfo &Info, bool IsNonContiguous, - function_ref DeviceAddrCB, - function_ref CustomMapperCB) { + TargetDataInfo &Info, CustomMapperCallbackTy CustomMapperCB, + bool IsNonContiguous, + function_ref DeviceAddrCB) { // Reset the array information. Info.clearArrayInfo(); Info.NumberOfPtrs = CombinedInfo.BasePointers.size(); if (Info.NumberOfPtrs == 0) - return; + return Error::success(); Builder.restoreIP(AllocaIP); // Detect if we have any capture size requiring runtime evaluation of the @@ -8606,9 +8615,13 @@ void OpenMPIRBuilder::emitOffloadingArrays( // Fill up the mapper array. unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(0); Value *MFunc = ConstantPointerNull::get(PtrTy); - if (CustomMapperCB) - if (Value *CustomMFunc = CustomMapperCB(I)) - MFunc = Builder.CreatePointerCast(CustomMFunc, PtrTy); + + auto CustomMFunc = CustomMapperCB(I); + if (!CustomMFunc) + return CustomMFunc.takeError(); + if (*CustomMFunc) + MFunc = Builder.CreatePointerCast(*CustomMFunc, PtrTy); + Value *MAddr = Builder.CreateInBoundsGEP( MappersArray->getAllocatedType(), MappersArray, {Builder.getIntN(IndexSize, 0), Builder.getIntN(IndexSize, I)}); @@ -8618,8 +8631,9 @@ void OpenMPIRBuilder::emitOffloadingArrays( if (!IsNonContiguous || CombinedInfo.NonContigInfo.Offsets.empty() || Info.NumberOfPtrs == 0) - return; + return Error::success(); emitNonContiguousDescriptor(AllocaIP, CodeGenIP, CombinedInfo, Info); + return Error::success(); } void OpenMPIRBuilder::emitBranch(BasicBlock *Target) { diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index a8c6e7d6a1ac6..d2c4cc7acaec4 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -5945,6 +5945,7 @@ TEST_F(OpenMPIRBuilderTest, TargetEnterData) { return CombinedInfo; }; + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; llvm::OpenMPIRBuilder::TargetDataInfo Info( /*RequiresDevicePointerInfo=*/false, /*SeparateBeginEndCalls=*/true); @@ -5956,7 +5957,7 @@ TEST_F(OpenMPIRBuilderTest, TargetEnterData) { OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTargetData( Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID), - /* IfCond= */ nullptr, Info, GenMapInfoCB, &RTLFunc)); + /* IfCond= */ nullptr, Info, GenMapInfoCB, CustomMapperCB, &RTLFunc)); Builder.restoreIP(AfterIP); CallInst *TargetDataCall = dyn_cast(&BB->back()); @@ -6007,6 +6008,7 @@ TEST_F(OpenMPIRBuilderTest, TargetExitData) { return CombinedInfo; }; + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; llvm::OpenMPIRBuilder::TargetDataInfo Info( /*RequiresDevicePointerInfo=*/false, /*SeparateBeginEndCalls=*/true); @@ -6018,7 +6020,7 @@ TEST_F(OpenMPIRBuilderTest, TargetExitData) { OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTargetData( Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID), - /* IfCond= */ nullptr, Info, GenMapInfoCB, &RTLFunc)); + /* IfCond= */ nullptr, Info, GenMapInfoCB, CustomMapperCB, &RTLFunc)); Builder.restoreIP(AfterIP); CallInst *TargetDataCall = dyn_cast(&BB->back()); @@ -6091,6 +6093,7 @@ TEST_F(OpenMPIRBuilderTest, TargetDataRegion) { return CombinedInfo; }; + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; llvm::OpenMPIRBuilder::TargetDataInfo Info( /*RequiresDevicePointerInfo=*/true, /*SeparateBeginEndCalls=*/true); @@ -6127,9 +6130,10 @@ TEST_F(OpenMPIRBuilderTest, TargetDataRegion) { ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, TargetDataIP1, - OMPBuilder.createTargetData( - Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID), - /* IfCond= */ nullptr, Info, GenMapInfoCB, nullptr, BodyCB)); + OMPBuilder.createTargetData(Loc, AllocaIP, Builder.saveIP(), + Builder.getInt64(DeviceID), + /* IfCond= */ nullptr, Info, GenMapInfoCB, + CustomMapperCB, nullptr, BodyCB)); Builder.restoreIP(TargetDataIP1); CallInst *TargetDataCall = dyn_cast(&BB->back()); @@ -6155,9 +6159,10 @@ TEST_F(OpenMPIRBuilderTest, TargetDataRegion) { }; ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, TargetDataIP2, - OMPBuilder.createTargetData( - Loc, AllocaIP, Builder.saveIP(), Builder.getInt64(DeviceID), - /* IfCond= */ nullptr, Info, GenMapInfoCB, nullptr, BodyTargetCB)); + OMPBuilder.createTargetData(Loc, AllocaIP, Builder.saveIP(), + Builder.getInt64(DeviceID), + /* IfCond= */ nullptr, Info, GenMapInfoCB, + CustomMapperCB, nullptr, BodyTargetCB)); Builder.restoreIP(TargetDataIP2); EXPECT_TRUE(CheckDevicePassBodyGen); @@ -6258,6 +6263,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { return CombinedInfos; }; + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; + llvm::OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); + TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17); OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL}); OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs; @@ -6271,9 +6281,10 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) { ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), - Builder.saveIP(), EntryInfo, DefaultAttrs, + Builder.saveIP(), Info, EntryInfo, DefaultAttrs, RuntimeAttrs, /*IfCond=*/nullptr, Inputs, - GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); + GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB, + CustomMapperCB, {}, false)); EXPECT_EQ(DL, Builder.getCurrentDebugLocation()); Builder.restoreIP(AfterIP); @@ -6418,6 +6429,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { return CombinedInfos; }; + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy AllocaIP, OpenMPIRBuilder::InsertPointTy CodeGenIP) -> OpenMPIRBuilder::InsertPointTy { @@ -6437,13 +6449,17 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) { OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + llvm::OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, - EntryInfo, DefaultAttrs, RuntimeAttrs, + Info, EntryInfo, DefaultAttrs, RuntimeAttrs, /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB, - BodyGenCB, SimpleArgAccessorCB)); + BodyGenCB, SimpleArgAccessorCB, CustomMapperCB, + {}, false)); EXPECT_EQ(DL, Builder.getCurrentDebugLocation()); Builder.restoreIP(AfterIP); @@ -6567,6 +6583,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) { F->setName("func"); IRBuilder<> Builder(BB); + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; auto BodyGenCB = [&](InsertPointTy, InsertPointTy CodeGenIP) -> InsertPointTy { Builder.restoreIP(CodeGenIP); @@ -6594,13 +6611,17 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) { /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; RuntimeAttrs.LoopTripCount = Builder.getInt64(1000); + llvm::OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), - Builder.saveIP(), EntryInfo, DefaultAttrs, + Builder.saveIP(), Info, EntryInfo, DefaultAttrs, RuntimeAttrs, /*IfCond=*/nullptr, Inputs, - GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB)); + GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB, + CustomMapperCB)); Builder.restoreIP(AfterIP); OMPBuilder.finalize(); @@ -6682,6 +6703,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) { return CombinedInfos; }; + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy, OpenMPIRBuilder::InsertPointTy CodeGenIP) -> OpenMPIRBuilder::InsertPointTy { @@ -6698,13 +6720,16 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) { OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + llvm::OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, - EntryInfo, DefaultAttrs, RuntimeAttrs, + Info, EntryInfo, DefaultAttrs, RuntimeAttrs, /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB, - BodyGenCB, SimpleArgAccessorCB)); + BodyGenCB, SimpleArgAccessorCB, CustomMapperCB)); Builder.restoreIP(AfterIP); Builder.CreateRetVoid(); @@ -6799,6 +6824,7 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) { llvm::Value *RaiseAlloca = nullptr; + auto CustomMapperCB = [&](unsigned int I) { return nullptr; }; auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy AllocaIP, OpenMPIRBuilder::InsertPointTy CodeGenIP) -> OpenMPIRBuilder::InsertPointTy { @@ -6819,13 +6845,17 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) { OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; + llvm::OpenMPIRBuilder::TargetDataInfo Info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); ASSERT_EXPECTED_INIT( OpenMPIRBuilder::InsertPointTy, AfterIP, OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, - EntryInfo, DefaultAttrs, RuntimeAttrs, + Info, EntryInfo, DefaultAttrs, RuntimeAttrs, /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB, - BodyGenCB, SimpleArgAccessorCB)); + BodyGenCB, SimpleArgAccessorCB, CustomMapperCB, + {}, false)); EXPECT_EQ(DL, Builder.getCurrentDebugLocation()); Builder.restoreIP(AfterIP); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 3a29ffb2770db..3b9a5495a4071 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -2987,13 +2987,23 @@ getRefPtrIfDeclareTarget(mlir::Value value, } namespace { +// Append customMappers information to existing MapInfosTy +struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy { + SmallVector Mappers; + + /// Append arrays in \a CurInfo. + void append(MapInfosTy &curInfo) { + Mappers.append(curInfo.Mappers.begin(), curInfo.Mappers.end()); + llvm::OpenMPIRBuilder::MapInfosTy::append(curInfo); + } +}; // A small helper structure to contain data gathered // for map lowering and coalese it into one area and // avoiding extra computations such as searches in the // llvm module for lowered mapped variables or checking // if something is declare target (and retrieving the // value) more than neccessary. -struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy { +struct MapInfoData : MapInfosTy { llvm::SmallVector IsDeclareTarget; llvm::SmallVector IsAMember; // Identify if mapping was added by mapClause or use_device clauses. @@ -3012,7 +3022,7 @@ struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy { OriginalValue.append(CurInfo.OriginalValue.begin(), CurInfo.OriginalValue.end()); BaseType.append(CurInfo.BaseType.begin(), CurInfo.BaseType.end()); - llvm::OpenMPIRBuilder::MapInfosTy::append(CurInfo); + MapInfosTy::append(CurInfo); } }; @@ -3161,6 +3171,12 @@ static void collectMapDataFromMapOperands( mapData.Names.push_back(LLVM::createMappingInformation( mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder())); mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None); + if (mapOp.getMapperId()) + mapData.Mappers.push_back( + SymbolTable::lookupNearestSymbolFrom( + mapOp, mapOp.getMapperIdAttr())); + else + mapData.Mappers.push_back(nullptr); mapData.IsAMapping.push_back(true); mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp)); } @@ -3205,6 +3221,7 @@ static void collectMapDataFromMapOperands( mapData.Names.push_back(LLVM::createMappingInformation( mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder())); mapData.DevicePointers.push_back(devInfoTy); + mapData.Mappers.push_back(nullptr); mapData.IsAMapping.push_back(false); mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp)); } @@ -3475,11 +3492,12 @@ static bool checkIfPointerMap(omp::MapInfoOp mapOp) { // // This function borrows a lot from Clang's emitCombinedEntry function // inside of CGOpenMPRuntime.cpp -static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( - LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, - llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, - llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData, - uint64_t mapDataIndex, TargetDirective targetDirective) { +static llvm::omp::OpenMPOffloadMappingFlags +mapParentWithMembers(LLVM::ModuleTranslation &moduleTranslation, + llvm::IRBuilderBase &builder, + llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, + MapInfosTy &combinedInfo, MapInfoData &mapData, + uint64_t mapDataIndex, TargetDirective targetDirective) { // Map the first segment of our structure const size_t parentIndex = combinedInfo.Types.size(); combinedInfo.Types.emplace_back( @@ -3489,6 +3507,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE); combinedInfo.DevicePointers.emplace_back( mapData.DevicePointers[mapDataIndex]); + combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]); combinedInfo.Names.emplace_back(LLVM::createMappingInformation( mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]); @@ -3558,6 +3577,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( combinedInfo.Types.emplace_back(mapFlag); combinedInfo.DevicePointers.emplace_back( mapData.DevicePointers[mapDataIndex]); + combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]); combinedInfo.Names.emplace_back(LLVM::createMappingInformation( mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); combinedInfo.BasePointers.emplace_back( @@ -3592,6 +3612,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( combinedInfo.Types.emplace_back(mapFlag); combinedInfo.DevicePointers.emplace_back( mapData.DevicePointers[mapDataOverlapIdx]); + combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataOverlapIdx]); combinedInfo.Names.emplace_back(LLVM::createMappingInformation( mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); combinedInfo.BasePointers.emplace_back( @@ -3613,6 +3634,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( combinedInfo.Types.emplace_back(mapFlag); combinedInfo.DevicePointers.emplace_back( mapData.DevicePointers[mapDataIndex]); + combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIndex]); combinedInfo.Names.emplace_back(LLVM::createMappingInformation( mapData.MapClause[mapDataIndex]->getLoc(), ompBuilder)); combinedInfo.BasePointers.emplace_back( @@ -3629,9 +3651,9 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( // This function is intended to add explicit mappings of members static void processMapMembersWithParent( LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, - llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, - llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData, - uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, + llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo, + MapInfoData &mapData, uint64_t mapDataIndex, + llvm::omp::OpenMPOffloadMappingFlags memberOfFlag, TargetDirective targetDirective) { auto parentClause = @@ -3660,6 +3682,7 @@ static void processMapMembersWithParent( combinedInfo.Types.emplace_back(mapFlag); combinedInfo.DevicePointers.emplace_back( llvm::OpenMPIRBuilder::DeviceInfoTy::None); + combinedInfo.Mappers.emplace_back(nullptr); combinedInfo.Names.emplace_back( LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder)); combinedInfo.BasePointers.emplace_back( @@ -3689,6 +3712,7 @@ static void processMapMembersWithParent( combinedInfo.Types.emplace_back(mapFlag); combinedInfo.DevicePointers.emplace_back( llvm::OpenMPIRBuilder::DeviceInfoTy::None); + combinedInfo.Mappers.emplace_back(nullptr); combinedInfo.Names.emplace_back( LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder)); uint64_t basePointerIndex = @@ -3708,11 +3732,10 @@ static void processMapMembersWithParent( } } -static void -processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, - llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, - TargetDirective targetDirective, - int mapDataParentIdx = -1) { +static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, + MapInfosTy &combinedInfo, + TargetDirective targetDirective, + int mapDataParentIdx = -1) { // Declare Target Mappings are excluded from being marked as // OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're // marked with OMP_MAP_PTR_AND_OBJ instead. @@ -3743,16 +3766,18 @@ processIndividualMap(MapInfoData &mapData, size_t mapDataIdx, combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIdx]); combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[mapDataIdx]); + combinedInfo.Mappers.emplace_back(mapData.Mappers[mapDataIdx]); combinedInfo.Names.emplace_back(mapData.Names[mapDataIdx]); combinedInfo.Types.emplace_back(mapFlag); combinedInfo.Sizes.emplace_back(mapData.Sizes[mapDataIdx]); } -static void processMapWithMembersOf( - LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder, - llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, - llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData, - uint64_t mapDataIndex, TargetDirective targetDirective) { +static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation, + llvm::IRBuilderBase &builder, + llvm::OpenMPIRBuilder &ompBuilder, + DataLayout &dl, MapInfosTy &combinedInfo, + MapInfoData &mapData, uint64_t mapDataIndex, + TargetDirective targetDirective) { auto parentClause = llvm::cast(mapData.MapClause[mapDataIndex]); @@ -3858,8 +3883,7 @@ createAlteredByCaptureMap(MapInfoData &mapData, // Generate all map related information and fill the combinedInfo. static void genMapInfos(llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, - DataLayout &dl, - llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, + DataLayout &dl, MapInfosTy &combinedInfo, MapInfoData &mapData, TargetDirective targetDirective) { // We wish to modify some of the methods in which arguments are // passed based on their capture type by the target region, this can @@ -3899,6 +3923,79 @@ static void genMapInfos(llvm::IRBuilderBase &builder, } } +static llvm::Expected +emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + llvm::StringRef mapperFuncName); + +static llvm::Expected +getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + auto declMapperOp = cast(op); + std::string mapperFuncName = + moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName( + {"omp_mapper", declMapperOp.getSymName()}); + + if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName)) + return lookupFunc; + + return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation, + mapperFuncName); +} + +static llvm::Expected +emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + llvm::StringRef mapperFuncName) { + auto declMapperOp = cast(op); + auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo(); + DataLayout dl = DataLayout(declMapperOp->getParentOfType()); + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + llvm::Type *varType = moduleTranslation.convertType(declMapperOp.getType()); + SmallVector mapVars = declMapperInfoOp.getMapVars(); + + using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy; + + // Fill up the arrays with all the mapped variables. + MapInfosTy combinedInfo; + auto genMapInfoCB = + [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI, + llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy { + builder.restoreIP(codeGenIP); + moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI); + moduleTranslation.mapBlock(&declMapperOp.getRegion().front(), + builder.GetInsertBlock()); + if (failed(moduleTranslation.convertBlock(declMapperOp.getRegion().front(), + /*ignoreArguments=*/true, + builder))) + return llvm::make_error(); + MapInfoData mapData; + collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl, + builder); + genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData, + TargetDirective::None); + + // Drop the mapping that is no longer necessary so that the same region can + // be processed multiple times. + moduleTranslation.forgetMapping(declMapperOp.getRegion()); + return combinedInfo; + }; + + auto customMapperCB = [&](unsigned i) -> llvm::Expected { + if (!combinedInfo.Mappers[i]) + return nullptr; + return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder, + moduleTranslation); + }; + + llvm::Expected newFn = ompBuilder->emitUserDefinedMapper( + genMapInfoCB, varType, mapperFuncName, customMapperCB); + if (!newFn) + return newFn.takeError(); + moduleTranslation.mapFunction(mapperFuncName, *newFn); + return *newFn; +} + static LogicalResult convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { @@ -4011,9 +4108,8 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, builder, useDevicePtrVars, useDeviceAddrVars); // Fill up the arrays with all the mapped variables. - llvm::OpenMPIRBuilder::MapInfosTy combinedInfo; - auto genMapInfoCB = - [&](InsertPointTy codeGenIP) -> llvm::OpenMPIRBuilder::MapInfosTy & { + MapInfosTy combinedInfo; + auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & { builder.restoreIP(codeGenIP); genMapInfos(builder, moduleTranslation, DL, combinedInfo, mapData, targetDirective); @@ -4057,6 +4153,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy; auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType) -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy { + builder.restoreIP(codeGenIP); assert(isa(op) && "BodyGen requested for non TargetDataOp"); auto blockArgIface = cast(op); @@ -4065,8 +4162,6 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, case BodyGenTy::Priv: // Check if any device ptr/addr info is available if (!info.DevicePtrInfoMap.empty()) { - builder.restoreIP(codeGenIP); - mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address, blockArgIface.getUseDeviceAddrBlockArgs(), useDeviceAddrVars, mapData, @@ -4101,7 +4196,6 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, case BodyGenTy::NoPriv: // If device info is available then region has already been generated if (info.DevicePtrInfoMap.empty()) { - builder.restoreIP(codeGenIP); // For device pass, if use_device_ptr(addr) mappings were present, // we need to link them here before codegen. if (ompBuilder->Config.IsTargetDevice.value_or(false)) { @@ -4127,17 +4221,28 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, return builder.saveIP(); }; + auto customMapperCB = + [&](unsigned int i) -> llvm::Expected { + if (!combinedInfo.Mappers[i]) + return nullptr; + info.HasMapper = true; + return getOrCreateUserDefinedMapperFunc(combinedInfo.Mappers[i], builder, + moduleTranslation); + }; + llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); llvm::OpenMPIRBuilder::InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation); llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() { if (isa(op)) - return ompBuilder->createTargetData( - ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), - ifCond, info, genMapInfoCB, nullptr, bodyGenCB); - return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(), - builder.getInt64(deviceID), ifCond, - info, genMapInfoCB, &RTLFn); + return ompBuilder->createTargetData(ompLoc, allocaIP, builder.saveIP(), + builder.getInt64(deviceID), ifCond, + info, genMapInfoCB, customMapperCB, + /*MapperFunc=*/nullptr, bodyGenCB, + /*DeviceAddrCB=*/nullptr); + return ompBuilder->createTargetData( + ompLoc, allocaIP, builder.saveIP(), builder.getInt64(deviceID), ifCond, + info, genMapInfoCB, customMapperCB, &RTLFn); }(); if (failed(handleError(afterIP, *op))) @@ -4952,9 +5057,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl, builder); - llvm::OpenMPIRBuilder::MapInfosTy combinedInfos; - auto genMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) - -> llvm::OpenMPIRBuilder::MapInfosTy & { + MapInfosTy combinedInfos; + auto genMapInfoCB = + [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & { builder.restoreIP(codeGenIP); genMapInfos(builder, moduleTranslation, dl, combinedInfos, mapData, targetDirective); @@ -5024,15 +5129,28 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, findAllocaInsertPoint(builder, moduleTranslation); llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); + llvm::OpenMPIRBuilder::TargetDataInfo info( + /*RequiresDevicePointerInfo=*/false, + /*SeparateBeginEndCalls=*/true); + + auto customMapperCB = + [&](unsigned int i) -> llvm::Expected { + if (!combinedInfos.Mappers[i]) + return nullptr; + info.HasMapper = true; + return getOrCreateUserDefinedMapperFunc(combinedInfos.Mappers[i], builder, + moduleTranslation); + }; + llvm::Value *ifCond = nullptr; if (Value targetIfCond = targetOp.getIfExpr()) ifCond = moduleTranslation.lookupValue(targetIfCond); llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = moduleTranslation.getOpenMPBuilder()->createTarget( - ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo, + ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), info, entryInfo, defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB, - argAccessorCB, dds, targetOp.getNowait()); + argAccessorCB, customMapperCB, dds, targetOp.getNowait()); if (failed(handleError(afterIP, opInst))) return failure(); @@ -5272,12 +5390,15 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, .Case([&](omp::TaskwaitOp op) { return convertOmpTaskwaitOp(op, builder, moduleTranslation); }) - .Case([](auto op) { // `yield` and `terminator` can be just omitted. The block structure // was created in the region that handles their parent operation. // `declare_reduction` will be used by reductions and is not // converted directly, skip it. + // `declare_mapper` and `declare_mapper.info` are handled whenever + // they are referred to through a `map` clause. // `critical.declare` is only used to declare names of critical // sections which will be used by `critical` ops and hence can be // ignored for lowering. The OpenMP IRBuilder will create unique diff --git a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir index 7f21095763a39..02b84ff66a0d3 100644 --- a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir @@ -485,3 +485,120 @@ llvm.func @_QPopenmp_target_data_update() { // CHECK: call void @__tgt_target_data_update_mapper(ptr @2, i64 -1, i32 1, ptr %[[BASEPTRS_VAL_2]], ptr %[[PTRS_VAL_2]], ptr @{{.*}}, ptr @{{.*}}, ptr @{{.*}}, ptr null) // CHECK: ret void + +// ----- + +omp.declare_mapper @_QQFmy_testmy_mapper : !llvm.struct<"_QFmy_testTmy_type", (i32)> { +^bb0(%arg0: !llvm.ptr): + %0 = llvm.mlir.constant(0 : i32) : i32 + %1 = llvm.getelementptr %arg0[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"_QFmy_testTmy_type", (i32)> + %2 = omp.map.info var_ptr(%1 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "var%data"} + %3 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.struct<"_QFmy_testTmy_type", (i32)>) map_clauses(tofrom) capture(ByRef) members(%2 : [0] : !llvm.ptr) -> !llvm.ptr {name = "var", partial_map = true} + omp.declare_mapper.info map_entries(%3, %2 : !llvm.ptr, !llvm.ptr) +} + +llvm.func @_QPopenmp_target_data_mapper() { + %0 = llvm.mlir.constant(1 : i64) : i64 + %1 = llvm.alloca %0 x !llvm.struct<"_QFmy_testTmy_type", (i32)> {bindc_name = "a"} : (i64) -> !llvm.ptr + %2 = omp.map.info var_ptr(%1 : !llvm.ptr, !llvm.struct<"_QFmy_testTmy_type", (i32)>) mapper(@_QQFmy_testmy_mapper) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "a"} + omp.target_data map_entries(%2 : !llvm.ptr) { + %3 = llvm.mlir.constant(10 : i32) : i32 + %4 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"_QFmy_testTmy_type", (i32)> + llvm.store %3, %4 : i32, !llvm.ptr + omp.terminator + } + llvm.return +} + +// CHECK: @.offload_sizes = private unnamed_addr constant [1 x i64] [i64 4] +// CHECK: @.offload_maptypes = private unnamed_addr constant [1 x i64] [i64 3] +// CHECK-LABEL: define void @_QPopenmp_target_data_mapper +// CHECK: %[[VAL_0:.*]] = alloca [1 x ptr], align 8 +// CHECK: %[[VAL_1:.*]] = alloca [1 x ptr], align 8 +// CHECK: %[[VAL_2:.*]] = alloca [1 x ptr], align 8 +// CHECK: %[[VAL_3:.*]] = alloca %[[VAL_4:.*]], i64 1, align 8 +// CHECK: br label %[[VAL_5:.*]] +// CHECK: entry: ; preds = %[[VAL_6:.*]] +// CHECK: %[[VAL_7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_0]], i32 0, i32 0 +// CHECK: store ptr %[[VAL_3]], ptr %[[VAL_7]], align 8 +// CHECK: %[[VAL_8:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_1]], i32 0, i32 0 +// CHECK: store ptr %[[VAL_3]], ptr %[[VAL_8]], align 8 +// CHECK: %[[VAL_9:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_2]], i64 0, i64 0 +// CHECK: store ptr @.omp_mapper._QQFmy_testmy_mapper, ptr %[[VAL_9]], align 8 +// CHECK: %[[VAL_10:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_0]], i32 0, i32 0 +// CHECK: %[[VAL_11:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_1]], i32 0, i32 0 +// CHECK: call void @__tgt_target_data_begin_mapper(ptr @4, i64 -1, i32 1, ptr %[[VAL_10]], ptr %[[VAL_11]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr %[[VAL_2]]) +// CHECK: %[[VAL_12:.*]] = getelementptr %[[VAL_4]], ptr %[[VAL_3]], i32 0, i32 0 +// CHECK: store i32 10, ptr %[[VAL_12]], align 4 +// CHECK: %[[VAL_13:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_0]], i32 0, i32 0 +// CHECK: %[[VAL_14:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_1]], i32 0, i32 0 +// CHECK: call void @__tgt_target_data_end_mapper(ptr @4, i64 -1, i32 1, ptr %[[VAL_13]], ptr %[[VAL_14]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr %[[VAL_2]]) +// CHECK: ret void + +// CHECK-LABEL: define internal void @.omp_mapper._QQFmy_testmy_mapper +// CHECK: entry: +// CHECK: %[[VAL_15:.*]] = udiv exact i64 %[[VAL_16:.*]], 4 +// CHECK: %[[VAL_17:.*]] = getelementptr %[[VAL_18:.*]], ptr %[[VAL_19:.*]], i64 %[[VAL_15]] +// CHECK: %[[VAL_20:.*]] = icmp sgt i64 %[[VAL_15]], 1 +// CHECK: %[[VAL_21:.*]] = and i64 %[[VAL_22:.*]], 8 +// CHECK: %[[VAL_23:.*]] = icmp ne ptr %[[VAL_24:.*]], %[[VAL_19]] +// CHECK: %[[VAL_25:.*]] = and i64 %[[VAL_22]], 16 +// CHECK: %[[VAL_26:.*]] = icmp ne i64 %[[VAL_25]], 0 +// CHECK: %[[VAL_27:.*]] = and i1 %[[VAL_23]], %[[VAL_26]] +// CHECK: %[[VAL_28:.*]] = or i1 %[[VAL_20]], %[[VAL_27]] +// CHECK: %[[VAL_29:.*]] = icmp eq i64 %[[VAL_21]], 0 +// CHECK: %[[VAL_30:.*]] = and i1 %[[VAL_28]], %[[VAL_29]] +// CHECK: br i1 %[[VAL_30]], label %[[VAL_31:.*]], label %[[VAL_32:.*]] +// CHECK: .omp.array..init: ; preds = %[[VAL_33:.*]] +// CHECK: %[[VAL_34:.*]] = mul nuw i64 %[[VAL_15]], 4 +// CHECK: %[[VAL_35:.*]] = and i64 %[[VAL_22]], -4 +// CHECK: %[[VAL_36:.*]] = or i64 %[[VAL_35]], 512 +// CHECK: call void @__tgt_push_mapper_component(ptr %[[VAL_37:.*]], ptr %[[VAL_24]], ptr %[[VAL_19]], i64 %[[VAL_34]], i64 %[[VAL_36]], ptr %[[VAL_38:.*]]) +// CHECK: br label %[[VAL_32]] +// CHECK: omp.arraymap.head: ; preds = %[[VAL_31]], %[[VAL_33]] +// CHECK: %[[VAL_39:.*]] = icmp eq ptr %[[VAL_19]], %[[VAL_17]] +// CHECK: br i1 %[[VAL_39]], label %[[VAL_40:.*]], label %[[VAL_41:.*]] +// CHECK: omp.arraymap.body: ; preds = %[[VAL_42:.*]], %[[VAL_32]] +// CHECK: %[[VAL_43:.*]] = phi ptr [ %[[VAL_19]], %[[VAL_32]] ], [ %[[VAL_44:.*]], %[[VAL_42]] ] +// CHECK: %[[VAL_45:.*]] = getelementptr %[[VAL_18]], ptr %[[VAL_43]], i32 0, i32 0 +// CHECK: %[[VAL_46:.*]] = call i64 @__tgt_mapper_num_components(ptr %[[VAL_37]]) +// CHECK: %[[VAL_47:.*]] = shl i64 %[[VAL_46]], 48 +// CHECK: %[[VAL_48:.*]] = add nuw i64 3, %[[VAL_47]] +// CHECK: %[[VAL_49:.*]] = and i64 %[[VAL_22]], 3 +// CHECK: %[[VAL_50:.*]] = icmp eq i64 %[[VAL_49]], 0 +// CHECK: br i1 %[[VAL_50]], label %[[VAL_51:.*]], label %[[VAL_52:.*]] +// CHECK: omp.type.alloc: ; preds = %[[VAL_41]] +// CHECK: %[[VAL_53:.*]] = and i64 %[[VAL_48]], -4 +// CHECK: br label %[[VAL_42]] +// CHECK: omp.type.alloc.else: ; preds = %[[VAL_41]] +// CHECK: %[[VAL_54:.*]] = icmp eq i64 %[[VAL_49]], 1 +// CHECK: br i1 %[[VAL_54]], label %[[VAL_55:.*]], label %[[VAL_56:.*]] +// CHECK: omp.type.to: ; preds = %[[VAL_52]] +// CHECK: %[[VAL_57:.*]] = and i64 %[[VAL_48]], -3 +// CHECK: br label %[[VAL_42]] +// CHECK: omp.type.to.else: ; preds = %[[VAL_52]] +// CHECK: %[[VAL_58:.*]] = icmp eq i64 %[[VAL_49]], 2 +// CHECK: br i1 %[[VAL_58]], label %[[VAL_59:.*]], label %[[VAL_42]] +// CHECK: omp.type.from: ; preds = %[[VAL_56]] +// CHECK: %[[VAL_60:.*]] = and i64 %[[VAL_48]], -2 +// CHECK: br label %[[VAL_42]] +// CHECK: omp.type.end: ; preds = %[[VAL_59]], %[[VAL_56]], %[[VAL_55]], %[[VAL_51]] +// CHECK: %[[VAL_61:.*]] = phi i64 [ %[[VAL_53]], %[[VAL_51]] ], [ %[[VAL_57]], %[[VAL_55]] ], [ %[[VAL_60]], %[[VAL_59]] ], [ %[[VAL_48]], %[[VAL_56]] ] +// CHECK: call void @__tgt_push_mapper_component(ptr %[[VAL_37]], ptr %[[VAL_43]], ptr %[[VAL_45]], i64 4, i64 %[[VAL_61]], ptr @2) +// CHECK: %[[VAL_44]] = getelementptr %[[VAL_18]], ptr %[[VAL_43]], i32 1 +// CHECK: %[[VAL_62:.*]] = icmp eq ptr %[[VAL_44]], %[[VAL_17]] +// CHECK: br i1 %[[VAL_62]], label %[[VAL_63:.*]], label %[[VAL_41]] +// CHECK: omp.arraymap.exit: ; preds = %[[VAL_42]] +// CHECK: %[[VAL_64:.*]] = icmp sgt i64 %[[VAL_15]], 1 +// CHECK: %[[VAL_65:.*]] = and i64 %[[VAL_22]], 8 +// CHECK: %[[VAL_66:.*]] = icmp ne i64 %[[VAL_65]], 0 +// CHECK: %[[VAL_67:.*]] = and i1 %[[VAL_64]], %[[VAL_66]] +// CHECK: br i1 %[[VAL_67]], label %[[VAL_68:.*]], label %[[VAL_40]] +// CHECK: .omp.array..del: ; preds = %[[VAL_63]] +// CHECK: %[[VAL_69:.*]] = mul nuw i64 %[[VAL_15]], 4 +// CHECK: %[[VAL_70:.*]] = and i64 %[[VAL_22]], -4 +// CHECK: %[[VAL_71:.*]] = or i64 %[[VAL_70]], 512 +// CHECK: call void @__tgt_push_mapper_component(ptr %[[VAL_37]], ptr %[[VAL_24]], ptr %[[VAL_19]], i64 %[[VAL_69]], i64 %[[VAL_71]], ptr %[[VAL_38]]) +// CHECK: br label %[[VAL_40]] +// CHECK: omp.done: ; preds = %[[VAL_68]], %[[VAL_63]], %[[VAL_32]] +// CHECK: ret void