diff --git a/test/test_nvfuser_frontend.py b/test/test_nvfuser_frontend.py index b34ac484ea40..4872f86b54db 100644 --- a/test/test_nvfuser_frontend.py +++ b/test/test_nvfuser_frontend.py @@ -61,13 +61,107 @@ def exec_nvfuser(self, fusion_func, inputs, new_fusion_expected=True) : self.assertEqual(fc.num_fusions() - before_fusions, int(new_fusion_expected)) return out, fs - def test_basic(self) : + def test_add(self): inputs = [ - torch.ones(2, 4, 8, device='cuda'), - torch.ones(2, 4, 8, device='cuda'), + torch.ones(2, 4, 8, device="cuda"), + torch.ones(2, 4, 8, device="cuda"), ] - def fusion_func(fd: FusionDefinition) : + def fusion_func(fd: FusionDefinition): + t0 = fd.define_tensor(3) + t1 = fd.define_tensor(3) + + t2 = fd.ops.add(t0, t1) + + fd.add_output(t2) + + # Expected Output is a tensor of 2's + nvf_out1, _ = self.exec_nvfuser(fusion_func, inputs) + + # Create a new fusion with the same definition, it should hit the cache! + nvf_out2, fs2 = self.exec_nvfuser( + fusion_func, inputs, new_fusion_expected=False + ) + + # Create a fusion from a fusion id and make sure it executes! + fs3 = Fusion(fs2.id()) + nvf_out3 = fs3.execute(inputs)[0] + + eager_out = inputs[0] + inputs[1] + self.assertEqual(eager_out, nvf_out1) + self.assertEqual(eager_out, nvf_out2) + self.assertEqual(eager_out, nvf_out3) + + def test_super_basic(self): + inputs = [ + torch.ones(4, 8, device="cuda"), + ] + + def fusion_func(fd: FusionDefinition): + t0 = fd.define_tensor(2) + c0 = fd.define_constant(3.0) + + t1 = fd.ops.mul(t0, c0) + t2 = fd.ops.sum(t1, [-1], False, DataType.Float) + + fd.add_output(t2) + + # Expected Output is a tensor of 24's + nvf_out1, _ = self.exec_nvfuser(fusion_func, inputs) + + # Create a new fusion with the same definition, it should hit the cache! + nvf_out2, fs2 = self.exec_nvfuser( + fusion_func, inputs, new_fusion_expected=False + ) + + # Create a fusion from a fusion id and make sure it executes! + fs3 = Fusion(fs2.id()) + nvf_out3 = fs3.execute(inputs)[0] + + eager_out = torch.sum(inputs[0] * 3.0, dim=-1) + self.assertEqual(eager_out, nvf_out1) + self.assertEqual(eager_out, nvf_out2) + self.assertEqual(eager_out, nvf_out3) + + def test_super_basic_fp16(self): + inputs = [ + torch.ones(4, 8, device="cuda", dtype=torch.float16), + ] + + def fusion_func(fd: FusionDefinition): + t0 = fd.define_tensor(2, DataType.Half) + c0 = fd.define_constant(3.0) + + t1 = fd.ops.mul(t0, c0) + t2 = fd.ops.sum(t1, [-1], False, DataType.Float) + + t3 = fd.ops.cast(t2, DataType.Half) + fd.add_output(t3) + + # Expected Output is a tensor of 48's + nvf_out1, _ = self.exec_nvfuser(fusion_func, inputs) + + # Create a new fusion with the same definition, it should hit the cache! + nvf_out2, fs2 = self.exec_nvfuser( + fusion_func, inputs, new_fusion_expected=False + ) + + # Create a fusion from a fusion id and make sure it executes! 
+ fs3 = Fusion(fs2.id()) + nvf_out3 = fs3.execute(inputs)[0] + + eager_out = torch.sum(inputs[0] * 3.0, dim=-1) + self.assertEqual(eager_out, nvf_out1) + self.assertEqual(eager_out, nvf_out2) + self.assertEqual(eager_out, nvf_out3) + + def test_basic(self): + inputs = [ + torch.ones(2, 4, 8, device="cuda"), + torch.ones(2, 4, 8, device="cuda"), + ] + + def fusion_func(fd: FusionDefinition): t0 = fd.define_tensor(3) t1 = fd.define_tensor(3) c0 = fd.define_constant(3.0) diff --git a/third_party/nvfuser/CMakeLists.txt b/third_party/nvfuser/CMakeLists.txt index 020b3694721f..b418f9101693 100644 --- a/third_party/nvfuser/CMakeLists.txt +++ b/third_party/nvfuser/CMakeLists.txt @@ -332,6 +332,7 @@ if(BUILD_TEST) list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_indexing_ops.cpp) list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_indexing.cpp) list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_gather_ops.cpp) + list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/test/test_gpu_match_frontend.cpp) set(JIT_TEST_CU_SRCS) list(APPEND JIT_TEST_CU_SRCS ${NVFUSER_ROOT}/test/test_gpu_rng.cu) diff --git a/third_party/nvfuser/csrc/fusion.cpp b/third_party/nvfuser/csrc/fusion.cpp index 31bb763ac559..901e729f6cfc 100644 --- a/third_party/nvfuser/csrc/fusion.cpp +++ b/third_party/nvfuser/csrc/fusion.cpp @@ -13,6 +13,8 @@ #include #include +#include + namespace torch { namespace jit { namespace fuser { @@ -344,26 +346,26 @@ void Fusion::validateInputs() { } } -void Fusion::print() { +void Fusion::print(std::ostream& stream) { FUSER_PERF_SCOPE("Fusion::print"); FusionGuard fg(this); - std::cout << "\n%kernel {\n"; - IrMathPrinter op_exprs(std::cout); + stream << "\n%kernel {\n"; + IrMathPrinter op_exprs(stream); op_exprs.handle(this); - std::cout << "\nTransformPrinter : \n"; - IrTransformPrinter t_exprs(std::cout); + stream << "\nTransformPrinter : \n"; + IrTransformPrinter t_exprs(stream); t_exprs.handle(this); - std::cout << "}\n\n"; + stream << "}\n\n"; } -void Fusion::printKernel(DataType index_type) { +void Fusion::printKernel(DataType index_type, std::ostream& stream) { FUSER_PERF_SCOPE("Fusion::printKernel"); TORCH_INTERNAL_ASSERT( !this->isA(), "Cannot \"print kernel\" of a kernel container. 
", "This would require lowering during lowering."); - std::cout << codegen::generateCudaKernel(GpuLower(this, index_type).kernel()); + stream << codegen::generateCudaKernel(GpuLower(this, index_type).kernel()); } std::unordered_map> Fusion::bankConflictInfo( @@ -380,19 +382,19 @@ std::unordered_map> Fusion::bankConflictInfo( return result; } -void Fusion::printMath(bool from_outputs_only) { +void Fusion::printMath(bool from_outputs_only, std::ostream& stream) { FUSER_PERF_SCOPE("Fusion::printMath"); FusionGuard fg(this); auto exprs_for_print = exprs(); - std::cout << "Inputs:" << std::endl; + stream << "Inputs:" << std::endl; for (auto inp : inputs()) { - std::cout << " " << inp << ", " << inp->getDataType().value() << std::endl; + stream << " " << inp << ", " << inp->getDataType().value() << std::endl; } - std::cout << "Outputs:" << std::endl; + stream << "Outputs:" << std::endl; for (auto out : outputs()) { - std::cout << " " << out << ", " << out->getDataType().value() << std::endl; + stream << " " << out << ", " << out->getDataType().value() << std::endl; } // If we want everything in the fusion, grab all values without uses to @@ -407,11 +409,11 @@ void Fusion::printMath(bool from_outputs_only) { exprs_for_print = StmtSort::getExprs(this, leaf_vals); } - std::cout << "\n%kernel_math {\n"; + stream << "\n%kernel_math {\n"; for (auto expr : exprs_for_print) { - std::cout << expr; + stream << expr; } - std::cout << "}\n\n"; + stream << "}\n\n"; } std::vector Fusion::inputsAndCreated() { @@ -427,11 +429,11 @@ std::vector Fusion::inputsAndCreated() { return result; } -void Fusion::printTransforms() { +void Fusion::printTransforms(std::ostream& stream) { FUSER_PERF_SCOPE("Fusion::printTransforms"); FusionGuard fg(this); - IrTransformPrinter t_exprs(std::cout); + IrTransformPrinter t_exprs(stream); t_exprs.handle(this); } diff --git a/third_party/nvfuser/csrc/fusion.h b/third_party/nvfuser/csrc/fusion.h index d8cef33fda0d..3912048c1e11 100644 --- a/third_party/nvfuser/csrc/fusion.h +++ b/third_party/nvfuser/csrc/fusion.h @@ -124,17 +124,22 @@ class TORCH_CUDA_CU_API Fusion : public IrContainer { void validateInputs(); //! Print this fusion to the console - void print(); + void print(std::ostream& stream = std::cout); //! Print Arith exprs //! \param from_outputs_only Only print exprs reachable from outputs - void printMath(bool from_outputs_only = true); + //! \param stream Where to print output (defaults to std::cout) + void printMath( + bool from_outputs_only = true, + std::ostream& stream = std::cout); //! Print transformations used in fusion (can be very verbose) - void printTransforms(); + void printTransforms(std::ostream& stream = std::cout); //! Lower the fusion and print a kernel - void printKernel(DataType index_type = DataType::Int); + void printKernel( + DataType index_type = DataType::Int, + std::ostream& stream = std::cout); //! Returns if this fusion is noop, for example, trivially forwarding inputs, //! or all outputs are size-0 tensors, etc. diff --git a/third_party/nvfuser/test/test_gpu_match_frontend.cpp b/third_party/nvfuser/test/test_gpu_match_frontend.cpp new file mode 100644 index 000000000000..32a27e6a55cb --- /dev/null +++ b/third_party/nvfuser/test/test_gpu_match_frontend.cpp @@ -0,0 +1,1367 @@ +//! These tests replicate those that appear in test/test_nvfuser_frontend.py +//! In this file, we manually schedule each fusion, and compare that to the +//! automatic scheduling that occurs in the python test. 
+
+#if defined(USE_CUDA)
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+
+#include
+
+// Tests go in torch::jit
+namespace torch {
+namespace jit {
+
+using namespace torch::jit::fuser::cuda;
+
+//! Compare fusions by printing the IR math of each to string then doing a
+//! strcmp
+void compare_ir_math(Fusion& factual, Fusion& fexpected) {
+  std::ostringstream sactual, sexpected;
+
+  factual.printMath(true, sactual);
+  fexpected.printMath(true, sexpected);
+
+  if (sactual.str() != sexpected.str()) {
+    std::cerr << "========= EXPECTED ==========" << std::endl;
+    std::cerr << sexpected.str() << std::endl;
+    std::cerr << "========= ACTUAL ==========" << std::endl;
+    std::cerr << sactual.str() << std::endl;
+    TORCH_INTERNAL_ASSERT(false, "Fusion IR math does not match expected");
+  }
+}
+
+//! Compare fusions by printing the IR transforms of each to string then doing
+//! a strcmp
+void compare_transforms(Fusion& factual, Fusion& fexpected) {
+  std::ostringstream sactual, sexpected;
+
+  factual.printTransforms(sactual);
+  fexpected.printTransforms(sexpected);
+
+  if (sactual.str() != sexpected.str()) {
+    std::cerr << "========= EXPECTED ==========" << std::endl;
+    std::cerr << sexpected.str() << std::endl;
+    std::cerr << "========= ACTUAL ==========" << std::endl;
+    std::cerr << sactual.str() << std::endl;
+    TORCH_INTERNAL_ASSERT(false, "Generated transforms do not match expected");
+  }
+}
+
+//! Compare fusions by printing the generated CUDA kernel of each to string
+//! then doing a strcmp
+void compare_kernels(Fusion& factual, Fusion& fexpected) {
+  std::ostringstream sactual, sexpected;
+
+  factual.printKernel(DataType::Int, sactual);
+  fexpected.printKernel(DataType::Int, sexpected);
+
+  if (sactual.str() != sexpected.str()) {
+    std::cerr << "========= EXPECTED ==========" << std::endl;
+    std::cerr << sexpected.str() << std::endl;
+    std::cerr << "========= ACTUAL ==========" << std::endl;
+    std::cerr << sactual.str() << std::endl;
+    TORCH_INTERNAL_ASSERT(
+        false, "Generated CUDA kernel does not match expected");
+  }
+}
+
+void compare_ir(Fusion& factual, Fusion& fexpected) {
+  compare_ir_math(factual, fexpected);
+  compare_transforms(factual, fexpected);
+  compare_kernels(factual, fexpected);
+}
+
+//! A simple point-wise test computing (x + y) for 3D inputs
+//! ```python
+//! def fusion_func(fd: FusionDefinition):
+//!     t0 = fd.define_tensor(3)
+//!     t1 = fd.define_tensor(3)
+//!
+//!     t2 = fd.ops.add(t0, t1)
+//!
+//!     fd.add_output(t2)
+//!
``` +TEST_F(NVFuserTest, FusionFrontendAdd_CUDA) { + // Create inputs + + int x = 2, y = 4, z = 8; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({x, y, z}, options); + at::Tensor t1 = at::randn({x, y, z}, options); + + std::vector inputs = {t0, t1}; + + // Define fusion for automatic scheduling + Fusion fauto; + { + FusionGuard fg(&fauto); + + auto tv0 = makeSymbolicTensor(3); + auto tv1 = makeSymbolicTensor(3); + + fauto.addInput(tv0); + fauto.addInput(tv1); + + auto tv2 = add(tv0, tv1); + // auto tv4 = sum(tv2, {-1}, false, DataType::Float); + + fauto.addOutput(tv2); + + // Run automatic scheduler + auto pointwise_params = getPointwiseHeuristics(&fauto, inputs); + TORCH_CHECK(pointwise_params, "Pointwise schedule was not generated!"); + schedulePointwise(&fauto, *pointwise_params); + } + + // Repeat definition of fusion for manual scheduling + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(3); + auto tv1 = makeSymbolicTensor(3); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, tv1); + // auto tv4 = sum(tv2, {-1}, false, DataType::Float); + + fusion.addOutput(tv2); + + // Perform manual scheduling + + // Before schedulePointwise() is called, getPointwiseHeuristics() calls + // vectorize_helper::getExpandedVectorization() which in turn calls: + // vectorize_helper::getVectorizationSize + // vectorize_helper::ProjectedExtent::getNumerator + // vectorize_helper::ProjectedExtent::computeNumerDenomir + // IrContainer::oneVal + // oneVal() creates an actual Val here to hold the denominator and + // initializes it to 1. Since this is reflected in the fusion log, I'm + // inserting it here even though it has not effect on the generated kernel. 
+ fusion.oneVal(); + + // scheduler_utils::cacheInputs(fusion, true); + tv0->cacheAfter(); // tv3 + tv1->cacheAfter(); // tv4 + + // scheduler_utils::cacheAndForkOutputs(fusion, true); + auto tv5 = tv2->cacheBefore(); // tv5 + + tv2->merge(1, 2); + tv2->merge(0, 1); + tv2->reorder({{0, -1}}); + tv2->reorder({{-1, 0}}); + tv2->split(0, 128); + tv2->split(0, 1); + tv2->split(0, 1); + tv2->axis(0)->parallelize(ParallelType::BIDx); + tv2->axis(1)->parallelize(ParallelType::Unswitch); + tv2->axis(3)->parallelize(ParallelType::TIDx); + + // inlineMost(); + // tv3->computeAt(tv2, 2); + // tv4->computeAt(tv2, 2); + + TransformPropagatorWithCheck propagator(tv2); + MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator); + scheduler_utils::parallelizeAllLike(tv2); + + // Pointwise scheduler does not use inlineMost(), as reduction scheduler does + // Instead, it uses inlineAllAt followed by inlineMost(innermost_tensors) + inlineAllAt(tv2, 2, true); + inlineMost(std::vector({tv5, tv1, tv0})); + + // Note that inlineAllAt iterates through an unordered_set to do inlining, so + // it is not practical to match the fusion_debug log exactly when using + // pointwise scheduler + compare_ir_math(fusion, fauto); + compare_transforms(fusion, fauto); + // compare_fusion_debug(fusion, fauto); + compare_kernels(fusion, fauto); + + // compare_ir(fusion, fauto); + + // Perform eager computation and verify + auto t2 = t0.add(t1); + // auto t4 = t2.sum({-1}, false); + + int runtime_threadIdx_dim = 128; + LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs, lparams); + auto cg_outputs = fe.runFusion(inputs, lparams); + + testValidate( + &fusion, cg_outputs, inputs, {t2}, __LINE__, __FILE__, "", lparams); +} + +//! A very simple test computing sum(x * 3.0, dim=-1) for 2D inputs +//! ```python +//! def fusion_func(fd: FusionDefinition): +//! t0 = fd.define_tensor(2) +//! c0 = fd.define_constant(3.0) +//! +//! t1 = fd.ops.mul(t0, c0) +//! t2 = fd.ops.sum(t1, [-1], False, DataType.Float) +//! +//! fd.add_output(t2) +//! ``` +TEST_F(NVFuserTest, FusionFrontendSuperBasic_CUDA) { + // Create inputs + int y = 4, z = 8; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({y, z}, options); + + std::vector inputs = {t0}; + + Fusion fauto; + { // Do automatic scheduling on fauto + FusionGuard fg(&fauto); + + auto tv0 = makeSymbolicTensor(2); // {i0, i1} + auto c0 = IrBuilder::create(3.0); + + fauto.addInput(tv0); + + auto tv1 = mul(tv0, c0); // {i0, i1} + auto tv2 = sum(tv1, {-1}, false, DataType::Float); // {i0, r1} + + fauto.addOutput(tv2); + + // Run automatic scheduler + auto reduction_params = getReductionHeuristics(&fauto, inputs); + TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); + scheduleReduction(&fauto, *reduction_params); + } + + // Re-define the fusion exactly for manual scheduling + // This is necessary in order to catch all the constructors inside each + // Fusion independently. 
+ Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); // {i0, i1} + auto c0 = IrBuilder::create(3.0); + + fusion.addInput(tv0); + + auto tv1 = mul(tv0, c0); // {i0, i1} + auto tv2 = sum(tv1, {-1}, false, DataType::Float); // {i0, r1} + + fusion.addOutput(tv2); + + // Perform manual scheduling + + tv2->reorder({{1, 0}}); // Removing these two reorders does not effect the + // generated kernel + tv2->reorder({{1, 0}}); + tv2->split(1, NamedScalar::getParallelDim(ParallelType::TIDx)); + tv2->axis(2)->parallelize(ParallelType::TIDx); + tv2->split(1, 1); + tv2->axis(2)->parallelize(ParallelType::Unswitch); + tv2->axis(0)->parallelize(ParallelType::BIDx); + + // tv2->reorder({{-2, -1}}) has same effect but this shows the mapping + // explicitly + tv2->reorder({{0, 0}, {1, 1}, {2, 3}, {3, 2}}); + + auto tv3 = tv2->rFactor({1, 3}); + + // propagate the mapping to other tensors + TransformPropagatorWithCheck propagator(tv3); + MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator); + scheduler_utils::parallelizeAllLike( + tv3, + {}, + allParallelTypesExcept( + {ParallelType::Unroll, + ParallelType::Vectorize, + ParallelType::MisalignedVectorize})); + + inlineMost(); + + compare_ir(fusion, fauto); + + // Perform eager computation and verify + auto t1 = t0 * 3.0; + auto t2 = t1.sum({-1}, false); + + int runtime_threadIdx_dim = 128; + LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs, lparams); + auto cg_outputs = fe.runFusion(inputs, lparams); + + testValidate( + &fusion, cg_outputs, inputs, {t2}, __LINE__, __FILE__, "", lparams); +} + +//! The same test as FusionFrontendSuperBasic_CUDA, but with half-precision +//! inputs and outputs +//! ```python +//! def fusion_func(fd: FusionDefinition): +//! t0 = fd.define_tensor(2, DataType.Half) +//! c0 = fd.define_constant(3.0) +//! +//! t1 = fd.ops.mul(t0, c0) +//! t2 = fd.ops.sum(t1, [-1], False, DataType.Float) +//! +//! t3 = fd.ops.cast(t2, DataType.Half) +//! fd.add_output(t3) +//! ``` +TEST_F(NVFuserTest, FusionFrontendSuperBasicFP16_CUDA) { + // Create inputs + int y = 4, z = 8; + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({y, z}, options); + + std::vector inputs = {t0}; + + Fusion fauto; + { // Do automatic scheduling on fauto + FusionGuard fg(&fauto); + + auto tv0 = makeSymbolicTensor(2, DataType::Half); // {i0, i1} + auto c0 = IrBuilder::create(3.0); + + fauto.addInput(tv0); + + auto tv1 = mul(tv0, c0); // {i0, i1} + auto tv2 = sum(tv1, {-1}, false, DataType::Float); // {i0, r1} + auto tv3 = castOp(DataType::Half, tv2); + + fauto.addOutput(tv3); + + // Run automatic scheduler + auto reduction_params = getReductionHeuristics(&fauto, inputs); + TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); + scheduleReduction(&fauto, *reduction_params); + } + + // Re-define the fusion exactly for manual scheduling + // This is necessary in order to catch all the constructors inside each + // Fusion independently. 
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeSymbolicTensor(2, DataType::Half); // {i0, i1}
+  auto c0 = IrBuilder::create(3.0);
+
+  fusion.addInput(tv0);
+
+  auto tv1 = mul(tv0, c0); // {i0, i1}
+  auto tv2 = sum(tv1, {-1}, false, DataType::Float); // {i0, r1}
+  auto tv4 = castOp(DataType::Half, tv2);
+
+  fusion.addOutput(tv4);
+
+  // Perform manual scheduling
+  tv2->reorder({{1, 0}}); // Removing these two reorders does not affect the
+                          // generated kernel
+  tv2->reorder({{1, 0}});
+  tv2->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
+  tv2->axis(2)->parallelize(ParallelType::TIDx);
+  tv2->split(1, 1);
+  tv2->axis(2)->parallelize(ParallelType::Unswitch);
+  tv2->axis(0)->parallelize(ParallelType::BIDx);
+
+  // tv2->reorder({{-2, -1}}) has the same effect but this shows the mapping
+  // explicitly
+  tv2->reorder({{0, 0}, {1, 1}, {2, 3}, {3, 2}});
+
+  auto tv3 = tv2->rFactor({1, 3});
+
+  // propagate the mapping to other tensors
+  TransformPropagatorWithCheck propagator(tv3);
+  MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator);
+  scheduler_utils::parallelizeAllLike(
+      tv3,
+      {},
+      allParallelTypesExcept(
+          {ParallelType::Unroll,
+           ParallelType::Vectorize,
+           ParallelType::MisalignedVectorize}));
+
+  inlineMost();
+
+  compare_ir(fusion, fauto);
+
+  // Perform eager computation and verify
+  auto t1 = t0 * 3.0;
+  auto t2 = t1.sum({-1}, false);
+
+  int runtime_threadIdx_dim = 128;
+  LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1);
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, inputs, lparams);
+  auto cg_outputs = fe.runFusion(inputs, lparams);
+
+  testValidate(
+      &fusion, cg_outputs, inputs, {t2}, __LINE__, __FILE__, "", lparams);
+}
+
+//! A simple test computing sum((x + y) * 3.0, dim=-1) for 3D inputs
+//! ```python
+//! def fusion_func(fd: FusionDefinition) :
+//!     t0 = fd.define_tensor(3)
+//!     t1 = fd.define_tensor(3)
+//!     c0 = fd.define_constant(3.0)
+//!
+//!     t2 = fd.ops.add(t0, t1)
+//!     t3 = fd.ops.mul(t2, c0)
+//!     t4 = fd.ops.sum(t3, [-1], False, DataType.Float)
+//!
+//!     fd.add_output(t4)
+//! ```
+TEST_F(NVFuserTest, FusionFrontendBasic_CUDA) {
+  // Create inputs
+  int x = 2, y = 4, z = 8;
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+
+  at::Tensor t0 = at::randn({x, y, z}, options);
+  at::Tensor t1 = at::randn({x, y, z}, options);
+
+  std::vector inputs = {t0, t1};
+
+  Fusion fauto;
+  { // Do automatic scheduling on fauto
+    FusionGuard fg(&fauto);
+
+    auto tv0 = makeSymbolicTensor(3);
+    auto tv1 = makeSymbolicTensor(3);
+    auto c0 = IrBuilder::create(3.0);
+
+    fauto.addInput(tv0);
+    fauto.addInput(tv1);
+
+    auto tv2 = add(tv0, tv1);
+    auto tv3 = mul(tv2, c0);
+    auto tv4 = sum(tv3, {-1}, false, DataType::Float);
+
+    fauto.addOutput(tv4);
+
+    // Run automatic scheduler
+    auto reduction_params = getReductionHeuristics(&fauto, inputs);
+    TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
+    scheduleReduction(&fauto, *reduction_params);
+  }
+
+  // Re-define the fusion exactly for manual scheduling
+  // This is necessary in order to catch all the constructors inside each
+  // Fusion independently.
+ Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(3); + auto tv1 = makeSymbolicTensor(3); + auto c0 = IrBuilder::create(3.0); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, tv1); + auto tv3 = mul(tv2, c0); + auto tv4 = sum(tv3, {-1}, false, DataType::Float); + + fusion.addOutput(tv4); + + // Perform manual scheduling + + auto tv5 = tv0->cacheAfter(); // tv5 + auto tv6 = tv1->cacheAfter(); // tv6 + auto tv7 = tv4->cacheBefore(); // tv7 + + tv7->reorder({{2, 0}}); + tv7->merge(1, 2); + tv7->reorder({{1, 0}}); + tv7->split(1, NamedScalar::getParallelDim(ParallelType::TIDx)); + tv7->axis(2)->parallelize(ParallelType::TIDx); + tv7->split(1, 1); + tv7->axis(2)->parallelize(ParallelType::Unswitch); + tv7->split(0, 2); + tv7->axis(1)->parallelize(ParallelType::Unroll); + tv7->split(0, 1); + tv7->axis(1)->parallelize(ParallelType::Unswitch); + tv7->axis(0)->parallelize(ParallelType::BIDx); + + tv7->reorder({{0, 0}, {1, 2}, {2, 3}, {3, 1}, {4, 5}, {5, 4}}); + + auto tv8 = tv7->rFactor({1, 5}); + + // NOTE: see multiReductionInliner for more info on how propagation and + // inlining works in the reduction scheduler + + // propagate the mapping to other tensors + TransformPropagatorWithCheck propagator(tv8); + MaxRootDomainInfoSpanningTree(tv8).traverse(&propagator); + // Propagate parallelization except vectorization and unrolling + scheduler_utils::parallelizeAllLike( + tv8, + {}, + allParallelTypesExcept( + {ParallelType::Unroll, + ParallelType::Vectorize, + ParallelType::MisalignedVectorize})); + // Propagate vectorization/unrolling to those tensors that need it + scheduler_utils::parallelizeAllLike( + tv8, + {tv4, tv6, tv5}, + { + ParallelType::Unroll, + ParallelType::Vectorize, + ParallelType::MisalignedVectorize, + }); + // If reference shouldn't be unrolled, clear that parallel type. + tv8->axis(3)->parallelize(ParallelType::Serial); + tv7->axis(2)->parallelize(ParallelType::Serial); + + inlineMost(); + + compare_ir(fusion, fauto); + + // Perform eager computation and verify + auto t2 = t0.add(t1); + auto t3 = t2.mul(3.0); + auto t4 = t3.sum({-1}, false); + + int runtime_threadIdx_dim = 128; + LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs, lparams); + auto cg_outputs = fe.runFusion(inputs, lparams); + + testValidate( + &fusion, cg_outputs, inputs, {t4}, __LINE__, __FILE__, "", lparams); +} + +//! The same test as FusionFrontendBasic_CUDA, but with half-precision +//! inputs and outputs +//! ```python +//! def fusion_func(fd: FusionDefinition) : +//! t0 = fd.define_tensor(3, DataType.Half) +//! t1 = fd.define_tensor(3, DataType.Half) +//! c0 = fd.define_constant(3.0) +//! +//! t2 = fd.ops.add(t0, t1) +//! t3 = fd.ops.mul(t2, c0) +//! t4 = fd.ops.sum(t3, [-1], False, DataType.Float) +//! +//! t5 = fd.ops.cast(t4, DataType.Half) +//! fd.add_output(t5) +//! 
``` +TEST_F(NVFuserTest, FusionFrontendBasicFP16_CUDA) { + // Create inputs + int x = 2, y = 4, z = 8; + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({x, y, z}, options); + at::Tensor t1 = at::randn({x, y, z}, options); + + std::vector inputs = {t0, t1}; + + Fusion fauto; + { // Do automatic scheduling on fauto + FusionGuard fg(&fauto); + + auto tv0 = makeSymbolicTensor(3, DataType::Half); + auto tv1 = makeSymbolicTensor(3, DataType::Half); + auto c0 = IrBuilder::create(3.0); + + fauto.addInput(tv0); + fauto.addInput(tv1); + + auto tv2 = add(tv0, tv1); + auto tv3 = mul(tv2, c0); + auto tv4 = sum(tv3, {-1}, false, DataType::Float); + auto tv5 = castOp(DataType::Half, tv4); + + fauto.addOutput(tv5); + + // Run automatic scheduler + auto reduction_params = getReductionHeuristics(&fauto, inputs); + TORCH_CHECK(reduction_params, "Reduction schedule was not generated!"); + scheduleReduction(&fauto, *reduction_params); + } + + // Re-define the fusion exactly for manual scheduling + // This is necessary in order to catch all the constructors inside each + // Fusion independently. + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(3, DataType::Half); + auto tv1 = makeSymbolicTensor(3, DataType::Half); + auto c0 = IrBuilder::create(3.0); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = add(tv0, tv1); + auto tv3 = mul(tv2, c0); + auto tv4 = sum(tv3, {-1}, false, DataType::Float); + auto tv5 = castOp(DataType::Half, tv4); + + fusion.addOutput(tv5); + + // Perform manual scheduling + + auto tv6 = tv0->cacheAfter(); // tv6 + auto tv7 = tv1->cacheAfter(); // tv7 + tv5->cacheBefore(); // tv8 + + // NOTE: tv4 is now chosen as the representative tensor + tv4->reorder({{2, 0}}); + tv4->merge(1, 2); + tv4->reorder({{1, 0}}); + tv4->split(1, NamedScalar::getParallelDim(ParallelType::TIDx)); + tv4->axis(2)->parallelize(ParallelType::TIDx); + tv4->split(1, 1); + tv4->axis(2)->parallelize(ParallelType::Unswitch); + tv4->split(0, 2); + tv4->axis(1)->parallelize(ParallelType::Unroll); + tv4->split(0, 1); + tv4->axis(1)->parallelize(ParallelType::Unswitch); + tv4->axis(0)->parallelize(ParallelType::BIDx); + + tv4->reorder({{0, 0}, {1, 2}, {2, 3}, {3, 1}, {4, 5}, {5, 4}}); + + auto tv9 = tv4->rFactor({1, 5}); + + // NOTE: see multiReductionInliner for more info on how propagation and + // inlining works in the reduction scheduler + + // propagate the mapping to other tensors + TransformPropagatorWithCheck propagator(tv9); + MaxRootDomainInfoSpanningTree(tv9).traverse(&propagator); + // Propagate parallelization except vectorization and unrolling + scheduler_utils::parallelizeAllLike( + tv9, + {}, + allParallelTypesExcept( + {ParallelType::Unroll, + ParallelType::Vectorize, + ParallelType::MisalignedVectorize})); + // Propagate vectorization/unrolling to those tensors that need it + scheduler_utils::parallelizeAllLike( + tv9, + {tv5, tv7, tv6}, + { + ParallelType::Unroll, + ParallelType::Vectorize, + ParallelType::MisalignedVectorize, + }); + // If reference shouldn't be unrolled, clear that parallel type. 
+ tv9->axis(3)->parallelize(ParallelType::Serial); + tv4->axis(2)->parallelize(ParallelType::Serial); + + inlineMost(); + + compare_ir(fusion, fauto); + + // Perform eager computation and verify + auto t2 = t0.add(t1); + auto t3 = t2.mul(3.0); + auto t4 = t3.sum({-1}, false); + + int runtime_threadIdx_dim = 128; + LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs, lparams); + auto cg_outputs = fe.runFusion(inputs, lparams); + + testValidate( + &fusion, cg_outputs, inputs, {t4}, __LINE__, __FILE__, "", lparams); +} + +//! Convert double inputs to half, then do some point-wise operations and +//! output half precision +//! ```python +//! def fusion_func(fd: FusionDefinition) : +//! t0 = fd.define_tensor(2, DataType.Double) +//! t1 = fd.define_tensor(2, DataType.Double) +//! +//! t0h = fd.ops.cast(t0, DataType.Half) +//! t1h = fd.ops.cast(t1, DataType.Half) +//! t2 = fd.ops.add(t0h, t1h) +//! t3 = fd.ops.relu(t2) +//! t4 = fd.ops.cast(t3, DataType.Half) +//! +//! fd.add_output(t4) +//! ``` +TEST_F(NVFuserTest, FusionFrontendCastDoubleToHalf_CUDA) { + // Create inputs + int x = 2, y = 4; + auto options = at::TensorOptions().dtype(at::kDouble).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({x, y}, options); + at::Tensor t1 = at::randn({x, y}, options); + + std::vector inputs = {t0, t1}; + + Fusion fauto; + { // Do automatic scheduling on fauto + FusionGuard fg(&fauto); + + auto tv0 = makeSymbolicTensor(2, DataType::Double); + auto tv1 = makeSymbolicTensor(2, DataType::Double); + + fauto.addInput(tv0); + fauto.addInput(tv1); + + auto tv2 = castOp(DataType::Half, tv0); + auto tv3 = castOp(DataType::Half, tv1); + // implicit casts + auto tv4 = castOp(DataType::Float, tv2); + auto tv5 = castOp(DataType::Float, tv3); + auto tv6 = add(tv4, tv5); + auto tv7 = relu(tv6); + auto tv8 = castOp(DataType::Half, tv7); + + fauto.addOutput(tv8); + + // Run automatic scheduler + auto pointwise_params = getPointwiseHeuristics(&fauto, inputs); + TORCH_CHECK(pointwise_params, "Pointwise schedule was not generated!"); + schedulePointwise(&fauto, *pointwise_params); + } + + // Re-define the fusion exactly for manual scheduling + // This is necessary in order to catch all the constructors inside each + // Fusion independently. + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2, DataType::Double); + auto tv1 = makeSymbolicTensor(2, DataType::Double); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = castOp(DataType::Half, tv0); + auto tv3 = castOp(DataType::Half, tv1); + // implicit casts + auto tv4 = castOp(DataType::Float, tv2); + auto tv5 = castOp(DataType::Float, tv3); + auto tv6 = add(tv4, tv5); + auto tv7 = relu(tv6); + auto tv8 = castOp(DataType::Half, tv7); + + fusion.addOutput(tv8); + + // Perform manual scheduling + + // Before schedulePointwise() is called, getPointwiseHeuristics() calls + // vectorize_helper::getExpandedVectorization() which in turn calls: + // vectorize_helper::getVectorizationSize + // vectorize_helper::ProjectedExtent::getNumerator + // vectorize_helper::ProjectedExtent::computeNumerDenomir + // IrContainer::oneVal + // oneVal() creates an actual Val here to hold the denominator and + // initializes it to 1. Since this is reflected in the fusion log, I'm + // inserting it here even though it has not effect on the generated kernel. 
+ fusion.oneVal(); + + tv0->cacheAfter(); // tv9 + tv1->cacheAfter(); // tv10 + auto tv11 = tv8->cacheBefore(); // tv11 + + tv8->merge(0, 1); + tv8->reorder({{0, -1}}); + tv8->reorder({{-1, 0}}); + tv8->split(0, 128); + tv8->split(0, 1); + tv8->split(0, 1); + tv8->axis(0)->parallelize(ParallelType::BIDx); + tv8->axis(1)->parallelize(ParallelType::Unswitch); + tv8->axis(3)->parallelize(ParallelType::TIDx); + + // propagate the mapping to other tensors + TransformPropagatorWithCheck propagator(tv8); + MaxRootDomainInfoSpanningTree(tv8).traverse(&propagator); + scheduler_utils::parallelizeAllLike(tv8); + + // Pointwise scheduler does not use inlineMost(), as reduction scheduler does + // Instead, it uses inlineAllAt followed by inlineMost(innermost_tensors) + inlineAllAt(tv8, 2, true); + inlineMost( + std::vector({tv0, tv1, tv2, tv3, tv4, tv5, tv6, tv7, tv11})); + + // Note that inlineAllAt iterates through an unordered_set to do inlining, so + // it is not practical to match the fusion_debug log exactly when using + // pointwise scheduler + compare_ir_math(fusion, fauto); + compare_transforms(fusion, fauto); + // compare_fusion_debug(fusion, fauto); + compare_kernels(fusion, fauto); + + // compare_ir(fusion, fauto); + + // Perform eager computation and verify + auto t0h = t0.to(options.dtype(at::kHalf)); + auto t1h = t1.to(options.dtype(at::kHalf)); + + auto t2 = t0h.add(t1h); + auto t3 = t2.relu(); + auto t4 = t3.to(options); + + int runtime_threadIdx_dim = 128; + LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs, lparams); + auto cg_outputs = fe.runFusion(inputs, lparams); + + testValidate( + &fusion, cg_outputs, inputs, {t4}, __LINE__, __FILE__, "", lparams); +} + +//! Same test as FusionFrontendCastDoubleToHalf_CUDA, but with mixed inputs +//! (double and half) and without the explicit cast to half, so that +//! computation and output are all at double precision. +//! ```python +//! def fusion_func(fd: FusionDefinition) : +//! t0 = fd.define_tensor(2, DataType.Half) +//! t1 = fd.define_tensor(2, DataType.Double) +//! +//! t2 = fd.ops.add(t0, t1) +//! t5 = fd.ops.relu(t2) +//! +//! fd.add_output(t5) +//! ``` +TEST_F(NVFuserTest, FusionFrontendPromoteToDouble_CUDA) { + // Create inputs + int x = 2, y = 4; + auto options = at::TensorOptions().dtype(at::kDouble).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({x, y}, options.dtype(at::kHalf)); + at::Tensor t1 = at::randn({x, y}, options); + + std::vector inputs = {t0, t1}; + + Fusion fauto; + { // Do automatic scheduling on fauto + FusionGuard fg(&fauto); + + auto tv0 = makeSymbolicTensor(2, DataType::Half); + auto tv1 = makeSymbolicTensor(2, DataType::Double); + + fauto.addInput(tv0); + fauto.addInput(tv1); + + auto tv2 = castOp(DataType::Double, tv0); + auto tv3 = add(tv2, tv1); + auto tv4 = relu(tv3); + + fauto.addOutput(tv4); + + // Run automatic scheduler + auto pointwise_params = getPointwiseHeuristics(&fauto, inputs); + TORCH_CHECK(pointwise_params, "Pointwise schedule was not generated!"); + schedulePointwise(&fauto, *pointwise_params); + } + + // Re-define the fusion exactly for manual scheduling + // This is necessary in order to catch all the constructors inside each + // Fusion independently. 
+ Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2, DataType::Half); + auto tv1 = makeSymbolicTensor(2, DataType::Double); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = castOp(DataType::Double, tv0); + auto tv3 = add(tv2, tv1); + auto tv4 = relu(tv3); + + fusion.addOutput(tv4); + + // Perform manual scheduling + + // Before schedulePointwise() is called, getPointwiseHeuristics() calls + // vectorize_helper::getExpandedVectorization() which in turn calls: + // vectorize_helper::getVectorizationSize + // vectorize_helper::ProjectedExtent::getNumerator + // vectorize_helper::ProjectedExtent::computeNumerDenomir + // IrContainer::oneVal + // oneVal() creates an actual Val here to hold the denominator and + // initializes it to 1. Since this is reflected in the fusion log, I'm + // inserting it here even though it has not effect on the generated kernel. + fusion.oneVal(); + + tv0->cacheAfter(); // tv5 + tv1->cacheAfter(); // tv6 + auto tv7 = tv4->cacheBefore(); + + tv4->merge(0, 1); + tv4->reorder({{0, -1}}); + tv4->reorder({{-1, 0}}); + tv4->split(0, 128); + tv4->split(0, 1); + tv4->split(0, 1); + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(1)->parallelize(ParallelType::Unswitch); + tv4->axis(3)->parallelize(ParallelType::TIDx); + + // propagate the mapping to other tensors + TransformPropagatorWithCheck propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + scheduler_utils::parallelizeAllLike(tv4); + + // Pointwise scheduler does not use inlineMost(), as reduction scheduler does + // Instead, it uses inlineAllAt followed by inlineMost(innermost_tensors) + inlineAllAt(tv4, 2, true); + inlineMost(std::vector({tv0, tv1, tv2, tv3, tv7})); + + // Note that inlineAllAt iterates through an unordered_set to do inlining, so + // it is not practical to match the fusion_debug log exactly when using + // pointwise scheduler + // compare_fusion_debug(fusion, fauto); + compare_ir_math(fusion, fauto); + compare_transforms(fusion, fauto); + compare_kernels(fusion, fauto); + + // compare_ir(fusion, fauto); + + // Perform eager computation and verify + auto t2 = t0.to(options.dtype(at::kDouble)); + auto t3 = t2.add(t1); + auto t4 = t3.relu(); + + int runtime_threadIdx_dim = 128; + LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs, lparams); + auto cg_outputs = fe.runFusion(inputs, lparams); + + testValidate( + &fusion, cg_outputs, inputs, {t4}, __LINE__, __FILE__, "", lparams); +} + +//! Test broadcasting one input then adding another +//! ```python +//! def fusion_func(fd: FusionDefinition) : +//! t0 = fd.define_tensor(1) +//! t1 = fd.define_tensor(3) +//! +//! t0_b = fd.ops.broadcast_in_dim(t0, [2, 3, 4], [1]) +//! t2 = fd.ops.add(t0_b, t1) +//! +//! fd.add_output(t2) +//! 
``` +TEST_F(NVFuserTest, FusionFrontendImplicitBroadcastInput_CUDA) { + // Create inputs + int w = 3, x = 2, y = 3, z = 4; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({w}, options); + at::Tensor t1 = at::randn({x, y, z}, options); + + std::vector inputs = {t0, t1}; + + Fusion fauto; + { // Do automatic scheduling on fauto + FusionGuard fg(&fauto); + + auto tv0 = makeSymbolicTensor(1); + auto tv1 = makeSymbolicTensor(3); + + fauto.addInput(tv0); + fauto.addInput(tv1); + + // explicitly tell tv0 to broadcast along new first and last dimensions + auto tv2 = broadcast(tv0, {true, false, true}); + auto tv3 = expand( + tv2, + {tv1->axis(0)->extent(), + tv1->axis(1)->extent(), + tv1->axis(2)->extent()}); + auto tv4 = add(tv3, tv1); + + fauto.addOutput(tv4); + + // Run automatic scheduler + auto pointwise_params = getPointwiseHeuristics(&fauto, inputs); + TORCH_CHECK(pointwise_params, "Pointwise schedule was not generated!"); + schedulePointwise(&fauto, *pointwise_params); + } + + // Re-define the fusion exactly for manual scheduling + // This is necessary in order to catch all the constructors inside each + // Fusion independently. + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(1); + auto tv1 = makeSymbolicTensor(3); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + // explicitly tell tv0 to broadcast along new first and last dimensions + auto tv2 = broadcast(tv0, {true, false, true}); + auto tv3 = expand( + tv2, + {tv1->axis(0)->extent(), tv1->axis(1)->extent(), tv1->axis(2)->extent()}); + auto tv4 = add(tv3, tv1); + + fusion.addOutput(tv4); + + // Perform manual scheduling + + // Before schedulePointwise() is called, getPointwiseHeuristics() calls + // vectorize_helper::getExpandedVectorization() which in turn calls: + // vectorize_helper::getVectorizationSize + // vectorize_helper::ProjectedExtent::getNumerator + // vectorize_helper::ProjectedExtent::computeNumerDenomir + // IrContainer::oneVal + // oneVal() creates an actual Val here to hold the denominator and + // initializes it to 1. Since this is reflected in the fusion log, I'm + // inserting it here even though it has not effect on the generated kernel. 
+ fusion.oneVal(); + + tv0->cacheAfter(); // tv5 + tv1->cacheAfter(); // tv6 + auto tv7 = tv4->cacheBefore(); + + tv4->merge(1, 2); + tv4->merge(0, 1); + tv4->reorder({{0, -1}}); + tv4->reorder({{-1, 0}}); + tv4->split(0, 128); + tv4->split(0, 1); + tv4->split(0, 1); + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(1)->parallelize(ParallelType::Unswitch); + tv4->axis(3)->parallelize(ParallelType::TIDx); + + // propagate the mapping to other tensors + TransformPropagatorWithCheck propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + scheduler_utils::parallelizeAllLike(tv4); + + // Pointwise scheduler does not use inlineMost(), as reduction scheduler does + // Instead, it uses inlineAllAt followed by inlineMost(innermost_tensors) + inlineAllAt(tv4, 2, true); + inlineMost(std::vector({tv0, tv1, tv2, tv3, tv7})); + + // Note that inlineAllAt iterates through an unordered_set to do inlining, so + // it is not practical to match the fusion_debug log exactly when using + // pointwise scheduler + // compare_fusion_debug(fusion, fauto); + compare_ir_math(fusion, fauto); + compare_transforms(fusion, fauto); + compare_kernels(fusion, fauto); + + // compare_ir(fusion, fauto); + + // Perform eager computation and verify + auto t2 = t0.view({1, w, 1}); + auto t4 = t2.add(t1); + + int runtime_threadIdx_dim = 128; + LaunchParams lparams(-1, -1, -1, runtime_threadIdx_dim, -1, -1); + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs, lparams); + auto cg_outputs = fe.runFusion(inputs, lparams); + + testValidate( + &fusion, cg_outputs, inputs, {t4}, __LINE__, __FILE__, "", lparams); +} + +//! Test broadcasting an input with existing broadcast dimensions, then adding +//! ```python +//! inputs = [ +//! torch.randn(1, 1, 4, device='cuda'), +//! torch.randn(2, 3, 4, device='cuda'), +//! ] +//! +//! def fusion_func(fd: FusionDefinition) : +//! t0 = fd.define_tensor(sizes=inputs[0].size(), +//! strides=inputs[0].stride()) t1 = +//! fd.define_tensor(sizes=inputs[1].size(), strides=inputs[1].stride()) +//! +//! t0_b = fd.ops.broadcast_in_dim(t0, inputs[1].size(), [0, 1, 2]) +//! t2 = fd.ops.add(t0_b, t1) +//! +//! fd.add_output(t2) +//! ``` +TEST_F(NVFuserTest, FusionFrontendExplicitBroadcastInput_CUDA) { + // Create inputs + int x = 2, y = 3, z = 4; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({1, 1, z}, options); + at::Tensor t1 = at::randn({x, y, z}, options); + + std::vector inputs = {t0, t1}; + + Fusion fauto; + { // Do automatic scheduling on fauto + FusionGuard fg(&fauto); + + // We depend on the input having size 1 in the first two dimensions, so we + // create a concrete tensor instead of using makeSymbolicTensor. The last + // dimension is still free. + auto tv0 = makeConcreteTensor({1, 1, -1}); + auto tv1 = makeSymbolicTensor(3); + + fauto.addInput(tv0); + fauto.addInput(tv1); + + // The following line is unnecessary, but matches what is done in the + // frontend's broadcast_in_dim + auto tv2 = broadcast(tv0, {false, false, false}); + + auto tv3 = expand_as(tv2, tv1); + auto tv4 = add(tv3, tv1); + + fauto.addOutput(tv4); + + // Run automatic scheduler + auto pointwise_params = getPointwiseHeuristics(&fauto, inputs); + TORCH_CHECK(pointwise_params, "Pointwise schedule was not generated!"); + schedulePointwise(&fauto, *pointwise_params); + } + + // Re-define the fusion exactly for manual scheduling + // This is necessary in order to catch all the constructors inside each + // Fusion independently. 
+ Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({1, 1, -1}); + auto tv1 = makeSymbolicTensor(3); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + // The following line is unnecessary, but matches what is done in the + // frontend's broadcast_in_dim + auto tv2 = broadcast(tv0, {false, false, false}); + + auto tv3 = expand_as(tv2, tv1); + auto tv4 = add(tv3, tv1); + + fusion.addOutput(tv4); + + // Perform manual scheduling + + // Before schedulePointwise() is called, getPointwiseHeuristics() calls + // vectorize_helper::getExpandedVectorization() which in turn calls: + // vectorize_helper::getVectorizationSize + // vectorize_helper::ProjectedExtent::getNumerator + // vectorize_helper::ProjectedExtent::computeNumerDenomir + // IrContainer::oneVal + // oneVal() creates an actual Val here to hold the denominator and + // initializes it to 1. Since this is reflected in the fusion log, I'm + // inserting it here even though it has not effect on the generated kernel. + fusion.oneVal(); + + tv0->cacheAfter(); // tv5 + tv1->cacheAfter(); // tv6 + auto tv7 = tv4->cacheBefore(); + + tv4->merge(1, 2); + tv4->merge(0, 1); + tv4->reorder({{0, -1}}); + tv4->reorder({{-1, 0}}); + tv4->split(0, 128); + tv4->split(0, 1); + tv4->split(0, 1); + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(1)->parallelize(ParallelType::Unswitch); + tv4->axis(3)->parallelize(ParallelType::TIDx); + + // propagate the mapping to other tensors + TransformPropagatorWithCheck propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + scheduler_utils::parallelizeAllLike(tv4); + + // Pointwise scheduler does not use inlineMost(), as reduction scheduler does + // Instead, it uses inlineAllAt followed by inlineMost(innermost_tensors) + inlineAllAt(tv4, 2, true); + inlineMost(std::vector({tv0, tv1, tv2, tv3, tv7})); + + // Note that inlineAllAt iterates through an unordered_set to do inlining, so + // it is not practical to match the fusion_debug log exactly when using + // pointwise scheduler + // compare_fusion_debug(fusion, fauto); + compare_ir_math(fusion, fauto); + compare_transforms(fusion, fauto); + compare_kernels(fusion, fauto); + + // compare_ir(fusion, fauto); + + // Perform eager computation and verify + auto t4 = t1.add(t0); + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto cg_outputs = fe.runFusion(inputs); + + testValidate(&fusion, cg_outputs, inputs, {t4}, __LINE__, __FILE__); +} + +//! Test adding implicitly and explicitly broadcast tensors together +//! ```python +//! inputs = [ +//! torch.randn(3, 1, device='cuda'), +//! torch.randn(3, device='cuda'), +//! ] +//! +//! def fusion_func(fd: FusionDefinition) : +//! t0 = fd.define_tensor([3, 1], [1, 1]) +//! t1 = fd.define_tensor(1) +//! +//! t1_b = fd.ops.broadcast_in_dim(t1, [3, 3], [0]) +//! t2 = fd.ops.add(t0, t1_b) +//! +//! fd.add_output(t2) +//! 
``` +TEST_F(NVFuserTest, FusionFrontendBroadcastMixing_CUDA) { + // Create inputs + int x = 3; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({x, 1}, options); + at::Tensor t1 = at::randn({x}, options); + + std::vector inputs = {t0, t1}; + + Fusion fauto; + { // Do automatic scheduling on fauto + FusionGuard fg(&fauto); + + auto tv0 = makeConcreteTensor({-1, 1}); + auto tv1 = makeSymbolicTensor(1); + + fauto.addInput(tv0); + fauto.addInput(tv1); + + // The following line is unnecessary, but matches what is done in the + // frontend's broadcast_in_dim + auto tv2 = broadcast(tv1, {false, true}); + auto xc = IrBuilder::create(x); + auto tv3 = expand(tv2, {tv2->axis(0)->extent(), xc}); + auto tv4 = add(tv0, tv3); + + fauto.addOutput(tv4); + + // Run automatic scheduler + auto pointwise_params = getPointwiseHeuristics(&fauto, inputs); + TORCH_CHECK(pointwise_params, "Pointwise schedule was not generated!"); + schedulePointwise(&fauto, *pointwise_params); + } + + // Re-define the fusion exactly for manual scheduling + // This is necessary in order to catch all the constructors inside each + // Fusion independently. + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({-1, 1}); + auto tv1 = makeSymbolicTensor(1); + + fusion.addInput(tv0); + fusion.addInput(tv1); + + // The following line is unnecessary, but matches what is done in the + // frontend's broadcast_in_dim + auto tv2 = broadcast(tv1, {false, true}); + auto xc = IrBuilder::create(x); + auto tv3 = expand(tv2, {tv2->axis(0)->extent(), xc}); + auto tv4 = add(tv0, tv3); + + fusion.addOutput(tv4); + + // Perform manual scheduling + + // Before schedulePointwise() is called, getPointwiseHeuristics() calls + // vectorize_helper::getExpandedVectorization() which in turn calls: + // vectorize_helper::getVectorizationSize + // vectorize_helper::ProjectedExtent::getNumerator + // vectorize_helper::ProjectedExtent::computeNumerDenomir + // IrContainer::oneVal + // oneVal() creates an actual Val here to hold the denominator and + // initializes it to 1. Since this is reflected in the fusion log, I'm + // inserting it here even though it has not effect on the generated kernel. 
+ fusion.oneVal(); + + tv0->cacheAfter(); // tv5 + tv1->cacheAfter(); // tv6 + auto tv7 = tv4->cacheBefore(); + + tv4->merge(0, 1); + tv4->reorder({{0, -1}}); + tv4->reorder({{-1, 0}}); + tv4->split(0, 128); + tv4->split(0, 1); + tv4->split(0, 1); + tv4->axis(0)->parallelize(ParallelType::BIDx); + tv4->axis(1)->parallelize(ParallelType::Unswitch); + tv4->axis(3)->parallelize(ParallelType::TIDx); + + // propagate the mapping to other tensors + TransformPropagatorWithCheck propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + scheduler_utils::parallelizeAllLike(tv4); + + // Pointwise scheduler does not use inlineMost(), as reduction scheduler does + // Instead, it uses inlineAllAt followed by inlineMost(innermost_tensors) + inlineAllAt(tv4, 2, true); + inlineMost(std::vector({tv0, tv1, tv2, tv3, tv7})); + + // Note that inlineAllAt iterates through an unordered_set to do inlining, so + // it is not practical to match the fusion_debug log exactly when using + // pointwise scheduler + // compare_fusion_debug(fusion, fauto); + compare_ir_math(fusion, fauto); + compare_transforms(fusion, fauto); + compare_kernels(fusion, fauto); + + // compare_ir(fusion, fauto); + + // Perform eager computation and verify + auto t4 = t1.view({x, 1}).add(t0.expand({x, x})); + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto cg_outputs = fe.runFusion(inputs); + + testValidate(&fusion, cg_outputs, inputs, {t4}, __LINE__, __FILE__); +} + +} // namespace jit +} // namespace torch +#endif // #if defined(USE_CUDA)
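
For reference, the `std::ostream` parameters added to `Fusion::print()`, `printMath()`, `printTransforms()`, and `printKernel()` in this diff are what let the new C++ tests compare two fusions as strings instead of printing to stdout. Below is a minimal sketch of that usage pattern, assuming the same nvfuser headers and `torch::jit::fuser::cuda` namespace used by `test_gpu_match_frontend.cpp`; the helper name `dumpIr` is hypothetical and not part of this patch.

```cpp
#include <sstream>
#include <string>

// Sketch only: capture a fusion's IR math, transforms, and generated CUDA
// kernel into a string using the stream overloads introduced in this diff
// (the defaults remain std::cout). The compare_* helpers in the new test
// file do the same thing inline for two fusions and then compare the strings.
std::string dumpIr(torch::jit::fuser::cuda::Fusion& fusion) {
  std::ostringstream ss;
  fusion.printMath(/*from_outputs_only=*/true, ss);                // arith exprs
  fusion.printTransforms(ss);                                      // scheduling transforms
  fusion.printKernel(torch::jit::fuser::cuda::DataType::Int, ss);  // lowered kernel
  return ss.str();
}
```

Two strings produced this way can be compared directly, which is how `compare_ir_math()`, `compare_transforms()`, and `compare_kernels()` detect mismatches between the manually and automatically scheduled fusions.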