Closed
239 commits
16f064e
Builds and can start debugging again.
csarofeen Mar 19, 2023
45959ef
Helps if I include the right files.
csarofeen Mar 19, 2023
eee9bf7
Forward replay now works, need to implement backward index replay.
csarofeen Mar 23, 2023
8ce0963
Merge branch 'main' of https://github.com/nvidia/fuser into IterDomai…
csarofeen Mar 25, 2023
f745410
Finish computing loop group promotions.
csarofeen Mar 25, 2023
6edbba3
Prepare for backward graph replays.
csarofeen Mar 27, 2023
1a2b261
Build promoted tensor domains.
csarofeen Apr 5, 2023
8b3fe64
First attempt at replaying the index operations.
csarofeen Apr 5, 2023
4d5c604
Print indexing expressoins and iter domains.
csarofeen Apr 5, 2023
0e47e5c
Clean up some printing.
csarofeen Apr 6, 2023
91fb637
Fix for iel promotion replay.
csarofeen Apr 6, 2023
cc17fef
Merge branch 'main' of https://github.com/nvidia/fuser into IterDomai…
csarofeen Apr 7, 2023
1175da2
Fix indexing expression generation. Some minor cleanup and refactoring.
csarofeen Apr 8, 2023
7891741
First working index graph example.
csarofeen Apr 13, 2023
ee4e311
Fix building permissive graph.
csarofeen Apr 13, 2023
4249e94
Small fix, clean up printing.
csarofeen Apr 14, 2023
853cd44
Rework some interfaces around IdGraph, add option to not propagate th…
csarofeen Apr 20, 2023
dc88239
Merge branch 'main' of https://github.com/nvidia/fuser into IterDomai…
csarofeen Apr 20, 2023
8d571a6
Cleanup debug print.
csarofeen Apr 20, 2023
8879459
Improve accuracy of loop grouping.
csarofeen Apr 20, 2023
a3a86fd
Four major tests working, minimal index map.
csarofeen Apr 26, 2023
6eaeb06
Cleanup print utilities.
csarofeen Apr 26, 2023
32cbc5b
Cleanup.
csarofeen Apr 26, 2023
db3ba36
Refactoring mono-function, in progress, functional.
csarofeen Apr 30, 2023
adbad3d
WIP Index Graph.
csarofeen May 6, 2023
68dec08
Merge branch 'main' of https://github.com/nvidia/fuser into IterDomai…
csarofeen May 6, 2023
6f682d7
Merge conflicts.
csarofeen May 6, 2023
0588276
Swizzle fix.
csarofeen May 7, 2023
6991b90
Merge branch 'main' of https://github.com/nvidia/fuser into IterDomai…
csarofeen May 7, 2023
81ba299
Most tests passing, only those looking for exact code match fail, or …
csarofeen May 8, 2023
4db481b
Project CI Green
csarofeen May 9, 2023
ad7012b
Reenable self mapping.
csarofeen May 9, 2023
a9d192e
Cleanup.
csarofeen May 9, 2023
e923a0a
Cleanup, move id modelling into its own directory.
csarofeen May 10, 2023
d3793f0
Merge branch 'main' of https://github.com/nvidia/fuser into IterDomai…
csarofeen May 10, 2023
2896a28
License reference in files.
csarofeen May 10, 2023
0cc1356
Merge branch 'main' of https://github.com/nvidia/fuser into IterDomai…
csarofeen May 20, 2023
ddc858e
Merge Conflict Fix.
csarofeen May 20, 2023
8fd5bce
Cleanup.
csarofeen May 20, 2023
84ed670
Comment out WIP test.
csarofeen May 20, 2023
a466c5c
Minor cleanup.
csarofeen May 20, 2023
3934fe6
Cleanup.
csarofeen May 20, 2023
0729dee
Shuffling code.
csarofeen May 20, 2023
1f10bd3
Test fix.
csarofeen May 20, 2023
2a96ccb
Comments.
csarofeen May 20, 2023
cccb079
clang format.
csarofeen May 20, 2023
899d5e9
Merge branch 'main' into IterDomainGraphsMerge
csarofeen Jul 29, 2023
23b2e78
merge upstream
csarofeen Jul 30, 2023
cd593fd
Fix build with clang
naoyam Aug 15, 2023
69a0b0f
Merge branch 'main' into IterDomainGraphs
naoyam Aug 15, 2023
f8c1812
clang-format
naoyam Aug 15, 2023
93bb70a
Merge branch 'main' into HEAD
naoyam Aug 30, 2023
04172d7
Merge branch 'main' into IterDomainGraphs
naoyam Aug 30, 2023
b75e197
Cosmetic cleanup
naoyam Aug 30, 2023
8cf25af
Cleanup
naoyam Aug 31, 2023
a98883d
cleanup
naoyam Aug 31, 2023
51243b5
cleanup
naoyam Aug 31, 2023
a0d8c43
comment
naoyam Aug 31, 2023
397edcc
cleanup mapIds
naoyam Aug 31, 2023
fd97525
cleanup
naoyam Aug 31, 2023
177d40c
cleanup
naoyam Aug 31, 2023
535ce1d
Enable tests
naoyam Aug 31, 2023
b80e871
cleanup
naoyam Sep 1, 2023
bb4968b
comment
naoyam Sep 1, 2023
d6504f8
debug output
naoyam Sep 23, 2023
6a6ee7a
Merge remote-tracking branch 'origin/main' into IterDomainGraphs
naoyam Oct 2, 2023
edddb91
Merge branch 'main' into IterDomainGraphs
naoyam Oct 4, 2023
60407b2
Remove --no-allow-shlib-undefined as our dependencies may not have
naoyam Oct 4, 2023
49967fb
Merge branch 'remove-no-allow-shlib-undefined' into IterDomainGraphs
naoyam Oct 4, 2023
24dc758
Fill the final loop promotion map (#1028)
naoyam Oct 5, 2023
f9045f9
Just cleanup. There should be no functional change (#1061)
naoyam Oct 11, 2023
c51379e
IdGraph cleanup (#1062)
naoyam Oct 11, 2023
7d4acab
Merge branch 'main' into IterDomainGraphs
naoyam Oct 11, 2023
d77436c
remove TORCH_CUDA_CU_API
naoyam Oct 11, 2023
7f34b17
clang-tidy
naoyam Oct 11, 2023
6603e0a
Do not map broadcast and non-broadcast domains in EXACT (#1065)
naoyam Oct 12, 2023
aacc529
Fix mapping propagations through uses (#1072)
naoyam Oct 12, 2023
a761409
Use the EXACT map as the starting map of the PERMISSIVE map (#1073)
naoyam Oct 12, 2023
d3eb4c1
Same as #1072 but for definitions (#1081)
naoyam Oct 13, 2023
3bb9692
minor cleanup
naoyam Oct 13, 2023
b778788
Disable mappings of non-ca domains in LOOP. (#1082)
naoyam Oct 13, 2023
4c50dcf
Remove unnecessary code (#1087)
naoyam Oct 16, 2023
9fa2d1c
Merge branch 'main' into IterDomainGraphs
naoyam Oct 16, 2023
e20a76b
clang-format
naoyam Oct 16, 2023
27c619a
cleanup
naoyam Oct 16, 2023
4523cb2
WIP: Loop promotion with IEL (#1090)
naoyam Oct 16, 2023
86c574a
Verify loop mappings of leaf domains
naoyam Oct 17, 2023
bff13b8
minor change
naoyam Oct 17, 2023
283a3c9
Merge branch 'main' into IterDomainGraphs
naoyam Oct 25, 2023
aadcb9d
Fix use of replayed input IDs (#1144)
naoyam Oct 25, 2023
5f0e5c4
cleanup
naoyam Oct 25, 2023
811a4ad
Finalize loop promotion map (#1149)
naoyam Oct 26, 2023
2577dfc
Merge branch 'main' into IterDomainGraphs
naoyam Oct 26, 2023
f3bbd8f
minor
naoyam Oct 26, 2023
f2448c2
Renaming IterDomainGraphs to IdModel (#1161)
naoyam Oct 26, 2023
57469b3
Follow-up to #1161 (#1165)
naoyam Oct 27, 2023
5dde9f1
build fix (#1166)
naoyam Oct 27, 2023
7f956c4
Merge branch 'main' into IterDomainGraphs
naoyam Nov 8, 2023
e01e2e6
Mechanical changes extracted from #1168 (#1293)
naoyam Nov 14, 2023
b3e60b5
Merge branch 'main' into IterDomainGraphs
naoyam Nov 14, 2023
f7f4d84
Updating IterDomainGraphs branch with main (#1412)
naoyam Nov 29, 2023
f073d04
Merging current main to IterDomainGraphs (#1417)
naoyam Nov 30, 2023
6ba29ed
Merge lower2device
naoyam Nov 30, 2023
9eacdf6
Merge from main
naoyam Nov 30, 2023
f8a9585
Fix validation
naoyam Dec 1, 2023
9733ab6
Merge from main
naoyam Dec 1, 2023
c07402e
WIP
naoyam Dec 1, 2023
dbff25c
remove non-const disjointExprs and disjointVals
naoyam Dec 1, 2023
f1b5f63
Clean up self mapping
naoyam Dec 1, 2023
ad5debf
IdModel: merge main (#1453)
naoyam Dec 5, 2023
a17dc11
Merge branch 'main' into IterDomainGraphs
naoyam Dec 5, 2023
14ab237
IdModel: cleanup almost exact (#1457)
naoyam Dec 6, 2023
f6e6848
cleanup transform_iter
naoyam Dec 7, 2023
5715ce9
Merge branch 'main' into IterDomainGraphs
naoyam Dec 8, 2023
a8070a9
cleanup
naoyam Dec 8, 2023
6c4a5f3
fix
naoyam Dec 8, 2023
ad83b72
cleanup
naoyam Dec 8, 2023
554bd3e
It should not be necessary to handle swizzles in the permissive map
naoyam Dec 8, 2023
4691a3a
Bug fix
naoyam Dec 8, 2023
1f36eb7
Merge branch 'main' into IterDomainGraphs
naoyam Dec 8, 2023
f5b39e0
Fix mapping propagation through backward merge (#1511)
naoyam Dec 13, 2023
636fbd7
Merge branch 'main' into IterDomainGraphs
naoyam Dec 13, 2023
bc8fc05
Temporarily remove IdMappingMode::INDEX as it doesn't exist yet
naoyam Dec 13, 2023
9b6d761
Revert unnecessary change
naoyam Dec 13, 2023
4d2dc50
cleanup
naoyam Dec 13, 2023
4ad9eeb
Make it explicit that broadcasts are mapped
naoyam Dec 13, 2023
151b0ef
[IdModel] Propagation fix for permissive graphs (#1538)
naoyam Dec 21, 2023
27cd19c
Merge branch 'main' into IterDomainGraphs
naoyam Dec 22, 2023
e45b1bb
[IdModel] merge main 20231227 (#1564)
naoyam Dec 28, 2023
3adffd9
Merge branch 'main' into IterDomainGraphs
naoyam Dec 28, 2023
f9c9d37
cleanup
naoyam Dec 29, 2023
6fbebc9
cleanup
naoyam Dec 30, 2023
b364f28
cleanup
naoyam Dec 31, 2023
3b94574
IdModel: refactoring loop promotion (#1589)
naoyam Jan 6, 2024
1904eff
further refactoring of loop promotion
naoyam Jan 7, 2024
133613a
IdModel: enable compliment mapping (#1611)
naoyam Jan 11, 2024
c66617d
test cleanup
naoyam Jan 12, 2024
119cf0f
Merge branch 'main' into IterDomainGraphs
naoyam Jan 12, 2024
15bfe64
[IdModle] Val graph cleanup (#1637)
naoyam Jan 18, 2024
64b409e
fix memory usage
naoyam Jan 18, 2024
a12514e
IdModel: Fix inconsistent graph issue (#1627)
naoyam Jan 18, 2024
e6c43e9
Merge branch 'main' into IterDomainGraphs
naoyam Jan 18, 2024
197f227
Cherry pick #1627
naoyam Jan 18, 2024
5d5458e
enable idmodel
naoyam Jan 18, 2024
de20be4
disable idmodel
naoyam Jan 19, 2024
9286489
[IdModel] Refactoring for testing (#1624)
naoyam Jan 19, 2024
ae70e3c
Merge branch 'main' into IterDomainGraphs
naoyam Jan 19, 2024
2dc3262
Merge branch 'main' into idmodel_fix_inconsistent_graph
naoyam Jan 20, 2024
a902c6e
Refactoring to allow more flexible testing
naoyam Jan 20, 2024
7ef52af
enable idmodel
naoyam Jan 20, 2024
eced85a
comment
naoyam Jan 20, 2024
927aec4
Merge branch 'main' into idmodel_refactoring
naoyam Jan 21, 2024
cbaaf0e
WIP: Build a map of broadcast resolutions for root domains
naoyam Jan 21, 2024
45dd418
fix
naoyam Jan 21, 2024
be6eaac
Merge branch 'main' into idmodel_root_resolution
naoyam Jan 22, 2024
534ac78
rename
naoyam Jan 26, 2024
2837e53
Add tests with example fusions used in the design doc
naoyam Jan 26, 2024
45b8be9
Clean up tests
naoyam Jan 26, 2024
ec8a2f5
bug fix
naoyam Jan 26, 2024
1960432
Add test
naoyam Jan 26, 2024
74c98e2
Add a test
naoyam Jan 27, 2024
ccd7bab
comment
naoyam Jan 27, 2024
6505c63
test cleanup
naoyam Jan 27, 2024
1c11182
IdModel: merge c8c21fd087737dd8fb863867e5a012116a69a731 (#1688)
naoyam Jan 27, 2024
148ef83
Merge branch 'main' into IterDomainGraphs
naoyam Jan 27, 2024
3e7992c
Merge branch 'main' into idmodel_root_resolution
naoyam Jan 27, 2024
fe1517d
Merge remote-tracking branch 'origin/idmodel_root_resolution' into It…
naoyam Jan 27, 2024
62346a9
Add tests for Step 2
naoyam Jan 27, 2024
04458cb
rename
naoyam Feb 1, 2024
6a112a3
Cleaning up visitor
naoyam Feb 1, 2024
f055994
Move ValGraphVisitor out of id_model
naoyam Feb 1, 2024
4781d65
Tests for ValGraphStmtSort
naoyam Feb 1, 2024
7962338
WIP
naoyam Feb 1, 2024
1eebde4
Fix determinism bug
naoyam Feb 2, 2024
0e72708
refactoring
naoyam Feb 2, 2024
7561d19
test cleanup
naoyam Feb 2, 2024
48e3019
cleanup
naoyam Feb 2, 2024
a828f6f
Simplify by removing sub_selection as it's not used
naoyam Feb 2, 2024
fa99fe2
fix for trivial exprs
naoyam Feb 2, 2024
2321f3e
Accidentally added
naoyam Feb 2, 2024
ad3da55
Revert "Accidentally added"
naoyam Feb 2, 2024
e35fb69
Accidentally added
naoyam Feb 2, 2024
9727240
Repro for the compliment mapping issue
naoyam Feb 13, 2024
4028dea
clang-format
naoyam Feb 13, 2024
32f47db
Merge branch 'main' into IterDomainGraphs
naoyam Feb 13, 2024
3d8b582
rename
naoyam Feb 13, 2024
44da75a
Error check
naoyam Feb 13, 2024
44ad7e0
Merge branch 'main' into IterDomainGraphs
naoyam Feb 16, 2024
d8f3ed4
WIP: Loop promotion analysis step 2
naoyam Feb 16, 2024
cb5d1bc
Merge branch 'main' into idmodel_step2
naoyam Feb 23, 2024
c14dd72
cleanup
naoyam Feb 23, 2024
fc861a6
update
naoyam Feb 23, 2024
e6071e5
cleanup
naoyam Feb 23, 2024
0998319
Merge branch 'main' into idmodel_step2
naoyam Feb 23, 2024
c8db2f7
Merge branch 'main' into IterDomainGraphs
naoyam Feb 23, 2024
8740217
Merge branch 'idmodel_step2' into IterDomainGraphs
naoyam Feb 23, 2024
0e69513
Simplify for step 2
naoyam Feb 23, 2024
30e3d5b
cleanup
naoyam Feb 23, 2024
fcc3b96
Merge branch 'idmodel_step2' into IterDomainGraphs
naoyam Feb 24, 2024
b30682a
format
naoyam Feb 24, 2024
9ed0b82
Merge branch 'main' into IterDomainGraphs
naoyam Mar 13, 2024
4d1dcdc
Merge branch 'main' into IterDomainGraphs
naoyam Mar 15, 2024
1231b60
Merge branch 'main' into IterDomainGraphs
naoyam Mar 18, 2024
04b1479
Test cleanup
naoyam Mar 18, 2024
4b391d8
cleanup
naoyam Mar 25, 2024
270ac19
test update for step 4
naoyam Mar 26, 2024
5d20b8d
Merge branch 'main' into IterDomainGraphs
naoyam Mar 26, 2024
bd46e88
Merge branch 'main' into IterDomainGraphs
naoyam Apr 4, 2024
e17fa35
cleanup
naoyam Apr 4, 2024
5573fbc
Merge branch 'main' into IterDomainGraphs
naoyam Apr 11, 2024
e644665
Merge branch 'main' into IterDomainGraphs
naoyam May 8, 2024
e21e4e5
clang-tidy
naoyam May 8, 2024
801a7b4
Step 5 of loop promotion analysis
naoyam May 8, 2024
649563a
comment
naoyam May 10, 2024
b25ced1
repro of issue #1759
naoyam Mar 18, 2024
3038e0d
Merge branch 'main' into idmodel_step5
naoyam May 10, 2024
6ef567b
Merge branch 'main' into IterDomainGraphs
naoyam May 10, 2024
0027ee9
Merge branch 'idmodel_step5' into IterDomainGraphs
naoyam May 10, 2024
2bcbcb4
Merge branch 'main' into IterDomainGraphs
naoyam May 10, 2024
fa8455a
LoopPromotionBuilder class
naoyam May 10, 2024
c05fd85
Copied all loop promotion code
naoyam May 10, 2024
31d28e5
enable idmodel
naoyam May 10, 2024
d0bdc8e
Switch to the new builder
naoyam May 10, 2024
aa9b414
const
naoyam May 10, 2024
7aef29d
cleanup
naoyam May 10, 2024
24a0c8c
Remove loop promotion code from IdModel
naoyam May 10, 2024
c7c04b1
replace tester with callback
naoyam May 10, 2024
2b605cd
clang-format
naoyam May 10, 2024
0f5ab07
comment
naoyam May 10, 2024
423fde1
Merge branch 'idmodel_loop_promotion_cleanup' into IterDomainGraphs
naoyam May 10, 2024
c78da86
Merge branch 'main' into IterDomainGraphs
naoyam May 21, 2024
f0aeab7
cleanup
naoyam May 21, 2024
ce5f8ee
Merge branch 'main' into IterDomainGraphs
naoyam May 21, 2024
3214bc7
Merge branch 'main' into IterDomainGraphs
naoyam May 24, 2024
d8aacea
Merge branch 'main' into IterDomainGraphs
naoyam May 28, 2024
18de523
enable
naoyam May 28, 2024
83edc44
Merge branch 'main' into IterDomainGraphs
naoyam May 30, 2024
e71cb5a
Merge branch 'main' into IterDomainGraphs
naoyam May 31, 2024
009a11a
Merge branch 'main' into IterDomainGraphs
naoyam Jun 4, 2024
2 changes: 1 addition & 1 deletion csrc/device_lower/lower2device.cpp
@@ -390,7 +390,7 @@ void GpuLower::analysis(Fusion* fusion) {
// functionality should be affected. New IterDomains may be created,
// so it is expected that generated code may use different variable
// names
- if (isOptionEnabled(EnableOption::IdModel)) {
+ if (true || isOptionEnabled(EnableOption::IdModel)) {
IdModel id_model(fusion_);
}

21 changes: 20 additions & 1 deletion csrc/disjoint_set.h
@@ -81,6 +81,15 @@ class VectorOfUniqueEntries {
return false;
}

// Returns if a node was actually added
bool pushFront(T entry) {
if (set_.emplace(entry).second) {
vector_.insert(vector_.begin(), entry);
return true;
}
return false;
}

// Returns true if any node was added
bool pushBack(const VectorOfUniqueEntries<T, Hash>& other) {
return pushBack(other.vector());
@@ -170,6 +179,14 @@ class VectorOfUniqueEntries {
return v;
}

// Removes and returns the first element in the vector
T popFront() {
[Review comment, Collaborator] If we need to pop front, should we make this DequeOfUniqueEntries?
T v = vector_.front();
set_.erase(v);
vector_.erase(vector_.begin());
return v;
}

// Returns if this container is empty
bool empty() const {
return vector_.empty();
@@ -394,7 +411,9 @@ class DisjointSets {
entry_it != disjointSetMap().end(),
"Strict mapping failed on element: ",
abstractToString(entry0),
- " either an error occurred, or non strict mapping should have been used.");
+ " either an error occurred, or non strict mapping should have been used.",
+ " ",
+ entry0->name());
return entry_it->second->has(entry1);
}
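
The reviewer's question about popFront above can be made concrete. Erasing from or inserting at the front of a std::vector shifts every remaining element, so pushFront and popFront as written are linear in the container size. Below is a minimal sketch of what a deque-backed variant might look like. DequeOfUniqueEntries is a hypothetical name borrowed from the review comment, not an existing nvFuser class, and the interface is trimmed to the operations discussed here.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <functional>
#include <unordered_set>

// Hypothetical container (an assumption for illustration, not nvFuser code):
// keeps insertion order and uniqueness like VectorOfUniqueEntries, but stores
// entries in a std::deque so both ends support constant-time push and pop.
template <typename T, typename Hash = std::hash<T>>
class DequeOfUniqueEntries {
 public:
  // Returns true if the entry was actually added.
  bool pushBack(const T& entry) {
    if (set_.emplace(entry).second) {
      deque_.push_back(entry);
      return true;
    }
    return false;
  }

  // Returns true if the entry was actually added.
  bool pushFront(const T& entry) {
    if (set_.emplace(entry).second) {
      deque_.push_front(entry);
      return true;
    }
    return false;
  }

  // Removes and returns the first element. The container must not be empty.
  T popFront() {
    assert(!deque_.empty() && "popFront() called on an empty container");
    T v = deque_.front();
    deque_.pop_front();
    set_.erase(v);
    return v;
  }

  bool has(const T& entry) const {
    return set_.find(entry) != set_.end();
  }

  bool empty() const {
    return deque_.empty();
  }

  std::size_t size() const {
    return deque_.size();
  }

 private:
  std::deque<T> deque_;
  std::unordered_set<T, Hash> set_;
};
```

The trade-off is that std::deque does not guarantee contiguous storage, so any caller that relies on vector() returning a std::vector would need to change. Whether that cost is acceptable depends on how often the front operations are actually used.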

271 changes: 271 additions & 0 deletions csrc/id_model/id_model.cpp
@@ -9,6 +9,7 @@
#include <id_model/loop_promotion.h>
#include <id_model/to_string.h>
#include <id_model/transform_replay.h>
#include <id_model/utils.h>
#include <id_model/validation_utils.h>

#include <device_lower/analysis/trivial_broadcast.h>
@@ -233,6 +234,203 @@ std::string IdModel::toString() const {
return ss.str();
}

// Generate a new expr with the IterDomain inputs/outputs replaced based on map.
// Replaced inputs/outputs should almost-exactly match those of the provided expr.
Expr* IdModel::addExprWithReplacement(
[Review comment, Collaborator] Unused
const std::unordered_map<IterDomain*, IterDomain*>& old_2_new_ids,
Expr* old_expr) {
// Figure out which graphs are already initialized to make sure we add the new
// expression to them.
std::vector<IdMappingMode> initialized_modes;
for (auto mode : kIdMappingModes) {
auto graph_it = id_graphs_.find(mode);
if (graph_it == id_graphs_.end()) {
continue;
}

auto& graph = graph_it->second;
if (graph.disjointValSets().disjointSetMap().empty()) {
continue;
}

initialized_modes.push_back(mode);
}

// We will fill this map for every IterDomain in input and output.
std::unordered_map<IterDomain*, IterDomain*> replacement_map = old_2_new_ids;

// Validate replacement map. Make sure the keys are an input or output
for (auto replacement_entry : replacement_map) {
NVF_ERROR(
std::find(
old_expr->inputs().begin(),
old_expr->inputs().end(),
replacement_entry.first) != old_expr->inputs().end() ||
std::find(
old_expr->outputs().begin(),
old_expr->outputs().end(),
replacement_entry.first) != old_expr->outputs().end(),
"Wanted to replace ",
replacement_entry.first->toString(),
" however it is not an input or output of:\n",
old_expr->toString());
}

// Track whether all inputs and/or all outputs were replaced
bool all_inps_replaced = true;
bool all_outs_replaced = true;
{
for (auto inp_id : ir_utils::filterByType<IterDomain>(old_expr->inputs())) {
if (replacement_map.find(inp_id) == replacement_map.end()) {
all_inps_replaced = false;
replacement_map[inp_id] = inp_id->cloneWithoutRFactor();
}
}

for (auto out_id :
ir_utils::filterByType<IterDomain>(old_expr->outputs())) {
if (replacement_map.find(out_id) == replacement_map.end()) {
all_outs_replaced = false;
replacement_map[out_id] = out_id->cloneWithoutRFactor();
}
}

NVF_ERROR(
(all_inps_replaced || all_outs_replaced),
"Either all the inputs or all the outputs need to be replaced when using this function.");

for (auto mode : initialized_modes) {
for (auto inp_or_out_id : all_inps_replaced
? ir_utils::filterByType<IterDomain>(old_expr->inputs())
: ir_utils::filterByType<IterDomain>(old_expr->outputs())) {
NVF_ERROR(
idGraph(mode).hasGroup(inp_or_out_id),
"Expected ",
inp_or_out_id->toString(),
" to be initialized in graph mode: ",
mode);
}
}
}

// Create the new expression with provided outputs
auto replay = ReplacementTransformCloner::clone(replacement_map, old_expr);

// Add new output iter domains to id_definitions_/id_uses_ of IdModel
for (auto out_id : ir_utils::filterByType<IterDomain>(replay->outputs())) {
id_definitions_[out_id].pushBack(replay);
id_uses_[out_id];
}

// Add new input iter domains to id_definitions_/id_uses_ of IdModel
for (auto inp_id : ir_utils::filterByType<IterDomain>(replay->inputs())) {
id_definitions_[inp_id];
id_uses_[inp_id].pushBack(replay);
}

// Update all the initialized graph mappings
for (auto mode : initialized_modes) {
auto& graph = idGraph(mode);

graph.registerExpr(replay);
auto replay_group = graph.toGroup(replay);

// Initialize any non-existent input ids, update existing ones
for (auto inp_id : ir_utils::filterByType<IterDomain>(replay->inputs())) {
if (!graph.disjointValSets().mappingExists(inp_id)) {
// inp_id is not initialized in the map, initialize it
graph.initializeVal(inp_id, {}, {replay});
} else {
// Update unique uses of existing input ids
auto inp_group = graph.toGroup(inp_id);
graph.addUniqueUses(inp_group, replay_group);
}
}

// Initialize any non-existent output ids, update existing ones
for (auto out_id : ir_utils::filterByType<IterDomain>(replay->outputs())) {
if (!graph.disjointValSets().mappingExists(out_id)) {
// out_id is not initialized in the map, initialize it
graph.initializeVal(out_id, {replay}, {});
} else {
// out_id is already initialized, add the replay as a unique definition
// of its group
auto out_group = graph.toGroup(out_id);
graph.addUniqueDefinitions(out_group, replay_group);
}
}

// If the inputs were replaced we want to map through forward the newly
// added expression. If the outputs were replaced we want to map through
// backwards the newly added expression.

// Forward
VectorOfUniqueEntries<Expr*> representative_uses;
for (auto in : ir_utils::filterByType<IterDomain>(replay->inputs())) {
for (const ExprGroup& use_group : graph.getUses(graph.toGroup(in))) {
if (use_group == replay_group) {
continue;
}
representative_uses.pushBack(use_group->front());
}
}

for (auto rep_use : representative_uses) {
graph.maybeMapThroughExprs(rep_use, replay, true);
}

// Backwards
VectorOfUniqueEntries<Expr*> representative_defs;
for (auto out : ir_utils::filterByType<IterDomain>(replay->outputs())) {
for (const ExprGroup& def_group :
graph.getDefinitions(graph.toGroup(out))) {
if (def_group == replay_group) {
continue;
}
representative_defs.pushBack(def_group->front());
}
}

for (auto rep_def : representative_defs) {
graph.maybeMapThroughExprs(rep_def, replay, false);
}
}
return replay;
}
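
The precondition enforced by addExprWithReplacement can be isolated from the graph bookkeeping: the caller must supply replacements for every input or for every output, and any id left out is filled in with a fresh clone. The following is a simplified sketch of that check only. It uses plain strings in place of IterDomain pointers; completeReplacementMap, ReplacementStatus, and the "_clone" suffix are hypothetical names introduced for illustration and stand in for cloneWithoutRFactor().

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-in for IterDomain*; an assumption made to keep the sketch standalone.
using Id = std::string;

struct ReplacementStatus {
  bool all_inputs_replaced = true;
  bool all_outputs_replaced = true;

  // Mirrors the NVF_ERROR condition: one full side must have been replaced.
  bool valid() const {
    return all_inputs_replaced || all_outputs_replaced;
  }
};

// Fills replacement_map so every input and output has an entry. Ids without
// a caller-provided replacement map to a freshly named clone.
ReplacementStatus completeReplacementMap(
    const std::vector<Id>& inputs,
    const std::vector<Id>& outputs,
    std::unordered_map<Id, Id>& replacement_map) {
  ReplacementStatus status;

  for (const Id& in : inputs) {
    if (replacement_map.find(in) == replacement_map.end()) {
      status.all_inputs_replaced = false;
      replacement_map[in] = in + "_clone";
    }
  }

  for (const Id& out : outputs) {
    if (replacement_map.find(out) == replacement_map.end()) {
      status.all_outputs_replaced = false;
      replacement_map[out] = out + "_clone";
    }
  }

  return status;
}
```

With inputs {"i0", "i1"} fully mapped and output "o0" unmapped, the status is valid and "o0" maps to a clone. If only "i0" is mapped, neither side is complete and the real function would raise an error.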

// Clone provided iter domain and return the new copy. Map that copy in relevant
// maps.
IterDomain* IdModel::cloneIterDomain(IterDomain* id) {
[Review comment, Collaborator] Unused
// Figure out which graphs are already initialized to make sure we add the new
// iter domain to them.
std::vector<IdMappingMode> initialized_modes;
for (auto mode : kIdMappingModes) {
auto graph_it = id_graphs_.find(mode);
if (graph_it == id_graphs_.end()) {
continue;
}

auto& graph = graph_it->second;
if (graph.disjointValSets().disjointSetMap().empty()) {
continue;
}

initialized_modes.push_back(mode);
}

auto id_copy = id->cloneWithoutRFactor();

id_uses_[id_copy] = {};
id_definitions_[id_copy] = {};

for (auto mode : initialized_modes) {
idGraph(mode).initializeVal(id_copy, {}, {});
idGraph(mode).mapVals(id, id_copy);
}

return id_copy;
}

ValGraph IdModel::initializeIdGraph(bool propagate_through_exprs) const {
ValGraph id_graph(propagate_through_exprs);

@@ -603,13 +801,25 @@ void IdModel::buildLoopGraph() {
maybeBuildGraph(IdMappingMode::EXACT);
maybeBuildGraph(IdMappingMode::PERMISSIVE);

if (!tv_exprs_.empty()) {
std::stringstream ss;
tv_exprs_.at(0)->fusion()->print(ss);
VERBOSE() << ss.str();
}

const StatefulInliningInfo inlining_info =
buildStatefulInliningInfo(tv_exprs_, idGraph(IdMappingMode::PERMISSIVE));

initializeLoopGraph(inlining_info);

validateLoopGraphHasNoSelfMappedLeafDomains();

VERBOSE() << "Initial loop graph:\n";
for (const auto& group :
idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) {
VERBOSE() << nvfuser::toString(group) << std::endl;
}

loop_promotion_map_ = LoopPromotionMapBuilder::get(
*this, inlining_info, loop_promotion_map_builder_callback_);

@@ -620,7 +830,55 @@ void IdModel::buildLoopGraph() {
idGraph(IdMappingMode::LOOP).validateConsistency();
}

// TODO: Reenable after reenabling parallel propagation.
// propagateLoopPTypes
void IdModel::validatePTypes(const std::vector<TensorView*>& all_tvs) const {
[Review comment, Collaborator] Unused
// VectorOfUniqueEntries<IterDomain*> leaf_ids;
// for (auto tv : all_tvs) {
// leaf_ids.pushBack(tv->domain()->leaf());
// }

// for (const auto& disjoint_set :
// idGraph(IdMappingMode::EXACT).disjointValSets().disjointSets()) {
// for (auto id : disjoint_set->vector()) {
// auto id_ptype = id->getParallelType();

// NVF_ERROR(
// leaf_ids.has(id) || id_ptype == ParallelType::Serial,
// "Invalid parallelization of non leaf iter domain: ",
// id->toString());
// }
// }
}

void IdModel::propagateLoopPTypes() const {
[Review comment, Collaborator] unused
for (const auto& loop_disjoint_set :
idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) {
ParallelType common_ptype = ParallelType::Serial;
for (auto id : loop_disjoint_set->vector()) {
auto id_ptype = id->as<IterDomain>()->getParallelType();

NVF_ERROR(
id_ptype == common_ptype || id_ptype == ParallelType::Serial ||
common_ptype == ParallelType::Serial,
"Issue validating parallel type disjoint ptype is, ",
common_ptype,
" but found in the set the id: ",
id->toString());

common_ptype =
common_ptype == ParallelType::Serial ? id_ptype : common_ptype;
}

for (auto id : loop_disjoint_set->vector()) {
id->as<IterDomain>()->parallelize(common_ptype);
}
}
}
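
The rule applied by propagateLoopPTypes can be stated on its own: within one loop group, every non-Serial parallel type must agree, and the agreed type is then written to all members. Here is a standalone sketch of that rule, assuming a small stand-in enum instead of nvFuser's ParallelType. PType, commonPType, and propagatePType are hypothetical names, and the sketch reports conflicts through a return value where the original raises NVF_ERROR.

```cpp
#include <cassert>
#include <vector>

// Stand-in for nvFuser's ParallelType; only a few values, for illustration.
enum class PType { Serial, TIDx, TIDy, BIDx };

// Computes the parallel type shared by a loop group. Serial members are
// compatible with anything. Returns false if two different non-Serial types
// appear; otherwise writes the common type (Serial if all are Serial) to out.
bool commonPType(const std::vector<PType>& group, PType& out) {
  PType common = PType::Serial;
  for (PType p : group) {
    if (p == PType::Serial) {
      continue;
    }
    if (common != PType::Serial && common != p) {
      return false;
    }
    common = p;
  }
  out = common;
  return true;
}

// Overwrites every member with the group's common type. On conflict the
// group is left untouched and false is returned.
bool propagatePType(std::vector<PType>& group) {
  PType common = PType::Serial;
  if (!commonPType(group, common)) {
    return false;
  }
  for (PType& p : group) {
    p = common;
  }
  return true;
}
```

For example, a group {Serial, TIDx, Serial} becomes {TIDx, TIDx, TIDx}, while {TIDx, BIDx} is rejected.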

void IdModel::buildAllGraphs() {
VERBOSE() << "*** Building all graphs ***";

if (tvs_.empty()) {
return;
}
@@ -663,6 +921,11 @@ void IdModel::buildAllGraphs() {
idGraph(IdMappingMode::PERMISSIVE));
}

// Permissive graph needs the trivial exprs from the almost exact graph to
// build correctly. Once built though we can remove the trivial expressions
// from the almost exact graph.
idGraph(IdMappingMode::ALMOSTEXACT).removeTrivialExprs();

buildLoopGraph();
}

@@ -870,4 +1133,12 @@ std::unordered_map<ValGroup, IterDomain*> updateValGroupIdMap(
return new_map;
}

std::unordered_map<IterDomain*, IterDomain*> IdModel::buildIndexGraph(
const std::vector<Expr*>& exprs,
const std::vector<TensorView*>& all_tvs,
StatefulInliningInfo& info,
std::unordered_map<ValGroup, IterDomain*> stale_promotion_map) {
NVF_ERROR(false, "Not implemented yet.");
}

} // namespace nvfuser