diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp
index 0206ab7f885..6871e6a592d 100644
--- a/csrc/device_lower/lower2device.cpp
+++ b/csrc/device_lower/lower2device.cpp
@@ -390,7 +390,7 @@ void GpuLower::analysis(Fusion* fusion) {
   // functionality should be affected. New IterDomains may be created,
   // so it is expected that generated code may use diffrent variable
   // names
-  if (isOptionEnabled(EnableOption::IdModel)) {
+  if (true || isOptionEnabled(EnableOption::IdModel)) {
     IdModel id_model(fusion_);
   }
 
diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h
index 618f768fff1..2afa9cecdc2 100644
--- a/csrc/disjoint_set.h
+++ b/csrc/disjoint_set.h
@@ -81,6 +81,15 @@ class VectorOfUniqueEntries {
     return false;
   }
 
+  // Returns true if a node was actually added
+  bool pushFront(T entry) {
+    if (set_.emplace(entry).second) {
+      vector_.insert(vector_.begin(), entry);
+      return true;
+    }
+    return false;
+  }
+
   // Returns true if any node was added
   bool pushBack(const VectorOfUniqueEntries& other) {
     return pushBack(other.vector());
@@ -170,6 +179,14 @@ class VectorOfUniqueEntries {
     return v;
   }
 
+  // Removes and returns the first element in the vector
+  T popFront() {
+    T v = vector_.front();
+    set_.erase(v);
+    vector_.erase(vector_.begin());
+    return v;
+  }
+
   // Returns if this container is empty
   bool empty() const {
     return vector_.empty();
@@ -394,7 +411,9 @@ class DisjointSets {
         entry_it != disjointSetMap().end(),
         "Strict mapping failed on element: ",
         abstractToString(entry0),
-        " either an error occurred, or non strict mapping should have been used.");
+        " either an error occurred, or non strict mapping should have been used.",
+        " ",
+        entry0->name());
     return entry_it->second->has(entry1);
   }
 
diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp
index cec856faf64..378ec142dd5 100644
--- a/csrc/id_model/id_model.cpp
+++ b/csrc/id_model/id_model.cpp
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
@@ -233,6 +234,203 @@ std::string IdModel::toString() const {
   return ss.str();
 }
 
+// Generate a new expr with the IterDomain inputs/outputs replaced based on the
+// map. Replaced inputs/outputs should almost exactly match the provided expr.
+Expr* IdModel::addExprWithReplacement(
+    const std::unordered_map<IterDomain*, IterDomain*>& old_2_new_ids,
+    Expr* old_expr) {
+  // Figure out which graphs are already initialized to make sure we add the
+  // new expression to them.
+  std::vector<IdMappingMode> initialized_modes;
+  for (auto mode : kIdMappingModes) {
+    auto graph_it = id_graphs_.find(mode);
+    if (graph_it == id_graphs_.end()) {
+      continue;
+    }
+
+    auto& graph = graph_it->second;
+    if (graph.disjointValSets().disjointSetMap().empty()) {
+      continue;
+    }
+
+    initialized_modes.push_back(mode);
+  }
+
+  // We will fill this map for every IterDomain in input and output.
+  std::unordered_map<IterDomain*, IterDomain*> replacement_map = old_2_new_ids;
+
+  // Validate replacement map. Make sure the keys are an input or output
+  for (auto replacement_entry : replacement_map) {
+    NVF_ERROR(
+        std::find(
+            old_expr->inputs().begin(),
+            old_expr->inputs().end(),
+            replacement_entry.first) != old_expr->inputs().end() ||
+            std::find(
+                old_expr->outputs().begin(),
+                old_expr->outputs().end(),
+                replacement_entry.first) != old_expr->outputs().end(),
+        "Wanted to replace ",
+        replacement_entry.first->toString(),
+        " however it is not an input or output of:\n",
+        old_expr->toString());
+  }
+
+  // Whether all inputs and/or all outputs were replaced
+  bool all_inps_replaced = true;
+  bool all_outs_replaced = true;
+  {
+    for (auto inp_id : ir_utils::filterByType<IterDomain>(old_expr->inputs())) {
+      if (replacement_map.find(inp_id) == replacement_map.end()) {
+        all_inps_replaced = false;
+        replacement_map[inp_id] = inp_id->cloneWithoutRFactor();
+      }
+    }
+
+    for (auto out_id :
+         ir_utils::filterByType<IterDomain>(old_expr->outputs())) {
+      if (replacement_map.find(out_id) == replacement_map.end()) {
+        all_outs_replaced = false;
+        replacement_map[out_id] = out_id->cloneWithoutRFactor();
+      }
+    }
+
+    NVF_ERROR(
+        (all_inps_replaced || all_outs_replaced),
+        "Either all the inputs or all the outputs need to be replaced when using this function.");
+
+    for (auto mode : initialized_modes) {
+      for (auto inp_or_out_id : all_inps_replaced
+               ? ir_utils::filterByType<IterDomain>(old_expr->inputs())
+               : ir_utils::filterByType<IterDomain>(old_expr->outputs())) {
+        NVF_ERROR(
+            idGraph(mode).hasGroup(inp_or_out_id),
+            "Expected ",
+            inp_or_out_id->toString(),
+            " to be initialized in graph mode: ",
+            mode);
+      }
+    }
+  }
+
+  // Create the new expression with provided outputs
+  auto replay = ReplacementTransformCloner::clone(replacement_map, old_expr);
+
+  // Add new output iter domains to id_definitions_/id_uses_ of IdModel
+  for (auto out_id : ir_utils::filterByType<IterDomain>(replay->outputs())) {
+    id_definitions_[out_id].pushBack(replay);
+    id_uses_[out_id];
+  }
+
+  // Add new input iter domains to id_definitions_/id_uses_ of IdModel
+  for (auto inp_id : ir_utils::filterByType<IterDomain>(replay->inputs())) {
+    id_definitions_[inp_id];
+    id_uses_[inp_id].pushBack(replay);
+  }
+
+  // Update all the initialized graph mappings
+  for (auto mode : initialized_modes) {
+    auto& graph = idGraph(mode);
+
+    graph.registerExpr(replay);
+    auto replay_group = graph.toGroup(replay);
+
+    // Initialize any non-existent input ids, update existing ones
+    for (auto inp_id : ir_utils::filterByType<IterDomain>(replay->inputs())) {
+      if (!graph.disjointValSets().mappingExists(inp_id)) {
+        // inp_id is not initialized in the map, initialize it
+        graph.initializeVal(inp_id, {}, {replay});
+      } else {
+        // Update unique uses of existing input ids
+        auto inp_group = graph.toGroup(inp_id);
+        graph.addUniqueUses(inp_group, replay_group);
+      }
+    }
+
+    // Initialize any non-existent output ids, update existing ones
+    for (auto out_id : ir_utils::filterByType<IterDomain>(replay->outputs())) {
+      if (!graph.disjointValSets().mappingExists(out_id)) {
+        // out_id is not initialized in the map, initialize it
+        graph.initializeVal(out_id, {replay}, {});
+      } else {
+        // out_id is already initialized, add the replay as a unique definition
+        // of its group
+        auto out_group = graph.toGroup(out_id);
+        graph.addUniqueDefinitions(out_group, replay_group);
+      }
+    }
+
+    // If the inputs were replaced, we want to map forward through the newly
+    // added expression. If the outputs were replaced, we want to map
+    // backwards through the newly added expression.
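+    //
+    // Hypothetical example (illustration only, not from the original
+    // comment): if old_expr is a Split producing {i1, i2} from i0 and
+    // old_2_new_ids maps i0 -> j0, the replayed Split consumes j0. The
+    // forward pass below then lets maybeMapThroughExprs map the replay with
+    // any other use of j0's group whose inputs already map.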
+ + // Forward + VectorOfUniqueEntries representative_uses; + for (auto in : ir_utils::filterByType(replay->inputs())) { + for (const ExprGroup& use_group : graph.getUses(graph.toGroup(in))) { + if (use_group == replay_group) { + continue; + } + representative_uses.pushBack(use_group->front()); + } + } + + for (auto rep_use : representative_uses) { + graph.maybeMapThroughExprs(rep_use, replay, true); + } + + // Backwards + VectorOfUniqueEntries representative_defs; + for (auto out : ir_utils::filterByType(replay->outputs())) { + for (const ExprGroup& def_group : + graph.getDefinitions(graph.toGroup(out))) { + if (def_group == replay_group) { + continue; + } + representative_defs.pushBack(def_group->front()); + } + } + + for (auto rep_def : representative_defs) { + graph.maybeMapThroughExprs(rep_def, replay, false); + } + } + return replay; +} + +// Clone provided iter domain and return the new copy. Map that copy in relevant +// maps. +IterDomain* IdModel::cloneIterDomain(IterDomain* id) { + // Figure out which graphs are already initialized to make sure we add the new + // expression to them. + std::vector initialized_modes; + for (auto mode : kIdMappingModes) { + auto graph_it = id_graphs_.find(mode); + if (graph_it == id_graphs_.end()) { + continue; + } + + auto& graph = graph_it->second; + if (graph.disjointValSets().disjointSetMap().empty()) { + continue; + } + + initialized_modes.push_back(mode); + } + + auto id_copy = id->cloneWithoutRFactor(); + + id_uses_[id_copy] = {}; + id_definitions_[id_copy] = {}; + + for (auto mode : initialized_modes) { + idGraph(mode).initializeVal(id_copy, {}, {}); + idGraph(mode).mapVals(id, id_copy); + } + + return id_copy; +} + ValGraph IdModel::initializeIdGraph(bool propagate_through_exprs) const { ValGraph id_graph(propagate_through_exprs); @@ -603,6 +801,12 @@ void IdModel::buildLoopGraph() { maybeBuildGraph(IdMappingMode::EXACT); maybeBuildGraph(IdMappingMode::PERMISSIVE); + if (!tv_exprs_.empty()) { + std::stringstream ss; + tv_exprs_.at(0)->fusion()->print(ss); + VERBOSE() << ss.str(); + } + const StatefulInliningInfo inlining_info = buildStatefulInliningInfo(tv_exprs_, idGraph(IdMappingMode::PERMISSIVE)); @@ -610,6 +814,12 @@ void IdModel::buildLoopGraph() { validateLoopGraphHasNoSelfMappedLeafDomains(); + VERBOSE() << "Initial loop graph:\n"; + for (const auto& group : + idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) { + VERBOSE() << nvfuser::toString(group) << std::endl; + } + loop_promotion_map_ = LoopPromotionMapBuilder::get( *this, inlining_info, loop_promotion_map_builder_callback_); @@ -620,7 +830,55 @@ void IdModel::buildLoopGraph() { idGraph(IdMappingMode::LOOP).validateConsistency(); } +// TODO: Reenable after reenabling parallel propagation. 
+// propagateLoopPTypes +void IdModel::validatePTypes(const std::vector& all_tvs) const { + // VectorOfUniqueEntries leaf_ids; + // for (auto tv : all_tvs) { + // leaf_ids.pushBack(tv->domain()->leaf()); + // } + + // for (const auto& disjoint_set : + // idGraph(IdMappingMode::EXACT).disjointValSets().disjointSets()) { + // for (auto id : disjoint_set->vector()) { + // auto id_ptype = id->getParallelType(); + + // NVF_ERROR( + // leaf_ids.has(id) || id_ptype == ParallelType::Serial, + // "Invalid parallelization of non leaf iter domain: ", + // id->toString()); + // } + // } +} + +void IdModel::propagateLoopPTypes() const { + for (const auto& loop_disjoint_set : + idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) { + ParallelType common_ptype = ParallelType::Serial; + for (auto id : loop_disjoint_set->vector()) { + auto id_ptype = id->as()->getParallelType(); + + NVF_ERROR( + id_ptype == common_ptype || id_ptype == ParallelType::Serial || + common_ptype == ParallelType::Serial, + "Issue validating parallel type disjoint ptype is, ", + common_ptype, + " but found in the set the id: ", + id->toString()); + + common_ptype = + common_ptype == ParallelType::Serial ? id_ptype : common_ptype; + } + + for (auto id : loop_disjoint_set->vector()) { + id->as()->parallelize(common_ptype); + } + } +} + void IdModel::buildAllGraphs() { + VERBOSE() << "*** Building all graphs ***"; + if (tvs_.empty()) { return; } @@ -663,6 +921,11 @@ void IdModel::buildAllGraphs() { idGraph(IdMappingMode::PERMISSIVE)); } + // Permissive graph needs the trivial exprs from the almost exact graph to + // build correctly. Once built though we can remove the trivial expressions + // from the almost exact graph. + idGraph(IdMappingMode::ALMOSTEXACT).removeTrivialExprs(); + buildLoopGraph(); } @@ -870,4 +1133,12 @@ std::unordered_map updateValGroupIdMap( return new_map; } +std::unordered_map IdModel::buildIndexGraph( + const std::vector& exprs, + const std::vector& all_tvs, + StatefulInliningInfo& info, + std::unordered_map stale_promotion_map) { + NVF_ERROR(false, "Not implemented yet."); +} + } // namespace nvfuser diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index 5b60ff474c6..b459b79d5de 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -216,6 +216,36 @@ class IdModel : public PolymorphicBase { std::unordered_map buildLoopPromotionMap( const StatefulInliningInfo& info); + // Make sure only leaf nodes of tensor views are parallelized + void validatePTypes(const std::vector& all_tvs) const; + + //! Run through disjoint sets in the LOOP map, make sure there's only one + //! non-serial parallel type in each disjoint set, set the parallel type of + //! all IterDomains in the disjoint set to that PType. + void propagateLoopPTypes() const; + + // !! END Helper functions to build loop promotion and index map!! + + // Builds idGraph(IdMappingMode::INDEX) and returns the iter domain promotion + // map to go from leaf domains of each (consumer only?) tensor to their + // corresponding leaf domain in the index graph. + std::unordered_map buildIndexGraph( + const std::vector& exprs, + const std::vector& all_tvs, + StatefulInliningInfo& info, + std::unordered_map stale_promotion_map); + + // Returns the terminal rfactor or input iter domains each group in the almost + // exact map covers (in the almost exact map). This effectively returns all + // the input almost exact iter domain groups for each almost exact iter domain + // group. 
RFactor axes are considered an "input" as all broadcast dimensions
+  // have to be resolved by or before the rfactor iter domain.
+  std::unordered_map<ValGroup, ValGroups> buildCoveredAlmostExact();
+
+  // TODO:
+  // Update the LOOP ID disjoint sets with resolved computeWith
+  void updateComputeWith(TensorView* compute_with_tv);
+
   // Errors if self mapping occurs
   void assertNoSelfMapping();
 
@@ -224,6 +254,34 @@ class IdModel : public PolymorphicBase {
   // tensor.
   void validateLoopGraphHasNoSelfMappedLeafDomains() const;
 
+  // Similar to addReplayAs, but clones the expr exactly instead of replaying
+  // it forward. It's up to the calling code to make sure the replacements are
+  // valid for the provided expr. It's generally recommended that the
+  // IterDomains exactly match those in the expr.
+  //
+  // "forward" dictates the same argument for mapThroughExpr. If forward, the
+  // function will apply mapThroughExpr forward if inputs map in each
+  // initialized map. Else it does the same but backwards through the
+  // expression from outputs.
+  Expr* addExprWithReplacement(
+      const std::unordered_map<IterDomain*, IterDomain*>& old_2_new_ids,
+      Expr* old_expr);
+
+  // Make a new expr matching the provided one but using the provided outputs.
+  // IterDomainGraphs will be updated for all maps that have entries, adding
+  // the input iter domains of the replayed expression and adding potential
+  // mappings through the expressions. Input domains will exactly match all
+  // properties of those in expr. This is unlike addReplayAs, which will
+  // produce new outputs using transformations directly.
+  Expr* addBackwardsReplayAs(
+      const std::vector<IterDomain*>& new_outputs,
+      Expr* expr);
+
+  // Make an exact copy of the provided IterDomain (without the rfactor flag
+  // set), and map the copy to the original in all initialized graphs of the
+  // IdModel. The IterDomain copy will not have any registered uses or
+  // definitions.
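+  //
+  // Hypothetical usage sketch (illustration only): given a domain id0 that
+  // needs a stand-in during promotion, `auto id1 = cloneIterDomain(id0);`
+  // yields id1 with identical properties that is mapped to id0 in every
+  // initialized graph, so idGraph(IdMappingMode::EXACT).toGroup(id1)
+  // resolves to id0's group even though id1 has no uses or definitions.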
+ IterDomain* cloneIterDomain(IterDomain* id); + protected: // All tensor expressions that this model analyzes std::vector tv_exprs_; @@ -268,6 +326,9 @@ class IdModel : public PolymorphicBase { std::unordered_set view_rfactor_ids_; + // Loop promotion map for inlined root broadcast domains + std::unordered_map iel_root_promotion_map_; + // Promotion domain for each loop group std::unordered_map loop_promotion_map_; }; diff --git a/csrc/id_model/transform_replay.cpp b/csrc/id_model/transform_replay.cpp index c9ec2e981ac..048ef0f94a2 100644 --- a/csrc/id_model/transform_replay.cpp +++ b/csrc/id_model/transform_replay.cpp @@ -82,4 +82,95 @@ void ReplayTransform::handle(const Resize* resize) { ->definition(); } +Expr* ReplacementTransformCloner::clone( + const std::unordered_map& + provided_expr_val_2_replacement_val, + const Expr* expression_to_match) { + ReplacementTransformCloner replay( + provided_expr_val_2_replacement_val, expression_to_match); + return replay.new_expr_; +} + +ReplacementTransformCloner::ReplacementTransformCloner( + const std::unordered_map& + provided_expr_val_2_replacement_val, + const Expr* expression_to_match) + : provided_expr_val_2_replacement_val_( + provided_expr_val_2_replacement_val) { + OptOutConstDispatch::dispatch(expression_to_match); +} + +IterDomain* ReplacementTransformCloner::replaceOrClone(IterDomain* id) { + if (provided_expr_val_2_replacement_val_.find(id) != + provided_expr_val_2_replacement_val_.end()) { + return provided_expr_val_2_replacement_val_.at(id); + } + return id->cloneWithoutRFactor(); +} + +// We're going to replay this split operation on the corresponding ID +void ReplacementTransformCloner::handle(const Split* split) { + // Replace or clone + + auto split_in = replaceOrClone(split->in()); + auto split_outer = replaceOrClone(split->outer()); + auto split_inner = replaceOrClone(split->inner()); + + // TODO: Should we check inner/outer matches the factor if + // innerSplit()/!innerSplit()? + + new_expr_ = IrBuilder::create( + split_outer, + split_inner, + split_in, + split->factor(), + split->innerSplit(), + split->startOffset(), + split->stopOffset()); +} + +// We're going to replay this merge operation on the corresponding IDs +void ReplacementTransformCloner::handle(const Merge* merge) { + // Replace or clone + auto merge_outer = replaceOrClone(merge->outer()); + auto merge_inner = replaceOrClone(merge->inner()); + auto merge_out = replaceOrClone(merge->out()); + new_expr_ = IrBuilder::create(merge_out, merge_outer, merge_inner); +} + +// We're going to replay this swizzle operation on the corresponding IDs +// if replaying swizzle is enabled. +void ReplacementTransformCloner::handle(const Swizzle2D* swizzle_2d) { + // Replace or clone + auto swizzle_inx = replaceOrClone(swizzle_2d->inX()); + auto swizzle_iny = replaceOrClone(swizzle_2d->inY()); + auto swizzle_outx = replaceOrClone(swizzle_2d->outX()); + auto swizzle_outy = replaceOrClone(swizzle_2d->outY()); + + new_expr_ = IrBuilder::create( + swizzle_outx, + swizzle_outy, + swizzle_inx, + swizzle_iny, + swizzle_2d->swizzleType(), + swizzle_2d->swizzleMode()); +} + +void ReplacementTransformCloner::handle(const Resize* resize) { + auto resize_in = resize->in(); + resize_in = provided_expr_val_2_replacement_val_.find(resize_in) != + provided_expr_val_2_replacement_val_.end() + ? 
provided_expr_val_2_replacement_val_.at(resize_in)
+      : resize_in->cloneWithoutRFactor();
+
+  auto resize_out = resize->out();
+  resize_out = provided_expr_val_2_replacement_val_.find(resize_out) !=
+          provided_expr_val_2_replacement_val_.end()
+      ? provided_expr_val_2_replacement_val_.at(resize_out)
+      : resize_out->cloneWithoutRFactor();
+
+  new_expr_ = IrBuilder::create<Resize>(
+      resize_out, resize_in, resize->leftExpand(), resize->rightExpand());
+}
+
 } // namespace nvfuser
diff --git a/csrc/id_model/transform_replay.h b/csrc/id_model/transform_replay.h
index 531dcc9729d..eab671c00a1 100644
--- a/csrc/id_model/transform_replay.h
+++ b/csrc/id_model/transform_replay.h
@@ -55,4 +55,50 @@ class ReplayTransform : OptInConstDispatch {
   const std::vector<IterDomain*>& input_ids_;
 };
 
+class ReplacementTransformCloner : OptInConstDispatch {
+ public:
+  // Generates a copy of expression_to_match with inputs and/or outputs
+  // replaced by entries provided in the map. Inputs and outputs are expected
+  // to be "clones". Not literally, but it's up to the invoking code to make
+  // sure the input/output replacements are safe to use in the cloned
+  // expression. No validation is done on provided inputs/outputs.
+  //
+  // In other words, a split i0{I0}->i1{I0//2}, i2{2} with a map:
+  // i2{2} -> i3{48} wouldn't throw an error, but would not be valid.
+  static Expr* clone(
+      const std::unordered_map<IterDomain*, IterDomain*>&
+          provided_expr_val_2_replacement_val,
+      const Expr* expression_to_match);
+
+ private:
+  ReplacementTransformCloner(
+      const std::unordered_map<IterDomain*, IterDomain*>&
+          expr_to_match_2_replacement,
+      const Expr* expression_to_match);
+
+  using OptInConstDispatch::handle;
+
+  // Returns the entry in provided_expr_val_2_replacement_val_ if it exists,
+  // otherwise returns a clone of the provided iter domain.
+  IterDomain* replaceOrClone(IterDomain* id);
+
+  // We're going to replay this split operation on the corresponding ID
+  void handle(const Split* split) override;
+
+  // We're going to replay this merge operation on the corresponding IDs
+  void handle(const Merge* merge) override;
+
+  // We're going to replay this swizzle operation on the corresponding IDs
+  // if replaying swizzle is enabled.
+  void handle(const Swizzle2D* swizzle_2d) override;
+
+  // We're going to replay this resize operation on the corresponding IDs
+  // if replaying resize is enabled.
+  void handle(const Resize* resize) override;
+
+  Expr* new_expr_ = nullptr;
+  const std::unordered_map<IterDomain*, IterDomain*>&
+      provided_expr_val_2_replacement_val_;
+};
+
 } // namespace nvfuser
diff --git a/csrc/id_model/utils.h b/csrc/id_model/utils.h
new file mode 100644
index 00000000000..2d6327bf586
--- /dev/null
+++ b/csrc/id_model/utils.h
@@ -0,0 +1,55 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +#include +#include +#include + +#define VERBOSE() verbose(__LINE__) +#define WARN() warn(__LINE__) + +namespace nvfuser { + +// Temporary logging utility +class DebugStream { + public: + DebugStream() + : enabled_(getNvFuserEnv("ID_MODEL_VERBOSE")), out_(std::cerr) {} + + template + DebugStream& operator<<(const T& v) { + if (enabled_) { + out_ << v; + } + return *this; + } + + DebugStream& operator<<(std::ostream& (*endl)(std::ostream&)) { + if (enabled_) { + out_ << endl; + } + return *this; + } + + private: + bool enabled_ = false; + std::ostream& out_; +}; + +inline DebugStream verbose(int line) { + return DebugStream() << "[DEBUG@" << line << "] "; +} + +inline DebugStream warn(int line) { + return DebugStream() << "[WARN@" << line << "] "; +} + +} // namespace nvfuser diff --git a/csrc/val_graph.cpp b/csrc/val_graph.cpp index a297e5b8df2..82516c7259b 100644 --- a/csrc/val_graph.cpp +++ b/csrc/val_graph.cpp @@ -6,9 +6,12 @@ */ // clang-format on #include +#include #include #include +#include + namespace nvfuser { namespace { @@ -30,9 +33,7 @@ ValGraph::ValGraph(const ValGraph& other) new_expr_groups.pushBack(toGroup(orig_expr_group->front())); } - NVF_ERROR( - unique_definitions_.emplace(new_val_group, std::move(new_expr_groups)) - .second); + unique_definitions_[new_val_group] = new_expr_groups; } for (const auto& [orig_val_group, orig_expr_groups] : other.unique_uses_) { @@ -211,6 +212,251 @@ bool ValGraph::hasUses(const ValGroup& val_group) const { return unique_uses_.find(val_group) != unique_uses_.end(); } +ExprGroups ValGraph::getExprsBetween(const ValGroups& from, const ValGroups& to) + const { + ExprGroups all_uses_of_from = allUsesOf(from); + ExprGroups all_definitions_of_to = allDefinitionsOf(to); + + // All of the expressions between from and to. Not all will be used as we + // just want to define each iter domain group once. + ExprGroups all_exprs = + all_uses_of_from.computeIntersect(all_definitions_of_to); + + // There could be IterDomains in from or to that are between other from and + // to nodes. Make sure to clear those out. + ValGroups terminating_inputs; + ValGroups terminating_outputs; + { + ValGroups not_inputs; + ValGroups not_outputs; + ValGroups all_id_groups; + + for (const ExprGroup& expr_group : all_exprs) { + if (isTrivialExprGroup(expr_group)) { + // Expression is just a loop to its current group, ignore + continue; + } + + std::vector inp_groups = inputGroups(expr_group); + std::vector out_groups = outputGroups(expr_group); + + all_id_groups.pushBack(inp_groups); + not_outputs.pushBack(inp_groups); + + all_id_groups.pushBack(out_groups); + not_inputs.pushBack(out_groups); + } + terminating_inputs = all_id_groups.computeSubtract(not_inputs); + terminating_outputs = all_id_groups.computeSubtract(not_outputs); + } + + // Track all expressions to get from outputs to this IterDomain. We + // traverse backwards as that's the direction of indexing expressions. An + // index is assigned to each leaf of a domain and as we traverse backwards + // we're effectively accumulating indexing math. We'll only keep the fewest + // expression lists to get to the iter domain. 
+  std::unordered_map<ValGroup, ExprGroups> required_ind_exprs_ids;
+  std::unordered_map<ExprGroup, ExprGroups> required_ind_exprs_exprs;
+
+  // Returns whether all output IterDomain groups of an expression group have
+  // already been visited
+  auto outputsVisited = [&](ExprGroup expr_group) {
+    auto output_groups = outputGroups(expr_group);
+    return std::all_of(
+        output_groups.begin(),
+        output_groups.end(),
+        [&](const ValGroup& output_group) {
+          return required_ind_exprs_ids.find(output_group) !=
+              required_ind_exprs_ids.end();
+        });
+  };
+
+  // Returns all expression groups in required_ind_exprs_ids for the outputs
+  auto requiredExprsOutputs = [&](ExprGroup expr_group) -> ExprGroups {
+    ExprGroups all_output_required_exprs;
+    for (const ValGroup& output_id_group : outputGroups(expr_group)) {
+      auto id_group_exprs_it = required_ind_exprs_ids.find(output_id_group);
+      NVF_ERROR(
+          id_group_exprs_it != required_ind_exprs_ids.end(),
+          "Failure in Iter Domain Graph index resolution, entry expected for group: ",
+          output_id_group->toString());
+      all_output_required_exprs.pushBack(id_group_exprs_it->second);
+    }
+    return all_output_required_exprs;
+  };
+
+  auto processExprGroup = [&](ExprGroup expr_group) -> bool {
+    if (!outputsVisited(expr_group)) {
+      return false;
+    }
+    // Accumulate the expressions required by all outputs and set the result
+    // as this expression group's required indexing expressions.
+    required_ind_exprs_exprs[expr_group] = requiredExprsOutputs(expr_group);
+    return true;
+  };
+
+  auto processValGroup = [&](ValGroup id_group) -> bool {
+    // Track whether we've grabbed any of the uses' required indexing
+    // expressions.
+    bool initialized = false;
+    // Expression group of all indexing expressions required for this iter
+    // domain coming back from any of its uses.
+    ExprGroups min_groups;
+
+    const ExprGroups& uses = getUses(id_group);
+
+    if (uses.empty()) {
+      // No expressions required for this iter domain, it must be a
+      // terminating output.
+      required_ind_exprs_ids[id_group] = min_groups;
+      return true;
+    }
+
+    // Only worry about expressions between inputs and outputs we're
+    // looking at.
+    for (const ExprGroup& use_group : uses.computeIntersect(all_exprs)) {
+      auto use_required_ind_exprs_it = required_ind_exprs_exprs.find(use_group);
+      if (use_required_ind_exprs_it == required_ind_exprs_exprs.end()) {
+        // If there isn't an entry for the use expression it wasn't
+        // processed, so don't try to process this iter domain yet.
+        return false;
+      }
+      if (!initialized) {
+        // If this is the first use found, initialize the minimum expression
+        // group
+        min_groups =
+            use_required_ind_exprs_it->second.computeUnion({use_group});
+        initialized = true;
+      } else if (
+          use_required_ind_exprs_it->second.size() + 1 < min_groups.size()) {
+        // If the current use has fewer expressions, use that set; make sure
+        // to add the use expression itself.
+        min_groups =
+            use_required_ind_exprs_it->second.computeUnion({use_group});
+      }
+    }
+    required_ind_exprs_ids[id_group] = min_groups;
+    return true;
+  };
+
+  // Backward traversal from the terminating outputs
+  ValGroups to_visit_ids = terminating_outputs;
+  ExprGroups to_visit_exprs;
+
+  while (!to_visit_ids.empty() || !to_visit_exprs.empty()) {
+    // Process expressions first as all uses of iter domains have to be
+    // processed before we can process that iter domain.
+ + // Try to detect when nothing has been processed which would put us in an + // infinite loop + bool something_was_processed = false; + ExprGroups still_to_visit_exprs; + while (!to_visit_exprs.empty()) { + ExprGroup currently_visiting_exprs = to_visit_exprs.popFront(); + if (required_ind_exprs_exprs.find(currently_visiting_exprs) != + required_ind_exprs_exprs.end()) { + // currently_visiting_exprs is already visited + continue; + } + if (processExprGroup(currently_visiting_exprs)) { + something_was_processed = true; + std::vector inp_groups = + inputGroups(currently_visiting_exprs); + for (const ValGroup& inp_group : inp_groups) { + to_visit_ids.pushBack(inp_group); + } + } else { + still_to_visit_exprs.pushBack(currently_visiting_exprs); + } + } + + std::swap(to_visit_exprs, still_to_visit_exprs); + + ValGroups still_to_visit_ids; + while (!to_visit_ids.empty()) { + auto currently_visiting_ids = to_visit_ids.popFront(); + if (required_ind_exprs_ids.find(currently_visiting_ids) != + required_ind_exprs_ids.end()) { + continue; + } + + if (processValGroup(currently_visiting_ids)) { + something_was_processed = true; + for (const ExprGroup& def : getDefinitions(currently_visiting_ids)) { + if (!all_exprs.has(def)) { + continue; + } + if (required_ind_exprs_exprs.find(def) == + required_ind_exprs_exprs.end()) { + to_visit_exprs.pushBack(def); + } + } + } else { + still_to_visit_ids.pushBack(currently_visiting_ids); + } + } + + NVF_ERROR( + something_was_processed || + (to_visit_ids.empty() && to_visit_exprs.empty()), + "Infinite loop entered."); + } + + // We want to traverse the expressions registered in required_ind_exprs_ids, + // let's create a strict "uses path" + std::unordered_map uses_path; + for (const auto& entry : required_ind_exprs_ids) { + const ValGroup& id = entry.first; + const ExprGroups& traverse_exprs = entry.second; + const ExprGroups& all_uses = getUses(id); + uses_path[id] = traverse_exprs.computeIntersect(all_uses); + } + + // Topologically sort the uses_path. + ExprGroups sorted_exprs; + ExprGroups to_visit_expr_groups; + + for (const ValGroup& inp : terminating_inputs) { + auto use_it = uses_path.find(inp); + if (use_it == uses_path.end()) { + // This can happen for a trivial traversal where inputs and outputs are + // exactly the same. 
+      continue;
+    }
+    const ExprGroups& uses = use_it->second;
+    for (const ExprGroup& use : uses) {
+      to_visit_expr_groups.pushBack(use);
+    }
+  }
+
+  ValGroups visited = terminating_inputs;
+
+  while (!to_visit_expr_groups.empty()) {
+    bool something_processed = false;
+    ExprGroups still_to_visit;
+    while (!to_visit_expr_groups.empty()) {
+      auto currently_visiting = to_visit_expr_groups.popFront();
+      auto inputs = inputGroups(currently_visiting);
+      if (std::all_of(inputs.begin(), inputs.end(), [&](ValGroup inp_id) {
+            return visited.has(inp_id);
+          })) {
+        something_processed = true;
+        sorted_exprs.pushBack(currently_visiting);
+        auto outputs = outputGroups(currently_visiting);
+        for (const ValGroup& out_id : outputs) {
+          visited.pushBack(out_id);
+          const ExprGroups& uses = getUses(out_id);
+          still_to_visit.pushBack(uses.computeIntersect(all_exprs));
+        }
+      } else {
+        still_to_visit.pushBack(currently_visiting);
+      }
+    }
+    std::swap(to_visit_expr_groups, still_to_visit);
+    NVF_ERROR(something_processed, "Infinite loop entered.");
+  }
+
+  return sorted_exprs;
+}
+
 std::unordered_map<Val*, VectorOfUniqueEntries<Val*>> ValGraph::buildMapBetween(
     const std::vector<Val*>& from,
     const std::vector<Val*>& to) const {
@@ -269,10 +515,6 @@ void ValGraph::initializeVal(
   const ValGroup& val_disjoint_set =
       disjoint_vals_.initializeSet(val).first->second;
 
-  // For now, the definition of a val should be unique. Remove this
-  // assertion as necessary
-  NVF_ERROR(definitions.size() <= 1);
-
   ExprGroups def_groups;
   for (auto def : definitions) {
     const ExprGroup& expr_set =
@@ -563,6 +805,56 @@ bool ValGraph::mapThroughExpr(Expr* first, Expr* second, bool forward) {
   return true;
 }
 
+void ValGraph::removeTrivialExprs() {
+  ExprGroups trivial_expr_groups;
+  // This seems like it should just be a std::copy_if.
+  for (const ExprGroup& expr_group : disjointExprSets().disjointSets()) {
+    if (isTrivialExprGroup(expr_group)) {
+      trivial_expr_groups.pushBack(expr_group);
+    }
+  }
+
+  // Clear out expressions that map inputs and outputs to the same group
+  // from definitions and uses. They shouldn't be important in traversal, and
+  // will break the terminal input/terminal output logic of traversal. Similar
+  // to what's drafted in buildIndexGraph.
+  for (const ExprGroup& trivial_expr_group : trivial_expr_groups) {
+    // Complexity of erase is not good, as both the disjoint set and the
+    // vector of unique entries require a vector find to erase an entry.
+    eraseExprGroup(trivial_expr_group);
+  }
+}
+
+// Complexity here is not great. We might want a better complexity version
+// when erasing multiple expr_groups.
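+//
+// Hypothetical example (illustration only): in the almost-exact graph, a
+// Merge of a domain iS0 with a broadcast bS1 has its output mapped back into
+// iS0's group, so inputGroups and outputGroups intersect and the expression
+// group is "trivial". Erasing such groups keeps them from breaking the
+// terminal input/output logic used by getExprsBetween.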
+void ValGraph::eraseExprGroup(const ExprGroup& expr_group) { + // Erase entries that exist in unique_definitions_ and unique_uses_ + for (const ValGroup& id_group : disjointValSets().disjointSets()) { + // Make sure the entries exists + NVF_ERROR( + unique_definitions_.find(id_group) != unique_definitions_.end(), + "Broken definitions, couldn't find entry for id group, ", + nvfuser::toString(id_group, 0, true)); + NVF_ERROR( + unique_uses_.find(id_group) != unique_uses_.end(), + "Broken uses, couldn't find entry for id group, ", + nvfuser::toString(id_group, 0, true)); + + unique_definitions_[id_group].erase(expr_group); + unique_uses_[id_group].erase(expr_group); + } + + for (auto expr : *expr_group) { + disjoint_exprs_.erase(expr); + } +} + +bool ValGraph::isTrivialExprGroup(const ExprGroup& expr_group) const { + return !ValGroups(inputGroups(expr_group)) + .computeIntersect(ValGroups(outputGroups(expr_group))) + .empty(); +} + void ValGraph::validateConsistency() const { // Check the consistency of the mapping information. Specifically: // 1. All ValGroup and ExprGroup sets are not empty. This may not be diff --git a/csrc/val_graph.h b/csrc/val_graph.h index 537a268dc7b..467ab9bfddc 100644 --- a/csrc/val_graph.h +++ b/csrc/val_graph.h @@ -146,6 +146,11 @@ class ValGraph { bool hasUses(const ValGroup& val_group) const; + // Return sorted expressions to go from the provided IterDomains in from to + // the provided IterDomains in to with provided mode. Minimal expressions to + // get from 'from' to 'to' returned. + ExprGroups getExprsBetween(const ValGroups& from, const ValGroups& to) const; + // Uses the Valgraph to produce mappings between from and to. // Supports one to many mappings. If a single Val in from maps to // multiple Vals in to, the order of the Vals in value of @@ -210,6 +215,7 @@ class ValGraph { // mappings. void validateConsistency() const; + public: void addUniqueUses(const ValGroup& id_group, const ExprGroup& uses) { unique_uses_.at(id_group).pushBack(uses); } @@ -230,6 +236,23 @@ class ValGraph { // be the only call in ValGraph to mapThroughExpr. void maybeMapThroughExprs(Expr* expr0, Expr* expr1, bool forward); + // Removes expressions from unique_definitions_ and unique_uses_ that return + // mappings from IdGraph::isTrivialExpr + void removeTrivialExprs(); + + // Removes the provided expression group from unique_definitions_ and + // unique_uses_ breaking traversal through them. + void eraseExprGroup(const ExprGroup& expr_group); + + // Returns if the expression group has an input id group that matches an + // output id group. This means traversing on this expression doesn't actually + // do anything. + bool isTrivialExprGroup(const ExprGroup& expr_group) const; + + void setPropagateThroughExprs(bool b) { + propagate_through_exprs_ = b; + } + // Can't back prop through merge without making sure one input actually // matches. This can be done on a map or extent basis. 
// TODO: Move this to val_graph.cpp once validation_utils.cpp is diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index b676b9043e2..80a73f0de30 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -3937,39 +3937,6 @@ TEST_F(NVFuserTest, FusionScheduleTransposeRepro1_CUDA) { testValidate(&fusion, outputs, {input0, input1}, __LINE__, __FILE__); } -// Repro for issue #1873 -TEST_F(NVFuserTest, FusionInlineBroadcastIndexing0_CUDA) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeContigTensor(1); - auto tv1 = makeContigTensor(2); - fusion.addInput(tv0); - fusion.addInput(tv1); - auto tv2 = set(tv0); - auto tv3 = broadcast(tv2, {true, false}); - auto tv4 = add(tv3, tv1); - fusion.addOutput(tv4); - - tv4->merge(0); - tv4->split(0, 32); - - tv0->computeAt(tv4, 1); - - tv2->split(-1, 8); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({123}, options); - at::Tensor t1 = at::randn({3, 123}, options); - - FusionExecutor fe; - fe.compileFusion(&fusion, {t0, t1}); - - auto outputs = fe.runFusion({t0, t1}); - - testValidate(&fusion, outputs, {t0, t1}, __LINE__, __FILE__); -} - TEST_F(NVFuserTest, FusionPredicateUnshare_CUDA) { // https://github.com/csarofeen/pytorch/issues/1926 std::unique_ptr fusion_ptr = std::make_unique(); diff --git a/tests/cpp/test_gpu_indexing.cpp b/tests/cpp/test_gpu_indexing.cpp index 6a54182bc22..6b805bf5bf2 100644 --- a/tests/cpp/test_gpu_indexing.cpp +++ b/tests/cpp/test_gpu_indexing.cpp @@ -9,7 +9,10 @@ #include #include +#include #include +#include +#include #include #include #include @@ -73,6 +76,7 @@ TEST_F(NVFuserTest, FusionIndexing1_CUDA) { testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } +// Same as 1 but merge starting from inner most dimension TEST_F(NVFuserTest, FusionIndexing2_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -123,6 +127,7 @@ TEST_F(NVFuserTest, FusionIndexing2_CUDA) { testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } +// Same compute as 1 and 2 but use a scheduler. 
TEST_F(NVFuserTest, FusionIndexing3_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -153,6 +158,7 @@ TEST_F(NVFuserTest, FusionIndexing3_CUDA) { testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } +// Same as 3 but use 3 dimensions and concrete sizes TEST_F(NVFuserTest, FusionIndexing4_CUDA) { Fusion fusion; FusionGuard fg(&fusion); @@ -352,8 +358,8 @@ TEST_F(NVFuserTest, FusionIndexing8_CUDA) { &fusion, cg_outputs, {at_t0, at_t1}, {aten_output}, __LINE__, __FILE__); } +// Same as 5 but using implicit broadcast TEST_F(NVFuserTest, FusionIndexing9_CUDA) { - // Same as 7 but with outer splits instead of inner Fusion fusion; FusionGuard fg(&fusion); @@ -727,43 +733,574 @@ TEST_F(NVFuserTest, FusionIndexing17_CUDA) { testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); } -// Repro of issue #2560 +// TODO: Finish and enable test TEST_F(NVFuserTest, FusionIndexing18_CUDA) { Fusion fusion; FusionGuard fg(&fusion); + TensorView* tv0 = makeConcreteTensor({5, 7, 11, 13}); + fusion.addInput(tv0); + + auto tv1 = set(tv0); + + auto tv2 = makeConcreteTensor({5, 11}); + fusion.addInput(tv2); + + auto tv3 = broadcast(tv2, {false, true, false, true}); + auto tv4 = add(tv3, tv1); + fusion.addOutput(tv4); + + // // tv4[5, 7, 11, 13] = tv3[5, b1, 11, b3] + tv1[5, 7, 11, 13] + tv4->merge(0, 3); + // tv4[5*13, 7, 11] + tv4->split(0, 3); + // tv4[5*13//3, 3, 7, 11] + tv4->merge(2, 3)->split(2, 2); + // tv4[5*13//3, 3, 7*11//2, 2] + // tv4->merge(0, 2); + // // tv4[(5*13//3)*(7*11//2), 3, 2] + + TransformPropagatorWithCheck propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + inlineAllAt(tv4, 1, false); + fusion.printKernel(); + // std::cout<definition()->toString()<merge(0)->merge(0); + // tv10[7*11*13] + tv10->split(0, 5)->split(0, 3); + // tv10[7*11*13//5//3, 3, 5] + + TransformPropagatorWithCheck propagator(tv10); + MaxRootDomainInfoSpanningTree(tv10).traverse(&propagator); + + std::vector tensors_to_inline{tv1, tv2, tv4, tv6, tv8}; + for (auto tensor : tensors_to_inline) { + tensor->inlineAt(1); + } + + // Validation needs to be disabled as ComputeAtMap would fail with this fusion + IdModel id_model( + &fusion, + /* build_graphs */ true, + /* allow_self_mapping */ false, + /* validate */ false); + + // All of the IDs that are generated with merge operations from the + // root domains should be mapped to the single group. 
+ const ValGroup& merge_loop_group = + id_model.idGraph(IdMappingMode::LOOP).toGroup(tv1->getRootDomain().at(0)); + for (auto tv : {tv1, tv2, tv4, tv5, tv6, tv8, tv9}) { + for (auto id : ir_utils::allIDsOf(tv)) { + if (dynamic_cast(id->definition()) == nullptr) { + const ValGroup& loop_group = + id_model.idGraph(IdMappingMode::LOOP).toGroup(id); + ASSERT_EQ(loop_group, merge_loop_group) + << "Unexpected loop group: " << nvfuser::toString(loop_group); + } + } + } + + const auto& promotion_map = id_model.loopPromotionMap(); + + // The merge loop group should be promoted to the output of the + // final merge in tv10 + auto ref_merge_out = tv10->axis(0) + ->definition() + ->input(0) + ->definition() + ->input(0) + ->as(); + + auto promotion_map_it = promotion_map.find(merge_loop_group); + ASSERT_TRUE(promotion_map_it != promotion_map.end()) + << "Loop promotion not found for merge loop group: " + << nvfuser::toString(merge_loop_group); + auto merge_out_promotion_id = promotion_map_it->second; + ASSERT_EQ( + id_model.idGraph(IdMappingMode::EXACT).toGroup(merge_out_promotion_id), + id_model.idGraph(IdMappingMode::EXACT).toGroup(ref_merge_out)) + << "Merge loop group should be promoted to " << ref_merge_out->toString(); + ASSERT_NE( + id_model.idGraph(IdMappingMode::LOOP).toGroup(merge_out_promotion_id), + id_model.idGraph(IdMappingMode::LOOP).toGroup(ref_merge_out)) + << "Should not be loop-mapped with ref: " + << merge_out_promotion_id->toString(); + + // Get the corresponding reference ID in tv10 + auto getRefId = [&](TensorView* tv, IterDomain* id) -> IterDomain* { + if (dynamic_cast(id->definition()) != nullptr) { + if (id->uses().empty()) { + auto it = std::find( + tv->getLeafDomain().begin(), tv->getLeafDomain().end(), id); + NVF_ERROR(it != tv->getLeafDomain().end()); + int leaf_pos = + static_cast(std::distance(tv->getLeafDomain().begin(), it)); + return tv10->axis(leaf_pos); + } else { + return tv10->axis(0)->definition()->input(0)->as(); + } + } else { + return ref_merge_out; + } + }; + + // Check if id is a leaf of a consumer tensor of tv + auto isIdOfConsumerTensor = [&](IterDomain* id, TensorView* tv) -> bool { + auto consumer_tvs = ir_utils::consumerTvsOf(tv); + return std::any_of( + consumer_tvs.begin(), consumer_tvs.end(), [&](auto consumer_tv) { + auto all_ids = ir_utils::allIDsOf(consumer_tv); + return std::find(all_ids.begin(), all_ids.end(), id) != all_ids.end(); + }); + }; + + // At this point, all of the IDs from the root until split are + // validated. Validating the remaining IDs + for (auto tv : {tv1, tv2, tv4, tv5, tv6, tv8, tv9}) { + for (auto id : ir_utils::allIDsOf(tv)) { + const auto& loop_group = + id_model.idGraph(IdMappingMode::LOOP).toGroup(id); + if (loop_group == merge_loop_group) { + // already validated + continue; + } + + auto promotion_map_it = promotion_map.find(loop_group); + ASSERT_TRUE(promotion_map_it != promotion_map.end()) + << "Loop promotion not found for " << id->toString() << " of " + << tv->toString() + << ". 
Loop group: " << nvfuser::toString(loop_group); + + auto promotion_id = promotion_map_it->second; + + // Promotion ID should be loop-mapped + ASSERT_TRUE(loop_group->has(promotion_id)) + << "Loop promotion for " << id->toString() << " of " << tv->toString() + << " is promoted to an ID that isn't loop mapped: " + << promotion_id->toString() << std::endl; + + auto promotion_exact_group = + id_model.idGraph(IdMappingMode::EXACT).toGroup(promotion_id); + + auto ref_id = getRefId(tv, id); + auto ref_exact_group = + id_model.idGraph(IdMappingMode::EXACT).toGroup(ref_id); + + ASSERT_EQ(promotion_exact_group, ref_exact_group) + << "Invalid promotion: " << id->toString() << " of " << tv->toString() + << ". Promotion group: " << nvfuser::toString(promotion_exact_group); + + auto ref_loop_group = + id_model.idGraph(IdMappingMode::LOOP).toGroup(ref_id); + ASSERT_NE(loop_group, ref_loop_group) + << "Invalid promotion: " << id->toString() << " of " << tv->toString() + << ". Should not be loop-mapped with ref: " + << nvfuser::toString(loop_group); + + // If id is a leaf, make sure it isn't mapped with + auto leaf_id_it = + std::find(tv->getLeafDomain().begin(), tv->getLeafDomain().end(), id); + if (leaf_id_it != tv->getLeafDomain().end() && + std::distance(tv->getLeafDomain().begin(), leaf_id_it) >= + tv->getComputeAtPosition()) { + for (auto loop_mapped_id : *loop_group) { + if (loop_mapped_id == id) { + continue; + } + ASSERT_FALSE( + isIdOfConsumerTensor(loop_mapped_id->as(), tv)) + << "Invalid promotion: " << id->toString() << " of " + << tv->toString() << ". Found to mapped a consumer tensor: " + << loop_mapped_id->name(); + } + } + } + } + + // The current ComputeAtMap fails with this fusion + // fusion.printKernel(); +} + +// Progressive loop promotion. producer gets promoted in consumer, consumer is +// promoted in a different way to its consumer. +TEST_F(NVFuserTest, FusionIndexing20_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeConcreteTensor({5}); + fusion.addInput(tv0); + + // [5] + auto tv1 = set(tv0); + auto tv2 = broadcast(tv1, {true, false}); + // [1, 5] + auto tv3 = makeConcreteTensor({3, 5}); + fusion.addInput(tv3); + auto tv4 = add(tv3, tv2); + // [3, 5] + + auto tv5 = broadcast(tv4, {false, false, true}); + // [3, 5, 1] + auto tv6 = makeConcreteTensor({3, 5, 7}); + fusion.addInput(tv6); + auto tv7 = add(tv5, tv6); + // [3, 5, 7] + fusion.addOutput(tv7); + + tv4->merge(0)->split(0, 2, false); + // [3, 5] + // [3, 3*5//2] + + TransformPropagatorWithCheck propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + + // tv0->tv1->tv2(b)->tv4->tv5(b)->tv7 + + tv1->inlineAt(1); + tv2->inlineAt(1); + tv4->inlineAt(1); + + // [2, 3*5//2] + tv5->merge(1)->split(1, 4, false); + // [2, 4, (3*5//2)*1//4] + tv7->merge(1)->split(1, 4, false); + // [2, 4, (3*5//2)*7//4] + tv5->inlineAt(2); + + IdModel id_model(&fusion); + const auto& promotion_map = id_model.loopPromotionMap(); + + // For tv1, tv2, tv4, their first leaf domains should all be + // loop-mapped and promoted to a domain that is exaclty mapped with + // the first leaf domain of tv7. 
The second leaf domains should also + // be promoted to the domain at the same position in tv7 but since + // they are not inlined, they should not be loop-mapped + for (auto tv : {tv1, tv2, tv4}) { + // Validating the first leaf ID + { + auto leaf_id0 = tv->axis(0); + auto ref_id0 = tv7->axis(0); + ASSERT_TRUE(id_model.idGraph(IdMappingMode::LOOP) + .disjointValSets() + .strictAreMapped(leaf_id0, ref_id0)); + auto promotion_map_it = promotion_map.find( + id_model.idGraph(IdMappingMode::LOOP).toGroup(leaf_id0)); + ASSERT_NE(promotion_map_it, promotion_map.end()); + auto promoted_id = promotion_map_it->second; + ASSERT_TRUE(id_model.idGraph(IdMappingMode::EXACT) + .disjointValSets() + .strictAreMapped(promoted_id, ref_id0)) + << "Expected exact mapping: " << promoted_id->toString() << " with " + << ref_id0->toString() << " of " << tv7->toString(); + } + + // Validating the second leaf ID + { + auto leaf_id1 = tv->axis(1); + // Should be promoted to a domain that is exactly mapped with iS31 + auto ref_id1 = tv7->axis(1) + ->definition() + ->as() + ->in() + ->definition() + ->as() + ->outer(); + auto promotion_map_it = promotion_map.find( + id_model.idGraph(IdMappingMode::LOOP).toGroup(leaf_id1)); + ASSERT_NE(promotion_map_it, promotion_map.end()); + auto promoted_id = promotion_map_it->second; + ASSERT_TRUE(id_model.idGraph(IdMappingMode::EXACT) + .disjointValSets() + .strictAreMapped(promoted_id, ref_id1)) + << "Expected exact mapping: " << promoted_id->toString() << " with " + << ref_id1->toString() << " of " << tv7->toString(); + // While promoted ID should be exact-mapped with the reference ID, they + // should not be loop-mapped + ASSERT_FALSE(id_model.idGraph(IdMappingMode::LOOP) + .disjointValSets() + .strictAreMapped(promoted_id, ref_id1)) + << "Expected no loop mapping: " << promoted_id->toString() << " with " + << ref_id1->toString() << " of " << tv7->toString(); + + // In the case of tv1 and tv2, the promoted id is a newly replayed + // domain, whereas for the tv4, there should be no replay as + // there's no broadcast. So, the size of the loop group should be + // 2 for the former and 1 for the latter. + const auto& leaf_id1_loop_group = + id_model.idGraph(IdMappingMode::LOOP).toGroup(leaf_id1); + ASSERT_EQ(leaf_id1_loop_group->size(), tv == tv4 ? 1 : 2) + << "Unexpected loop group: " + << nvfuser::toString(leaf_id1_loop_group); + } + } + + // Validate tv5. The last leaf domain should be promoted to a domain + // that is exactly mapped with the last domain of tv7 + { + auto last_leaf = tv5->axis(-1); + auto promotion_map_it = promotion_map.find( + id_model.idGraph(IdMappingMode::LOOP).toGroup(last_leaf)); + ASSERT_NE(promotion_map_it, promotion_map.end()); + auto promoted_id = promotion_map_it->second; + ASSERT_TRUE(id_model.idGraph(IdMappingMode::EXACT) + .disjointValSets() + .strictAreMapped(promoted_id, tv7->axis(-1))) + << "Expected exact mapping: " << promoted_id->toString() << " with " + << tv7->axis(-1)->toString() << " of " << tv7->toString(); + + // While promoted ID should be exact-mapped with the last ID, they + // should not be loop-mapped + ASSERT_FALSE(id_model.idGraph(IdMappingMode::LOOP) + .disjointValSets() + .strictAreMapped(promoted_id, tv7->axis(-1))) + << "Expected no loop maping: " << promoted_id->toString() << " with " + << tv7->axis(-1)->toString() << " of " << tv7->toString(); + } + + // Validation not enabled yet as incorrect code is generated. 
Need + // to use the loop promotion info to generate correct loop-nests +#if 0 + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({5}, options); + auto t3 = at::randn({3, 5}, options); + auto t6 = at::randn({3, 5, 7}, options); + std::vector aten_inputs = {t0, t3, t6}; + + FusionExecutor fe; + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__); +#endif +} + +// Repro for issue #1873 +TEST_F(NVFuserTest, FusionInlineBroadcastIndexing0_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeContigTensor(1); + auto tv1 = makeContigTensor(2); + fusion.addInput(tv0); + fusion.addInput(tv1); + auto tv2 = set(tv0); + auto tv3 = broadcast(tv2, {true, false}); + auto tv4 = add(tv3, tv1); + fusion.addOutput(tv4); + + tv4->merge(0); + tv4->split(0, 32); + + TransformPropagatorWithCheck propagator(tv4); + MaxRootDomainInfoSpanningTree(tv4).traverse(&propagator); + + tv2->inlineAt(1); + tv3->inlineAt(1); + + tv2->split(-1, 8); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({123}, options); + at::Tensor t1 = at::randn({3, 123}, options); + + FusionExecutor fe; + fe.compileFusion(&fusion, {t0, t1}); + + auto outputs = fe.runFusion({t0, t1}); + + testValidate(&fusion, outputs, {t0, t1}, __LINE__, __FILE__); +} + +// Broadcast inline 3 times and merge all domains +TEST_F(NVFuserTest, FusionMultiPromotion_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + + // [y] + auto tv0 = makeSymbolicTensor(1); + // [w, x, y, z] + auto tv1 = makeSymbolicTensor(4); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // y + auto tv2 = broadcast(tv0, {true, false}); + // w, y, z + auto tv3 = broadcast(tv2, {false, false, true}); + // w, y, z + auto tv4 = broadcast(tv3, {false, true, false, false}); + // w, x, y, z + auto tv5 = add(tv4, tv1); + + fusion.addOutput(tv5); + + tv5->merge(1)->merge(1)->merge(0)->split(0, 11); + + tv0->computeAt(tv5, 1); + tv1->computeAt(tv5, 1); + + FusionExecutor fe; + + int w = 3, x = 4, y = 7, z = 8; + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + + at::Tensor t0 = at::randn({y}, options); + at::Tensor t1 = at::randn({w, x, y, z}, options); + + auto t4 = t0.unsqueeze(0).unsqueeze(0).unsqueeze(-1); + auto aten_output = t4.add(t1); + + std::vector aten_inputs = {t0, t1}; + + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); +} + +// Broadcast and concretize same domain in two different ways and try to merge +// their loops. The inlining pattern is invalid but the current +// inlining check is not capable of flagging the inlining poistion as +// invalid. The loop promotion analysis should not find any promotion +// of the loop group where all the leaf domains are merged into. 
+TEST_F(NVFuserTest, FusionMultiPromotion2_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + // [w] auto tv0 = makeSymbolicTensor(1); fusion.addInput(tv0); + + // [w, x] + auto tv1 = makeSymbolicTensor(2); + fusion.addInput(tv1); + + // [w, y] + auto tv2 = makeSymbolicTensor(2); + fusion.addInput(tv2); + + auto tv3 = set(tv0); + // [w] + auto tv4 = broadcast(tv3, {false, true}); + // [w, 1] + auto tv5 = add(tv4, tv1); + // [w, x] + fusion.addOutput(tv5); + + // [w] + auto tv6 = broadcast(tv3, {false, true}); + // [w, 1] + auto tv7 = add(tv6, tv2); + // [y] + fusion.addOutput(tv7); + + for (auto tv : std::vector{tv4, tv5, tv6, tv7}) { + tv->merge(0); + } + + // Since x and y are not proven to be the same, this inling position + // should not be allowed. + for (auto tv : std::vector{tv3, tv4, tv6}) { + tv->inlineAt(1); + } + + // For now, just make sure there's no loop promotion for the merged + // loop group. + IdModel id_model(&fusion); + const auto& leaf_loop_group = + id_model.idGraph(IdMappingMode::LOOP).toGroup(tv7->axis(0)); + auto promotion_map_it = id_model.loopPromotionMap().find(leaf_loop_group); + ASSERT_EQ(promotion_map_it, id_model.loopPromotionMap().end()); +} + +// TODO: All the above tests are merges followed by splits, we should make some +// more complex examples even though merging then spliting is the most likely +// use case. In multi-gpu it may be the exact opposite where we split out the +// outer most iter domain to the multi-gpu dimension, then schedule. + +TEST_F(NVFuserTest, FusionIndexSplitMerge_CUDA) { + Fusion fusion; + FusionGuard fg(&fusion); + // [w] + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + // [w, x] auto tv1 = makeSymbolicTensor(2); fusion.addInput(tv1); auto tv2 = broadcast(tv0, {false, true}); - auto tv3 = add(tv2, tv1); - auto tv4 = sum(tv3, {0, 1}); - fusion.addOutput(tv4); + auto tv3 = add(tv1, tv2); + fusion.addOutput(tv3); - tv4->merge(0); - tv4->split(0, 4); - auto tv5 = tv4->rFactor({1}); + tv3->split(0, 3); + tv3->split(2, 4); + tv3->merge(1); + tv3->split(1, 5); - MaxRootDomainInfoSpanningTree tree(tv5); - TransformPropagator tp(tv5); + MaxRootDomainInfoSpanningTree tree(tv3); + TransformPropagator tp(tv3); tree.traverse(&tp); - inlineAllAt(tv4, 1, true); + inlineAllAt(tv3, 1, true); + FusionExecutor fe; + int x = 4, y = 7; auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({5}, options); - at::Tensor t1 = at::randn({5, 3}, options); - std::vector inputs = {t0, t1}; - FusionExecutor fe; - fe.compileFusion(&fusion, inputs); - auto cg_outputs = fe.runFusion(inputs); + at::Tensor t0 = at::randn({x}, options); + at::Tensor t1 = at::randn({x, y}, options); - auto ref = (t0.unsqueeze(-1) + t1).sum(); + auto t2 = t0.unsqueeze(-1); + auto aten_output = t1.add(t2); + + std::vector aten_inputs = {t0, t1}; - testValidate(fe.kernel(), cg_outputs, inputs, {ref}, __LINE__, __FILE__); + fe.compileFusion(&fusion, aten_inputs); + auto cg_outputs = fe.runFusion(aten_inputs); + + testValidate( + &fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__); } } // namespace nvfuser diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp index eb7245a3b8d..15506e6b938 100644 --- a/tests/cpp/test_id_model.cpp +++ b/tests/cpp/test_id_model.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include
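Illustrative sketch (not part of the patch) of how the new VectorOfUniqueEntries front operations and the VERBOSE() logger from csrc/id_model/utils.h behave. It assumes csrc is on the include path and that getNvFuserEnv("ID_MODEL_VERBOSE") maps to the NVFUSER_ID_MODEL_VERBOSE environment variable, following the usual NVFUSER_ prefix convention.

    #include <disjoint_set.h>
    #include <id_model/utils.h>

    namespace nvfuser {

    void frontOpsSketch() {
      VectorOfUniqueEntries<int> entries;
      entries.pushBack(1);             // {1}
      entries.pushFront(2);            // {2, 1}; returns true: 2 was new
      entries.pushFront(1);            // still {2, 1}; returns false: 1 already present
      int first = entries.popFront();  // first == 2, container is now {1}

      // Prints to stderr only when the verbose env var is set; otherwise a no-op.
      VERBOSE() << "front element was " << first << std::endl;
    }

    } // namespace nvfuser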