-
Notifications
You must be signed in to change notification settings - Fork 51
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement generate-dataflow-hierarchy pass
- Loading branch information
Showing
7 changed files
with
138 additions
and
73 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
103 changes: 103 additions & 0 deletions
103
lib/Dialect/HLS/Transforms/GenerateDataflowHierarchy.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// Copyright 2020-2021 The ScaleHLS Authors. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/IR/Dominance.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
#include "scalehls/Dialect/HLS/Transforms/Passes.h" | ||
#include "scalehls/Utils/Utils.h" | ||
|
||
using namespace mlir; | ||
using namespace scalehls; | ||
using namespace hls; | ||
|
||
namespace mlir { | ||
namespace scalehls { | ||
namespace hls { | ||
#define GEN_PASS_DEF_GENERATEDATAFLOWHIERARCHY | ||
#include "scalehls/Dialect/HLS/Transforms/Passes.h.inc" | ||
} // namespace hls | ||
} // namespace scalehls | ||
} // namespace mlir | ||
|
||
static hls::TaskOp wrapOpIntoTask(Operation *op, StringRef taskName, | ||
SmallVectorImpl<Value> &destOperands, | ||
OpBuilder &builder) { | ||
builder.setInsertionPoint(op); | ||
auto task = builder.create<TaskOp>(op->getLoc(), destOperands, | ||
builder.getStringAttr(taskName)); | ||
op->replaceAllUsesWith(task.getResults()); | ||
auto taskBlock = builder.createBlock( | ||
&task.getBody(), task.getBody().end(), TypeRange(destOperands), | ||
llvm::map_to_vector(destOperands, [&](Value v) { return v.getLoc(); })); | ||
|
||
builder.setInsertionPointToEnd(taskBlock); | ||
auto yieldOp = builder.create<YieldOp>(op->getLoc(), op->getResults()); | ||
|
||
op->moveBefore(yieldOp); | ||
for (auto [destOperand, taskBlockArg] : | ||
llvm::zip(destOperands, taskBlock->getArguments())) | ||
destOperand.replaceUsesWithIf( | ||
taskBlockArg, [&](OpOperand &use) { return use.getOwner() == op; }); | ||
return task; | ||
} | ||
|
||
static LogicalResult generateTasksInBlock(StringRef prefix, Block *block, | ||
OpBuilder &builder) { | ||
if (!isa<func::FuncOp, scf::ForOp>(block->getParentOp())) | ||
return block->getParentOp()->emitOpError("expected a FuncOp or a ForOp"); | ||
|
||
// Collect all ops that need to be wrapped into tasks. | ||
SmallVector<std::pair<Operation *, SmallVector<Value>>> opsToWrap; | ||
for (auto &op : *block) { | ||
if (auto loop = dyn_cast<scf::ForOp>(op)) | ||
opsToWrap.push_back({loop, loop.getInitArgs()}); | ||
else if (auto writeOp = dyn_cast<ITensorWriteLikeOpInterface>(op)) | ||
opsToWrap.push_back({writeOp, {writeOp.getDest()}}); | ||
else if (auto destStyleOp = dyn_cast<DestinationStyleOpInterface>(op)) { | ||
// Because tensor insertion-like ops will be eliminated in the tensor | ||
// bufferization pass, we don't need to wrap them into tasks. | ||
if (!isa<tensor::InsertOp, tensor::InsertSliceOp, | ||
tensor::ParallelInsertSliceOp>(op)) | ||
opsToWrap.push_back({destStyleOp, destStyleOp.getDpsInits()}); | ||
} | ||
} | ||
|
||
// Handle cases when there is no op to wrap or only one op to wrap. | ||
if (opsToWrap.empty()) | ||
return success(); | ||
else if (llvm::hasSingleElement(opsToWrap)) { | ||
if (auto loop = dyn_cast<scf::ForOp>(opsToWrap.front().first)) | ||
return generateTasksInBlock(prefix, loop.getBody(), builder); | ||
else | ||
return success(); | ||
} | ||
|
||
// Generate tasks for all ops that need to be wrapped. | ||
unsigned taskId = 0; | ||
for (auto [op, destOperands] : opsToWrap) { | ||
std::string taskName = prefix.str() + "_" + std::to_string(taskId++); | ||
if (auto loop = dyn_cast<scf::ForOp>(op)) | ||
if (failed(generateTasksInBlock(taskName, loop.getBody(), builder))) | ||
return failure(); | ||
wrapOpIntoTask(op, taskName, destOperands, builder); | ||
} | ||
return success(); | ||
} | ||
|
||
namespace { | ||
struct GenerateDataflowHierarchy | ||
: public hls::impl::GenerateDataflowHierarchyBase< | ||
GenerateDataflowHierarchy> { | ||
void runOnOperation() override { | ||
auto func = getOperation(); | ||
auto builder = OpBuilder(func); | ||
if (failed(generateTasksInBlock(func.getName(), &func.front(), builder))) { | ||
signalPassFailure(); | ||
} | ||
} | ||
}; | ||
} // namespace |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters