1
+ #pragma once
2
+
3
+ #include " recorders.h"
4
+ #include < mutex>
5
+ #include < thread>
6
+ #include < vector>
7
+
8
+ namespace celerity ::detail {
9
+ // in c++23 replace this with mdspan
10
+ template <typename T>
11
+ struct mpi_2d_send_wrapper {
12
+ public:
13
+ const T& operator [](std::pair<int , int > ij) const {
14
+ assert (ij.first * m_width + ij.second < m_data.size ());
15
+ return m_data[ij.first * m_width + ij.second ];
16
+ }
17
+
18
+ T* data () { return m_data.data (); }
19
+
20
+ mpi_2d_send_wrapper (size_t width, size_t height) : m_data(width * height), m_width(width){};
21
+
22
+ private:
23
+ std::vector<T> m_data;
24
+ const size_t m_width;
25
+ };
26
+
27
+ // Probably replace this in c++20 with span
28
+ template <typename T>
29
+ struct window {
30
+ public:
31
+ window (const std::vector<T>& value) : m_value(value) {}
32
+
33
+ const T& operator [](size_t i) const {
34
+ assert (i >= 0 && i < m_width);
35
+ return m_value[m_offset + i];
36
+ }
37
+
38
+ size_t size () {
39
+ m_width = m_value.size () - m_offset;
40
+ return m_width;
41
+ }
42
+
43
+ void slide (size_t i) {
44
+ assert (i == 0 || (i >= 0 && i <= m_width));
45
+ m_offset += i;
46
+ m_width -= i;
47
+ }
48
+
49
+ private:
50
+ const std::vector<T>& m_value;
51
+ size_t m_offset = 0 ;
52
+ size_t m_width = 0 ;
53
+ };
54
+
55
+ using task_hash = size_t ;
56
+ using task_hash_data = mpi_2d_send_wrapper<task_hash>;
57
+ using divergence_map = std::unordered_map<task_hash, std::vector<node_id>>;
58
+
59
+ class abstract_block_chain {
60
+ friend struct abstract_block_chain_testspy ;
61
+
62
+ public:
63
+ virtual void stop () { m_is_running = false ; };
64
+
65
+ abstract_block_chain (const abstract_block_chain&) = delete ;
66
+ abstract_block_chain& operator =(const abstract_block_chain&) = delete ;
67
+ abstract_block_chain& operator =(abstract_block_chain&&) = delete ;
68
+
69
+ abstract_block_chain (abstract_block_chain&&) = default ;
70
+
71
+ abstract_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_recorder, MPI_Comm comm)
72
+ : m_local_nid(local_nid), m_num_nodes(num_nodes), m_sizes(num_nodes), m_task_recorder_window(task_recorder), m_comm(comm) {}
73
+
74
+ virtual ~abstract_block_chain () = default ;
75
+
76
+ protected:
77
+ void start () { m_is_running = true ; };
78
+
79
+ virtual void run () = 0;
80
+
81
+ virtual void divergence_out (const divergence_map& check_map, const int task_num) = 0;
82
+
83
+ void add_new_hashes ();
84
+ void clear (const int min_progress);
85
+ virtual void allgather_sizes ();
86
+ virtual void allgather_hashes (const int max_size, task_hash_data& data);
87
+ std::pair<int , int > collect_sizes ();
88
+ task_hash_data collect_hashes (const int max_size);
89
+ divergence_map create_check_map (const task_hash_data& task_graphs, const int task_num) const ;
90
+
91
+ void check_for_deadlock () const ;
92
+
93
+ static void print_node_divergences (const divergence_map& check_map, const int task_num);
94
+
95
+ static void print_task_record (const divergence_map& check_map, const task_record& task, const task_hash hash);
96
+
97
+ virtual void dedub_print_task_record (const divergence_map& check_map, const int task_num) const ;
98
+
99
+ bool check_for_divergence ();
100
+
101
+ protected:
102
+ node_id m_local_nid;
103
+ size_t m_num_nodes;
104
+
105
+ std::vector<task_hash> m_hashes;
106
+ std::vector<int > m_sizes;
107
+
108
+ bool m_is_running = true ;
109
+
110
+ window<task_record> m_task_recorder_window;
111
+
112
+ std::chrono::time_point<std::chrono::steady_clock> m_last_cleared = std::chrono::steady_clock::now();
113
+
114
+ MPI_Comm m_comm;
115
+ };
116
+
117
+ class single_node_test_divergence_block_chain : public abstract_block_chain {
118
+ public:
119
+ single_node_test_divergence_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_recorder, MPI_Comm comm,
120
+ const std::vector<std::reference_wrapper<const std::vector<task_record>>>& other_task_records)
121
+ : abstract_block_chain(num_nodes, local_nid, task_recorder, comm), m_other_hashes(other_task_records.size()) {
122
+ for (auto & tsk_rcd : other_task_records) {
123
+ m_other_task_records.push_back (window<task_record>(tsk_rcd));
124
+ }
125
+ }
126
+
127
+ private:
128
+ void run () override {}
129
+
130
+ void divergence_out (const divergence_map& check_map, const int task_num) override ;
131
+ void allgather_sizes () override ;
132
+ void allgather_hashes (const int max_size, task_hash_data& data) override ;
133
+
134
+ void dedub_print_task_record (const divergence_map& check_map, const int task_num) const override ;
135
+
136
+ std::vector<std::vector<task_hash>> m_other_hashes;
137
+ std::vector<window<task_record>> m_other_task_records;
138
+
139
+ int m_injected_delete_size = 0 ;
140
+ };
141
+
142
+ class distributed_test_divergence_block_chain : public abstract_block_chain {
143
+ public:
144
+ distributed_test_divergence_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_record, MPI_Comm comm)
145
+ : abstract_block_chain(num_nodes, local_nid, task_record, comm) {}
146
+
147
+ private:
148
+ void run () override {}
149
+
150
+ void divergence_out (const divergence_map& check_map, const int task_num) override ;
151
+ };
152
+
153
+ class divergence_block_chain : public abstract_block_chain {
154
+ public:
155
+ void start ();
156
+ void stop () override ;
157
+
158
+ divergence_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_record, MPI_Comm comm)
159
+ : abstract_block_chain(num_nodes, local_nid, task_record, comm) {
160
+ divergence_block_chain::start ();
161
+ }
162
+
163
+ divergence_block_chain (const divergence_block_chain&) = delete ;
164
+ divergence_block_chain& operator =(const divergence_block_chain&) = delete ;
165
+ divergence_block_chain& operator =(divergence_block_chain&&) = delete ;
166
+
167
+ divergence_block_chain (divergence_block_chain&&) = default ;
168
+
169
+ ~divergence_block_chain () override { divergence_block_chain::stop (); }
170
+
171
+ private:
172
+ void run () override ;
173
+
174
+ void divergence_out (const divergence_map& check_map, const int task_num) override ;
175
+
176
+ private:
177
+ std::thread m_thread;
178
+ };
179
+ } // namespace celerity::detail
0 commit comments