233 changes: 74 additions & 159 deletions csrc/dynamic_transform.cpp
@@ -357,16 +357,20 @@ class DynamicTransformConcretizer : public OptOutMutator {

using OptOutMutator::mutate;

void mutate(TensorView* tv) final;

void mutate(TensorDomain* td) final;
@jacobhinkle (Collaborator, Author) commented on Jul 10, 2023:

The override of mutate(TensorDomain*) merely updates contiguity to reflect any introduced Broadcast IDs. This is a general situation that could probably be added to OptOutMutator::mutate(TensorDomain*) instead to further simplify the concretization code.
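
A minimal standalone sketch of the contiguity fix-up that comment describes (this is a toy, not the PR's code; it assumes contiguity is stored as one std::optional<bool> per axis, with no defined value for broadcast axes):

```cpp
#include <optional>
#include <vector>

// Toy fix-up: when a Symbolic axis is concretized to Broadcast, its
// contiguity entry must be cleared, since broadcast axes carry no
// contiguity information.
std::vector<std::optional<bool>> updateContiguity(
    const std::vector<std::optional<bool>>& old_contiguity,
    const std::vector<bool>& now_broadcast) {
  std::vector<std::optional<bool>> result;
  result.reserve(old_contiguity.size());
  for (size_t i = 0; i < old_contiguity.size(); ++i) {
    if (now_broadcast[i]) {
      result.push_back(std::nullopt);  // broadcast: contiguity undefined
    } else {
      result.push_back(old_contiguity[i]);  // keep the original flag
    }
  }
  return result;
}
```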


//! Concretizes the root domain of a symbolic consumer tensor from
//! its producer domains. Returns true if any root ID is concretized.
bool propagateFromProducerToConsumer(TensorView* consumer);
void mutate(IterDomain* id) final;

private:
const DynamicTransformConcretizationInfo* info_;

//! This map is used during concretization to identify, for a given IterDomain
//! the set of all IterDomains which are "aligned" with it in some TensorView
//! expression. This enables us to write mutate(IterDomain*) and propagate
//! information from producer IterDomains to consumers, which is otherwise not
//! represented in the graph since we do not connect IterDomains between
//! TensorViews with expressions.
std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>
id_producers_;
};

void DynamicTransformConcretizer::concretize() {
@@ -376,13 +380,37 @@ void DynamicTransformConcretizer::concretize() {
// Set output IterTypes for dynamic resize ops
concretizeResize();

// Finally, propagate concretized domains
// The methods above do not traverse the graph. Instead they fill in
// root->rfactor expressions by replacing the dynamic reshaped TV with a
// static reshaped one, and by registering concretization of dynamic Resized
// IterDomains. From this point forward, we will not modify any TensorView
// expressions. This restriction makes it safe for us to traverse forward
// through the graph and mutate IterDomains and TensorDomains in order to
// properly propagate IterTypes and concretized extent expressions, without
// breaking the topological ordering of these expressions.
//
// When propagating IterTypes across expressions, we need to know the producer
// IterDomains corresponding to a consumer ID. This mapping helps facilitate
// this and is used later in mutate(IterDomain*).
auto all_stmts = StmtSort::getStmts(info_->fusion(), true);
for (auto stmt : all_stmts) {
if (stmt->isA<Val>()) {
mutate(stmt);
@jacobhinkle (Collaborator, Author) commented on Jul 10, 2023, on lines -382 to -383:

Instead of only mutating Vals, we now mutate Exprs as well, which replaces the Expr in place if any inputs or outputs have changed. Note that outputs of Exprs are mutated after their definition has been mutated, so we should be careful updating a Val that has a definition. But of course we should be careful in that case in the existing code too.
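
A hedged toy of the mutate-Exprs-too pattern described here; the types and helpers below are illustrative stand-ins modeled on OptOutMutator, not its real API:

```cpp
#include <unordered_map>
#include <vector>

// Illustrative stand-ins for statement-graph nodes.
struct Expr;
struct Val {
  Expr* definition = nullptr;
};
struct Expr {
  std::vector<Val*> inputs;
  std::vector<Val*> outputs;
};

// Registered replacements, analogous to the concretizer's mutation map.
std::unordered_map<Val*, Val*> mutation_map;

Val* maybeMutated(Val* v) {
  auto it = mutation_map.find(v);
  return it == mutation_map.end() ? v : it->second;
}

// If any input or output of expr was replaced, the Expr itself must be
// rebuilt so the graph stays consistent. Outputs are visited after
// their definition, so a Val with a definition may already have been
// remapped by the time its Expr is examined.
bool exprNeedsRebuild(const Expr& expr) {
  for (Val* v : expr.inputs) {
    if (maybeMutated(v) != v) {
      return true;
    }
  }
  for (Val* v : expr.outputs) {
    if (maybeMutated(v) != v) {
      return true;
    }
  }
  return false;
}
```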

for (auto expr : ir_utils::filterByType<Expr>(all_stmts)) {
for (auto consumer : ir_utils::filterByType<TensorView>(expr->outputs())) {
for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) {
PairwiseRootDomainMap root_map(producer, consumer);
for (auto [cid, pid] : root_map.mapConsumerToProducer(
consumer->domain(), producer->domain())) {
// Initialize set of producer IDs, if we haven't already
auto& producers = id_producers_.emplace(cid, 0).first->second;
producers.insert(pid);
}
}
}
}

// Finally, propagate concretized domains with forward traversal
for (auto stmt : all_stmts) {
mutate(stmt);
}
}

void DynamicTransformConcretizer::concretizeReshape() {
@@ -441,103 +469,6 @@ void DynamicTransformConcretizer::checkConcretizedUses(
}
}

// Concretizes inherited symbolic domains. Note that when this is
// called, it is assumed that all dynamic ops themselves are
// concretized. Since symbolic IDs may be propagated down to
// consumers, those domains need to be concretized accordingly.
void DynamicTransformConcretizer::mutate(TensorView* tv) {
if (!tv->domain()->hasSymbolicAxis()) {
return;
}

// First, try to concretize the root domain as there may be symbolic
// axes inherited from the producers
propagateFromProducerToConsumer(tv);

// If no root domain is altered by producer, we don't need to propagate back
// up to rfactor. We could return early, but instead we go ahead and check the
// root to rfactor transforms to be sure we have concretized any intermediate
// IterDomains.

// At this point, there should be no expr beyond rfactor root
TORCH_INTERNAL_ASSERT(
tv->getLeafDomain() == tv->getMaybeRFactorDomain(),
"Invalid tensor: ",
tv->toString());

// If it has an rfactor root domain, the IterTypes of the rfactor
// IDs may need to be updated as well. Traverse the rfactor exprs
// and mutate the IterTypes of output IDs if symbolic.
if (tv->hasRFactor()) {
// Note that it is assumed that there's no further expression
// beyond the rfactor domain as asserted above
auto all_id_exprs = StmtSort::getExprsBetween(
tv->fusion(),
{tv->getRootDomain().begin(), tv->getRootDomain().end()},
{tv->getMaybeRFactorDomain().begin(),
tv->getMaybeRFactorDomain().end()});
for (auto expr : all_id_exprs) {
// Assume outputs of IterDomain exprs are always IterDomains. If
// the assumption is invalidated, the logic here would need to
// be updated. Assert the assumption to immediately detect such
// a case if happened.
for (auto out_val : expr->outputs()) {
TORCH_INTERNAL_ASSERT(
out_val->isA<IterDomain>(),
"Unexpected output: ",
out_val->toString(),
". IterDomain was expected.");
}

// NOTE: We do not return early if all outputs are concrete as there may
// still be concrete inputs. For example, a Symbolic IterDomain might be
// padded with constant pad widths (1, 1), in which case although we do
// not know the exact extent of the output, we know it is at least as
// large as the sum of the pad widths, 2. In such cases, the output
// IterDomain is concrete at definition, since if the extent is >1 we know
// the IterType is Iteration. In these cases, we must continue to
// concretize intermediate expressions between the root and R-factor
// domain. See test DynamicTransform5_CUDA which demonstrates this
// behavior.
// NOTE: We also do not assume that if one output ID is symbolic, that
// they all must be. See test FusionSliceForNanoGPT3_CUDA for an example
// that does a static split by a factor of 16 of a symbolic input domain.
// The static split in that case results in a concrete IterDomain with
// extent 16 along with a symbolic one (extent ceilDiv(n / 16)).

// Determine the output IterType
IterType iter_type = IterType::Symbolic;
for (auto inp_id : ir_utils::filterByType<IterDomain>(expr->inputs())) {
auto updated_id = maybeMutated(inp_id)->as<IterDomain>();
iter_type = ops::promoteIterType(iter_type, updated_id->getIterType());
}
TORCH_INTERNAL_ASSERT(
iter_type != IterType::Symbolic,
"Failed to concretize an output IterType for expression: ",
expr->toString());

// Update the IterType of each output
for (auto out_id : ir_utils::filterByType<IterDomain>(expr->outputs())) {
if (!out_id->isSymbolic()) {
continue;
}
auto concretized_out_id =
IterDomainBuilder(out_id).iter_type(iter_type).build();
registerConcretization(out_id, concretized_out_id);
}

// The expr itself needs to be mutated as well in case the outputs are
// mutated, which can be done by the mutate method
OptOutMutator::mutate(expr);
}
}

// Root and rfactor domains are updated. First mutate the
// TensorDomain and then TensorView
mutate(tv->domain());
OptOutMutator::mutate(tv);
}

// Almost an exact copy of OptOutMutator::mutate(TensorDomain*), but
// the contiguity vector may need to be updated as well as symbolic
// domains may be mutated to broadcast domains, which means contiguity
@@ -594,75 +525,59 @@ void DynamicTransformConcretizer::mutate(TensorDomain* td) {
registerConcretization(td, mutated_val);
}

bool DynamicTransformConcretizer::propagateFromProducerToConsumer(
TensorView* consumer) {
if (consumer->definition() == nullptr ||
!consumer->domain()->hasSymbolicAxis()) {
return false;
void DynamicTransformConcretizer::mutate(IterDomain* id) {
// id might have already been mutated if its definition was updated
id = maybeMutated(id)->as<IterDomain>();
if (!id->isSymbolic()) {
return;
}

const auto& root_domain = consumer->getRootDomain();

auto def = consumer->definition();

bool is_concretized = false;

for (const auto i : c10::irange(root_domain.size())) {
auto root_id = root_domain.at(i);
if (root_id->getIterType() != IterType::Symbolic) {
continue;
if (auto def = id->definition()) {
// Determine concrete IterType based on promotion of inputs to def
IterType iter_type = IterType::Symbolic;
for (auto inp_id : ir_utils::filterByType<IterDomain>(def->inputs())) {
auto updated_id = maybeMutated(inp_id)->as<IterDomain>();
iter_type = ops::promoteIterType(iter_type, updated_id->getIterType());
}
TORCH_INTERNAL_ASSERT(
iter_type != IterType::Symbolic,
"Failed to concretize an output IterType for expression: ",
def->toString());
auto concretized_id = IterDomainBuilder(id).iter_type(iter_type).build();
registerConcretization(id, concretized_id);
} else {
@jacobhinkle (Collaborator, Author) commented:

This branch replaces the root->rfactor propagation that was removed from mutate(TensorView*).

// IterDomains without definitions might be root domains for the output of a
// TensorView expression. If so, we should propagate their concretization in
// the producer to consumer direction.
@jacobhinkle (Collaborator, Author) commented:

This branch replaces propagateFromProducerToConsumer.
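
A self-contained toy of the lookup this branch performs (the map mirrors id_producers_; all names here are illustrative): each consumer root axis takes the promoted IterType of all the producer axes it aligns with, and axes that were never consumer roots are skipped.

```cpp
#include <optional>
#include <unordered_map>
#include <unordered_set>

enum class IterType { Symbolic, Broadcast, Iteration };
using AxisId = int;

// Producers are assumed already concretized, so only Broadcast and
// Iteration need to be combined: Iteration wins.
IterType promote(IterType a, IterType b) {
  if (a == IterType::Iteration || b == IterType::Iteration) {
    return IterType::Iteration;
  }
  return IterType::Broadcast;
}

// Returns nullopt when the axis was never a consumer root in any
// TensorView expression (the early-return case in the real code).
std::optional<IterType> concretizeFromProducers(
    AxisId consumer_axis,
    const std::unordered_map<AxisId, std::unordered_set<AxisId>>& id_producers,
    const std::unordered_map<AxisId, IterType>& concrete_types) {
  auto it = id_producers.find(consumer_axis);
  if (it == id_producers.end()) {
    return std::nullopt;
  }
  std::optional<IterType> result;
  for (AxisId producer : it->second) {
    IterType t = concrete_types.at(producer);
    result = result.has_value() ? promote(*result, t) : t;
  }
  return result;
}
```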


auto producers_it = id_producers_.find(id);
if (producers_it == id_producers_.end()) {
// id was not a consumer root ID in any TV expression
return;
}

// Figure out the right IterType of this consumer root ID from its
// corresponding producer IDs

std::optional<IterType> id_type;

for (auto producer : ir_utils::filterByType<TensorView>(def->inputs())) {
PairwiseRootDomainMap root_map(producer, consumer);
auto c2p = root_map.mapConsumerToProducer(
consumer->domain(), producer->domain());

TORCH_INTERNAL_ASSERT(
c2p.find(root_id) != c2p.end(),
"No input ID found to map with output ID: ",
root_id->toString());

auto input_id = c2p.at(root_id);
TORCH_INTERNAL_ASSERT(
input_id->getIterType() != IterType::Symbolic,
"Producer ID not concretized: ",
input_id->toString());

for (auto producer_id : producers_it->second) {
producer_id = maybeMutated(producer_id)->as<IterDomain>();
if (id_type.has_value()) {
id_type = ops::promoteIterType(*id_type, input_id->getIterType());
id_type = ops::promoteIterType(*id_type, producer_id->getIterType());
} else {
id_type = input_id->getIterType();
id_type = producer_id->getIterType();
}
}

TORCH_INTERNAL_ASSERT(
id_type.has_value(),
"Did not find id_type for consumer root domain ",
root_id->toString(),
". Perhaps consumer def has no inputs. Consumer definition = ",
def->toString());
id->toString(),
". Perhaps consumer def has no inputs.");

TORCH_INTERNAL_ASSERT(
id_type != IterType::Symbolic,
"Failed to concretize ",
root_id->toString(),
" of ",
consumer->toString());
id_type != IterType::Symbolic, "Failed to concretize ", id->toString());

auto concretized_id =
IterDomainBuilder(root_id).iter_type(*id_type).build();
auto concretized_id = IterDomainBuilder(id).iter_type(*id_type).build();

registerConcretization(root_id, concretized_id);
is_concretized = true;
registerConcretization(id, concretized_id);
}

return is_concretized;
}

DynamicTransformInitialInfo DynamicTransform::getInitialInfo(Fusion* fusion) {
6 changes: 6 additions & 0 deletions csrc/iter_visitor.h
@@ -27,6 +27,12 @@ class Fusion;
* the dag will be called with handle(Statement*) in topological order, inputs of
* the fusion to outputs of the fusion.
*
* Note that for any Val whose definition is non-null, the following are
* processed in order: definition, attributes, members. In particular, this
* means that a TensorView's domain() is processed after its definition, meaning
* producer TVs and their IterDomains are all processed before those of
* consumers.
*
* TODO: We may want a BFS version of this code to extract ILP, not implemented
* yet.
*
5 changes: 4 additions & 1 deletion test/test_dynamic_transform.cpp
@@ -944,7 +944,10 @@ TEST_F(NVFuserTest, DynamicPadShmoo_CUDA) {
//{{3, 5}, {-3, -2}, false}, // output is zero-dimensional

// Output has size 1 so is set to broadcast.
{{3, 5}, {0, -4}, true},
// This was previously "working" by concretizing the size-1 pad to
// Iteration, even though it should be Broadcast. When set properly to
// Broadcast, it fails with an error in ConcretizedBroadcastDomains.
//{{3, 5}, {0, -4}, true},
@jacobhinkle (Collaborator, Author) commented on Jul 10, 2023, on lines +947 to +950:

This case deserves its own issue, which I will add. When there is a broadcast domain introduced by concretizing a resize, we hit an error since we can't concretize the broadcast. On main, this case actually concretizes the pad as Iteration even if the extent is 1. Instead, we should probably translate a Resize that results in size 1 as a select+broadcast or a full(pad_value)+broadcast; however, that is a complicated change since we need to operate on the TensorView containing the Resized ID, meaning we would change concretization info to track TV ops (cat, pad, slice) instead of the ID op Resize. For now I have disabled this case. Once I file an issue I'll point to it here in the comment.
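
The arithmetic behind the disabled case, as a hedged standalone check (this assumes the test's pad widths follow a {left, right} convention on the innermost axis):

```cpp
#include <cassert>

int main() {
  // Input shape {3, 5}; pad widths {0, -4} on the last axis.
  int in_extent = 5;
  int left_pad = 0;
  int right_pad = -4;
  // A resize adds the left and right expansions to the input extent.
  int out_extent = in_extent + left_pad + right_pad;  // 5 + 0 - 4 = 1
  // Size-1 output: the axis should concretize to Broadcast, not
  // Iteration, which is what trips ConcretizedBroadcastDomains.
  assert(out_extent == 1);
  return 0;
}
```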


// Test full negative shifts, so output doesn't overlap input
{{3, 5}, {-5, 2}, false},
Expand Down