From 9b9eb57d42c1475d9b84fe77130e57b5eaae4075 Mon Sep 17 00:00:00 2001
From: Zhang Yan
Date: Thu, 20 Jun 2024 16:11:41 +0800
Subject: [PATCH 01/10] base

---
 include/gc/Transforms/Passes.td               |  12 +
 lib/gc/Transforms/CMakeLists.txt              |   1 +
 .../SplitComputeIntensivePatterns.cpp         | 456 ++++++++++++++++++
 .../split-compute-intensive-patterns.mlir     |  27 ++
 4 files changed, 496 insertions(+)
 create mode 100644 lib/gc/Transforms/SplitComputeIntensivePatterns.cpp
 create mode 100644 test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir

diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td
index aaea602b6..b733baad5 100644
--- a/include/gc/Transforms/Passes.td
+++ b/include/gc/Transforms/Passes.td
@@ -46,4 +46,16 @@ def GCCPUPipeline: Pass<"gc-cpu-pipeline"> {
     "vector::VectorDialect"];
 }
 
+def SplitComputeIntensivePatterns : Pass<"split-compute-intensive-patterns"> {
+  let summary = "Split matmul patterns";
+  let description = [{
+    Split the weights of matmul patterns into several parts, the number of
+    which matches the number of NUMA nodes on the target machine.
+  }];
+  let dependentDialects = [
+    "mlir::linalg::LinalgDialect",
+    "mlir::tensor::TensorDialect",
+    "mlir::arith::ArithDialect"];
+}
+
 #endif // GC_DIALECT_GC_PASSES

diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt
index 7be337566..cbee61a60 100644
--- a/lib/gc/Transforms/CMakeLists.txt
+++ b/lib/gc/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_library(GCPasses
   OneDNNGraphToLinalg.cpp
   Pipeline.cpp
   TileNamed.cpp
+  SplitComputeIntensivePatterns.cpp
 
   ADDITIONAL_HEADER_DIRS
     ${PROJECT_SOURCE_DIR}/include

diff --git a/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp b/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp
new file mode 100644
index 000000000..528ca5f54
--- /dev/null
+++ b/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp
@@ -0,0 +1,456 @@
+/*******************************************************************************
+ * Copyright 2024 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/ + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/IR/Location.h" + +#include "gc/Transforms/Passes.h" + +#include + +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_SPLITCOMPUTEINTENSIVEPATTERNS +#include "gc/Transforms/Passes.h.inc" +} // namespace gc + +size_t NUM_OF_NUMA = 3; +size_t SUPPORTED_RANK = 2; + +void printValueType(Value value) { + if (!value) { + llvm::outs() << "Invalid value\n"; + return; + } + + Type type = value.getType(); + type.print(llvm::outs()); + llvm::outs() << "\n"; +} + +void getSplitedTensors(SmallVector& outputs, Location& loc, Value tensor, int64_t target_dim, PatternRewriter &rewriter) { + if (auto definingOp = tensor.getDefiningOp()) { + std::cout << "tensor operation name: " << definingOp->getName().getStringRef().str() << std::endl; + } else { + std::cout << "tensor does not have a defining operation." << std::endl; + } + + + auto Type = tensor.getType().cast(); + int64_t rank = Type.getRank(); + if (!Type || Type.getRank() != SUPPORTED_RANK) { + return; + } + + int64_t M = Type.getDimSize(0); + int64_t N = Type.getDimSize(1); + std::cout << "M: " << M << ", N: " << N << std::endl; + bool has_tail = target_dim == 1 ? N % NUM_OF_NUMA != 0 : M % NUM_OF_NUMA != 0; + int64_t split_length = target_dim == 1 ? (N + NUM_OF_NUMA - 1) / NUM_OF_NUMA : (M + NUM_OF_NUMA - 1) / NUM_OF_NUMA; + // Split the weight tensor into NUM_OF_NUMA parts + auto splitEvenType = target_dim == 1 + ? RankedTensorType::get({M, split_length}, Type.getElementType()) + : RankedTensorType::get({split_length, N}, Type.getElementType()); + auto splitTailType = splitEvenType; + if (has_tail) splitTailType = target_dim == 1 + ? RankedTensorType::get({M, int64_t(N % split_length)}, Type.getElementType()) + : RankedTensorType::get({int64_t(M % split_length), N}, Type.getElementType()); + for (auto split_idx : llvm::seq(0, NUM_OF_NUMA)) { + SmallVector sizes; + SmallVector offsets; + SmallVector strides(rank, rewriter.getIndexAttr(1)); + for (auto i : llvm::seq(0, rank)) { + sizes.push_back(rewriter.getIndexAttr((split_idx == (NUM_OF_NUMA-1)) ? splitTailType.getShape()[i] : splitEvenType.getShape()[i])); + offsets.push_back(rewriter.getIndexAttr((split_idx == 0 || i != target_dim) ? 0 : splitEvenType.getShape()[i] * split_idx)); + } + Value res = rewriter.create( + loc, split_idx == (NUM_OF_NUMA-1) ? 
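        // Slice sizing, illustrated for NUM_OF_NUMA = 3 and a 64-wide target
        // dim: split_length = ceil(64 / 3) = 22, so the first two slices cover
        // [0, 22) and [22, 44), and the last slice takes the 64 % 22 = 20
        // remaining elements; the ternary below picks the tail type only for
        // that last slice.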
splitTailType : splitEvenType, tensor, offsets, sizes, strides)->getResult(0); + auto res_type = res.getType().cast(); + std::cout << split_idx << ", res_type M: " << res_type.getDimSize(0) << ", N: " << res_type.getDimSize(1) << std::endl; + outputs.push_back(res); + std::cout << outputs.size() << std::endl; + } +} + +void SplitMMonN(SmallVector& outputs, SmallVector& inputs, TensorType& resultTy, Location& loc, PatternRewriter &rewriter) { + /*Split on N axis*/ + std::cout << "split on N" << std::endl; + int64_t M = inputs[0].getType().cast().getDimSize(0); + int64_t N = inputs[1].getType().cast().getDimSize(1); + int64_t K = inputs[0].getType().cast().getDimSize(1); + SmallVector splited_weights; + getSplitedTensors(splited_weights, loc, inputs[1], /*target_dim*/1, rewriter); + if (splited_weights.size() != NUM_OF_NUMA) return; + + for (Value weight : splited_weights) { + Value zero = rewriter.create( + loc, rewriter.getZeroAttr(resultTy.getElementType())); + std::cout << "weight.getType().cast().getDimSize(1): " << weight.getType().cast().getDimSize(1) << std::endl; + Value empty = rewriter.create( + loc, ArrayRef {M, weight.getType().cast().getDimSize(1)}, resultTy.getElementType()); + Value tensor = + rewriter.create(loc, zero, empty).getResult(0); + outputs.push_back(rewriter.create( + /*location=*/loc, + /*resultTensorTypes=*/tensor.getType().cast(), + /*inputs=*/ValueRange{inputs[0], weight}, + /*outputs=*/tensor)->getResult(0)); + } +} + +void SplitMMonK(SmallVector& outputs, SmallVector& inputs, TensorType& resultTy, Location& loc, PatternRewriter &rewriter) { + /*Split on K axis*/ + std::cout << "split on K" << std::endl; + int64_t M = inputs[0].getType().cast().getDimSize(0); + int64_t N = inputs[1].getType().cast().getDimSize(1); + int64_t K = inputs[0].getType().cast().getDimSize(1); + SmallVector splited_data, splited_weights; + getSplitedTensors(splited_data, loc, inputs[0], /*target_dim*/1, rewriter); + std::cout << "splited_data size: " << splited_data.size() << std::endl; + if (splited_data.size() != NUM_OF_NUMA) return; + getSplitedTensors(splited_weights, loc, inputs[1], /*target_dim*/0, rewriter); + std::cout << "splited_weights size: " << splited_weights.size() << std::endl; + if (splited_weights.size() != NUM_OF_NUMA) return; + + for (auto [data, weight] : + llvm::zip_equal(splited_data, splited_weights)) { + Value zero = rewriter.create( + loc, rewriter.getZeroAttr(resultTy.getElementType())); + Value empty = rewriter.create( + loc, resultTy.getShape(), resultTy.getElementType()); + Value tensor = + rewriter.create(loc, zero, empty).getResult(0); + outputs.push_back(rewriter.create( + /*location=*/loc, + /*resultTensorTypes=*/tensor.getType().cast(), + /*inputs=*/ValueRange{data, weight}, + /*outputs=*/tensor)->getResult(0)); + } +} + +bool isSupportedPostOp(Operation *op) { + // Check if the operation is a linalg operation + if (!isa(op)) + return false; + + // Get the inputs and outputs of the linalg operation + bool ismax = isa(op); + bool isadd = isa(op); + bool ismul = isa(op); + return ismax || isadd || ismul; +} + +// Helper function to get all post ops following the given operation +void getUnOps(Operation *op, SmallVectorImpl &postOps) { + for (auto user : op->getUsers()) { + if (isSupportedPostOp(user)) postOps.push_back(user); + // Recursively search for unary ops + getUnOps(user, postOps); + } +} + +template +void duplicateBinary(SmallVector& outputs,std::vector>& inputs, TensorType& resultTy, Location& loc, PatternRewriter &rewriter) { + Value zero = 
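  // duplicateBinary re-creates one `opType` (linalg.add/mul/max) per NUMA
  // chunk: inputs[i] carries the i-th slice of every operand, and each clone
  // gets its own tensor.empty destination of the matching slice shape, so the
  // duplicated post-ops stay independent across chunks.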
rewriter.create( + loc, rewriter.getZeroAttr(resultTy.getElementType())); + for (int i = 0; i < NUM_OF_NUMA; ++i) { + TensorType type = inputs[i][0].getType().cast(); + Value Empty = rewriter.create( + loc, type.getShape(), type.getElementType()); + auto tmpOp = rewriter.create(loc, inputs[i], ValueRange {Empty}); + for (auto result : tmpOp->getResults()) { + outputs.push_back(result); + } + } +} + +void deleteOperation(Operation *op) { + // Step 1: Ensure the operation exists + if (!op) + return; + + // Step 2: Check each operand of the operation + for (auto operand : op->getOperands()) { + if (!operand) continue; + if (operand.use_empty()) continue; // Skip if operand has no uses + + // If the operand is an operation and is either emptyOp or fillOp + if (auto definingOp = operand.getDefiningOp()) { + if (isa(definingOp) || isa(definingOp)) { + llvm::outs() << "is empty \n"; + // Recursively delete the operand operation if it has only one use + if (definingOp->hasOneUse()) { + deleteOperation(definingOp); + } + } + } + } + + // Step 3: Disconnect the operation from its operands and users + op->dropAllUses(); + op->dropAllReferences(); + + // Step 4: Erase the operation from its parent block + op->erase(); +} + +Value addN(Value& initTensor, SmallVector& ins, TensorType& resultTy, Location& loc, PatternRewriter &rewriter) { + llvm::outs() << "start addN \n"; + // Create indexing maps (for input tensors and output tensor) + int num_of_args = int(ins.size()) + 1; + MLIRContext *context = rewriter.getContext(); + SmallVector indexingMaps(num_of_args, + AffineMap::getMultiDimIdentityMap(resultTy.getRank(), context)); + llvm::outs() << "created affinemap \n"; + // Create iterator types (parallel for all dimensions) + // ArrayRef iteratorTypes(resultTy.getRank(), "parallel"); + SmallVector iteratorTypes(resultTy.getRank(), utils::IteratorType::parallel); + llvm::outs() << "created IteratorType \n"; + // Create the linalg.generic op + auto genericOp = rewriter.create( + loc, resultTy, ValueRange{ins}, ValueRange{initTensor}, + indexingMaps, iteratorTypes, + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { + // Define the body of the linalg.generic operation (elementwise addition) + Value sum = nestedBuilder.create(nestedLoc, args[0], args[1]); + for (auto i = 2; i < num_of_args - 1; ++i) + sum = nestedBuilder.create(nestedLoc, sum, args[i]); // Add more if more inputs + nestedBuilder.create(nestedLoc, sum); + }); + + // Mark the output as the result of the function (for demonstration purposes) + return genericOp.getResults().front();; +} + +LogicalResult splitSingleMM(linalg::MatmulOp& op, + PatternRewriter &rewriter) { + SmallVector postOps; + getUnOps(op, postOps); + auto loc = op->getLoc(); + auto resultTy = dyn_cast(op->getResultTypes().front()); + auto input_operands = op.getInputs(); + SmallVector input_tensors; + for (Value operand : input_operands) { + if (!operand.getType().isa()) { + continue; + } + input_tensors.push_back(operand); + } + + int64_t M = input_tensors[0].getType().cast().getDimSize(0); + int64_t N = input_tensors[1].getType().cast().getDimSize(1); + int64_t K = input_tensors[0].getType().cast().getDimSize(1); + std::cout << "M: " << M << ", N: " << N << ", K: " << K << std::endl; + + int64_t target_dim = N / K >= 2 ? 
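      // Heuristic: when the matmul widens the tensor (N at least 2x K), split
      // the weight along N so each chunk computes disjoint output columns and
      // the results can simply be concatenated; otherwise split along K and
      // sum the partial products afterwards.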
1 : 0; + SmallVector splites_res; + if (target_dim == 1) { + SplitMMonN(splites_res, input_tensors, resultTy, loc, rewriter); + if (splites_res.size() != NUM_OF_NUMA) return failure(); + SmallVector Outputs = splites_res; + auto lastInput = op->getResult(0); + for (auto postOp : postOps) { + llvm::outs() << "Operation name: " << postOp->getName().getStringRef() << "\n"; + auto opInputs = postOp->getOperands().drop_back(); + llvm::outs() << "inputs: " << opInputs.size() << "\n"; + auto opOutputs = postOp->getResults(); + llvm::outs() << "outputs: " << opOutputs.size() << "\n"; + + std::vector> Inputs; + for (auto input : opInputs) { + if (auto definingOp = input.getDefiningOp()) { + std::cout << "Input operation name: " << definingOp->getName().getStringRef().str() << std::endl; + } else { + std::cout << "Input does not have a defining operation." << std::endl; + } + if (input == lastInput) { + std::cout << "enter mm output" << std::endl; + for (size_t i = 0; i < NUM_OF_NUMA; ++i) { + SmallVector innerVector; + innerVector.push_back(Outputs[0]); + Inputs.push_back(innerVector); + Outputs.erase(Outputs.begin()); + llvm::outs() << "inputs[0].size: " << Inputs[0].size() <<" \n"; + } + } else { + llvm::outs() << "doesnot match anything \n"; + SmallVector splited_inputs; + getSplitedTensors(splited_inputs, loc, input, /*target_dim*/1, rewriter); + llvm::outs() << "inputs[0].size: " << Inputs[0].size() <<" \n"; + int i = 0; + for (const auto &splited_input : splited_inputs) { + Inputs[i].push_back(splited_input); + i++; + } + llvm::outs() << "split input done \n"; + } + } + if (auto postOpType = llvm::dyn_cast(postOp)) + duplicateBinary(Outputs, Inputs, resultTy, loc, rewriter); + else if (auto postOpType = llvm::dyn_cast(postOp)) + duplicateBinary(Outputs, Inputs, resultTy, loc, rewriter); + else if (auto postOpType = llvm::dyn_cast(postOp)) + duplicateBinary(Outputs, Inputs, resultTy, loc, rewriter); + llvm::outs() << "post op creation and deletion done \n"; + lastInput = postOp->getResult(0); + } + // Concatenate the two halves back together on N axis + auto newop = rewriter.create( + loc, target_dim, Outputs); + llvm::outs() << "created concat \n"; + auto replaced_op = postOps.size() ? 
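        // Choose what the concat result replaces: the last duplicated post-op
        // when the matmul had a fused chain, or the matmul itself when it had
        // none.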
postOps.back() : op; + if (postOps.size() > 1) { + postOps.pop_back(); + deleteOperation(op); + for (auto &deleteOp : postOps) + deleteOperation(deleteOp); + } + rewriter.replaceOp(replaced_op, newop); + } else { + SplitMMonK(splites_res, input_tensors, resultTy, loc, rewriter); + if (splites_res.size() != NUM_OF_NUMA) return failure(); + // Add the two halves back together + // Create linalg.map operation + Value zero = rewriter.create( + loc, rewriter.getZeroAttr(resultTy.getElementType())); + Value empty = rewriter.create( + loc, resultTy.getShape(), resultTy.getElementType()); + Value initTensor = + rewriter.create(loc, zero, empty).getResult(0); + auto newop = addN(initTensor, splites_res, resultTy, loc, rewriter); + // Replace the original operation with the new linalg.map operation + rewriter.replaceOp(op, newop); + } + return success(); +} + +LogicalResult splitSingleMMwithUnary(linalg::MatmulOp& op, + PatternRewriter &rewriter) { + auto loc = op->getLoc(); + auto resultTy = dyn_cast(op->getResultTypes().front()); + auto input_operands = op.getInputs(); + SmallVector input_tensors; + for (Value operand : input_operands) { + if (!operand.getType().isa()) { + continue; + } + input_tensors.push_back(operand); + } + + int64_t M = input_tensors[0].getType().cast().getDimSize(0); + int64_t N = input_tensors[1].getType().cast().getDimSize(1); + int64_t K = input_tensors[0].getType().cast().getDimSize(1); + std::cout << "M: " << M << ", N: " << N << ", K: " << K << std::endl; + + int64_t target_dim = N / K >= 2 ? 1 : 0; + SmallVector splites_res; + if (target_dim == 1) { + SplitMMonN(splites_res, input_tensors, resultTy, loc, rewriter); + if (splites_res.size() != NUM_OF_NUMA) return failure(); + + + // Concatenate the two halves back together on N axis + auto newop = rewriter.create( + loc, target_dim, splites_res); + rewriter.replaceOp(op, newop); + } else { + SplitMMonK(splites_res, input_tensors, resultTy, loc, rewriter); + if (splites_res.size() != NUM_OF_NUMA) return failure(); + // Add the two halves back together + // Create linalg.map operation + Value zero = rewriter.create( + loc, rewriter.getZeroAttr(resultTy.getElementType())); + Value empty = rewriter.create( + loc, resultTy.getShape(), resultTy.getElementType()); + Value initTensor = + rewriter.create(loc, zero, empty).getResult(0); + auto newop = rewriter.create( + loc, resultTy, splites_res, ValueRange{initTensor}); + + // Replace the original operation with the new linalg.map operation + rewriter.replaceOp(op, newop); + } + return success(); +} + +LogicalResult splitMLP(linalg::MatmulOp& op, + PatternRewriter &rewriter) { + return success(); +} + +class SplitComputeIntensivePatternsRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(linalg::MatmulOp op, + PatternRewriter &rewriter) const final { + // Check if the operation has already been processed + if (op->hasAttr("splited")) + return failure(); + // Ensure the operation is followed by relu and another matmul. 
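    // The lookahead below (relu + second matmul detection for the MLP case)
    // is currently disabled; every matched matmul takes the splitSingleMM
    // path.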
+ // auto nextOp = op->getNextNode(); + // while (isa(nextOp) || isa(nextOp)) nextOp = nextOp->getNextNode(); + // std::cout << !isa(nextOp) << std::endl; + // need to break when encounters computational op + // if (!nextOp || !isa(nextOp)) + return splitSingleMM(op, rewriter); + // auto reluOp = cast(nextOp); + // auto nextNextOp = reluOp->getNextNode(); + // while (isa(nextNextOp) || isa(nextNextOp)) nextNextOp = nextNextOp->getNextNode(); + // // need to break when encounters binary op + // if (!nextNextOp || !isa(nextNextOp)) + // return splitSingleMMwithUnary(op, rewriter); + // auto nextMatmulOp = cast(nextNextOp); + // return splitMLP(op, rewriter); + } +}; + +namespace gc { +class SplitComputeIntensivePatterns + : public impl::SplitComputeIntensivePatternsBase { +public: + using impl::SplitComputeIntensivePatternsBase< + SplitComputeIntensivePatterns>::SplitComputeIntensivePatternsBase; + void runOnOperation() final { + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext()); + FrozenRewritePatternSet patternSet(std::move(patterns)); + SmallVector ops; + getOperation()->walk([&](Operation *op) { + if (isa(op)) + ops.push_back(op); + }); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::ExistingOps; + bool erased; + std::cout << "ops.size(): " << ops.size() << std::endl; + if (failed(applyOpPatternsAndFold(ops, patternSet, + config, /*changed=*/nullptr, &erased))) + signalPassFailure(); + return; + } +}; + +} // namespace gc +} // namespace mlir diff --git a/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir b/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir new file mode 100644 index 000000000..25f3bf685 --- /dev/null +++ b/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir @@ -0,0 +1,27 @@ +// RUN: gc-opt %s --split-compute-intensive-patterns | FileCheck %s + +func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x64xbf16>, %arg2: tensor<64xbf16>, %arg3: tensor<64x256xbf16>, %arg4: tensor<256xbf16>) -> tensor<128x256xbf16> { +%cst = arith.constant 0.000000e+00 : bf16 +%0 = tensor.empty() : tensor<128x64xbf16> +%1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x64xbf16>) -> tensor<128x64xbf16> +%2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x64xbf16>) outs(%1 : tensor<128x64xbf16>) -> tensor<128x64xbf16> +%3 = tensor.empty() : tensor<128x64xbf16> +%broadcasted = linalg.broadcast ins(%arg2 : tensor<64xbf16>) outs(%3 : tensor<128x64xbf16>) dimensions = [0] +%4 = tensor.empty() : tensor<128x64xbf16> +%5 = linalg.add ins(%2, %broadcasted : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%4 : tensor<128x64xbf16>) -> tensor<128x64xbf16> +%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xbf16> +%6 = tensor.empty() : tensor<128x64xbf16> +%7 = linalg.max ins(%5, %cst_0 : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%6 : tensor<128x64xbf16>) -> tensor<128x64xbf16> +%cst_1 = arith.constant 0.000000e+00 : bf16 +%8 = tensor.empty() : tensor<128x256xbf16> +%9 = linalg.fill ins(%cst_1 : bf16) outs(%8 : tensor<128x256xbf16>) -> tensor<128x256xbf16> +%10 = linalg.matmul ins(%7, %arg3 : tensor<128x64xbf16>, tensor<64x256xbf16>) outs(%9 : tensor<128x256xbf16>) -> tensor<128x256xbf16> +%11 = tensor.empty() : tensor<128x256xbf16> +%broadcasted_2 = linalg.broadcast ins(%arg4 : tensor<256xbf16>) outs(%11 : tensor<128x256xbf16>) dimensions = [0] +%12 = tensor.empty() : tensor<128x256xbf16> +%13 = linalg.add ins(%10, %broadcasted_2 : 
tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%12 : tensor<128x256xbf16>) -> tensor<128x256xbf16> +%cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xbf16> +%14 = tensor.empty() : tensor<128x256xbf16> +%15 = linalg.max ins(%13, %cst_3 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%14 : tensor<128x256xbf16>) -> tensor<128x256xbf16> +return %15 : tensor<128x256xbf16> +} From c03deb521ad4d696c746ee59fb08cf4803faa224 Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Fri, 21 Jun 2024 16:43:48 +0800 Subject: [PATCH 02/10] add broadcast, constant support --- .../SplitComputeIntensivePatterns.cpp | 180 +++++++++++++----- .../split-compute-intensive-patterns.mlir | 43 ++--- 2 files changed, 154 insertions(+), 69 deletions(-) diff --git a/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp b/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp index 528ca5f54..f5688e78b 100644 --- a/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp +++ b/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp @@ -35,7 +35,7 @@ namespace gc { #include "gc/Transforms/Passes.h.inc" } // namespace gc -size_t NUM_OF_NUMA = 3; +size_t NUM_OF_NUMA = 2; size_t SUPPORTED_RANK = 2; void printValueType(Value value) { @@ -49,33 +49,32 @@ void printValueType(Value value) { llvm::outs() << "\n"; } -void getSplitedTensors(SmallVector& outputs, Location& loc, Value tensor, int64_t target_dim, PatternRewriter &rewriter) { - if (auto definingOp = tensor.getDefiningOp()) { - std::cout << "tensor operation name: " << definingOp->getName().getStringRef().str() << std::endl; - } else { - std::cout << "tensor does not have a defining operation." << std::endl; - } - - +void getSplitedTensors(SmallVector& outputs, Value tensor, int64_t target_dim, PatternRewriter &rewriter) { auto Type = tensor.getType().cast(); + auto loc = tensor.getLoc(); int64_t rank = Type.getRank(); - if (!Type || Type.getRank() != SUPPORTED_RANK) { + llvm::outs() << "split rank: " << rank << "\n"; + if (!Type || Type.getRank() > SUPPORTED_RANK) { return; } - - int64_t M = Type.getDimSize(0); - int64_t N = Type.getDimSize(1); - std::cout << "M: " << M << ", N: " << N << std::endl; - bool has_tail = target_dim == 1 ? N % NUM_OF_NUMA != 0 : M % NUM_OF_NUMA != 0; - int64_t split_length = target_dim == 1 ? (N + NUM_OF_NUMA - 1) / NUM_OF_NUMA : (M + NUM_OF_NUMA - 1) / NUM_OF_NUMA; + llvm::outs() << "split shape: ["; + for (int64_t dim : Type.getShape()) { + llvm::outs() << dim << " "; + } + llvm::outs() << "]\n"; + llvm::outs() << "target_dim: " << target_dim << "\n"; + bool has_tail = Type.getDimSize(target_dim) % NUM_OF_NUMA != 0; + int64_t split_length = (Type.getDimSize(target_dim) + NUM_OF_NUMA - 1) / NUM_OF_NUMA; + SmallVector shape(Type.getShape().begin(), Type.getShape().end()); + shape[target_dim] = split_length; // Split the weight tensor into NUM_OF_NUMA parts - auto splitEvenType = target_dim == 1 - ? RankedTensorType::get({M, split_length}, Type.getElementType()) - : RankedTensorType::get({split_length, N}, Type.getElementType()); + auto splitEvenType = RankedTensorType::get(shape, Type.getElementType()); auto splitTailType = splitEvenType; - if (has_tail) splitTailType = target_dim == 1 - ? 
RankedTensorType::get({M, int64_t(N % split_length)}, Type.getElementType()) - : RankedTensorType::get({int64_t(M % split_length), N}, Type.getElementType()); + if (has_tail) { + shape[target_dim] = Type.getDimSize(target_dim) % split_length; + splitTailType = RankedTensorType::get(shape, Type.getElementType()); + } + llvm::outs() << "start to extract slice\n"; for (auto split_idx : llvm::seq(0, NUM_OF_NUMA)) { SmallVector sizes; SmallVector offsets; @@ -87,12 +86,51 @@ void getSplitedTensors(SmallVector& outputs, Location& loc, Value tensor, Value res = rewriter.create( loc, split_idx == (NUM_OF_NUMA-1) ? splitTailType : splitEvenType, tensor, offsets, sizes, strides)->getResult(0); auto res_type = res.getType().cast(); - std::cout << split_idx << ", res_type M: " << res_type.getDimSize(0) << ", N: " << res_type.getDimSize(1) << std::endl; + llvm::outs() << "splited shape: ["; + for (int64_t dim : res_type.getShape()) { + llvm::outs() << dim << " "; + } + llvm::outs() << "]\n"; outputs.push_back(res); std::cout << outputs.size() << std::endl; } } +void splitBroadcast(SmallVector& outputs, linalg::BroadcastOp broadcastOp, int64_t target_dim, PatternRewriter &rewriter) { + auto loc = broadcastOp->getLoc(); + SmallVector broadcastInputs; + auto in = broadcastOp.getInput(); + if (in.getType().getShape().size() > SUPPORTED_RANK) { + llvm::outs() << "cannot split broadcast on current size.\n"; + return; + } + auto out = broadcastOp.getInit(); + auto outType = out.getType().dyn_cast(); + auto shape = outType.getShape(); + if (shape.size() != SUPPORTED_RANK || target_dim != 1) { + llvm::outs() << "cannot split broadcast on current size or current target dim \n"; + return; + } + llvm::outs() << "Tensor shape: ["; + for (int64_t dim : shape) { + llvm::outs() << dim << " "; + } + llvm::outs() << "]\n"; + llvm::outs() << "duplicate broadcast inputs\n"; + getSplitedTensors(broadcastInputs, in, /*target_dim*/in.getType().getShape().size()-1, rewriter); + if (auto emptyOp = dyn_cast(out.getDefiningOp())) { + int64_t split_length = (shape[1] + NUM_OF_NUMA - 1) / NUM_OF_NUMA; + int64_t split_tail = shape[1] % NUM_OF_NUMA != 0 ? shape[1] % split_length : split_length; + for (auto split_idx : llvm::seq(0, NUM_OF_NUMA)) { + Value empty = rewriter.create( + loc, ArrayRef{shape[0], (split_idx == (NUM_OF_NUMA - 1)) ? 
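        // Splitting a bias broadcast, e.g. with NUM_OF_NUMA = 2 and a
        // <128x256> init: the 1-D input is sliced into two tensor<128xbf16>
        // halves and each half is re-broadcast into its own <128x128> chunk;
        // the ternary below only differs from split_length when the width
        // does not divide evenly.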
split_tail : split_length}, outType.getElementType()); + Value res = rewriter.create(loc, broadcastInputs[split_idx], empty, broadcastOp.getDimensions()).getResults()[0]; + outputs.push_back(res); + std::cout << outputs.size() << std::endl; + } + } +} + void SplitMMonN(SmallVector& outputs, SmallVector& inputs, TensorType& resultTy, Location& loc, PatternRewriter &rewriter) { /*Split on N axis*/ std::cout << "split on N" << std::endl; @@ -100,7 +138,7 @@ void SplitMMonN(SmallVector& outputs, SmallVector& inputs, TensorT int64_t N = inputs[1].getType().cast().getDimSize(1); int64_t K = inputs[0].getType().cast().getDimSize(1); SmallVector splited_weights; - getSplitedTensors(splited_weights, loc, inputs[1], /*target_dim*/1, rewriter); + getSplitedTensors(splited_weights, inputs[1], /*target_dim*/1, rewriter); if (splited_weights.size() != NUM_OF_NUMA) return; for (Value weight : splited_weights) { @@ -126,10 +164,10 @@ void SplitMMonK(SmallVector& outputs, SmallVector& inputs, TensorT int64_t N = inputs[1].getType().cast().getDimSize(1); int64_t K = inputs[0].getType().cast().getDimSize(1); SmallVector splited_data, splited_weights; - getSplitedTensors(splited_data, loc, inputs[0], /*target_dim*/1, rewriter); + getSplitedTensors(splited_data, inputs[0], /*target_dim*/1, rewriter); std::cout << "splited_data size: " << splited_data.size() << std::endl; if (splited_data.size() != NUM_OF_NUMA) return; - getSplitedTensors(splited_weights, loc, inputs[1], /*target_dim*/0, rewriter); + getSplitedTensors(splited_weights, inputs[1], /*target_dim*/0, rewriter); std::cout << "splited_weights size: " << splited_weights.size() << std::endl; if (splited_weights.size() != NUM_OF_NUMA) return; @@ -171,10 +209,9 @@ void getUnOps(Operation *op, SmallVectorImpl &postOps) { } template -void duplicateBinary(SmallVector& outputs,std::vector>& inputs, TensorType& resultTy, Location& loc, PatternRewriter &rewriter) { - Value zero = rewriter.create( - loc, rewriter.getZeroAttr(resultTy.getElementType())); +void duplicateBinary(SmallVector& outputs,std::vector>& inputs, TensorType& resultTy, PatternRewriter &rewriter) { for (int i = 0; i < NUM_OF_NUMA; ++i) { + auto loc = inputs[i][0].getLoc(); TensorType type = inputs[i][0].getType().cast(); Value Empty = rewriter.create( loc, type.getShape(), type.getElementType()); @@ -197,13 +234,13 @@ void deleteOperation(Operation *op) { // If the operand is an operation and is either emptyOp or fillOp if (auto definingOp = operand.getDefiningOp()) { - if (isa(definingOp) || isa(definingOp)) { - llvm::outs() << "is empty \n"; - // Recursively delete the operand operation if it has only one use - if (definingOp->hasOneUse()) { - deleteOperation(definingOp); - } + // if (isa(definingOp) || isa(definingOp)) { + // llvm::outs() << "is empty \n"; + // // Recursively delete the operand operation if it has only one use + if (definingOp->hasOneUse()) { + deleteOperation(definingOp); } + // } } } @@ -215,6 +252,20 @@ void deleteOperation(Operation *op) { op->erase(); } +void deleteOperands(Operation *op) { + for (auto operand : op->getOperands()) { + if (!operand) continue; + if (operand.use_empty()) {continue;} // Skip if operand has no uses + if (auto definingOp = operand.getDefiningOp()) { + if (definingOp->hasOneUse()) { + definingOp->dropAllUses(); + definingOp->dropAllReferences(); + definingOp->erase(); + } + } + } +} + Value addN(Value& initTensor, SmallVector& ins, TensorType& resultTy, Location& loc, PatternRewriter &rewriter) { llvm::outs() << "start addN \n"; // Create 
indexing maps (for input tensors and output tensor) @@ -263,7 +314,7 @@ LogicalResult splitSingleMM(linalg::MatmulOp& op, int64_t K = input_tensors[0].getType().cast().getDimSize(1); std::cout << "M: " << M << ", N: " << N << ", K: " << K << std::endl; - int64_t target_dim = N / K >= 2 ? 1 : 0; + int64_t target_dim = N / K >= 2 ? 1 : 1; SmallVector splites_res; if (target_dim == 1) { SplitMMonN(splites_res, input_tensors, resultTy, loc, rewriter); @@ -279,11 +330,6 @@ LogicalResult splitSingleMM(linalg::MatmulOp& op, std::vector> Inputs; for (auto input : opInputs) { - if (auto definingOp = input.getDefiningOp()) { - std::cout << "Input operation name: " << definingOp->getName().getStringRef().str() << std::endl; - } else { - std::cout << "Input does not have a defining operation." << std::endl; - } if (input == lastInput) { std::cout << "enter mm output" << std::endl; for (size_t i = 0; i < NUM_OF_NUMA; ++i) { @@ -291,12 +337,53 @@ LogicalResult splitSingleMM(linalg::MatmulOp& op, innerVector.push_back(Outputs[0]); Inputs.push_back(innerVector); Outputs.erase(Outputs.begin()); + llvm::outs() << "inputs[" << i << "].size: " << Inputs[i].size() <<" \n"; + } + } else if (auto definingOp = input.getDefiningOp()) { + llvm::outs() << "is definingOp\n"; + std::cout << "Input operation name: " << definingOp->getName().getStringRef().str() << std::endl; + if (auto fillOp = dyn_cast(definingOp)) { + llvm::outs() << "is fill \n"; + SmallVector splited_inputs; + getSplitedTensors(splited_inputs, input, /*target_dim*/1, rewriter); + int i = 0; + for (const auto &splited_input : splited_inputs) { + Inputs[i].push_back(splited_input); + llvm::outs() << "inputs[" << i << "].size: " << Inputs[i].size() <<" \n"; + i++; + } + llvm::outs() << "split input done \n"; + } else if (auto broadcastOp = dyn_cast(definingOp)){ + llvm::outs() << "is broadcast \n"; + SmallVector splited_inputs; + splitBroadcast(splited_inputs, broadcastOp, /*target_dim*/1, rewriter); llvm::outs() << "inputs[0].size: " << Inputs[0].size() <<" \n"; + int i = 0; + for (const auto &splited_input : splited_inputs) { + Inputs[i].push_back(splited_input); + i++; + } + deleteOperation(broadcastOp); + llvm::outs() << "split input done \n"; + } else if (auto constantOp = dyn_cast(definingOp)){ + llvm::outs() << "is constant \n"; + auto newConstantOp = rewriter.create( + constantOp.getLoc(), constantOp.getType(), constantOp.getValue()); + SmallVector splited_inputs; + getSplitedTensors(splited_inputs, newConstantOp, /*target_dim*/1, rewriter); + int i = 0; + for (const auto &splited_input : splited_inputs) { + Inputs[i].push_back(splited_input); + llvm::outs() << "inputs[" << i << "].size: " << Inputs[i].size() <<" \n"; + i++; + } + deleteOperation(constantOp); + llvm::outs() << "split input done \n"; } } else { llvm::outs() << "doesnot match anything \n"; SmallVector splited_inputs; - getSplitedTensors(splited_inputs, loc, input, /*target_dim*/1, rewriter); + getSplitedTensors(splited_inputs, input, /*target_dim*/1, rewriter); llvm::outs() << "inputs[0].size: " << Inputs[0].size() <<" \n"; int i = 0; for (const auto &splited_input : splited_inputs) { @@ -307,17 +394,17 @@ LogicalResult splitSingleMM(linalg::MatmulOp& op, } } if (auto postOpType = llvm::dyn_cast(postOp)) - duplicateBinary(Outputs, Inputs, resultTy, loc, rewriter); + duplicateBinary(Outputs, Inputs, resultTy, rewriter); else if (auto postOpType = llvm::dyn_cast(postOp)) - duplicateBinary(Outputs, Inputs, resultTy, loc, rewriter); + duplicateBinary(Outputs, Inputs, resultTy, 
rewriter); else if (auto postOpType = llvm::dyn_cast(postOp)) - duplicateBinary(Outputs, Inputs, resultTy, loc, rewriter); + duplicateBinary(Outputs, Inputs, resultTy, rewriter); llvm::outs() << "post op creation and deletion done \n"; lastInput = postOp->getResult(0); } // Concatenate the two halves back together on N axis auto newop = rewriter.create( - loc, target_dim, Outputs); + Outputs.back().getLoc(), target_dim, Outputs); llvm::outs() << "created concat \n"; auto replaced_op = postOps.size() ? postOps.back() : op; if (postOps.size() > 1) { @@ -326,6 +413,7 @@ LogicalResult splitSingleMM(linalg::MatmulOp& op, for (auto &deleteOp : postOps) deleteOperation(deleteOp); } + deleteOperands(replaced_op); rewriter.replaceOp(replaced_op, newop); } else { SplitMMonK(splites_res, input_tensors, resultTy, loc, rewriter); diff --git a/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir b/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir index 25f3bf685..7e9542933 100644 --- a/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir +++ b/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir @@ -1,27 +1,24 @@ // RUN: gc-opt %s --split-compute-intensive-patterns | FileCheck %s func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x64xbf16>, %arg2: tensor<64xbf16>, %arg3: tensor<64x256xbf16>, %arg4: tensor<256xbf16>) -> tensor<128x256xbf16> { -%cst = arith.constant 0.000000e+00 : bf16 -%0 = tensor.empty() : tensor<128x64xbf16> -%1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x64xbf16>) -> tensor<128x64xbf16> -%2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x64xbf16>) outs(%1 : tensor<128x64xbf16>) -> tensor<128x64xbf16> -%3 = tensor.empty() : tensor<128x64xbf16> -%broadcasted = linalg.broadcast ins(%arg2 : tensor<64xbf16>) outs(%3 : tensor<128x64xbf16>) dimensions = [0] -%4 = tensor.empty() : tensor<128x64xbf16> -%5 = linalg.add ins(%2, %broadcasted : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%4 : tensor<128x64xbf16>) -> tensor<128x64xbf16> -%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xbf16> -%6 = tensor.empty() : tensor<128x64xbf16> -%7 = linalg.max ins(%5, %cst_0 : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%6 : tensor<128x64xbf16>) -> tensor<128x64xbf16> -%cst_1 = arith.constant 0.000000e+00 : bf16 -%8 = tensor.empty() : tensor<128x256xbf16> -%9 = linalg.fill ins(%cst_1 : bf16) outs(%8 : tensor<128x256xbf16>) -> tensor<128x256xbf16> -%10 = linalg.matmul ins(%7, %arg3 : tensor<128x64xbf16>, tensor<64x256xbf16>) outs(%9 : tensor<128x256xbf16>) -> tensor<128x256xbf16> -%11 = tensor.empty() : tensor<128x256xbf16> -%broadcasted_2 = linalg.broadcast ins(%arg4 : tensor<256xbf16>) outs(%11 : tensor<128x256xbf16>) dimensions = [0] -%12 = tensor.empty() : tensor<128x256xbf16> -%13 = linalg.add ins(%10, %broadcasted_2 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%12 : tensor<128x256xbf16>) -> tensor<128x256xbf16> -%cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xbf16> -%14 = tensor.empty() : tensor<128x256xbf16> -%15 = linalg.max ins(%13, %cst_3 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%14 : tensor<128x256xbf16>) -> tensor<128x256xbf16> -return %15 : tensor<128x256xbf16> + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<128x64xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x64xbf16>) -> tensor<128x64xbf16> + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x64xbf16>) outs(%1 : 
tensor<128x64xbf16>) -> tensor<128x64xbf16>
+  %3 = tensor.empty() : tensor<128x64xbf16>
+  %broadcasted = linalg.broadcast ins(%arg2 : tensor<64xbf16>) outs(%3 : tensor<128x64xbf16>) dimensions = [0]
+  %4 = tensor.empty() : tensor<128x64xbf16>
+  %5 = linalg.add ins(%2, %broadcasted : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%4 : tensor<128x64xbf16>) -> tensor<128x64xbf16>
+  %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xbf16>
+  %6 = tensor.empty() : tensor<128x64xbf16>
+  %7 = linalg.max ins(%5, %cst_0 : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%6 : tensor<128x64xbf16>) -> tensor<128x64xbf16>
+  %cst_1 = arith.constant 0.000000e+00 : bf16
+  %8 = tensor.empty() : tensor<128x256xbf16>
+  %9 = linalg.fill ins(%cst_1 : bf16) outs(%8 : tensor<128x256xbf16>) -> tensor<128x256xbf16>
+  %10 = linalg.matmul ins(%7, %arg3 : tensor<128x64xbf16>, tensor<64x256xbf16>) outs(%9 : tensor<128x256xbf16>) -> tensor<128x256xbf16>
+  %11 = tensor.empty() : tensor<128x256xbf16>
+  %broadcasted_2 = linalg.broadcast ins(%arg4 : tensor<256xbf16>) outs(%11 : tensor<128x256xbf16>) dimensions = [0]
+  %12 = tensor.empty() : tensor<128x256xbf16>
+  %13 = linalg.add ins(%10, %broadcasted_2 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%12 : tensor<128x256xbf16>) -> tensor<128x256xbf16>
+  return %13 : tensor<128x256xbf16>
 }

From 62fc8bffe548737655d5b4212b19beb2bdea8051 Mon Sep 17 00:00:00 2001
From: Zhang Yan
Date: Tue, 25 Jun 2024 14:59:49 +0800
Subject: [PATCH 03/10] fix color and recursive logic

---
 .../SplitComputeIntensivePatterns.cpp         | 34 +++++++++++++++----
 .../split-compute-intensive-patterns.mlir     | 25 ++++++++++++-
 2 files changed, 51 insertions(+), 8 deletions(-)

diff --git a/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp b/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp
index f5688e78b..441f61eb2 100644
--- a/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp
+++ b/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp
@@ -35,7 +35,7 @@ namespace gc {
 #include "gc/Transforms/Passes.h.inc"
 } // namespace gc
 
-size_t NUM_OF_NUMA = 2;
+size_t NUM_OF_NUMA = 3;
 size_t SUPPORTED_RANK = 2;
 
 void printValueType(Value value) {
@@ -149,11 +149,14 @@ void SplitMMonN(SmallVector<Value>& outputs, SmallVector<Value>& inputs, TensorT
       loc, ArrayRef<int64_t> {M, weight.getType().cast<TensorType>().getDimSize(1)}, resultTy.getElementType());
     Value tensor =
       rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);
-    outputs.push_back(rewriter.create<linalg::MatmulOp>(
+    auto newMM = rewriter.create<linalg::MatmulOp>(
       /*location=*/loc,
       /*resultTensorTypes=*/tensor.getType().cast<TensorType>(),
       /*inputs=*/ValueRange{inputs[0], weight},
-      /*outputs=*/tensor)->getResult(0));
+      /*outputs=*/tensor);
+    mlir::BoolAttr boolAttr = rewriter.getBoolAttr(true);
+    newMM->setAttr("splited", boolAttr);
+    outputs.push_back(newMM->getResult(0));
   }
 }
 
@@ -179,11 +182,14 @@ void SplitMMonK(SmallVector<Value>& outputs, SmallVector<Value>& inputs, TensorT
       loc, resultTy.getShape(), resultTy.getElementType());
     Value tensor =
       rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);
-    outputs.push_back(rewriter.create<linalg::MatmulOp>(
+    auto newMM = rewriter.create<linalg::MatmulOp>(
       /*location=*/loc,
       /*resultTensorTypes=*/tensor.getType().cast<TensorType>(),
       /*inputs=*/ValueRange{data, weight},
-      /*outputs=*/tensor)->getResult(0));
+      /*outputs=*/tensor);
+    mlir::BoolAttr boolAttr = rewriter.getBoolAttr(true);
+    newMM->setAttr("splited", boolAttr);
+    outputs.push_back(newMM->getResult(0));
   }
 }
 
@@ -203,8 +209,10 @@ bool isSupportedPostOp(Operation *op) {
 void getUnOps(Operation *op, SmallVectorImpl<Operation *> &postOps) {
   for (auto user : op->getUsers()) {
     if
(isSupportedPostOp(user)) postOps.push_back(user); - // Recursively search for unary ops + if (isa(user)) return; + // Recursively search for unary ops, unless it's a matmul op getUnOps(user, postOps); + // } } } @@ -296,7 +305,12 @@ Value addN(Value& initTensor, SmallVector& ins, TensorType& resultTy, Loc LogicalResult splitSingleMM(linalg::MatmulOp& op, PatternRewriter &rewriter) { - SmallVector postOps; + // rewriter.updateRootInPlace(op, [&]() { + // mlir::BoolAttr boolAttr = rewriter.getBoolAttr(true); + // op->setAttr("splited", boolAttr); + // }); + + SmallVector postOps = {}; getUnOps(op, postOps); auto loc = op->getLoc(); auto resultTy = dyn_cast(op->getResultTypes().front()); @@ -321,6 +335,7 @@ LogicalResult splitSingleMM(linalg::MatmulOp& op, if (splites_res.size() != NUM_OF_NUMA) return failure(); SmallVector Outputs = splites_res; auto lastInput = op->getResult(0); + llvm::outs() << "postOps num: " << postOps.size() << "\n"; for (auto postOp : postOps) { llvm::outs() << "Operation name: " << postOp->getName().getStringRef() << "\n"; auto opInputs = postOp->getOperands().drop_back(); @@ -401,6 +416,8 @@ LogicalResult splitSingleMM(linalg::MatmulOp& op, duplicateBinary(Outputs, Inputs, resultTy, rewriter); llvm::outs() << "post op creation and deletion done \n"; lastInput = postOp->getResult(0); + if(auto lastop = lastInput.getDefiningOp()) + std::cout << "lastInput operation name: " << lastop->getName().getStringRef().str() << std::endl; } // Concatenate the two halves back together on N axis auto newop = rewriter.create( @@ -415,6 +432,8 @@ LogicalResult splitSingleMM(linalg::MatmulOp& op, } deleteOperands(replaced_op); rewriter.replaceOp(replaced_op, newop); + postOps = {}; + llvm::outs() << "after duplicate, postOps num: " << postOps.size() << "\n"; } else { SplitMMonK(splites_res, input_tensors, resultTy, loc, rewriter); if (splites_res.size() != NUM_OF_NUMA) return failure(); @@ -430,6 +449,8 @@ LogicalResult splitSingleMM(linalg::MatmulOp& op, // Replace the original operation with the new linalg.map operation rewriter.replaceOp(op, newop); } + llvm::outs() << "exit duplicate mm.\n"; + llvm::outs() << "==================================================\n"; return success(); } diff --git a/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir b/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir index 7e9542933..09cb6e04e 100644 --- a/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir +++ b/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir @@ -1,4 +1,24 @@ // RUN: gc-opt %s --split-compute-intensive-patterns | FileCheck %s +func.func @basic_mlp(%in: tensor<128x512xbf16>, + %weight: tensor<512x256xbf16>, + %offset: tensor<128x256xbf16>, + %scale: tensor<128x256xbf16>, + %weight2: tensor<256x1024xbf16>) -> tensor<128x1024xbf16> { + %0 = tensor.empty() : tensor<128x256xbf16> + %cst = arith.constant 0.000000e+00 : bf16 + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %2 = linalg.matmul ins(%in, %weight : tensor<128x512xbf16>, tensor<512x256xbf16>) outs(%1 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %3 = tensor.empty() : tensor<128x256xbf16> + %4 = linalg.add ins(%2, %offset : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%3 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %5 = tensor.empty() : tensor<128x256xbf16> + %6 = linalg.mul ins(%4, %scale : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%5 : tensor<128x256xbf16>) -> 
tensor<128x256xbf16> + %9 = tensor.empty() : tensor<128x256xbf16> + %10 = linalg.max ins(%6, %1 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%9 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %11 = tensor.empty() : tensor<128x1024xbf16> + %12 = linalg.fill ins(%cst : bf16) outs(%11 : tensor<128x1024xbf16>) -> tensor<128x1024xbf16> + %13 = linalg.matmul ins(%10, %weight2 : tensor<128x256xbf16>, tensor<256x1024xbf16>) outs(%12 : tensor<128x1024xbf16>) -> tensor<128x1024xbf16> + return %13 : tensor<128x1024xbf16> +} func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x64xbf16>, %arg2: tensor<64xbf16>, %arg3: tensor<64x256xbf16>, %arg4: tensor<256xbf16>) -> tensor<128x256xbf16> { %cst = arith.constant 0.000000e+00 : bf16 @@ -20,5 +40,8 @@ func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x64xbf16>, %arg2: t %broadcasted_2 = linalg.broadcast ins(%arg4 : tensor<256xbf16>) outs(%11 : tensor<128x256xbf16>) dimensions = [0] %12 = tensor.empty() : tensor<128x256xbf16> %13 = linalg.add ins(%10, %broadcasted_2 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%12 : tensor<128x256xbf16>) -> tensor<128x256xbf16> - return %13 : tensor<128x256xbf16> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xbf16> + %14 = tensor.empty() : tensor<128x256xbf16> + %15 = linalg.max ins(%13, %cst_3 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%14 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %15 : tensor<128x256xbf16> } From 569b710e2dcf87ffe96c71883a91d8caa4c3a540 Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Tue, 25 Jun 2024 16:05:33 +0800 Subject: [PATCH 04/10] disable transpose --- .../SplitComputeIntensivePatterns.cpp | 44 ++++++++++++++++--- .../split-compute-intensive-patterns.mlir | 19 ++++++++ 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp b/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp index 441f61eb2..37293ef92 100644 --- a/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp +++ b/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp @@ -200,17 +200,18 @@ bool isSupportedPostOp(Operation *op) { return false; // Get the inputs and outputs of the linalg operation - bool ismax = isa(op); - bool isadd = isa(op); - bool ismul = isa(op); - return ismax || isadd || ismul; + bool isMax = isa(op); + bool isAdd = isa(op); + bool isMul = isa(op); + // bool isTranspose = isa(op); + return isMax || isAdd || isMul; } // Helper function to get all post ops following the given operation void getUnOps(Operation *op, SmallVectorImpl &postOps) { for (auto user : op->getUsers()) { if (isSupportedPostOp(user)) postOps.push_back(user); - if (isa(user)) return; + if (isa(user)) return; // Recursively search for unary ops, unless it's a matmul op getUnOps(user, postOps); // } @@ -231,6 +232,33 @@ void duplicateBinary(SmallVector& outputs,std::vector> } } +void duplicateTranspose(SmallVector& outputs,std::vector>& inputs, linalg::TransposeOp transposeOp, TensorType& resultTy, PatternRewriter &rewriter) { + ArrayRef permutation = transposeOp.getPermutation(); + if (permutation.size() != SUPPORTED_RANK) {llvm::outs() << "unsupported rank\n"; return;} + for (int i = 0; i < NUM_OF_NUMA; ++i) { + auto loc = inputs[i][0].getLoc(); + TensorType type = inputs[i][0].getType().cast(); + const auto &inputShape = type.getShape(); + SmallVector transShape{inputShape[permutation[0]], inputShape[permutation[1]]}; + auto transTy = type.clone(transShape); + llvm::outs() << "TransTy shape: ["; + for (int64_t dim : 
transTy.getShape()) { + llvm::outs() << dim << " "; + } + llvm::outs() << "]\n"; + Value zero = rewriter.create( + loc, rewriter.getZeroAttr(transTy.getElementType())); + Value empty = rewriter.create( + loc, transTy.getShape(), transTy.getElementType()); + Value tensor = + rewriter.create(loc, zero, empty).getResult(0); + auto tmpOp = rewriter.create(loc, inputs[i][0], tensor, permutation); + for (auto result : tmpOp->getResults()) { + outputs.push_back(result); + } + } +} + void deleteOperation(Operation *op) { // Step 1: Ensure the operation exists if (!op) @@ -263,10 +291,12 @@ void deleteOperation(Operation *op) { void deleteOperands(Operation *op) { for (auto operand : op->getOperands()) { + // llvm::outs() << "operands: " << operand << "\n"; if (!operand) continue; if (operand.use_empty()) {continue;} // Skip if operand has no uses if (auto definingOp = operand.getDefiningOp()) { if (definingOp->hasOneUse()) { + deleteOperands(definingOp); definingOp->dropAllUses(); definingOp->dropAllReferences(); definingOp->erase(); @@ -414,6 +444,10 @@ LogicalResult splitSingleMM(linalg::MatmulOp& op, duplicateBinary(Outputs, Inputs, resultTy, rewriter); else if (auto postOpType = llvm::dyn_cast(postOp)) duplicateBinary(Outputs, Inputs, resultTy, rewriter); + // else if (auto transOp = llvm::dyn_cast(postOp)) { + // duplicateTranspose(Outputs, Inputs, transOp, resultTy, rewriter); + // target_dim ^= 0x1; + // } llvm::outs() << "post op creation and deletion done \n"; lastInput = postOp->getResult(0); if(auto lastop = lastInput.getDefiningOp()) diff --git a/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir b/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir index 09cb6e04e..d68eade46 100644 --- a/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir +++ b/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir @@ -45,3 +45,22 @@ func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x64xbf16>, %arg2: t %15 = linalg.max ins(%13, %cst_3 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%14 : tensor<128x256xbf16>) -> tensor<128x256xbf16> return %15 : tensor<128x256xbf16> } + +func.func @mlp_transpose_a_b(%arg0: tensor<512x128xbf16>, %arg1: tensor<256x512xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<256x128xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<256x128xbf16>) -> tensor<256x128xbf16> + %2 = linalg.matmul ins(%arg1, %arg0 : tensor<256x512xbf16>, tensor<512x128xbf16>) outs(%1 : tensor<256x128xbf16>) -> tensor<256x128xbf16> + %cst_0 = arith.constant 0.000000e+00 : bf16 + %3 = tensor.empty() : tensor<128x256xbf16> + %4 = linalg.fill ins(%cst_0 : bf16) outs(%3 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %transposed = linalg.transpose ins(%2 : tensor<256x128xbf16>) outs(%4 : tensor<128x256xbf16>) permutation = [1, 0] + %5 = tensor.empty() : tensor<128x256xbf16> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<256xbf16>) outs(%5 : tensor<128x256xbf16>) dimensions = [0] + %6 = tensor.empty() : tensor<128x256xbf16> + %7 = linalg.add ins(%transposed, %broadcasted : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%6 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x256xbf16> + %8 = tensor.empty() : tensor<128x256xbf16> + %9 = linalg.max ins(%7, %cst_1 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%8 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + 
return %9 : tensor<128x256xbf16> +} From e0da29ef5c7544dd154f5b7f09b48d9bfdef811c Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Wed, 26 Jun 2024 15:53:19 +0800 Subject: [PATCH 05/10] add mm_trans --- .../SplitComputeIntensivePatterns.cpp | 145 ++++------- .../split-compute-intensive-patterns.mlir | 230 +++++++++++++----- 2 files changed, 221 insertions(+), 154 deletions(-) diff --git a/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp b/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp index 37293ef92..519528782 100644 --- a/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp +++ b/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp @@ -131,36 +131,42 @@ void splitBroadcast(SmallVector& outputs, linalg::BroadcastOp broadcastOp } } -void SplitMMonN(SmallVector& outputs, SmallVector& inputs, TensorType& resultTy, Location& loc, PatternRewriter &rewriter) { +void SplitMMonN(Operation* op, SmallVector& outputs, SmallVector& inputs, TensorType& resultTy, int64_t target_dim, Location& loc, PatternRewriter &rewriter) { /*Split on N axis*/ std::cout << "split on N" << std::endl; int64_t M = inputs[0].getType().cast().getDimSize(0); - int64_t N = inputs[1].getType().cast().getDimSize(1); + int64_t N = inputs[1].getType().cast().getDimSize(target_dim); int64_t K = inputs[0].getType().cast().getDimSize(1); SmallVector splited_weights; - getSplitedTensors(splited_weights, inputs[1], /*target_dim*/1, rewriter); + getSplitedTensors(splited_weights, inputs[1], target_dim, rewriter); if (splited_weights.size() != NUM_OF_NUMA) return; for (Value weight : splited_weights) { Value zero = rewriter.create( loc, rewriter.getZeroAttr(resultTy.getElementType())); - std::cout << "weight.getType().cast().getDimSize(1): " << weight.getType().cast().getDimSize(1) << std::endl; + std::cout << "weight.getType().cast().getDimSize(1): " << weight.getType().cast().getDimSize(target_dim) << std::endl; Value empty = rewriter.create( - loc, ArrayRef {M, weight.getType().cast().getDimSize(1)}, resultTy.getElementType()); + loc, ArrayRef {M, weight.getType().cast().getDimSize(target_dim)}, resultTy.getElementType()); Value tensor = rewriter.create(loc, zero, empty).getResult(0); - auto newMM = rewriter.create( - /*location=*/loc, - /*resultTensorTypes=*/tensor.getType().cast(), - /*inputs=*/ValueRange{inputs[0], weight}, - /*outputs=*/tensor); + auto newMM = isa(op) ? 
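        // Keep the op kind when re-creating the per-chunk matmuls: a
        // matmul_transpose_b source is rebuilt as matmul_transpose_b chunks
        // (its weight was sliced along dim 0 above), a plain matmul as
        // matmul chunks.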
+ rewriter.create( + /*location=*/loc, + /*resultTensorTypes=*/tensor.getType().cast(), + /*inputs=*/ValueRange{inputs[0], weight}, + /*outputs=*/tensor) : + rewriter.create( + /*location=*/loc, + /*resultTensorTypes=*/tensor.getType().cast(), + /*inputs=*/ValueRange{inputs[0], weight}, + /*outputs=*/tensor); mlir::BoolAttr boolAttr = rewriter.getBoolAttr(true); newMM->setAttr("splited", boolAttr); outputs.push_back(newMM->getResult(0)); } } -void SplitMMonK(SmallVector& outputs, SmallVector& inputs, TensorType& resultTy, Location& loc, PatternRewriter &rewriter) { +void SplitMMonK(Operation* op, SmallVector& outputs, SmallVector& inputs, TensorType& resultTy, Location& loc, PatternRewriter &rewriter) { /*Split on K axis*/ std::cout << "split on K" << std::endl; int64_t M = inputs[0].getType().cast().getDimSize(0); @@ -211,7 +217,7 @@ bool isSupportedPostOp(Operation *op) { void getUnOps(Operation *op, SmallVectorImpl &postOps) { for (auto user : op->getUsers()) { if (isSupportedPostOp(user)) postOps.push_back(user); - if (isa(user)) return; + if (isa(user)) return; // Recursively search for unary ops, unless it's a matmul op getUnOps(user, postOps); // } @@ -333,18 +339,13 @@ Value addN(Value& initTensor, SmallVector& ins, TensorType& resultTy, Loc return genericOp.getResults().front();; } -LogicalResult splitSingleMM(linalg::MatmulOp& op, +LogicalResult splitSingleMM(Operation* op, PatternRewriter &rewriter) { - // rewriter.updateRootInPlace(op, [&]() { - // mlir::BoolAttr boolAttr = rewriter.getBoolAttr(true); - // op->setAttr("splited", boolAttr); - // }); - SmallVector postOps = {}; getUnOps(op, postOps); auto loc = op->getLoc(); auto resultTy = dyn_cast(op->getResultTypes().front()); - auto input_operands = op.getInputs(); + auto input_operands = op->getOperands().drop_back(); SmallVector input_tensors; for (Value operand : input_operands) { if (!operand.getType().isa()) { @@ -352,16 +353,17 @@ LogicalResult splitSingleMM(linalg::MatmulOp& op, } input_tensors.push_back(operand); } - + bool istransB = isa(op); + llvm::outs() << "is trans B\n"; int64_t M = input_tensors[0].getType().cast().getDimSize(0); - int64_t N = input_tensors[1].getType().cast().getDimSize(1); + int64_t N = input_tensors[1].getType().cast().getDimSize(istransB ? 0 : 1); int64_t K = input_tensors[0].getType().cast().getDimSize(1); std::cout << "M: " << M << ", N: " << N << ", K: " << K << std::endl; int64_t target_dim = N / K >= 2 ? 
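      // Both arms of this ternary are 1, so target_dim is always the N axis
      // here and the K-split branch below is currently unreachable.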
1 : 1; SmallVector splites_res; if (target_dim == 1) { - SplitMMonN(splites_res, input_tensors, resultTy, loc, rewriter); + SplitMMonN(op, splites_res, input_tensors, resultTy, target_dim ^ istransB, loc, rewriter); if (splites_res.size() != NUM_OF_NUMA) return failure(); SmallVector Outputs = splites_res; auto lastInput = op->getResult(0); @@ -390,7 +392,7 @@ LogicalResult splitSingleMM(linalg::MatmulOp& op, if (auto fillOp = dyn_cast(definingOp)) { llvm::outs() << "is fill \n"; SmallVector splited_inputs; - getSplitedTensors(splited_inputs, input, /*target_dim*/1, rewriter); + getSplitedTensors(splited_inputs, input, target_dim, rewriter); int i = 0; for (const auto &splited_input : splited_inputs) { Inputs[i].push_back(splited_input); @@ -401,7 +403,7 @@ LogicalResult splitSingleMM(linalg::MatmulOp& op, } else if (auto broadcastOp = dyn_cast(definingOp)){ llvm::outs() << "is broadcast \n"; SmallVector splited_inputs; - splitBroadcast(splited_inputs, broadcastOp, /*target_dim*/1, rewriter); + splitBroadcast(splited_inputs, broadcastOp, target_dim, rewriter); llvm::outs() << "inputs[0].size: " << Inputs[0].size() <<" \n"; int i = 0; for (const auto &splited_input : splited_inputs) { @@ -415,7 +417,7 @@ LogicalResult splitSingleMM(linalg::MatmulOp& op, auto newConstantOp = rewriter.create( constantOp.getLoc(), constantOp.getType(), constantOp.getValue()); SmallVector splited_inputs; - getSplitedTensors(splited_inputs, newConstantOp, /*target_dim*/1, rewriter); + getSplitedTensors(splited_inputs, newConstantOp, target_dim, rewriter); int i = 0; for (const auto &splited_input : splited_inputs) { Inputs[i].push_back(splited_input); @@ -428,7 +430,7 @@ LogicalResult splitSingleMM(linalg::MatmulOp& op, } else { llvm::outs() << "doesnot match anything \n"; SmallVector splited_inputs; - getSplitedTensors(splited_inputs, input, /*target_dim*/1, rewriter); + getSplitedTensors(splited_inputs, input, target_dim, rewriter); llvm::outs() << "inputs[0].size: " << Inputs[0].size() <<" \n"; int i = 0; for (const auto &splited_input : splited_inputs) { @@ -469,7 +471,7 @@ LogicalResult splitSingleMM(linalg::MatmulOp& op, postOps = {}; llvm::outs() << "after duplicate, postOps num: " << postOps.size() << "\n"; } else { - SplitMMonK(splites_res, input_tensors, resultTy, loc, rewriter); + SplitMMonK(op, splites_res, input_tensors, resultTy, loc, rewriter); if (splites_res.size() != NUM_OF_NUMA) return failure(); // Add the two halves back together // Create linalg.map operation @@ -488,61 +490,7 @@ LogicalResult splitSingleMM(linalg::MatmulOp& op, return success(); } -LogicalResult splitSingleMMwithUnary(linalg::MatmulOp& op, - PatternRewriter &rewriter) { - auto loc = op->getLoc(); - auto resultTy = dyn_cast(op->getResultTypes().front()); - auto input_operands = op.getInputs(); - SmallVector input_tensors; - for (Value operand : input_operands) { - if (!operand.getType().isa()) { - continue; - } - input_tensors.push_back(operand); - } - - int64_t M = input_tensors[0].getType().cast().getDimSize(0); - int64_t N = input_tensors[1].getType().cast().getDimSize(1); - int64_t K = input_tensors[0].getType().cast().getDimSize(1); - std::cout << "M: " << M << ", N: " << N << ", K: " << K << std::endl; - - int64_t target_dim = N / K >= 2 ? 
1 : 0; - SmallVector splites_res; - if (target_dim == 1) { - SplitMMonN(splites_res, input_tensors, resultTy, loc, rewriter); - if (splites_res.size() != NUM_OF_NUMA) return failure(); - - - // Concatenate the two halves back together on N axis - auto newop = rewriter.create( - loc, target_dim, splites_res); - rewriter.replaceOp(op, newop); - } else { - SplitMMonK(splites_res, input_tensors, resultTy, loc, rewriter); - if (splites_res.size() != NUM_OF_NUMA) return failure(); - // Add the two halves back together - // Create linalg.map operation - Value zero = rewriter.create( - loc, rewriter.getZeroAttr(resultTy.getElementType())); - Value empty = rewriter.create( - loc, resultTy.getShape(), resultTy.getElementType()); - Value initTensor = - rewriter.create(loc, zero, empty).getResult(0); - auto newop = rewriter.create( - loc, resultTy, splites_res, ValueRange{initTensor}); - - // Replace the original operation with the new linalg.map operation - rewriter.replaceOp(op, newop); - } - return success(); -} - -LogicalResult splitMLP(linalg::MatmulOp& op, - PatternRewriter &rewriter) { - return success(); -} - -class SplitComputeIntensivePatternsRewriter +class SplitMatmulRewriter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -551,21 +499,21 @@ class SplitComputeIntensivePatternsRewriter // Check if the operation has already been processed if (op->hasAttr("splited")) return failure(); - // Ensure the operation is followed by relu and another matmul. - // auto nextOp = op->getNextNode(); - // while (isa(nextOp) || isa(nextOp)) nextOp = nextOp->getNextNode(); - // std::cout << !isa(nextOp) << std::endl; - // need to break when encounters computational op - // if (!nextOp || !isa(nextOp)) return splitSingleMM(op, rewriter); - // auto reluOp = cast(nextOp); - // auto nextNextOp = reluOp->getNextNode(); - // while (isa(nextNextOp) || isa(nextNextOp)) nextNextOp = nextNextOp->getNextNode(); - // // need to break when encounters binary op - // if (!nextNextOp || !isa(nextNextOp)) - // return splitSingleMMwithUnary(op, rewriter); - // auto nextMatmulOp = cast(nextNextOp); - // return splitMLP(op, rewriter); + } +}; + +class SplitMatmulTransposeBRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(linalg::MatmulTransposeBOp op, + PatternRewriter &rewriter) const final { + // Check if the operation has already been processed + llvm::outs() << "get into mm transpose b\n"; + if (op->hasAttr("splited")) + return failure(); + return splitSingleMM(op, rewriter); } }; @@ -577,11 +525,12 @@ class SplitComputeIntensivePatterns SplitComputeIntensivePatterns>::SplitComputeIntensivePatternsBase; void runOnOperation() final { RewritePatternSet patterns(&getContext()); - patterns.insert(&getContext()); + patterns.insert(&getContext()); + patterns.insert(&getContext()); FrozenRewritePatternSet patternSet(std::move(patterns)); SmallVector ops; getOperation()->walk([&](Operation *op) { - if (isa(op)) + if (isa(op)) ops.push_back(op); }); GreedyRewriteConfig config; diff --git a/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir b/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir index d68eade46..5595b166e 100644 --- a/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir +++ b/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir @@ -1,66 +1,184 @@ // RUN: gc-opt %s --split-compute-intensive-patterns | FileCheck %s -func.func 
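The rewritten test file below still pipes the pass output into FileCheck, but the revision shown carries no CHECK directives, so the RUN line only verifies that the pass does not crash. A hedged sketch of the kind of anchor that would make it a real test (the count and shapes depend on NUM_OF_NUMA and the chosen split axis):

// CHECK-LABEL: func.func @mlp
// CHECK-COUNT-2: linalg.matmul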
@basic_mlp(%in: tensor<128x512xbf16>, - %weight: tensor<512x256xbf16>, - %offset: tensor<128x256xbf16>, - %scale: tensor<128x256xbf16>, - %weight2: tensor<256x1024xbf16>) -> tensor<128x1024xbf16> { - %0 = tensor.empty() : tensor<128x256xbf16> + +func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x64xbf16>, %arg2: tensor<64xbf16>, %arg3: tensor<64x256xbf16>, %arg4: tensor<256xbf16>) -> tensor<128x256xbf16> { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<128x64xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x64xbf16>) -> tensor<128x64xbf16> + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x64xbf16>) outs(%1 : tensor<128x64xbf16>) -> tensor<128x64xbf16> + %3 = tensor.empty() : tensor<128x64xbf16> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<64xbf16>) outs(%3 : tensor<128x64xbf16>) dimensions = [0] + %4 = tensor.empty() : tensor<128x64xbf16> + %5 = linalg.add ins(%2, %broadcasted : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%4 : tensor<128x64xbf16>) -> tensor<128x64xbf16> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xbf16> + %6 = tensor.empty() : tensor<128x64xbf16> + %7 = linalg.max ins(%5, %cst_0 : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%6 : tensor<128x64xbf16>) -> tensor<128x64xbf16> + %cst_1 = arith.constant 0.000000e+00 : bf16 + %8 = tensor.empty() : tensor<128x256xbf16> + %9 = linalg.fill ins(%cst_1 : bf16) outs(%8 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %10 = linalg.matmul ins(%7, %arg3 : tensor<128x64xbf16>, tensor<64x256xbf16>) outs(%9 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %11 = tensor.empty() : tensor<128x256xbf16> + %broadcasted_2 = linalg.broadcast ins(%arg4 : tensor<256xbf16>) outs(%11 : tensor<128x256xbf16>) dimensions = [0] + %12 = tensor.empty() : tensor<128x256xbf16> + %13 = linalg.add ins(%10, %broadcasted_2 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%12 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xbf16> + %14 = tensor.empty() : tensor<128x256xbf16> + %15 = linalg.max ins(%13, %cst_3 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%14 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %15 : tensor<128x256xbf16> +} + +func.func @mlp_transpose_a(%arg0: tensor<512x128xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<128x256xbf16> %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x256xbf16>) -> tensor<128x256xbf16> - %2 = linalg.matmul ins(%in, %weight : tensor<128x512xbf16>, tensor<512x256xbf16>) outs(%1 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<512x128xbf16>, tensor<512x256xbf16>) outs(%1 : tensor<128x256xbf16>) -> tensor<128x256xbf16> %3 = tensor.empty() : tensor<128x256xbf16> - %4 = linalg.add ins(%2, %offset : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%3 : tensor<128x256xbf16>) -> tensor<128x256xbf16> - %5 = tensor.empty() : tensor<128x256xbf16> - %6 = linalg.mul ins(%4, %scale : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%5 : tensor<128x256xbf16>) -> tensor<128x256xbf16> - %9 = tensor.empty() : tensor<128x256xbf16> - %10 = linalg.max ins(%6, %1 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%9 : tensor<128x256xbf16>) -> tensor<128x256xbf16> - %11 = tensor.empty() : tensor<128x1024xbf16> - %12 = linalg.fill ins(%cst : bf16) outs(%11 : tensor<128x1024xbf16>) -> 
tensor<128x1024xbf16> - %13 = linalg.matmul ins(%10, %weight2 : tensor<128x256xbf16>, tensor<256x1024xbf16>) outs(%12 : tensor<128x1024xbf16>) -> tensor<128x1024xbf16> - return %13 : tensor<128x1024xbf16> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<256xbf16>) outs(%3 : tensor<128x256xbf16>) dimensions = [0] + %4 = tensor.empty() : tensor<128x256xbf16> + %5 = linalg.add ins(%2, %broadcasted : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%4 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xbf16> + %6 = tensor.empty() : tensor<128x256xbf16> + %7 = linalg.max ins(%5, %cst_0 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%6 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %7 : tensor<128x256xbf16> } -func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x64xbf16>, %arg2: tensor<64xbf16>, %arg3: tensor<64x256xbf16>, %arg4: tensor<256xbf16>) -> tensor<128x256xbf16> { - %cst = arith.constant 0.000000e+00 : bf16 - %0 = tensor.empty() : tensor<128x64xbf16> - %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x64xbf16>) -> tensor<128x64xbf16> - %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x64xbf16>) outs(%1 : tensor<128x64xbf16>) -> tensor<128x64xbf16> - %3 = tensor.empty() : tensor<128x64xbf16> - %broadcasted = linalg.broadcast ins(%arg2 : tensor<64xbf16>) outs(%3 : tensor<128x64xbf16>) dimensions = [0] - %4 = tensor.empty() : tensor<128x64xbf16> - %5 = linalg.add ins(%2, %broadcasted : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%4 : tensor<128x64xbf16>) -> tensor<128x64xbf16> - %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xbf16> - %6 = tensor.empty() : tensor<128x64xbf16> - %7 = linalg.max ins(%5, %cst_0 : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%6 : tensor<128x64xbf16>) -> tensor<128x64xbf16> - %cst_1 = arith.constant 0.000000e+00 : bf16 - %8 = tensor.empty() : tensor<128x256xbf16> - %9 = linalg.fill ins(%cst_1 : bf16) outs(%8 : tensor<128x256xbf16>) -> tensor<128x256xbf16> - %10 = linalg.matmul ins(%7, %arg3 : tensor<128x64xbf16>, tensor<64x256xbf16>) outs(%9 : tensor<128x256xbf16>) -> tensor<128x256xbf16> - %11 = tensor.empty() : tensor<128x256xbf16> - %broadcasted_2 = linalg.broadcast ins(%arg4 : tensor<256xbf16>) outs(%11 : tensor<128x256xbf16>) dimensions = [0] - %12 = tensor.empty() : tensor<128x256xbf16> - %13 = linalg.add ins(%10, %broadcasted_2 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%12 : tensor<128x256xbf16>) -> tensor<128x256xbf16> - %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xbf16> - %14 = tensor.empty() : tensor<128x256xbf16> - %15 = linalg.max ins(%13, %cst_3 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%14 : tensor<128x256xbf16>) -> tensor<128x256xbf16> - return %15 : tensor<128x256xbf16> +func.func @mlp_transpose_b(%arg0: tensor<128x512xbf16>, %arg1: tensor<256x512xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<128x256xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<256x512xbf16>) outs(%1 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %3 = tensor.empty() : tensor<128x256xbf16> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<256xbf16>) outs(%3 : tensor<128x256xbf16>) dimensions = [0] + %4 = tensor.empty() : tensor<128x256xbf16> + %5 = linalg.add ins(%2, %broadcasted : 
tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%4 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xbf16> + %6 = tensor.empty() : tensor<128x256xbf16> + %7 = linalg.max ins(%5, %cst_0 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%6 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %7 : tensor<128x256xbf16> } func.func @mlp_transpose_a_b(%arg0: tensor<512x128xbf16>, %arg1: tensor<256x512xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { - %cst = arith.constant 0.000000e+00 : bf16 - %0 = tensor.empty() : tensor<256x128xbf16> - %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<256x128xbf16>) -> tensor<256x128xbf16> - %2 = linalg.matmul ins(%arg1, %arg0 : tensor<256x512xbf16>, tensor<512x128xbf16>) outs(%1 : tensor<256x128xbf16>) -> tensor<256x128xbf16> - %cst_0 = arith.constant 0.000000e+00 : bf16 - %3 = tensor.empty() : tensor<128x256xbf16> - %4 = linalg.fill ins(%cst_0 : bf16) outs(%3 : tensor<128x256xbf16>) -> tensor<128x256xbf16> - %transposed = linalg.transpose ins(%2 : tensor<256x128xbf16>) outs(%4 : tensor<128x256xbf16>) permutation = [1, 0] - %5 = tensor.empty() : tensor<128x256xbf16> - %broadcasted = linalg.broadcast ins(%arg2 : tensor<256xbf16>) outs(%5 : tensor<128x256xbf16>) dimensions = [0] - %6 = tensor.empty() : tensor<128x256xbf16> - %7 = linalg.add ins(%transposed, %broadcasted : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%6 : tensor<128x256xbf16>) -> tensor<128x256xbf16> - %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x256xbf16> - %8 = tensor.empty() : tensor<128x256xbf16> - %9 = linalg.max ins(%7, %cst_1 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%8 : tensor<128x256xbf16>) -> tensor<128x256xbf16> - return %9 : tensor<128x256xbf16> + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<256x128xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<256x128xbf16>) -> tensor<256x128xbf16> + %2 = linalg.matmul ins(%arg1, %arg0 : tensor<256x512xbf16>, tensor<512x128xbf16>) outs(%1 : tensor<256x128xbf16>) -> tensor<256x128xbf16> + %cst_0 = arith.constant 0.000000e+00 : bf16 + %3 = tensor.empty() : tensor<128x256xbf16> + %4 = linalg.fill ins(%cst_0 : bf16) outs(%3 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %transposed = linalg.transpose ins(%2 : tensor<256x128xbf16>) outs(%4 : tensor<128x256xbf16>) permutation = [1, 0] + %5 = tensor.empty() : tensor<128x256xbf16> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<256xbf16>) outs(%5 : tensor<128x256xbf16>) dimensions = [0] + %6 = tensor.empty() : tensor<128x256xbf16> + %7 = linalg.add ins(%transposed, %broadcasted : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%6 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x256xbf16> + %8 = tensor.empty() : tensor<128x256xbf16> + %9 = linalg.max ins(%7, %cst_1 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%8 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %9 : tensor<128x256xbf16> +} + +func.func @llama2_mlp(%arg0: tensor<1x32x4096xbf16>, %arg1: tensor<4096x4096xbf16>, %arg2: tensor<1x32x4096xbf16>, %arg3: tensor<1xf32>, %arg4: tensor<4096xbf16>, %arg5: tensor<11008x4096xbf16>, %arg6: tensor<11008x4096xbf16>, %arg7: tensor<4096x11008xbf16>, %arg8: tensor<1xf32>, %arg9: tensor<4096xbf16>) -> (tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) { + %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<1x32x4096xbf16> into tensor<32x4096xbf16> + %cst = arith.constant 
0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<32x4096xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %2 = linalg.matmul_transpose_b ins(%collapsed, %arg1 : tensor<32x4096xbf16>, tensor<4096x4096xbf16>) outs(%1 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %expanded = tensor.expand_shape %2 [[0, 1], [2]] output_shape [1, 32, 4096] : tensor<32x4096xbf16> into tensor<1x32x4096xbf16> + %3 = tensor.empty() : tensor<1x32x4096xbf16> + %4 = linalg.add ins(%arg2, %expanded : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%3 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %5 = tensor.empty() : tensor<1x32x4096xf32> + %6 = linalg.copy ins(%4 : tensor<1x32x4096xbf16>) outs(%5 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_0 = arith.constant dense<2.000000e+00> : tensor<1x32x4096xf32> + %7 = tensor.empty() : tensor<1x32x4096xf32> + %8 = linalg.powf ins(%6, %cst_0 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%7 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_1 = arith.constant 0.000000e+00 : f32 + %9 = tensor.empty() : tensor<1x32xf32> + %10 = linalg.fill ins(%cst_1 : f32) outs(%9 : tensor<1x32xf32>) -> tensor<1x32xf32> + %reduced = linalg.reduce ins(%8 : tensor<1x32x4096xf32>) outs(%10 : tensor<1x32xf32>) dimensions = [2] + (%in: f32, %init: f32) { + %64 = arith.addf %in, %init : f32 + linalg.yield %64 : f32 + } + %cst_2 = arith.constant dense<4.096000e+03> : tensor<1x32xf32> + %11 = tensor.empty() : tensor<1x32xf32> + %12 = linalg.div ins(%reduced, %cst_2 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%11 : tensor<1x32xf32>) -> tensor<1x32xf32> + %expanded_3 = tensor.expand_shape %12 [[0], [1, 2]] output_shape [1, 32, 1] : tensor<1x32xf32> into tensor<1x32x1xf32> + %13 = tensor.empty() : tensor<1x32x1xf32> + %broadcasted = linalg.broadcast ins(%arg8 : tensor<1xf32>) outs(%13 : tensor<1x32x1xf32>) dimensions = [0, 1] + %14 = tensor.empty() : tensor<1x32x1xf32> + %15 = linalg.add ins(%expanded_3, %broadcasted : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%14 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %cst_4 = arith.constant dense<-5.000000e-01> : tensor<1x32x1xf32> + %16 = tensor.empty() : tensor<1x32x1xf32> + %17 = linalg.powf ins(%15, %cst_4 : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%16 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %collapsed_5 = tensor.collapse_shape %17 [[0], [1, 2]] : tensor<1x32x1xf32> into tensor<1x32xf32> + %18 = tensor.empty() : tensor<1x32x4096xf32> + %broadcasted_6 = linalg.broadcast ins(%collapsed_5 : tensor<1x32xf32>) outs(%18 : tensor<1x32x4096xf32>) dimensions = [2] + %19 = tensor.empty() : tensor<1x32x4096xf32> + %20 = linalg.mul ins(%6, %broadcasted_6 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%19 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %21 = tensor.empty() : tensor<1x32x4096xbf16> + %22 = linalg.copy ins(%20 : tensor<1x32x4096xf32>) outs(%21 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %23 = tensor.empty() : tensor<1x32x4096xbf16> + %broadcasted_7 = linalg.broadcast ins(%arg4 : tensor<4096xbf16>) outs(%23 : tensor<1x32x4096xbf16>) dimensions = [0, 1] + %24 = tensor.empty() : tensor<1x32x4096xbf16> + %25 = linalg.mul ins(%broadcasted_7, %22 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%24 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %collapsed_8 = tensor.collapse_shape %25 [[0, 1], [2]] : tensor<1x32x4096xbf16> into tensor<32x4096xbf16> + %cst_9 = arith.constant 0.000000e+00 : bf16 + %26 = tensor.empty() : 
tensor<32x11008xbf16> + %27 = linalg.fill ins(%cst_9 : bf16) outs(%26 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %28 = linalg.matmul_transpose_b ins(%collapsed_8, %arg5 : tensor<32x4096xbf16>, tensor<11008x4096xbf16>) outs(%27 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %expanded_10 = tensor.expand_shape %28 [[0, 1], [2]] output_shape [1, 32, 11008] : tensor<32x11008xbf16> into tensor<1x32x11008xbf16> + %29 = tensor.empty() : tensor<1x32x11008xbf16> + %30 = linalgx.sigmoid ins(%expanded_10 : tensor<1x32x11008xbf16>) outs(%29 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %31 = tensor.empty() : tensor<1x32x11008xbf16> + %32 = linalg.mul ins(%30, %expanded_10 : tensor<1x32x11008xbf16>, tensor<1x32x11008xbf16>) outs(%31 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %collapsed_11 = tensor.collapse_shape %25 [[0, 1], [2]] : tensor<1x32x4096xbf16> into tensor<32x4096xbf16> + %cst_12 = arith.constant 0.000000e+00 : bf16 + %33 = tensor.empty() : tensor<32x11008xbf16> + %34 = linalg.fill ins(%cst_12 : bf16) outs(%33 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %35 = linalg.matmul_transpose_b ins(%collapsed_11, %arg6 : tensor<32x4096xbf16>, tensor<11008x4096xbf16>) outs(%34 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %expanded_13 = tensor.expand_shape %35 [[0, 1], [2]] output_shape [1, 32, 11008] : tensor<32x11008xbf16> into tensor<1x32x11008xbf16> + %36 = tensor.empty() : tensor<1x32x11008xbf16> + %37 = linalg.mul ins(%32, %expanded_13 : tensor<1x32x11008xbf16>, tensor<1x32x11008xbf16>) outs(%36 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %collapsed_14 = tensor.collapse_shape %37 [[0, 1], [2]] : tensor<1x32x11008xbf16> into tensor<32x11008xbf16> + %cst_15 = arith.constant 0.000000e+00 : bf16 + %38 = tensor.empty() : tensor<32x4096xbf16> + %39 = linalg.fill ins(%cst_15 : bf16) outs(%38 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %40 = linalg.matmul_transpose_b ins(%collapsed_14, %arg7 : tensor<32x11008xbf16>, tensor<4096x11008xbf16>) outs(%39 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %expanded_16 = tensor.expand_shape %40 [[0, 1], [2]] output_shape [1, 32, 4096] : tensor<32x4096xbf16> into tensor<1x32x4096xbf16> + %41 = tensor.empty() : tensor<1x32x4096xbf16> + %42 = linalg.add ins(%4, %expanded_16 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%41 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %43 = tensor.empty() : tensor<1x32x4096xf32> + %44 = linalg.copy ins(%42 : tensor<1x32x4096xbf16>) outs(%43 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_17 = arith.constant dense<2.000000e+00> : tensor<1x32x4096xf32> + %45 = tensor.empty() : tensor<1x32x4096xf32> + %46 = linalg.powf ins(%44, %cst_17 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%45 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_18 = arith.constant 0.000000e+00 : f32 + %47 = tensor.empty() : tensor<1x32xf32> + %48 = linalg.fill ins(%cst_18 : f32) outs(%47 : tensor<1x32xf32>) -> tensor<1x32xf32> + %reduced_19 = linalg.reduce ins(%46 : tensor<1x32x4096xf32>) outs(%48 : tensor<1x32xf32>) dimensions = [2] + (%in: f32, %init: f32) { + %64 = arith.addf %in, %init : f32 + linalg.yield %64 : f32 + } + %cst_20 = arith.constant dense<4.096000e+03> : tensor<1x32xf32> + %49 = tensor.empty() : tensor<1x32xf32> + %50 = linalg.div ins(%reduced_19, %cst_20 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%49 : tensor<1x32xf32>) -> tensor<1x32xf32> + %expanded_21 = tensor.expand_shape %50 [[0], [1, 2]] output_shape [1, 32, 1] : tensor<1x32xf32> 
into tensor<1x32x1xf32> + %51 = tensor.empty() : tensor<1x32x1xf32> + %broadcasted_22 = linalg.broadcast ins(%arg8 : tensor<1xf32>) outs(%51 : tensor<1x32x1xf32>) dimensions = [0, 1] + %52 = tensor.empty() : tensor<1x32x1xf32> + %53 = linalg.add ins(%expanded_21, %broadcasted_22 : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%52 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %cst_23 = arith.constant dense<-5.000000e-01> : tensor<1x32x1xf32> + %54 = tensor.empty() : tensor<1x32x1xf32> + %55 = linalg.powf ins(%53, %cst_23 : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%54 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %collapsed_24 = tensor.collapse_shape %55 [[0], [1, 2]] : tensor<1x32x1xf32> into tensor<1x32xf32> + %56 = tensor.empty() : tensor<1x32x4096xf32> + %broadcasted_25 = linalg.broadcast ins(%collapsed_24 : tensor<1x32xf32>) outs(%56 : tensor<1x32x4096xf32>) dimensions = [2] + %57 = tensor.empty() : tensor<1x32x4096xf32> + %58 = linalg.mul ins(%44, %broadcasted_25 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%57 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %59 = tensor.empty() : tensor<1x32x4096xbf16> + %60 = linalg.copy ins(%58 : tensor<1x32x4096xf32>) outs(%59 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %61 = tensor.empty() : tensor<1x32x4096xbf16> + %broadcasted_26 = linalg.broadcast ins(%arg9 : tensor<4096xbf16>) outs(%61 : tensor<1x32x4096xbf16>) dimensions = [0, 1] + %62 = tensor.empty() : tensor<1x32x4096xbf16> + %63 = linalg.mul ins(%broadcasted_26, %60 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%62 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + return %63, %42 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16> } From 89b7e28424b2c81a8da559a349f843a3a5e538c9 Mon Sep 17 00:00:00 2001 From: Zhang Yan Date: Mon, 1 Jul 2024 13:22:32 +0800 Subject: [PATCH 06/10] enable split compute intensive --- lib/gc/Transforms/Pipeline.cpp | 1 + .../SplitComputeIntensivePatterns.cpp | 2 +- tools/workloads/test.mlir | 27 ++++++++++++++++--- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 56f89e3af..0c6849b72 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -38,6 +38,7 @@ void populateFrontendPasses(mlir::PassManager &pm) { // scf + arith + math + vector + tensor + linalg.brgemm + tensor.pack/unpack void populateTensorPasses(mlir::PassManager &pm) { + pm.addNestedPass(createSplitComputeIntensivePatterns()); // todo: padding propagation pass // todo: layout propagation pass // todo: tensor constant propagation pass diff --git a/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp b/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp index 519528782..290d64994 100644 --- a/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp +++ b/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp @@ -360,7 +360,7 @@ LogicalResult splitSingleMM(Operation* op, int64_t K = input_tensors[0].getType().cast().getDimSize(1); std::cout << "M: " << M << ", N: " << N << ", K: " << K << std::endl; - int64_t target_dim = N / K >= 2 ? 1 : 1; + int64_t target_dim = N / K >= 2 ? 
0 : 0; SmallVector splites_res; if (target_dim == 1) { SplitMMonN(op, splites_res, input_tensors, resultTy, target_dim ^ istransB, loc, rewriter); diff --git a/tools/workloads/test.mlir b/tools/workloads/test.mlir index f1a7eb53b..99feeffc3 100644 --- a/tools/workloads/test.mlir +++ b/tools/workloads/test.mlir @@ -1,4 +1,25 @@ -func.func @main_entry(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x256xbf16>) -> tensor<128x256xbf16> attributes {llvm.emit_c_interface} { - %0 = onednn_graph.matmul %arg0, %arg1 : (tensor<128x512xbf16>, tensor<512x256xbf16>) -> tensor<128x256xbf16> - return %0 : tensor<128x256xbf16> +func.func @main_entry(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x64xbf16>, %arg2: tensor<64xbf16>, %arg3: tensor<64x256xbf16>, %arg4: tensor<256xbf16>) -> tensor<128x256xbf16> attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<128x64xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x64xbf16>) -> tensor<128x64xbf16> + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x64xbf16>) outs(%1 : tensor<128x64xbf16>) -> tensor<128x64xbf16> + %3 = tensor.empty() : tensor<128x64xbf16> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<64xbf16>) outs(%3 : tensor<128x64xbf16>) dimensions = [0] + %4 = tensor.empty() : tensor<128x64xbf16> + %5 = linalg.add ins(%2, %broadcasted : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%4 : tensor<128x64xbf16>) -> tensor<128x64xbf16> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xbf16> + %6 = tensor.empty() : tensor<128x64xbf16> + %7 = linalg.max ins(%5, %cst_0 : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%6 : tensor<128x64xbf16>) -> tensor<128x64xbf16> + %cst_1 = arith.constant 0.000000e+00 : bf16 + %8 = tensor.empty() : tensor<128x256xbf16> + %9 = linalg.fill ins(%cst_1 : bf16) outs(%8 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %10 = linalg.matmul ins(%7, %arg3 : tensor<128x64xbf16>, tensor<64x256xbf16>) outs(%9 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %11 = tensor.empty() : tensor<128x256xbf16> + %broadcasted_2 = linalg.broadcast ins(%arg4 : tensor<256xbf16>) outs(%11 : tensor<128x256xbf16>) dimensions = [0] + %12 = tensor.empty() : tensor<128x256xbf16> + %13 = linalg.add ins(%10, %broadcasted_2 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%12 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xbf16> + %14 = tensor.empty() : tensor<128x256xbf16> + %15 = linalg.max ins(%13, %cst_3 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%14 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %15 : tensor<128x256xbf16> } \ No newline at end of file From 3de1679db3513406b36163c95e7b4d4757c25098 Mon Sep 17 00:00:00 2001 From: ZhangYan Date: Mon, 1 Jul 2024 20:07:12 -0700 Subject: [PATCH 07/10] use ctypes to call libnuma --- tools/utils.py | 41 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/tools/utils.py b/tools/utils.py index 4369f4663..ab8ee708e 100644 --- a/tools/utils.py +++ b/tools/utils.py @@ -12,6 +12,16 @@ from gc_mlir.dialects import arith, func, memref, onednn_graph from op_config import * +# Load libnuma +libnuma = ctypes.CDLL("libnuma.so.1") + +# Define numa_alloc_onnode function +libnuma.numa_alloc_onnode.restype = ctypes.c_void_p +libnuma.numa_alloc_onnode.argtypes = [ctypes.c_size_t, ctypes.c_int] + +# Define numa_free function +libnuma.numa_free.argtypes = [ctypes.c_void_p, ctypes.c_size_t] + 
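The ctypes block above binds the two libnuma entry points the benchmark harness needs. For reference, a hedged C++ rendering of the same calls (assumes <numa.h> is available and the binary links with -lnuma; helper name hypothetical):

#include <numa.h>
#include <cstddef>
#include <cstring>

void *allocOnNode(std::size_t nbytes, int node) {
  if (numa_available() < 0)
    return nullptr;                            // libnuma unusable on this host
  void *buf = numa_alloc_onnode(nbytes, node); // pages bound to `node`
  if (buf)
    std::memset(buf, 0, nbytes); // touching the pages commits them on `node`
  return buf;
}
// Every successful allocation should eventually be paired with
//   numa_free(buf, nbytes);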
 MLIR_TYPE_TO_NUMPY_TYPE = {
     "bf16": ml_dtypes.bfloat16,
     "f32": np.float32,
@@ -116,11 +126,32 @@ def mlir_type(s, ctx):
     return type_mapping[s]
 
 
-def make_tensor(tensor_type):
-    return np.zeros(
-        tensor_type.shape, MLIR_TYPE_TO_NUMPY_TYPE[str(tensor_type.element_type)]
-    )
-
+def make_tensor(tensor_type, numa_node=1):
+    shape = tensor_type.shape
+    element_type = MLIR_TYPE_TO_NUMPY_TYPE[str(tensor_type.element_type)]
+    dtype = np.dtype(element_type)
+    # Total size of the tensor in bytes; a plain int keeps ctypes happy
+    tensor_size = int(np.prod(shape)) * dtype.itemsize
+
+    # Allocate memory on the specified NUMA node
+    buffer_addr = libnuma.numa_alloc_onnode(tensor_size, numa_node)
+    if not buffer_addr:
+        raise MemoryError(f"Failed to allocate memory on NUMA node {numa_node}")
+
+    # View the allocation as raw bytes, then reinterpret it with the real
+    # element type; casting to POINTER(c_float) would read the buffer as
+    # f32 regardless of dtype and overrun it for bf16 tensors.
+    raw = (ctypes.c_char * tensor_size).from_address(buffer_addr)
+    tensor = np.frombuffer(raw, dtype=dtype).reshape(shape)
+
+    # Note: nothing calls libnuma.numa_free(buffer_addr, tensor_size) yet,
+    # so these buffers live until the process exits.
+    return tensor
 
 def get_kernel_func_from_module(
     module: ir.Module, func_name: str = "main_entry"

From 105a5e7dc0eebc6ecbb471e1a44163ff08371b19 Mon Sep 17 00:00:00 2001
From: ZhangYan
Date: Mon, 1 Jul 2024 23:01:08 -0700
Subject: [PATCH 08/10] try bench

---
 lib/gc/Transforms/Pipeline.cpp                |   2 +-
 .../SplitComputeIntensivePatterns.cpp         | 313 +++++++++++-------
 tools/drivers.py                              |  16 +-
 tools/workloads/single_mm.mlir                |   7 +
 tools/workloads/splited_mm.mlir               |  12 +
 tools/workloads/test.mlir                     |  25 --
 6 files changed, 231 insertions(+), 144 deletions(-)
 create mode 100644 tools/workloads/single_mm.mlir
 create mode 100644 tools/workloads/splited_mm.mlir
 delete mode 100644 tools/workloads/test.mlir

diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp
index 0c6849b72..fd5e342c0 100644
--- a/lib/gc/Transforms/Pipeline.cpp
+++ b/lib/gc/Transforms/Pipeline.cpp
@@ -38,7 +38,7 @@ void populateFrontendPasses(mlir::PassManager &pm) {
 
 // scf + arith + math + vector + tensor + linalg.brgemm + tensor.pack/unpack
 void populateTensorPasses(mlir::PassManager &pm) {
-  pm.addNestedPass(createSplitComputeIntensivePatterns());
+  // pm.addNestedPass(createSplitComputeIntensivePatterns());
   // todo: padding propagation pass
   // todo: layout propagation pass
   // todo: tensor constant propagation pass
diff --git a/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp b/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp
index 290d64994..9b8addb7d 100644
--- a/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp
+++ b/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp
@@ -14,16 +14,16 @@
  * limitations under the License.
*******************************************************************************/ -#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" -#include "mlir/IR/Location.h" #include "gc/Transforms/Passes.h" @@ -35,7 +35,7 @@ namespace gc { #include "gc/Transforms/Passes.h.inc" } // namespace gc -size_t NUM_OF_NUMA = 3; +size_t NUM_OF_NUMA = 2; size_t SUPPORTED_RANK = 2; void printValueType(Value value) { @@ -49,7 +49,8 @@ void printValueType(Value value) { llvm::outs() << "\n"; } -void getSplitedTensors(SmallVector& outputs, Value tensor, int64_t target_dim, PatternRewriter &rewriter) { +void getSplitedTensors(SmallVector &outputs, Value tensor, + int64_t target_dim, PatternRewriter &rewriter) { auto Type = tensor.getType().cast(); auto loc = tensor.getLoc(); int64_t rank = Type.getRank(); @@ -64,7 +65,8 @@ void getSplitedTensors(SmallVector& outputs, Value tensor, int64_t target llvm::outs() << "]\n"; llvm::outs() << "target_dim: " << target_dim << "\n"; bool has_tail = Type.getDimSize(target_dim) % NUM_OF_NUMA != 0; - int64_t split_length = (Type.getDimSize(target_dim) + NUM_OF_NUMA - 1) / NUM_OF_NUMA; + int64_t split_length = + (Type.getDimSize(target_dim) + NUM_OF_NUMA - 1) / NUM_OF_NUMA; SmallVector shape(Type.getShape().begin(), Type.getShape().end()); shape[target_dim] = split_length; // Split the weight tensor into NUM_OF_NUMA parts @@ -80,11 +82,21 @@ void getSplitedTensors(SmallVector& outputs, Value tensor, int64_t target SmallVector offsets; SmallVector strides(rank, rewriter.getIndexAttr(1)); for (auto i : llvm::seq(0, rank)) { - sizes.push_back(rewriter.getIndexAttr((split_idx == (NUM_OF_NUMA-1)) ? splitTailType.getShape()[i] : splitEvenType.getShape()[i])); - offsets.push_back(rewriter.getIndexAttr((split_idx == 0 || i != target_dim) ? 0 : splitEvenType.getShape()[i] * split_idx)); + sizes.push_back(rewriter.getIndexAttr((split_idx == (NUM_OF_NUMA - 1)) + ? splitTailType.getShape()[i] + : splitEvenType.getShape()[i])); + offsets.push_back( + rewriter.getIndexAttr((split_idx == 0 || i != target_dim) + ? 0 + : splitEvenType.getShape()[i] * split_idx)); } - Value res = rewriter.create( - loc, split_idx == (NUM_OF_NUMA-1) ? splitTailType : splitEvenType, tensor, offsets, sizes, strides)->getResult(0); + Value res = + rewriter + .create( + loc, + split_idx == (NUM_OF_NUMA - 1) ? 
splitTailType : splitEvenType, + tensor, offsets, sizes, strides) + ->getResult(0); auto res_type = res.getType().cast(); llvm::outs() << "splited shape: ["; for (int64_t dim : res_type.getShape()) { @@ -96,19 +108,22 @@ void getSplitedTensors(SmallVector& outputs, Value tensor, int64_t target } } -void splitBroadcast(SmallVector& outputs, linalg::BroadcastOp broadcastOp, int64_t target_dim, PatternRewriter &rewriter) { +void splitBroadcast(SmallVector &outputs, + linalg::BroadcastOp broadcastOp, int64_t target_dim, + PatternRewriter &rewriter) { auto loc = broadcastOp->getLoc(); SmallVector broadcastInputs; auto in = broadcastOp.getInput(); if (in.getType().getShape().size() > SUPPORTED_RANK) { - llvm::outs() << "cannot split broadcast on current size.\n"; + llvm::outs() << "cannot split broadcast on current size.\n"; return; } auto out = broadcastOp.getInit(); auto outType = out.getType().dyn_cast(); auto shape = outType.getShape(); if (shape.size() != SUPPORTED_RANK || target_dim != 1) { - llvm::outs() << "cannot split broadcast on current size or current target dim \n"; + llvm::outs() + << "cannot split broadcast on current size or current target dim \n"; return; } llvm::outs() << "Tensor shape: ["; @@ -117,86 +132,112 @@ void splitBroadcast(SmallVector& outputs, linalg::BroadcastOp broadcastOp } llvm::outs() << "]\n"; llvm::outs() << "duplicate broadcast inputs\n"; - getSplitedTensors(broadcastInputs, in, /*target_dim*/in.getType().getShape().size()-1, rewriter); + getSplitedTensors(broadcastInputs, in, + /*target_dim*/ in.getType().getShape().size() - 1, + rewriter); if (auto emptyOp = dyn_cast(out.getDefiningOp())) { int64_t split_length = (shape[1] + NUM_OF_NUMA - 1) / NUM_OF_NUMA; - int64_t split_tail = shape[1] % NUM_OF_NUMA != 0 ? shape[1] % split_length : split_length; + int64_t split_tail = + shape[1] % NUM_OF_NUMA != 0 ? shape[1] % split_length : split_length; for (auto split_idx : llvm::seq(0, NUM_OF_NUMA)) { Value empty = rewriter.create( - loc, ArrayRef{shape[0], (split_idx == (NUM_OF_NUMA - 1)) ? split_tail : split_length}, outType.getElementType()); - Value res = rewriter.create(loc, broadcastInputs[split_idx], empty, broadcastOp.getDimensions()).getResults()[0]; + loc, + ArrayRef{shape[0], (split_idx == (NUM_OF_NUMA - 1)) + ? 
split_tail + : split_length}, + outType.getElementType()); + Value res = + rewriter + .create(loc, broadcastInputs[split_idx], + empty, broadcastOp.getDimensions()) + .getResults()[0]; outputs.push_back(res); std::cout << outputs.size() << std::endl; } } } -void SplitMMonN(Operation* op, SmallVector& outputs, SmallVector& inputs, TensorType& resultTy, int64_t target_dim, Location& loc, PatternRewriter &rewriter) { +void SplitMMonN(Operation *op, SmallVector &outputs, + SmallVector &inputs, TensorType &resultTy, + int64_t target_dim, Location &loc, PatternRewriter &rewriter) { /*Split on N axis*/ std::cout << "split on N" << std::endl; int64_t M = inputs[0].getType().cast().getDimSize(0); - int64_t N = inputs[1].getType().cast().getDimSize(target_dim); + int64_t N = + inputs[1].getType().cast().getDimSize(target_dim); int64_t K = inputs[0].getType().cast().getDimSize(1); SmallVector splited_weights; getSplitedTensors(splited_weights, inputs[1], target_dim, rewriter); - if (splited_weights.size() != NUM_OF_NUMA) return; + if (splited_weights.size() != NUM_OF_NUMA) + return; for (Value weight : splited_weights) { Value zero = rewriter.create( - loc, rewriter.getZeroAttr(resultTy.getElementType())); - std::cout << "weight.getType().cast().getDimSize(1): " << weight.getType().cast().getDimSize(target_dim) << std::endl; + loc, rewriter.getZeroAttr(resultTy.getElementType())); + std::cout << "weight.getType().cast().getDimSize(1): " + << weight.getType().cast().getDimSize( + target_dim) + << std::endl; Value empty = rewriter.create( - loc, ArrayRef {M, weight.getType().cast().getDimSize(target_dim)}, resultTy.getElementType()); + loc, + ArrayRef{ + M, + weight.getType().cast().getDimSize(target_dim)}, + resultTy.getElementType()); Value tensor = rewriter.create(loc, zero, empty).getResult(0); - auto newMM = isa(op) ? - rewriter.create( - /*location=*/loc, - /*resultTensorTypes=*/tensor.getType().cast(), - /*inputs=*/ValueRange{inputs[0], weight}, - /*outputs=*/tensor) : - rewriter.create( - /*location=*/loc, - /*resultTensorTypes=*/tensor.getType().cast(), - /*inputs=*/ValueRange{inputs[0], weight}, - /*outputs=*/tensor); + auto newMM = isa(op) + ? 
rewriter.create( + /*location=*/loc, + /*resultTensorTypes=*/ + tensor.getType().cast(), + /*inputs=*/ValueRange{inputs[0], weight}, + /*outputs=*/tensor) + : rewriter.create( + /*location=*/loc, + /*resultTensorTypes=*/ + tensor.getType().cast(), + /*inputs=*/ValueRange{inputs[0], weight}, + /*outputs=*/tensor); mlir::BoolAttr boolAttr = rewriter.getBoolAttr(true); newMM->setAttr("splited", boolAttr); outputs.push_back(newMM->getResult(0)); } } -void SplitMMonK(Operation* op, SmallVector& outputs, SmallVector& inputs, TensorType& resultTy, Location& loc, PatternRewriter &rewriter) { +void SplitMMonK(Operation *op, SmallVector &outputs, + SmallVector &inputs, TensorType &resultTy, Location &loc, + PatternRewriter &rewriter) { /*Split on K axis*/ std::cout << "split on K" << std::endl; int64_t M = inputs[0].getType().cast().getDimSize(0); int64_t N = inputs[1].getType().cast().getDimSize(1); int64_t K = inputs[0].getType().cast().getDimSize(1); SmallVector splited_data, splited_weights; - getSplitedTensors(splited_data, inputs[0], /*target_dim*/1, rewriter); + getSplitedTensors(splited_data, inputs[0], /*target_dim*/ 1, rewriter); std::cout << "splited_data size: " << splited_data.size() << std::endl; - if (splited_data.size() != NUM_OF_NUMA) return; - getSplitedTensors(splited_weights, inputs[1], /*target_dim*/0, rewriter); + if (splited_data.size() != NUM_OF_NUMA) + return; + getSplitedTensors(splited_weights, inputs[1], /*target_dim*/ 0, rewriter); std::cout << "splited_weights size: " << splited_weights.size() << std::endl; - if (splited_weights.size() != NUM_OF_NUMA) return; + if (splited_weights.size() != NUM_OF_NUMA) + return; - for (auto [data, weight] : - llvm::zip_equal(splited_data, splited_weights)) { + for (auto [data, weight] : llvm::zip_equal(splited_data, splited_weights)) { Value zero = rewriter.create( - loc, rewriter.getZeroAttr(resultTy.getElementType())); - Value empty = rewriter.create( - loc, resultTy.getShape(), resultTy.getElementType()); + loc, rewriter.getZeroAttr(resultTy.getElementType())); + Value empty = rewriter.create(loc, resultTy.getShape(), + resultTy.getElementType()); Value tensor = rewriter.create(loc, zero, empty).getResult(0); auto newMM = rewriter.create( - /*location=*/loc, - /*resultTensorTypes=*/tensor.getType().cast(), - /*inputs=*/ValueRange{data, weight}, - /*outputs=*/tensor); + /*location=*/loc, + /*resultTensorTypes=*/tensor.getType().cast(), + /*inputs=*/ValueRange{data, weight}, + /*outputs=*/tensor); mlir::BoolAttr boolAttr = rewriter.getBoolAttr(true); newMM->setAttr("splited", boolAttr); outputs.push_back(newMM->getResult(0)); - outputs.push_back(newMM->getResult(0)); } } @@ -216,36 +257,47 @@ bool isSupportedPostOp(Operation *op) { // Helper function to get all post ops following the given operation void getUnOps(Operation *op, SmallVectorImpl &postOps) { for (auto user : op->getUsers()) { - if (isSupportedPostOp(user)) postOps.push_back(user); - if (isa(user)) return; - // Recursively search for unary ops, unless it's a matmul op + if (isSupportedPostOp(user)) + postOps.push_back(user); + if (isa(user)) + return; + // Recursively search for unary ops, unless it's a matmul op getUnOps(user, postOps); // } } } template -void duplicateBinary(SmallVector& outputs,std::vector>& inputs, TensorType& resultTy, PatternRewriter &rewriter) { +void duplicateBinary(SmallVector &outputs, + std::vector> &inputs, + TensorType &resultTy, PatternRewriter &rewriter) { for (int i = 0; i < NUM_OF_NUMA; ++i) { auto loc = inputs[i][0].getLoc(); TensorType 
type = inputs[i][0].getType().cast(); - Value Empty = rewriter.create( - loc, type.getShape(), type.getElementType()); - auto tmpOp = rewriter.create(loc, inputs[i], ValueRange {Empty}); + Value Empty = rewriter.create(loc, type.getShape(), + type.getElementType()); + auto tmpOp = rewriter.create(loc, inputs[i], ValueRange{Empty}); for (auto result : tmpOp->getResults()) { outputs.push_back(result); } } } -void duplicateTranspose(SmallVector& outputs,std::vector>& inputs, linalg::TransposeOp transposeOp, TensorType& resultTy, PatternRewriter &rewriter) { +void duplicateTranspose(SmallVector &outputs, + std::vector> &inputs, + linalg::TransposeOp transposeOp, TensorType &resultTy, + PatternRewriter &rewriter) { ArrayRef permutation = transposeOp.getPermutation(); - if (permutation.size() != SUPPORTED_RANK) {llvm::outs() << "unsupported rank\n"; return;} + if (permutation.size() != SUPPORTED_RANK) { + llvm::outs() << "unsupported rank\n"; + return; + } for (int i = 0; i < NUM_OF_NUMA; ++i) { auto loc = inputs[i][0].getLoc(); TensorType type = inputs[i][0].getType().cast(); const auto &inputShape = type.getShape(); - SmallVector transShape{inputShape[permutation[0]], inputShape[permutation[1]]}; + SmallVector transShape{inputShape[permutation[0]], + inputShape[permutation[1]]}; auto transTy = type.clone(transShape); llvm::outs() << "TransTy shape: ["; for (int64_t dim : transTy.getShape()) { @@ -253,12 +305,13 @@ void duplicateTranspose(SmallVector& outputs,std::vector( - loc, rewriter.getZeroAttr(transTy.getElementType())); - Value empty = rewriter.create( - loc, transTy.getShape(), transTy.getElementType()); + loc, rewriter.getZeroAttr(transTy.getElementType())); + Value empty = rewriter.create(loc, transTy.getShape(), + transTy.getElementType()); Value tensor = rewriter.create(loc, zero, empty).getResult(0); - auto tmpOp = rewriter.create(loc, inputs[i][0], tensor, permutation); + auto tmpOp = rewriter.create(loc, inputs[i][0], tensor, + permutation); for (auto result : tmpOp->getResults()) { outputs.push_back(result); } @@ -272,12 +325,15 @@ void deleteOperation(Operation *op) { // Step 2: Check each operand of the operation for (auto operand : op->getOperands()) { - if (!operand) continue; - if (operand.use_empty()) continue; // Skip if operand has no uses + if (!operand) + continue; + if (operand.use_empty()) + continue; // Skip if operand has no uses // If the operand is an operation and is either emptyOp or fillOp if (auto definingOp = operand.getDefiningOp()) { - // if (isa(definingOp) || isa(definingOp)) { + // if (isa(definingOp) || + // isa(definingOp)) { // llvm::outs() << "is empty \n"; // // Recursively delete the operand operation if it has only one use if (definingOp->hasOneUse()) { @@ -298,8 +354,11 @@ void deleteOperation(Operation *op) { void deleteOperands(Operation *op) { for (auto operand : op->getOperands()) { // llvm::outs() << "operands: " << operand << "\n"; - if (!operand) continue; - if (operand.use_empty()) {continue;} // Skip if operand has no uses + if (!operand) + continue; + if (operand.use_empty()) { + continue; + } // Skip if operand has no uses if (auto definingOp = operand.getDefiningOp()) { if (definingOp->hasOneUse()) { deleteOperands(definingOp); @@ -311,36 +370,42 @@ void deleteOperands(Operation *op) { } } -Value addN(Value& initTensor, SmallVector& ins, TensorType& resultTy, Location& loc, PatternRewriter &rewriter) { +Value addN(Value &initTensor, SmallVector &ins, TensorType &resultTy, + Location &loc, PatternRewriter &rewriter) { llvm::outs() << 
"start addN \n"; // Create indexing maps (for input tensors and output tensor) int num_of_args = int(ins.size()) + 1; MLIRContext *context = rewriter.getContext(); - SmallVector indexingMaps(num_of_args, - AffineMap::getMultiDimIdentityMap(resultTy.getRank(), context)); + SmallVector indexingMaps( + num_of_args, + AffineMap::getMultiDimIdentityMap(resultTy.getRank(), context)); llvm::outs() << "created affinemap \n"; // Create iterator types (parallel for all dimensions) // ArrayRef iteratorTypes(resultTy.getRank(), "parallel"); - SmallVector iteratorTypes(resultTy.getRank(), utils::IteratorType::parallel); + SmallVector iteratorTypes(resultTy.getRank(), + utils::IteratorType::parallel); llvm::outs() << "created IteratorType \n"; // Create the linalg.generic op auto genericOp = rewriter.create( - loc, resultTy, ValueRange{ins}, ValueRange{initTensor}, - indexingMaps, iteratorTypes, + loc, resultTy, ValueRange{ins}, ValueRange{initTensor}, indexingMaps, + iteratorTypes, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - // Define the body of the linalg.generic operation (elementwise addition) - Value sum = nestedBuilder.create(nestedLoc, args[0], args[1]); + // Define the body of the linalg.generic operation (elementwise + // addition) + Value sum = + nestedBuilder.create(nestedLoc, args[0], args[1]); for (auto i = 2; i < num_of_args - 1; ++i) - sum = nestedBuilder.create(nestedLoc, sum, args[i]); // Add more if more inputs + sum = nestedBuilder.create( + nestedLoc, sum, args[i]); // Add more if more inputs nestedBuilder.create(nestedLoc, sum); }); // Mark the output as the result of the function (for demonstration purposes) - return genericOp.getResults().front();; + return genericOp.getResults().front(); + ; } -LogicalResult splitSingleMM(Operation* op, - PatternRewriter &rewriter) { +LogicalResult splitSingleMM(Operation *op, PatternRewriter &rewriter) { SmallVector postOps = {}; getUnOps(op, postOps); auto loc = op->getLoc(); @@ -356,20 +421,24 @@ LogicalResult splitSingleMM(Operation* op, bool istransB = isa(op); llvm::outs() << "is trans B\n"; int64_t M = input_tensors[0].getType().cast().getDimSize(0); - int64_t N = input_tensors[1].getType().cast().getDimSize(istransB ? 0 : 1); + int64_t N = input_tensors[1].getType().cast().getDimSize( + istransB ? 0 : 1); int64_t K = input_tensors[0].getType().cast().getDimSize(1); std::cout << "M: " << M << ", N: " << N << ", K: " << K << std::endl; int64_t target_dim = N / K >= 2 ? 
0 : 0; SmallVector splites_res; if (target_dim == 1) { - SplitMMonN(op, splites_res, input_tensors, resultTy, target_dim ^ istransB, loc, rewriter); - if (splites_res.size() != NUM_OF_NUMA) return failure(); + SplitMMonN(op, splites_res, input_tensors, resultTy, target_dim ^ istransB, + loc, rewriter); + if (splites_res.size() != NUM_OF_NUMA) + return failure(); SmallVector Outputs = splites_res; auto lastInput = op->getResult(0); llvm::outs() << "postOps num: " << postOps.size() << "\n"; for (auto postOp : postOps) { - llvm::outs() << "Operation name: " << postOp->getName().getStringRef() << "\n"; + llvm::outs() << "Operation name: " << postOp->getName().getStringRef() + << "\n"; auto opInputs = postOp->getOperands().drop_back(); llvm::outs() << "inputs: " << opInputs.size() << "\n"; auto opOutputs = postOp->getResults(); @@ -384,45 +453,53 @@ LogicalResult splitSingleMM(Operation* op, innerVector.push_back(Outputs[0]); Inputs.push_back(innerVector); Outputs.erase(Outputs.begin()); - llvm::outs() << "inputs[" << i << "].size: " << Inputs[i].size() <<" \n"; + llvm::outs() << "inputs[" << i << "].size: " << Inputs[i].size() + << " \n"; } } else if (auto definingOp = input.getDefiningOp()) { llvm::outs() << "is definingOp\n"; - std::cout << "Input operation name: " << definingOp->getName().getStringRef().str() << std::endl; + std::cout << "Input operation name: " + << definingOp->getName().getStringRef().str() << std::endl; if (auto fillOp = dyn_cast(definingOp)) { llvm::outs() << "is fill \n"; SmallVector splited_inputs; getSplitedTensors(splited_inputs, input, target_dim, rewriter); int i = 0; for (const auto &splited_input : splited_inputs) { - Inputs[i].push_back(splited_input); - llvm::outs() << "inputs[" << i << "].size: " << Inputs[i].size() <<" \n"; - i++; + Inputs[i].push_back(splited_input); + llvm::outs() << "inputs[" << i << "].size: " << Inputs[i].size() + << " \n"; + i++; } llvm::outs() << "split input done \n"; - } else if (auto broadcastOp = dyn_cast(definingOp)){ + } else if (auto broadcastOp = + dyn_cast(definingOp)) { llvm::outs() << "is broadcast \n"; SmallVector splited_inputs; splitBroadcast(splited_inputs, broadcastOp, target_dim, rewriter); - llvm::outs() << "inputs[0].size: " << Inputs[0].size() <<" \n"; + llvm::outs() << "inputs[0].size: " << Inputs[0].size() << " \n"; int i = 0; for (const auto &splited_input : splited_inputs) { - Inputs[i].push_back(splited_input); - i++; + Inputs[i].push_back(splited_input); + i++; } deleteOperation(broadcastOp); llvm::outs() << "split input done \n"; - } else if (auto constantOp = dyn_cast(definingOp)){ + } else if (auto constantOp = + dyn_cast(definingOp)) { llvm::outs() << "is constant \n"; auto newConstantOp = rewriter.create( - constantOp.getLoc(), constantOp.getType(), constantOp.getValue()); + constantOp.getLoc(), constantOp.getType(), + constantOp.getValue()); SmallVector splited_inputs; - getSplitedTensors(splited_inputs, newConstantOp, target_dim, rewriter); + getSplitedTensors(splited_inputs, newConstantOp, target_dim, + rewriter); int i = 0; for (const auto &splited_input : splited_inputs) { - Inputs[i].push_back(splited_input); - llvm::outs() << "inputs[" << i << "].size: " << Inputs[i].size() <<" \n"; - i++; + Inputs[i].push_back(splited_input); + llvm::outs() << "inputs[" << i << "].size: " << Inputs[i].size() + << " \n"; + i++; } deleteOperation(constantOp); llvm::outs() << "split input done \n"; @@ -431,11 +508,11 @@ LogicalResult splitSingleMM(Operation* op, llvm::outs() << "doesnot match anything \n"; 
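With the ternary above pinned to 0 : 0 for these benchmarking runs, only the K-split branch is ever taken. The two branches reassemble the partial results differently: an N-split yields disjoint column blocks that are stitched back with tensor.concat, while a K-split yields full-sized partial products that must be summed elementwise. In outline, using the names from the surrounding code (builder context assumed):

// N-split: partials are disjoint along dim 1, concatenate them back.
if (target_dim == 1) {
  auto cat = rewriter.create<tensor::ConcatOp>(loc, /*dim=*/1, Outputs);
  rewriter.replaceOp(replaced_op, cat.getResult());
} else {
  // K-split: every partial is full-sized, sum them with the addN generic.
  Value sum = addN(initTensor, splites_res, resultTy, loc, rewriter);
  rewriter.replaceOp(op, sum);
}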
SmallVector splited_inputs; getSplitedTensors(splited_inputs, input, target_dim, rewriter); - llvm::outs() << "inputs[0].size: " << Inputs[0].size() <<" \n"; + llvm::outs() << "inputs[0].size: " << Inputs[0].size() << " \n"; int i = 0; for (const auto &splited_input : splited_inputs) { - Inputs[i].push_back(splited_input); - i++; + Inputs[i].push_back(splited_input); + i++; } llvm::outs() << "split input done \n"; } @@ -452,12 +529,13 @@ LogicalResult splitSingleMM(Operation* op, // } llvm::outs() << "post op creation and deletion done \n"; lastInput = postOp->getResult(0); - if(auto lastop = lastInput.getDefiningOp()) - std::cout << "lastInput operation name: " << lastop->getName().getStringRef().str() << std::endl; + if (auto lastop = lastInput.getDefiningOp()) + std::cout << "lastInput operation name: " + << lastop->getName().getStringRef().str() << std::endl; } // Concatenate the two halves back together on N axis - auto newop = rewriter.create( - Outputs.back().getLoc(), target_dim, Outputs); + auto newop = rewriter.create(Outputs.back().getLoc(), + target_dim, Outputs); llvm::outs() << "created concat \n"; auto replaced_op = postOps.size() ? postOps.back() : op; if (postOps.size() > 1) { @@ -472,13 +550,16 @@ LogicalResult splitSingleMM(Operation* op, llvm::outs() << "after duplicate, postOps num: " << postOps.size() << "\n"; } else { SplitMMonK(op, splites_res, input_tensors, resultTy, loc, rewriter); - if (splites_res.size() != NUM_OF_NUMA) return failure(); + if (splites_res.size() != NUM_OF_NUMA) { + llvm::outs() << "not getting the expected splited outputs\n"; + return failure(); + } // Add the two halves back together // Create linalg.map operation Value zero = rewriter.create( - loc, rewriter.getZeroAttr(resultTy.getElementType())); - Value empty = rewriter.create( - loc, resultTy.getShape(), resultTy.getElementType()); + loc, rewriter.getZeroAttr(resultTy.getElementType())); + Value empty = rewriter.create(loc, resultTy.getShape(), + resultTy.getElementType()); Value initTensor = rewriter.create(loc, zero, empty).getResult(0); auto newop = addN(initTensor, splites_res, resultTy, loc, rewriter); @@ -490,8 +571,7 @@ LogicalResult splitSingleMM(Operation* op, return success(); } -class SplitMatmulRewriter - : public OpRewritePattern { +class SplitMatmulRewriter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(linalg::MatmulOp op, @@ -519,8 +599,9 @@ class SplitMatmulTransposeBRewriter namespace gc { class SplitComputeIntensivePatterns - : public impl::SplitComputeIntensivePatternsBase { -public: + : public impl::SplitComputeIntensivePatternsBase< + SplitComputeIntensivePatterns> { +public: using impl::SplitComputeIntensivePatternsBase< SplitComputeIntensivePatterns>::SplitComputeIntensivePatternsBase; void runOnOperation() final { @@ -534,11 +615,11 @@ class SplitComputeIntensivePatterns ops.push_back(op); }); GreedyRewriteConfig config; - config.strictMode = GreedyRewriteStrictness::ExistingOps; + config.strictMode = GreedyRewriteStrictness::ExistingOps; bool erased; std::cout << "ops.size(): " << ops.size() << std::endl; - if (failed(applyOpPatternsAndFold(ops, patternSet, - config, /*changed=*/nullptr, &erased))) + if (failed(applyOpPatternsAndFold(ops, patternSet, config, + /*changed=*/nullptr, &erased))) signalPassFailure(); return; } diff --git a/tools/drivers.py b/tools/drivers.py index 35a26023f..15ca09401 100644 --- a/tools/drivers.py +++ b/tools/drivers.py @@ -83,11 +83,23 @@ def init_module(self, ctx: 
ir.Context) -> ir.Module:
 
     def prepare_np_args(self, disable_results_to_params: False) -> List[np.ndarray]:
         bench_func = get_kernel_func_from_module(self.ir_module, self.main_entry)
         np_args = []
+        idx = 0
         for arg in bench_func.arguments:
-            np_args.append(make_tensor(arg.type))
+            print(idx, ":", arg)
+            if idx in [1, 3]:
+                np_args.append(make_tensor(arg.type, 1))
+            else:
+                np_args.append(make_tensor(arg.type, 0))
+            idx += 1
+        idx = 0
         if not disable_results_to_params:
             for res in bench_func.type.results:
-                np_args.append(make_tensor(res))
+                print(idx, ":", res)
+                if idx in [1]:
+                    np_args.append(make_tensor(res, 1))
+                else:
+                    np_args.append(make_tensor(res, 0))
+                idx += 1
         # todo : data filling
         for i in range(len(np_args)):
             np.ndarray.fill(np_args[i], 1)
diff --git a/tools/workloads/single_mm.mlir b/tools/workloads/single_mm.mlir
new file mode 100644
index 000000000..705b2ea76
--- /dev/null
+++ b/tools/workloads/single_mm.mlir
@@ -0,0 +1,7 @@
+func.func @main_entry(%arg0: tensor<128x512xf32>, %arg1: tensor<512x64xf32>) -> tensor<128x64xf32> attributes {llvm.emit_c_interface} {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<128x64xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128x64xf32>) -> tensor<128x64xf32>
+  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x512xf32>, tensor<512x64xf32>) outs(%1 : tensor<128x64xf32>) -> tensor<128x64xf32>
+  return %2 : tensor<128x64xf32>
+}
\ No newline at end of file
diff --git a/tools/workloads/splited_mm.mlir b/tools/workloads/splited_mm.mlir
new file mode 100644
index 000000000..07c1bb059
--- /dev/null
+++ b/tools/workloads/splited_mm.mlir
@@ -0,0 +1,12 @@
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @main_entry(%arg0: tensor<128x256xf32>, %arg0_1: tensor<128x256xf32>, %arg1: tensor<256x64xf32>, %arg1_1: tensor<256x64xf32>) -> (tensor<128x64xf32>, tensor<128x64xf32>) attributes {llvm.emit_c_interface} {
+  %cst_3 = arith.constant 0.000000e+00 : f32
+  %2 = tensor.empty() : tensor<128x64xf32>
+  %3 = linalg.fill ins(%cst_3 : f32) outs(%2 : tensor<128x64xf32>) -> tensor<128x64xf32>
+  %4 = linalg.matmul {splited = true} ins(%arg0, %arg1 : tensor<128x256xf32>, tensor<256x64xf32>) outs(%3 : tensor<128x64xf32>) -> tensor<128x64xf32>
+  %cst_4 = arith.constant 0.000000e+00 : f32
+  %5 = tensor.empty() : tensor<128x64xf32>
+  %6 = linalg.fill ins(%cst_4 : f32) outs(%5 : tensor<128x64xf32>) -> tensor<128x64xf32>
+  %7 = linalg.matmul {splited = true} ins(%arg0_1, %arg1_1 : tensor<128x256xf32>, tensor<256x64xf32>) outs(%6 : tensor<128x64xf32>) -> tensor<128x64xf32>
+  return %4, %7 : tensor<128x64xf32>, tensor<128x64xf32>
+}
\ No newline at end of file
diff --git a/tools/workloads/test.mlir b/tools/workloads/test.mlir
deleted file mode 100644
index 99feeffc3..000000000
--- a/tools/workloads/test.mlir
+++ /dev/null
@@ -1,25 +0,0 @@
-func.func @main_entry(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x64xbf16>, %arg2: tensor<64xbf16>, %arg3: tensor<64x256xbf16>, %arg4: tensor<256xbf16>) -> tensor<128x256xbf16> attributes {llvm.emit_c_interface} {
-  %cst = arith.constant 0.000000e+00 : bf16
-  %0 = tensor.empty() : tensor<128x64xbf16>
-  %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x64xbf16>) -> tensor<128x64xbf16>
-  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x64xbf16>) outs(%1 : tensor<128x64xbf16>) -> tensor<128x64xbf16>
-  %3 = tensor.empty() : tensor<128x64xbf16>
-  %broadcasted = linalg.broadcast ins(%arg2 : tensor<64xbf16>) outs(%3 : tensor<128x64xbf16>) dimensions = [0]
-  %4 = tensor.empty()
: tensor<128x64xbf16> - %5 = linalg.add ins(%2, %broadcasted : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%4 : tensor<128x64xbf16>) -> tensor<128x64xbf16> - %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xbf16> - %6 = tensor.empty() : tensor<128x64xbf16> - %7 = linalg.max ins(%5, %cst_0 : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%6 : tensor<128x64xbf16>) -> tensor<128x64xbf16> - %cst_1 = arith.constant 0.000000e+00 : bf16 - %8 = tensor.empty() : tensor<128x256xbf16> - %9 = linalg.fill ins(%cst_1 : bf16) outs(%8 : tensor<128x256xbf16>) -> tensor<128x256xbf16> - %10 = linalg.matmul ins(%7, %arg3 : tensor<128x64xbf16>, tensor<64x256xbf16>) outs(%9 : tensor<128x256xbf16>) -> tensor<128x256xbf16> - %11 = tensor.empty() : tensor<128x256xbf16> - %broadcasted_2 = linalg.broadcast ins(%arg4 : tensor<256xbf16>) outs(%11 : tensor<128x256xbf16>) dimensions = [0] - %12 = tensor.empty() : tensor<128x256xbf16> - %13 = linalg.add ins(%10, %broadcasted_2 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%12 : tensor<128x256xbf16>) -> tensor<128x256xbf16> - %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xbf16> - %14 = tensor.empty() : tensor<128x256xbf16> - %15 = linalg.max ins(%13, %cst_3 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%14 : tensor<128x256xbf16>) -> tensor<128x256xbf16> - return %15 : tensor<128x256xbf16> -} \ No newline at end of file From 9c7671ea09af2478e3acd404866e4d4a3bd2f47f Mon Sep 17 00:00:00 2001 From: ZhangYan Date: Tue, 2 Jul 2024 19:48:23 -0700 Subject: [PATCH 09/10] enable pybind build --- scripts/compile.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/compile.sh b/scripts/compile.sh index 3d97a3926..35b77e816 100755 --- a/scripts/compile.sh +++ b/scripts/compile.sh @@ -74,7 +74,7 @@ build_llvm() { -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=true \ -DLLVM_ENABLE_PROJECTS="mlir" -DLLVM_TARGETS_TO_BUILD="X86" \ -DLLVM_INSTALL_UTILS=true -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ - -DLLVM_INSTALL_GTEST=ON -DLLVM_BUILD_LLVM_DYLIB=$dylib -DLLVM_LINK_LLVM_DYLIB=$dylib + -DLLVM_INSTALL_GTEST=ON -DLLVM_BUILD_LLVM_DYLIB=$dylib -DLLVM_LINK_LLVM_DYLIB=$dylib -DMLIR_ENABLE_BINDINGS_PYTHON=ON -DPython3_EXECUTABLE=$(which python3) cmake --build build MLIR_DIR="$PWD/build/lib/cmake/mlir" From 978bcd4ebaad657ac14868f6b8e9e9d6eedf1a38 Mon Sep 17 00:00:00 2001 From: ZhangYan Date: Tue, 2 Jul 2024 19:49:09 -0700 Subject: [PATCH 10/10] update llvm --- cmake/llvm-version.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/llvm-version.txt b/cmake/llvm-version.txt index bd9ee1a10..0e5ced519 100644 --- a/cmake/llvm-version.txt +++ b/cmake/llvm-version.txt @@ -1 +1 @@ -7042fcc6389c6c103d501b6f39988eafed0d9b5b +891ec2af45c02718c65f539cb6dad1758f079e73
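One loose end running through the series: NUM_OF_NUMA is hard-coded (3 in the first revision, 2 by patch 08) rather than queried from the machine the pass actually runs on. A hedged sketch of deriving it from libnuma at pass-setup time instead (assumes the pass may link with -lnuma, as the Python tooling already relies on):

#include <numa.h>
#include <cstddef>

std::size_t numNumaNodes() {
  if (numa_available() < 0)
    return 1;                                  // no NUMA: keep one partition
  return static_cast<std::size_t>(numa_max_node()) + 1; // nodes start at 0
}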