Feat: Support pandas in BigQuery cache #597

Open · wants to merge 1 commit into base: main
22 changes: 19 additions & 3 deletions airbyte/caches/base.py
@@ -28,6 +28,8 @@
if TYPE_CHECKING:
from collections.abc import Iterator

from sqlalchemy.engine import Engine

from airbyte._message_iterators import AirbyteMessageIterator
from airbyte.caches._state_backend_base import StateBackendBase
from airbyte.progress import ProgressTracker
@@ -66,7 +68,9 @@ class CacheBase(SqlConfig, AirbyteWriterInterface):
paired_destination_config_class: ClassVar[type | None] = None

@property
def paired_destination_config(self) -> Any | dict[str, Any]: # noqa: ANN401 # Allow Any return type
def paired_destination_config(
self,
) -> Any | dict[str, Any]: # noqa: ANN401 # Allow Any return type
"""Return a dictionary of destination configuration values."""
raise NotImplementedError(
f"The type '{type(self).__name__}' does not define an equivalent destination "
@@ -177,6 +181,14 @@ def get_record_processor(

# Read methods:

def _read_to_pandas_dataframe(
self,
table_name: str,
con: Engine,
**kwargs,
) -> pd.DataFrame:
return pd.read_sql_table(table_name, con=con, **kwargs)

def get_records(
self,
stream_name: str,
@@ -191,7 +203,11 @@ def get_pandas_dataframe(
"""Return a Pandas data frame with the stream's data."""
table_name = self._read_processor.get_sql_table_name(stream_name)
engine = self.get_sql_engine()
return pd.read_sql_table(table_name, engine, schema=self.schema_name)
return self._read_to_pandas_dataframe(
table_name=table_name,
con=engine,
schema=self.schema_name,
)

def get_arrow_dataset(
self,
@@ -204,7 +220,7 @@ def get_arrow_dataset(
engine = self.get_sql_engine()

# Read the table in chunks to handle large tables which do not fit in memory
pandas_chunks = pd.read_sql_table(
pandas_chunks = self._read_to_pandas_dataframe(
table_name=table_name,
con=engine,
schema=self.schema_name,
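The base-class change turns the pandas read into an overridable hook: get_pandas_dataframe and get_arrow_dataset both route through _read_to_pandas_dataframe, which defaults to pd.read_sql_table and which a cache subclass can replace when its warehouse needs a different read path (the BigQuery override below also accepts a chunk size for the Arrow path). A minimal sketch of how a subclass might use the hook; ExampleCache and its query are hypothetical, not part of this PR:

import pandas as pd
from sqlalchemy.engine import Engine

from airbyte.caches.base import CacheBase


class ExampleCache(CacheBase):  # hypothetical subclass, for illustration only
    def _read_to_pandas_dataframe(
        self,
        table_name: str,
        con: Engine,
        **kwargs,
    ) -> pd.DataFrame:
        # The base class forwards schema=self.schema_name through **kwargs (and,
        # on the Arrow path, a chunksize); this sketch only handles the schema and
        # reads with a plain query instead of pd.read_sql_table.
        schema = kwargs.pop("schema", None)
        qualified_name = f"{schema}.{table_name}" if schema else table_name
        return pd.read_sql_query(f"SELECT * FROM {qualified_name}", con=con)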
44 changes: 31 additions & 13 deletions airbyte/caches/bigquery.py
@@ -19,19 +19,23 @@

from typing import TYPE_CHECKING, ClassVar, NoReturn

import pandas as pd
import pandas_gbq
from airbyte_api.models import DestinationBigquery
from google.oauth2.service_account import Credentials

from airbyte._processors.sql.bigquery import BigQueryConfig, BigQuerySqlProcessor
from airbyte.caches.base import (
CacheBase,
)
from airbyte.constants import DEFAULT_ARROW_MAX_CHUNK_SIZE
from airbyte.destinations._translate_cache_to_dest import (
bigquery_cache_to_destination_configuration,
)


if TYPE_CHECKING:
from collections.abc import Iterator

from airbyte.shared.sql_processor import SqlProcessorBase


@@ -48,21 +52,35 @@ def paired_destination_config(self) -> DestinationBigquery:
"""Return a dictionary of destination configuration values."""
return bigquery_cache_to_destination_configuration(cache=self)

def get_arrow_dataset(
def _read_to_pandas_dataframe(
self,
stream_name: str,
*,
max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE,
) -> NoReturn:
"""Raises NotImplementedError; BigQuery doesn't support `pd.read_sql_table`.

See: https://github.com/airbytehq/PyAirbyte/issues/165
"""
raise NotImplementedError(
"BigQuery doesn't currently support to_arrow"
"Please consider using a different cache implementation for these functionalities."
table_name: str,
chunksize: int | None = None,
**kwargs,
) -> pd.DataFrame | Iterator[pd.DataFrame]:
# Pop unused kwargs, maybe not the best way to do this
kwargs.pop("con", None)
kwargs.pop("schema", None)

# Read the table using pandas_gbq
credentials = Credentials.from_service_account_file(self.credentials_path)
result = pandas_gbq.read_gbq(
f"{self.project_name}.{self.dataset_name}.{table_name}",
project_id=self.project_name,
credentials=credentials,
**kwargs,
Comment on lines +65 to +71
🛠️ Refactor suggestion

Add error handling for credentials loading.

The credentials loading could fail for various reasons (file not found, invalid credentials, etc.). Should we add some error handling here? Maybe something like:

-        credentials = Credentials.from_service_account_file(self.credentials_path)
+        try:
+            credentials = Credentials.from_service_account_file(self.credentials_path)
+        except FileNotFoundError as e:
+            raise ValueError(f"Credentials file not found at {self.credentials_path}") from e
+        except Exception as e:
+            raise ValueError(f"Failed to load credentials: {str(e)}") from e

Wdyt?

📝 Committable suggestion


Suggested change
-        # Read the table using pandas_gbq
-        credentials = Credentials.from_service_account_file(self.credentials_path)
-        result = pandas_gbq.read_gbq(
-            f"{self.project_name}.{self.dataset_name}.{table_name}",
-            project_id=self.project_name,
-            credentials=credentials,
-            **kwargs,
+        # Read the table using pandas_gbq
+        try:
+            credentials = Credentials.from_service_account_file(self.credentials_path)
+        except FileNotFoundError as e:
+            raise ValueError(f"Credentials file not found at {self.credentials_path}") from e
+        except Exception as e:
+            raise ValueError(f"Failed to load credentials: {str(e)}") from e
+        result = pandas_gbq.read_gbq(
+            f"{self.project_name}.{self.dataset_name}.{table_name}",
+            project_id=self.project_name,
+            credentials=credentials,
+            **kwargs,

)

# Cast result to DataFrame if it's not already a DataFrame
if not isinstance(result, pd.DataFrame):
result = pd.DataFrame(result)

# Return chunks as iterator if chunksize is provided
if chunksize is not None:
return (result[i : i + chunksize] for i in range(0, len(result), chunksize))

Comment on lines +78 to +81
🛠️ Refactor suggestion

Revisit chunking performance.

For very large tables, returning chunked slices of the DataFrame might still be memory-intensive as the entire DataFrame is loaded first. Would you consider a chunked read directly from pandas_gbq instead, if available? Wdyt?

Author:

pandas_gbq doesn't support it :(
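
If this client-side slicing ever becomes a memory problem in practice, one possible follow-up (not part of this PR, and only a sketch) would be to page through the table with google-cloud-bigquery, which pandas-gbq already depends on; its row iterator can yield one DataFrame per fetched page instead of materializing the whole table first:

from collections.abc import Iterator

import pandas as pd
from google.cloud import bigquery
from google.oauth2.service_account import Credentials


def read_gbq_table_in_chunks(
    credentials_path: str,
    project_name: str,
    dataset_name: str,
    table_name: str,
    page_size: int = 100_000,
) -> Iterator[pd.DataFrame]:
    """Yield one DataFrame per fetched page rather than loading the full table."""
    credentials = Credentials.from_service_account_file(credentials_path)
    client = bigquery.Client(project=project_name, credentials=credentials)
    rows = client.list_rows(
        f"{project_name}.{dataset_name}.{table_name}",
        page_size=page_size,
    )
    yield from rows.to_dataframe_iterable()

The parameter names mirror the cache fields used in this method; whether this is worth the change and how it behaves with the BigQuery Storage API enabled would need benchmarking.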

return result


# Expose the Cache class and also the Config class.
__all__ = [
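With the override in place, the BigQuery cache is read through the same public API as the other caches. A rough end-to-end sketch of what this PR enables; the connector choice and configuration values are placeholders, not taken from this PR:

import airbyte as ab
from airbyte.caches.bigquery import BigQueryCache

cache = BigQueryCache(
    project_name="my-gcp-project",            # placeholder
    dataset_name="pyairbyte_cache",           # placeholder
    credentials_path="/path/to/sa-key.json",  # placeholder service-account key
)

source = ab.get_source(
    "source-faker",
    config={"count": 100},
    install_if_missing=True,
)
source.select_all_streams()
read_result = source.read(cache=cache)

# Both calls below previously failed or were skipped for BigQuery (issue #165):
users_df = read_result["users"].to_pandas()
arrow_dataset = read_result["users"].to_arrow(max_chunk_size=10)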
150 changes: 149 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -29,6 +29,7 @@ jsonschema = ">=3.2.0,<5.0"
orjson = "^3.10"
overrides = "^7.4.0"
pandas = { version = ">=1.5.3,<3.0" }
pandas-gbq = ">=0.26.1"
psycopg = {extras = ["binary", "pool"], version = "^3.1.19"}
psycopg2-binary = "^2.9.9"
pyarrow = ">=16.1,<18.0"
@@ -359,3 +360,4 @@ DEP002 = [
"psycopg2-binary",
"sqlalchemy-bigquery",
]

17 changes: 2 additions & 15 deletions tests/integration_tests/cloud/test_cloud_sql_reads.py
@@ -74,14 +74,7 @@ def test_read_from_deployed_connection(

dataset: ab.CachedDataset = sync_result.get_dataset(stream_name="users")
assert dataset.stream_name == "users"
data_as_list = list(dataset)
assert len(data_as_list) == 100

# TODO: Fails on BigQuery: https://github.com/airbytehq/PyAirbyte/issues/165
# pandas_df = dataset.to_pandas()

pandas_df = pd.DataFrame(data_as_list)

pandas_df = dataset.to_pandas()
assert pandas_df.shape[0] == 100
assert pandas_df.shape[1] in { # Column count diff depending on when it was created
20,
@@ -187,14 +180,8 @@ def test_read_from_previous_job(
assert "users" in sync_result.stream_names
dataset: ab.CachedDataset = sync_result.get_dataset(stream_name="users")
assert dataset.stream_name == "users"
data_as_list = list(dataset)
assert len(data_as_list) == 100

# TODO: Fails on BigQuery: https://github.com/airbytehq/PyAirbyte/issues/165
# pandas_df = dataset.to_pandas()

pandas_df = pd.DataFrame(data_as_list)

pandas_df = dataset.to_pandas()
assert pandas_df.shape[0] == 100
assert pandas_df.shape[1] in { # Column count diff depending on when it was created
20,
14 changes: 5 additions & 9 deletions tests/integration_tests/test_all_cache_types.py
@@ -158,15 +158,11 @@ def test_faker_read(
assert "Read **0** records" not in status_msg
assert f"Read **{configured_count}** records" in status_msg

if "bigquery" not in new_generic_cache.get_sql_alchemy_url():
# BigQuery doesn't support to_arrow
# https://github.com/airbytehq/PyAirbyte/issues/165
arrow_dataset = read_result["users"].to_arrow(max_chunk_size=10)
assert arrow_dataset.count_rows() == FAKER_SCALE_A
assert sum(1 for _ in arrow_dataset.to_batches()) == FAKER_SCALE_A / 10

# TODO: Uncomment this line after resolving https://github.com/airbytehq/PyAirbyte/issues/165
# assert len(result["users"].to_pandas()) == FAKER_SCALE_A
arrow_dataset = read_result["users"].to_arrow(max_chunk_size=10)
assert arrow_dataset.count_rows() == FAKER_SCALE_A
assert sum(1 for _ in arrow_dataset.to_batches()) == FAKER_SCALE_A / 10

assert len(read_result["users"].to_pandas()) == FAKER_SCALE_A


@pytest.mark.requires_creds
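With the BigQuery skip removed, to_arrow() runs against every cache type, and for BigQuery the chunks now come from the client-side slicing in _read_to_pandas_dataframe. As a rough illustration of the general mechanism (not PyAirbyte's exact implementation), assembling an iterator of pandas chunks into a single Arrow table can look like this:

from collections.abc import Iterator

import pandas as pd
import pyarrow as pa


def pandas_chunks_to_arrow(chunks: Iterator[pd.DataFrame]) -> pa.Table:
    """Build one Arrow table from DataFrame chunks without a pandas concat."""
    batches = [
        pa.RecordBatch.from_pandas(chunk, preserve_index=False) for chunk in chunks
    ]
    return pa.Table.from_batches(batches)


# Mirroring the 10-row chunks asserted in the test above:
frames = (pd.DataFrame({"id": range(i, i + 10)}) for i in range(0, 100, 10))
table = pandas_chunks_to_arrow(frames)
assert table.num_rows == 100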