Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][cuda] Add cuf.device_address operation #122975

Merged
merged 1 commit into from
Jan 15, 2025

Conversation

clementval
Copy link
Contributor

Introduce a new op to get the device address from a host symbol. This simplify the current conversion and this is also in preparation for some legalization work that need to be done in cuf kernel and cuf kernel launch similar to
#122802

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Jan 14, 2025
@llvmbot
Copy link
Member

llvmbot commented Jan 14, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Introduce a new op to get the device address from a host symbol. This simplify the current conversion and this is also in preparation for some legalization work that need to be done in cuf kernel and cuf kernel launch similar to
#122802


Full diff: https://github.com/llvm/llvm-project/pull/122975.diff

5 Files Affected:

  • (modified) flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td (+12)
  • (modified) flang/lib/Optimizer/Transforms/CUFOpConversion.cpp (+52-24)
  • (modified) flang/test/Fir/CUDA/cuda-data-transfer.fir (+3)
  • (modified) flang/test/Fir/CUDA/cuda-global-addr.mlir (+1)
  • (modified) flang/test/Fir/CUDA/cuda-launch.fir (+3-3)
diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index 6f886726b12834..a270e69b394104 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -335,4 +335,16 @@ def cuf_RegisterKernelOp : cuf_Op<"register_kernel", []> {
   }];
 }
 
+def cuf_DeviceAddressOp : cuf_Op<"device_address", []> {
+  let summary = "Get the device address from a host symbol";
+
+  let arguments = (ins SymbolRefAttr:$hostSymbol);
+
+  let assemblyFormat = [{
+    $hostSymbol attr-dict `->` type($addr)
+  }];
+
+  let results = (outs fir_ReferenceType:$addr);
+}
+
 #endif // FORTRAN_DIALECT_CUF_CUF_OPS
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index d61d9f63cb2949..e93bed37d39f78 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -366,22 +366,47 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
   const fir::LLVMTypeConverter *typeConverter;
 };
 
-static mlir::Value genGetDeviceAddress(mlir::PatternRewriter &rewriter,
-                                       mlir::ModuleOp mod, mlir::Location loc,
-                                       mlir::Value inputArg) {
-  fir::FirOpBuilder builder(rewriter, mod);
-  mlir::func::FuncOp callee =
-      fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc, builder);
-  auto fTy = callee.getFunctionType();
-  mlir::Value conv = createConvertOp(rewriter, loc, fTy.getInput(0), inputArg);
-  mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
-  mlir::Value sourceLine =
-      fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
-  llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
-      builder, loc, fTy, conv, sourceFile, sourceLine)};
-  auto call = rewriter.create<fir::CallOp>(loc, callee, args);
-  return createConvertOp(rewriter, loc, inputArg.getType(), call->getResult(0));
-}
+struct CUFDeviceAddressOpConversion
+    : public mlir::OpRewritePattern<cuf::DeviceAddressOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  CUFDeviceAddressOpConversion(mlir::MLIRContext *context,
+                               const mlir::SymbolTable &symtab)
+      : OpRewritePattern(context), symTab{symtab} {}
+
+  mlir::LogicalResult
+  matchAndRewrite(cuf::DeviceAddressOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    if (auto global = symTab.lookup<fir::GlobalOp>(
+            op.getHostSymbol().getRootReference().getValue())) {
+      auto mod = op->getParentOfType<mlir::ModuleOp>();
+      mlir::Location loc = op.getLoc();
+      auto hostAddr = rewriter.create<fir::AddrOfOp>(
+          loc, fir::ReferenceType::get(global.getType()), op.getHostSymbol());
+      fir::FirOpBuilder builder(rewriter, mod);
+      mlir::func::FuncOp callee =
+          fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc,
+                                                                     builder);
+      auto fTy = callee.getFunctionType();
+      mlir::Value conv =
+          createConvertOp(rewriter, loc, fTy.getInput(0), hostAddr);
+      mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+      mlir::Value sourceLine =
+          fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+          builder, loc, fTy, conv, sourceFile, sourceLine)};
+      auto call = rewriter.create<fir::CallOp>(loc, callee, args);
+      mlir::Value addr = createConvertOp(rewriter, loc, hostAddr.getType(),
+                                         call->getResult(0));
+      rewriter.replaceOp(op, addr.getDefiningOp());
+      return success();
+    }
+    return failure();
+  }
+
+private:
+  const mlir::SymbolTable &symTab;
+};
 
 struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -398,9 +423,8 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
               addrOfOp.getSymbol().getRootReference().getValue())) {
         if (cuf::isRegisteredDeviceGlobal(global)) {
           rewriter.setInsertionPointAfter(addrOfOp);
-          auto mod = op->getParentOfType<mlir::ModuleOp>();
-          mlir::Value devAddr = genGetDeviceAddress(rewriter, mod, op.getLoc(),
-                                                    addrOfOp.getResult());
+          mlir::Value devAddr = rewriter.create<cuf::DeviceAddressOp>(
+              op.getLoc(), addrOfOp.getType(), addrOfOp.getSymbol());
           rewriter.startOpModification(op);
           op.getMemrefMutable().assign(devAddr);
           rewriter.finalizeOpModification(op);
@@ -773,7 +797,6 @@ struct CUFLaunchOpConversion
       }
     }
     llvm::SmallVector<mlir::Value> args;
