Merged

79 commits
55b0dfd
Introduce info_.has_possible_empty_tensor_
jacobhinkle Jun 4, 2023
759d803
Remove fusion arg from concretizeFusion
jacobhinkle Jun 4, 2023
e0f4eb1
Sketch of empty branch finding in conc info
jacobhinkle Jun 4, 2023
5b56be7
Cleanup and fix assumptions in a couple tests
jacobhinkle Jun 5, 2023
7d01339
noReductions, replaceOutput, add tests
jacobhinkle Jun 5, 2023
c16c362
Clean up clang-tidy
jacobhinkle Jun 5, 2023
137ad20
Merge remote-tracking branch 'origin/main' into remove_empty_branches
jacobhinkle Jun 6, 2023
b15929f
Fix FusionMagicSchedulerInstanceNormalizationBackward_CUDA
jacobhinkle Jun 6, 2023
2a2eef7
Bind to expanded extent if needed, and only if non-const
jacobhinkle Jun 6, 2023
3270f15
Add all symbolic tensor extents to leaf_dynamic_vals_
jacobhinkle Jun 6, 2023
aee5a2f
Print pre-concretization fusion for fusion_ir_concretized
jacobhinkle Jun 6, 2023
d41decc
Remove empty reductions
jacobhinkle Jun 8, 2023
96e105a
Clean up EmptyBranchFinder
jacobhinkle Jun 8, 2023
3892578
Simplify removeEmptyBranches
jacobhinkle Jun 8, 2023
0e75ba1
Merge remote-tracking branch 'origin/main' into remove_empty_branches
jacobhinkle Jun 8, 2023
e38d027
Evaluate extents instead of shallow getInt
jacobhinkle Jun 8, 2023
6099940
Add FusionResizeMultiSliceEmpty_CUDA test
jacobhinkle Jun 8, 2023
478ed4a
Add FusionReduceZeroElementTensor_CUDA
jacobhinkle Jun 8, 2023
5ce87fe
Sweep reduction dims in reduce zero elt test
jacobhinkle Jun 8, 2023
13a5b57
Remove debug print
jacobhinkle Jun 8, 2023
c56bd86
Add failing tests
jacobhinkle Jun 8, 2023
145f4de
Add Fusion::replaceInput
jacobhinkle Jun 8, 2023
91101b9
Silence clang-tidy
jacobhinkle Jun 8, 2023
9a4f857
Fix test by switching from BackwardVisitor to standalone function
jacobhinkle Jun 8, 2023
c56f163
Fix length check in conc info operator==
jacobhinkle Jun 8, 2023
739ae06
Merge remote-tracking branch 'origin/main' into remove_empty_branches
jacobhinkle Jun 9, 2023
fc4a484
Merge remote-tracking branch 'origin/main' into remove_empty_branches
jacobhinkle Jun 13, 2023
9f28007
Look up TVs by name() instead of holding ptrs in conc_info
jacobhinkle Jun 14, 2023
c80a93b
Change assumption in PadShmoo due to empty concretization change
jacobhinkle Jun 14, 2023
08c4edc
Print name to TV mapping and dyn extent vals
jacobhinkle Jun 14, 2023
7e770ac
Change placement of FusionGuard in getKernelRuntimeFor
jacobhinkle Jun 14, 2023
5c0a9e0
Handle PadOp
jacobhinkle Jun 14, 2023
b864aed
Merge remote-tracking branch 'origin/main' into remove_empty_branches
jacobhinkle Jun 14, 2023
b52df95
Remove stray debugging printMath
jacobhinkle Jun 14, 2023
a664464
Replace cats with empty inputs.
jacobhinkle Jun 21, 2023
c10f56b
Add test with three catted tensors, only one empty
jacobhinkle Jun 21, 2023
3d4f6ae
Merge branch 'main' into remove_empty_branches
jacobhinkle Jun 21, 2023
97cf441
Bind tv1 extents in DynamicTransform1_CUDA
jacobhinkle Jun 21, 2023
9bdb140
Only bind tv1 extents in the one case. See comment
jacobhinkle Jun 21, 2023
b45638d
Minor cleanup
jacobhinkle Jun 21, 2023
3026cc6
Update comment on initial info handling of TVs
jacobhinkle Jun 23, 2023
8b8524e
Rename dynamic_extent_vals to maybe_zero_extents_
jacobhinkle Jun 23, 2023
0faf0ef
Place findEmptyTensors in anonymous namespace
jacobhinkle Jun 23, 2023
c5dfe02
Improve comments and recurse in maybeReplaced
jacobhinkle Jun 23, 2023
9811fa8
Fix typo and add example to comment
jacobhinkle Jun 23, 2023
c09691c
Fix stuff I broke when trying to write a recursive lambda.
jacobhinkle Jun 23, 2023
9a0ba86
Update comment in removeEmptyBranches()
jacobhinkle Jun 23, 2023
c34caa7
Update comments, don't replace inputs
jacobhinkle Jun 23, 2023
f42a48b
Merge remote-tracking branch 'origin/main' into remove_empty_branches
jacobhinkle Jun 23, 2023
57dfd15
Merge branch 'main' into remove_empty_branches
jacobhinkle Jun 23, 2023
ef53abe
Use quiet_NaN instead of 0.0 / 0.0
jacobhinkle Jun 26, 2023
ba7e9dd
Rename getDynamicExtentVals -> getMaybeZeroExtents
jacobhinkle Jun 26, 2023
8483d14
Skip replacing reductions with empty outputs
jacobhinkle Jun 26, 2023
1d7165e
Merge remote-tracking branch 'origin/main' into remove_empty_branches
jacobhinkle Jun 26, 2023
7fb9428
Refactor to more clearly show dispatch of replaceEmptyUse
jacobhinkle Jun 26, 2023
cd209c2
Merge remote-tracking branch 'origin/main' into remove_empty_branches
jacobhinkle Jul 7, 2023
2175ed8
Remove TV replacement code. Only mutate extents.
jacobhinkle Jul 7, 2023
bdbb31c
Skip check for symbolic axis in mutate(TensorView)
jacobhinkle Jul 7, 2023
42392d3
Specify iter_type in cat() during RemoveEmptyPass
jacobhinkle Jul 7, 2023
1e31b10
Merge branch 'main' into remove_empty_branches
jacobhinkle Jul 7, 2023
7f9b588
Fix up this PR based on lessons from #576
jacobhinkle Jul 12, 2023
7593ea4
Comment out non-working test case in DynamicPadShmoo
jacobhinkle Jul 12, 2023
aa5d75c
Clean up stale code
jacobhinkle Jul 12, 2023
ae8b5b5
Remove unused Fusion::replaceInput
jacobhinkle Jul 12, 2023
da70cc1
Minor comment cleanup
jacobhinkle Jul 12, 2023
899ee21
Update comment in FusionResizeMultiSliceEmpty_CUDA
jacobhinkle Jul 12, 2023
69356eb
Verify that tv2 is replaced by full in #365 repro test
jacobhinkle Jul 12, 2023
4dcff1e
Fix segfault in segmentation
jacobhinkle Jul 12, 2023
7817163
Merge branch 'main' into remove_empty_branches
jacobhinkle Jul 12, 2023
aac04e8
Fix clang-tidy warning about .empty()
jacobhinkle Jul 12, 2023
a1c9238
Update test to accomodate recent Scalar refactor
jacobhinkle Jul 12, 2023
16a9418
Initialize maybe_zero_extents_ more efficiently
jacobhinkle Jul 13, 2023
91e3c04
Minor type and linter fix
jacobhinkle Jul 13, 2023
756e533
Merge remote-tracking branch 'origin/main' into remove_empty_branches
jacobhinkle Jul 13, 2023
14468a6
Expand comment about why we concretize ext->zero
jacobhinkle Jul 13, 2023
6f2db3b
Update comment about cat in EmptyTensorRemover
jacobhinkle Jul 13, 2023
8eda179
Merge branch 'main' into remove_empty_branches
jacobhinkle Jul 13, 2023
a55586a
Merge remote-tracking branch 'origin/main' into remove_empty_branches
jacobhinkle Jul 14, 2023
3c10b7a
Merge remote-tracking branch 'origin/main' into remove_empty_branches
jacobhinkle Jul 14, 2023
143 changes: 124 additions & 19 deletions csrc/dynamic_transform.cpp
@@ -37,6 +37,18 @@ DynamicTransformInitialInfo DynamicTransformInitialInfo::clone(
cloned_info.dynamic_resized_ids_.push_back(ir_cloner.clone(op));
}
}
cloned_info.maybe_zero_extents_set_.reserve(maybe_zero_extents_set_.size());
for (const auto v : maybe_zero_extents_set_) {
if (v) {
cloned_info.maybe_zero_extents_set_.insert(ir_cloner.clone(v));
}
}
cloned_info.maybe_zero_extents_.reserve(maybe_zero_extents_.size());
for (const auto v : maybe_zero_extents_) {
if (v) {
cloned_info.maybe_zero_extents_.push_back(ir_cloner.clone(v));
}
}
cloned_info.root_dynamic_vals_.reserve(root_dynamic_vals_.size());
for (const auto v : root_dynamic_vals_) {
if (v) {
@@ -58,6 +70,10 @@ std::string DynamicTransformInitialInfo::toString() const {
for (const auto& op : dynamic_resized_ids_) {
ss << indent << indent << op->toString() << "\n";
}
ss << indent << "Dynamic extent Vals:\n";
for (const auto& v : maybe_zero_extents_) {
ss << indent << indent << v->toString() << "\n";
}
ss << indent << "Root dynamic Vals:\n";
for (const auto& v : root_dynamic_vals_) {
ss << indent << indent << v->toString() << "\n";
@@ -77,6 +93,8 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor {
traverseTo(fusion, fusion->getTerminatingOutputs(), false, false);

finalizeDynamicVals();

finalizeMaybeEmptyExtents();
}

const auto& getInfo() const {
@@ -95,22 +113,24 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor {
info_.dynamic_reshaped_tvs_.push_back(out_tv);

// Input and output extent expressions both affect concretization
const auto& inp_dom =
TensorDomain::noReductions(inp_tv->getMaybeRFactorDomain());
for (const auto id : inp_dom) {
for (const auto& id :
TensorDomain::noReductions(inp_tv->getMaybeRFactorDomain())) {
leaf_dynamic_vals_.push_back(id->extent());
}
const auto& out_dom = out_tv->getMaybeRFactorDomain();
for (const auto id : out_dom) {
for (const auto& id : out_tv->getMaybeRFactorDomain()) {
leaf_dynamic_vals_.push_back(id->extent());
}
}
}

//! Detect dynamic IterDomain transforms when handling TensorViews
//! Detect possibly empty TensorViews and dynamic IterDomain transforms
void handle(TensorView* tv) override {
const auto& rfd = tv->getMaybeRFactorDomain();
for (auto id : rfd) {
if (!id->extent()->isConstScalar() || id->extent()->evaluateInt() == 0) {
info_.maybe_zero_extents_set_.insert(id->extent());
leaf_dynamic_vals_.push_back(id->extent());
}
if (!id->definition() || id->getIterType() != IterType::Symbolic) {
continue;
}
@@ -141,6 +161,15 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor {
}
}

//! Convert maybe_zero_extents_set_ to a vector so we can index it reliably
void finalizeMaybeEmptyExtents() {
info_.maybe_zero_extents_ = std::vector<Val*>(
info_.maybe_zero_extents_set_.begin(),
info_.maybe_zero_extents_set_.end());
// Clear the corresponding set to free memory and speed up cloning
info_.maybe_zero_extents_set_.clear();
}

private:
DynamicTransformInitialInfo info_;

@@ -154,6 +183,36 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor {
std::vector<Val*> leaf_dynamic_vals_;
};

DynamicTransformConcretizationInfo::DynamicTransformConcretizationInfo(
const DynamicTransformInitialInfo* initial_info,
ExpressionEvaluator* expr_eval)
: initial_info_(initial_info) {
TORCH_INTERNAL_ASSERT(
!fusion()->isA<kir::Kernel>(),
"Invalid container. Kernel container not allowed.\n");

// Make sure all exactly mapped IDs have the same value in the
// evaluator when any one of the IDs has a known value
expr_eval->propagateBoundValuesThroughExactMaps(initial_info_->fusion());

analyzeReshapes(expr_eval);

analyzeResizes(expr_eval);

auto maybe_zero_extents = initial_info_->getMaybeZeroExtents();
for (auto i : c10::irange(maybe_zero_extents.size())) {
auto ext = maybe_zero_extents.at(i);
auto ext_opt = expr_eval->evaluate(ext);
TORCH_INTERNAL_ASSERT(
ext_opt.hasValue(),
"Could not evaluate dynamic extent: ",
ext->toString());
if (ext_opt == 0) {
empty_extents_.push_back(i);
}
}
}

void DynamicTransformConcretizationInfo::analyzeReshapes(
ExpressionEvaluator* expr_eval) {
const auto& reshape_tvs = initial_info_->getDynamicReshapedTensorViews();
@@ -281,7 +340,8 @@ bool DynamicTransformConcretizationInfo::operator==(
}

if (reshape_transforms_.size() != other.reshape_transforms_.size() ||
resize_itertypes_.size() != other.resize_itertypes_.size()) {
resize_itertypes_.size() != other.resize_itertypes_.size() ||
empty_extents_.size() != other.empty_extents_.size()) {
return false;
}

@@ -301,13 +361,26 @@
}
}

for (const auto i : c10::irange(empty_extents_.size())) {
const auto& ee = empty_extents_.at(i);
const auto& other_ee = other.empty_extents_.at(i);
if (ee != other_ee) {
return false;
}
}

return true;
}

std::string DynamicTransformConcretizationInfo::toString() const {
std::stringstream ss;
ss << "DynamicTransformConcretizationInfo\n";
std::string indent = " ";
ss << indent << "Empty tensor extents:\n";
for (const auto& i : empty_extents_) {
auto ext = initial_info_->getMaybeZeroExtents().at(i);
ss << indent << indent << ext->toString() << " is zero\n";
}
ss << indent << "Reshape:\n";
for (const auto& [tv_index, analyze_result] : reshape_transforms_) {
auto tv = initial_info_->getDynamicReshapedTensorViews().at(tv_index);
@@ -333,6 +406,7 @@ class DynamicTransformConcretizer : public OptOutMutator {
TORCH_INTERNAL_ASSERT(
fusion == info->fusion(),
"Invalid DynamicTransformInitialInfo. The associated Fusion is different from the given Fusion");
FusionGuard fg(fusion);
concretize();
}

@@ -343,6 +417,8 @@

void concretizeResize();

void concretizeEmptyExtents();

//! Use this instead of calling registerMutation directly, since it will also
//! check that the concretized value is a valid input to all of its uses.
void registerConcretization(Val* old_val, Val* new_val) {
@@ -370,18 +446,40 @@
};

void DynamicTransformConcretizer::concretize() {
// First, concretize all dynamic reshape ops
// Concretize all dynamic reshape ops
concretizeReshape();

// Set output IterTypes for dynamic resize ops
concretizeResize();

// Registers replacement of all empty extents with zeroVal()
concretizeEmptyExtents();

// Finally, propagate concretized domains
auto all_stmts = StmtSort::getStmts(info_->fusion(), true);
for (auto stmt : all_stmts) {
if (stmt->isA<Val>()) {
mutate(stmt);
auto all_stmts = StmtSort::getStmts(info_->fusion());
[Review thread]
Collaborator (Author): Note we no longer traverse into members.
Collaborator: Don't remember why it did.
Collaborator (Author): I changed it to true in #258 so that the traversal would handle IterDomains, but I didn't have a clear enough picture of how that should work at that time.
Collaborator (Author): Now we explicitly traverse only the TensorView graph, and we handle IterDomain members manually.

for (auto tv : ir_utils::filterByType<TensorView>(all_stmts)) {
mutate(tv);
}
}
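The review thread above notes that concretization now visits only TensorViews rather than traversing into statement members. A minimal, self-contained sketch of that filter-then-mutate pattern is below; Statement, Scalar, TensorView, and mutateTensorViews are simplified stand-ins for illustration, not the real nvFuser types.

```cpp
#include <memory>
#include <vector>

// Simplified stand-ins for nvFuser's Statement hierarchy.
struct Statement {
  virtual ~Statement() = default;
};
struct Scalar : Statement {};
struct TensorView : Statement {
  bool mutated = false;
};

// Analogous to iterating ir_utils::filterByType<TensorView>(all_stmts)
// and calling mutate(tv): select only the TensorView nodes from a mixed
// statement list and mutate them, skipping everything else.
int mutateTensorViews(std::vector<std::unique_ptr<Statement>>& stmts) {
  int count = 0;
  for (auto& s : stmts) {
    if (auto* tv = dynamic_cast<TensorView*>(s.get())) {
      tv->mutated = true;  // stand-in for the concretizer's mutate(tv)
      ++count;
    }
  }
  return count;
}
```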

void DynamicTransformConcretizer::concretizeEmptyExtents() {
auto fusion = FusionGuard::getCurFusion();
for (const auto& ext_index : info_->getEmptyExtents()) {
auto ext = info_->initialInfo()->getMaybeZeroExtents().at(ext_index);
auto zero = fusion->zeroVal(ext->getDataType().value());
auto uses = ext->uses();
for (auto use : uses) {
ir_utils::replaceValInExpr(use, ext, zero);
}
// Register the concretization of this scalar, which allows us to replace it
// whenever it is used as an extent member of an IterDomain.
[Review thread]
Collaborator: This is because ext is "used" by an IterDomain but it isn't part of uses(), correct?
Collaborator (Author): Yes exactly. When we replace in uses() we update things like i0 + 1, but iS2{i0} would remain unless registered for mutation.
Collaborator: Can you please add this to the code comment as well?
Collaborator (Author): Done

//
// When we replace ext in all uses above, it affects downstream expressions. For
// example we might replace i0 with 0 in (i0 + i1) + i2 to form (0 + i1) +
// i2. However, i0 itself might be used as the extent, start, or stop values
// in an IterDomain, so we register the concretization here so that we can
// replace these values whenever we encounter them.
registerConcretization(ext, zero);
}
}
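The review thread and the comment above make a subtle point: replacing a Val in its uses() rewrites downstream expressions, but not objects that hold the Val as a member (such as an IterDomain's extent), so a separate mutation registration is needed. A toy sketch under hypothetical stand-in types — Val, Expr, IterDomain, and the two helpers here are simplified mock-ups, not the real nvFuser classes.

```cpp
#include <unordered_map>
#include <vector>

// Simplified stand-ins for nvFuser IR nodes.
struct Val {};

// An Expr references Vals through its input list.
struct Expr {
  std::vector<Val*> inputs;
};

// An IterDomain holds a Val as a member, so it is NOT reached by
// rewriting expression inputs.
struct IterDomain {
  Val* extent;
};

// Analogous to ir_utils::replaceValInExpr: rewrite uses inside an Expr.
void replaceValInExpr(Expr& e, Val* old_val, Val* new_val) {
  for (auto& in : e.inputs) {
    if (in == old_val) {
      in = new_val;
    }
  }
}

// Analogous to registerConcretization/maybeMutated: a registry through
// which member references are resolved later.
using Mutations = std::unordered_map<Val*, Val*>;

Val* maybeMutated(const Mutations& m, Val* v) {
  auto it = m.find(v);
  return it == m.end() ? v : it->second;
}
```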

@@ -400,7 +498,8 @@ void DynamicTransformConcretizer::concretizeReshape() {
checkConcretizedUses(incomplete_out_tv, concrete_reshape_out_tv);

// Replace the old tensor with the new concretized tensor
for (auto use_of_old_tv : incomplete_out_tv->uses()) {
auto uses = incomplete_out_tv->uses();
for (auto use_of_old_tv : uses) {
ir_utils::replaceValInExpr(
use_of_old_tv, incomplete_out_tv, concrete_reshape_out_tv);
}
@@ -446,8 +545,10 @@ void DynamicTransformConcretizer::checkConcretizedUses(
// concretized. Since symbolic IDs may be propagated down to
// consumers, those domains need to be concretized accordingly.
void DynamicTransformConcretizer::mutate(TensorView* tv) {
if (!tv->domain()->hasSymbolicAxis()) {
return;
for (auto root_id : tv->getRootDomain()) {
[Review thread]
Collaborator: This looks a bit confusing. A root ID may be mutated here, but there's also propagateFromProducerToConsumer below, which may also mutate a root ID. Is there any concern of conflicts?
Collaborator (Author): At line 659 in propagateFromProducerToConsumer, we reflect the earlier mutation by basing the new mutated ID on maybeMutated(root_id). That way we compose both mutations.
Collaborator: Well, what happens if the same root ID is mutated both at line 546 and within line 551? Doesn't the latter just overwrite the mutation at line 546?
Collaborator: Oh, are you referring to line 659 before this PR?
Collaborator (Author): It will overwrite the mutation, but using IterDomainBuilder(maybeMutated(root_id)->as<IterDomain>()) as the basis of the new mutation lets us update the IterType without changing other mutations, like the extent.
Collaborator: OK, I see it now. Maybe it'd be helpful to add that to the code comment too.

// This will register root_id for mutation if its extent, start, or
// stop_offset is registered for mutation
OptOutMutator::mutate(root_id);
}
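The thread above explains that a later mutation built from maybeMutated(root_id) composes with, rather than overwrites, an earlier extent mutation. A toy registry illustrating that composition — Id, Registry, and the mutate helpers are illustrative stand-ins, not nvFuser's OptOutMutator API.

```cpp
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical miniature of a mutation registry: each new mutation is
// built from maybeMutated(original), so earlier member changes survive.
struct Id {
  long extent;
  std::string iter_type;
};

using Registry = std::unordered_map<const Id*, Id>;

const Id& maybeMutated(const Registry& r, const Id& id) {
  auto it = r.find(&id);
  return it == r.end() ? id : it->second;
}

// First pass: extent concretized (e.g. to zero for an empty tensor).
void mutateExtent(Registry& r, const Id& id, long e) {
  Id next = maybeMutated(r, id);  // start from the already-mutated value
  next.extent = e;
  r[&id] = next;
}

// Second pass: IterType concretized. Basing the rebuild on
// maybeMutated(id) composes with the extent mutation instead of
// overwriting it, mirroring IterDomainBuilder(maybeMutated(root_id)).
void mutateIterType(Registry& r, const Id& id, std::string t) {
  Id next = maybeMutated(r, id);
  next.iter_type = std::move(t);
  r[&id] = next;
}
```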

// First, try to concretize the root domain as there may be symbolic
@@ -522,12 +623,14 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) {
continue;
}
auto concretized_out_id =
IterDomainBuilder(out_id).iter_type(iter_type).build();
IterDomainBuilder(maybeMutated(out_id)->as<IterDomain>())
.iter_type(iter_type)
.build();
registerConcretization(out_id, concretized_out_id);
}

// The expr itself needs to be mutated as well in case the outputs are
// mutated, which can be done by the mutate method
// expr must be mutated in order to set it as the definition for the
// concretized outputs.
OptOutMutator::mutate(expr);
}
}
@@ -656,7 +759,9 @@ bool DynamicTransformConcretizer::propagateFromProducerToConsumer(
consumer->toString());

auto concretized_id =
IterDomainBuilder(root_id).iter_type(*id_type).build();
IterDomainBuilder(maybeMutated(root_id)->as<IterDomain>())
[Review thread]
Collaborator: Is this change because the extent of root_id may be changed to 0?
Collaborator (Author): Yes exactly, when we call OptOutMutator::mutate(root_id) it will register it for mutation if any of its members are mutated, including the extent. Since we are going to register another mutation here, we don't want to lose those changes, so we base concretized_id on the mutated ID.

.iter_type(*id_type)
.build();

registerConcretization(root_id, concretized_id);
is_concretized = true;
55 changes: 38 additions & 17 deletions csrc/dynamic_transform.h
@@ -14,6 +14,7 @@
#include <ir/cloner.h>
#include <iter_visitor.h>
#include <transform_view.h>
#include <utils.h>

#include <functional>
#include <memory>
@@ -38,18 +39,35 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo {
return fusion_;
}

//! Return whether any dynamic transforms exist in the Fusion
bool hasDynamicTransforms() const {
return !dynamic_reshaped_tvs_.empty() || !dynamic_resized_ids_.empty();
//! Return whether any dynamic transforms exist in the Fusion, or whether
//! there are any tensors which could potentially be empty (size-0 extent)
//! given some user input. In either of these cases, concretization may change
//! the structure of the Fusion.
bool isDynamic() const {
return hasPossibleEmptyTensor() || !dynamic_reshaped_tvs_.empty() ||
!dynamic_resized_ids_.empty();
}

//! Return whether there are any tensors with unknown extent in some
//! dimension, so that they might be empty
bool hasPossibleEmptyTensor() const {
return !maybe_zero_extents_.empty();
}

//! Return a set of scalars that are inputs or extents of input TensorViews
//! and that appear in inputs to dynamic expressions. Any Vals not in this
//! list do not affect concretization.
const std::unordered_set<Val*> getRootDynamicVals() const {
const std::unordered_set<Val*>& getRootDynamicVals() const {
return root_dynamic_vals_;
}

//! Return a set of scalars that appear as extents in TensorViews in the
//! Fusion. If any of these evaluate to zero, there is at least one empty
//! TensorView present.
const std::vector<Val*>& getMaybeZeroExtents() const {
return maybe_zero_extents_;
}

//! Return a vector of outputs of ViewOp expressions that have dynamic output
//! shapes
const std::vector<TensorView*>& getDynamicReshapedTensorViews() const {
@@ -93,6 +111,12 @@

std::vector<IterDomain*> dynamic_resized_ids_;

// This is a minimal set of scalars to check for empty tensors. If any are
// zero, we should traverse to find empty tensors.
std::unordered_set<Val*> maybe_zero_extents_set_;
// The set above is populated then used to create this unique vector
std::vector<Val*> maybe_zero_extents_;

// Root Vals that determine concretization
std::unordered_set<Val*> root_dynamic_vals_;

@@ -105,19 +129,10 @@
public:
DynamicTransformConcretizationInfo(
const DynamicTransformInitialInfo* initial_info,
ExpressionEvaluator* expr_eval)
: initial_info_(initial_info) {
TORCH_INTERNAL_ASSERT(
!fusion()->isA<kir::Kernel>(),
"Invalid container. Kernel container not allowed.\n");
ExpressionEvaluator* expr_eval);

// Make sure all exactly mapped IDs have the same value in the
// evaluator when any one of the IDs has a known value
expr_eval->propagateBoundValuesThroughExactMaps(initial_info->fusion());

analyzeReshapes(expr_eval);

analyzeResizes(expr_eval);
const std::vector<size_t>& getEmptyExtents() const {
return empty_extents_;
}

//! Return a vector of pairs holding the index of each reshaped TensorView in
@@ -185,10 +200,16 @@
//! result of analyzeView
std::vector<std::pair<size_t, AnalyzeViewResult>> reshape_transforms_;

//! Holds a vector of indices into initial_info_.getMaybeZeroExtents() which
//! evaluate to 0
std::vector<size_t> empty_extents_;

//! Holds the index of the resized IterDomain (output of the Resize op) in the
//! vector returned by initial_info_->getDynamicResizedIterDomains() along
//! with its concretized IterType
std::vector<std::pair<size_t, IterType>> resize_itertypes_;

friend class DynamicTransformInfoBuilder;
};

class TORCH_CUDA_CU_API DynamicTransform {
@@ -201,7 +222,7 @@ class TORCH_CUDA_CU_API DynamicTransform {
//! Concretizes a given fusion. Note that the concretization is
//! in-place and the given fusion is modified.
static void concretizeFusion(
Fusion*,
Fusion* fusion,
const DynamicTransformConcretizationInfo* info);
};
