Refactor concretization traversal #576
@@ -357,16 +357,20 @@ class DynamicTransformConcretizer : public OptOutMutator {

```cpp
  using OptOutMutator::mutate;

  void mutate(TensorView* tv) final;

  void mutate(TensorDomain* td) final;

  //! Concretizes the root domain of a symbolic consumer tensor from
  //! its producer domains. Returns true if any root ID is concretized.
  bool propagateFromProducerToConsumer(TensorView* consumer);
  void mutate(IterDomain* id) final;

 private:
  const DynamicTransformConcretizationInfo* info_;

  //! This map is used during concretization to identify, for a given
  //! IterDomain, the set of all IterDomains which are "aligned" with it in
  //! some TensorView expression. This enables us to write mutate(IterDomain*)
  //! and propagate information from producer IterDomains to consumers, which
  //! is otherwise not represented in the graph since we do not connect
  //! IterDomains between TensorViews with expressions.
  std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>
      id_producers_;
};

void DynamicTransformConcretizer::concretize() {
```

@@ -376,13 +380,37 @@ void DynamicTransformConcretizer::concretize() {

```cpp
  // Set output IterTypes for dynamic resize ops
  concretizeResize();

  // Finally, propagate concretized domains
  // The methods above do not traverse the graph. Instead they fill in
  // root->rfactor expressions by replacing the dynamic reshaped TV with a
  // static reshaped one, and by registering concretization of dynamic Resized
  // IterDomains. From this point forward, we will not modify any TensorView
  // expressions. This restriction makes it safe for us to traverse forward
  // through the graph and mutate IterDomains and TensorDomains in order to
  // properly propagate IterTypes and concretized extent expressions, without
  // breaking the topological ordering of these expressions.
  //
  // When propagating IterTypes across expressions, we need to know the
  // producer IterDomains corresponding to a consumer ID. The mapping built
  // here facilitates this and is used later in mutate(IterDomain*).
  auto all_stmts = StmtSort::getStmts(info_->fusion(), true);
  for (auto stmt : all_stmts) {
    if (stmt->isA<Val>()) {
      mutate(stmt);
```
> **Comment on lines -382 to -383** (Collaborator, Author): Instead of only mutating …
```cpp
  for (auto expr : ir_utils::filterByType<Expr>(all_stmts)) {
    for (auto consumer : ir_utils::filterByType<TensorView>(expr->outputs())) {
      for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) {
        PairwiseRootDomainMap root_map(producer, consumer);
        for (auto [cid, pid] : root_map.mapConsumerToProducer(
                 consumer->domain(), producer->domain())) {
          // Initialize set of producer IDs, if we haven't already
          auto& producers = id_producers_.emplace(cid, 0).first->second;
          producers.insert(pid);
        }
      }
    }
  }

  // Finally, propagate concretized domains with forward traversal
  for (auto stmt : all_stmts) {
    mutate(stmt);
  }
}

void DynamicTransformConcretizer::concretizeReshape() {
```
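A side note on the `emplace` idiom used when populating `id_producers_` above: `emplace(cid, 0)` default-constructs the set only when `cid` is not yet present and otherwise returns the existing entry, so producer IDs accumulate per consumer ID. Below is a minimal, self-contained sketch of that pattern only, with plain strings standing in for `IterDomain*` (the names `id_producers` and `record` here are illustrative, not nvFuser identifiers):

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

int main() {
  // Stand-in for id_producers_: each consumer ID maps to the set of producer
  // IDs it is aligned with across TensorView expressions.
  std::unordered_map<std::string, std::unordered_set<std::string>> id_producers;

  auto record = [&](const std::string& cid, const std::string& pid) {
    // emplace constructs an empty set only if cid is not present yet; either
    // way we get back the entry for cid and insert the producer ID into it.
    auto& producers = id_producers.emplace(cid, 0).first->second;
    producers.insert(pid);
  };

  record("iS3", "iS0");
  record("iS3", "iS1"); // a second producer aligned with the same consumer ID
  record("iS4", "iS2");

  for (const auto& [cid, pids] : id_producers) {
    std::cout << cid << " <- " << pids.size() << " producer ID(s)\n";
  }
  return 0;
}
```

Using `id_producers[cid].insert(pid)` would behave the same here, since `operator[]` also default-constructs missing entries.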
@@ -441,103 +469,6 @@ void DynamicTransformConcretizer::checkConcretizedUses(

```cpp
  }
}

// Concretizes inherited symbolic domains. Note that when this is
// called, it is assumed that all dynamic ops themselves are
// concretized. Since symbolic IDs may be propagated down to
// consumers, those domains need to be concretized accordingly.
void DynamicTransformConcretizer::mutate(TensorView* tv) {
  if (!tv->domain()->hasSymbolicAxis()) {
    return;
  }

  // First, try to concretize the root domain as there may be symbolic
  // axes inherited from the producers
  propagateFromProducerToConsumer(tv);

  // If no root domain is altered by producers, we don't need to propagate
  // back up to rfactor. We could return early, but instead we go ahead and
  // check the root to rfactor transforms to be sure we have concretized any
  // intermediate IterDomains.

  // At this point, there should be no expr beyond rfactor root
  TORCH_INTERNAL_ASSERT(
      tv->getLeafDomain() == tv->getMaybeRFactorDomain(),
      "Invalid tensor: ",
      tv->toString());

  // If it has an rfactor root domain, the IterTypes of the rfactor
  // IDs may need to be updated as well. Traverse the rfactor exprs
  // and mutate the IterTypes of output IDs if symbolic.
  if (tv->hasRFactor()) {
    // Note that it is assumed that there's no further expression
    // beyond the rfactor domain as asserted above
    auto all_id_exprs = StmtSort::getExprsBetween(
        tv->fusion(),
        {tv->getRootDomain().begin(), tv->getRootDomain().end()},
        {tv->getMaybeRFactorDomain().begin(),
         tv->getMaybeRFactorDomain().end()});
    for (auto expr : all_id_exprs) {
      // Assume outputs of IterDomain exprs are always IterDomains. If
      // the assumption is invalidated, the logic here would need to
      // be updated. Assert the assumption to immediately detect such
      // a case if it happens.
      for (auto out_val : expr->outputs()) {
        TORCH_INTERNAL_ASSERT(
            out_val->isA<IterDomain>(),
            "Unexpected output: ",
            out_val->toString(),
            ". IterDomain was expected.");
      }

      // NOTE: We do not return early if all outputs are concrete as there may
      // still be concrete inputs. For example, a Symbolic IterDomain might be
      // padded with constant pad widths (1, 1), in which case although we do
      // not know the exact extent of the output, we know it is at least as
      // large as the sum of the pad widths, 2. In such cases, the output
      // IterDomain is concrete at definition, since if the extent is >1 we
      // know the IterType is Iteration. In these cases, we must continue to
      // concretize intermediate expressions between the root and rfactor
      // domain. See test DynamicTransform5_CUDA which demonstrates this
      // behavior.
      // NOTE: We also do not assume that if one output ID is symbolic, they
      // all must be. See test FusionSliceForNanoGPT3_CUDA for an example that
      // does a static split by a factor of 16 of a symbolic input domain. The
      // static split in that case results in a concrete IterDomain with
      // extent 16 along with a symbolic one (extent ceilDiv(n / 16)).

      // Determine the output IterType
      IterType iter_type = IterType::Symbolic;
      for (auto inp_id : ir_utils::filterByType<IterDomain>(expr->inputs())) {
        auto updated_id = maybeMutated(inp_id)->as<IterDomain>();
        iter_type = ops::promoteIterType(iter_type, updated_id->getIterType());
      }
      TORCH_INTERNAL_ASSERT(
          iter_type != IterType::Symbolic,
          "Failed to concretize an output IterType for expression: ",
          expr->toString());

      // Update the IterType of each output
      for (auto out_id : ir_utils::filterByType<IterDomain>(expr->outputs())) {
        if (!out_id->isSymbolic()) {
          continue;
        }
        auto concretized_out_id =
            IterDomainBuilder(out_id).iter_type(iter_type).build();
        registerConcretization(out_id, concretized_out_id);
      }

      // The expr itself needs to be mutated as well in case the outputs are
      // mutated, which can be done by the mutate method
      OptOutMutator::mutate(expr);
    }
  }

  // Root and rfactor domains are updated. First mutate the
  // TensorDomain and then TensorView
  mutate(tv->domain());
  OptOutMutator::mutate(tv);
}

// Almost an exact copy of OptOutMutator::mutate(TensorDomain*), but
// the contiguity vector may need to be updated as well as symbolic
// domains may be mutated to broadcast domains, which means contiguity
```
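The first NOTE in the removed block is easy to see with a little arithmetic: padding with constant widths (1, 1) gives an output extent of `input_extent + 2`, so the output is at least 2 and its IterType can be fixed to Iteration at the definition even while the input extent is still symbolic. A tiny sketch of that lower-bound reasoning (hypothetical helper, not nvFuser code):

```cpp
#include <cassert>

// Lower bound on a padded extent when only non-negative pad widths are known:
// input extent >= 0 implies padded extent >= left + right.
long min_padded_extent(long left, long right) {
  return left + right;
}

int main() {
  // Pad widths (1, 1): the output extent is at least 2, so the output
  // IterType is Iteration even though the input extent is symbolic.
  assert(min_padded_extent(1, 1) >= 2);
  return 0;
}
```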
@@ -594,75 +525,59 @@ void DynamicTransformConcretizer::mutate(TensorDomain* td) {

```cpp
  registerConcretization(td, mutated_val);
}

bool DynamicTransformConcretizer::propagateFromProducerToConsumer(
    TensorView* consumer) {
  if (consumer->definition() == nullptr ||
      !consumer->domain()->hasSymbolicAxis()) {
    return false;
void DynamicTransformConcretizer::mutate(IterDomain* id) {
  // id might have already been mutated if its definition was updated
  id = maybeMutated(id)->as<IterDomain>();
  if (!id->isSymbolic()) {
    return;
  }

  const auto& root_domain = consumer->getRootDomain();

  auto def = consumer->definition();

  bool is_concretized = false;

  for (const auto i : c10::irange(root_domain.size())) {
    auto root_id = root_domain.at(i);
    if (root_id->getIterType() != IterType::Symbolic) {
      continue;
  if (auto def = id->definition()) {
    // Determine concrete IterType based on promotion of inputs to def
    IterType iter_type = IterType::Symbolic;
    for (auto inp_id : ir_utils::filterByType<IterDomain>(def->inputs())) {
      auto updated_id = maybeMutated(inp_id)->as<IterDomain>();
      iter_type = ops::promoteIterType(iter_type, updated_id->getIterType());
    }
    TORCH_INTERNAL_ASSERT(
        iter_type != IterType::Symbolic,
        "Failed to concretize an output IterType for expression: ",
        def->toString());
    auto concretized_id = IterDomainBuilder(id).iter_type(iter_type).build();
    registerConcretization(id, concretized_id);
  } else {
    // IterDomains without definitions might be root domains for the output of
    // a TensorView expression. If so, we should propagate their
    // concretization in the producer to consumer direction.

    auto producers_it = id_producers_.find(id);
    if (producers_it == id_producers_.end()) {
      // id was not a consumer root ID in any TV expression
      return;
    }

    // Figure out the right IterType of this consumer root ID from its
    // corresponding producer IDs

    std::optional<IterType> id_type;

    for (auto producer : ir_utils::filterByType<TensorView>(def->inputs())) {
      PairwiseRootDomainMap root_map(producer, consumer);
      auto c2p = root_map.mapConsumerToProducer(
          consumer->domain(), producer->domain());

      TORCH_INTERNAL_ASSERT(
          c2p.find(root_id) != c2p.end(),
          "No input ID found to map with output ID: ",
          root_id->toString());

      auto input_id = c2p.at(root_id);
      TORCH_INTERNAL_ASSERT(
          input_id->getIterType() != IterType::Symbolic,
          "Producer ID not concretized: ",
          input_id->toString());

    for (auto producer_id : producers_it->second) {
      producer_id = maybeMutated(producer_id)->as<IterDomain>();
      if (id_type.has_value()) {
        id_type = ops::promoteIterType(*id_type, input_id->getIterType());
        id_type = ops::promoteIterType(*id_type, producer_id->getIterType());
      } else {
        id_type = input_id->getIterType();
        id_type = producer_id->getIterType();
      }
    }

    TORCH_INTERNAL_ASSERT(
        id_type.has_value(),
        "Did not find id_type for consumer root domain ",
        root_id->toString(),
        ". Perhaps consumer def has no inputs. Consumer definition = ",
        def->toString());
        id->toString(),
        ". Perhaps consumer def has no inputs.");

    TORCH_INTERNAL_ASSERT(
        id_type != IterType::Symbolic,
        "Failed to concretize ",
        root_id->toString(),
        " of ",
        consumer->toString());
        id_type != IterType::Symbolic, "Failed to concretize ", id->toString());

    auto concretized_id =
        IterDomainBuilder(root_id).iter_type(*id_type).build();
    auto concretized_id = IterDomainBuilder(id).iter_type(*id_type).build();

    registerConcretization(root_id, concretized_id);
    is_concretized = true;
    registerConcretization(id, concretized_id);
  }

  return is_concretized;
}

DynamicTransformInitialInfo DynamicTransform::getInitialInfo(Fusion* fusion) {
```
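The producer-to-consumer branch of `mutate(IterDomain*)` folds the IterTypes of all mapped producer IDs into a single `std::optional<IterType>`. The sketch below illustrates only that accumulation pattern; the `promote` rules shown (anything concrete overrides Symbolic, Iteration wins over Broadcast) are an assumption standing in for `ops::promoteIterType`, whose exact behavior is not reproduced here:

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Toy stand-in types; the real IterType and promotion rules live in nvFuser.
enum class IterType { Symbolic, Broadcast, Iteration };

// Assumed promotion rules for illustration only.
IterType promote(IterType a, IterType b) {
  if (a == IterType::Symbolic) return b;
  if (b == IterType::Symbolic) return a;
  if (a == IterType::Iteration || b == IterType::Iteration) {
    return IterType::Iteration;
  }
  return IterType::Broadcast;
}

int main() {
  // IterTypes of the (already mutated) producer IDs mapped to one consumer ID.
  std::vector<IterType> producer_types = {
      IterType::Broadcast, IterType::Iteration};

  // Same accumulation shape as in mutate(IterDomain*): start with an empty
  // optional, then fold each producer's IterType into it.
  std::optional<IterType> id_type;
  for (IterType t : producer_types) {
    id_type = id_type.has_value() ? promote(*id_type, t) : t;
  }

  // The consumer root ID would be concretized to Iteration in this example.
  assert(id_type.has_value() && *id_type == IterType::Iteration);
  return 0;
}
```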
@@ -944,7 +944,10 @@ TEST_F(NVFuserTest, DynamicPadShmoo_CUDA) {

```cpp
      //{{3, 5}, {-3, -2}, false}, // output is zero-dimensional

      // Output has size 1 so is set to broadcast.
      {{3, 5}, {0, -4}, true},
      // This was previously "working" by concretizing the size-1 pad to
      // Iteration, even though it should be Broadcast. When set properly to
      // Broadcast, it fails with an error in ConcretizedBroadcastDomains.
      //{{3, 5}, {0, -4}, true},
```
> **Comment on lines +947 to +950** (Collaborator, Author): This case deserves its own issue, which I will add. When there is a broadcast domain introduced by concretizing a resize, we hit an error since we can't concretize the broadcast. On …
```cpp
      // Test full negative shifts, so output doesn't overlap input
      {{3, 5}, {-5, 2}, false},
```
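These test cases are easiest to read by computing the padded extent directly: the output extent is the input extent plus both (possibly negative) pad widths. A quick sketch of that arithmetic, assuming the widths apply to the size-5 dimension as the comments above suggest:

```cpp
#include <cstdio>

// Output extent of a padded dimension: input extent plus left and right pad
// widths (negative widths trim the input).
long padded_extent(long extent, long left, long right) {
  return extent + left + right;
}

int main() {
  std::printf("%ld\n", padded_extent(5, -3, -2)); // 0 -> zero-dimensional, skipped
  std::printf("%ld\n", padded_extent(5, 0, -4));  // 1 -> Broadcast output
  std::printf("%ld\n", padded_extent(5, -5, 2));  // 2 -> no overlap with input
  return 0;
}
```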
> **Comment:** The override of `mutate(TensorDomain*)` merely updates contiguity to reflect any introduced Broadcast IDs. This is a general situation that could probably be added to `OptOutMutator::mutate(TensorDomain*)` instead to further simplify the concretization code.
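As a rough sketch of the situation this comment describes (not nvFuser's actual implementation): once a formerly Symbolic ID is concretized to Broadcast, its contiguity entry no longer carries a meaningful true/false value, so, assuming the convention that broadcast domains have no contiguity flag, the vector is rewritten alongside the domains:

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

enum class IterType { Symbolic, Broadcast, Iteration };

// Toy model of the contiguity fixup: entries for domains concretized to
// Broadcast are cleared, since a broadcast domain carries no contiguity flag
// (assumed convention for this sketch).
std::vector<std::optional<bool>> updateContiguity(
    const std::vector<IterType>& concretized_types,
    std::vector<std::optional<bool>> contiguity) {
  for (std::size_t i = 0; i < concretized_types.size(); ++i) {
    if (concretized_types[i] == IterType::Broadcast) {
      contiguity[i] = std::nullopt;
    }
  }
  return contiguity;
}

int main() {
  // The second domain was Symbolic and got concretized to Broadcast.
  auto contiguity = updateContiguity(
      {IterType::Iteration, IterType::Broadcast}, {true, true});
  assert(!contiguity[1].has_value());
  return 0;
}
```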