Skip to content

Commit 2b123e0

Browse files
authored
CUDA Stream support (#213)
* CUDA Stream support
* Async lowering [WIP]
* Fix lowering to moccuda
* Convert to malloc/free
* Fix non-async
* Update LLVM
* Fix build
1 parent d061557 commit 2b123e0

File tree

12 files changed: +405 additions, -83 deletions.

include/polygeist/PolygeistOps.td

+10
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,16 @@ def SubIndexOp : Polygeist_Op<"subindex", [
6565
}];
6666
}
6767

68+
69+
def StreamToTokenOp : Polygeist_Op<"stream2token", [
70+
NoSideEffect
71+
]> {
72+
let summary = "Extract an async stream from a cuda stream";
73+
74+
let arguments = (ins AnyType : $source);
75+
let results = (outs AnyType : $result);
76+
}
77+
6878
//===----------------------------------------------------------------------===//
6979
// Memref2PointerOp
7080
//===----------------------------------------------------------------------===//

lib/polygeist/Passes/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ add_mlir_dialect_library(MLIRPolygeistTransforms
2323

2424
LINK_LIBS PUBLIC
2525
MLIRAffine
26+
MLIRAsync
2627
MLIRAffineUtils
2728
MLIRFunc
2829
MLIRFuncTransforms

lib/polygeist/Passes/ConvertPolygeistToLLVM.cpp

+318-65
Large diffs are not rendered by default.

lib/polygeist/Passes/ParallelLower.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Analysis/CallGraph.h"
1515
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
1616
#include "mlir/Dialect/Affine/IR/AffineOps.h"
17+
#include "mlir/Dialect/Async/IR/Async.h"
1718
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1819
#include "mlir/Dialect/Func/IR/FuncOps.h"
1920
#include "mlir/Dialect/GPU/GPUDialect.h"
@@ -274,6 +275,21 @@ void ParallelLower::runOnOperation() {
274275

275276
auto oneindex = builder.create<ConstantIndexOp>(loc, 1);
276277

278+
async::ExecuteOp asyncOp = nullptr;
279+
if (!llvm::empty(launchOp.asyncDependencies())) {
280+
SmallVector<Value> dependencies;
281+
for (auto v : launchOp.asyncDependencies()) {
282+
auto tok = v.getDefiningOp<polygeist::StreamToTokenOp>();
283+
dependencies.push_back(builder.create<polygeist::StreamToTokenOp>(
284+
tok.getLoc(), builder.getType<async::TokenType>(), tok.source()));
285+
}
286+
asyncOp = builder.create<mlir::async::ExecuteOp>(
287+
loc, /*results*/ TypeRange(), /*dependencies*/ dependencies,
288+
/*operands*/ ValueRange());
289+
Block *blockB = &asyncOp.body().front();
290+
builder.setInsertionPointToStart(blockB);
291+
}
292+
277293
auto block = builder.create<mlir::scf::ParallelOp>(
278294
loc, std::vector<Value>({zindex, zindex, zindex}),
279295
std::vector<Value>(

llvm-project

Submodule llvm-project updated 3107 files

tools/mlir-clang/Lib/CGCall.cc

+16-3
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,22 @@ ValueCategory MLIRScanner::CallHelper(
296296
val, idx)));
297297
}
298298
}
299-
auto op = builder.create<mlir::gpu::LaunchOp>(loc, blocks[0], blocks[1],
300-
blocks[2], threads[0],
301-
threads[1], threads[2]);
299+
mlir::Value stream = nullptr;
300+
SmallVector<mlir::Value, 1> asyncDependencies;
301+
if (3 < CU->getConfig()->getNumArgs() &&
302+
!isa<CXXDefaultArgExpr>(CU->getConfig()->getArg(3))) {
303+
stream = Visit(CU->getConfig()->getArg(3)).getValue(builder);
304+
stream = builder.create<polygeist::StreamToTokenOp>(
305+
loc, builder.getType<gpu::AsyncTokenType>(), stream);
306+
assert(stream);
307+
asyncDependencies.push_back(stream);
308+
}
309+
auto op = builder.create<mlir::gpu::LaunchOp>(
310+
loc, blocks[0], blocks[1], blocks[2], threads[0], threads[1],
311+
threads[2],
312+
/*dynamic shmem size*/ nullptr,
313+
/*token type*/ stream ? stream.getType() : nullptr,
314+
/*dependencies*/ asyncDependencies);
302315
auto oldpoint = builder.getInsertionPoint();
303316
auto *oldblock = builder.getInsertionBlock();
304317
builder.setInsertionPointToStart(&op.getRegion().front());

