-
Notifications
You must be signed in to change notification settings - Fork 70
Refactor concretization traversal #576
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
csrc/dynamic_transform.cpp
Outdated
|
|
||
| void mutate(TensorView* tv) final; | ||
|
|
||
| void mutate(TensorDomain* td) final; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The override of mutate(TensorDomain*) merely updates contiguity to reflect any introduced Broadcast IDs. This is a general situation that could probably be added to OptOutMutator::mutate(TensorDomain*) instead to further simplify the concretization code.
csrc/dynamic_transform.cpp
Outdated
| if (auto def = id->definition()) { | ||
| // Determine concrete IterType based on promotion of inputs to def | ||
| IterType iter_type = IterType::Symbolic; | ||
| for (auto inp_id : ir_utils::filterByType<IterDomain>(def->inputs())) { | ||
| auto updated_id = maybeMutated(inp_id)->as<IterDomain>(); | ||
| iter_type = ops::promoteIterType(iter_type, updated_id->getIterType()); | ||
| } | ||
| TORCH_INTERNAL_ASSERT( | ||
| iter_type != IterType::Symbolic, | ||
| "Failed to concretize an output IterType for expression: ", | ||
| def->toString()); | ||
| auto concretized_id = IterDomainBuilder(id).iter_type(iter_type).build(); | ||
| registerConcretization(id, concretized_id); | ||
| } else { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This branch replaces the root->rfactor propagation that was removed from mutate(TensorView*).
csrc/dynamic_transform.cpp
Outdated
| // IterDomains without definitions might be root domains for the output of a | ||
| // TensorView expression. If so, we should propagate their concretization in | ||
| // the producer to consumer direction. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This branch replaces propagateFromProducerToConsumer.
| // This was previously "working" by concretizing the size-1 pad to | ||
| // Iteration, even though it should be Broadcast. When set properly to | ||
| // Broadcast, it fails with an error in ConcretizedBroadcastDomains. | ||
| //{{3, 5}, {0, -4}, true}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This case deserves its own issue, which I will add. When there is a broadcast domain introduced by concretizing a resize we hit an error since we can't concretize the broadcast. On main, this case actually concretizes the pad as Iteration even if extent is 1. Instead, we should probably translate a Resize that results in size 1 as a select+broadcast or a full(pad_value)+broadcast; however that is a complicated change since we need to operate on the TensorView containing the Resized ID, meaning we would change concretization info to track TV ops (cat, pad, slice) instead of ID op Resize. For now I have disabled this case. Once I file an issue I'll point to it here in the comment.
| if (stmt->isA<Val>()) { | ||
| mutate(stmt); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of only mutating Vals, we now mutate Exprs as well, which replaces the Expr in place if any inputs or outputs have changed. Note that outputs of Exprs are mutated after their definition has been mutated, so we should be careful updating a Val that has a definition. But of course we should be careful in that case in the existing code too.
Also register extents in concretizeReshape
| std::vector<Val*> non_tds_tvs; | ||
| std::vector<Expr*> all_exprs; | ||
| std::vector<Val*> tvs_and_tds; | ||
| for (auto stmt : StmtSort::getStmts(info_->fusion(), true)) { | ||
| if (stmt->isExpr()) { | ||
| all_exprs.push_back(stmt->asExpr()); | ||
| } else { | ||
| auto val = stmt->asVal(); | ||
| if (val->isA<TensorView>() || val->isA<TensorDomain>()) { | ||
| tvs_and_tds.push_back(val); | ||
| } else { | ||
| non_tds_tvs.push_back(val); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What was previously a single loop over all_stmts is now three separate loops over these subsets.
| if (updated_id->isBroadcast()) { | ||
| contig.at(i) = std::nullopt; | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was previously done in dynamic_transform.cpp but I think it makes sense to always recompute contig if mutating Symbolic to Broadcast.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that this only covers the case where the original ID was Symbolic and checks that it was marked contig then. If we mutated from an original ID with type Broadcast or Iteration we might want to fill in a different contiguity here instead.
|
Closing in favor of #449, which was fixed without needing such an invasive refactor. |
Concretization requires two fundamental operations:
IterTypes and extent expressions forIterDomains might change during step 1 but downstream expressions have already been defined using the original symbolicIterDomains. This propagation step ensures that we have no dangling symbolicIterDomains and that extent expressions are properly replaced so that everything is consistent and appears as it would have if theFusionhad been static at definition.The downstream propagation in the second operation needs to pass through
IterDomainexpressions like those found inpadorreshapeops. But it also needs to cross from producer TVs to consumers. Those jumps are more complicated to handle inmutate(IterDomain*), since we cannot use the standarddefinition/usesmachinery to find tensor expressions linkingIterDomains in this way. Instead, currently we usemutate(TensorView* tv)and callpropagateProducerToConsumer()at the beginning of that method. That method finds exact mapped producer IDs usingtv->definition()andPairwiseRootDomainMapand propagates their concretization information to the corresponding root IDs intv. The root IDs are then updated and propagation alongIterDomainexpressions is done between root and rfactor oftv.The problem comes when we'd like to propagate information in the root->rfactor expressions. Since we do not modify the Fusion during traversal, these expressions are fixed. However, when we call
propagateProducerToConsumer()we might register a concretization of some rootIterDomains. These concretizations should ideally propagate in all of their uses. However,propagateProducerToConsumer()is called inmutate(TensorView* tv)which is called after all oftv's dependencies have been mutated. In particular,tv->domain()and allIterDomains in it must be mutated before this point. SubsequentlyStmtSort::getExprsis called on the unmutated root and rfactor, requiring manual mutation of intermediate exprs. That is why in the current code we need to perform another manual traversal from root to rfactor insidemutate(tv), which is the cause of some unneeded complexity since thoseIterDomains might be mutated multiple times.This PR addresses this by splitting the traversal into three loops, each of which is done in topological order:
Vals which are neitherTensorDomains orTensorViews and callmutate(val)on each. We exclude TDs and TVs here sinceOptOutMutator::mutateactually modifies the Fusion when these are called. This loop does not modify theFusionat all, but registers mutations for mostVals.Exprs. This actually will remove any expression with registered mutations of its inputs, outputs, or attributes, and replaces it with a new one linking those new vals.TensorDomains andTensorViews.TensorDomains are registered for mutation if theirIterDomains were registered for mutation, at which point a newTensorDomainis created. It is important that theFusionhas properly linked root to rfactor IDs at this point, which is done in the second loop.TensorViews then have theirdomain()(which is mutable) set to the newTensorDomain.In order to perform the producer to consumer jump across
TensorViewexpressions, we first extract mappings from consumer IDs to sets of producer IDs before the traversal begins. Those sets are then looked up inmutate(IterDomain*)in the first loop.Note on
IterVisitor's topological orderingConcretization, like most other traversals in nvfuser, uses
IterVisitorto obtain a topologically ordered set of statements in theFusion. This class guarantees that the statements will be in proper topological order with respect to theFusiongraph. This graph has directed edges from inputVals toExprs, fromExprs to outputVals, from attributeStatements toExprs, and from memberStatements toVals (such asTensorDomain(domain) ->TensorView,Val(start, stop, extent) ->IterDomain, etc.).At first glance it seems that this is sufficient, but it does not represent the relation that the rfactor
IterDomains of producerTensorViews are dependencies of their corresponding consumer rootIterDomains. These relations between "aligned"IterDomains are not present in the graph since they could easily become inconsistent when replacingTensorDomains orTensorViews. However, their absence means that there may be valid topological orderings that visit consumer domains before producer domains if neither has any other dependencies.The current implementation of
IterVisitor::traverseBetweendoes in fact maintain the ordering we'd like, sinceconsumer->definition()is processed beforeconsumer->domain(). A comment is added to the definition ofIterVisitorto briefly explain this.