From 67721286fc1f46d4ef12df52003ed71cd7151505 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 16 May 2023 16:38:45 -0700 Subject: [PATCH 01/81] initial code shelve --- CMakeLists.txt | 6 +- csrc/optimization/opt_pass.cpp | 64 +++++++++++++++++++ csrc/optimization/opt_pass.h | 28 ++++++++ .../optimize_consecutive_cast.cpp | 27 ++++++++ 4 files changed, 120 insertions(+), 5 deletions(-) create mode 100644 csrc/optimization/opt_pass.cpp create mode 100644 csrc/optimization/opt_pass.h create mode 100644 csrc/optimization/optimize_consecutive_cast.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 40d6a9aabc4..23ac054eb61 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -41,11 +41,7 @@ else() # PROJECT_IS_TOP_LEVEL return() endif() - if(NOT USE_ROCM) - set(TORCHLIB_FLAVOR torch_cuda) - else() - set(TORCHLIB_FLAVOR torch_hip) - endif() + set(TORCHLIB_FLAVOR torch_cuda) # TODO: have TORCH_ROOT setup as a variable instead # currently we are expecting nvfuser to be added from the pytorch root cmake file. diff --git a/csrc/optimization/opt_pass.cpp b/csrc/optimization/opt_pass.cpp new file mode 100644 index 00000000000..06cddc42fad --- /dev/null +++ b/csrc/optimization/opt_pass.cpp @@ -0,0 +1,64 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +namespace nvfuser::optimization { + +namespace { + +class OptimizationRegistry { + public: + struct PassEntry { + int priority_; + FusionPass pass_; + std::string name_; + }; + + void register(const OptimizationPassCategory& cat, FusionPass func, std::string name_, int priority) { + + std::guard guard(mutex_); + auto& pass_entry_list = pass_categories_[cat]; + entry_iter = pass_entry_list.begin(); + while (entry_iter != pass_entry_list.end()) { + if (entry_iter->priority_ < priority) { + break; + } + } + pass_entry_list.emplace(entry_iter, priority, std::move(func), std::move(name_)); + } + + void apply(const OptimizationPassCategory& cat, Fusion* fusion) { + std::guard guard(mutex_); + const auto& pass_entry_list = pass_categories_[cat]; + for (const auto& entry : pass_entry_list) { + entry.pass_(fusion); + } + } + + static OptimizationRegistry& getInstance() { + static OptimizationRegistry registry; + return registry; + } + + protected: + // TODO: read access mutex_ should/could be optimized, since graph pass is thread-safe. + std::mutex mutex_; + std::unordered_map> pass_categories_; +}; + +} // namespace + +void registerOptimizationPass(const OptimizationPassCategory& category, OptimizationPass pass, int priority) { + OptimizationRegistry::getInstance().register(category, pass.func(), pass.name(), priority); +} + +void applyOptimizationPass(const OptimizationPassCategory& category, Fusion* fusion) { + OptimizationRegistry::getInstance().apply(category, fusion); +} + +} // namespace nvfuser::optimization diff --git a/csrc/optimization/opt_pass.h b/csrc/optimization/opt_pass.h new file mode 100644 index 00000000000..415908b8fb7 --- /dev/null +++ b/csrc/optimization/opt_pass.h @@ -0,0 +1,28 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +namespace nvfuser::optimization { + +TORCH_CUDA_CU_API enum class OptimizationPassCategory { PreSegmenter, Null }; +using FusionPass = std::function>; + +class OptimizationPass { + public: + FusionPass func() = 0; + std::string name() = 0; +}; + +// higher priority pass runs earlier +// newly registered pass runs at the end of all passes with identical priority +TORCH_CUDA_CU_API void registerOptimizationPass(const OptimizationPassCategory& category, OptimizationPass& pass, int priority = 0); +TORCH_CUDA_CU_API void applyOptimizationPass(const OptimizationPassCategory& category, Fusion* fusion); + +} // namespace nvfuser::optimization diff --git a/csrc/optimization/optimize_consecutive_cast.cpp b/csrc/optimization/optimize_consecutive_cast.cpp new file mode 100644 index 00000000000..5cef81fb8bd --- /dev/null +++ b/csrc/optimization/optimize_consecutive_cast.cpp @@ -0,0 +1,27 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +namespace nvfuser::optimization { + +namespace { + +class ConsecutiveCastPass : OptimizationPass { + public: + static void runPass(Fusion* fusion) { + } + std::string name() { return "ConsecutiveCastOptimization"; } + + ConsecutiveCastPass() { + registerOptimizationPass(OptimizationPassCategory::PreSegmenter, *this); + } +}; + +static Register register; + +} // namespace nvfuser::optimization From 3bb0edbdea58bfda77ad65f8929256494fbf4084 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 16 May 2023 21:21:20 -0700 Subject: [PATCH 02/81] wip add build files --- CMakeLists.txt | 2 ++ csrc/kernel_cache.cpp | 4 ++++ csrc/optimization/optimize_consecutive_cast.cpp | 2 ++ 3 files changed, 8 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 23ac054eb61..88ada643ca1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -180,6 +180,8 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/utils.cpp ${NVFUSER_SRCS_DIR}/mma_type.cpp ${NVFUSER_SRCS_DIR}/scheduler/mma_utils.cpp + ${NVFUSER_SRCS_DIR}/optimization/opt_pass.cpp + ${NVFUSER_SRCS_DIR}/optimization/optimize_consecutive_cast.cpp ) set(NVFUSER_CODEGEN ${PROJECT_NAME}_codegen) diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index a12348ef9a4..fe96887c483 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -14,6 +14,8 @@ #include #include #include +#include + #include #include @@ -472,6 +474,8 @@ FusionKernelRuntime::FusionKernelRuntime( !fusion->hasDynamicTransform(), "Fusion must be concretized before constructing FusionKernelRuntime"); + applyOptimizationPass(optimization::OptimizationPassCategory::PreSegmenter, fusion.get()); + all_tvs_ = ir_utils::allTvs(fusion.get()); // Run segmentation on the copied fusion diff --git a/csrc/optimization/optimize_consecutive_cast.cpp b/csrc/optimization/optimize_consecutive_cast.cpp index 5cef81fb8bd..34be0b7e9ec 100644 --- a/csrc/optimization/optimize_consecutive_cast.cpp +++ b/csrc/optimization/optimize_consecutive_cast.cpp @@ -14,6 +14,8 @@ namespace { class ConsecutiveCastPass : OptimizationPass { public: static void runPass(Fusion* fusion) { + std::cout << "running optimization pass on fusion: " << std::endl; + fusion->printMath(); } std::string name() { return "ConsecutiveCastOptimization"; } From 28b488ab0da6f08338cea92df62f4315b8669df5 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: 
Tue, 16 May 2023 22:00:43 -0700 Subject: [PATCH 03/81] fixing build --- csrc/optimization/opt_pass.cpp | 17 ++++++++++------- csrc/optimization/opt_pass.h | 10 +++++----- csrc/optimization/optimize_consecutive_cast.cpp | 9 ++++++--- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/csrc/optimization/opt_pass.cpp b/csrc/optimization/opt_pass.cpp index 06cddc42fad..fefed5e4571 100644 --- a/csrc/optimization/opt_pass.cpp +++ b/csrc/optimization/opt_pass.cpp @@ -7,6 +7,8 @@ // clang-format on #include +#include + namespace nvfuser::optimization { namespace { @@ -17,23 +19,24 @@ class OptimizationRegistry { int priority_; FusionPass pass_; std::string name_; + PassEntry(int priority, FusionPass pass, std::string name) : priority_(priority), pass_(std::move(pass)), name_(std::move(name_)) {} }; - void register(const OptimizationPassCategory& cat, FusionPass func, std::string name_, int priority) { - - std::guard guard(mutex_); + void registerPass(const OptimizationPassCategory& cat, FusionPass func, std::string name_, int priority) { + std::lock_guard guard(mutex_); auto& pass_entry_list = pass_categories_[cat]; - entry_iter = pass_entry_list.begin(); + auto entry_iter = pass_entry_list.begin(); while (entry_iter != pass_entry_list.end()) { if (entry_iter->priority_ < priority) { break; } + entry_iter++; } pass_entry_list.emplace(entry_iter, priority, std::move(func), std::move(name_)); } void apply(const OptimizationPassCategory& cat, Fusion* fusion) { - std::guard guard(mutex_); + std::lock_guard guard(mutex_); const auto& pass_entry_list = pass_categories_[cat]; for (const auto& entry : pass_entry_list) { entry.pass_(fusion); @@ -53,8 +56,8 @@ class OptimizationRegistry { } // namespace -void registerOptimizationPass(const OptimizationPassCategory& category, OptimizationPass pass, int priority) { - OptimizationRegistry::getInstance().register(category, pass.func(), pass.name(), priority); +void registerOptimizationPass(const OptimizationPassCategory& category, OptimizationPass* pass, int priority) { + OptimizationRegistry::getInstance().registerPass(category, pass->func(), pass->name(), priority); } void applyOptimizationPass(const OptimizationPassCategory& category, Fusion* fusion) { diff --git a/csrc/optimization/opt_pass.h b/csrc/optimization/opt_pass.h index 415908b8fb7..c62d62648ae 100644 --- a/csrc/optimization/opt_pass.h +++ b/csrc/optimization/opt_pass.h @@ -11,18 +11,18 @@ namespace nvfuser::optimization { -TORCH_CUDA_CU_API enum class OptimizationPassCategory { PreSegmenter, Null }; -using FusionPass = std::function>; +enum class TORCH_CUDA_CU_API OptimizationPassCategory { PreSegmenter, Null }; +using FusionPass = std::function; class OptimizationPass { public: - FusionPass func() = 0; - std::string name() = 0; + virtual FusionPass func() = 0; + virtual std::string name() = 0; }; // higher priority pass runs earlier // newly registered pass runs at the end of all passes with identical priority -TORCH_CUDA_CU_API void registerOptimizationPass(const OptimizationPassCategory& category, OptimizationPass& pass, int priority = 0); +TORCH_CUDA_CU_API void registerOptimizationPass(const OptimizationPassCategory& category, OptimizationPass* pass, int priority = 0); TORCH_CUDA_CU_API void applyOptimizationPass(const OptimizationPassCategory& category, Fusion* fusion); } // namespace nvfuser::optimization diff --git a/csrc/optimization/optimize_consecutive_cast.cpp b/csrc/optimization/optimize_consecutive_cast.cpp index 34be0b7e9ec..76b98a2bfe2 100644 --- 
a/csrc/optimization/optimize_consecutive_cast.cpp +++ b/csrc/optimization/optimize_consecutive_cast.cpp @@ -17,13 +17,16 @@ class ConsecutiveCastPass : OptimizationPass { std::cout << "running optimization pass on fusion: " << std::endl; fusion->printMath(); } - std::string name() { return "ConsecutiveCastOptimization"; } + std::string name() override { return "ConsecutiveCastOptimization"; } + FusionPass func() override { return runPass; } ConsecutiveCastPass() { - registerOptimizationPass(OptimizationPassCategory::PreSegmenter, *this); + registerOptimizationPass(OptimizationPassCategory::PreSegmenter, this); } }; -static Register register; +static ConsecutiveCastPass register_; + +} } // namespace nvfuser::optimization From 6675a7e79c17466a9453178135a1c8e9c70f39cf Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 16 May 2023 23:01:34 -0700 Subject: [PATCH 04/81] adding optimization; adding test --- .../optimize_consecutive_cast.cpp | 31 ++++++++++++- test/test_gpu3.cpp | 45 +++++++++++++++++++ 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/csrc/optimization/optimize_consecutive_cast.cpp b/csrc/optimization/optimize_consecutive_cast.cpp index 76b98a2bfe2..25c7ec0e43e 100644 --- a/csrc/optimization/optimize_consecutive_cast.cpp +++ b/csrc/optimization/optimize_consecutive_cast.cpp @@ -6,6 +6,7 @@ */ // clang-format on #include +#include namespace nvfuser::optimization { @@ -14,7 +15,35 @@ namespace { class ConsecutiveCastPass : OptimizationPass { public: static void runPass(Fusion* fusion) { - std::cout << "running optimization pass on fusion: " << std::endl; + auto is_cast_op = [] (Expr* expr) { + if (expr->isA()) { + auto op = expr->as(); + if (op->getUnaryOpType() == UnaryOpType::Cast) { + return true; + } + } + return false; + }; + + std::cout << "original fusion:" << std::endl; + fusion->printMath(); + + // NOTE: not the most efficient pass + for (auto expr : fusion->exprs()) { + if (is_cast_op(expr)) { + while (true) { + // in the loop, we just repetitively skip consecutive casts. 
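          // e.g. with T1 = castOp(Half, T0); T2 = castOp(Float, T1); the
          // second cast gets rewired to consume T0 directly, leaving the
          // intermediate cast producing T1 as dead code.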
+ auto intermeidate_cast = expr->input(0); + auto prev_expr = intermeidate_cast->definition(); + if (prev_expr!=nullptr && is_cast_op(prev_expr)) { + replaceValInExpr(expr, intermediate_cast, prev_expr->input(0)); + } else { + break; + } + } + } + } + std::cout << "after mutation fusion:" << std::endl; fusion->printMath(); } std::string name() override { return "ConsecutiveCastOptimization"; } diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 9966a26c036..ce75e811f10 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -8434,6 +8434,51 @@ TEST_F(NVFuserTest, FusionTestSegmenterHint_CUDA) { executor_cache.fusion(), outputs, {at_x}, {ref_out}, __LINE__, __FILE__); } +// Test cast optimization +TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + std::vector input_shape{32, 64, 8, 128}; + auto tv0 = TensorViewBuilder() + .ndims(input_shape.size()) + .dtype(DataType::Double) + .build(); + fusion->addInput(tv0); + auto tv1 = castOp(DataType::Half, tv0); + auto tv2 = castOp(DataType::Float, tv1); + auto tv3 = relu(tv2); + auto tv4 = neg(tv3); + auto tv5 = castOp(DataType::Double, tv4); + fusion->addOutput(tv5); + + auto options = at::TensorOptions().dtype(at::kDouble).device(at::kCUDA, 0); + at::Tensor at_x = at::randn(input_shape, options); + FusionExecutorCache executor_cache(std::move(fusion)); + auto outputs = executor_cache.runFusionWithInputs({at_x}); + auto ref_out = at_x.clone().relu().neg(); + + auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); + + // TORCH_CHECK(optimized_fusion->isSegmented(), "segmentation didn't happen"); + // auto groups = optimized_fusion->fusionSegments()->groups(); + // TORCH_CHECK( + // groups.size() == 2, "segmentation hint isn't working as expected"); + // with the hint, segment_set should be grouped with its producer + // [relu, segment_set], [neg] + // for (auto& group : groups) { + // // we only check the group with a single node + // if (group->exprs().size() == 1) { + // auto relu_expr = group->exprs()[0]; + // TORCH_CHECK( + // relu_expr->isA() && + // relu_expr->as()->getUnaryOpType() == UnaryOpType::Neg, + // "segmentation result is not expected"); + // } + // } + testValidate( + executor_cache.fusion(), outputs, {at_x}, {ref_out}, __LINE__, __FILE__); +} + // Test file size should be up to 10K LoC. Create a new file for more tests. } // namespace nvfuser From ec9827af55d9842662db14befe2cc3521563c5e9 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 16 May 2023 23:34:14 -0700 Subject: [PATCH 05/81] fixing tests --- csrc/optimization/optimize_consecutive_cast.cpp | 8 ++++---- test/test_gpu3.cpp | 3 +-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/csrc/optimization/optimize_consecutive_cast.cpp b/csrc/optimization/optimize_consecutive_cast.cpp index 25c7ec0e43e..a4b349d59b6 100644 --- a/csrc/optimization/optimize_consecutive_cast.cpp +++ b/csrc/optimization/optimize_consecutive_cast.cpp @@ -6,7 +6,7 @@ */ // clang-format on #include -#include +#include namespace nvfuser::optimization { @@ -33,10 +33,10 @@ class ConsecutiveCastPass : OptimizationPass { if (is_cast_op(expr)) { while (true) { // in the loop, we just repetitively skip consecutive casts. 
- auto intermeidate_cast = expr->input(0); - auto prev_expr = intermeidate_cast->definition(); + auto intermediate_cast = expr->input(0); + auto prev_expr = intermediate_cast->definition(); if (prev_expr!=nullptr && is_cast_op(prev_expr)) { - replaceValInExpr(expr, intermediate_cast, prev_expr->input(0)); + expr = nvfuser::ir_utils::replaceValInExpr(expr, intermediate_cast, prev_expr->input(0)); } else { break; } diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index ce75e811f10..4374bd64075 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -8457,8 +8457,7 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { auto outputs = executor_cache.runFusionWithInputs({at_x}); auto ref_out = at_x.clone().relu().neg(); - auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); - + // auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); // TORCH_CHECK(optimized_fusion->isSegmented(), "segmentation didn't happen"); // auto groups = optimized_fusion->fusionSegments()->groups(); // TORCH_CHECK( From 623ae4fad671c591688a707b4bacbce59b136cd3 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 16 May 2023 23:54:08 -0700 Subject: [PATCH 06/81] fixing tests --- test/test_gpu3.cpp | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 4374bd64075..6ccdebdd317 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -8444,7 +8444,7 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { .dtype(DataType::Double) .build(); fusion->addInput(tv0); - auto tv1 = castOp(DataType::Half, tv0); + auto tv1 = castOp(DataType::Half, tv0); // consecutive cast should be removed auto tv2 = castOp(DataType::Float, tv1); auto tv3 = relu(tv2); auto tv4 = neg(tv3); @@ -8457,23 +8457,16 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { auto outputs = executor_cache.runFusionWithInputs({at_x}); auto ref_out = at_x.clone().relu().neg(); - // auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); - // TORCH_CHECK(optimized_fusion->isSegmented(), "segmentation didn't happen"); - // auto groups = optimized_fusion->fusionSegments()->groups(); - // TORCH_CHECK( - // groups.size() == 2, "segmentation hint isn't working as expected"); - // with the hint, segment_set should be grouped with its producer - // [relu, segment_set], [neg] - // for (auto& group : groups) { - // // we only check the group with a single node - // if (group->exprs().size() == 1) { - // auto relu_expr = group->exprs()[0]; - // TORCH_CHECK( - // relu_expr->isA() && - // relu_expr->as()->getUnaryOpType() == UnaryOpType::Neg, - // "segmentation result is not expected"); - // } - // } + auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); + auto complete_fusion = optimized_fusion->fusionSegments()->completeFusion(); + int cast_op_count = 0; + for (auto expr : complete_fusion->exprs()) { + if (expr->isA() && expr->as()->getUnaryOpType() == UnaryOpType::Cast) { + ++cast_op_count; + } + } + TORCH_CHECK(cast_op_count == 2, "cast optimization isn't working as expected"); + testValidate( executor_cache.fusion(), outputs, {at_x}, {ref_out}, __LINE__, __FILE__); } From 2c4d4c196c6db015229aa65c5045d143e33f8c59 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 16 May 2023 23:55:04 -0700 Subject: [PATCH 07/81] remove debug print --- csrc/optimization/optimize_consecutive_cast.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/csrc/optimization/optimize_consecutive_cast.cpp 
b/csrc/optimization/optimize_consecutive_cast.cpp index a4b349d59b6..7acb1513929 100644 --- a/csrc/optimization/optimize_consecutive_cast.cpp +++ b/csrc/optimization/optimize_consecutive_cast.cpp @@ -25,9 +25,6 @@ class ConsecutiveCastPass : OptimizationPass { return false; }; - std::cout << "original fusion:" << std::endl; - fusion->printMath(); - // NOTE: not the most efficient pass for (auto expr : fusion->exprs()) { if (is_cast_op(expr)) { @@ -43,8 +40,6 @@ class ConsecutiveCastPass : OptimizationPass { } } } - std::cout << "after mutation fusion:" << std::endl; - fusion->printMath(); } std::string name() override { return "ConsecutiveCastOptimization"; } FusionPass func() override { return runPass; } From b2d9ed8df6386c5aee197067e5cc408c4f4e3179 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 17 May 2023 11:23:56 -0700 Subject: [PATCH 08/81] clangformat --- csrc/kernel_cache.cpp | 5 +-- csrc/optimization/opt_pass.cpp | 32 ++++++++++++++----- csrc/optimization/opt_pass.h | 9 ++++-- .../optimize_consecutive_cast.cpp | 21 +++++++----- test/test_gpu3.cpp | 6 ++-- 5 files changed, 51 insertions(+), 22 deletions(-) diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index 2280ab0bc17..ffebfb1e433 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -11,10 +11,10 @@ #include #include #include +#include #include #include #include -#include #include #include @@ -473,7 +473,8 @@ FusionKernelRuntime::FusionKernelRuntime( !fusion->hasDynamicTransform(), "Fusion must be concretized before constructing FusionKernelRuntime"); - applyOptimizationPass(optimization::OptimizationPassCategory::PreSegmenter, fusion.get()); + applyOptimizationPass( + optimization::OptimizationPassCategory::PreSegmenter, fusion.get()); all_tvs_ = ir_utils::allTvs(fusion.get()); diff --git a/csrc/optimization/opt_pass.cpp b/csrc/optimization/opt_pass.cpp index fefed5e4571..91347e700a1 100644 --- a/csrc/optimization/opt_pass.cpp +++ b/csrc/optimization/opt_pass.cpp @@ -19,10 +19,17 @@ class OptimizationRegistry { int priority_; FusionPass pass_; std::string name_; - PassEntry(int priority, FusionPass pass, std::string name) : priority_(priority), pass_(std::move(pass)), name_(std::move(name_)) {} + PassEntry(int priority, FusionPass pass, std::string name) + : priority_(priority), + pass_(std::move(pass)), + name_(std::move(name_)) {} }; - void registerPass(const OptimizationPassCategory& cat, FusionPass func, std::string name_, int priority) { + void registerPass( + const OptimizationPassCategory& cat, + FusionPass func, + std::string name_, + int priority) { std::lock_guard guard(mutex_); auto& pass_entry_list = pass_categories_[cat]; auto entry_iter = pass_entry_list.begin(); @@ -32,7 +39,8 @@ class OptimizationRegistry { } entry_iter++; } - pass_entry_list.emplace(entry_iter, priority, std::move(func), std::move(name_)); + pass_entry_list.emplace( + entry_iter, priority, std::move(func), std::move(name_)); } void apply(const OptimizationPassCategory& cat, Fusion* fusion) { @@ -49,18 +57,26 @@ class OptimizationRegistry { } protected: - // TODO: read access mutex_ should/could be optimized, since graph pass is thread-safe. + // TODO: read access mutex_ should/could be optimized, since graph pass is + // thread-safe. 
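  // (for instance, a std::shared_mutex taking shared locks in apply() and an
  // exclusive lock in registerPass() would let concurrent compilations read
  // the registry in parallel; noted only as a possible follow-up.)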
std::mutex mutex_; - std::unordered_map> pass_categories_; + std::unordered_map> + pass_categories_; }; } // namespace -void registerOptimizationPass(const OptimizationPassCategory& category, OptimizationPass* pass, int priority) { - OptimizationRegistry::getInstance().registerPass(category, pass->func(), pass->name(), priority); +void registerOptimizationPass( + const OptimizationPassCategory& category, + OptimizationPass* pass, + int priority) { + OptimizationRegistry::getInstance().registerPass( + category, pass->func(), pass->name(), priority); } -void applyOptimizationPass(const OptimizationPassCategory& category, Fusion* fusion) { +void applyOptimizationPass( + const OptimizationPassCategory& category, + Fusion* fusion) { OptimizationRegistry::getInstance().apply(category, fusion); } diff --git a/csrc/optimization/opt_pass.h b/csrc/optimization/opt_pass.h index c62d62648ae..a02aa86773c 100644 --- a/csrc/optimization/opt_pass.h +++ b/csrc/optimization/opt_pass.h @@ -22,7 +22,12 @@ class OptimizationPass { // higher priority pass runs earlier // newly registered pass runs at the end of all passes with identical priority -TORCH_CUDA_CU_API void registerOptimizationPass(const OptimizationPassCategory& category, OptimizationPass* pass, int priority = 0); -TORCH_CUDA_CU_API void applyOptimizationPass(const OptimizationPassCategory& category, Fusion* fusion); +TORCH_CUDA_CU_API void registerOptimizationPass( + const OptimizationPassCategory& category, + OptimizationPass* pass, + int priority = 0); +TORCH_CUDA_CU_API void applyOptimizationPass( + const OptimizationPassCategory& category, + Fusion* fusion); } // namespace nvfuser::optimization diff --git a/csrc/optimization/optimize_consecutive_cast.cpp b/csrc/optimization/optimize_consecutive_cast.cpp index 7acb1513929..9a21f93d844 100644 --- a/csrc/optimization/optimize_consecutive_cast.cpp +++ b/csrc/optimization/optimize_consecutive_cast.cpp @@ -5,8 +5,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#include #include +#include namespace nvfuser::optimization { @@ -15,7 +15,7 @@ namespace { class ConsecutiveCastPass : OptimizationPass { public: static void runPass(Fusion* fusion) { - auto is_cast_op = [] (Expr* expr) { + auto is_cast_op = [](Expr* expr) { if (expr->isA()) { auto op = expr->as(); if (op->getUnaryOpType() == UnaryOpType::Cast) { @@ -32,8 +32,9 @@ class ConsecutiveCastPass : OptimizationPass { // in the loop, we just repetitively skip consecutive casts. 
auto intermediate_cast = expr->input(0); auto prev_expr = intermediate_cast->definition(); - if (prev_expr!=nullptr && is_cast_op(prev_expr)) { - expr = nvfuser::ir_utils::replaceValInExpr(expr, intermediate_cast, prev_expr->input(0)); + if (prev_expr != nullptr && is_cast_op(prev_expr)) { + expr = nvfuser::ir_utils::replaceValInExpr( + expr, intermediate_cast, prev_expr->input(0)); } else { break; } @@ -41,16 +42,20 @@ class ConsecutiveCastPass : OptimizationPass { } } } - std::string name() override { return "ConsecutiveCastOptimization"; } - FusionPass func() override { return runPass; } + std::string name() override { + return "ConsecutiveCastOptimization"; + } + FusionPass func() override { + return runPass; + } ConsecutiveCastPass() { - registerOptimizationPass(OptimizationPassCategory::PreSegmenter, this); + registerOptimizationPass(OptimizationPassCategory::PreSegmenter, this); } }; static ConsecutiveCastPass register_; -} +} // namespace } // namespace nvfuser::optimization diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 6ccdebdd317..b5f37a7be01 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -8461,11 +8461,13 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { auto complete_fusion = optimized_fusion->fusionSegments()->completeFusion(); int cast_op_count = 0; for (auto expr : complete_fusion->exprs()) { - if (expr->isA() && expr->as()->getUnaryOpType() == UnaryOpType::Cast) { + if (expr->isA() && + expr->as()->getUnaryOpType() == UnaryOpType::Cast) { ++cast_op_count; } } - TORCH_CHECK(cast_op_count == 2, "cast optimization isn't working as expected"); + TORCH_CHECK( + cast_op_count == 2, "cast optimization isn't working as expected"); testValidate( executor_cache.fusion(), outputs, {at_x}, {ref_out}, __LINE__, __FILE__); From 72da69c74178d588399f7122fb1370dbfacd39e1 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 17 May 2023 13:21:31 -0700 Subject: [PATCH 09/81] short-cut to skip trivial casting --- csrc/optimization/optimize_consecutive_cast.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/csrc/optimization/optimize_consecutive_cast.cpp b/csrc/optimization/optimize_consecutive_cast.cpp index 9a21f93d844..ee87e252030 100644 --- a/csrc/optimization/optimize_consecutive_cast.cpp +++ b/csrc/optimization/optimize_consecutive_cast.cpp @@ -26,8 +26,10 @@ class ConsecutiveCastPass : OptimizationPass { }; // NOTE: not the most efficient pass + std::unordered_map replacement_map; for (auto expr : fusion->exprs()) { if (is_cast_op(expr)) { + bool mutated = false; while (true) { // in the loop, we just repetitively skip consecutive casts. 
auto intermediate_cast = expr->input(0); @@ -35,16 +37,29 @@ class ConsecutiveCastPass : OptimizationPass { if (prev_expr != nullptr && is_cast_op(prev_expr)) { expr = nvfuser::ir_utils::replaceValInExpr( expr, intermediate_cast, prev_expr->input(0)); + mutated = true; } else { break; } } + + if (mutated) { + // quick short-wire to skip current cast node if it's trivially casting to the same type + if (expr->input(0)->getDataType().value() == expr->output(0)->getDataType().value()) { + replacement_map[expr->output(0)] = expr->input(0); + } + } } } + if (!replacement_map.empty()) { + nvfuser::ir_utils::replaceValue(fusion, replacement_map); + } } + std::string name() override { return "ConsecutiveCastOptimization"; } + FusionPass func() override { return runPass; } From b786160235d850d58b4f8b45229b0bbff29310e8 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 17 May 2023 16:47:06 -0700 Subject: [PATCH 10/81] fixing logic in cast operation. updating cpp tests --- .../optimize_consecutive_cast.cpp | 25 +++- test/test_gpu3.cpp | 128 ++++++++++++++---- 2 files changed, 121 insertions(+), 32 deletions(-) diff --git a/csrc/optimization/optimize_consecutive_cast.cpp b/csrc/optimization/optimize_consecutive_cast.cpp index ee87e252030..9e48f6c3fb1 100644 --- a/csrc/optimization/optimize_consecutive_cast.cpp +++ b/csrc/optimization/optimize_consecutive_cast.cpp @@ -35,9 +35,24 @@ class ConsecutiveCastPass : OptimizationPass { auto intermediate_cast = expr->input(0); auto prev_expr = intermediate_cast->definition(); if (prev_expr != nullptr && is_cast_op(prev_expr)) { - expr = nvfuser::ir_utils::replaceValInExpr( - expr, intermediate_cast, prev_expr->input(0)); - mutated = true; + auto original_dtype = prev_expr->input(0)->getDataType().value(); + auto intermediate_dtype = intermediate_cast->getDataType().value(); + auto out_dtype = expr->output(0)->getDataType().value(); + // cases where skipping the intermediate cast is relatively safe, either: + // 1. intermediate is the same type category; + // 2. intermediate is a floating point while output is integral; + // and we support direct cast from input dtype to output dtype. + if (cast_func_str({original_dtype, out_dtype}).has_value() && + ((isIntegralType(intermediate_dtype) && isIntegralType(out_dtype)) || + (isFloatingPointType(intermediate_dtype) && isFloatingPointType(out_dtype)) || + (isComplexType(intermediate_dtype) && isComplexType(out_dtype)) || + (isFloatingPointType(intermediate_dtype) && isIntegralType(out_dtype)))) { + expr = nvfuser::ir_utils::replaceValInExpr( + expr, intermediate_cast, prev_expr->input(0)); + mutated = true; + } else { + break; + } } else { break; } @@ -47,6 +62,10 @@ class ConsecutiveCastPass : OptimizationPass { // quick short-wire to skip current cast node if it's trivially casting to the same type if (expr->input(0)->getDataType().value() == expr->output(0)->getDataType().value()) { replacement_map[expr->output(0)] = expr->input(0); + // NOTE: if current output is a fusion output, DCE won't kick in and we'll ended up with an illegal cast. 
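          // To avoid that, the fusion output itself is rewired below to the
          // cast's input, instead of relying on the replacement map alone.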
+ if (expr->output(0)->isFusionOutput()) { + fusion->replaceOutput(expr->output(0), expr->input(0)); + } } } } diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index b5f37a7be01..0a4f6ee14cf 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -8436,41 +8436,111 @@ TEST_F(NVFuserTest, FusionTestSegmenterHint_CUDA) { // Test cast optimization TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - std::vector input_shape{32, 64, 8, 128}; - auto tv0 = TensorViewBuilder() - .ndims(input_shape.size()) - .dtype(DataType::Double) - .build(); - fusion->addInput(tv0); - auto tv1 = castOp(DataType::Half, tv0); // consecutive cast should be removed - auto tv2 = castOp(DataType::Float, tv1); - auto tv3 = relu(tv2); - auto tv4 = neg(tv3); - auto tv5 = castOp(DataType::Double, tv4); - fusion->addOutput(tv5); - + std::vector input_shape{3, 7, 8}; auto options = at::TensorOptions().dtype(at::kDouble).device(at::kCUDA, 0); at::Tensor at_x = at::randn(input_shape, options); - FusionExecutorCache executor_cache(std::move(fusion)); - auto outputs = executor_cache.runFusionWithInputs({at_x}); - auto ref_out = at_x.clone().relu().neg(); - auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); - auto complete_fusion = optimized_fusion->fusionSegments()->completeFusion(); - int cast_op_count = 0; - for (auto expr : complete_fusion->exprs()) { - if (expr->isA() && - expr->as()->getUnaryOpType() == UnaryOpType::Cast) { - ++cast_op_count; + { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + auto tv0 = TensorViewBuilder() + .ndims(input_shape.size()) + .dtype(DataType::Double) + .build(); + fusion->addInput(tv0); + auto tv1 = castOp(DataType::Half, tv0); // consecutive cast should be removed + auto tv2 = castOp(DataType::Float, tv1); + auto tv3 = relu(tv2); + auto tv4 = neg(tv3); + auto tv5 = castOp(DataType::Double, tv4); + fusion->addOutput(tv5); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto outputs = executor_cache.runFusionWithInputs({at_x}); + auto ref_out = at_x.clone().relu().neg(); + + auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); + auto complete_fusion = optimized_fusion->fusionSegments()->completeFusion(); + int cast_op_count = 0; + for (auto expr : complete_fusion->exprs()) { + if (expr->isA() && + expr->as()->getUnaryOpType() == UnaryOpType::Cast) { + ++cast_op_count; + } } + TORCH_CHECK( + cast_op_count == 2, "cast optimization isn't working as expected"); + + testValidate( + executor_cache.fusion(), outputs, {at_x}, {ref_out}, __LINE__, __FILE__); } - TORCH_CHECK( - cast_op_count == 2, "cast optimization isn't working as expected"); - testValidate( - executor_cache.fusion(), outputs, {at_x}, {ref_out}, __LINE__, __FILE__); + { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + auto tv0 = TensorViewBuilder() + .ndims(input_shape.size()) + .dtype(DataType::Double) + .build(); + fusion->addInput(tv0); + auto tv1 = castOp(DataType::Int, tv0); + // previous cast cannot be optimized away due to precision + auto tv2 = castOp(DataType::Float, tv1); + auto tv3 = neg(tv2); + fusion->addOutput(tv3); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto outputs = executor_cache.runFusionWithInputs({at_x}); + auto ref_out = at_x.clone().int().float().neg(); + + auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); + auto complete_fusion = optimized_fusion->fusionSegments()->completeFusion(); + int cast_op_count = 
0; + for (auto expr : complete_fusion->exprs()) { + if (expr->isA() && + expr->as()->getUnaryOpType() == UnaryOpType::Cast) { + ++cast_op_count; + } + } + TORCH_CHECK( + cast_op_count == 2, "cast optimization isn't working as expected"); + + testValidate( + executor_cache.fusion(), outputs, {at_x}, {ref_out}, __LINE__, __FILE__); + } + + { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + auto tv0 = TensorViewBuilder() + .ndims(input_shape.size()) + .dtype(DataType::Double) + .build(); + fusion->addInput(tv0); + auto tv1 = neg(tv0); + auto tv2 = castOp(DataType::Float, tv1); + auto tv3 = castOp(DataType::Double, tv2); + fusion->addOutput(tv3); + + FusionExecutorCache executor_cache(std::move(fusion)); + auto outputs = executor_cache.runFusionWithInputs({at_x}); + auto ref_out = at_x.clone().neg(); + + auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); + auto complete_fusion = optimized_fusion->fusionSegments()->completeFusion(); + int cast_op_count = 0; + for (auto expr : complete_fusion->exprs()) { + if (expr->isA() && + expr->as()->getUnaryOpType() == UnaryOpType::Cast) { + ++cast_op_count; + } + } + TORCH_CHECK( + cast_op_count == 0, "cast optimization isn't working as expected"); + + testValidate( + executor_cache.fusion(), outputs, {at_x}, {ref_out}, __LINE__, __FILE__); + } } // Test file size should be up to 10K LoC. Create a new file for more tests. From 816a5a0a19474566deb6f0ee8dfa8aa8b8d5455f Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 17 May 2023 17:09:37 -0700 Subject: [PATCH 11/81] fixing logic in safety check for cast optimization; fixing test --- csrc/optimization/optimize_consecutive_cast.cpp | 10 ++++++---- test/test_gpu3.cpp | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/csrc/optimization/optimize_consecutive_cast.cpp b/csrc/optimization/optimize_consecutive_cast.cpp index 9e48f6c3fb1..7515eeea22c 100644 --- a/csrc/optimization/optimize_consecutive_cast.cpp +++ b/csrc/optimization/optimize_consecutive_cast.cpp @@ -38,10 +38,12 @@ class ConsecutiveCastPass : OptimizationPass { auto original_dtype = prev_expr->input(0)->getDataType().value(); auto intermediate_dtype = intermediate_cast->getDataType().value(); auto out_dtype = expr->output(0)->getDataType().value(); - // cases where skipping the intermediate cast is relatively safe, either: - // 1. intermediate is the same type category; - // 2. intermediate is a floating point while output is integral; - // and we support direct cast from input dtype to output dtype. + // cases where skipping the intermediate cast is relatively safe, two conditions: + // 1. original_dtype is the same as out_dtype; or + // 2. we support direct cast from original_dtype to out_dtype. + // and + // 1. intermediate_dtype is the same type category as with out_dtype; or + // 2. 
intermediate_dtype is a floating point while output is integral; if (cast_func_str({original_dtype, out_dtype}).has_value() && ((isIntegralType(intermediate_dtype) && isIntegralType(out_dtype)) || (isFloatingPointType(intermediate_dtype) && isFloatingPointType(out_dtype)) || diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 0a4f6ee14cf..2853af651d2 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -8491,7 +8491,7 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { FusionExecutorCache executor_cache(std::move(fusion)); auto outputs = executor_cache.runFusionWithInputs({at_x}); - auto ref_out = at_x.clone().int().float().neg(); + auto ref_out = at_x.clone().to(at::kInt).to(at::kFloat).neg(); auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); auto complete_fusion = optimized_fusion->fusionSegments()->completeFusion(); From 875164fb5e546215d350225cf070080331063779 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 17 May 2023 17:47:20 -0700 Subject: [PATCH 12/81] a few knobs to switch optimization pass --- csrc/optimization/opt_pass.cpp | 29 ++++++++++++++++++++++++++++- csrc/optimization/opt_pass.h | 12 ++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/csrc/optimization/opt_pass.cpp b/csrc/optimization/opt_pass.cpp index 91347e700a1..e353f8465b6 100644 --- a/csrc/optimization/opt_pass.cpp +++ b/csrc/optimization/opt_pass.cpp @@ -13,6 +13,8 @@ namespace nvfuser::optimization { namespace { +thread_local std::unordered_set disabled_pass_flag; + class OptimizationRegistry { public: struct PassEntry { @@ -66,6 +68,14 @@ class OptimizationRegistry { } // namespace +OptimizationPassGuard::OptimizationPassGuard(const OptimizationPassCategory& category, bool enable) : cat_(category) { + prev_status_ = switchOptimizationPass(cat_, enable); +} + +OptimizationPassGuard::~OptimizationPassGuard() { + switchOptimizationPass(cat_, prev_status_); +} + void registerOptimizationPass( const OptimizationPassCategory& category, OptimizationPass* pass, @@ -77,7 +87,24 @@ void registerOptimizationPass( void applyOptimizationPass( const OptimizationPassCategory& category, Fusion* fusion) { - OptimizationRegistry::getInstance().apply(category, fusion); + if (disabled_pass_flag.count(category) != 0) { + OptimizationRegistry::getInstance().apply(category, fusion); + } +} + +bool switchOptimizationPass( + const OptimizationPassCategory& category, + std::optional enable) { + auto enabled = disabled_pass_flag.count(category) == 0; + + if (enable.has_value()) { + if (enable.value()) { + disabled_pass_flag.erase(category); + } else { + disabled_pass_flag.insert(category); + } + } + return enabled; } } // namespace nvfuser::optimization diff --git a/csrc/optimization/opt_pass.h b/csrc/optimization/opt_pass.h index a02aa86773c..016d058e381 100644 --- a/csrc/optimization/opt_pass.h +++ b/csrc/optimization/opt_pass.h @@ -20,6 +20,15 @@ class OptimizationPass { virtual std::string name() = 0; }; +class OptimizationPassGuard { + public: + OptimizationPassGuard(const OptimizationPassCategory& category, bool enable); + ~OptimizationPassGuard(); + protected: + OptimizationPassCategory cat_; + bool prev_status_; +}; + // higher priority pass runs earlier // newly registered pass runs at the end of all passes with identical priority TORCH_CUDA_CU_API void registerOptimizationPass( @@ -29,5 +38,8 @@ TORCH_CUDA_CU_API void registerOptimizationPass( TORCH_CUDA_CU_API void applyOptimizationPass( const OptimizationPassCategory& category, Fusion* fusion); 
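// A minimal usage sketch of these knobs, assuming the OptimizationPassGuard
// declared above (fusion_ptr/inputs stand in for an existing fusion and its
// arguments); the guard simply wraps the query-and-restore pattern:
//
//   {
//     // temporarily run without PreSegmenter passes
//     optimization::OptimizationPassGuard guard(
//         optimization::OptimizationPassCategory::PreSegmenter, false);
//     FusionExecutorCache fec(std::move(fusion_ptr));
//     auto outputs = fec.runFusionWithInputs(inputs);
//   } // the previous enable state is restored when the guard goes out of scope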
+TORCH_CUDA_CU_API bool switchOptimizationPass( + const OptimizationPassCategory& category, + std::optional enable); } // namespace nvfuser::optimization From aca1abd2c06e401daa14a2d5b51cc933b74098ed Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 17 May 2023 17:53:57 -0700 Subject: [PATCH 13/81] fixing logic --- csrc/optimization/optimize_consecutive_cast.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/optimization/optimize_consecutive_cast.cpp b/csrc/optimization/optimize_consecutive_cast.cpp index 7515eeea22c..a2c6b3dc22a 100644 --- a/csrc/optimization/optimize_consecutive_cast.cpp +++ b/csrc/optimization/optimize_consecutive_cast.cpp @@ -44,7 +44,7 @@ class ConsecutiveCastPass : OptimizationPass { // and // 1. intermediate_dtype is the same type category as with out_dtype; or // 2. intermediate_dtype is a floating point while output is integral; - if (cast_func_str({original_dtype, out_dtype}).has_value() && + if ((original_dtype == out_dtype || cast_func_str({original_dtype, out_dtype}).has_value()) && ((isIntegralType(intermediate_dtype) && isIntegralType(out_dtype)) || (isFloatingPointType(intermediate_dtype) && isFloatingPointType(out_dtype)) || (isComplexType(intermediate_dtype) && isComplexType(out_dtype)) || From de7dee71b3a0e274d45d5984f420d368ea151d62 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 17 May 2023 18:01:03 -0700 Subject: [PATCH 14/81] fixing logic in disabling flags --- csrc/optimization/opt_pass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/optimization/opt_pass.cpp b/csrc/optimization/opt_pass.cpp index e353f8465b6..fd30bf97f3b 100644 --- a/csrc/optimization/opt_pass.cpp +++ b/csrc/optimization/opt_pass.cpp @@ -87,7 +87,7 @@ void registerOptimizationPass( void applyOptimizationPass( const OptimizationPassCategory& category, Fusion* fusion) { - if (disabled_pass_flag.count(category) != 0) { + if (disabled_pass_flag.count(category) == 0) { OptimizationRegistry::getInstance().apply(category, fusion); } } From 87d55dfcce6fcaa123c58324997568a4cff7ee13 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 17 May 2023 18:07:55 -0700 Subject: [PATCH 15/81] patching tests --- test/test_gpu3.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 2853af651d2..087329e95e2 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -6076,6 +6076,8 @@ TEST_F(NVFuserTest, FusionBroadcastPersistentReduction_CUDA) { // Repro for // https://github.com/csarofeen/pytorch/issues/2094 TEST_F(NVFuserTest, FusionRepro2094_CUDA) { + // disable cast optimization, which causes numerical issue on tests + optimization::OptimizationPassGuard guard(optimization::OptimizationPassCategory::PreSegmenter, false); std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); From c5dd6ad79ad4447a719f2007194d3ddd6ef260cf Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 17 May 2023 18:11:29 -0700 Subject: [PATCH 16/81] fixing tests --- test/test_gpu3.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 087329e95e2..4e79c5d22d1 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -38,6 +38,7 @@ #include #include #include +#include #include #include From d258a37edc881e6beb663241ba54692ea092275e Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 17 May 2023 18:14:51 -0700 Subject: [PATCH 17/81] clangformat --- csrc/optimization/opt_pass.cpp | 5 +- csrc/optimization/opt_pass.h | 1 + 
.../optimize_consecutive_cast.cpp | 67 +++++++++++-------- test/test_gpu3.cpp | 29 ++++++-- 4 files changed, 67 insertions(+), 35 deletions(-) diff --git a/csrc/optimization/opt_pass.cpp b/csrc/optimization/opt_pass.cpp index fd30bf97f3b..e915c5815dd 100644 --- a/csrc/optimization/opt_pass.cpp +++ b/csrc/optimization/opt_pass.cpp @@ -68,7 +68,10 @@ class OptimizationRegistry { } // namespace -OptimizationPassGuard::OptimizationPassGuard(const OptimizationPassCategory& category, bool enable) : cat_(category) { +OptimizationPassGuard::OptimizationPassGuard( + const OptimizationPassCategory& category, + bool enable) + : cat_(category) { prev_status_ = switchOptimizationPass(cat_, enable); } diff --git a/csrc/optimization/opt_pass.h b/csrc/optimization/opt_pass.h index 016d058e381..ea113b4dc00 100644 --- a/csrc/optimization/opt_pass.h +++ b/csrc/optimization/opt_pass.h @@ -24,6 +24,7 @@ class OptimizationPassGuard { public: OptimizationPassGuard(const OptimizationPassCategory& category, bool enable); ~OptimizationPassGuard(); + protected: OptimizationPassCategory cat_; bool prev_status_; diff --git a/csrc/optimization/optimize_consecutive_cast.cpp b/csrc/optimization/optimize_consecutive_cast.cpp index a2c6b3dc22a..f35687eb14c 100644 --- a/csrc/optimization/optimize_consecutive_cast.cpp +++ b/csrc/optimization/optimize_consecutive_cast.cpp @@ -29,47 +29,58 @@ class ConsecutiveCastPass : OptimizationPass { std::unordered_map replacement_map; for (auto expr : fusion->exprs()) { if (is_cast_op(expr)) { - bool mutated = false; + bool mutated = false; while (true) { // in the loop, we just repetitively skip consecutive casts. auto intermediate_cast = expr->input(0); auto prev_expr = intermediate_cast->definition(); if (prev_expr != nullptr && is_cast_op(prev_expr)) { - auto original_dtype = prev_expr->input(0)->getDataType().value(); - auto intermediate_dtype = intermediate_cast->getDataType().value(); - auto out_dtype = expr->output(0)->getDataType().value(); - // cases where skipping the intermediate cast is relatively safe, two conditions: - // 1. original_dtype is the same as out_dtype; or - // 2. we support direct cast from original_dtype to out_dtype. - // and - // 1. intermediate_dtype is the same type category as with out_dtype; or - // 2. intermediate_dtype is a floating point while output is integral; - if ((original_dtype == out_dtype || cast_func_str({original_dtype, out_dtype}).has_value()) && - ((isIntegralType(intermediate_dtype) && isIntegralType(out_dtype)) || - (isFloatingPointType(intermediate_dtype) && isFloatingPointType(out_dtype)) || - (isComplexType(intermediate_dtype) && isComplexType(out_dtype)) || - (isFloatingPointType(intermediate_dtype) && isIntegralType(out_dtype)))) { + auto original_dtype = prev_expr->input(0)->getDataType().value(); + auto intermediate_dtype = intermediate_cast->getDataType().value(); + auto out_dtype = expr->output(0)->getDataType().value(); + // cases where skipping the intermediate cast is relatively safe, + // two conditions: + // 1. original_dtype is the same as out_dtype; or + // 2. we support direct cast from original_dtype to out_dtype. + // and + // 1. intermediate_dtype is the same type category as with + // out_dtype; or + // 2. 
intermediate_dtype is a floating point while output is + // integral; + if ((original_dtype == out_dtype || + cast_func_str({original_dtype, out_dtype}).has_value()) && + ((isIntegralType(intermediate_dtype) && + isIntegralType(out_dtype)) || + (isFloatingPointType(intermediate_dtype) && + isFloatingPointType(out_dtype)) || + (isComplexType(intermediate_dtype) && + isComplexType(out_dtype)) || + (isFloatingPointType(intermediate_dtype) && + isIntegralType(out_dtype)))) { expr = nvfuser::ir_utils::replaceValInExpr( expr, intermediate_cast, prev_expr->input(0)); - mutated = true; - } else { - break; - } + mutated = true; + } else { + break; + } } else { break; } } - if (mutated) { - // quick short-wire to skip current cast node if it's trivially casting to the same type - if (expr->input(0)->getDataType().value() == expr->output(0)->getDataType().value()) { + if (mutated) { + // quick short-wire to skip current cast node if it's trivially + // casting to the same type + if (expr->input(0)->getDataType().value() == + expr->output(0)->getDataType().value()) { replacement_map[expr->output(0)] = expr->input(0); - // NOTE: if current output is a fusion output, DCE won't kick in and we'll ended up with an illegal cast. - if (expr->output(0)->isFusionOutput()) { - fusion->replaceOutput(expr->output(0), expr->input(0)); - } - } - } + // NOTE: if current output is a fusion output, DCE won't kick in and + // we'll ended up with an illegal cast. + if (expr->output(0)->isFusionOutput()) { + fusion->replaceOutput(expr->output(0), expr->input(0)); + } + } + } } } if (!replacement_map.empty()) { diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index de94bb059d5..1d39920925f 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -38,7 +39,6 @@ #include #include #include -#include #include #include @@ -6071,7 +6071,8 @@ TEST_F(NVFuserTest, FusionBroadcastPersistentReduction_CUDA) { // https://github.com/csarofeen/pytorch/issues/2094 TEST_F(NVFuserTest, FusionRepro2094_CUDA) { // disable cast optimization, which causes numerical issue on tests - optimization::OptimizationPassGuard guard(optimization::OptimizationPassCategory::PreSegmenter, false); + optimization::OptimizationPassGuard guard( + optimization::OptimizationPassCategory::PreSegmenter, false); std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -8444,7 +8445,8 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { .dtype(DataType::Double) .build(); fusion->addInput(tv0); - auto tv1 = castOp(DataType::Half, tv0); // consecutive cast should be removed + auto tv1 = + castOp(DataType::Half, tv0); // consecutive cast should be removed auto tv2 = castOp(DataType::Float, tv1); auto tv3 = relu(tv2); auto tv4 = neg(tv3); @@ -8468,7 +8470,12 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { cast_op_count == 2, "cast optimization isn't working as expected"); testValidate( - executor_cache.fusion(), outputs, {at_x}, {ref_out}, __LINE__, __FILE__); + executor_cache.fusion(), + outputs, + {at_x}, + {ref_out}, + __LINE__, + __FILE__); } { @@ -8502,7 +8509,12 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { cast_op_count == 2, "cast optimization isn't working as expected"); testValidate( - executor_cache.fusion(), outputs, {at_x}, {ref_out}, __LINE__, __FILE__); + executor_cache.fusion(), + outputs, + {at_x}, + {ref_out}, + __LINE__, + __FILE__); } { @@ -8535,7 +8547,12 @@ TEST_F(NVFuserTest, 
FusionTestCastOptimization_CUDA) { cast_op_count == 0, "cast optimization isn't working as expected"); testValidate( - executor_cache.fusion(), outputs, {at_x}, {ref_out}, __LINE__, __FILE__); + executor_cache.fusion(), + outputs, + {at_x}, + {ref_out}, + __LINE__, + __FILE__); } } From faa269a7bd281c5b374ffb77fd02ee372500dc7a Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 17 May 2023 18:39:25 -0700 Subject: [PATCH 18/81] clangtidy --- csrc/optimization/opt_pass.cpp | 2 +- csrc/optimization/opt_pass.h | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/csrc/optimization/opt_pass.cpp b/csrc/optimization/opt_pass.cpp index e915c5815dd..1fc60541be2 100644 --- a/csrc/optimization/opt_pass.cpp +++ b/csrc/optimization/opt_pass.cpp @@ -24,7 +24,7 @@ class OptimizationRegistry { PassEntry(int priority, FusionPass pass, std::string name) : priority_(priority), pass_(std::move(pass)), - name_(std::move(name_)) {} + name_(std::move(name)) {} }; void registerPass( diff --git a/csrc/optimization/opt_pass.h b/csrc/optimization/opt_pass.h index ea113b4dc00..62406d3e6c0 100644 --- a/csrc/optimization/opt_pass.h +++ b/csrc/optimization/opt_pass.h @@ -18,6 +18,7 @@ class OptimizationPass { public: virtual FusionPass func() = 0; virtual std::string name() = 0; + virtual ~OptimizationPass() = default; }; class OptimizationPassGuard { @@ -27,7 +28,7 @@ class OptimizationPassGuard { protected: OptimizationPassCategory cat_; - bool prev_status_; + bool prev_status_ = false; }; // higher priority pass runs earlier @@ -41,6 +42,6 @@ TORCH_CUDA_CU_API void applyOptimizationPass( Fusion* fusion); TORCH_CUDA_CU_API bool switchOptimizationPass( const OptimizationPassCategory& category, - std::optional enable); + std::optional enable) noexcept; } // namespace nvfuser::optimization From b42bba436580574473e1ff1aead4f2ff95b959e4 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 17 May 2023 18:42:07 -0700 Subject: [PATCH 19/81] clangformat --- csrc/optimization/opt_pass.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/csrc/optimization/opt_pass.cpp b/csrc/optimization/opt_pass.cpp index 1fc60541be2..4d03483f224 100644 --- a/csrc/optimization/opt_pass.cpp +++ b/csrc/optimization/opt_pass.cpp @@ -22,9 +22,7 @@ class OptimizationRegistry { FusionPass pass_; std::string name_; PassEntry(int priority, FusionPass pass, std::string name) - : priority_(priority), - pass_(std::move(pass)), - name_(std::move(name)) {} + : priority_(priority), pass_(std::move(pass)), name_(std::move(name)) {} }; void registerPass( @@ -97,7 +95,7 @@ void applyOptimizationPass( bool switchOptimizationPass( const OptimizationPassCategory& category, - std::optional enable) { + std::optional enable) noexcept { auto enabled = disabled_pass_flag.count(category) == 0; if (enable.has_value()) { From 6949fef2f7d5aa0557dc3679b0ce8cba951bdd1f Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 17 May 2023 20:04:39 -0700 Subject: [PATCH 20/81] CLANGTIDY --- csrc/optimization/opt_pass.cpp | 34 +++++++++++++++++++++------------- csrc/optimization/opt_pass.h | 4 ++-- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/csrc/optimization/opt_pass.cpp b/csrc/optimization/opt_pass.cpp index 4d03483f224..7122b1ecbd7 100644 --- a/csrc/optimization/opt_pass.cpp +++ b/csrc/optimization/opt_pass.cpp @@ -13,7 +13,21 @@ namespace nvfuser::optimization { namespace { -thread_local std::unordered_set disabled_pass_flag; +// TODO: throw away toy flags. 
Will redo this at later stage +thread_local bool pre_segmenter_flag; + +void setOptimizationFlag(const OptimizationPassCategory& flag, bool enabled) { + if (flag == OptimizationPassCategory::PreSegmenter) { + pre_segmenter_flag = enabled.value(); + } +} + +bool getOptimizationFlag(const OptimizationPassCategory& flag) { + if (flag == OptimizationPassCategory::PreSegmenter) { + return pre_segmenter_flag; + } + return false; +} class OptimizationRegistry { public: @@ -69,12 +83,10 @@ class OptimizationRegistry { OptimizationPassGuard::OptimizationPassGuard( const OptimizationPassCategory& category, bool enable) - : cat_(category) { - prev_status_ = switchOptimizationPass(cat_, enable); -} + : cat_(category), prev_status_(switchOptimizationPass(cat_, enable)) {} OptimizationPassGuard::~OptimizationPassGuard() { - switchOptimizationPass(cat_, prev_status_); + setOptimizationFlag(cat_, prev_status_); } void registerOptimizationPass( @@ -88,22 +100,18 @@ void registerOptimizationPass( void applyOptimizationPass( const OptimizationPassCategory& category, Fusion* fusion) { - if (disabled_pass_flag.count(category) == 0) { + if (getOptimizationFlag(category)) { OptimizationRegistry::getInstance().apply(category, fusion); } } bool switchOptimizationPass( const OptimizationPassCategory& category, - std::optional enable) noexcept { - auto enabled = disabled_pass_flag.count(category) == 0; + std::optional enable) { + auto enabled = getOptimizationFlag(category); if (enable.has_value()) { - if (enable.value()) { - disabled_pass_flag.erase(category); - } else { - disabled_pass_flag.insert(category); - } + setOptimizationFlag(category, enable.value()); } return enabled; } diff --git a/csrc/optimization/opt_pass.h b/csrc/optimization/opt_pass.h index 62406d3e6c0..45962c6cac3 100644 --- a/csrc/optimization/opt_pass.h +++ b/csrc/optimization/opt_pass.h @@ -11,7 +11,7 @@ namespace nvfuser::optimization { -enum class TORCH_CUDA_CU_API OptimizationPassCategory { PreSegmenter, Null }; +enum class TORCH_CUDA_CU_API OptimizationPassCategory { PreSegmenter }; using FusionPass = std::function; class OptimizationPass { @@ -42,6 +42,6 @@ TORCH_CUDA_CU_API void applyOptimizationPass( Fusion* fusion); TORCH_CUDA_CU_API bool switchOptimizationPass( const OptimizationPassCategory& category, - std::optional enable) noexcept; + std::optional enable); } // namespace nvfuser::optimization From 47da287fec66b07e88255cada0c3acbd22f4b485 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 17 May 2023 20:06:24 -0700 Subject: [PATCH 21/81] typo --- csrc/optimization/opt_pass.cpp | 2 +- csrc/optimization/opt_pass.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/optimization/opt_pass.cpp b/csrc/optimization/opt_pass.cpp index 7122b1ecbd7..1ab4159294f 100644 --- a/csrc/optimization/opt_pass.cpp +++ b/csrc/optimization/opt_pass.cpp @@ -18,7 +18,7 @@ thread_local bool pre_segmenter_flag; void setOptimizationFlag(const OptimizationPassCategory& flag, bool enabled) { if (flag == OptimizationPassCategory::PreSegmenter) { - pre_segmenter_flag = enabled.value(); + pre_segmenter_flag = enabled; } } diff --git a/csrc/optimization/opt_pass.h b/csrc/optimization/opt_pass.h index 45962c6cac3..92eb800a16d 100644 --- a/csrc/optimization/opt_pass.h +++ b/csrc/optimization/opt_pass.h @@ -11,7 +11,7 @@ namespace nvfuser::optimization { -enum class TORCH_CUDA_CU_API OptimizationPassCategory { PreSegmenter }; +enum class TORCH_CUDA_CU_API OptimizationPassCategory { PreSegmenter, Null }; using FusionPass = std::function; class 
OptimizationPass { From f51481e776d167ded2b9eb558d32133e9be6db62 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 18 May 2023 02:04:51 -0700 Subject: [PATCH 22/81] initial value for optimization pass --- csrc/optimization/opt_pass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/optimization/opt_pass.cpp b/csrc/optimization/opt_pass.cpp index 1ab4159294f..196948bab19 100644 --- a/csrc/optimization/opt_pass.cpp +++ b/csrc/optimization/opt_pass.cpp @@ -14,7 +14,7 @@ namespace nvfuser::optimization { namespace { // TODO: throw away toy flags. Will redo this at later stage -thread_local bool pre_segmenter_flag; +thread_local bool pre_segmenter_flag = true; void setOptimizationFlag(const OptimizationPassCategory& flag, bool enabled) { if (flag == OptimizationPassCategory::PreSegmenter) { From 56bd373d0c20afd8e00b0398259812bd771f591f Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 18 May 2023 10:55:27 -0700 Subject: [PATCH 23/81] addressing review comments --- csrc/optimization/opt_pass.h | 30 ++++++++++++++++--- .../optimize_consecutive_cast.cpp | 14 +++++---- test/test_gpu3.cpp | 15 ++++++++++ 3 files changed, 49 insertions(+), 10 deletions(-) diff --git a/csrc/optimization/opt_pass.h b/csrc/optimization/opt_pass.h index 92eb800a16d..c9efe8e7685 100644 --- a/csrc/optimization/opt_pass.h +++ b/csrc/optimization/opt_pass.h @@ -11,17 +11,26 @@ namespace nvfuser::optimization { +//! [experimental API] +//! enum class to group optimization pass groups that runs at certain time in +//! the fusion execution. enum class TORCH_CUDA_CU_API OptimizationPassCategory { PreSegmenter, Null }; + using FusionPass = std::function; -class OptimizationPass { +//! [experimental API] +//! OptimizationPass is the base class to unify optimization pass APIs. +class TORCH_CUDA_CU_API OptimizationPass { public: virtual FusionPass func() = 0; virtual std::string name() = 0; virtual ~OptimizationPass() = default; }; -class OptimizationPassGuard { +//! [experimental API] +//! OptimizationPassGuard is used to temporarily switch enable/disable on a +//! certain pass. Original status will be restored at exit. +class TORCH_CUDA_CU_API OptimizationPassGuard { public: OptimizationPassGuard(const OptimizationPassCategory& category, bool enable); ~OptimizationPassGuard(); @@ -31,15 +40,28 @@ class OptimizationPassGuard { bool prev_status_ = false; }; -// higher priority pass runs earlier -// newly registered pass runs at the end of all passes with identical priority +//! [experimental API] +//! Register optimization pass with the `OptimizationPassCategroty`. +//! +//! all registered passes will run in order, where: +//! higher priority pass runs earlier; +//! newly registered pass runs at the end of all passes with identical priority. TORCH_CUDA_CU_API void registerOptimizationPass( const OptimizationPassCategory& category, OptimizationPass* pass, int priority = 0); + +//! [experimental API] +//! Run `category` group of optimization passes to `fusion`. TORCH_CUDA_CU_API void applyOptimizationPass( const OptimizationPassCategory& category, Fusion* fusion); + +//! [experimental API] +//! Switch the enable flag for a `category` group of optimization passes. +//! Returns the previous `enabled` status. Argument `std::optional enable` +//! is used to update the enable flag. An std::nullopt arg will leave the flag +//! unchanged. 
TORCH_CUDA_CU_API bool switchOptimizationPass( const OptimizationPassCategory& category, std::optional enable); diff --git a/csrc/optimization/optimize_consecutive_cast.cpp b/csrc/optimization/optimize_consecutive_cast.cpp index f35687eb14c..34b94059b8b 100644 --- a/csrc/optimization/optimize_consecutive_cast.cpp +++ b/csrc/optimization/optimize_consecutive_cast.cpp @@ -14,6 +14,11 @@ namespace { class ConsecutiveCastPass : OptimizationPass { public: + ConsecutiveCastPass() { + // registering ConsecutiveCastPass to PreSegmenter pass group + registerOptimizationPass(OptimizationPassCategory::PreSegmenter, this); + } + static void runPass(Fusion* fusion) { auto is_cast_op = [](Expr* expr) { if (expr->isA()) { @@ -45,8 +50,8 @@ class ConsecutiveCastPass : OptimizationPass { // and // 1. intermediate_dtype is the same type category as with // out_dtype; or - // 2. intermediate_dtype is a floating point while output is - // integral; + // 2. intermediate_dtype is more relaxed than out_dtype. e.g. a + // floating point vs. integral; if ((original_dtype == out_dtype || cast_func_str({original_dtype, out_dtype}).has_value()) && ((isIntegralType(intermediate_dtype) && @@ -95,12 +100,9 @@ class ConsecutiveCastPass : OptimizationPass { FusionPass func() override { return runPass; } - - ConsecutiveCastPass() { - registerOptimizationPass(OptimizationPassCategory::PreSegmenter, this); - } }; +// triggering the ConsecutiveCastPass constructor to register the pass static ConsecutiveCastPass register_; } // namespace diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 1d39920925f..f4bc5fcb4a8 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -8589,6 +8589,8 @@ TEST_F(NVFuserTest, FusionLayerNormFusedOpsRedundantCast_CUDA) { auto tv11 = castOp(DataType::Float, tv9); auto tv12 = castOp(DataType::Float, tv10); auto tv13 = add(tv11, tv12); + // The this pair of cast just cancels each other out, we'll simply rewire it + // to be use tv13 in places of tv15 in the follow up auto tv14 = castOp(DataType::Half, tv13); auto tv15 = castOp(DataType::Float, tv14); auto tv16 = variance(tv15, {1}, false, false); @@ -8653,6 +8655,19 @@ TEST_F(NVFuserTest, FusionLayerNormFusedOpsRedundantCast_CUDA) { FusionExecutorCache fec(std::move(fusion_ptr)); auto cg_outputs = fec.runFusionWithInputs(inputs); + + auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); + auto complete_fusion = optimized_fusion->fusionSegments()->completeFusion(); + int cast_op_count = 0; + for (auto expr : complete_fusion->exprs()) { + if (expr->isA() && + expr->as()->getUnaryOpType() == UnaryOpType::Cast) { + ++cast_op_count; + } + } + TORCH_CHECK( + cast_op_count == 9, "cast optimization isn't working as expected"); + testValidate(fusion, cg_outputs, inputs, outputs, __LINE__, __FILE__); } From 41b09c671cef37b68893a69f9858c9ccc77a5eeb Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 18 May 2023 15:40:57 -0700 Subject: [PATCH 24/81] patching tests --- python_tests/test_python_frontend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python_tests/test_python_frontend.py b/python_tests/test_python_frontend.py index 0d50d6631f9..f94c0a8c511 100644 --- a/python_tests/test_python_frontend.py +++ b/python_tests/test_python_frontend.py @@ -176,8 +176,8 @@ def fusion_func(fd: FusionDefinition): def test_cast_double_to_half(self): inputs = [ - torch.randn(2, 4, device="cuda", dtype=torch.float64), - torch.randn(2, 4, device="cuda", dtype=torch.float64), + torch.randn(2, 4, device="cuda", 
dtype=torch.float64).half().double(), + torch.randn(2, 4, device="cuda", dtype=torch.float64).half().double(), ] def fusion_func(fd: FusionDefinition): From 3e404d128cdd98c32cd94a35c43b41b0989d5b4f Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 18 May 2023 15:48:18 -0700 Subject: [PATCH 25/81] typo --- test/test_gpu3.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 44fcd75e028..8c212064f90 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -8656,7 +8656,7 @@ TEST_F(NVFuserTest, FusionLayerNormFusedOpsRedundantCast_CUDA) { FusionExecutorCache fec(std::move(fusion_ptr)); auto cg_outputs = fec.runFusionWithInputs(inputs); - auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); + auto optimized_fusion = fec.getMostRecentKernelRuntime(); auto complete_fusion = optimized_fusion->fusionSegments()->completeFusion(); int cast_op_count = 0; for (auto expr : complete_fusion->exprs()) { From e1c4e42baff1e400fcab9b3bf81c8e83cd2208c1 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 19 May 2023 13:03:42 -0700 Subject: [PATCH 26/81] renaming header --- csrc/kernel_cache.cpp | 2 +- csrc/optimization/optimize_consecutive_cast.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index 8aabc8ab113..4a41c16cba1 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include #include #include diff --git a/csrc/optimization/optimize_consecutive_cast.cpp b/csrc/optimization/optimize_consecutive_cast.cpp index 34b94059b8b..7305dae9b43 100644 --- a/csrc/optimization/optimize_consecutive_cast.cpp +++ b/csrc/optimization/optimize_consecutive_cast.cpp @@ -5,7 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#include +#include #include namespace nvfuser::optimization { From 206915d5a4c3e7acb0317e0c30df90892555ebc3 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 19 May 2023 13:19:16 -0700 Subject: [PATCH 27/81] header change --- csrc/optimization/opt_pass.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/optimization/opt_pass.h b/csrc/optimization/opt_pass.h index c9efe8e7685..f164cada9d0 100644 --- a/csrc/optimization/opt_pass.h +++ b/csrc/optimization/opt_pass.h @@ -7,7 +7,7 @@ // clang-format on #pragma once -#include +#include namespace nvfuser::optimization { From f1c1e6b3f3e19b0854ac37e9a2d65d7b1eabbde3 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Sun, 21 May 2023 13:00:36 -0700 Subject: [PATCH 28/81] refactor cast opt pass --- csrc/optimization/opt_pass.h | 2 +- .../optimize_consecutive_cast.cpp | 145 ++++++++---------- csrc/optimization/optimize_consecutive_cast.h | 18 +++ 3 files changed, 86 insertions(+), 79 deletions(-) create mode 100644 csrc/optimization/optimize_consecutive_cast.h diff --git a/csrc/optimization/opt_pass.h b/csrc/optimization/opt_pass.h index f164cada9d0..d4dda2a8522 100644 --- a/csrc/optimization/opt_pass.h +++ b/csrc/optimization/opt_pass.h @@ -22,7 +22,7 @@ using FusionPass = std::function; //! OptimizationPass is the base class to unify optimization pass APIs. 
class TORCH_CUDA_CU_API OptimizationPass { public: - virtual FusionPass func() = 0; + virtual void run(Fusion*) = 0; virtual std::string name() = 0; virtual ~OptimizationPass() = default; }; diff --git a/csrc/optimization/optimize_consecutive_cast.cpp b/csrc/optimization/optimize_consecutive_cast.cpp index 7305dae9b43..c70fab10dea 100644 --- a/csrc/optimization/optimize_consecutive_cast.cpp +++ b/csrc/optimization/optimize_consecutive_cast.cpp @@ -12,99 +12,88 @@ namespace nvfuser::optimization { namespace { -class ConsecutiveCastPass : OptimizationPass { - public: - ConsecutiveCastPass() { - // registering ConsecutiveCastPass to PreSegmenter pass group - registerOptimizationPass(OptimizationPassCategory::PreSegmenter, this); - } - - static void runPass(Fusion* fusion) { - auto is_cast_op = [](Expr* expr) { - if (expr->isA()) { - auto op = expr->as(); - if (op->getUnaryOpType() == UnaryOpType::Cast) { - return true; - } +void castOptimizationPass(Fusion* fusion) { + auto is_cast_op = [](Expr* expr) { + if (expr->isA()) { + auto op = expr->as(); + if (op->getUnaryOpType() == UnaryOpType::Cast) { + return true; } - return false; - }; + } + return false; + }; - // NOTE: not the most efficient pass - std::unordered_map replacement_map; - for (auto expr : fusion->exprs()) { - if (is_cast_op(expr)) { - bool mutated = false; - while (true) { - // in the loop, we just repetitively skip consecutive casts. - auto intermediate_cast = expr->input(0); - auto prev_expr = intermediate_cast->definition(); - if (prev_expr != nullptr && is_cast_op(prev_expr)) { - auto original_dtype = prev_expr->input(0)->getDataType().value(); - auto intermediate_dtype = intermediate_cast->getDataType().value(); - auto out_dtype = expr->output(0)->getDataType().value(); - // cases where skipping the intermediate cast is relatively safe, - // two conditions: - // 1. original_dtype is the same as out_dtype; or - // 2. we support direct cast from original_dtype to out_dtype. - // and - // 1. intermediate_dtype is the same type category as with - // out_dtype; or - // 2. intermediate_dtype is more relaxed than out_dtype. e.g. a - // floating point vs. integral; - if ((original_dtype == out_dtype || - cast_func_str({original_dtype, out_dtype}).has_value()) && - ((isIntegralType(intermediate_dtype) && - isIntegralType(out_dtype)) || - (isFloatingPointType(intermediate_dtype) && - isFloatingPointType(out_dtype)) || - (isComplexType(intermediate_dtype) && - isComplexType(out_dtype)) || - (isFloatingPointType(intermediate_dtype) && - isIntegralType(out_dtype)))) { - expr = nvfuser::ir_utils::replaceValInExpr( - expr, intermediate_cast, prev_expr->input(0)); - mutated = true; - } else { - break; - } + // NOTE: not the most efficient pass + std::unordered_map replacement_map; + for (auto expr : fusion->exprs()) { + if (is_cast_op(expr)) { + bool mutated = false; + while (true) { + // in the loop, we just repetitively skip consecutive casts. + auto intermediate_cast = expr->input(0); + auto prev_expr = intermediate_cast->definition(); + if (prev_expr != nullptr && is_cast_op(prev_expr)) { + auto original_dtype = prev_expr->input(0)->getDataType().value(); + auto intermediate_dtype = intermediate_cast->getDataType().value(); + auto out_dtype = expr->output(0)->getDataType().value(); + // cases where skipping the intermediate cast is relatively safe, + // two conditions: + // 1. original_dtype is the same as out_dtype; or + // 2. we support direct cast from original_dtype to out_dtype. + // and + // 1. 
intermediate_dtype is the same type category as with + // out_dtype; or + // 2. intermediate_dtype is more relaxed than out_dtype. e.g. a + // floating point vs. integral; + if ((original_dtype == out_dtype || + cast_func_str({original_dtype, out_dtype}).has_value()) && + ((isIntegralType(intermediate_dtype) && + isIntegralType(out_dtype)) || + (isFloatingPointType(intermediate_dtype) && + isFloatingPointType(out_dtype)) || + (isComplexType(intermediate_dtype) && + isComplexType(out_dtype)) || + (isFloatingPointType(intermediate_dtype) && + isIntegralType(out_dtype)))) { + expr = nvfuser::ir_utils::replaceValInExpr( + expr, intermediate_cast, prev_expr->input(0)); + mutated = true; } else { break; } + } else { + break; } + } - if (mutated) { - // quick short-wire to skip current cast node if it's trivially - // casting to the same type - if (expr->input(0)->getDataType().value() == - expr->output(0)->getDataType().value()) { - replacement_map[expr->output(0)] = expr->input(0); - // NOTE: if current output is a fusion output, DCE won't kick in and - // we'll ended up with an illegal cast. - if (expr->output(0)->isFusionOutput()) { - fusion->replaceOutput(expr->output(0), expr->input(0)); - } + if (mutated) { + // quick short-wire to skip current cast node if it's trivially + // casting to the same type + if (expr->input(0)->getDataType().value() == + expr->output(0)->getDataType().value()) { + replacement_map[expr->output(0)] = expr->input(0); + // NOTE: if current output is a fusion output, DCE won't kick in and + // we'll ended up with an illegal cast. + if (expr->output(0)->isFusionOutput()) { + fusion->replaceOutput(expr->output(0), expr->input(0)); } } } } - if (!replacement_map.empty()) { - nvfuser::ir_utils::replaceValue(fusion, replacement_map); - } } - - std::string name() override { - return "ConsecutiveCastOptimization"; + if (!replacement_map.empty()) { + nvfuser::ir_utils::replaceValue(fusion, replacement_map); } +} - FusionPass func() override { - return runPass; - } -}; +} // namespace -// triggering the ConsecutiveCastPass constructor to register the pass -static ConsecutiveCastPass register_; +void ConsecutiveCastPass::run(Fusion* fusion) override { + castOptimizationPass(fusion); +} -} // namespace +std::string ConsecutiveCastPass::name() override { + return "ConsecutiveCastOptimization"; +} } // namespace nvfuser::optimization diff --git a/csrc/optimization/optimize_consecutive_cast.h b/csrc/optimization/optimize_consecutive_cast.h new file mode 100644 index 00000000000..9a73e5d9fce --- /dev/null +++ b/csrc/optimization/optimize_consecutive_cast.h @@ -0,0 +1,18 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +namespace nvfuser::optimization { + +class TORCH_CUDA_CU_API ConsecutiveCastPass : OptimizationPass { + public: + void run(Fusion* fusion) override; + std::string name() override; +} + +} // namespace nvfuser::optimization From 44605e152fd75c816f38e286c6989d215c9648ea Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Sun, 21 May 2023 17:00:59 -0700 Subject: [PATCH 29/81] refactor registration --- CMakeLists.txt | 4 +- csrc/kernel_cache.cpp | 6 +- ...ive_cast.cpp => consecutive_cast_pass.cpp} | 2 +- ...ecutive_cast.h => consecutive_cast_pass.h} | 2 +- csrc/optimization/opt_pass.cpp | 119 ------------------ csrc/optimization/opt_pass.h | 69 ---------- csrc/optimization/optimization_pass.h | 68 ++++++++++ csrc/optimization/pre_segmenter.cpp | 18 +++ csrc/optimization/pre_segmenter.h | 19 +++ test/test_gpu3.cpp | 7 +- 10 files changed, 114 insertions(+), 200 deletions(-) rename csrc/optimization/{optimize_consecutive_cast.cpp => consecutive_cast_pass.cpp} (98%) rename csrc/optimization/{optimize_consecutive_cast.h => consecutive_cast_pass.h} (90%) delete mode 100644 csrc/optimization/opt_pass.cpp delete mode 100644 csrc/optimization/opt_pass.h create mode 100644 csrc/optimization/optimization_pass.h create mode 100644 csrc/optimization/pre_segmenter.cpp create mode 100644 csrc/optimization/pre_segmenter.h diff --git a/CMakeLists.txt b/CMakeLists.txt index d93ffd20955..391004226b0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -180,8 +180,8 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/utils.cpp ${NVFUSER_SRCS_DIR}/mma_type.cpp ${NVFUSER_SRCS_DIR}/scheduler/mma_utils.cpp - ${NVFUSER_SRCS_DIR}/optimization/opt_pass.cpp - ${NVFUSER_SRCS_DIR}/optimization/optimize_consecutive_cast.cpp + ${NVFUSER_SRCS_DIR}/optimization/consecutive_cast_pass.cpp + ${NVFUSER_SRCS_DIR}/optimization/pre_segmenter.cpp ) set(NVFUSER_CODEGEN ${PROJECT_NAME}_codegen) diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index 4a41c16cba1..d1460fde243 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -11,11 +11,10 @@ #include #include #include -#include +#include #include #include #include - #include #include @@ -622,8 +621,7 @@ FusionKernelRuntime::FusionKernelRuntime( !fusion->hasDynamicTransform(), "Fusion must be concretized before constructing FusionKernelRuntime"); - applyOptimizationPass( - optimization::OptimizationPassCategory::PreSegmenter, fusion.get()); + optimization::PreSegmenter.runPass(fusion.get()); all_tvs_ = ir_utils::allTvs(fusion.get()); diff --git a/csrc/optimization/optimize_consecutive_cast.cpp b/csrc/optimization/consecutive_cast_pass.cpp similarity index 98% rename from csrc/optimization/optimize_consecutive_cast.cpp rename to csrc/optimization/consecutive_cast_pass.cpp index c70fab10dea..61cf5150b8d 100644 --- a/csrc/optimization/optimize_consecutive_cast.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -6,7 +6,7 @@ */ // clang-format on #include -#include +#include namespace nvfuser::optimization { diff --git a/csrc/optimization/optimize_consecutive_cast.h b/csrc/optimization/consecutive_cast_pass.h similarity index 90% rename from csrc/optimization/optimize_consecutive_cast.h rename to csrc/optimization/consecutive_cast_pass.h index 9a73e5d9fce..ca9b7d31fab 100644 --- a/csrc/optimization/optimize_consecutive_cast.h +++ b/csrc/optimization/consecutive_cast_pass.h @@ -5,7 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#include +#include namespace 
nvfuser::optimization { diff --git a/csrc/optimization/opt_pass.cpp b/csrc/optimization/opt_pass.cpp deleted file mode 100644 index 196948bab19..00000000000 --- a/csrc/optimization/opt_pass.cpp +++ /dev/null @@ -1,119 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#include - -#include - -namespace nvfuser::optimization { - -namespace { - -// TODO: throw away toy flags. Will redo this at later stage -thread_local bool pre_segmenter_flag = true; - -void setOptimizationFlag(const OptimizationPassCategory& flag, bool enabled) { - if (flag == OptimizationPassCategory::PreSegmenter) { - pre_segmenter_flag = enabled; - } -} - -bool getOptimizationFlag(const OptimizationPassCategory& flag) { - if (flag == OptimizationPassCategory::PreSegmenter) { - return pre_segmenter_flag; - } - return false; -} - -class OptimizationRegistry { - public: - struct PassEntry { - int priority_; - FusionPass pass_; - std::string name_; - PassEntry(int priority, FusionPass pass, std::string name) - : priority_(priority), pass_(std::move(pass)), name_(std::move(name)) {} - }; - - void registerPass( - const OptimizationPassCategory& cat, - FusionPass func, - std::string name_, - int priority) { - std::lock_guard guard(mutex_); - auto& pass_entry_list = pass_categories_[cat]; - auto entry_iter = pass_entry_list.begin(); - while (entry_iter != pass_entry_list.end()) { - if (entry_iter->priority_ < priority) { - break; - } - entry_iter++; - } - pass_entry_list.emplace( - entry_iter, priority, std::move(func), std::move(name_)); - } - - void apply(const OptimizationPassCategory& cat, Fusion* fusion) { - std::lock_guard guard(mutex_); - const auto& pass_entry_list = pass_categories_[cat]; - for (const auto& entry : pass_entry_list) { - entry.pass_(fusion); - } - } - - static OptimizationRegistry& getInstance() { - static OptimizationRegistry registry; - return registry; - } - - protected: - // TODO: read access mutex_ should/could be optimized, since graph pass is - // thread-safe. - std::mutex mutex_; - std::unordered_map> - pass_categories_; -}; - -} // namespace - -OptimizationPassGuard::OptimizationPassGuard( - const OptimizationPassCategory& category, - bool enable) - : cat_(category), prev_status_(switchOptimizationPass(cat_, enable)) {} - -OptimizationPassGuard::~OptimizationPassGuard() { - setOptimizationFlag(cat_, prev_status_); -} - -void registerOptimizationPass( - const OptimizationPassCategory& category, - OptimizationPass* pass, - int priority) { - OptimizationRegistry::getInstance().registerPass( - category, pass->func(), pass->name(), priority); -} - -void applyOptimizationPass( - const OptimizationPassCategory& category, - Fusion* fusion) { - if (getOptimizationFlag(category)) { - OptimizationRegistry::getInstance().apply(category, fusion); - } -} - -bool switchOptimizationPass( - const OptimizationPassCategory& category, - std::optional enable) { - auto enabled = getOptimizationFlag(category); - - if (enable.has_value()) { - setOptimizationFlag(category, enable.value()); - } - return enabled; -} - -} // namespace nvfuser::optimization diff --git a/csrc/optimization/opt_pass.h b/csrc/optimization/opt_pass.h deleted file mode 100644 index d4dda2a8522..00000000000 --- a/csrc/optimization/opt_pass.h +++ /dev/null @@ -1,69 +0,0 @@ -// clang-format off -/* - * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. 
- * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - */ -// clang-format on -#pragma once - -#include - -namespace nvfuser::optimization { - -//! [experimental API] -//! enum class to group optimization pass groups that runs at certain time in -//! the fusion execution. -enum class TORCH_CUDA_CU_API OptimizationPassCategory { PreSegmenter, Null }; - -using FusionPass = std::function; - -//! [experimental API] -//! OptimizationPass is the base class to unify optimization pass APIs. -class TORCH_CUDA_CU_API OptimizationPass { - public: - virtual void run(Fusion*) = 0; - virtual std::string name() = 0; - virtual ~OptimizationPass() = default; -}; - -//! [experimental API] -//! OptimizationPassGuard is used to temporarily switch enable/disable on a -//! certain pass. Original status will be restored at exit. -class TORCH_CUDA_CU_API OptimizationPassGuard { - public: - OptimizationPassGuard(const OptimizationPassCategory& category, bool enable); - ~OptimizationPassGuard(); - - protected: - OptimizationPassCategory cat_; - bool prev_status_ = false; -}; - -//! [experimental API] -//! Register optimization pass with the `OptimizationPassCategroty`. -//! -//! all registered passes will run in order, where: -//! higher priority pass runs earlier; -//! newly registered pass runs at the end of all passes with identical priority. -TORCH_CUDA_CU_API void registerOptimizationPass( - const OptimizationPassCategory& category, - OptimizationPass* pass, - int priority = 0); - -//! [experimental API] -//! Run `category` group of optimization passes to `fusion`. -TORCH_CUDA_CU_API void applyOptimizationPass( - const OptimizationPassCategory& category, - Fusion* fusion); - -//! [experimental API] -//! Switch the enable flag for a `category` group of optimization passes. -//! Returns the previous `enabled` status. Argument `std::optional enable` -//! is used to update the enable flag. An std::nullopt arg will leave the flag -//! unchanged. -TORCH_CUDA_CU_API bool switchOptimizationPass( - const OptimizationPassCategory& category, - std::optional enable); - -} // namespace nvfuser::optimization diff --git a/csrc/optimization/optimization_pass.h b/csrc/optimization/optimization_pass.h new file mode 100644 index 00000000000..c30d7b0b4c6 --- /dev/null +++ b/csrc/optimization/optimization_pass.h @@ -0,0 +1,68 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +#include + +namespace nvfuser::optimization { + +using FusionPass = std::function; + +//! [experimental API] +//! Base class to unify optimization pass APIs. +class TORCH_CUDA_CU_API OptimizationPass { + public: + virtual void run(Fusion*) = 0; + virtual std::string name() = 0; + virtual ~OptimizationPass() = default; +}; + +//! [experimental API] +//! Base class to unify optimization pass APIs. +template +class TORCH_CUDA_CU_API OptimizationGroup { + public: + static bool flipEnabled(bool flip) { + std::lock_guard guard(mutex_); + static bool enable_flag_ = true; + enable_flag_ = enable_flag_ ^ flip; + return enable_flag_ ^ flip; + } + + static bool setEnabled(bool enabled) { + auto tmp = flipEnabled(false); + if (enable != tmp) { + OptGroup::flipEnabled(true); + } + return tmp + } + + private: + static std::mutex mutex_; + virtual ~OptimizationGroup() = default; +}; + +//! [experimental API] +////! 
OptimizationGroupGuard is used to temporarily switch enable/disable on a +////! certain pass. Original status will be restored at destruction. +template +class TORCH_CUDA_CU_API OptimizationGroupGuard { + public: + OptimizationGroupGuard(bool enable) { + prev_status_ = OptGroup::setEnable(enable); + } + ~OptimizationGroupGuard(){ + OptGroup::setEnabled(prev_status_); + } + protected: + bool prev_status_ = false; +}; + +} // namespace nvfuser::optimization diff --git a/csrc/optimization/pre_segmenter.cpp b/csrc/optimization/pre_segmenter.cpp new file mode 100644 index 00000000000..adeb23832ca --- /dev/null +++ b/csrc/optimization/pre_segmenter.cpp @@ -0,0 +1,18 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include + +namespace nvfuser::optimization { + +void PreSegmenter::runPass(Fusion* fusion) { + ConsecutiveCastPass consecutive_cast_pass; + consecutive_cast_pass.run(fusion); +} + +} // namespace nvfuser::optimization diff --git a/csrc/optimization/pre_segmenter.h b/csrc/optimization/pre_segmenter.h new file mode 100644 index 00000000000..469921ed867 --- /dev/null +++ b/csrc/optimization/pre_segmenter.h @@ -0,0 +1,19 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +namespace nvfuser::optimization { + +class TORCH_CUDA_CU_API PreSegmenter : public OptimizationGroup { + public: + static void runPass(Fusion* fusion); +}; + +} // namespace nvfuser::optimization diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 2233eac29ad..255ac834914 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -30,7 +30,7 @@ #include #include #include -#include +#include #include #include #include @@ -6071,9 +6071,8 @@ TEST_F(NVFuserTest, FusionBroadcastPersistentReduction_CUDA) { // Repro for // https://github.com/csarofeen/pytorch/issues/2094 TEST_F(NVFuserTest, FusionRepro2094_CUDA) { - // disable cast optimization, which causes numerical issue on tests - optimization::OptimizationPassGuard guard( - optimization::OptimizationPassCategory::PreSegmenter, false); + // disable cast optimization in pre segmenter, which causes numerical issue on tests + optimization::OptimizationGroupGuard guard(optimization::PreSegmenter, false); std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); From 109a4d6191af939a3af53ab3787a965e1e264f51 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Sun, 21 May 2023 17:03:46 -0700 Subject: [PATCH 30/81] fixing typo and stuff --- csrc/optimization/optimization_pass.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/optimization/optimization_pass.h b/csrc/optimization/optimization_pass.h index c30d7b0b4c6..11acec1f0f1 100644 --- a/csrc/optimization/optimization_pass.h +++ b/csrc/optimization/optimization_pass.h @@ -38,10 +38,10 @@ class TORCH_CUDA_CU_API OptimizationGroup { static bool setEnabled(bool enabled) { auto tmp = flipEnabled(false); - if (enable != tmp) { - OptGroup::flipEnabled(true); + if (enabled != tmp) { + flipEnabled(true); } - return tmp + return tmp; } private: @@ -55,8 +55,8 @@ class TORCH_CUDA_CU_API OptimizationGroup { template class TORCH_CUDA_CU_API OptimizationGroupGuard { public: - OptimizationGroupGuard(bool enable) 
{ - prev_status_ = OptGroup::setEnable(enable); + OptimizationGroupGuard(bool enabled) { + prev_status_ = OptGroup::setEnable(enabled); } ~OptimizationGroupGuard(){ OptGroup::setEnabled(prev_status_); From 36c32d635d8ca2c9e23dcc232c13e7a85c82121a Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Sun, 21 May 2023 17:19:57 -0700 Subject: [PATCH 31/81] fixing build issue --- csrc/kernel_cache.cpp | 2 +- csrc/optimization/consecutive_cast_pass.cpp | 4 ++-- csrc/optimization/consecutive_cast_pass.h | 4 ++-- csrc/optimization/optimization_pass.h | 4 ++-- csrc/optimization/pre_segmenter.h | 2 +- test/test_gpu3.cpp | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index d1460fde243..961161ac998 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -621,7 +621,7 @@ FusionKernelRuntime::FusionKernelRuntime( !fusion->hasDynamicTransform(), "Fusion must be concretized before constructing FusionKernelRuntime"); - optimization::PreSegmenter.runPass(fusion.get()); + optimization::PreSegmenter::runPass(fusion.get()); all_tvs_ = ir_utils::allTvs(fusion.get()); diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp index 61cf5150b8d..95b60394cbd 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -88,11 +88,11 @@ void castOptimizationPass(Fusion* fusion) { } // namespace -void ConsecutiveCastPass::run(Fusion* fusion) override { +void ConsecutiveCastPass::run(Fusion* fusion) { castOptimizationPass(fusion); } -std::string ConsecutiveCastPass::name() override { +std::string ConsecutiveCastPass::name() { return "ConsecutiveCastOptimization"; } diff --git a/csrc/optimization/consecutive_cast_pass.h b/csrc/optimization/consecutive_cast_pass.h index ca9b7d31fab..7069299baa5 100644 --- a/csrc/optimization/consecutive_cast_pass.h +++ b/csrc/optimization/consecutive_cast_pass.h @@ -9,10 +9,10 @@ namespace nvfuser::optimization { -class TORCH_CUDA_CU_API ConsecutiveCastPass : OptimizationPass { +class TORCH_CUDA_CU_API ConsecutiveCastPass : public OptimizationPass { public: void run(Fusion* fusion) override; std::string name() override; -} +}; } // namespace nvfuser::optimization diff --git a/csrc/optimization/optimization_pass.h b/csrc/optimization/optimization_pass.h index 11acec1f0f1..9d7b76016c7 100644 --- a/csrc/optimization/optimization_pass.h +++ b/csrc/optimization/optimization_pass.h @@ -44,9 +44,9 @@ class TORCH_CUDA_CU_API OptimizationGroup { return tmp; } + virtual ~OptimizationGroup() = default; private: static std::mutex mutex_; - virtual ~OptimizationGroup() = default; }; //! 
[experimental API] @@ -56,7 +56,7 @@ template class TORCH_CUDA_CU_API OptimizationGroupGuard { public: OptimizationGroupGuard(bool enabled) { - prev_status_ = OptGroup::setEnable(enabled); + prev_status_ = OptGroup::setEnabled(enabled); } ~OptimizationGroupGuard(){ OptGroup::setEnabled(prev_status_); diff --git a/csrc/optimization/pre_segmenter.h b/csrc/optimization/pre_segmenter.h index 469921ed867..7f6bc13c5e8 100644 --- a/csrc/optimization/pre_segmenter.h +++ b/csrc/optimization/pre_segmenter.h @@ -11,7 +11,7 @@ namespace nvfuser::optimization { -class TORCH_CUDA_CU_API PreSegmenter : public OptimizationGroup { +class TORCH_CUDA_CU_API PreSegmenter : public OptimizationGroup { public: static void runPass(Fusion* fusion); }; diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 255ac834914..5cde6588a3f 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -6072,7 +6072,7 @@ TEST_F(NVFuserTest, FusionBroadcastPersistentReduction_CUDA) { // https://github.com/csarofeen/pytorch/issues/2094 TEST_F(NVFuserTest, FusionRepro2094_CUDA) { // disable cast optimization in pre segmenter, which causes numerical issue on tests - optimization::OptimizationGroupGuard guard(optimization::PreSegmenter, false); + optimization::OptimizationGroupGuard guard(false); std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); From 9e81880d623414edffa2f0d396812afe1edaef23 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Sun, 21 May 2023 17:25:05 -0700 Subject: [PATCH 32/81] fixing static member in template --- csrc/optimization/optimization_pass.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/optimization/optimization_pass.h b/csrc/optimization/optimization_pass.h index 9d7b76016c7..caa8d711edc 100644 --- a/csrc/optimization/optimization_pass.h +++ b/csrc/optimization/optimization_pass.h @@ -30,8 +30,10 @@ template class TORCH_CUDA_CU_API OptimizationGroup { public: static bool flipEnabled(bool flip) { - std::lock_guard guard(mutex_); + static std::mutex mutex_; static bool enable_flag_ = true; + + std::lock_guard guard(mutex_); enable_flag_ = enable_flag_ ^ flip; return enable_flag_ ^ flip; } @@ -45,8 +47,6 @@ class TORCH_CUDA_CU_API OptimizationGroup { } virtual ~OptimizationGroup() = default; - private: - static std::mutex mutex_; }; //! 
[experimental API] From 2b849a3b5b6d397f10e4fa31632ff9d0338a6f0a Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Sun, 21 May 2023 17:27:31 -0700 Subject: [PATCH 33/81] clangformat --- csrc/optimization/optimization_pass.h | 3 ++- csrc/optimization/pre_segmenter.cpp | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/csrc/optimization/optimization_pass.h b/csrc/optimization/optimization_pass.h index caa8d711edc..e2deaa37d27 100644 --- a/csrc/optimization/optimization_pass.h +++ b/csrc/optimization/optimization_pass.h @@ -58,9 +58,10 @@ class TORCH_CUDA_CU_API OptimizationGroupGuard { OptimizationGroupGuard(bool enabled) { prev_status_ = OptGroup::setEnabled(enabled); } - ~OptimizationGroupGuard(){ + ~OptimizationGroupGuard() { OptGroup::setEnabled(prev_status_); } + protected: bool prev_status_ = false; }; diff --git a/csrc/optimization/pre_segmenter.cpp b/csrc/optimization/pre_segmenter.cpp index adeb23832ca..27c095eaf58 100644 --- a/csrc/optimization/pre_segmenter.cpp +++ b/csrc/optimization/pre_segmenter.cpp @@ -5,8 +5,8 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#include #include +#include namespace nvfuser::optimization { From f1eeafaa9e71a9db7f11218bd8e1dc7e82004b8e Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Sun, 21 May 2023 18:20:19 -0700 Subject: [PATCH 34/81] linter; fixing flag check for groups --- csrc/optimization/optimization_pass.h | 5 ++--- csrc/optimization/pre_segmenter.cpp | 3 +++ test/test_gpu3.cpp | 3 ++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/csrc/optimization/optimization_pass.h b/csrc/optimization/optimization_pass.h index e2deaa37d27..68c0274f6bd 100644 --- a/csrc/optimization/optimization_pass.h +++ b/csrc/optimization/optimization_pass.h @@ -55,9 +55,8 @@ class TORCH_CUDA_CU_API OptimizationGroup { template class TORCH_CUDA_CU_API OptimizationGroupGuard { public: - OptimizationGroupGuard(bool enabled) { - prev_status_ = OptGroup::setEnabled(enabled); - } + OptimizationGroupGuard(bool enabled) + : prev_status_(OptGroup::setEnabled(enabled)) {} ~OptimizationGroupGuard() { OptGroup::setEnabled(prev_status_); } diff --git a/csrc/optimization/pre_segmenter.cpp b/csrc/optimization/pre_segmenter.cpp index 27c095eaf58..2db7545f51c 100644 --- a/csrc/optimization/pre_segmenter.cpp +++ b/csrc/optimization/pre_segmenter.cpp @@ -11,6 +11,9 @@ namespace nvfuser::optimization { void PreSegmenter::runPass(Fusion* fusion) { + if (!flipEnabled(false)) { + return; + } ConsecutiveCastPass consecutive_cast_pass; consecutive_cast_pass.run(fusion); } diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 5cde6588a3f..fed4cd6d6e5 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -6071,7 +6071,8 @@ TEST_F(NVFuserTest, FusionBroadcastPersistentReduction_CUDA) { // Repro for // https://github.com/csarofeen/pytorch/issues/2094 TEST_F(NVFuserTest, FusionRepro2094_CUDA) { - // disable cast optimization in pre segmenter, which causes numerical issue on tests + // disable cast optimization in pre segmenter, which causes numerical issue on + // tests optimization::OptimizationGroupGuard guard(false); std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); From ce88a4cdec56c65829df2f92939f4d48a98a2906 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 22 May 2023 14:05:07 -0700 Subject: [PATCH 35/81] adding comment --- csrc/optimization/consecutive_cast_pass.h | 1 + csrc/optimization/optimization_pass.h | 9 ++++++--- csrc/optimization/pre_segmenter.cpp | 2 ++ 
csrc/optimization/pre_segmenter.h | 1 + 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/csrc/optimization/consecutive_cast_pass.h b/csrc/optimization/consecutive_cast_pass.h index 7069299baa5..9aaccc6d6cf 100644 --- a/csrc/optimization/consecutive_cast_pass.h +++ b/csrc/optimization/consecutive_cast_pass.h @@ -9,6 +9,7 @@ namespace nvfuser::optimization { +//! ConsecutiveCastPass removes redundant consecutive cast operations class TORCH_CUDA_CU_API ConsecutiveCastPass : public OptimizationPass { public: void run(Fusion* fusion) override; diff --git a/csrc/optimization/optimization_pass.h b/csrc/optimization/optimization_pass.h index 68c0274f6bd..1b07d7f4edd 100644 --- a/csrc/optimization/optimization_pass.h +++ b/csrc/optimization/optimization_pass.h @@ -17,6 +17,7 @@ using FusionPass = std::function; //! [experimental API] //! Base class to unify optimization pass APIs. +//! OptimizationPass is functional and defines the granularity of mutation passes that is used to compose OptimizationGroups class TORCH_CUDA_CU_API OptimizationPass { public: virtual void run(Fusion*) = 0; @@ -25,7 +26,9 @@ class TORCH_CUDA_CU_API OptimizationPass { }; //! [experimental API] -//! Base class to unify optimization pass APIs. +//! Base class to unify optimization group APIs. +//! OptimizationGroup composes optimization passes that is used at certain stage in the runtime system. OptimizationGroup can be turned on/off programmatically with the `setEnabled/flipEnabled` API. There's helper template OptimizationGroupGuard to temporarily switch the enablement within the context. +//! Note the we are using a curiously recurring template pattern here to ensure that static objects are unique for each DerivedClass. template class TORCH_CUDA_CU_API OptimizationGroup { public: @@ -50,8 +53,8 @@ class TORCH_CUDA_CU_API OptimizationGroup { }; //! [experimental API] -////! OptimizationGroupGuard is used to temporarily switch enable/disable on a -////! certain pass. Original status will be restored at destruction. +//! OptimizationGroupGuard is used to temporarily switch enable/disable on a +//! certain pass. Original status will be restored at destruction. template class TORCH_CUDA_CU_API OptimizationGroupGuard { public: diff --git a/csrc/optimization/pre_segmenter.cpp b/csrc/optimization/pre_segmenter.cpp index 2db7545f51c..eb5bc1210e4 100644 --- a/csrc/optimization/pre_segmenter.cpp +++ b/csrc/optimization/pre_segmenter.cpp @@ -11,9 +11,11 @@ namespace nvfuser::optimization { void PreSegmenter::runPass(Fusion* fusion) { + // TODO: boilerplate code needed to enable on/off switch if (!flipEnabled(false)) { return; } + // removes consecutive cast operations ConsecutiveCastPass consecutive_cast_pass; consecutive_cast_pass.run(fusion); } diff --git a/csrc/optimization/pre_segmenter.h b/csrc/optimization/pre_segmenter.h index 7f6bc13c5e8..909f1e43b82 100644 --- a/csrc/optimization/pre_segmenter.h +++ b/csrc/optimization/pre_segmenter.h @@ -11,6 +11,7 @@ namespace nvfuser::optimization { +//! PreSegmenter is an optimization group that runs right before fusion executor segments a fusion into multiple kernels. 
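+//!
+//! For illustration, a usage sketch (assuming the OptimizationGroup /
+//! OptimizationGroupGuard templates declared in optimization_pass.h; the
+//! template arguments are written out explicitly here):
+//!
+//!   // run the whole PreSegmenter group, honoring its enable flag
+//!   optimization::OptimizationGroup<optimization::PreSegmenter>::runPass(fusion);
+//!
+//!   // temporarily disable the group, e.g. inside a test
+//!   {
+//!     optimization::OptimizationGroupGuard<optimization::PreSegmenter> guard(false);
+//!     // ... build / run the fusion without pre-segmenter passes ...
+//!   } // previous enable status restored when the guard is destroyed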
class TORCH_CUDA_CU_API PreSegmenter : public OptimizationGroup { public: static void runPass(Fusion* fusion); From 5c47da46e7bd50383bb0ddf41074f846184cb3bc Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 22 May 2023 14:10:44 -0700 Subject: [PATCH 36/81] updating skip logic --- csrc/kernel_cache.cpp | 2 +- csrc/optimization/optimization_pass.h | 7 +++++++ csrc/optimization/pre_segmenter.cpp | 4 ---- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index 961161ac998..b7dfe616f7f 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -621,7 +621,7 @@ FusionKernelRuntime::FusionKernelRuntime( !fusion->hasDynamicTransform(), "Fusion must be concretized before constructing FusionKernelRuntime"); - optimization::PreSegmenter::runPass(fusion.get()); + optimization::OptimizationGroup::runPass(fusion.get()); all_tvs_ = ir_utils::allTvs(fusion.get()); diff --git a/csrc/optimization/optimization_pass.h b/csrc/optimization/optimization_pass.h index 1b07d7f4edd..060b6d04a84 100644 --- a/csrc/optimization/optimization_pass.h +++ b/csrc/optimization/optimization_pass.h @@ -49,6 +49,13 @@ class TORCH_CUDA_CU_API OptimizationGroup { return tmp; } + static void runPass(Fusion* fusion) { + if (!flipEnabled(false)) { + return; + } + DerivedClass::runPass(fusion); + } + virtual ~OptimizationGroup() = default; }; diff --git a/csrc/optimization/pre_segmenter.cpp b/csrc/optimization/pre_segmenter.cpp index eb5bc1210e4..c3fe4f8b615 100644 --- a/csrc/optimization/pre_segmenter.cpp +++ b/csrc/optimization/pre_segmenter.cpp @@ -11,10 +11,6 @@ namespace nvfuser::optimization { void PreSegmenter::runPass(Fusion* fusion) { - // TODO: boilerplate code needed to enable on/off switch - if (!flipEnabled(false)) { - return; - } // removes consecutive cast operations ConsecutiveCastPass consecutive_cast_pass; consecutive_cast_pass.run(fusion); From b8a56ca864b70beb4c4a5e06bcd2d199236c3721 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 22 May 2023 14:25:01 -0700 Subject: [PATCH 37/81] comment --- csrc/optimization/optimization_pass.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/optimization/optimization_pass.h b/csrc/optimization/optimization_pass.h index 060b6d04a84..7c7c75df361 100644 --- a/csrc/optimization/optimization_pass.h +++ b/csrc/optimization/optimization_pass.h @@ -29,6 +29,8 @@ class TORCH_CUDA_CU_API OptimizationPass { //! Base class to unify optimization group APIs. //! OptimizationGroup composes optimization passes that is used at certain stage in the runtime system. OptimizationGroup can be turned on/off programmatically with the `setEnabled/flipEnabled` API. There's helper template OptimizationGroupGuard to temporarily switch the enablement within the context. //! Note the we are using a curiously recurring template pattern here to ensure that static objects are unique for each DerivedClass. +//! In order to apply OptimizationGroup with the switch enabled, you need to run the function with +//! 
`OptimizationGroup::runPass(...)` template class TORCH_CUDA_CU_API OptimizationGroup { public: From f0ea352897ff4cfa94cc6c607a94e90fe5fe5bcd Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 22 May 2023 14:25:48 -0700 Subject: [PATCH 38/81] clangformat --- csrc/kernel_cache.cpp | 3 ++- csrc/optimization/optimization_pass.h | 15 ++++++++++----- csrc/optimization/pre_segmenter.h | 3 ++- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index b7dfe616f7f..b7fe036e11c 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -621,7 +621,8 @@ FusionKernelRuntime::FusionKernelRuntime( !fusion->hasDynamicTransform(), "Fusion must be concretized before constructing FusionKernelRuntime"); - optimization::OptimizationGroup::runPass(fusion.get()); + optimization::OptimizationGroup::runPass( + fusion.get()); all_tvs_ = ir_utils::allTvs(fusion.get()); diff --git a/csrc/optimization/optimization_pass.h b/csrc/optimization/optimization_pass.h index 7c7c75df361..a4a196523b4 100644 --- a/csrc/optimization/optimization_pass.h +++ b/csrc/optimization/optimization_pass.h @@ -17,7 +17,8 @@ using FusionPass = std::function; //! [experimental API] //! Base class to unify optimization pass APIs. -//! OptimizationPass is functional and defines the granularity of mutation passes that is used to compose OptimizationGroups +//! OptimizationPass is functional and defines the granularity of mutation +//! passes that is used to compose OptimizationGroups class TORCH_CUDA_CU_API OptimizationPass { public: virtual void run(Fusion*) = 0; @@ -27,10 +28,14 @@ class TORCH_CUDA_CU_API OptimizationPass { //! [experimental API] //! Base class to unify optimization group APIs. -//! OptimizationGroup composes optimization passes that is used at certain stage in the runtime system. OptimizationGroup can be turned on/off programmatically with the `setEnabled/flipEnabled` API. There's helper template OptimizationGroupGuard to temporarily switch the enablement within the context. -//! Note the we are using a curiously recurring template pattern here to ensure that static objects are unique for each DerivedClass. -//! In order to apply OptimizationGroup with the switch enabled, you need to run the function with -//! `OptimizationGroup::runPass(...)` +//! OptimizationGroup composes optimization passes that is used at certain stage +//! in the runtime system. OptimizationGroup can be turned on/off +//! programmatically with the `setEnabled/flipEnabled` API. There's helper +//! template OptimizationGroupGuard to temporarily switch the enablement within +//! the context. Note the we are using a curiously recurring template pattern +//! here to ensure that static objects are unique for each DerivedClass. In +//! order to apply OptimizationGroup with the switch enabled, you need to run +//! the function with `OptimizationGroup::runPass(...)` template class TORCH_CUDA_CU_API OptimizationGroup { public: diff --git a/csrc/optimization/pre_segmenter.h b/csrc/optimization/pre_segmenter.h index 909f1e43b82..97dca9eff27 100644 --- a/csrc/optimization/pre_segmenter.h +++ b/csrc/optimization/pre_segmenter.h @@ -11,7 +11,8 @@ namespace nvfuser::optimization { -//! PreSegmenter is an optimization group that runs right before fusion executor segments a fusion into multiple kernels. +//! PreSegmenter is an optimization group that runs right before fusion executor +//! segments a fusion into multiple kernels. 
class TORCH_CUDA_CU_API PreSegmenter : public OptimizationGroup { public: static void runPass(Fusion* fusion); From d7cb3e38a8875387602930aedc50efc24c3a644c Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 23 May 2023 00:39:52 -0700 Subject: [PATCH 39/81] Apply suggestions from code review committing suggestions on code changes Co-authored-by: Naoya Maruyama --- csrc/optimization/consecutive_cast_pass.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp index 95b60394cbd..e03f8af5f58 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -14,7 +14,7 @@ namespace { void castOptimizationPass(Fusion* fusion) { auto is_cast_op = [](Expr* expr) { - if (expr->isA()) { + if (expr != nullptr && expr->isA()) { auto op = expr->as(); if (op->getUnaryOpType() == UnaryOpType::Cast) { return true; @@ -32,7 +32,9 @@ void castOptimizationPass(Fusion* fusion) { // in the loop, we just repetitively skip consecutive casts. auto intermediate_cast = expr->input(0); auto prev_expr = intermediate_cast->definition(); - if (prev_expr != nullptr && is_cast_op(prev_expr)) { + if (prev_expr == nullptr || !is_cast_op(prev_expr)) { + break; + } auto original_dtype = prev_expr->input(0)->getDataType().value(); auto intermediate_dtype = intermediate_cast->getDataType().value(); auto out_dtype = expr->output(0)->getDataType().value(); @@ -82,7 +84,7 @@ void castOptimizationPass(Fusion* fusion) { } } if (!replacement_map.empty()) { - nvfuser::ir_utils::replaceValue(fusion, replacement_map); + ir_utils::replaceValue(fusion, replacement_map); } } From 352025e14fcc185db3149edbf9b96717c8d20434 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 23 May 2023 01:42:26 -0700 Subject: [PATCH 40/81] addressing review comments --- csrc/optimization/consecutive_cast_pass.cpp | 66 ++++++++++----------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp index e03f8af5f58..97047544017 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -24,7 +24,6 @@ void castOptimizationPass(Fusion* fusion) { }; // NOTE: not the most efficient pass - std::unordered_map replacement_map; for (auto expr : fusion->exprs()) { if (is_cast_op(expr)) { bool mutated = false; @@ -35,34 +34,37 @@ void castOptimizationPass(Fusion* fusion) { if (prev_expr == nullptr || !is_cast_op(prev_expr)) { break; } - auto original_dtype = prev_expr->input(0)->getDataType().value(); - auto intermediate_dtype = intermediate_cast->getDataType().value(); - auto out_dtype = expr->output(0)->getDataType().value(); - // cases where skipping the intermediate cast is relatively safe, - // two conditions: - // 1. original_dtype is the same as out_dtype; or - // 2. we support direct cast from original_dtype to out_dtype. - // and - // 1. intermediate_dtype is the same type category as with - // out_dtype; or - // 2. intermediate_dtype is more relaxed than out_dtype. e.g. a - // floating point vs. 
integral; - if ((original_dtype == out_dtype || - cast_func_str({original_dtype, out_dtype}).has_value()) && - ((isIntegralType(intermediate_dtype) && - isIntegralType(out_dtype)) || - (isFloatingPointType(intermediate_dtype) && - isFloatingPointType(out_dtype)) || - (isComplexType(intermediate_dtype) && - isComplexType(out_dtype)) || - (isFloatingPointType(intermediate_dtype) && - isIntegralType(out_dtype)))) { - expr = nvfuser::ir_utils::replaceValInExpr( - expr, intermediate_cast, prev_expr->input(0)); - mutated = true; - } else { - break; - } + // Note, if intermediate_cast: + // is used by other none-cast operations; or + // is a direct output from fusion + // we skip the casting + if (intermediate_cast->isFusionOutput() || + !std::all_of( + intermediate_cast->uses().begin(), + intermediate_cast->uses().end(), + is_cast_op)) { + break; + } + auto original_dtype = prev_expr->input(0)->getDataType().value(); + auto intermediate_dtype = intermediate_cast->getDataType().value(); + auto out_dtype = expr->output(0)->getDataType().value(); + // cases where skipping the intermediate cast is relatively safe, + // two conditions: + // 1. original_dtype is the same as out_dtype; or + // 2. we support direct cast from original_dtype to out_dtype. + // and + // 1. intermediate_dtype is the same type category as with + // out_dtype + if ((original_dtype == out_dtype || + cast_func_str({original_dtype, out_dtype}).has_value()) && + ((isIntegralType(intermediate_dtype) && + isIntegralType(out_dtype)) || + (isFloatingPointType(intermediate_dtype) && + isFloatingPointType(out_dtype)) || + (isComplexType(intermediate_dtype) && isComplexType(out_dtype)))) { + expr = nvfuser::ir_utils::replaceValInExpr( + expr, intermediate_cast, prev_expr->input(0)); + mutated = true; } else { break; } @@ -73,7 +75,8 @@ void castOptimizationPass(Fusion* fusion) { // casting to the same type if (expr->input(0)->getDataType().value() == expr->output(0)->getDataType().value()) { - replacement_map[expr->output(0)] = expr->input(0); + // replacing output with input in the fusion + ir_utils::replaceValue(fusion, {{expr->output(0), expr->input(0)}}); // NOTE: if current output is a fusion output, DCE won't kick in and // we'll ended up with an illegal cast. if (expr->output(0)->isFusionOutput()) { @@ -83,9 +86,6 @@ void castOptimizationPass(Fusion* fusion) { } } } - if (!replacement_map.empty()) { - ir_utils::replaceValue(fusion, replacement_map); - } } } // namespace From 83f58895fe97096b1888163d037692ba83d7e683 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 25 May 2023 19:00:22 -0700 Subject: [PATCH 41/81] WIP --- csrc/optimization/consecutive_cast_pass.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp index 97047544017..f625d299f9b 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -27,6 +27,16 @@ void castOptimizationPass(Fusion* fusion) { for (auto expr : fusion->exprs()) { if (is_cast_op(expr)) { bool mutated = false; + std::vector chain_casts; + auto prev_expr = expr->input(0)->definition(); + while (prev_expr != nullptr && is_cast_op(prev_expr)) { + chain_casts.push_back(prev_expr->output(0)); + prev_expr = prev_expr->input(0)->definition(); + } + + if (!chain_casts.empty()) { + } + while (true) { // in the loop, we just repetitively skip consecutive casts. 
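      // (For clarity: after the walk added above, chain_casts holds the outputs of the
      //  producer casts feeding this expr, nearest first; e.g. for
      //  t0 --cast--> t1 --cast--> t2 --cast--> t3 produced by expr, chain_casts = {t2, t1}.)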
auto intermediate_cast = expr->input(0); From 6fece9a78ffc8acf227a5e81cf00c31e90699947 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 26 May 2023 16:10:43 -0700 Subject: [PATCH 42/81] filling in implementation --- csrc/optimization/consecutive_cast_pass.cpp | 166 +++++++++++++------- 1 file changed, 106 insertions(+), 60 deletions(-) diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp index f625d299f9b..968311a5433 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -12,87 +12,133 @@ namespace nvfuser::optimization { namespace { +bool isSameDtypeCategory(const DataType& input_t, const DataType& output_t) { + if ((isIntegralType(input_t) && isIntegralType(out_dtype)) || + (isFloatingPointType(input_t) && isFloatingPointType(out_dtype)) || + (isComplexType(input_t) && isComplexType(out_dtype))) { + return true; + } + return false; +} + +// check if type1 is a wider type than type0 +// Which indicates a cast from type0 -> type1 -> type0 should be bit-wise identical +bool isWiderType(const DataType& type0, const DataType& type1) { + if (type0 == type1) { + return true; + } else if (type0 == DataType::Double && (type1 == DataType::Float || type1 == DataType::Half || type1 == DataType::BFloat16)) { + return true; + } else if (type0 == DataType::Float && (type1 == DataType::Half || type1 == DataType::BFloat16)) { + return true; + } else if (type0 == DataType::Int && type1 == DataType::Int32) { + return true; + } else if (type0 == DataType::ComplexDouble && type1 == DataType::ComplexFloat) { + return true; + } + return false; +} + +// note: returns +// - -1 : v0 contains strictly less information than v1; +// - 0 : a complex case, where each v0 and v1 isn't a super set of the other; +// - 1 : v0 and v1 has the same dtype; +// - 2 : v0 contains strictly more information than v1; +int checkInformationLoss(Val* v0, Val* v1) { + auto dtype0 = v0->getDataType().value(); + auto dtype1 = v1->getDataType().value(); + if (dtype0 == dtype1) { + return 1; + } + if ((dtype0 == DataType::BFloat16 && dtype1 == DataType::Half) || + (dtype1 == DataType::BFloat16 && dtype0 == DataType::Half)) { + return 0; + } + if (isWiderType(dtype0, dtype1)) { + return -1; + } + TORCH_INTERNAL_ASSERT(isWiderType(dtype1, dtype0), "unrecognized cast category is encountered"); + return 2; +} + void castOptimizationPass(Fusion* fusion) { - auto is_cast_op = [](Expr* expr) { + auto is_foldable_cast_op = [](Expr* expr) { if (expr != nullptr && expr->isA()) { auto op = expr->as(); - if (op->getUnaryOpType() == UnaryOpType::Cast) { + if (op->getUnaryOpType() == UnaryOpType::Cast && + isSameDtypeCategory(expr->input(0)->getDataType().value(), expr->output(0)->getDataType().value())) { return true; } } return false; }; - // NOTE: not the most efficient pass + // TODO: Traveral implies topological order on returns exprs, we can leverage that to improve the effieciency of the pass. In the case of a straight line casts, we are doing a lot of meaningless work here on mutating intermediate casts that would have been done again at the end of the chain. 
for (auto expr : fusion->exprs()) { - if (is_cast_op(expr)) { + if (is_foldable_cast_op(expr)) { bool mutated = false; std::vector chain_casts; auto prev_expr = expr->input(0)->definition(); - while (prev_expr != nullptr && is_cast_op(prev_expr)) { - chain_casts.push_back(prev_expr->output(0)); - prev_expr = prev_expr->input(0)->definition(); - } - - if (!chain_casts.empty()) { - } - - while (true) { - // in the loop, we just repetitively skip consecutive casts. - auto intermediate_cast = expr->input(0); - auto prev_expr = intermediate_cast->definition(); - if (prev_expr == nullptr || !is_cast_op(prev_expr)) { - break; - } - // Note, if intermediate_cast: - // is used by other none-cast operations; or + while (prev_expr != nullptr && is_foldable_cast_op(prev_expr)) { + auto intermediate_cast = prev_expr->output(0); + // Note, if the output f prev_expr + // is used by other operation(s); or // is a direct output from fusion - // we skip the casting + // we skip the casting chaining if (intermediate_cast->isFusionOutput() || - !std::all_of( - intermediate_cast->uses().begin(), - intermediate_cast->uses().end(), - is_cast_op)) { - break; - } - auto original_dtype = prev_expr->input(0)->getDataType().value(); - auto intermediate_dtype = intermediate_cast->getDataType().value(); - auto out_dtype = expr->output(0)->getDataType().value(); - // cases where skipping the intermediate cast is relatively safe, - // two conditions: - // 1. original_dtype is the same as out_dtype; or - // 2. we support direct cast from original_dtype to out_dtype. - // and - // 1. intermediate_dtype is the same type category as with - // out_dtype - if ((original_dtype == out_dtype || - cast_func_str({original_dtype, out_dtype}).has_value()) && - ((isIntegralType(intermediate_dtype) && - isIntegralType(out_dtype)) || - (isFloatingPointType(intermediate_dtype) && - isFloatingPointType(out_dtype)) || - (isComplexType(intermediate_dtype) && isComplexType(out_dtype)))) { - expr = nvfuser::ir_utils::replaceValInExpr( - expr, intermediate_cast, prev_expr->input(0)); - mutated = true; - } else { + intermediate_cast->uses().size() > 1) { break; } + + // in the loop, we just repetitively chaining consecutive casts. + chain_casts.push_back(intermediate_cast); + prev_expr = prev_expr->input(0)->definition(); } - if (mutated) { - // quick short-wire to skip current cast node if it's trivially - // casting to the same type - if (expr->input(0)->getDataType().value() == - expr->output(0)->getDataType().value()) { - // replacing output with input in the fusion - ir_utils::replaceValue(fusion, {{expr->output(0), expr->input(0)}}); - // NOTE: if current output is a fusion output, DCE won't kick in and - // we'll ended up with an illegal cast. + // Note, chain_casts has a straight-line use without branches + if (!chain_casts.empty()) { + auto lo_anchor = chain_casts[0]->definition()->input(0); + auto starting_anchor = lo_anchor; + for (auto val : chain_casts) { + auto info = checkInformationLoss(anchor, val); + if (info >= 0) { + if (info == 0) { + // we run into a lemon case where we are casting between two types that can't be folded away. i.e. bf16 & fp16. 
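+            // (neither bf16 nor fp16 covers the other: bf16 keeps more exponent
+            // range while fp16 keeps more mantissa bits, so this cast must stay)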
+ auto tmp_expr = val->definition(); + if (lo_anchor != tmp_expr->input(0)) { + tmp_expr = nvfuser::ir_utils::replaceValInExpr(tmp_expr, tmp_expr->input(0), lo_anchor); + } + // move starting_anchor past the ambiguous case + starting_anchor = val; + } + // updating lo_anchor + lo_anchor = val; + } + } + + auto info = checkInformationLoss(lo_anchor, expr->output(0)); + if (info == 1) { + // replacing output with lo_anchor in the fusion + ir_utils::replaceValue(fusion, {{expr->output(0), lo_anchor}}); if (expr->output(0)->isFusionOutput()) { - fusion->replaceOutput(expr->output(0), expr->input(0)); + fusion->replaceOutput(expr->output(0), lo_anchor); } - } + } else if (info == -1 || info == 0) { + // expr output has either: + // higher precision than lo_anchor; or + // incompatible precision + // in either case, we can't fold away lo_anchor, we'll just re-wire the input to expr to lo_anchor + expr = nvfuser::ir_utils::replaceValInExpr( + expr, expr->input(0), lo_anchor); + } else if (info == 2) { + // if expr has lower precision than lo_anchor, we'll just fold away to the starting_anchor instead + expr = nvfuser::ir_utils::replaceValInExpr( + expr, expr->input(0), starting_anchor); + } else { + TORCH_INTERNAL_ASSERT(false, "checkInformationLoss returns a flag that's not recognized"); + } + + expr = nvfuser::ir_utils::replaceValInExpr( + expr, intermediate_cast, prev_expr->input(0)); } } } From 3d1ad953276a424ee51b42647c8d0ffe43a09db6 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 26 May 2023 16:26:37 -0700 Subject: [PATCH 43/81] fixing typo; updating logic --- csrc/optimization/consecutive_cast_pass.cpp | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp index 968311a5433..00b487c3287 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -13,9 +13,9 @@ namespace nvfuser::optimization { namespace { bool isSameDtypeCategory(const DataType& input_t, const DataType& output_t) { - if ((isIntegralType(input_t) && isIntegralType(out_dtype)) || - (isFloatingPointType(input_t) && isFloatingPointType(out_dtype)) || - (isComplexType(input_t) && isComplexType(out_dtype))) { + if ((isIntegralType(input_t) && isIntegralType(output_t)) || + (isFloatingPointType(input_t) && isFloatingPointType(output_t)) || + (isComplexType(input_t) && isComplexType(output_t))) { return true; } return false; @@ -75,7 +75,6 @@ void castOptimizationPass(Fusion* fusion) { // TODO: Traveral implies topological order on returns exprs, we can leverage that to improve the effieciency of the pass. In the case of a straight line casts, we are doing a lot of meaningless work here on mutating intermediate casts that would have been done again at the end of the chain. 
for (auto expr : fusion->exprs()) { if (is_foldable_cast_op(expr)) { - bool mutated = false; std::vector chain_casts; auto prev_expr = expr->input(0)->definition(); while (prev_expr != nullptr && is_foldable_cast_op(prev_expr)) { @@ -99,10 +98,11 @@ void castOptimizationPass(Fusion* fusion) { auto lo_anchor = chain_casts[0]->definition()->input(0); auto starting_anchor = lo_anchor; for (auto val : chain_casts) { - auto info = checkInformationLoss(anchor, val); - if (info >= 0) { + auto info = checkInformationLoss(lo_anchor, val); + // if information on new val drops below the anchor, we want to update the anchor + if (info <= 0) { + // we run into a complex case where we are casting between two types that can't be folded away. i.e. bf16 & fp16. We need to update the starting_anchor for the final fold to be past this current cast. if (info == 0) { - // we run into a lemon case where we are casting between two types that can't be folded away. i.e. bf16 & fp16. auto tmp_expr = val->definition(); if (lo_anchor != tmp_expr->input(0)) { tmp_expr = nvfuser::ir_utils::replaceValInExpr(tmp_expr, tmp_expr->input(0), lo_anchor); @@ -136,9 +136,6 @@ void castOptimizationPass(Fusion* fusion) { } else { TORCH_INTERNAL_ASSERT(false, "checkInformationLoss returns a flag that's not recognized"); } - - expr = nvfuser::ir_utils::replaceValInExpr( - expr, intermediate_cast, prev_expr->input(0)); } } } From 082c7fc29bc2a9ee221be6f6b550e90e1986062e Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 26 May 2023 17:09:05 -0700 Subject: [PATCH 44/81] patched tests --- csrc/type.h | 10 ++-- test/test_gpu3.cpp | 131 ++++----------------------------------------- 2 files changed, 17 insertions(+), 124 deletions(-) diff --git a/csrc/type.h b/csrc/type.h index 5b1e474284e..05a65b7bca7 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -62,16 +62,20 @@ enum class PredicateType { // int64_t which is relatively heavy to carry around. Index will be resolved // at compile time with KernelIndexMode. 
enum class PrimDataType { + // Floating point types Double, Float, Half, + BFloat16, + // Integral types Int, - Index, Int32, + Index, + // Boolean types Bool, - BFloat16, - ComplexFloat, + // Complex types ComplexDouble, + ComplexFloat, // Pointers SMemAddress, // Null diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 6364b8acdfa..537ed1ff88d 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -8400,114 +8400,16 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { .dtype(DataType::Double) .build(); fusion->addInput(tv0); - auto tv1 = - castOp(DataType::Half, tv0); // consecutive cast should be removed - auto tv2 = castOp(DataType::Float, tv1); - auto tv3 = relu(tv2); - auto tv4 = neg(tv3); - auto tv5 = castOp(DataType::Double, tv4); - fusion->addOutput(tv5); - - FusionExecutorCache executor_cache(std::move(fusion)); - auto outputs = executor_cache.runFusionWithInputs({at_x}); - auto ref_out = at_x.clone().relu().neg(); - - auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); - auto complete_fusion = optimized_fusion->fusionSegments()->completeFusion(); - int cast_op_count = 0; - for (auto expr : complete_fusion->exprs()) { - if (expr->isA() && - expr->as()->getUnaryOpType() == UnaryOpType::Cast) { - ++cast_op_count; - } - } - TORCH_CHECK( - cast_op_count == 2, "cast optimization isn't working as expected"); - - testValidate( - executor_cache.fusion(), - outputs, - {at_x}, - {ref_out}, - __LINE__, - __FILE__); - } - - { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - auto tv0 = TensorViewBuilder() - .ndims(input_shape.size()) - .dtype(DataType::Double) - .build(); - fusion->addInput(tv0); - auto tv1 = castOp(DataType::Int, tv0); - // previous cast cannot be optimized away due to precision - auto tv2 = castOp(DataType::Float, tv1); - auto tv3 = neg(tv2); - fusion->addOutput(tv3); - - FusionExecutorCache executor_cache(std::move(fusion)); - auto outputs = executor_cache.runFusionWithInputs({at_x}); - auto ref_out = at_x.clone().to(at::kInt).to(at::kFloat).neg(); - - auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); - auto complete_fusion = optimized_fusion->fusionSegments()->completeFusion(); - int cast_op_count = 0; - for (auto expr : complete_fusion->exprs()) { - if (expr->isA() && - expr->as()->getUnaryOpType() == UnaryOpType::Cast) { - ++cast_op_count; - } - } - TORCH_CHECK( - cast_op_count == 2, "cast optimization isn't working as expected"); - - testValidate( - executor_cache.fusion(), - outputs, - {at_x}, - {ref_out}, - __LINE__, - __FILE__); - } - - { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - auto tv0 = TensorViewBuilder() - .ndims(input_shape.size()) - .dtype(DataType::Double) - .build(); - fusion->addInput(tv0); - auto tv1 = neg(tv0); - auto tv2 = castOp(DataType::Float, tv1); - auto tv3 = castOp(DataType::Double, tv2); - fusion->addOutput(tv3); - - FusionExecutorCache executor_cache(std::move(fusion)); - auto outputs = executor_cache.runFusionWithInputs({at_x}); - auto ref_out = at_x.clone().neg(); - - auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); - auto complete_fusion = optimized_fusion->fusionSegments()->completeFusion(); - int cast_op_count = 0; - for (auto expr : complete_fusion->exprs()) { - if (expr->isA() && - expr->as()->getUnaryOpType() == UnaryOpType::Cast) { - ++cast_op_count; - } - } - TORCH_CHECK( - cast_op_count == 0, "cast optimization isn't working as expected"); - - testValidate( - executor_cache.fusion(), - outputs, - {at_x}, - 
{ref_out}, - __LINE__, - __FILE__); + auto tv1 = castOp(DataType::Float, tv0); + auto tv2 = castOp(DataType::Half, tv1); + auto tv3 = castOp(DataType::Float, tv2); + auto tv4 = castOp(DataType::Double, tv3); + // (input)double -> float -> half -> float -> double + fusion->addOutput(tv4); + optimization::OptimizationGroup::runPass(fusion.get()); + auto ref_tv = castOp(DataType::Half, tv0); + auto ref_tv = castOp(DataType::Double, ref_tv); + ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); } } @@ -8610,19 +8512,6 @@ TEST_F(NVFuserTest, FusionLayerNormFusedOpsRedundantCast_CUDA) { FusionExecutorCache fec(std::move(fusion_ptr)); auto cg_outputs = fec.runFusionWithInputs(inputs); - - auto optimized_fusion = fec.getMostRecentKernelRuntime(); - auto complete_fusion = optimized_fusion->fusionSegments()->completeFusion(); - int cast_op_count = 0; - for (auto expr : complete_fusion->exprs()) { - if (expr->isA() && - expr->as()->getUnaryOpType() == UnaryOpType::Cast) { - ++cast_op_count; - } - } - TORCH_CHECK( - cast_op_count == 9, "cast optimization isn't working as expected"); - testValidate(fusion, cg_outputs, inputs, outputs, __LINE__, __FILE__); } From ba667977016c948e5267506f7d3b8c2787905ba5 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 26 May 2023 17:17:49 -0700 Subject: [PATCH 45/81] fixing typo --- test/test_gpu3.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 537ed1ff88d..fd79b85bd2d 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -8408,7 +8408,7 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { fusion->addOutput(tv4); optimization::OptimizationGroup::runPass(fusion.get()); auto ref_tv = castOp(DataType::Half, tv0); - auto ref_tv = castOp(DataType::Double, ref_tv); + ref_tv = castOp(DataType::Double, ref_tv); ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); } } From 8233762b40c8761f76415550747343c02375bae5 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 26 May 2023 19:19:45 -0700 Subject: [PATCH 46/81] fixing logic, prints in test --- csrc/optimization/consecutive_cast_pass.cpp | 28 ++++++++++----------- test/test_gpu3.cpp | 7 ++++++ 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp index 00b487c3287..828561679f7 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -21,28 +21,28 @@ bool isSameDtypeCategory(const DataType& input_t, const DataType& output_t) { return false; } -// check if type1 is a wider type than type0 -// Which indicates a cast from type0 -> type1 -> type0 should be bit-wise identical -bool isWiderType(const DataType& type0, const DataType& type1) { - if (type0 == type1) { +// check if type is a wider type than ref +// Which indicates a cast from ref -> type -> ref should be bit-wise identical +bool isWiderType(const DataType& ref, const DataType& type) { + if (ref == type) { return true; - } else if (type0 == DataType::Double && (type1 == DataType::Float || type1 == DataType::Half || type1 == DataType::BFloat16)) { + } else if (type == DataType::Double && (ref == DataType::Float || ref == DataType::Half || ref == DataType::BFloat16)) { return true; - } else if (type0 == DataType::Float && (type1 == DataType::Half || type1 == DataType::BFloat16)) { + } else if (type == DataType::Float && (ref == DataType::Half || ref == DataType::BFloat16)) { return true; - } else if (type0 == DataType::Int && type1 == 
DataType::Int32) { + } else if (type == DataType::Int && ref == DataType::Int32) { return true; - } else if (type0 == DataType::ComplexDouble && type1 == DataType::ComplexFloat) { + } else if (type == DataType::ComplexDouble && ref == DataType::ComplexFloat) { return true; } return false; } // note: returns -// - -1 : v0 contains strictly less information than v1; +// - -1 : v0 contains strictly more information than v1; // - 0 : a complex case, where each v0 and v1 isn't a super set of the other; // - 1 : v0 and v1 has the same dtype; -// - 2 : v0 contains strictly more information than v1; +// - 2 : v0 contains strictly less information than v1; int checkInformationLoss(Val* v0, Val* v1) { auto dtype0 = v0->getDataType().value(); auto dtype1 = v1->getDataType().value(); @@ -54,10 +54,10 @@ int checkInformationLoss(Val* v0, Val* v1) { return 0; } if (isWiderType(dtype0, dtype1)) { - return -1; + return 2; } TORCH_INTERNAL_ASSERT(isWiderType(dtype1, dtype0), "unrecognized cast category is encountered"); - return 2; + return -1; } void castOptimizationPass(Fusion* fusion) { @@ -122,14 +122,14 @@ void castOptimizationPass(Fusion* fusion) { if (expr->output(0)->isFusionOutput()) { fusion->replaceOutput(expr->output(0), lo_anchor); } - } else if (info == -1 || info == 0) { + } else if (info == 2 || info == 0) { // expr output has either: // higher precision than lo_anchor; or // incompatible precision // in either case, we can't fold away lo_anchor, we'll just re-wire the input to expr to lo_anchor expr = nvfuser::ir_utils::replaceValInExpr( expr, expr->input(0), lo_anchor); - } else if (info == 2) { + } else if (info == -1) { // if expr has lower precision than lo_anchor, we'll just fold away to the starting_anchor instead expr = nvfuser::ir_utils::replaceValInExpr( expr, expr->input(0), starting_anchor); diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index fd79b85bd2d..9dbbe276023 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -8406,9 +8406,16 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { auto tv4 = castOp(DataType::Double, tv3); // (input)double -> float -> half -> float -> double fusion->addOutput(tv4); + printf("----start----\n"); + fusion->printMath(); optimization::OptimizationGroup::runPass(fusion.get()); + printf("---- opt ----\n"); + fusion->printMath(); auto ref_tv = castOp(DataType::Half, tv0); ref_tv = castOp(DataType::Double, ref_tv); + printf("---- ref ----\n"); + fusion->addOutput(ref_tv); + fusion->printMath(); ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); } } From a6066c3357a784f7a9eb67efb85f073143a65adb Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 26 May 2023 19:35:20 -0700 Subject: [PATCH 47/81] more test case added --- test/test_gpu3.cpp | 95 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 89 insertions(+), 6 deletions(-) diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 9dbbe276023..ee58a36a684 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -8400,21 +8400,104 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { .dtype(DataType::Double) .build(); fusion->addInput(tv0); - auto tv1 = castOp(DataType::Float, tv0); - auto tv2 = castOp(DataType::Half, tv1); - auto tv3 = castOp(DataType::Float, tv2); - auto tv4 = castOp(DataType::Double, tv3); + auto tv = castOp(DataType::Float, tv0); + tv = castOp(DataType::Half, tv); + tv = castOp(DataType::Float, tv); + tv = castOp(DataType::Double, tv); // (input)double -> float -> half -> float -> double - fusion->addOutput(tv4); + fusion->addOutput(tv); + 
optimization::OptimizationGroup::runPass(fusion.get()); + // simplified as (input)double -> half -> double + auto ref_tv = castOp(DataType::Half, tv0); + ref_tv = castOp(DataType::Double, ref_tv); + ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); + } + + { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + auto tv0 = TensorViewBuilder() + .ndims(input_shape.size()) + .dtype(DataType::Float) + .build(); + fusion->addInput(tv0); + auto tv = castOp(DataType::Double, tv0); + tv = castOp(DataType::Float, tv); + // (input)float -> double -> float + fusion->addOutput(tv); + printf("----start----\n"); + fusion->printMath(); + optimization::OptimizationGroup::runPass(fusion.get()); + printf("---- opt ----\n"); + fusion->printMath(); + // TODO: should I have copied the tensor to avoid an alised output?! + // simplified as (input) + printf("---- ref ----\n"); + fusion->printMath(); + ASSERT_TRUE(tv0->sameAs(fusion->outputs()[0])); + } + + { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + auto tv0 = TensorViewBuilder() + .ndims(input_shape.size()) + .dtype(DataType::Float) + .build(); + fusion->addInput(tv0); + auto tv = castOp(DataType::Double, tv0); + tv = castOp(DataType::Half, tv); + tv = castOp(DataType::Float, tv); + tv = castOp(DataType::Double, tv); + tv = castOp(DataType::Float, tv); + tv = castOp(DataType::Double, tv); + tv = castOp(DataType::Float, tv); + tv = castOp(DataType::Double, tv); + tv = castOp(DataType::Float, tv); + // (input)float -> double -> half -> float -> double -> float -> double -> float -> double -> float + fusion->addOutput(tv); printf("----start----\n"); fusion->printMath(); optimization::OptimizationGroup::runPass(fusion.get()); printf("---- opt ----\n"); fusion->printMath(); + // TODO: should I have copied the tensor to avoid an alised output?! + // simplified as (input)float -> half -> float auto ref_tv = castOp(DataType::Half, tv0); - ref_tv = castOp(DataType::Double, ref_tv); + ref_tv = castOp(DataType::Float, ref_tv); + fusion->addOutput(ref_tv); printf("---- ref ----\n"); + fusion->printMath(); + ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); + } + + { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + auto tv0 = TensorViewBuilder() + .ndims(input_shape.size()) + .dtype(DataType::Float) + .build(); + fusion->addInput(tv0); + auto tv = castOp(DataType::Double, tv0); + tv = castOp(DataType::Half, tv); + tv = castOp(DataType::Float, tv); + tv = castOp(DataType::BFloat16, tv); + tv = castOp(DataType::Float, tv); + // (input)float -> double -> half -> float -> bfloat16 -> float + fusion->addOutput(tv); + printf("----start----\n"); + fusion->printMath(); + optimization::OptimizationGroup::runPass(fusion.get()); + printf("---- opt ----\n"); + fusion->printMath(); + // TODO: should I have copied the tensor to avoid an alised output?! 
+ // simplified as (input)float -> half -> bfloat16 -> float + auto ref_tv = castOp(DataType::Half, tv0); + ref_tv = castOp(DataType::BFloat16, ref_tv); + ref_tv = castOp(DataType::Float, ref_tv); fusion->addOutput(ref_tv); + printf("---- ref ----\n"); fusion->printMath(); ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); } From b08054d0f35ebfaf96b43972cca925fbe72429a1 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 26 May 2023 20:01:15 -0700 Subject: [PATCH 48/81] fixing tests --- csrc/optimization/consecutive_cast_pass.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp index 828561679f7..47a35f94dd4 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -89,13 +89,13 @@ void castOptimizationPass(Fusion* fusion) { } // in the loop, we just repetitively chaining consecutive casts. - chain_casts.push_back(intermediate_cast); + chain_casts.push_front(intermediate_cast); prev_expr = prev_expr->input(0)->definition(); } // Note, chain_casts has a straight-line use without branches if (!chain_casts.empty()) { - auto lo_anchor = chain_casts[0]->definition()->input(0); + auto lo_anchor = chain_casts.front()->definition()->input(0); auto starting_anchor = lo_anchor; for (auto val : chain_casts) { auto info = checkInformationLoss(lo_anchor, val); From 3cdd2dc687acde5d44faf4d992d53766db42f355 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 26 May 2023 20:02:57 -0700 Subject: [PATCH 49/81] changing container type --- csrc/optimization/consecutive_cast_pass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp index 47a35f94dd4..0e5a85746c7 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -75,7 +75,7 @@ void castOptimizationPass(Fusion* fusion) { // TODO: Traveral implies topological order on returns exprs, we can leverage that to improve the effieciency of the pass. In the case of a straight line casts, we are doing a lot of meaningless work here on mutating intermediate casts that would have been done again at the end of the chain. 
for (auto expr : fusion->exprs()) { if (is_foldable_cast_op(expr)) { - std::vector chain_casts; + std::list chain_casts; auto prev_expr = expr->input(0)->definition(); while (prev_expr != nullptr && is_foldable_cast_op(prev_expr)) { auto intermediate_cast = prev_expr->output(0); From 48ac0cf4a9412262208a0e07d3cad0243323dec6 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 26 May 2023 20:10:36 -0700 Subject: [PATCH 50/81] moving wide type check to type.h/cpp --- csrc/optimization/consecutive_cast_pass.cpp | 17 ----------------- csrc/type.cpp | 17 +++++++++++++++++ csrc/type.h | 4 ++++ 3 files changed, 21 insertions(+), 17 deletions(-) diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp index 0e5a85746c7..d679ba79ec6 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -21,23 +21,6 @@ bool isSameDtypeCategory(const DataType& input_t, const DataType& output_t) { return false; } -// check if type is a wider type than ref -// Which indicates a cast from ref -> type -> ref should be bit-wise identical -bool isWiderType(const DataType& ref, const DataType& type) { - if (ref == type) { - return true; - } else if (type == DataType::Double && (ref == DataType::Float || ref == DataType::Half || ref == DataType::BFloat16)) { - return true; - } else if (type == DataType::Float && (ref == DataType::Half || ref == DataType::BFloat16)) { - return true; - } else if (type == DataType::Int && ref == DataType::Int32) { - return true; - } else if (type == DataType::ComplexDouble && ref == DataType::ComplexFloat) { - return true; - } - return false; -} - // note: returns // - -1 : v0 contains strictly more information than v1; // - 0 : a complex case, where each v0 and v1 isn't a super set of the other; diff --git a/csrc/type.cpp b/csrc/type.cpp index 49b5bdee5e3..d4df95d423b 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -32,6 +32,23 @@ KernelIndexMode indexTypeToMode(DataType index_type) { : KernelIndexMode::INT64; } +bool isWiderType(const DataType& ref, const DataType& type) { + if (ref == type) { + return true; + } else if (ref == DataType::Bool) { + return true; + } else if ((type == DataType::Double || type == DataType::ComplexDouble) && (ref == DataType::Float || ref == DataType::Half || ref == DataType::BFloat16)) { + return true; + } else if ((type == DataType::Float || type == DataType::ComplexFloat) && (ref == DataType::Half || ref == DataType::BFloat16)) { + return true; + } else if ((type == DataType::Int || type == DataType::Double || type == DataType::ComplexDouble) && ref == DataType::Int32) { + return true; + } else if (type == DataType::ComplexDouble && ref == DataType::ComplexFloat) { + return true; + } + return false; +} + DataType getTypeFromComplexType(DataType dtype) { switch (std::get(dtype.type)) { case DataType::ComplexFloat: diff --git a/csrc/type.h b/csrc/type.h index 05a65b7bca7..996ca0002fc 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -141,6 +141,10 @@ enum class KernelIndexMode { INT32, INT64 }; PrimDataType indexModeToDtype(KernelIndexMode index_mode); KernelIndexMode indexTypeToMode(DataType index_type); +// check if type is a wider type than ref +// Which indicates a cast from ref -> type -> ref should be bit-wise identical +bool isWiderType(const DataType& ref, const DataType& type); + // Returns if the datatype is a floating point type TORCH_CUDA_CU_API inline bool isFloatingPointType(DataType dtype) { TORCH_CHECK( From 527ec3d370751ee8305edaf7641495993fd610ec 
Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 26 May 2023 20:11:54 -0700 Subject: [PATCH 51/81] lintrunner --- csrc/optimization/consecutive_cast_pass.cpp | 81 ++++++++++++--------- csrc/type.cpp | 14 +++- test/test_gpu3.cpp | 15 ++-- 3 files changed, 69 insertions(+), 41 deletions(-) diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp index d679ba79ec6..f99f218352a 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -35,11 +35,12 @@ int checkInformationLoss(Val* v0, Val* v1) { if ((dtype0 == DataType::BFloat16 && dtype1 == DataType::Half) || (dtype1 == DataType::BFloat16 && dtype0 == DataType::Half)) { return 0; - } + } if (isWiderType(dtype0, dtype1)) { return 2; } - TORCH_INTERNAL_ASSERT(isWiderType(dtype1, dtype0), "unrecognized cast category is encountered"); + TORCH_INTERNAL_ASSERT( + isWiderType(dtype1, dtype0), "unrecognized cast category is encountered"); return -1; } @@ -47,21 +48,26 @@ void castOptimizationPass(Fusion* fusion) { auto is_foldable_cast_op = [](Expr* expr) { if (expr != nullptr && expr->isA()) { auto op = expr->as(); - if (op->getUnaryOpType() == UnaryOpType::Cast && - isSameDtypeCategory(expr->input(0)->getDataType().value(), expr->output(0)->getDataType().value())) { + if (op->getUnaryOpType() == UnaryOpType::Cast && + isSameDtypeCategory( + expr->input(0)->getDataType().value(), + expr->output(0)->getDataType().value())) { return true; } } return false; }; - // TODO: Traveral implies topological order on returns exprs, we can leverage that to improve the effieciency of the pass. In the case of a straight line casts, we are doing a lot of meaningless work here on mutating intermediate casts that would have been done again at the end of the chain. + // TODO: Traveral implies topological order on returns exprs, we can leverage + // that to improve the effieciency of the pass. In the case of a straight line + // casts, we are doing a lot of meaningless work here on mutating intermediate + // casts that would have been done again at the end of the chain. for (auto expr : fusion->exprs()) { if (is_foldable_cast_op(expr)) { std::list chain_casts; auto prev_expr = expr->input(0)->definition(); while (prev_expr != nullptr && is_foldable_cast_op(prev_expr)) { - auto intermediate_cast = prev_expr->output(0); + auto intermediate_cast = prev_expr->output(0); // Note, if the output f prev_expr // is used by other operation(s); or // is a direct output from fusion @@ -72,53 +78,62 @@ void castOptimizationPass(Fusion* fusion) { } // in the loop, we just repetitively chaining consecutive casts. - chain_casts.push_front(intermediate_cast); + chain_casts.push_front(intermediate_cast); prev_expr = prev_expr->input(0)->definition(); } // Note, chain_casts has a straight-line use without branches if (!chain_casts.empty()) { auto lo_anchor = chain_casts.front()->definition()->input(0); - auto starting_anchor = lo_anchor; - for (auto val : chain_casts) { - auto info = checkInformationLoss(lo_anchor, val); - // if information on new val drops below the anchor, we want to update the anchor + auto starting_anchor = lo_anchor; + for (auto val : chain_casts) { + auto info = checkInformationLoss(lo_anchor, val); + // if information on new val drops below the anchor, we want to update + // the anchor if (info <= 0) { - // we run into a complex case where we are casting between two types that can't be folded away. i.e. bf16 & fp16. 
We need to update the starting_anchor for the final fold to be past this current cast. + // we run into a complex case where we are casting between two types + // that can't be folded away. i.e. bf16 & fp16. We need to update + // the starting_anchor for the final fold to be past this current + // cast. if (info == 0) { - auto tmp_expr = val->definition(); - if (lo_anchor != tmp_expr->input(0)) { - tmp_expr = nvfuser::ir_utils::replaceValInExpr(tmp_expr, tmp_expr->input(0), lo_anchor); - } - // move starting_anchor past the ambiguous case - starting_anchor = val; - } - // updating lo_anchor - lo_anchor = val; - } - } + auto tmp_expr = val->definition(); + if (lo_anchor != tmp_expr->input(0)) { + tmp_expr = nvfuser::ir_utils::replaceValInExpr( + tmp_expr, tmp_expr->input(0), lo_anchor); + } + // move starting_anchor past the ambiguous case + starting_anchor = val; + } + // updating lo_anchor + lo_anchor = val; + } + } - auto info = checkInformationLoss(lo_anchor, expr->output(0)); - if (info == 1) { + auto info = checkInformationLoss(lo_anchor, expr->output(0)); + if (info == 1) { // replacing output with lo_anchor in the fusion ir_utils::replaceValue(fusion, {{expr->output(0), lo_anchor}}); if (expr->output(0)->isFusionOutput()) { fusion->replaceOutput(expr->output(0), lo_anchor); } - } else if (info == 2 || info == 0) { + } else if (info == 2 || info == 0) { // expr output has either: - // higher precision than lo_anchor; or - // incompatible precision - // in either case, we can't fold away lo_anchor, we'll just re-wire the input to expr to lo_anchor + // higher precision than lo_anchor; or + // incompatible precision + // in either case, we can't fold away lo_anchor, we'll just re-wire + // the input to expr to lo_anchor expr = nvfuser::ir_utils::replaceValInExpr( expr, expr->input(0), lo_anchor); - } else if (info == -1) { - // if expr has lower precision than lo_anchor, we'll just fold away to the starting_anchor instead + } else if (info == -1) { + // if expr has lower precision than lo_anchor, we'll just fold away to + // the starting_anchor instead expr = nvfuser::ir_utils::replaceValInExpr( expr, expr->input(0), starting_anchor); } else { - TORCH_INTERNAL_ASSERT(false, "checkInformationLoss returns a flag that's not recognized"); - } + TORCH_INTERNAL_ASSERT( + false, + "checkInformationLoss returns a flag that's not recognized"); + } } } } diff --git a/csrc/type.cpp b/csrc/type.cpp index d4df95d423b..2cc29e24567 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -37,11 +37,19 @@ bool isWiderType(const DataType& ref, const DataType& type) { return true; } else if (ref == DataType::Bool) { return true; - } else if ((type == DataType::Double || type == DataType::ComplexDouble) && (ref == DataType::Float || ref == DataType::Half || ref == DataType::BFloat16)) { + } else if ( + (type == DataType::Double || type == DataType::ComplexDouble) && + (ref == DataType::Float || ref == DataType::Half || + ref == DataType::BFloat16)) { return true; - } else if ((type == DataType::Float || type == DataType::ComplexFloat) && (ref == DataType::Half || ref == DataType::BFloat16)) { + } else if ( + (type == DataType::Float || type == DataType::ComplexFloat) && + (ref == DataType::Half || ref == DataType::BFloat16)) { return true; - } else if ((type == DataType::Int || type == DataType::Double || type == DataType::ComplexDouble) && ref == DataType::Int32) { + } else if ( + (type == DataType::Int || type == DataType::Double || + type == DataType::ComplexDouble) && + ref == DataType::Int32) { return 
true; } else if (type == DataType::ComplexDouble && ref == DataType::ComplexFloat) { return true; diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 02ee943d94b..e1e64a54c8d 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -8406,7 +8406,8 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { tv = castOp(DataType::Double, tv); // (input)double -> float -> half -> float -> double fusion->addOutput(tv); - optimization::OptimizationGroup::runPass(fusion.get()); + optimization::OptimizationGroup::runPass( + fusion.get()); // simplified as (input)double -> half -> double auto ref_tv = castOp(DataType::Half, tv0); ref_tv = castOp(DataType::Double, ref_tv); @@ -8427,7 +8428,8 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { fusion->addOutput(tv); printf("----start----\n"); fusion->printMath(); - optimization::OptimizationGroup::runPass(fusion.get()); + optimization::OptimizationGroup::runPass( + fusion.get()); printf("---- opt ----\n"); fusion->printMath(); // TODO: should I have copied the tensor to avoid an alised output?! @@ -8454,11 +8456,13 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { tv = castOp(DataType::Float, tv); tv = castOp(DataType::Double, tv); tv = castOp(DataType::Float, tv); - // (input)float -> double -> half -> float -> double -> float -> double -> float -> double -> float + // (input)float -> double -> half -> float -> double -> float -> double -> + // float -> double -> float fusion->addOutput(tv); printf("----start----\n"); fusion->printMath(); - optimization::OptimizationGroup::runPass(fusion.get()); + optimization::OptimizationGroup::runPass( + fusion.get()); printf("---- opt ----\n"); fusion->printMath(); // TODO: should I have copied the tensor to avoid an alised output?! @@ -8488,7 +8492,8 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { fusion->addOutput(tv); printf("----start----\n"); fusion->printMath(); - optimization::OptimizationGroup::runPass(fusion.get()); + optimization::OptimizationGroup::runPass( + fusion.get()); printf("---- opt ----\n"); fusion->printMath(); // TODO: should I have copied the tensor to avoid an alised output?! From c3bcd35c6f4454df1973c7ada4f766045c2873a0 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 26 May 2023 20:16:49 -0700 Subject: [PATCH 52/81] removing print --- test/test_gpu3.cpp | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index e1e64a54c8d..d1c6d1fee9f 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -8426,16 +8426,10 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { tv = castOp(DataType::Float, tv); // (input)float -> double -> float fusion->addOutput(tv); - printf("----start----\n"); - fusion->printMath(); optimization::OptimizationGroup::runPass( fusion.get()); - printf("---- opt ----\n"); - fusion->printMath(); // TODO: should I have copied the tensor to avoid an alised output?! // simplified as (input) - printf("---- ref ----\n"); - fusion->printMath(); ASSERT_TRUE(tv0->sameAs(fusion->outputs()[0])); } @@ -8459,19 +8453,12 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { // (input)float -> double -> half -> float -> double -> float -> double -> // float -> double -> float fusion->addOutput(tv); - printf("----start----\n"); - fusion->printMath(); optimization::OptimizationGroup::runPass( fusion.get()); - printf("---- opt ----\n"); - fusion->printMath(); // TODO: should I have copied the tensor to avoid an alised output?! 
// simplified as (input)float -> half -> float auto ref_tv = castOp(DataType::Half, tv0); ref_tv = castOp(DataType::Float, ref_tv); - fusion->addOutput(ref_tv); - printf("---- ref ----\n"); - fusion->printMath(); ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); } @@ -8490,20 +8477,13 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { tv = castOp(DataType::Float, tv); // (input)float -> double -> half -> float -> bfloat16 -> float fusion->addOutput(tv); - printf("----start----\n"); - fusion->printMath(); optimization::OptimizationGroup::runPass( fusion.get()); - printf("---- opt ----\n"); - fusion->printMath(); // TODO: should I have copied the tensor to avoid an alised output?! // simplified as (input)float -> half -> bfloat16 -> float auto ref_tv = castOp(DataType::Half, tv0); ref_tv = castOp(DataType::BFloat16, ref_tv); ref_tv = castOp(DataType::Float, ref_tv); - fusion->addOutput(ref_tv); - printf("---- ref ----\n"); - fusion->printMath(); ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); } } From fcba58992e9bac09d6083a16f88554246f86a831 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 26 May 2023 23:27:14 -0700 Subject: [PATCH 53/81] clang-format clang-tidy --- csrc/optimization/consecutive_cast_pass.cpp | 6 +++--- csrc/type.cpp | 18 ++++++++++-------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp index f99f218352a..327e65e78d5 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -98,7 +98,7 @@ void castOptimizationPass(Fusion* fusion) { if (info == 0) { auto tmp_expr = val->definition(); if (lo_anchor != tmp_expr->input(0)) { - tmp_expr = nvfuser::ir_utils::replaceValInExpr( + nvfuser::ir_utils::replaceValInExpr( tmp_expr, tmp_expr->input(0), lo_anchor); } // move starting_anchor past the ambiguous case @@ -122,12 +122,12 @@ void castOptimizationPass(Fusion* fusion) { // incompatible precision // in either case, we can't fold away lo_anchor, we'll just re-wire // the input to expr to lo_anchor - expr = nvfuser::ir_utils::replaceValInExpr( + nvfuser::ir_utils::replaceValInExpr( expr, expr->input(0), lo_anchor); } else if (info == -1) { // if expr has lower precision than lo_anchor, we'll just fold away to // the starting_anchor instead - expr = nvfuser::ir_utils::replaceValInExpr( + nvfuser::ir_utils::replaceValInExpr( expr, expr->input(0), starting_anchor); } else { TORCH_INTERNAL_ASSERT( diff --git a/csrc/type.cpp b/csrc/type.cpp index 2cc29e24567..1dea623802b 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -35,23 +35,25 @@ KernelIndexMode indexTypeToMode(DataType index_type) { bool isWiderType(const DataType& ref, const DataType& type) { if (ref == type) { return true; - } else if (ref == DataType::Bool) { + } + if (ref == DataType::Bool) { return true; - } else if ( - (type == DataType::Double || type == DataType::ComplexDouble) && + } + if ((type == DataType::Double || type == DataType::ComplexDouble) && (ref == DataType::Float || ref == DataType::Half || ref == DataType::BFloat16)) { return true; - } else if ( - (type == DataType::Float || type == DataType::ComplexFloat) && + } + if ((type == DataType::Float || type == DataType::ComplexFloat) && (ref == DataType::Half || ref == DataType::BFloat16)) { return true; - } else if ( - (type == DataType::Int || type == DataType::Double || + } + if ((type == DataType::Int || type == DataType::Double || type == DataType::ComplexDouble) && ref == 
DataType::Int32) { return true; - } else if (type == DataType::ComplexDouble && ref == DataType::ComplexFloat) { + } + if (type == DataType::ComplexDouble && ref == DataType::ComplexFloat) { return true; } return false; From 0f7ddcc0da5d828ba2df8d2262238e06d3e9dadd Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 26 May 2023 23:34:35 -0700 Subject: [PATCH 54/81] clangformat --- csrc/optimization/consecutive_cast_pass.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp index 327e65e78d5..18fa5afb69c 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -122,8 +122,7 @@ void castOptimizationPass(Fusion* fusion) { // incompatible precision // in either case, we can't fold away lo_anchor, we'll just re-wire // the input to expr to lo_anchor - nvfuser::ir_utils::replaceValInExpr( - expr, expr->input(0), lo_anchor); + nvfuser::ir_utils::replaceValInExpr(expr, expr->input(0), lo_anchor); } else if (info == -1) { // if expr has lower precision than lo_anchor, we'll just fold away to // the starting_anchor instead From 5966632730cac35f02ead6cb0fbf195448b7eebd Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 30 May 2023 11:14:02 -0700 Subject: [PATCH 55/81] review comment --- csrc/type.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/csrc/type.cpp b/csrc/type.cpp index 1dea623802b..73f5e6b0c75 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -40,12 +40,13 @@ bool isWiderType(const DataType& ref, const DataType& type) { return true; } if ((type == DataType::Double || type == DataType::ComplexDouble) && - (ref == DataType::Float || ref == DataType::Half || - ref == DataType::BFloat16)) { + (ref == DataType::Double || ref == DataType::Float || + ref == DataType::Half || ref == DataType::BFloat16)) { return true; } if ((type == DataType::Float || type == DataType::ComplexFloat) && - (ref == DataType::Half || ref == DataType::BFloat16)) { + (ref == DataType::Float || ref == DataType::Half || + ref == DataType::BFloat16)) { return true; } if ((type == DataType::Int || type == DataType::Double || From e95f8985b2205df72a67d2f424dab0a369a9635b Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 30 May 2023 15:44:27 -0700 Subject: [PATCH 56/81] addressing review comments --- csrc/optimization/consecutive_cast_pass.cpp | 169 ++++++++++---------- csrc/type.cpp | 26 +-- csrc/type.h | 2 +- test/test_gpu3.cpp | 5 - 4 files changed, 101 insertions(+), 101 deletions(-) diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp index 18fa5afb69c..3ebe5b13d93 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -12,11 +12,24 @@ namespace nvfuser::optimization { namespace { +// same dtype category for folding is only considered on for integral/floating point and complex dtypes bool isSameDtypeCategory(const DataType& input_t, const DataType& output_t) { - if ((isIntegralType(input_t) && isIntegralType(output_t)) || + return (isIntegralType(input_t) && isIntegralType(output_t)) || (isFloatingPointType(input_t) && isFloatingPointType(output_t)) || - (isComplexType(input_t) && isComplexType(output_t))) { - return true; + (isComplexType(input_t) && isComplexType(output_t)); +} + +// We consider cast operations are foldable when it's casting within selected dtype categories. 
+// NOTE: this might not be necessary, but it keeps the logic simpler when we try to optimize a chain of cast ops. +bool isFoldableCast(Expr* expr) { + if (expr != nullptr && expr->isA()) { + auto op = expr->as(); + if (op->getUnaryOpType() == UnaryOpType::Cast && + isSameDtypeCategory( + expr->input(0)->getDataType().value(), + expr->output(0)->getDataType().value())) { + return true; + } } return false; } @@ -45,95 +58,87 @@ int checkInformationLoss(Val* v0, Val* v1) { } void castOptimizationPass(Fusion* fusion) { - auto is_foldable_cast_op = [](Expr* expr) { - if (expr != nullptr && expr->isA()) { - auto op = expr->as(); - if (op->getUnaryOpType() == UnaryOpType::Cast && - isSameDtypeCategory( - expr->input(0)->getDataType().value(), - expr->output(0)->getDataType().value())) { - return true; - } - } - return false; - }; - - // TODO: Traveral implies topological order on returns exprs, we can leverage + // TODO: Traveral implies topological order on returned exprs, we can leverage // that to improve the effieciency of the pass. In the case of a straight line // casts, we are doing a lot of meaningless work here on mutating intermediate // casts that would have been done again at the end of the chain. for (auto expr : fusion->exprs()) { - if (is_foldable_cast_op(expr)) { - std::list chain_casts; - auto prev_expr = expr->input(0)->definition(); - while (prev_expr != nullptr && is_foldable_cast_op(prev_expr)) { - auto intermediate_cast = prev_expr->output(0); - // Note, if the output f prev_expr - // is used by other operation(s); or - // is a direct output from fusion - // we skip the casting chaining - if (intermediate_cast->isFusionOutput() || - intermediate_cast->uses().size() > 1) { - break; - } - - // in the loop, we just repetitively chaining consecutive casts. - chain_casts.push_front(intermediate_cast); - prev_expr = prev_expr->input(0)->definition(); + // skip current expr if it's not a foldable cast + if (!isFoldableCast(expr)) { + continue; + } + std::list chain_cast_tvs; + auto prev_expr = expr->input(0)->definition(); + while (prev_expr != nullptr && isFoldableCast(prev_expr)) { + auto intermediate_cast = prev_expr->output(0); + // Note, if the output of prev_expr + // is used by other operation(s); or + // is a direct output from fusion + // we skip the casting chaining + if (intermediate_cast->isFusionOutput() || + intermediate_cast->uses().size() > 1) { + break; } - // Note, chain_casts has a straight-line use without branches - if (!chain_casts.empty()) { - auto lo_anchor = chain_casts.front()->definition()->input(0); - auto starting_anchor = lo_anchor; - for (auto val : chain_casts) { - auto info = checkInformationLoss(lo_anchor, val); - // if information on new val drops below the anchor, we want to update - // the anchor - if (info <= 0) { - // we run into a complex case where we are casting between two types - // that can't be folded away. i.e. bf16 & fp16. We need to update - // the starting_anchor for the final fold to be past this current - // cast. - if (info == 0) { - auto tmp_expr = val->definition(); - if (lo_anchor != tmp_expr->input(0)) { - nvfuser::ir_utils::replaceValInExpr( - tmp_expr, tmp_expr->input(0), lo_anchor); - } - // move starting_anchor past the ambiguous case - starting_anchor = val; - } - // updating lo_anchor - lo_anchor = val; - } - } + // in the loop, we just repetitively chaining consecutive casts. 
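+        // push_front keeps chain_cast_tvs ordered from the cast closest to the
+        // original producer down to the cast feeding expr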
+ chain_cast_tvs.push_front(intermediate_cast); + prev_expr = prev_expr->input(0)->definition(); + } + + // skip current expr if there's no chain_cast_tvs + if (chain_cast_tvs.empty()) { + continue; + } - auto info = checkInformationLoss(lo_anchor, expr->output(0)); - if (info == 1) { - // replacing output with lo_anchor in the fusion - ir_utils::replaceValue(fusion, {{expr->output(0), lo_anchor}}); - if (expr->output(0)->isFusionOutput()) { - fusion->replaceOutput(expr->output(0), lo_anchor); + // Note, chain_cast_tvs has a straight-line use without branches + auto lo_anchor = chain_cast_tvs.front()->definition()->input(0); + auto starting_anchor = lo_anchor; + for (auto val : chain_cast_tvs) { + auto info = checkInformationLoss(lo_anchor, val); + // if information on new val drops below the anchor, we want to update + // the anchor + if (info <= 0) { + // we run into a complex case where we are casting between two types + // that can't be folded away. i.e. bf16 & fp16. We need to update + // the starting_anchor for the final fold to be past this current + // cast. + if (info == 0) { + auto tmp_expr = val->definition(); + if (lo_anchor != tmp_expr->input(0)) { + nvfuser::ir_utils::replaceValInExpr( + tmp_expr, tmp_expr->input(0), lo_anchor); } - } else if (info == 2 || info == 0) { - // expr output has either: - // higher precision than lo_anchor; or - // incompatible precision - // in either case, we can't fold away lo_anchor, we'll just re-wire - // the input to expr to lo_anchor - nvfuser::ir_utils::replaceValInExpr(expr, expr->input(0), lo_anchor); - } else if (info == -1) { - // if expr has lower precision than lo_anchor, we'll just fold away to - // the starting_anchor instead - nvfuser::ir_utils::replaceValInExpr( - expr, expr->input(0), starting_anchor); - } else { - TORCH_INTERNAL_ASSERT( - false, - "checkInformationLoss returns a flag that's not recognized"); + // move starting_anchor past the ambiguous case + starting_anchor = val; } + // updating lo_anchor + lo_anchor = val; + } + } + + auto info = checkInformationLoss(lo_anchor, expr->output(0)); + if (info == 1) { + // replacing output with lo_anchor in the fusion + ir_utils::replaceValue(fusion, {{expr->output(0), lo_anchor}}); + if (expr->output(0)->isFusionOutput()) { + fusion->replaceOutput(expr->output(0), lo_anchor); } + } else if (info == 2 || info == 0) { + // expr output has either: + // higher precision than lo_anchor; or + // incompatible precision + // in either case, we can't fold away lo_anchor, we'll just re-wire + // the input to expr to lo_anchor + nvfuser::ir_utils::replaceValInExpr(expr, expr->input(0), lo_anchor); + } else if (info == -1) { + // if expr has lower precision than lo_anchor, we'll just fold away to + // the starting_anchor instead + nvfuser::ir_utils::replaceValInExpr( + expr, expr->input(0), starting_anchor); + } else { + TORCH_INTERNAL_ASSERT( + false, + "checkInformationLoss returns a flag that's not recognized"); } } } diff --git a/csrc/type.cpp b/csrc/type.cpp index 691243e85c7..619f743cd96 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -32,29 +32,29 @@ KernelIndexMode indexTypeToMode(DataType index_type) { : KernelIndexMode::INT64; } -bool isWiderType(const DataType& ref, const DataType& type) { - if (ref == type) { +bool isWiderType(const DataType& base_type, const DataType& wider_type) { + if (base_type == wider_type) { return true; } - if (ref == DataType::Bool) { + if (base_type == DataType::Bool) { return true; } - if ((type == DataType::Double || type == 
DataType::ComplexDouble) && - (ref == DataType::Double || ref == DataType::Float || - ref == DataType::Half || ref == DataType::BFloat16)) { + if ((wider_type == DataType::Double || wider_type == DataType::ComplexDouble) && + (base_type == DataType::Double || base_type == DataType::Float || + base_type == DataType::Half || base_type == DataType::BFloat16)) { return true; } - if ((type == DataType::Float || type == DataType::ComplexFloat) && - (ref == DataType::Float || ref == DataType::Half || - ref == DataType::BFloat16)) { + if ((wider_type == DataType::Float || wider_type == DataType::ComplexFloat) && + (base_type == DataType::Float || base_type == DataType::Half || + base_type == DataType::BFloat16)) { return true; } - if ((type == DataType::Int || type == DataType::Double || - type == DataType::ComplexDouble) && - ref == DataType::Int32) { + if ((wider_type == DataType::Int || wider_type == DataType::Double || + wider_type == DataType::ComplexDouble) && + base_type == DataType::Int32) { return true; } - if (type == DataType::ComplexDouble && ref == DataType::ComplexFloat) { + if (wider_type == DataType::ComplexDouble && base_type == DataType::ComplexFloat) { return true; } return false; diff --git a/csrc/type.h b/csrc/type.h index 649cd89895e..9ae3b172ac8 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -143,7 +143,7 @@ KernelIndexMode indexTypeToMode(DataType index_type); // check if type is a wider type than ref // Which indicates a cast from ref -> type -> ref should be bit-wise identical -bool isWiderType(const DataType& ref, const DataType& type); +bool isWiderType(const DataType& base_type, const DataType& wider_type); // Returns if the datatype is a floating point type TORCH_CUDA_CU_API inline bool isFloatingPointType(DataType dtype) { diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 1ab320c71eb..61ac8df3fd0 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -6036,9 +6036,6 @@ TEST_F(NVFuserTest, FusionBroadcastPersistentReduction_CUDA) { // Repro for // https://github.com/csarofeen/pytorch/issues/2094 TEST_F(NVFuserTest, FusionRepro2094_CUDA) { - // disable cast optimization in pre segmenter, which causes numerical issue on - // tests - optimization::OptimizationGroupGuard guard(false); std::unique_ptr fusion_ptr = std::make_unique(); auto fusion = fusion_ptr.get(); FusionGuard fg(fusion); @@ -8580,8 +8577,6 @@ TEST_F(NVFuserTest, FusionLayerNormFusedOpsRedundantCast_CUDA) { auto tv11 = castOp(DataType::Float, tv9); auto tv12 = castOp(DataType::Float, tv10); auto tv13 = add(tv11, tv12); - // The this pair of cast just cancels each other out, we'll simply rewire it - // to be use tv13 in places of tv15 in the follow up auto tv14 = castOp(DataType::Half, tv13); auto tv15 = castOp(DataType::Float, tv14); auto tv16 = variance(tv15, {1}, false, false); From 0cf81a10cd358fbe4ea7ae0527a20c86d893296f Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 30 May 2023 15:45:16 -0700 Subject: [PATCH 57/81] clangformat --- csrc/optimization/consecutive_cast_pass.cpp | 11 ++++++----- csrc/type.cpp | 6 ++++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp index 3ebe5b13d93..406a124ee5d 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -12,15 +12,17 @@ namespace nvfuser::optimization { namespace { -// same dtype category for folding is only considered on for integral/floating point and complex dtypes +// same dtype 
category for folding is only considered on for integral/floating +// point and complex dtypes bool isSameDtypeCategory(const DataType& input_t, const DataType& output_t) { return (isIntegralType(input_t) && isIntegralType(output_t)) || (isFloatingPointType(input_t) && isFloatingPointType(output_t)) || (isComplexType(input_t) && isComplexType(output_t)); } -// We consider cast operations are foldable when it's casting within selected dtype categories. -// NOTE: this might not be necessary, but it keeps the logic simpler when we try to optimize a chain of cast ops. +// We consider cast operations are foldable when it's casting within selected +// dtype categories. NOTE: this might not be necessary, but it keeps the logic +// simpler when we try to optimize a chain of cast ops. bool isFoldableCast(Expr* expr) { if (expr != nullptr && expr->isA()) { auto op = expr->as(); @@ -137,8 +139,7 @@ void castOptimizationPass(Fusion* fusion) { expr, expr->input(0), starting_anchor); } else { TORCH_INTERNAL_ASSERT( - false, - "checkInformationLoss returns a flag that's not recognized"); + false, "checkInformationLoss returns a flag that's not recognized"); } } } diff --git a/csrc/type.cpp b/csrc/type.cpp index 619f743cd96..c4a2aa2cb4a 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -39,7 +39,8 @@ bool isWiderType(const DataType& base_type, const DataType& wider_type) { if (base_type == DataType::Bool) { return true; } - if ((wider_type == DataType::Double || wider_type == DataType::ComplexDouble) && + if ((wider_type == DataType::Double || + wider_type == DataType::ComplexDouble) && (base_type == DataType::Double || base_type == DataType::Float || base_type == DataType::Half || base_type == DataType::BFloat16)) { return true; @@ -54,7 +55,8 @@ bool isWiderType(const DataType& base_type, const DataType& wider_type) { base_type == DataType::Int32) { return true; } - if (wider_type == DataType::ComplexDouble && base_type == DataType::ComplexFloat) { + if (wider_type == DataType::ComplexDouble && + base_type == DataType::ComplexFloat) { return true; } return false; From 9b175edb26f26b1336b7bc659819c350dfec112e Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 30 May 2023 17:13:54 -0700 Subject: [PATCH 58/81] refactoring optimization passes --- csrc/optimization/consecutive_cast_pass.cpp | 100 +++++++++----------- 1 file changed, 47 insertions(+), 53 deletions(-) diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp index 406a124ee5d..d9d0dc05d9b 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -12,26 +12,10 @@ namespace nvfuser::optimization { namespace { -// same dtype category for folding is only considered on for integral/floating -// point and complex dtypes -bool isSameDtypeCategory(const DataType& input_t, const DataType& output_t) { - return (isIntegralType(input_t) && isIntegralType(output_t)) || - (isFloatingPointType(input_t) && isFloatingPointType(output_t)) || - (isComplexType(input_t) && isComplexType(output_t)); -} - -// We consider cast operations are foldable when it's casting within selected -// dtype categories. NOTE: this might not be necessary, but it keeps the logic -// simpler when we try to optimize a chain of cast ops. 
-bool isFoldableCast(Expr* expr) { +bool isCast(Expr* expr) { if (expr != nullptr && expr->isA()) { auto op = expr->as(); - if (op->getUnaryOpType() == UnaryOpType::Cast && - isSameDtypeCategory( - expr->input(0)->getDataType().value(), - expr->output(0)->getDataType().value())) { - return true; - } + return op->getUnaryOpType() == UnaryOpType::Cast; } return false; } @@ -59,6 +43,7 @@ int checkInformationLoss(Val* v0, Val* v1) { return -1; } +// castOptimizationPass void castOptimizationPass(Fusion* fusion) { // TODO: Traveral implies topological order on returned exprs, we can leverage // that to improve the effieciency of the pass. In the case of a straight line @@ -66,12 +51,12 @@ void castOptimizationPass(Fusion* fusion) { // casts that would have been done again at the end of the chain. for (auto expr : fusion->exprs()) { // skip current expr if it's not a foldable cast - if (!isFoldableCast(expr)) { + if (!isCast(expr)) { continue; } std::list chain_cast_tvs; auto prev_expr = expr->input(0)->definition(); - while (prev_expr != nullptr && isFoldableCast(prev_expr)) { + while (prev_expr != nullptr && isCast(prev_expr)) { auto intermediate_cast = prev_expr->output(0); // Note, if the output of prev_expr // is used by other operation(s); or @@ -94,52 +79,61 @@ void castOptimizationPass(Fusion* fusion) { // Note, chain_cast_tvs has a straight-line use without branches auto lo_anchor = chain_cast_tvs.front()->definition()->input(0); + auto anchor_dtype = lo_anchor->getDataType().value(); auto starting_anchor = lo_anchor; for (auto val : chain_cast_tvs) { - auto info = checkInformationLoss(lo_anchor, val); - // if information on new val drops below the anchor, we want to update - // the anchor - if (info <= 0) { - // we run into a complex case where we are casting between two types - // that can't be folded away. i.e. bf16 & fp16. We need to update - // the starting_anchor for the final fold to be past this current - // cast. - if (info == 0) { - auto tmp_expr = val->definition(); - if (lo_anchor != tmp_expr->input(0)) { - nvfuser::ir_utils::replaceValInExpr( - tmp_expr, tmp_expr->input(0), lo_anchor); - } - // move starting_anchor past the ambiguous case - starting_anchor = val; + auto val_dtype = val->getDataType().value(); + + // short-cut when we are not losing precision in the cast, either: + // 1. casting to the same type as the previously seen lowest precision; + // or + // 2. casting to a wider type. + if (val_dtype == anchor_dtype || isWiderType(anchor_dtype, val_dtype)) { + continue; + } + + // NOTE: To enter here, we have + // !isWiderType(anchor_dtype, val_dtype) && isWiderType(val_dtype, + // anchor_dtype) + // + // Which means the dtype between lo_anchor and val isn't compatible and + // can't be fold away without losing information. So we update the + // starting_anchor to current val, which ensures that we preserve the + // incompatible casts. e.g. for cases where no one type is strictly wider + // than the other: i.e. bf16 & fp16, int32 & float32 e.t.c. + if (!isWiderType(val_dtype, anchor_dtype)) { + auto tmp_expr = val->definition(); + // we replace the input to current expr with lo_anchor when it's not. + if (lo_anchor != tmp_expr->input(0)) { + nvfuser::ir_utils::replaceValInExpr( + tmp_expr, tmp_expr->input(0), lo_anchor); } - // updating lo_anchor - lo_anchor = val; + // We need to update the starting_anchor for the fold to be past this + // current cast. 
+ starting_anchor = val; } + // updating new lo_anchor to current val + lo_anchor = val; + anchor_dtype = lo_anchor->getDataType().value(); } - auto info = checkInformationLoss(lo_anchor, expr->output(0)); - if (info == 1) { - // replacing output with lo_anchor in the fusion + auto output_dtype = expr->output(0)->getDataType().value(); + if (anchor_dtype == output_dtype) { + // final cast is the same dtype as with previous lo_anchor, replacing + // output with lo_anchor in the fusion ir_utils::replaceValue(fusion, {{expr->output(0), lo_anchor}}); if (expr->output(0)->isFusionOutput()) { fusion->replaceOutput(expr->output(0), lo_anchor); } - } else if (info == 2 || info == 0) { - // expr output has either: - // higher precision than lo_anchor; or - // incompatible precision - // in either case, we can't fold away lo_anchor, we'll just re-wire - // the input to expr to lo_anchor - nvfuser::ir_utils::replaceValInExpr(expr, expr->input(0), lo_anchor); - } else if (info == -1) { - // if expr has lower precision than lo_anchor, we'll just fold away to - // the starting_anchor instead + } else if (isWiderType(output_dtype, lo_anchor)) { + // if lo_anchor is wider than output_dtype, casting to lo_anchor isn't + // doing anything, we'll just fold away to the starting_anchor instead nvfuser::ir_utils::replaceValInExpr( expr, expr->input(0), starting_anchor); } else { - TORCH_INTERNAL_ASSERT( - false, "checkInformationLoss returns a flag that's not recognized"); + // This is the case where we cannot fold away the cast of lo_anchor; we'll + // just re-wire input to expr with lo_anchor + nvfuser::ir_utils::replaceValInExpr(expr, expr->input(0), lo_anchor); } } } From 7c8ea520ffa5bd0cc804f5c35b0b26964bf9c9f2 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 30 May 2023 17:18:14 -0700 Subject: [PATCH 59/81] fixing typoe --- csrc/optimization/consecutive_cast_pass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp index d9d0dc05d9b..7e293aa0515 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -125,7 +125,7 @@ void castOptimizationPass(Fusion* fusion) { if (expr->output(0)->isFusionOutput()) { fusion->replaceOutput(expr->output(0), lo_anchor); } - } else if (isWiderType(output_dtype, lo_anchor)) { + } else if (isWiderType(output_dtype, anchor_dtype)) { // if lo_anchor is wider than output_dtype, casting to lo_anchor isn't // doing anything, we'll just fold away to the starting_anchor instead nvfuser::ir_utils::replaceValInExpr( From 784660049a37c34829d930885470ab0a55b016b8 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 30 May 2023 18:10:50 -0700 Subject: [PATCH 60/81] comment added --- csrc/optimization/consecutive_cast_pass.cpp | 63 ++++++++++++++++++++- csrc/optimization/consecutive_cast_pass.h | 3 +- 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp index 7e293aa0515..d3599f09673 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -43,7 +43,68 @@ int checkInformationLoss(Val* v0, Val* v1) { return -1; } -// castOptimizationPass +// castOptimizationPass folds away consecutive cast operations. e.g. a chain of +// cast like fp16 -> fp32 -> fp16 can be simplified away without impacting the +// output of the fusion. 
+// For a section of consecutive casts, chained as:
+// first_tv -> cast_tv0 -> cast_tv1 -> cast_tv2 -> output_tv
+// Suppose we have a TensorView `lo_anchor` in the chain, such that the dtype of
+// every other tv is either the same as or wider than
+// `lo_anchor->getDataType()`, we can then assume that all the other cast ops in
+// the chain are no-ops (except the last one, which defines the output dtype).
+// So the above chain can be re-wired with only two casts as:
+// first_tv -> lo_anchor -> output_tv
+//
+// A complexity that could happen is that we might not necessarily have a narrowest
+// dtype in the chain, e.g. think about pairs like fp16/bfloat16, or fp32/int32,
+// where one can't represent the other. In order to handle this scenario, we can
+// just keep track of a `starting_anchor` that indicates the starting point of
+// the section that has a valid `lo_anchor`. When we encounter a new cast op
+// that breaks the assumption, we'll optimize what we have seen in the existing
+// section and start a new section with the next cast op.
+//
+// The algorithm:
+// 1. iterating through all exprs in the fusion:
+// 1.1 we skip all exprs other than cast;
+// 1.2 for each end cast-op 'expr', we trace back its producers iteratively
+// and push the value(s) on top of `chain_cast_tvs`, until:
+//
+// a. the producer is not a cast op; or
+//
+// b. the producer is used by other ops, or is a fusion output.
+//
+// 1.3 at this point, each `chain_cast_tvs` has an ordered list of cast outputs
+// with a straight-line dependency:
+// 1.3.1 we point starting_anchor at the beginning op, indicating the
+// starting point of our folding optimization, meanwhile, we point
+// lo_anchor at the first op, indicating the narrowest dtype we have seen
+// in the segment;
+// 1.3.2 we enter the loop to iterate through items
+// inside `chain_cast_tvs`, for item `val`:
+//
+// a. if `val_dtype` is the same as, or wider than, `anchor_dtype`
+// of `lo_anchor`, the current cast is a no-op and can be ignored;
+//
+// b. if `anchor_dtype` is narrower than `val_dtype`, the previous cast
+// to `lo_anchor` is a no-op and can be folded away. We update
+// `lo_anchor` to point to `val`;
+//
+// c. otherwise, `val` and `lo_anchor` are incompatible casts and
+// both need to be preserved. We'll rewire it as:
+// `starting_anchor`->`lo_anchor`->`val`. Afterwards, we'll update
+// `starting_anchor` and `lo_anchor` to both point at `val`.
+//
+// 1.4 At this point we look at `anchor_dtype` of `lo_anchor` and
+// `output_dtype` of `expr->output(0)`:
+//
+// a. if `anchor_dtype` is the same as `output_dtype`, we skip the last
+// cast op and replace all its uses with `lo_anchor`;
+//
+// b. if `anchor_dtype` is wider than `output_dtype`, all previous casts
+// after `starting_anchor` are no-ops, so we re-wire `starting_anchor`
+// directly to `expr`;
+//
+// c. otherwise, we can't bypass the `lo_anchor` cast, so we rewire this
+// section as `starting_anchor`->`lo_anchor`->`expr->output(0)`
 void castOptimizationPass(Fusion* fusion) {
   // TODO: Traversal implies topological order on returned exprs, we can leverage
   // that to improve the efficiency of the pass. In the case of a straight line
   // casts, we are doing a lot of meaningless work here on mutating intermediate
   // casts that would have been done again at the end of the chain.
diff --git a/csrc/optimization/consecutive_cast_pass.h b/csrc/optimization/consecutive_cast_pass.h
index 9aaccc6d6cf..9addb694a22 100644
--- a/csrc/optimization/consecutive_cast_pass.h
+++ b/csrc/optimization/consecutive_cast_pass.h
@@ -9,7 +9,8 @@
 
 namespace nvfuser::optimization {
 
-//! ConsecutiveCastPass removes redundant consecutive cast operations
+//! ConsecutiveCastPass removes redundant consecutive cast operations that
+//! doesn't have any impact on output from fusion.
 class TORCH_CUDA_CU_API ConsecutiveCastPass : public OptimizationPass {
  public:
   void run(Fusion* fusion) override;

From 1216b38d6e8c463fb3825bdd23200d759043b2fb Mon Sep 17 00:00:00 2001
From: jjsjann123
Date: Tue, 30 May 2023 18:38:34 -0700
Subject: [PATCH 61/81] added more documentation

---
 csrc/optimization/consecutive_cast_pass.cpp | 69 ++++++++++------------
 1 file changed, 29 insertions(+), 40 deletions(-)

diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp
index d3599f09673..223532fa559 100644
--- a/csrc/optimization/consecutive_cast_pass.cpp
+++ b/csrc/optimization/consecutive_cast_pass.cpp
@@ -20,27 +20,17 @@ namespace {
 
-// note: returns
-// - -1 : v0 contains strictly more information than v1;
-// - 0 : a complex case, where each v0 and v1 isn't a super set of the other;
-// - 1 : v0 and v1 has the same dtype;
-// - 2 : v0 contains strictly less information than v1;
-int checkInformationLoss(Val* v0, Val* v1) {
-  auto dtype0 = v0->getDataType().value();
-  auto dtype1 = v1->getDataType().value();
-  if (dtype0 == dtype1) {
-    return 1;
+// replaces the input to the cast op that produces cast_output, returns the new
+// cast_output
+Val* replaceInputInCast(Val* cast_output, Val* new_input) {
+  auto tmp_expr = cast_output->definition();
+  // short-cut for cases when no substitution is needed;
+  if (new_input == tmp_expr->input(0)) {
+    return cast_output;
   }
-  if ((dtype0 == DataType::BFloat16 && dtype1 == DataType::Half) ||
-      (dtype1 == DataType::BFloat16 && dtype0 == DataType::Half)) {
-    return 0;
-  }
-  if (isWiderType(dtype0, dtype1)) {
-    return 2;
-  }
-  TORCH_INTERNAL_ASSERT(
-      isWiderType(dtype1, dtype0), "unrecognized cast category is encountered");
-  return -1;
+  auto new_expr = nvfuser::ir_utils::replaceValInExpr(
+      tmp_expr, tmp_expr->input(0), new_input);
+  return new_expr->output(0);
 }
 
 // castOptimizationPass folds away consecutive cast operations. e.g. a chain of
@@ -110,6 +100,8 @@ void castOptimizationPass(Fusion* fusion) {
   // TODO: Traversal implies topological order on returned exprs, we can leverage
   // that to improve the efficiency of the pass. In the case of a straight line
   // casts, we are doing a lot of meaningless work here on mutating intermediate
   // casts that would have been done again at the end of the chain.
+  // We should really use the reverse topological order and filter out exprs
+  // that have been rendered as dead code during the pass.
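+  // For example (taken from the tests added later in this series), a chain like
+  //   (input)double -> float -> half -> float -> double
+  // gets folded down to (input)double -> half -> double: the narrowing cast to
+  // half has to stay, while the surrounding float/double casts are no-ops.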
for (auto expr : fusion->exprs()) { // skip current expr if it's not a foldable cast if (!isCast(expr)) { @@ -119,7 +111,7 @@ void castOptimizationPass(Fusion* fusion) { auto prev_expr = expr->input(0)->definition(); while (prev_expr != nullptr && isCast(prev_expr)) { auto intermediate_cast = prev_expr->output(0); - // Note, if the output of prev_expr + // 1.2 Note, if the output of prev_expr // is used by other operation(s); or // is a direct output from fusion // we skip the casting chaining @@ -138,14 +130,14 @@ void castOptimizationPass(Fusion* fusion) { continue; } - // Note, chain_cast_tvs has a straight-line use without branches + // 1.3.1 Note, chain_cast_tvs has a straight-line use without branches auto lo_anchor = chain_cast_tvs.front()->definition()->input(0); auto anchor_dtype = lo_anchor->getDataType().value(); auto starting_anchor = lo_anchor; for (auto val : chain_cast_tvs) { auto val_dtype = val->getDataType().value(); - // short-cut when we are not losing precision in the cast, either: + // 1.3.2.a short-cut when we are not losing precision, either: // 1. casting to the same type as the previously seen lowest precision; // or // 2. casting to a wider type. @@ -153,7 +145,7 @@ void castOptimizationPass(Fusion* fusion) { continue; } - // NOTE: To enter here, we have + // 1.3.2.c NOTE: To enter here, we have // !isWiderType(anchor_dtype, val_dtype) && isWiderType(val_dtype, // anchor_dtype) // @@ -163,38 +155,35 @@ void castOptimizationPass(Fusion* fusion) { // incompatible casts. e.g. for cases where no one type is strictly wider // than the other: i.e. bf16 & fp16, int32 & float32 e.t.c. if (!isWiderType(val_dtype, anchor_dtype)) { - auto tmp_expr = val->definition(); - // we replace the input to current expr with lo_anchor when it's not. - if (lo_anchor != tmp_expr->input(0)) { - nvfuser::ir_utils::replaceValInExpr( - tmp_expr, tmp_expr->input(0), lo_anchor); - } + lo_anchor = replaceInputInCast(lo_anchor, starting_anchor); + val = replaceInputInCast(val, lo_anchor); // We need to update the starting_anchor for the fold to be past this // current cast. 
starting_anchor = val; } - // updating new lo_anchor to current val + // 1.3.2.b/c updating new lo_anchor to current val lo_anchor = val; anchor_dtype = lo_anchor->getDataType().value(); } auto output_dtype = expr->output(0)->getDataType().value(); if (anchor_dtype == output_dtype) { - // final cast is the same dtype as with previous lo_anchor, replacing - // output with lo_anchor in the fusion + // 1.4.a final cast is the same dtype as with previous lo_anchor, + // replacing output with lo_anchor in the fusion ir_utils::replaceValue(fusion, {{expr->output(0), lo_anchor}}); if (expr->output(0)->isFusionOutput()) { fusion->replaceOutput(expr->output(0), lo_anchor); } } else if (isWiderType(output_dtype, anchor_dtype)) { - // if lo_anchor is wider than output_dtype, casting to lo_anchor isn't - // doing anything, we'll just fold away to the starting_anchor instead - nvfuser::ir_utils::replaceValInExpr( - expr, expr->input(0), starting_anchor); + // 1.4.b: if lo_anchor is wider than output_dtype, casting to lo_anchor + // isn't doing anything, we'll just fold away to the starting_anchor + // instead + replaceInputInCast(expr->input(0), starting_anchor); } else { - // This is the case where we cannot fold away the cast of lo_anchor; we'll - // just re-wire input to expr with lo_anchor - nvfuser::ir_utils::replaceValInExpr(expr, expr->input(0), lo_anchor); + // 1.4.c: This is the case where we cannot fold away the cast of + // lo_anchor; we'll just re-wire input to expr with lo_anchor + lo_anchor = replaceInputInCast(lo_anchor, starting_anchor); + replaceInputInCast(expr->input(0), lo_anchor); } } } From 10520e94d02b8fda5111bdaae604ca8f231c95e3 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 30 May 2023 18:47:38 -0700 Subject: [PATCH 62/81] added test case with mixed dtype categories --- test/test_gpu3.cpp | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 61ac8df3fd0..a44cd60612d 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -8483,6 +8483,31 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { ref_tv = castOp(DataType::Float, ref_tv); ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); } + + { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + auto tv0 = TensorViewBuilder() + .ndims(input_shape.size()) + .dtype(DataType::Int32) + .build(); + fusion->addInput(tv0); + auto tv = castOp(DataType::Double, tv0); + tv = castOp(DataType::ComplexDouble, tv); + tv = castOp(DataType::Int, tv); + tv = castOp(DataType::BFloat16, tv); + tv = castOp(DataType::Float, tv); + tv = castOp(DataType::Double, tv); + // (input)float -> double -> half -> float -> bfloat16 -> float + fusion->addOutput(tv); + optimization::OptimizationGroup::runPass( + fusion.get()); + // TODO: should I have copied the tensor to avoid an alised output?! 
+ // simplified as (input)float -> half -> bfloat16 -> float + auto ref_tv = castOp(DataType::BFloat16, tv0); + ref_tv = castOp(DataType::Double, ref_tv); + ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); + } } TEST_F(NVFuserTest, FusionTestWarnRegisterSpill_CUDA) { From 90fabd466cd7a79601351675000179b4bbc63d85 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 31 May 2023 08:20:06 -0700 Subject: [PATCH 63/81] merging OptimizationGroup to OptimizationPass --- csrc/kernel_cache.cpp | 2 +- csrc/optimization/consecutive_cast_pass.cpp | 4 -- csrc/optimization/consecutive_cast_pass.h | 5 +-- csrc/optimization/optimization_pass.h | 44 ++++++--------------- csrc/optimization/pre_segmenter.cpp | 3 +- csrc/optimization/pre_segmenter.h | 2 +- test/test_gpu3.cpp | 10 ++--- 7 files changed, 23 insertions(+), 47 deletions(-) diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index 7780e8b7c6c..0d20eac996e 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -666,7 +666,7 @@ FusionKernelRuntime::FusionKernelRuntime( !fusion->hasDynamicTransform(), "Fusion must be concretized before constructing FusionKernelRuntime"); - optimization::OptimizationGroup::runPass( + optimization::OptimizationPass::runPass( fusion.get()); all_tvs_ = ir_utils::allTvs(fusion.get()); diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp index 223532fa559..59d4589b608 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -194,8 +194,4 @@ void ConsecutiveCastPass::run(Fusion* fusion) { castOptimizationPass(fusion); } -std::string ConsecutiveCastPass::name() { - return "ConsecutiveCastOptimization"; -} - } // namespace nvfuser::optimization diff --git a/csrc/optimization/consecutive_cast_pass.h b/csrc/optimization/consecutive_cast_pass.h index 9addb694a22..e92e7b7d35e 100644 --- a/csrc/optimization/consecutive_cast_pass.h +++ b/csrc/optimization/consecutive_cast_pass.h @@ -11,10 +11,9 @@ namespace nvfuser::optimization { //! ConsecutiveCastPass removes redundant consecutive cast operations that //! doesn't have any impact on output from fusion. -class TORCH_CUDA_CU_API ConsecutiveCastPass : public OptimizationPass { +class TORCH_CUDA_CU_API ConsecutiveCastPass : public OptimizationPass { public: - void run(Fusion* fusion) override; - std::string name() override; + static void run(Fusion* fusion); }; } // namespace nvfuser::optimization diff --git a/csrc/optimization/optimization_pass.h b/csrc/optimization/optimization_pass.h index a4a196523b4..ba16539ca18 100644 --- a/csrc/optimization/optimization_pass.h +++ b/csrc/optimization/optimization_pass.h @@ -15,55 +15,37 @@ namespace nvfuser::optimization { using FusionPass = std::function; -//! [experimental API] -//! Base class to unify optimization pass APIs. -//! OptimizationPass is functional and defines the granularity of mutation -//! passes that is used to compose OptimizationGroups -class TORCH_CUDA_CU_API OptimizationPass { - public: - virtual void run(Fusion*) = 0; - virtual std::string name() = 0; - virtual ~OptimizationPass() = default; -}; - //! [experimental API] //! Base class to unify optimization group APIs. -//! OptimizationGroup composes optimization passes that is used at certain stage -//! in the runtime system. OptimizationGroup can be turned on/off +//! OptimizationPass composes optimization passes that is used at certain stage +//! in the runtime system. OptimizationPass can be turned on/off //! 
programmatically with the `setEnabled/flipEnabled` API. There's helper //! template OptimizationGroupGuard to temporarily switch the enablement within //! the context. Note the we are using a curiously recurring template pattern //! here to ensure that static objects are unique for each DerivedClass. In -//! order to apply OptimizationGroup with the switch enabled, you need to run -//! the function with `OptimizationGroup::runPass(...)` +//! order to apply OptimizationPass with the switch enabled, you need to run +//! the function with `OptimizationPass::runPass(...)` template -class TORCH_CUDA_CU_API OptimizationGroup { +class TORCH_CUDA_CU_API OptimizationPass { public: - static bool flipEnabled(bool flip) { - static std::mutex mutex_; - static bool enable_flag_ = true; - - std::lock_guard guard(mutex_); - enable_flag_ = enable_flag_ ^ flip; - return enable_flag_ ^ flip; + static void setEnabled(bool enabled) { + flag_.store(enabled); } - static bool setEnabled(bool enabled) { - auto tmp = flipEnabled(false); - if (enabled != tmp) { - flipEnabled(true); - } - return tmp; + static bool getEnabled() { + return flag_.load(); } static void runPass(Fusion* fusion) { - if (!flipEnabled(false)) { + if (!flag_.load()) { return; } DerivedClass::runPass(fusion); } - virtual ~OptimizationGroup() = default; + virtual ~OptimizationPass() = default; + protected: + static inline std::atomic flag_{true}; }; //! [experimental API] diff --git a/csrc/optimization/pre_segmenter.cpp b/csrc/optimization/pre_segmenter.cpp index c3fe4f8b615..334f2a10f03 100644 --- a/csrc/optimization/pre_segmenter.cpp +++ b/csrc/optimization/pre_segmenter.cpp @@ -12,8 +12,7 @@ namespace nvfuser::optimization { void PreSegmenter::runPass(Fusion* fusion) { // removes consecutive cast operations - ConsecutiveCastPass consecutive_cast_pass; - consecutive_cast_pass.run(fusion); + OptimizationPass::run(fusion); } } // namespace nvfuser::optimization diff --git a/csrc/optimization/pre_segmenter.h b/csrc/optimization/pre_segmenter.h index 97dca9eff27..f1d55e27059 100644 --- a/csrc/optimization/pre_segmenter.h +++ b/csrc/optimization/pre_segmenter.h @@ -13,7 +13,7 @@ namespace nvfuser::optimization { //! PreSegmenter is an optimization group that runs right before fusion executor //! segments a fusion into multiple kernels. -class TORCH_CUDA_CU_API PreSegmenter : public OptimizationGroup { +class TORCH_CUDA_CU_API PreSegmenter : public OptimizationPass { public: static void runPass(Fusion* fusion); }; diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index a44cd60612d..94b8617ef49 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -8403,7 +8403,7 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { tv = castOp(DataType::Double, tv); // (input)double -> float -> half -> float -> double fusion->addOutput(tv); - optimization::OptimizationGroup::runPass( + optimization::OptimizationPass::runPass( fusion.get()); // simplified as (input)double -> half -> double auto ref_tv = castOp(DataType::Half, tv0); @@ -8423,7 +8423,7 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { tv = castOp(DataType::Float, tv); // (input)float -> double -> float fusion->addOutput(tv); - optimization::OptimizationGroup::runPass( + optimization::OptimizationPass::runPass( fusion.get()); // TODO: should I have copied the tensor to avoid an alised output?! 
// simplified as (input) @@ -8450,7 +8450,7 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { // (input)float -> double -> half -> float -> double -> float -> double -> // float -> double -> float fusion->addOutput(tv); - optimization::OptimizationGroup::runPass( + optimization::OptimizationPass::runPass( fusion.get()); // TODO: should I have copied the tensor to avoid an alised output?! // simplified as (input)float -> half -> float @@ -8474,7 +8474,7 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { tv = castOp(DataType::Float, tv); // (input)float -> double -> half -> float -> bfloat16 -> float fusion->addOutput(tv); - optimization::OptimizationGroup::runPass( + optimization::OptimizationPass::runPass( fusion.get()); // TODO: should I have copied the tensor to avoid an alised output?! // simplified as (input)float -> half -> bfloat16 -> float @@ -8500,7 +8500,7 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { tv = castOp(DataType::Double, tv); // (input)float -> double -> half -> float -> bfloat16 -> float fusion->addOutput(tv); - optimization::OptimizationGroup::runPass( + optimization::OptimizationPass::runPass( fusion.get()); // TODO: should I have copied the tensor to avoid an alised output?! // simplified as (input)float -> half -> bfloat16 -> float From 1501a767c28c67dfc67652b42125db88ce702477 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 31 May 2023 09:16:02 -0700 Subject: [PATCH 64/81] renaming methods --- csrc/optimization/consecutive_cast_pass.cpp | 2 +- csrc/optimization/consecutive_cast_pass.h | 2 +- csrc/optimization/pre_segmenter.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp index 59d4589b608..81ea399f108 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -190,7 +190,7 @@ void castOptimizationPass(Fusion* fusion) { } // namespace -void ConsecutiveCastPass::run(Fusion* fusion) { +void ConsecutiveCastPass::runPass(Fusion* fusion) { castOptimizationPass(fusion); } diff --git a/csrc/optimization/consecutive_cast_pass.h b/csrc/optimization/consecutive_cast_pass.h index e92e7b7d35e..f084d8242dc 100644 --- a/csrc/optimization/consecutive_cast_pass.h +++ b/csrc/optimization/consecutive_cast_pass.h @@ -13,7 +13,7 @@ namespace nvfuser::optimization { //! doesn't have any impact on output from fusion. 
class TORCH_CUDA_CU_API ConsecutiveCastPass : public OptimizationPass { public: - static void run(Fusion* fusion); + static void runPass(Fusion* fusion); }; } // namespace nvfuser::optimization diff --git a/csrc/optimization/pre_segmenter.cpp b/csrc/optimization/pre_segmenter.cpp index 334f2a10f03..97ae6a27675 100644 --- a/csrc/optimization/pre_segmenter.cpp +++ b/csrc/optimization/pre_segmenter.cpp @@ -12,7 +12,7 @@ namespace nvfuser::optimization { void PreSegmenter::runPass(Fusion* fusion) { // removes consecutive cast operations - OptimizationPass::run(fusion); + OptimizationPass::runPass(fusion); } } // namespace nvfuser::optimization From 96c239837bf5ef615b2d5ac353cda863740f1791 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 31 May 2023 11:04:59 -0700 Subject: [PATCH 65/81] fixing re-wiring bug --- csrc/optimization/consecutive_cast_pass.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast_pass.cpp index 81ea399f108..013a2c82fab 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast_pass.cpp @@ -25,7 +25,7 @@ bool isCast(Expr* expr) { Val* replaceInputInCast(Val* cast_output, Val* new_input) { auto tmp_expr = cast_output->definition(); // short-cut for cases when no substitution is needed; - if (new_input == tmp_expr->input(0)) { + if (cast_output == new_input || new_input == tmp_expr->input(0)) { return cast_output; } auto new_expr = nvfuser::ir_utils::replaceValInExpr( @@ -178,12 +178,12 @@ void castOptimizationPass(Fusion* fusion) { // 1.4.b: if lo_anchor is wider than output_dtype, casting to lo_anchor // isn't doing anything, we'll just fold away to the starting_anchor // instead - replaceInputInCast(expr->input(0), starting_anchor); + replaceInputInCast(expr->output(0), starting_anchor); } else { // 1.4.c: This is the case where we cannot fold away the cast of // lo_anchor; we'll just re-wire input to expr with lo_anchor lo_anchor = replaceInputInCast(lo_anchor, starting_anchor); - replaceInputInCast(expr->input(0), lo_anchor); + replaceInputInCast(expr->output(0), lo_anchor); } } } From 80c98bc8845eb67ff764532b62b25f30aed2f4dd Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 31 May 2023 11:07:20 -0700 Subject: [PATCH 66/81] clangformat --- csrc/optimization/consecutive_cast_pass.h | 3 ++- csrc/optimization/optimization_pass.h | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/csrc/optimization/consecutive_cast_pass.h b/csrc/optimization/consecutive_cast_pass.h index f084d8242dc..f8a745dc13b 100644 --- a/csrc/optimization/consecutive_cast_pass.h +++ b/csrc/optimization/consecutive_cast_pass.h @@ -11,7 +11,8 @@ namespace nvfuser::optimization { //! ConsecutiveCastPass removes redundant consecutive cast operations that //! doesn't have any impact on output from fusion. 
-class TORCH_CUDA_CU_API ConsecutiveCastPass : public OptimizationPass { +class TORCH_CUDA_CU_API ConsecutiveCastPass + : public OptimizationPass { public: static void runPass(Fusion* fusion); }; diff --git a/csrc/optimization/optimization_pass.h b/csrc/optimization/optimization_pass.h index ba16539ca18..09906bea982 100644 --- a/csrc/optimization/optimization_pass.h +++ b/csrc/optimization/optimization_pass.h @@ -44,6 +44,7 @@ class TORCH_CUDA_CU_API OptimizationPass { } virtual ~OptimizationPass() = default; + protected: static inline std::atomic flag_{true}; }; From 0fdfd612c7c99d7c95da91053b7d1f872252fae1 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 31 May 2023 11:36:54 -0700 Subject: [PATCH 67/81] file renaming and class renaming --- CMakeLists.txt | 2 +- ...ive_cast_pass.cpp => consecutive_cast.cpp} | 3 +- ...ecutive_cast_pass.h => consecutive_cast.h} | 4 +- csrc/optimization/optimization_pass.h | 40 +++++++++++-------- csrc/optimization/pre_segmenter.cpp | 3 +- csrc/optimization/pre_segmenter.h | 4 +- 6 files changed, 35 insertions(+), 21 deletions(-) rename csrc/optimization/{consecutive_cast_pass.cpp => consecutive_cast.cpp} (99%) rename csrc/optimization/{consecutive_cast_pass.h => consecutive_cast.h} (89%) diff --git a/CMakeLists.txt b/CMakeLists.txt index c1e45698c59..b7a258b0bbb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -181,7 +181,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/utils.cpp ${NVFUSER_SRCS_DIR}/mma_type.cpp ${NVFUSER_SRCS_DIR}/scheduler/mma_utils.cpp - ${NVFUSER_SRCS_DIR}/optimization/consecutive_cast_pass.cpp + ${NVFUSER_SRCS_DIR}/optimization/consecutive_cast.cpp ${NVFUSER_SRCS_DIR}/optimization/pre_segmenter.cpp ) diff --git a/csrc/optimization/consecutive_cast_pass.cpp b/csrc/optimization/consecutive_cast.cpp similarity index 99% rename from csrc/optimization/consecutive_cast_pass.cpp rename to csrc/optimization/consecutive_cast.cpp index 013a2c82fab..dad57983191 100644 --- a/csrc/optimization/consecutive_cast_pass.cpp +++ b/csrc/optimization/consecutive_cast.cpp @@ -5,8 +5,9 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on +#include + #include -#include namespace nvfuser::optimization { diff --git a/csrc/optimization/consecutive_cast_pass.h b/csrc/optimization/consecutive_cast.h similarity index 89% rename from csrc/optimization/consecutive_cast_pass.h rename to csrc/optimization/consecutive_cast.h index f8a745dc13b..d643c9bc7e2 100644 --- a/csrc/optimization/consecutive_cast_pass.h +++ b/csrc/optimization/consecutive_cast.h @@ -13,7 +13,9 @@ namespace nvfuser::optimization { //! doesn't have any impact on output from fusion. class TORCH_CUDA_CU_API ConsecutiveCastPass : public OptimizationPass { - public: + friend class OptimizationPass; + + protected: static void runPass(Fusion* fusion); }; diff --git a/csrc/optimization/optimization_pass.h b/csrc/optimization/optimization_pass.h index 09906bea982..74232dc546c 100644 --- a/csrc/optimization/optimization_pass.h +++ b/csrc/optimization/optimization_pass.h @@ -16,15 +16,23 @@ namespace nvfuser::optimization { using FusionPass = std::function; //! [experimental API] -//! Base class to unify optimization group APIs. -//! OptimizationPass composes optimization passes that is used at certain stage -//! in the runtime system. OptimizationPass can be turned on/off -//! programmatically with the `setEnabled/flipEnabled` API. There's helper -//! template OptimizationGroupGuard to temporarily switch the enablement within -//! the context. 
Note the we are using a curiously recurring template pattern -//! here to ensure that static objects are unique for each DerivedClass. In -//! order to apply OptimizationPass with the switch enabled, you need to run -//! the function with `OptimizationPass::runPass(...)` +//! Base class to unify optimization pass APIs. +//! OptimizationPass can be turned on/off programmatically with the `setEnabled` +//! API. There's helper template OptimizationPassGuard to temporarily switch the +//! enablement within the context. Note the we are using a curiously recurring +//! template pattern here to ensure that static objects are unique for each +//! DerivedClass. In order to apply OptimizationPass with the switch enabled, +//! you need to run the function with +//! `OptimizationPass::runPass(...)` +//! +//! Specific optimization pass needs to be created like: +//! +//! class TORCH_CUDA_CU_API Pass0 : public OptimizationPass { +//! friend class OptimizationPass; +//! +//! protected: +//! static void runPass(Fusion* fusion); +//! }; template class TORCH_CUDA_CU_API OptimizationPass { public: @@ -50,15 +58,15 @@ class TORCH_CUDA_CU_API OptimizationPass { }; //! [experimental API] -//! OptimizationGroupGuard is used to temporarily switch enable/disable on a +//! OptimizationPassGuard is used to temporarily switch enable/disable on a //! certain pass. Original status will be restored at destruction. -template -class TORCH_CUDA_CU_API OptimizationGroupGuard { +template +class TORCH_CUDA_CU_API OptimizationPassGuard { public: - OptimizationGroupGuard(bool enabled) - : prev_status_(OptGroup::setEnabled(enabled)) {} - ~OptimizationGroupGuard() { - OptGroup::setEnabled(prev_status_); + OptimizationPassGuard(bool enabled) + : prev_status_(OptPass::setEnabled(enabled)) {} + ~OptimizationPassGuard() { + OptPass::setEnabled(prev_status_); } protected: diff --git a/csrc/optimization/pre_segmenter.cpp b/csrc/optimization/pre_segmenter.cpp index 97ae6a27675..baaa69fd7e4 100644 --- a/csrc/optimization/pre_segmenter.cpp +++ b/csrc/optimization/pre_segmenter.cpp @@ -5,9 +5,10 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on -#include #include +#include + namespace nvfuser::optimization { void PreSegmenter::runPass(Fusion* fusion) { diff --git a/csrc/optimization/pre_segmenter.h b/csrc/optimization/pre_segmenter.h index f1d55e27059..c7f4f4257a6 100644 --- a/csrc/optimization/pre_segmenter.h +++ b/csrc/optimization/pre_segmenter.h @@ -14,7 +14,9 @@ namespace nvfuser::optimization { //! PreSegmenter is an optimization group that runs right before fusion executor //! segments a fusion into multiple kernels. 
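//! A typical call site (as wired up in kernel_cache.cpp earlier in this series)
//! looks roughly like: OptimizationPass<PreSegmenter>::runPass(fusion);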
class TORCH_CUDA_CU_API PreSegmenter : public OptimizationPass { - public: + friend class OptimizationPass; + + protected: static void runPass(Fusion* fusion); }; From 1c3c6ced0095b0b192f3f61e49051f5bfa0f8143 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 31 May 2023 11:45:28 -0700 Subject: [PATCH 68/81] moving tests to a separate file --- CMakeLists.txt | 1 + test/test_gpu3.cpp | 127 --------------------------- test/test_optimization_pass.cpp | 148 ++++++++++++++++++++++++++++++++ 3 files changed, 149 insertions(+), 127 deletions(-) create mode 100644 test/test_optimization_pass.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index b7a258b0bbb..2f97c6518c1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -374,6 +374,7 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/test/test_gpu_multidevice.cpp ${NVFUSER_ROOT}/test/test_multicluster_fusion.cpp ${NVFUSER_ROOT}/test/test_combined_inner_outer_reduction.cpp + ${NVFUSER_ROOT}/test/test_optimization_pass.cpp ) list(APPEND JIT_TEST_CU_SRCS ${NVFUSER_ROOT}/test/test_gpu_rng.cu) diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 94b8617ef49..6571da2c311 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -8383,133 +8383,6 @@ TEST_F(NVFuserTest, FusionTestSegmenterHint_CUDA) { executor_cache.fusion(), outputs, {at_x}, {ref_out}, __LINE__, __FILE__); } -// Test cast optimization -TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { - std::vector input_shape{3, 7, 8}; - auto options = at::TensorOptions().dtype(at::kDouble).device(at::kCUDA, 0); - at::Tensor at_x = at::randn(input_shape, options); - - { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - auto tv0 = TensorViewBuilder() - .ndims(input_shape.size()) - .dtype(DataType::Double) - .build(); - fusion->addInput(tv0); - auto tv = castOp(DataType::Float, tv0); - tv = castOp(DataType::Half, tv); - tv = castOp(DataType::Float, tv); - tv = castOp(DataType::Double, tv); - // (input)double -> float -> half -> float -> double - fusion->addOutput(tv); - optimization::OptimizationPass::runPass( - fusion.get()); - // simplified as (input)double -> half -> double - auto ref_tv = castOp(DataType::Half, tv0); - ref_tv = castOp(DataType::Double, ref_tv); - ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); - } - - { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - auto tv0 = TensorViewBuilder() - .ndims(input_shape.size()) - .dtype(DataType::Float) - .build(); - fusion->addInput(tv0); - auto tv = castOp(DataType::Double, tv0); - tv = castOp(DataType::Float, tv); - // (input)float -> double -> float - fusion->addOutput(tv); - optimization::OptimizationPass::runPass( - fusion.get()); - // TODO: should I have copied the tensor to avoid an alised output?! 
- // simplified as (input) - ASSERT_TRUE(tv0->sameAs(fusion->outputs()[0])); - } - - { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - auto tv0 = TensorViewBuilder() - .ndims(input_shape.size()) - .dtype(DataType::Float) - .build(); - fusion->addInput(tv0); - auto tv = castOp(DataType::Double, tv0); - tv = castOp(DataType::Half, tv); - tv = castOp(DataType::Float, tv); - tv = castOp(DataType::Double, tv); - tv = castOp(DataType::Float, tv); - tv = castOp(DataType::Double, tv); - tv = castOp(DataType::Float, tv); - tv = castOp(DataType::Double, tv); - tv = castOp(DataType::Float, tv); - // (input)float -> double -> half -> float -> double -> float -> double -> - // float -> double -> float - fusion->addOutput(tv); - optimization::OptimizationPass::runPass( - fusion.get()); - // TODO: should I have copied the tensor to avoid an alised output?! - // simplified as (input)float -> half -> float - auto ref_tv = castOp(DataType::Half, tv0); - ref_tv = castOp(DataType::Float, ref_tv); - ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); - } - - { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - auto tv0 = TensorViewBuilder() - .ndims(input_shape.size()) - .dtype(DataType::Float) - .build(); - fusion->addInput(tv0); - auto tv = castOp(DataType::Double, tv0); - tv = castOp(DataType::Half, tv); - tv = castOp(DataType::Float, tv); - tv = castOp(DataType::BFloat16, tv); - tv = castOp(DataType::Float, tv); - // (input)float -> double -> half -> float -> bfloat16 -> float - fusion->addOutput(tv); - optimization::OptimizationPass::runPass( - fusion.get()); - // TODO: should I have copied the tensor to avoid an alised output?! - // simplified as (input)float -> half -> bfloat16 -> float - auto ref_tv = castOp(DataType::Half, tv0); - ref_tv = castOp(DataType::BFloat16, ref_tv); - ref_tv = castOp(DataType::Float, ref_tv); - ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); - } - - { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - auto tv0 = TensorViewBuilder() - .ndims(input_shape.size()) - .dtype(DataType::Int32) - .build(); - fusion->addInput(tv0); - auto tv = castOp(DataType::Double, tv0); - tv = castOp(DataType::ComplexDouble, tv); - tv = castOp(DataType::Int, tv); - tv = castOp(DataType::BFloat16, tv); - tv = castOp(DataType::Float, tv); - tv = castOp(DataType::Double, tv); - // (input)float -> double -> half -> float -> bfloat16 -> float - fusion->addOutput(tv); - optimization::OptimizationPass::runPass( - fusion.get()); - // TODO: should I have copied the tensor to avoid an alised output?! - // simplified as (input)float -> half -> bfloat16 -> float - auto ref_tv = castOp(DataType::BFloat16, tv0); - ref_tv = castOp(DataType::Double, ref_tv); - ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); - } -} - TEST_F(NVFuserTest, FusionTestWarnRegisterSpill_CUDA) { const int hidden_size = 1024 * 10; std::unique_ptr fusion_ptr = std::make_unique(); diff --git a/test/test_optimization_pass.cpp b/test/test_optimization_pass.cpp new file mode 100644 index 00000000000..eb28b322cb0 --- /dev/null +++ b/test/test_optimization_pass.cpp @@ -0,0 +1,148 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include + +#include + +#include + +#include +#include +#include + +namespace nvfuser::optimization { + +// Test cast optimization +TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { + std::vector input_shape{3, 7, 8}; + auto options = at::TensorOptions().dtype(at::kDouble).device(at::kCUDA, 0); + at::Tensor at_x = at::randn(input_shape, options); + + { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + auto tv0 = TensorViewBuilder() + .ndims(input_shape.size()) + .dtype(DataType::Double) + .build(); + fusion->addInput(tv0); + auto tv = castOp(DataType::Float, tv0); + tv = castOp(DataType::Half, tv); + tv = castOp(DataType::Float, tv); + tv = castOp(DataType::Double, tv); + // (input)double -> float -> half -> float -> double + fusion->addOutput(tv); + optimization::OptimizationPass::runPass( + fusion.get()); + // simplified as (input)double -> half -> double + auto ref_tv = castOp(DataType::Half, tv0); + ref_tv = castOp(DataType::Double, ref_tv); + ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); + } + + { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + auto tv0 = TensorViewBuilder() + .ndims(input_shape.size()) + .dtype(DataType::Float) + .build(); + fusion->addInput(tv0); + auto tv = castOp(DataType::Double, tv0); + tv = castOp(DataType::Float, tv); + // (input)float -> double -> float + fusion->addOutput(tv); + optimization::OptimizationPass::runPass( + fusion.get()); + // TODO: should I have copied the tensor to avoid an alised output?! + // simplified as (input) + ASSERT_TRUE(tv0->sameAs(fusion->outputs()[0])); + } + + { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + auto tv0 = TensorViewBuilder() + .ndims(input_shape.size()) + .dtype(DataType::Float) + .build(); + fusion->addInput(tv0); + auto tv = castOp(DataType::Double, tv0); + tv = castOp(DataType::Half, tv); + tv = castOp(DataType::Float, tv); + tv = castOp(DataType::Double, tv); + tv = castOp(DataType::Float, tv); + tv = castOp(DataType::Double, tv); + tv = castOp(DataType::Float, tv); + tv = castOp(DataType::Double, tv); + tv = castOp(DataType::Float, tv); + // (input)float -> double -> half -> float -> double -> float -> double -> + // float -> double -> float + fusion->addOutput(tv); + optimization::OptimizationPass::runPass( + fusion.get()); + // TODO: should I have copied the tensor to avoid an alised output?! + // simplified as (input)float -> half -> float + auto ref_tv = castOp(DataType::Half, tv0); + ref_tv = castOp(DataType::Float, ref_tv); + ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); + } + + { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + auto tv0 = TensorViewBuilder() + .ndims(input_shape.size()) + .dtype(DataType::Float) + .build(); + fusion->addInput(tv0); + auto tv = castOp(DataType::Double, tv0); + tv = castOp(DataType::Half, tv); + tv = castOp(DataType::Float, tv); + tv = castOp(DataType::BFloat16, tv); + tv = castOp(DataType::Float, tv); + // (input)float -> double -> half -> float -> bfloat16 -> float + fusion->addOutput(tv); + optimization::OptimizationPass::runPass( + fusion.get()); + // TODO: should I have copied the tensor to avoid an alised output?! 
+ // simplified as (input)float -> half -> bfloat16 -> float + auto ref_tv = castOp(DataType::Half, tv0); + ref_tv = castOp(DataType::BFloat16, ref_tv); + ref_tv = castOp(DataType::Float, ref_tv); + ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); + } + + { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + auto tv0 = TensorViewBuilder() + .ndims(input_shape.size()) + .dtype(DataType::Int32) + .build(); + fusion->addInput(tv0); + auto tv = castOp(DataType::Double, tv0); + tv = castOp(DataType::ComplexDouble, tv); + tv = castOp(DataType::Int, tv); + tv = castOp(DataType::BFloat16, tv); + tv = castOp(DataType::Float, tv); + tv = castOp(DataType::Double, tv); + // (input)float -> double -> half -> float -> bfloat16 -> float + fusion->addOutput(tv); + optimization::OptimizationPass::runPass( + fusion.get()); + // TODO: should I have copied the tensor to avoid an alised output?! + // simplified as (input)float -> half -> bfloat16 -> float + auto ref_tv = castOp(DataType::BFloat16, tv0); + ref_tv = castOp(DataType::Double, ref_tv); + ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); + } +} + +} From 29f6403cec7d71dd0ad687648339bea43d59afe4 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 31 May 2023 12:31:42 -0700 Subject: [PATCH 69/81] fixing test; fixing context switch --- csrc/optimization/optimization_pass.h | 7 +++-- test/test_optimization_pass.cpp | 37 ++++++++++++++++++++++++++- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/csrc/optimization/optimization_pass.h b/csrc/optimization/optimization_pass.h index 74232dc546c..359977b9ca5 100644 --- a/csrc/optimization/optimization_pass.h +++ b/csrc/optimization/optimization_pass.h @@ -63,8 +63,11 @@ class TORCH_CUDA_CU_API OptimizationPass { template class TORCH_CUDA_CU_API OptimizationPassGuard { public: - OptimizationPassGuard(bool enabled) - : prev_status_(OptPass::setEnabled(enabled)) {} + OptimizationPassGuard(bool enabled) : prev_status_(OptPass::getEnabled()) { + if (prev_status_ != enabled) { + OptPass::setEnabled(enabled); + } + } ~OptimizationPassGuard() { OptPass::setEnabled(prev_status_); } diff --git a/test/test_optimization_pass.cpp b/test/test_optimization_pass.cpp index eb28b322cb0..5f1968d9b63 100644 --- a/test/test_optimization_pass.cpp +++ b/test/test_optimization_pass.cpp @@ -9,6 +9,11 @@ #include #include +#include +#include +#include +#include +#include #include @@ -18,6 +23,36 @@ namespace nvfuser::optimization { +TEST_F(NVFuserTest, FusionTestOptimizationPassFlag_CUDA) { + class DerivedPass : public OptimizationPass { + friend class OptimizationPass; + + protected: + static void runPass(Fusion* fusion) { + throw std::runtime_error("running DerivedPass"); + }; + }; + + auto fusion = std::make_unique(); + + { + // disabling the flag explicitly + OptimizationPassGuard guard(false); + OptimizationPass::runPass(fusion.get()); + } + + // the flag should be default on + bool except_thrown = false; + try { + OptimizationPass::runPass(fusion.get()); + } catch (std::runtime_error& err) { + if (std::strcmp(err.what(), "running DerivedPass") == 0) { + except_thrown = true; + } + } + TORCH_CHECK(except_thrown, "optimization pass is skipped unexpectedly"); +} + // Test cast optimization TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { std::vector input_shape{3, 7, 8}; @@ -145,4 +180,4 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { } } -} +} // namespace nvfuser::optimization From e87ce9af3eb9e4f65332a713db7c05390008267e Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 31 
May 2023 16:07:45 -0700 Subject: [PATCH 70/81] code cleaning; review comments addressed --- csrc/optimization/consecutive_cast.cpp | 40 ++++++++++++-------------- csrc/optimization/optimization_pass.h | 2 +- csrc/type.cpp | 2 +- csrc/type.h | 6 ++-- test/test_gpu3.cpp | 1 - test/test_optimization_pass.cpp | 23 ++++++--------- 6 files changed, 31 insertions(+), 43 deletions(-) diff --git a/csrc/optimization/consecutive_cast.cpp b/csrc/optimization/consecutive_cast.cpp index dad57983191..34b9e015aa0 100644 --- a/csrc/optimization/consecutive_cast.cpp +++ b/csrc/optimization/consecutive_cast.cpp @@ -14,8 +14,7 @@ namespace nvfuser::optimization { namespace { bool isCast(Expr* expr) { - if (expr != nullptr && expr->isA()) { - auto op = expr->as(); + if (auto op = dynamic_cast(expr)) { return op->getUnaryOpType() == UnaryOpType::Cast; } return false; @@ -57,20 +56,20 @@ Val* replaceInputInCast(Val* cast_output, Val* new_input) { // 1. iterating through all expr in the fusion: // 1.1 we skip all exprs other than cast; // 1.2 for each end cast-op 'expr', we trace back its producers iteratively -// and push the value(s) on top of `chain_cast_tvs`, until: +// and push the value(s) on top of `chain_cast_vals`, until: // // a. the producer is not a cast op; or // // b. the producer is used by other ops, or is a fusion output. // -// 1.3 at this point, each `chain_cast_tvs` has an ordered cast outputs with +// 1.3 at this point, each `chain_cast_vals` has an ordered cast outputs with // a straight line dependency: // 1.3.1 we point starting_anchor at the beginning op, indicating the // starting point of our folding optimization, meanwhile, we point // lo_anchor at the first op, indicating the narrowest dtype we have seen // in the segment; // 1.3.2 we enter the loop to iterate through items -// inside `chain_cast_tvs`, for item `val`: +// inside `chain_cast_vals`, for item `val`: // // a. if `val_dtype` is the same as, or wider than `anchor_dtype` // of `lo_anchor`, current cast is a no-op and can be ignored; @@ -108,9 +107,9 @@ void castOptimizationPass(Fusion* fusion) { if (!isCast(expr)) { continue; } - std::list chain_cast_tvs; + std::list chain_cast_vals; auto prev_expr = expr->input(0)->definition(); - while (prev_expr != nullptr && isCast(prev_expr)) { + while (isCast(prev_expr)) { auto intermediate_cast = prev_expr->output(0); // 1.2 Note, if the output of prev_expr // is used by other operation(s); or @@ -122,40 +121,37 @@ void castOptimizationPass(Fusion* fusion) { } // in the loop, we just repetitively chaining consecutive casts. - chain_cast_tvs.push_front(intermediate_cast); + chain_cast_vals.push_front(intermediate_cast); prev_expr = prev_expr->input(0)->definition(); } - // skip current expr if there's no chain_cast_tvs - if (chain_cast_tvs.empty()) { + // skip current expr if there's no chain_cast_vals + if (chain_cast_vals.empty()) { continue; } - // 1.3.1 Note, chain_cast_tvs has a straight-line use without branches - auto lo_anchor = chain_cast_tvs.front()->definition()->input(0); + // 1.3.1 Note, chain_cast_vals has a straight-line use without branches + auto lo_anchor = chain_cast_vals.front()->definition()->input(0); auto anchor_dtype = lo_anchor->getDataType().value(); auto starting_anchor = lo_anchor; - for (auto val : chain_cast_tvs) { + for (auto val : chain_cast_vals) { auto val_dtype = val->getDataType().value(); - // 1.3.2.a short-cut when we are not losing precision, either: - // 1. casting to the same type as the previously seen lowest precision; - // or - // 2. 
casting to a wider type. - if (val_dtype == anchor_dtype || isWiderType(anchor_dtype, val_dtype)) { + // 1.3.2.a short-cut when we are not losing precision + if (isInclusiveType(anchor_dtype, val_dtype)) { continue; } // 1.3.2.c NOTE: To enter here, we have - // !isWiderType(anchor_dtype, val_dtype) && isWiderType(val_dtype, - // anchor_dtype) + // !isInclusiveType(anchor_dtype, val_dtype) && + // !isInclusiveType(val_dtype, anchor_dtype) // // Which means the dtype between lo_anchor and val isn't compatible and // can't be fold away without losing information. So we update the // starting_anchor to current val, which ensures that we preserve the // incompatible casts. e.g. for cases where no one type is strictly wider // than the other: i.e. bf16 & fp16, int32 & float32 e.t.c. - if (!isWiderType(val_dtype, anchor_dtype)) { + if (!isInclusiveType(val_dtype, anchor_dtype)) { lo_anchor = replaceInputInCast(lo_anchor, starting_anchor); val = replaceInputInCast(val, lo_anchor); // We need to update the starting_anchor for the fold to be past this @@ -175,7 +171,7 @@ void castOptimizationPass(Fusion* fusion) { if (expr->output(0)->isFusionOutput()) { fusion->replaceOutput(expr->output(0), lo_anchor); } - } else if (isWiderType(output_dtype, anchor_dtype)) { + } else if (isInclusiveType(output_dtype, anchor_dtype)) { // 1.4.b: if lo_anchor is wider than output_dtype, casting to lo_anchor // isn't doing anything, we'll just fold away to the starting_anchor // instead diff --git a/csrc/optimization/optimization_pass.h b/csrc/optimization/optimization_pass.h index 359977b9ca5..2826ab7afc3 100644 --- a/csrc/optimization/optimization_pass.h +++ b/csrc/optimization/optimization_pass.h @@ -9,7 +9,7 @@ #include -#include +#include namespace nvfuser::optimization { diff --git a/csrc/type.cpp b/csrc/type.cpp index c4a2aa2cb4a..fe5ba535464 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -32,7 +32,7 @@ KernelIndexMode indexTypeToMode(DataType index_type) { : KernelIndexMode::INT64; } -bool isWiderType(const DataType& base_type, const DataType& wider_type) { +bool isInclusiveType(const DataType& base_type, const DataType& wider_type) { if (base_type == wider_type) { return true; } diff --git a/csrc/type.h b/csrc/type.h index 9ae3b172ac8..9d575c7834f 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -141,9 +141,9 @@ enum class KernelIndexMode { INT32, INT64 }; PrimDataType indexModeToDtype(KernelIndexMode index_mode); KernelIndexMode indexTypeToMode(DataType index_type); -// check if type is a wider type than ref -// Which indicates a cast from ref -> type -> ref should be bit-wise identical -bool isWiderType(const DataType& base_type, const DataType& wider_type); +// check if type preserves all information from base_type. 
Which indicates a +// cast from base_type -> type -> base_type should be bit-wise identical +bool isInclusiveType(const DataType& base_type, const DataType& type); // Returns if the datatype is a floating point type TORCH_CUDA_CU_API inline bool isFloatingPointType(DataType dtype) { diff --git a/test/test_gpu3.cpp b/test/test_gpu3.cpp index 6571da2c311..62274e49b6e 100644 --- a/test/test_gpu3.cpp +++ b/test/test_gpu3.cpp @@ -30,7 +30,6 @@ #include #include #include -#include #include #include #include diff --git a/test/test_optimization_pass.cpp b/test/test_optimization_pass.cpp index 5f1968d9b63..a74662961a4 100644 --- a/test/test_optimization_pass.cpp +++ b/test/test_optimization_pass.cpp @@ -42,15 +42,10 @@ TEST_F(NVFuserTest, FusionTestOptimizationPassFlag_CUDA) { } // the flag should be default on - bool except_thrown = false; - try { - OptimizationPass::runPass(fusion.get()); - } catch (std::runtime_error& err) { - if (std::strcmp(err.what(), "running DerivedPass") == 0) { - except_thrown = true; - } - } - TORCH_CHECK(except_thrown, "optimization pass is skipped unexpectedly"); + EXPECT_THAT( + [&]() { OptimizationPass::runPass(fusion.get()); }, + ::testing::ThrowsMessage( + ::testing::HasSubstr("running DerivedPass"))); } // Test cast optimization @@ -95,7 +90,7 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { fusion->addOutput(tv); optimization::OptimizationPass::runPass( fusion.get()); - // TODO: should I have copied the tensor to avoid an alised output?! + // TODO: should I have copied the tensor to avoid an aliased output?! // simplified as (input) ASSERT_TRUE(tv0->sameAs(fusion->outputs()[0])); } @@ -122,7 +117,6 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { fusion->addOutput(tv); optimization::OptimizationPass::runPass( fusion.get()); - // TODO: should I have copied the tensor to avoid an alised output?! // simplified as (input)float -> half -> float auto ref_tv = castOp(DataType::Half, tv0); ref_tv = castOp(DataType::Float, ref_tv); @@ -146,7 +140,6 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { fusion->addOutput(tv); optimization::OptimizationPass::runPass( fusion.get()); - // TODO: should I have copied the tensor to avoid an alised output?! // simplified as (input)float -> half -> bfloat16 -> float auto ref_tv = castOp(DataType::Half, tv0); ref_tv = castOp(DataType::BFloat16, ref_tv); @@ -168,12 +161,12 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { tv = castOp(DataType::BFloat16, tv); tv = castOp(DataType::Float, tv); tv = castOp(DataType::Double, tv); - // (input)float -> double -> half -> float -> bfloat16 -> float + // (input)int32 -> double -> complex double -> int64 -> bfloat16 -> float -> + // double fusion->addOutput(tv); optimization::OptimizationPass::runPass( fusion.get()); - // TODO: should I have copied the tensor to avoid an alised output?! 
- // simplified as (input)float -> half -> bfloat16 -> float + // simplified as (input)int32 -> bfloat16 -> double auto ref_tv = castOp(DataType::BFloat16, tv0); ref_tv = castOp(DataType::Double, ref_tv); ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); From 0429b122e19b94f4c7d81d8e45cdf2fc9974ce01 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 31 May 2023 17:02:41 -0700 Subject: [PATCH 71/81] added missing test case --- test/test_optimization_pass.cpp | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/test/test_optimization_pass.cpp b/test/test_optimization_pass.cpp index a74662961a4..9a624d012e0 100644 --- a/test/test_optimization_pass.cpp +++ b/test/test_optimization_pass.cpp @@ -171,6 +171,31 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { ref_tv = castOp(DataType::Double, ref_tv); ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); } + + { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + auto tv0 = TensorViewBuilder() + .ndims(input_shape.size()) + .dtype(DataType::Float) + .build(); + fusion->addInput(tv0); + auto tv = castOp(DataType::Double, tv0); + fusion->addOutput(tv); + tv = castOp(DataType::Half, tv); + tv = castOp(DataType::Double, tv); + tv = castOp(DataType::Float, tv); + // (input)float -> double(output0) -> half -> double -> float(output1) + fusion->addOutput(tv); + optimization::OptimizationPass::runPass( + fusion.get()); + // simplified as (input)int32 -> bfloat16 -> double + auto ref_tv = castOp(DataType::Double, tv0); + ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); + ref_tv = castOp(DataType::Half, ref_tv); + ref_tv = castOp(DataType::Float, ref_tv); + ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[1])); + } } } // namespace nvfuser::optimization From 3a8724e76fd81778d01e987e8090b67ea9a3bcd3 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 31 May 2023 17:12:57 -0700 Subject: [PATCH 72/81] reverting unwanted changes --- python_tests/test_python_frontend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python_tests/test_python_frontend.py b/python_tests/test_python_frontend.py index 4b89ecaeaba..96d62566d9a 100644 --- a/python_tests/test_python_frontend.py +++ b/python_tests/test_python_frontend.py @@ -176,8 +176,8 @@ def fusion_func(fd: FusionDefinition): def test_cast_double_to_half(self): inputs = [ - torch.randn(2, 4, device="cuda", dtype=torch.float64).half().double(), - torch.randn(2, 4, device="cuda", dtype=torch.float64).half().double(), + torch.randn(2, 4, device="cuda", dtype=torch.float64), + torch.randn(2, 4, device="cuda", dtype=torch.float64), ] def fusion_func(fd: FusionDefinition): From 6153e2f3f27e8eb0ec6445f2e663066173151f86 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 1 Jun 2023 09:32:36 -0700 Subject: [PATCH 73/81] review comments --- csrc/optimization/consecutive_cast.cpp | 1 + test/test_optimization_pass.cpp | 76 +++++++++++++++++++++++++- 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/csrc/optimization/consecutive_cast.cpp b/csrc/optimization/consecutive_cast.cpp index 34b9e015aa0..e057b1d3050 100644 --- a/csrc/optimization/consecutive_cast.cpp +++ b/csrc/optimization/consecutive_cast.cpp @@ -167,6 +167,7 @@ void castOptimizationPass(Fusion* fusion) { if (anchor_dtype == output_dtype) { // 1.4.a final cast is the same dtype as with previous lo_anchor, // replacing output with lo_anchor in the fusion + lo_anchor = replaceInputInCast(lo_anchor, starting_anchor); ir_utils::replaceValue(fusion, {{expr->output(0), lo_anchor}}); if 
(expr->output(0)->isFusionOutput()) { fusion->replaceOutput(expr->output(0), lo_anchor); diff --git a/test/test_optimization_pass.cpp b/test/test_optimization_pass.cpp index 9a624d012e0..329b28e4d79 100644 --- a/test/test_optimization_pass.cpp +++ b/test/test_optimization_pass.cpp @@ -55,6 +55,7 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { at::Tensor at_x = at::randn(input_shape, options); { + // 1.4.c testing case auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = TensorViewBuilder() @@ -77,6 +78,52 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { } { + // 1.4.b testing case + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + auto tv0 = TensorViewBuilder() + .ndims(input_shape.size()) + .dtype(DataType::Double) + .build(); + fusion->addInput(tv0); + auto tv = castOp(DataType::Float, tv0); + tv = castOp(DataType::Double, tv); + tv = castOp(DataType::Half, tv); + // (input)double -> float -> double -> half + fusion->addOutput(tv); + optimization::OptimizationPass::runPass( + fusion.get()); + // simplified as (input)double -> half + auto ref_tv = castOp(DataType::Half, tv0); + ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); + } + + { + // 1.4.a testing case + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + auto tv0 = TensorViewBuilder() + .ndims(input_shape.size()) + .dtype(DataType::Double) + .build(); + fusion->addInput(tv0); + auto tv = castOp(DataType::Float, tv0); + tv = castOp(DataType::Half, tv); + tv = castOp(DataType::Float, tv); + tv = castOp(DataType::Double, tv); + tv = castOp(DataType::Half, tv); + // (input)double -> float -> half -> float -> double -> half + fusion->addOutput(tv); + optimization::OptimizationPass::runPass( + fusion.get()); + // simplified as (input)double -> half + auto ref_tv = castOp(DataType::Half, tv0); + ref_tv = castOp(DataType::Double, ref_tv); + ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); + } + + { + // 1.4.a testing case auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = TensorViewBuilder() @@ -96,6 +143,7 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { } { + // 1.4.c testing case auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = TensorViewBuilder() @@ -124,6 +172,7 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { } { + // 1.4.c testing case auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = TensorViewBuilder() @@ -148,6 +197,7 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { } { + // 1.4.c testing case auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = TensorViewBuilder() @@ -173,6 +223,7 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { } { + // 1.4.c testing case auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = TensorViewBuilder() @@ -189,13 +240,36 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { fusion->addOutput(tv); optimization::OptimizationPass::runPass( fusion.get()); - // simplified as (input)int32 -> bfloat16 -> double + // simplified as (input)float -> double(output0) -> half -> float(output1) auto ref_tv = castOp(DataType::Double, tv0); ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); ref_tv = castOp(DataType::Half, ref_tv); ref_tv = castOp(DataType::Float, ref_tv); ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[1])); } + + { + // 1.4.c testing case + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + auto tv0 = TensorViewBuilder() + .ndims(input_shape.size()) + 
.dtype(DataType::Float) + .build(); + fusion->addInput(tv0); + auto tv = castOp(DataType::Half, tv0); + tv = castOp(DataType::BFloat16, tv); + tv = castOp(DataType::Half, tv); + // (input)float -> half -> bfloat16 -> half + fusion->addOutput(tv); + optimization::OptimizationPass::runPass( + fusion.get()); + // simplified as (input)float -> half -> bfloat -> half + auto ref_tv = castOp(DataType::Half, tv0); + ref_tv = castOp(DataType::BFloat16, ref_tv); + ref_tv = castOp(DataType::Half, ref_tv); + ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); + } } } // namespace nvfuser::optimization From a85e1a64792b55604b0ab5a1f588a892f9d5a4db Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 1 Jun 2023 09:44:15 -0700 Subject: [PATCH 74/81] err test is wrong --- test/test_optimization_pass.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_optimization_pass.cpp b/test/test_optimization_pass.cpp index 329b28e4d79..792120ef2fd 100644 --- a/test/test_optimization_pass.cpp +++ b/test/test_optimization_pass.cpp @@ -118,7 +118,6 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { fusion.get()); // simplified as (input)double -> half auto ref_tv = castOp(DataType::Half, tv0); - ref_tv = castOp(DataType::Double, ref_tv); ASSERT_TRUE(ref_tv->sameAs(fusion->outputs()[0])); } From 29a61377c73044829a14b16c23176d69a960d978 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 1 Jun 2023 09:50:35 -0700 Subject: [PATCH 75/81] quick refactor on redundant processing of short-wired casts --- csrc/optimization/consecutive_cast.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/csrc/optimization/consecutive_cast.cpp b/csrc/optimization/consecutive_cast.cpp index e057b1d3050..0a186c46c6a 100644 --- a/csrc/optimization/consecutive_cast.cpp +++ b/csrc/optimization/consecutive_cast.cpp @@ -96,15 +96,13 @@ Val* replaceInputInCast(Val* cast_output, Val* new_input) { // c. otherwise, we can't bypass `lo_anchor` cast, we rewire this // section as `starting_anchor`->`lo_anchor`->`expr->output(0)` void castOptimizationPass(Fusion* fusion) { - // TODO: Traveral implies topological order on returned exprs, we can leverage - // that to improve the effieciency of the pass. In the case of a straight line - // casts, we are doing a lot of meaningless work here on mutating intermediate - // casts that would have been done again at the end of the chain. - // We should really use the reverse topological order and filters out exprs - // that has been rendered as dead code during the pass. - for (auto expr : fusion->exprs()) { - // skip current expr if it's not a foldable cast - if (!isCast(expr)) { + auto exprs = fusion->exprs(); + std::unordered_set visited; + for (int i = exprs.size(); i >= 0; --i) { + auto expr = exprs[i]; + // skip current expr if it's not a foldable cast or it has already been + // addressed + if (!isCast(expr) || visited.count(expr) != 0) { continue; } std::list chain_cast_vals; @@ -122,6 +120,8 @@ void castOptimizationPass(Fusion* fusion) { // in the loop, we just repetitively chaining consecutive casts. chain_cast_vals.push_front(intermediate_cast); + // adding intermediate_cast to visited node so we'll short-cut it. 
+ visited.insert(intermediate_cast); prev_expr = prev_expr->input(0)->definition(); } From e665635d69545688bd17cca6f7e93daf6b5f9715 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 1 Jun 2023 09:52:15 -0700 Subject: [PATCH 76/81] type was wrong --- csrc/optimization/consecutive_cast.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/optimization/consecutive_cast.cpp b/csrc/optimization/consecutive_cast.cpp index 0a186c46c6a..699d3cc293a 100644 --- a/csrc/optimization/consecutive_cast.cpp +++ b/csrc/optimization/consecutive_cast.cpp @@ -97,7 +97,7 @@ Val* replaceInputInCast(Val* cast_output, Val* new_input) { // section as `starting_anchor`->`lo_anchor`->`expr->output(0)` void castOptimizationPass(Fusion* fusion) { auto exprs = fusion->exprs(); - std::unordered_set visited; + std::unordered_set visited; for (int i = exprs.size(); i >= 0; --i) { auto expr = exprs[i]; // skip current expr if it's not a foldable cast or it has already been From 55241f63ca3cfcb8448478d6c3fad292bd766c24 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 1 Jun 2023 10:13:52 -0700 Subject: [PATCH 77/81] fixing short-cut check to avoid segfault --- csrc/optimization/consecutive_cast.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/optimization/consecutive_cast.cpp b/csrc/optimization/consecutive_cast.cpp index 699d3cc293a..0e96b153d73 100644 --- a/csrc/optimization/consecutive_cast.cpp +++ b/csrc/optimization/consecutive_cast.cpp @@ -98,11 +98,11 @@ Val* replaceInputInCast(Val* cast_output, Val* new_input) { void castOptimizationPass(Fusion* fusion) { auto exprs = fusion->exprs(); std::unordered_set visited; - for (int i = exprs.size(); i >= 0; --i) { + for (int i = exprs.size() - 1; i >= 0; --i) { auto expr = exprs[i]; // skip current expr if it's not a foldable cast or it has already been // addressed - if (!isCast(expr) || visited.count(expr) != 0) { + if (visited.count(expr) != 0 || !isCast(expr) ) { continue; } std::list chain_cast_vals; @@ -118,10 +118,10 @@ void castOptimizationPass(Fusion* fusion) { break; } + // adding prev_expr to visited node so we'll short-cut it. + visited.insert(prev_expr); // in the loop, we just repetitively chaining consecutive casts. chain_cast_vals.push_front(intermediate_cast); - // adding intermediate_cast to visited node so we'll short-cut it. 
- visited.insert(intermediate_cast); prev_expr = prev_expr->input(0)->definition(); } From 57cf5ea31c79a074df396ad7bd270528852b6caf Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 1 Jun 2023 10:15:59 -0700 Subject: [PATCH 78/81] clangformat --- csrc/optimization/consecutive_cast.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/optimization/consecutive_cast.cpp b/csrc/optimization/consecutive_cast.cpp index 0e96b153d73..6038fabf89b 100644 --- a/csrc/optimization/consecutive_cast.cpp +++ b/csrc/optimization/consecutive_cast.cpp @@ -102,7 +102,7 @@ void castOptimizationPass(Fusion* fusion) { auto expr = exprs[i]; // skip current expr if it's not a foldable cast or it has already been // addressed - if (visited.count(expr) != 0 || !isCast(expr) ) { + if (visited.count(expr) != 0 || !isCast(expr)) { continue; } std::list chain_cast_vals; From a5bef364d611fb9f09f324f0c9a6c01068e3d3a0 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 1 Jun 2023 10:41:30 -0700 Subject: [PATCH 79/81] clangtidy --- csrc/optimization/consecutive_cast.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/optimization/consecutive_cast.cpp b/csrc/optimization/consecutive_cast.cpp index 6038fabf89b..81e0c8a7d90 100644 --- a/csrc/optimization/consecutive_cast.cpp +++ b/csrc/optimization/consecutive_cast.cpp @@ -98,8 +98,8 @@ Val* replaceInputInCast(Val* cast_output, Val* new_input) { void castOptimizationPass(Fusion* fusion) { auto exprs = fusion->exprs(); std::unordered_set visited; - for (int i = exprs.size() - 1; i >= 0; --i) { - auto expr = exprs[i]; + for (auto iter = exprs.rbegin(); iter != exprs.rend(); ++iter) { + auto expr = *iter; // skip current expr if it's not a foldable cast or it has already been // addressed if (visited.count(expr) != 0 || !isCast(expr)) { From 99c54b6be212c7203e8442d219ba7ffd0ed86bc4 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 1 Jun 2023 11:10:50 -0700 Subject: [PATCH 80/81] merging logic since we are only looking at straight-line logic now --- csrc/optimization/consecutive_cast.cpp | 26 +++++++------------------- test/test_optimization_pass.cpp | 9 --------- 2 files changed, 7 insertions(+), 28 deletions(-) diff --git a/csrc/optimization/consecutive_cast.cpp b/csrc/optimization/consecutive_cast.cpp index 81e0c8a7d90..7e6bb1b1d01 100644 --- a/csrc/optimization/consecutive_cast.cpp +++ b/csrc/optimization/consecutive_cast.cpp @@ -86,14 +86,11 @@ Val* replaceInputInCast(Val* cast_output, Val* new_input) { // 1.4 At this point we look at `anchor_dtype` of `lo_anchor` and // `output_dtype` of `expr->output(0)`: // -// a. if `anchor_dtype` is the same as `output_dtype`, we skip the last -// cast op and replace all its uses with `lo_anchor`; -// -// b. if `anchor_dtype` is wider than `output_dtype`, all previous cast -// after `starting_anchor` is no-op, we re-wire `starting_anchor` +// a. if `anchor_dtype` is no narrower than `output_dtype`, all previous +// cast after `starting_anchor` is no-op, we re-wire `starting_anchor` // directly to `expr`; // -// c. otherwise, we can't bypass `lo_anchor` cast, we rewire this +// b. 
otherwise, we can't bypass `lo_anchor` cast, we rewire this // section as `starting_anchor`->`lo_anchor`->`expr->output(0)` void castOptimizationPass(Fusion* fusion) { auto exprs = fusion->exprs(); @@ -164,21 +161,12 @@ void castOptimizationPass(Fusion* fusion) { } auto output_dtype = expr->output(0)->getDataType().value(); - if (anchor_dtype == output_dtype) { - // 1.4.a final cast is the same dtype as with previous lo_anchor, - // replacing output with lo_anchor in the fusion - lo_anchor = replaceInputInCast(lo_anchor, starting_anchor); - ir_utils::replaceValue(fusion, {{expr->output(0), lo_anchor}}); - if (expr->output(0)->isFusionOutput()) { - fusion->replaceOutput(expr->output(0), lo_anchor); - } - } else if (isInclusiveType(output_dtype, anchor_dtype)) { - // 1.4.b: if lo_anchor is wider than output_dtype, casting to lo_anchor - // isn't doing anything, we'll just fold away to the starting_anchor - // instead + + if (isInclusiveType(output_dtype, anchor_dtype)) { + // 1.4.a: if lo_anchor is no narrower than output_dtype, everything is an no-op replaceInputInCast(expr->output(0), starting_anchor); } else { - // 1.4.c: This is the case where we cannot fold away the cast of + // 1.4.b: This is the case where we cannot fold away the cast of // lo_anchor; we'll just re-wire input to expr with lo_anchor lo_anchor = replaceInputInCast(lo_anchor, starting_anchor); replaceInputInCast(expr->output(0), lo_anchor); diff --git a/test/test_optimization_pass.cpp b/test/test_optimization_pass.cpp index 792120ef2fd..2cc3411fd37 100644 --- a/test/test_optimization_pass.cpp +++ b/test/test_optimization_pass.cpp @@ -55,7 +55,6 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { at::Tensor at_x = at::randn(input_shape, options); { - // 1.4.c testing case auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = TensorViewBuilder() @@ -78,7 +77,6 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { } { - // 1.4.b testing case auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = TensorViewBuilder() @@ -99,7 +97,6 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { } { - // 1.4.a testing case auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = TensorViewBuilder() @@ -122,7 +119,6 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { } { - // 1.4.a testing case auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = TensorViewBuilder() @@ -142,7 +138,6 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { } { - // 1.4.c testing case auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = TensorViewBuilder() @@ -171,7 +166,6 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { } { - // 1.4.c testing case auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = TensorViewBuilder() @@ -196,7 +190,6 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { } { - // 1.4.c testing case auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = TensorViewBuilder() @@ -222,7 +215,6 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { } { - // 1.4.c testing case auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = TensorViewBuilder() @@ -248,7 +240,6 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { } { - // 1.4.c testing case auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); auto tv0 = TensorViewBuilder() From 44a6aa38affdd8805107687fd93c1e1d17bcf6fa Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 1 
Jun 2023 11:23:03 -0700 Subject: [PATCH 81/81] patching the case where the cast needs to be removed --- csrc/optimization/consecutive_cast.cpp | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/csrc/optimization/consecutive_cast.cpp b/csrc/optimization/consecutive_cast.cpp index 7e6bb1b1d01..ca8fcd1b5de 100644 --- a/csrc/optimization/consecutive_cast.cpp +++ b/csrc/optimization/consecutive_cast.cpp @@ -163,8 +163,19 @@ void castOptimizationPass(Fusion* fusion) { auto output_dtype = expr->output(0)->getDataType().value(); if (isInclusiveType(output_dtype, anchor_dtype)) { - // 1.4.a: if lo_anchor is no narrower than output_dtype, everything is an no-op - replaceInputInCast(expr->output(0), starting_anchor); + // 1.4.a: if lo_anchor is no narrower than output_dtype, everything is an + // no-op + + if (starting_anchor->getDataType().value() == output_dtype) { + // if output dtype is identical to starting_anchor dtype, we can't keep + // the last cast op and will need to re-write all uses here + ir_utils::replaceValue(fusion, {{expr->output(0), starting_anchor}}); + if (expr->output(0)->isFusionOutput()) { + fusion->replaceOutput(expr->output(0), starting_anchor); + } + } else { + replaceInputInCast(expr->output(0), starting_anchor); + } } else { // 1.4.b: This is the case where we cannot fold away the cast of // lo_anchor; we'll just re-wire input to expr with lo_anchor
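
The straight-line folding logic that patches 71-81 converge on can be summarized with two anchors: an "origin" (the starting_anchor, the last value that must survive) and a "low anchor" (the narrowest dtype the value has been squeezed through since then). Below is a minimal standalone C++ sketch of that idea operating on plain dtype labels; DType, losslessIn and foldCastChain are illustrative names rather than the nvFuser API, losslessIn plays the role of isInclusiveType, and the sketch ignores everything the real pass handles on Fusion IR (multiple uses, fusion outputs, and the visited-set reverse traversal).

// Standalone sketch of consecutive-cast folding on dtype labels only.
// Assumes widening casts are exact and that a narrowing followed by a
// wider-or-equal cast collapses into the narrowing alone.
#include <cstddef>
#include <iostream>
#include <vector>

enum class DType { Double, Float, Half, BFloat16 };

// True if `outer` can represent every value of `inner`, i.e. a cast
// inner -> outer -> inner round-trips exactly (the isInclusiveType idea).
bool losslessIn(DType outer, DType inner) {
  if (outer == inner) return true;
  if (outer == DType::Double) return true;  // double holds float/half/bfloat16
  if (outer == DType::Float)
    return inner == DType::Half || inner == DType::BFloat16;
  return false;  // half and bfloat16 are mutually incomparable
}

// Fold start -> chain[0] -> ... -> chain.back() into the minimal list of
// cast targets with the same final dtype and (under this model) same values.
std::vector<DType> foldCastChain(DType start, const std::vector<DType>& chain) {
  std::vector<DType> kept;
  if (chain.empty()) return kept;
  DType origin = start;  // dtype feeding the current foldable run
  DType anchor = start;  // narrowest dtype of the run so far ("low anchor")
  for (std::size_t i = 0; i + 1 < chain.size(); ++i) {
    DType d = chain[i];
    if (losslessIn(d, anchor)) {
      continue;  // widening w.r.t. the low anchor: a no-op cast, drop it
    }
    if (losslessIn(anchor, d)) {
      anchor = d;  // genuine narrowing: remember it, decide later if it stays
      continue;
    }
    // Incomparable pair (e.g. half vs bfloat16): the narrowing to `anchor`
    // is observable, so materialize it and start a fresh run at `d`.
    if (anchor != origin) kept.push_back(anchor);
    kept.push_back(d);
    origin = anchor = d;
  }
  DType out = chain.back();
  if (losslessIn(anchor, out)) {
    // Pending narrowing is no stronger than the final cast: fold it away;
    // if the output dtype equals the origin, no cast is needed at all.
    if (out != origin) kept.push_back(out);
  } else {
    if (anchor != origin) kept.push_back(anchor);
    if (out != anchor) kept.push_back(out);
  }
  return kept;
}

const char* name(DType d) {
  switch (d) {
    case DType::Double: return "double";
    case DType::Float: return "float";
    case DType::Half: return "half";
    default: return "bfloat16";
  }
}

int main() {
  // double -> float -> double -> half   folds to   double -> half
  std::vector<DType> a =
      foldCastChain(DType::Double, {DType::Float, DType::Double, DType::Half});
  std::cout << name(DType::Double);
  for (DType d : a) std::cout << " -> " << name(d);
  std::cout << "\n";

  // float -> double -> float -> half -> bfloat16 -> float
  //   folds to   float -> half -> bfloat16 -> float
  std::vector<DType> b = foldCastChain(
      DType::Float, {DType::Double, DType::Float, DType::Half,
                     DType::BFloat16, DType::Float});
  std::cout << name(DType::Float);
  for (DType d : b) std::cout << " -> " << name(d);
  std::cout << "\n";
  return 0;
}

Running the sketch prints "double -> half" and "float -> half -> bfloat16 -> float", which lines up with the expectations encoded in test_optimization_pass.cpp above.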