-
Notifications
You must be signed in to change notification settings - Fork 69
WIP: IterDomain Graphs #32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
16f064e
45959ef
eee9bf7
8ce0963
f745410
6edbba3
1a2b261
8b3fe64
4d5c604
0e47e5c
91fb637
cc17fef
1175da2
7891741
ee4e311
4249e94
853cd44
dc88239
8d571a6
8879459
a3a86fd
6eaeb06
32cbc5b
db3ba36
adbad3d
68dec08
6f682d7
0588276
6991b90
81ba299
4db481b
ad7012b
a9d192e
e923a0a
d3793f0
2896a28
0cc1356
ddc858e
8fd5bce
84ed670
a466c5c
3934fe6
0729dee
1f10bd3
2a96ccb
cccb079
899d5e9
23b2e78
cd593fd
69a0b0f
f8c1812
93bb70a
04172d7
b75e197
8cf25af
a98883d
51243b5
a0d8c43
397edcc
fd97525
177d40c
535ce1d
b80e871
bb4968b
d6504f8
6a6ee7a
edddb91
60407b2
49967fb
24dc758
f9045f9
c51379e
7d4acab
d77436c
7f34b17
6603e0a
aacc529
a761409
d3eb4c1
3bb9692
b778788
4c50dcf
9fa2d1c
e20a76b
27c619a
4523cb2
86c574a
bff13b8
283a3c9
aadcb9d
5f0e5c4
811a4ad
2577dfc
f3bbd8f
f2448c2
57469b3
5dde9f1
7f956c4
e01e2e6
b3e60b5
f7f4d84
f073d04
6ba29ed
9eacdf6
f8a9585
9733ab6
c07402e
dbff25c
f1b5f63
ad5debf
a17dc11
14ab237
f6e6848
5715ce9
a8070a9
6c4a5f3
ad83b72
554bd3e
4691a3a
1f36eb7
f5b39e0
636fbd7
bc8fc05
9b6d761
4d2dc50
4ad9eeb
151b0ef
27cd19c
e45b1bb
3adffd9
f9c9d37
6fbebc9
b364f28
3b94574
1904eff
133613a
c66617d
119cf0f
15bfe64
64b409e
a12514e
e6c43e9
197f227
5d5458e
de20be4
9286489
ae70e3c
2dc3262
a902c6e
7ef52af
eced85a
927aec4
cbaaf0e
45dd418
be6eaac
534ac78
2837e53
45b8be9
ec8a2f5
1960432
74c98e2
ccd7bab
6505c63
1c11182
148ef83
3e7992c
fe1517d
62346a9
04458cb
6a112a3
f055994
4781d65
7962338
1eebde4
0e72708
7561d19
48e3019
a828f6f
fa99fe2
2321f3e
ad3da55
e35fb69
9727240
4028dea
32f47db
3d8b582
44da75a
44ad7e0
d8f3ed4
cb5d1bc
c14dd72
fc861a6
e6071e5
0998319
c8db2f7
8740217
0e69513
30e3d5b
fcc3b96
b30682a
9ed0b82
4d1dcdc
1231b60
04b1479
4b391d8
270ac19
5d20b8d
bd46e88
e17fa35
5573fbc
e644665
e21e4e5
801a7b4
649563a
b25ced1
3038e0d
6ef567b
0027ee9
2bcbcb4
fa8455a
c05fd85
31d28e5
d0bdc8e
aa9b414
7aef29d
24a0c8c
c7c04b1
2b605cd
0f5ab07
423fde1
c78da86
f0aeab7
ce5f8ee
3214bc7
d8aacea
18de523
83edc44
e71cb5a
009a11a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,7 @@ | |
| #include <id_model/loop_promotion.h> | ||
| #include <id_model/to_string.h> | ||
| #include <id_model/transform_replay.h> | ||
| #include <id_model/utils.h> | ||
| #include <id_model/validation_utils.h> | ||
|
|
||
| #include <device_lower/analysis/trivial_broadcast.h> | ||
|
|
@@ -233,6 +234,203 @@ std::string IdModel::toString() const { | |
| return ss.str(); | ||
| } | ||
|
|
||
| // Generate a new expr with the IterDomain inputs/outputs replaced based on map. | ||
| // Replaced inputs/outputs should almost exact match with provided expr. | ||
| Expr* IdModel::addExprWithReplacement( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unused |
||
| const std::unordered_map<IterDomain*, IterDomain*>& old_2_new_ids, | ||
| Expr* old_expr) { | ||
| // Figure out which graphs are already initialized to make sure we add the new | ||
| // expression to them. | ||
| std::vector<IdMappingMode> initialized_modes; | ||
| for (auto mode : kIdMappingModes) { | ||
| auto graph_it = id_graphs_.find(mode); | ||
| if (graph_it == id_graphs_.end()) { | ||
| continue; | ||
| } | ||
|
|
||
| auto& graph = graph_it->second; | ||
| if (graph.disjointValSets().disjointSetMap().empty()) { | ||
| continue; | ||
| } | ||
|
|
||
| initialized_modes.push_back(mode); | ||
| } | ||
|
|
||
| // We will fill this map for every IterDomain in input and output. | ||
| std::unordered_map<IterDomain*, IterDomain*> replacement_map = old_2_new_ids; | ||
|
|
||
| // Validate replacement map. Make sure the keys are an input or output | ||
| for (auto replacement_entry : replacement_map) { | ||
| NVF_ERROR( | ||
| std::find( | ||
| old_expr->inputs().begin(), | ||
| old_expr->inputs().end(), | ||
| replacement_entry.first) != old_expr->inputs().end() || | ||
| std::find( | ||
| old_expr->outputs().begin(), | ||
| old_expr->outputs().end(), | ||
| replacement_entry.first) != old_expr->outputs().end(), | ||
| "Wanted to replace ", | ||
| replacement_entry.first->toString(), | ||
| " however the is not an input or output of:\n", | ||
| old_expr->toString()); | ||
| } | ||
|
|
||
| // If all inputs and or all output were replaced | ||
| bool all_inps_replaced = true; | ||
| bool all_outs_replaced = true; | ||
| { | ||
| for (auto inp_id : ir_utils::filterByType<IterDomain>(old_expr->inputs())) { | ||
| if (replacement_map.find(inp_id) == replacement_map.end()) { | ||
| all_inps_replaced = false; | ||
| replacement_map[inp_id] = inp_id->cloneWithoutRFactor(); | ||
| } | ||
| } | ||
|
|
||
| for (auto out_id : | ||
| ir_utils::filterByType<IterDomain>(old_expr->outputs())) { | ||
| if (replacement_map.find(out_id) == replacement_map.end()) { | ||
| all_outs_replaced = false; | ||
| replacement_map[out_id] = out_id->cloneWithoutRFactor(); | ||
| } | ||
| } | ||
|
|
||
| NVF_ERROR( | ||
| (all_inps_replaced || all_outs_replaced), | ||
| "Either all the inputs or all the outputs need to be replaced when using this function."); | ||
|
|
||
| for (auto mode : initialized_modes) { | ||
| for (auto inp_or_out_id : all_inps_replaced | ||
| ? ir_utils::filterByType<IterDomain>(old_expr->inputs()) | ||
| : ir_utils::filterByType<IterDomain>(old_expr->outputs())) { | ||
| NVF_ERROR( | ||
| idGraph(mode).hasGroup(inp_or_out_id), | ||
| "Expected ", | ||
| inp_or_out_id->toString(), | ||
| " to be initialized in graph mode: ", | ||
| mode); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Create the new expression with provided outputs | ||
| auto replay = ReplacementTransformCloner::clone(replacement_map, old_expr); | ||
|
|
||
| // Add new output iter domains to id_definitions_/id_uses_ of IdModel | ||
| for (auto out_id : ir_utils::filterByType<IterDomain>(replay->outputs())) { | ||
| id_definitions_[out_id].pushBack(replay); | ||
| id_uses_[out_id]; | ||
| } | ||
|
|
||
| // Add new input iter domains to id_definitions_/id_uses_ of IdModel | ||
| for (auto inp_id : ir_utils::filterByType<IterDomain>(replay->inputs())) { | ||
| id_definitions_[inp_id]; | ||
| id_uses_[inp_id].pushBack(replay); | ||
| } | ||
|
|
||
| // Update all the initialized graph mappings | ||
| for (auto mode : initialized_modes) { | ||
| auto& graph = idGraph(mode); | ||
|
|
||
| graph.registerExpr(replay); | ||
| auto replay_group = graph.toGroup(replay); | ||
|
|
||
| // Initialize any non-existent input ids, update existing ones | ||
| for (auto inp_id : ir_utils::filterByType<IterDomain>(replay->inputs())) { | ||
| if (!graph.disjointValSets().mappingExists(inp_id)) { | ||
| // inp_id is not initialized in the map, initialize it | ||
| graph.initializeVal(inp_id, {}, {replay}); | ||
| } else { | ||
| // Update unique uses of existing input ids | ||
| auto inp_group = graph.toGroup(inp_id); | ||
| graph.addUniqueUses(inp_group, replay_group); | ||
| } | ||
| } | ||
|
|
||
| // Initialize any non-existent output ids, update existing ones | ||
| for (auto out_id : ir_utils::filterByType<IterDomain>(replay->outputs())) { | ||
| if (!graph.disjointValSets().mappingExists(out_id)) { | ||
| // out_id is not initialized in the map, initialize it | ||
| graph.initializeVal(out_id, {replay}, {}); | ||
| } else { | ||
| // out_id is already initialized, add the replay as a unique definition | ||
| // of its group | ||
| auto out_group = graph.toGroup(out_id); | ||
| graph.addUniqueDefinitions(out_group, replay_group); | ||
| } | ||
| } | ||
|
|
||
| // If the inputs were replaced we want to map through forward the newly | ||
| // added expression. If the outputs were replaced we want to map through | ||
| // backwards the newly added expression. | ||
|
|
||
| // Forward | ||
| VectorOfUniqueEntries<Expr*> representative_uses; | ||
| for (auto in : ir_utils::filterByType<IterDomain>(replay->inputs())) { | ||
| for (const ExprGroup& use_group : graph.getUses(graph.toGroup(in))) { | ||
| if (use_group == replay_group) { | ||
| continue; | ||
| } | ||
| representative_uses.pushBack(use_group->front()); | ||
| } | ||
| } | ||
|
|
||
| for (auto rep_use : representative_uses) { | ||
| graph.maybeMapThroughExprs(rep_use, replay, true); | ||
| } | ||
|
|
||
| // Backwards | ||
| VectorOfUniqueEntries<Expr*> representative_defs; | ||
| for (auto out : ir_utils::filterByType<IterDomain>(replay->outputs())) { | ||
| for (const ExprGroup& def_group : | ||
| graph.getDefinitions(graph.toGroup(out))) { | ||
| if (def_group == replay_group) { | ||
| continue; | ||
| } | ||
| representative_defs.pushBack(def_group->front()); | ||
| } | ||
| } | ||
|
|
||
| for (auto rep_def : representative_defs) { | ||
| graph.maybeMapThroughExprs(rep_def, replay, false); | ||
| } | ||
| } | ||
| return replay; | ||
| } | ||
|
|
||
| // Clone provided iter domain and return the new copy. Map that copy in relevant | ||
| // maps. | ||
| IterDomain* IdModel::cloneIterDomain(IterDomain* id) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unused |
||
| // Figure out which graphs are already initialized to make sure we add the new | ||
| // expression to them. | ||
| std::vector<IdMappingMode> initialized_modes; | ||
| for (auto mode : kIdMappingModes) { | ||
| auto graph_it = id_graphs_.find(mode); | ||
| if (graph_it == id_graphs_.end()) { | ||
| continue; | ||
| } | ||
|
|
||
| auto& graph = graph_it->second; | ||
| if (graph.disjointValSets().disjointSetMap().empty()) { | ||
| continue; | ||
| } | ||
|
|
||
| initialized_modes.push_back(mode); | ||
| } | ||
|
|
||
| auto id_copy = id->cloneWithoutRFactor(); | ||
|
|
||
| id_uses_[id_copy] = {}; | ||
| id_definitions_[id_copy] = {}; | ||
|
|
||
| for (auto mode : initialized_modes) { | ||
| idGraph(mode).initializeVal(id_copy, {}, {}); | ||
| idGraph(mode).mapVals(id, id_copy); | ||
| } | ||
|
|
||
| return id_copy; | ||
| } | ||
|
|
||
| ValGraph IdModel::initializeIdGraph(bool propagate_through_exprs) const { | ||
| ValGraph id_graph(propagate_through_exprs); | ||
|
|
||
|
|
@@ -603,13 +801,25 @@ void IdModel::buildLoopGraph() { | |
| maybeBuildGraph(IdMappingMode::EXACT); | ||
| maybeBuildGraph(IdMappingMode::PERMISSIVE); | ||
|
|
||
| if (!tv_exprs_.empty()) { | ||
| std::stringstream ss; | ||
| tv_exprs_.at(0)->fusion()->print(ss); | ||
| VERBOSE() << ss.str(); | ||
| } | ||
|
|
||
| const StatefulInliningInfo inlining_info = | ||
| buildStatefulInliningInfo(tv_exprs_, idGraph(IdMappingMode::PERMISSIVE)); | ||
|
|
||
| initializeLoopGraph(inlining_info); | ||
|
|
||
| validateLoopGraphHasNoSelfMappedLeafDomains(); | ||
|
|
||
| VERBOSE() << "Initial loop graph:\n"; | ||
| for (const auto& group : | ||
| idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) { | ||
| VERBOSE() << nvfuser::toString(group) << std::endl; | ||
| } | ||
|
|
||
| loop_promotion_map_ = LoopPromotionMapBuilder::get( | ||
| *this, inlining_info, loop_promotion_map_builder_callback_); | ||
|
|
||
|
|
@@ -620,7 +830,55 @@ void IdModel::buildLoopGraph() { | |
| idGraph(IdMappingMode::LOOP).validateConsistency(); | ||
| } | ||
|
|
||
| // TODO: Reenable after reenabling parallel propagation. | ||
| // propagateLoopPTypes | ||
| void IdModel::validatePTypes(const std::vector<TensorView*>& all_tvs) const { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unused |
||
| // VectorOfUniqueEntries<IterDomain*> leaf_ids; | ||
| // for (auto tv : all_tvs) { | ||
| // leaf_ids.pushBack(tv->domain()->leaf()); | ||
| // } | ||
|
|
||
| // for (const auto& disjoint_set : | ||
| // idGraph(IdMappingMode::EXACT).disjointValSets().disjointSets()) { | ||
| // for (auto id : disjoint_set->vector()) { | ||
| // auto id_ptype = id->getParallelType(); | ||
|
|
||
| // NVF_ERROR( | ||
| // leaf_ids.has(id) || id_ptype == ParallelType::Serial, | ||
| // "Invalid parallelization of non leaf iter domain: ", | ||
| // id->toString()); | ||
| // } | ||
| // } | ||
| } | ||
|
|
||
| void IdModel::propagateLoopPTypes() const { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. unused |
||
| for (const auto& loop_disjoint_set : | ||
| idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) { | ||
| ParallelType common_ptype = ParallelType::Serial; | ||
| for (auto id : loop_disjoint_set->vector()) { | ||
| auto id_ptype = id->as<IterDomain>()->getParallelType(); | ||
|
|
||
| NVF_ERROR( | ||
| id_ptype == common_ptype || id_ptype == ParallelType::Serial || | ||
| common_ptype == ParallelType::Serial, | ||
| "Issue validating parallel type disjoint ptype is, ", | ||
| common_ptype, | ||
| " but found in the set the id: ", | ||
| id->toString()); | ||
|
|
||
| common_ptype = | ||
| common_ptype == ParallelType::Serial ? id_ptype : common_ptype; | ||
| } | ||
|
|
||
| for (auto id : loop_disjoint_set->vector()) { | ||
| id->as<IterDomain>()->parallelize(common_ptype); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| void IdModel::buildAllGraphs() { | ||
| VERBOSE() << "*** Building all graphs ***"; | ||
|
|
||
| if (tvs_.empty()) { | ||
| return; | ||
| } | ||
|
|
@@ -663,6 +921,11 @@ void IdModel::buildAllGraphs() { | |
| idGraph(IdMappingMode::PERMISSIVE)); | ||
| } | ||
|
|
||
| // Permissive graph needs the trivial exprs from the almost exact graph to | ||
| // build correctly. Once built though we can remove the trivial expressions | ||
| // from the almost exact graph. | ||
| idGraph(IdMappingMode::ALMOSTEXACT).removeTrivialExprs(); | ||
|
|
||
| buildLoopGraph(); | ||
| } | ||
|
|
||
|
|
@@ -870,4 +1133,12 @@ std::unordered_map<ValGroup, IterDomain*> updateValGroupIdMap( | |
| return new_map; | ||
| } | ||
|
|
||
| std::unordered_map<IterDomain*, IterDomain*> IdModel::buildIndexGraph( | ||
| const std::vector<Expr*>& exprs, | ||
| const std::vector<TensorView*>& all_tvs, | ||
| StatefulInliningInfo& info, | ||
| std::unordered_map<ValGroup, IterDomain*> stale_promotion_map) { | ||
| NVF_ERROR(false, "Not implemented yet."); | ||
| } | ||
|
|
||
| } // namespace nvfuser | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we need to pop front, should we make this
DequeOfUniqueEntries?