Skip to content

Commit

Permalink
[FIRRTL][InferWidths] Make solveExpr Iterative (#5305)
Browse files Browse the repository at this point in the history
The `solveExpr` implementation of `InferWidths` was recursive and causing stack
 overflow issues on some designs. This PR implements an iterative version
 of the algorithm.

---------

Co-authored-by: Schuyler Eldridge <[email protected]>
  • Loading branch information
prithayan and seldridge authored Jun 2, 2023
1 parent 1b426fe commit 5afd784
Showing 1 changed file with 130 additions and 81 deletions.
211 changes: 130 additions & 81 deletions lib/Dialect/FIRRTL/Transforms/InferWidths.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -822,95 +822,142 @@ computeBinary(ExprSolution lhs, ExprSolution rhs,
/// to memoize the result of expressions in case they were not involved in a
/// cycle (which may alter their value from the perspective of a variable).
static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl<Expr *> &seenVars,
unsigned indent = 1) {
// See if we have a memoized result we can return.
if (expr->solution) {
unsigned defaultWorklistSize) {

struct Frame {
Expr *expr;
unsigned indent;
};

// indent only used for debug logs.
unsigned indent = 1;
std::vector<Frame> worklist({{expr, indent}});
llvm::DenseMap<Expr *, ExprSolution> solvedExprs;
// Reserving the vector size, to avoid frequent reallocs. The worklist can be
// quite large.
worklist.reserve(defaultWorklistSize);

while (!worklist.empty()) {
auto &frame = worklist.back();
auto setSolution = [&](ExprSolution solution) {
// Memoize the result.
if (solution.first && !solution.second)
frame.expr->solution = *solution.first;
solvedExprs[frame.expr] = solution;

// Produce some useful debug prints.
LLVM_DEBUG({
if (!isa<KnownExpr>(frame.expr)) {
if (solution.first)
llvm::dbgs().indent(frame.indent * 2)
<< "= Solved " << *frame.expr << " = " << *solution.first;
else
llvm::dbgs().indent(frame.indent * 2)
<< "= Skipped " << *frame.expr;
llvm::dbgs() << " (" << (solution.second ? "cycle broken" : "unique")
<< ")\n";
}
});

worklist.pop_back();
};

// See if we have a memoized result we can return.
if (frame.expr->solution) {
LLVM_DEBUG({
if (!isa<KnownExpr>(frame.expr))
llvm::dbgs().indent(indent * 2) << "- Cached " << *frame.expr << " = "
<< *frame.expr->solution << "\n";
});
setSolution(ExprSolution{*frame.expr->solution, false});
continue;
}

// Otherwise compute the value of the expression.
LLVM_DEBUG({
if (!isa<KnownExpr>(expr))
llvm::dbgs().indent(indent * 2)
<< "- Cached " << *expr << " = " << *expr->solution << "\n";
if (!isa<KnownExpr>(frame.expr))
llvm::dbgs().indent(frame.indent * 2)
<< "- Solving " << *frame.expr << "\n";
});
return {*expr->solution, false};
}

// Otherwise compute the value of the expression.
LLVM_DEBUG({
if (!isa<KnownExpr>(expr))
llvm::dbgs().indent(indent * 2) << "- Solving " << *expr << "\n";
});
auto solution =
TypeSwitch<Expr *, ExprSolution>(expr)
.Case<KnownExpr>([&](auto *expr) {
return ExprSolution{*expr->solution, false};
})
.Case<VarExpr>([&](auto *expr) {
// Unconstrained variables produce no solution.
if (!expr->constraint)
return ExprSolution{std::nullopt, false};
// Return no solution for recursions in the variables. This is sane
// and will cause the expression to be ignored when computing the
// parent, e.g. `a >= max(a, 1)` will become just `a >= 1`.
if (!seenVars.insert(expr).second)
return ExprSolution{std::nullopt, true};
auto solution = solveExpr(expr->constraint, seenVars, indent + 1);
TypeSwitch<Expr *>(frame.expr)
.Case<KnownExpr>([&](auto *expr) {
setSolution(ExprSolution{*expr->solution, false});
})
.Case<VarExpr>([&](auto *expr) {
if (solvedExprs.contains(expr->constraint)) {
auto solution = solvedExprs[expr->constraint];
if (expr->upperBound)
expr->upperBoundSolution =
solveExpr(expr->upperBound, seenVars, indent + 1).first;
expr->upperBoundSolution = solvedExprs[expr->upperBound].second;

seenVars.erase(expr);
// Constrain variables >= 0.
if (solution.first && *solution.first < 0)
solution.first = 0;
return solution;
})
.Case<IdExpr>([&](auto *expr) {
return solveExpr(expr->arg, seenVars, indent + 1);
})
.Case<PowExpr>([&](auto *expr) {
auto arg = solveExpr(expr->arg, seenVars, indent + 1);
return computeUnary(arg, [](int32_t arg) { return 1 << arg; });
})
.Case<AddExpr>([&](auto *expr) {
auto lhs = solveExpr(expr->lhs(), seenVars, indent + 1);
auto rhs = solveExpr(expr->rhs(), seenVars, indent + 1);
return computeBinary(
lhs, rhs, [](int32_t lhs, int32_t rhs) { return lhs + rhs; });
})
.Case<MaxExpr>([&](auto *expr) {
auto lhs = solveExpr(expr->lhs(), seenVars, indent + 1);
auto rhs = solveExpr(expr->rhs(), seenVars, indent + 1);
return computeBinary(lhs, rhs, [](int32_t lhs, int32_t rhs) {
return std::max(lhs, rhs);
});
})
.Case<MinExpr>([&](auto *expr) {
auto lhs = solveExpr(expr->lhs(), seenVars, indent + 1);
auto rhs = solveExpr(expr->rhs(), seenVars, indent + 1);
return computeBinary(lhs, rhs, [](int32_t lhs, int32_t rhs) {
return std::min(lhs, rhs);
});
})
.Default([](auto) {
return ExprSolution{std::nullopt, false};
});

// Memoize the result.
if (solution.first && !solution.second)
expr->solution = *solution.first;
return setSolution(solution);
}

// Produce some useful debug prints.
LLVM_DEBUG({
if (!isa<KnownExpr>(expr)) {
if (solution.first)
llvm::dbgs().indent(indent * 2)
<< "= Solved " << *expr << " = " << *solution.first;
else
llvm::dbgs().indent(indent * 2) << "= Skipped " << *expr;
llvm::dbgs() << " (" << (solution.second ? "cycle broken" : "unique")
<< ")\n";
}
});
// Unconstrained variables produce no solution.
if (!expr->constraint)
return setSolution(ExprSolution{std::nullopt, false});
// Return no solution for recursions in the variables. This is sane
// and will cause the expression to be ignored when computing the
// parent, e.g. `a >= max(a, 1)` will become just `a >= 1`.
if (!seenVars.insert(expr).second)
return setSolution(ExprSolution{std::nullopt, true});

worklist.push_back({expr->constraint, indent + 1});
if (expr->upperBound)
worklist.push_back({expr->upperBound, indent + 1});
})
.Case<IdExpr>([&](auto *expr) {
if (solvedExprs.contains(expr->arg))
return setSolution(solvedExprs[expr->arg]);
worklist.push_back({expr->arg, indent + 1});
})
.Case<PowExpr>([&](auto *expr) {
if (solvedExprs.contains(expr->arg))
return setSolution(computeUnary(
solvedExprs[expr->arg], [](int32_t arg) { return 1 << arg; }));

worklist.push_back({expr->arg, indent + 1});
})
.Case<AddExpr>([&](auto *expr) {
if (solvedExprs.contains(expr->lhs()) &&
solvedExprs.contains(expr->rhs()))
return setSolution(computeBinary(
solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
[](int32_t lhs, int32_t rhs) { return lhs + rhs; }));

worklist.push_back({expr->lhs(), indent + 1});
worklist.push_back({expr->rhs(), indent + 1});
})
.Case<MaxExpr>([&](auto *expr) {
if (solvedExprs.contains(expr->lhs()) &&
solvedExprs.contains(expr->rhs()))
return setSolution(computeBinary(
solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
[](int32_t lhs, int32_t rhs) { return std::max(lhs, rhs); }));

worklist.push_back({expr->lhs(), indent + 1});
worklist.push_back({expr->rhs(), indent + 1});
})
.Case<MinExpr>([&](auto *expr) {
if (solvedExprs.contains(expr->lhs()) &&
solvedExprs.contains(expr->rhs()))
return setSolution(computeBinary(
solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
[](int32_t lhs, int32_t rhs) { return std::min(lhs, rhs); }));

worklist.push_back({expr->lhs(), indent + 1});
worklist.push_back({expr->rhs(), indent + 1});
})
.Default([&](auto) {
setSolution(ExprSolution{std::nullopt, false});
});
}

return solution;
return solvedExprs[expr];
}

/// Solve the constraint problem. This is a very simple implementation that
Expand Down Expand Up @@ -984,6 +1031,7 @@ LogicalResult ConstraintSolver::solve() {

// Iterate over the constraint variables and solve each.
LLVM_DEBUG(llvm::dbgs() << "\n===----- Solving constraints -----===\n\n");
unsigned defaultWorklistSize = exprs.size() / 2;
for (auto *expr : exprs) {
// Only work on variables.
auto *var = dyn_cast<VarExpr>(expr);
Expand All @@ -1002,9 +1050,10 @@ LogicalResult ConstraintSolver::solve() {
LLVM_DEBUG(llvm::dbgs()
<< "- Solving " << *var << " >= " << *var->constraint << "\n");
seenVars.insert(var);
auto solution = solveExpr(var->constraint, seenVars);
auto solution = solveExpr(var->constraint, seenVars, defaultWorklistSize);
if (var->upperBound)
var->upperBoundSolution = solveExpr(var->upperBound, seenVars).first;
var->upperBoundSolution =
solveExpr(var->upperBound, seenVars, defaultWorklistSize).first;
seenVars.clear();

// Constrain variables >= 0.
Expand Down

0 comments on commit 5afd784

Please sign in to comment.