-
Notifications
You must be signed in to change notification settings - Fork 69
Set extents of empty tensors to zeroVal() during concretization #449
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
55b0dfd
759d803
e0f4eb1
5b56be7
7d01339
c16c362
137ad20
b15929f
2a2eef7
3270f15
aee5a2f
d41decc
96e105a
3892578
0e75ba1
e38d027
6099940
478ed4a
5ce87fe
13a5b57
c56bd86
145f4de
91101b9
9a4f857
c56f163
739ae06
fc4a484
9f28007
c80a93b
08c4edc
7e770ac
5c0a9e0
b864aed
b52df95
a664464
c10f56b
3d4f6ae
97cf441
9bdb140
b45638d
3026cc6
8b8524e
0faf0ef
c5dfe02
9811fa8
c09691c
9a0ba86
c34caa7
f42a48b
57dfd15
ef53abe
ba7e9dd
8483d14
1d7165e
7fb9428
cd209c2
2175ed8
bdbb31c
42392d3
1e31b10
7f9b588
7593ea4
aa5d75c
ae8b5b5
da70cc1
899ee21
69356eb
4dcff1e
7817163
aac04e8
a1c9238
16a9418
91e3c04
756e533
14468a6
6f2db3b
8eda179
a55586a
3c10b7a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,6 +37,18 @@ DynamicTransformInitialInfo DynamicTransformInitialInfo::clone( | |
| cloned_info.dynamic_resized_ids_.push_back(ir_cloner.clone(op)); | ||
| } | ||
| } | ||
| cloned_info.maybe_zero_extents_set_.reserve(maybe_zero_extents_set_.size()); | ||
| for (const auto v : maybe_zero_extents_set_) { | ||
| if (v) { | ||
| cloned_info.maybe_zero_extents_set_.insert(ir_cloner.clone(v)); | ||
| } | ||
| } | ||
| cloned_info.maybe_zero_extents_.reserve(maybe_zero_extents_.size()); | ||
| for (const auto v : maybe_zero_extents_) { | ||
| if (v) { | ||
| cloned_info.maybe_zero_extents_.push_back(ir_cloner.clone(v)); | ||
| } | ||
| } | ||
| cloned_info.root_dynamic_vals_.reserve(root_dynamic_vals_.size()); | ||
| for (const auto v : root_dynamic_vals_) { | ||
| if (v) { | ||
|
|
@@ -58,6 +70,10 @@ std::string DynamicTransformInitialInfo::toString() const { | |
| for (const auto& op : dynamic_resized_ids_) { | ||
| ss << indent << indent << op->toString() << "\n"; | ||
| } | ||
| ss << indent << "Dynamic extent Vals:\n"; | ||
| for (const auto& v : maybe_zero_extents_) { | ||
| ss << indent << indent << v->toString() << "\n"; | ||
| } | ||
| ss << indent << "Root dynamic Vals:\n"; | ||
| for (const auto& v : root_dynamic_vals_) { | ||
| ss << indent << indent << v->toString() << "\n"; | ||
|
|
@@ -77,6 +93,8 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { | |
| traverseTo(fusion, fusion->getTerminatingOutputs(), false, false); | ||
|
|
||
| finalizeDynamicVals(); | ||
|
|
||
| finalizeMaybeEmptyExtents(); | ||
| } | ||
|
|
||
| const auto& getInfo() const { | ||
|
|
@@ -95,22 +113,24 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { | |
| info_.dynamic_reshaped_tvs_.push_back(out_tv); | ||
|
|
||
| // Input and output extent expressions both affect concretization | ||
| const auto& inp_dom = | ||
| TensorDomain::noReductions(inp_tv->getMaybeRFactorDomain()); | ||
| for (const auto id : inp_dom) { | ||
| for (const auto& id : | ||
| TensorDomain::noReductions(inp_tv->getMaybeRFactorDomain())) { | ||
| leaf_dynamic_vals_.push_back(id->extent()); | ||
| } | ||
| const auto& out_dom = out_tv->getMaybeRFactorDomain(); | ||
| for (const auto id : out_dom) { | ||
| for (const auto& id : out_tv->getMaybeRFactorDomain()) { | ||
| leaf_dynamic_vals_.push_back(id->extent()); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| //! Detect dynamic IterDomain transforms when handling TensorViews | ||
| //! Detect possibly empty TensorViews and dynamic IterDomain transforms | ||
| void handle(TensorView* tv) override { | ||
| const auto& rfd = tv->getMaybeRFactorDomain(); | ||
| for (auto id : rfd) { | ||
| if (!id->extent()->isConstScalar() || id->extent()->evaluateInt() == 0) { | ||
| info_.maybe_zero_extents_set_.insert(id->extent()); | ||
| leaf_dynamic_vals_.push_back(id->extent()); | ||
| } | ||
| if (!id->definition() || id->getIterType() != IterType::Symbolic) { | ||
| continue; | ||
| } | ||
|
|
@@ -141,6 +161,15 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { | |
| } | ||
| } | ||
|
|
||
| //! Convert maybe_zero_extents_set_ to a vector so we can index it reliably | ||
| void finalizeMaybeEmptyExtents() { | ||
| info_.maybe_zero_extents_ = std::vector<Val*>( | ||
| info_.maybe_zero_extents_set_.begin(), | ||
| info_.maybe_zero_extents_set_.end()); | ||
| // Clear the corresponding set to free memory and speed up cloning | ||
| info_.maybe_zero_extents_set_.clear(); | ||
| } | ||
|
|
||
| private: | ||
| DynamicTransformInitialInfo info_; | ||
|
|
||
|
|
@@ -154,6 +183,36 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor { | |
| std::vector<Val*> leaf_dynamic_vals_; | ||
| }; | ||
|
|
||
| DynamicTransformConcretizationInfo::DynamicTransformConcretizationInfo( | ||
| const DynamicTransformInitialInfo* initial_info, | ||
| ExpressionEvaluator* expr_eval) | ||
| : initial_info_(initial_info) { | ||
| TORCH_INTERNAL_ASSERT( | ||
| !fusion()->isA<kir::Kernel>(), | ||
| "Invalid container. Kernel container not allowed.\n"); | ||
|
|
||
| // Make sure all exactly mapped IDs have the same value in the | ||
| // evaluator when any one of the IDs has a known value | ||
| expr_eval->propagateBoundValuesThroughExactMaps(initial_info_->fusion()); | ||
|
|
||
| analyzeReshapes(expr_eval); | ||
|
|
||
| analyzeResizes(expr_eval); | ||
|
|
||
| auto maybe_zero_extents = initial_info_->getMaybeZeroExtents(); | ||
| for (auto i : c10::irange(maybe_zero_extents.size())) { | ||
| auto ext = maybe_zero_extents.at(i); | ||
| auto ext_opt = expr_eval->evaluate(ext); | ||
| TORCH_INTERNAL_ASSERT( | ||
| ext_opt.hasValue(), | ||
| "Could not evaluate dynamic extent: ", | ||
| ext->toString()); | ||
| if (ext_opt == 0) { | ||
| empty_extents_.push_back(i); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| void DynamicTransformConcretizationInfo::analyzeReshapes( | ||
| ExpressionEvaluator* expr_eval) { | ||
| const auto& reshape_tvs = initial_info_->getDynamicReshapedTensorViews(); | ||
|
|
@@ -281,7 +340,8 @@ bool DynamicTransformConcretizationInfo::operator==( | |
| } | ||
|
|
||
| if (reshape_transforms_.size() != other.reshape_transforms_.size() || | ||
| resize_itertypes_.size() != other.resize_itertypes_.size()) { | ||
| resize_itertypes_.size() != other.resize_itertypes_.size() || | ||
| empty_extents_.size() != other.empty_extents_.size()) { | ||
| return false; | ||
| } | ||
|
|
||
|
|
@@ -301,13 +361,26 @@ bool DynamicTransformConcretizationInfo::operator==( | |
| } | ||
| } | ||
|
|
||
| for (const auto i : c10::irange(empty_extents_.size())) { | ||
| const auto& ee = empty_extents_.at(i); | ||
| const auto& other_ee = other.empty_extents_.at(i); | ||
| if (ee != other_ee) { | ||
| return false; | ||
| } | ||
| } | ||
|
|
||
| return true; | ||
| } | ||
|
|
||
| std::string DynamicTransformConcretizationInfo::toString() const { | ||
| std::stringstream ss; | ||
| ss << "DynamicTransformConcretizationInfo\n"; | ||
| std::string indent = " "; | ||
| ss << indent << "Empty tensor extents:\n"; | ||
| for (const auto& i : empty_extents_) { | ||
| auto ext = initial_info_->getMaybeZeroExtents().at(i); | ||
| ss << indent << indent << ext->toString() << " is zero\n"; | ||
| } | ||
| ss << indent << "Reshape:\n"; | ||
| for (const auto& [tv_index, analyze_result] : reshape_transforms_) { | ||
| auto tv = initial_info_->getDynamicReshapedTensorViews().at(tv_index); | ||
|
|
@@ -333,6 +406,7 @@ class DynamicTransformConcretizer : public OptOutMutator { | |
| TORCH_INTERNAL_ASSERT( | ||
| fusion == info->fusion(), | ||
| "Invalid DynamicTransformInitialInfo. The associated Fusion is different from the given Fusion"); | ||
| FusionGuard fg(fusion); | ||
| concretize(); | ||
| } | ||
|
|
||
|
|
@@ -343,6 +417,8 @@ class DynamicTransformConcretizer : public OptOutMutator { | |
|
|
||
| void concretizeResize(); | ||
|
|
||
| void concretizeEmptyExtents(); | ||
|
|
||
| //! Use this instead of calling registerMutation directly, since it will also | ||
| //! check that the concretized value is a valid input to all of its uses. | ||
| void registerConcretization(Val* old_val, Val* new_val) { | ||
|
|
@@ -370,18 +446,40 @@ class DynamicTransformConcretizer : public OptOutMutator { | |
| }; | ||
|
|
||
| void DynamicTransformConcretizer::concretize() { | ||
| // First, concretize all dynamic reshape ops | ||
| // Concretize all dynamic reshape ops | ||
| concretizeReshape(); | ||
|
|
||
| // Set output IterTypes for dynamic resize ops | ||
| concretizeResize(); | ||
|
|
||
| // Registers replacement of all empty extents with zeroVal() | ||
| concretizeEmptyExtents(); | ||
|
|
||
| // Finally, propagate concretized domains | ||
| auto all_stmts = StmtSort::getStmts(info_->fusion(), true); | ||
| for (auto stmt : all_stmts) { | ||
| if (stmt->isA<Val>()) { | ||
| mutate(stmt); | ||
| auto all_stmts = StmtSort::getStmts(info_->fusion()); | ||
| for (auto tv : ir_utils::filterByType<TensorView>(all_stmts)) { | ||
| mutate(tv); | ||
| } | ||
| } | ||
|
|
||
| void DynamicTransformConcretizer::concretizeEmptyExtents() { | ||
| auto fusion = FusionGuard::getCurFusion(); | ||
| for (const auto& ext_index : info_->getEmptyExtents()) { | ||
| auto ext = info_->initialInfo()->getMaybeZeroExtents().at(ext_index); | ||
| auto zero = fusion->zeroVal(ext->getDataType().value()); | ||
| auto uses = ext->uses(); | ||
| for (auto use : uses) { | ||
| ir_utils::replaceValInExpr(use, ext, zero); | ||
| } | ||
| // Register the concretization of this scalar, which allows us to replace it | ||
| // whenever it is used as an extent member of an IterDomain. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is because
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes exactly. When we replace in
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you please add this to the code comment as well?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
| // | ||
| // When we replace ext in all uses above, it affects downstream expressions. For | ||
| // example we might replace i0 with 0 in (i0 + i1) + i2 to form (0 + i1) + | ||
| // i2. However, i0 itself might be used as the extent, start, or stop values | ||
| // in an IterDomain, so we register the concretization here so that we can | ||
| // replace these values whenever we encounter them. | ||
| registerConcretization(ext, zero); | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -400,7 +498,8 @@ void DynamicTransformConcretizer::concretizeReshape() { | |
| checkConcretizedUses(incomplete_out_tv, concrete_reshape_out_tv); | ||
|
|
||
| // Replace the old tensor with the new concretized tensor | ||
| for (auto use_of_old_tv : incomplete_out_tv->uses()) { | ||
| auto uses = incomplete_out_tv->uses(); | ||
| for (auto use_of_old_tv : uses) { | ||
| ir_utils::replaceValInExpr( | ||
| use_of_old_tv, incomplete_out_tv, concrete_reshape_out_tv); | ||
| } | ||
|
|
@@ -446,8 +545,10 @@ void DynamicTransformConcretizer::checkConcretizedUses( | |
| // concretized. Since symbolic IDs may be propagated down to | ||
| // consumers, those domains need to be concretized accordingly. | ||
| void DynamicTransformConcretizer::mutate(TensorView* tv) { | ||
| if (!tv->domain()->hasSymbolicAxis()) { | ||
| return; | ||
| for (auto root_id : tv->getRootDomain()) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks a bit confusing. A root ID may be mutated here, but there's also
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. At line 659 in
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well, what happens if the same root ID is mutated both at line 546 and within line 551? Doesn't the latter just overwrite the mutation at line 546?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, are you referring to line 659 before this PR?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It will overwrite the mutation, but since we are using
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, I see it now. Maybe it'd be helpful to add that to the code comment too. |
||
| // This will register root_id for mutation if its extent, start, or | ||
| // stop_offset is registered for mutation | ||
| OptOutMutator::mutate(root_id); | ||
| } | ||
|
|
||
| // First, try to concretize the root domain as there may be symbolic | ||
|
|
@@ -522,12 +623,14 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) { | |
| continue; | ||
| } | ||
| auto concretized_out_id = | ||
| IterDomainBuilder(out_id).iter_type(iter_type).build(); | ||
| IterDomainBuilder(maybeMutated(out_id)->as<IterDomain>()) | ||
| .iter_type(iter_type) | ||
| .build(); | ||
| registerConcretization(out_id, concretized_out_id); | ||
| } | ||
|
|
||
| // The expr itself needs to be mutated as well in case the outputs are | ||
| // mutated, which can be done by the mutate method | ||
| // expr must be mutated in order to set it as the definition for the | ||
| // concretized outputs. | ||
| OptOutMutator::mutate(expr); | ||
| } | ||
| } | ||
|
|
@@ -656,7 +759,9 @@ bool DynamicTransformConcretizer::propagateFromProducerToConsumer( | |
| consumer->toString()); | ||
|
|
||
| auto concretized_id = | ||
| IterDomainBuilder(root_id).iter_type(*id_type).build(); | ||
| IterDomainBuilder(maybeMutated(root_id)->as<IterDomain>()) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this change because the extent of
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes exactly, when we call |
||
| .iter_type(*id_type) | ||
| .build(); | ||
|
|
||
| registerConcretization(root_id, concretized_id); | ||
| is_concretized = true; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note we no longer traverse into members.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't remember why it did
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I changed it to `true` in #258 so that the traversal would handle `IterDomain`s, but I didn't have a clear enough picture of how that should work at that time.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now we explicitly are only traversing the `TensorView` graph, and we only handle `IterDomain` members manually.