-    auto mod = op->getParentOfType<mlir::ModuleOp>();
     for (mlir::Value arg : op.getArgs()) {
       // If the argument is a global descriptor, make sure we pass the device
       // copy of this descriptor and not the host one.
@@ -785,8 +808,11 @@ struct CUFLaunchOpConversion
             if (auto global = symTab.lookup<fir::GlobalOp>(
                     addrOfOp.getSymbol().getRootReference().getValue())) {
               if (cuf::isRegisteredDeviceGlobal(global)) {
-                arg = genGetDeviceAddress(rewriter, mod, op.getLoc(),
-                                          declareOp.getResult());
+                arg = rewriter
+                          .create<cuf::DeviceAddressOp>(op.getLoc(),
+                                                        addrOfOp.getType(),
+                                                        addrOfOp.getSymbol())
+                          .getResult();
               }
             }
           }
@@ -907,10 +933,12 @@ void cuf::populateCUFToFIRConversionPatterns(
       patterns.getContext());
   patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
                                                &dl, &converter);
-  patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab);
+  patterns.insert<CUFLaunchOpConversion, CUFDeviceAddressOpConversion>(
+      patterns.getContext(), symtab);
 }
 
 void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
                                            mlir::RewritePatternSet &patterns) {
-  patterns.insert<DeclareOpConversion>(patterns.getContext(), symtab);
+  patterns.insert<DeclareOpConversion, CUFDeviceAddressOpConversion>(
+      patterns.getContext(), symtab);
 }
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index 7203c33e7eb11f..5ed27f1be0a430 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -198,6 +198,7 @@ func.func @_QPsub8() attributes {fir.bindc_name = "t"} {
 // CHECK-LABEL: func.func @_QPsub8()
 // CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32> 
 // CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
+// CHECK: fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
 // CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
 // CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
 // CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
@@ -222,6 +223,7 @@ func.func @_QPsub9() {
 // CHECK-LABEL: func.func @_QPsub9()
 // CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32> 
 // CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
+// CHECK: fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
 // CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
 // CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
 // CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
@@ -380,6 +382,7 @@ func.func @_QPdevice_addr_conv() {
 }
 
 // CHECK-LABEL: func.func @_QPdevice_addr_conv()
+// CHECK: fir.address_of(@_QMmod1Ea_dev) : !fir.ref<!fir.array<4xf32>>
 // CHECK: %[[GBL:.*]] = fir.address_of(@_QMmod1Ea_dev) : !fir.ref<!fir.array<4xf32>>
 // CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<4xf32>>) -> !fir.llvm_ptr<i8>
 // CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
diff --git a/flang/test/Fir/CUDA/cuda-global-addr.mlir b/flang/test/Fir/CUDA/cuda-global-addr.mlir
index 94ee74736f6508..0ccd0c797fb6f5 100644
--- a/flang/test/Fir/CUDA/cuda-global-addr.mlir
+++ b/flang/test/Fir/CUDA/cuda-global-addr.mlir
@@ -26,6 +26,7 @@ func.func @_QQmain() attributes {fir.bindc_name = "test"} {
 }
 
 // CHECK-LABEL: func.func @_QQmain()
+// CHECK: fir.address_of(@_QMmod1Eadev) : !fir.ref<!fir.array<10xi32>>
 // CHECK: %[[ADDR:.*]] = fir.address_of(@_QMmod1Eadev) : !fir.ref<!fir.array<10xi32>>
 // CHECK: %[[ADDRPTR:.*]] = fir.convert %[[ADDR]] : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
 // CHECK: %[[DEVICE_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[ADDRPTR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
diff --git a/flang/test/Fir/CUDA/cuda-launch.fir b/flang/test/Fir/CUDA/cuda-launch.fir
index 1e19b3bea1296f..8432b9ec926e38 100644
--- a/flang/test/Fir/CUDA/cuda-launch.fir
+++ b/flang/test/Fir/CUDA/cuda-launch.fir
@@ -98,9 +98,9 @@ module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_e
 }
 
 // CHECK-LABEL: func.func @_QQmain()
+// CHECK: _FortranACUFSyncGlobalDescriptor
 // CHECK: %[[ADDROF:.*]] = fir.address_of(@_QMdevptrEdev_ptr) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
-// CHECK: %[[DECL:.*]] = fir.declare %[[ADDROF]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QMdevptrEdev_ptr"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
-// CHECK: %[[CONV_DECL:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[DEVADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[CONV_DECL]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[CONV_ADDR:.*]] = fir.convert %[[ADDROF]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[DEVADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[CONV_ADDR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
 // CHECK: %[[CONV_DEVADDR:.*]] = fir.convert %[[DEVADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
 // CHECK: gpu.launch_func  @cuda_device_mod::@_QMdevptrPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}})  dynamic_shared_memory_size %{{.*}} args(%[[CONV_DEVADDR]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)

@clementval clementval merged commit a19919f into llvm:main Jan 15, 2025
11 checks passed
@clementval clementval deleted the cuf_get_addr branch January 15, 2025 01:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants