Conversation

@jacobhinkle jacobhinkle commented Jul 10, 2023

Concretization requires two fundamental operations:

  1. Concretize individual dynamic operations. For example, we need to set the root-to-rfactor transforms for a dynamic reshape.
  2. Propagate information downstream. This is because the IterTypes and extent expressions for IterDomains might change during step 1 but downstream expressions have already been defined using the original symbolic IterDomains. This propagation step ensures that we have no dangling symbolic IterDomains and that extent expressions are properly replaced so that everything is consistent and appears as it would have if the Fusion had been static at definition.
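The registry idea behind step 2 can be sketched as follows. This is a toy model, not nvfuser's API: `Registry` and `maybeMutated` here are illustrative stand-ins showing how downstream code can look up a symbolic ID's concretized replacement so that no dangling symbolic IterDomains remain.

```cpp
#include <cassert>
#include <map>
#include <string>

// Toy model of downstream propagation: concretizations are registered as a
// mapping from symbolic IDs to their replacements, and downstream
// expressions consult the registry instead of the original symbolic IDs.
// All names are illustrative, not nvfuser's actual types.
struct Registry {
  std::map<std::string, std::string> concretized;

  // Return the registered replacement if one exists, else the original ID.
  std::string maybeMutated(const std::string& id) const {
    auto it = concretized.find(id);
    return it == concretized.end() ? id : it->second;
  }
};
```

In the real code the registered values are mutated `IterDomain*` pointers rather than strings, but the lookup pattern is the same.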

The downstream propagation in the second operation needs to pass through IterDomain expressions like those found in pad or reshape ops. But it also needs to cross from producer TVs to consumers. Those jumps are more complicated to handle in mutate(IterDomain*), since we cannot use the standard definition/uses machinery to find tensor expressions linking IterDomains in this way. Instead, currently we use mutate(TensorView* tv) and call propagateProducerToConsumer() at the beginning of that method. That method finds exact mapped producer IDs using tv->definition() and PairwiseRootDomainMap and propagates their concretization information to the corresponding root IDs in tv. The root IDs are then updated and propagation along IterDomain expressions is done between root and rfactor of tv.

The problem comes when we'd like to propagate information through the root->rfactor expressions. Since we do not modify the Fusion during traversal, these expressions are fixed. However, when we call propagateProducerToConsumer() we might register a concretization of some root IterDomains, and those concretizations should ideally propagate to all of their uses. But propagateProducerToConsumer() is called in mutate(TensorView* tv), which runs after all of tv's dependencies have been mutated; in particular, tv->domain() and all IterDomains in it must be mutated before this point. StmtSort::getExprs is then called on the unmutated root and rfactor domains, requiring manual mutation of the intermediate exprs. That is why the current code performs another manual traversal from root to rfactor inside mutate(tv), which causes some unneeded complexity since those IterDomains might be mutated multiple times.

This PR addresses this by splitting the traversal into three loops, each of which is done in topological order:

  1. First we loop over all Vals that are neither TensorDomains nor TensorViews and call mutate(val) on each. We exclude TDs and TVs here since OptOutMutator::mutate actually modifies the Fusion when called on them. This loop does not modify the Fusion at all, but registers mutations for most Vals.
  2. The second loop is over all Exprs. This loop removes any expression whose inputs, outputs, or attributes have registered mutations and replaces it with a new expression linking the new Vals.
  3. The third loop is only over TensorDomains and TensorViews. TensorDomains are registered for mutation if their IterDomains were registered for mutation, at which point a new TensorDomain is created. It is important that the Fusion has properly linked root to rfactor IDs at this point, which is done in the second loop. TensorViews then have their domain() (which is mutable) set to the new TensorDomain.

In order to perform the producer to consumer jump across TensorView expressions, we first extract mappings from consumer IDs to sets of producer IDs before the traversal begins. Those sets are then looked up in mutate(IterDomain*) in the first loop.
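The shape of that pre-extracted mapping can be sketched like this. The container and helper names are illustrative assumptions (the real code builds the mapping from tv->definition() and PairwiseRootDomainMap); the point is that mutate(IterDomain*) only needs a lookup, not a walk over TensorView expressions.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <unordered_map>

// Hypothetical shape of the pre-traversal extraction: each consumer root ID
// maps to the set of exact-mapped producer IDs whose concretization info it
// should inherit. Built once before traversal begins.
using IdSetMap = std::unordered_map<std::string, std::set<std::string>>;

// Lookup pattern used during the first loop: is there producer info to
// propagate to this consumer ID?
bool hasProducerInfo(const IdSetMap& m, const std::string& consumer_id) {
  auto it = m.find(consumer_id);
  return it != m.end() && !it->second.empty();
}
```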

Note on IterVisitor's topological ordering

Concretization, like most other traversals in nvfuser, uses IterVisitor to obtain a topologically ordered set of statements in the Fusion. This class guarantees that the statements will be in proper topological order with respect to the Fusion graph. This graph has directed edges from input Vals to Exprs, from Exprs to output Vals, from attribute Statements to Exprs, and from member Statements to Vals (such as TensorDomain (domain) -> TensorView, Val (start, stop, extent) -> IterDomain, etc.).

At first glance it seems that this is sufficient, but it does not represent the relation that the rfactor IterDomains of producer TensorViews are dependencies of their corresponding consumer root IterDomains. These relations between "aligned" IterDomains are not present in the graph since they could easily become inconsistent when replacing TensorDomains or TensorViews. However, their absence means that there may be valid topological orderings that visit consumer domains before producer domains if neither has any other dependencies.

The current implementation of IterVisitor::traverseBetween does in fact maintain the ordering we'd like, since consumer->definition() is processed before consumer->domain(). A comment is added to the definition of IterVisitor to briefly explain this.


void mutate(TensorView* tv) final;

void mutate(TensorDomain* td) final;
@jacobhinkle jacobhinkle Jul 10, 2023

The override of mutate(TensorDomain*) merely updates contiguity to reflect any introduced Broadcast IDs. This is a general situation that could probably be added to OptOutMutator::mutate(TensorDomain*) instead to further simplify the concretization code.

Comment on lines 534 to 547
if (auto def = id->definition()) {
  // Determine concrete IterType based on promotion of inputs to def
  IterType iter_type = IterType::Symbolic;
  for (auto inp_id : ir_utils::filterByType<IterDomain>(def->inputs())) {
    auto updated_id = maybeMutated(inp_id)->as<IterDomain>();
    iter_type = ops::promoteIterType(iter_type, updated_id->getIterType());
  }
  TORCH_INTERNAL_ASSERT(
      iter_type != IterType::Symbolic,
      "Failed to concretize an output IterType for expression: ",
      def->toString());
  auto concretized_id = IterDomainBuilder(id).iter_type(iter_type).build();
  registerConcretization(id, concretized_id);
} else {

This branch replaces the root->rfactor propagation that was removed from mutate(TensorView*).
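The promotion in the snippet above can be illustrated with a toy lattice. This is a simplified assumption about the semantics, not ops::promoteIterType's full rules: Symbolic acts as the identity, and Iteration dominates Broadcast, so an output is only Broadcast if every concretized input is Broadcast.

```cpp
#include <cassert>

// Illustrative IterType promotion lattice (simplified; the real rules live
// in nvfuser's ops::promoteIterType).
enum class IterType { Symbolic, Broadcast, Iteration };

IterType promote(IterType a, IterType b) {
  // Symbolic is the identity element: it carries no information yet.
  if (a == IterType::Symbolic) return b;
  if (b == IterType::Symbolic) return a;
  // Iteration dominates Broadcast.
  return (a == IterType::Iteration || b == IterType::Iteration)
      ? IterType::Iteration
      : IterType::Broadcast;
}
```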

Comment on lines 548 to 550
// IterDomains without definitions might be root domains for the output of a
// TensorView expression. If so, we should propagate their concretization in
// the producer to consumer direction.

This branch replaces propagateFromProducerToConsumer.

Comment on lines +947 to +950
// This was previously "working" by concretizing the size-1 pad to
// Iteration, even though it should be Broadcast. When set properly to
// Broadcast, it fails with an error in ConcretizedBroadcastDomains.
//{{3, 5}, {0, -4}, true},
@jacobhinkle jacobhinkle Jul 10, 2023

This case deserves its own issue, which I will add. When there is a broadcast domain introduced by concretizing a resize we hit an error since we can't concretize the broadcast. On main, this case actually concretizes the pad as Iteration even if extent is 1. Instead, we should probably translate a Resize that results in size 1 as a select+broadcast or a full(pad_value)+broadcast; however that is a complicated change since we need to operate on the TensorView containing the Resized ID, meaning we would change concretization info to track TV ops (cat, pad, slice) instead of ID op Resize. For now I have disabled this case. Once I file an issue I'll point to it here in the comment.

Comment on lines -382 to -383
if (stmt->isA<Val>()) {
mutate(stmt);
@jacobhinkle jacobhinkle Jul 10, 2023

Instead of only mutating Vals, we now mutate Exprs as well, which replaces the Expr in place if any inputs or outputs have changed. Note that outputs of Exprs are mutated after their definition has been mutated, so we should be careful updating a Val that has a definition. But of course we should be careful in that case in the existing code too.
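The replacement rule for Exprs can be sketched with toy types (the struct and predicate below are illustrative only; in nvfuser the check also covers attributes):

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-in for an expression: named inputs and outputs.
struct Expr {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// An Expr is removed and rebuilt whenever any of its inputs or outputs has
// a registered mutation.
bool needsReplacement(const Expr& e, const std::vector<std::string>& mutated) {
  auto touched = [&](const std::string& v) {
    for (const auto& m : mutated) {
      if (m == v) {
        return true;
      }
    }
    return false;
  };
  for (const auto& v : e.inputs) {
    if (touched(v)) {
      return true;
    }
  }
  for (const auto& v : e.outputs) {
    if (touched(v)) {
      return true;
    }
  }
  return false;
}
```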

@jacobhinkle jacobhinkle requested a review from naoyam July 10, 2023 17:03
@jacobhinkle jacobhinkle marked this pull request as ready for review July 10, 2023 17:04
@jacobhinkle jacobhinkle marked this pull request as draft July 12, 2023 12:44
@jacobhinkle jacobhinkle marked this pull request as ready for review July 12, 2023 14:00
Comment on lines +389 to +401
std::vector<Val*> non_tds_tvs;
std::vector<Expr*> all_exprs;
std::vector<Val*> tvs_and_tds;
for (auto stmt : StmtSort::getStmts(info_->fusion(), true)) {
  if (stmt->isExpr()) {
    all_exprs.push_back(stmt->asExpr());
  } else {
    auto val = stmt->asVal();
    if (val->isA<TensorView>() || val->isA<TensorDomain>()) {
      tvs_and_tds.push_back(val);
    } else {
      non_tds_tvs.push_back(val);
    }

What was previously a single loop over all_stmts is now three separate loops over these subsets.
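The bucketing can be mirrored with a toy, self-contained version (the enum and struct are illustrative; statements are visited once in topological order and partitioned so each of the three passes can run to completion before the next begins):

```cpp
#include <cassert>
#include <vector>

// Illustrative statement kinds; in nvfuser these are isExpr()/isA<>() checks.
enum class Kind { Val, Expr, TensorDomain, TensorView };

struct Buckets {
  std::vector<int> non_tds_tvs; // plain Vals, mutated in loop 1
  std::vector<int> all_exprs;   // Exprs, replaced in loop 2
  std::vector<int> tvs_and_tds; // TDs and TVs, handled in loop 3
};

// Single pass over topologically sorted statements, preserving order
// within each bucket.
Buckets partition(const std::vector<Kind>& stmts) {
  Buckets b;
  for (int i = 0; i < static_cast<int>(stmts.size()); ++i) {
    switch (stmts[i]) {
      case Kind::Expr:
        b.all_exprs.push_back(i);
        break;
      case Kind::TensorDomain:
      case Kind::TensorView:
        b.tvs_and_tds.push_back(i);
        break;
      default:
        b.non_tds_tvs.push_back(i);
    }
  }
  return b;
}
```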

@jacobhinkle jacobhinkle changed the title Refactor concretization to traverse in strict topo order Refactor concretization traversal Jul 12, 2023
if (updated_id->isBroadcast()) {
  contig.at(i) = std::nullopt;
}
}
@jacobhinkle jacobhinkle Jul 12, 2023

This was previously done in dynamic_transform.cpp but I think it makes sense to always recompute contig if mutating Symbolic to Broadcast.


Note that this only covers the case where the original ID was Symbolic and checks that it was marked contig then. If we mutated from an original ID with type Broadcast or Iteration we might want to fill in a different contiguity here instead.
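A minimal sketch of the fix-up, with simplified stand-in types (the real code operates on the TensorDomain's contiguity vector and mutated IterDomains): entries for dimensions that concretize to Broadcast become nullopt, since broadcast dimensions carry no contiguity flag.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Clear the contiguity flag of any dimension that concretized to Broadcast;
// other entries are left as they were.
std::vector<std::optional<bool>> updateContiguity(
    std::vector<std::optional<bool>> contig,
    const std::vector<bool>& is_broadcast) {
  for (size_t i = 0; i < contig.size(); ++i) {
    if (is_broadcast[i]) {
      contig[i] = std::nullopt;
    }
  }
  return contig;
}
```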

jacobhinkle added a commit that referenced this pull request Jul 12, 2023
@jacobhinkle jacobhinkle marked this pull request as draft July 13, 2023 11:47
@jacobhinkle

Closing in favor of #449, which was fixed without needing such an invasive refactor.

@jacobhinkle jacobhinkle deleted the concretization_topo_order branch July 25, 2023 16:39