Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#sdy fix constant splitter taking too long. #251

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions shardy/dialect/sdy/transforms/import/constant_splitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ namespace {

using func::FuncOp;

// Returns true if `value` has multiple uses that are not by `ShardingGroupOp`.
bool hasMultipleRealUses(Value value) {
// Returns true if `op` has multiple uses that are not by `ShardingGroupOp`.
bool hasMultipleRealUses(Operation* op) {
bool seenUse = false;
for (Operation* user : value.getUsers()) {
for (Operation* user : op->getUsers()) {
if (!isa<ShardingGroupOp>(user)) {
if (seenUse) {
return true;
Expand Down Expand Up @@ -156,25 +156,25 @@ struct ConstantSplitterPass
// Then we split constant sub-computations for each non-constant user.
llvm::DenseSet<Operation*> constantOps;
funcOp.walk([&](Operation* op) {
if (isa<ShardingGroupOp>(op)) {
if (!isConstantExpression(op, constantOps) || isa<ShardingGroupOp>(op)) {
return;
}
if (isConstantExpression(op, constantOps)) {
// `op` is a constant expression.
constantOps.insert(op);

// `op` is a constant expression.
constantOps.insert(op);

if (!hasMultipleRealUses(op)) {
// No need to split `op`.
return;
}

for (OpOperand& operand : op->getOpOperands()) {
// For each operand that is produced by a constant sub-computation
// (exists in `constantOps`) that has multiples uses, we recursively
// clone the sub-computation whose root is the defining op, and replace
// the operand with the cloned defining op. This will ensure that by the
// end of this walk, all constant sub-computations will have a single
// user.
if (auto defOpResult = dyn_cast<OpResult>(operand.get());
defOpResult && constantOps.contains(defOpResult.getOwner()) &&
hasMultipleRealUses(defOpResult)) {
operand.set(cloneSubComputation(defOpResult));
// For each use of `op`, we recursively clone the sub-computation whose
// root is `op`, and replace the use with the cloned defining op. This
// will ensure that by the end of this walk, all constant sub-computations
// will have a single user.
for (OpOperand& use : op->getUses()) {
if (!isa<ShardingGroupOp>(use.getOwner())) {
use.set(cloneSubComputation(cast<OpResult>(use.get())));
}
}
});
Expand Down
Loading