
Commit dd386ea

Fix more linting errors, improve test coverage, make all unit tests pass
1 parent 741823e commit dd386ea

File tree

3 files changed: +471 additions, -35 deletions


sdks/python/apache_beam/runners/interactive/interactive_beam_test.py

Lines changed: 149 additions & 6 deletions
@@ -24,11 +24,12 @@
 import time
 import unittest
 from typing import NamedTuple
-from unittest.mock import patch
+from unittest.mock import patch, MagicMock, ANY, call
 from concurrent.futures import TimeoutError
 
 import apache_beam as beam
 from apache_beam import dataframe as frames
+from apache_beam.dataframe.frame_base import DeferredBase
 from apache_beam.options.pipeline_options import FlinkRunnerOptions
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.runners.interactive import interactive_beam as ib
@@ -775,6 +776,8 @@ def raise_error(elem):
 
   @patch('apache_beam.runners.interactive.recording_manager.IS_IPYTHON', True)
   @patch('apache_beam.runners.interactive.recording_manager.display')
+  @patch('apache_beam.runners.interactive.recording_manager.clear_output')
+  @patch('apache_beam.runners.interactive.recording_manager.HTML')
   @patch('ipywidgets.Button')
   @patch('ipywidgets.FloatProgress')
   @patch('ipywidgets.Output')
@@ -787,6 +790,8 @@ def test_compute_non_blocking_ipython_widgets(
       mock_output,
       mock_progress,
       mock_button,
+      mock_html,
+      mock_clear_output,
       mock_display,
   ):
     self.env._is_in_ipython = True
@@ -795,14 +800,25 @@ def test_compute_non_blocking_ipython_widgets(
     ib.watch(locals())
     self.env.track_user_pipelines()
 
+    mock_controls = mock_vbox.return_value
+    mock_html_instance = mock_html.return_value
+
     async_result = ib.compute(pcoll, blocking=False)
     self.assertIsNotNone(async_result)
     mock_button.assert_called_once_with(description='Cancel')
     mock_progress.assert_called_once()
     mock_output.assert_called_once()
     mock_hbox.assert_called_once()
     mock_vbox.assert_called_once()
-    mock_display.assert_called_once()
+    mock_html.assert_called_once_with('<p>Initializing...</p>')
+
+    self.assertEqual(mock_display.call_count, 2)
+    mock_display.assert_has_calls([
+        call(mock_controls, display_id=async_result._display_id),
+        call(mock_html_instance)
+    ])
+
+    mock_clear_output.assert_called_once()
     async_result.result(timeout=60)  # Let it finish
 
   def test_compute_dependency_wait_true(self):
@@ -885,14 +901,21 @@ def test_async_computation_result_cancel(self):
     with self.assertRaises(TimeoutError):
       async_result.result(timeout=1)  # It should not complete successfully
 
-  def test_compute_multiple_async(self):
+  @patch(
+      'apache_beam.runners.interactive.recording_manager.RecordingManager.'
+      '_execute_pipeline_fragment')
+  def test_compute_multiple_async(self, mock_execute_fragment):
     p = beam.Pipeline(ir.InteractiveRunner())
     pcoll1 = p | 'Create1' >> beam.Create([1, 2, 3])
     pcoll2 = p | 'Create2' >> beam.Create([4, 5, 6])
     pcoll3 = pcoll1 | 'Map1' >> beam.Map(lambda x: x * 2)
     ib.watch(locals())
     self.env.track_user_pipelines()
 
+    mock_pipeline_result = MagicMock()
+    mock_pipeline_result.state = PipelineState.DONE
+    mock_execute_fragment.return_value = mock_pipeline_result
+
     res1 = ib.compute(pcoll1, blocking=False)
     res2 = ib.compute(pcoll2, blocking=False)
     res3 = ib.compute(pcoll3, blocking=False)  # Depends on pcoll1
@@ -905,9 +928,129 @@ def test_compute_multiple_async(self):
     res2.result(timeout=60)
     res3.result(timeout=60)
 
-    self.assertTrue(pcoll1 in self.env.computed_pcollections)
-    self.assertTrue(pcoll2 in self.env.computed_pcollections)
-    self.assertTrue(pcoll3 in self.env.computed_pcollections)
+    time.sleep(0.1)
+
+    self.assertTrue(
+        pcoll1 in self.env.computed_pcollections, "pcoll1 not marked computed")
+    self.assertTrue(
+        pcoll2 in self.env.computed_pcollections, "pcoll2 not marked computed")
+    self.assertTrue(
+        pcoll3 in self.env.computed_pcollections, "pcoll3 not marked computed")
+
+    self.assertEqual(mock_execute_fragment.call_count, 3)
+
+  @patch(
+      'apache_beam.runners.interactive.interactive_beam.'
+      'deferred_df_to_pcollection')
+  def test_compute_input_flattening(self, mock_deferred_to_pcoll):
+    p = beam.Pipeline(ir.InteractiveRunner())
+    pcoll1 = p | 'C1' >> beam.Create([1])
+    pcoll2 = p | 'C2' >> beam.Create([2])
+    pcoll3 = p | 'C3' >> beam.Create([3])
+    pcoll4 = p | 'C4' >> beam.Create([4])
+
+    class MockDeferred(DeferredBase):
+      def __init__(self, pcoll):
+        mock_expr = MagicMock()
+        super().__init__(mock_expr)
+        self._pcoll = pcoll
+
+      def _get_underlying_pcollection(self):
+        return self._pcoll
+
+    deferred_pcoll = MockDeferred(pcoll4)
+
+    mock_deferred_to_pcoll.return_value = (pcoll4, p)
+
+    ib.watch(locals())
+    self.env.track_user_pipelines()
+
+    with patch.object(self.env, 'get_recording_manager') as mock_get_rm:
+      mock_rm = MagicMock()
+      mock_get_rm.return_value = mock_rm
+      ib.compute(pcoll1, [pcoll2], {'a': pcoll3}, deferred_pcoll)
+
+    expected_pcolls = {pcoll1, pcoll2, pcoll3, pcoll4}
+    mock_rm.compute_async.assert_called_once_with(
+        expected_pcolls,
+        wait_for_inputs=True,
+        blocking=False,
+        runner=None,
+        options=None,
+        force_compute=False)
+
+  def test_compute_invalid_input_type(self):
+    with self.assertRaisesRegex(ValueError,
+                                "not a dict, an iterable or a PCollection"):
+      ib.compute(123)
+
+  def test_compute_mixed_pipelines(self):
+    p1 = beam.Pipeline(ir.InteractiveRunner())
+    pcoll1 = p1 | 'C1' >> beam.Create([1])
+    p2 = beam.Pipeline(ir.InteractiveRunner())
+    pcoll2 = p2 | 'C2' >> beam.Create([2])
+    ib.watch(locals())
+    self.env.track_user_pipelines()
+
+    with self.assertRaisesRegex(
+        ValueError, "All PCollections must belong to the same pipeline"):
+      ib.compute(pcoll1, pcoll2)
+
+  @patch(
+      'apache_beam.runners.interactive.interactive_beam.'
+      'deferred_df_to_pcollection')
+  @patch.object(ib, 'watch')
+  def test_compute_with_deferred_base(self, mock_watch, mock_deferred_to_pcoll):
+    p = beam.Pipeline(ir.InteractiveRunner())
+    pcoll = p | 'C1' >> beam.Create([1])
+
+    class MockDeferred(DeferredBase):
+      def __init__(self, pcoll):
+        # Provide a dummy expression to satisfy DeferredBase.__init__
+        mock_expr = MagicMock()
+        super().__init__(mock_expr)
+        self._pcoll = pcoll
+
+      def _get_underlying_pcollection(self):
+        return self._pcoll
+
+    deferred = MockDeferred(pcoll)
+
+    mock_deferred_to_pcoll.return_value = (pcoll, p)
+
+    with patch.object(self.env, 'get_recording_manager') as mock_get_rm:
+      mock_rm = MagicMock()
+      mock_get_rm.return_value = mock_rm
+      ib.compute(deferred)
+
+    mock_deferred_to_pcoll.assert_called_once_with(deferred)
+    self.assertEqual(mock_watch.call_count, 2)
+    mock_watch.assert_has_calls([
+        call({f'anonymous_pcollection_{id(pcoll)}': pcoll}),
+        call({f'anonymous_pipeline_{id(p)}': p})
+    ],
+        any_order=False)
+    mock_rm.compute_async.assert_called_once_with({pcoll},
+        wait_for_inputs=True,
+        blocking=False,
+        runner=None,
+        options=None,
+        force_compute=False)
+
+  def test_compute_new_pipeline(self):
+    p = beam.Pipeline(ir.InteractiveRunner())
+    pcoll = p | 'Create' >> beam.Create([1])
+    # NOT calling ib.watch() or track_user_pipelines()
+
+    with patch.object(self.env, 'get_recording_manager') as mock_get_rm, \
+         patch.object(ib, 'watch') as mock_watch:
+      mock_rm = MagicMock()
+      mock_get_rm.return_value = mock_rm
+      ib.compute(pcoll)
+
+    mock_watch.assert_called_with({f'anonymous_pipeline_{id(p)}': p})
+    mock_get_rm.assert_called_once_with(p, create_if_absent=True)
+    mock_rm.compute_async.assert_called_once()
 
 
 if __name__ == '__main__':
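
For orientation, the new tests above drive the non-blocking compute path through ib.compute(pcoll, blocking=False). The following is a rough usage sketch assembled from the calls the tests make; it is not part of the commit, and the pipeline contents are illustrative.

import apache_beam as beam
from apache_beam.runners.interactive import interactive_beam as ib
from apache_beam.runners.interactive.interactive_runner import InteractiveRunner

p = beam.Pipeline(InteractiveRunner())
pcoll = p | 'Create' >> beam.Create([1, 2, 3])
ib.watch(locals())  # register the pipeline and its PCollections with the environment

# Non-blocking compute returns an async result instead of waiting for the
# pipeline fragment to finish.
async_result = ib.compute(pcoll, blocking=False)

# Either wait for it (the tests use a 60-second timeout) or cancel it.
async_result.result(timeout=60)
# async_result.cancel()

In an IPython environment the widget test above additionally expects the controls (Button, FloatProgress, Output, HBox, VBox) to be built once and the display to be updated twice: once for the controls and once for an HTML status line that starts as '<p>Initializing...</p>' and is later cleared.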

sdks/python/apache_beam/runners/interactive/recording_manager.py

Lines changed: 19 additions & 20 deletions
@@ -163,16 +163,16 @@ def _on_done(self, future: Future):
     exc = future.exception()
     if exc:
       self.update_display(f'Error: {exc}', 1.0)
-      _LOGGER.error(f'Asynchronous computation failed: {exc}', exc_info=exc)
+      _LOGGER.error('Asynchronous computation failed: %s', exc, exc_info=exc)
     else:
       self.update_display('Computation Finished Successfully.', 1.0)
       res = future.result()
       if res and res.state == PipelineState.DONE:
         self._env.mark_pcollection_computed(self._pcolls)
       else:
         _LOGGER.warning(
-            'Async computation finished but state is not DONE:'
-            f" {res.state if res else 'Unknown'}")
+            'Async computation finished but state is not DONE: %s',
+            res.state if res else 'Unknown')
 
   def cancel(self):
     if self._future.done():
@@ -198,8 +198,8 @@ def cancel(self):
         # The future will be cancelled by the runner if successful
         return True
       except Exception as e:
-        self.update_display(f'Error sending cancel signal: {e}')
-        _LOGGER.warning(f'Error during pipeline cancel(): {e}', exc_info=e)
+        self.update_display('Error sending cancel signal: %s', e)
+        _LOGGER.warning('Error during pipeline cancel(): %s', e, exc_info=e)
         # Still try to cancel the future as a fallback
         return self._future.cancel()
     else:
@@ -495,8 +495,8 @@ def _run_async_computation(
 
       self._env.mark_pcollection_computing(pcolls_to_compute)
       _LOGGER.info(
-          'Starting asynchronous computation for'
-          f' {len(pcolls_to_compute)} PCollections.')
+          'Starting asynchronous computation for %d PCollections.',
+          len(pcolls_to_compute))
 
       pipeline_result = self._execute_pipeline_fragment(
           pcolls_to_compute, async_result, runner, options)
@@ -515,7 +515,7 @@ def _run_async_computation(
       # )
       return pipeline_result
     except Exception as e:
-      _LOGGER.exception(f'Exception during asynchronous computation: {e}')
+      _LOGGER.exception('Exception during asynchronous computation: %s', e)
       raise
     # finally:
     #   self._env.unmark_pcollection_computing(pcolls_to_compute)
@@ -685,9 +685,9 @@ def compute_async(
           self._env.mark_pcollection_computed(pcolls_to_compute)
         else:
           _LOGGER.error(
-              f'Blocking computation failed. State: {pipeline_result.state}')
+              'Blocking computation failed. State: %s', pipeline_result.state)
           raise RuntimeError(
-              f'Blocking computation failed. State: {pipeline_result.state}')
+              'Blocking computation failed. State: %s', pipeline_result.state)
       finally:
         self._env.unmark_pcollection_computing(pcolls_to_compute)
       return None
@@ -715,15 +715,13 @@ def _get_pcoll_id_map(self):
     pcoll_to_id = {}
     if self._pipeline_graph._pipeline_instrument:
       pcoll_to_id = self._pipeline_graph._pipeline_instrument._pcoll_to_pcoll_id
-    else:
-      # Fallback for proto-based PipelineGraph, though less likely in this context
-      proto = self._pipeline_graph._pipeline_proto
     return {v: k for k, v in pcoll_to_id.items()}
 
   def _get_all_dependencies(
       self,
       pcolls: Set[beam.pvalue.PCollection]) -> Set[beam.pvalue.PCollection]:
-    """Gets all upstream PCollection dependencies for the given set of PCollections."""
+    """Gets all upstream PCollection dependencies
+    for the given set of PCollections."""
     if not self._pipeline_graph:
       return set()
 
@@ -732,7 +730,6 @@ def _get_all_dependencies(
       return set()
 
     pcoll_to_id = analyzer._pcoll_to_pcoll_id
-    id_to_pcoll_str = {v: k for k, v in pcoll_to_id.items()}
 
     target_pcoll_ids = {
         pcoll_to_id.get(str(pcoll))
@@ -784,7 +781,8 @@ def _wait_for_dependencies(
       pcolls: Set[beam.pvalue.PCollection],
      async_result: Optional[AsyncComputationResult] = None,
   ) -> bool:
-    """Waits for any dependencies of the given PCollections that are currently being computed."""
+    """Waits for any dependencies of the given
+    PCollections that are currently being computed."""
     dependencies = self._get_all_dependencies(pcolls)
     computing_deps: Dict[beam.pvalue.PCollection, AsyncComputationResult] = {}
 
@@ -800,10 +798,11 @@ def _wait_for_dependencies(
 
     if async_result:
       async_result.update_display(
-          f'Waiting for {len(computing_deps)} dependencies to finish...')
+          'Waiting for %d dependencies to finish...', len(computing_deps))
     _LOGGER.info(
-        f'Waiting for {len(computing_deps)} dependencies:'
-        f' {computing_deps.keys()}')
+        'Waiting for %d dependencies: %s',
+        len(computing_deps),
+        computing_deps.keys())
 
     futures_to_wait = list(
         set(comp._future for comp in computing_deps.values()))
@@ -823,7 +822,7 @@ def _wait_for_dependencies(
     except Exception as e:
       if async_result:
         async_result.update_display(f'Dependency failed: {e}')
-      _LOGGER.error(f'Dependency computation failed: {e}', exc_info=e)
+      _LOGGER.error('Dependency computation failed: %s', e, exc_info=e)
      return False
 
   def record(
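
Most of the recording_manager.py changes above replace f-string log messages with lazy %-style arguments, which is the usual fix for pylint's logging-fstring-interpolation warning. A minimal sketch of the pattern, with an illustrative count variable:

import logging

_LOGGER = logging.getLogger(__name__)
count = 3  # illustrative value

# Flagged by the linter: the f-string is formatted even when INFO is disabled.
_LOGGER.info(f'Starting asynchronous computation for {count} PCollections.')

# Preferred: pass the arguments and let logging interpolate them only if the
# record is actually emitted.
_LOGGER.info('Starting asynchronous computation for %d PCollections.', count)

Note that this lazy-argument convention is specific to the logging API; non-logging calls such as update_display() or RuntimeError() receive the extra arguments verbatim rather than interpolating them into the message.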