tools/mlir-clang/Lib/clang-mlir.cc

+1-2
Original file line numberDiff line numberDiff line change
@@ -718,8 +718,7 @@ mlir::Attribute MLIRScanner::InitializeValueByInitListExpr(mlir::Value toInit,
718718
return mlir::DenseElementsAttr();
719719
if (auto mt = toInit.getType().dyn_cast<MemRefType>()) {
720720
return DenseElementsAttr::getFromRawBuffer(
721-
RankedTensorType::get(mt.getShape(), mt.getElementType()), attrs,
722-
false);
721+
RankedTensorType::get(mt.getShape(), mt.getElementType()), attrs);
723722
}
724723
return mlir::DenseElementsAttr();
725724
} else {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: mlir-clang %s --cuda-gpu-arch=sm_60 -nocudalib -nocudainc %resourcedir --function=* -S | FileCheck %s
2+
3+
#include "Inputs/cuda.h"
4+
5+
__device__ void something(int* array, int n);
6+
7+
// Type your code here, or load an example.
8+
__global__ void square(int *array, int n) {
9+
something(array, n);
10+
}
11+
12+
void run(cudaStream_t stream1, int *array, int n) {
13+
square<<< 10, 20, 0, stream1>>> (array, n) ;
14+
}
15+
16+
// CHECK: func.func @_Z3runP10cudaStreamPii(%arg0: !llvm.ptr<struct<()>>, %arg1: memref<?xi32>, %arg2: i32) attributes {llvm.linkage = #llvm.linkage<external>} {
17+
// CHECK-NEXT: %c10 = arith.constant 10 : index
18+
// CHECK-NEXT: %c1 = arith.constant 1 : index
19+
// CHECK-NEXT: %c20 = arith.constant 20 : index
20+
// CHECK-NEXT: %0 = "polygeist.stream2token"(%arg0) : (!llvm.ptr<struct<()>>) -> !gpu.async.token
21+
// CHECK-NEXT: %1 = gpu.launch async [%0] blocks(%arg3, %arg4, %arg5) in (%arg9 = %c10, %arg10 = %c1, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c20, %arg13 = %c1, %arg14 = %c1) {
22+
// CHECK-NEXT: func.call @_Z21__device_stub__squarePii(%arg1, %arg2) : (memref<?xi32>, i32) -> ()
23+
// CHECK-NEXT: gpu.terminator
24+
// CHECK-NEXT: }
25+
// CHECK-NEXT: return
26+
// CHECK-NEXT: }

tools/mlir-clang/Test/Verification/whiletofor.c

+8-8
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ void whiletofor() {
2222

2323
// TODO redundant for elim
2424
// CHECK: func @whiletofor()
25-
// CHECK-NEXT: %c7_i32 = arith.constant 7 : i32
26-
// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32
27-
// CHECK-NEXT: %c20_i32 = arith.constant 20 : i32
28-
// CHECK-NEXT: %c2_i32 = arith.constant 2 : i32
29-
// CHECK-NEXT: %c3_i32 = arith.constant 3 : i32
30-
// CHECK-NEXT: %c1 = arith.constant 1 : index
31-
// CHECK-NEXT: %c0 = arith.constant 0 : index
32-
// CHECK-NEXT: %c100 = arith.constant 100 : index
25+
// CHECK-DAG: %c7_i32 = arith.constant 7 : i32
26+
// CHECK-DAG: %c0_i32 = arith.constant 0 : i32
27+
// CHECK-DAG: %c20_i32 = arith.constant 20 : i32
28+
// CHECK-DAG: %c2_i32 = arith.constant 2 : i32
29+
// CHECK-DAG: %c3_i32 = arith.constant 3 : i32
30+
// CHECK-DAG: %c1 = arith.constant 1 : index
31+
// CHECK-DAG: %c0 = arith.constant 0 : index
32+
// CHECK-DAG: %c100 = arith.constant 100 : index
3333
// CHECK-NEXT: %0 = memref.alloca() : memref<100x100xi32>
3434
// CHECK-NEXT: %1 = scf.for %arg0 = %c0 to %c100 step %c1 iter_args(%arg1 = %c7_i32) -> (i32) {
3535
// CHECK-NEXT: %3 = arith.index_cast %arg1 : i32 to index

tools/mlir-clang/Test/canonicalization.c

+4-4
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55

66
// CHECK-LABEL: func @matrix_power(
77
// CHECK: %[[VAL_0:.*]]: memref<20x20xi32>, %[[VAL_1:.*]]: memref<20xi32>, %[[VAL_2:.*]]: memref<20xi32>, %[[VAL_3:.*]]: memref<20xi32>)
8-
// CHECK-NEXT: %c1 = arith.constant 1 : index
9-
// CHECK-NEXT: %c20 = arith.constant 20 : index
10-
// CHECK-NEXT: %c0 = arith.constant 0 : index
11-
// CHECK-NEXT: %c-1_i32 = arith.constant -1 : i32
8+
// CHECK-DAG: %c1 = arith.constant 1 : index
9+
// CHECK-DAG: %c20 = arith.constant 20 : index
10+
// CHECK-DAG: %c0 = arith.constant 0 : index
11+
// CHECK-DAG: %c-1_i32 = arith.constant -1 : i32
1212
// CHECK-NEXT: scf.for %arg4 = %c1 to %c20 step %c1 {
1313
// CHECK-NEXT: %0 = arith.index_cast %arg4 : index to i32
1414
// CHECK-NEXT: %1 = arith.addi %0, %c-1_i32 : i32

tools/mlir-clang/mlir-clang.cc

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
2727
#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
2828
#include "mlir/Dialect/Affine/Passes.h"
29+
#include "mlir/Dialect/Async/IR/Async.h"
2930
#include "mlir/Dialect/DLTI/DLTI.h"
3031
#include "mlir/Dialect/GPU/GPUDialect.h"
3132
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -418,6 +419,7 @@ int main(int argc, char **argv) {
418419
context.getOrLoadDialect<func::FuncDialect>();
419420
context.getOrLoadDialect<DLTIDialect>();
420421
context.getOrLoadDialect<mlir::scf::SCFDialect>();
422+
context.getOrLoadDialect<mlir::async::AsyncDialect>();
421423
context.getOrLoadDialect<mlir::LLVM::LLVMDialect>();
422424
context.getOrLoadDialect<mlir::NVVM::NVVMDialect>();
423425
context.getOrLoadDialect<mlir::gpu::GPUDialect>();

tools/polygeist-opt/polygeist-opt.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Conversion/Passes.h"
1515
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1616
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17+
#include "mlir/Dialect/Async/IR/Async.h"
1718
#include "mlir/Dialect/DLTI/DLTI.h"
1819
#include "mlir/Dialect/Func/IR/FuncOps.h"
1920
#include "mlir/Dialect/GPU/GPUDialect.h"
@@ -48,6 +49,7 @@ int main(int argc, char **argv) {
4849
registry.insert<mlir::AffineDialect>();
4950
registry.insert<mlir::LLVM::LLVMDialect>();
5051
registry.insert<mlir::memref::MemRefDialect>();
52+
registry.insert<mlir::async::AsyncDialect>();
5153
registry.insert<mlir::func::FuncDialect>();
5254
registry.insert<mlir::arith::ArithmeticDialect>();
5355
registry.insert<mlir::scf::SCFDialect>();

0 commit comments

Comments
 (0)