@@ -33,6 +33,12 @@ struct kernel_node_params {
33
33
unsigned int shared_mem_bytes{};
34
34
35
35
std::vector<dpct::experimental::node_ptr> dependencies{};
36
+ kernel_node_params () = default ;
37
+ kernel_node_params (const kernel_node_params &other)
38
+ : block_dim(other.block_dim), grid_dim(other.grid_dim),
39
+ kernel_params (other.kernel_params), func(other.func),
40
+ shared_mem_bytes(other.shared_mem_bytes),
41
+ dependencies(other.dependencies) {}
36
42
37
43
public:
38
44
void set_block_dim (const dpct::dim3 &block_dim) {
@@ -142,6 +148,7 @@ class graph_mgr {
142
148
dpct::experimental::node_ptr *dependencies,
143
149
std::size_t numberOfDependencies,
144
150
dpct::experimental::kernel_node_params *params) {
151
+ node_graph_params_map[*node] = std::make_pair (graph, params);
145
152
for (std::size_t i = 0 ; i < numberOfDependencies; i++) {
146
153
params->add_dependency (dependencies[i]);
147
154
}
@@ -156,7 +163,6 @@ class graph_mgr {
156
163
for (std::size_t i = 0 ; i < kernel_params_vector.size (); i++) {
157
164
auto &node_kernel_params_pair = kernel_params_vector[i];
158
165
auto node_params = node_kernel_params_pair.second ;
159
-
160
166
const auto &dependency_ptrs = node_params->get_dependencies ();
161
167
std::vector<sycl::ext::oneapi::experimental::node> dependencies;
162
168
dependencies.reserve (dependency_ptrs.size ());
@@ -184,9 +190,12 @@ class graph_mgr {
184
190
}
185
191
node_kernel_params_pair.first = new_node;
186
192
}
187
- auto final_graph = graph->finalize ();
193
+ execGraph = new sycl::ext::oneapi::experimental::command_graph<
194
+ sycl::ext::oneapi::experimental::graph_state::executable>(
195
+ graph->finalize (
196
+ sycl::ext::oneapi::experimental::property::graph::updatable{}));
188
197
queue->submit (
189
- [&](sycl::handler &cgh) { cgh.ext_oneapi_graph (final_graph ); });
198
+ [&](sycl::handler &cgh) { cgh.ext_oneapi_graph (*execGraph ); });
190
199
}
191
200
192
201
void instantiate (dpct::experimental::command_graph_exec_ptr *execGraph,
@@ -195,7 +204,31 @@ class graph_mgr {
195
204
}
196
205
197
206
void kernel_node_get_params (dpct::experimental::node_ptr node,
198
- dpct::experimental::kernel_node_params *params) {}
207
+ dpct::experimental::kernel_node_params *params) {
208
+ auto it = node_graph_params_map.find (node);
209
+ if (it == node_graph_params_map.end ()) {
210
+ return ;
211
+ }
212
+ *params = *(it->second .second );
213
+ }
214
+
215
+ void kernel_node_set_params (dpct::experimental::node_ptr node,
216
+ dpct::experimental::kernel_node_params *params) {
217
+ node_graph_params_map[node].second = params;
218
+ }
219
+
220
+ void get_node_type (dpct::experimental::node_ptr node,
221
+ sycl::ext::oneapi::experimental::node_type *nodeType) {
222
+ if (node_graph_params_map.find (node) != node_graph_params_map.end ()) {
223
+ *nodeType = sycl::ext::oneapi::experimental::node_type::kernel;
224
+ } else {
225
+ if (node) {
226
+ *nodeType = node->get_type ();
227
+ } else {
228
+ *nodeType = sycl::ext::oneapi::experimental::node_type::empty;
229
+ }
230
+ }
231
+ }
199
232
200
233
private:
201
234
std::unordered_map<sycl::queue *, command_graph_ptr> queue_graph_map;
@@ -214,8 +247,9 @@ class graph_mgr {
214
247
dpct::experimental::kernel_node_params *>>>
215
248
graph_kernel_node_params_map;
216
249
std::unordered_map<dpct::experimental::node_ptr,
217
- dpct::experimental::kernel_node_params>
218
- node_params_map;
250
+ std::pair<dpct::experimental::command_graph_ptr,
251
+ dpct::experimental::kernel_node_params *>>
252
+ node_graph_params_map;
219
253
};
220
254
} // namespace detail
221
255
@@ -326,11 +360,31 @@ static void launch(dpct::experimental::command_graph_exec_ptr execGraph,
326
360
327
361
static void
328
362
kernel_node_get_params (dpct::experimental::node_ptr node,
329
- dpct::experimental::kernel_node_params *params) {}
363
+ dpct::experimental::kernel_node_params *params) {
364
+ detail::graph_mgr::instance ().kernel_node_get_params (node, params);
365
+ }
330
366
331
367
static void
332
368
kernel_node_set_params (dpct::experimental::node_ptr node,
333
- dpct::experimental::kernel_node_params *params) {}
369
+ dpct::experimental::kernel_node_params *params) {
370
+ detail::graph_mgr::instance ().kernel_node_set_params (node, params);
371
+ }
372
+
373
+ static void
374
+ get_node_type (dpct::experimental::node_ptr node,
375
+ sycl::ext::oneapi::experimental::node_type *nodeType) {
376
+ detail::graph_mgr::instance ().get_node_type (node, nodeType);
377
+ }
378
+
379
+ static void update (dpct::experimental::command_graph_exec_ptr graphExec,
380
+ dpct::experimental::command_graph_ptr graph,
381
+ int *updateResultInfo) {
382
+ graphExec->update (*graph);
383
+ if (!graphExec) {
384
+ *updateResultInfo = 0 ;
385
+ }
386
+ *updateResultInfo = 1 ;
387
+ }
334
388
335
389
} // namespace experimental
336
390
} // namespace dpct
0 commit comments