1
+ #pragma once
2
+
3
+ #include " recorders.h"
4
+ #include < mutex>
5
+ #include < thread>
6
+ #include < vector>
7
+
8
+ namespace celerity ::detail {
9
+ // in c++23 replace this with mdspan
10
+ template <typename T>
11
+ struct mpi_multidim_send_wrapper {
12
+ public:
13
+ const T& operator [](std::pair<int , int > ij) const {
14
+ assert (ij.first * m_width + ij.second < m_data.size ());
15
+ return m_data[ij.first * m_width + ij.second ];
16
+ }
17
+
18
+ T* data () { return m_data.data (); }
19
+
20
+ mpi_multidim_send_wrapper (size_t width, size_t height) : m_data(width * height), m_width(width){};
21
+
22
+ private:
23
+ std::vector<T> m_data;
24
+ const size_t m_width;
25
+ };
26
+
27
+ // Probably replace this in c++20 with span
28
+ template <typename T>
29
+ struct window {
30
+ public:
31
+ window (const std::vector<T>& value) : m_value(value) {}
32
+
33
+ const T& operator [](size_t i) const {
34
+ assert (i >= 0 && i < m_width);
35
+ return m_value[m_offset + i];
36
+ }
37
+
38
+ size_t size () {
39
+ m_width = m_value.size () - m_offset;
40
+ return m_width;
41
+ }
42
+
43
+ void slide (size_t i) {
44
+ assert (i == 0 || (i >= 0 && i <= m_width));
45
+ m_offset += i;
46
+ m_width -= i;
47
+ }
48
+
49
+ private:
50
+ const std::vector<T>& m_value;
51
+ size_t m_offset = 0 ;
52
+ size_t m_width = 0 ;
53
+ };
54
+
55
+ using task_hash = size_t ;
56
+ using task_hash_data = mpi_multidim_send_wrapper<task_hash>;
57
+ using divergence_map = std::unordered_map<task_hash, std::vector<node_id>>;
58
+
59
+ class abstract_block_chain {
60
+ friend struct abstract_block_chain_testspy ;
61
+
62
+ public:
63
+ virtual void start () { m_is_running = true ; };
64
+ virtual void stop () { m_is_running = false ; };
65
+
66
+ abstract_block_chain (const abstract_block_chain&) = delete ;
67
+ abstract_block_chain& operator =(const abstract_block_chain&) = delete ;
68
+ abstract_block_chain& operator =(abstract_block_chain&&) = delete ;
69
+
70
+ abstract_block_chain (abstract_block_chain&&) = default ;
71
+ virtual ~abstract_block_chain () { stop (); }
72
+
73
+ abstract_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_recorder, MPI_Comm comm)
74
+ : m_local_nid(local_nid), m_num_nodes(num_nodes), m_sizes(num_nodes), m_task_recorder_window(task_recorder), m_comm(comm) {}
75
+
76
+ protected:
77
+ virtual void run () = 0;
78
+
79
+ virtual void divergence_out (const divergence_map& check_map, const int task_num) = 0;
80
+
81
+ void add_new_hashes ();
82
+ void clear (const int min_progress);
83
+ virtual void allgather_sizes ();
84
+ virtual void allgather_hashes (const int max_size, task_hash_data& data);
85
+ std::pair<int , int > collect_sizes ();
86
+ task_hash_data collect_hashes (const int max_size);
87
+ divergence_map create_check_map (const task_hash_data& task_graphs, const int task_num) const ;
88
+
89
+ void check_for_deadlock () const ;
90
+
91
+ static void print_node_divergences (const divergence_map& check_map, const int task_num);
92
+
93
+ static void print_task_record (const divergence_map& check_map, const task_record& task, const task_hash hash);
94
+
95
+ virtual void dedub_print_task_record (const divergence_map& check_map, const int task_num) const ;
96
+
97
+ bool check_for_divergence ();
98
+
99
+ protected:
100
+ node_id m_local_nid;
101
+ size_t m_num_nodes;
102
+
103
+ std::vector<task_hash> m_hashes;
104
+ std::vector<int > m_sizes;
105
+
106
+ bool m_is_running = true ;
107
+
108
+ window<task_record> m_task_recorder_window;
109
+
110
+ std::chrono::time_point<std::chrono::steady_clock> m_last_cleared = std::chrono::steady_clock::now();
111
+
112
+ MPI_Comm m_comm;
113
+ };
114
+
115
+ class single_node_test_divergence_block_chain : public abstract_block_chain {
116
+ public:
117
+ single_node_test_divergence_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_recorder, MPI_Comm comm,
118
+ const std::vector<std::reference_wrapper<const std::vector<task_record>>>& other_task_records)
119
+ : abstract_block_chain(num_nodes, local_nid, task_recorder, comm), m_other_hashes(other_task_records.size()) {
120
+ for (auto & tsk_rcd : other_task_records) {
121
+ m_other_task_records.push_back (window<task_record>(tsk_rcd));
122
+ }
123
+ }
124
+
125
+ private:
126
+ void run () override {}
127
+
128
+ void divergence_out (const divergence_map& check_map, const int task_num) override ;
129
+ void allgather_sizes () override ;
130
+ void allgather_hashes (const int max_size, task_hash_data& data) override ;
131
+
132
+ void dedub_print_task_record (const divergence_map& check_map, const int task_num) const override ;
133
+
134
+ std::vector<std::vector<task_hash>> m_other_hashes;
135
+ std::vector<window<task_record>> m_other_task_records;
136
+
137
+ int m_injected_delete_size = 0 ;
138
+ };
139
+
140
+ class distributed_test_divergence_block_chain : public abstract_block_chain {
141
+ public:
142
+ distributed_test_divergence_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_record, MPI_Comm comm)
143
+ : abstract_block_chain(num_nodes, local_nid, task_record, comm) {}
144
+
145
+ private:
146
+ void run () override {}
147
+
148
+ void divergence_out (const divergence_map& check_map, const int task_num) override ;
149
+ };
150
+
151
+ class divergence_block_chain : public abstract_block_chain {
152
+ public:
153
+ void start () override ;
154
+ void stop () override ;
155
+
156
+ divergence_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_record, MPI_Comm comm)
157
+ : abstract_block_chain(num_nodes, local_nid, task_record, comm) {
158
+ start ();
159
+ }
160
+
161
+ private:
162
+ void run () override ;
163
+
164
+ void divergence_out (const divergence_map& check_map, const int task_num) override ;
165
+
166
+ private:
167
+ std::thread m_thread;
168
+ };
169
+ } // namespace celerity::detail
0 commit comments