feat: support series input in managed function #1920

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into base: main
7 changes: 7 additions & 0 deletions bigframes/exceptions.py
@@ -103,6 +103,13 @@ class FunctionAxisOnePreviewWarning(PreviewWarning):
"""Remote Function and Managed UDF with axis=1 preview."""


class FunctionPackageVersionWarning(PreviewWarning):
"""
Managed UDF package versions for Numpy, Pandas, and Pyarrow may not
precisely match the user's local environment or the exact versions
specified.
"""


def format_message(message: str, fill: bool = True):
"""Formats a warning message with ANSI color codes for the warning color.
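
Usage note (not part of this diff): a caller who wants to silence this preview warning can filter on the new category. This is a minimal sketch and assumes FunctionPackageVersionWarning is the category that ultimately gets passed to warnings.warn when the managed function is created.

import warnings

import bigframes.exceptions as bfe

# Suppress only the package-version preview warning; other BigQuery
# DataFrames warnings keep surfacing as usual.
warnings.filterwarnings("ignore", category=bfe.FunctionPackageVersionWarning)
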
36 changes: 9 additions & 27 deletions bigframes/functions/_function_client.py
@@ -19,7 +19,6 @@
import logging
import os
import random
import re
import shutil
import string
import tempfile
@@ -247,7 +246,7 @@ def provision_bq_managed_function(
# Augment user package requirements with any internal package
# requirements.
packages = _utils._get_updated_package_requirements(
packages, is_row_processor, capture_references
packages, is_row_processor, capture_references, ignore_package_version=True
)
if packages:
managed_function_options["packages"] = packages
@@ -270,28 +269,6 @@ def provision_bq_managed_function(
)

udf_name = func.__name__
if capture_references:
# This code path ensures that if the udf body contains any
# references to variables and/or imports outside the body, they are
# captured as well.
import cloudpickle

pickled = cloudpickle.dumps(func)
udf_code = textwrap.dedent(
f"""
import cloudpickle
{udf_name} = cloudpickle.loads({pickled})
"""
)
else:
# This code path handles a self-contained udf body, i.e. one with no
# references to variables or imports outside the body.
udf_code = textwrap.dedent(inspect.getsource(func))
match = re.search(r"^def ", udf_code, flags=re.MULTILINE)
if match is None:
raise ValueError("The UDF is not defined correctly.")
udf_code = udf_code[match.start() :]

with_connection_clause = (
(
@@ -301,6 +278,13 @@ def provision_bq_managed_function(
else ""
)

# Generate the complete Python code block for the managed Python UDF,
# including the user's function, necessary imports, and the BigQuery
# handler wrapper.
python_code_block = bff_template.generate_managed_function_code(
func, udf_name, is_row_processor, capture_references
)

create_function_ddl = (
textwrap.dedent(
f"""
@@ -311,13 +295,11 @@ def provision_bq_managed_function(
OPTIONS ({managed_function_options_str})
AS r'''
__UDF_PLACE_HOLDER__
def bigframes_handler(*args):
return {udf_name}(*args)
'''
"""
)
.strip()
.replace("__UDF_PLACE_HOLDER__", udf_code)
.replace("__UDF_PLACE_HOLDER__", python_code_block)
)

self._ensure_dataset_exists()
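
To make the placeholder substitution concrete, here is roughly what __UDF_PLACE_HOLDER__ expands to for a scalar (non-row-processor) UDF when capture_references is True, based on the template added later in this PR. The UDF name my_udf is a placeholder and the pickled bytes literal is abbreviated.

import cloudpickle

# Rehydrate the user's function from its cloudpickle payload (abbreviated here).
my_udf = cloudpickle.loads(b'...')

def bigframes_handler(*args):
    return my_udf(*args)
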
4 changes: 2 additions & 2 deletions bigframes/functions/_function_session.py
@@ -847,15 +847,15 @@ def wrapper(func):
if output_type:
py_sig = py_sig.replace(return_annotation=output_type)

udf_sig = udf_def.UdfSignature.from_py_signature(py_sig)

# The function will actually be receiving a pandas Series, but allow
# both BigQuery DataFrames and pandas object types for compatibility.
is_row_processor = False
if new_sig := _convert_row_processor_sig(py_sig):
py_sig = new_sig
is_row_processor = True

udf_sig = udf_def.UdfSignature.from_py_signature(py_sig)

managed_function_client = _function_client.FunctionClient(
dataset_ref.project,
bq_location,
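
With the row-processor signature now feeding into UdfSignature, the end-to-end usage this PR targets looks roughly like the sketch below. The bpd.udf decorator parameters (dataset, name) and the exact call surface are assumptions for illustration, not taken from this diff; the key point is that each DataFrame row reaches the managed function as a pandas Series.

import pandas as pd

import bigframes.pandas as bpd

@bpd.udf(dataset="my_dataset", name="row_sum")  # managed Python UDF; parameters assumed
def row_sum(row: pd.Series) -> float:
    # The row arrives as a pandas Series via the row-processor path above.
    return float(row["a"]) + float(row["b"])

df = bpd.DataFrame({"a": [1, 2], "b": [10, 20]})
result = df.apply(row_sum, axis=1)  # Series input to the managed function
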
34 changes: 26 additions & 8 deletions bigframes/functions/_utils.py
@@ -18,6 +18,7 @@
import sys
import typing
from typing import cast, Optional, Set
import warnings

import cloudpickle
import google.api_core.exceptions
@@ -26,6 +27,7 @@
import pandas
import pyarrow

import bigframes.exceptions as bfe
import bigframes.formatting_helpers as bf_formatting
from bigframes.functions import function_typing

@@ -61,20 +63,36 @@ def get_remote_function_locations(bq_location):


def _get_updated_package_requirements(
package_requirements=None, is_row_processor=False, capture_references=True
package_requirements=None,
is_row_processor=False,
capture_references=True,
ignore_package_version=False,
):
requirements = []
if capture_references:
requirements.append(f"cloudpickle=={cloudpickle.__version__}")

if is_row_processor:
# bigframes function will send an entire row of data as json, which
# would be converted to a pandas series and processed. Ensure numpy
# versions match to avoid unpickling problems. See internal issue
# b/347934471.
requirements.append(f"numpy=={numpy.__version__}")
requirements.append(f"pandas=={pandas.__version__}")
requirements.append(f"pyarrow=={pyarrow.__version__}")
if ignore_package_version:
# TODO(jialuo): Add back the version after b/410924784 is resolved.
# Due to current limitations on the packages version in Python UDFs,
# we use `ignore_package_version` to optionally omit the version for
# managed functions only.
msg = bfe.format_message(
"Numpy, Pandas, and Pyarrow version may not precisely match your local environment."
)
warnings.warn(msg, category=bfe.PreviewWarning)
requirements.append("pandas")
requirements.append("pyarrow")
requirements.append("numpy")
else:
# bigframes function will send an entire row of data as json, which
# would be converted to a pandas series and processed. Ensure numpy
# versions match to avoid unpickling problems. See internal issue
# b/347934471.
requirements.append(f"pandas=={pandas.__version__}")
requirements.append(f"pyarrow=={pyarrow.__version__}")
requirements.append(f"numpy=={numpy.__version__}")

if package_requirements:
requirements.extend(package_requirements)
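
For clarity, roughly what the two branches above produce for a row-processor function. cloudpickle is pinned in both cases while capture_references stays True; the exact version strings depend on the local environment, and any ordering or de-duplication applied after the visible lines is elided here.

from bigframes.functions._utils import _get_updated_package_requirements

# Managed-function path: core packages left unpinned (plus a preview warning).
_get_updated_package_requirements(
    ["scikit-learn"], is_row_processor=True, ignore_package_version=True
)
# e.g. ["cloudpickle==<local version>", "pandas", "pyarrow", "numpy", "scikit-learn"]

# Default path: core packages pinned to the local versions.
_get_updated_package_requirements(["scikit-learn"], is_row_processor=True)
# e.g. ["cloudpickle==<local version>", "pandas==<local version>",
#       "pyarrow==<local version>", "numpy==<local version>", "scikit-learn"]
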
53 changes: 53 additions & 0 deletions bigframes/functions/function_template.py
@@ -17,6 +17,7 @@
import inspect
import logging
import os
import re
import textwrap
from typing import Tuple

@@ -291,3 +292,55 @@ def generate_cloud_function_main_code(
logger.debug(f"Wrote {os.path.abspath(main_py)}:\n{open(main_py).read()}")

return handler_func_name


def generate_managed_function_code(
def_,
udf_name: str,
is_row_processor: bool,
capture_references: bool,
) -> str:
"""Generates the Python code block for managed Python UDF."""

if capture_references:
# This code path ensures that if the udf body contains any
# references to variables and/or imports outside the body, they are
# captured as well.
import cloudpickle

pickled = cloudpickle.dumps(def_)
func_code = textwrap.dedent(
f"""
import cloudpickle
{udf_name} = cloudpickle.loads({pickled})
"""
)
else:
# This code path handles a self-contained udf body, i.e. one with no
# references to variables or imports outside the body.
func_code = textwrap.dedent(inspect.getsource(def_))
match = re.search(r"^def ", func_code, flags=re.MULTILINE)
if match is None:
raise ValueError("The UDF is not defined correctly.")
func_code = func_code[match.start() :]

if is_row_processor:
udf_code = textwrap.dedent(inspect.getsource(get_pd_series))
udf_code = udf_code[udf_code.index("def") :]
bigframes_handler_code = textwrap.dedent(
f"""def bigframes_handler(str_arg):
return {udf_name}({get_pd_series.__name__}(str_arg))"""
)
else:
udf_code = ""
bigframes_handler_code = textwrap.dedent(
f"""def bigframes_handler(*args):
return {udf_name}(*args)"""
)

udf_code_block = textwrap.dedent(
f"{udf_code}\n{func_code}\n{bigframes_handler_code}"
)

return udf_code_block
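
As an illustration of the self-contained path (capture_references=False, is_row_processor=False), the block returned for a hypothetical add_one UDF is echoed in the trailing comment below; the import path mirrors how _function_client aliases this module as bff_template.

from bigframes.functions import function_template as bff_template


def add_one(x: int) -> int:
    """Hypothetical self-contained UDF."""
    return x + 1


code = bff_template.generate_managed_function_code(
    add_one, "add_one", is_row_processor=False, capture_references=False
)
# The returned code string is roughly:
#
# def add_one(x: int) -> int:
#     """Hypothetical self-contained UDF."""
#     return x + 1
#
# def bigframes_handler(*args):
#     return add_one(*args)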