From 3a15525461daa9dee683a0e77613d5bf8dd7dfb3 Mon Sep 17 00:00:00 2001 From: Hanchen Ye Date: Mon, 25 Mar 2024 23:45:09 -0500 Subject: [PATCH] Implement generate-dataflow-hierarchy pass --- include/scalehls/Dialect/HLS/IR/HLSOps.td | 12 +- .../scalehls/Dialect/HLS/Transforms/Passes.td | 7 +- lib/Dialect/HLS/IR/HLSOps.cpp | 4 +- lib/Dialect/HLS/Transforms/CMakeLists.txt | 1 + .../Transforms/GenerateDataflowHierarchy.cpp | 103 ++++++++++++++++++ .../HLS/Transforms/ScheduleDataflow.cpp | 64 +---------- python/scalehls/transforms.py | 20 +++- 7 files changed, 138 insertions(+), 73 deletions(-) create mode 100644 lib/Dialect/HLS/Transforms/GenerateDataflowHierarchy.cpp diff --git a/include/scalehls/Dialect/HLS/IR/HLSOps.td b/include/scalehls/Dialect/HLS/IR/HLSOps.td index 7939ee36..ce9e0872 100644 --- a/include/scalehls/Dialect/HLS/IR/HLSOps.td +++ b/include/scalehls/Dialect/HLS/IR/HLSOps.td @@ -386,11 +386,19 @@ def BufferOp : HLSOp<"buffer", [ def TaskOp : HLSOp<"task", [SingleBlockImplicitTerminator<"YieldOp">]> { let summary = "Represent a dataflow task"; - let arguments = (ins Variadic:$inits); + let arguments = (ins Variadic:$inits, OptionalAttr:$name); let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$body); + let builders = [ + OpBuilder<(ins "mlir::ValueRange":$inits, "mlir::StringAttr":$name), + "build($_builder, $_state, inits, inits, name);">, + OpBuilder<(ins "mlir::ValueRange":$inits), + "build($_builder, $_state, inits, inits, nullptr);"> + ]; + let assemblyFormat = [{ - (`inits` $inits^)? `:` functional-type(operands, results) $body attr-dict + $name (`inits` $inits^)? $body attr-dict `:` + functional-type(operands, results) }]; let hasVerifier = 1; diff --git a/include/scalehls/Dialect/HLS/Transforms/Passes.td b/include/scalehls/Dialect/HLS/Transforms/Passes.td index 8f47a3ed..5dd6edc7 100644 --- a/include/scalehls/Dialect/HLS/Transforms/Passes.td +++ b/include/scalehls/Dialect/HLS/Transforms/Passes.td @@ -22,8 +22,13 @@ def PackITensorDMA : Pass<"scalehls-pack-itensor-dma", "func::FuncOp"> { let summary = "Pack/unpack itensor DMAs"; } +def GenerateDataflowHierarchy : + Pass<"scalehls-generate-dataflow-hierarchy", "func::FuncOp"> { + let summary = "Generate dataflow hierarchy"; +} + def ScheduleDataflow : Pass<"scalehls-schedule-dataflow", "func::FuncOp"> { - let summary = "Create a dataflow schedule"; + let summary = "Schedule dataflow"; } def GenerateRuntimeFunc : diff --git a/lib/Dialect/HLS/IR/HLSOps.cpp b/lib/Dialect/HLS/IR/HLSOps.cpp index d10f2b00..7330bc36 100644 --- a/lib/Dialect/HLS/IR/HLSOps.cpp +++ b/lib/Dialect/HLS/IR/HLSOps.cpp @@ -448,8 +448,8 @@ struct FoldTaskIterArgs : public OpRewritePattern { if (!canonicalize) return failure(); - TaskOp newtask = rewriter.create( - task.getLoc(), TypeRange(newIterArgs), newIterArgs); + TaskOp newtask = + rewriter.create(task.getLoc(), newIterArgs, task.getNameAttr()); newtask->setAttrs(task->getAttrs()); Block *newBlock = rewriter.createBlock( &newtask.getBody(), newtask.getBody().begin(), TypeRange(newIterArgs), diff --git a/lib/Dialect/HLS/Transforms/CMakeLists.txt b/lib/Dialect/HLS/Transforms/CMakeLists.txt index 8975c3ea..274e3aa1 100644 --- a/lib/Dialect/HLS/Transforms/CMakeLists.txt +++ b/lib/Dialect/HLS/Transforms/CMakeLists.txt @@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRScaleHLSHLSTransforms ScalarizeITensor.cpp LowerITensorToStream.cpp ScheduleDataflow.cpp + GenerateDataflowHierarchy.cpp ConvertDataflowToFunc.cpp TransformInterpreter.cpp ComprehensiveBufferize.cpp diff --git a/lib/Dialect/HLS/Transforms/GenerateDataflowHierarchy.cpp b/lib/Dialect/HLS/Transforms/GenerateDataflowHierarchy.cpp new file mode 100644 index 00000000..a4598178 --- /dev/null +++ b/lib/Dialect/HLS/Transforms/GenerateDataflowHierarchy.cpp @@ -0,0 +1,103 @@ +//===----------------------------------------------------------------------===// +// +// Copyright 2020-2021 The ScaleHLS Authors. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Dominance.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "scalehls/Dialect/HLS/Transforms/Passes.h" +#include "scalehls/Utils/Utils.h" + +using namespace mlir; +using namespace scalehls; +using namespace hls; + +namespace mlir { +namespace scalehls { +namespace hls { +#define GEN_PASS_DEF_GENERATEDATAFLOWHIERARCHY +#include "scalehls/Dialect/HLS/Transforms/Passes.h.inc" +} // namespace hls +} // namespace scalehls +} // namespace mlir + +static hls::TaskOp wrapOpIntoTask(Operation *op, StringRef taskName, + SmallVectorImpl &destOperands, + OpBuilder &builder) { + builder.setInsertionPoint(op); + auto task = builder.create(op->getLoc(), destOperands, + builder.getStringAttr(taskName)); + op->replaceAllUsesWith(task.getResults()); + auto taskBlock = builder.createBlock( + &task.getBody(), task.getBody().end(), TypeRange(destOperands), + llvm::map_to_vector(destOperands, [&](Value v) { return v.getLoc(); })); + + builder.setInsertionPointToEnd(taskBlock); + auto yieldOp = builder.create(op->getLoc(), op->getResults()); + + op->moveBefore(yieldOp); + for (auto [destOperand, taskBlockArg] : + llvm::zip(destOperands, taskBlock->getArguments())) + destOperand.replaceUsesWithIf( + taskBlockArg, [&](OpOperand &use) { return use.getOwner() == op; }); + return task; +} + +static LogicalResult generateTasksInBlock(StringRef prefix, Block *block, + OpBuilder &builder) { + if (!isa(block->getParentOp())) + return block->getParentOp()->emitOpError("expected a FuncOp or a ForOp"); + + // Collect all ops that need to be wrapped into tasks. + SmallVector>> opsToWrap; + for (auto &op : *block) { + if (auto loop = dyn_cast(op)) + opsToWrap.push_back({loop, loop.getInitArgs()}); + else if (auto writeOp = dyn_cast(op)) + opsToWrap.push_back({writeOp, {writeOp.getDest()}}); + else if (auto destStyleOp = dyn_cast(op)) { + // Because tensor insertion-like ops will be eliminated in the tensor + // bufferization pass, we don't need to wrap them into tasks. + if (!isa(op)) + opsToWrap.push_back({destStyleOp, destStyleOp.getDpsInits()}); + } + } + + // Handle cases when there is no op to wrap or only one op to wrap. + if (opsToWrap.empty()) + return success(); + else if (llvm::hasSingleElement(opsToWrap)) { + if (auto loop = dyn_cast(opsToWrap.front().first)) + return generateTasksInBlock(prefix, loop.getBody(), builder); + else + return success(); + } + + // Generate tasks for all ops that need to be wrapped. + unsigned taskId = 0; + for (auto [op, destOperands] : opsToWrap) { + std::string taskName = prefix.str() + "_" + std::to_string(taskId++); + if (auto loop = dyn_cast(op)) + if (failed(generateTasksInBlock(taskName, loop.getBody(), builder))) + return failure(); + wrapOpIntoTask(op, taskName, destOperands, builder); + } + return success(); +} + +namespace { +struct GenerateDataflowHierarchy + : public hls::impl::GenerateDataflowHierarchyBase< + GenerateDataflowHierarchy> { + void runOnOperation() override { + auto func = getOperation(); + auto builder = OpBuilder(func); + if (failed(generateTasksInBlock(func.getName(), &func.front(), builder))) { + signalPassFailure(); + } + } +}; +} // namespace diff --git a/lib/Dialect/HLS/Transforms/ScheduleDataflow.cpp b/lib/Dialect/HLS/Transforms/ScheduleDataflow.cpp index 64a3b6e1..345d4140 100644 --- a/lib/Dialect/HLS/Transforms/ScheduleDataflow.cpp +++ b/lib/Dialect/HLS/Transforms/ScheduleDataflow.cpp @@ -23,71 +23,9 @@ namespace hls { } // namespace scalehls } // namespace mlir -hls::TaskOp wrapOpIntoTask(Operation *op, StringRef taskName, - ValueRange destOperands, OpBuilder &builder) { - auto destTypes = TypeRange(destOperands); - - builder.setInsertionPoint(op); - auto task = builder.create(op->getLoc(), destTypes, destOperands); - op->replaceAllUsesWith(task.getResults()); - task->setAttr(taskName, builder.getUnitAttr()); - auto taskBlock = builder.createBlock( - &task.getBody(), task.getBody().end(), destTypes, - llvm::map_to_vector(destOperands, [&](Value v) { return v.getLoc(); })); - - builder.setInsertionPointToEnd(taskBlock); - auto yieldOp = builder.create(op->getLoc(), op->getResults()); - - op->moveBefore(yieldOp); - for (auto [destOperand, taskBlockArg] : - llvm::zip(destOperands, taskBlock->getArguments())) - destOperand.replaceUsesWithIf( - taskBlockArg, [&](OpOperand &use) { return use.getOwner() == op; }); - return task; -} - -static LogicalResult scheduleBlock(StringRef prefix, Block *block, - OpBuilder &builder) { - if (!isa(block->getParentOp())) - return block->getParentOp()->emitOpError("expected a FuncOp or a ForOp"); - - unsigned taskId = 0; - for (auto &op : llvm::make_early_inc_range(block->getOperations())) { - std::string taskName = prefix.str() + "_" + std::to_string(taskId); - - ValueRange destOperands; - - if (auto loop = dyn_cast(op)) { - if (failed(scheduleBlock(taskName, loop.getBody(), builder))) - return failure(); - destOperands = loop.getInitArgs(); - } else if (isa(op)) { - // TODO: For now, tensor insert ops are not scheduled into separate tasks - // as they will be handled in the bufferization passes. - continue; - } else if (auto destStyleOp = dyn_cast(op)) { - destOperands = destStyleOp.getDpsInits(); - } else if (auto writeOp = dyn_cast(op)) { - destOperands = writeOp.getDest(); - } else - continue; - - wrapOpIntoTask(&op, taskName, destOperands, builder); - taskId++; - } - return success(); -} - namespace { struct ScheduleDataflow : public hls::impl::ScheduleDataflowBase { - void runOnOperation() override { - auto func = getOperation(); - auto builder = OpBuilder(func); - if (failed(scheduleBlock(func.getName(), &func.front(), builder))) { - signalPassFailure(); - } - } + void runOnOperation() override { auto func = getOperation(); } }; } // namespace diff --git a/python/scalehls/transforms.py b/python/scalehls/transforms.py index 84519164..c00a048d 100644 --- a/python/scalehls/transforms.py +++ b/python/scalehls/transforms.py @@ -25,6 +25,9 @@ # ===----------------------------------------------------------------------=== # +k_id_attr_name = "__id__" + + def apply_transform_sequence( module: Module, sequence: transform.NamedSequenceOp, @@ -88,6 +91,13 @@ def apply_comprehensive_bufferize_passes(module: Module): pm.run(module.operation) +def apply_generate_dataflow_hierarchy(module: Module): + pm = PassManager.parse( + "builtin.module(func.func(scalehls-generate-dataflow-hierarchy)," + "cse, canonicalize)") + pm.run(module.operation) + + def apply_schedule_dataflow(module: Module): pm = PassManager.parse( "builtin.module(func.func(scalehls-schedule-dataflow)," @@ -488,7 +498,7 @@ def __init__(self, module: Module, top_name: str = "forward"): self.add_node(self.top, name=self.top.OPERATION_NAME, id=-1) for id, op in enumerate(self.top.entry_block): self.add_node(op, name=op.OPERATION_NAME, id=id) - op.attributes["id"] = i64_attr(id) + op.attributes[k_id_attr_name] = i64_attr(id) for operand in op.operands: parent = operand.owner.owner if isinstance( operand.owner, Block) else operand.owner @@ -623,7 +633,7 @@ def construct_transform_sequence(target: BlockArgument, """ for node, data in graph.nodes(data=True): node_handle = match(target, [data["name"]], { - "id": i64_attr(data["id"])}) + k_id_attr_name: i64_attr(data["id"])}) if isinstance(node, linalg.GenericOp): if "parallel_tile_sizes" not in data: @@ -642,7 +652,7 @@ def construct_transform_sequence(target: BlockArgument, data["unroll_sizes"], data["permutation"], len(node.inputs) > 0) - annotate(linalg_op_handle, "id", i64_param(data["id"])) + annotate(linalg_op_handle, k_id_attr_name, i64_param(data["id"])) if isinstance(node, tensor.ExpandShapeOp): if "source_tile_sizes" not in data: @@ -655,7 +665,7 @@ def construct_transform_sequence(target: BlockArgument, data["source_tile_sizes"], data["result_tile_sizes"]) annotate(convert_op.itensor_reassociate, - "id", i64_param(data["id"])) + k_id_attr_name, i64_param(data["id"])) if isinstance(node, tensor.CollapseShapeOp): if "source_tile_sizes" not in data: @@ -668,7 +678,7 @@ def construct_transform_sequence(target: BlockArgument, data["source_tile_sizes"], data["result_tile_sizes"]) annotate(convert_op.itensor_reassociate, - "id", i64_param(data["id"])) + k_id_attr_name, i64_param(data["id"])) return []