From 9465b73b4af27b12d565cf14bbd6fc8c0a54778b Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 9 Mar 2023 03:50:54 -0800 Subject: [PATCH 01/12] Include non-view rfactor IDs in CA map rfactor ID sets --- third_party/nvfuser/csrc/compute_at_map.cpp | 5 +- .../nvfuser/test/test_gpu_indexing.cpp | 47 +++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/third_party/nvfuser/csrc/compute_at_map.cpp b/third_party/nvfuser/csrc/compute_at_map.cpp index 6cdafc45d8f8..41bc8454649e 100644 --- a/third_party/nvfuser/csrc/compute_at_map.cpp +++ b/third_party/nvfuser/csrc/compute_at_map.cpp @@ -331,12 +331,13 @@ void IterDomainGraph::build(Fusion* fusion) { auto all_ids = ir_utils::allIDsOf(tv); // Check is this domain is a consumer of a view-like operation - bool view_like_domain = tv->domain()->hasViewLikeRFactor(); + // bool view_like_domain = tv->domain()->hasViewLikeRFactor(); for (auto id : all_ids) { // Check if this id is a view like rfactor id bool is_view_rfactor_id = false; - if (view_like_domain && id->isRFactorProduct()) { + // if (view_like_domain && id->isRFactorProduct()) { + if (id->isRFactorProduct()) { // If the tensor domain is a view like domain, and the iteration domain // is marked as an rfactor product and is in the rfactor domain, it's a // view like rfactor iteration domain diff --git a/third_party/nvfuser/test/test_gpu_indexing.cpp b/third_party/nvfuser/test/test_gpu_indexing.cpp index 2d52d255da50..7523faf638c9 100644 --- a/third_party/nvfuser/test/test_gpu_indexing.cpp +++ b/third_party/nvfuser/test/test_gpu_indexing.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -783,4 +784,50 @@ TEST_F(NVFuserTest, FusionIndexing17_CUDA) { &fusion, cg_outputs, aten_inputs, aten_outputs, __LINE__, __FILE__); } +// Repro of issue #2560 +TEST_F(NVFuserTest, FusionIndexing18_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + + auto tv2 = broadcast(tv0, {false, true}); + auto tv3 = add(tv2, tv1); + auto tv4 = sum(tv3, {0, 1}); + fusion.addOutput(tv4); + + fusion.printMath(); + + tv4->merge(0); + tv4->split(0, 4); + auto tv5 = tv4->rFactor({1}); + + MaxRootDomainInfoSpanningTree tree(tv5); + TransformPropagator tp(tv5); + tree.traverse(&tp); + + inlineAllAt(tv4, 1, true); + + fusion.printMath(); + + fusion.printKernel(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::manual_seed(1); + at::Tensor t0 = at::randn({5}, options); + at::Tensor t1 = at::randn({5, 3}, options); + std::vector inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto cg_outputs = fe.runFusion(inputs); + + auto ref = (t0.unsqueeze(-1) + t1).sum(); + + testValidate(fe.kernel(), cg_outputs, inputs, {ref}, __LINE__, __FILE__); +} + } // namespace nvfuser From dc4b79628a0f1e034b6426d9a03054fa6928dfaf Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 9 Mar 2023 04:09:02 -0800 Subject: [PATCH 02/12] WIP: Fix #2559 --- .../nvfuser/csrc/scheduler/normalization.cpp | 2 ++ .../csrc/scheduler/normalization_utils.cpp | 32 +++++++++++++++++++ .../csrc/scheduler/normalization_utils.h | 5 +++ 3 files changed, 39 insertions(+) diff --git a/third_party/nvfuser/csrc/scheduler/normalization.cpp b/third_party/nvfuser/csrc/scheduler/normalization.cpp index c3b5994419f7..06644e6f3f0c 100644 --- a/third_party/nvfuser/csrc/scheduler/normalization.cpp +++ b/third_party/nvfuser/csrc/scheduler/normalization.cpp @@ -1209,6 +1209,8 @@ void schedulePersistentKernel(Fusion* fusion, const ReductionParams& rparams) { cached_outputs, dummy_outputs); + normalization_scheduler_utils::fixUpInvalidPersistentBuffers(fusion); + if (rparams.compute_persistent_buffer_with_first_consumer) { TORCH_INTERNAL_ASSERT( rparams.persistent_kernel, diff --git a/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp b/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp index 479bb4a6e573..d6eb7bc2b7a0 100644 --- a/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp +++ b/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp @@ -1,5 +1,10 @@ +#include +#include +#include +#include #include #include +#include #include #include @@ -494,5 +499,32 @@ std::optional getGridOuterNormalizationParams( return std::nullopt; } +void fixUpInvalidPersistentBuffers(Fusion* fusion) { + auto persistent_buffer_info = scheduler_utils::persistentBuffers(fusion); + ConcretizedBroadcastDomains concretize_info(fusion); + + for (auto persistent_buffer : persistent_buffer_info.persistent_buffers) { + std::cerr << "PB: " << persistent_buffer->toString() << std::endl; + + for (auto axis : persistent_buffer->domain()->domain()) { + if (!concretize_info.isConcretized(axis) || !axis->isThread()) { + continue; + } + // Found + std::cerr << "Concretized broadcast in persistent buffer: " + << axis->toString() << std::endl; + // Recompute + for (Expr* use : persistent_buffer->uses()) { + auto buffer_replicate = RecomputeTv::recompute(persistent_buffer); + ir_utils::replaceValInExpr(use, persistent_buffer, buffer_replicate); + std::cerr << "Replicated: " << buffer_replicate->toString() + << std::endl; + } + } + } + + inlineMost(); +} + } // namespace normalization_scheduler_utils } // namespace nvfuser diff --git a/third_party/nvfuser/csrc/scheduler/normalization_utils.h b/third_party/nvfuser/csrc/scheduler/normalization_utils.h index ff353de41a34..fd13513453a0 100644 --- a/third_party/nvfuser/csrc/scheduler/normalization_utils.h +++ b/third_party/nvfuser/csrc/scheduler/normalization_utils.h @@ -8,6 +8,9 @@ #include namespace nvfuser { + +class Fusion; + namespace normalization_scheduler_utils { //! Utility class to iterate candidates of launch configurations in a @@ -145,5 +148,7 @@ std::optional getGridOuterNormalizationParams( int64_t vectorize_factor, int64_t persistent_buffer_size); +void fixUpInvalidPersistentBuffers(Fusion* fusion); + } // namespace normalization_scheduler_utils } // namespace nvfuser From f98bbd60bb7369294cd9edb6e7514081562e4bd2 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 10 Mar 2023 11:51:00 -0800 Subject: [PATCH 03/12] benchmark fix --- third_party/nvfuser/benchmark/scale_bias_relu.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/third_party/nvfuser/benchmark/scale_bias_relu.cpp b/third_party/nvfuser/benchmark/scale_bias_relu.cpp index 6a3dc7cd2e4d..f9ed37ccbba2 100644 --- a/third_party/nvfuser/benchmark/scale_bias_relu.cpp +++ b/third_party/nvfuser/benchmark/scale_bias_relu.cpp @@ -20,8 +20,7 @@ static void setupSBR(Fusion* fusion, DataType dtype) { std::vector bcast_shape(kNumberOfDims, 1); bcast_shape[bcast_shape.size() - 1] = -1; - std::vector bcast_contig(kNumberOfDims, false); - bcast_contig[bcast_contig.size() - 1] = true; + std::vector bcast_contig(1, true); auto x = makeContigTensor(kNumberOfDims, dtype); From 650ef8e3cf2e61fa3c4ba740c0e86564df784dd9 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 10 Mar 2023 16:41:05 -0800 Subject: [PATCH 04/12] cleanup --- third_party/nvfuser/csrc/compute_at_map.cpp | 45 +++++++------------ third_party/nvfuser/csrc/compute_at_map.h | 20 +++++---- .../nvfuser/csrc/lower_index_compute.cpp | 2 +- third_party/nvfuser/test/test_gpu3.cpp | 21 ++++++++- .../nvfuser/test/test_gpu_indexing.cpp | 6 --- 5 files changed, 48 insertions(+), 46 deletions(-) diff --git a/third_party/nvfuser/csrc/compute_at_map.cpp b/third_party/nvfuser/csrc/compute_at_map.cpp index 41bc8454649e..d1198e2db4c6 100644 --- a/third_party/nvfuser/csrc/compute_at_map.cpp +++ b/third_party/nvfuser/csrc/compute_at_map.cpp @@ -330,26 +330,16 @@ void IterDomainGraph::build(Fusion* fusion) { const auto& domain = tv->domain()->domain(); auto all_ids = ir_utils::allIDsOf(tv); - // Check is this domain is a consumer of a view-like operation - // bool view_like_domain = tv->domain()->hasViewLikeRFactor(); - for (auto id : all_ids) { - // Check if this id is a view like rfactor id - bool is_view_rfactor_id = false; - // if (view_like_domain && id->isRFactorProduct()) { - if (id->isRFactorProduct()) { - // If the tensor domain is a view like domain, and the iteration domain - // is marked as an rfactor product and is in the rfactor domain, it's a - // view like rfactor iteration domain - const auto& rfactor_domain = tv->domain()->getMaybeRFactorDomain(); - if (std::find(rfactor_domain.begin(), rfactor_domain.end(), id) != - rfactor_domain.end()) { - is_view_rfactor_id = true; - } - } + // Check if this id is an rfactor id in the rfactor domain + bool is_rfactor_domain_id = id->isRFactorProduct() && + std::find( + tv->getMaybeRFactorDomain().begin(), + tv->getMaybeRFactorDomain().end(), + id) != tv->getMaybeRFactorDomain().end(); bool is_leaf_id = std::find(domain.begin(), domain.end(), id) != domain.end(); - initializeId(id, is_view_rfactor_id, is_leaf_id); + initializeId(id, is_rfactor_domain_id, is_leaf_id); } } @@ -688,7 +678,7 @@ void IterDomainGraph::build(Fusion* fusion) { void IterDomainGraph::initializeId( IterDomain* id, - bool is_view_rfactor_id, + bool is_rfactor_id, bool is_leaf_id) { permissive_nodes_.initializeSet(id); exact_nodes_.initializeSet(id); @@ -701,8 +691,8 @@ void IterDomainGraph::initializeId( all_ids_.pushBack(id); - if (is_view_rfactor_id) { - view_rfactor_ids_.emplace(id); + if (is_rfactor_id) { + rfactor_ids_.emplace(id); } } @@ -995,7 +985,7 @@ IterDomain* ComputeAtMap::computeConcreteId( if (std::none_of( exact_set->vector().begin(), exact_set->vector().end(), - [&](IterDomain* id) { return isViewRfactor(id); })) { + [&](IterDomain* id) { return isRfactor(id); })) { continue; } VectorOfUniqueEntries>> @@ -1373,19 +1363,18 @@ std::string ComputeAtMap::toString() const { return ss.str(); } -bool ComputeAtMap::isViewRfactor(IterDomain* ref_id) const { - return id_graph_.viewRfactorIds().find(ref_id) != - id_graph_.viewRfactorIds().end(); +bool ComputeAtMap::isRfactor(IterDomain* ref_id) const { + return id_graph_.rfactorIds().find(ref_id) != id_graph_.rfactorIds().end(); } -std::vector ComputeAtMap::getViewRfactorDomainsOfIdGroup( +std::vector ComputeAtMap::getRfactorDomainsOfIdGroup( IterDomain* ref_id, IdMappingMode mode) const { auto disjoint_set = disjointSetOf(ref_id, mode); std::vector rfactor_ids; for (auto disjoint_id : disjoint_set->vector()) { - if (id_graph_.viewRfactorIds().find(disjoint_id) != - id_graph_.viewRfactorIds().end()) { + if (id_graph_.rfactorIds().find(disjoint_id) != + id_graph_.rfactorIds().end()) { rfactor_ids.push_back(disjoint_id); } } @@ -1454,7 +1443,7 @@ ComputeAtMap::getInputDisjointSetsOf(IterDomain* of_id, bool stop_at_rfactor) { std::any_of( currently_visiting->vector().begin(), currently_visiting->vector().end(), - [&](IterDomain* id) { return isViewRfactor(id); })) { + [&](IterDomain* id) { return isRfactor(id); })) { input_disjoint_sets.pushBack(currently_visiting); continue; } diff --git a/third_party/nvfuser/csrc/compute_at_map.h b/third_party/nvfuser/csrc/compute_at_map.h index a0c1dd72eefd..9d596ade2a3f 100644 --- a/third_party/nvfuser/csrc/compute_at_map.h +++ b/third_party/nvfuser/csrc/compute_at_map.h @@ -91,8 +91,8 @@ class TORCH_CUDA_CU_API IterDomainGraph { return all_ids_; } - const std::unordered_set& viewRfactorIds() const { - return view_rfactor_ids_; + const std::unordered_set& rfactorIds() const { + return rfactor_ids_; } // Returns if first and second are expressions through which the provided @@ -115,7 +115,7 @@ class TORCH_CUDA_CU_API IterDomainGraph { private: void build(Fusion* fusion); - void initializeId(IterDomain* id, bool is_view_rfactor_id, bool is_leaf_id); + void initializeId(IterDomain* id, bool is_rfactor_id, bool is_leaf_id); // Checks if exprsMap then if forward will map outputs else inputs in exact // and permissive map. @@ -136,7 +136,9 @@ class TORCH_CUDA_CU_API IterDomainGraph { VectorOfUniqueEntries all_ids_; - std::unordered_set view_rfactor_ids_; + // This used to only have non-reduction rfactor IDs. Changed to + // include reduction rfactor IDs as well at PR #2562 + std::unordered_set rfactor_ids_; c10::optional> self_mapping_info_ = c10::nullopt; @@ -214,13 +216,13 @@ class TORCH_CUDA_CU_API ComputeAtMap { // Prints mapping information, forwards to an internal IterDomainGraph std::string toString() const; - // Returns if the provided ID is a view like rfactor id - bool isViewRfactor(IterDomain* ref_id) const; + // Returns if the provided ID is an rfactor id + bool isRfactor(IterDomain* ref_id) const; // Returns all rfactor domains in rfactor_concrete_count_reset_domains_ that - // are in the disjoint set of the provided IterDomain. This will be every view - // like rfactor ID the provided ID "depends" on in the map. - std::vector getViewRfactorDomainsOfIdGroup( + // are in the disjoint set of the provided IterDomain. This will be every + // rfactor ID the provided ID "depends" on in the map. + std::vector getRfactorDomainsOfIdGroup( IterDomain* ref_id, IdMappingMode mode) const; diff --git a/third_party/nvfuser/csrc/lower_index_compute.cpp b/third_party/nvfuser/csrc/lower_index_compute.cpp index f2ca4b1381a3..e0f36273790f 100644 --- a/third_party/nvfuser/csrc/lower_index_compute.cpp +++ b/third_party/nvfuser/csrc/lower_index_compute.cpp @@ -1381,7 +1381,7 @@ IterDomain* getRfactorIDToTraverse( IterDomain* id, const std::vector& consumer_all_ids) { const auto& rfactor_ids = - GpuLower::current()->caMap()->getViewRfactorDomainsOfIdGroup( + GpuLower::current()->caMap()->getRfactorDomainsOfIdGroup( id, IdMappingMode::PERMISSIVE); if (rfactor_ids.empty()) { diff --git a/third_party/nvfuser/test/test_gpu3.cpp b/third_party/nvfuser/test/test_gpu3.cpp index 131e2d17435f..e506ec6b3add 100644 --- a/third_party/nvfuser/test/test_gpu3.cpp +++ b/third_party/nvfuser/test/test_gpu3.cpp @@ -6089,14 +6089,14 @@ TEST_F(NVFuserTest, FusionRepro2094_CUDA) { auto tv0 = TensorViewBuilder() .ndims(1) .shape(neg_one_vec) - .contiguity({true}) + .contiguity(true) .dtype(DataType::Float) .build(); fusion->addInput(tv0); auto tv1 = TensorViewBuilder() .ndims(1) .shape(neg_one_vec) - .contiguity({true}) + .contiguity(true) .dtype(DataType::Float) .build(); fusion->addInput(tv1); @@ -7879,6 +7879,23 @@ TEST_F(NVFuserTest, FusionCompileIndexType_CUDA) { } } +TEST_F(NVFuserTest, TMP) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {0}); + fusion.addOutput(tv1); + + tv1->split(0, 4); + auto tv2 = tv1->rFactor({0}); + + std::cout << tv2->toString() << std::endl; + fusion.print(); +} + // Test file size should be up to 10K LoC. Create a new file for more tests. } // namespace nvfuser diff --git a/third_party/nvfuser/test/test_gpu_indexing.cpp b/third_party/nvfuser/test/test_gpu_indexing.cpp index 7523faf638c9..64689f5f7925 100644 --- a/third_party/nvfuser/test/test_gpu_indexing.cpp +++ b/third_party/nvfuser/test/test_gpu_indexing.cpp @@ -799,8 +799,6 @@ TEST_F(NVFuserTest, FusionIndexing18_CUDA) { auto tv4 = sum(tv3, {0, 1}); fusion.addOutput(tv4); - fusion.printMath(); - tv4->merge(0); tv4->split(0, 4); auto tv5 = tv4->rFactor({1}); @@ -811,10 +809,6 @@ TEST_F(NVFuserTest, FusionIndexing18_CUDA) { inlineAllAt(tv4, 1, true); - fusion.printMath(); - - fusion.printKernel(); - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::manual_seed(1); at::Tensor t0 = at::randn({5}, options); From 30cc5b9ecc5c1b688bce6b0667e813b64f89c936 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 10 Mar 2023 17:13:40 -0800 Subject: [PATCH 05/12] remove test added accidentally --- third_party/nvfuser/test/test_gpu3.cpp | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/third_party/nvfuser/test/test_gpu3.cpp b/third_party/nvfuser/test/test_gpu3.cpp index e506ec6b3add..6058378c8679 100644 --- a/third_party/nvfuser/test/test_gpu3.cpp +++ b/third_party/nvfuser/test/test_gpu3.cpp @@ -7879,23 +7879,6 @@ TEST_F(NVFuserTest, FusionCompileIndexType_CUDA) { } } -TEST_F(NVFuserTest, TMP) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(1); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {0}); - fusion.addOutput(tv1); - - tv1->split(0, 4); - auto tv2 = tv1->rFactor({0}); - - std::cout << tv2->toString() << std::endl; - fusion.print(); -} - // Test file size should be up to 10K LoC. Create a new file for more tests. } // namespace nvfuser From 9b5e343cc87d9f330a91b8d8842685385268c943 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 10 Mar 2023 18:22:27 -0800 Subject: [PATCH 06/12] Add a repro --- third_party/nvfuser/test/test_gpu3.cpp | 84 ++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/third_party/nvfuser/test/test_gpu3.cpp b/third_party/nvfuser/test/test_gpu3.cpp index 6058378c8679..6424f593055f 100644 --- a/third_party/nvfuser/test/test_gpu3.cpp +++ b/third_party/nvfuser/test/test_gpu3.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -7879,6 +7880,89 @@ TEST_F(NVFuserTest, FusionCompileIndexType_CUDA) { } } +// Repro of #2559 +TEST_F(NVFuserTest, FusionNonMatchingPersistentBuffer_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1, DataType::Half); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2, DataType::Half); + fusion.addInput(tv1); + + auto tv2 = castOp(DataType::Float, tv0); + auto tv3 = castOp(DataType::Float, tv1); + + auto tv4 = broadcast(tv2, {false, true}); + auto tv5 = add(tv4, tv3); + + auto tv6 = sum(tv5, {0, 1}); + auto tv7 = broadcast(tv6, {true, true}); + + auto tv8 = add(tv3, tv7); + auto tv9 = add(tv4, tv7); + auto tv10 = add(tv8, tv9); + + fusion.addOutput(tv10); + + fusion.printMath(); + fusion.printKernel(); + + const int tidx = 256; + const int pb = 8; + const int bidx = 2; + + tv6->merge(0); + // tidx + tv6->split(0, tidx); + // persistent buffer + tv6->split(0, pb); + // [BIDx, PB, TIDX] + + MaxRootDomainInfoSpanningTree tree(tv6); + TransformPropagator tp(tv6); + tree.traverse(&tp); + + auto tv6_rf = tv6->rFactor({1}); + + inlineMost(); + + tv6_rf->axis(0)->parallelize(ParallelType::BIDx); + tv6_rf->axis(2)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv6_rf); + + if (getenv("FIX")) { + normalization_scheduler_utils::fixUpInvalidPersistentBuffers(&fusion); + } + + fusion.printMath(); + fusion.printKernel(); + + std::vector shape({bidx, tidx * pb}); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({shape[0]}, options); + at::Tensor t1 = at::randn(shape, options); + std::vector inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto cg_outputs = fe.runFusion(inputs); + + auto t2 = t0.to(at::kDouble); + auto t3 = t1.to(at::kDouble); + auto t4 = t2.unsqueeze(-1); + auto t5 = t4 + t3; + auto t6 = sum(t5); + auto t7 = t6.unsqueeze(-1).unsqueeze(-1); + auto t8 = t3 + t7; + auto t9 = t4 + t7; + auto t10 = t8 + t9; + + testValidate(fe.kernel(), cg_outputs, inputs, {t10}, __LINE__, __FILE__); +} + // Test file size should be up to 10K LoC. Create a new file for more tests. } // namespace nvfuser From 2516f6ad0cea4aed428b34634bb7679e2643cfd5 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 10 Mar 2023 19:12:36 -0800 Subject: [PATCH 07/12] cleanup --- .../nvfuser/csrc/scheduler/normalization.cpp | 3 +- .../csrc/scheduler/normalization_utils.cpp | 154 ++++++++++++++++-- .../csrc/scheduler/normalization_utils.h | 17 ++ .../csrc/scheduler/reduction_utils.cpp | 106 ------------ .../nvfuser/csrc/scheduler/reduction_utils.h | 6 - third_party/nvfuser/test/test_gpu2.cpp | 3 +- 6 files changed, 161 insertions(+), 128 deletions(-) diff --git a/third_party/nvfuser/csrc/scheduler/normalization.cpp b/third_party/nvfuser/csrc/scheduler/normalization.cpp index f611968caf0e..e326db5b659b 100644 --- a/third_party/nvfuser/csrc/scheduler/normalization.cpp +++ b/third_party/nvfuser/csrc/scheduler/normalization.cpp @@ -1130,7 +1130,8 @@ void schedulePersistentKernel(Fusion* fusion, const ReductionParams& rparams) { std::vector dummy_outputs; if (rparams.project_persistent_buffers && ir_utils::getViewOps(fusion).empty()) { - dummy_outputs = reduction_scheduler_utils::projectPersistentBuffers(fusion); + dummy_outputs = + normalization_scheduler_utils::projectPersistentBuffers(fusion); } // Cache tensors before grabbing any references to reductions as cache_before diff --git a/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp b/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp index d6eb7bc2b7a0..6e443848b490 100644 --- a/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp +++ b/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -499,31 +500,156 @@ std::optional getGridOuterNormalizationParams( return std::nullopt; } +std::vector projectPersistentBuffers(Fusion* fusion) { + auto persistent_info = scheduler_utils::persistentBuffers(fusion); + std::vector dummy_outputs; + + // Convenience accessors + const auto& persistent_buffers = persistent_info.persistent_buffers; + const auto& persistent_resolution_points = + persistent_info.persistent_buffer_resolution_points; + const auto& projected_buffers = + persistent_info.projectable_persistent_buffers; + + TORCH_INTERNAL_ASSERT( + persistent_buffers.size() == persistent_resolution_points.size()); + + // Iterate through projected buffers, tracking which index it corresponds too + // since there's a resolution point entry for every buffer. + for (auto buffer_i : c10::irange(persistent_buffers.size())) { + auto buffer = persistent_buffers[buffer_i]; + if (std::find(projected_buffers.begin(), projected_buffers.end(), buffer) == + projected_buffers.end()) { + continue; + } + + auto resolution_points = persistent_resolution_points[buffer_i]; + + std::vector persistent_use_of_buffer; + + // Go through the resolution points one by one. Resolution points are points + // in which the reduction branch meets the residual branch. These are points + // where the persitent buffer may no longer be needed (one point could be + // after another, and the buffer would be needed until the last resolution + // points) + for (auto resolution_point : resolution_points) { + // Need to go through all paths from the persistent buffer to the + // resolution point + auto chains_to_resolution = + DependencyCheck::getAllDependencyChains(buffer, resolution_point); + for (auto chain : chains_to_resolution) { + auto tv_chain = ir_utils::filterByType(chain); + + // To move the persistent buffers to the inputs, we need to recompute + // the persistent buffer for all branches that don't go through a + // reduction. If there's a reduction on the current path between the + // persistent buffer and resolution, continue, there's no need to + // replicate this use. + if (std::any_of(tv_chain.begin(), tv_chain.end(), [](TensorView* tv) { + return tv->hasReduction(); + })) { + continue; + } + + // Grab use of the buffer, chain[0] is the persistent buffer, chain[1] + // is its first use. + auto use = chain[1]; + + // Only grab unique uses, a persistent buffer could be used multiple + // times in the same expression. + if (std::find( + persistent_use_of_buffer.begin(), + persistent_use_of_buffer.end(), + use) != persistent_use_of_buffer.end()) { + continue; + } + persistent_use_of_buffer.emplace_back(use); + } + + // For all uses that do not go towards the reduction operations in the + // persistent section of the graph, recompute the persistent buffer. + for (auto use : persistent_use_of_buffer) { + TORCH_INTERNAL_ASSERT(use->definition() != nullptr); + auto buffer_replicate = RecomputeTv::recompute(buffer); + // Create a shortcut buffer <--> buffer_replicate for propagation. + // Why is this needed? + // Consider that we have a fusion + // + // T0[I] + // T1[b b I] = broadcast(T0) + // T2[b b r] = reduction(T1) + // T3[b b b] = broadcast(T2) + // T4[b, b, I] = T1 + T3 + // T5[b, b, r] = reduction(T4) + // + // After projection, it becomes + // + // T0[I] + // T1[b b I] = broadcast(T0) + // T2[b b r] = reduction(T1) + // T3[b b b] = broadcast(T2) + // T6[b b I] = broadcast(T0) + // T4[b, b, I] = T6 + T3 + // T5[b, b, r] = reduction(T4) + // + // During schedule, we need to propagate from T2 to T5. However, in the + // resulting DAG, neither the propagation path T2->T3->T4->T5 nor + // T2->T1->T0->T6->T4->T5 works because they both have missing root + // domain. But adding `T7 = T1 + T6` creates a new propagation path + // `T2->T1->T7->T6->T4->T5` which has all root domain information. + // See FusionBroadcastPersistentReduction_CUDA for an example + dummy_outputs.emplace_back(add(buffer_replicate, buffer)); + ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate); + } + } + } + return dummy_outputs; +} + void fixUpInvalidPersistentBuffers(Fusion* fusion) { auto persistent_buffer_info = scheduler_utils::persistentBuffers(fusion); ConcretizedBroadcastDomains concretize_info(fusion); - for (auto persistent_buffer : persistent_buffer_info.persistent_buffers) { - std::cerr << "PB: " << persistent_buffer->toString() << std::endl; + std::vector recomputed_tvs; - for (auto axis : persistent_buffer->domain()->domain()) { - if (!concretize_info.isConcretized(axis) || !axis->isThread()) { + for (auto persistent_buffer : persistent_buffer_info.persistent_buffers) { + // Check if this buffer needs to be recomputed + + bool need_recomputation = false; + for (const auto i : c10::irange( + persistent_buffer->getComputeAtPosition(), + persistent_buffer->domain()->domain().size())) { + auto axis = persistent_buffer->axis(i); + // concretize_info.isConcretized(axis) means the axis is a + // concreteized broadcast domain + if (!concretize_info.isConcretized(axis)) { continue; } - // Found - std::cerr << "Concretized broadcast in persistent buffer: " - << axis->toString() << std::endl; - // Recompute - for (Expr* use : persistent_buffer->uses()) { - auto buffer_replicate = RecomputeTv::recompute(persistent_buffer); - ir_utils::replaceValInExpr(use, persistent_buffer, buffer_replicate); - std::cerr << "Replicated: " << buffer_replicate->toString() - << std::endl; + + // If not parallelized, no dependency can exist + if (!axis->isThread()) { + continue; } + + need_recomputation = true; + break; + } + + if (!need_recomputation) { + continue; + } + + // Recompute + for (Expr* use : persistent_buffer->uses()) { + auto buffer_replicate = RecomputeTv::recompute(persistent_buffer); + ir_utils::replaceValInExpr(use, persistent_buffer, buffer_replicate); + std::cerr << "Replicated: " << buffer_replicate->toString() << std::endl; + recomputed_tvs.push_back(buffer_replicate); } } - inlineMost(); + // The new tensors are not yet inlined + inlineMost(recomputed_tvs); } } // namespace normalization_scheduler_utils diff --git a/third_party/nvfuser/csrc/scheduler/normalization_utils.h b/third_party/nvfuser/csrc/scheduler/normalization_utils.h index fd13513453a0..de1183429f30 100644 --- a/third_party/nvfuser/csrc/scheduler/normalization_utils.h +++ b/third_party/nvfuser/csrc/scheduler/normalization_utils.h @@ -148,6 +148,23 @@ std::optional getGridOuterNormalizationParams( int64_t vectorize_factor, int64_t persistent_buffer_size); +// Take all projectable persistent buffers, and move them to the inputs. This +// function create dummy outputs which should be used in later stages of the +// scheduling. +TORCH_CUDA_CU_API std::vector projectPersistentBuffers( + Fusion* fusion); + +//! Persistent buffers are effectively tensors that cannot be inlined +//! due to the reduction and broadcast pattern. Since we store +//! persistent buffers in registers, they have to be parallelized in +//! such a way that no data dependency exists between threads. Buffers +//! that do have dependencies cannot be persistent. This function +//! detects such buffers and make them non-persistent by +//! recomputation. +//! +//! Alternatively, such buffers could be stored on shared memory if +//! the dependecy only exists between threads in the same thread +//! block. Not considered yet. void fixUpInvalidPersistentBuffers(Fusion* fusion); } // namespace normalization_scheduler_utils diff --git a/third_party/nvfuser/csrc/scheduler/reduction_utils.cpp b/third_party/nvfuser/csrc/scheduler/reduction_utils.cpp index c16e500e2385..22d84e02a5a2 100644 --- a/third_party/nvfuser/csrc/scheduler/reduction_utils.cpp +++ b/third_party/nvfuser/csrc/scheduler/reduction_utils.cpp @@ -656,111 +656,5 @@ TensorView* sortAndRFactor(TensorView* reference_tv) { return ir_utils::rfactorHelper(reference_tv, rfactor_axes); } -std::vector projectPersistentBuffers(Fusion* fusion) { - auto persistent_info = scheduler_utils::persistentBuffers(fusion); - std::vector dummy_outputs; - - // Convenience accessors - const auto& persistent_buffers = persistent_info.persistent_buffers; - const auto& persistent_resolution_points = - persistent_info.persistent_buffer_resolution_points; - const auto& projected_buffers = - persistent_info.projectable_persistent_buffers; - - TORCH_INTERNAL_ASSERT( - persistent_buffers.size() == persistent_resolution_points.size()); - - // Iterate through projected buffers, tracking which index it corresponds too - // since there's a resolution point entry for every buffer. - for (auto buffer_i : c10::irange(persistent_buffers.size())) { - auto buffer = persistent_buffers[buffer_i]; - if (std::find(projected_buffers.begin(), projected_buffers.end(), buffer) == - projected_buffers.end()) { - continue; - } - - auto resolution_points = persistent_resolution_points[buffer_i]; - - std::vector persistent_use_of_buffer; - - // Go through the resolution points one by one. Resolution points are points - // in which the reduction branch meets the residual branch. These are points - // where the persitent buffer may no longer be needed (one point could be - // after another, and the buffer would be needed until the last resolution - // points) - for (auto resolution_point : resolution_points) { - // Need to go through all paths from the persistent buffer to the - // resolution point - auto chains_to_resolution = - DependencyCheck::getAllDependencyChains(buffer, resolution_point); - for (auto chain : chains_to_resolution) { - auto tv_chain = ir_utils::filterByType(chain); - - // To move the persistent buffers to the inputs, we need to recompute - // the persistent buffer for all branches that don't go through a - // reduction. If there's a reduction on the current path between the - // persistent buffer and resolution, continue, there's no need to - // replicate this use. - if (std::any_of(tv_chain.begin(), tv_chain.end(), [](TensorView* tv) { - return tv->hasReduction(); - })) { - continue; - } - - // Grab use of the buffer, chain[0] is the persistent buffer, chain[1] - // is its first use. - auto use = chain[1]; - - // Only grab unique uses, a persistent buffer could be used multiple - // times in the same expression. - if (std::find( - persistent_use_of_buffer.begin(), - persistent_use_of_buffer.end(), - use) != persistent_use_of_buffer.end()) { - continue; - } - persistent_use_of_buffer.emplace_back(use); - } - - // For all uses that do not go towards the reduction operations in the - // persistent section of the graph, recompute the persistent buffer. - for (auto use : persistent_use_of_buffer) { - TORCH_INTERNAL_ASSERT(use->definition() != nullptr); - auto buffer_replicate = RecomputeTv::recompute(buffer); - // Create a shortcut buffer <--> buffer_replicate for propagation. - // Why is this needed? - // Consider that we have a fusion - // - // T0[I] - // T1[b b I] = broadcast(T0) - // T2[b b r] = reduction(T1) - // T3[b b b] = broadcast(T2) - // T4[b, b, I] = T1 + T3 - // T5[b, b, r] = reduction(T4) - // - // After projection, it becomes - // - // T0[I] - // T1[b b I] = broadcast(T0) - // T2[b b r] = reduction(T1) - // T3[b b b] = broadcast(T2) - // T6[b b I] = broadcast(T0) - // T4[b, b, I] = T6 + T3 - // T5[b, b, r] = reduction(T4) - // - // During schedule, we need to propagate from T2 to T5. However, in the - // resulting DAG, neither the propagation path T2->T3->T4->T5 nor - // T2->T1->T0->T6->T4->T5 works because they both have missing root - // domain. But adding `T7 = T1 + T6` creates a new propagation path - // `T2->T1->T7->T6->T4->T5` which has all root domain information. - // See FusionBroadcastPersistentReduction_CUDA for an example - dummy_outputs.emplace_back(add(buffer_replicate, buffer)); - ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate); - } - } - } - return dummy_outputs; -} - } // namespace reduction_scheduler_utils } // namespace nvfuser diff --git a/third_party/nvfuser/csrc/scheduler/reduction_utils.h b/third_party/nvfuser/csrc/scheduler/reduction_utils.h index e05ada8a618b..adb6548a1974 100644 --- a/third_party/nvfuser/csrc/scheduler/reduction_utils.h +++ b/third_party/nvfuser/csrc/scheduler/reduction_utils.h @@ -41,11 +41,5 @@ TORCH_CUDA_CU_API void multiReductionInliner( // Reduction inliner expects an rfactored domain. TORCH_CUDA_CU_API TensorView* sortAndRFactor(TensorView* reference_tv); -// Take all projectable persistent buffers, and move them to the inputs. This -// function create dummy outputs which should be used in later stages of the -// scheduling. -TORCH_CUDA_CU_API std::vector projectPersistentBuffers( - Fusion* fusion); - } // namespace reduction_scheduler_utils } // namespace nvfuser diff --git a/third_party/nvfuser/test/test_gpu2.cpp b/third_party/nvfuser/test/test_gpu2.cpp index e67b80f332d6..7352a7715cae 100644 --- a/third_party/nvfuser/test/test_gpu2.cpp +++ b/third_party/nvfuser/test/test_gpu2.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -9609,7 +9610,7 @@ TEST_F(NVFuserTest, FusionPersistentBufferProjection_CUDA) { fusion.addOutput(tv9); - reduction_scheduler_utils::projectPersistentBuffers(&fusion); + normalization_scheduler_utils::projectPersistentBuffers(&fusion); auto tv5_producers = ir_utils::producerTvsOf(tv5); auto tv7_producers = ir_utils::producerTvsOf(tv7); From f70953d3c7ad7591c7b6ecb75aa4b4d80891d2df Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 10 Mar 2023 19:26:47 -0800 Subject: [PATCH 08/12] cleanup --- .../csrc/scheduler/normalization_utils.cpp | 130 ++++++++++-------- 1 file changed, 76 insertions(+), 54 deletions(-) diff --git a/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp b/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp index 6e443848b490..6fd22b68f34b 100644 --- a/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp +++ b/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp @@ -500,6 +500,56 @@ std::optional getGridOuterNormalizationParams( return std::nullopt; } +namespace { + +// Go through the resolution points one by one. Resolution points are points +// in which the reduction branch meets the residual branch. These are points +// where the persitent buffer may no longer be needed (one point could be +// after another, and the buffer would be needed until the last resolution +// points) +std::vector getPersistentUseOfBuffer( + TensorView* buffer, + TensorView* resolution_point) { + std::vector persistent_use_of_buffer; + + // Need to go through all paths from the persistent buffer to the + // resolution point + auto chains_to_resolution = + DependencyCheck::getAllDependencyChains(buffer, resolution_point); + for (auto chain : chains_to_resolution) { + auto tv_chain = ir_utils::filterByType(chain); + + // To move the persistent buffers to the inputs, we need to recompute + // the persistent buffer for all branches that don't go through a + // reduction. If there's a reduction on the current path between the + // persistent buffer and resolution, continue, there's no need to + // replicate this use. + if (std::any_of(tv_chain.begin(), tv_chain.end(), [](TensorView* tv) { + return tv->hasReduction(); + })) { + continue; + } + + // Grab use of the buffer, chain[0] is the persistent buffer, chain[1] + // is its first use. + auto use = chain[1]; + + // Only grab unique uses, a persistent buffer could be used multiple + // times in the same expression. + if (std::find( + persistent_use_of_buffer.begin(), + persistent_use_of_buffer.end(), + use) != persistent_use_of_buffer.end()) { + continue; + } + persistent_use_of_buffer.emplace_back(use->as()); + } + + return persistent_use_of_buffer; +} + +} // namespace + std::vector projectPersistentBuffers(Fusion* fusion) { auto persistent_info = scheduler_utils::persistentBuffers(fusion); std::vector dummy_outputs; @@ -523,48 +573,9 @@ std::vector projectPersistentBuffers(Fusion* fusion) { continue; } - auto resolution_points = persistent_resolution_points[buffer_i]; - - std::vector persistent_use_of_buffer; - - // Go through the resolution points one by one. Resolution points are points - // in which the reduction branch meets the residual branch. These are points - // where the persitent buffer may no longer be needed (one point could be - // after another, and the buffer would be needed until the last resolution - // points) - for (auto resolution_point : resolution_points) { - // Need to go through all paths from the persistent buffer to the - // resolution point - auto chains_to_resolution = - DependencyCheck::getAllDependencyChains(buffer, resolution_point); - for (auto chain : chains_to_resolution) { - auto tv_chain = ir_utils::filterByType(chain); - - // To move the persistent buffers to the inputs, we need to recompute - // the persistent buffer for all branches that don't go through a - // reduction. If there's a reduction on the current path between the - // persistent buffer and resolution, continue, there's no need to - // replicate this use. - if (std::any_of(tv_chain.begin(), tv_chain.end(), [](TensorView* tv) { - return tv->hasReduction(); - })) { - continue; - } - - // Grab use of the buffer, chain[0] is the persistent buffer, chain[1] - // is its first use. - auto use = chain[1]; - - // Only grab unique uses, a persistent buffer could be used multiple - // times in the same expression. - if (std::find( - persistent_use_of_buffer.begin(), - persistent_use_of_buffer.end(), - use) != persistent_use_of_buffer.end()) { - continue; - } - persistent_use_of_buffer.emplace_back(use); - } + for (auto resolution_point : persistent_resolution_points[buffer_i]) { + const auto persistent_use_of_buffer = + getPersistentUseOfBuffer(buffer, resolution_point); // For all uses that do not go towards the reduction operations in the // persistent section of the graph, recompute the persistent buffer. @@ -607,19 +618,26 @@ std::vector projectPersistentBuffers(Fusion* fusion) { } void fixUpInvalidPersistentBuffers(Fusion* fusion) { - auto persistent_buffer_info = scheduler_utils::persistentBuffers(fusion); + auto persistent_info = scheduler_utils::persistentBuffers(fusion); + const auto& persistent_buffers = persistent_info.persistent_buffers; + const auto& persistent_resolution_points = + persistent_info.persistent_buffer_resolution_points; + + // TODO: Compute this on demand ConcretizedBroadcastDomains concretize_info(fusion); std::vector recomputed_tvs; - for (auto persistent_buffer : persistent_buffer_info.persistent_buffers) { - // Check if this buffer needs to be recomputed + for (auto buffer_i : c10::irange(persistent_buffers.size())) { + auto buffer = persistent_buffers.at(buffer_i); + // Check if this buffer needs to be recomputed bool need_recomputation = false; + for (const auto i : c10::irange( - persistent_buffer->getComputeAtPosition(), - persistent_buffer->domain()->domain().size())) { - auto axis = persistent_buffer->axis(i); + buffer->getComputeAtPosition(), + buffer->domain()->domain().size())) { + auto axis = buffer->axis(i); // concretize_info.isConcretized(axis) means the axis is a // concreteized broadcast domain if (!concretize_info.isConcretized(axis)) { @@ -639,12 +657,16 @@ void fixUpInvalidPersistentBuffers(Fusion* fusion) { continue; } - // Recompute - for (Expr* use : persistent_buffer->uses()) { - auto buffer_replicate = RecomputeTv::recompute(persistent_buffer); - ir_utils::replaceValInExpr(use, persistent_buffer, buffer_replicate); - std::cerr << "Replicated: " << buffer_replicate->toString() << std::endl; - recomputed_tvs.push_back(buffer_replicate); + for (auto resolution_point : persistent_resolution_points[buffer_i]) { + const auto persistent_use_of_buffer = + getPersistentUseOfBuffer(buffer, resolution_point); + for (auto use : persistent_use_of_buffer) { + TORCH_INTERNAL_ASSERT(use->definition() != nullptr); + auto buffer_replicate = RecomputeTv::recompute(buffer); + recomputed_tvs.push_back(buffer_replicate); + std::cerr << "Replicated: " << buffer_replicate->toString() + << std::endl; + } } } From 50116359f44492c6058419ba659d750feed38c98 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 10 Mar 2023 21:18:13 -0800 Subject: [PATCH 09/12] cleanup --- .../csrc/scheduler/normalization_utils.cpp | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp b/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp index 6fd22b68f34b..98a7a79c4152 100644 --- a/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp +++ b/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp @@ -10,6 +10,8 @@ #include +#include + namespace nvfuser { namespace normalization_scheduler_utils { @@ -623,11 +625,12 @@ void fixUpInvalidPersistentBuffers(Fusion* fusion) { const auto& persistent_resolution_points = persistent_info.persistent_buffer_resolution_points; - // TODO: Compute this on demand - ConcretizedBroadcastDomains concretize_info(fusion); + std::unique_ptr concretize_info; std::vector recomputed_tvs; + bool recompute_done = false; + for (auto buffer_i : c10::irange(persistent_buffers.size())) { auto buffer = persistent_buffers.at(buffer_i); @@ -638,17 +641,22 @@ void fixUpInvalidPersistentBuffers(Fusion* fusion) { buffer->getComputeAtPosition(), buffer->domain()->domain().size())) { auto axis = buffer->axis(i); - // concretize_info.isConcretized(axis) means the axis is a - // concreteized broadcast domain - if (!concretize_info.isConcretized(axis)) { - continue; - } // If not parallelized, no dependency can exist if (!axis->isThread()) { continue; } + if (!concretize_info) { + concretize_info = std::make_unique(fusion); + } + + // concretize_info.isConcretized(axis) means the axis is a + // concreteized broadcast domain + if (!concretize_info->isConcretized(axis)) { + continue; + } + need_recomputation = true; break; } @@ -663,15 +671,18 @@ void fixUpInvalidPersistentBuffers(Fusion* fusion) { for (auto use : persistent_use_of_buffer) { TORCH_INTERNAL_ASSERT(use->definition() != nullptr); auto buffer_replicate = RecomputeTv::recompute(buffer); + ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate); recomputed_tvs.push_back(buffer_replicate); std::cerr << "Replicated: " << buffer_replicate->toString() << std::endl; + recompute_done = true; } } } - // The new tensors are not yet inlined - inlineMost(recomputed_tvs); + if (recompute_done) { + inlineMost(); + } } } // namespace normalization_scheduler_utils From ff50f47ba76a05e1552e9222ed38c78206ce93f9 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Sat, 11 Mar 2023 00:39:39 -0800 Subject: [PATCH 10/12] cleanup --- .../csrc/scheduler/normalization_utils.cpp | 43 +++++++++++-------- .../csrc/scheduler/normalization_utils.h | 4 +- third_party/nvfuser/test/test_gpu3.cpp | 10 +---- 3 files changed, 27 insertions(+), 30 deletions(-) diff --git a/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp b/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp index 98a7a79c4152..3a26adf0fe3c 100644 --- a/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp +++ b/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp @@ -619,7 +619,7 @@ std::vector projectPersistentBuffers(Fusion* fusion) { return dummy_outputs; } -void fixUpInvalidPersistentBuffers(Fusion* fusion) { +std::unordered_set fixUpInvalidPersistentBuffers(Fusion* fusion) { auto persistent_info = scheduler_utils::persistentBuffers(fusion); const auto& persistent_buffers = persistent_info.persistent_buffers; const auto& persistent_resolution_points = @@ -627,7 +627,7 @@ void fixUpInvalidPersistentBuffers(Fusion* fusion) { std::unique_ptr concretize_info; - std::vector recomputed_tvs; + std::unordered_set recomputed_tvs; bool recompute_done = false; @@ -642,23 +642,28 @@ void fixUpInvalidPersistentBuffers(Fusion* fusion) { buffer->domain()->domain().size())) { auto axis = buffer->axis(i); - // If not parallelized, no dependency can exist - if (!axis->isThread()) { - continue; - } - - if (!concretize_info) { - concretize_info = std::make_unique(fusion); - } - - // concretize_info.isConcretized(axis) means the axis is a - // concreteized broadcast domain - if (!concretize_info->isConcretized(axis)) { - continue; + // Unresolved data dependency could exist if: + // - Parallelized by tidx and stored on Local + // - Parallelized by bidx and stored on Local or Shared + if ((axis->isThreadDim() && + buffer->getMemoryType() == MemoryType::Local) || + (axis->isBlockDim() && + (buffer->getMemoryType() == MemoryType::Local || + buffer->getMemoryType() == MemoryType::Shared))) { + if (!concretize_info) { + concretize_info = + std::make_unique(fusion); + } + + // concretize_info.isConcretized(axis) means the axis is a + // concreteized broadcast domain + if (!concretize_info->isConcretized(axis)) { + continue; + } + + need_recomputation = true; + break; } - - need_recomputation = true; - break; } if (!need_recomputation) { @@ -672,7 +677,7 @@ void fixUpInvalidPersistentBuffers(Fusion* fusion) { TORCH_INTERNAL_ASSERT(use->definition() != nullptr); auto buffer_replicate = RecomputeTv::recompute(buffer); ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate); - recomputed_tvs.push_back(buffer_replicate); + recomputed_tvs.insert(buffer); std::cerr << "Replicated: " << buffer_replicate->toString() << std::endl; recompute_done = true; diff --git a/third_party/nvfuser/csrc/scheduler/normalization_utils.h b/third_party/nvfuser/csrc/scheduler/normalization_utils.h index de1183429f30..0393b25d0f7c 100644 --- a/third_party/nvfuser/csrc/scheduler/normalization_utils.h +++ b/third_party/nvfuser/csrc/scheduler/normalization_utils.h @@ -160,12 +160,12 @@ TORCH_CUDA_CU_API std::vector projectPersistentBuffers( //! such a way that no data dependency exists between threads. Buffers //! that do have dependencies cannot be persistent. This function //! detects such buffers and make them non-persistent by -//! recomputation. +//! recomputation. Returns buffers that are made non-persistent. //! //! Alternatively, such buffers could be stored on shared memory if //! the dependecy only exists between threads in the same thread //! block. Not considered yet. -void fixUpInvalidPersistentBuffers(Fusion* fusion); +std::unordered_set fixUpInvalidPersistentBuffers(Fusion* fusion); } // namespace normalization_scheduler_utils } // namespace nvfuser diff --git a/third_party/nvfuser/test/test_gpu3.cpp b/third_party/nvfuser/test/test_gpu3.cpp index 6424f593055f..e4bc53c52275 100644 --- a/third_party/nvfuser/test/test_gpu3.cpp +++ b/third_party/nvfuser/test/test_gpu3.cpp @@ -7905,9 +7905,6 @@ TEST_F(NVFuserTest, FusionNonMatchingPersistentBuffer_CUDA) { fusion.addOutput(tv10); - fusion.printMath(); - fusion.printKernel(); - const int tidx = 256; const int pb = 8; const int bidx = 2; @@ -7931,12 +7928,7 @@ TEST_F(NVFuserTest, FusionNonMatchingPersistentBuffer_CUDA) { tv6_rf->axis(2)->parallelize(ParallelType::TIDx); scheduler_utils::parallelizeAllLike(tv6_rf); - if (getenv("FIX")) { - normalization_scheduler_utils::fixUpInvalidPersistentBuffers(&fusion); - } - - fusion.printMath(); - fusion.printKernel(); + normalization_scheduler_utils::fixUpInvalidPersistentBuffers(&fusion); std::vector shape({bidx, tidx * pb}); From 35f0990bb9e916600b02437e05a3689c32081451 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Sat, 11 Mar 2023 00:59:13 -0800 Subject: [PATCH 11/12] cleanup --- third_party/nvfuser/csrc/scheduler/normalization_utils.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp b/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp index 3a26adf0fe3c..58518cdcd78d 100644 --- a/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp +++ b/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp @@ -678,8 +678,6 @@ std::unordered_set fixUpInvalidPersistentBuffers(Fusion* fusion) { auto buffer_replicate = RecomputeTv::recompute(buffer); ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate); recomputed_tvs.insert(buffer); - std::cerr << "Replicated: " << buffer_replicate->toString() - << std::endl; recompute_done = true; } } @@ -688,6 +686,8 @@ std::unordered_set fixUpInvalidPersistentBuffers(Fusion* fusion) { if (recompute_done) { inlineMost(); } + + return recomputed_tvs; } } // namespace normalization_scheduler_utils From 47f5b24d148445ec99c9825270eeb5aba4dd3f5f Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 13 Mar 2023 18:04:40 -0700 Subject: [PATCH 12/12] update based on PR #2576 --- .../csrc/scheduler/normalization_utils.cpp | 159 +++++++++--------- 1 file changed, 81 insertions(+), 78 deletions(-) diff --git a/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp b/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp index 58518cdcd78d..b9e3ff6f35b8 100644 --- a/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp +++ b/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp @@ -511,40 +511,42 @@ namespace { // points) std::vector getPersistentUseOfBuffer( TensorView* buffer, - TensorView* resolution_point) { + const std::vector& resolution_points) { std::vector persistent_use_of_buffer; - // Need to go through all paths from the persistent buffer to the - // resolution point - auto chains_to_resolution = - DependencyCheck::getAllDependencyChains(buffer, resolution_point); - for (auto chain : chains_to_resolution) { - auto tv_chain = ir_utils::filterByType(chain); - - // To move the persistent buffers to the inputs, we need to recompute - // the persistent buffer for all branches that don't go through a - // reduction. If there's a reduction on the current path between the - // persistent buffer and resolution, continue, there's no need to - // replicate this use. - if (std::any_of(tv_chain.begin(), tv_chain.end(), [](TensorView* tv) { - return tv->hasReduction(); - })) { - continue; - } + for (auto resolution_point : resolution_points) { + // Need to go through all paths from the persistent buffer to the + // resolution point + auto chains_to_resolution = + DependencyCheck::getAllDependencyChains(buffer, resolution_point); + for (auto chain : chains_to_resolution) { + auto tv_chain = ir_utils::filterByType(chain); + + // To move the persistent buffers to the inputs, we need to recompute + // the persistent buffer for all branches that don't go through a + // reduction. If there's a reduction on the current path between the + // persistent buffer and resolution, continue, there's no need to + // replicate this use. + if (std::any_of(tv_chain.begin(), tv_chain.end(), [](TensorView* tv) { + return tv->hasReduction(); + })) { + continue; + } - // Grab use of the buffer, chain[0] is the persistent buffer, chain[1] - // is its first use. - auto use = chain[1]; + // Grab use of the buffer, chain[0] is the persistent buffer, chain[1] + // is its first use. + auto use = chain[1]; - // Only grab unique uses, a persistent buffer could be used multiple - // times in the same expression. - if (std::find( - persistent_use_of_buffer.begin(), - persistent_use_of_buffer.end(), - use) != persistent_use_of_buffer.end()) { - continue; + // Only grab unique uses, a persistent buffer could be used multiple + // times in the same expression. + if (std::find( + persistent_use_of_buffer.begin(), + persistent_use_of_buffer.end(), + use) != persistent_use_of_buffer.end()) { + continue; + } + persistent_use_of_buffer.emplace_back(use->as()); } - persistent_use_of_buffer.emplace_back(use->as()); } return persistent_use_of_buffer; @@ -575,45 +577,45 @@ std::vector projectPersistentBuffers(Fusion* fusion) { continue; } - for (auto resolution_point : persistent_resolution_points[buffer_i]) { - const auto persistent_use_of_buffer = - getPersistentUseOfBuffer(buffer, resolution_point); - - // For all uses that do not go towards the reduction operations in the - // persistent section of the graph, recompute the persistent buffer. - for (auto use : persistent_use_of_buffer) { - TORCH_INTERNAL_ASSERT(use->definition() != nullptr); - auto buffer_replicate = RecomputeTv::recompute(buffer); - // Create a shortcut buffer <--> buffer_replicate for propagation. - // Why is this needed? - // Consider that we have a fusion - // - // T0[I] - // T1[b b I] = broadcast(T0) - // T2[b b r] = reduction(T1) - // T3[b b b] = broadcast(T2) - // T4[b, b, I] = T1 + T3 - // T5[b, b, r] = reduction(T4) - // - // After projection, it becomes - // - // T0[I] - // T1[b b I] = broadcast(T0) - // T2[b b r] = reduction(T1) - // T3[b b b] = broadcast(T2) - // T6[b b I] = broadcast(T0) - // T4[b, b, I] = T6 + T3 - // T5[b, b, r] = reduction(T4) - // - // During schedule, we need to propagate from T2 to T5. However, in the - // resulting DAG, neither the propagation path T2->T3->T4->T5 nor - // T2->T1->T0->T6->T4->T5 works because they both have missing root - // domain. But adding `T7 = T1 + T6` creates a new propagation path - // `T2->T1->T7->T6->T4->T5` which has all root domain information. - // See FusionBroadcastPersistentReduction_CUDA for an example - dummy_outputs.emplace_back(add(buffer_replicate, buffer)); - ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate); - } + const auto& resolution_points = persistent_resolution_points.at(buffer_i); + + const auto persistent_use_of_buffer = + getPersistentUseOfBuffer(buffer, resolution_points); + + // For all uses that do not go towards the reduction operations in the + // persistent section of the graph, recompute the persistent buffer. + for (auto use : persistent_use_of_buffer) { + TORCH_INTERNAL_ASSERT(use->definition() != nullptr); + auto buffer_replicate = RecomputeTv::recompute(buffer); + // Create a shortcut buffer <--> buffer_replicate for propagation. + // Why is this needed? + // Consider that we have a fusion + // + // T0[I] + // T1[b b I] = broadcast(T0) + // T2[b b r] = reduction(T1) + // T3[b b b] = broadcast(T2) + // T4[b, b, I] = T1 + T3 + // T5[b, b, r] = reduction(T4) + // + // After projection, it becomes + // + // T0[I] + // T1[b b I] = broadcast(T0) + // T2[b b r] = reduction(T1) + // T3[b b b] = broadcast(T2) + // T6[b b I] = broadcast(T0) + // T4[b, b, I] = T6 + T3 + // T5[b, b, r] = reduction(T4) + // + // During schedule, we need to propagate from T2 to T5. However, in the + // resulting DAG, neither the propagation path T2->T3->T4->T5 nor + // T2->T1->T0->T6->T4->T5 works because they both have missing root + // domain. But adding `T7 = T1 + T6` creates a new propagation path + // `T2->T1->T7->T6->T4->T5` which has all root domain information. + // See FusionBroadcastPersistentReduction_CUDA for an example + dummy_outputs.emplace_back(add(buffer_replicate, buffer)); + ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate); } } return dummy_outputs; @@ -670,16 +672,17 @@ std::unordered_set fixUpInvalidPersistentBuffers(Fusion* fusion) { continue; } - for (auto resolution_point : persistent_resolution_points[buffer_i]) { - const auto persistent_use_of_buffer = - getPersistentUseOfBuffer(buffer, resolution_point); - for (auto use : persistent_use_of_buffer) { - TORCH_INTERNAL_ASSERT(use->definition() != nullptr); - auto buffer_replicate = RecomputeTv::recompute(buffer); - ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate); - recomputed_tvs.insert(buffer); - recompute_done = true; - } + const auto& resolution_points = persistent_resolution_points.at(buffer_i); + + const auto persistent_use_of_buffer = + getPersistentUseOfBuffer(buffer, resolution_points); + + for (auto use : persistent_use_of_buffer) { + TORCH_INTERNAL_ASSERT(use->definition() != nullptr); + auto buffer_replicate = RecomputeTv::recompute(buffer); + ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate); + recomputed_tvs.insert(buffer); + recompute_done = true; } }