Commit 3a15525: Implement generate-dataflow-hierarchy pass

hanchenye committed Mar 26, 2024
1 parent e5e7e69 commit 3a15525

Showing 7 changed files with 138 additions and 73 deletions.
12 changes: 10 additions & 2 deletions include/scalehls/Dialect/HLS/IR/HLSOps.td
@@ -386,11 +386,19 @@ def BufferOp : HLSOp<"buffer", [
 def TaskOp : HLSOp<"task", [SingleBlockImplicitTerminator<"YieldOp">]> {
   let summary = "Represent a dataflow task";
 
-  let arguments = (ins Variadic<AnyType>:$inits);
+  let arguments = (ins Variadic<AnyType>:$inits, OptionalAttr<StrAttr>:$name);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$body);
+  let builders = [
+    OpBuilder<(ins "mlir::ValueRange":$inits, "mlir::StringAttr":$name),
+      "build($_builder, $_state, inits, inits, name);">,
+    OpBuilder<(ins "mlir::ValueRange":$inits),
+      "build($_builder, $_state, inits, inits, nullptr);">
+  ];
+
   let assemblyFormat = [{
-    (`inits` $inits^)? `:` functional-type(operands, results) $body attr-dict
+    $name (`inits` $inits^)? $body attr-dict `:`
+    functional-type(operands, results)
   }];
 
   let hasVerifier = 1;
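
For reference, a round trip of the updated `hls.task` syntax would look roughly as follows; this is a hypothetical sketch (the value names and tensor types are illustrative, not taken from the commit), with the name attribute printed first and the functional type trailing the body:

    %result = hls.task "forward_0" inits %init {
    ^bb0(%arg0: tensor<8xf32>):
      // ... body that computes %value from the block argument %arg0 ...
      hls.yield %value : tensor<8xf32>
    } : (tensor<8xf32>) -> tensor<8xf32>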
7 changes: 6 additions & 1 deletion include/scalehls/Dialect/HLS/Transforms/Passes.td
@@ -22,8 +22,13 @@ def PackITensorDMA : Pass<"scalehls-pack-itensor-dma", "func::FuncOp"> {
   let summary = "Pack/unpack itensor DMAs";
 }
 
+def GenerateDataflowHierarchy :
+    Pass<"scalehls-generate-dataflow-hierarchy", "func::FuncOp"> {
+  let summary = "Generate dataflow hierarchy";
+}
+
 def ScheduleDataflow : Pass<"scalehls-schedule-dataflow", "func::FuncOp"> {
-  let summary = "Create a dataflow schedule";
+  let summary = "Schedule dataflow";
 }
 
 def GenerateRuntimeFunc :
4 changes: 2 additions & 2 deletions lib/Dialect/HLS/IR/HLSOps.cpp
@@ -448,8 +448,8 @@ struct FoldTaskIterArgs : public OpRewritePattern<hls::TaskOp> {
     if (!canonicalize)
       return failure();
 
-    TaskOp newtask = rewriter.create<TaskOp>(
-        task.getLoc(), TypeRange(newIterArgs), newIterArgs);
+    TaskOp newtask =
+        rewriter.create<TaskOp>(task.getLoc(), newIterArgs, task.getNameAttr());
     newtask->setAttrs(task->getAttrs());
     Block *newBlock = rewriter.createBlock(
         &newtask.getBody(), newtask.getBody().begin(), TypeRange(newIterArgs),
1 change: 1 addition & 0 deletions lib/Dialect/HLS/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRScaleHLSHLSTransforms
   ScalarizeITensor.cpp
   LowerITensorToStream.cpp
   ScheduleDataflow.cpp
+  GenerateDataflowHierarchy.cpp
   ConvertDataflowToFunc.cpp
   TransformInterpreter.cpp
   ComprehensiveBufferize.cpp
103 changes: 103 additions & 0 deletions lib/Dialect/HLS/Transforms/GenerateDataflowHierarchy.cpp
@@ -0,0 +1,103 @@
//===----------------------------------------------------------------------===//
//
// Copyright 2020-2021 The ScaleHLS Authors.
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "scalehls/Dialect/HLS/Transforms/Passes.h"
#include "scalehls/Utils/Utils.h"

using namespace mlir;
using namespace scalehls;
using namespace hls;

namespace mlir {
namespace scalehls {
namespace hls {
#define GEN_PASS_DEF_GENERATEDATAFLOWHIERARCHY
#include "scalehls/Dialect/HLS/Transforms/Passes.h.inc"
} // namespace hls
} // namespace scalehls
} // namespace mlir

static hls::TaskOp wrapOpIntoTask(Operation *op, StringRef taskName,
                                  SmallVectorImpl<Value> &destOperands,
                                  OpBuilder &builder) {
  builder.setInsertionPoint(op);
  auto task = builder.create<TaskOp>(op->getLoc(), destOperands,
                                     builder.getStringAttr(taskName));
  op->replaceAllUsesWith(task.getResults());
  auto taskBlock = builder.createBlock(
      &task.getBody(), task.getBody().end(), TypeRange(destOperands),
      llvm::map_to_vector(destOperands, [&](Value v) { return v.getLoc(); }));

  builder.setInsertionPointToEnd(taskBlock);
  auto yieldOp = builder.create<YieldOp>(op->getLoc(), op->getResults());

  op->moveBefore(yieldOp);
  for (auto [destOperand, taskBlockArg] :
       llvm::zip(destOperands, taskBlock->getArguments()))
    destOperand.replaceUsesWithIf(
        taskBlockArg, [&](OpOperand &use) { return use.getOwner() == op; });
  return task;
}
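
To illustrate the rewrite, wrapping a hypothetical destination-style op whose dest operand is %init would produce IR along these lines (a sketch; the linalg.generic, value names, and types are illustrative):

    // Before:
    %0 = linalg.generic ... outs(%init : tensor<8xf32>) ...

    // After: the op now lives in the task body, with its dest operand
    // remapped to the task's block argument.
    %0 = hls.task "forward_0" inits %init {
    ^bb0(%arg0: tensor<8xf32>):
      %1 = linalg.generic ... outs(%arg0 : tensor<8xf32>) ...
      hls.yield %1 : tensor<8xf32>
    } : (tensor<8xf32>) -> tensor<8xf32>

Only uses owned by the wrapped op are remapped to the block argument; all other uses of the op's results are redirected to the task's results.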

static LogicalResult generateTasksInBlock(StringRef prefix, Block *block,
                                          OpBuilder &builder) {
  if (!isa<func::FuncOp, scf::ForOp>(block->getParentOp()))
    return block->getParentOp()->emitOpError("expected a FuncOp or a ForOp");

  // Collect all ops that need to be wrapped into tasks.
  SmallVector<std::pair<Operation *, SmallVector<Value>>> opsToWrap;
  for (auto &op : *block) {
    if (auto loop = dyn_cast<scf::ForOp>(op))
      opsToWrap.push_back({loop, loop.getInitArgs()});
    else if (auto writeOp = dyn_cast<ITensorWriteLikeOpInterface>(op))
      opsToWrap.push_back({writeOp, {writeOp.getDest()}});
    else if (auto destStyleOp = dyn_cast<DestinationStyleOpInterface>(op)) {
      // Because tensor insertion-like ops will be eliminated in the tensor
      // bufferization pass, we don't need to wrap them into tasks.
      if (!isa<tensor::InsertOp, tensor::InsertSliceOp,
               tensor::ParallelInsertSliceOp>(op))
        opsToWrap.push_back({destStyleOp, destStyleOp.getDpsInits()});
    }
  }

  // Handle the cases where there is no op to wrap or only one op to wrap.
  if (opsToWrap.empty())
    return success();
  else if (llvm::hasSingleElement(opsToWrap)) {
    if (auto loop = dyn_cast<scf::ForOp>(opsToWrap.front().first))
      return generateTasksInBlock(prefix, loop.getBody(), builder);
    else
      return success();
  }

  // Generate tasks for all ops that need to be wrapped.
  unsigned taskId = 0;
  for (auto [op, destOperands] : opsToWrap) {
    std::string taskName = prefix.str() + "_" + std::to_string(taskId++);
    if (auto loop = dyn_cast<scf::ForOp>(op))
      if (failed(generateTasksInBlock(taskName, loop.getBody(), builder)))
        return failure();
    wrapOpIntoTask(op, taskName, destOperands, builder);
  }
  return success();
}

namespace {
struct GenerateDataflowHierarchy
    : public hls::impl::GenerateDataflowHierarchyBase<
          GenerateDataflowHierarchy> {
  void runOnOperation() override {
    auto func = getOperation();
    auto builder = OpBuilder(func);
    if (failed(generateTasksInBlock(func.getName(), &func.front(), builder))) {
      signalPassFailure();
    }
  }
};
} // namespace
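
Putting the pieces together: for a function with two top-level dataflow ops, the second being an scf.for loop that itself contains two more, the pass would emit a hierarchy roughly like the sketch below (hypothetical; task names follow the prefix plus task-id scheme of generateTasksInBlock, and most types are elided):

    func.func @forward(...) {
      %0 = hls.task "forward_0" inits ... { ... }
      %1 = hls.task "forward_1" inits %init {
      ^bb0(%arg0: ...):
        %2 = scf.for ... iter_args(%iter = %arg0) -> (...) {
          %3 = hls.task "forward_1_0" inits %iter { ... }
          %4 = hls.task "forward_1_1" inits %3 { ... }
          scf.yield %4 : ...
        }
        hls.yield %2 : ...
      } : ...
      return %1 : ...
    }

Note the recursion order: a loop's body is processed first, using the loop's task name as the prefix, and the loop itself is wrapped afterwards, so inner task names nest under the outer ones.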
64 changes: 1 addition & 63 deletions lib/Dialect/HLS/Transforms/ScheduleDataflow.cpp
@@ -23,71 +23,9 @@ namespace hls {
 } // namespace scalehls
 } // namespace mlir
 
-hls::TaskOp wrapOpIntoTask(Operation *op, StringRef taskName,
-                           ValueRange destOperands, OpBuilder &builder) {
-  auto destTypes = TypeRange(destOperands);
-
-  builder.setInsertionPoint(op);
-  auto task = builder.create<TaskOp>(op->getLoc(), destTypes, destOperands);
-  op->replaceAllUsesWith(task.getResults());
-  task->setAttr(taskName, builder.getUnitAttr());
-  auto taskBlock = builder.createBlock(
-      &task.getBody(), task.getBody().end(), destTypes,
-      llvm::map_to_vector(destOperands, [&](Value v) { return v.getLoc(); }));
-
-  builder.setInsertionPointToEnd(taskBlock);
-  auto yieldOp = builder.create<YieldOp>(op->getLoc(), op->getResults());
-
-  op->moveBefore(yieldOp);
-  for (auto [destOperand, taskBlockArg] :
-       llvm::zip(destOperands, taskBlock->getArguments()))
-    destOperand.replaceUsesWithIf(
-        taskBlockArg, [&](OpOperand &use) { return use.getOwner() == op; });
-  return task;
-}
-
-static LogicalResult scheduleBlock(StringRef prefix, Block *block,
-                                   OpBuilder &builder) {
-  if (!isa<func::FuncOp, scf::ForOp>(block->getParentOp()))
-    return block->getParentOp()->emitOpError("expected a FuncOp or a ForOp");
-
-  unsigned taskId = 0;
-  for (auto &op : llvm::make_early_inc_range(block->getOperations())) {
-    std::string taskName = prefix.str() + "_" + std::to_string(taskId);
-
-    ValueRange destOperands;
-
-    if (auto loop = dyn_cast<scf::ForOp>(op)) {
-      if (failed(scheduleBlock(taskName, loop.getBody(), builder)))
-        return failure();
-      destOperands = loop.getInitArgs();
-    } else if (isa<tensor::InsertOp, tensor::InsertSliceOp,
-                   tensor::ParallelInsertSliceOp>(op)) {
-      // TODO: For now, tensor insert ops are not scheduled into separate tasks
-      // as they will be handled in the bufferization passes.
-      continue;
-    } else if (auto destStyleOp = dyn_cast<DestinationStyleOpInterface>(op)) {
-      destOperands = destStyleOp.getDpsInits();
-    } else if (auto writeOp = dyn_cast<ITensorWriteLikeOpInterface>(op)) {
-      destOperands = writeOp.getDest();
-    } else
-      continue;
-
-    wrapOpIntoTask(&op, taskName, destOperands, builder);
-    taskId++;
-  }
-  return success();
-}
-
 namespace {
 struct ScheduleDataflow
     : public hls::impl::ScheduleDataflowBase<ScheduleDataflow> {
-  void runOnOperation() override {
-    auto func = getOperation();
-    auto builder = OpBuilder(func);
-    if (failed(scheduleBlock(func.getName(), &func.front(), builder))) {
-      signalPassFailure();
-    }
-  }
+  void runOnOperation() override { auto func = getOperation(); }
 };
 } // namespace
20 changes: 15 additions & 5 deletions python/scalehls/transforms.py
@@ -25,6 +25,9 @@
 # ===----------------------------------------------------------------------=== #
 
 
+k_id_attr_name = "__id__"
+
+
 def apply_transform_sequence(
         module: Module,
         sequence: transform.NamedSequenceOp,
@@ -88,6 +91,13 @@ def apply_comprehensive_bufferize_passes(module: Module):
     pm.run(module.operation)
 
 
+def apply_generate_dataflow_hierarchy(module: Module):
+    pm = PassManager.parse(
+        "builtin.module(func.func(scalehls-generate-dataflow-hierarchy),"
+        "cse, canonicalize)")
+    pm.run(module.operation)
+
+
 def apply_schedule_dataflow(module: Module):
     pm = PassManager.parse(
         "builtin.module(func.func(scalehls-schedule-dataflow),"
@@ -488,7 +498,7 @@ def __init__(self, module: Module, top_name: str = "forward"):
         self.add_node(self.top, name=self.top.OPERATION_NAME, id=-1)
         for id, op in enumerate(self.top.entry_block):
             self.add_node(op, name=op.OPERATION_NAME, id=id)
-            op.attributes["id"] = i64_attr(id)
+            op.attributes[k_id_attr_name] = i64_attr(id)
             for operand in op.operands:
                 parent = operand.owner.owner if isinstance(
                     operand.owner, Block) else operand.owner
@@ -623,7 +633,7 @@ def construct_transform_sequence(target: BlockArgument,
     """
     for node, data in graph.nodes(data=True):
         node_handle = match(target, [data["name"]], {
-            "id": i64_attr(data["id"])})
+            k_id_attr_name: i64_attr(data["id"])})
 
         if isinstance(node, linalg.GenericOp):
             if "parallel_tile_sizes" not in data:
@@ -642,7 +652,7 @@
                 data["unroll_sizes"],
                 data["permutation"],
                 len(node.inputs) > 0)
-            annotate(linalg_op_handle, "id", i64_param(data["id"]))
+            annotate(linalg_op_handle, k_id_attr_name, i64_param(data["id"]))
 
         if isinstance(node, tensor.ExpandShapeOp):
             if "source_tile_sizes" not in data:
@@ -655,7 +665,7 @@
                 data["source_tile_sizes"],
                 data["result_tile_sizes"])
             annotate(convert_op.itensor_reassociate,
-                     "id", i64_param(data["id"]))
+                     k_id_attr_name, i64_param(data["id"]))
 
         if isinstance(node, tensor.CollapseShapeOp):
             if "source_tile_sizes" not in data:
@@ -668,7 +678,7 @@
                 data["source_tile_sizes"],
                 data["result_tile_sizes"])
             annotate(convert_op.itensor_reassociate,
-                     "id", i64_param(data["id"]))
+                     k_id_attr_name, i64_param(data["id"]))
     return []
