Support pandas in BigQuery cache
jscheel committed Jan 29, 2025
1 parent cc2c533 commit dc810cb
Showing 6 changed files with 208 additions and 41 deletions.
22 changes: 19 additions & 3 deletions airbyte/caches/base.py
@@ -28,6 +28,8 @@
if TYPE_CHECKING:
from collections.abc import Iterator

from sqlalchemy.engine import Engine

from airbyte._message_iterators import AirbyteMessageIterator
from airbyte.caches._state_backend_base import StateBackendBase
from airbyte.progress import ProgressTracker
@@ -66,7 +68,9 @@ class CacheBase(SqlConfig, AirbyteWriterInterface):
paired_destination_config_class: ClassVar[type | None] = None

@property
def paired_destination_config(self) -> Any | dict[str, Any]: # noqa: ANN401 # Allow Any return type
def paired_destination_config(
self,
) -> Any | dict[str, Any]: # noqa: ANN401 # Allow Any return type
"""Return a dictionary of destination configuration values."""
raise NotImplementedError(
f"The type '{type(self).__name__}' does not define an equivalent destination "
@@ -177,6 +181,14 @@ def get_record_processor(

# Read methods:

def _read_to_pandas_dataframe(
self,
table_name: str,
con: Engine,
**kwargs,
) -> pd.DataFrame:
return pd.read_sql_table(table_name, con=con, **kwargs)

def get_records(
self,
stream_name: str,
@@ -191,7 +203,11 @@ def get_pandas_dataframe(
"""Return a Pandas data frame with the stream's data."""
table_name = self._read_processor.get_sql_table_name(stream_name)
engine = self.get_sql_engine()
return pd.read_sql_table(table_name, engine, schema=self.schema_name)
return self._read_to_pandas_dataframe(
table_name=table_name,
con=engine,
schema=self.schema_name,
)

def get_arrow_dataset(
self,
@@ -204,7 +220,7 @@ def get_arrow_dataset(
engine = self.get_sql_engine()

# Read the table in chunks to handle large tables that don't fit in memory
pandas_chunks = pd.read_sql_table(
pandas_chunks = self._read_to_pandas_dataframe(
table_name=table_name,
con=engine,
schema=self.schema_name,
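Taken together, these base-class changes route both `get_pandas_dataframe()` and `get_arrow_dataset()` through a single overridable reader, `_read_to_pandas_dataframe()`. A minimal sketch of that hook pattern, using hypothetical standalone classes rather than the real `CacheBase`:

from collections.abc import Iterator

import pandas as pd
from sqlalchemy.engine import Engine


class SqlCacheSketch:
    """Default behavior: delegate to pandas' SQLAlchemy-based table reader."""

    def _read_to_pandas_dataframe(self, table_name: str, con: Engine, **kwargs) -> pd.DataFrame:
        return pd.read_sql_table(table_name, con=con, **kwargs)

    def get_pandas_dataframe(self, table_name: str, con: Engine) -> pd.DataFrame:
        # The public read path always goes through the hook, so a subclass
        # override changes every pandas (and Arrow) read at once.
        return self._read_to_pandas_dataframe(table_name, con=con, schema="main")


class CustomReaderCacheSketch(SqlCacheSketch):
    """A backend that cannot use read_sql_table supplies its own reader."""

    def _read_to_pandas_dataframe(
        self, table_name: str, con: Engine, **kwargs
    ) -> pd.DataFrame | Iterator[pd.DataFrame]:
        kwargs.pop("schema", None)  # drop base-signature kwargs the custom reader ignores
        return pd.DataFrame({"table_name": [table_name]})  # stand-in for a real client call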
44 changes: 31 additions & 13 deletions airbyte/caches/bigquery.py
@@ -19,19 +19,23 @@

from typing import TYPE_CHECKING, ClassVar, NoReturn

import pandas as pd
import pandas_gbq
from airbyte_api.models import DestinationBigquery
from google.oauth2.service_account import Credentials

from airbyte._processors.sql.bigquery import BigQueryConfig, BigQuerySqlProcessor
from airbyte.caches.base import (
CacheBase,
)
from airbyte.constants import DEFAULT_ARROW_MAX_CHUNK_SIZE
from airbyte.destinations._translate_cache_to_dest import (
bigquery_cache_to_destination_configuration,
)


if TYPE_CHECKING:
from collections.abc import Iterator

from airbyte.shared.sql_processor import SqlProcessorBase


@@ -48,21 +52,35 @@ def paired_destination_config(self) -> DestinationBigquery:
"""Return a dictionary of destination configuration values."""
return bigquery_cache_to_destination_configuration(cache=self)

def get_arrow_dataset(
def _read_to_pandas_dataframe(
self,
stream_name: str,
*,
max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE,
) -> NoReturn:
"""Raises NotImplementedError; BigQuery doesn't support `pd.read_sql_table`.
See: https://github.com/airbytehq/PyAirbyte/issues/165
"""
raise NotImplementedError(
"BigQuery doesn't currently support to_arrow"
"Please consider using a different cache implementation for these functionalities."
table_name: str,
chunksize: int | None = None,
**kwargs,
) -> pd.DataFrame | Iterator[pd.DataFrame]:
# Drop base-signature kwargs (`con`, `schema`) that pandas_gbq does not accept
kwargs.pop("con", None)
kwargs.pop("schema", None)

# Read the table using pandas_gbq
credentials = Credentials.from_service_account_file(self.credentials_path)
result = pandas_gbq.read_gbq(
f"{self.project_name}.{self.dataset_name}.{table_name}",
project_id=self.project_name,
credentials=credentials,
**kwargs,
)

# Cast result to DataFrame if it's not already a DataFrame
if not isinstance(result, pd.DataFrame):
result = pd.DataFrame(result)

# Return chunks as iterator if chunksize is provided
if chunksize is not None:
return (result[i : i + chunksize] for i in range(0, len(result), chunksize))

return result


# Expose the Cache class and also the Config class.
__all__ = [
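With this override in place, a BigQuery-backed cache serves pandas and Arrow reads through the same public API as the other caches. A usage sketch, assuming valid GCP service-account credentials; the project, dataset, and source configuration values below are illustrative:

import airbyte as ab
from airbyte.caches.bigquery import BigQueryCache

# Illustrative values; substitute your own project, dataset, and key file.
cache = BigQueryCache(
    project_name="my-gcp-project",
    dataset_name="pyairbyte_cache",
    credentials_path="/path/to/service_account.json",
)

source = ab.get_source("source-faker", config={"count": 100})
source.select_all_streams()
read_result = source.read(cache=cache)

# Both calls now route through BigQueryCache._read_to_pandas_dataframe
# (pandas_gbq under the hood) instead of pd.read_sql_table.
users_df = read_result["users"].to_pandas()
arrow_dataset = read_result["users"].to_arrow(max_chunk_size=10)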
150 changes: 149 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -29,6 +29,7 @@ jsonschema = ">=3.2.0,<5.0"
orjson = "^3.10"
overrides = "^7.4.0"
pandas = { version = ">=1.5.3,<3.0" }
pandas-gbq = ">=0.26.1"
psycopg = {extras = ["binary", "pool"], version = "^3.1.19"}
psycopg2-binary = "^2.9.9"
pyarrow = ">=16.1,<18.0"
@@ -359,3 +360,4 @@ DEP002 = [
"psycopg2-binary",
"sqlalchemy-bigquery",
]

17 changes: 2 additions & 15 deletions tests/integration_tests/cloud/test_cloud_sql_reads.py
@@ -74,14 +74,7 @@ def test_read_from_deployed_connection(

dataset: ab.CachedDataset = sync_result.get_dataset(stream_name="users")
assert dataset.stream_name == "users"
data_as_list = list(dataset)
assert len(data_as_list) == 100

# TODO: Fails on BigQuery: https://github.com/airbytehq/PyAirbyte/issues/165
# pandas_df = dataset.to_pandas()

pandas_df = pd.DataFrame(data_as_list)

pandas_df = dataset.to_pandas()
assert pandas_df.shape[0] == 100
assert pandas_df.shape[1] in { # Column count diff depending on when it was created
20,
@@ -187,14 +180,8 @@ def test_read_from_previous_job(
assert "users" in sync_result.stream_names
dataset: ab.CachedDataset = sync_result.get_dataset(stream_name="users")
assert dataset.stream_name == "users"
data_as_list = list(dataset)
assert len(data_as_list) == 100

# TODO: Fails on BigQuery: https://github.com/airbytehq/PyAirbyte/issues/165
# pandas_df = dataset.to_pandas()

pandas_df = pd.DataFrame(data_as_list)

pandas_df = dataset.to_pandas()
assert pandas_df.shape[0] == 100
assert pandas_df.shape[1] in { # Column count diff depending on when it was created
20,
14 changes: 5 additions & 9 deletions tests/integration_tests/test_all_cache_types.py
@@ -158,15 +158,11 @@ def test_faker_read(
assert "Read **0** records" not in status_msg
assert f"Read **{configured_count}** records" in status_msg

if "bigquery" not in new_generic_cache.get_sql_alchemy_url():
# BigQuery doesn't support to_arrow
# https://github.com/airbytehq/PyAirbyte/issues/165
arrow_dataset = read_result["users"].to_arrow(max_chunk_size=10)
assert arrow_dataset.count_rows() == FAKER_SCALE_A
assert sum(1 for _ in arrow_dataset.to_batches()) == FAKER_SCALE_A / 10

# TODO: Uncomment this line after resolving https://github.com/airbytehq/PyAirbyte/issues/165
# assert len(result["users"].to_pandas()) == FAKER_SCALE_A
arrow_dataset = read_result["users"].to_arrow(max_chunk_size=10)
assert arrow_dataset.count_rows() == FAKER_SCALE_A
assert sum(1 for _ in arrow_dataset.to_batches()) == FAKER_SCALE_A / 10

assert len(read_result["users"].to_pandas()) == FAKER_SCALE_A


@pytest.mark.requires_creds
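For the BigQuery cache, the `to_arrow(max_chunk_size=10)` assertions above now hold because the new reader materializes the query result and yields fixed-size DataFrame slices. A standalone sketch of that slicing generator (hypothetical helper name):

import pandas as pd


def chunk_dataframe(df: pd.DataFrame, chunksize: int):
    """Yield consecutive slices of `df`, each at most `chunksize` rows long."""
    return (df[i : i + chunksize] for i in range(0, len(df), chunksize))


df = pd.DataFrame({"id": range(100)})
chunks = list(chunk_dataframe(df, 10))
assert len(chunks) == 10                    # 100 rows / 10 rows per chunk
assert sum(len(c) for c in chunks) == 100   # no rows dropped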
