diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp index 79e7bc4cf8fa..71a994da149a 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp @@ -48,7 +48,7 @@ IterDomainGraph::IterDomainGraph(Fusion* fusion, bool allow_self_mapping) { IdMappingMode::LOOP}; for (auto mode : mapping_types) { - nodes_[mode] = DisjointSets(); + disjoint_ids_[mode] = DisjointSets(); } build(fusion); @@ -68,19 +68,25 @@ IterDomainGraph::IterDomainGraph(Fusion* fusion, bool allow_self_mapping) { } } -const DisjointSets& IterDomainGraph::getNodes( +const DisjointSets& IterDomainGraph::getDisjointIdsSet( IdMappingMode mode) const { - auto node_set_it = nodes_.find(mode); + auto disjoint_ids_it = disjoint_ids_.find(mode); TORCH_INTERNAL_ASSERT( - node_set_it != nodes_.end(), "Mapping mode ", mode, " not supported."); - return node_set_it->second; + disjoint_ids_it != disjoint_ids_.end(), + "Mapping mode ", + mode, + " not supported."); + return disjoint_ids_it->second; } -DisjointSets& IterDomainGraph::nodes(IdMappingMode mode) { - auto node_set_it = nodes_.find(mode); +DisjointSets& IterDomainGraph::disjointIdsSet(IdMappingMode mode) { + auto disjoint_ids_it = disjoint_ids_.find(mode); TORCH_INTERNAL_ASSERT( - node_set_it != nodes_.end(), "Mapping mode ", mode, " not supported."); - return node_set_it->second; + disjoint_ids_it != disjoint_ids_.end(), + "Mapping mode ", + mode, + " not supported."); + return disjoint_ids_it->second; } bool IterDomainGraph::exprsMap( @@ -133,7 +139,7 @@ bool IterDomainGraph::exprsMap( zipped_ids.begin(), zipped_ids.end(), [&](std::pair id_pair) { - return !getNodes(mode).permissiveAreMapped( + return !getDisjointIdsSet(mode).permissiveAreMapped( id_pair.first, id_pair.second); })) { return false; @@ -212,7 +218,7 @@ void IterDomainGraph::mapThroughExpr( "\nand\n", second->toString()); for (auto out_i : c10::irange(first_ids.size())) { - mapNodes(first_ids[out_i], second_ids[out_i], mode); + mapIds(first_ids[out_i], second_ids[out_i], mode); } } @@ -258,7 +264,7 @@ c10::optional> detectMappablePair( if (id1 == id2) { continue; } - if (id_graph.getNodes(mode).disjointSetMap().at(id1)->has(id2)) { + if (id_graph.getDisjointIdsSet(mode).disjointSetMap().at(id1)->has(id2)) { return std::make_pair(id1, id2); } } @@ -319,15 +325,15 @@ findFirstSelfMapping(Fusion* fusion, const IterDomainGraph& id_graph) { } // namespace -// TODO: Should we avoid marking leaf nodes at this point? +// TODO: Should we avoid marking leaf Ids at this point? void IterDomainGraph::initializeId( IterDomain* id, bool is_view_rfactor_id, bool is_leaf_id) { - nodes(IdMappingMode::PERMISSIVE).initializeSet(id); - nodes(IdMappingMode::EXACT).initializeSet(id); + disjointIdsSet(IdMappingMode::PERMISSIVE).initializeSet(id); + disjointIdsSet(IdMappingMode::EXACT).initializeSet(id); if (is_leaf_id) { - nodes(IdMappingMode::LOOP).initializeSet(id); + disjointIdsSet(IdMappingMode::LOOP).initializeSet(id); } consumers_[id] = {}; producers_[id] = {}; @@ -338,7 +344,7 @@ void IterDomainGraph::initializeId( } void IterDomainGraph::initialIdProcessing(Fusion* fusion) { - // Initialize a node for every iteration domain and mark view like iteration + // Initialize entries for every iteration domain and mark view like iteration // domains and leaf iteration domains. for (auto tv : ir_utils::allTvs(fusion)) { const auto& domain = tv->domain()->domain(); @@ -421,8 +427,8 @@ void IterDomainGraph::mapMultiOutput(Expr* expr) { } auto id0 = *disjoint_set->begin(); for (auto id1 : disjoint_set->vector()) { - mapNodes(id0, id1, IdMappingMode::PERMISSIVE); - mapNodes(id0, id1, IdMappingMode::EXACT); + mapIds(id0, id1, IdMappingMode::PERMISSIVE); + mapIds(id0, id1, IdMappingMode::EXACT); } } @@ -431,7 +437,7 @@ void IterDomainGraph::mapMultiOutput(Expr* expr) { auto disjoint_set = c2f_disjoint_sets.getDisjointSetOf(f_id); auto id0 = *(disjoint_set.begin()); for (auto id1 : disjoint_set) { - mapNodes(id0, id1, IdMappingMode::LOOP); + mapIds(id0, id1, IdMappingMode::LOOP); } } } @@ -484,7 +490,7 @@ void IterDomainGraph::mapExact(Expr* expr) { for (auto c_id : getSortedKeys(exact_c2p_map, Statement::lessThan)) { auto p_id = exact_c2p_map.at(c_id); - mapNodes(c_id, p_id, IdMappingMode::EXACT); + mapIds(c_id, p_id, IdMappingMode::EXACT); // TODO: consumers/producers should be on a per map basis, mapping should // include unique expr between the disjoint sets @@ -494,8 +500,8 @@ void IterDomainGraph::mapExact(Expr* expr) { // Add the swizzle inputs to the same // disjoint set as well if either c_id // or p_id is swizzle output. - mapMaybeSwizzleOp(nodes(IdMappingMode::EXACT), p_id); - mapMaybeSwizzleOp(nodes(IdMappingMode::EXACT), c_id); + mapMaybeSwizzleOp(disjointIdsSet(IdMappingMode::EXACT), p_id); + mapMaybeSwizzleOp(disjointIdsSet(IdMappingMode::EXACT), c_id); } } } @@ -526,12 +532,12 @@ void IterDomainGraph::mapPermissiveAndLoop(Expr* expr) { auto& vec = dset->vector(); for (auto i : c10::irange(vec.size())) { auto id1 = vec[i]; - mapNodes(id1, vec[0], IdMappingMode::PERMISSIVE); + mapIds(id1, vec[0], IdMappingMode::PERMISSIVE); // Add the swizzle inputs to the same // disjoint set as well if either c_id // or p_id is swizzle output. - mapMaybeSwizzleOp(nodes(IdMappingMode::PERMISSIVE), id1); + mapMaybeSwizzleOp(disjointIdsSet(IdMappingMode::PERMISSIVE), id1); // Loop/producer/consumer for (auto j : c10::irange(i + 1, vec.size())) { @@ -541,7 +547,7 @@ void IterDomainGraph::mapPermissiveAndLoop(Expr* expr) { producers_.at(id2).pushBack(id1); if (idIsAComputeAtLeafDomain(id1, p_tv, c_tv) && idIsALeafDomain(id2, c_tv)) { - mapNodes(id1, id2, IdMappingMode::LOOP); + mapIds(id1, id2, IdMappingMode::LOOP); } } if (c_ids.count(id1) && p_ids.count(id2)) { @@ -549,7 +555,7 @@ void IterDomainGraph::mapPermissiveAndLoop(Expr* expr) { consumers_.at(id2).pushBack(id1); if (idIsAComputeAtLeafDomain(id2, p_tv, c_tv) && idIsALeafDomain(id1, c_tv)) { - mapNodes(id1, id2, IdMappingMode::LOOP); + mapIds(id1, id2, IdMappingMode::LOOP); } } } @@ -669,8 +675,8 @@ void IterDomainGraph::mapRFactorExprs(Fusion* fusion) { // Only need to be concerned here with mapping across rfactor iter // domains, so isolate out those. - auto all_exact_map_ids = - nodes(IdMappingMode::EXACT).getDisjointSetOf(first_rfactor_id); + auto all_exact_map_ids = disjointIdsSet(IdMappingMode::EXACT) + .getDisjointSetOf(first_rfactor_id); std::vector exact_map_rf_ids; std::copy_if( all_exact_map_ids.vector().begin(), @@ -711,9 +717,10 @@ void IterDomainGraph::mapRFactorExprs(Fusion* fusion) { void IterDomainGraph::buildAlmostExactMap() { // Build almost exact map by forwarding through broadcast axes - nodes(IdMappingMode::ALMOSTEXACT) = nodes(IdMappingMode::EXACT); + disjointIdsSet(IdMappingMode::ALMOSTEXACT) = + disjointIdsSet(IdMappingMode::EXACT); std::unordered_set visited; - auto all_elements = nodes(IdMappingMode::EXACT).getAllElements(); + auto all_elements = disjointIdsSet(IdMappingMode::EXACT).getAllElements(); for (auto entry : all_elements.vector()) { if (entry->definition() == nullptr) { continue; @@ -724,18 +731,18 @@ void IterDomainGraph::buildAlmostExactMap() { } if (auto merge = dynamic_cast(def)) { if (merge->inner()->extent()->isOneInt()) { - mapNodes(merge->outer(), merge->out(), IdMappingMode::ALMOSTEXACT); + mapIds(merge->outer(), merge->out(), IdMappingMode::ALMOSTEXACT); } if (merge->outer()->extent()->isOneInt()) { - mapNodes(merge->inner(), merge->out(), IdMappingMode::ALMOSTEXACT); + mapIds(merge->inner(), merge->out(), IdMappingMode::ALMOSTEXACT); } } else if (auto split = dynamic_cast(def)) { if (split->factor()->isOneInt() && split->startOffset()->isZeroInt() && split->stopOffset()->isZeroInt()) { if (split->innerSplit()) { - mapNodes(split->in(), split->outer(), IdMappingMode::ALMOSTEXACT); + mapIds(split->in(), split->outer(), IdMappingMode::ALMOSTEXACT); } else { - mapNodes(split->in(), split->inner(), IdMappingMode::ALMOSTEXACT); + mapIds(split->in(), split->inner(), IdMappingMode::ALMOSTEXACT); } } } @@ -787,7 +794,7 @@ void ComputeAtMap::build(Fusion* fusion) { void ComputeAtMap::validateAndPropagatePType() { for (const auto& loop_disjoint_set : - id_graph_.getNodes(IdMappingMode::LOOP).disjointSets()) { + id_graph_.getDisjointIdsSet(IdMappingMode::LOOP).disjointSets()) { ParallelType common_ptype = ParallelType::Serial; for (auto id : loop_disjoint_set->vector()) { auto id_ptype = id->getParallelType(); @@ -813,12 +820,12 @@ void ComputeAtMap::allocateIndexVariables() { // all lowered kir::ForLoop will correspond to one of the disjoint sets // and we only need one index variable for each set. for (const auto& loop_disjoint_set : - id_graph_.getNodes(IdMappingMode::LOOP).disjointSets()) { + id_graph_.getDisjointIdsSet(IdMappingMode::LOOP).disjointSets()) { ParallelType ptype; // first allocate thread and grid parallel indices: // The validation pass will check that the parallel bindings within the - // loop nodes are consistent so all the loops within this disjoint set - // will be realized implicitly using parallel index variables. + // loop disjoint IDs set are consistent so all the loops within this + // disjoint set will be realized implicitly using parallel index variables. if (std::any_of( loop_disjoint_set->vector().begin(), loop_disjoint_set->vector().end(), @@ -881,12 +888,12 @@ Val* ComputeAtMap::getIndexVariable( IterDomain* id, DoubleBufferLoopStage double_buffer_loop_stage) const { TORCH_INTERNAL_ASSERT( - id_graph_.getNodes(IdMappingMode::LOOP).mappingExists(id), + id_graph_.getDisjointIdsSet(IdMappingMode::LOOP).mappingExists(id), "Index Variable: no index variable allocated as ", id->toString(), " is not registered in loop map"); const auto* loop_set = - &(id_graph_.getNodes(IdMappingMode::LOOP).getDisjointSetOf(id)); + &(id_graph_.getDisjointIdsSet(IdMappingMode::LOOP).getDisjointSetOf(id)); // Check if this loop was modified by double buffer pass. bool is_double_buffer_iterdomain = @@ -1154,7 +1161,7 @@ void ComputeAtMap::buildConcreteIds() { // deterministic but which ID gets selected her depends on the traversal // order generating the set (compute at map build). for (const auto& disjoint_set_shared_ptr : - id_graph_.getNodes(IdMappingMode::EXACT).disjointSets()) { + id_graph_.getDisjointIdsSet(IdMappingMode::EXACT).disjointSets()) { TORCH_INTERNAL_ASSERT( disjoint_set_shared_ptr->vector().size(), "Cannot compute concrete id of empty set."); @@ -1165,7 +1172,7 @@ void ComputeAtMap::buildConcreteIds() { // The following two algorithms seem quite wasteful. Should find a more // efficient way to compute concrete IDs. for (const auto& disjoint_set_shared_ptr : - id_graph_.getNodes(IdMappingMode::PERMISSIVE).disjointSets()) { + id_graph_.getDisjointIdsSet(IdMappingMode::PERMISSIVE).disjointSets()) { TORCH_INTERNAL_ASSERT( disjoint_set_shared_ptr->vector().size(), "Cannot compute concrete id of empty set."); @@ -1176,7 +1183,7 @@ void ComputeAtMap::buildConcreteIds() { // Same as exact computation for (const auto& disjoint_set_shared_ptr : - id_graph_.getNodes(IdMappingMode::ALMOSTEXACT).disjointSets()) { + id_graph_.getDisjointIdsSet(IdMappingMode::ALMOSTEXACT).disjointSets()) { TORCH_INTERNAL_ASSERT( disjoint_set_shared_ptr->vector().size(), "Cannot compute concrete id of empty set."); @@ -1186,7 +1193,7 @@ void ComputeAtMap::buildConcreteIds() { } for (const auto& disjoint_set_shared_ptr : - id_graph_.getNodes(IdMappingMode::LOOP).disjointSets()) { + id_graph_.getDisjointIdsSet(IdMappingMode::LOOP).disjointSets()) { TORCH_INTERNAL_ASSERT( disjoint_set_shared_ptr->vector().size(), "Cannot compute concrete id of empty set."); @@ -1244,7 +1251,7 @@ bool ComputeAtMap::areExactExprs(Expr* expr_1, Expr* expr_2) { void ComputeAtMap::buildUniqueExactExprMaps() { // Start by building definitions for (const auto& disjoint_set_shared_ptr : - id_graph_.getNodes(IdMappingMode::EXACT).disjointSets()) { + id_graph_.getDisjointIdsSet(IdMappingMode::EXACT).disjointSets()) { std::vector definitions; // N^2 in number of unique transformations, this might be better to do @@ -1290,7 +1297,7 @@ void ComputeAtMap::buildUniqueExactExprMaps() { // Use definitions to build uses for (const auto& disjoint_set_shared_ptr : - id_graph_.getNodes(IdMappingMode::EXACT).disjointSets()) { + id_graph_.getDisjointIdsSet(IdMappingMode::EXACT).disjointSets()) { // Make sure uses is always initialized even there are no uses. if (unique_exact_uses_.find(disjoint_set_shared_ptr) == unique_exact_uses_.end()) { @@ -1362,7 +1369,7 @@ IterDomain* ComputeAtMap::getConcreteMappedID( namespace { -std::string idGraphNodesToString( +std::string idGraphDisjointIdSetToString( const ComputeAtMap& ca_map, IdMappingMode mode) { std::stringstream ss; @@ -1410,12 +1417,14 @@ std::string idGraphNodesToString( std::string ComputeAtMap::toString() const { std::stringstream ss; ss << "Compute at map { \n"; - ss << "Exact map:\n" << idGraphNodesToString(*this, IdMappingMode::EXACT); + ss << "Exact map:\n" + << idGraphDisjointIdSetToString(*this, IdMappingMode::EXACT); ss << "Almost Exact map:\n" - << idGraphNodesToString(*this, IdMappingMode::ALMOSTEXACT); - ss << "Loop map:\n" << idGraphNodesToString(*this, IdMappingMode::LOOP); + << idGraphDisjointIdSetToString(*this, IdMappingMode::ALMOSTEXACT); + ss << "Loop map:\n" + << idGraphDisjointIdSetToString(*this, IdMappingMode::LOOP); ss << "Permissive map:\n" - << idGraphNodesToString(*this, IdMappingMode::PERMISSIVE); + << idGraphDisjointIdSetToString(*this, IdMappingMode::PERMISSIVE); ss << "Consumer maps:\n"; for (auto key : getSortedKeys(id_graph_.consumers(), Statement::lessThan)) { auto consumers = id_graph_.consumers().at(key); @@ -1465,7 +1474,7 @@ const std::shared_ptr>& ComputeAtMap:: const DisjointSets& ComputeAtMap::getIdSets( IdMappingMode mode) const { - return id_graph_.getNodes(mode); + return id_graph_.getDisjointIdsSet(mode); } bool ComputeAtMap::idExistsInMap(IterDomain* id) const { @@ -1648,7 +1657,7 @@ void IterDomainGraph::updateComputeWith(TensorView* compute_with_tv) { consumer_tv->domain()->domain().begin(), consumer_tv->domain()->domain().end(), [&](auto consumer_id) { - return getNodes(IdMappingMode::PERMISSIVE) + return getDisjointIdsSet(IdMappingMode::PERMISSIVE) .disjointSetMap() .at(id) ->has(consumer_id); @@ -1662,7 +1671,7 @@ void IterDomainGraph::updateComputeWith(TensorView* compute_with_tv) { IterDomain* consumer_id = *it; - mapNodes(id, consumer_id, IdMappingMode::LOOP); + mapIds(id, consumer_id, IdMappingMode::LOOP); } } @@ -1676,7 +1685,7 @@ void ComputeAtMap::updateComputeWith(TensorView* compute_with_tv) { // Update the LOOP concrete IDs for (const auto& disjoint_set_shared_ptr : - id_graph_.getNodes(IdMappingMode::LOOP).disjointSets()) { + id_graph_.getDisjointIdsSet(IdMappingMode::LOOP).disjointSets()) { TORCH_INTERNAL_ASSERT( disjoint_set_shared_ptr->vector().size(), "Cannot compute concrete id of empty set."); diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.h b/torch/csrc/jit/codegen/cuda/compute_at_map.h index 6c8051d6993c..dd5173fb72c0 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.h +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.h @@ -64,7 +64,8 @@ class TORCH_CUDA_CU_API IterDomainGraph { IterDomainGraph(Fusion* fusion, bool allow_self_mapping = false); // Returns the disjoint set according to one of the mapping mode types. - const DisjointSets& getNodes(IdMappingMode mode) const; + const DisjointSets& getDisjointIdsSet(IdMappingMode mode) const; + // Consumers and producers is not symmetric like the other sets const std::unordered_map>& consumers() const { @@ -94,7 +95,7 @@ class TORCH_CUDA_CU_API IterDomainGraph { return self_mapping_info_.has_value(); } - // Update the LOOP nodes with resolved computeWith + // Update the LOOP ID disjoint sets with resolved computeWith void updateComputeWith(TensorView* compute_with_tv); private: @@ -114,14 +115,15 @@ class TORCH_CUDA_CU_API IterDomainGraph { // be replayed the same as eachother, so mapping them is very straightforward. void mapMultiOutput(Expr* expr); - // Fills nodes_[IdMappingMode::EXACT] for relationships between inputs and - // first output of expr + // Fills disjoint_ids_[IdMappingMode::EXACT] for relationships between inputs + // and first output of expr void mapExact(Expr* expr); - // Fills nodes_[IdMappingMode::PERMISSIVE] for relationships between inputs - // and first output of expr + // Fills disjoint_ids_[IdMappingMode::PERMISSIVE] for relationships between + // inputs and first output of expr // - // Currently also fills nodes_[IdMappingMode::LOOP], consumer_, and producer_ + // Currently also fills disjoint_ids_[IdMappingMode::LOOP], consumer_, and + // producer_ void mapPermissiveAndLoop(Expr* expr); // Propagates forward then backward through all view like rfactor @@ -139,12 +141,12 @@ class TORCH_CUDA_CU_API IterDomainGraph { // ======= END Iteration domain build process in order called ======= - // Non-const internal only version of getNodes. - DisjointSets& nodes(IdMappingMode mode); + // Non-const internal only version of getDisjointIdsSet. + DisjointSets& disjointIdsSet(IdMappingMode mode); // Simple alias - void mapNodes(IterDomain* id0, IterDomain* id1, IdMappingMode mode) { - nodes(mode).mapEntries(id0, id1); + void mapIds(IterDomain* id0, IterDomain* id1, IdMappingMode mode) { + disjointIdsSet(mode).mapEntries(id0, id1); } // Checks if expr's are considered "the same" where sameness inputs and @@ -166,7 +168,7 @@ class TORCH_CUDA_CU_API IterDomainGraph { // Using an array here might be nice, but it seems hard to use an enum as an // array key // https://stackoverflow.com/questions/2102582/how-can-i-count-the-items-in-an-enum - std::unordered_map> nodes_; + std::unordered_map> disjoint_ids_; // Consumers and producers is not symmetric like the other sets // TODO: Generalize to mapping type. Mappings between producer TV ids and @@ -228,7 +230,7 @@ class TORCH_CUDA_CU_API ComputeAtMap { //! Simple alias to IdGraph mappings. bool areMapped(IterDomain* id0, IterDomain* id1, IdMappingMode mode) const { - return idGraph().getNodes(mode).strictAreMapped(id0, id1); + return idGraph().getDisjointIdsSet(mode).strictAreMapped(id0, id1); } //! Returns an iter domain that is the maximum expanded size of all iter //! domains the one provided maps to. Useful for opening loops to the correct diff --git a/torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp b/torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp index 2a8c724ff32b..991ecb72daef 100644 --- a/torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp @@ -91,7 +91,7 @@ std::unordered_set getAllDivisibleSplits( auto original_view_split = entry.second; const auto& exact_mapped_ids = ca_map->idGraph() - .getNodes(IdMappingMode::EXACT) + .getDisjointIdsSet(IdMappingMode::EXACT) .getDisjointSetOf(concrete_id) .vector(); for (auto other_id : exact_mapped_ids) { diff --git a/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp index 9b231d835ede..2ab45c289da1 100644 --- a/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp @@ -1274,7 +1274,7 @@ bool isPermissivelyMappedWithAny(IterDomain* id, const std::vector& ids) { GpuLower::current() ->caMap() ->idGraph() - .getNodes(IdMappingMode::PERMISSIVE) + .getDisjointIdsSet(IdMappingMode::PERMISSIVE) .permissiveAreMapped(id, val->as()); }); } diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.cpp b/torch/csrc/jit/codegen/cuda/lower_shift.cpp index b6bde3193950..8fdbeea6a2c2 100644 --- a/torch/csrc/jit/codegen/cuda/lower_shift.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_shift.cpp @@ -157,7 +157,8 @@ void HaloInfo::setRootAxisInfo( HaloInfo::HaloInfo(Fusion* fusion, std::shared_ptr ca_map) // Make a copy of the permissive map for extent comparators - : permissive_map_(ca_map->idGraph().getNodes(IdMappingMode::PERMISSIVE)) { + : permissive_map_( + ca_map->idGraph().getDisjointIdsSet(IdMappingMode::PERMISSIVE)) { const auto vals = fusion->usedMathVals(); auto tvs = ir_utils::filterByType(vals); diff --git a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp index 10efc4f63605..b8a2234ddfcf 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp @@ -489,7 +489,9 @@ bool requiresForwardViewReplay(Fusion* fusion, ComputeAtMap& ca_map) { // Mark those as an active use of the rfactor, if two are detected, return // true. for (const auto& disjoint_set_shared_ptr : - ca_map.idGraph().getNodes(IdMappingMode::EXACT).disjointSets()) { + ca_map.idGraph() + .getDisjointIdsSet(IdMappingMode::EXACT) + .disjointSets()) { // Make sure there's at least one rfactor domain in the set, otherwise we // don't need to check anything from this set. if (!std::any_of( diff --git a/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp b/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp index 14d30f9b300d..21523880e6a5 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp @@ -51,7 +51,7 @@ class DomainMap : public pointwise_utils::DomainMap { IterDomain* mapped_id = nullptr; for (auto i : c10::irange(root_dom.size())) { if (ca_map_.idGraph() - .getNodes(IdMappingMode::EXACT) + .getDisjointIdsSet(IdMappingMode::EXACT) .permissiveAreMapped(root_dom[i], root_dim)) { mapped_id = root_dom[i]; break; diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp index ffaf31ac5126..32735070a37f 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp @@ -2100,7 +2100,7 @@ void BoundedDirectionalTransformPropagator::bothWays( DisjointSets disjointViewSets(Fusion* fusion) { // Start from the exact iter domain graph of the fusion IterDomainGraph id_graph(fusion); - auto disjoint_view_ids = id_graph.getNodes(IdMappingMode::EXACT); + auto disjoint_view_ids = id_graph.getDisjointIdsSet(IdMappingMode::EXACT); // If iter domains are involved in any transformation from root domains to // rfactor domains they should be considered "contaminated". @@ -2240,7 +2240,9 @@ void propagateViewTransforms(Fusion* fusion, const ComputeAtMap& ca_map) { std::unordered_set terminating_rfactor_dims; for (const auto& disjoint_set_shared_ptr : - ca_map.idGraph().getNodes(IdMappingMode::EXACT).disjointSets()) { + ca_map.idGraph() + .getDisjointIdsSet(IdMappingMode::EXACT) + .disjointSets()) { if (std::none_of( disjoint_set_shared_ptr->vector().begin(), disjoint_set_shared_ptr->vector().end(), diff --git a/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp b/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp index f21cf9f7c136..b74099715af1 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp @@ -150,7 +150,7 @@ Val* commonOrConstExtent( std::shared_ptr ca_map, IterDomain* id) { auto disjoint_set = ca_map->idGraph() - .getNodes(IdMappingMode::ALMOSTEXACT) + .getDisjointIdsSet(IdMappingMode::ALMOSTEXACT) .getDisjointSetOf(id); for (auto entry : disjoint_set) { if (entry->extent()->isConstScalar()) { diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp index 545ca501a7e2..3ebe83bd8e05 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp @@ -1210,21 +1210,21 @@ TEST_F(NVFuserTest, FusionViewIdGraph_CUDA) { // Start from the exact iter domain graph of the fusion IterDomainGraph id_graph(&fusion); - auto disjoint_view_ids = id_graph.getNodes(IdMappingMode::EXACT); + auto disjoint_view_ids = id_graph.getDisjointIdsSet(IdMappingMode::EXACT); - TORCH_CHECK(id_graph.getNodes(IdMappingMode::EXACT) + TORCH_CHECK(id_graph.getDisjointIdsSet(IdMappingMode::EXACT) .strictAreMapped(tv2->axis(1), tv4->axis(1))); - TORCH_CHECK(id_graph.getNodes(IdMappingMode::EXACT) + TORCH_CHECK(id_graph.getDisjointIdsSet(IdMappingMode::EXACT) .strictAreMapped(tv2->axis(2), tv4->axis(2))); TORCH_CHECK( - id_graph.getNodes(IdMappingMode::EXACT) + id_graph.getDisjointIdsSet(IdMappingMode::EXACT) .strictAreMapped(tv2->getRootDomain()[1], tv12->getRootDomain()[1])); TORCH_CHECK( - id_graph.getNodes(IdMappingMode::EXACT) + id_graph.getDisjointIdsSet(IdMappingMode::EXACT) .strictAreMapped(tv2->getRootDomain()[2], tv12->getRootDomain()[2])); TORCH_CHECK( - id_graph.getNodes(IdMappingMode::EXACT) + id_graph.getDisjointIdsSet(IdMappingMode::EXACT) .strictAreMapped(tv2->getRootDomain()[3], tv12->getRootDomain()[3])); }