diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8d82a7af8..b5f7be5f7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -69,6 +69,20 @@ include_directories(
     ${PROJECT_SOURCE_DIR}/include
 )
 
+if(TPP_DIR)
+    message(STATUS "Using TPP_DIR: ${TPP_DIR}")
+    add_definitions("-DTPP_ENABLED")
+    include_directories(${TPP_DIR}/include)
+    include_directories(${TPP_DIR}/build/include)
+    link_directories(${TPP_DIR}/build)
+    link_directories(${TPP_DIR}/build/lib)
+    set(TPP_AVAILABLE_LIBS
+        TPPCheckDialect TPPCheckToLoops TPPGPU TPPIR TPPLinalgToFunc TPPLinalgToXSMM TPPPassBundles
+        TPPPerfDialect TPPPerfToFunc TPPPerfToLoop TPPPipeline TPPRunner TPPTestLib TPPTransforms
+        TPPTransformsUtils TPPXsmmDialect tpp_xsmm_runner_utils TPPXsmmToFunc xsmm
+    )
+endif()
+
 # The paths are added in the subfolders using the gc_add_path() function.
 # These lists are also used by tests.
 set(GC_LIB_SOURCES CACHE INTERNAL "The graph_compiler library source paths")
diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td
index aaea602b6..678f4969c 100644
--- a/include/gc/Transforms/Passes.td
+++ b/include/gc/Transforms/Passes.td
@@ -46,4 +46,31 @@ def GCCPUPipeline: Pass<"gc-cpu-pipeline"> {
                            "vector::VectorDialect"];
 }
 
+def GCGPUPipeline: Pass<"gc-gpu-pipeline"> {
+  let summary = "All-in-one GC pipeline for GPU";
+  let dependentDialects = ["onednn_graph::OneDNNGraphDialect",
+                           "tensor::TensorDialect",
+                           "memref::MemRefDialect",
+                           "linalg::LinalgDialect",
+                           "linalgx::LinalgxDialect",
+                           "LLVM::LLVMDialect",
+                           "scf::SCFDialect",
+                           "bufferization::BufferizationDialect",
+                           "omp::OpenMPDialect",
+                           "gpu::GPUDialect",
+                           "xegpu::XeGPUDialect",
+                           "math::MathDialect",
+                           "vector::VectorDialect"];
+  let options = [
+    Option<"kTile", "k-tile", "int64_t",
+           /*default=*/"32",
+           "GEMM tile size for the reduction dimension.">,
+    Option<"stages", "stages", "int64_t",
+           /*default=*/"1",
+           "Number of cooperative prefetch stages.">,
+    ListOption<"dpasTile", "dpas-tile", "int64_t",
+               "DPAS register block sizes MxNxK">,
+  ];
+}
+
 #endif // GC_DIALECT_GC_PASSES
diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt
index 7be337566..610052ad9 100644
--- a/lib/gc/Transforms/CMakeLists.txt
+++ b/lib/gc/Transforms/CMakeLists.txt
@@ -25,4 +25,8 @@ add_mlir_library(GCPasses
   MLIROneDNNGraph
 )
 
+if(TPP_DIR)
+  target_link_libraries(GCPasses PRIVATE ${TPP_AVAILABLE_LIBS})
+endif()
+
 set_property(GLOBAL APPEND PROPERTY GC_PASS_LIBS GCPasses)
diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp
index 6e5151e9e..66f273007 100644
--- a/lib/gc/Transforms/Pipeline.cpp
+++ b/lib/gc/Transforms/Pipeline.cpp
@@ -13,17 +13,23 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/InitAllPasses.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/Passes.h"
 
+#ifdef TPP_ENABLED
+#include "TPP/Passes.h"
+#endif
+
 #include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h"
 #include "gc/Dialect/Linalgx/LinalgxDialect.h"
 #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
@@ -143,7 +149,19 @@ void populateCPUPipeline(mlir::PassManager &pm) {
   populateLLVMPasses(pm);
 }
 
+#ifdef TPP_ENABLED
+void populateGPUPipeline(mlir::PassManager &pm,
+                         tpp::LinalgToXeGPUOptions options) {
+  // middle-end, arith/math/vector dialects
+  populateVectorPasses(pm);
+  // back-end, arith/math/vector/memref dialects
+  populateBufferizationPasses(pm);
+  pm.addNestedPass<func::FuncOp>(tpp::createLinalgToXeGPU(options));
+}
+#endif
+
 #define GEN_PASS_DEF_GCCPUPIPELINE
+#define GEN_PASS_DEF_GCGPUPIPELINE
 #include "gc/Transforms/Passes.h.inc"
 
 namespace {
@@ -162,5 +180,24 @@ class GCCPUPipeline : public impl::GCCPUPipelineBase<GCCPUPipeline> {
   }
 };
 
+class GCGPUPipeline : public impl::GCGPUPipelineBase<GCGPUPipeline> {
+public:
+  friend struct PassHelper;
+  using impl::GCGPUPipelineBase<GCGPUPipeline>::GCGPUPipelineBase;
+  void runOnOperation() final {
+    auto op = getOperation();
+#ifdef TPP_ENABLED
+    PassManager pm{op->getContext()};
+    tpp::LinalgToXeGPUOptions options{kTile, stages, dpasTile};
+    populateGPUPipeline(pm, options);
+    if (failed(pm.run(op)))
+      signalPassFailure();
+#else
+    op->emitError() << "TPP passes are unavailable; build with TPP_DIR set";
+    signalPassFailure();
+#endif
+  }
+};
+
 } // namespace
 } // namespace mlir::gc
diff --git a/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir b/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir
new file mode 100644
index 000000000..ac02c18a8
--- /dev/null
+++ b/test/mlir/test/gc/Transforms/Pipeline/gpu.mlir
@@ -0,0 +1,93 @@
+// RUN: gc-opt %s -o=/dev/null 2>&1
+// gc-opt --gc-gpu-pipeline="dpas-tile=8,16,16 k-tile=16" -canonicalize %s | FileCheck %s
+
+func.func @matmul(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+  linalg.matmul ins(%arg0, %arg1 : memref<8x16xf16>, memref<16x16xf16>)
+                outs(%arg2 : memref<8x16xf32>)
+  return
+}
+
+// func.func @matmul(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
+//   %c1024 = arith.constant 1024 : index
+//   %c16 = arith.constant 16 : index
+//   %c0 = arith.constant 0 : index
+//   %0 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr>
+//   %1 = xegpu.update_nd_offset %0, [0, 0] : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr>
+//   %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> -> vector<8x16xf32>
+//   %3 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr>
+//   %4 = xegpu.update_nd_offset %3, [0, 0] : !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr>
+//   %5 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr>
+//   %6 = xegpu.update_nd_offset %5, [0, 0] : !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr>
+//   %7:3 = scf.for %arg3 = %c0 to %c16 step %c16 iter_args(%arg4 = %2, %arg5 = %4, %arg6 = %6) -> (vector<8x16xf32>, !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr>, !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr>) {
+//     %8 = arith.remui %arg3, %c1024 : index
+//     %9 = arith.cmpi eq, %8, %c0 : index
+//     scf.if %9 {
+//       gpu.barrier
+//     }
+//     %10 = xegpu.load_nd %arg5 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, vnni_axis = 1 : i64}> : !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr> -> vector<8x8x2xf16>
+//     %11 = xegpu.load_nd %arg6 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, vnni_axis = 0 : i64}> : !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr> -> vector<8x16x2xf16>
+//     %12 = xegpu.update_nd_offset %arg5, [0, 16] : !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr>
+//     %13 = xegpu.update_nd_offset %arg6, [16, 0] : !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr>
+//     xegpu.prefetch_nd %12 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr>
+//     xegpu.prefetch_nd %13 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr>
+//     %14 = xegpu.dpas %10, %11, %arg4 : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32>
+//     scf.yield %14, %12, %13 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr>, !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr>
+//   }
+//   xegpu.store_nd %7#0, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr>
+//   return
+// }
+
+func.func @mlp(%arg0: tensor<8x16xf16>, %arg1: tensor<16x16xf16>, %arg2: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<8x16xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<8x16xf32>) -> tensor<8x16xf32>
+  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<8x16xf16>, tensor<16x16xf16>)
+                     outs(%1 : tensor<8x16xf32>) -> tensor<8x16xf32>
+  %3 = tensor.empty() : tensor<8x16xf32>
+  %4 = linalg.add ins(%arg2, %2 : tensor<8x16xf32>, tensor<8x16xf32>) outs(%3 : tensor<8x16xf32>) -> tensor<8x16xf32>
+  return %4 : tensor<8x16xf32>
+}
+
+// func.func @mlp(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>, %arg3: memref<8x16xf32>) {
+//   %c1024 = arith.constant 1024 : index
+//   %c16 = arith.constant 16 : index
+//   %c0 = arith.constant 0 : index
+//   %cst = arith.constant 0.000000e+00 : f32
+//   %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x16xf32>
+//   linalg.fill ins(%cst : f32) outs(%alloc : memref<8x16xf32>)
+//   %0 = xegpu.create_nd_tdesc %alloc[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr>
+//   %1 = xegpu.update_nd_offset %0, [0, 0] : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr>
+//   %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> -> vector<8x16xf32>
+//   %3 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr>
+//   %4 = xegpu.update_nd_offset %3, [0, 0] : !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr>
+//   %5 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr>
+//   %6 = xegpu.update_nd_offset %5, [0, 0] : !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr>
+//   %7:3 = scf.for %arg4 = %c0 to %c16 step %c16 iter_args(%arg5 = %2, %arg6 = %4, %arg7 = %6) -> (vector<8x16xf32>, !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr>, !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr>) {
+//     %17 = arith.remui %arg4, %c1024 : index
+//     %18 = arith.cmpi eq, %17, %c0 : index
+//     scf.if %18 {
+//       gpu.barrier
+//     }
+//     %19 = xegpu.load_nd %arg6 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, vnni_axis = 1 : i64}> : !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr> -> vector<8x8x2xf16>
+//     %20 = xegpu.load_nd %arg7 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint, vnni_axis = 0 : i64}> : !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr> -> vector<8x16x2xf16>
+//     %21 = xegpu.update_nd_offset %arg6, [0, 16] : !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr>
+//     %22 = xegpu.update_nd_offset %arg7, [16, 0] : !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr>
+//     xegpu.prefetch_nd %21 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr>
+//     xegpu.prefetch_nd %22 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr>
+//     %23 = xegpu.dpas %19, %20, %arg5 : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32>
+//     scf.yield %23, %21, %22 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr>, !xegpu.tensor_desc<16x16xf16, #xegpu.tdesc_attr>
+//   }
+//   xegpu.store_nd %7#0, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr>
+//   %8 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr>
+//   %9 = xegpu.update_nd_offset %8, [0, 0] : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr>
+//   %10 = xegpu.load_nd %9 : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> -> vector<8x16xf32>
+//   %11 = xegpu.create_nd_tdesc %alloc[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr>
+//   %12 = xegpu.update_nd_offset %11, [0, 0] : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr>
+//   %13 = xegpu.load_nd %12 : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr> -> vector<8x16xf32>
+//   %14 = arith.addf %10, %13 : vector<8x16xf32>
+//   %15 = xegpu.create_nd_tdesc %arg3[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr>
+//   %16 = xegpu.update_nd_offset %15, [0, 0] : !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr>
+//   xegpu.store_nd %14, %16 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint, l3_hint = #xegpu.cache_hint}> : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr>
+//   memref.dealloc %alloc : memref<8x16xf32>
+//   return
+// }
\ No newline at end of file
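
A quick sketch of how to exercise the new pipeline, assuming graph-compiler is configured against a local tpp-mlir checkout (the TPP_DIR layout must match the include/ and build/ paths hard-coded in the CMake block above; the checkout path and the gc-opt target name are assumptions, not part of the patch):

  cmake -S . -B build -DTPP_DIR=$HOME/tpp-mlir    # assumed checkout location
  cmake --build build --target gc-opt
  build/bin/gc-opt --gc-gpu-pipeline="dpas-tile=8,16,16 k-tile=16" \
      test/mlir/test/gc/Transforms/Pipeline/gpu.mlir

Without TPP_DIR the TPP_ENABLED define is never set, so --gc-gpu-pipeline takes the #else branch of GCGPUPipeline::runOnOperation(): it emits the "TPP passes are unavailable" diagnostic and signals pass failure instead of lowering to XeGPU.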