-
Notifications
You must be signed in to change notification settings - Fork 13.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flang][cuda] Add cuf.device_address operation #122975
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@llvm/pr-subscribers-flang-fir-hlfir Author: Valentin Clement (バレンタイン クレメン) (clementval) Changes: Introduce a new op to get the device address from a host symbol. This simplifies the current conversion and is also in preparation for some legalization work that needs to be done in cuf kernel and cuf kernel launch, similar to #122802. Full diff: https://github.com/llvm/llvm-project/pull/122975.diff 5 Files Affected:
diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index 6f886726b12834..a270e69b394104 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -335,4 +335,16 @@ def cuf_RegisterKernelOp : cuf_Op<"register_kernel", []> {
}];
}
+def cuf_DeviceAddressOp : cuf_Op<"device_address", []> {
+ let summary = "Get the device address from a host symbol";
+
+ let arguments = (ins SymbolRefAttr:$hostSymbol);
+
+ let assemblyFormat = [{
+ $hostSymbol attr-dict `->` type($addr)
+ }];
+
+ let results = (outs fir_ReferenceType:$addr);
+}
+
#endif // FORTRAN_DIALECT_CUF_CUF_OPS
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index d61d9f63cb2949..e93bed37d39f78 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -366,22 +366,47 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
const fir::LLVMTypeConverter *typeConverter;
};
-static mlir::Value genGetDeviceAddress(mlir::PatternRewriter &rewriter,
- mlir::ModuleOp mod, mlir::Location loc,
- mlir::Value inputArg) {
- fir::FirOpBuilder builder(rewriter, mod);
- mlir::func::FuncOp callee =
- fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc, builder);
- auto fTy = callee.getFunctionType();
- mlir::Value conv = createConvertOp(rewriter, loc, fTy.getInput(0), inputArg);
- mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
- mlir::Value sourceLine =
- fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
- builder, loc, fTy, conv, sourceFile, sourceLine)};
- auto call = rewriter.create<fir::CallOp>(loc, callee, args);
- return createConvertOp(rewriter, loc, inputArg.getType(), call->getResult(0));
-}
+struct CUFDeviceAddressOpConversion
+ : public mlir::OpRewritePattern<cuf::DeviceAddressOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ CUFDeviceAddressOpConversion(mlir::MLIRContext *context,
+ const mlir::SymbolTable &symtab)
+ : OpRewritePattern(context), symTab{symtab} {}
+
+ mlir::LogicalResult
+ matchAndRewrite(cuf::DeviceAddressOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ if (auto global = symTab.lookup<fir::GlobalOp>(
+ op.getHostSymbol().getRootReference().getValue())) {
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ mlir::Location loc = op.getLoc();
+ auto hostAddr = rewriter.create<fir::AddrOfOp>(
+ loc, fir::ReferenceType::get(global.getType()), op.getHostSymbol());
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::func::FuncOp callee =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc,
+ builder);
+ auto fTy = callee.getFunctionType();
+ mlir::Value conv =
+ createConvertOp(rewriter, loc, fTy.getInput(0), hostAddr);
+ mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, conv, sourceFile, sourceLine)};
+ auto call = rewriter.create<fir::CallOp>(loc, callee, args);
+ mlir::Value addr = createConvertOp(rewriter, loc, hostAddr.getType(),
+ call->getResult(0));
+ rewriter.replaceOp(op, addr.getDefiningOp());
+ return success();
+ }
+ return failure();
+ }
+
+private:
+ const mlir::SymbolTable &symTab;
+};
struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
using OpRewritePattern::OpRewritePattern;
@@ -398,9 +423,8 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
addrOfOp.getSymbol().getRootReference().getValue())) {
if (cuf::isRegisteredDeviceGlobal(global)) {
rewriter.setInsertionPointAfter(addrOfOp);
- auto mod = op->getParentOfType<mlir::ModuleOp>();
- mlir::Value devAddr = genGetDeviceAddress(rewriter, mod, op.getLoc(),
- addrOfOp.getResult());
+ mlir::Value devAddr = rewriter.create<cuf::DeviceAddressOp>(
+ op.getLoc(), addrOfOp.getType(), addrOfOp.getSymbol());
rewriter.startOpModification(op);
op.getMemrefMutable().assign(devAddr);
rewriter.finalizeOpModification(op);
@@ -773,7 +797,6 @@ struct CUFLaunchOpConversion
}
}
llvm::SmallVector<mlir::Value> args;
- auto mod = op->getParentOfType<mlir::ModuleOp>();
for (mlir::Value arg : op.getArgs()) {
// If the argument is a global descriptor, make sure we pass the device
// copy of this descriptor and not the host one.
@@ -785,8 +808,11 @@ struct CUFLaunchOpConversion
if (auto global = symTab.lookup<fir::GlobalOp>(
addrOfOp.getSymbol().getRootReference().getValue())) {
if (cuf::isRegisteredDeviceGlobal(global)) {
- arg = genGetDeviceAddress(rewriter, mod, op.getLoc(),
- declareOp.getResult());
+ arg = rewriter
+ .create<cuf::DeviceAddressOp>(op.getLoc(),
+ addrOfOp.getType(),
+ addrOfOp.getSymbol())
+ .getResult();
}
}
}
@@ -907,10 +933,12 @@ void cuf::populateCUFToFIRConversionPatterns(
patterns.getContext());
patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
&dl, &converter);
- patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab);
+ patterns.insert<CUFLaunchOpConversion, CUFDeviceAddressOpConversion>(
+ patterns.getContext(), symtab);
}
void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
mlir::RewritePatternSet &patterns) {
- patterns.insert<DeclareOpConversion>(patterns.getContext(), symtab);
+ patterns.insert<DeclareOpConversion, CUFDeviceAddressOpConversion>(
+ patterns.getContext(), symtab);
}
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index 7203c33e7eb11f..5ed27f1be0a430 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -198,6 +198,7 @@ func.func @_QPsub8() attributes {fir.bindc_name = "t"} {
// CHECK-LABEL: func.func @_QPsub8()
// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32>
// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
+// CHECK: fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
@@ -222,6 +223,7 @@ func.func @_QPsub9() {
// CHECK-LABEL: func.func @_QPsub9()
// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32>
// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
+// CHECK: fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
@@ -380,6 +382,7 @@ func.func @_QPdevice_addr_conv() {
}
// CHECK-LABEL: func.func @_QPdevice_addr_conv()
+// CHECK: fir.address_of(@_QMmod1Ea_dev) : !fir.ref<!fir.array<4xf32>>
// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmod1Ea_dev) : !fir.ref<!fir.array<4xf32>>
// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<4xf32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
diff --git a/flang/test/Fir/CUDA/cuda-global-addr.mlir b/flang/test/Fir/CUDA/cuda-global-addr.mlir
index 94ee74736f6508..0ccd0c797fb6f5 100644
--- a/flang/test/Fir/CUDA/cuda-global-addr.mlir
+++ b/flang/test/Fir/CUDA/cuda-global-addr.mlir
@@ -26,6 +26,7 @@ func.func @_QQmain() attributes {fir.bindc_name = "test"} {
}
// CHECK-LABEL: func.func @_QQmain()
+// CHECK: fir.address_of(@_QMmod1Eadev) : !fir.ref<!fir.array<10xi32>>
// CHECK: %[[ADDR:.*]] = fir.address_of(@_QMmod1Eadev) : !fir.ref<!fir.array<10xi32>>
// CHECK: %[[ADDRPTR:.*]] = fir.convert %[[ADDR]] : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[DEVICE_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[ADDRPTR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
diff --git a/flang/test/Fir/CUDA/cuda-launch.fir b/flang/test/Fir/CUDA/cuda-launch.fir
index 1e19b3bea1296f..8432b9ec926e38 100644
--- a/flang/test/Fir/CUDA/cuda-launch.fir
+++ b/flang/test/Fir/CUDA/cuda-launch.fir
@@ -98,9 +98,9 @@ module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_e
}
// CHECK-LABEL: func.func @_QQmain()
+// CHECK: _FortranACUFSyncGlobalDescriptor
// CHECK: %[[ADDROF:.*]] = fir.address_of(@_QMdevptrEdev_ptr) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
-// CHECK: %[[DECL:.*]] = fir.declare %[[ADDROF]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QMdevptrEdev_ptr"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
-// CHECK: %[[CONV_DECL:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[DEVADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[CONV_DECL]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[CONV_ADDR:.*]] = fir.convert %[[ADDROF]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[DEVADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[CONV_ADDR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[CONV_DEVADDR:.*]] = fir.convert %[[DEVADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
// CHECK: gpu.launch_func @cuda_device_mod::@_QMdevptrPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) dynamic_shared_memory_size %{{.*}} args(%[[CONV_DEVADDR]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
|
wangzpgi
approved these changes
Jan 14, 2025
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Introduce a new op to get the device address from a host symbol. This simplifies the current conversion and is also in preparation for some legalization work that needs to be done in cuf kernel and cuf kernel launch, similar to #122802.