Skip to content

Commit cf491db

Browse files
authored
[CIR][CUDA] Generate kernel calls (#1348)
We can now generate calls to `__global__` functions. Most of the work is already done in the AST: it rewrites `fn<<<2, 2>>>()` to something like `__cudaPushCallConfiguration(dim3(2, 1, 1), dim3(2, 1, 1), 0, nullptr)`, which returns a bool. We call the device stub as a normal function when that call returns true.
1 parent 0d2a01f commit cf491db

File tree

4 files changed

+48
-2
lines changed

4 files changed

+48
-2
lines changed

clang/lib/CIR/CodeGen/CIRGenCUDARuntime.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,23 @@ void CIRGenCUDARuntime::emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn,
169169
else
170170
emitDeviceStubBodyLegacy(cgf, fn, args);
171171
}
172+
173+
// Emits CIR for a CUDA kernel call expression (`fn<<<...>>>(args)`).
//
// The launch configuration expression (`expr->getConfig()`) is used as the
// condition of an `if`; the call to the callee is emitted only inside the
// "then" region. The "else" region is left empty.
//
// \param cgf       Function-level codegen state used to emit the condition,
//                  the callee and the call itself.
// \param expr      The kernel call expression being lowered.
// \param retValue  Return slot forwarded unchanged to `cgf.emitCall`.
// \returns `RValue::get(nullptr)`; the result of the emitted call is not
//          propagated to the caller.
RValue CIRGenCUDARuntime::emitCUDAKernelCallExpr(CIRGenFunction &cgf,
174+
const CUDAKernelCallExpr *expr,
175+
ReturnValueSlot retValue) {
176+
// NOTE(review): `auto` deduces a value type here, so this may copy the
// builder if `getBuilder()` returns a reference — confirm that is intended.
auto builder = cgm.getBuilder();
177+
// Use the current source location when one is set; otherwise fall back to an
// unknown location.
mlir::Location loc =
178+
cgf.currSrcLoc ? cgf.currSrcLoc.value() : builder.getUnknownLoc();
179+
180+
cgf.emitIfOnBoolExpr(
181+
// Condition: the kernel launch configuration expression.
expr->getConfig(),
182+
// "Then" region: emit the callee, call it, and terminate the region.
// NOTE(review): the lambda's own location parameter `l` is unused; the
// yield is created at the outer `loc` instead.
[&](mlir::OpBuilder &b, mlir::Location l) {
183+
CIRGenCallee callee = cgf.emitCallee(expr->getCallee());
184+
cgf.emitCall(expr->getCallee()->getType(), callee, expr, retValue);
185+
b.create<cir::YieldOp>(loc);
186+
},
187+
// "Else" region: intentionally empty builder, with no explicit location
// for it (empty optional).
loc, [](mlir::OpBuilder &b, mlir::Location l) {},
188+
std::optional<mlir::Location>());
189+
190+
// No value is produced for the kernel call expression itself.
return RValue::get(nullptr);
191+
}

clang/lib/CIR/CodeGen/CIRGenCUDARuntime.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ namespace clang::CIRGen {
2323
class CIRGenFunction;
2424
class CIRGenModule;
2525
class FunctionArgList;
26+
class RValue;
27+
class ReturnValueSlot;
2628

2729
class CIRGenCUDARuntime {
2830
protected:
@@ -40,6 +42,10 @@ class CIRGenCUDARuntime {
4042

4143
virtual void emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn,
4244
FunctionArgList &args);
45+
46+
virtual RValue emitCUDAKernelCallExpr(CIRGenFunction &cgf,
47+
const CUDAKernelCallExpr *expr,
48+
ReturnValueSlot retValue);
4349
};
4450

4551
} // namespace clang::CIRGen

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,10 @@ static CIRGenCallee emitDirectCallee(CIRGenModule &CGM, GlobalDecl GD) {
530530

531531
auto CalleePtr = emitFunctionDeclPointer(CGM, GD);
532532

533-
assert(!CGM.getLangOpts().CUDA && "NYI");
533+
// For HIP, the device stub should be converted to handle.
534+
if (CGM.getLangOpts().HIP && !CGM.getLangOpts().CUDAIsDevice &&
535+
FD->hasAttr<CUDAGlobalAttr>())
536+
llvm_unreachable("NYI");
534537

535538
return CIRGenCallee::forDirect(CalleePtr, GD);
536539
}
@@ -1405,7 +1408,9 @@ RValue CIRGenFunction::emitCallExpr(const clang::CallExpr *E,
14051408
if (const auto *CE = dyn_cast<CXXMemberCallExpr>(E))
14061409
return emitCXXMemberCallExpr(CE, ReturnValue);
14071410

1408-
assert(!dyn_cast<CUDAKernelCallExpr>(E) && "CUDA NYI");
1411+
if (const auto *CE = dyn_cast<CUDAKernelCallExpr>(E))
1412+
return CGM.getCUDARuntime().emitCUDAKernelCallExpr(*this, CE, ReturnValue);
1413+
14091414
if (const auto *CE = dyn_cast<CXXOperatorCallExpr>(E))
14101415
if (const CXXMethodDecl *MD =
14111416
dyn_cast_or_null<CXXMethodDecl>(CE->getCalleeDecl()))

clang/test/CIR/CodeGen/CUDA/simple.cu

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,18 @@ __global__ void global_fn(int a) {}
3131
// CIR-HOST: cir.call @__cudaPopCallConfiguration
3232
// CIR-HOST: cir.get_global @_Z24__device_stub__global_fni
3333
// CIR-HOST: cir.call @cudaLaunchKernel
34+
35+
int main() {
36+
global_fn<<<1, 1>>>(1);
37+
}
38+
// CIR-DEVICE-NOT: cir.func @main()
39+
40+
// CIR-HOST: cir.func @main()
41+
// CIR-HOST: cir.call @_ZN4dim3C1Ejjj
42+
// CIR-HOST: cir.call @_ZN4dim3C1Ejjj
43+
// CIR-HOST: [[Push:%[0-9]+]] = cir.call @__cudaPushCallConfiguration
44+
// CIR-HOST: [[ConfigOK:%[0-9]+]] = cir.cast(int_to_bool, [[Push]]
45+
// CIR-HOST: cir.if [[ConfigOK]] {
46+
// CIR-HOST: [[Arg:%[0-9]+]] = cir.const #cir.int<1>
47+
// CIR-HOST: cir.call @_Z24__device_stub__global_fni([[Arg]])
48+
// CIR-HOST: }

0 commit comments

Comments
 (0)