@@ -366,6 +366,23 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
366
366
const fir::LLVMTypeConverter *typeConverter;
367
367
};
368
368
369
+ static mlir::Value genGetDeviceAddress (mlir::PatternRewriter &rewriter,
370
+ mlir::ModuleOp mod, mlir::Location loc,
371
+ mlir::Value inputArg) {
372
+ fir::FirOpBuilder builder (rewriter, mod);
373
+ mlir::func::FuncOp callee =
374
+ fir::runtime::getRuntimeFunc<mkRTKey (CUFGetDeviceAddress)>(loc, builder);
375
+ auto fTy = callee.getFunctionType ();
376
+ mlir::Value conv = createConvertOp (rewriter, loc, fTy .getInput (0 ), inputArg);
377
+ mlir::Value sourceFile = fir::factory::locationToFilename (builder, loc);
378
+ mlir::Value sourceLine =
379
+ fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
380
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
381
+ builder, loc, fTy , conv, sourceFile, sourceLine)};
382
+ auto call = rewriter.create <fir::CallOp>(loc, callee, args);
383
+ return createConvertOp (rewriter, loc, inputArg.getType (), call->getResult (0 ));
384
+ }
385
+
369
386
struct DeclareOpConversion : public mlir ::OpRewritePattern<fir::DeclareOp> {
370
387
using OpRewritePattern::OpRewritePattern;
371
388
@@ -382,26 +399,10 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
382
399
if (cuf::isRegisteredDeviceGlobal (global)) {
383
400
rewriter.setInsertionPointAfter (addrOfOp);
384
401
auto mod = op->getParentOfType <mlir::ModuleOp>();
385
- fir::FirOpBuilder builder (rewriter, mod);
386
- mlir::Location loc = op.getLoc ();
387
- mlir::func::FuncOp callee =
388
- fir::runtime::getRuntimeFunc<mkRTKey (CUFGetDeviceAddress)>(
389
- loc, builder);
390
- auto fTy = callee.getFunctionType ();
391
- mlir::Type toTy = fTy .getInput (0 );
392
- mlir::Value inputArg =
393
- createConvertOp (rewriter, loc, toTy, addrOfOp.getResult ());
394
- mlir::Value sourceFile =
395
- fir::factory::locationToFilename (builder, loc);
396
- mlir::Value sourceLine =
397
- fir::factory::locationToLineNo (builder, loc, fTy .getInput (2 ));
398
- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments (
399
- builder, loc, fTy , inputArg, sourceFile, sourceLine)};
400
- auto call = rewriter.create <fir::CallOp>(loc, callee, args);
401
- mlir::Value cast = createConvertOp (
402
- rewriter, loc, op.getMemref ().getType (), call->getResult (0 ));
402
+ mlir::Value devAddr = genGetDeviceAddress (rewriter, mod, op.getLoc (),
403
+ addrOfOp.getResult ());
403
404
rewriter.startOpModification (op);
404
- op.getMemrefMutable ().assign (cast );
405
+ op.getMemrefMutable ().assign (devAddr );
405
406
rewriter.finalizeOpModification (op);
406
407
return success ();
407
408
}
@@ -771,10 +772,32 @@ struct CUFLaunchOpConversion
771
772
loc, clusterDimsAttr.getZ ().getInt ());
772
773
}
773
774
}
775
+ llvm::SmallVector<mlir::Value> args;
776
+ auto mod = op->getParentOfType <mlir::ModuleOp>();
777
+ for (mlir::Value arg : op.getArgs ()) {
778
+ // If the argument is a global descriptor, make sure we pass the device
779
+ // copy of this descriptor and not the host one.
780
+ if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType (arg.getType ()))) {
781
+ if (auto declareOp =
782
+ mlir::dyn_cast_or_null<fir::DeclareOp>(arg.getDefiningOp ())) {
783
+ if (auto addrOfOp = mlir::dyn_cast_or_null<fir::AddrOfOp>(
784
+ declareOp.getMemref ().getDefiningOp ())) {
785
+ if (auto global = symTab.lookup <fir::GlobalOp>(
786
+ addrOfOp.getSymbol ().getRootReference ().getValue ())) {
787
+ if (cuf::isRegisteredDeviceGlobal (global)) {
788
+ arg = genGetDeviceAddress (rewriter, mod, op.getLoc (),
789
+ declareOp.getResult ());
790
+ }
791
+ }
792
+ }
793
+ }
794
+ }
795
+ args.push_back (arg);
796
+ }
797
+
774
798
auto gpuLaunchOp = rewriter.create <mlir::gpu::LaunchFuncOp>(
775
799
loc, kernelName, mlir::gpu::KernelDim3{gridSizeX, gridSizeY, gridSizeZ},
776
- mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ}, zero,
777
- op.getArgs ());
800
+ mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ}, zero, args);
778
801
if (clusterDimX && clusterDimY && clusterDimZ) {
779
802
gpuLaunchOp.getClusterSizeXMutable ().assign (clusterDimX);
780
803
gpuLaunchOp.getClusterSizeYMutable ().assign (clusterDimY);
0 commit comments