diff --git a/cmake/llvm-version.txt b/cmake/llvm-version.txt index bd9ee1a10..0e5ced519 100644 --- a/cmake/llvm-version.txt +++ b/cmake/llvm-version.txt @@ -1 +1 @@ -7042fcc6389c6c103d501b6f39988eafed0d9b5b +891ec2af45c02718c65f539cb6dad1758f079e73 diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index aaea602b6..b733baad5 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -46,4 +46,16 @@ def GCCPUPipeline: Pass<"gc-cpu-pipeline"> { "vector::VectorDialect"]; } +def SplitComputeIntensivePatterns : Pass<"split-compute-intensive-patterns"> { + let summary = "Split matmul patterns"; + let description = [{ + Split the weights of matmul patterns into several parts, so that the number + of parts matches the number of NUMA nodes on the target machine. + }]; + let dependentDialects = [ + "mlir::linalg::LinalgDialect", + "mlir::tensor::TensorDialect", + "mlir::arith::ArithDialect"]; +} + #endif // GC_DIALECT_GC_PASSES diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index 7be337566..cbee61a60 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_library(GCPasses OneDNNGraphToLinalg.cpp Pipeline.cpp TileNamed.cpp + SplitComputeIntensivePatterns.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include diff --git a/lib/gc/Transforms/OneDNNGraphToLinalg.cpp b/lib/gc/Transforms/OneDNNGraphToLinalg.cpp index c472dbe87..d31ebf633 100644 --- a/lib/gc/Transforms/OneDNNGraphToLinalg.cpp +++ b/lib/gc/Transforms/OneDNNGraphToLinalg.cpp @@ -26,6 +26,8 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/raw_ostream.h" +#include using namespace mlir::onednn_graph; namespace mlir { @@ -492,6 +494,8 @@ struct MatMulOpLowering : public OpRewritePattern { /*outputs=*/outBias); } + // Pass matmul configs through to linalg.matmul + newOp->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newOp); return success(); } diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index 6e5151e9e..fd5e342c0 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -38,6 +38,7 @@ void populateFrontendPasses(mlir::PassManager &pm) { // scf + arith + math + vector + tensor + linalg.brgemm + tensor.pack/unpack void populateTensorPasses(mlir::PassManager &pm) { + // pm.addNestedPass(createSplitComputeIntensivePatterns()); // todo: padding propagation pass // todo: layout propagation pass // todo: tensor constant propagation pass @@ -75,6 +76,7 @@ void populateBufferizationPasses(mlir::PassManager &pm) { bufferization::LayoutMapOption::IdentityLayoutMap); pm.addPass(bufferization::createOneShotBufferizePass(options)); pm.addPass(createCSEPass()); + bufferization::BufferResultsToOutParamsOpts opt{}; opt.hoistStaticAllocs = true; pm.addPass(bufferization::createBufferResultsToOutParamsPass(opt)); diff --git a/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp b/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp new file mode 100644 index 000000000..9b8addb7d --- /dev/null +++ b/lib/gc/Transforms/SplitComputeIntensivePatterns.cpp @@ -0,0 +1,629 @@ +/******************************************************************************* + * Copyright 2024 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" + +#include "gc/Transforms/Passes.h" + +#include + +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_SPLITCOMPUTEINTENSIVEPATTERNS +#include "gc/Transforms/Passes.h.inc" +} // namespace gc + +size_t NUM_OF_NUMA = 2; +size_t SUPPORTED_RANK = 2; + +void printValueType(Value value) { + if (!value) { + llvm::outs() << "Invalid value\n"; + return; + } + + Type type = value.getType(); + type.print(llvm::outs()); + llvm::outs() << "\n"; +} + +void getSplitedTensors(SmallVector &outputs, Value tensor, + int64_t target_dim, PatternRewriter &rewriter) { + auto Type = tensor.getType().cast(); + auto loc = tensor.getLoc(); + int64_t rank = Type.getRank(); + llvm::outs() << "split rank: " << rank << "\n"; + if (!Type || Type.getRank() > SUPPORTED_RANK) { + return; + } + llvm::outs() << "split shape: ["; + for (int64_t dim : Type.getShape()) { + llvm::outs() << dim << " "; + } + llvm::outs() << "]\n"; + llvm::outs() << "target_dim: " << target_dim << "\n"; + bool has_tail = Type.getDimSize(target_dim) % NUM_OF_NUMA != 0; + int64_t split_length = + (Type.getDimSize(target_dim) + NUM_OF_NUMA - 1) / NUM_OF_NUMA; + SmallVector shape(Type.getShape().begin(), Type.getShape().end()); + shape[target_dim] = split_length; + // Split the weight tensor into NUM_OF_NUMA parts + auto splitEvenType = RankedTensorType::get(shape, Type.getElementType()); + auto splitTailType = splitEvenType; + if (has_tail) { + shape[target_dim] = Type.getDimSize(target_dim) % split_length; + splitTailType = RankedTensorType::get(shape, Type.getElementType()); + } + llvm::outs() << "start to extract slice\n"; + for (auto split_idx : llvm::seq(0, NUM_OF_NUMA)) { + SmallVector sizes; + SmallVector offsets; + SmallVector strides(rank, rewriter.getIndexAttr(1)); + for (auto i : llvm::seq(0, rank)) { + sizes.push_back(rewriter.getIndexAttr((split_idx == (NUM_OF_NUMA - 1)) + ? splitTailType.getShape()[i] + : splitEvenType.getShape()[i])); + offsets.push_back( + rewriter.getIndexAttr((split_idx == 0 || i != target_dim) + ? 0 + : splitEvenType.getShape()[i] * split_idx)); + } + Value res = + rewriter + .create( + loc, + split_idx == (NUM_OF_NUMA - 1) ? 
splitTailType : splitEvenType, + tensor, offsets, sizes, strides) + ->getResult(0); + auto res_type = res.getType().cast(); + llvm::outs() << "splited shape: ["; + for (int64_t dim : res_type.getShape()) { + llvm::outs() << dim << " "; + } + llvm::outs() << "]\n"; + outputs.push_back(res); + std::cout << outputs.size() << std::endl; + } +} + +void splitBroadcast(SmallVector &outputs, + linalg::BroadcastOp broadcastOp, int64_t target_dim, + PatternRewriter &rewriter) { + auto loc = broadcastOp->getLoc(); + SmallVector broadcastInputs; + auto in = broadcastOp.getInput(); + if (in.getType().getShape().size() > SUPPORTED_RANK) { + llvm::outs() << "cannot split broadcast on current size.\n"; + return; + } + auto out = broadcastOp.getInit(); + auto outType = out.getType().dyn_cast(); + auto shape = outType.getShape(); + if (shape.size() != SUPPORTED_RANK || target_dim != 1) { + llvm::outs() + << "cannot split broadcast on current size or current target dim \n"; + return; + } + llvm::outs() << "Tensor shape: ["; + for (int64_t dim : shape) { + llvm::outs() << dim << " "; + } + llvm::outs() << "]\n"; + llvm::outs() << "duplicate broadcast inputs\n"; + getSplitedTensors(broadcastInputs, in, + /*target_dim*/ in.getType().getShape().size() - 1, + rewriter); + if (auto emptyOp = dyn_cast(out.getDefiningOp())) { + int64_t split_length = (shape[1] + NUM_OF_NUMA - 1) / NUM_OF_NUMA; + int64_t split_tail = + shape[1] % NUM_OF_NUMA != 0 ? shape[1] % split_length : split_length; + for (auto split_idx : llvm::seq(0, NUM_OF_NUMA)) { + Value empty = rewriter.create( + loc, + ArrayRef{shape[0], (split_idx == (NUM_OF_NUMA - 1)) + ? split_tail + : split_length}, + outType.getElementType()); + Value res = + rewriter + .create(loc, broadcastInputs[split_idx], + empty, broadcastOp.getDimensions()) + .getResults()[0]; + outputs.push_back(res); + std::cout << outputs.size() << std::endl; + } + } +} + +void SplitMMonN(Operation *op, SmallVector &outputs, + SmallVector &inputs, TensorType &resultTy, + int64_t target_dim, Location &loc, PatternRewriter &rewriter) { + /*Split on N axis*/ + std::cout << "split on N" << std::endl; + int64_t M = inputs[0].getType().cast().getDimSize(0); + int64_t N = + inputs[1].getType().cast().getDimSize(target_dim); + int64_t K = inputs[0].getType().cast().getDimSize(1); + SmallVector splited_weights; + getSplitedTensors(splited_weights, inputs[1], target_dim, rewriter); + if (splited_weights.size() != NUM_OF_NUMA) + return; + + for (Value weight : splited_weights) { + Value zero = rewriter.create( + loc, rewriter.getZeroAttr(resultTy.getElementType())); + std::cout << "weight.getType().cast().getDimSize(1): " + << weight.getType().cast().getDimSize( + target_dim) + << std::endl; + Value empty = rewriter.create( + loc, + ArrayRef{ + M, + weight.getType().cast().getDimSize(target_dim)}, + resultTy.getElementType()); + Value tensor = + rewriter.create(loc, zero, empty).getResult(0); + auto newMM = isa(op) + ? 
rewriter.create( + /*location=*/loc, + /*resultTensorTypes=*/ + tensor.getType().cast(), + /*inputs=*/ValueRange{inputs[0], weight}, + /*outputs=*/tensor) + : rewriter.create( + /*location=*/loc, + /*resultTensorTypes=*/ + tensor.getType().cast(), + /*inputs=*/ValueRange{inputs[0], weight}, + /*outputs=*/tensor); + mlir::BoolAttr boolAttr = rewriter.getBoolAttr(true); + newMM->setAttr("splited", boolAttr); + outputs.push_back(newMM->getResult(0)); + } +} + +void SplitMMonK(Operation *op, SmallVector &outputs, + SmallVector &inputs, TensorType &resultTy, Location &loc, + PatternRewriter &rewriter) { + /*Split on K axis*/ + std::cout << "split on K" << std::endl; + int64_t M = inputs[0].getType().cast().getDimSize(0); + int64_t N = inputs[1].getType().cast().getDimSize(1); + int64_t K = inputs[0].getType().cast().getDimSize(1); + SmallVector splited_data, splited_weights; + getSplitedTensors(splited_data, inputs[0], /*target_dim*/ 1, rewriter); + std::cout << "splited_data size: " << splited_data.size() << std::endl; + if (splited_data.size() != NUM_OF_NUMA) + return; + getSplitedTensors(splited_weights, inputs[1], /*target_dim*/ 0, rewriter); + std::cout << "splited_weights size: " << splited_weights.size() << std::endl; + if (splited_weights.size() != NUM_OF_NUMA) + return; + + for (auto [data, weight] : llvm::zip_equal(splited_data, splited_weights)) { + Value zero = rewriter.create( + loc, rewriter.getZeroAttr(resultTy.getElementType())); + Value empty = rewriter.create(loc, resultTy.getShape(), + resultTy.getElementType()); + Value tensor = + rewriter.create(loc, zero, empty).getResult(0); + auto newMM = rewriter.create( + /*location=*/loc, + /*resultTensorTypes=*/tensor.getType().cast(), + /*inputs=*/ValueRange{data, weight}, + /*outputs=*/tensor); + mlir::BoolAttr boolAttr = rewriter.getBoolAttr(true); + newMM->setAttr("splited", boolAttr); + outputs.push_back(newMM->getResult(0)); + } +} + +bool isSupportedPostOp(Operation *op) { + // Check if the operation is a linalg operation + if (!isa(op)) + return false; + + // Get the inputs and outputs of the linalg operation + bool isMax = isa(op); + bool isAdd = isa(op); + bool isMul = isa(op); + // bool isTranspose = isa(op); + return isMax || isAdd || isMul; +} + +// Helper function to get all post ops following the given operation +void getUnOps(Operation *op, SmallVectorImpl &postOps) { + for (auto user : op->getUsers()) { + if (isSupportedPostOp(user)) + postOps.push_back(user); + if (isa(user)) + return; + // Recursively search for unary ops, unless it's a matmul op + getUnOps(user, postOps); + // } + } +} + +template +void duplicateBinary(SmallVector &outputs, + std::vector> &inputs, + TensorType &resultTy, PatternRewriter &rewriter) { + for (int i = 0; i < NUM_OF_NUMA; ++i) { + auto loc = inputs[i][0].getLoc(); + TensorType type = inputs[i][0].getType().cast(); + Value Empty = rewriter.create(loc, type.getShape(), + type.getElementType()); + auto tmpOp = rewriter.create(loc, inputs[i], ValueRange{Empty}); + for (auto result : tmpOp->getResults()) { + outputs.push_back(result); + } + } +} + +void duplicateTranspose(SmallVector &outputs, + std::vector> &inputs, + linalg::TransposeOp transposeOp, TensorType &resultTy, + PatternRewriter &rewriter) { + ArrayRef permutation = transposeOp.getPermutation(); + if (permutation.size() != SUPPORTED_RANK) { + llvm::outs() << "unsupported rank\n"; + return; + } + for (int i = 0; i < NUM_OF_NUMA; ++i) { + auto loc = inputs[i][0].getLoc(); + TensorType type = inputs[i][0].getType().cast(); + const 
auto &inputShape = type.getShape(); + SmallVector transShape{inputShape[permutation[0]], + inputShape[permutation[1]]}; + auto transTy = type.clone(transShape); + llvm::outs() << "TransTy shape: ["; + for (int64_t dim : transTy.getShape()) { + llvm::outs() << dim << " "; + } + llvm::outs() << "]\n"; + Value zero = rewriter.create( + loc, rewriter.getZeroAttr(transTy.getElementType())); + Value empty = rewriter.create(loc, transTy.getShape(), + transTy.getElementType()); + Value tensor = + rewriter.create(loc, zero, empty).getResult(0); + auto tmpOp = rewriter.create(loc, inputs[i][0], tensor, + permutation); + for (auto result : tmpOp->getResults()) { + outputs.push_back(result); + } + } +} + +void deleteOperation(Operation *op) { + // Step 1: Ensure the operation exists + if (!op) + return; + + // Step 2: Check each operand of the operation + for (auto operand : op->getOperands()) { + if (!operand) + continue; + if (operand.use_empty()) + continue; // Skip if operand has no uses + + // If the operand is an operation and is either emptyOp or fillOp + if (auto definingOp = operand.getDefiningOp()) { + // if (isa(definingOp) || + // isa(definingOp)) { + // llvm::outs() << "is empty \n"; + // // Recursively delete the operand operation if it has only one use + if (definingOp->hasOneUse()) { + deleteOperation(definingOp); + } + // } + } + } + + // Step 3: Disconnect the operation from its operands and users + op->dropAllUses(); + op->dropAllReferences(); + + // Step 4: Erase the operation from its parent block + op->erase(); +} + +void deleteOperands(Operation *op) { + for (auto operand : op->getOperands()) { + // llvm::outs() << "operands: " << operand << "\n"; + if (!operand) + continue; + if (operand.use_empty()) { + continue; + } // Skip if operand has no uses + if (auto definingOp = operand.getDefiningOp()) { + if (definingOp->hasOneUse()) { + deleteOperands(definingOp); + definingOp->dropAllUses(); + definingOp->dropAllReferences(); + definingOp->erase(); + } + } + } +} + +Value addN(Value &initTensor, SmallVector &ins, TensorType &resultTy, + Location &loc, PatternRewriter &rewriter) { + llvm::outs() << "start addN \n"; + // Create indexing maps (for input tensors and output tensor) + int num_of_args = int(ins.size()) + 1; + MLIRContext *context = rewriter.getContext(); + SmallVector indexingMaps( + num_of_args, + AffineMap::getMultiDimIdentityMap(resultTy.getRank(), context)); + llvm::outs() << "created affinemap \n"; + // Create iterator types (parallel for all dimensions) + // ArrayRef iteratorTypes(resultTy.getRank(), "parallel"); + SmallVector iteratorTypes(resultTy.getRank(), + utils::IteratorType::parallel); + llvm::outs() << "created IteratorType \n"; + // Create the linalg.generic op + auto genericOp = rewriter.create( + loc, resultTy, ValueRange{ins}, ValueRange{initTensor}, indexingMaps, + iteratorTypes, + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { + // Define the body of the linalg.generic operation (elementwise + // addition) + Value sum = + nestedBuilder.create(nestedLoc, args[0], args[1]); + for (auto i = 2; i < num_of_args - 1; ++i) + sum = nestedBuilder.create( + nestedLoc, sum, args[i]); // Add more if more inputs + nestedBuilder.create(nestedLoc, sum); + }); + + // Mark the output as the result of the function (for demonstration purposes) + return genericOp.getResults().front(); + ; +} + +LogicalResult splitSingleMM(Operation *op, PatternRewriter &rewriter) { + SmallVector postOps = {}; + getUnOps(op, postOps); + auto loc = op->getLoc(); + 
auto resultTy = dyn_cast(op->getResultTypes().front()); + auto input_operands = op->getOperands().drop_back(); + SmallVector input_tensors; + for (Value operand : input_operands) { + if (!operand.getType().isa()) { + continue; + } + input_tensors.push_back(operand); + } + bool istransB = isa(op); + llvm::outs() << "is trans B\n"; + int64_t M = input_tensors[0].getType().cast().getDimSize(0); + int64_t N = input_tensors[1].getType().cast().getDimSize( + istransB ? 0 : 1); + int64_t K = input_tensors[0].getType().cast().getDimSize(1); + std::cout << "M: " << M << ", N: " << N << ", K: " << K << std::endl; + + int64_t target_dim = N / K >= 2 ? 0 : 0; + SmallVector splites_res; + if (target_dim == 1) { + SplitMMonN(op, splites_res, input_tensors, resultTy, target_dim ^ istransB, + loc, rewriter); + if (splites_res.size() != NUM_OF_NUMA) + return failure(); + SmallVector Outputs = splites_res; + auto lastInput = op->getResult(0); + llvm::outs() << "postOps num: " << postOps.size() << "\n"; + for (auto postOp : postOps) { + llvm::outs() << "Operation name: " << postOp->getName().getStringRef() + << "\n"; + auto opInputs = postOp->getOperands().drop_back(); + llvm::outs() << "inputs: " << opInputs.size() << "\n"; + auto opOutputs = postOp->getResults(); + llvm::outs() << "outputs: " << opOutputs.size() << "\n"; + + std::vector> Inputs; + for (auto input : opInputs) { + if (input == lastInput) { + std::cout << "enter mm output" << std::endl; + for (size_t i = 0; i < NUM_OF_NUMA; ++i) { + SmallVector innerVector; + innerVector.push_back(Outputs[0]); + Inputs.push_back(innerVector); + Outputs.erase(Outputs.begin()); + llvm::outs() << "inputs[" << i << "].size: " << Inputs[i].size() + << " \n"; + } + } else if (auto definingOp = input.getDefiningOp()) { + llvm::outs() << "is definingOp\n"; + std::cout << "Input operation name: " + << definingOp->getName().getStringRef().str() << std::endl; + if (auto fillOp = dyn_cast(definingOp)) { + llvm::outs() << "is fill \n"; + SmallVector splited_inputs; + getSplitedTensors(splited_inputs, input, target_dim, rewriter); + int i = 0; + for (const auto &splited_input : splited_inputs) { + Inputs[i].push_back(splited_input); + llvm::outs() << "inputs[" << i << "].size: " << Inputs[i].size() + << " \n"; + i++; + } + llvm::outs() << "split input done \n"; + } else if (auto broadcastOp = + dyn_cast(definingOp)) { + llvm::outs() << "is broadcast \n"; + SmallVector splited_inputs; + splitBroadcast(splited_inputs, broadcastOp, target_dim, rewriter); + llvm::outs() << "inputs[0].size: " << Inputs[0].size() << " \n"; + int i = 0; + for (const auto &splited_input : splited_inputs) { + Inputs[i].push_back(splited_input); + i++; + } + deleteOperation(broadcastOp); + llvm::outs() << "split input done \n"; + } else if (auto constantOp = + dyn_cast(definingOp)) { + llvm::outs() << "is constant \n"; + auto newConstantOp = rewriter.create( + constantOp.getLoc(), constantOp.getType(), + constantOp.getValue()); + SmallVector splited_inputs; + getSplitedTensors(splited_inputs, newConstantOp, target_dim, + rewriter); + int i = 0; + for (const auto &splited_input : splited_inputs) { + Inputs[i].push_back(splited_input); + llvm::outs() << "inputs[" << i << "].size: " << Inputs[i].size() + << " \n"; + i++; + } + deleteOperation(constantOp); + llvm::outs() << "split input done \n"; + } + } else { + llvm::outs() << "doesnot match anything \n"; + SmallVector splited_inputs; + getSplitedTensors(splited_inputs, input, target_dim, rewriter); + llvm::outs() << "inputs[0].size: " << 
Inputs[0].size() << " \n"; + int i = 0; + for (const auto &splited_input : splited_inputs) { + Inputs[i].push_back(splited_input); + i++; + } + llvm::outs() << "split input done \n"; + } + } + if (auto postOpType = llvm::dyn_cast(postOp)) + duplicateBinary(Outputs, Inputs, resultTy, rewriter); + else if (auto postOpType = llvm::dyn_cast(postOp)) + duplicateBinary(Outputs, Inputs, resultTy, rewriter); + else if (auto postOpType = llvm::dyn_cast(postOp)) + duplicateBinary(Outputs, Inputs, resultTy, rewriter); + // else if (auto transOp = llvm::dyn_cast(postOp)) { + // duplicateTranspose(Outputs, Inputs, transOp, resultTy, rewriter); + // target_dim ^= 0x1; + // } + llvm::outs() << "post op creation and deletion done \n"; + lastInput = postOp->getResult(0); + if (auto lastop = lastInput.getDefiningOp()) + std::cout << "lastInput operation name: " + << lastop->getName().getStringRef().str() << std::endl; + } + // Concatenate the two halves back together on N axis + auto newop = rewriter.create(Outputs.back().getLoc(), + target_dim, Outputs); + llvm::outs() << "created concat \n"; + auto replaced_op = postOps.size() ? postOps.back() : op; + if (postOps.size() > 1) { + postOps.pop_back(); + deleteOperation(op); + for (auto &deleteOp : postOps) + deleteOperation(deleteOp); + } + deleteOperands(replaced_op); + rewriter.replaceOp(replaced_op, newop); + postOps = {}; + llvm::outs() << "after duplicate, postOps num: " << postOps.size() << "\n"; + } else { + SplitMMonK(op, splites_res, input_tensors, resultTy, loc, rewriter); + if (splites_res.size() != NUM_OF_NUMA) { + llvm::outs() << "not getting the expected splited outputs\n"; + return failure(); + } + // Add the two halves back together + // Create linalg.map operation + Value zero = rewriter.create( + loc, rewriter.getZeroAttr(resultTy.getElementType())); + Value empty = rewriter.create(loc, resultTy.getShape(), + resultTy.getElementType()); + Value initTensor = + rewriter.create(loc, zero, empty).getResult(0); + auto newop = addN(initTensor, splites_res, resultTy, loc, rewriter); + // Replace the original operation with the new linalg.map operation + rewriter.replaceOp(op, newop); + } + llvm::outs() << "exit duplicate mm.\n"; + llvm::outs() << "==================================================\n"; + return success(); +} + +class SplitMatmulRewriter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(linalg::MatmulOp op, + PatternRewriter &rewriter) const final { + // Check if the operation has already been processed + if (op->hasAttr("splited")) + return failure(); + return splitSingleMM(op, rewriter); + } +}; + +class SplitMatmulTransposeBRewriter + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(linalg::MatmulTransposeBOp op, + PatternRewriter &rewriter) const final { + // Check if the operation has already been processed + llvm::outs() << "get into mm transpose b\n"; + if (op->hasAttr("splited")) + return failure(); + return splitSingleMM(op, rewriter); + } +}; + +namespace gc { +class SplitComputeIntensivePatterns + : public impl::SplitComputeIntensivePatternsBase< + SplitComputeIntensivePatterns> { +public: + using impl::SplitComputeIntensivePatternsBase< + SplitComputeIntensivePatterns>::SplitComputeIntensivePatternsBase; + void runOnOperation() final { + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext()); + patterns.insert(&getContext()); + FrozenRewritePatternSet 
patternSet(std::move(patterns)); + SmallVector ops; + getOperation()->walk([&](Operation *op) { + if (isa(op)) + ops.push_back(op); + }); + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::ExistingOps; + bool erased; + std::cout << "ops.size(): " << ops.size() << std::endl; + if (failed(applyOpPatternsAndFold(ops, patternSet, config, + /*changed=*/nullptr, &erased))) + signalPassFailure(); + return; + } +}; + +} // namespace gc +} // namespace mlir diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 2aef3f17a..df613d3b4 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -34,6 +34,7 @@ declare_mlir_python_sources(GcPythonSources.Common ADD_TO_PARENT GcPythonSources SOURCES __init__.py + graph_compiler.py dialects/__init__.py # init hooks _mlir_libs/_site_initialize_0.py @@ -83,6 +84,8 @@ add_mlir_python_common_capi_library(GcPythonCAPI GcPythonSources MLIRPythonExtension.RegisterEverything MLIRPythonSources.Core + MLIRPythonSources.Dialects.linalg + MLIRPythonSources.ExecutionEngine ) ################################################################################ @@ -96,6 +99,7 @@ add_mlir_python_modules(GcPythonModules GcPythonSources MLIRPythonExtension.RegisterEverything MLIRPythonSources + MLIRPythonSources.ExecutionEngine COMMON_CAPI_LINK_LIBS GcPythonCAPI ) \ No newline at end of file diff --git a/python/gc_mlir/graph_compiler.py b/python/gc_mlir/graph_compiler.py new file mode 100644 index 000000000..0873f2b61 --- /dev/null +++ b/python/gc_mlir/graph_compiler.py @@ -0,0 +1,48 @@ +# ===-- graph_compiler.py - DESC ------------------------------*- Python -*-===# +# +# This file is licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +# ===-----------------------------------------------------------------------===# + +from gc_mlir import execution_engine +from gc_mlir import ir +from gc_mlir import passmanager +from typing import Sequence + +__all__ = [ + "GraphCompiler", +] + + +class GraphCompiler: + def __init__( + self, + pipeline: str = "any(gc-cpu-pipeline)", + shared_libs: Sequence[str] = [], + opt_level: int = 3, + ): + self.shared_libs = shared_libs + self.pipeline = pipeline + self.opt_level = opt_level + + def __call__(self, module: ir.Module, ir_printing: bool = False): + self.compile(module, ir_printing) + + def compile(self, module: ir.Module, ir_printing: bool = False): + pm = passmanager.PassManager.parse(self.pipeline) + if ir_printing: + pm.enable_ir_printing() + pm.run(module.operation) + + def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine: + return execution_engine.ExecutionEngine( + module, opt_level=self.opt_level, shared_libs=self.shared_libs + ) + + def compile_and_jit( + self, module: ir.Module, ir_printing: bool = False + ) -> execution_engine.ExecutionEngine: + self.compile(module, ir_printing) + return self.jit(module) diff --git a/scripts/compile.sh b/scripts/compile.sh index 3d97a3926..35b77e816 100755 --- a/scripts/compile.sh +++ b/scripts/compile.sh @@ -74,7 +74,7 @@ build_llvm() { -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=true \ -DLLVM_ENABLE_PROJECTS="mlir" -DLLVM_TARGETS_TO_BUILD="X86" \ -DLLVM_INSTALL_UTILS=true -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ - -DLLVM_INSTALL_GTEST=ON -DLLVM_BUILD_LLVM_DYLIB=$dylib -DLLVM_LINK_LLVM_DYLIB=$dylib + -DLLVM_INSTALL_GTEST=ON -DLLVM_BUILD_LLVM_DYLIB=$dylib -DLLVM_LINK_LLVM_DYLIB=$dylib -DMLIR_ENABLE_BINDINGS_PYTHON=ON -DPython3_EXECUTABLE=$(which python3) cmake --build build MLIR_DIR="$PWD/build/lib/cmake/mlir" diff --git a/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir b/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir new file mode 100644 index 000000000..5595b166e --- /dev/null +++ b/test/mlir/test/gc/Dialect/Linlagx/split-compute-intensive-patterns.mlir @@ -0,0 +1,184 @@ +// RUN: gc-opt %s --split-compute-intensive-patterns | FileCheck %s + +func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x64xbf16>, %arg2: tensor<64xbf16>, %arg3: tensor<64x256xbf16>, %arg4: tensor<256xbf16>) -> tensor<128x256xbf16> { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<128x64xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x64xbf16>) -> tensor<128x64xbf16> + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x64xbf16>) outs(%1 : tensor<128x64xbf16>) -> tensor<128x64xbf16> + %3 = tensor.empty() : tensor<128x64xbf16> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<64xbf16>) outs(%3 : tensor<128x64xbf16>) dimensions = [0] + %4 = tensor.empty() : tensor<128x64xbf16> + %5 = linalg.add ins(%2, %broadcasted : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%4 : tensor<128x64xbf16>) -> tensor<128x64xbf16> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xbf16> + %6 = tensor.empty() : tensor<128x64xbf16> + %7 = linalg.max ins(%5, %cst_0 : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%6 : tensor<128x64xbf16>) -> tensor<128x64xbf16> + %cst_1 = arith.constant 0.000000e+00 : bf16 + %8 = tensor.empty() : tensor<128x256xbf16> + %9 = linalg.fill ins(%cst_1 : bf16) outs(%8 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %10 = linalg.matmul ins(%7, %arg3 : 
tensor<128x64xbf16>, tensor<64x256xbf16>) outs(%9 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %11 = tensor.empty() : tensor<128x256xbf16> + %broadcasted_2 = linalg.broadcast ins(%arg4 : tensor<256xbf16>) outs(%11 : tensor<128x256xbf16>) dimensions = [0] + %12 = tensor.empty() : tensor<128x256xbf16> + %13 = linalg.add ins(%10, %broadcasted_2 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%12 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xbf16> + %14 = tensor.empty() : tensor<128x256xbf16> + %15 = linalg.max ins(%13, %cst_3 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%14 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %15 : tensor<128x256xbf16> +} + +func.func @mlp_transpose_a(%arg0: tensor<512x128xbf16>, %arg1: tensor<512x256xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<128x256xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<512x128xbf16>, tensor<512x256xbf16>) outs(%1 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %3 = tensor.empty() : tensor<128x256xbf16> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<256xbf16>) outs(%3 : tensor<128x256xbf16>) dimensions = [0] + %4 = tensor.empty() : tensor<128x256xbf16> + %5 = linalg.add ins(%2, %broadcasted : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%4 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xbf16> + %6 = tensor.empty() : tensor<128x256xbf16> + %7 = linalg.max ins(%5, %cst_0 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%6 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %7 : tensor<128x256xbf16> +} + +func.func @mlp_transpose_b(%arg0: tensor<128x512xbf16>, %arg1: tensor<256x512xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<128x256xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<256x512xbf16>) outs(%1 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %3 = tensor.empty() : tensor<128x256xbf16> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<256xbf16>) outs(%3 : tensor<128x256xbf16>) dimensions = [0] + %4 = tensor.empty() : tensor<128x256xbf16> + %5 = linalg.add ins(%2, %broadcasted : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%4 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x256xbf16> + %6 = tensor.empty() : tensor<128x256xbf16> + %7 = linalg.max ins(%5, %cst_0 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%6 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %7 : tensor<128x256xbf16> +} + +func.func @mlp_transpose_a_b(%arg0: tensor<512x128xbf16>, %arg1: tensor<256x512xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<256x128xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<256x128xbf16>) -> tensor<256x128xbf16> + %2 = linalg.matmul ins(%arg1, %arg0 : tensor<256x512xbf16>, tensor<512x128xbf16>) outs(%1 : tensor<256x128xbf16>) -> tensor<256x128xbf16> + %cst_0 = arith.constant 0.000000e+00 : bf16 + %3 = tensor.empty() : tensor<128x256xbf16> + %4 = linalg.fill ins(%cst_0 : bf16) outs(%3 : 
tensor<128x256xbf16>) -> tensor<128x256xbf16> + %transposed = linalg.transpose ins(%2 : tensor<256x128xbf16>) outs(%4 : tensor<128x256xbf16>) permutation = [1, 0] + %5 = tensor.empty() : tensor<128x256xbf16> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<256xbf16>) outs(%5 : tensor<128x256xbf16>) dimensions = [0] + %6 = tensor.empty() : tensor<128x256xbf16> + %7 = linalg.add ins(%transposed, %broadcasted : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%6 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x256xbf16> + %8 = tensor.empty() : tensor<128x256xbf16> + %9 = linalg.max ins(%7, %cst_1 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%8 : tensor<128x256xbf16>) -> tensor<128x256xbf16> + return %9 : tensor<128x256xbf16> +} + +func.func @llama2_mlp(%arg0: tensor<1x32x4096xbf16>, %arg1: tensor<4096x4096xbf16>, %arg2: tensor<1x32x4096xbf16>, %arg3: tensor<1xf32>, %arg4: tensor<4096xbf16>, %arg5: tensor<11008x4096xbf16>, %arg6: tensor<11008x4096xbf16>, %arg7: tensor<4096x11008xbf16>, %arg8: tensor<1xf32>, %arg9: tensor<4096xbf16>) -> (tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) { + %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<1x32x4096xbf16> into tensor<32x4096xbf16> + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<32x4096xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %2 = linalg.matmul_transpose_b ins(%collapsed, %arg1 : tensor<32x4096xbf16>, tensor<4096x4096xbf16>) outs(%1 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %expanded = tensor.expand_shape %2 [[0, 1], [2]] output_shape [1, 32, 4096] : tensor<32x4096xbf16> into tensor<1x32x4096xbf16> + %3 = tensor.empty() : tensor<1x32x4096xbf16> + %4 = linalg.add ins(%arg2, %expanded : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%3 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %5 = tensor.empty() : tensor<1x32x4096xf32> + %6 = linalg.copy ins(%4 : tensor<1x32x4096xbf16>) outs(%5 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_0 = arith.constant dense<2.000000e+00> : tensor<1x32x4096xf32> + %7 = tensor.empty() : tensor<1x32x4096xf32> + %8 = linalg.powf ins(%6, %cst_0 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%7 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_1 = arith.constant 0.000000e+00 : f32 + %9 = tensor.empty() : tensor<1x32xf32> + %10 = linalg.fill ins(%cst_1 : f32) outs(%9 : tensor<1x32xf32>) -> tensor<1x32xf32> + %reduced = linalg.reduce ins(%8 : tensor<1x32x4096xf32>) outs(%10 : tensor<1x32xf32>) dimensions = [2] + (%in: f32, %init: f32) { + %64 = arith.addf %in, %init : f32 + linalg.yield %64 : f32 + } + %cst_2 = arith.constant dense<4.096000e+03> : tensor<1x32xf32> + %11 = tensor.empty() : tensor<1x32xf32> + %12 = linalg.div ins(%reduced, %cst_2 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%11 : tensor<1x32xf32>) -> tensor<1x32xf32> + %expanded_3 = tensor.expand_shape %12 [[0], [1, 2]] output_shape [1, 32, 1] : tensor<1x32xf32> into tensor<1x32x1xf32> + %13 = tensor.empty() : tensor<1x32x1xf32> + %broadcasted = linalg.broadcast ins(%arg8 : tensor<1xf32>) outs(%13 : tensor<1x32x1xf32>) dimensions = [0, 1] + %14 = tensor.empty() : tensor<1x32x1xf32> + %15 = linalg.add ins(%expanded_3, %broadcasted : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%14 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %cst_4 = arith.constant dense<-5.000000e-01> : tensor<1x32x1xf32> + %16 = tensor.empty() : tensor<1x32x1xf32> + %17 = linalg.powf 
ins(%15, %cst_4 : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%16 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %collapsed_5 = tensor.collapse_shape %17 [[0], [1, 2]] : tensor<1x32x1xf32> into tensor<1x32xf32> + %18 = tensor.empty() : tensor<1x32x4096xf32> + %broadcasted_6 = linalg.broadcast ins(%collapsed_5 : tensor<1x32xf32>) outs(%18 : tensor<1x32x4096xf32>) dimensions = [2] + %19 = tensor.empty() : tensor<1x32x4096xf32> + %20 = linalg.mul ins(%6, %broadcasted_6 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%19 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %21 = tensor.empty() : tensor<1x32x4096xbf16> + %22 = linalg.copy ins(%20 : tensor<1x32x4096xf32>) outs(%21 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %23 = tensor.empty() : tensor<1x32x4096xbf16> + %broadcasted_7 = linalg.broadcast ins(%arg4 : tensor<4096xbf16>) outs(%23 : tensor<1x32x4096xbf16>) dimensions = [0, 1] + %24 = tensor.empty() : tensor<1x32x4096xbf16> + %25 = linalg.mul ins(%broadcasted_7, %22 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%24 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %collapsed_8 = tensor.collapse_shape %25 [[0, 1], [2]] : tensor<1x32x4096xbf16> into tensor<32x4096xbf16> + %cst_9 = arith.constant 0.000000e+00 : bf16 + %26 = tensor.empty() : tensor<32x11008xbf16> + %27 = linalg.fill ins(%cst_9 : bf16) outs(%26 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %28 = linalg.matmul_transpose_b ins(%collapsed_8, %arg5 : tensor<32x4096xbf16>, tensor<11008x4096xbf16>) outs(%27 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %expanded_10 = tensor.expand_shape %28 [[0, 1], [2]] output_shape [1, 32, 11008] : tensor<32x11008xbf16> into tensor<1x32x11008xbf16> + %29 = tensor.empty() : tensor<1x32x11008xbf16> + %30 = linalgx.sigmoid ins(%expanded_10 : tensor<1x32x11008xbf16>) outs(%29 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %31 = tensor.empty() : tensor<1x32x11008xbf16> + %32 = linalg.mul ins(%30, %expanded_10 : tensor<1x32x11008xbf16>, tensor<1x32x11008xbf16>) outs(%31 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %collapsed_11 = tensor.collapse_shape %25 [[0, 1], [2]] : tensor<1x32x4096xbf16> into tensor<32x4096xbf16> + %cst_12 = arith.constant 0.000000e+00 : bf16 + %33 = tensor.empty() : tensor<32x11008xbf16> + %34 = linalg.fill ins(%cst_12 : bf16) outs(%33 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %35 = linalg.matmul_transpose_b ins(%collapsed_11, %arg6 : tensor<32x4096xbf16>, tensor<11008x4096xbf16>) outs(%34 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %expanded_13 = tensor.expand_shape %35 [[0, 1], [2]] output_shape [1, 32, 11008] : tensor<32x11008xbf16> into tensor<1x32x11008xbf16> + %36 = tensor.empty() : tensor<1x32x11008xbf16> + %37 = linalg.mul ins(%32, %expanded_13 : tensor<1x32x11008xbf16>, tensor<1x32x11008xbf16>) outs(%36 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %collapsed_14 = tensor.collapse_shape %37 [[0, 1], [2]] : tensor<1x32x11008xbf16> into tensor<32x11008xbf16> + %cst_15 = arith.constant 0.000000e+00 : bf16 + %38 = tensor.empty() : tensor<32x4096xbf16> + %39 = linalg.fill ins(%cst_15 : bf16) outs(%38 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %40 = linalg.matmul_transpose_b ins(%collapsed_14, %arg7 : tensor<32x11008xbf16>, tensor<4096x11008xbf16>) outs(%39 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %expanded_16 = tensor.expand_shape %40 [[0, 1], [2]] output_shape [1, 32, 4096] : tensor<32x4096xbf16> into tensor<1x32x4096xbf16> + %41 = tensor.empty() : 
tensor<1x32x4096xbf16> + %42 = linalg.add ins(%4, %expanded_16 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%41 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %43 = tensor.empty() : tensor<1x32x4096xf32> + %44 = linalg.copy ins(%42 : tensor<1x32x4096xbf16>) outs(%43 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_17 = arith.constant dense<2.000000e+00> : tensor<1x32x4096xf32> + %45 = tensor.empty() : tensor<1x32x4096xf32> + %46 = linalg.powf ins(%44, %cst_17 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%45 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_18 = arith.constant 0.000000e+00 : f32 + %47 = tensor.empty() : tensor<1x32xf32> + %48 = linalg.fill ins(%cst_18 : f32) outs(%47 : tensor<1x32xf32>) -> tensor<1x32xf32> + %reduced_19 = linalg.reduce ins(%46 : tensor<1x32x4096xf32>) outs(%48 : tensor<1x32xf32>) dimensions = [2] + (%in: f32, %init: f32) { + %64 = arith.addf %in, %init : f32 + linalg.yield %64 : f32 + } + %cst_20 = arith.constant dense<4.096000e+03> : tensor<1x32xf32> + %49 = tensor.empty() : tensor<1x32xf32> + %50 = linalg.div ins(%reduced_19, %cst_20 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%49 : tensor<1x32xf32>) -> tensor<1x32xf32> + %expanded_21 = tensor.expand_shape %50 [[0], [1, 2]] output_shape [1, 32, 1] : tensor<1x32xf32> into tensor<1x32x1xf32> + %51 = tensor.empty() : tensor<1x32x1xf32> + %broadcasted_22 = linalg.broadcast ins(%arg8 : tensor<1xf32>) outs(%51 : tensor<1x32x1xf32>) dimensions = [0, 1] + %52 = tensor.empty() : tensor<1x32x1xf32> + %53 = linalg.add ins(%expanded_21, %broadcasted_22 : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%52 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %cst_23 = arith.constant dense<-5.000000e-01> : tensor<1x32x1xf32> + %54 = tensor.empty() : tensor<1x32x1xf32> + %55 = linalg.powf ins(%53, %cst_23 : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%54 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %collapsed_24 = tensor.collapse_shape %55 [[0], [1, 2]] : tensor<1x32x1xf32> into tensor<1x32xf32> + %56 = tensor.empty() : tensor<1x32x4096xf32> + %broadcasted_25 = linalg.broadcast ins(%collapsed_24 : tensor<1x32xf32>) outs(%56 : tensor<1x32x4096xf32>) dimensions = [2] + %57 = tensor.empty() : tensor<1x32x4096xf32> + %58 = linalg.mul ins(%44, %broadcasted_25 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%57 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %59 = tensor.empty() : tensor<1x32x4096xbf16> + %60 = linalg.copy ins(%58 : tensor<1x32x4096xf32>) outs(%59 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %61 = tensor.empty() : tensor<1x32x4096xbf16> + %broadcasted_26 = linalg.broadcast ins(%arg9 : tensor<4096xbf16>) outs(%61 : tensor<1x32x4096xbf16>) dimensions = [0, 1] + %62 = tensor.empty() : tensor<1x32x4096xbf16> + %63 = linalg.mul ins(%broadcasted_26, %60 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%62 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + return %63, %42 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16> +} diff --git a/tools/README.md b/tools/README.md new file mode 100644 index 000000000..20d763c6b --- /dev/null +++ b/tools/README.md @@ -0,0 +1,14 @@ +# Python Tools +## Pre-requisites +* Enable python binding +* Install `/tools/requirements.txt` +* Set env +** PYTHONPATH=${BUILD_DIR}/python_packages/gc_mlir_core +** LD_PRELOAD=path/to/libiomp5.so +** MLIR_C_RUNNER_UTILS=${LLVM_INSTALL_DIR}/lib/libmlir_c_runner_utils.so +** MLIR_RUNNER_UTILS=${LLVM_INSTALL_DIR}/lib/libmlir_runner_utils.so + + +##Bench +##Tuning +TODO \ No newline at end of 
file diff --git a/tools/bench.py b/tools/bench.py new file mode 100644 index 000000000..7c3b06102 --- /dev/null +++ b/tools/bench.py @@ -0,0 +1,110 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + +import ctypes +import random +import timeit +from time import sleep +from typing import Sequence + +import numpy as np +from gc_mlir import ir, runtime +from gc_mlir.dialects import arith, func, memref +from gc_mlir.graph_compiler import GraphCompiler +from utils import ( + emit_benchmark_wrapped_main_func, + emit_nano_time, + get_kernel_func_from_module, +) + + +def py_timeit_bench( + ir_module: ir.Module, + entry_name: str, + pipeline: str, + mlir_args: list, + shared_libs: Sequence, + ir_printing=False, + repeat_time=100, + warm_up=20, +) -> float: + + compiler = GraphCompiler( + pipeline, + shared_libs, + ) + compile_begin = timeit.default_timer() + engine = compiler.compile_and_jit(ir_module, ir_printing=ir_printing) + compile_cost = (timeit.default_timer() - compile_begin) * 1000 + + func = engine.lookup(entry_name) + packed_args = (ctypes.c_void_p * len(mlir_args))() + for argNum in range(len(mlir_args)): + packed_args[argNum] = ctypes.cast(mlir_args[argNum], ctypes.c_void_p) + + def run_bench(func, arg): + func(arg) + + timeit.timeit(lambda: run_bench(func, packed_args), number=warm_up) + total_time = timeit.timeit(lambda: run_bench(func, packed_args), number=repeat_time) + execute_cost = total_time * 1000 / repeat_time + return (execute_cost, compile_cost) + + +def mlir_wrapper_bench( + ir_module: ir.Module, + entry_name: str, + pipeline: str, + mlir_args: list, + shared_libs: Sequence, + ir_printing=False, + repeat_time=100, + warm_up=20, +) -> float: + kernel_func = get_kernel_func_from_module(ir_module, entry_name) + + wrapper_module = ir_module + with ir.InsertionPoint(wrapper_module.body): + emit_benchmark_wrapped_main_func(kernel_func, emit_nano_time()) + compiler = GraphCompiler( + pipeline, + shared_libs, + ) + compile_begin = timeit.default_timer() + engine = compiler.compile_and_jit(wrapper_module, ir_printing=ir_printing) + compile_cost = (timeit.default_timer() - compile_begin) * 1000 + + np_timers_ns = np.array([0], dtype=np.int64) + time_arg = ctypes.pointer( + ctypes.pointer(runtime.get_ranked_memref_descriptor(np_timers_ns)) + ) + total_time = 0 + ns_to_ms_scale = 1e-6 + def run(engine_invoke, bench_func_name, *mlir_args): + engine_invoke(bench_func_name, *mlir_args) + + for i in range(repeat_time + warm_up): + run(engine.invoke, "wrapped_main", time_arg, *mlir_args) + if i >= warm_up: + total_time += int(np_timers_ns[0]) * ns_to_ms_scale + execute_cost = total_time / repeat_time + return (execute_cost, compile_cost) + + +# for test +def fake_bench() -> float: + return float(random.randint(1, 100)) diff --git 
a/tools/config_filter.py b/tools/config_filter.py new file mode 100644 index 000000000..cf804ac14 --- /dev/null +++ b/tools/config_filter.py @@ -0,0 +1,82 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + +import math +from abc import ABC, abstractmethod +from typing import List + +import mmh3 + + +class ConfigFilter(ABC): + + @abstractmethod + def already_met(self, v: List[int]) -> bool: + pass + + @abstractmethod + def add(self, v: List[int]): + pass + + @abstractmethod + def save(self): + pass + + def load(self, data): + pass + + +class BloomFilter(ConfigFilter): + def __init__(self, num_samples: int, err_rate: float): + self.num_bits = int(-(num_samples * math.log(err_rate)) / (math.log(2) ** 2)) + self.num_hashes = int((self.num_bits / num_samples) * math.log(2)) + self.bit_array = [0] * self.num_bits + + def already_met(self, v): + for i in range(int(self.num_hashes)): + hash_v = mmh3.hash(v, i) % self.num_bits + if self.bit_array[hash_v] == 0: + return False + return True + + def add(self, v): + for i in range(int(self.num_hashes)): + hash_v = mmh3.hash(v, i) % self.num_bits + self.bit_array[hash_v] = 1 + + def save(self): + return self.bit_array + + def load(self, data): + self.bit_array == data + + +class HashSetFilter(ConfigFilter): + def __init__(self): + self.data = set() + + def add(self, v): + self.data.add(tuple(v)) + + def already_met(self, v: List[int]) -> bool: + return tuple(v) in self.data + + def save(self): + return self.data + + def load(self, data): + self.data = data diff --git a/tools/drivers.py b/tools/drivers.py new file mode 100644 index 000000000..15ca09401 --- /dev/null +++ b/tools/drivers.py @@ -0,0 +1,221 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. 
+# SPDX-License-Identifier: Apache-2.0 +################################################################################ + +import argparse +from abc import ABC, abstractmethod +from typing import List + +import numpy as np +from gc_mlir import ir +from gc_mlir.dialects import func, onednn_graph +from utils import ( + get_default_passes, + get_kernel_func_from_module, + make_tensor, + mlir_type, + to_bool_vector, + to_int_vector, +) + + +class Driver(ABC): + """Abstract class for driver.""" + + @staticmethod + @abstractmethod + def add_args(parser: argparse.ArgumentParser): + pass + + @abstractmethod + def handle_args(self, args: argparse.Namespace): + pass + + def __init__(self, ctx: ir.Context, args: argparse.Namespace): + self.main_entry = "main_entry" + self.handle_args(args) + self.ir_module = self.init_module(ctx) + + @abstractmethod + def init_module(self, ctx: ir.Context) -> ir.Module: + pass + + @abstractmethod + def prepare_np_args(self, disable_results_to_params: False) -> List[np.ndarray]: + pass + + def get_passes(self) -> str: + return get_default_passes() + + +class LoadMLIR(Driver): + @staticmethod + def add_args(parser: argparse.ArgumentParser): + parser.add_argument("--path", type=str, required=True) + parser.add_argument("--entry", type=str, default="main_entry") + + def handle_args(self, args: argparse.Namespace): + self.path = args.path + self.main_entry = args.entry + + def _get_mlir(self): + with open(self.path, "r") as file: + content = file.read() + return content + + def init_module(self, ctx: ir.Context) -> ir.Module: + module = ir.Module.parse(self._get_mlir(), ctx) + return module + + def prepare_np_args(self, disable_results_to_params: False) -> List[np.ndarray]: + bench_func = get_kernel_func_from_module(self.ir_module, self.main_entry) + np_args = [] + idx = 0 + for arg in bench_func.arguments: + print(idx, ":", arg) + if idx in [1, 3]: + np_args.append(make_tensor(arg.type, 1)) + else: + np_args.append(make_tensor(arg.type, 0)) + # idx += 1 + idx = 0 + if not disable_results_to_params: + for res in bench_func.type.results: + print(idx, ":", res) + if idx in [1]: + np_args.append(make_tensor(res, 1)) + else: + np_args.append(make_tensor(res, 0)) + # idx += 1 + # todo : data filling + for i in range(len(np_args)): + np.ndarray.fill(np_args[i], 1) + return np_args + +class MLP(Driver): + @staticmethod + def add_args(parser: argparse.ArgumentParser): + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--hidden_size_list", type=str, default="") + parser.add_argument("--has_bias", type=str, default="") + parser.add_argument("--has_ln", type=str, default="") + parser.add_argument( + "--act_type", type=str, choices=["noop", "relu", "sigmoid"], default="noop" + ) + parser.add_argument( + "--dtype", + type=str, + choices=[ + "f32", + "bf16", + ], + default="f32", + ) + + def handle_args(self, args: argparse.Namespace): + self.batch_size = args.batch_size + assert self.batch_size > 0, "batch size should be greater than 0" + + self.hidden_size_list = to_int_vector(args.hidden_size_list) + layers = len(self.hidden_size_list) - 1 + assert layers >= 1, "hidden_size_list should have at least 2 elements" + + self.has_bias = ( + [False] * layers + if "has_bias" not in args.__dict__ + else to_bool_vector(args.has_bias) + ) + + assert ( + len(self.has_bias) == layers + ), "has_bias should have the same length as hidden_size_list" + + # TODO + self.has_ln = to_bool_vector(args.has_ln) + self.act_type = args.act_type + self.dtype = args.dtype + + 
def init_module(self, ctx: ir.Context) -> ir.Module: + with ctx, ir.Location.unknown(): + layers = len(self.hidden_size_list) - 1 + module = ir.Module.create() + dtype = mlir_type(self.dtype, ctx) + src = ir.RankedTensorType.get( + [self.batch_size, self.hidden_size_list[0]], dtype + ) + weights = [] + bias = [] + for i in range(layers): + weights.append( + ir.RankedTensorType.get( + [ + self.hidden_size_list[i], + self.hidden_size_list[i + 1], + ], + dtype, + ) + ) + if self.has_bias[i]: + bias.append( + ir.RankedTensorType.get([self.hidden_size_list[i + 1]], dtype) + ) + result = ir.RankedTensorType.get( + [ + self.batch_size, + self.hidden_size_list[-1], + ], + dtype, + ) + with ir.InsertionPoint(module.body): + f = func.FuncOp( + name=self.main_entry, + type=ir.FunctionType.get( + inputs=[src] + weights + bias, results=[result] + ), + ) + f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(f.add_entry_block()): + data = f.entry_block.arguments[0] + bias_idx = len(weights) + 1 + for i in range(layers): + weight = f.entry_block.arguments[i + 1] + if self.has_bias[i]: + bias = f.entry_block.arguments[bias_idx] + bias_idx += 1 + else: + bias = None + data = onednn_graph.MatMulOp( + data, + weight, + bias=bias, + transpose_a=False, + transpose_b=False, + ).result + func.ReturnOp([data]) + return module + + def prepare_np_args(self, disable_results_to_params: False) -> List[np.ndarray]: + bench_func = get_kernel_func_from_module(self.ir_module, self.main_entry) + np_args = [] + for arg in bench_func.arguments: + np_args.append(make_tensor(arg.type)) + if not disable_results_to_params: + for res in bench_func.type.results: + np_args.append(make_tensor(res)) + # todo : data filling + for i in range(len(np_args)): + np.ndarray.fill(np_args[i], 1) + return np_args diff --git a/tools/enhanced_np_to_memref.py b/tools/enhanced_np_to_memref.py new file mode 100644 index 000000000..43a7e134d --- /dev/null +++ b/tools/enhanced_np_to_memref.py @@ -0,0 +1,185 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + +# This file contains functions to convert between Memrefs and NumPy arrays and vice-versa. + +import numpy as np +import ctypes + +try: + import ml_dtypes +except ModuleNotFoundError: + # The third-party ml_dtypes provides some optional low precision data-types for NumPy. 
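+    # (e.g. bfloat16). Fall back to None here; to_numpy() below asserts if a BF16 buffer is converted while ml_dtypes is missing.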
+ ml_dtypes = None + + +class C128(ctypes.Structure): + """A ctype representation for MLIR's Double Complex.""" + + _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)] + + +class C64(ctypes.Structure): + """A ctype representation for MLIR's Float Complex.""" + + _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)] + + +class F16(ctypes.Structure): + """A ctype representation for MLIR's Float16.""" + + _fields_ = [("f16", ctypes.c_int16)] + + +class BF16(ctypes.Structure): + """A ctype representation for MLIR's BFloat16.""" + + _fields_ = [("bf16", ctypes.c_int16)] + + +# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype +def as_ctype(dtp): + """Converts dtype to ctype.""" + if dtp == np.dtype(np.complex128): + return C128 + if dtp == np.dtype(np.complex64): + return C64 + if dtp == np.dtype(np.float16): + return F16 + if ml_dtypes is not None and dtp == ml_dtypes.bfloat16: + return BF16 + return np.ctypeslib.as_ctypes_type(dtp) + + +def to_numpy(array): + """Converts ctypes array back to numpy dtype array.""" + if array.dtype == C128: + return array.view("complex128") + if array.dtype == C64: + return array.view("complex64") + if array.dtype == F16: + return array.view("float16") + assert not ( + array.dtype == BF16 and ml_dtypes is None + ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n" + if array.dtype == BF16: + return array.view("bfloat16") + return array + + +def make_nd_memref_descriptor(rank, dtype): + class MemRefDescriptor(ctypes.Structure): + """Builds an empty descriptor for the given rank/dtype, where rank>0.""" + + _fields_ = [ + ("allocated", ctypes.c_longlong), + ("aligned", ctypes.POINTER(dtype)), + ("offset", ctypes.c_longlong), + ("shape", ctypes.c_longlong * rank), + ("strides", ctypes.c_longlong * rank), + ] + + return MemRefDescriptor + + +def make_zero_d_memref_descriptor(dtype): + class MemRefDescriptor(ctypes.Structure): + """Builds an empty descriptor for the given dtype, where rank=0.""" + + _fields_ = [ + ("allocated", ctypes.c_longlong), + ("aligned", ctypes.POINTER(dtype)), + ("offset", ctypes.c_longlong), + ] + + return MemRefDescriptor + + +class UnrankedMemRefDescriptor(ctypes.Structure): + """Creates a ctype struct for memref descriptor""" + + _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)] + + +def get_ranked_memref_descriptor(nparray): + """Returns a ranked memref descriptor for the given numpy array.""" + ctp = as_ctype(nparray.dtype) + if nparray.ndim == 0: + x = make_zero_d_memref_descriptor(ctp)() + x.allocated = nparray.ctypes.data + x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp)) + x.offset = ctypes.c_longlong(0) + return x + + x = make_nd_memref_descriptor(nparray.ndim, ctp)() + x.allocated = nparray.ctypes.data + x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp)) + x.offset = ctypes.c_longlong(0) + x.shape = nparray.ctypes.shape + + # Numpy uses byte quantities to express strides, MLIR OTOH uses the + # torch abstraction which specifies strides in terms of elements. 
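+    # For example, a contiguous float32 array of shape (4, 5) has NumPy byte strides (20, 4);
+    # dividing by the 4-byte itemsize gives the element strides (5, 1) stored in the descriptor.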
+ strides_ctype_t = ctypes.c_longlong * nparray.ndim + x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides]) + return x + + +def get_unranked_memref_descriptor(nparray): + """Returns a generic/unranked memref descriptor for the given numpy array.""" + d = UnrankedMemRefDescriptor() + d.rank = nparray.ndim + x = get_ranked_memref_descriptor(nparray) + d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p) + return d + + +def move_aligned_ptr_by_offset(aligned_ptr, offset): + """Moves the supplied ctypes pointer ahead by `offset` elements.""" + aligned_addr = ctypes.addressof(aligned_ptr.contents) + elem_size = ctypes.sizeof(aligned_ptr.contents) + shift = offset * elem_size + content_ptr = ctypes.cast(aligned_addr + shift, type(aligned_ptr)) + return content_ptr + + +def unranked_memref_to_numpy(unranked_memref, np_dtype): + """Converts unranked memrefs to numpy arrays.""" + ctp = as_ctype(np_dtype) + descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp) + val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor)) + content_ptr = move_aligned_ptr_by_offset(val[0].aligned, val[0].offset) + np_arr = np.ctypeslib.as_array(content_ptr, shape=val[0].shape) + strided_arr = np.lib.stride_tricks.as_strided( + np_arr, + np.ctypeslib.as_array(val[0].shape), + np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize, + ) + return to_numpy(strided_arr) + + +def ranked_memref_to_numpy(ranked_memref): + """Converts ranked memrefs to numpy arrays.""" + content_ptr = move_aligned_ptr_by_offset( + ranked_memref[0].aligned, ranked_memref[0].offset + ) + np_arr = np.ctypeslib.as_array(content_ptr, shape=ranked_memref[0].shape) + strided_arr = np.lib.stride_tricks.as_strided( + np_arr, + np.ctypeslib.as_array(ranked_memref[0].shape), + np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize, + ) + return to_numpy(strided_arr) \ No newline at end of file diff --git a/tools/example/simple_test.py b/tools/example/simple_test.py new file mode 100644 index 000000000..35d3c41e5 --- /dev/null +++ b/tools/example/simple_test.py @@ -0,0 +1,87 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. 
+# SPDX-License-Identifier: Apache-2.0 +################################################################################ + +import os +import sys + +import numpy as np +from gc_mlir import ir +from gc_mlir.graph_compiler import GraphCompiler +from numpy.testing import assert_allclose + +project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if project_dir not in sys.path: + sys.path.insert(0, project_dir) + +import ml_dtypes +import torch +from bench import py_timeit_bench +from enhanced_np_to_memref import ranked_memref_to_numpy +from utils import get_mlir_args + +if __name__ == "__main__": + with ir.Context() as ctx: + module = ir.Module.parse( + """ + module { + func.func @main_entry(%arg0:tensor<10x10xbf16>, %arg1:tensor<10x10xbf16>) -> tensor<10x10xbf16> attributes {llvm.emit_c_interface} { + %0 = onednn_graph.matmul %arg0, %arg1: (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xbf16> + return %0:tensor<10x10xbf16> + } + } + """ + ) + torch_arg0 = torch.full((10, 10), 1.0, dtype=torch.bfloat16) + torch_arg1 = torch.full((10, 10), 1.0, dtype=torch.bfloat16) + # torch_arg0 = torch.randn((10, 10), dtype=torch.bfloat16) + # torch_arg1 = torch.randn((10, 10), dtype=torch.bfloat16) + ref_res = torch.matmul(torch_arg0, torch_arg1) + + np_arg0 = torch_arg0.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) + np_arg1 = torch_arg1.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) + gc_res = np.zeros((10, 10), dtype=ml_dtypes.bfloat16) + + entry = "main_entry" + mlir_args = get_mlir_args(module, entry, [np_arg0, np_arg1, gc_res]) + passes = "any(gc-cpu-pipeline)" + shared_libs = [ + os.environ["MLIR_C_RUNNER_UTILS"], + os.environ["MLIR_RUNNER_UTILS"], + ] + + # bench + # _, cost = py_timeit_bench( + # module, + # "main_entry", + # passes, + # mlir_args, + # shared_libs, + # ) + # print("cost=", cost) + + # just run + compiler = GraphCompiler(passes, shared_libs) + engine = compiler.compile_and_jit(module) + engine.invoke(entry, *mlir_args) + + print(gc_res) + assert_allclose( + gc_res.astype(np.float32), + ref_res.to(torch.float32).numpy(), + rtol=1e-5, + atol=0, + ) diff --git a/tools/main.py b/tools/main.py new file mode 100644 index 000000000..3d58c795a --- /dev/null +++ b/tools/main.py @@ -0,0 +1,165 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. 
+# SPDX-License-Identifier: Apache-2.0 +################################################################################ + +import argparse +import json +import os +from timeit import repeat, timeit + +import numpy as np +from bench import * +from drivers import * +from gc_mlir import ir +from gc_mlir import runtime as rt +from gc_mlir.passmanager import * +from utils import get_mlir_args +from tuner import * + + +def get_driver_clz(diver_str: str): + clz = {"mlp": MLP, "load_mlir": LoadMLIR}[diver_str] + return clz + + +def add_driver_args(parser): + driver = parser.parse_known_args()[0].driver + get_driver_clz(driver).add_args(parser) + + +def do_bench(args): + with ir.Context() as ctx, ir.Location.unknown(): + driver_clz = get_driver_clz(args.driver) + driver = driver_clz(ctx, args) + if args.print_ir: + ctx.enable_multithreading(False) + np_args = driver.prepare_np_args(args.disable_results_to_params) + + # TODO need data filling + # for test, fill all data with 1 + for i in range(len(np_args)): + np.ndarray.fill(np_args[i], 1) + + mlir_args = get_mlir_args( + driver.ir_module, driver.main_entry, np_args, args.disable_results_to_params + ) + + print("===========bench func name: ", driver.main_entry, "===========") + print(driver.ir_module) + bench_alg = py_timeit_bench if args.bench_alg == "py" else mlir_wrapper_bench + execute_cost, compile_cost = bench_alg( + driver.ir_module, + driver.main_entry, + driver.get_passes(), + mlir_args, + [os.environ["MLIR_C_RUNNER_UTILS"], os.environ["MLIR_RUNNER_UTILS"]], + args.print_ir, + args.repeat, + args.warm_up, + ) + print("===========bench result===========") + json_res = json.dumps( + { + "args": vars(args), + "compile_cost": compile_cost, + "execute_cost": execute_cost, + }, + indent=4, + ) + print(json_res) + +def do_tune(args): + with ir.Context() as ctx, ir.Location.unknown(): + ctx.allow_unregistered_dialects = True + driver_clz = get_driver_clz(args.driver) + driver = driver_clz(ctx, args) + if args.print_ir: + ctx.enable_multithreading(False) + # todo (data filling) + np_args = driver.prepare_np_args(args.disable_results_to_params) + # TODO need data filling + # for test, fill all data with 1 + for i in range(len(np_args)): + np.ndarray.fill(np_args[i], 1) + + mlir_args = get_mlir_args( + driver.ir_module, driver.main_entry, np_args, args.disable_results_to_params + ) + + bench_alg = py_timeit_bench if args.bench_alg == "py" else mlir_wrapper_bench + tuner_bench = lambda ir_moudle: bench_alg( + ir_moudle, + driver.main_entry, + driver.get_passes(), + mlir_args, + [os.environ["MLIR_C_RUNNER_UTILS"], os.environ["MLIR_RUNNER_UTILS"]], + args.print_ir, + repeat_time=1, + warm_up=1, + ) + + space = TuningSpace(driver.ir_module) + if args.search_alg == "grid": + tuner = GridTuner( + tuner_bench, + space, + args.tuning_batch, + args.early_stop, + args.checkpoint_path, + ) + else: + tuner = GATuner( + tuner_bench, + space, + args.tuning_batch, + args.early_stop, + args.checkpoint_path, + ) + tuner.run(args.tuning_times, args.timeout) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--type", type=str, choices=["bench", "tune"], default="bench") + parser.add_argument( + "--driver", type=str, choices=["load_mlir", "mlp"], required=True + ) + parser.add_argument( + "--disable_results_to_params", action="store_true", default=False + ) + add_driver_args(parser) + parser.add_argument( + "--bench_alg", type=str, choices=["py", "wrapper"], default="py" + ) + parser.add_argument("-p", "--print_ir", 
action="store_true") + + if parser.parse_known_args()[0].type == "bench": + parser.add_argument("--warm_up", type=int, default=20) + parser.add_argument("--repeat", type=int, default=100) + args = parser.parse_args() + do_bench(args) + else: + parser.add_argument( + "--search_alg", type=str, choices=["grid", "ga"], default="ga" + ) + parser.add_argument("--tuning_batch", type=int, default=50) + parser.add_argument("--early_stop", type=int, default=-1) + parser.add_argument("--tuning_times", type=int, default=100) + parser.add_argument("--timeout", type=int, default=-1) + parser.add_argument("--space_percent", type=float, default=1.0) + parser.add_argument("--checkpoint_path", type=str, default="") + args = parser.parse_args() + do_tune(args) diff --git a/tools/op_config.py b/tools/op_config.py new file mode 100644 index 000000000..53835ee14 --- /dev/null +++ b/tools/op_config.py @@ -0,0 +1,116 @@ +import json +import os + +from gc_mlir.dialects import onednn_graph +from gc_mlir.dialects._ods_common import OpView +from gc_mlir.extras import types as T +from gc_mlir.ir import IntegerAttr + + +class Config: + def __init__(self): + self.field_candidates = {} + self.field_constraints = {} + self.init_candidates() + self.init_constraints() + + def init_candidates(self): + pass + + def init_constraints(self): + pass + + def attach_to_ir(self, op: OpView): + pass + +class MatMulConfig(Config): + def __init__( + self, + M_threads: int = 1, + K_threads: int = 1, + N_threads: int = 1, + M_block: int = 1, + K_block: int = 1, + N_block: int = 1, + innermostM_block: int = 1, + innermostK_block: int = 1, + innermostN_block: int = 1, + ): + super().__init__() + self.M_threads = M_threads + self.K_threads = K_threads + self.N_threads = N_threads + self.M_block = M_block + self.K_block = K_block + self.N_block = N_block + self.innermostM_block = innermostM_block + self.innermostK_block = innermostK_block + self.innermostN_block = innermostN_block + + def __init__(self, op: OpView): + super().__init__() + assert isinstance(op, onednn_graph.MatMulOp) + # you can set the default value by matmul_op + # cpu_counts = os.cpu_count() + # self.input_a_shape = op.input_a.type.shape + # self.input_b_shape = op.input_b.type.shape + # self.input_a_dtype = op.input_a.type.element_type + # print(self.input_a_shape, self.input_a_dtype) + self.M_threads = 1 + self.K_threads = 1 + self.N_threads = 1 + self.M_block = 1 + self.K_block = 1 + self.N_block = 1 + self.innermostM_block = 1 + self.innermostK_block = 1 + self.innermostN_block = 1 + + def init_candidates(self): + # you can set the candidates by info form matmul op + self.field_candidates["M_block"] = [16, 32] + self.field_candidates["K_block"] = [16, 32, 64] + self.field_candidates["N_block"] = [16] + + def init_constraints(self): + # example: using lambda to add constraints + # self.field_constraints["K_block"] = ( + # lambda MatMulConfig, K_block: MatMulConfig.M_block <= K_block + # ) + self.field_constraints["M_block"] = None + self.field_constraints["K_block"] = None + self.field_constraints["N_block"] = None + + def attach_to_ir(self, op: OpView): + assert isinstance(op, onednn_graph.MatMulOp) + attr_to_field = { + "Mthreads": self.M_threads, + "Kthreads": self.K_threads, + "Nthreads": self.N_threads, + "MBlock": self.M_block, + "KBlock": self.K_block, + "NBlock": self.N_block, + "innermostMBlock": self.innermostM_block, + "innermostKBlock": self.innermostK_block, + "innermostNBlock": self.innermostN_block, + } + for name, value in attr_to_field.items(): + 
op.attributes[name] = IntegerAttr.get(T.i32(), value) + def __repr__(self) -> str: + return self.__str__() + + def __str__(self) -> str: + obj_dict = { + "MatMulConfig": { + "M_threads": self.M_threads, + "K_threads": self.K_threads, + "N_threads": self.N_threads, + "M_block": self.M_block, + "K_block": self.K_block, + "N_block": self.N_block, + "innermostM_block": self.innermostM_block, + "innermostK_block": self.innermostK_block, + "innermostN_block": self.innermostN_block, + } + } + return json.dumps(obj_dict, indent=4) diff --git a/tools/requirements.txt b/tools/requirements.txt new file mode 100644 index 000000000..3a942605e --- /dev/null +++ b/tools/requirements.txt @@ -0,0 +1 @@ +ml_dtypes \ No newline at end of file diff --git a/tools/tuner.py b/tools/tuner.py new file mode 100644 index 000000000..39e4f2c5c --- /dev/null +++ b/tools/tuner.py @@ -0,0 +1,492 @@ +from copy import deepcopy +import os +import sys +import random +from functools import reduce +import time +from config_filter import * +from op_config import * +from gc_mlir import ir +import utils +import json +from abc import ABC, abstractmethod +from typing import List + +need_print = False + + +class TuningSpace: + def __init__(self, ir_module: ir.Module): + self.initial_ir = ir_module + self.graph_config = utils.gen_configs_from_ir(ir_module) + self.space_size = 1 + self.flatten_candidates = [] + self.flatten_field_name = [] + self.flatten_constraints = [] + self.ind_candidate_to_config = {} + candidate_ind = 0 + for config_ind, config in enumerate(self.graph_config): + for field_name, candidates in config.field_candidates.items(): + self.space_size = self.space_size * len(candidates) + self.flatten_candidates.append(candidates) + self.flatten_field_name.append(field_name) + self.flatten_constraints.append(config.field_constraints[field_name]) + self.ind_candidate_to_config[candidate_ind] = config_ind + candidate_ind += 1 + + def make_config_from_indexes(self, indexes: List[int]): + graph_config = deepcopy(self.graph_config) + for cid, candidate in enumerate(self.flatten_candidates): + val = candidate[indexes[cid]] + config = graph_config[self.ind_candidate_to_config[cid]] + field_name = self.flatten_field_name[cid] + setattr(config, field_name, val) + return graph_config + + def get_cur_config(self, candidate_ind): + return self.graph_config[self.ind_candidate_to_config[candidate_ind]] + + def verify_config(self, candidate_idx, val) -> bool: + config = self.get_cur_config(candidate_idx) + field_name = self.flatten_field_name[candidate_idx] + constraint = self.flatten_constraints[candidate_idx] + val = self.flatten_candidates[candidate_idx][val] + setattr(config, field_name, val) + if constraint: + return constraint(config, val) + return True + + def filter_next_candidates(self, candidate_idx, val) -> List[int]: + field_name = self.flatten_field_name[candidate_idx] + config = self.get_cur_config(candidate_idx) + setattr( + config, + field_name, + self.flatten_candidates[candidate_idx][val], + ) + if (candidate_idx + 1) >= len(self.flatten_candidates): + return [] + constraint = self.flatten_constraints[candidate_idx + 1] + if constraint: + next_candidates = self.flatten_candidates[candidate_idx + 1] + return [ + index + for index, value in enumerate(next_candidates) + if constraint(config, value) + ] + else: + return list(range(len(self.flatten_candidates[candidate_idx + 1]))) + + +class Tuner(ABC): + def __init__( + self, + executor, + tunning_space: TuningSpace, + batch_size=50, + early_stop=-1, + checkpoint="", + ): 
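+ # executor is a callable taking an ir.Module and returning a (result, cost) pair;
+ # batch_size configs are measured per iteration, and early_stop (disabled when
+ # non-positive) is the number of iterations tolerated without improving best_cost.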
+ self.executor = executor + self.batch_size = batch_size + self.early_stop = early_stop + self.best_cost = sys.float_info.max + self.best = [] + self.iter = 0 + self.last_update_iter = 0 + self.skipped_num = 0 + self.tunning_space = tunning_space + self.checkpoint = checkpoint + if self.checkpoint: + os.makedirs(os.path.dirname(self.checkpoint), exist_ok=True) + + def tuner_update(self, config_indices_batch: List[List[int]], costs: List[float]): + if min(costs) < self.best_cost: + self.best_cost = min(costs) + self.best = config_indices_batch[costs.index(min(costs))] + if self.checkpoint: + self.save_status() + + @abstractmethod + def get_next_config_indices_batch(self) -> List[List[int]]: + pass + + @abstractmethod + def load_status(self): + pass + + @abstractmethod + def save_status(self): + pass + + def tuner_finish(self, tuning_time): + print("Tuning ends in", tuning_time, "s") + best_config = self.tunning_space.make_config_from_indexes(self.best) + print("Best cost:", self.best_cost, "ms") + print("Best config:", best_config) + utils.attach_configs_to_ir(self.tunning_space.initial_ir, best_config), + print( + "mlir:\n", + self.tunning_space.initial_ir, + ) + + def run(self, max_iter: int, timeout: int = -1): + if self.early_stop > 0 and self.iter - self.last_update_iter > self.early_stop: + # in case of resuming from a saved state and it has already + # early-stopped + print("Early stop now") + return + start_time = time.time() + spaces_size = self.tunning_space.space_size + while self.iter < max_iter and self.iter < spaces_size: + config_indices_batch = self.get_next_config_indices_batch() + if not config_indices_batch: + print("Tuner returns empty batch, early stop now") + break + if len(config_indices_batch) > min( + max_iter - self.iter, spaces_size - self.iter + ): + config_indices_batch = config_indices_batch[ + : min(max_iter - self.iter, spaces_size - self.iter) + ] + + old_iter = self.iter + self.iter += len(config_indices_batch) + if need_print: + print("config_indices_batch:", config_indices_batch) + perf_result = [] + for config_indexes in config_indices_batch: + real_config = self.tunning_space.make_config_from_indexes( + config_indexes + ) + # todo : ir.Module can not support deepcopy + new_ir = ir.Module.parse( + str(self.tunning_space.initial_ir), + self.tunning_space.initial_ir.context, + ) + utils.attach_configs_to_ir(new_ir, real_config) + _, cost = self.executor(new_ir) + perf_result.append(cost) + + print( + "[", + self.iter, + "/", + max_iter, + "] skipped:", + self.skipped_num, + "best:", + self.best_cost, + "ms", + ) + old_best = self.best_cost + self.tuner_update(config_indices_batch, perf_result) + if self.best_cost != old_best: + self.last_update_iter = old_iter + else: + if ( + self.early_stop > 0 + and old_iter - self.last_update_iter > self.early_stop + ): + print("Early stop now") + break + if timeout >= 0 and time.time() - start_time > timeout: + print("Tuning timeout...") + break + self.tuner_finish(time.time() - start_time) + + +class GridTuner(Tuner): + def __init__( + self, + executor, + tunning_space: TuningSpace, + batch_size, + early_stop, + checkpoint="", + ): + super().__init__(executor, tunning_space, batch_size, early_stop, checkpoint) + self.current_idx = 0 + self.cumulative_size = [1] * len(self.tunning_space.flatten_candidates) + self.cumulative_size[-1] = 1 + for i in range(len(self.cumulative_size) - 2, -1, -1): + self.cumulative_size[i] = self.cumulative_size[i + 1] * len( + self.tunning_space.flatten_candidates[i + 1] + ) + if 
self.checkpoint: + self.load_status() + + def get_next_config_indices_batch(self) -> list: + config_indices_batch = [] + while len(config_indices_batch) < self.batch_size: + if self.current_idx >= self.tunning_space.space_size: + break + config_ids = [-1] * len(self.tunning_space.flatten_candidates) + remain = self.current_idx + valid_config_idx = True + for j in range(len(config_ids)): + config_ids[j] = remain // self.cumulative_size[j] + valid_config_idx = self.tunning_space.verify_config(j, config_ids[j]) + if not valid_config_idx: + break + remain = remain % self.cumulative_size[j] + self.current_idx = self.current_idx + 1 + if valid_config_idx: + config_indices_batch.append(config_ids) + if need_print: + print(self.tunning_space.make_config_from_indexes(config_ids)) + else: + self.skipped_num += 1 + print("bad config, skip") + return config_indices_batch + + def save_status(self): + save_dict = { + "iter": self.iter, + "last_update_iter": self.last_update_iter, + "best": self.best, + "best_cost": self.best_cost, + "current_idx": self.current_idx, + "skipped_num": self.skipped_num, + } + with open(self.checkpoint, "w") as file: + json.dump(save_dict, file, indent=4) + + def load_status(self): + print("continue tuning from checkpoint...") + with open( + self.checkpoint, + "r", + ) as file: + try: + data = json.load(file) + assert set( + [ + "iter", + "last_update_iter", + "best", + "best_cost", + "current_idx", + "skipped_num", + ] + ) == set(data.keys()) + self.iter = data["iter"] + self.last_update_iter = data["last_update_iter"] + self.best = data["best"] + self.best_cost = data["best_cost"] + self.current_idx = data["current_idx"] + self.skipped_num = data["skipped_num"] + except Exception as e: + print("load checkpoint failed", e) + + +class GATuner(Tuner): + def __init__( + self, + executor, + tuning_space, + pop_size=100, + early_stop=-1, + checkpoint="", + elite_num: int = 9, + mutation_prob: float = 0.1, + random_seed: int = 0, + expected_tune_num: int = 0, + ): + super().__init__(executor, tuning_space, pop_size, early_stop, checkpoint) + self.elite_num = min(elite_num, pop_size) + self.mutation_prob = mutation_prob + self.pop_size = pop_size + self.cur_mutation_prob = mutation_prob + self.prev_result = [] + self.elites = [] + if expected_tune_num == 0: + self.filter = HashSetFilter() + else: + self.filter = BloomFilter(expected_tune_num) + + self.candidate_indices = [[]] * len(self.tunning_space.flatten_candidates) + self.candidate_indices[0] = list( + range(len(self.tunning_space.flatten_candidates[0])) + ) + + def save_status(self): + save_dict = { + "iter": self.iter, + "last_update_iter": self.last_update_iter, + "best": self.best, + "best_cost": self.best_cost, + "skipped_num": self.skipped_num, + "cur_mutation_prob": self.cur_mutation_prob, + "prev_result": self.prev_result, + "elites": self.elites, + "tuned": self.filter.save(), + } + return super().save_status() + + def load_status(self): + return super().load_status() + + def set_field(self, gene, idx, val): + gene[idx] = val + self.update_candidate_indices(idx, val) + + def update_candidate_indices(self, idx, val): + next_candidates = self.tunning_space.filter_next_candidates(idx, val) + if idx + 1 < len(self.candidate_indices): + self.candidate_indices[idx + 1] = next_candidates + + @staticmethod + def update_mutation_prob(prob, lower_bound, move_up): + if move_up: + prob = min(prob * 1.01, 0.5) + else: + prob = max(prob * 0.98, lower_bound) + return prob + + @staticmethod + def random_choice(prob_range) -> int: + 
random_val = random.randint(0, sys.maxsize) / sys.maxsize + for i in range(len(prob_range)): + if random_val <= prob_range[i]: + return i + return -1 + + def push_to_tune(self, to_tune, gene) -> bool: + if self.filter.already_met(gene): + self.cur_mutation_prob = GATuner.update_mutation_prob( + self.cur_mutation_prob, self.mutation_prob, True + ) + return False + if gene in to_tune: + self.cur_mutation_prob = GATuner.update_mutation_prob( + self.cur_mutation_prob, self.mutation_prob, True + ) + return False + + to_tune.append(gene) + self.cur_mutation_prob = GATuner.update_mutation_prob( + self.cur_mutation_prob, self.mutation_prob, False + ) + return True + + def get_next_config_indices_batch(self) -> list: + prob_range = [0.0] * len(self.prev_result) + total_score = 0 + for i in range(len(self.prev_result)): + total_score += self.prev_result[i][1] + prob_range[i] = total_score + prob_range = [x / total_score for x in prob_range] + to_tune = [] + for i in range(self.pop_size): + self.get_next_config(prob_range, to_tune) + + if need_print: + print("to_tune", to_tune) + for i in range(len(to_tune)): + print(self.tunning_space.make_config_from_indexes(to_tune[i])) + + if len(to_tune) < self.pop_size: + print( + f"GA Cannot generate enough unmet genes in this batch (batch_size={self.pop_size})" + ) + return to_tune + + def get_next_config(self, prob_range, to_tune): + max_tries = 20 + try_cnt = 0 + while try_cnt < max_tries: + try_cnt += 1 + if not self.elites: + gene = [-1] * len(self.tunning_space.flatten_candidates) + need_repo = True + redo_cnt = 0 + while redo_cnt < 50 and need_repo: + need_repo = False + for j in range(len(gene)): + # try to randomly pick one candidate + data, success = GATuner.random_item_from( + self.candidate_indices[j] + ) + if not success: + need_repo = True + break + else: + self.set_field(gene, j, data) + redo_cnt += 1 + if need_repo: + print("Cannot create a valid random gene") + if self.push_to_tune(to_tune, gene): + return + else: + assert len(self.prev_result) > 0 + # print("len(prob_range) = ", len(prob_range)) + if len(prob_range) == 1: + return + gene_size = len(self.tunning_space.flatten_candidates) + first_gene = GATuner.random_choice(prob_range) + second_gene = GATuner.random_choice(prob_range) + while second_gene == first_gene: + second_gene = GATuner.random_choice(prob_range) + + joint_point = random.randint(0, gene_size) + + new_gene = [-1] * gene_size + need_redo = False + for j in range(gene_size): + candidates = self.candidate_indices[j] + if not candidates: + need_redo = True + continue + if ( + random.randint(0, sys.maxsize) / sys.maxsize + ) < self.cur_mutation_prob: + self.set_field( + new_gene, j, GATuner.random_item_from(candidates)[0] + ) + else: + # inherit from parents + left_gene = self.prev_result[first_gene][0][j] + right_gene = self.prev_result[second_gene][0][j] + if j < joint_point: + prefered_gene = left_gene + unprefered_gene = right_gene + else: + prefered_gene = right_gene + unprefered_gene = left_gene + + if prefered_gene in candidates: + self.set_field(new_gene, j, prefered_gene) + elif unprefered_gene in candidates: + self.set_field(new_gene, j, unprefered_gene) + else: + self.set_field( + new_gene, j, GATuner.random_item_from(candidates)[0] + ) + if need_redo: + print("need_redo") + continue + + if self.push_to_tune(to_tune, new_gene): + return + + def tuner_update( + self, config_indices_batch: List[List[int]], perf_result: List[float] + ): + super().tuner_update(config_indices_batch, perf_result) + self.prev_result.clear() 
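+ # Record each measured gene with its fitness (the reciprocal of its cost) so that
+ # cheaper configs get a larger share of the roulette-wheel probabilities built in
+ # get_next_config_indices_batch(); surviving elites are appended and re-ranked below.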
+ for i in range(len(config_indices_batch)): + self.filter.add(config_indices_batch[i]) + self.prev_result.append((config_indices_batch[i], 1 / perf_result[i])) + + for elite in self.elites: + self.prev_result.append(elite) + self.elites = sorted(self.prev_result, key=lambda x: x[1], reverse=True)[ + : self.elite_num + ] + + @staticmethod + def random_item_from(v: List[int]): + if not v: + return 0, False + return v[random.randint(0, len(v) - 1)], True diff --git a/tools/utils.py b/tools/utils.py new file mode 100644 index 000000000..ab8ee708e --- /dev/null +++ b/tools/utils.py @@ -0,0 +1,237 @@ +import ctypes +from typing import List + +import ml_dtypes +import numpy as np +from enhanced_np_to_memref import ( + BF16, + get_ranked_memref_descriptor, + make_nd_memref_descriptor, +) +from gc_mlir import ir +from gc_mlir.dialects import arith, func, memref, onednn_graph +from op_config import * + +# Load libnuma +libnuma = ctypes.CDLL("libnuma.so.1") + +# Define numa_alloc_onnode function +libnuma.numa_alloc_onnode.restype = ctypes.c_void_p +libnuma.numa_alloc_onnode.argtypes = [ctypes.c_size_t, ctypes.c_int] + +# Define numa_free function +libnuma.numa_free.argtypes = [ctypes.c_void_p, ctypes.c_size_t] + +MLIR_TYPE_TO_NUMPY_TYPE = { + "bf16": ml_dtypes.bfloat16, + "f32": np.float32, + "f64": np.float64, + "i8": np.int8, + "i32": np.int32, + "i64": np.int64, +} + +MLIR_TYPE_TO_C_TYPE = { + "f32": ctypes.c_float, + "f64": ctypes.c_double, + "i32": ctypes.c_int, + "i8": ctypes.c_byte, + "bf16": BF16, +} + + +def emit_nano_time() -> func.FuncOp: + nanoTime = func.FuncOp( + "nanoTime", ([], [ir.IntegerType.get_signless(64)]), visibility="private" + ) + nanoTime.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + return nanoTime + + +def emit_benchmark_wrapped_main_func( + kernel_func: func.FuncOp, timer_func: func.FuncOp +) -> func.FuncOp: + memref_of_i64_type = ir.MemRefType.get([1], ir.IntegerType.get_signless(64)) + wrapped_func_name = "wrapped_main" + assert wrapped_func_name != str( + kernel_func.name + ), "wrapped function name should be different from kernel function name" + wrapped_func = func.FuncOp( + wrapped_func_name, + ([memref_of_i64_type] + kernel_func.arguments.types, kernel_func.type.results), + visibility="public", + ) + wrapped_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(wrapped_func.add_entry_block()): + timer_buffer = wrapped_func.arguments[0] + start = func.CallOp(timer_func, []) + call_op = func.CallOp( + kernel_func, + list(wrapped_func.arguments[1:]), + ) + end = func.CallOp(timer_func, []) + time_taken = arith.SubIOp(end, start) + zero = arith.ConstantOp.create_index(0) + memref.StoreOp(time_taken, timer_buffer, [zero]) + func.ReturnOp(call_op.results) + return wrapped_func + + + + +def np_args_to_mlir_args(np_args: List[np.ndarray]) -> List: + mlir_args = [] + for arg in np_args: + mlir_args.append( + ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg))) + ) + return mlir_args + + +def get_mlir_args( + module: ir.Module, + entry: str, + np_args: List[np.ndarray], + disable_results_to_params=False, +): + f = get_kernel_func_from_module(module, entry) + compiled_func_args = [] + if disable_results_to_params: + assert len(np_args) == len(f.arguments), "input args mismatch" + for res in f.type.results: + compiled_func_args.append( + ctypes.pointer( + ctypes.pointer( + make_nd_memref_descriptor( + len(res.shape), MLIR_TYPE_TO_C_TYPE[str(res.element_type)] + )() + ) + ) + ) + for arg in np_args: + 
compiled_func_args.append( + ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg))) + ) + return compiled_func_args + + +def mlir_type(s, ctx): + type_mapping = { + "f32": ir.F32Type.get(ctx), + "f64": ir.F64Type.get(ctx), + "bf16": ir.BF16Type.get(ctx), + "i32": ir.IntegerType.get_signed(32), + "i8": ir.IntegerType.get_signed(8), + } + return type_mapping[s] + + +def make_tensor(tensor_type, numa_node = 1): + # return np.zeros( + # tensor_type.shape, MLIR_TYPE_TO_NUMPY_TYPE[str(tensor_type.element_type)] + # ) + shape = tensor_type.shape + element_type = MLIR_TYPE_TO_NUMPY_TYPE[str(tensor_type.element_type)] + dtype = np.dtype(element_type) + # Calculate the total size of the tensor in bytes + tensor_size = np.prod(shape) * dtype.itemsize + + # Allocate memory on the specified NUMA node + buffer_addr = libnuma.numa_alloc_onnode(tensor_size, numa_node) + if not buffer_addr: + raise MemoryError(f"Failed to allocate memory on NUMA node {numa_node}") + + # Cast buffer_addr to the correct pointer type + buffer_pointer = ctypes.cast(buffer_addr, ctypes.POINTER(ctypes.c_float)) + # Create numpy array pointing to allocated memory + tensor = np.ctypeslib.as_array(buffer_pointer, shape=shape) + + # Check if the actual buffer size is sufficient + actual_buffer_size = ctypes.sizeof(ctypes.c_char) * tensor.nbytes + if actual_buffer_size < tensor_size: + raise ValueError(f"Buffer size {actual_buffer_size} is smaller than required size {tensor_size}") + + return tensor + +def get_kernel_func_from_module( + module: ir.Module, func_name: str = "main_entry" +) -> func.FuncOp: + assert ( + len(module.operation.regions) == 1 + ), "Expected kernel module to have only one region" + assert ( + len(module.operation.regions[0].blocks) == 1 + ), "Expected kernel module to have only one block" + for f in module.operation.regions[0].blocks[0].operations: + if type(f) is func.FuncOp and str(f.name).strip('"') == func_name: + return f + raise ValueError("can not find the entry function") + + +def get_default_passes(): + passes = """ + any(gc-cpu-pipeline) + """ + return passes + + +def to_int_vector(s: str) -> List[int]: + if not s or len(s) == 0: + return [] + return [int(i) for i in s.strip().split("x")] + + +def to_bool_vector(s: str) -> List[bool]: + if not s or len(s) == 0: + return [] + return [bool(i) for i in s.strip().split("x")] + + +def load_mlir_from_path(path: str) -> str: + with open(path, "r") as file: + content = file.read() + return content + + +def walk_operations(op: ir.Operation, callback=None): + for region in op.regions: + for block in region: + for child_op in block: + if callback: + callback(child_op) + walk_operations(child_op, callback) + + +def get_all_tunable_ops(op: ir.Operation): + tunable_ops = [] + for region in op.regions: + for block in region: + for child_op in block: + if ( + "skipTuner" in child_op.attributes + and child_op.attributes["skipTuner"] + ): + continue + if child_op.name == "onednn_graph.matmul": + tunable_ops.append(child_op) + tunable_ops = tunable_ops + get_all_tunable_ops(child_op) + return tunable_ops + + +def gen_configs_from_ir(ir_module: ir.Module): + tunable_ops = get_all_tunable_ops(ir_module.operation) + configs = [] + for op in tunable_ops: + if op.name == "onednn_graph.matmul": + configs.append(MatMulConfig(op)) + return configs + + +def attach_configs_to_ir(ir_module: ir.Module, configs: List[Config]): + tunable_ops = get_all_tunable_ops(ir_module.operation) + assert len(tunable_ops) == len( + configs + ), "tunable ops and configs should have the 
same length" + for i, op in enumerate(tunable_ops): + if op.name == "onednn_graph.matmul": + configs[i].attach_to_ir(op) diff --git a/tools/workloads/single_mm.mlir b/tools/workloads/single_mm.mlir new file mode 100644 index 000000000..705b2ea76 --- /dev/null +++ b/tools/workloads/single_mm.mlir @@ -0,0 +1,7 @@ +func.func @main_entry(%arg0: tensor<128x512xf32>, %arg1: tensor<512x64xf32>) -> tensor<128x64xf32> attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<128x64xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128x64xf32>) -> tensor<128x64xf32> + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x512xf32>, tensor<512x64xf32>) outs(%1 : tensor<128x64xf32>) -> tensor<128x64xf32> + return %2 : tensor<128x64xf32> +} \ No newline at end of file diff --git a/tools/workloads/splited_mm.mlir b/tools/workloads/splited_mm.mlir new file mode 100644 index 000000000..07c1bb059 --- /dev/null +++ b/tools/workloads/splited_mm.mlir @@ -0,0 +1,12 @@ +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @main_entry(%arg0: tensor<128x256xf32>, %arg0_1: tensor<128x256xf32>, %arg1: tensor<256x64xf32>, %arg1_1: tensor<256x64xf32>) -> (tensor<128x64xf32>, tensor<128x64xf32>) attributes {llvm.emit_c_interface} { + %cst_3 = arith.constant 0.000000e+00 : f32 + %2 = tensor.empty() : tensor<128x64xf32> + %3 = linalg.fill ins(%cst_3 : f32) outs(%2 : tensor<128x64xf32>) -> tensor<128x64xf32> + %4 = linalg.matmul {splited = true} ins(%arg0, %arg1 : tensor<128x256xf32>, tensor<256x64xf32>) outs(%3 : tensor<128x64xf32>) -> tensor<128x64xf32> + %cst_4 = arith.constant 0.000000e+00 : f32 + %5 = tensor.empty() : tensor<128x64xf32> + %6 = linalg.fill ins(%cst_4 : f32) outs(%5 : tensor<128x64xf32>) -> tensor<128x64xf32> + %7 = linalg.matmul {splited = true} ins(%arg0_1, %arg1_1 : tensor<128x256xf32>, tensor<256x64xf32>) outs(%6 : tensor<128x64xf32>) -> tensor<128x64xf32> + return %4, %7 : tensor<128x64xf32>, tensor<128x64xf32> +} \ No newline at end of file