diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam.py b/sdks/python/apache_beam/runners/interactive/interactive_beam.py index 76c4ea0aa666..7b773fda5db8 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_beam.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_beam.py @@ -35,11 +35,9 @@ # pytype: skip-file import logging +from collections.abc import Iterable from datetime import timedelta from typing import Any -from typing import Dict -from typing import Iterable -from typing import List from typing import Optional from typing import Union @@ -57,6 +55,7 @@ from apache_beam.runners.interactive.display.pcoll_visualization import visualize from apache_beam.runners.interactive.display.pcoll_visualization import visualize_computed_pcoll from apache_beam.runners.interactive.options import interactive_options +from apache_beam.runners.interactive.recording_manager import AsyncComputationResult from apache_beam.runners.interactive.utils import deferred_df_to_pcollection from apache_beam.runners.interactive.utils import elements_to_df from apache_beam.runners.interactive.utils import find_pcoll_name @@ -275,7 +274,7 @@ class Recordings(): """ def describe( self, - pipeline: Optional[beam.Pipeline] = None) -> Dict[str, Any]: # noqa: F821 + pipeline: Optional[beam.Pipeline] = None) -> dict[str, Any]: # noqa: F821 """Returns a description of all the recordings for the given pipeline. If no pipeline is given then this returns a dictionary of descriptions for @@ -417,10 +416,10 @@ class Clusters: # DATAPROC_IMAGE_VERSION = '2.0.XX-debian10' def __init__(self) -> None: - self.dataproc_cluster_managers: Dict[ClusterMetadata, + self.dataproc_cluster_managers: dict[ClusterMetadata, DataprocClusterManager] = {} - self.master_urls: Dict[str, ClusterMetadata] = {} - self.pipelines: Dict[beam.Pipeline, DataprocClusterManager] = {} + self.master_urls: dict[str, ClusterMetadata] = {} + self.pipelines: dict[beam.Pipeline, DataprocClusterManager] = {} self.default_cluster_metadata: Optional[ClusterMetadata] = None def create( @@ -511,7 +510,7 @@ def cleanup( def describe( self, cluster_identifier: Optional[ClusterIdentifier] = None - ) -> Union[ClusterMetadata, List[ClusterMetadata]]: + ) -> Union[ClusterMetadata, list[ClusterMetadata]]: """Describes the ClusterMetadata by a ClusterIdentifier. If no cluster_identifier is given or if the cluster_identifier is unknown, @@ -679,7 +678,7 @@ def run_pipeline(self): @progress_indicated def show( - *pcolls: Union[Dict[Any, PCollection], Iterable[PCollection], PCollection], + *pcolls: Union[dict[Any, PCollection], Iterable[PCollection], PCollection], include_window_info: bool = False, visualize_data: bool = False, n: Union[int, str] = 'inf', @@ -1012,6 +1011,88 @@ def as_pcollection(pcoll_or_df): return result_tuple +@progress_indicated +def compute( + *pcolls: Union[dict[Any, PCollection], Iterable[PCollection], PCollection], + wait_for_inputs: bool = True, + blocking: bool = False, + runner=None, + options=None, + force_compute=False, +) -> Optional[AsyncComputationResult]: + """Computes the given PCollections, potentially asynchronously. + + Args: + *pcolls: PCollections to compute. Can be a single PCollection, an iterable + of PCollections, or a dictionary with PCollections as values. + wait_for_inputs: Whether to wait until the asynchronous dependencies are + computed. 
Setting this to False allows to immediately schedule the + computation, but also potentially results in running the same pipeline + stages multiple times. + blocking: If False, the computation will run in non-blocking fashion. In + Colab/IPython environment this mode will also provide the controls for the + running pipeline. If True, the computation will block until the pipeline + is done. + runner: (optional) the runner with which to compute the results. + options: (optional) any additional pipeline options to use to compute the + results. + force_compute: (optional) if True, forces recomputation rather than using + cached PCollections. + + Returns: + An AsyncComputationResult object if blocking is False, otherwise None. + """ + flatten_pcolls = [] + for pcoll_container in pcolls: + if isinstance(pcoll_container, dict): + flatten_pcolls.extend(pcoll_container.values()) + elif isinstance(pcoll_container, (beam.pvalue.PCollection, DeferredBase)): + flatten_pcolls.append(pcoll_container) + else: + try: + flatten_pcolls.extend(iter(pcoll_container)) + except TypeError: + raise ValueError( + f'The given pcoll {pcoll_container} is not a dict, an iterable or ' + 'a PCollection.') + + pcolls_set = set() + for pcoll in flatten_pcolls: + if isinstance(pcoll, DeferredBase): + pcoll, _ = deferred_df_to_pcollection(pcoll) + watch({f'anonymous_pcollection_{id(pcoll)}': pcoll}) + assert isinstance( + pcoll, beam.pvalue.PCollection + ), f'{pcoll} is not an apache_beam.pvalue.PCollection.' + pcolls_set.add(pcoll) + + if not pcolls_set: + _LOGGER.info('No PCollections to compute.') + return None + + pcoll_pipeline = next(iter(pcolls_set)).pipeline + user_pipeline = ie.current_env().user_pipeline(pcoll_pipeline) + if not user_pipeline: + watch({f'anonymous_pipeline_{id(pcoll_pipeline)}': pcoll_pipeline}) + user_pipeline = pcoll_pipeline + + for pcoll in pcolls_set: + if pcoll.pipeline is not user_pipeline: + raise ValueError('All PCollections must belong to the same pipeline.') + + recording_manager = ie.current_env().get_recording_manager( + user_pipeline, create_if_absent=True) + + return recording_manager.compute_async( + pcolls_set, + wait_for_inputs=wait_for_inputs, + blocking=blocking, + runner=runner, + options=options, + force_compute=force_compute, + ) + + @progress_indicated def show_graph(pipeline): """Shows the current pipeline shape of a given Beam pipeline as a DAG. 
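A minimal usage sketch of the ib.compute() API added above, for reviewers. The pipeline, transform labels, and element values are illustrative assumptions, not part of this patch:

import apache_beam as beam
from apache_beam.runners.interactive import interactive_beam as ib
from apache_beam.runners.interactive.interactive_runner import InteractiveRunner

p = beam.Pipeline(InteractiveRunner())
squares = (p
           | 'Create' >> beam.Create(range(10))
           | 'Square' >> beam.Map(lambda x: x * x))
ib.watch(locals())

# Non-blocking (the default): schedules the pipeline fragment and returns
# an AsyncComputationResult handle immediately.
async_result = ib.compute(squares)
pipeline_result = async_result.result(timeout=60)  # block only when needed

# Blocking: runs the fragment to completion and returns None; force_compute
# evicts any previously cached results for the pipeline first.
ib.compute(squares, blocking=True, force_compute=True)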
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py b/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py index 37cd63842b1e..21163fc121c5 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py @@ -23,11 +23,16 @@ import sys import time import unittest +from concurrent.futures import TimeoutError from typing import NamedTuple +from unittest.mock import ANY +from unittest.mock import MagicMock +from unittest.mock import call from unittest.mock import patch import apache_beam as beam from apache_beam import dataframe as frames +from apache_beam.dataframe.frame_base import DeferredBase from apache_beam.options.pipeline_options import FlinkRunnerOptions from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.runners.interactive import interactive_beam as ib @@ -36,6 +41,7 @@ from apache_beam.runners.interactive.dataproc.dataproc_cluster_manager import DataprocClusterManager from apache_beam.runners.interactive.dataproc.types import ClusterMetadata from apache_beam.runners.interactive.options.capture_limiters import Limiter +from apache_beam.runners.interactive.recording_manager import AsyncComputationResult from apache_beam.runners.interactive.testing.mock_env import isolated_env from apache_beam.runners.runner import PipelineState from apache_beam.testing.test_stream import TestStream @@ -65,6 +71,9 @@ def _get_watched_pcollections_with_variable_names(): return watched_pcollections +@unittest.skipIf( + not ie.current_env().is_interactive_ready, + '[interactive] dependency is not installed.') @isolated_env class InteractiveBeamTest(unittest.TestCase): def setUp(self): @@ -671,5 +680,387 @@ def test_default_value_for_invalid_worker_number(self): self.assertEqual(meta.num_workers, 2) +@unittest.skipIf( + not ie.current_env().is_interactive_ready, + '[interactive] dependency is not installed.') +@isolated_env +class InteractiveBeamComputeTest(unittest.TestCase): + def setUp(self): + self.env = ie.current_env() + self.env._is_in_ipython = False # Default to non-IPython + + def test_compute_blocking(self): + p = beam.Pipeline(ir.InteractiveRunner()) + data = list(range(10)) + pcoll = p | 'Create' >> beam.Create(data) + ib.watch(locals()) + self.env.track_user_pipelines() + + result = ib.compute(pcoll, blocking=True) + self.assertIsNone(result) # Blocking returns None + self.assertTrue(pcoll in self.env.computed_pcollections) + collected = ib.collect(pcoll, raw_records=True) + self.assertEqual(collected, data) + + def test_compute_non_blocking(self): + p = beam.Pipeline(ir.InteractiveRunner()) + data = list(range(5)) + pcoll = p | 'Create' >> beam.Create(data) + ib.watch(locals()) + self.env.track_user_pipelines() + + async_result = ib.compute(pcoll, blocking=False) + self.assertIsInstance(async_result, AsyncComputationResult) + + pipeline_result = async_result.result(timeout=60) + self.assertTrue(async_result.done()) + self.assertIsNone(async_result.exception()) + self.assertEqual(pipeline_result.state, PipelineState.DONE) + self.assertTrue(pcoll in self.env.computed_pcollections) + collected = ib.collect(pcoll, raw_records=True) + self.assertEqual(collected, data) + + def test_compute_with_list_input(self): + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) + pcoll2 = p | 'Create2' >> beam.Create([4, 5, 6]) + ib.watch(locals()) + self.env.track_user_pipelines() + + 
ib.compute([pcoll1, pcoll2], blocking=True) + self.assertTrue(pcoll1 in self.env.computed_pcollections) + self.assertTrue(pcoll2 in self.env.computed_pcollections) + + def test_compute_with_dict_input(self): + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) + pcoll2 = p | 'Create2' >> beam.Create([4, 5, 6]) + ib.watch(locals()) + self.env.track_user_pipelines() + + ib.compute({'a': pcoll1, 'b': pcoll2}, blocking=True) + self.assertTrue(pcoll1 in self.env.computed_pcollections) + self.assertTrue(pcoll2 in self.env.computed_pcollections) + + def test_compute_empty_input(self): + result = ib.compute([], blocking=True) + self.assertIsNone(result) + result_async = ib.compute([], blocking=False) + self.assertIsNone(result_async) + + def test_compute_force_recompute(self): + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll = p | 'Create' >> beam.Create([1, 2, 3]) + ib.watch(locals()) + self.env.track_user_pipelines() + + ib.compute(pcoll, blocking=True) + self.assertTrue(pcoll in self.env.computed_pcollections) + + # Mock evict_computed_pcollections to check if it's called + with patch.object(self.env, 'evict_computed_pcollections') as mock_evict: + ib.compute(pcoll, blocking=True, force_compute=True) + mock_evict.assert_called_once_with(p) + self.assertTrue(pcoll in self.env.computed_pcollections) + + def test_compute_non_blocking_exception(self): + p = beam.Pipeline(ir.InteractiveRunner()) + + def raise_error(elem): + raise ValueError('Test Error') + + pcoll = p | 'Create' >> beam.Create([1]) | 'Error' >> beam.Map(raise_error) + ib.watch(locals()) + self.env.track_user_pipelines() + + async_result = ib.compute(pcoll, blocking=False) + self.assertIsInstance(async_result, AsyncComputationResult) + + with self.assertRaises(ValueError): + async_result.result(timeout=60) + + self.assertTrue(async_result.done()) + self.assertIsInstance(async_result.exception(), ValueError) + self.assertFalse(pcoll in self.env.computed_pcollections) + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True) + @patch('apache_beam.runners.interactive.recording_manager.display') + @patch('apache_beam.runners.interactive.recording_manager.clear_output') + @patch('apache_beam.runners.interactive.recording_manager.HTML') + @patch('ipywidgets.Button') + @patch('ipywidgets.FloatProgress') + @patch('ipywidgets.Output') + @patch('ipywidgets.HBox') + @patch('ipywidgets.VBox') + def test_compute_non_blocking_ipython_widgets( + self, + mock_vbox, + mock_hbox, + mock_output, + mock_progress, + mock_button, + mock_html, + mock_clear_output, + mock_display, + ): + self.env._is_in_ipython = True + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll = p | 'Create' >> beam.Create(range(3)) + ib.watch(locals()) + self.env.track_user_pipelines() + + mock_controls = mock_vbox.return_value + mock_html_instance = mock_html.return_value + + async_result = ib.compute(pcoll, blocking=False) + self.assertIsNotNone(async_result) + mock_button.assert_called_once_with(description='Cancel') + mock_progress.assert_called_once() + mock_output.assert_called_once() + mock_hbox.assert_called_once() + mock_vbox.assert_called_once() + mock_html.assert_called_once_with('
<div>Initializing...</div>
') + + self.assertEqual(mock_display.call_count, 2) + mock_display.assert_has_calls([ + call(mock_controls, display_id=async_result._display_id), + call(mock_html_instance) + ]) + + mock_clear_output.assert_called_once() + async_result.result(timeout=60) # Let it finish + + def test_compute_dependency_wait_true(self): + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) + pcoll2 = pcoll1 | 'Map' >> beam.Map(lambda x: x * 2) + ib.watch(locals()) + self.env.track_user_pipelines() + + rm = self.env.get_recording_manager(p) + + # Start pcoll1 computation + async_res1 = ib.compute(pcoll1, blocking=False) + self.assertTrue(self.env.is_pcollection_computing(pcoll1)) + + # Spy on _wait_for_dependencies + with patch.object(rm, + '_wait_for_dependencies', + wraps=rm._wait_for_dependencies) as spy_wait: + async_res2 = ib.compute(pcoll2, blocking=False, wait_for_inputs=True) + + # Check that wait_for_dependencies was called for pcoll2 + spy_wait.assert_called_with({pcoll2}, async_res2) + + # Let pcoll1 finish + async_res1.result(timeout=60) + self.assertTrue(pcoll1 in self.env.computed_pcollections) + self.assertFalse(self.env.is_pcollection_computing(pcoll1)) + + # pcoll2 should now run and complete + async_res2.result(timeout=60) + self.assertTrue(pcoll2 in self.env.computed_pcollections) + + @patch.object(ie.InteractiveEnvironment, 'is_pcollection_computing') + def test_compute_dependency_wait_false(self, mock_is_computing): + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) + pcoll2 = pcoll1 | 'Map' >> beam.Map(lambda x: x * 2) + ib.watch(locals()) + self.env.track_user_pipelines() + + rm = self.env.get_recording_manager(p) + + # Pretend pcoll1 is computing + mock_is_computing.side_effect = lambda pcoll: pcoll is pcoll1 + + with patch.object(rm, + '_execute_pipeline_fragment', + wraps=rm._execute_pipeline_fragment) as spy_execute: + async_res2 = ib.compute(pcoll2, blocking=False, wait_for_inputs=False) + async_res2.result(timeout=60) + + # Assert that execute was called for pcoll2 without waiting + spy_execute.assert_called_with({pcoll2}, async_res2, ANY, ANY) + self.assertTrue(pcoll2 in self.env.computed_pcollections) + + def test_async_computation_result_cancel(self): + p = beam.Pipeline(ir.InteractiveRunner()) + # A stream that never finishes to test cancellation + pcoll = p | beam.Create([1]) | beam.Map(lambda x: time.sleep(100)) + ib.watch(locals()) + self.env.track_user_pipelines() + + async_result = ib.compute(pcoll, blocking=False) + self.assertIsInstance(async_result, AsyncComputationResult) + + # Give it a moment to start + time.sleep(0.1) + + # Mock the pipeline result's cancel method + mock_pipeline_result = MagicMock() + mock_pipeline_result.state = PipelineState.RUNNING + async_result.set_pipeline_result(mock_pipeline_result) + + self.assertTrue(async_result.cancel()) + mock_pipeline_result.cancel.assert_called_once() + + # The future should be cancelled eventually by the runner + # This part is hard to test without deeper runner integration + with self.assertRaises(TimeoutError): + async_result.result(timeout=1) # It should not complete successfully + + @patch( + 'apache_beam.runners.interactive.recording_manager.RecordingManager.' 
+ '_execute_pipeline_fragment') + def test_compute_multiple_async(self, mock_execute_fragment): + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) + pcoll2 = p | 'Create2' >> beam.Create([4, 5, 6]) + pcoll3 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2) + ib.watch(locals()) + self.env.track_user_pipelines() + + mock_pipeline_result = MagicMock() + mock_pipeline_result.state = PipelineState.DONE + mock_execute_fragment.return_value = mock_pipeline_result + + res1 = ib.compute(pcoll1, blocking=False) + res2 = ib.compute(pcoll2, blocking=False) + res3 = ib.compute(pcoll3, blocking=False) # Depends on pcoll1 + + self.assertIsNotNone(res1) + self.assertIsNotNone(res2) + self.assertIsNotNone(res3) + + res1.result(timeout=60) + res2.result(timeout=60) + res3.result(timeout=60) + + time.sleep(0.1) + + self.assertTrue( + pcoll1 in self.env.computed_pcollections, "pcoll1 not marked computed") + self.assertTrue( + pcoll2 in self.env.computed_pcollections, "pcoll2 not marked computed") + self.assertTrue( + pcoll3 in self.env.computed_pcollections, "pcoll3 not marked computed") + + self.assertEqual(mock_execute_fragment.call_count, 3) + + @patch( + 'apache_beam.runners.interactive.interactive_beam.' + 'deferred_df_to_pcollection') + def test_compute_input_flattening(self, mock_deferred_to_pcoll): + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll1 = p | 'C1' >> beam.Create([1]) + pcoll2 = p | 'C2' >> beam.Create([2]) + pcoll3 = p | 'C3' >> beam.Create([3]) + pcoll4 = p | 'C4' >> beam.Create([4]) + + class MockDeferred(DeferredBase): + def __init__(self, pcoll): + mock_expr = MagicMock() + super().__init__(mock_expr) + self._pcoll = pcoll + + def _get_underlying_pcollection(self): + return self._pcoll + + deferred_pcoll = MockDeferred(pcoll4) + + mock_deferred_to_pcoll.return_value = (pcoll4, p) + + ib.watch(locals()) + self.env.track_user_pipelines() + + with patch.object(self.env, 'get_recording_manager') as mock_get_rm: + mock_rm = MagicMock() + mock_get_rm.return_value = mock_rm + ib.compute(pcoll1, [pcoll2], {'a': pcoll3}, deferred_pcoll) + + expected_pcolls = {pcoll1, pcoll2, pcoll3, pcoll4} + mock_rm.compute_async.assert_called_once_with( + expected_pcolls, + wait_for_inputs=True, + blocking=False, + runner=None, + options=None, + force_compute=False) + + def test_compute_invalid_input_type(self): + with self.assertRaisesRegex(ValueError, + "not a dict, an iterable or a PCollection"): + ib.compute(123) + + def test_compute_mixed_pipelines(self): + p1 = beam.Pipeline(ir.InteractiveRunner()) + pcoll1 = p1 | 'C1' >> beam.Create([1]) + p2 = beam.Pipeline(ir.InteractiveRunner()) + pcoll2 = p2 | 'C2' >> beam.Create([2]) + ib.watch(locals()) + self.env.track_user_pipelines() + + with self.assertRaisesRegex( + ValueError, "All PCollections must belong to the same pipeline"): + ib.compute(pcoll1, pcoll2) + + @patch( + 'apache_beam.runners.interactive.interactive_beam.' 
+ 'deferred_df_to_pcollection') + @patch.object(ib, 'watch') + def test_compute_with_deferred_base(self, mock_watch, mock_deferred_to_pcoll): + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll = p | 'C1' >> beam.Create([1]) + + class MockDeferred(DeferredBase): + def __init__(self, pcoll): + # Provide a dummy expression to satisfy DeferredBase.__init__ + mock_expr = MagicMock() + super().__init__(mock_expr) + self._pcoll = pcoll + + def _get_underlying_pcollection(self): + return self._pcoll + + deferred = MockDeferred(pcoll) + + mock_deferred_to_pcoll.return_value = (pcoll, p) + + with patch.object(self.env, 'get_recording_manager') as mock_get_rm: + mock_rm = MagicMock() + mock_get_rm.return_value = mock_rm + ib.compute(deferred) + + mock_deferred_to_pcoll.assert_called_once_with(deferred) + self.assertEqual(mock_watch.call_count, 2) + mock_watch.assert_has_calls([ + call({f'anonymous_pcollection_{id(pcoll)}': pcoll}), + call({f'anonymous_pipeline_{id(p)}': p}) + ], + any_order=False) + mock_rm.compute_async.assert_called_once_with({pcoll}, + wait_for_inputs=True, + blocking=False, + runner=None, + options=None, + force_compute=False) + + def test_compute_new_pipeline(self): + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll = p | 'Create' >> beam.Create([1]) + # NOT calling ib.watch() or track_user_pipelines() + + with patch.object(self.env, 'get_recording_manager') as mock_get_rm, \ + patch.object(ib, 'watch') as mock_watch: + mock_rm = MagicMock() + mock_get_rm.return_value = mock_rm + ib.compute(pcoll) + + mock_watch.assert_called_with({f'anonymous_pipeline_{id(p)}': p}) + mock_get_rm.assert_called_once_with(p, create_if_absent=True) + mock_rm.compute_async.assert_called_once() + + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/runners/interactive/interactive_environment.py b/sdks/python/apache_beam/runners/interactive/interactive_environment.py index e9ff86c6276f..2a8fc23088a6 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_environment.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_environment.py @@ -175,6 +175,9 @@ def __init__(self): # Tracks the computation completeness of PCollections. PCollections tracked # here don't need to be re-computed when data introspection is needed. self._computed_pcolls = set() + + self._computing_pcolls = set() + # Always watch __main__ module. self.watch('__main__') # Check if [interactive] dependencies are installed. 
@@ -720,3 +723,19 @@ def _get_gcs_cache_dir(self, pipeline, cache_dir): bucket_name = cache_dir_path.parts[1] assert_bucket_exists(bucket_name) return 'gs://{}/{}'.format('/'.join(cache_dir_path.parts[1:]), id(pipeline)) + + @property + def computing_pcollections(self): + return self._computing_pcolls + + def mark_pcollection_computing(self, pcolls): + """Marks the given pcolls as currently being computed.""" + self._computing_pcolls.update(pcolls) + + def unmark_pcollection_computing(self, pcolls): + """Removes the given pcolls from the computing set.""" + self._computing_pcolls.difference_update(pcolls) + + def is_pcollection_computing(self, pcoll): + """Checks if the given pcollection is currently being computed.""" + return pcoll in self._computing_pcolls diff --git a/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py b/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py index 4d5f3f36ce67..eb3b4b514824 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py @@ -34,6 +34,9 @@ _module_name = 'apache_beam.runners.interactive.interactive_environment_test' +@unittest.skipIf( + not ie.current_env().is_interactive_ready, + '[interactive] dependency is not installed.') @isolated_env class InteractiveEnvironmentTest(unittest.TestCase): def setUp(self): @@ -341,6 +344,44 @@ def test_get_gcs_cache_dir_invalid_path(self): with self.assertRaises(ValueError): env._get_gcs_cache_dir(p, cache_root) + def test_pcollection_computing_state(self): + env = ie.InteractiveEnvironment() + p = beam.Pipeline() + pcoll1 = p | 'Create1' >> beam.Create([1]) + pcoll2 = p | 'Create2' >> beam.Create([2]) + + self.assertFalse(env.is_pcollection_computing(pcoll1)) + self.assertFalse(env.is_pcollection_computing(pcoll2)) + self.assertEqual(env.computing_pcollections, set()) + + env.mark_pcollection_computing({pcoll1}) + self.assertTrue(env.is_pcollection_computing(pcoll1)) + self.assertFalse(env.is_pcollection_computing(pcoll2)) + self.assertEqual(env.computing_pcollections, {pcoll1}) + + env.mark_pcollection_computing({pcoll2}) + self.assertTrue(env.is_pcollection_computing(pcoll1)) + self.assertTrue(env.is_pcollection_computing(pcoll2)) + self.assertEqual(env.computing_pcollections, {pcoll1, pcoll2}) + + env.unmark_pcollection_computing({pcoll1}) + self.assertFalse(env.is_pcollection_computing(pcoll1)) + self.assertTrue(env.is_pcollection_computing(pcoll2)) + self.assertEqual(env.computing_pcollections, {pcoll2}) + + env.unmark_pcollection_computing({pcoll2}) + self.assertFalse(env.is_pcollection_computing(pcoll1)) + self.assertFalse(env.is_pcollection_computing(pcoll2)) + self.assertEqual(env.computing_pcollections, set()) + + def test_mark_unmark_empty(self): + env = ie.InteractiveEnvironment() + # Ensure no errors with empty sets + env.mark_pcollection_computing(set()) + self.assertEqual(env.computing_pcollections, set()) + env.unmark_pcollection_computing(set()) + self.assertEqual(env.computing_pcollections, set()) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/runners/interactive/recording_manager.py b/sdks/python/apache_beam/runners/interactive/recording_manager.py index f72ec2fe8e17..c19b60b64fd2 100644 --- a/sdks/python/apache_beam/runners/interactive/recording_manager.py +++ b/sdks/python/apache_beam/runners/interactive/recording_manager.py @@ -15,13 +15,17 @@ # limitations under the License. 
# +import collections import logging +import os import threading import time +import uuid import warnings +from concurrent.futures import Future +from concurrent.futures import ThreadPoolExecutor from typing import Any -from typing import Dict -from typing import List +from typing import Optional from typing import Union import pandas as pd @@ -37,11 +41,176 @@ from apache_beam.runners.interactive import pipeline_fragment as pf from apache_beam.runners.interactive import utils from apache_beam.runners.interactive.caching.cacheable import CacheKey +from apache_beam.runners.interactive.display.pipeline_graph import PipelineGraph from apache_beam.runners.interactive.options import capture_control from apache_beam.runners.runner import PipelineState _LOGGER = logging.getLogger(__name__) +try: + import ipywidgets as widgets + from IPython.display import HTML + from IPython.display import clear_output + from IPython.display import display + + IS_IPYTHON = True +except ImportError: + IS_IPYTHON = False + + +class AsyncComputationResult: + """Represents the result of an asynchronous computation.""" + def __init__( + self, + future: Future, + pcolls: set[beam.pvalue.PCollection], + user_pipeline: beam.Pipeline, + recording_manager: 'RecordingManager', + ): + self._future = future + self._pcolls = pcolls + self._user_pipeline = user_pipeline + self._env = ie.current_env() + self._recording_manager = recording_manager + self._pipeline_result: Optional[beam.runners.runner.PipelineResult] = None + self._display_id = str(uuid.uuid4()) + self._output_widget = widgets.Output() if IS_IPYTHON else None + self._cancel_button = ( + widgets.Button(description='Cancel') if IS_IPYTHON else None) + self._progress_bar = ( + widgets.FloatProgress( + value=0.0, + min=0.0, + max=1.0, + description='Running:', + bar_style='info', + ) if IS_IPYTHON else None) + self._cancel_requested = False + + if IS_IPYTHON: + self._cancel_button.on_click(self._cancel_clicked) + controls = widgets.VBox([ + widgets.HBox([self._cancel_button, self._progress_bar]), + self._output_widget, + ]) + display(controls, display_id=self._display_id) + self.update_display('Initializing...') + + self._future.add_done_callback(self._on_done) + + def _cancel_clicked(self, b): + self._cancel_requested = True + self._cancel_button.disabled = True + self.update_display('Cancel requested...') + self.cancel() + + def update_display(self, msg: str, progress: Optional[float] = None): + if not IS_IPYTHON: + print(f'AsyncCompute: {msg}') + return + + with self._output_widget: + clear_output(wait=True) + display(HTML(f'
<div>{msg}</div>
')) + + if progress is not None: + self._progress_bar.value = progress + + if self.done(): + self._cancel_button.disabled = True + if self.exception(): + self._progress_bar.bar_style = 'danger' + self._progress_bar.description = 'Failed' + elif self._future.cancelled(): + self._progress_bar.bar_style = 'warning' + self._progress_bar.description = 'Cancelled' + else: + self._progress_bar.bar_style = 'success' + self._progress_bar.description = 'Done' + elif self._cancel_requested: + self._cancel_button.disabled = True + self._progress_bar.description = 'Cancelling...' + else: + self._cancel_button.disabled = False + + def set_pipeline_result( + self, pipeline_result: beam.runners.runner.PipelineResult): + self._pipeline_result = pipeline_result + if self._cancel_requested: + self.cancel() + + def result(self, timeout=None): + return self._future.result(timeout=timeout) + + def done(self): + return self._future.done() + + def exception(self, timeout=None): + try: + return self._future.exception(timeout=timeout) + except TimeoutError: + return None + + def _on_done(self, future: Future): + self._env.unmark_pcollection_computing(self._pcolls) + self._recording_manager._async_computations.pop(self._display_id, None) + + if future.cancelled(): + self.update_display('Computation Cancelled.', 1.0) + return + + exc = future.exception() + if exc: + self.update_display(f'Error: {exc}', 1.0) + _LOGGER.error('Asynchronous computation failed: %s', exc, exc_info=exc) + else: + self.update_display('Computation Finished Successfully.', 1.0) + res = future.result() + if res and res.state == PipelineState.DONE: + self._env.mark_pcollection_computed(self._pcolls) + else: + _LOGGER.warning( + 'Async computation finished but state is not DONE: %s', + res.state if res else 'Unknown') + + def cancel(self): + if self._future.done(): + self.update_display('Cannot cancel: Computation already finished.') + return False + + self._cancel_requested = True + self._cancel_button.disabled = True + self.update_display('Attempting to cancel...') + + if self._pipeline_result: + try: + # Check pipeline state before cancelling + current_state = self._pipeline_result.state + if PipelineState.is_terminal(current_state): + self.update_display( + 'Cannot cancel: Pipeline already in terminal state' + f' {current_state}.') + return False + + self._pipeline_result.cancel() + self.update_display('Cancel signal sent to pipeline.') + # The future will be cancelled by the runner if successful + return True + except Exception as e: + self.update_display('Error sending cancel signal: %s', e) + _LOGGER.warning('Error during pipeline cancel(): %s', e, exc_info=e) + # Still try to cancel the future as a fallback + return self._future.cancel() + else: + self.update_display('Pipeline not yet fully started, cancelling future.') + return self._future.cancel() + + def __repr__(self): + return ( + f'") + class ElementStream: """A stream of elements from a given PCollection.""" @@ -151,7 +320,7 @@ class Recording: def __init__( self, user_pipeline: beam.Pipeline, - pcolls: List[beam.pvalue.PCollection], # noqa: F821 + pcolls: list[beam.pvalue.PCollection], # noqa: F821 result: 'beam.runner.PipelineResult', max_n: int, max_duration_secs: float, @@ -244,7 +413,7 @@ def wait_until_finish(self) -> None: self._mark_computed.join() return self._result.state - def describe(self) -> Dict[str, int]: + def describe(self) -> dict[str, int]: """Returns a dictionary describing the cache and recording.""" cache_manager = 
ie.current_env().get_cache_manager(self._user_pipeline) @@ -259,15 +428,97 @@ def __init__( self, user_pipeline: beam.Pipeline, pipeline_var: str = None, - test_limiters: List['Limiter'] = None) -> None: # noqa: F821 + test_limiters: list['Limiter'] = None) -> None: # noqa: F821 self.user_pipeline: beam.Pipeline = user_pipeline self.pipeline_var: str = pipeline_var if pipeline_var else '' self._recordings: set[Recording] = set() self._start_time_sec: float = 0 self._test_limiters = test_limiters if test_limiters else [] + self._executor = ThreadPoolExecutor(max_workers=os.cpu_count()) + self._env = ie.current_env() + self._async_computations: dict[str, AsyncComputationResult] = {} + self._pipeline_graph = None + + def _execute_pipeline_fragment( + self, + pcolls_to_compute: set[beam.pvalue.PCollection], + async_result: Optional['AsyncComputationResult'] = None, + runner: runner.PipelineRunner = None, + options: pipeline_options.PipelineOptions = None, + ) -> beam.runners.runner.PipelineResult: + """Synchronously executes a pipeline fragment for the given PCollections.""" + merged_options = pipeline_options.PipelineOptions(**{ + **self.user_pipeline.options.get_all_options( + drop_default=True, retain_unknown_options=True + ), + **( + options.get_all_options( + drop_default=True, retain_unknown_options=True + ) + if options + else {} + ), + }) + + fragment = pf.PipelineFragment( + list(pcolls_to_compute), merged_options, runner=runner) + + if async_result: + async_result.update_display('Building pipeline fragment...', 0.1) + + pipeline_to_run = fragment.deduce_fragment() + if async_result: + async_result.update_display('"Pipeline running, awaiting finish..."', 0.2) + + pipeline_result = pipeline_to_run.run() + if async_result: + async_result.set_pipeline_result(pipeline_result) + + pipeline_result.wait_until_finish() + return pipeline_result + + def _run_async_computation( + self, + pcolls_to_compute: set[beam.pvalue.PCollection], + async_result: 'AsyncComputationResult', + wait_for_inputs: bool, + runner: runner.PipelineRunner = None, + options: pipeline_options.PipelineOptions = None, + ): + """The function to be run in the thread pool for async computation.""" + try: + if wait_for_inputs: + if not self._wait_for_dependencies(pcolls_to_compute, async_result): + raise RuntimeError('Dependency computation failed or was cancelled.') + + _LOGGER.info( + 'Starting asynchronous computation for %d PCollections.', + len(pcolls_to_compute)) + + pipeline_result = self._execute_pipeline_fragment( + pcolls_to_compute, async_result, runner, options) + + # if pipeline_result.state == PipelineState.DONE: + # self._env.mark_pcollection_computed(pcolls_to_compute) + # _LOGGER.info( + # 'Asynchronous computation finished successfully for' + # f' {len(pcolls_to_compute)} PCollections.' + # ) + # else: + # _LOGGER.error( + # 'Asynchronous computation failed for' + # f' {len(pcolls_to_compute)} PCollections. State:' + # f' {pipeline_result.state}' + # ) + return pipeline_result + except Exception as e: + _LOGGER.exception('Exception during asynchronous computation: %s', e) + raise + # finally: + # self._env.unmark_pcollection_computing(pcolls_to_compute) - def _watch(self, pcolls: List[beam.pvalue.PCollection]) -> None: + def _watch(self, pcolls: list[beam.pvalue.PCollection]) -> None: """Watch any pcollections not being watched. This allows for the underlying caching layer to identify the PCollection as @@ -337,7 +588,7 @@ def cancel(self: None) -> None: # evict the BCJ after they complete. 
ie.current_env().evict_background_caching_job(self.user_pipeline) - def describe(self) -> Dict[str, int]: + def describe(self) -> dict[str, int]: """Returns a dictionary describing the cache and recording.""" cache_manager = ie.current_env().get_cache_manager(self.user_pipeline) @@ -386,9 +637,213 @@ def record_pipeline(self) -> bool: return True return False + def compute_async( + self, + pcolls: set[beam.pvalue.PCollection], + wait_for_inputs: bool = True, + blocking: bool = False, + runner: runner.PipelineRunner = None, + options: pipeline_options.PipelineOptions = None, + force_compute: bool = False, + ) -> Optional[AsyncComputationResult]: + """Computes the given PCollections, potentially asynchronously.""" + + if force_compute: + self._env.evict_computed_pcollections(self.user_pipeline) + + computed_pcolls = { + pcoll + for pcoll in pcolls if pcoll in self._env.computed_pcollections + } + computing_pcolls = { + pcoll + for pcoll in pcolls if self._env.is_pcollection_computing(pcoll) + } + pcolls_to_compute = pcolls - computed_pcolls - computing_pcolls + + if not pcolls_to_compute: + _LOGGER.info( + 'All requested PCollections are already computed or are being' + ' computed.') + return None + + self._watch(list(pcolls_to_compute)) + self.record_pipeline() + + if blocking: + self._env.mark_pcollection_computing(pcolls_to_compute) + try: + if wait_for_inputs: + if not self._wait_for_dependencies(pcolls_to_compute): + raise RuntimeError( + 'Dependency computation failed or was cancelled.') + pipeline_result = self._execute_pipeline_fragment( + pcolls_to_compute, None, runner, options) + if pipeline_result.state == PipelineState.DONE: + self._env.mark_pcollection_computed(pcolls_to_compute) + else: + _LOGGER.error( + 'Blocking computation failed. State: %s', pipeline_result.state) + raise RuntimeError( + 'Blocking computation failed. State: %s', pipeline_result.state) + finally: + self._env.unmark_pcollection_computing(pcolls_to_compute) + return None + + else: # Asynchronous + future = Future() + async_result = AsyncComputationResult( + future, pcolls_to_compute, self.user_pipeline, self) + self._async_computations[async_result._display_id] = async_result + self._env.mark_pcollection_computing(pcolls_to_compute) + + def task(): + try: + result = self._run_async_computation( + pcolls_to_compute, async_result, wait_for_inputs, runner, options) + future.set_result(result) + except Exception as e: + if not future.cancelled(): + future.set_exception(e) + + self._executor.submit(task) + return async_result + + def _get_pipeline_graph(self): + """Lazily initializes and returns the PipelineGraph.""" + if self._pipeline_graph is None: + try: + # Try to create the graph. + self._pipeline_graph = PipelineGraph(self.user_pipeline) + except (ImportError, NameError, AttributeError): + # If pydot is missing, PipelineGraph() might crash. + _LOGGER.warning( + "Could not create PipelineGraph (pydot missing?). " \ + "Async features disabled." 
+ ) + self._pipeline_graph = None + return self._pipeline_graph + + def _get_pcoll_id_map(self): + """Creates a map from PCollection object to its ID in the proto.""" + pcoll_to_id = {} + graph = self._get_pipeline_graph() + if graph and graph._pipeline_instrument: + pcoll_to_id = graph._pipeline_instrument._pcoll_to_pcoll_id + return {v: k for k, v in pcoll_to_id.items()} + + def _get_all_dependencies( + self, + pcolls: set[beam.pvalue.PCollection]) -> set[beam.pvalue.PCollection]: + """Gets all upstream PCollection dependencies + for the given set of PCollections.""" + graph = self._get_pipeline_graph() + if not graph: + return set() + + analyzer = graph._pipeline_instrument + if not analyzer: + return set() + + pcoll_to_id = analyzer._pcoll_to_pcoll_id + + target_pcoll_ids = { + pcoll_to_id.get(str(pcoll)) + for pcoll in pcolls if str(pcoll) in pcoll_to_id + } + + if not target_pcoll_ids: + return set() + + # Build a map from PCollection ID to the actual PCollection object + id_to_pcoll_obj = {} + for _, inspectable in self._env.inspector.inspectables.items(): + value = inspectable['value'] + if isinstance(value, beam.pvalue.PCollection): + pcoll_id = pcoll_to_id.get(str(value)) + if pcoll_id: + id_to_pcoll_obj[pcoll_id] = value + + dependencies = set() + queue = collections.deque(target_pcoll_ids) + visited_pcoll_ids = set(target_pcoll_ids) + + producers = graph._producers + transforms = graph._pipeline_proto.components.transforms + + while queue: + pcoll_id = queue.popleft() + if pcoll_id not in producers: + continue + + producer_id = producers[pcoll_id] + transform_proto = transforms.get(producer_id) + if not transform_proto: + continue + + for input_pcoll_id in transform_proto.inputs.values(): + if input_pcoll_id not in visited_pcoll_ids: + visited_pcoll_ids.add(input_pcoll_id) + queue.append(input_pcoll_id) + + dep_obj = id_to_pcoll_obj.get(input_pcoll_id) + if dep_obj and dep_obj not in pcolls: + dependencies.add(dep_obj) + + return dependencies + + def _wait_for_dependencies( + self, + pcolls: set[beam.pvalue.PCollection], + async_result: Optional[AsyncComputationResult] = None, + ) -> bool: + """Waits for any dependencies of the given + PCollections that are currently being computed.""" + dependencies = self._get_all_dependencies(pcolls) + computing_deps: dict[beam.pvalue.PCollection, AsyncComputationResult] = {} + + for dep in dependencies: + if self._env.is_pcollection_computing(dep): + for comp in self._async_computations.values(): + if dep in comp._pcolls: + computing_deps[dep] = comp + break + + if not computing_deps: + return True + + if async_result: + async_result.update_display( + 'Waiting for %d dependencies to finish...', len(computing_deps)) + _LOGGER.info( + 'Waiting for %d dependencies: %s', + len(computing_deps), + computing_deps.keys()) + + futures_to_wait = list( + set(comp._future for comp in computing_deps.values())) + + try: + for i, future in enumerate(futures_to_wait): + if async_result: + async_result.update_display( + f'Waiting for dependency {i + 1}/{len(futures_to_wait)}...', + progress=0.05 + 0.05 * (i / len(futures_to_wait)), + ) + future.result() + if async_result: + async_result.update_display('Dependencies finished.', progress=0.1) + _LOGGER.info('Dependencies finished successfully.') + return True + except Exception as e: + if async_result: + async_result.update_display(f'Dependency failed: {e}') + _LOGGER.error('Dependency computation failed: %s', e, exc_info=e) + return False + def record( self, - pcolls: List[beam.pvalue.PCollection], + pcolls: 
list[beam.pvalue.PCollection], *, max_n: int, max_duration: Union[int, str], @@ -431,8 +886,11 @@ def record( # Start a pipeline fragment to start computing the PCollections. uncomputed_pcolls = set(pcolls).difference(computed_pcolls) if uncomputed_pcolls: - # Clear the cache of the given uncomputed PCollections because they are - # incomplete. + if not self._wait_for_dependencies(uncomputed_pcolls): + raise RuntimeError( + 'Cannot record because a dependency failed to compute' + ' asynchronously.') + self._clear() merged_options = pipeline_options.PipelineOptions( diff --git a/sdks/python/apache_beam/runners/interactive/recording_manager_test.py b/sdks/python/apache_beam/runners/interactive/recording_manager_test.py index 698a464ae739..d2038719f67a 100644 --- a/sdks/python/apache_beam/runners/interactive/recording_manager_test.py +++ b/sdks/python/apache_beam/runners/interactive/recording_manager_test.py @@ -17,7 +17,9 @@ import time import unittest +from concurrent.futures import Future from unittest.mock import MagicMock +from unittest.mock import call from unittest.mock import patch import apache_beam as beam @@ -30,6 +32,8 @@ from apache_beam.runners.interactive.caching.cacheable import CacheKey from apache_beam.runners.interactive.interactive_runner import InteractiveRunner from apache_beam.runners.interactive.options.capture_limiters import Limiter +from apache_beam.runners.interactive.recording_manager import _LOGGER +from apache_beam.runners.interactive.recording_manager import AsyncComputationResult from apache_beam.runners.interactive.recording_manager import ElementStream from apache_beam.runners.interactive.recording_manager import Recording from apache_beam.runners.interactive.recording_manager import RecordingManager @@ -43,6 +47,386 @@ from apache_beam.utils.windowed_value import WindowedValue +@unittest.skipIf( + not ie.current_env().is_interactive_ready, + '[interactive] dependency is not installed.') +class AsyncComputationResultTest(unittest.TestCase): + def setUp(self): + self.mock_future = MagicMock(spec=Future) + self.pcolls = {MagicMock(spec=beam.pvalue.PCollection)} + self.user_pipeline = MagicMock(spec=beam.Pipeline) + self.recording_manager = MagicMock(spec=RecordingManager) + self.recording_manager._async_computations = {} + self.env = ie.InteractiveEnvironment() + patch.object(ie, 'current_env', return_value=self.env).start() + + self.mock_button = patch('ipywidgets.Button', autospec=True).start() + self.mock_float_progress = patch( + 'ipywidgets.FloatProgress', autospec=True).start() + self.mock_output = patch('ipywidgets.Output', autospec=True).start() + self.mock_hbox = patch('ipywidgets.HBox', autospec=True).start() + self.mock_vbox = patch('ipywidgets.VBox', autospec=True).start() + self.mock_display = patch( + 'apache_beam.runners.interactive.recording_manager.display', + autospec=True).start() + self.mock_clear_output = patch( + 'apache_beam.runners.interactive.recording_manager.clear_output', + autospec=True).start() + self.mock_html = patch( + 'apache_beam.runners.interactive.recording_manager.HTML', + autospec=True).start() + + self.addCleanup(patch.stopall) + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', False) + def test_async_result_init_non_ipython(self): + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + self.assertIsNotNone(async_res) + self.mock_future.add_done_callback.assert_called_once() + self.assertIsNone(async_res._cancel_button) + + def 
test_on_done_success(self): + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + mock_pipeline_result = MagicMock() + mock_pipeline_result.state = PipelineState.DONE + self.mock_future.result.return_value = mock_pipeline_result + self.mock_future.exception.return_value = None + self.mock_future.cancelled.return_value = False + async_res._display_id = 'test_id' + self.recording_manager._async_computations['test_id'] = async_res + + with patch.object( + self.env, 'unmark_pcollection_computing' + ) as mock_unmark, patch.object( + self.env, 'mark_pcollection_computed' + ) as mock_mark_computed, patch.object( + async_res, 'update_display' + ) as mock_update: + async_res._on_done(self.mock_future) + mock_unmark.assert_called_once_with(self.pcolls) + mock_mark_computed.assert_called_once_with(self.pcolls) + self.assertNotIn('test_id', self.recording_manager._async_computations) + mock_update.assert_called_with('Computation Finished Successfully.', 1.0) + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', False) + def test_on_done_failure(self): + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + test_exception = ValueError('Test') + self.mock_future.exception.return_value = test_exception + self.mock_future.cancelled.return_value = False + + with patch.object( + self.env, 'unmark_pcollection_computing' + ) as mock_unmark, patch.object( + self.env, 'mark_pcollection_computed' + ) as mock_mark_computed: + async_res._on_done(self.mock_future) + mock_unmark.assert_called_once_with(self.pcolls) + mock_mark_computed.assert_not_called() + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', False) + def test_on_done_cancelled(self): + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + self.mock_future.cancelled.return_value = True + + with patch.object(self.env, 'unmark_pcollection_computing') as mock_unmark: + async_res._on_done(self.mock_future) + mock_unmark.assert_called_once_with(self.pcolls) + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True) + def test_cancel(self): + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + mock_pipeline_result = MagicMock() + mock_pipeline_result.state = PipelineState.RUNNING + async_res.set_pipeline_result(mock_pipeline_result) + self.mock_future.done.return_value = False + + self.assertTrue(async_res.cancel()) + mock_pipeline_result.cancel.assert_called_once() + self.assertTrue(async_res._cancel_requested) + self.assertTrue(async_res._cancel_button.disabled) + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', False) + def test_cancel_already_done(self): + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + self.mock_future.done.return_value = True + self.assertFalse(async_res.cancel()) + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True) + @patch('apache_beam.runners.interactive.recording_manager.display') + @patch('ipywidgets.Button') + @patch('ipywidgets.FloatProgress') + @patch('ipywidgets.Output') + @patch('ipywidgets.HBox') + @patch('ipywidgets.VBox') + def test_async_result_init_ipython( + self, + mock_vbox, + mock_hbox, + mock_output, + mock_progress, + mock_button, + 
mock_display, + ): + mock_btn_instance = mock_button.return_value + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + self.assertIsNotNone(async_res) + mock_button.assert_called_once_with(description='Cancel') + mock_progress.assert_called_once() + mock_output.assert_called_once() + mock_hbox.assert_called_once() + mock_vbox.assert_called_once() + mock_display.assert_called() + mock_btn_instance.on_click.assert_called_once_with( + async_res._cancel_clicked) + self.mock_future.add_done_callback.assert_called_once() + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True) + @patch( + 'apache_beam.runners.interactive.recording_manager.display', MagicMock()) + @patch('ipywidgets.Button', MagicMock()) + @patch('ipywidgets.FloatProgress', MagicMock()) + @patch('ipywidgets.Output', MagicMock()) + def test_cancel_clicked(self): + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + with patch.object(async_res, 'cancel') as mock_cancel, patch.object( + async_res, 'update_display' + ) as mock_update: + async_res._cancel_clicked(None) + self.assertTrue(async_res._cancel_requested) + self.assertTrue(async_res._cancel_button.disabled) + mock_update.assert_called_once_with('Cancel requested...') + mock_cancel.assert_called_once() + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', False) + def test_update_display_non_ipython(self): + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + with patch('builtins.print') as mock_print: + async_res.update_display('Test Message') + mock_print.assert_called_once_with('AsyncCompute: Test Message') + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True) + def test_update_display_ipython(self): + mock_prog_instance = self.mock_float_progress.return_value + mock_btn_instance = self.mock_button.return_value + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + + update_call_count = 1 + self.assertEqual(self.mock_clear_output.call_count, update_call_count) + + # State: Running + self.mock_future.done.return_value = False + async_res._cancel_requested = False + async_res.update_display('Running Test', 0.5) + update_call_count += 1 + self.mock_display.assert_called() + self.assertEqual(self.mock_clear_output.call_count, update_call_count) + self.assertEqual(mock_prog_instance.value, 0.5) + self.assertFalse(mock_btn_instance.disabled) + self.mock_html.assert_called_with('
<div>Running Test</div>
') + + # State: Done Success + self.mock_future.done.return_value = True + self.mock_future.exception.return_value = None + self.mock_future.cancelled.return_value = False + async_res.update_display('Done') + update_call_count += 1 + self.assertEqual(self.mock_clear_output.call_count, update_call_count) + self.assertTrue(mock_btn_instance.disabled) + self.assertEqual(mock_prog_instance.bar_style, 'success') + self.assertEqual(mock_prog_instance.description, 'Done') + + # State: Done Failed + self.mock_future.exception.return_value = Exception() + async_res.update_display('Failed') + update_call_count += 1 + self.assertEqual(self.mock_clear_output.call_count, update_call_count) + self.assertEqual(mock_prog_instance.bar_style, 'danger') + self.assertEqual(mock_prog_instance.description, 'Failed') + + # State: Done Cancelled + self.mock_future.exception.return_value = None + self.mock_future.cancelled.return_value = True + async_res.update_display('Cancelled') + update_call_count += 1 + self.assertEqual(self.mock_clear_output.call_count, update_call_count) + self.assertEqual(mock_prog_instance.bar_style, 'warning') + self.assertEqual(mock_prog_instance.description, 'Cancelled') + + # State: Cancelling + self.mock_future.done.return_value = False + async_res._cancel_requested = True + async_res.update_display('Cancelling') + update_call_count += 1 + self.assertEqual(self.mock_clear_output.call_count, update_call_count) + self.assertTrue(mock_btn_instance.disabled) + self.assertEqual(mock_prog_instance.description, 'Cancelling...') + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', False) + def test_set_pipeline_result_cancel_requested(self): + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + async_res._cancel_requested = True + mock_pipeline_result = MagicMock() + with patch.object(async_res, 'cancel') as mock_cancel: + async_res.set_pipeline_result(mock_pipeline_result) + self.assertIs(async_res._pipeline_result, mock_pipeline_result) + mock_cancel.assert_called_once() + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', False) + def test_exception_timeout(self): + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + self.mock_future.exception.side_effect = TimeoutError + self.assertIsNone(async_res.exception(timeout=0.1)) + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', False) + @patch.object(_LOGGER, 'warning') + def test_on_done_not_done_state(self, mock_logger_warning): + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + mock_pipeline_result = MagicMock() + mock_pipeline_result.state = PipelineState.FAILED + self.mock_future.result.return_value = mock_pipeline_result + self.mock_future.exception.return_value = None + self.mock_future.cancelled.return_value = False + + with patch.object(self.env, + 'mark_pcollection_computed') as mock_mark_computed: + async_res._on_done(self.mock_future) + mock_mark_computed.assert_not_called() + mock_logger_warning.assert_called_once() + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True) + def test_cancel_no_pipeline_result(self): + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + self.mock_future.done.return_value = False + self.mock_future.cancel.return_value = 
True + with patch.object(async_res, 'update_display') as mock_update: + self.assertTrue(async_res.cancel()) + mock_update.assert_any_call( + 'Pipeline not yet fully started, cancelling future.') + self.mock_future.cancel.assert_called_once() + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True) + def test_cancel_pipeline_terminal_state(self): + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + self.mock_future.done.return_value = False + mock_pipeline_result = MagicMock() + mock_pipeline_result.state = PipelineState.DONE + async_res.set_pipeline_result(mock_pipeline_result) + + with patch.object(async_res, 'update_display') as mock_update: + self.assertFalse(async_res.cancel()) + mock_update.assert_any_call( + 'Cannot cancel: Pipeline already in terminal state DONE.') + mock_pipeline_result.cancel.assert_not_called() + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True) + @patch.object(_LOGGER, 'warning') + @patch.object(AsyncComputationResult, 'update_display') + def test_cancel_pipeline_exception( + self, mock_update_display, mock_logger_warning): + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + self.mock_future.done.return_value = False + mock_pipeline_result = MagicMock() + mock_pipeline_result.state = PipelineState.RUNNING + test_exception = RuntimeError('Cancel Failed') + mock_pipeline_result.cancel.side_effect = test_exception + async_res.set_pipeline_result(mock_pipeline_result) + self.mock_future.cancel.return_value = False + + self.assertFalse(async_res.cancel()) + + expected_calls = [ + call('Initializing...'), # From __init__ + call('Attempting to cancel...'), # From cancel() start + call('Error sending cancel signal: %s', + test_exception) # From except block + ] + mock_update_display.assert_has_calls(expected_calls, any_order=False) + + mock_logger_warning.assert_called_once() + self.mock_future.cancel.assert_called_once() + + class MockPipelineResult(beam.runners.runner.PipelineResult): """Mock class for controlling a PipelineResult.""" def __init__(self): @@ -283,6 +667,9 @@ def test_describe(self): cache_manager.size('full', letters_stream.cache_key)) +@unittest.skipIf( + not ie.current_env().is_interactive_ready, + '[interactive] dependency is not installed.') class RecordingManagerTest(unittest.TestCase): def test_basic_execution(self): """A basic pipeline to be used as a smoke test.""" @@ -565,6 +952,119 @@ def test_record_detects_remote_runner( # Reset cache_root value. ib.options.cache_root = None + def test_compute_async_blocking(self): + p = beam.Pipeline(InteractiveRunner()) + pcoll = p | beam.Create([1, 2, 3]) + ib.watch(locals()) + ie.current_env().track_user_pipelines() + rm = RecordingManager(p) + + with patch.object(rm, '_execute_pipeline_fragment') as mock_execute: + mock_result = MagicMock() + mock_result.state = PipelineState.DONE + mock_execute.return_value = mock_result + res = rm.compute_async({pcoll}, blocking=True) + self.assertIsNone(res) + mock_execute.assert_called_once() + self.assertTrue(pcoll in ie.current_env().computed_pcollections) + + @patch( + 'apache_beam.runners.interactive.recording_manager.AsyncComputationResult' + ) + @patch( + 'apache_beam.runners.interactive.recording_manager.ThreadPoolExecutor.' 
+ 'submit') + def test_compute_async_non_blocking(self, mock_submit, mock_async_result_cls): + p = beam.Pipeline(InteractiveRunner()) + pcoll = p | beam.Create([1, 2, 3]) + ib.watch(locals()) + ie.current_env().track_user_pipelines() + rm = RecordingManager(p) + mock_async_res_instance = mock_async_result_cls.return_value + + # Capture the task + task_submitted = None + + def capture_task(task): + nonlocal task_submitted + task_submitted = task + # Return a mock future + return MagicMock() + + mock_submit.side_effect = capture_task + + with patch.object( + rm, '_wait_for_dependencies', return_value=True + ), patch.object( + rm, '_execute_pipeline_fragment' + ) as _, patch.object( + ie.current_env(), + 'mark_pcollection_computing', + wraps=ie.current_env().mark_pcollection_computing, + ) as wrapped_mark: + + res = rm.compute_async({pcoll}, blocking=False) + wrapped_mark.assert_called_once_with({pcoll}) + + # Run the task to trigger the marks + self.assertIs(res, mock_async_res_instance) + mock_submit.assert_called_once() + self.assertIsNotNone(task_submitted) + + with patch.object( + rm, '_wait_for_dependencies', return_value=True + ), patch.object( + rm, '_execute_pipeline_fragment' + ) as _: + task_submitted() + + self.assertTrue(pcoll in ie.current_env().computing_pcollections) + + def test_get_all_dependencies(self): + p = beam.Pipeline(InteractiveRunner()) + p1 = p | 'C1' >> beam.Create([1]) + p2 = p | 'C2' >> beam.Create([2]) + p3 = p1 | 'M1' >> beam.Map(lambda x: x) + p4 = (p2, p3) | 'F1' >> beam.Flatten() + p5 = p3 | 'M2' >> beam.Map(lambda x: x) + ib.watch(locals()) + ie.current_env().track_user_pipelines() + rm = RecordingManager(p) + rm.record_pipeline() # Analyze pipeline + + self.assertEqual(rm._get_all_dependencies({p1}), set()) + self.assertEqual(rm._get_all_dependencies({p3}), {p1}) + self.assertEqual(rm._get_all_dependencies({p4}), {p1, p2, p3}) + self.assertEqual(rm._get_all_dependencies({p5}), {p1, p3}) + self.assertEqual(rm._get_all_dependencies({p4, p5}), {p1, p2, p3}) + + @patch( + 'apache_beam.runners.interactive.recording_manager.AsyncComputationResult' + ) + def test_wait_for_dependencies(self, mock_async_result_cls): + p = beam.Pipeline(InteractiveRunner()) + p1 = p | 'C1' >> beam.Create([1]) + p2 = p1 | 'M1' >> beam.Map(lambda x: x) + ib.watch(locals()) + ie.current_env().track_user_pipelines() + rm = RecordingManager(p) + rm.record_pipeline() + + # Scenario 1: No dependencies computing + self.assertTrue(rm._wait_for_dependencies({p2})) + + # Scenario 2: Dependency is computing + mock_future = MagicMock(spec=Future) + mock_async_res = MagicMock(spec=AsyncComputationResult) + mock_async_res._future = mock_future + mock_async_res._pcolls = {p1} + rm._async_computations['dep_id'] = mock_async_res + ie.current_env().mark_pcollection_computing({p1}) + + self.assertTrue(rm._wait_for_dependencies({p2})) + mock_future.result.assert_called_once() + ie.current_env().unmark_pcollection_computing({p1}) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/runners/interactive/utils_test.py b/sdks/python/apache_beam/runners/interactive/utils_test.py index 5fb41df35862..17285ac52af7 100644 --- a/sdks/python/apache_beam/runners/interactive/utils_test.py +++ b/sdks/python/apache_beam/runners/interactive/utils_test.py @@ -244,6 +244,9 @@ def test_child_module_logger_can_override_logging_level(self, mock_emit): reason='[interactive] dependency is not installed.') class ProgressIndicatorTest(unittest.TestCase): def setUp(self): + self.patcher = patch( + 
'apache_beam.runners.interactive.cache_manager.CacheManager.cleanup') + self.patcher.start() ie.new_env() @patch('IPython.get_ipython', new_callable=mock_get_ipython) @@ -279,6 +282,9 @@ def test_progress_in_HTML_JS_when_in_notebook( mocked_html.assert_called() mocked_js.assert_called() + def tearDown(self): + self.patcher.stop() + @unittest.skipIf( not ie.current_env().is_interactive_ready, @@ -287,6 +293,9 @@ class MessagingUtilTest(unittest.TestCase): SAMPLE_DATA = {'a': [1, 2, 3], 'b': 4, 'c': '5', 'd': {'e': 'f'}} def setUp(self): + self.patcher = patch( + 'apache_beam.runners.interactive.cache_manager.CacheManager.cleanup') + self.patcher.start() ie.new_env() def test_as_json_decorator(self): @@ -298,6 +307,9 @@ def dummy(): # dictionaries remember the order of items inserted. self.assertEqual(json.loads(dummy()), MessagingUtilTest.SAMPLE_DATA) + def tearDown(self): + self.patcher.stop() + class GeneralUtilTest(unittest.TestCase): def test_pcoll_by_name(self):
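Stripped to its essentials, the asynchronous path added in recording_manager.py is the standard concurrent.futures pattern: a Future resolved by a ThreadPoolExecutor task, with a done-callback that marks results complete only on success. A self-contained sketch under that reading; run_fragment, mark_computed, and the 'squares' value are stand-ins, not Beam APIs:

from concurrent.futures import Future, ThreadPoolExecutor

def mark_computed(pcolls):
    # Stand-in for InteractiveEnvironment.mark_pcollection_computed.
    print(f'computed: {sorted(pcolls)}')

def run_fragment(pcolls):
    # Stand-in for RecordingManager._execute_pipeline_fragment.
    return {'state': 'DONE', 'pcolls': pcolls}

future = Future()

def task():
    try:
        future.set_result(run_fragment({'squares'}))
    except Exception as exc:  # surface failures to waiters, as the patch does
        if not future.cancelled():
            future.set_exception(exc)

def on_done(f):
    # Mirrors AsyncComputationResult._on_done: mark results computed only
    # when the run finished without cancellation or error.
    if not f.cancelled() and f.exception() is None and f.result()['state'] == 'DONE':
        mark_computed(f.result()['pcolls'])

future.add_done_callback(on_done)
with ThreadPoolExecutor(max_workers=1) as executor:
    executor.submit(task)
print(future.result(timeout=10)['state'])  # -> DONE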