From 685445993b03be30ce6934970e4e23b328ec107f Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 26 Sep 2024 08:28:16 -0700 Subject: [PATCH 1/4] test cleanup --- CMakeLists.txt | 3 +- tests/cpp/test_gpu3.cpp | 33 -- ...ndexing.cpp => test_indexing_advanced.cpp} | 451 ++++++++++++++++-- 3 files changed, 416 insertions(+), 71 deletions(-) rename tests/cpp/{test_gpu_indexing.cpp => test_indexing_advanced.cpp} (62%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 41690a9e30e..0611cfecb38 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -529,14 +529,13 @@ list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/tests/cpp/test_gpu3.cpp ${NVFUSER_ROOT}/tests/cpp/test_gpu_compute_with.cpp ${NVFUSER_ROOT}/tests/cpp/test_gpu_fused_reduction.cpp - ${NVFUSER_ROOT}/tests/cpp/test_gpu_indexing.cpp ${NVFUSER_ROOT}/tests/cpp/test_gpu_indexing_ops.cpp ${NVFUSER_ROOT}/tests/cpp/test_gpu_outer_reduction.cpp ${NVFUSER_ROOT}/tests/cpp/test_gpu_transpose.cpp ${NVFUSER_ROOT}/tests/cpp/test_gpu_utils.cpp ${NVFUSER_ROOT}/tests/cpp/test_id_model.cpp ${NVFUSER_ROOT}/tests/cpp/test_indexing.cpp - ${NVFUSER_ROOT}/tests/cpp/test_inlining.cpp + ${NVFUSER_ROOT}/tests/cpp/test_indexing_advanced.cpp ${NVFUSER_ROOT}/tests/cpp/test_iter_visitor.cpp ${NVFUSER_ROOT}/tests/cpp/test_linked_hash_map.cpp ${NVFUSER_ROOT}/tests/cpp/test_loop_rotation.cpp diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index a7c18898530..9b256db75db 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -3765,39 +3765,6 @@ TEST_F(NVFuserTest, FusionScheduleTransposeRepro1_CUDA) { testValidate(&fusion, cg_outputs, {input0, input1}, __LINE__, __FILE__); } -// Repro for issue #1873 -TEST_F(NVFuserTest, FusionInlineBroadcastIndexing0_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(1); - auto tv1 = makeContigTensor(2); - fusion.addInput(tv0); - fusion.addInput(tv1); - auto tv2 = set(tv0); - auto tv3 = broadcast(tv2, {true, false}); - auto tv4 = add(tv3, tv1); - 
fusion.addOutput(tv4); - - tv4->merge(0); - tv4->split(0, 32); - - tv0->computeAt(tv4, 1); - - tv2->split(-1, 8); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({123}, options); - at::Tensor t1 = at::randn({3, 123}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - - auto outputs = fe.runFusion({t0, t1}); - - testValidate(&fusion, outputs, {t0, t1}, __LINE__, __FILE__); -} - TEST_F(NVFuserTest, FusionPredicateUnshare_CUDA) { // https://github.com/csarofeen/pytorch/issues/1926 std::unique_ptr fusion_ptr = std::make_unique(); diff --git a/tests/cpp/test_gpu_indexing.cpp b/tests/cpp/test_indexing_advanced.cpp similarity index 62% rename from tests/cpp/test_gpu_indexing.cpp rename to tests/cpp/test_indexing_advanced.cpp index fd0c4211173..9677e903d84 100644 --- a/tests/cpp/test_gpu_indexing.cpp +++ b/tests/cpp/test_indexing_advanced.cpp @@ -9,21 +9,78 @@ #include #include -#include -#include -#include -#include -#include -#include - #include #include -#include +#include +#include +#include +#include namespace nvfuser { -TEST_F(NVFuserTest, FusionIndexing1_CUDA) { +class AdvancedIndexingTest : public NVFuserFixtureParamTest { + protected: + void SetUp() override { + if (GetParam()) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } else { + EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel); + } + } + + private: + EnableOptionsGuard enable_options_guard_; +}; + +class AdvancedIndexingIdModelTest : public NVFuserTest { + protected: + void SetUp() override { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + } + + private: + EnableOptionsGuard enable_options_guard_; +}; + +// Repro for issue #1873 +TEST_P(AdvancedIndexingTest, InlineBroadcast) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + auto tv1 = makeContigTensor(2); + fusion.addInput(tv0); + fusion.addInput(tv1); + auto tv2 = 
set(tv0); + auto tv3 = broadcast(tv2, {true, false}); + auto tv4 = add(tv3, tv1); + fusion.addOutput(tv4); + + tv4->merge(0); + tv4->split(0, 32); + + TransformPropagatorWithCheck propagator(tv4); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); + + tv2->inlineAt(1); + tv3->inlineAt(1); + + tv2->split(-1, 8); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({123}, options); + at::Tensor t1 = at::randn({3, 123}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + + auto outputs = fe.runFusion({t0, t1}); + + testValidate(&fusion, outputs, {t0, t1}, __LINE__, __FILE__); +} + +TEST_P(AdvancedIndexingTest, 1) { Fusion fusion; FusionGuard fg(&fusion); @@ -73,7 +130,8 @@ TEST_F(NVFuserTest, FusionIndexing1_CUDA) { testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionIndexing2_CUDA) { +// Same as 1 but merge starting from inner most dimension +TEST_P(AdvancedIndexingTest, 2) { Fusion fusion; FusionGuard fg(&fusion); @@ -123,7 +181,8 @@ TEST_F(NVFuserTest, FusionIndexing2_CUDA) { testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionIndexing3_CUDA) { +// Same compute as 1 and 2 but use a scheduler. 
+TEST_P(AdvancedIndexingTest, 3) { Fusion fusion; FusionGuard fg(&fusion); @@ -143,12 +202,14 @@ TEST_F(NVFuserTest, FusionIndexing3_CUDA) { at::Tensor t1 = at::randn({w, x, y, z}, options); std::vector aten_inputs = {t0, t1}; - auto cg_outputs = - scheduleAndRun(&fusion, SchedulerType::PointWise, aten_inputs).outputs; + + auto cg_outputs = scheduleAndRun(&fusion, SchedulerType::PointWise, aten_inputs).outputs; + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionIndexing4_CUDA) { +// Same as 3 but use 3 dimensions and concrete sizes +TEST_P(AdvancedIndexingTest, 4) { Fusion fusion; FusionGuard fg(&fusion); @@ -176,7 +237,7 @@ TEST_F(NVFuserTest, FusionIndexing4_CUDA) { testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionIndexing5_CUDA) { +TEST_P(AdvancedIndexingTest, 5) { Fusion fusion; FusionGuard fg(&fusion); @@ -210,7 +271,7 @@ TEST_F(NVFuserTest, FusionIndexing5_CUDA) { testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionIndexing6_CUDA) { +TEST_P(AdvancedIndexingTest, 6) { Fusion fusion; FusionGuard fg(&fusion); @@ -231,26 +292,26 @@ TEST_F(NVFuserTest, FusionIndexing6_CUDA) { at::Tensor input0 = at::randn(tensor0_shape, options); at::Tensor input1 = at::randn(tensor1_shape, options); - std::vector aten_inputs({input0, input1}); + std::vector aten_inputs = {input0, input1}; + + std::vector reduction_axes{0, 1}; - auto cg_reults = - scheduleAndRun(&fusion, SchedulerType::Reduction, aten_inputs); + auto results = scheduleAndRun(&fusion, SchedulerType::Reduction, aten_inputs); - std::vector reduction_axes{0, 1}; auto aten_output = input0.add(input1).to(at::kDouble).sum(reduction_axes); testValidate( &fusion, - cg_reults.outputs, + results.outputs, {input0, input1}, {aten_output}, __LINE__, __FILE__, "", - cg_reults.heuristic_params->lparams); + results.heuristic_params->lparams); } -TEST_F(NVFuserTest, FusionIndexing7_CUDA) { 
+TEST_P(AdvancedIndexingTest, 7) { // Might be able to use this one without 6 as the heuristics in 6 may change // and this test is to cover the same issue. Fusion fusion; @@ -297,7 +358,7 @@ TEST_F(NVFuserTest, FusionIndexing7_CUDA) { &fusion, cg_outputs, {at_t0, at_t1}, {aten_output}, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionIndexing8_CUDA) { +TEST_P(AdvancedIndexingTest, 8) { // Same as 7 but with outer splits instead of inner Fusion fusion; FusionGuard fg(&fusion); @@ -343,8 +404,8 @@ TEST_F(NVFuserTest, FusionIndexing8_CUDA) { &fusion, cg_outputs, {at_t0, at_t1}, {aten_output}, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionIndexing9_CUDA) { - // Same as 7 but with outer splits instead of inner +// Same as 5 but using implicit broadcast +TEST_P(AdvancedIndexingTest, 9) { Fusion fusion; FusionGuard fg(&fusion); @@ -372,10 +433,11 @@ TEST_F(NVFuserTest, FusionIndexing9_CUDA) { auto cg_outputs = scheduleAndRun(&fusion, SchedulerType::PointWise, aten_inputs).outputs; + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionIndexing10_CUDA) { +TEST_P(AdvancedIndexingTest, 10) { Fusion fusion; FusionGuard fg(&fusion); @@ -434,7 +496,7 @@ TEST_F(NVFuserTest, FusionIndexing10_CUDA) { NVF_CHECK(output_ref.equal(output)); } -TEST_F(NVFuserTest, FusionIndexing11_CUDA) { +TEST_P(AdvancedIndexingTest, 11) { Fusion fusion; FusionGuard fg(&fusion); @@ -484,7 +546,7 @@ TEST_F(NVFuserTest, FusionIndexing11_CUDA) { testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionIndexing12_CUDA) { +TEST_P(AdvancedIndexingTest, 12) { Fusion fusion; FusionGuard fg(&fusion); @@ -507,7 +569,7 @@ TEST_F(NVFuserTest, FusionIndexing12_CUDA) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor aten_input = at::randn({9, 5}, options); - auto t1 = aten_input.add(1.0); + auto t1 = aten_input.to(at::kDouble).add(1.0); auto t2 = t1.add(2.0); auto t3 = t1.add(3.0); auto t4 
= t3.sum(1); @@ -522,7 +584,7 @@ TEST_F(NVFuserTest, FusionIndexing12_CUDA) { &fusion, cg_outputs, {aten_input}, aten_outputs, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionIndexing13_CUDA) { +TEST_P(AdvancedIndexingTest, 13) { Fusion fusion; FusionGuard fg(&fusion); @@ -570,7 +632,7 @@ TEST_F(NVFuserTest, FusionIndexing13_CUDA) { testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionIndexing14_CUDA) { +TEST_P(AdvancedIndexingTest, 14) { Fusion fusion; FusionGuard fg(&fusion); @@ -615,7 +677,7 @@ TEST_F(NVFuserTest, FusionIndexing14_CUDA) { // This excercises indexing with broadcast root axes. Non-broadcast // axes need to be preferred when propagating index exprs to root // axes. See, e.g., Index::getConsumerIndex_impl. -TEST_F(NVFuserTest, FusionIndexing15_CUDA) { +TEST_P(AdvancedIndexingTest, 15) { Fusion fusion; FusionGuard fg(&fusion); @@ -647,7 +709,7 @@ TEST_F(NVFuserTest, FusionIndexing15_CUDA) { testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionIndexing16_CUDA) { +TEST_P(AdvancedIndexingTest, 16) { Fusion fusion; FusionGuard fg(&fusion); @@ -679,7 +741,7 @@ TEST_F(NVFuserTest, FusionIndexing16_CUDA) { testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } -TEST_F(NVFuserTest, FusionIndexing17_CUDA) { +TEST_P(AdvancedIndexingTest, 17) { Fusion fusion; FusionGuard fg(&fusion); @@ -715,7 +777,7 @@ TEST_F(NVFuserTest, FusionIndexing17_CUDA) { } // Repro of issue #2560 -TEST_F(NVFuserTest, FusionIndexing18_CUDA) { +TEST_P(AdvancedIndexingTest, 18) { Fusion fusion; FusionGuard fg(&fusion); @@ -753,4 +815,321 @@ TEST_F(NVFuserTest, FusionIndexing18_CUDA) { testValidate(fe.kernel(), cg_outputs, inputs, {ref}, __LINE__, __FILE__); } +TEST_P(AdvancedIndexingTest, 19) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv0 = makeConcreteTensor({5, 7, 11, 13}); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + + auto tv2 = makeConcreteTensor({5, 
11}); + fusion.addInput(tv2); + + auto tv3 = broadcast(tv2, {false, true, false, true}); + auto tv4 = add(tv3, tv1); + fusion.addOutput(tv4); + + // // tv4[5, 7, 11, 13] = tv3[5, b1, 11, b3] + tv1[5, 7, 11, 13] + tv4->merge(0, 3); + // tv4[5*13, 7, 11] + tv4->split(0, 3); + // tv4[5*13//3, 3, 7, 11] + tv4->merge(2, 3)->split(2, 2); + // tv4[5*13//3, 3, 7*11//2, 2] + tv4->merge(0, 2); + // tv4[(5*13//3)*(7*11//2), 3, 2] + + TransformPropagatorWithCheck propagator(tv4); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); + inlineAllAt(tv4, 1, false); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({5, 7, 11, 13}, options); + at::Tensor t1 = at::randn({5, 11}, options); + std::vector inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto outputs = fe.runFusion(inputs); + + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); +} + +// Create a case where we're missing a valid concrete id so the compute at map +// processing will fail. We need to be able to create the concrete ID not just +// look for one. 
+TEST_F(AdvancedIndexingIdModelTest, 20) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({7}); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + + auto tv2 = broadcast(tv1, {false, true}); + + auto tv3 = makeConcreteTensor({7, 11}); + fusion.addInput(tv3); + + auto tv4 = add(tv3, tv2); + auto tv5 = broadcast(tv4, {false, false, true}); + // tv4[7, 11, 1] + + auto tv6 = broadcast(tv1, {false, true}); + + auto tv7 = makeConcreteTensor({7, 13}); + fusion.addInput(tv7); + auto tv8 = add(tv7, tv6); + auto tv9 = broadcast(tv8, {false, true, false}); + // tv9[7, 1, 13] + + auto tv10 = add(tv5, tv9); + fusion.addOutput(tv10); + + // tv10[7, 11, 13] + tv10->merge(0)->merge(0); + // tv10[7*11*13] + tv10->split(0, 5)->split(0, 3); + // tv10[7*11*13//5//3, 3, 5] + + TransformPropagatorWithCheck propagator(tv10); + MaxLogicalDomainInfoSpanningTree(tv10).traverse(&propagator); + + std::vector tensors_to_inline{tv1, tv2, tv4, tv6, tv8}; + for (auto tensor : tensors_to_inline) { + tensor->inlineAt(1); + } + + // TODO: Finish and enable the lowering and execution. ComputeAtMap + // fails as it cannot find concrete domains. Even if IdModel is + // fully enabled, ComputeAtMap is still required at this moment. +#if 0 + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({7}, options); + at::Tensor t1 = at::randn({7, 11}, options); + at::Tensor t2 = at::randn({7, 13}, options); + std::vector inputs = {t0, t1, t2}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto outputs = fe.runFusion(inputs); + + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); +#endif +} + +// Progressive loop promotion. producer gets promoted in consumer, consumer is +// promoted in a different way to its consumer. 
+TEST_F(AdvancedIndexingIdModelTest, 21) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({5}); + fusion.addInput(tv0); + + // [5] + auto tv1 = set(tv0); + auto tv2 = broadcast(tv1, {true, false}); + // [1, 5] + auto tv3 = makeConcreteTensor({3, 5}); + fusion.addInput(tv3); + auto tv4 = add(tv3, tv2); + // [3, 5] + + auto tv5 = broadcast(tv4, {false, false, true}); + // [3, 5, 1] + auto tv6 = makeConcreteTensor({3, 5, 7}); + fusion.addInput(tv6); + auto tv7 = add(tv5, tv6); + // [3, 5, 7] + fusion.addOutput(tv7); + + tv4->merge(0)->split(0, 2, false); + // [3, 5] + // [3, 3*5//2] + + TransformPropagatorWithCheck propagator(tv4); + MaxLogicalDomainInfoSpanningTree(tv4).traverse(&propagator); + + // tv0->tv1->tv2(b)->tv4->tv5(b)->tv7 + + tv1->inlineAt(1); + tv2->inlineAt(1); + tv4->inlineAt(1); + + // [2, 3*5//2] + tv5->merge(1)->split(1, 4, false); + // [2, 4, (3*5//2)*1//4] + tv7->merge(1)->split(1, 4, false); + // [2, 4, (3*5//2)*7//4] + tv5->inlineAt(2); + + // Validation not enabled yet as incorrect code is generated. 
Need + // to use the loop promotion info to generate correct loop-nests +#if 0 + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({5}, options); + auto t3 = at::randn({3, 5}, options); + auto t6 = at::randn({3, 5, 7}, options); + std::vector inputs = {t0, t3, t6}; + + FusionExecutor fe; + fe.compileFusion(&fusion, inputs); + auto outputs = fe.runFusion(inputs); + + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); +#endif +} + +// Broadcast inline 3 times and merge all domains +TEST_F(AdvancedIndexingIdModelTest, MultiPromotion1) { + Fusion fusion; + FusionGuard fg(&fusion); + + // [y] + auto tv0 = makeSymbolicTensor(1); + // [w, x, y, z] + auto tv1 = makeSymbolicTensor(4); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // y + auto tv2 = broadcast(tv0, {true, false}); + // w, y, z + auto tv3 = broadcast(tv2, {false, false, true}); + // w, y, z + auto tv4 = broadcast(tv3, {false, true, false, false}); + // w, x, y, z + auto tv5 = add(tv4, tv1); + + fusion.addOutput(tv5); + + tv5->merge(1)->merge(1)->merge(0)->split(0, 11); + + tv0->computeAt(tv5, 1); + tv1->computeAt(tv5, 1); + + int w = 3, x = 4, y = 7, z = 8; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({y}, options); + at::Tensor t1 = at::randn({w, x, y, z}, options); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); +} + +// Broadcast and concretize same domain in two different ways and try to merge +// their loops. The inlining pattern is invalid but the current +// inlining check is not capable of flagging the inlining position as +// invalid. The loop promotion analysis should not find any promotion +// of the loop group where all the leaf domains are merged into. 
+TEST_F(AdvancedIndexingIdModelTest, MultiPromotion2) { + Fusion fusion; + FusionGuard fg(&fusion); + // [w] + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + // [w, x] + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + + // [w, y] + auto tv2 = makeSymbolicTensor(2); + fusion.addInput(tv2); + + auto tv3 = set(tv0); + // [w] + auto tv4 = broadcast(tv3, {false, true}); + // [w, 1] + auto tv5 = add(tv4, tv1); + // [w, x] + fusion.addOutput(tv5); + + // [w] + auto tv6 = broadcast(tv3, {false, true}); + // [w, 1] + auto tv7 = add(tv6, tv2); + // [w, y] + fusion.addOutput(tv7); + + for (auto tv : std::vector{tv4, tv5, tv6, tv7}) { + tv->merge(0); + } + + // Since x and y are not proven to be the same, this inlining position + // should not be allowed. + // TODO: Make this throw an error + for (auto tv : std::vector{tv3, tv4, tv6}) { + tv->inlineAt(1); + } +} + +// TODO: All the above tests are merges followed by splits, we should make some +// more complex examples even though merging then splitting is the most likely +// use case. In multi-gpu it may be the exact opposite where we split out the +// outer most iter domain to the multi-gpu dimension, then schedule. 
+TEST_F(AdvancedIndexingIdModelTest, IndexSplitMerge) { + Fusion fusion; + FusionGuard fg(&fusion); + // [w] + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + // [w, x] + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + + auto tv2 = broadcast(tv0, {false, true}); + auto tv3 = add(tv1, tv2); + fusion.addOutput(tv3); + + tv3->split(0, 3); + tv3->split(2, 4); + tv3->merge(1); + tv3->split(1, 5); + + MaxLogicalDomainInfoSpanningTree tree(tv3); + TransformPropagator tp(tv3); + tree.traverse(&tp); + + inlineAllAt(tv3, 1, true); + + int x = 4, y = 7; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({x}, options); + at::Tensor t1 = at::randn({x, y}, options); + + auto t2 = t0.unsqueeze(-1); + auto aten_output = t1.add(t2); + + std::vector aten_inputs = {t0, t1}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +INSTANTIATE_TEST_SUITE_P( + , + AdvancedIndexingTest, + testing::Bool(), + [](const testing::TestParamInfo& info) { + return info.param ? 
"IdModel" : "Legacy"; + }); + } // namespace nvfuser From cd249d99792ecd9c82f0479bacb4b4552b00d527 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 26 Sep 2024 08:32:57 -0700 Subject: [PATCH 2/4] disable idmodel validation with computeatmap --- csrc/device_lower/lower2device.cpp | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index 940061dc461..034fdd078b5 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -417,21 +417,11 @@ void GpuLower::analysis(Fusion* fusion) { // names if (this->requiresIdModel() || isOptionEnabled(EnableOption::IdModel)) { // Enable validation in the DEBUG build mode -#ifdef NDEBUG - // Not DEBUG build id_model_ = std::make_unique( fusion_, /*build_graphs=*/true, /*allow_self_mapping=*/false, /*validate=*/false); -#else - // DEBUG build - id_model_ = std::make_unique( - fusion_, - /*build_graphs=*/true, - /*allow_self_mapping=*/false, - /*validate=*/true); -#endif id_model_->validateAndPropagatePType(); } From 090e8f80ed0b1b6d95390334e9332c51a24780b4 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 26 Sep 2024 08:42:27 -0700 Subject: [PATCH 3/4] format --- tests/cpp/test_indexing_advanced.cpp | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/tests/cpp/test_indexing_advanced.cpp b/tests/cpp/test_indexing_advanced.cpp index 9677e903d84..3e62dcfec5c 100644 --- a/tests/cpp/test_indexing_advanced.cpp +++ b/tests/cpp/test_indexing_advanced.cpp @@ -203,7 +203,8 @@ TEST_P(AdvancedIndexingTest, 3) { std::vector aten_inputs = {t0, t1}; - auto cg_outputs = scheduleAndRun(&fusion, SchedulerType::PointWise, aten_inputs).outputs; + auto cg_outputs = + scheduleAndRun(&fusion, SchedulerType::PointWise, aten_inputs).outputs; testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } @@ -293,7 +294,7 @@ TEST_P(AdvancedIndexingTest, 6) { at::Tensor input0 = 
at::randn(tensor0_shape, options); at::Tensor input1 = at::randn(tensor1_shape, options); std::vector aten_inputs = {input0, input1}; - + std::vector reduction_axes{0, 1}; auto results = scheduleAndRun(&fusion, SchedulerType::Reduction, aten_inputs); @@ -848,7 +849,7 @@ TEST_P(AdvancedIndexingTest, 19) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({5, 7, 11, 13}, options); at::Tensor t1 = at::randn({5, 11}, options); - std::vector inputs = {t0, t1}; + std::vector inputs = {t0, t1}; FusionExecutor fe; fe.compileFusion(&fusion, inputs); @@ -861,6 +862,8 @@ TEST_P(AdvancedIndexingTest, 19) { // processing will fail. We need to be able to create the concrete ID not just // look for one. TEST_F(AdvancedIndexingIdModelTest, 20) { + GTEST_SKIP() << "Not supported yet"; + Fusion fusion; FusionGuard fg(&fusion); @@ -918,12 +921,14 @@ TEST_F(AdvancedIndexingIdModelTest, 20) { auto outputs = fe.runFusion(inputs); testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); -#endif +#endif } // Progressive loop promotion. producer gets promoted in consumer, consumer is // promoted in a different way to its consumer. TEST_F(AdvancedIndexingIdModelTest, 21) { + GTEST_SKIP() << "Not supported yet"; + Fusion fusion; FusionGuard fg(&fusion); @@ -966,7 +971,7 @@ TEST_F(AdvancedIndexingIdModelTest, 21) { tv7->merge(1)->split(1, 4, false); // [2, 4, (3*5//2)*7//4] tv5->inlineAt(2); - + // Validation not enabled yet as incorrect code is generated. 
Need // to use the loop promotion info to generate correct loop-nests #if 0 @@ -1024,8 +1029,7 @@ TEST_F(AdvancedIndexingIdModelTest, MultiPromotion1) { fe.compileFusion(&fusion, aten_inputs); auto cg_outputs = fe.runFusion(aten_inputs); - testValidate( - &fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } // Broadcast and concretize same domain in two different ways and try to merge @@ -1034,6 +1038,8 @@ TEST_F(AdvancedIndexingIdModelTest, MultiPromotion1) { // invalid. The loop promotion analysis should not find any promotion // of the loop group where all the leaf domains are merged into. TEST_F(AdvancedIndexingIdModelTest, MultiPromotion2) { + GTEST_SKIP() << "Not supported yet"; + Fusion fusion; FusionGuard fg(&fusion); // [w] @@ -1130,6 +1136,6 @@ INSTANTIATE_TEST_SUITE_P( testing::Bool(), [](const testing::TestParamInfo& info) { return info.param ? "IdModel" : "Legacy"; - }); + }); } // namespace nvfuser From eefaa184680c21e36cfaca55f0970282bf9cf1e5 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 26 Sep 2024 08:50:58 -0700 Subject: [PATCH 4/4] format --- tests/cpp/test_indexing_advanced.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cpp/test_indexing_advanced.cpp b/tests/cpp/test_indexing_advanced.cpp index 3e62dcfec5c..1e746b1a22c 100644 --- a/tests/cpp/test_indexing_advanced.cpp +++ b/tests/cpp/test_indexing_advanced.cpp @@ -913,8 +913,8 @@ TEST_F(AdvancedIndexingIdModelTest, 20) { auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); at::Tensor t0 = at::randn({7}, options); at::Tensor t1 = at::randn({7, 11}, options); - at::Tensor t2 = at::randn({7, 13}, options); - std::vector inputs = {t0, t1, t2}; + at::Tensor t2 = at::randn({7, 13}, options); + std::vector inputs = {t0, t1, t2}; FusionExecutor fe; fe.compileFusion(&fusion, inputs);