diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam.py b/sdks/python/apache_beam/runners/interactive/interactive_beam.py index 76c4ea0aa666..7b773fda5db8 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_beam.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_beam.py @@ -35,11 +35,9 @@ # pytype: skip-file import logging +from collections.abc import Iterable from datetime import timedelta from typing import Any -from typing import Dict -from typing import Iterable -from typing import List from typing import Optional from typing import Union @@ -57,6 +55,7 @@ from apache_beam.runners.interactive.display.pcoll_visualization import visualize from apache_beam.runners.interactive.display.pcoll_visualization import visualize_computed_pcoll from apache_beam.runners.interactive.options import interactive_options +from apache_beam.runners.interactive.recording_manager import AsyncComputationResult from apache_beam.runners.interactive.utils import deferred_df_to_pcollection from apache_beam.runners.interactive.utils import elements_to_df from apache_beam.runners.interactive.utils import find_pcoll_name @@ -275,7 +274,7 @@ class Recordings(): """ def describe( self, - pipeline: Optional[beam.Pipeline] = None) -> Dict[str, Any]: # noqa: F821 + pipeline: Optional[beam.Pipeline] = None) -> dict[str, Any]: # noqa: F821 """Returns a description of all the recordings for the given pipeline. If no pipeline is given then this returns a dictionary of descriptions for @@ -417,10 +416,10 @@ class Clusters: # DATAPROC_IMAGE_VERSION = '2.0.XX-debian10' def __init__(self) -> None: - self.dataproc_cluster_managers: Dict[ClusterMetadata, + self.dataproc_cluster_managers: dict[ClusterMetadata, DataprocClusterManager] = {} - self.master_urls: Dict[str, ClusterMetadata] = {} - self.pipelines: Dict[beam.Pipeline, DataprocClusterManager] = {} + self.master_urls: dict[str, ClusterMetadata] = {} + self.pipelines: dict[beam.Pipeline, DataprocClusterManager] = {} self.default_cluster_metadata: Optional[ClusterMetadata] = None def create( @@ -511,7 +510,7 @@ def cleanup( def describe( self, cluster_identifier: Optional[ClusterIdentifier] = None - ) -> Union[ClusterMetadata, List[ClusterMetadata]]: + ) -> Union[ClusterMetadata, list[ClusterMetadata]]: """Describes the ClusterMetadata by a ClusterIdentifier. If no cluster_identifier is given or if the cluster_identifier is unknown, @@ -679,7 +678,7 @@ def run_pipeline(self): @progress_indicated def show( - *pcolls: Union[Dict[Any, PCollection], Iterable[PCollection], PCollection], + *pcolls: Union[dict[Any, PCollection], Iterable[PCollection], PCollection], include_window_info: bool = False, visualize_data: bool = False, n: Union[int, str] = 'inf', @@ -1012,6 +1011,88 @@ def as_pcollection(pcoll_or_df): return result_tuple +@progress_indicated +def compute( + *pcolls: Union[dict[Any, PCollection], Iterable[PCollection], PCollection], + wait_for_inputs: bool = True, + blocking: bool = False, + runner=None, + options=None, + force_compute=False, +) -> Optional[AsyncComputationResult]: + """Computes the given PCollections, potentially asynchronously. + + Args: + *pcolls: PCollections to compute. Can be a single PCollection, an iterable + of PCollections, or a dictionary with PCollections as values. + wait_for_inputs: Whether to wait until the asynchronous dependencies are + computed. 
Setting this to False allows the computation to be scheduled
+      immediately, but may result in running the same pipeline stages
+      multiple times.
+    blocking: If False, the computation runs in a non-blocking fashion; in a
+      Colab/IPython environment this mode also provides controls for the
+      running pipeline. If True, the computation blocks until the pipeline
+      is done.
+    runner: (optional) the runner with which to compute the results.
+    options: (optional) any additional pipeline options to use to compute the
+      results.
+    force_compute: (optional) if True, forces recomputation rather than using
+      cached PCollections.
+
+  Returns:
+    An AsyncComputationResult object if blocking is False, otherwise None.
+  """
+  flatten_pcolls = []
+  for pcoll_container in pcolls:
+    if isinstance(pcoll_container, dict):
+      flatten_pcolls.extend(pcoll_container.values())
+    elif isinstance(pcoll_container, (beam.pvalue.PCollection, DeferredBase)):
+      flatten_pcolls.append(pcoll_container)
+    else:
+      try:
+        flatten_pcolls.extend(iter(pcoll_container))
+      except TypeError:
+        raise ValueError(
+            f'The given pcoll {pcoll_container} is not a dict, an iterable or '
+            'a PCollection.')
+
+  pcolls_set = set()
+  for pcoll in flatten_pcolls:
+    if isinstance(pcoll, DeferredBase):
+      pcoll, _ = deferred_df_to_pcollection(pcoll)
+      watch({f'anonymous_pcollection_{id(pcoll)}': pcoll})
+    assert isinstance(
+        pcoll, beam.pvalue.PCollection
+    ), f'{pcoll} is not an apache_beam.pvalue.PCollection.'
+    pcolls_set.add(pcoll)
+
+  if not pcolls_set:
+    _LOGGER.info('No PCollections to compute.')
+    return None
+
+  pcoll_pipeline = next(iter(pcolls_set)).pipeline
+  user_pipeline = ie.current_env().user_pipeline(pcoll_pipeline)
+  if not user_pipeline:
+    watch({f'anonymous_pipeline_{id(pcoll_pipeline)}': pcoll_pipeline})
+    user_pipeline = pcoll_pipeline
+
+  for pcoll in pcolls_set:
+    if pcoll.pipeline is not user_pipeline:
+      raise ValueError('All PCollections must belong to the same pipeline.')
+
+  recording_manager = ie.current_env().get_recording_manager(
+      user_pipeline, create_if_absent=True)
+
+  return recording_manager.compute_async(
+      pcolls_set,
+      wait_for_inputs=wait_for_inputs,
+      blocking=blocking,
+      runner=runner,
+      options=options,
+      force_compute=force_compute,
+  )
+
+
 @progress_indicated
 def show_graph(pipeline):
   """Shows the current pipeline shape of a given Beam pipeline as a DAG.
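Reviewer note on the new API: a minimal usage sketch of `ib.compute`, matching the behavior asserted in the tests below. The pipeline and variable names here are illustrative, not part of this PR.

```python
import apache_beam as beam
from apache_beam.runners.interactive import interactive_beam as ib
from apache_beam.runners.interactive.interactive_runner import InteractiveRunner

p = beam.Pipeline(InteractiveRunner())
squares = p | beam.Create(range(10)) | beam.Map(lambda x: x * x)
ib.watch(locals())  # Register the pipeline and PCollection with the environment.

# Non-blocking (the default): returns an AsyncComputationResult right away.
handle = ib.compute(squares)
handle.result(timeout=60)  # Future-style wait; re-raises pipeline errors.
assert handle.done() and handle.exception() is None

# Blocking: runs the fragment to completion and returns None.
# force_compute=True evicts cached results and recomputes.
ib.compute(squares, blocking=True, force_compute=True)
```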
diff --git a/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py b/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py index 37cd63842b1e..21163fc121c5 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_beam_test.py @@ -23,11 +23,16 @@ import sys import time import unittest +from concurrent.futures import TimeoutError from typing import NamedTuple +from unittest.mock import ANY +from unittest.mock import MagicMock +from unittest.mock import call from unittest.mock import patch import apache_beam as beam from apache_beam import dataframe as frames +from apache_beam.dataframe.frame_base import DeferredBase from apache_beam.options.pipeline_options import FlinkRunnerOptions from apache_beam.options.pipeline_options import PipelineOptions from apache_beam.runners.interactive import interactive_beam as ib @@ -36,6 +41,7 @@ from apache_beam.runners.interactive.dataproc.dataproc_cluster_manager import DataprocClusterManager from apache_beam.runners.interactive.dataproc.types import ClusterMetadata from apache_beam.runners.interactive.options.capture_limiters import Limiter +from apache_beam.runners.interactive.recording_manager import AsyncComputationResult from apache_beam.runners.interactive.testing.mock_env import isolated_env from apache_beam.runners.runner import PipelineState from apache_beam.testing.test_stream import TestStream @@ -65,6 +71,9 @@ def _get_watched_pcollections_with_variable_names(): return watched_pcollections +@unittest.skipIf( + not ie.current_env().is_interactive_ready, + '[interactive] dependency is not installed.') @isolated_env class InteractiveBeamTest(unittest.TestCase): def setUp(self): @@ -671,5 +680,387 @@ def test_default_value_for_invalid_worker_number(self): self.assertEqual(meta.num_workers, 2) +@unittest.skipIf( + not ie.current_env().is_interactive_ready, + '[interactive] dependency is not installed.') +@isolated_env +class InteractiveBeamComputeTest(unittest.TestCase): + def setUp(self): + self.env = ie.current_env() + self.env._is_in_ipython = False # Default to non-IPython + + def test_compute_blocking(self): + p = beam.Pipeline(ir.InteractiveRunner()) + data = list(range(10)) + pcoll = p | 'Create' >> beam.Create(data) + ib.watch(locals()) + self.env.track_user_pipelines() + + result = ib.compute(pcoll, blocking=True) + self.assertIsNone(result) # Blocking returns None + self.assertTrue(pcoll in self.env.computed_pcollections) + collected = ib.collect(pcoll, raw_records=True) + self.assertEqual(collected, data) + + def test_compute_non_blocking(self): + p = beam.Pipeline(ir.InteractiveRunner()) + data = list(range(5)) + pcoll = p | 'Create' >> beam.Create(data) + ib.watch(locals()) + self.env.track_user_pipelines() + + async_result = ib.compute(pcoll, blocking=False) + self.assertIsInstance(async_result, AsyncComputationResult) + + pipeline_result = async_result.result(timeout=60) + self.assertTrue(async_result.done()) + self.assertIsNone(async_result.exception()) + self.assertEqual(pipeline_result.state, PipelineState.DONE) + self.assertTrue(pcoll in self.env.computed_pcollections) + collected = ib.collect(pcoll, raw_records=True) + self.assertEqual(collected, data) + + def test_compute_with_list_input(self): + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) + pcoll2 = p | 'Create2' >> beam.Create([4, 5, 6]) + ib.watch(locals()) + self.env.track_user_pipelines() + + 
ib.compute([pcoll1, pcoll2], blocking=True) + self.assertTrue(pcoll1 in self.env.computed_pcollections) + self.assertTrue(pcoll2 in self.env.computed_pcollections) + + def test_compute_with_dict_input(self): + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) + pcoll2 = p | 'Create2' >> beam.Create([4, 5, 6]) + ib.watch(locals()) + self.env.track_user_pipelines() + + ib.compute({'a': pcoll1, 'b': pcoll2}, blocking=True) + self.assertTrue(pcoll1 in self.env.computed_pcollections) + self.assertTrue(pcoll2 in self.env.computed_pcollections) + + def test_compute_empty_input(self): + result = ib.compute([], blocking=True) + self.assertIsNone(result) + result_async = ib.compute([], blocking=False) + self.assertIsNone(result_async) + + def test_compute_force_recompute(self): + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll = p | 'Create' >> beam.Create([1, 2, 3]) + ib.watch(locals()) + self.env.track_user_pipelines() + + ib.compute(pcoll, blocking=True) + self.assertTrue(pcoll in self.env.computed_pcollections) + + # Mock evict_computed_pcollections to check if it's called + with patch.object(self.env, 'evict_computed_pcollections') as mock_evict: + ib.compute(pcoll, blocking=True, force_compute=True) + mock_evict.assert_called_once_with(p) + self.assertTrue(pcoll in self.env.computed_pcollections) + + def test_compute_non_blocking_exception(self): + p = beam.Pipeline(ir.InteractiveRunner()) + + def raise_error(elem): + raise ValueError('Test Error') + + pcoll = p | 'Create' >> beam.Create([1]) | 'Error' >> beam.Map(raise_error) + ib.watch(locals()) + self.env.track_user_pipelines() + + async_result = ib.compute(pcoll, blocking=False) + self.assertIsInstance(async_result, AsyncComputationResult) + + with self.assertRaises(ValueError): + async_result.result(timeout=60) + + self.assertTrue(async_result.done()) + self.assertIsInstance(async_result.exception(), ValueError) + self.assertFalse(pcoll in self.env.computed_pcollections) + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True) + @patch('apache_beam.runners.interactive.recording_manager.display') + @patch('apache_beam.runners.interactive.recording_manager.clear_output') + @patch('apache_beam.runners.interactive.recording_manager.HTML') + @patch('ipywidgets.Button') + @patch('ipywidgets.FloatProgress') + @patch('ipywidgets.Output') + @patch('ipywidgets.HBox') + @patch('ipywidgets.VBox') + def test_compute_non_blocking_ipython_widgets( + self, + mock_vbox, + mock_hbox, + mock_output, + mock_progress, + mock_button, + mock_html, + mock_clear_output, + mock_display, + ): + self.env._is_in_ipython = True + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll = p | 'Create' >> beam.Create(range(3)) + ib.watch(locals()) + self.env.track_user_pipelines() + + mock_controls = mock_vbox.return_value + mock_html_instance = mock_html.return_value + + async_result = ib.compute(pcoll, blocking=False) + self.assertIsNotNone(async_result) + mock_button.assert_called_once_with(description='Cancel') + mock_progress.assert_called_once() + mock_output.assert_called_once() + mock_hbox.assert_called_once() + mock_vbox.assert_called_once() + mock_html.assert_called_once_with('
Initializing...
') + + self.assertEqual(mock_display.call_count, 2) + mock_display.assert_has_calls([ + call(mock_controls, display_id=async_result._display_id), + call(mock_html_instance) + ]) + + mock_clear_output.assert_called_once() + async_result.result(timeout=60) # Let it finish + + def test_compute_dependency_wait_true(self): + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) + pcoll2 = pcoll1 | 'Map' >> beam.Map(lambda x: x * 2) + ib.watch(locals()) + self.env.track_user_pipelines() + + rm = self.env.get_recording_manager(p) + + # Start pcoll1 computation + async_res1 = ib.compute(pcoll1, blocking=False) + self.assertTrue(self.env.is_pcollection_computing(pcoll1)) + + # Spy on _wait_for_dependencies + with patch.object(rm, + '_wait_for_dependencies', + wraps=rm._wait_for_dependencies) as spy_wait: + async_res2 = ib.compute(pcoll2, blocking=False, wait_for_inputs=True) + + # Check that wait_for_dependencies was called for pcoll2 + spy_wait.assert_called_with({pcoll2}, async_res2) + + # Let pcoll1 finish + async_res1.result(timeout=60) + self.assertTrue(pcoll1 in self.env.computed_pcollections) + self.assertFalse(self.env.is_pcollection_computing(pcoll1)) + + # pcoll2 should now run and complete + async_res2.result(timeout=60) + self.assertTrue(pcoll2 in self.env.computed_pcollections) + + @patch.object(ie.InteractiveEnvironment, 'is_pcollection_computing') + def test_compute_dependency_wait_false(self, mock_is_computing): + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) + pcoll2 = pcoll1 | 'Map' >> beam.Map(lambda x: x * 2) + ib.watch(locals()) + self.env.track_user_pipelines() + + rm = self.env.get_recording_manager(p) + + # Pretend pcoll1 is computing + mock_is_computing.side_effect = lambda pcoll: pcoll is pcoll1 + + with patch.object(rm, + '_execute_pipeline_fragment', + wraps=rm._execute_pipeline_fragment) as spy_execute: + async_res2 = ib.compute(pcoll2, blocking=False, wait_for_inputs=False) + async_res2.result(timeout=60) + + # Assert that execute was called for pcoll2 without waiting + spy_execute.assert_called_with({pcoll2}, async_res2, ANY, ANY) + self.assertTrue(pcoll2 in self.env.computed_pcollections) + + def test_async_computation_result_cancel(self): + p = beam.Pipeline(ir.InteractiveRunner()) + # A stream that never finishes to test cancellation + pcoll = p | beam.Create([1]) | beam.Map(lambda x: time.sleep(100)) + ib.watch(locals()) + self.env.track_user_pipelines() + + async_result = ib.compute(pcoll, blocking=False) + self.assertIsInstance(async_result, AsyncComputationResult) + + # Give it a moment to start + time.sleep(0.1) + + # Mock the pipeline result's cancel method + mock_pipeline_result = MagicMock() + mock_pipeline_result.state = PipelineState.RUNNING + async_result.set_pipeline_result(mock_pipeline_result) + + self.assertTrue(async_result.cancel()) + mock_pipeline_result.cancel.assert_called_once() + + # The future should be cancelled eventually by the runner + # This part is hard to test without deeper runner integration + with self.assertRaises(TimeoutError): + async_result.result(timeout=1) # It should not complete successfully + + @patch( + 'apache_beam.runners.interactive.recording_manager.RecordingManager.' 
+ '_execute_pipeline_fragment') + def test_compute_multiple_async(self, mock_execute_fragment): + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3]) + pcoll2 = p | 'Create2' >> beam.Create([4, 5, 6]) + pcoll3 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2) + ib.watch(locals()) + self.env.track_user_pipelines() + + mock_pipeline_result = MagicMock() + mock_pipeline_result.state = PipelineState.DONE + mock_execute_fragment.return_value = mock_pipeline_result + + res1 = ib.compute(pcoll1, blocking=False) + res2 = ib.compute(pcoll2, blocking=False) + res3 = ib.compute(pcoll3, blocking=False) # Depends on pcoll1 + + self.assertIsNotNone(res1) + self.assertIsNotNone(res2) + self.assertIsNotNone(res3) + + res1.result(timeout=60) + res2.result(timeout=60) + res3.result(timeout=60) + + time.sleep(0.1) + + self.assertTrue( + pcoll1 in self.env.computed_pcollections, "pcoll1 not marked computed") + self.assertTrue( + pcoll2 in self.env.computed_pcollections, "pcoll2 not marked computed") + self.assertTrue( + pcoll3 in self.env.computed_pcollections, "pcoll3 not marked computed") + + self.assertEqual(mock_execute_fragment.call_count, 3) + + @patch( + 'apache_beam.runners.interactive.interactive_beam.' + 'deferred_df_to_pcollection') + def test_compute_input_flattening(self, mock_deferred_to_pcoll): + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll1 = p | 'C1' >> beam.Create([1]) + pcoll2 = p | 'C2' >> beam.Create([2]) + pcoll3 = p | 'C3' >> beam.Create([3]) + pcoll4 = p | 'C4' >> beam.Create([4]) + + class MockDeferred(DeferredBase): + def __init__(self, pcoll): + mock_expr = MagicMock() + super().__init__(mock_expr) + self._pcoll = pcoll + + def _get_underlying_pcollection(self): + return self._pcoll + + deferred_pcoll = MockDeferred(pcoll4) + + mock_deferred_to_pcoll.return_value = (pcoll4, p) + + ib.watch(locals()) + self.env.track_user_pipelines() + + with patch.object(self.env, 'get_recording_manager') as mock_get_rm: + mock_rm = MagicMock() + mock_get_rm.return_value = mock_rm + ib.compute(pcoll1, [pcoll2], {'a': pcoll3}, deferred_pcoll) + + expected_pcolls = {pcoll1, pcoll2, pcoll3, pcoll4} + mock_rm.compute_async.assert_called_once_with( + expected_pcolls, + wait_for_inputs=True, + blocking=False, + runner=None, + options=None, + force_compute=False) + + def test_compute_invalid_input_type(self): + with self.assertRaisesRegex(ValueError, + "not a dict, an iterable or a PCollection"): + ib.compute(123) + + def test_compute_mixed_pipelines(self): + p1 = beam.Pipeline(ir.InteractiveRunner()) + pcoll1 = p1 | 'C1' >> beam.Create([1]) + p2 = beam.Pipeline(ir.InteractiveRunner()) + pcoll2 = p2 | 'C2' >> beam.Create([2]) + ib.watch(locals()) + self.env.track_user_pipelines() + + with self.assertRaisesRegex( + ValueError, "All PCollections must belong to the same pipeline"): + ib.compute(pcoll1, pcoll2) + + @patch( + 'apache_beam.runners.interactive.interactive_beam.' 
+ 'deferred_df_to_pcollection') + @patch.object(ib, 'watch') + def test_compute_with_deferred_base(self, mock_watch, mock_deferred_to_pcoll): + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll = p | 'C1' >> beam.Create([1]) + + class MockDeferred(DeferredBase): + def __init__(self, pcoll): + # Provide a dummy expression to satisfy DeferredBase.__init__ + mock_expr = MagicMock() + super().__init__(mock_expr) + self._pcoll = pcoll + + def _get_underlying_pcollection(self): + return self._pcoll + + deferred = MockDeferred(pcoll) + + mock_deferred_to_pcoll.return_value = (pcoll, p) + + with patch.object(self.env, 'get_recording_manager') as mock_get_rm: + mock_rm = MagicMock() + mock_get_rm.return_value = mock_rm + ib.compute(deferred) + + mock_deferred_to_pcoll.assert_called_once_with(deferred) + self.assertEqual(mock_watch.call_count, 2) + mock_watch.assert_has_calls([ + call({f'anonymous_pcollection_{id(pcoll)}': pcoll}), + call({f'anonymous_pipeline_{id(p)}': p}) + ], + any_order=False) + mock_rm.compute_async.assert_called_once_with({pcoll}, + wait_for_inputs=True, + blocking=False, + runner=None, + options=None, + force_compute=False) + + def test_compute_new_pipeline(self): + p = beam.Pipeline(ir.InteractiveRunner()) + pcoll = p | 'Create' >> beam.Create([1]) + # NOT calling ib.watch() or track_user_pipelines() + + with patch.object(self.env, 'get_recording_manager') as mock_get_rm, \ + patch.object(ib, 'watch') as mock_watch: + mock_rm = MagicMock() + mock_get_rm.return_value = mock_rm + ib.compute(pcoll) + + mock_watch.assert_called_with({f'anonymous_pipeline_{id(p)}': p}) + mock_get_rm.assert_called_once_with(p, create_if_absent=True) + mock_rm.compute_async.assert_called_once() + + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/runners/interactive/interactive_environment.py b/sdks/python/apache_beam/runners/interactive/interactive_environment.py index e9ff86c6276f..2a8fc23088a6 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_environment.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_environment.py @@ -175,6 +175,9 @@ def __init__(self): # Tracks the computation completeness of PCollections. PCollections tracked # here don't need to be re-computed when data introspection is needed. self._computed_pcolls = set() + + self._computing_pcolls = set() + # Always watch __main__ module. self.watch('__main__') # Check if [interactive] dependencies are installed. 
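For clarity on the new `_computing_pcolls` state introduced above (its accessors `mark_pcollection_computing`, `unmark_pcollection_computing`, and `is_pcollection_computing` are added in the next hunk), a rough sketch of the intended bookkeeping. This is illustrative only; `run_fragment` is a hypothetical stand-in for the fragment execution that actually lives in `RecordingManager`:

```python
from apache_beam.runners.interactive import interactive_environment as ie

env = ie.current_env()
env.mark_pcollection_computing(pcolls)  # flag the PCollections as in-flight
try:
    result = run_fragment(pcolls)  # hypothetical stand-in for the real execution
finally:
    # In this PR the cleanup happens in AsyncComputationResult._on_done.
    env.unmark_pcollection_computing(pcolls)

assert not any(env.is_pcollection_computing(pc) for pc in pcolls)
```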
@@ -720,3 +723,19 @@ def _get_gcs_cache_dir(self, pipeline, cache_dir): bucket_name = cache_dir_path.parts[1] assert_bucket_exists(bucket_name) return 'gs://{}/{}'.format('/'.join(cache_dir_path.parts[1:]), id(pipeline)) + + @property + def computing_pcollections(self): + return self._computing_pcolls + + def mark_pcollection_computing(self, pcolls): + """Marks the given pcolls as currently being computed.""" + self._computing_pcolls.update(pcolls) + + def unmark_pcollection_computing(self, pcolls): + """Removes the given pcolls from the computing set.""" + self._computing_pcolls.difference_update(pcolls) + + def is_pcollection_computing(self, pcoll): + """Checks if the given pcollection is currently being computed.""" + return pcoll in self._computing_pcolls diff --git a/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py b/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py index 4d5f3f36ce67..eb3b4b514824 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_environment_test.py @@ -34,6 +34,9 @@ _module_name = 'apache_beam.runners.interactive.interactive_environment_test' +@unittest.skipIf( + not ie.current_env().is_interactive_ready, + '[interactive] dependency is not installed.') @isolated_env class InteractiveEnvironmentTest(unittest.TestCase): def setUp(self): @@ -341,6 +344,44 @@ def test_get_gcs_cache_dir_invalid_path(self): with self.assertRaises(ValueError): env._get_gcs_cache_dir(p, cache_root) + def test_pcollection_computing_state(self): + env = ie.InteractiveEnvironment() + p = beam.Pipeline() + pcoll1 = p | 'Create1' >> beam.Create([1]) + pcoll2 = p | 'Create2' >> beam.Create([2]) + + self.assertFalse(env.is_pcollection_computing(pcoll1)) + self.assertFalse(env.is_pcollection_computing(pcoll2)) + self.assertEqual(env.computing_pcollections, set()) + + env.mark_pcollection_computing({pcoll1}) + self.assertTrue(env.is_pcollection_computing(pcoll1)) + self.assertFalse(env.is_pcollection_computing(pcoll2)) + self.assertEqual(env.computing_pcollections, {pcoll1}) + + env.mark_pcollection_computing({pcoll2}) + self.assertTrue(env.is_pcollection_computing(pcoll1)) + self.assertTrue(env.is_pcollection_computing(pcoll2)) + self.assertEqual(env.computing_pcollections, {pcoll1, pcoll2}) + + env.unmark_pcollection_computing({pcoll1}) + self.assertFalse(env.is_pcollection_computing(pcoll1)) + self.assertTrue(env.is_pcollection_computing(pcoll2)) + self.assertEqual(env.computing_pcollections, {pcoll2}) + + env.unmark_pcollection_computing({pcoll2}) + self.assertFalse(env.is_pcollection_computing(pcoll1)) + self.assertFalse(env.is_pcollection_computing(pcoll2)) + self.assertEqual(env.computing_pcollections, set()) + + def test_mark_unmark_empty(self): + env = ie.InteractiveEnvironment() + # Ensure no errors with empty sets + env.mark_pcollection_computing(set()) + self.assertEqual(env.computing_pcollections, set()) + env.unmark_pcollection_computing(set()) + self.assertEqual(env.computing_pcollections, set()) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/runners/interactive/recording_manager.py b/sdks/python/apache_beam/runners/interactive/recording_manager.py index f72ec2fe8e17..c19b60b64fd2 100644 --- a/sdks/python/apache_beam/runners/interactive/recording_manager.py +++ b/sdks/python/apache_beam/runners/interactive/recording_manager.py @@ -15,13 +15,17 @@ # limitations under the License. 
# +import collections import logging +import os import threading import time +import uuid import warnings +from concurrent.futures import Future +from concurrent.futures import ThreadPoolExecutor from typing import Any -from typing import Dict -from typing import List +from typing import Optional from typing import Union import pandas as pd @@ -37,11 +41,176 @@ from apache_beam.runners.interactive import pipeline_fragment as pf from apache_beam.runners.interactive import utils from apache_beam.runners.interactive.caching.cacheable import CacheKey +from apache_beam.runners.interactive.display.pipeline_graph import PipelineGraph from apache_beam.runners.interactive.options import capture_control from apache_beam.runners.runner import PipelineState _LOGGER = logging.getLogger(__name__) +try: + import ipywidgets as widgets + from IPython.display import HTML + from IPython.display import clear_output + from IPython.display import display + + IS_IPYTHON = True +except ImportError: + IS_IPYTHON = False + + +class AsyncComputationResult: + """Represents the result of an asynchronous computation.""" + def __init__( + self, + future: Future, + pcolls: set[beam.pvalue.PCollection], + user_pipeline: beam.Pipeline, + recording_manager: 'RecordingManager', + ): + self._future = future + self._pcolls = pcolls + self._user_pipeline = user_pipeline + self._env = ie.current_env() + self._recording_manager = recording_manager + self._pipeline_result: Optional[beam.runners.runner.PipelineResult] = None + self._display_id = str(uuid.uuid4()) + self._output_widget = widgets.Output() if IS_IPYTHON else None + self._cancel_button = ( + widgets.Button(description='Cancel') if IS_IPYTHON else None) + self._progress_bar = ( + widgets.FloatProgress( + value=0.0, + min=0.0, + max=1.0, + description='Running:', + bar_style='info', + ) if IS_IPYTHON else None) + self._cancel_requested = False + + if IS_IPYTHON: + self._cancel_button.on_click(self._cancel_clicked) + controls = widgets.VBox([ + widgets.HBox([self._cancel_button, self._progress_bar]), + self._output_widget, + ]) + display(controls, display_id=self._display_id) + self.update_display('Initializing...') + + self._future.add_done_callback(self._on_done) + + def _cancel_clicked(self, b): + self._cancel_requested = True + self._cancel_button.disabled = True + self.update_display('Cancel requested...') + self.cancel() + + def update_display(self, msg: str, progress: Optional[float] = None): + if not IS_IPYTHON: + print(f'AsyncCompute: {msg}') + return + + with self._output_widget: + clear_output(wait=True) + display(HTML(f'{msg}
')) + + if progress is not None: + self._progress_bar.value = progress + + if self.done(): + self._cancel_button.disabled = True + if self.exception(): + self._progress_bar.bar_style = 'danger' + self._progress_bar.description = 'Failed' + elif self._future.cancelled(): + self._progress_bar.bar_style = 'warning' + self._progress_bar.description = 'Cancelled' + else: + self._progress_bar.bar_style = 'success' + self._progress_bar.description = 'Done' + elif self._cancel_requested: + self._cancel_button.disabled = True + self._progress_bar.description = 'Cancelling...' + else: + self._cancel_button.disabled = False + + def set_pipeline_result( + self, pipeline_result: beam.runners.runner.PipelineResult): + self._pipeline_result = pipeline_result + if self._cancel_requested: + self.cancel() + + def result(self, timeout=None): + return self._future.result(timeout=timeout) + + def done(self): + return self._future.done() + + def exception(self, timeout=None): + try: + return self._future.exception(timeout=timeout) + except TimeoutError: + return None + + def _on_done(self, future: Future): + self._env.unmark_pcollection_computing(self._pcolls) + self._recording_manager._async_computations.pop(self._display_id, None) + + if future.cancelled(): + self.update_display('Computation Cancelled.', 1.0) + return + + exc = future.exception() + if exc: + self.update_display(f'Error: {exc}', 1.0) + _LOGGER.error('Asynchronous computation failed: %s', exc, exc_info=exc) + else: + self.update_display('Computation Finished Successfully.', 1.0) + res = future.result() + if res and res.state == PipelineState.DONE: + self._env.mark_pcollection_computed(self._pcolls) + else: + _LOGGER.warning( + 'Async computation finished but state is not DONE: %s', + res.state if res else 'Unknown') + + def cancel(self): + if self._future.done(): + self.update_display('Cannot cancel: Computation already finished.') + return False + + self._cancel_requested = True + self._cancel_button.disabled = True + self.update_display('Attempting to cancel...') + + if self._pipeline_result: + try: + # Check pipeline state before cancelling + current_state = self._pipeline_result.state + if PipelineState.is_terminal(current_state): + self.update_display( + 'Cannot cancel: Pipeline already in terminal state' + f' {current_state}.') + return False + + self._pipeline_result.cancel() + self.update_display('Cancel signal sent to pipeline.') + # The future will be cancelled by the runner if successful + return True + except Exception as e: + self.update_display('Error sending cancel signal: %s', e) + _LOGGER.warning('Error during pipeline cancel(): %s', e, exc_info=e) + # Still try to cancel the future as a fallback + return self._future.cancel() + else: + self.update_display('Pipeline not yet fully started, cancelling future.') + return self._future.cancel() + + def __repr__(self): + return ( + f'Running Test
') + + # State: Done Success + self.mock_future.done.return_value = True + self.mock_future.exception.return_value = None + self.mock_future.cancelled.return_value = False + async_res.update_display('Done') + update_call_count += 1 + self.assertEqual(self.mock_clear_output.call_count, update_call_count) + self.assertTrue(mock_btn_instance.disabled) + self.assertEqual(mock_prog_instance.bar_style, 'success') + self.assertEqual(mock_prog_instance.description, 'Done') + + # State: Done Failed + self.mock_future.exception.return_value = Exception() + async_res.update_display('Failed') + update_call_count += 1 + self.assertEqual(self.mock_clear_output.call_count, update_call_count) + self.assertEqual(mock_prog_instance.bar_style, 'danger') + self.assertEqual(mock_prog_instance.description, 'Failed') + + # State: Done Cancelled + self.mock_future.exception.return_value = None + self.mock_future.cancelled.return_value = True + async_res.update_display('Cancelled') + update_call_count += 1 + self.assertEqual(self.mock_clear_output.call_count, update_call_count) + self.assertEqual(mock_prog_instance.bar_style, 'warning') + self.assertEqual(mock_prog_instance.description, 'Cancelled') + + # State: Cancelling + self.mock_future.done.return_value = False + async_res._cancel_requested = True + async_res.update_display('Cancelling') + update_call_count += 1 + self.assertEqual(self.mock_clear_output.call_count, update_call_count) + self.assertTrue(mock_btn_instance.disabled) + self.assertEqual(mock_prog_instance.description, 'Cancelling...') + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', False) + def test_set_pipeline_result_cancel_requested(self): + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + async_res._cancel_requested = True + mock_pipeline_result = MagicMock() + with patch.object(async_res, 'cancel') as mock_cancel: + async_res.set_pipeline_result(mock_pipeline_result) + self.assertIs(async_res._pipeline_result, mock_pipeline_result) + mock_cancel.assert_called_once() + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', False) + def test_exception_timeout(self): + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + self.mock_future.exception.side_effect = TimeoutError + self.assertIsNone(async_res.exception(timeout=0.1)) + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', False) + @patch.object(_LOGGER, 'warning') + def test_on_done_not_done_state(self, mock_logger_warning): + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + mock_pipeline_result = MagicMock() + mock_pipeline_result.state = PipelineState.FAILED + self.mock_future.result.return_value = mock_pipeline_result + self.mock_future.exception.return_value = None + self.mock_future.cancelled.return_value = False + + with patch.object(self.env, + 'mark_pcollection_computed') as mock_mark_computed: + async_res._on_done(self.mock_future) + mock_mark_computed.assert_not_called() + mock_logger_warning.assert_called_once() + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True) + def test_cancel_no_pipeline_result(self): + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + self.mock_future.done.return_value = False + self.mock_future.cancel.return_value = 
True + with patch.object(async_res, 'update_display') as mock_update: + self.assertTrue(async_res.cancel()) + mock_update.assert_any_call( + 'Pipeline not yet fully started, cancelling future.') + self.mock_future.cancel.assert_called_once() + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True) + def test_cancel_pipeline_terminal_state(self): + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + self.mock_future.done.return_value = False + mock_pipeline_result = MagicMock() + mock_pipeline_result.state = PipelineState.DONE + async_res.set_pipeline_result(mock_pipeline_result) + + with patch.object(async_res, 'update_display') as mock_update: + self.assertFalse(async_res.cancel()) + mock_update.assert_any_call( + 'Cannot cancel: Pipeline already in terminal state DONE.') + mock_pipeline_result.cancel.assert_not_called() + + @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True) + @patch.object(_LOGGER, 'warning') + @patch.object(AsyncComputationResult, 'update_display') + def test_cancel_pipeline_exception( + self, mock_update_display, mock_logger_warning): + async_res = AsyncComputationResult( + self.mock_future, + self.pcolls, + self.user_pipeline, + self.recording_manager, + ) + self.mock_future.done.return_value = False + mock_pipeline_result = MagicMock() + mock_pipeline_result.state = PipelineState.RUNNING + test_exception = RuntimeError('Cancel Failed') + mock_pipeline_result.cancel.side_effect = test_exception + async_res.set_pipeline_result(mock_pipeline_result) + self.mock_future.cancel.return_value = False + + self.assertFalse(async_res.cancel()) + + expected_calls = [ + call('Initializing...'), # From __init__ + call('Attempting to cancel...'), # From cancel() start + call('Error sending cancel signal: %s', + test_exception) # From except block + ] + mock_update_display.assert_has_calls(expected_calls, any_order=False) + + mock_logger_warning.assert_called_once() + self.mock_future.cancel.assert_called_once() + + class MockPipelineResult(beam.runners.runner.PipelineResult): """Mock class for controlling a PipelineResult.""" def __init__(self): @@ -283,6 +667,9 @@ def test_describe(self): cache_manager.size('full', letters_stream.cache_key)) +@unittest.skipIf( + not ie.current_env().is_interactive_ready, + '[interactive] dependency is not installed.') class RecordingManagerTest(unittest.TestCase): def test_basic_execution(self): """A basic pipeline to be used as a smoke test.""" @@ -565,6 +952,119 @@ def test_record_detects_remote_runner( # Reset cache_root value. ib.options.cache_root = None + def test_compute_async_blocking(self): + p = beam.Pipeline(InteractiveRunner()) + pcoll = p | beam.Create([1, 2, 3]) + ib.watch(locals()) + ie.current_env().track_user_pipelines() + rm = RecordingManager(p) + + with patch.object(rm, '_execute_pipeline_fragment') as mock_execute: + mock_result = MagicMock() + mock_result.state = PipelineState.DONE + mock_execute.return_value = mock_result + res = rm.compute_async({pcoll}, blocking=True) + self.assertIsNone(res) + mock_execute.assert_called_once() + self.assertTrue(pcoll in ie.current_env().computed_pcollections) + + @patch( + 'apache_beam.runners.interactive.recording_manager.AsyncComputationResult' + ) + @patch( + 'apache_beam.runners.interactive.recording_manager.ThreadPoolExecutor.' 
+ 'submit') + def test_compute_async_non_blocking(self, mock_submit, mock_async_result_cls): + p = beam.Pipeline(InteractiveRunner()) + pcoll = p | beam.Create([1, 2, 3]) + ib.watch(locals()) + ie.current_env().track_user_pipelines() + rm = RecordingManager(p) + mock_async_res_instance = mock_async_result_cls.return_value + + # Capture the task + task_submitted = None + + def capture_task(task): + nonlocal task_submitted + task_submitted = task + # Return a mock future + return MagicMock() + + mock_submit.side_effect = capture_task + + with patch.object( + rm, '_wait_for_dependencies', return_value=True + ), patch.object( + rm, '_execute_pipeline_fragment' + ) as _, patch.object( + ie.current_env(), + 'mark_pcollection_computing', + wraps=ie.current_env().mark_pcollection_computing, + ) as wrapped_mark: + + res = rm.compute_async({pcoll}, blocking=False) + wrapped_mark.assert_called_once_with({pcoll}) + + # Run the task to trigger the marks + self.assertIs(res, mock_async_res_instance) + mock_submit.assert_called_once() + self.assertIsNotNone(task_submitted) + + with patch.object( + rm, '_wait_for_dependencies', return_value=True + ), patch.object( + rm, '_execute_pipeline_fragment' + ) as _: + task_submitted() + + self.assertTrue(pcoll in ie.current_env().computing_pcollections) + + def test_get_all_dependencies(self): + p = beam.Pipeline(InteractiveRunner()) + p1 = p | 'C1' >> beam.Create([1]) + p2 = p | 'C2' >> beam.Create([2]) + p3 = p1 | 'M1' >> beam.Map(lambda x: x) + p4 = (p2, p3) | 'F1' >> beam.Flatten() + p5 = p3 | 'M2' >> beam.Map(lambda x: x) + ib.watch(locals()) + ie.current_env().track_user_pipelines() + rm = RecordingManager(p) + rm.record_pipeline() # Analyze pipeline + + self.assertEqual(rm._get_all_dependencies({p1}), set()) + self.assertEqual(rm._get_all_dependencies({p3}), {p1}) + self.assertEqual(rm._get_all_dependencies({p4}), {p1, p2, p3}) + self.assertEqual(rm._get_all_dependencies({p5}), {p1, p3}) + self.assertEqual(rm._get_all_dependencies({p4, p5}), {p1, p2, p3}) + + @patch( + 'apache_beam.runners.interactive.recording_manager.AsyncComputationResult' + ) + def test_wait_for_dependencies(self, mock_async_result_cls): + p = beam.Pipeline(InteractiveRunner()) + p1 = p | 'C1' >> beam.Create([1]) + p2 = p1 | 'M1' >> beam.Map(lambda x: x) + ib.watch(locals()) + ie.current_env().track_user_pipelines() + rm = RecordingManager(p) + rm.record_pipeline() + + # Scenario 1: No dependencies computing + self.assertTrue(rm._wait_for_dependencies({p2})) + + # Scenario 2: Dependency is computing + mock_future = MagicMock(spec=Future) + mock_async_res = MagicMock(spec=AsyncComputationResult) + mock_async_res._future = mock_future + mock_async_res._pcolls = {p1} + rm._async_computations['dep_id'] = mock_async_res + ie.current_env().mark_pcollection_computing({p1}) + + self.assertTrue(rm._wait_for_dependencies({p2})) + mock_future.result.assert_called_once() + ie.current_env().unmark_pcollection_computing({p1}) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/runners/interactive/utils_test.py b/sdks/python/apache_beam/runners/interactive/utils_test.py index 5fb41df35862..17285ac52af7 100644 --- a/sdks/python/apache_beam/runners/interactive/utils_test.py +++ b/sdks/python/apache_beam/runners/interactive/utils_test.py @@ -244,6 +244,9 @@ def test_child_module_logger_can_override_logging_level(self, mock_emit): reason='[interactive] dependency is not installed.') class ProgressIndicatorTest(unittest.TestCase): def setUp(self): + self.patcher = patch( + 
'apache_beam.runners.interactive.cache_manager.CacheManager.cleanup') + self.patcher.start() ie.new_env() @patch('IPython.get_ipython', new_callable=mock_get_ipython) @@ -279,6 +282,9 @@ def test_progress_in_HTML_JS_when_in_notebook( mocked_html.assert_called() mocked_js.assert_called() + def tearDown(self): + self.patcher.stop() + @unittest.skipIf( not ie.current_env().is_interactive_ready, @@ -287,6 +293,9 @@ class MessagingUtilTest(unittest.TestCase): SAMPLE_DATA = {'a': [1, 2, 3], 'b': 4, 'c': '5', 'd': {'e': 'f'}} def setUp(self): + self.patcher = patch( + 'apache_beam.runners.interactive.cache_manager.CacheManager.cleanup') + self.patcher.start() ie.new_env() def test_as_json_decorator(self): @@ -298,6 +307,9 @@ def dummy(): # dictionaries remember the order of items inserted. self.assertEqual(json.loads(dummy()), MessagingUtilTest.SAMPLE_DATA) + def tearDown(self): + self.patcher.stop() + class GeneralUtilTest(unittest.TestCase): def test_pcoll_by_name(self):
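As a closing note on the surface exercised by these tests, a short sketch of notebook-side handling of an in-flight `AsyncComputationResult`, including cancellation. The PCollection and timeout are illustrative; the cancel-forwarding behavior is the one implemented in recording_manager.py above:

```python
from concurrent.futures import TimeoutError

handle = ib.compute(pcoll)  # non-blocking by default
try:
    pipeline_result = handle.result(timeout=5)  # concurrent.futures semantics
except TimeoutError:
    # Still running: cancel() forwards to the underlying PipelineResult once
    # set_pipeline_result() has been called, and otherwise falls back to
    # cancelling the bare future.
    handle.cancel()
else:
    assert handle.done() and handle.exception() is None
```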