121 changes: 65 additions & 56 deletions torch/csrc/jit/codegen/cuda/compute_at_map.cpp

Large diffs are not rendered by default.

28 changes: 15 additions & 13 deletions torch/csrc/jit/codegen/cuda/compute_at_map.h
@@ -64,7 +64,8 @@ class TORCH_CUDA_CU_API IterDomainGraph {
IterDomainGraph(Fusion* fusion, bool allow_self_mapping = false);

// Returns the disjoint set according to one of the mapping mode types.
-const DisjointSets<IterDomain*>& getNodes(IdMappingMode mode) const;
+const DisjointSets<IterDomain*>& getDisjointIdsSet(IdMappingMode mode) const;

// Consumers and producers is not symmetric like the other sets
const std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>&
consumers() const {
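An aside for readers tracking the rename: `getDisjointIdsSet` is the same accessor the old `getNodes` was, keyed by `IdMappingMode`, and the updated `FusionViewIdGraph_CUDA` test at the bottom of this diff exercises it directly. A minimal call-pattern sketch, assuming a `Fusion` named `fusion` and TensorViews `tv2`/`tv4` as in that test:

// Sketch only; fusion, tv2, and tv4 are assumed from a surrounding test.
IterDomainGraph id_graph(&fusion);
// Ask whether two iter domains fall in the same EXACT equivalence class.
bool mapped = id_graph.getDisjointIdsSet(IdMappingMode::EXACT)
                  .strictAreMapped(tv2->axis(1), tv4->axis(1));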
@@ -94,7 +95,7 @@ class TORCH_CUDA_CU_API IterDomainGraph {
return self_mapping_info_.has_value();
}

-// Update the LOOP nodes with resolved computeWith
+// Update the LOOP ID disjoint sets with resolved computeWith
void updateComputeWith(TensorView* compute_with_tv);

private:
@@ -114,14 +115,15 @@ class TORCH_CUDA_CU_API IterDomainGraph {
// be replayed the same as eachother, so mapping them is very straightforward.
void mapMultiOutput(Expr* expr);

-// Fills nodes_[IdMappingMode::EXACT] for relationships between inputs and
-// first output of expr
+// Fills disjoint_ids_[IdMappingMode::EXACT] for relationships between inputs
+// and first output of expr
void mapExact(Expr* expr);

-// Fills nodes_[IdMappingMode::PERMISSIVE] for relationships between inputs
-// and first output of expr
+// Fills disjoint_ids_[IdMappingMode::PERMISSIVE] for relationships between
+// inputs and first output of expr
//
-// Currently also fills nodes_[IdMappingMode::LOOP], consumer_, and producer_
+// Currently also fills disjoint_ids_[IdMappingMode::LOOP], consumer_, and
+// producer_
void mapPermissiveAndLoop(Expr* expr);

// Propagates forward then backward through all view like rfactor
@@ -139,12 +141,12 @@ class TORCH_CUDA_CU_API IterDomainGraph {

// ======= END Iteration domain build process in order called =======

-// Non-const internal only version of getNodes.
-DisjointSets<IterDomain*>& nodes(IdMappingMode mode);
+// Non-const internal only version of getDisjointIdsSet.
+DisjointSets<IterDomain*>& disjointIdsSet(IdMappingMode mode);

// Simple alias
-void mapNodes(IterDomain* id0, IterDomain* id1, IdMappingMode mode) {
-nodes(mode).mapEntries(id0, id1);
+void mapIds(IterDomain* id0, IterDomain* id1, IdMappingMode mode) {
+disjointIdsSet(mode).mapEntries(id0, id1);
}

// Checks if expr's are considered "the same" where sameness inputs and
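Every `map*` routine above ultimately lands in `DisjointSets<IterDomain*>::mapEntries`, so the rename is about naming the union-find state honestly. For intuition, here is a self-contained stand-in with the same semantics; this is illustrative only, not nvFuser's actual `DisjointSets` implementation:

#include <unordered_map>

// Illustrative union-find with path compression: mapEntries unions two
// equivalence classes, strictAreMapped tests whether two keys share a class.
template <typename T>
struct SimpleDisjointSets {
  std::unordered_map<T, T> parent;

  T find(T x) {
    if (parent.find(x) == parent.end())
      parent[x] = x; // lazily create a singleton class
    if (parent[x] != x)
      parent[x] = find(parent[x]); // path compression
    return parent[x];
  }

  void mapEntries(T a, T b) { // union the classes of a and b
    T ra = find(a);
    T rb = find(b);
    parent[ra] = rb;
  }

  bool strictAreMapped(T a, T b) { // same equivalence class?
    return find(a) == find(b);
  }
};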
@@ -166,7 +168,7 @@ class TORCH_CUDA_CU_API IterDomainGraph {
// Using an array here might be nice, but it seems hard to use an enum as an
// array key
// https://stackoverflow.com/questions/2102582/how-can-i-count-the-items-in-an-enum
-std::unordered_map<IdMappingMode, DisjointSets<IterDomain*>> nodes_;
+std::unordered_map<IdMappingMode, DisjointSets<IterDomain*>> disjoint_ids_;

// Consumers and producers is not symmetric like the other sets
// TODO: Generalize to mapping type. Mappings between producer TV ids and
@@ -228,7 +230,7 @@ class TORCH_CUDA_CU_API ComputeAtMap {

//! Simple alias to IdGraph mappings.
bool areMapped(IterDomain* id0, IterDomain* id1, IdMappingMode mode) const {
-return idGraph().getNodes(mode).strictAreMapped(id0, id1);
+return idGraph().getDisjointIdsSet(mode).strictAreMapped(id0, id1);
}
//! Returns an iter domain that is the maximum expanded size of all iter
//! domains the one provided maps to. Useful for opening loops to the correct
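For orientation at call sites: `ComputeAtMap::areMapped` above is a thin alias, so the two queries below are equivalent by construction. A sketch; `ca_map`, `id0`, and `id1` are assumed in scope:

// The alias spelled out; both booleans agree by definition.
bool a = ca_map.areMapped(id0, id1, IdMappingMode::PERMISSIVE);
bool b = ca_map.idGraph()
             .getDisjointIdsSet(IdMappingMode::PERMISSIVE)
             .strictAreMapped(id0, id1);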
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp
@@ -91,7 +91,7 @@ std::unordered_set<Split*> getAllDivisibleSplits(
auto original_view_split = entry.second;

const auto& exact_mapped_ids = ca_map->idGraph()
-.getNodes(IdMappingMode::EXACT)
+.getDisjointIdsSet(IdMappingMode::EXACT)
.getDisjointSetOf(concrete_id)
.vector();
for (auto other_id : exact_mapped_ids) {
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/lower_index_compute.cpp
@@ -1274,7 +1274,7 @@ bool isPermissivelyMappedWithAny(IterDomain* id, const std::vector<Val*>& ids) {
GpuLower::current()
->caMap()
->idGraph()
-.getNodes(IdMappingMode::PERMISSIVE)
+.getDisjointIdsSet(IdMappingMode::PERMISSIVE)
.permissiveAreMapped(id, val->as<IterDomain>());
});
}
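Only the middle of `isPermissivelyMappedWithAny` is visible in the hunk above. For context, a hedged reconstruction of the whole helper; the `std::any_of` framing and the `isA<IterDomain>()` guard are inferred from the visible fragment and the function's name, not confirmed by this diff:

// Reconstruction sketch: true if id permissively maps to any IterDomain in ids.
bool isPermissivelyMappedWithAny(IterDomain* id, const std::vector<Val*>& ids) {
  return std::any_of(ids.begin(), ids.end(), [&](Val* val) {
    return val->isA<IterDomain>() && // guard assumed
        GpuLower::current()
            ->caMap()
            ->idGraph()
            .getDisjointIdsSet(IdMappingMode::PERMISSIVE)
            .permissiveAreMapped(id, val->as<IterDomain>());
  });
}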
3 changes: 2 additions & 1 deletion torch/csrc/jit/codegen/cuda/lower_shift.cpp
@@ -157,7 +157,8 @@ void HaloInfo::setRootAxisInfo(

HaloInfo::HaloInfo(Fusion* fusion, std::shared_ptr<const ComputeAtMap> ca_map)
// Make a copy of the permissive map for extent comparators
-: permissive_map_(ca_map->idGraph().getNodes(IdMappingMode::PERMISSIVE)) {
+: permissive_map_(
+ca_map->idGraph().getDisjointIdsSet(IdMappingMode::PERMISSIVE)) {
const auto vals = fusion->usedMathVals();
auto tvs = ir_utils::filterByType<TensorView>(vals);

4 changes: 3 additions & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/registry.cpp
@@ -489,7 +489,9 @@ bool requiresForwardViewReplay(Fusion* fusion, ComputeAtMap& ca_map) {
// Mark those as an active use of the rfactor, if two are detected, return
// true.
for (const auto& disjoint_set_shared_ptr :
-ca_map.idGraph().getNodes(IdMappingMode::EXACT).disjointSets()) {
+ca_map.idGraph()
+.getDisjointIdsSet(IdMappingMode::EXACT)
+.disjointSets()) {
// Make sure there's at least one rfactor domain in the set, otherwise we
// don't need to check anything from this set.
if (!std::any_of(
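The loop in `requiresForwardViewReplay` now spells the call chain across three lines, but the logic is unchanged: visit each EXACT-mode disjoint set and skip those containing no rfactor domain. A condensed sketch of that filter, assuming `IterDomain::isRfactorProduct()` as the membership test (the surrounding comment and `std::any_of` suggest it):

// Skip disjoint sets that contain no rfactor iter domain.
for (const auto& set_ptr :
     ca_map.idGraph().getDisjointIdsSet(IdMappingMode::EXACT).disjointSets()) {
  if (!std::any_of(
          set_ptr->vector().begin(),
          set_ptr->vector().end(),
          [](IterDomain* id) { return id->isRfactorProduct(); })) {
    continue;
  }
  // ... inspect rfactor uses within this set ...
}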
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp
@@ -51,7 +51,7 @@ class DomainMap : public pointwise_utils::DomainMap {
IterDomain* mapped_id = nullptr;
for (auto i : c10::irange(root_dom.size())) {
if (ca_map_.idGraph()
-.getNodes(IdMappingMode::EXACT)
+.getDisjointIdsSet(IdMappingMode::EXACT)
.permissiveAreMapped(root_dom[i], root_dim)) {
mapped_id = root_dom[i];
break;
6 changes: 4 additions & 2 deletions torch/csrc/jit/codegen/cuda/scheduler/utils.cpp
@@ -2100,7 +2100,7 @@ void BoundedDirectionalTransformPropagator::bothWays(
DisjointSets<IterDomain*> disjointViewSets(Fusion* fusion) {
// Start from the exact iter domain graph of the fusion
IterDomainGraph id_graph(fusion);
-auto disjoint_view_ids = id_graph.getNodes(IdMappingMode::EXACT);
+auto disjoint_view_ids = id_graph.getDisjointIdsSet(IdMappingMode::EXACT);

// If iter domains are involved in any transformation from root domains to
// rfactor domains they should be considered "contaminated".
@@ -2240,7 +2240,9 @@ void propagateViewTransforms(Fusion* fusion, const ComputeAtMap& ca_map) {

std::unordered_set<IterDomain*> terminating_rfactor_dims;
for (const auto& disjoint_set_shared_ptr :
-ca_map.idGraph().getNodes(IdMappingMode::EXACT).disjointSets()) {
+ca_map.idGraph()
+.getDisjointIdsSet(IdMappingMode::EXACT)
+.disjointSets()) {
if (std::none_of(
disjoint_set_shared_ptr->vector().begin(),
disjoint_set_shared_ptr->vector().end(),
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp
@@ -150,7 +150,7 @@ Val* commonOrConstExtent(
std::shared_ptr<const ComputeAtMap> ca_map,
IterDomain* id) {
auto disjoint_set = ca_map->idGraph()
-.getNodes(IdMappingMode::ALMOSTEXACT)
+.getDisjointIdsSet(IdMappingMode::ALMOSTEXACT)
.getDisjointSetOf(id);
for (auto entry : disjoint_set) {
if (entry->extent()->isConstScalar()) {
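The visible body of `commonOrConstExtent` scans the ALMOSTEXACT set for a compile-time-constant extent. A hedged completion of the loop; the fallback return is an assumption, as the hunk cuts off before it:

// Prefer any constant extent among almost-exact-mapped iter domains;
// otherwise fall back to the queried id's own extent (fallback assumed).
for (auto entry : disjoint_set) {
  if (entry->extent()->isConstScalar()) {
    return entry->extent();
  }
}
return id->extent();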
12 changes: 6 additions & 6 deletions torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp
@@ -1210,21 +1210,21 @@ TEST_F(NVFuserTest, FusionViewIdGraph_CUDA) {

// Start from the exact iter domain graph of the fusion
IterDomainGraph id_graph(&fusion);
-auto disjoint_view_ids = id_graph.getNodes(IdMappingMode::EXACT);
+auto disjoint_view_ids = id_graph.getDisjointIdsSet(IdMappingMode::EXACT);

-TORCH_CHECK(id_graph.getNodes(IdMappingMode::EXACT)
+TORCH_CHECK(id_graph.getDisjointIdsSet(IdMappingMode::EXACT)
.strictAreMapped(tv2->axis(1), tv4->axis(1)));
-TORCH_CHECK(id_graph.getNodes(IdMappingMode::EXACT)
+TORCH_CHECK(id_graph.getDisjointIdsSet(IdMappingMode::EXACT)
.strictAreMapped(tv2->axis(2), tv4->axis(2)));

TORCH_CHECK(
-id_graph.getNodes(IdMappingMode::EXACT)
+id_graph.getDisjointIdsSet(IdMappingMode::EXACT)
.strictAreMapped(tv2->getRootDomain()[1], tv12->getRootDomain()[1]));
TORCH_CHECK(
-id_graph.getNodes(IdMappingMode::EXACT)
+id_graph.getDisjointIdsSet(IdMappingMode::EXACT)
.strictAreMapped(tv2->getRootDomain()[2], tv12->getRootDomain()[2]));
TORCH_CHECK(
-id_graph.getNodes(IdMappingMode::EXACT)
+id_graph.getDisjointIdsSet(IdMappingMode::EXACT)
.strictAreMapped(tv2->getRootDomain()[3], tv12->getRootDomain()[3]));
}
