-
Notifications
You must be signed in to change notification settings - Fork 69
WIP: IterDomain Graphs #32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
16f064e
45959ef
eee9bf7
8ce0963
f745410
6edbba3
1a2b261
8b3fe64
4d5c604
0e47e5c
91fb637
cc17fef
1175da2
7891741
ee4e311
4249e94
853cd44
dc88239
8d571a6
8879459
a3a86fd
6eaeb06
32cbc5b
db3ba36
adbad3d
68dec08
6f682d7
0588276
6991b90
81ba299
4db481b
ad7012b
a9d192e
e923a0a
d3793f0
2896a28
0cc1356
ddc858e
8fd5bce
84ed670
a466c5c
3934fe6
0729dee
1f10bd3
2a96ccb
cccb079
899d5e9
23b2e78
cd593fd
69a0b0f
f8c1812
93bb70a
04172d7
b75e197
8cf25af
a98883d
51243b5
a0d8c43
397edcc
fd97525
177d40c
535ce1d
b80e871
bb4968b
d6504f8
6a6ee7a
edddb91
60407b2
49967fb
24dc758
f9045f9
c51379e
7d4acab
d77436c
7f34b17
6603e0a
aacc529
a761409
d3eb4c1
3bb9692
b778788
4c50dcf
9fa2d1c
e20a76b
27c619a
4523cb2
86c574a
bff13b8
283a3c9
aadcb9d
5f0e5c4
811a4ad
2577dfc
f3bbd8f
f2448c2
57469b3
5dde9f1
7f956c4
e01e2e6
b3e60b5
f7f4d84
f073d04
6ba29ed
9eacdf6
f8a9585
9733ab6
c07402e
dbff25c
f1b5f63
ad5debf
a17dc11
14ab237
f6e6848
5715ce9
a8070a9
6c4a5f3
ad83b72
554bd3e
4691a3a
1f36eb7
f5b39e0
636fbd7
bc8fc05
9b6d761
4d2dc50
4ad9eeb
151b0ef
27cd19c
e45b1bb
3adffd9
f9c9d37
6fbebc9
b364f28
3b94574
1904eff
133613a
c66617d
119cf0f
15bfe64
64b409e
a12514e
e6c43e9
197f227
5d5458e
de20be4
9286489
ae70e3c
2dc3262
a902c6e
7ef52af
eced85a
927aec4
cbaaf0e
45dd418
be6eaac
534ac78
2837e53
45b8be9
ec8a2f5
1960432
74c98e2
ccd7bab
6505c63
1c11182
148ef83
3e7992c
fe1517d
62346a9
04458cb
6a112a3
f055994
4781d65
7962338
1eebde4
0e72708
7561d19
48e3019
a828f6f
fa99fe2
2321f3e
ad3da55
e35fb69
9727240
4028dea
32f47db
3d8b582
44da75a
44ad7e0
d8f3ed4
cb5d1bc
c14dd72
fc861a6
e6071e5
0998319
c8db2f7
8740217
0e69513
30e3d5b
fcc3b96
b30682a
9ed0b82
4d1dcdc
1231b60
04b1479
4b391d8
270ac19
5d20b8d
bd46e88
e17fa35
5573fbc
e644665
e21e4e5
801a7b4
649563a
b25ced1
3038e0d
6ef567b
0027ee9
2bcbcb4
fa8455a
c05fd85
31d28e5
d0bdc8e
aa9b414
7aef29d
24a0c8c
c7c04b1
2b605cd
0f5ab07
423fde1
c78da86
f0aeab7
ce5f8ee
3214bc7
d8aacea
18de523
83edc44
e71cb5a
009a11a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,3 @@ | ||
| // clang-format off | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. | ||
| * All rights reserved. | ||
| * SPDX-License-Identifier: BSD-3-Clause | ||
| */ | ||
| // clang-format on | ||
| #pragma once | ||
|
|
||
| #include <c10/util/Exception.h> | ||
|
|
@@ -36,13 +29,38 @@ std::string abstractToString(T ref) { | |
|
|
||
| // Vector like class that will prevent adding duplicate entries by also | ||
| // maintaing a set | ||
| // | ||
| // TODO: Can we support std::back_inserter with this class? | ||
| template <typename T, typename Hash = std::hash<T>> | ||
| class VectorOfUniqueEntries { | ||
| public: | ||
| VectorOfUniqueEntries() = default; | ||
|
|
||
| VectorOfUniqueEntries(const std::initializer_list<T>& x) | ||
| : vector_(x), set_(x) {} | ||
| VectorOfUniqueEntries(const std::initializer_list<T>& initializer) { | ||
| for (auto entry : initializer) { | ||
| pushBack(entry); | ||
| } | ||
| } | ||
|
|
||
| VectorOfUniqueEntries(const VectorOfUniqueEntries<T>& other) { | ||
| vector_ = other.vector(); | ||
| set_ = other.set(); | ||
| } | ||
|
|
||
| VectorOfUniqueEntries& operator=(const VectorOfUniqueEntries<T>& other) { | ||
|
||
| if (this != &other) { | ||
| vector_ = other.vector(); | ||
| set_ = other.set(); | ||
| } | ||
| return *this; | ||
| } | ||
|
|
||
| template <class InputIt> | ||
| VectorOfUniqueEntries(InputIt first, InputIt last) { | ||
| while (first != last) { | ||
| pushBack(*first++); | ||
| } | ||
| } | ||
|
|
||
| // Returns if a node was actually added | ||
| bool pushBack(T entry) { | ||
|
|
@@ -53,6 +71,15 @@ class VectorOfUniqueEntries { | |
| return false; | ||
| } | ||
|
|
||
| // Returns if a node was actually added | ||
| bool pushFront(T entry) { | ||
| if (set_.emplace(entry).second) { | ||
| vector_.insert(vector_.begin(), entry); | ||
| return true; | ||
| } | ||
| return false; | ||
| } | ||
|
|
||
| // Returns if any node was added | ||
| bool pushBack(const VectorOfUniqueEntries<T, Hash>& other) { | ||
| bool any_added = false; | ||
|
|
@@ -62,11 +89,53 @@ class VectorOfUniqueEntries { | |
| return any_added; | ||
| } | ||
|
|
||
| // Returns a new VectorOfUniqueEntries with entries that are in both this and | ||
| // other, order is preserved as this. | ||
| VectorOfUniqueEntries<T, Hash> intersect( | ||
| const VectorOfUniqueEntries<T, Hash>& other) { | ||
| VectorOfUniqueEntries<T, Hash> intersection; | ||
| for (auto entry : vector()) { | ||
| if (other.has(entry)) { | ||
| intersection.pushBack(entry); | ||
| } | ||
| } | ||
| return intersection; | ||
| } | ||
|
|
||
| // Returns a new VectorOfUniqueEntries with entries that are in this but not | ||
| // in other. | ||
| VectorOfUniqueEntries<T, Hash> subtract( | ||
| const VectorOfUniqueEntries<T, Hash>& other) const { | ||
| VectorOfUniqueEntries<T, Hash> subtraction; | ||
| for (auto entry : vector()) { | ||
| if (!other.has(entry)) { | ||
| subtraction.pushBack(entry); | ||
| } | ||
| } | ||
| return subtraction; | ||
| } | ||
|
|
||
| // Returns a new VectorOfUniqueEntries with entries that are either in this or | ||
| // other. | ||
| VectorOfUniqueEntries<T, Hash> computeUnion( | ||
| const VectorOfUniqueEntries<T, Hash>& other) const { | ||
| const VectorOfUniqueEntries<T, Hash>& this_ref = *this; | ||
| VectorOfUniqueEntries<T, Hash> union_(this_ref); | ||
| for (auto entry : other.vector()) { | ||
| union_.pushBack(entry); | ||
| } | ||
| return union_; | ||
| } | ||
|
|
||
| // Returns a const vector useful for iterating on | ||
| const std::vector<T>& vector() const { | ||
| return vector_; | ||
| } | ||
|
|
||
| const std::unordered_set<T>& set() const { | ||
| return set_; | ||
| } | ||
|
|
||
| // Returns first element in vector | ||
| T front() const { | ||
| return vector_.front(); | ||
|
|
@@ -85,6 +154,14 @@ class VectorOfUniqueEntries { | |
| return v; | ||
| } | ||
|
|
||
| // Remove and returns the last element in vector | ||
| T popFront() { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we need to pop front, should we make this |
||
| T v = vector_.front(); | ||
| set_.erase(v); | ||
| vector_.erase(vector_.begin()); | ||
| return v; | ||
| } | ||
|
|
||
| // Returns if this container is empty | ||
| bool empty() const { | ||
| return vector_.empty(); | ||
|
|
@@ -141,7 +218,7 @@ class VectorOfUniqueEntries { | |
| return vector_.end(); | ||
| } | ||
|
|
||
| std::string toString() { | ||
| std::string toString() const { | ||
| std::stringstream ss; | ||
| ss << "{ "; | ||
| for (auto entry : vector()) { | ||
|
|
@@ -210,64 +287,78 @@ class DisjointSets { | |
| } | ||
|
|
||
| // Initializes a new set for provided entry | ||
| // | ||
| // TODO: Return iterator | ||
| void initializeSet(T entry) { | ||
| if (disjoint_set_maps_.find(entry) != disjoint_set_maps_.end()) { | ||
| return; | ||
| std::pair< | ||
| typename std::unordered_map< | ||
| T, | ||
| std::shared_ptr<VectorOfUniqueEntries<T, Hash>>, | ||
| Hash>::iterator, | ||
| bool> | ||
| initializeSet(T entry) { | ||
| auto disjoint_set_maps_it = disjoint_set_maps_.find(entry); | ||
| if (disjoint_set_maps_it != disjoint_set_maps_.end()) { | ||
| return std::make_pair(disjoint_set_maps_it, false); | ||
| } | ||
|
|
||
| disjoint_sets_.push_back( | ||
| std::make_shared<VectorOfUniqueEntries<T, Hash>>()); | ||
| disjoint_sets_.back()->pushBack(entry); | ||
| disjoint_set_maps_.emplace(std::make_pair(entry, disjoint_sets_.back())); | ||
| return disjoint_set_maps_.emplace( | ||
| std::make_pair(entry, disjoint_sets_.back())); | ||
| } | ||
|
|
||
| // Adds all of the disjoint set belonging to entry1 to the disjoint set | ||
| // belonging to entry0, maps all entries of disjoint set belonging to entry1 | ||
| // to entry0, removes original disjoint set belonging to entry1. | ||
| void mapEntries(T entry0, T entry1) { | ||
| if (entry0 == entry1) { | ||
csarofeen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return; | ||
| } | ||
|
|
||
| auto set_it_0 = disjoint_set_maps_.find(entry0); | ||
| auto set_it_1 = disjoint_set_maps_.find(entry1); | ||
|
|
||
| // Track if we need to reset iterators, optimize for case where both entries | ||
| // exist | ||
| bool invalid_iterators = false; | ||
| if (set_it_0 == disjoint_set_maps_.end()) { | ||
| initializeSet(entry0); | ||
| invalid_iterators = true; | ||
| } | ||
| auto set_0_found = set_it_0 != disjoint_set_maps_.end(); | ||
| auto set_1_found = set_it_1 != disjoint_set_maps_.end(); | ||
|
|
||
| if (set_it_1 == disjoint_set_maps_.end()) { | ||
| initializeSet(entry1); | ||
| invalid_iterators = true; | ||
| // Sets already joined | ||
| if (set_0_found && set_1_found && set_it_0->second == set_it_1->second) { | ||
| return; | ||
| } | ||
|
|
||
| // TODO: We can avoid refinding one iterator if initialize set returns an | ||
| // iterator, though if we insert entry1 we'd have to refind entry0 as it | ||
| // could invalidate all iterators | ||
| if (invalid_iterators) { | ||
| set_it_0 = disjoint_set_maps_.find(entry0); | ||
| // Make and map new set | ||
| disjoint_sets_.push_back( | ||
| std::make_shared<VectorOfUniqueEntries<T, Hash>>()); | ||
| auto new_set = disjoint_sets_.back(); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the motivation for this change? We are not unconditionally doing data copy, would this make |
||
|
|
||
| if (set_0_found) { | ||
| auto set_0 = set_it_0->second; | ||
| for (auto set_0_entry : *set_0) { | ||
| TORCH_INTERNAL_ASSERT(set_0_entry != entry1); | ||
| new_set->pushBack(set_0_entry); | ||
| disjoint_set_maps_[set_0_entry] = new_set; | ||
| } | ||
| disjoint_sets_.erase( | ||
| std::find(disjoint_sets_.begin(), disjoint_sets_.end(), set_0)); | ||
| // Erase invalidates iterators, regrab. | ||
| set_it_1 = disjoint_set_maps_.find(entry1); | ||
| set_1_found = set_it_1 != disjoint_set_maps_.end(); | ||
| } else { | ||
| new_set->pushBack(entry0); | ||
| disjoint_set_maps_[entry0] = new_set; | ||
| } | ||
|
|
||
| auto set0_shared_ptr = set_it_0->second; | ||
| auto set1_shared_ptr = set_it_1->second; | ||
|
|
||
| // If the sets are already the same, do nothing | ||
| if (set0_shared_ptr == set1_shared_ptr) { | ||
| return; | ||
| } | ||
|
|
||
| // Place everything in set1 into set0 and remap all entries in set1 to set0 | ||
| for (auto entry : set1_shared_ptr->vector()) { | ||
| set0_shared_ptr->pushBack(entry); | ||
| disjoint_set_maps_[entry] = set0_shared_ptr; | ||
| if (set_1_found) { | ||
| auto set_1 = set_it_1->second; | ||
| for (auto set_1_entry : *set_1) { | ||
| new_set->pushBack(set_1_entry); | ||
| disjoint_set_maps_[set_1_entry] = new_set; | ||
| } | ||
| disjoint_sets_.erase( | ||
| std::find(disjoint_sets_.begin(), disjoint_sets_.end(), set_1)); | ||
| } else { | ||
| new_set->pushBack(entry1); | ||
| disjoint_set_maps_[entry1] = new_set; | ||
| } | ||
|
|
||
| // set1 no longer needed as its entries are copied into set0 | ||
| disjoint_sets_.erase(std::find( | ||
| disjoint_sets_.begin(), disjoint_sets_.end(), set1_shared_ptr)); | ||
| } | ||
|
|
||
| // Will assert if provided entry0 is not in any disjoint set, otherwise | ||
|
|
@@ -323,11 +414,7 @@ class DisjointSets { | |
| const std::string sep(" "); | ||
| for (auto s_ptr : disjoint_sets_) { | ||
| auto& set = *s_ptr; | ||
| ss << sep << "{\n"; | ||
| for (auto entry : set.vector()) { | ||
| ss << sep << sep << abstractToString(entry) << "\n"; | ||
| } | ||
| ss << sep << "}\n"; | ||
| ss << sep << abstractToString(set) << "\n"; | ||
| } | ||
| ss << "}"; | ||
| return ss.str(); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we just do the following?
On the other hand, should we also have a default move ctor?
VectorOfUniqueEntries(VectorOfUniqueEntries&& other) = default;Note that there should be no
<T>in the argument.