diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py index 7df853760e..ed27bb122b 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/translation.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -189,33 +189,31 @@ def add_instrumentation(sdfg: dace.SDFG, gpu: bool) -> None: has_side_effects = True else: - sync_code = "" + sync_code = "/* The SDFG execution should already be synchronized */" has_side_effects = False #### 2. Timestamp the SDFG entry point. + start_block = sdfg.start_block entry_if_region, begin_state = _make_if_region_for_metrics_collection( - "program_entry", metrics_level, sdfg + "metrics_entry", metrics_level, sdfg ) - - for source_state in sdfg.source_nodes(): - if source_state is entry_if_region: - continue - sdfg.add_edge(entry_if_region, source_state, dace.InterstateEdge()) - source_state.is_start_block = False - assert sdfg.out_degree(entry_if_region) > 0 - entry_if_region.is_start_block = True + sdfg.add_edge(entry_if_region, start_block, dace.InterstateEdge()) + sdfg.start_block = sdfg.node_id(entry_if_region) + assert sdfg.start_block is entry_if_region tlet_start_timer = begin_state.add_tasklet( "gt_start_timer", inputs={}, outputs={"time"}, - code="""\ + code=sync_code + + """ auto now = std::chrono::high_resolution_clock::now(); time = std::chrono::duration_cast( now.time_since_epoch() ).count(); """, language=dace.dtypes.Language.CPP, + side_effects=has_side_effects, ) begin_state.add_edge( tlet_start_timer, @@ -227,13 +225,12 @@ def add_instrumentation(sdfg: dace.SDFG, gpu: bool) -> None: #### 3. Collect the SDFG end timestamp and produce the compute metric. exit_if_region, end_state = _make_if_region_for_metrics_collection( - "program_exit", metrics_level, sdfg + "metrics_exit", metrics_level, sdfg ) - - for sink_state in sdfg.sink_nodes(): - if sink_state is exit_if_region: + for sink_node in sdfg.sink_nodes(): + if sink_node is exit_if_region: continue - sdfg.add_edge(sink_state, exit_if_region, dace.InterstateEdge()) + sdfg.add_edge(sink_node, exit_if_region, dace.InterstateEdge()) assert sdfg.in_degree(exit_if_region) > 0 # Populate the branch that computes the stencil time metric @@ -292,8 +289,6 @@ def make_sdfg_call_sync(sdfg: dace.SDFG, gpu: bool) -> None: This means that `CompiledSDFG.fast_call()` will return only after all computations have _finished_ and the results are available. This function only has an effect for work that runs on the GPU. Furthermore, all work is scheduled on the default stream. - - Todo: Revisit this function once DaCe changes its behaviour in this regard. """ if not gpu: @@ -319,10 +314,10 @@ def make_sdfg_call_sync(sdfg: dace.SDFG, gpu: bool) -> None: # because that code is only run at the `exit()` stage, not after a call. Thus we # will generate an SDFGState that contains a Tasklet with the sync call. sync_state = sdfg.add_state("sync_state") - for sink_state in sdfg.sink_nodes(): - if sink_state is sync_state: + for sink_node in sdfg.sink_nodes(): + if sink_node is sync_state: continue - sdfg.add_edge(sink_state, sync_state, dace.InterstateEdge()) + sdfg.add_edge(sink_node, sync_state, dace.InterstateEdge()) assert sdfg.in_degree(sync_state) > 0 # NOTE: Since the synchronization is done through the Tasklet explicitly,