2424import time
2525import unittest
2626from typing import NamedTuple
27- from unittest .mock import patch
27+ from unittest .mock import patch , MagicMock , ANY , call
2828from concurrent .futures import TimeoutError
2929
3030import apache_beam as beam
3131from apache_beam import dataframe as frames
32+ from apache_beam .dataframe .frame_base import DeferredBase
3233from apache_beam .options .pipeline_options import FlinkRunnerOptions
3334from apache_beam .options .pipeline_options import PipelineOptions
3435from apache_beam .runners .interactive import interactive_beam as ib
@@ -775,6 +776,8 @@ def raise_error(elem):
775776
776777 @patch ('apache_beam.runners.interactive.recording_manager.IS_IPYTHON' , True )
777778 @patch ('apache_beam.runners.interactive.recording_manager.display' )
779+ @patch ('apache_beam.runners.interactive.recording_manager.clear_output' )
780+ @patch ('apache_beam.runners.interactive.recording_manager.HTML' )
778781 @patch ('ipywidgets.Button' )
779782 @patch ('ipywidgets.FloatProgress' )
780783 @patch ('ipywidgets.Output' )
@@ -787,6 +790,8 @@ def test_compute_non_blocking_ipython_widgets(
787790 mock_output ,
788791 mock_progress ,
789792 mock_button ,
793+ mock_html ,
794+ mock_clear_output ,
790795 mock_display ,
791796 ):
792797 self .env ._is_in_ipython = True
@@ -795,14 +800,25 @@ def test_compute_non_blocking_ipython_widgets(
795800 ib .watch (locals ())
796801 self .env .track_user_pipelines ()
797802
803+ mock_controls = mock_vbox .return_value
804+ mock_html_instance = mock_html .return_value
805+
798806 async_result = ib .compute (pcoll , blocking = False )
799807 self .assertIsNotNone (async_result )
800808 mock_button .assert_called_once_with (description = 'Cancel' )
801809 mock_progress .assert_called_once ()
802810 mock_output .assert_called_once ()
803811 mock_hbox .assert_called_once ()
804812 mock_vbox .assert_called_once ()
805- mock_display .assert_called_once ()
813+ mock_html .assert_called_once_with ('<p>Initializing...</p>' )
814+
815+ self .assertEqual (mock_display .call_count , 2 )
816+ mock_display .assert_has_calls ([
817+ call (mock_controls , display_id = async_result ._display_id ),
818+ call (mock_html_instance )
819+ ])
820+
821+ mock_clear_output .assert_called_once ()
806822 async_result .result (timeout = 60 ) # Let it finish
807823
808824 def test_compute_dependency_wait_true (self ):
@@ -885,14 +901,21 @@ def test_async_computation_result_cancel(self):
885901 with self .assertRaises (TimeoutError ):
886902 async_result .result (timeout = 1 ) # It should not complete successfully
887903
888- def test_compute_multiple_async (self ):
904+ @patch (
905+ 'apache_beam.runners.interactive.recording_manager.RecordingManager.'
906+ '_execute_pipeline_fragment' )
907+ def test_compute_multiple_async (self , mock_execute_fragment ):
889908 p = beam .Pipeline (ir .InteractiveRunner ())
890909 pcoll1 = p | 'Create1' >> beam .Create ([1 , 2 , 3 ])
891910 pcoll2 = p | 'Create2' >> beam .Create ([4 , 5 , 6 ])
892911 pcoll3 = pcoll1 | 'Map1' >> beam .Map (lambda x : x * 2 )
893912 ib .watch (locals ())
894913 self .env .track_user_pipelines ()
895914
915+ mock_pipeline_result = MagicMock ()
916+ mock_pipeline_result .state = PipelineState .DONE
917+ mock_execute_fragment .return_value = mock_pipeline_result
918+
896919 res1 = ib .compute (pcoll1 , blocking = False )
897920 res2 = ib .compute (pcoll2 , blocking = False )
898921 res3 = ib .compute (pcoll3 , blocking = False ) # Depends on pcoll1
@@ -905,9 +928,129 @@ def test_compute_multiple_async(self):
905928 res2 .result (timeout = 60 )
906929 res3 .result (timeout = 60 )
907930
908- self .assertTrue (pcoll1 in self .env .computed_pcollections )
909- self .assertTrue (pcoll2 in self .env .computed_pcollections )
910- self .assertTrue (pcoll3 in self .env .computed_pcollections )
931+ time .sleep (0.1 )
932+
933+ self .assertTrue (
934+ pcoll1 in self .env .computed_pcollections , "pcoll1 not marked computed" )
935+ self .assertTrue (
936+ pcoll2 in self .env .computed_pcollections , "pcoll2 not marked computed" )
937+ self .assertTrue (
938+ pcoll3 in self .env .computed_pcollections , "pcoll3 not marked computed" )
939+
940+ self .assertEqual (mock_execute_fragment .call_count , 3 )
941+
@patch(
    'apache_beam.runners.interactive.'
    'interactive_beam.deferred_df_to_pcollection')
def test_compute_input_flattening(self, mock_deferred_to_pcoll):
  # Checks that compute() accepts a mix of a bare PCollection, a list, a dict
  # and a deferred dataframe, and hands the recording manager one flat set.
  p = beam.Pipeline(ir.InteractiveRunner())
  pcoll1 = p | 'C1' >> beam.Create([1])
  pcoll2 = p | 'C2' >> beam.Create([2])
  pcoll3 = p | 'C3' >> beam.Create([3])
  pcoll4 = p | 'C4' >> beam.Create([4])

  class MockDeferred(DeferredBase):
    # Minimal DeferredBase stand-in wrapping an existing PCollection.
    def __init__(self, pcoll):
      # DeferredBase.__init__ needs an expression; a MagicMock is passed.
      super().__init__(MagicMock())
      self._pcoll = pcoll

    def _get_underlying_pcollection(self):
      return self._pcoll

  deferred_pcoll = MockDeferred(pcoll4)

  # The patched converter maps the deferred object back onto pcoll4.
  mock_deferred_to_pcoll.return_value = (pcoll4, p)

  # Local names above are kept stable: watch(locals()) registers them by name.
  ib.watch(locals())
  self.env.track_user_pipelines()

  recording_manager = MagicMock()
  with patch.object(self.env,
                    'get_recording_manager',
                    return_value=recording_manager):
    ib.compute(pcoll1, [pcoll2], {'a': pcoll3}, deferred_pcoll)

  # All four inputs must arrive as a single set, with the default options.
  recording_manager.compute_async.assert_called_once_with(
      {pcoll1, pcoll2, pcoll3, pcoll4},
      wait_for_inputs=True,
      blocking=False,
      runner=None,
      options=None,
      force_compute=False)
981+
def test_compute_invalid_input_type(self):
  # A plain int is not a dict, an iterable or a PCollection, so compute()
  # is expected to reject it with a ValueError carrying this message.
  expected_message = "not a dict, an iterable or a PCollection"
  with self.assertRaisesRegex(ValueError, expected_message):
    ib.compute(123)
986+
def test_compute_mixed_pipelines(self):
  # Two PCollections that live in different pipelines.
  p1 = beam.Pipeline(ir.InteractiveRunner())
  pcoll1 = p1 | 'C1' >> beam.Create([1])
  p2 = beam.Pipeline(ir.InteractiveRunner())
  pcoll2 = p2 | 'C2' >> beam.Create([2])

  # Local names above are kept stable: watch(locals()) registers them by name.
  ib.watch(locals())
  self.env.track_user_pipelines()

  # Passing both to a single compute() call must be refused.
  expected_message = "All PCollections must belong to the same pipeline"
  with self.assertRaisesRegex(ValueError, expected_message):
    ib.compute(pcoll1, pcoll2)
998+
@patch(
    'apache_beam.runners.interactive.'
    'interactive_beam.deferred_df_to_pcollection')
@patch.object(ib, 'watch')
def test_compute_with_deferred_base(self, mock_watch, mock_deferred_to_pcoll):
  # Checks that compute() converts a deferred dataframe object through
  # deferred_df_to_pcollection, watches the resulting PCollection and its
  # pipeline under anonymous names, and forwards the PCollection on.
  class MockDeferred(DeferredBase):
    # Minimal DeferredBase stand-in wrapping an existing PCollection.
    def __init__(self, pcoll):
      # Provide a dummy expression to satisfy DeferredBase.__init__
      super().__init__(MagicMock())
      self._pcoll = pcoll

    def _get_underlying_pcollection(self):
      return self._pcoll

  p = beam.Pipeline(ir.InteractiveRunner())
  pcoll = p | 'C1' >> beam.Create([1])
  deferred = MockDeferred(pcoll)

  # The patched converter hands back the wrapped PCollection and pipeline.
  mock_deferred_to_pcoll.return_value = (pcoll, p)

  recording_manager = MagicMock()
  with patch.object(self.env,
                    'get_recording_manager',
                    return_value=recording_manager):
    ib.compute(deferred)

  mock_deferred_to_pcoll.assert_called_once_with(deferred)

  # Exactly two watch() calls, in this order: PCollection, then pipeline.
  expected_watch_calls = [
      call({f'anonymous_pcollection_{id(pcoll)}': pcoll}),
      call({f'anonymous_pipeline_{id(p)}': p}),
  ]
  self.assertEqual(mock_watch.call_count, 2)
  mock_watch.assert_has_calls(expected_watch_calls, any_order=False)

  recording_manager.compute_async.assert_called_once_with(
      {pcoll},
      wait_for_inputs=True,
      blocking=False,
      runner=None,
      options=None,
      force_compute=False)
1039+
def test_compute_new_pipeline(self):
  # Checks compute() on a pipeline that was never registered beforehand: it
  # is expected to watch the pipeline itself and to request a recording
  # manager with create_if_absent=True.
  p = beam.Pipeline(ir.InteractiveRunner())
  pcoll = p | 'Create' >> beam.Create([1])
  # NOT calling ib.watch() or track_user_pipelines()

  recording_manager = MagicMock()
  with patch.object(self.env,
                    'get_recording_manager',
                    return_value=recording_manager) as mock_get_rm:
    with patch.object(ib, 'watch') as mock_watch:
      ib.compute(pcoll)

  # assert_called_with only inspects the most recent watch() call.
  mock_watch.assert_called_with({f'anonymous_pipeline_{id(p)}': p})
  mock_get_rm.assert_called_once_with(p, create_if_absent=True)
  recording_manager.compute_async.assert_called_once()
9111054
9121055
9131056if __name__ == '__main__' :
0 commit comments