99 changes: 90 additions & 9 deletions sdks/python/apache_beam/runners/interactive/interactive_beam.py
@@ -35,11 +35,9 @@
 # pytype: skip-file
 
 import logging
+from collections.abc import Iterable
 from datetime import timedelta
 from typing import Any
-from typing import Dict
-from typing import Iterable
-from typing import List
 from typing import Optional
 from typing import Union

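The hunk above swaps the deprecated typing aliases (Dict, List, Iterable) for the builtin generics standardized by PEP 585 and for collections.abc.Iterable, both available since Python 3.9. A minimal sketch of the equivalence; the function below is illustrative, not part of the change:

from collections.abc import Iterable

# PEP 585: the builtins dict and list accept subscripts directly, so the
# typing aliases Dict/List/Iterable are no longer needed.
def tally(xs: Iterable[int]) -> dict[str, list[int]]:
  return {'values': list(xs)}
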
@@ -57,6 +55,7 @@
 from apache_beam.runners.interactive.display.pcoll_visualization import visualize
 from apache_beam.runners.interactive.display.pcoll_visualization import visualize_computed_pcoll
 from apache_beam.runners.interactive.options import interactive_options
+from apache_beam.runners.interactive.recording_manager import AsyncComputationResult
 from apache_beam.runners.interactive.utils import deferred_df_to_pcollection
 from apache_beam.runners.interactive.utils import elements_to_df
 from apache_beam.runners.interactive.utils import find_pcoll_name
@@ -275,7 +274,7 @@ class Recordings():
   """
   def describe(
       self,
-      pipeline: Optional[beam.Pipeline] = None) -> Dict[str, Any]: # noqa: F821
+      pipeline: Optional[beam.Pipeline] = None) -> dict[str, Any]: # noqa: F821
     """Returns a description of all the recordings for the given pipeline.
 
     If no pipeline is given then this returns a dictionary of descriptions for
@@ -417,10 +416,10 @@ class Clusters:
   # DATAPROC_IMAGE_VERSION = '2.0.XX-debian10'
 
   def __init__(self) -> None:
-    self.dataproc_cluster_managers: Dict[ClusterMetadata,
+    self.dataproc_cluster_managers: dict[ClusterMetadata,
                                          DataprocClusterManager] = {}
-    self.master_urls: Dict[str, ClusterMetadata] = {}
-    self.pipelines: Dict[beam.Pipeline, DataprocClusterManager] = {}
+    self.master_urls: dict[str, ClusterMetadata] = {}
+    self.pipelines: dict[beam.Pipeline, DataprocClusterManager] = {}
     self.default_cluster_metadata: Optional[ClusterMetadata] = None
 
   def create(
@@ -511,7 +510,7 @@ def cleanup(
   def describe(
       self,
       cluster_identifier: Optional[ClusterIdentifier] = None
-  ) -> Union[ClusterMetadata, List[ClusterMetadata]]:
+  ) -> Union[ClusterMetadata, list[ClusterMetadata]]:
     """Describes the ClusterMetadata by a ClusterIdentifier.
 
     If no cluster_identifier is given or if the cluster_identifier is unknown,
@@ -679,7 +678,7 @@ def run_pipeline(self):
 
 @progress_indicated
 def show(
-    *pcolls: Union[Dict[Any, PCollection], Iterable[PCollection], PCollection],
+    *pcolls: Union[dict[Any, PCollection], Iterable[PCollection], PCollection],
     include_window_info: bool = False,
     visualize_data: bool = False,
     n: Union[int, str] = 'inf',
@@ -1012,6 +1011,88 @@ def as_pcollection(pcoll_or_df):
   return result_tuple
 
 
+@progress_indicated
+def compute(
+    *pcolls: Union[dict[Any, PCollection], Iterable[PCollection], PCollection],
+    wait_for_inputs: bool = True,
+    blocking: bool = False,
+    runner=None,
+    options=None,
+    force_compute=False,
+) -> Optional[AsyncComputationResult]:
+  """Computes the given PCollections, potentially asynchronously.
+
+  Args:
+    *pcolls: PCollections to compute. Can be a single PCollection, an iterable
+      of PCollections, or a dictionary with PCollections as values.
+    wait_for_inputs: Whether to wait until asynchronous dependencies have been
+      computed. Setting this to False schedules the computation immediately,
+      but may result in running the same pipeline stages multiple times.
+    blocking: If False, the computation runs in a non-blocking fashion; in a
+      Colab/IPython environment this mode also provides controls for the
+      running pipeline. If True, the call blocks until the pipeline is done.
+    runner: (optional) the runner with which to compute the results.
+    options: (optional) any additional pipeline options to use to compute the
+      results.
+    force_compute: (optional) if True, forces recomputation rather than using
+      cached PCollections.
+
+  Returns:
+    An AsyncComputationResult object if blocking is False, otherwise None.
+  """
+  flatten_pcolls = []
+  for pcoll_container in pcolls:
+    if isinstance(pcoll_container, dict):
+      flatten_pcolls.extend(pcoll_container.values())
+    elif isinstance(pcoll_container, (beam.pvalue.PCollection, DeferredBase)):
+      flatten_pcolls.append(pcoll_container)
+    else:
+      try:
+        flatten_pcolls.extend(iter(pcoll_container))
+      except TypeError:
+        raise ValueError(
+            f'The given pcoll {pcoll_container} is not a dict, an iterable, '
+            'or a PCollection.')
+
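+  # Convert deferred DataFrames to PCollections; deduplicate via a set.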
+  pcolls_set = set()
+  for pcoll in flatten_pcolls:
+    if isinstance(pcoll, DeferredBase):
+      pcoll, _ = deferred_df_to_pcollection(pcoll)
+      watch({f'anonymous_pcollection_{id(pcoll)}': pcoll})
+    assert isinstance(
+        pcoll, beam.pvalue.PCollection
+    ), f'{pcoll} is not an apache_beam.pvalue.PCollection.'
+    pcolls_set.add(pcoll)
+
+  if not pcolls_set:
+    _LOGGER.info('No PCollections to compute.')
+    return None
+
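+  # Resolve the user pipeline; watch it under an anonymous name if untracked.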
+  pcoll_pipeline = next(iter(pcolls_set)).pipeline
+  user_pipeline = ie.current_env().user_pipeline(pcoll_pipeline)
+  if not user_pipeline:
+    watch({f'anonymous_pipeline_{id(pcoll_pipeline)}': pcoll_pipeline})
+    user_pipeline = pcoll_pipeline
+
+  for pcoll in pcolls_set:
+    if pcoll.pipeline is not user_pipeline:
+      raise ValueError('All PCollections must belong to the same pipeline.')
+
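+  # The recording manager owns the (possibly asynchronous) computation.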
+  recording_manager = ie.current_env().get_recording_manager(
+      user_pipeline, create_if_absent=True)
+
+  return recording_manager.compute_async(
+      pcolls_set,
+      wait_for_inputs=wait_for_inputs,
+      blocking=blocking,
+      runner=runner,
+      options=options,
+      force_compute=force_compute,
+  )
+
+
 @progress_indicated
 def show_graph(pipeline):
   """Shows the current pipeline shape of a given Beam pipeline as a DAG.