diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp
index 90d7dce7b27..27f91a66c68 100644
--- a/csrc/dynamic_transform.cpp
+++ b/csrc/dynamic_transform.cpp
@@ -37,6 +37,18 @@ DynamicTransformInitialInfo DynamicTransformInitialInfo::clone(
       cloned_info.dynamic_resized_ids_.push_back(ir_cloner.clone(op));
     }
   }
+  cloned_info.maybe_zero_extents_set_.reserve(maybe_zero_extents_set_.size());
+  for (const auto v : maybe_zero_extents_set_) {
+    if (v) {
+      cloned_info.maybe_zero_extents_set_.insert(ir_cloner.clone(v));
+    }
+  }
+  cloned_info.maybe_zero_extents_.reserve(maybe_zero_extents_.size());
+  for (const auto v : maybe_zero_extents_) {
+    if (v) {
+      cloned_info.maybe_zero_extents_.push_back(ir_cloner.clone(v));
+    }
+  }
   cloned_info.root_dynamic_vals_.reserve(root_dynamic_vals_.size());
   for (const auto v : root_dynamic_vals_) {
     if (v) {
@@ -58,6 +70,10 @@ std::string DynamicTransformInitialInfo::toString() const {
   for (const auto& op : dynamic_resized_ids_) {
     ss << indent << indent << op->toString() << "\n";
   }
+  ss << indent << "Dynamic extent Vals:\n";
+  for (const auto& v : maybe_zero_extents_) {
+    ss << indent << indent << v->toString() << "\n";
+  }
   ss << indent << "Root dynamic Vals:\n";
   for (const auto& v : root_dynamic_vals_) {
     ss << indent << indent << v->toString() << "\n";
@@ -77,6 +93,8 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor {
     traverseTo(fusion, fusion->getTerminatingOutputs(), false, false);
 
     finalizeDynamicVals();
+
+    finalizeMaybeEmptyExtents();
   }
 
   const auto& getInfo() const {
@@ -95,22 +113,24 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor {
       info_.dynamic_reshaped_tvs_.push_back(out_tv);
 
       // Input and output extent expressions both affect concretization
-      const auto& inp_dom =
-          TensorDomain::noReductions(inp_tv->getMaybeRFactorDomain());
-      for (const auto id : inp_dom) {
+      for (const auto& id :
+           TensorDomain::noReductions(inp_tv->getMaybeRFactorDomain())) {
         leaf_dynamic_vals_.push_back(id->extent());
       }
-      const auto& out_dom = out_tv->getMaybeRFactorDomain();
-      for (const auto id : out_dom) {
+      for (const auto& id : out_tv->getMaybeRFactorDomain()) {
         leaf_dynamic_vals_.push_back(id->extent());
       }
     }
   }
 
-  //! Detect dynamic IterDomain transforms when handling TensorViews
+  //! Detect possibly empty TensorViews and dynamic IterDomain transforms
   void handle(TensorView* tv) override {
     const auto& rfd = tv->getMaybeRFactorDomain();
     for (auto id : rfd) {
+      if (!id->extent()->isConstScalar() || id->extent()->evaluateInt() == 0) {
+        info_.maybe_zero_extents_set_.insert(id->extent());
+        leaf_dynamic_vals_.push_back(id->extent());
+      }
       if (!id->definition() || id->getIterType() != IterType::Symbolic) {
         continue;
       }
@@ -141,6 +161,15 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor {
     }
   }
 
+  //! Convert maybe_zero_extents_set_ to a vector so we can index it reliably
+  void finalizeMaybeEmptyExtents() {
+    info_.maybe_zero_extents_ = std::vector<Val*>(
+        info_.maybe_zero_extents_set_.begin(),
+        info_.maybe_zero_extents_set_.end());
+    // Clear the corresponding set to free memory and speed up cloning
+    info_.maybe_zero_extents_set_.clear();
+  }
+
  private:
   DynamicTransformInitialInfo info_;
@@ -154,6 +183,36 @@ class DynamicTransformInitialInfoBuilder : public IterVisitor {
   std::vector<Val*> leaf_dynamic_vals_;
 };
 
+DynamicTransformConcretizationInfo::DynamicTransformConcretizationInfo(
+    const DynamicTransformInitialInfo* initial_info,
+    ExpressionEvaluator* expr_eval)
+    : initial_info_(initial_info) {
+  TORCH_INTERNAL_ASSERT(
+      !fusion()->isA<kir::Kernel>(),
+      "Invalid container. Kernel container not allowed.\n");
+
+  // Make sure all exactly mapped IDs have the same value in the
+  // evaluator when any one of the IDs has a known value
+  expr_eval->propagateBoundValuesThroughExactMaps(initial_info_->fusion());
+
+  analyzeReshapes(expr_eval);
+
+  analyzeResizes(expr_eval);
+
+  auto maybe_zero_extents = initial_info_->getMaybeZeroExtents();
+  for (auto i : c10::irange(maybe_zero_extents.size())) {
+    auto ext = maybe_zero_extents.at(i);
+    auto ext_opt = expr_eval->evaluate(ext);
+    TORCH_INTERNAL_ASSERT(
+        ext_opt.hasValue(),
+        "Could not evaluate dynamic extent: ",
+        ext->toString());
+    if (ext_opt == 0) {
+      empty_extents_.push_back(i);
+    }
+  }
+}
+
 void DynamicTransformConcretizationInfo::analyzeReshapes(
     ExpressionEvaluator* expr_eval) {
   const auto& reshape_tvs = initial_info_->getDynamicReshapedTensorViews();
@@ -281,7 +340,8 @@ bool DynamicTransformConcretizationInfo::operator==(
   }
 
   if (reshape_transforms_.size() != other.reshape_transforms_.size() ||
-      resize_itertypes_.size() != other.resize_itertypes_.size()) {
+      resize_itertypes_.size() != other.resize_itertypes_.size() ||
+      empty_extents_.size() != other.empty_extents_.size()) {
     return false;
   }
 
@@ -301,6 +361,14 @@ bool DynamicTransformConcretizationInfo::operator==(
     }
   }
 
+  for (const auto i : c10::irange(empty_extents_.size())) {
+    const auto& ee = empty_extents_.at(i);
+    const auto& other_ee = other.empty_extents_.at(i);
+    if (ee != other_ee) {
+      return false;
+    }
+  }
+
   return true;
 }
 
@@ -308,6 +376,11 @@ std::string DynamicTransformConcretizationInfo::toString() const {
   std::stringstream ss;
   ss << "DynamicTransformConcretizationInfo\n";
   std::string indent = " ";
+  ss << indent << "Empty tensor extents:\n";
+  for (const auto& i : empty_extents_) {
+    auto ext = initial_info_->getMaybeZeroExtents().at(i);
+    ss << indent << indent << ext->toString() << " is zero\n";
+  }
   ss << indent << "Reshape:\n";
   for (const auto& [tv_index, analyze_result] : reshape_transforms_) {
     auto tv = initial_info_->getDynamicReshapedTensorViews().at(tv_index);
@@ -333,6 +406,7 @@ class DynamicTransformConcretizer : public OptOutMutator {
     TORCH_INTERNAL_ASSERT(
         fusion == info->fusion(),
         "Invalid DynamicTransformInitialInfo. The associated Fusion is different from the given Fusion");
+    FusionGuard fg(fusion);
     concretize();
   }
 
@@ -343,6 +417,8 @@
 
   void concretizeResize();
 
+  void concretizeEmptyExtents();
+
   //! Use this instead of calling registerMutation directly, since it will also
   //! check that the concretized value is a valid input to all of its uses.
   void registerConcretization(Val* old_val, Val* new_val) {
@@ -370,18 +446,40 @@
 };
 
 void DynamicTransformConcretizer::concretize() {
-  // First, concretize all dynamic reshape ops
+  // Concretize all dynamic reshape ops
   concretizeReshape();
 
   // Set output IterTypes for dynamic resize ops
   concretizeResize();
 
+  // Registers replacement of all empty extents with zeroVal()
+  concretizeEmptyExtents();
+
   // Finally, propagate concretized domains
-  auto all_stmts = StmtSort::getStmts(info_->fusion(), true);
-  for (auto stmt : all_stmts) {
-    if (stmt->isA()) {
-      mutate(stmt);
+  auto all_stmts = StmtSort::getStmts(info_->fusion());
+  for (auto tv : ir_utils::filterByType<TensorView>(all_stmts)) {
+    mutate(tv);
+  }
+}
+
+void DynamicTransformConcretizer::concretizeEmptyExtents() {
+  auto fusion = FusionGuard::getCurFusion();
+  for (const auto& ext_index : info_->getEmptyExtents()) {
+    auto ext = info_->initialInfo()->getMaybeZeroExtents().at(ext_index);
+    auto zero = fusion->zeroVal(ext->getDataType().value());
+    auto uses = ext->uses();
+    for (auto use : uses) {
+      ir_utils::replaceValInExpr(use, ext, zero);
     }
+    // Register the concretization of this scalar, which allows us to replace it
+    // whenever it is used as an extent member of an IterDomain.
+    //
+    // When we replace ext in all uses above, it affects downstream expressions. For
+    // example we might replace i0 with 0 in (i0 + i1) + i2 to form (0 + i1) +
+    // i2. However, i0 itself might be used as the extent, start, or stop values
+    // in an IterDomain, so we register the concretization here so that we can
+    // replace these values whenever we encounter them.
+    registerConcretization(ext, zero);
   }
 }
 
@@ -400,7 +498,8 @@ void DynamicTransformConcretizer::concretizeReshape() {
     checkConcretizedUses(incomplete_out_tv, concrete_reshape_out_tv);
 
     // Replace the old tensor with the new concretized tensor
-    for (auto use_of_old_tv : incomplete_out_tv->uses()) {
+    auto uses = incomplete_out_tv->uses();
+    for (auto use_of_old_tv : uses) {
       ir_utils::replaceValInExpr(
           use_of_old_tv, incomplete_out_tv, concrete_reshape_out_tv);
     }
@@ -446,8 +545,10 @@
 // concretized. Since symbolic IDs may be propagated down to
 // consumers, those domains need to be concretized accordingly.
 void DynamicTransformConcretizer::mutate(TensorView* tv) {
-  if (!tv->domain()->hasSymbolicAxis()) {
-    return;
+  for (auto root_id : tv->getRootDomain()) {
+    // This will register root_id for mutation if its extent, start, or
+    // stop_offset is registered for mutation
+    OptOutMutator::mutate(root_id);
   }
 
   // First, try to concretize the root domain as there may be symbolic
@@ -522,12 +623,14 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) {
         continue;
       }
       auto concretized_out_id =
-          IterDomainBuilder(out_id).iter_type(iter_type).build();
+          IterDomainBuilder(maybeMutated(out_id)->as<IterDomain>())
+              .iter_type(iter_type)
+              .build();
       registerConcretization(out_id, concretized_out_id);
     }
 
-    // The expr itself needs to be mutated as well in case the outputs are
-    // mutated, which can be done by the mutate method
+    // expr must be mutated in order to set it as the definition for the
+    // concretized outputs.
     OptOutMutator::mutate(expr);
   }
 }
@@ -656,7 +759,9 @@ bool DynamicTransformConcretizer::propagateFromProducerToConsumer(
         consumer->toString());
 
     auto concretized_id =
-        IterDomainBuilder(root_id).iter_type(*id_type).build();
+        IterDomainBuilder(maybeMutated(root_id)->as<IterDomain>())
+            .iter_type(*id_type)
+            .build();
 
     registerConcretization(root_id, concretized_id);
     is_concretized = true;
diff --git a/csrc/dynamic_transform.h b/csrc/dynamic_transform.h
index 34609a002f1..0cb1af5f9d3 100644
--- a/csrc/dynamic_transform.h
+++ b/csrc/dynamic_transform.h
@@ -14,6 +14,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
@@ -38,18 +39,35 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo {
     return fusion_;
   }
 
-  //! Return whether any dynamic transforms exist in the Fusion
-  bool hasDynamicTransforms() const {
-    return !dynamic_reshaped_tvs_.empty() || !dynamic_resized_ids_.empty();
+  //! Return whether any dynamic transforms exist in the Fusion, or whether
+  //! there are any tensors which could potentially be empty (size-0 extent)
+  //! given some user input. In either of these cases, concretization may change
+  //! the structure of the Fusion.
+  bool isDynamic() const {
+    return hasPossibleEmptyTensor() || !dynamic_reshaped_tvs_.empty() ||
+        !dynamic_resized_ids_.empty();
+  }
+
+  //! Return whether there are any tensors with unknown extent in some
+  //! dimension, so that they might be empty
+  bool hasPossibleEmptyTensor() const {
+    return !maybe_zero_extents_.empty();
   }
 
   //! Return a set of scalars that are inputs or extents of input TensorViews
   //! and that appear in inputs to dynamic expressions. Any Vals not in this
   //! list do not affect concretization.
-  const std::unordered_set<Val*> getRootDynamicVals() const {
+  const std::unordered_set<Val*>& getRootDynamicVals() const {
     return root_dynamic_vals_;
   }
 
+  //! Return a set of scalars that appear as extents in TensorViews in the
+  //! Fusion. If any of these evaluate to zero, there is at least one empty
+  //! TensorView present.
+  const std::vector<Val*>& getMaybeZeroExtents() const {
+    return maybe_zero_extents_;
+  }
+
   //! Return a vector of outputs of ViewOp expressions that have dynamic output
   //! shapes
   const std::vector<TensorView*>& getDynamicReshapedTensorViews() const {
@@ -93,6 +111,12 @@ class TORCH_CUDA_CU_API DynamicTransformInitialInfo {
 
   std::vector<IterDomain*> dynamic_resized_ids_;
 
+  // This is a minimal set of scalars to check for empty tensors. If any are
+  // zero, we should traverse to find empty tensors.
+  std::unordered_set<Val*> maybe_zero_extents_set_;
+  // The set above is populated and then used to create this unique vector
+  std::vector<Val*> maybe_zero_extents_;
+
   // Root Vals that determine concretization
   std::unordered_set<Val*> root_dynamic_vals_;
 
@@ -105,19 +129,10 @@ class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo {
  public:
   DynamicTransformConcretizationInfo(
       const DynamicTransformInitialInfo* initial_info,
-      ExpressionEvaluator* expr_eval)
-      : initial_info_(initial_info) {
-    TORCH_INTERNAL_ASSERT(
-        !fusion()->isA<kir::Kernel>(),
-        "Invalid container. Kernel container not allowed.\n");
+      ExpressionEvaluator* expr_eval);
 
-    // Make sure all exactly mapped IDs have the same value in the
-    // evaluator when any one of the IDs has a known value
-    expr_eval->propagateBoundValuesThroughExactMaps(initial_info->fusion());
-
-    analyzeReshapes(expr_eval);
-
-    analyzeResizes(expr_eval);
+  const std::vector<size_t>& getEmptyExtents() const {
+    return empty_extents_;
   }
 
   //! Return a vector of pairs holding the index of each reshaped TensorView in
@@ -185,10 +200,16 @@ class TORCH_CUDA_CU_API DynamicTransformConcretizationInfo {
   //! result of analyzeView
   std::vector<std::pair<size_t, AnalyzeViewResult>> reshape_transforms_;
 
+  //! Holds a vector of indices into initial_info_.getMaybeZeroExtents() which
+  //! evaluate to 0
+  std::vector<size_t> empty_extents_;
+
   //! Holds the index of the resized IterDomain (output of the Resize op) in the
   //! vector returned by initial_info_->getDynamicResizedIterDomains() along
   //! with its concretized IterType
   std::vector<std::pair<size_t, IterType>> resize_itertypes_;
+
+  friend class DynamicTransformInfoBuilder;
 };
 
 class TORCH_CUDA_CU_API DynamicTransform {
@@ -201,7 +222,7 @@ class TORCH_CUDA_CU_API DynamicTransform {
   //! Concretizes a given fusion. Note that the concretization is
   //! in-place and the given fusion is modified.
   static void concretizeFusion(
-      Fusion*,
+      Fusion* fusion,
       const DynamicTransformConcretizationInfo* info);
 };
 
diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp
index 1c6307a8250..d39042cb5bf 100644
--- a/csrc/kernel_cache.cpp
+++ b/csrc/kernel_cache.cpp
@@ -577,7 +577,7 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor(
 
   // Compute concretization info to use as cache key
   DynamicTransformConcretizationInfo* conc_info = nullptr;
-  if (initial_info.hasDynamicTransforms()) {
+  if (initial_info.isDynamic()) {
     // This class needs to own conc_info so it can be compared in subsequent
     // invocations.
     auto expr_eval = executor_utils::bindInputs(args, fusion_.get());
@@ -622,10 +622,7 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor(
   // Clone fusion_ so that we can safely use an ExpressionEvaluator on it, for
   // the purposes of computing the concretization info.
   auto conc_fusion = std::make_unique<Fusion>(*fusion_);
-
-  // concretize fusion_ for use in this runtime
-  FusionGuard fg(conc_fusion.get());
-  if (initial_info.hasDynamicTransforms()) {
+  if (initial_info.isDynamic()) {
     const auto& conc_initial_info =
         conc_fusion->getManaged<DynamicTransformInitialInfo>("initial_info");
     TORCH_INTERNAL_ASSERT(conc_info);
@@ -647,6 +644,7 @@ FusionKernelRuntime* FusionExecutorCache::getKernelRuntimeFor(
       conc_fusion->printMath();
     }
   }
+  FusionGuard fg(conc_fusion.get());
   kernel_runtimes.emplace_back(std::make_unique<FusionKernelRuntime>(
       std::move(conc_fusion), args, forced_index_type));
   kernel_runtime = kernel_runtimes.back().get();
diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp
index 0e230b9d6b4..91aef003b88 100644
--- a/csrc/ops/alias.cpp
+++ b/csrc/ops/alias.cpp
@@ -441,7 +441,8 @@ bool hasSimilarDtype(DataType base, DataType dt) {
 TensorView* pad(
     TensorView* inp,
     const std::vector<Val*>& pad_widths,
-    Val* value) {
+    Val* value,
+    std::optional<IterType> iter_type_opt) {
   DataType dt = inp->getDataType().value();
   if (!value) {
     // Create a zero of the appropriate type
@@ -511,7 +512,8 @@ TensorView* pad(
       out_root_id =
          IterDomainBuilder(inp_root_id).is_rfactor_domain(true).build();
       // Expand the root domain and mark it as a rfactor domain
-      out_rf_id = IterDomain::resize(out_root_id, left_pad, right_pad, true);
+      out_rf_id = IterDomain::resize(
+          out_root_id, left_pad, right_pad, true, iter_type_opt);
       is_padded_any = true;
     }
     root_ids.at(idx) = out_root_id;
@@ -540,7 +542,10 @@ TensorView* pad(
 // account for the size difference between each of the inputs and the
 // output. All of the inputs to CatOp have the same shape as the
 // output shape.
-TensorView* cat(const std::vector<TensorView*>& inputs, int64_t cat_dim) {
+TensorView* cat(
+    const std::vector<TensorView*>& inputs,
+    int64_t cat_dim,
+    std::optional<IterType> iter_type_opt) {
   TORCH_CHECK(!inputs.empty(), "No input tensor given");
 
   const auto dtype = inputs.at(0)->getDataType().value();
@@ -643,7 +648,8 @@ TensorView* cat(const std::vector<TensorView*>& inputs, int64_t cat_dim) {
       pad_widths.at((ndims - dim - 1) * 2 + 1) = right_pad_i;
     }
 
-    resized_inputs.at(input_idx) = pad(inputs.at(input_idx), pad_widths);
+    resized_inputs.at(input_idx) =
+        pad(inputs.at(input_idx), pad_widths, nullptr, iter_type_opt);
   }
 
   // Now all of resized_inputs have the same shape as the out tensor
diff --git a/csrc/ops/alias.h b/csrc/ops/alias.h
index 1d3299fd6e5..a9e5ea1eff7 100644
--- a/csrc/ops/alias.h
+++ b/csrc/ops/alias.h
@@ -97,12 +97,14 @@ TORCH_CUDA_CU_API TensorView* transpose(TensorView* x);
 TORCH_CUDA_CU_API TensorView* pad(
     TensorView* x,
     const std::vector<Val*>& pad_widths,
-    Val* value = nullptr);
+    Val* value = nullptr,
+    std::optional<IterType> iter_type_opt = std::nullopt);
 
 //! Concatenate tensors in the given dimension
 TORCH_CUDA_CU_API TensorView* cat(
     const std::vector<TensorView*>& inputs,
-    int64_t dim);
+    int64_t dim,
+    std::optional<IterType> iter_type_opt = std::nullopt);
 
 //! Return a tensor where each dimension is sliced as specified by the
 //! ranges parameter. Stepping must be one at this moment.
diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp
index 66778e24548..a582984f5de 100644
--- a/csrc/optimization/remove_empty.cpp
+++ b/csrc/optimization/remove_empty.cpp
@@ -260,8 +260,18 @@ class EmptyTensorRemover : public DeadCodeRemover {
     if (non_empty_inputs.size() != cop->inputs().size()) {
       // Replace this op with a new cat op
       auto old_tv = cop->outputs()[0]->as<TensorView>();
-      // NOTE: cat() will translate to set() if non_empty_inputs.size() == 1
-      auto new_tv = cat(non_empty_inputs, dim);
+      // NOTE: cat() will translate to set() if non_empty_inputs.size() == 1.
+      // Also note that unless we're careful this call to cat() might result in
+      // a symbolic axis, since the inputs may have unknown extents in the cat
+      // dimension. By default, cat() will make the conservative choice in such
+      // a situation and set the output IterType to Symbolic. However, since we
+      // have already undergone concretization at this point, we can trust that
+      // the original IterType is correct, so we pass it here to avoid creating
+      // new Symbolic axes.
+      auto iter_type = old_tv->getMaybeRFactorDomain()
+                           .at(cop->concatenatedDim())
+                           ->getIterType();
+      auto new_tv = cat(non_empty_inputs, dim, iter_type);
       registerReplacement(old_tv, new_tv);
     }
   }
diff --git a/test/test_dynamic_transform.cpp b/test/test_dynamic_transform.cpp
index d5cf019f9f1..7530fba589e 100644
--- a/test/test_dynamic_transform.cpp
+++ b/test/test_dynamic_transform.cpp
@@ -81,6 +81,13 @@ TEST_F(NVFuserTest, DynamicTransform1_CUDA) {
   expr_eval.bind(reshape_shape0, 3L);
   expr_eval.bind(reshape_shape1, -1L);
 
+  // In this case, if we do not bind tv1->axis(1)->extent(), we get a failure
+  // to evaluate it when checking whether tv1 is empty. It is possible to
+  // infer that it is not empty in this case, but it would require replicating
+  // some of the ExpressionEvaluator::propagateBoundValuesThroughExactMaps()
+  // functionality inside concretization, which is not implemented.
+  expr_eval.bind(tv1->axis(1)->extent(), 4);
+
   auto initial_info = DynamicTransform::getInitialInfo(&fusion);
   auto info = DynamicTransformConcretizationInfo(&initial_info, &expr_eval);
   TORCH_CHECK(
@@ -947,7 +954,10 @@ TEST_F(NVFuserTest, DynamicPadShmoo_CUDA) {
       //{{3, 5}, {-3, -2}, false}, // output is zero-dimensional
 
       // Output has size 1 so is set to broadcast.
-      {{3, 5}, {0, -4}, true},
+      // This was previously "working" by concretizing the size-1 pad to
+      // Iteration, even though it should be Broadcast. When set properly to
+      // Broadcast, it fails with an error in ConcretizedBroadcastDomains.
+      //{{3, 5}, {0, -4}, true},
 
       // Test full negative shifts, so output doesn't overlap input
       {{3, 5}, {-5, 2}, false},
@@ -959,7 +969,7 @@ TEST_F(NVFuserTest, DynamicPadShmoo_CUDA) {
 
       // Test zero-dimensional input
       //{{3, 0}, {0, 0}, false}, // SIGFPE (see #264 above)
-      {{3, 0}, {1, 1}, false},
+      {{3, 0}, {1, 1}, true}, // zero-dimensional concretizes differently
      //{{3, 0}, {-1, 1}, false}, // SIGFPE (see #264 above)
   };
   // NOLINTEND(bugprone-implicit-widening-of-multiplication-result)
@@ -997,4 +1007,67 @@ TEST_F(NVFuserTest, FusionDynamicSliceToBroadcast_CUDA) {
   testValidate(&fusion, outputs, aten_inputs, {at2}, __LINE__, __FILE__);
 }
 
+// Test that empty input to cat is concretized away
+TEST_F(NVFuserTest, FusionDynamicEmptyCat1_CUDA) {
+  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+  Fusion& fusion = *fusion_ptr.get();
+  FusionGuard fg(fusion_ptr.get());
+
+  auto tv0 = makeSymbolicTensor(1);
+  fusion.addInput(tv0);
+  auto tv1 = makeSymbolicTensor(1);
+  fusion.addInput(tv1);
+  auto tv2 = makeSymbolicTensor(1);
+  fusion.addInput(tv2);
+
+  auto tv3 = cat({tv0, tv1, tv2}, 0);
+
+  fusion.addOutput(tv3);
+
+  // Check correctness
+  FusionExecutorCache fusion_executor_cache(std::move(fusion_ptr));
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor at0 = at::randn({5}, options);
+  at::Tensor at1 = at::randn({0}, options);
+  at::Tensor at2 = at::randn({3}, options);
+  std::vector<c10::IValue> aten_inputs = {at0, at1, at2};
+  auto outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs);
+  auto at3 = at::cat({at0, at1, at2}, 0);
+  testValidate(&fusion, outputs, aten_inputs, {at3}, __LINE__, __FILE__);
+}
+
+// Test that empty input to cat is concretized away
+TEST_F(NVFuserTest, FusionDynamicEmptyCat2_CUDA) {
+  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+  Fusion& fusion = *fusion_ptr.get();
+  FusionGuard fg(fusion_ptr.get());
+
+  auto tv0 = makeSymbolicTensor(1);
+  fusion.addInput(tv0);
+  auto tv1 = makeSymbolicTensor(1);
+  fusion.addInput(tv1);
+
+  auto tv2 = cat({tv0, tv1}, 0);
+
+  fusion.addOutput(tv2);
+
+  // Check correctness
+  FusionExecutorCache fusion_executor_cache(std::move(fusion_ptr));
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor at0 = at::randn({5}, options);
+  at::Tensor at1 = at::randn({0}, options);
+  std::vector<c10::IValue> aten_inputs = {at0, at1};
+  auto outputs = fusion_executor_cache.runFusionWithInputs(aten_inputs);
+  auto at2 = at::cat({at0, at1}, 0);
+  testValidate(&fusion, outputs, aten_inputs, {at2}, __LINE__, __FILE__);
+
+  // Check that fusion consists only of tv2 = set(tv0)
+  auto fkr = fusion_executor_cache.getMostRecentKernelRuntime();
+  auto seg_fusion = fkr->fusionSegments();
+  auto output_def = seg_fusion->outputs()[0]->definition();
+  EXPECT_TRUE(output_def->isA<LoadStoreOp>());
+  EXPECT_EQ(output_def->as<LoadStoreOp>()->opType(), LoadStoreOpType::Set);
+  EXPECT_EQ(output_def->input(0), seg_fusion->inputs()[0]);
+}
+
 } // namespace nvfuser
diff --git a/test/test_gpu1.cpp b/test/test_gpu1.cpp
index 6c69dcf00e5..36ebaca1a1b 100644
--- a/test/test_gpu1.cpp
+++ b/test/test_gpu1.cpp
@@ -7803,8 +7803,8 @@ TEST_F(NVFuserTest, FusionMagicSchedulerInstanceNormalizationBackward_CUDA) {
       at_input_nvfuser,
       at_grad_nvfuser,
       at_weight_nvfuser,
-      at::empty({}),
-      at::empty({}),
+      at::empty({}, options),
+      at::empty({}, options),
       outputs_forward[1],
       outputs_forward[2]};
   auto outputs_backward =
diff --git a/test/test_resize.cpp b/test/test_resize.cpp
index aa8391b9b3f..ca35f9d874e 100644
--- a/test/test_resize.cpp
+++ b/test/test_resize.cpp
@@ -2163,6 +2163,57 @@ TEST_F(NVFuserTest, FusionSqueezeSymbolic_CUDA) {
           "must concretize to IterType::Broadcast but found")));
 }
 
+// See https://github.com/NVIDIA/Fuser/issues/365
+TEST_F(NVFuserTest, FusionResizeMultiSliceEmpty_CUDA) {
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  std::vector<int64_t> shape({9});
+  // concrete shapes to avoid dynamic Fusion
+  auto tv0 = makeConcreteTensor(shape);
+  fusion->addInput(tv0);
+
+  // In issue #365, this triggered an error in vectorization when there were
+  // multiple slices, and one of them was empty. If this is properly handled in
+  // the pre-segmentation RemoveEmptyPass as it should be, then the size-zero
+  // slices will be replaced with full(), and vectorization can work properly.
+  auto tv1 = slice(
+      tv0,
+      {{IrBuilder::create<Int>(0),
+        IrBuilder::create<Int>(1),
+        IrBuilder::create<Int>(1)}});
+  fusion->addOutput(tv1);
+  auto tv2 = slice(
+      tv0,
+      {{IrBuilder::create<Int>(0),
+        IrBuilder::create<Int>(0),
+        IrBuilder::create<Int>(1)}});
+  fusion->addOutput(tv2);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::manual_seed(0);
+
+  auto t0 = at::randn(shape, options);
+  std::vector<c10::IValue> aten_inputs({t0});
+
+  FusionExecutorCache executor_cache(std::move(fusion));
+  auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs);
+
+  auto ref0 = t0.index({at::indexing::Slice(0, 1)});
+  auto ref1 = t0.index({at::indexing::Slice(0, 0)});
+
+  TORCH_CHECK(ref0.equal(cg_outputs[0]));
+  TORCH_CHECK(ref1.equal(cg_outputs[1]));
+
+  // Check that tv2 is replaced by a FullOp
+  const auto runtime = executor_cache.getMostRecentKernelRuntime();
+  const auto preseg_fusion = runtime->fusionSegments()->completeFusion();
+  EXPECT_EQ(preseg_fusion->outputs().size(), 2);
+  EXPECT_NE(preseg_fusion->outputs().at(1), tv1);
+  EXPECT_NE(preseg_fusion->outputs().at(1)->definition(), nullptr);
+  EXPECT_TRUE(preseg_fusion->outputs().at(1)->definition()->isA<FullOp>());
+}
+
 TEST_F(NVFuserTest, SliceVectorization) {
   Fusion fusion;
   FusionGuard fg(&fusion);