diff --git a/third_party/nvfuser/csrc/scheduler/normalization.cpp b/third_party/nvfuser/csrc/scheduler/normalization.cpp index 4a992365552a..2682e3f9195e 100644 --- a/third_party/nvfuser/csrc/scheduler/normalization.cpp +++ b/third_party/nvfuser/csrc/scheduler/normalization.cpp @@ -1130,7 +1130,8 @@ void schedulePersistentKernel(Fusion* fusion, const ReductionParams& rparams) { std::vector dummy_outputs; if (rparams.project_persistent_buffers && ir_utils::getViewOps(fusion).empty()) { - dummy_outputs = reduction_scheduler_utils::projectPersistentBuffers(fusion); + dummy_outputs = + normalization_scheduler_utils::projectPersistentBuffers(fusion); } // Cache tensors before grabbing any references to reductions as cache_before @@ -1213,6 +1214,8 @@ void schedulePersistentKernel(Fusion* fusion, const ReductionParams& rparams) { cached_outputs, dummy_outputs); + normalization_scheduler_utils::fixUpInvalidPersistentBuffers(fusion); + if (rparams.compute_persistent_buffer_with_first_consumer) { TORCH_INTERNAL_ASSERT( rparams.persistent_kernel, diff --git a/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp b/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp index 479bb4a6e573..b9e3ff6f35b8 100644 --- a/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp +++ b/third_party/nvfuser/csrc/scheduler/normalization_utils.cpp @@ -1,9 +1,17 @@ +#include +#include +#include +#include +#include #include #include +#include #include #include +#include + namespace nvfuser { namespace normalization_scheduler_utils { @@ -494,5 +502,196 @@ std::optional getGridOuterNormalizationParams( return std::nullopt; } +namespace { + +// Go through the resolution points one by one. Resolution points are points +// in which the reduction branch meets the residual branch. 
These are points +// where the persistent buffer may no longer be needed (one point could be +// after another, and the buffer would be needed until the last resolution +// points) +std::vector getPersistentUseOfBuffer( + TensorView* buffer, + const std::vector& resolution_points) { + std::vector persistent_use_of_buffer; + + for (auto resolution_point : resolution_points) { + // Need to go through all paths from the persistent buffer to the + // resolution point + auto chains_to_resolution = + DependencyCheck::getAllDependencyChains(buffer, resolution_point); + for (auto chain : chains_to_resolution) { + auto tv_chain = ir_utils::filterByType(chain); + + // To move the persistent buffers to the inputs, we need to recompute + // the persistent buffer for all branches that don't go through a + // reduction. If there's a reduction on the current path between the + // persistent buffer and resolution, continue, there's no need to + // replicate this use. + if (std::any_of(tv_chain.begin(), tv_chain.end(), [](TensorView* tv) { + return tv->hasReduction(); + })) { + continue; + } + + // Grab use of the buffer, chain[0] is the persistent buffer, chain[1] + // is its first use. + auto use = chain[1]; + + // Only grab unique uses, a persistent buffer could be used multiple + // times in the same expression. 
+ if (std::find( + persistent_use_of_buffer.begin(), + persistent_use_of_buffer.end(), + use) != persistent_use_of_buffer.end()) { + continue; + } + persistent_use_of_buffer.emplace_back(use->as()); + } + } + + return persistent_use_of_buffer; +} + +} // namespace + +std::vector projectPersistentBuffers(Fusion* fusion) { + auto persistent_info = scheduler_utils::persistentBuffers(fusion); + std::vector dummy_outputs; + + // Convenience accessors + const auto& persistent_buffers = persistent_info.persistent_buffers; + const auto& persistent_resolution_points = + persistent_info.persistent_buffer_resolution_points; + const auto& projected_buffers = + persistent_info.projectable_persistent_buffers; + + TORCH_INTERNAL_ASSERT( + persistent_buffers.size() == persistent_resolution_points.size()); + + // Iterate through projected buffers, tracking which index it corresponds to + // since there's a resolution point entry for every buffer. + for (auto buffer_i : c10::irange(persistent_buffers.size())) { + auto buffer = persistent_buffers[buffer_i]; + if (std::find(projected_buffers.begin(), projected_buffers.end(), buffer) == + projected_buffers.end()) { + continue; + } + + const auto& resolution_points = persistent_resolution_points.at(buffer_i); + + const auto persistent_use_of_buffer = + getPersistentUseOfBuffer(buffer, resolution_points); + + // For all uses that do not go towards the reduction operations in the + // persistent section of the graph, recompute the persistent buffer. + for (auto use : persistent_use_of_buffer) { + TORCH_INTERNAL_ASSERT(use->definition() != nullptr); + auto buffer_replicate = RecomputeTv::recompute(buffer); + // Create a shortcut buffer <--> buffer_replicate for propagation. + // Why is this needed? 
+ // Consider that we have a fusion + // + // T0[I] + // T1[b b I] = broadcast(T0) + // T2[b b r] = reduction(T1) + // T3[b b b] = broadcast(T2) + // T4[b, b, I] = T1 + T3 + // T5[b, b, r] = reduction(T4) + // + // After projection, it becomes + // + // T0[I] + // T1[b b I] = broadcast(T0) + // T2[b b r] = reduction(T1) + // T3[b b b] = broadcast(T2) + // T6[b b I] = broadcast(T0) + // T4[b, b, I] = T6 + T3 + // T5[b, b, r] = reduction(T4) + // + // During schedule, we need to propagate from T2 to T5. However, in the + // resulting DAG, neither the propagation path T2->T3->T4->T5 nor + // T2->T1->T0->T6->T4->T5 works because they both have missing root + // domain. But adding `T7 = T1 + T6` creates a new propagation path + // `T2->T1->T7->T6->T4->T5` which has all root domain information. + // See FusionBroadcastPersistentReduction_CUDA for an example + dummy_outputs.emplace_back(add(buffer_replicate, buffer)); + ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate); + } + } + return dummy_outputs; +} + +std::unordered_set fixUpInvalidPersistentBuffers(Fusion* fusion) { + auto persistent_info = scheduler_utils::persistentBuffers(fusion); + const auto& persistent_buffers = persistent_info.persistent_buffers; + const auto& persistent_resolution_points = + persistent_info.persistent_buffer_resolution_points; + + std::unique_ptr concretize_info; + + std::unordered_set recomputed_tvs; + + bool recompute_done = false; + + for (auto buffer_i : c10::irange(persistent_buffers.size())) { + auto buffer = persistent_buffers.at(buffer_i); + + // Check if this buffer needs to be recomputed + bool need_recomputation = false; + + for (const auto i : c10::irange( + buffer->getComputeAtPosition(), + buffer->domain()->domain().size())) { + auto axis = buffer->axis(i); + + // Unresolved data dependency could exist if: + // - Parallelized by tidx and stored on Local + // - Parallelized by bidx and stored on Local or Shared + if ((axis->isThreadDim() && + 
buffer->getMemoryType() == MemoryType::Local) || + (axis->isBlockDim() && + (buffer->getMemoryType() == MemoryType::Local || + buffer->getMemoryType() == MemoryType::Shared))) { + if (!concretize_info) { + concretize_info = + std::make_unique(fusion); + } + + // concretize_info.isConcretized(axis) means the axis is a + // concretized broadcast domain + if (!concretize_info->isConcretized(axis)) { + continue; + } + + need_recomputation = true; + break; + } + } + + if (!need_recomputation) { + continue; + } + + const auto& resolution_points = persistent_resolution_points.at(buffer_i); + + const auto persistent_use_of_buffer = + getPersistentUseOfBuffer(buffer, resolution_points); + + for (auto use : persistent_use_of_buffer) { + TORCH_INTERNAL_ASSERT(use->definition() != nullptr); + auto buffer_replicate = RecomputeTv::recompute(buffer); + ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate); + recomputed_tvs.insert(buffer); + recompute_done = true; + } + } + + if (recompute_done) { + inlineMost(); + } + + return recomputed_tvs; +} + } // namespace normalization_scheduler_utils } // namespace nvfuser diff --git a/third_party/nvfuser/csrc/scheduler/normalization_utils.h b/third_party/nvfuser/csrc/scheduler/normalization_utils.h index ff353de41a34..0393b25d0f7c 100644 --- a/third_party/nvfuser/csrc/scheduler/normalization_utils.h +++ b/third_party/nvfuser/csrc/scheduler/normalization_utils.h @@ -8,6 +8,9 @@ #include namespace nvfuser { + +class Fusion; + namespace normalization_scheduler_utils { //! Utility class to iterate candidates of launch configurations in a @@ -145,5 +148,24 @@ std::optional getGridOuterNormalizationParams( int64_t vectorize_factor, int64_t persistent_buffer_size); + +// Take all projectable persistent buffers, and move them to the inputs. This +// function creates dummy outputs which should be used in later stages of the +// scheduling. +TORCH_CUDA_CU_API std::vector projectPersistentBuffers( + Fusion* fusion); + +//! 
Persistent buffers are effectively tensors that cannot be inlined +//! due to the reduction and broadcast pattern. Since we store +//! persistent buffers in registers, they have to be parallelized in +//! such a way that no data dependency exists between threads. Buffers +//! that do have dependencies cannot be persistent. This function +//! detects such buffers and makes them non-persistent by +//! recomputation. Returns buffers that are made non-persistent. +//! +//! Alternatively, such buffers could be stored on shared memory if +//! the dependency only exists between threads in the same thread +//! block. Not considered yet. +std::unordered_set fixUpInvalidPersistentBuffers(Fusion* fusion); + } // namespace normalization_scheduler_utils } // namespace nvfuser diff --git a/third_party/nvfuser/csrc/scheduler/reduction_utils.cpp b/third_party/nvfuser/csrc/scheduler/reduction_utils.cpp index 4b32898fd78d..22d84e02a5a2 100644 --- a/third_party/nvfuser/csrc/scheduler/reduction_utils.cpp +++ b/third_party/nvfuser/csrc/scheduler/reduction_utils.cpp @@ -656,111 +656,5 @@ TensorView* sortAndRFactor(TensorView* reference_tv) { return ir_utils::rfactorHelper(reference_tv, rfactor_axes); } -std::vector projectPersistentBuffers(Fusion* fusion) { - auto persistent_info = scheduler_utils::persistentBuffers(fusion); - std::vector dummy_outputs; - - // Convenience accessors - const auto& persistent_buffers = persistent_info.persistent_buffers; - const auto& persistent_resolution_points = - persistent_info.persistent_buffer_resolution_points; - const auto& projected_buffers = - persistent_info.projectable_persistent_buffers; - - TORCH_INTERNAL_ASSERT( - persistent_buffers.size() == persistent_resolution_points.size()); - - // Iterate through projected buffers, tracking which index it corresponds too - // since there's a resolution point entry for every buffer. 
- for (auto buffer_i : c10::irange(persistent_buffers.size())) { - auto buffer = persistent_buffers[buffer_i]; - if (std::find(projected_buffers.begin(), projected_buffers.end(), buffer) == - projected_buffers.end()) { - continue; - } - - auto resolution_points = persistent_resolution_points[buffer_i]; - - std::vector persistent_use_of_buffer; - - // Go through the resolution points one by one. Resolution points are points - // in which the reduction branch meets the residual branch. These are points - // where the persitent buffer may no longer be needed (one point could be - // after another, and the buffer would be needed until the last resolution - // points) - for (auto resolution_point : resolution_points) { - // Need to go through all paths from the persistent buffer to the - // resolution point - auto chains_to_resolution = - DependencyCheck::getAllDependencyChains(buffer, resolution_point); - for (auto chain : chains_to_resolution) { - auto tv_chain = ir_utils::filterByType(chain); - - // To move the persistent buffers to the inputs, we need to recompute - // the persistent buffer for all branches that don't go through a - // reduction. If there's a reduction on the current path between the - // persistent buffer and resolution, continue, there's no need to - // replicate this use. - if (std::any_of(tv_chain.begin(), tv_chain.end(), [](TensorView* tv) { - return tv->hasReduction(); - })) { - continue; - } - - // Grab use of the buffer, chain[0] is the persistent buffer, chain[1] - // is its first use. - auto use = chain[1]; - - // Only grab unique uses, a persistent buffer could be used multiple - // times in the same expression. 
- if (std::find( - persistent_use_of_buffer.begin(), - persistent_use_of_buffer.end(), - use) != persistent_use_of_buffer.end()) { - continue; - } - persistent_use_of_buffer.emplace_back(use); - } - } - - // For all uses that do not go towards the reduction operations in the - // persistent section of the graph, recompute the persistent buffer. - for (auto use : persistent_use_of_buffer) { - TORCH_INTERNAL_ASSERT(use->definition() != nullptr); - auto buffer_replicate = RecomputeTv::recompute(buffer); - // Create a shortcut buffer <--> buffer_replicate for propagation. - // Why is this needed? - // Consider that we have a fusion - // - // T0[I] - // T1[b b I] = broadcast(T0) - // T2[b b r] = reduction(T1) - // T3[b b b] = broadcast(T2) - // T4[b, b, I] = T1 + T3 - // T5[b, b, r] = reduction(T4) - // - // After projection, it becomes - // - // T0[I] - // T1[b b I] = broadcast(T0) - // T2[b b r] = reduction(T1) - // T3[b b b] = broadcast(T2) - // T6[b b I] = broadcast(T0) - // T4[b, b, I] = T6 + T3 - // T5[b, b, r] = reduction(T4) - // - // During schedule, we need to propagate from T2 to T5. However, in the - // resulting DAG, neither the propagation path T2->T3->T4->T5 nor - // T2->T1->T0->T6->T4->T5 works because they both have missing root - // domain. But adding `T7 = T1 + T6` creates a new propagation path - // `T2->T1->T7->T6->T4->T5` which has all root domain information. 
- // See FusionBroadcastPersistentReduction_CUDA for an example - dummy_outputs.emplace_back(add(buffer_replicate, buffer)); - ir_utils::replaceValInExpr(use->definition(), buffer, buffer_replicate); - } - } - return dummy_outputs; -} - } // namespace reduction_scheduler_utils } // namespace nvfuser diff --git a/third_party/nvfuser/csrc/scheduler/reduction_utils.h b/third_party/nvfuser/csrc/scheduler/reduction_utils.h index e05ada8a618b..adb6548a1974 100644 --- a/third_party/nvfuser/csrc/scheduler/reduction_utils.h +++ b/third_party/nvfuser/csrc/scheduler/reduction_utils.h @@ -41,11 +41,5 @@ TORCH_CUDA_CU_API void multiReductionInliner( // Reduction inliner expects an rfactored domain. TORCH_CUDA_CU_API TensorView* sortAndRFactor(TensorView* reference_tv); -// Take all projectable persistent buffers, and move them to the inputs. This -// function create dummy outputs which should be used in later stages of the -// scheduling. -TORCH_CUDA_CU_API std::vector projectPersistentBuffers( - Fusion* fusion); - } // namespace reduction_scheduler_utils } // namespace nvfuser diff --git a/third_party/nvfuser/test/test_gpu2.cpp b/third_party/nvfuser/test/test_gpu2.cpp index ab905d91d39e..6f790dc039c8 100644 --- a/third_party/nvfuser/test/test_gpu2.cpp +++ b/third_party/nvfuser/test/test_gpu2.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -9614,7 +9615,7 @@ TEST_F(NVFuserTest, FusionPersistentBufferProjection_CUDA) { fusion.addOutput(tv9); - reduction_scheduler_utils::projectPersistentBuffers(&fusion); + normalization_scheduler_utils::projectPersistentBuffers(&fusion); auto tv5_producers = ir_utils::producerTvsOf(tv5); auto tv7_producers = ir_utils::producerTvsOf(tv7); diff --git a/third_party/nvfuser/test/test_gpu3.cpp b/third_party/nvfuser/test/test_gpu3.cpp index 0f40b9441390..1ba2dbbfe369 100644 --- a/third_party/nvfuser/test/test_gpu3.cpp +++ b/third_party/nvfuser/test/test_gpu3.cpp @@ -25,6 +25,7 @@ #include #include #include 
+#include #include #include #include @@ -7881,6 +7882,81 @@ TEST_F(NVFuserTest, FusionCompileIndexType_CUDA) { } } +// Repro of #2559 +TEST_F(NVFuserTest, FusionNonMatchingPersistentBuffer_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1, DataType::Half); + fusion.addInput(tv0); + auto tv1 = makeSymbolicTensor(2, DataType::Half); + fusion.addInput(tv1); + + auto tv2 = castOp(DataType::Float, tv0); + auto tv3 = castOp(DataType::Float, tv1); + + auto tv4 = broadcast(tv2, {false, true}); + auto tv5 = add(tv4, tv3); + + auto tv6 = sum(tv5, {0, 1}); + auto tv7 = broadcast(tv6, {true, true}); + + auto tv8 = add(tv3, tv7); + auto tv9 = add(tv4, tv7); + auto tv10 = add(tv8, tv9); + + fusion.addOutput(tv10); + + const int tidx = 256; + const int pb = 8; + const int bidx = 2; + + tv6->merge(0); + // tidx + tv6->split(0, tidx); + // persistent buffer + tv6->split(0, pb); + // [BIDx, PB, TIDX] + + MaxRootDomainInfoSpanningTree tree(tv6); + TransformPropagator tp(tv6); + tree.traverse(&tp); + + auto tv6_rf = tv6->rFactor({1}); + + inlineMost(); + + tv6_rf->axis(0)->parallelize(ParallelType::BIDx); + tv6_rf->axis(2)->parallelize(ParallelType::TIDx); + scheduler_utils::parallelizeAllLike(tv6_rf); + + normalization_scheduler_utils::fixUpInvalidPersistentBuffers(&fusion); + + std::vector shape({bidx, tidx * pb}); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::manual_seed(0); + at::Tensor t0 = at::randn({shape[0]}, options); + at::Tensor t1 = at::randn(shape, options); + std::vector inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto cg_outputs = fe.runFusion(inputs); + + auto t2 = t0.to(at::kDouble); + auto t3 = t1.to(at::kDouble); + auto t4 = t2.unsqueeze(-1); + auto t5 = t4 + t3; + auto t6 = sum(t5); + auto t7 = t6.unsqueeze(-1).unsqueeze(-1); + auto t8 = t3 + t7; + auto t9 = t4 + t7; + auto t10 = t8 + t9; + + testValidate(fe.kernel(), cg_outputs, inputs, {t10}, __LINE__, 
__FILE__); +} + // Test file size should be up to 10K LoC. Create a new file for more tests. } // namespace nvfuser