From 8327f821fb56dc7e04e2f2fab0fa72857d1dc175 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sat, 3 Dec 2022 11:25:03 -0500 Subject: [PATCH 1/4] Start reducing down IdGraph. --- .../csrc/jit/codegen/cuda/compute_at_map.cpp | 102 ++++++++++-------- torch/csrc/jit/codegen/cuda/compute_at_map.h | 27 ++--- .../codegen/cuda/lower_divisible_split.cpp | 8 +- torch/csrc/jit/codegen/cuda/lower_shift.cpp | 2 +- .../jit/codegen/cuda/scheduler/registry.cpp | 2 +- .../jit/codegen/cuda/scheduler/transpose.cpp | 5 +- .../csrc/jit/codegen/cuda/scheduler/utils.cpp | 4 +- .../cuda/scheduler/vectorize_helper.cpp | 4 +- .../jit/codegen/cuda/test/test_gpu_view.cpp | 23 ++-- 9 files changed, 100 insertions(+), 77 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp index bde91fda24f8..fb368b01a27c 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp @@ -56,6 +56,35 @@ IterDomainGraph::IterDomainGraph(Fusion* fusion, bool allow_self_mapping) { } } +const DisjointSets& IterDomainGraph::getNodes( + IdMappingMode mode) const { + switch (mode) { + case IdMappingMode::EXACT: + return exact_nodes_; + case IdMappingMode::ALMOSTEXACT: + return almost_exact_nodes_; + case IdMappingMode::LOOP: + return loop_nodes_; + case IdMappingMode::PERMISSIVE: + return permissive_nodes_; + } + TORCH_INTERNAL_ASSERT(false, "Error with mapping mode provided."); +} + +DisjointSets& IterDomainGraph::nodes(IdMappingMode mode) { + switch (mode) { + case IdMappingMode::EXACT: + return exact_nodes_; + case IdMappingMode::ALMOSTEXACT: + return almost_exact_nodes_; + case IdMappingMode::LOOP: + return loop_nodes_; + case IdMappingMode::PERMISSIVE: + return permissive_nodes_; + } + TORCH_INTERNAL_ASSERT(false, "Error with mapping mode provided."); +} + //! Map corresponding inputs and outputs of swizzle op together //! on the given disjoint set, if the given id is an output //! of a swizzle operator. 
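For readers outside the nvfuser tree, the pattern in the hunk above (one mode-keyed getter standing in for four named getters) can be sketched in isolation. Everything below is a stand-in: Mode, IdSet, and Graph are illustrative placeholders for IdMappingMode, DisjointSets<IterDomain*>, and IterDomainGraph, and only the const/non-const accessor pair mirrors the patch. The patch repeats the switch in both overloads; this sketch instead routes the private mutable overload through the const one, which is a design alternative rather than what the patch does.

#include <cassert>
#include <set>
#include <string>

enum class Mode { Exact, AlmostExact, Loop, Permissive };

// Stand-in for DisjointSets<IterDomain*>.
using IdSet = std::set<std::string>;

class Graph {
 public:
  // Public read-only access keyed by mode, mirroring getNodes(IdMappingMode).
  const IdSet& getNodes(Mode mode) const {
    switch (mode) {
      case Mode::Exact:
        return exact_;
      case Mode::AlmostExact:
        return almost_exact_;
      case Mode::Loop:
        return loop_;
      case Mode::Permissive:
        return permissive_;
    }
    assert(false && "unhandled mapping mode");
    return exact_; // unreachable; keeps -Wreturn-type quiet
  }

  void add(Mode mode, const std::string& id) {
    nodes(mode).insert(id);
  }

 private:
  // Private mutable access used only while building, mirroring nodes(mode).
  IdSet& nodes(Mode mode) {
    return const_cast<IdSet&>(static_cast<const Graph&>(*this).getNodes(mode));
  }

  IdSet exact_;
  IdSet almost_exact_;
  IdSet loop_;
  IdSet permissive_;
};

int main() {
  Graph g;
  g.add(Mode::Permissive, "iS0");
  assert(g.getNodes(Mode::Permissive).count("iS0") == 1);
  return 0;
}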
@@ -206,7 +235,8 @@ void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { second->toString()); for (auto out_i : c10::irange(first_ids.size())) { exact_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]); - permissive_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]); + nodes(IdMappingMode::PERMISSIVE) + .mapEntries(first_ids[out_i], second_ids[out_i]); } } @@ -252,20 +282,8 @@ c10::optional> detectMappablePair( if (id1 == id2) { continue; } - if (mode == IdMappingMode::EXACT) { - if (id_graph.exactNodes().disjointSetMap().at(id1)->has(id2)) { - return std::make_pair(id1, id2); - } - } else if (mode == IdMappingMode::PERMISSIVE) { - if (id_graph.permissiveNodes().disjointSetMap().at(id1)->has(id2)) { - return std::make_pair(id1, id2); - } - } else if (mode == IdMappingMode::LOOP) { - if (id_graph.loopNodes().disjointSetMap().at(id1)->has(id2)) { - return std::make_pair(id1, id2); - } - } else { - TORCH_INTERNAL_ASSERT(false, "Unrecognized IdMappingMode mode."); + if (id_graph.getNodes(mode).disjointSetMap().at(id1)->has(id2)) { + return std::make_pair(id1, id2); } } } @@ -411,7 +429,7 @@ void IterDomainGraph::build(Fusion* fusion) { } auto id0 = *disjoint_set->begin(); for (auto id1 : disjoint_set->vector()) { - permissive_nodes_.mapEntries(id0, id1); + nodes(IdMappingMode::PERMISSIVE).mapEntries(id0, id1); exact_nodes_.mapEntries(id0, id1); sibling_sets_.mapEntries(id0, id1); } @@ -478,12 +496,12 @@ void IterDomainGraph::build(Fusion* fusion) { auto& vec = dset->vector(); for (auto i : c10::irange(vec.size())) { auto id1 = vec[i]; - permissive_nodes_.mapEntries(id1, vec[0]); + nodes(IdMappingMode::PERMISSIVE).mapEntries(id1, vec[0]); // Add the swizzle inputs to the same // disjoint set as well if either c_id // or p_id is swizzle output. - mapMaybeSwizzleOp(permissive_nodes_, id1); + mapMaybeSwizzleOp(nodes(IdMappingMode::PERMISSIVE), id1); for (auto j : c10::irange(i + 1, vec.size())) { auto id2 = vec[j]; @@ -692,10 +710,10 @@ void IterDomainGraph::initializeId( IterDomain* id, bool is_view_rfactor_id, bool is_leaf_id) { - permissive_nodes_.initializeSet(id); - exact_nodes_.initializeSet(id); + nodes(IdMappingMode::PERMISSIVE).initializeSet(id); + nodes(IdMappingMode::EXACT).initializeSet(id); if (is_leaf_id) { - loop_nodes_.initializeSet(id); + nodes(IdMappingMode::LOOP).initializeSet(id); } consumers_[id] = {}; producers_[id] = {}; @@ -720,7 +738,8 @@ void ComputeAtMap::build(Fusion* fusion) { } void ComputeAtMap::validateAndPropagatePType() { - for (const auto& loop_disjoint_set : id_graph_.loopNodes().disjointSets()) { + for (const auto& loop_disjoint_set : + id_graph_.getNodes(IdMappingMode::LOOP).disjointSets()) { ParallelType common_ptype = ParallelType::Serial; for (auto id : loop_disjoint_set->vector()) { auto id_ptype = id->getParallelType(); @@ -745,7 +764,8 @@ void ComputeAtMap::allocateIndexVariables() { // Run through all disjoint sets registered in loop map, // all lowered kir::ForLoop will correspond to one of the disjoint sets // and we only need one index variable for each set. 
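The comment above states the key invariant behind index allocation: every lowered kir::ForLoop corresponds to one disjoint set of the loop map, so one index variable per set suffices. The sketch below shows only that shape with a generic union-find over integer ids; the names are hypothetical, and the real pass operates on DisjointSets of IterDomains rather than integers.

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

struct UnionFind {
  std::vector<int> parent;
  explicit UnionFind(int n) : parent(n) {
    for (int i = 0; i < n; ++i) parent[i] = i;
  }
  int find(int x) { return parent[x] == x ? x : parent[x] = find(parent[x]); }
  void join(int a, int b) { parent[find(a)] = find(b); }
};

int main() {
  // Pretend ids 0..4 are loop iter domains; {0,1,2} and {3,4} are loop-mapped.
  UnionFind loop_map(5);
  loop_map.join(0, 1);
  loop_map.join(1, 2);
  loop_map.join(3, 4);

  // One index variable per disjoint set, keyed by the set's representative.
  std::unordered_map<int, std::string> index_variable;
  int counter = 0;
  for (int id = 0; id < 5; ++id) {
    int root = loop_map.find(id);
    if (!index_variable.count(root)) {
      index_variable[root] = "i" + std::to_string(counter++);
    }
  }

  // Every loop-mapped id resolves to the same shared variable.
  assert(index_variable[loop_map.find(0)] == index_variable[loop_map.find(2)]);
  assert(index_variable[loop_map.find(3)] == index_variable[loop_map.find(4)]);
  return 0;
}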
- for (const auto& loop_disjoint_set : id_graph_.loopNodes().disjointSets()) { + for (const auto& loop_disjoint_set : + id_graph_.getNodes(IdMappingMode::LOOP).disjointSets()) { ParallelType ptype; // first allocate thread and grid parallel indices: // The validation pass will check that the parallel bindings within the @@ -813,11 +833,12 @@ Val* ComputeAtMap::getIndexVariable( IterDomain* id, DoubleBufferLoopStage double_buffer_loop_stage) const { TORCH_INTERNAL_ASSERT( - id_graph_.loopNodes().mappingExists(id), + id_graph_.getNodes(IdMappingMode::LOOP).mappingExists(id), "Index Variable: no index variable allocated as ", id->toString(), " is not registered in loop map"); - const auto* loop_set = &(id_graph_.loopNodes().getDisjointSetOf(id)); + const auto* loop_set = + &(id_graph_.getNodes(IdMappingMode::LOOP).getDisjointSetOf(id)); // Check if this loop was modified by double buffer pass. bool is_double_buffer_iterdomain = @@ -1092,7 +1113,7 @@ void ComputeAtMap::buildConcreteIds() { // deterministic but which ID gets selected her depends on the traversal order // generating the set (compute at map build). for (const auto& disjoint_set_shared_ptr : - id_graph_.exactNodes().disjointSets()) { + id_graph_.getNodes(IdMappingMode::EXACT).disjointSets()) { TORCH_INTERNAL_ASSERT( disjoint_set_shared_ptr->vector().size(), "Cannot compute concrete id of empty set."); @@ -1103,7 +1124,7 @@ void ComputeAtMap::buildConcreteIds() { // The following two algorithms seem quite wasteful. Should find a more // efficient way to compute concrete IDs. for (const auto& disjoint_set_shared_ptr : - id_graph_.permissiveNodes().disjointSets()) { + id_graph_.getNodes(IdMappingMode::PERMISSIVE).disjointSets()) { TORCH_INTERNAL_ASSERT( disjoint_set_shared_ptr->vector().size(), "Cannot compute concrete id of empty set."); @@ -1114,7 +1135,7 @@ void ComputeAtMap::buildConcreteIds() { // Same as exact computation for (const auto& disjoint_set_shared_ptr : - id_graph_.almostExactNodes().disjointSets()) { + id_graph_.getNodes(IdMappingMode::ALMOSTEXACT).disjointSets()) { TORCH_INTERNAL_ASSERT( disjoint_set_shared_ptr->vector().size(), "Cannot compute concrete id of empty set."); @@ -1124,7 +1145,7 @@ void ComputeAtMap::buildConcreteIds() { } for (const auto& disjoint_set_shared_ptr : - id_graph_.loopNodes().disjointSets()) { + id_graph_.getNodes(IdMappingMode::LOOP).disjointSets()) { TORCH_INTERNAL_ASSERT( disjoint_set_shared_ptr->vector().size(), "Cannot compute concrete id of empty set."); @@ -1182,7 +1203,7 @@ bool ComputeAtMap::areExactExprs(Expr* expr_1, Expr* expr_2) { void ComputeAtMap::buildUniqueExactExprMaps() { // Start by building definitions for (const auto& disjoint_set_shared_ptr : - id_graph_.exactNodes().disjointSets()) { + id_graph_.getNodes(IdMappingMode::EXACT).disjointSets()) { std::vector definitions; // N^2 in number of unique transformations, this might be better to do @@ -1228,7 +1249,7 @@ void ComputeAtMap::buildUniqueExactExprMaps() { // Use definitions to build uses for (const auto& disjoint_set_shared_ptr : - id_graph_.exactNodes().disjointSets()) { + id_graph_.getNodes(IdMappingMode::EXACT).disjointSets()) { // Make sure uses is always initialized even there are no uses. 
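The buildConcreteIds hunks above repeat the same traversal once per mapping mode, and the surrounding comment already notes these passes seem wasteful. With a mode-keyed accessor, a pass of that kind can be written once over a list of modes; the sketch below only illustrates the shape of such a loop with placeholder containers, it is not the codebase's computeConcreteId logic.

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

enum class Mode { Exact, AlmostExact, Loop, Permissive };

using DisjointSetVec = std::vector<std::set<std::string>>;

const char* toString(Mode m) {
  switch (m) {
    case Mode::Exact: return "EXACT";
    case Mode::AlmostExact: return "ALMOSTEXACT";
    case Mode::Loop: return "LOOP";
    case Mode::Permissive: return "PERMISSIVE";
  }
  return "UNKNOWN";
}

int main() {
  // Hypothetical per-mode disjoint sets over string ids.
  std::map<Mode, DisjointSetVec> nodes = {
      {Mode::Exact, {{"i0", "i1"}, {"i2"}}},
      {Mode::AlmostExact, {{"i0", "i1", "i2"}}},
      {Mode::Loop, {{"i1"}}},
      {Mode::Permissive, {{"i0", "i1", "i2"}}}};

  // One loop over modes replaces four copies of the same traversal.
  for (Mode mode :
       {Mode::Exact, Mode::AlmostExact, Mode::Loop, Mode::Permissive}) {
    for (const auto& disjoint_set : nodes.at(mode)) {
      // Stand-in for computeConcreteId: pick some representative of the set.
      std::cout << toString(mode) << " concrete id: " << *disjoint_set.begin()
                << "\n";
    }
  }
  return 0;
}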
if (unique_exact_uses_.find(disjoint_set_shared_ptr) == unique_exact_uses_.end()) { @@ -1405,17 +1426,7 @@ const std::shared_ptr>& ComputeAtMap:: const DisjointSets& ComputeAtMap::getIdSets( IdMappingMode mode) const { - switch (mode) { - case IdMappingMode::EXACT: - return id_graph_.exactNodes(); - case IdMappingMode::ALMOSTEXACT: - return id_graph_.almostExactNodes(); - case IdMappingMode::LOOP: - return id_graph_.loopNodes(); - case IdMappingMode::PERMISSIVE: - return id_graph_.permissiveNodes(); - } - TORCH_INTERNAL_ASSERT(false, "Error with mapping mode provided."); + return id_graph_.getNodes(mode); } bool ComputeAtMap::idExistsInMap(IterDomain* id) const { @@ -1598,7 +1609,10 @@ void IterDomainGraph::updateComputeWith(TensorView* compute_with_tv) { consumer_tv->domain()->domain().begin(), consumer_tv->domain()->domain().end(), [&](auto consumer_id) { - return permissiveNodes().disjointSetMap().at(id)->has(consumer_id); + return getNodes(IdMappingMode::PERMISSIVE) + .disjointSetMap() + .at(id) + ->has(consumer_id); }); TORCH_INTERNAL_ASSERT( it != consumer_tv->domain()->domain().end(), @@ -1623,7 +1637,7 @@ void ComputeAtMap::updateComputeWith(TensorView* compute_with_tv) { // Update the LOOP concrete IDs for (const auto& disjoint_set_shared_ptr : - id_graph_.loopNodes().disjointSets()) { + id_graph_.getNodes(IdMappingMode::LOOP).disjointSets()) { TORCH_INTERNAL_ASSERT( disjoint_set_shared_ptr->vector().size(), "Cannot compute concrete id of empty set."); diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.h b/torch/csrc/jit/codegen/cuda/compute_at_map.h index bfa40b422b21..b98e64cdd527 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.h +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.h @@ -63,24 +63,14 @@ class TORCH_CUDA_CU_API IterDomainGraph { public: IterDomainGraph(Fusion* fusion, bool allow_self_mapping = false); - const DisjointSets& permissiveNodes() const { - return permissive_nodes_; - } - const DisjointSets& exactNodes() const { - return exact_nodes_; - } - const DisjointSets& almostExactNodes() const { - return almost_exact_nodes_; - } - const DisjointSets& loopNodes() const { - return loop_nodes_; - } - + // Returns the disjoint set according to one of the mapping mode types. + const DisjointSets& getNodes(IdMappingMode mode) const; // Consumers and producers is not symmetric like the other sets const std::unordered_map>& consumers() const { return consumers_; } + const std::unordered_map>& producers() const { return producers_; @@ -118,12 +108,23 @@ class TORCH_CUDA_CU_API IterDomainGraph { private: void build(Fusion* fusion); + // Non-const internal only version of getNodes. + DisjointSets& nodes(IdMappingMode mode); + void initializeId(IterDomain* id, bool is_view_rfactor_id, bool is_leaf_id); // Checks if exprsMap then if forward will map outputs else inputs in exact // and permissive map. void mapThroughExpr(Expr* first, Expr* second, bool forward); + // Using an array here might be nice, but it seems hard to use an enum as an + // array key + // https://stackoverflow.com/questions/2102582/how-can-i-count-the-items-in-an-enum + // + // Keeps a disjoint set entry for all IterDomain mapping mode types. 
+ // TODO: + // std::unordered_map > nodes_; + DisjointSets permissive_nodes_; DisjointSets exact_nodes_; DisjointSets almost_exact_nodes_; diff --git a/torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp b/torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp index cb3eaacd372c..d7bed02d0694 100644 --- a/torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp @@ -90,8 +90,10 @@ std::unordered_set getAllDivisibleSplits( auto concrete_id = entry.first; auto original_view_split = entry.second; - const auto& exact_mapped_ids = - ca_map->idGraph().exactNodes().getDisjointSetOf(concrete_id).vector(); + const auto& exact_mapped_ids = ca_map->idGraph() + .getNodes(IdMappingMode::EXACT) + .getDisjointSetOf(concrete_id) + .vector(); for (auto other_id : exact_mapped_ids) { if (other_id->definition() == nullptr) { continue; @@ -106,7 +108,7 @@ std::unordered_set getAllDivisibleSplits( original_view_split, other_id->definition(), false, - ca_map->idGraph().exactNodes())) { + ca_map->idGraph().getNodes(IdMappingMode::EXACT))) { all_divisible_splits.emplace(other_id->definition()->as()); } } diff --git a/torch/csrc/jit/codegen/cuda/lower_shift.cpp b/torch/csrc/jit/codegen/cuda/lower_shift.cpp index 4b319ad59a12..b6bde3193950 100644 --- a/torch/csrc/jit/codegen/cuda/lower_shift.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_shift.cpp @@ -157,7 +157,7 @@ void HaloInfo::setRootAxisInfo( HaloInfo::HaloInfo(Fusion* fusion, std::shared_ptr ca_map) // Make a copy of the permissive map for extent comparators - : permissive_map_(ca_map->idGraph().permissiveNodes()) { + : permissive_map_(ca_map->idGraph().getNodes(IdMappingMode::PERMISSIVE)) { const auto vals = fusion->usedMathVals(); auto tvs = ir_utils::filterByType(vals); diff --git a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp index 7f3e7be31e0f..10efc4f63605 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp @@ -489,7 +489,7 @@ bool requiresForwardViewReplay(Fusion* fusion, ComputeAtMap& ca_map) { // Mark those as an active use of the rfactor, if two are detected, return // true. for (const auto& disjoint_set_shared_ptr : - ca_map.idGraph().exactNodes().disjointSets()) { + ca_map.idGraph().getNodes(IdMappingMode::EXACT).disjointSets()) { // Make sure there's at least one rfactor domain in the set, otherwise we // don't need to check anything from this set. 
if (!std::any_of( diff --git a/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp b/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp index 8a8de7772d02..14d30f9b300d 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp @@ -50,8 +50,9 @@ class DomainMap : public pointwise_utils::DomainMap { const auto& root_dom = tv->getRootDomain(); IterDomain* mapped_id = nullptr; for (auto i : c10::irange(root_dom.size())) { - if (ca_map_.idGraph().permissiveNodes().permissiveAreMapped( - root_dom[i], root_dim)) { + if (ca_map_.idGraph() + .getNodes(IdMappingMode::EXACT) + .permissiveAreMapped(root_dom[i], root_dim)) { mapped_id = root_dom[i]; break; } diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp index 1c39810e29fe..ffaf31ac5126 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp @@ -2100,7 +2100,7 @@ void BoundedDirectionalTransformPropagator::bothWays( DisjointSets disjointViewSets(Fusion* fusion) { // Start from the exact iter domain graph of the fusion IterDomainGraph id_graph(fusion); - auto disjoint_view_ids = id_graph.exactNodes(); + auto disjoint_view_ids = id_graph.getNodes(IdMappingMode::EXACT); // If iter domains are involved in any transformation from root domains to // rfactor domains they should be considered "contaminated". @@ -2240,7 +2240,7 @@ void propagateViewTransforms(Fusion* fusion, const ComputeAtMap& ca_map) { std::unordered_set terminating_rfactor_dims; for (const auto& disjoint_set_shared_ptr : - ca_map.idGraph().exactNodes().disjointSets()) { + ca_map.idGraph().getNodes(IdMappingMode::EXACT).disjointSets()) { if (std::none_of( disjoint_set_shared_ptr->vector().begin(), disjoint_set_shared_ptr->vector().end(), diff --git a/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp b/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp index 33b9b7aac3ed..f21cf9f7c136 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp @@ -149,7 +149,9 @@ namespace { Val* commonOrConstExtent( std::shared_ptr ca_map, IterDomain* id) { - auto disjoint_set = ca_map->idGraph().almostExactNodes().getDisjointSetOf(id); + auto disjoint_set = ca_map->idGraph() + .getNodes(IdMappingMode::ALMOSTEXACT) + .getDisjointSetOf(id); for (auto entry : disjoint_set) { if (entry->extent()->isConstScalar()) { return entry->extent(); diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp index 01276f744022..545ca501a7e2 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_view.cpp @@ -1210,19 +1210,22 @@ TEST_F(NVFuserTest, FusionViewIdGraph_CUDA) { // Start from the exact iter domain graph of the fusion IterDomainGraph id_graph(&fusion); - auto disjoint_view_ids = id_graph.exactNodes(); + auto disjoint_view_ids = id_graph.getNodes(IdMappingMode::EXACT); + TORCH_CHECK(id_graph.getNodes(IdMappingMode::EXACT) + .strictAreMapped(tv2->axis(1), tv4->axis(1))); + TORCH_CHECK(id_graph.getNodes(IdMappingMode::EXACT) + .strictAreMapped(tv2->axis(2), tv4->axis(2))); + + TORCH_CHECK( + id_graph.getNodes(IdMappingMode::EXACT) + .strictAreMapped(tv2->getRootDomain()[1], tv12->getRootDomain()[1])); TORCH_CHECK( - id_graph.exactNodes().strictAreMapped(tv2->axis(1), tv4->axis(1))); + id_graph.getNodes(IdMappingMode::EXACT) + 
.strictAreMapped(tv2->getRootDomain()[2], tv12->getRootDomain()[2])); TORCH_CHECK( - id_graph.exactNodes().strictAreMapped(tv2->axis(2), tv4->axis(2))); - - TORCH_CHECK(id_graph.exactNodes().strictAreMapped( - tv2->getRootDomain()[1], tv12->getRootDomain()[1])); - TORCH_CHECK(id_graph.exactNodes().strictAreMapped( - tv2->getRootDomain()[2], tv12->getRootDomain()[2])); - TORCH_CHECK(id_graph.exactNodes().strictAreMapped( - tv2->getRootDomain()[3], tv12->getRootDomain()[3])); + id_graph.getNodes(IdMappingMode::EXACT) + .strictAreMapped(tv2->getRootDomain()[3], tv12->getRootDomain()[3])); } TEST_F(NVFuserTest, FusionViewVectorize_CUDA) { From 3b5a7cecbd60391bc60901a23e7d77f0fbd320b2 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sat, 3 Dec 2022 12:15:52 -0500 Subject: [PATCH 2/4] Join the different sets into one structure based on MappingMode. --- .../csrc/jit/codegen/cuda/compute_at_map.cpp | 82 ++++++++++--------- torch/csrc/jit/codegen/cuda/compute_at_map.h | 13 +-- 2 files changed, 46 insertions(+), 49 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp index fb368b01a27c..3cb9aa3cba68 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp @@ -39,6 +39,18 @@ bool idIsALeafDomain(IterDomain* id, TensorView* tv) { } // namespace IterDomainGraph::IterDomainGraph(Fusion* fusion, bool allow_self_mapping) { + // Initialize the required sets as if a permissive relationship is never + // found, then querying an empty permissive map will fail later. + std::vector mapping_types{ + IdMappingMode::EXACT, + IdMappingMode::ALMOSTEXACT, + IdMappingMode::PERMISSIVE, + IdMappingMode::LOOP}; + + for (auto mode : mapping_types) { + nodes_[mode] = DisjointSets(); + } + build(fusion); if (!allow_self_mapping) { @@ -58,31 +70,17 @@ IterDomainGraph::IterDomainGraph(Fusion* fusion, bool allow_self_mapping) { const DisjointSets& IterDomainGraph::getNodes( IdMappingMode mode) const { - switch (mode) { - case IdMappingMode::EXACT: - return exact_nodes_; - case IdMappingMode::ALMOSTEXACT: - return almost_exact_nodes_; - case IdMappingMode::LOOP: - return loop_nodes_; - case IdMappingMode::PERMISSIVE: - return permissive_nodes_; - } - TORCH_INTERNAL_ASSERT(false, "Error with mapping mode provided."); + auto node_set_it = nodes_.find(mode); + TORCH_INTERNAL_ASSERT( + node_set_it != nodes_.end(), "Mapping mode ", mode, " not supported."); + return node_set_it->second; } DisjointSets& IterDomainGraph::nodes(IdMappingMode mode) { - switch (mode) { - case IdMappingMode::EXACT: - return exact_nodes_; - case IdMappingMode::ALMOSTEXACT: - return almost_exact_nodes_; - case IdMappingMode::LOOP: - return loop_nodes_; - case IdMappingMode::PERMISSIVE: - return permissive_nodes_; - } - TORCH_INTERNAL_ASSERT(false, "Error with mapping mode provided."); + auto node_set_it = nodes_.find(mode); + TORCH_INTERNAL_ASSERT( + node_set_it != nodes_.end(), "Mapping mode ", mode, " not supported."); + return node_set_it->second; } //! 
Map corresponding inputs and outputs of swizzle op together @@ -217,7 +215,7 @@ void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { return; } - if (!exprsMap(first, second, forward, exact_nodes_)) { + if (!exprsMap(first, second, forward, nodes(IdMappingMode::EXACT))) { return; } @@ -234,7 +232,7 @@ void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { "\nand\n", second->toString()); for (auto out_i : c10::irange(first_ids.size())) { - exact_nodes_.mapEntries(first_ids[out_i], second_ids[out_i]); + nodes(IdMappingMode::EXACT).mapEntries(first_ids[out_i], second_ids[out_i]); nodes(IdMappingMode::PERMISSIVE) .mapEntries(first_ids[out_i], second_ids[out_i]); } @@ -430,7 +428,7 @@ void IterDomainGraph::build(Fusion* fusion) { auto id0 = *disjoint_set->begin(); for (auto id1 : disjoint_set->vector()) { nodes(IdMappingMode::PERMISSIVE).mapEntries(id0, id1); - exact_nodes_.mapEntries(id0, id1); + nodes(IdMappingMode::EXACT).mapEntries(id0, id1); sibling_sets_.mapEntries(id0, id1); } } @@ -440,7 +438,7 @@ void IterDomainGraph::build(Fusion* fusion) { auto disjoint_set = c2f_disjoint_sets.getDisjointSetOf(f_id); auto id0 = *(disjoint_set.begin()); for (auto id1 : disjoint_set) { - loop_nodes_.mapEntries(id0, id1); + nodes(IdMappingMode::LOOP).mapEntries(id0, id1); } } } @@ -474,15 +472,15 @@ void IterDomainGraph::build(Fusion* fusion) { for (auto c_id : getSortedKeys(exact_c2p_map, Statement::lessThan)) { auto p_id = exact_c2p_map.at(c_id); - exact_nodes_.mapEntries(c_id, p_id); + nodes(IdMappingMode::EXACT).mapEntries(c_id, p_id); consumers_.at(p_id).pushBack(c_id); producers_.at(c_id).pushBack(p_id); // Add the swizzle inputs to the same // disjoint set as well if either c_id // or p_id is swizzle output. - mapMaybeSwizzleOp(exact_nodes_, p_id); - mapMaybeSwizzleOp(exact_nodes_, c_id); + mapMaybeSwizzleOp(nodes(IdMappingMode::EXACT), p_id); + mapMaybeSwizzleOp(nodes(IdMappingMode::EXACT), c_id); } auto p_ids_vec = ir_utils::allIDsOf(p_tv); @@ -510,7 +508,7 @@ void IterDomainGraph::build(Fusion* fusion) { producers_.at(id2).pushBack(id1); if (idIsAComputeAtLeafDomain(id1, p_tv, c_tv) && idIsALeafDomain(id2, c_tv)) { - loop_nodes_.mapEntries(id1, id2); + nodes(IdMappingMode::LOOP).mapEntries(id1, id2); } } if (c_ids.count(id1) && p_ids.count(id2)) { @@ -518,7 +516,7 @@ void IterDomainGraph::build(Fusion* fusion) { consumers_.at(id2).pushBack(id1); if (idIsAComputeAtLeafDomain(id2, p_tv, c_tv) && idIsALeafDomain(id1, c_tv)) { - loop_nodes_.mapEntries(id1, id2); + nodes(IdMappingMode::LOOP).mapEntries(id1, id2); } } } @@ -637,7 +635,8 @@ void IterDomainGraph::build(Fusion* fusion) { // Only need to be concerned here with mapping across rfactor iter // domains, so isolate out those. 
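The almost-exact map built in the hunks just below starts from a copy of the exact map and then additionally maps a merge output through its surviving input whenever the other input has extent one (and analogously for factor-one splits), since such an operation does not change the iteration count. The sketch below isolates the merge half of that folding rule; Merge here is a hypothetical plain record, not the nvfuser IR node, and the union-find stands in for DisjointSets<IterDomain*>.

#include <cassert>
#include <string>
#include <unordered_map>

struct UnionFind {
  std::unordered_map<std::string, std::string> parent;
  std::string find(std::string x) {
    if (!parent.count(x)) parent[x] = x;
    return parent[x] == x ? x : parent[x] = find(parent[x]);
  }
  void join(const std::string& a, const std::string& b) {
    parent[find(a)] = find(b);
  }
  bool mapped(const std::string& a, const std::string& b) {
    return find(a) == find(b);
  }
};

struct Merge {
  std::string outer, inner, out;
  int outer_extent, inner_extent;
};

int main() {
  // A copy of the exact map would normally seed this; it starts empty here.
  UnionFind almost_exact;

  // i1 has extent 1, so merging it into i0 keeps the iteration count of i0:
  // treat the merge output i2 as "almost exactly" the same domain as i0.
  Merge m{"i0", "i1", "i2", /*outer_extent=*/8, /*inner_extent=*/1};
  if (m.inner_extent == 1) almost_exact.join(m.outer, m.out);
  if (m.outer_extent == 1) almost_exact.join(m.inner, m.out);

  assert(almost_exact.mapped("i0", "i2"));
  assert(!almost_exact.mapped("i1", "i2"));
  return 0;
}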
- auto all_exact_map_ids = exact_nodes_.getDisjointSetOf(first_rfactor_id); + auto all_exact_map_ids = + nodes(IdMappingMode::EXACT).getDisjointSetOf(first_rfactor_id); std::vector exact_map_rf_ids; std::copy_if( all_exact_map_ids.vector().begin(), @@ -673,9 +672,9 @@ void IterDomainGraph::build(Fusion* fusion) { } // Build almost exact map by forwarding through broadcast axes - almost_exact_nodes_ = exact_nodes_; + nodes(IdMappingMode::ALMOSTEXACT) = nodes(IdMappingMode::EXACT); std::unordered_set visited; - auto all_elements = exact_nodes_.getAllElements(); + auto all_elements = nodes(IdMappingMode::EXACT).getAllElements(); for (auto entry : all_elements.vector()) { if (entry->definition() == nullptr) { continue; @@ -686,18 +685,22 @@ void IterDomainGraph::build(Fusion* fusion) { } if (auto merge = dynamic_cast(def)) { if (merge->inner()->extent()->isOneInt()) { - almost_exact_nodes_.mapEntries(merge->outer(), merge->out()); + nodes(IdMappingMode::ALMOSTEXACT) + .mapEntries(merge->outer(), merge->out()); } if (merge->outer()->extent()->isOneInt()) { - almost_exact_nodes_.mapEntries(merge->inner(), merge->out()); + nodes(IdMappingMode::ALMOSTEXACT) + .mapEntries(merge->inner(), merge->out()); } } else if (auto split = dynamic_cast(def)) { if (split->factor()->isOneInt() && split->startOffset()->isZeroInt() && split->stopOffset()->isZeroInt()) { if (split->innerSplit()) { - almost_exact_nodes_.mapEntries(split->in(), split->outer()); + nodes(IdMappingMode::ALMOSTEXACT) + .mapEntries(split->in(), split->outer()); } else { - almost_exact_nodes_.mapEntries(split->in(), split->inner()); + nodes(IdMappingMode::ALMOSTEXACT) + .mapEntries(split->in(), split->inner()); } } } @@ -734,7 +737,6 @@ ComputeAtMap::ComputeAtMap(Fusion* fusion) void ComputeAtMap::build(Fusion* fusion) { buildUniqueExactExprMaps(); buildConcreteIds(); - buildUniqueExactExprMaps(); } void ComputeAtMap::validateAndPropagatePType() { @@ -1623,7 +1625,7 @@ void IterDomainGraph::updateComputeWith(TensorView* compute_with_tv) { IterDomain* consumer_id = *it; - loop_nodes_.mapEntries(id, consumer_id); + nodes(IdMappingMode::LOOP).mapEntries(id, consumer_id); } } diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.h b/torch/csrc/jit/codegen/cuda/compute_at_map.h index b98e64cdd527..d94aa3de93b8 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.h +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.h @@ -117,18 +117,12 @@ class TORCH_CUDA_CU_API IterDomainGraph { // and permissive map. void mapThroughExpr(Expr* first, Expr* second, bool forward); + // Keeps a disjoint set entry for all IterDomain mapping mode types. + // // Using an array here might be nice, but it seems hard to use an enum as an // array key // https://stackoverflow.com/questions/2102582/how-can-i-count-the-items-in-an-enum - // - // Keeps a disjoint set entry for all IterDomain mapping mode types. - // TODO: - // std::unordered_map > nodes_; - - DisjointSets permissive_nodes_; - DisjointSets exact_nodes_; - DisjointSets almost_exact_nodes_; - DisjointSets loop_nodes_; + std::unordered_map> nodes_; // Consumers and producers is not symmetric like the other sets std::unordered_map> @@ -142,6 +136,7 @@ class TORCH_CUDA_CU_API IterDomainGraph { std::unordered_set view_rfactor_ids_; + // Debug information to hold if a self mapping in a TensorView is found. 
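The nodes_ container introduced by this patch is an unordered_map keyed directly by IdMappingMode, which works because std::hash is usable with enumeration types since C++14. The retained header comment still mentions the difficulty of using an enum as an array key; the sketch below shows both choices side by side with placeholder types, where the NumModes sentinel is an assumption of the array variant and not something the patch defines.

#include <array>
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <unordered_map>

enum class Mode : size_t { Exact, AlmostExact, Loop, Permissive, NumModes };

// Stand-in for DisjointSets<IterDomain*>.
using IdSet = std::set<std::string>;

int main() {
  // Option taken by the patch: an unordered_map keyed by the enum itself.
  std::unordered_map<Mode, IdSet> map_nodes;
  for (Mode m :
       {Mode::Exact, Mode::AlmostExact, Mode::Loop, Mode::Permissive}) {
    map_nodes[m] = IdSet{};
  }
  map_nodes.at(Mode::Loop).insert("i3");

  // Alternative the comment alludes to: a fixed std::array indexed by the
  // enum's underlying value, which requires a sentinel count such as NumModes.
  std::array<IdSet, static_cast<size_t>(Mode::NumModes)> array_nodes;
  array_nodes[static_cast<size_t>(Mode::Loop)].insert("i3");

  assert(map_nodes.at(Mode::Loop) ==
         array_nodes[static_cast<size_t>(Mode::Loop)]);
  return 0;
}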
c10::optional> self_mapping_info_ = c10::nullopt; }; From d8e5512485464d91d2a180976f79ffce59b33ddf Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sat, 3 Dec 2022 12:37:26 -0500 Subject: [PATCH 3/4] Small alias. --- .../csrc/jit/codegen/cuda/compute_at_map.cpp | 33 ++++++++----------- torch/csrc/jit/codegen/cuda/compute_at_map.h | 5 +++ 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp index 3cb9aa3cba68..f92f0971b50b 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp @@ -232,9 +232,8 @@ void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { "\nand\n", second->toString()); for (auto out_i : c10::irange(first_ids.size())) { - nodes(IdMappingMode::EXACT).mapEntries(first_ids[out_i], second_ids[out_i]); - nodes(IdMappingMode::PERMISSIVE) - .mapEntries(first_ids[out_i], second_ids[out_i]); + mapNodes(first_ids[out_i], second_ids[out_i], IdMappingMode::EXACT); + mapNodes(first_ids[out_i], second_ids[out_i], IdMappingMode::PERMISSIVE); } } @@ -427,8 +426,8 @@ void IterDomainGraph::build(Fusion* fusion) { } auto id0 = *disjoint_set->begin(); for (auto id1 : disjoint_set->vector()) { - nodes(IdMappingMode::PERMISSIVE).mapEntries(id0, id1); - nodes(IdMappingMode::EXACT).mapEntries(id0, id1); + mapNodes(id0, id1, IdMappingMode::PERMISSIVE); + mapNodes(id0, id1, IdMappingMode::EXACT); sibling_sets_.mapEntries(id0, id1); } } @@ -438,7 +437,7 @@ void IterDomainGraph::build(Fusion* fusion) { auto disjoint_set = c2f_disjoint_sets.getDisjointSetOf(f_id); auto id0 = *(disjoint_set.begin()); for (auto id1 : disjoint_set) { - nodes(IdMappingMode::LOOP).mapEntries(id0, id1); + mapNodes(id0, id1, IdMappingMode::LOOP); } } } @@ -472,7 +471,7 @@ void IterDomainGraph::build(Fusion* fusion) { for (auto c_id : getSortedKeys(exact_c2p_map, Statement::lessThan)) { auto p_id = exact_c2p_map.at(c_id); - nodes(IdMappingMode::EXACT).mapEntries(c_id, p_id); + mapNodes(c_id, p_id, IdMappingMode::EXACT); consumers_.at(p_id).pushBack(c_id); producers_.at(c_id).pushBack(p_id); @@ -494,7 +493,7 @@ void IterDomainGraph::build(Fusion* fusion) { auto& vec = dset->vector(); for (auto i : c10::irange(vec.size())) { auto id1 = vec[i]; - nodes(IdMappingMode::PERMISSIVE).mapEntries(id1, vec[0]); + mapNodes(id1, vec[0], IdMappingMode::PERMISSIVE); // Add the swizzle inputs to the same // disjoint set as well if either c_id @@ -508,7 +507,7 @@ void IterDomainGraph::build(Fusion* fusion) { producers_.at(id2).pushBack(id1); if (idIsAComputeAtLeafDomain(id1, p_tv, c_tv) && idIsALeafDomain(id2, c_tv)) { - nodes(IdMappingMode::LOOP).mapEntries(id1, id2); + mapNodes(id1, id2, IdMappingMode::LOOP); } } if (c_ids.count(id1) && p_ids.count(id2)) { @@ -516,7 +515,7 @@ void IterDomainGraph::build(Fusion* fusion) { consumers_.at(id2).pushBack(id1); if (idIsAComputeAtLeafDomain(id2, p_tv, c_tv) && idIsALeafDomain(id1, c_tv)) { - nodes(IdMappingMode::LOOP).mapEntries(id1, id2); + mapNodes(id1, id2, IdMappingMode::LOOP); } } } @@ -685,22 +684,18 @@ void IterDomainGraph::build(Fusion* fusion) { } if (auto merge = dynamic_cast(def)) { if (merge->inner()->extent()->isOneInt()) { - nodes(IdMappingMode::ALMOSTEXACT) - .mapEntries(merge->outer(), merge->out()); + mapNodes(merge->outer(), merge->out(), IdMappingMode::ALMOSTEXACT); } if (merge->outer()->extent()->isOneInt()) { - nodes(IdMappingMode::ALMOSTEXACT) - .mapEntries(merge->inner(), merge->out()); + 
mapNodes(merge->inner(), merge->out(), IdMappingMode::ALMOSTEXACT); } } else if (auto split = dynamic_cast(def)) { if (split->factor()->isOneInt() && split->startOffset()->isZeroInt() && split->stopOffset()->isZeroInt()) { if (split->innerSplit()) { - nodes(IdMappingMode::ALMOSTEXACT) - .mapEntries(split->in(), split->outer()); + mapNodes(split->in(), split->outer(), IdMappingMode::ALMOSTEXACT); } else { - nodes(IdMappingMode::ALMOSTEXACT) - .mapEntries(split->in(), split->inner()); + mapNodes(split->in(), split->inner(), IdMappingMode::ALMOSTEXACT); } } } @@ -1625,7 +1620,7 @@ void IterDomainGraph::updateComputeWith(TensorView* compute_with_tv) { IterDomain* consumer_id = *it; - nodes(IdMappingMode::LOOP).mapEntries(id, consumer_id); + mapNodes(id, consumer_id, IdMappingMode::LOOP); } } diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.h b/torch/csrc/jit/codegen/cuda/compute_at_map.h index d94aa3de93b8..35719de5ec5f 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.h +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.h @@ -111,6 +111,11 @@ class TORCH_CUDA_CU_API IterDomainGraph { // Non-const internal only version of getNodes. DisjointSets& nodes(IdMappingMode mode); + // Small alias + void mapNodes(IterDomain* id0, IterDomain* id1, IdMappingMode mode) { + nodes(mode).mapEntries(id0, id1); + } + void initializeId(IterDomain* id, bool is_view_rfactor_id, bool is_leaf_id); // Checks if exprsMap then if forward will map outputs else inputs in exact From d743fd277982f1dc6207b64c89978671201ea3fa Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Sat, 3 Dec 2022 14:59:33 -0500 Subject: [PATCH 4/4] More minor refactoring. --- .../csrc/jit/codegen/cuda/compute_at_map.cpp | 34 +++++------ torch/csrc/jit/codegen/cuda/compute_at_map.h | 57 +++++++++++-------- .../codegen/cuda/lower_divisible_split.cpp | 4 +- .../jit/codegen/cuda/lower_index_compute.cpp | 7 ++- 4 files changed, 54 insertions(+), 48 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp index f92f0971b50b..095cbdd05568 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp @@ -112,7 +112,7 @@ bool IterDomainGraph::exprsMap( Expr* first, Expr* second, bool forward, - const DisjointSets& id_map) { + IdMappingMode mode) const { if (first == nullptr || second == nullptr) { return false; } @@ -158,7 +158,8 @@ bool IterDomainGraph::exprsMap( zipped_ids.begin(), zipped_ids.end(), [&](std::pair id_pair) { - return !id_map.strictAreMapped(id_pair.first, id_pair.second); + return !getNodes(mode).permissiveAreMapped( + id_pair.first, id_pair.second); })) { return false; } @@ -210,12 +211,16 @@ bool IterDomainGraph::exprsMap( // better, as today it will just check it's the same symbol or evaluated to // the same constant. However, we know all the extents of all the // IterDomain's that exact map with eachother are the same value. 
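mapThroughExpr, whose definition follows, relies on exprsMap deciding that two expressions are "the same": operands in matching positions must be mapped under the chosen mode, and the transformations themselves must match. The sketch below reduces that idea to a hypothetical SplitRecord plus an injected areMapped predicate; using the split factor and direction as the "matching transformation" check is an assumption made for illustration, not a quote of the real implementation, and none of these names are the nvfuser API.

#include <cassert>
#include <functional>
#include <string>

struct SplitRecord {
  std::string in;    // input iter domain
  int factor;        // split factor
  bool inner_split;  // whether the factor is split off the inner side
};

// areMapped would come from the disjoint sets of the chosen mapping mode.
bool splitsMap(
    const SplitRecord& a,
    const SplitRecord& b,
    const std::function<bool(const std::string&, const std::string&)>&
        areMapped) {
  // Forward direction: inputs must map and the split parameters must agree;
  // the real check also reasons about extents, which is elided here.
  return areMapped(a.in, b.in) && a.factor == b.factor &&
      a.inner_split == b.inner_split;
}

int main() {
  auto areMapped = [](const std::string& x, const std::string& y) {
    // Pretend i0 and i7 ended up in the same exact disjoint set.
    return (x == y) || (x == "i0" && y == "i7") || (x == "i7" && y == "i0");
  };

  SplitRecord s1{"i0", 4, true};
  SplitRecord s2{"i7", 4, true};
  SplitRecord s3{"i7", 8, true};

  assert(splitsMap(s1, s2, areMapped));   // mapped inputs, same parameters
  assert(!splitsMap(s1, s3, areMapped));  // factor differs
  return 0;
}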
-void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { +void IterDomainGraph::mapThroughExpr( + Expr* first, + Expr* second, + bool forward, + IdMappingMode mode) { if (first == nullptr || second == nullptr) { return; } - if (!exprsMap(first, second, forward, nodes(IdMappingMode::EXACT))) { + if (!exprsMap(first, second, forward, mode)) { return; } @@ -232,8 +237,7 @@ void IterDomainGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) { "\nand\n", second->toString()); for (auto out_i : c10::irange(first_ids.size())) { - mapNodes(first_ids[out_i], second_ids[out_i], IdMappingMode::EXACT); - mapNodes(first_ids[out_i], second_ids[out_i], IdMappingMode::PERMISSIVE); + mapNodes(first_ids[out_i], second_ids[out_i], mode); } } @@ -428,7 +432,6 @@ void IterDomainGraph::build(Fusion* fusion) { for (auto id1 : disjoint_set->vector()) { mapNodes(id0, id1, IdMappingMode::PERMISSIVE); mapNodes(id0, id1, IdMappingMode::EXACT); - sibling_sets_.mapEntries(id0, id1); } } @@ -665,7 +668,10 @@ void IterDomainGraph::build(Fusion* fusion) { continue; } - mapThroughExpr(first_expr, other_expr, prop_forward); + mapThroughExpr( + first_expr, other_expr, prop_forward, IdMappingMode::EXACT); + mapThroughExpr( + first_expr, other_expr, prop_forward, IdMappingMode::PERMISSIVE); } } } @@ -715,9 +721,6 @@ void IterDomainGraph::initializeId( } consumers_[id] = {}; producers_[id] = {}; - sibling_sets_.initializeSet(id); - - all_ids_.pushBack(id); if (is_view_rfactor_id) { view_rfactor_ids_.emplace(id); @@ -858,13 +861,6 @@ Val* ComputeAtMap::getIndexVariable( } } -bool ComputeAtMap::areMapped( - IterDomain* id0, - IterDomain* id1, - IdMappingMode mode) const { - return disjointSetOf(id0, mode)->has(id1); -} - IterDomain* ComputeAtMap::computeConcreteId( IterDomain* id, IdMappingMode mode) { @@ -1387,8 +1383,6 @@ std::string ComputeAtMap::toString() const { ss << " " << key->toString() << " :: " << producers.toString() << "\n"; } - ss << "Sibling map:\n" << id_graph_.siblings().toString() << "\n"; - ss << "} compute at map" << std::endl; return ss.str(); } diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.h b/torch/csrc/jit/codegen/cuda/compute_at_map.h index 35719de5ec5f..bb6abd0c21b8 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.h +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.h @@ -76,14 +76,7 @@ class TORCH_CUDA_CU_API IterDomainGraph { return producers_; } - const DisjointSets& siblings() const { - return sibling_sets_; - } - - const VectorOfUniqueEntries& allIds() const { - return all_ids_; - } - + // TODO: Seems a bit unfortunate that this isn't IterDomain local information. const std::unordered_set& viewRfactorIds() const { return view_rfactor_ids_; } @@ -92,12 +85,11 @@ class TORCH_CUDA_CU_API IterDomainGraph { // id_map have matching inputs (if forward), or outputs (if not forward). // Returning true means the expressions are "the same", in terms they modify // matching original extents, by the same amount. - static bool exprsMap( - Expr* first, - Expr* second, - bool forward, - const DisjointSets& id_map); + bool exprsMap(Expr* first, Expr* second, bool forward, IdMappingMode mode) + const; + // Returns if a self mapping was detected that would invalidate assumptions of + // the overall lowering system. bool hasSelfMapping() const { return self_mapping_info_.has_value(); } @@ -111,16 +103,26 @@ class TORCH_CUDA_CU_API IterDomainGraph { // Non-const internal only version of getNodes. 
DisjointSets& nodes(IdMappingMode mode); - // Small alias + // Simple alias void mapNodes(IterDomain* id0, IterDomain* id1, IdMappingMode mode) { nodes(mode).mapEntries(id0, id1); } void initializeId(IterDomain* id, bool is_view_rfactor_id, bool is_leaf_id); - // Checks if exprsMap then if forward will map outputs else inputs in exact - // and permissive map. - void mapThroughExpr(Expr* first, Expr* second, bool forward); + // Checks if expr's are considered "the same" where sameness inputs and + // outputs in the same position across expressions map with provided + // MappingMode. If the expressions are determined the same then + // if forward + // will map outputs + // else + // will map inputs + // in the provided mode + void mapThroughExpr( + Expr* first, + Expr* second, + bool forward, + IdMappingMode mode); // Keeps a disjoint set entry for all IterDomain mapping mode types. // @@ -130,15 +132,16 @@ class TORCH_CUDA_CU_API IterDomainGraph { std::unordered_map> nodes_; // Consumers and producers is not symmetric like the other sets + // TODO: Generalize to mapping type. Mappings between producer TV ids and + // consumer TV ids depend on the mapping type. std::unordered_map> consumers_; std::unordered_map> producers_; - DisjointSets sibling_sets_; - - VectorOfUniqueEntries all_ids_; - + // Hold a set of iter domains that are considered view rfactor ids. This + // identification is particularly important to understand if split operations + // are divisible or not. std::unordered_set view_rfactor_ids_; // Debug information to hold if a self mapping in a TensorView is found. @@ -160,6 +163,8 @@ class TORCH_CUDA_CU_API ComputeAtMap { //! Run through disjoint sets in the LOOP map, make sure there's only one //! non-serial parallel type in each disjoint set, set the parallel type of //! all IterDomains in the disjoint set to that PType. + //! + //! TODO: Should this be moved to parallel validation? void validateAndPropagatePType(); //! Run through disjoint sets in the LOOP map and allocate the index @@ -179,11 +184,15 @@ class TORCH_CUDA_CU_API ComputeAtMap { //! Under this condition, we can pre-allocate all required index //! variable integers before creating any kir::forloop, and this //! would help optimizing the generated integer math for indexing. + //! + //! TODO: Should this be moved to an indexing map structure outside of + //! ComputeAtMap that has a ComputeAtMap reference? void allocateIndexVariables(); - //! Returns if id0 and id1 are mapped to eachother with provided IdMappingMode - bool areMapped(IterDomain* id0, IterDomain* id1, IdMappingMode mode) const; - + //! Simple alias to IdGraph mappings. + bool areMapped(IterDomain* id0, IterDomain* id1, IdMappingMode mode) const { + return idGraph().getNodes(mode).strictAreMapped(id0, id1); + } //! Returns an iter domain that is the maximum expanded size of all iter //! domains the one provided maps to. Useful for opening loops to the correct //! iteration size. 
Not guarenteed to return the same ID every call, but is
diff --git a/torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp b/torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp
index d7bed02d0694..2a8c724ff32b 100644
--- a/torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_divisible_split.cpp
@@ -104,11 +104,11 @@ std::unordered_set<Split*> getAllDivisibleSplits(
       continue;
     }
 
-    if (IterDomainGraph::exprsMap(
-            original_view_split,
-            other_id->definition(),
-            false,
-            ca_map->idGraph().getNodes(IdMappingMode::EXACT))) {
+    if (ca_map->idGraph().exprsMap(
+            original_view_split,
+            other_id->definition(),
+            false,
+            IdMappingMode::EXACT)) {
       all_divisible_splits.emplace(other_id->definition()->as<Split>());
     }
   }
diff --git a/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp
index 00376734de4c..9b231d835ede 100644
--- a/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_index_compute.cpp
@@ -1271,8 +1271,11 @@ namespace {
 bool isPermissivelyMappedWithAny(IterDomain* id, const std::vector<Val*>& ids) {
   return std::any_of(ids.begin(), ids.end(), [&](Val* val) {
     return val->isA<IterDomain>() &&
-        GpuLower::current()->caMap()->areMapped(
-            id, val->as<IterDomain>(), IdMappingMode::PERMISSIVE);
+        GpuLower::current()
+            ->caMap()
+            ->idGraph()
+            .getNodes(IdMappingMode::PERMISSIVE)
+            .permissiveAreMapped(id, val->as<IterDomain>());
   });
 }
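The last hunk is representative of how call sites look after the refactor: a query is phrased as std::any_of over a predicate that consults the PERMISSIVE disjoint sets through getNodes. The sketch below mimics only that calling shape with string ids and a stand-in predicate; it does not use the real GpuLower or ComputeAtMap types.

#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// "Is this id mapped with any of these ids?" expressed as std::any_of over an
// injected mapping predicate, mirroring the structure of the hunk above.
bool isMappedWithAny(
    const std::string& id,
    const std::vector<std::string>& ids,
    bool (*areMapped)(const std::string&, const std::string&)) {
  return std::any_of(ids.begin(), ids.end(), [&](const std::string& other) {
    return areMapped(id, other);
  });
}

int main() {
  // Pretend i0 and b0 are permissively mapped (e.g. broadcast forwarding).
  auto mapped = [](const std::string& a, const std::string& b) {
    return a == b || (a == "i0" && b == "b0") || (a == "b0" && b == "i0");
  };
  assert(isMappedWithAny("i0", {"i5", "b0"}, mapped));
  assert(!isMappedWithAny("i0", {"i5", "i6"}, mapped));
  return 0;
}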