From 73f9289500f9f2b0134d31da919c45019b47021b Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Sun, 21 Jul 2024 00:07:12 -0700 Subject: [PATCH 01/27] add back iceberg file --- airbyte/_processors/file/iceberg.py | 96 +++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 airbyte/_processors/file/iceberg.py diff --git a/airbyte/_processors/file/iceberg.py b/airbyte/_processors/file/iceberg.py new file mode 100644 index 00000000..3213eb12 --- /dev/null +++ b/airbyte/_processors/file/iceberg.py @@ -0,0 +1,96 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +"""A Parquet cache implementation.""" + +from __future__ import annotations + +import gzip +import json +from pathlib import Path +from typing import IO, TYPE_CHECKING, cast + +import orjson +import ulid +from overrides import overrides + +from airbyte._future_cdk.state_writers import StateWriterBase +from airbyte._processors.file.base import ( + FileWriterBase, +) + + +if TYPE_CHECKING: + from airbyte.records import StreamRecord + + +def _get_state_file_path(cache_dir: Path, stream_name: str) -> Path: + """Return the state file path for the given stream.""" + return cache_dir / f"{stream_name}_state.parquet" + + +def _get_records_file_path(cache_dir: Path, stream_name: str, batch_id: str) -> Path: + """Return the records file path for the given stream and batch.""" + return cache_dir / f"{stream_name}_{batch_id}.parquet" + + +class IcebergStateWriter(StateWriterBase): + """An Iceberg state writer implementation.""" + + def __init__(self, cache_dir: Path) -> None: + """Initialize the Iceberg state writer.""" + self._cache_dir = cache_dir + + @overrides + def write_state(self, state_message: dict) -> None: + """Write the state for the given stream.""" + stream_name = state_message["stream"] + state_file_path = Path( + _get_state_file_path( + cache_dir=self._cache_dir, + stream_name=stream_name, + ) + ) + state_file_path.write_text(json.dumps(state_message)) + + +class IcebergWriter(FileWriterBase): + """An Iceberg file writer implementation.""" + + default_cache_file_suffix = ".parquet" + prune_extra_fields = True + + def get_state_writer(self) -> IcebergStateWriter: + return IcebergStateWriter(self._cache_dir) + + @overrides + def _open_new_file( + self, + file_path: Path, + ) -> IO[str]: + """Open a new file for writing.""" + return cast(IO[str], gzip.open(file_path, "w")) + + @overrides + def _get_new_cache_file_path( + self, + stream_name: str, + batch_id: str | None = None, # ULID of the batch + ) -> Path: + """Return a new cache file path for the given stream.""" + batch_id = batch_id or str(ulid.ULID()) + target_dir = Path(self._cache_dir) + target_dir.mkdir(parents=True, exist_ok=True) + return target_dir / f"{stream_name}_{batch_id}{self.default_cache_file_suffix}" + + @overrides + def _write_record_dict( + self, + record_dict: StreamRecord, + open_file_writer: IO[str], + ) -> None: + # If the record is too nested, `orjson` will fail with error `TypeError: Recursion + # limit reached`. If so, fall back to the slower `json.dumps`. + try: + open_file_writer.write(orjson.dumps(record_dict).decode("utf-8") + "\n") + except TypeError: + # Using isoformat method for datetime serialization + open_file_writer.write(json.dumps(record_dict, default=lambda _: _.isoformat()) + "\n") From d26f29ecfa35c2be76a8c958a344e512fccdea36 Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Sun, 21 Jul 2024 00:50:02 -0700 Subject: [PATCH 02/27] updates! 
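The file writer serializes each record with `orjson` and falls back to the standard `json` module when deeply nested data hits orjson's recursion limit. A minimal standalone sketch of that fallback path (the `dump_record` helper and the sample record are illustrative only, not part of the change):

```python
import json

import orjson


def dump_record(record: dict) -> str:
    """Serialize one record, mirroring the writer's orjson-to-json fallback."""
    try:
        return orjson.dumps(record).decode("utf-8")
    except TypeError:
        # orjson raises TypeError ("Recursion limit reached") on very deep nesting;
        # json.dumps handles it, using isoformat() for datetime-like values.
        return json.dumps(record, default=lambda value: value.isoformat())


print(dump_record({"id": 1, "details": {"tags": ["a", "b"]}}))
```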
--- airbyte/_processors/file/__init__.py | 3 ++ .../file/{iceberg.py => parquet.py} | 37 ++--------------- airbyte/_processors/sql/duckdb.py | 11 ++--- airbyte/_processors/sql/iceberg.py | 0 airbyte/caches/iceberg.py | 41 +++++++++++++++++++ 5 files changed, 51 insertions(+), 41 deletions(-) rename airbyte/_processors/file/{iceberg.py => parquet.py} (61%) create mode 100644 airbyte/_processors/sql/iceberg.py create mode 100644 airbyte/caches/iceberg.py diff --git a/airbyte/_processors/file/__init__.py b/airbyte/_processors/file/__init__.py index 2ef9b9a4..ecd7cb54 100644 --- a/airbyte/_processors/file/__init__.py +++ b/airbyte/_processors/file/__init__.py @@ -5,6 +5,7 @@ from airbyte._batch_handles import BatchHandle from airbyte._processors.file.base import FileWriterBase +from airbyte._processors.file.parquet import LocalIcebergStateWriter, LocalIcebergWriter from airbyte._processors.file.jsonl import JsonlWriter @@ -12,4 +13,6 @@ "BatchHandle", "FileWriterBase", "JsonlWriter", + "LocalIcebergStateWriter", + "LocalIcebergWriter", ] diff --git a/airbyte/_processors/file/iceberg.py b/airbyte/_processors/file/parquet.py similarity index 61% rename from airbyte/_processors/file/iceberg.py rename to airbyte/_processors/file/parquet.py index 3213eb12..3ff1a4ec 100644 --- a/airbyte/_processors/file/iceberg.py +++ b/airbyte/_processors/file/parquet.py @@ -22,44 +22,15 @@ from airbyte.records import StreamRecord -def _get_state_file_path(cache_dir: Path, stream_name: str) -> Path: - """Return the state file path for the given stream.""" - return cache_dir / f"{stream_name}_state.parquet" - - -def _get_records_file_path(cache_dir: Path, stream_name: str, batch_id: str) -> Path: - """Return the records file path for the given stream and batch.""" - return cache_dir / f"{stream_name}_{batch_id}.parquet" - - -class IcebergStateWriter(StateWriterBase): - """An Iceberg state writer implementation.""" - - def __init__(self, cache_dir: Path) -> None: - """Initialize the Iceberg state writer.""" - self._cache_dir = cache_dir - - @overrides - def write_state(self, state_message: dict) -> None: - """Write the state for the given stream.""" - stream_name = state_message["stream"] - state_file_path = Path( - _get_state_file_path( - cache_dir=self._cache_dir, - stream_name=stream_name, - ) - ) - state_file_path.write_text(json.dumps(state_message)) - - -class IcebergWriter(FileWriterBase): +class LocalIcebergWriter(FileWriterBase): """An Iceberg file writer implementation.""" default_cache_file_suffix = ".parquet" prune_extra_fields = True - def get_state_writer(self) -> IcebergStateWriter: - return IcebergStateWriter(self._cache_dir) + def _get_records_file_path(self, cache_dir: Path, stream_name: str, batch_id: str) -> Path: + """Return the records file path for the given stream and batch.""" + return cache_dir / f"{stream_name}_{batch_id}.parquet" @overrides def _open_new_file( diff --git a/airbyte/_processors/sql/duckdb.py b/airbyte/_processors/sql/duckdb.py index 38b3a2e9..053a411c 100644 --- a/airbyte/_processors/sql/duckdb.py +++ b/airbyte/_processors/sql/duckdb.py @@ -15,7 +15,7 @@ from airbyte._future_cdk import SqlProcessorBase from airbyte._future_cdk.sql_processor import SqlConfig -from airbyte._processors.file import JsonlWriter +from airbyte._processors.file import IcebergWriter from airbyte.secrets.base import SecretString @@ -88,15 +88,10 @@ def get_sql_engine(self) -> Engine: class DuckDBSqlProcessor(SqlProcessorBase): - """A DuckDB implementation of the cache. 
- - Jsonl is used for local file storage before bulk loading. - Unlike the Snowflake implementation, we can't use the COPY command to load data - so we insert as values instead. - """ + """A DuckDB implementation of the cache.""" supports_merge_insert = False - file_writer_class = JsonlWriter + file_writer_class = IcebergWriter sql_config: DuckDBConfig @overrides diff --git a/airbyte/_processors/sql/iceberg.py b/airbyte/_processors/sql/iceberg.py new file mode 100644 index 00000000..e69de29b diff --git a/airbyte/caches/iceberg.py b/airbyte/caches/iceberg.py new file mode 100644 index 00000000..1464db2f --- /dev/null +++ b/airbyte/caches/iceberg.py @@ -0,0 +1,41 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +"""A Iceberg implementation of the PyAirbyte cache. + +## Usage Example + +```python +from airbyte as ab +from airbyte.caches import IcebergCache + +cache = IcebergCache( + db_path="/path/to/my/Iceberg-file", + schema_name="myschema", +) +``` +""" + +from __future__ import annotations + +import warnings + +from Iceberg_engine import IcebergEngineWarning +from pydantic import PrivateAttr + +from airbyte._processors.sql.iceberg import IcebergConfig, IcebergSqlProcessor +from airbyte.caches.base import CacheBase + + +# Suppress warnings from Iceberg about reflection on indices. +# https://github.com/Mause/Iceberg_engine/issues/905 +warnings.filterwarnings( + "ignore", + message="Iceberg-engine doesn't yet support reflection on indices", + category=IcebergEngineWarning, +) + + +# @dataclass +class IcebergCache(IcebergConfig, CacheBase): + """A Iceberg cache.""" + + _sql_processor_class: type[IcebergSqlProcessor] = PrivateAttr(default=IcebergSqlProcessor) From ca37196cf0a58dff8eab938c7d4018e74bb0f8d0 Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Thu, 1 Aug 2024 20:18:28 -0700 Subject: [PATCH 03/27] add polars schema utils --- airbyte/_util/polars.py | 112 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) create mode 100644 airbyte/_util/polars.py diff --git a/airbyte/_util/polars.py b/airbyte/_util/polars.py new file mode 100644 index 00000000..e3bd925c --- /dev/null +++ b/airbyte/_util/polars.py @@ -0,0 +1,112 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +"""Polars utility functions.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +import polars as pl +import pyarrow as pa + + +@dataclass +class PolarsStreamSchema: + """A Polars stream schema.""" + + expressions: list[pl.Expr] = field(default_factory=list) + arrow_schema: pa.Schema = field(default_factory=lambda: pa.schema([])) + + @classmethod + def from_json_schema(cls, json_schema: dict[str, Any]) -> PolarsStreamSchema: + """Create a Polars stream schema from a JSON schema.""" + expressions: list[pl.Expr] = [] + arrow_columns: list[pa.Field] = [] + _json_schema_to_polars( + json_schema=json_schema, + expressions=expressions, + arrow_columns=arrow_columns, + ) + return cls( + expressions=expressions, + arrow_schema=pa.schema(arrow_columns), + ) + + +def _json_schema_to_polars( + *, + json_schema: dict[str, Any], + expressions: list[pl.Expr], + arrow_columns: list[pa.Field], + _breadcrumbs: list[str] | None = None, +) -> None: + """Get Polars transformations and PyArrow column definitions from the provided JSON schema. + + Recursive operations are tracked with a breadcrumb list. 
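+
+    For example, given `_breadcrumbs=["user", "address"]`, the lookup walks
+    `json_schema["properties"]["user"]["properties"]["address"]` to find the node whose
+    `type` is being mapped (the property names here are hypothetical).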
+ """ + _breadcrumbs = _breadcrumbs or [] + json_schema_node = json_schema + for crumb in _breadcrumbs: + json_schema_node = json_schema_node["properties"][crumb] + + tidy_type: str | list[str] = json_schema_node["type"] + + if isinstance(tidy_type, list): + if "null" in tidy_type: + tidy_type.remove("null") + if len(tidy_type) == 1: + tidy_type = tidy_type[0] + else: + msg = f"Invalid type: {tidy_type}" + raise ValueError(msg) + + for key, value in json_schema_node.get("properties", {}).items(): + # Handle nested objects recursively + if tidy_type == "object": + # Use breadcrumbs to navigate into nested properties + _json_schema_to_polars( + json_schema=json_schema, + expressions=expressions, + arrow_columns=arrow_columns, + _breadcrumbs=[*_breadcrumbs, key], + ) + continue + + if tidy_type == "integer": + expressions.append(pl.col(key).cast(pl.Int64)) + arrow_columns.append(pa.field(key, pa.int64())) + continue + + if tidy_type == "number": + expressions.append(pl.col(key).cast(pl.Float64)) + arrow_columns.append(pa.field(key, pa.float64())) + continue + + if tidy_type == "boolean": + expressions.append(pl.col(key).cast(pl.Boolean)) + arrow_columns.append(pa.field(key, pa.bool_())) + continue + + if tidy_type == "string": + str_format = value.get("format") + if str_format == "date-time": + expressions.append(pl.col(key).cast(pl.Datetime)) + arrow_columns.append(pa.field(key, pa.timestamp("ms"))) + continue + + if str_format == "date": + expressions.append(pl.col(key).cast(pl.Date)) + arrow_columns.append(pa.field(key, pa.date32())) + continue + + if str_format == "time": + expressions.append(pl.col(key).cast(pl.Time)) + arrow_columns.append(pa.field(key, pa.time32("ms"))) + continue + + expressions.append(pl.col(key).cast(pl.Utf8)) + arrow_columns.append(pa.field(key, pa.string())) + continue + + msg = f"Invalid type: {tidy_type}" + raise ValueError(msg) From 07ac6a5882b3103629bad549e834a218418e2bed Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Thu, 1 Aug 2024 20:20:14 -0700 Subject: [PATCH 04/27] `poetry add polars` --- poetry.lock | 49 +++++++++++++++++++++++++++++++++++++++++++++---- pyproject.toml | 1 + 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/poetry.lock b/poetry.lock index f82f5d6c..fafe5132 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1334,8 +1334,8 @@ files = [ [package.dependencies] orjson = ">=3.9.14,<4.0.0" pydantic = [ - {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}, {version = ">=1,<3", markers = "python_full_version < \"3.12.4\""}, + {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}, ] requests = ">=2,<3" @@ -1747,8 +1747,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""}, {version = ">=1.23.2,<2", markers = "python_version == \"3.11\""}, + {version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""}, {version = ">=1.22.4,<2", markers = "python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" @@ -1910,6 +1910,47 @@ tomli = ">=1.2.2" [package.extras] poetry-plugin = ["poetry (>=1.0,<2.0)"] +[[package]] +name = "polars" +version = "1.3.0" +description = "Blazingly fast DataFrame library" +optional = false +python-versions = ">=3.8" +files = [ + {file = "polars-1.3.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:85a338b8f617fdf5e5472567d32efeb46e6624a604c45622cc96669324f82961"}, + {file = "polars-1.3.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:5859a11d1c8ec14089127043d8d6bae01f015021113ed01a2e4953e6c21feee5"}, 
+ {file = "polars-1.3.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a1ff7779315e6b0d17641af3eb4dd7aec2ab0bc1bee009efb12242bf6403aeb"}, + {file = "polars-1.3.0-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:e675269c17a83484c74165989d93572785d4298019f4f8ca65e25a49d4440236"}, + {file = "polars-1.3.0-cp38-abi3-win_amd64.whl", hash = "sha256:75cbbccc4a55a8ae8c5ea8b9daa8747aee5d182c2bba7c712496f32a8096562d"}, + {file = "polars-1.3.0.tar.gz", hash = "sha256:c7812d6c72ffdc9e70aaa8f9aa6378db80b393e7ecbe7005ad84b150c17c71cb"}, +] + +[package.extras] +adbc = ["adbc-driver-manager[dbapi]", "adbc-driver-sqlite[dbapi]"] +all = ["polars[async,cloudpickle,database,deltalake,excel,fsspec,graph,iceberg,numpy,pandas,plot,pyarrow,pydantic,style,timezone]"] +async = ["gevent"] +calamine = ["fastexcel (>=0.9)"] +cloudpickle = ["cloudpickle"] +connectorx = ["connectorx (>=0.3.2)"] +database = ["nest-asyncio", "polars[adbc,connectorx,sqlalchemy]"] +deltalake = ["deltalake (>=0.15.0)"] +excel = ["polars[calamine,openpyxl,xlsx2csv,xlsxwriter]"] +fsspec = ["fsspec"] +gpu = ["cudf-polars-cu12"] +graph = ["matplotlib"] +iceberg = ["pyiceberg (>=0.5.0)"] +numpy = ["numpy (>=1.16.0)"] +openpyxl = ["openpyxl (>=3.0.0)"] +pandas = ["pandas", "polars[pyarrow]"] +plot = ["hvplot (>=0.9.1)", "polars[pandas]"] +pyarrow = ["pyarrow (>=7.0.0)"] +pydantic = ["pydantic"] +sqlalchemy = ["polars[pandas]", "sqlalchemy"] +style = ["great-tables (>=0.8.0)"] +timezone = ["backports-zoneinfo", "tzdata"] +xlsx2csv = ["xlsx2csv (>=0.8.0)"] +xlsxwriter = ["xlsxwriter"] + [[package]] name = "proto-plus" version = "1.24.0" @@ -2127,8 +2168,8 @@ files = [ annotated-types = ">=0.4.0" pydantic-core = "2.20.1" typing-extensions = [ - {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, {version = ">=4.6.1", markers = "python_version < \"3.13\""}, + {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, ] [package.extras] @@ -3490,4 +3531,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "ff078465575d9db425a09b760fc3768d496ea291cb050206afcac62820d55a56" +content-hash = "94681c4910dbdb28d4cc01e42dddd29711ecc877557f9719a71658d8f515e158" diff --git a/pyproject.toml b/pyproject.toml index 29c9a556..6fb774cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ airbyte-api = "^0.49.2" google-cloud-bigquery-storage = "^2.25.0" pyiceberg = "^0.6.1" uuid7 = "^0.1.0" +polars = "^1.3.0" [tool.poetry.group.dev.dependencies] docker = "^7.0.0" From 80e0fdba2bf45937b64f7830a2ca51f492b882a2 Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Fri, 2 Aug 2024 12:03:26 -0700 Subject: [PATCH 05/27] add polars type mapping --- airbyte/_util/polars.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/airbyte/_util/polars.py b/airbyte/_util/polars.py index e3bd925c..7b750436 100644 --- a/airbyte/_util/polars.py +++ b/airbyte/_util/polars.py @@ -16,20 +16,25 @@ class PolarsStreamSchema: expressions: list[pl.Expr] = field(default_factory=list) arrow_schema: pa.Schema = field(default_factory=lambda: pa.schema([])) + polars_schema: pl.Schema = field(default_factory=lambda: pl.Schema([])) @classmethod def from_json_schema(cls, json_schema: dict[str, Any]) -> PolarsStreamSchema: """Create a Polars stream schema from a JSON schema.""" expressions: list[pl.Expr] = [] arrow_columns: list[pa.Field] = [] + polars_columns: list[pl.Col] = [] + _json_schema_to_polars( json_schema=json_schema, expressions=expressions, arrow_columns=arrow_columns, + 
polars_columns=polars_columns, ) return cls( expressions=expressions, arrow_schema=pa.schema(arrow_columns), + polars_schema=pl.Schema(polars_columns), ) @@ -38,6 +43,7 @@ def _json_schema_to_polars( json_schema: dict[str, Any], expressions: list[pl.Expr], arrow_columns: list[pa.Field], + polars_columns: list[pl.Col], _breadcrumbs: list[str] | None = None, ) -> None: """Get Polars transformations and PyArrow column definitions from the provided JSON schema. @@ -49,8 +55,10 @@ def _json_schema_to_polars( for crumb in _breadcrumbs: json_schema_node = json_schema_node["properties"][crumb] + # Determine the primary type from the schema node tidy_type: str | list[str] = json_schema_node["type"] + # Handle multiple types, focusing on non-nullable types if present if isinstance(tidy_type, list): if "null" in tidy_type: tidy_type.remove("null") @@ -68,23 +76,28 @@ def _json_schema_to_polars( json_schema=json_schema, expressions=expressions, arrow_columns=arrow_columns, + polars_columns=polars_columns, _breadcrumbs=[*_breadcrumbs, key], ) continue + # Map JSON schema types to Arrow and Polars types if tidy_type == "integer": expressions.append(pl.col(key).cast(pl.Int64)) arrow_columns.append(pa.field(key, pa.int64())) + polars_columns.append(pl.Int64) # Add corresponding Polars column type continue if tidy_type == "number": expressions.append(pl.col(key).cast(pl.Float64)) arrow_columns.append(pa.field(key, pa.float64())) + polars_columns.append(pl.Float64) # Add corresponding Polars column type continue if tidy_type == "boolean": expressions.append(pl.col(key).cast(pl.Boolean)) arrow_columns.append(pa.field(key, pa.bool_())) + polars_columns.append(pl.Boolean) # Add corresponding Polars column type continue if tidy_type == "string": @@ -92,20 +105,24 @@ def _json_schema_to_polars( if str_format == "date-time": expressions.append(pl.col(key).cast(pl.Datetime)) arrow_columns.append(pa.field(key, pa.timestamp("ms"))) + polars_columns.append(pl.Datetime) # Add corresponding Polars column type continue if str_format == "date": expressions.append(pl.col(key).cast(pl.Date)) arrow_columns.append(pa.field(key, pa.date32())) + polars_columns.append(pl.Date) # Add corresponding Polars column type continue if str_format == "time": expressions.append(pl.col(key).cast(pl.Time)) arrow_columns.append(pa.field(key, pa.time32("ms"))) + polars_columns.append(pl.Time) # Add corresponding Polars column type continue expressions.append(pl.col(key).cast(pl.Utf8)) arrow_columns.append(pa.field(key, pa.string())) + polars_columns.append(pl.Utf8) # Add corresponding Polars column type continue msg = f"Invalid type: {tidy_type}" From 8d9ef08f677dde4c2d9ef70d6311add1dec4a3c3 Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Fri, 2 Aug 2024 12:06:05 -0700 Subject: [PATCH 06/27] add as_filelike() for iterator --- airbyte/_message_iterators.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/airbyte/_message_iterators.py b/airbyte/_message_iterators.py index 82ee71e9..8103b436 100644 --- a/airbyte/_message_iterators.py +++ b/airbyte/_message_iterators.py @@ -56,6 +56,30 @@ def read(self) -> str: """Read the next message from the iterator.""" return next(self).model_dump_json() + def as_filelike(self) -> io.BytesIO: + """Return a file-like object that reads from the iterator.""" + + class FileLikeReader(io.RawIOBase): + def __init__(self, iterator: Iterator[AirbyteMessage]) -> None: + self.iterator = (msg.model_dump_json() for msg in iterator) + self.buffer = "" + + def readable(self) -> Literal[True]: 
+ return True + + def readinto(self, b: Any) -> int: + try: + chunk = next(self.iterator) + except StopIteration: + return 0 # EOF + + data = chunk.encode() + n = len(data) + b[:n] = data + return n + + return FileLikeReader(self._iterator) + @classmethod def from_read_result(cls, read_result: ReadResult) -> AirbyteMessageIterator: """Create a iterator from a `ReadResult` object.""" From 015faf0db7e2390fc04f338451a4dcd9bacdc089 Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Fri, 2 Aug 2024 12:06:28 -0700 Subject: [PATCH 07/27] lint fix --- airbyte/_processors/file/parquet.py | 1 - 1 file changed, 1 deletion(-) diff --git a/airbyte/_processors/file/parquet.py b/airbyte/_processors/file/parquet.py index 3ff1a4ec..8869bd4b 100644 --- a/airbyte/_processors/file/parquet.py +++ b/airbyte/_processors/file/parquet.py @@ -12,7 +12,6 @@ import ulid from overrides import overrides -from airbyte._future_cdk.state_writers import StateWriterBase from airbyte._processors.file.base import ( FileWriterBase, ) From 4fb9b3b011ae90a7a98a2589d554e880a3b0aff9 Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Fri, 2 Aug 2024 12:09:54 -0700 Subject: [PATCH 08/27] wip: iceberg updates, w example script --- airbyte/_processors/file/__init__.py | 3 +- airbyte/_processors/sql/duckdb.py | 4 +- airbyte/_processors/sql/iceberg.py | 12 ++++++ examples/run_polars_poc.py | 63 ++++++++++++++++++++++++++++ 4 files changed, 78 insertions(+), 4 deletions(-) create mode 100644 examples/run_polars_poc.py diff --git a/airbyte/_processors/file/__init__.py b/airbyte/_processors/file/__init__.py index ecd7cb54..43c3da97 100644 --- a/airbyte/_processors/file/__init__.py +++ b/airbyte/_processors/file/__init__.py @@ -5,14 +5,13 @@ from airbyte._batch_handles import BatchHandle from airbyte._processors.file.base import FileWriterBase -from airbyte._processors.file.parquet import LocalIcebergStateWriter, LocalIcebergWriter from airbyte._processors.file.jsonl import JsonlWriter +from airbyte._processors.file.parquet import LocalIcebergWriter __all__ = [ "BatchHandle", "FileWriterBase", "JsonlWriter", - "LocalIcebergStateWriter", "LocalIcebergWriter", ] diff --git a/airbyte/_processors/sql/duckdb.py b/airbyte/_processors/sql/duckdb.py index 053a411c..fc400256 100644 --- a/airbyte/_processors/sql/duckdb.py +++ b/airbyte/_processors/sql/duckdb.py @@ -15,7 +15,7 @@ from airbyte._future_cdk import SqlProcessorBase from airbyte._future_cdk.sql_processor import SqlConfig -from airbyte._processors.file import IcebergWriter +from airbyte._processors.file.jsonl import JsonlWriter from airbyte.secrets.base import SecretString @@ -91,7 +91,7 @@ class DuckDBSqlProcessor(SqlProcessorBase): """A DuckDB implementation of the cache.""" supports_merge_insert = False - file_writer_class = IcebergWriter + file_writer_class = JsonlWriter sql_config: DuckDBConfig @overrides diff --git a/airbyte/_processors/sql/iceberg.py b/airbyte/_processors/sql/iceberg.py index e69de29b..625cf617 100644 --- a/airbyte/_processors/sql/iceberg.py +++ b/airbyte/_processors/sql/iceberg.py @@ -0,0 +1,12 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
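+"""An Iceberg SQL processor (work in progress).
+
+A minimal usage sketch, assuming a local warehouse path and schema name (both values
+are hypothetical):
+
+    processor = IcebergSqlProcessor(
+        db_path="/tmp/iceberg-warehouse",
+        schema_name="airbyte_raw",
+    )
+"""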
+from __future__ import annotations + +from airbyte._future_cdk.sql_processor import SqlProcessorBase + + +class IcebergSqlProcessor(SqlProcessorBase): + """A Iceberg SQL processor.""" + + def __init__(self, db_path: str, schema_name: str) -> None: + """Initialize the Iceberg SQL processor.""" + super().__init__(db_path=db_path, schema_name=schema_name) diff --git a/examples/run_polars_poc.py b/examples/run_polars_poc.py new file mode 100644 index 00000000..0c745b51 --- /dev/null +++ b/examples/run_polars_poc.py @@ -0,0 +1,63 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +"""A test of Polars in PyAirbyte. + +Usage (from PyAirbyte root directory): +> poetry run python ./examples/run_polars_poc.py +""" + +from __future__ import annotations + +from collections.abc import Iterator +from io import BytesIO, StringIO +from typing import TextIO + +import airbyte as ab +import polars as pl +from airbyte import get_source +from airbyte._message_iterators import AirbyteMessageIterator +from airbyte._util.polars import PolarsStreamSchema +from airbyte.progress import ProgressTracker + + +def get_my_source() -> ab.Source: + return get_source( + "source-faker", + config={}, + streams=["users"], + ) + + +def main() -> None: + """Run the Polars proof of concept.""" + source = get_my_source() + + polars_stream_schema: PolarsStreamSchema = PolarsStreamSchema.from_json_schema( + json_schema=source.configured_catalog.streams[0].stream.json_schema, + ) + progress_tracker = ProgressTracker( + source=source, + cache=None, + destination=None, + ) + msg_iterator = AirbyteMessageIterator( + msg + for msg in source._get_airbyte_message_iterator( + streams=["users"], + progress_tracker=progress_tracker, + ) + ) + # jsonl_iterator = (msg.model_dump_json() for msg in msg_iterator) + # df = pl.read_ndjson( + # StringIO("\n".join(jsonl_iterator)), + # schema=polars_stream_schema.polars_schema, + # ) + filelike = msg_iterator.as_filelike() + print(filelike.readlines()) + df = pl.read_ndjson( + filelike, + schema=polars_stream_schema.polars_schema, + ) + + +if __name__ == "__main__": + main() From bddbc59593f1d6e9aca5aa51380e6e40b55cc9a3 Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Sat, 3 Aug 2024 12:20:23 -0700 Subject: [PATCH 09/27] custom pydantic classes with str data --- airbyte/_airbyte_message_overrides.py | 61 +++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 airbyte/_airbyte_message_overrides.py diff --git a/airbyte/_airbyte_message_overrides.py b/airbyte/_airbyte_message_overrides.py new file mode 100644 index 00000000..70196f9c --- /dev/null +++ b/airbyte/_airbyte_message_overrides.py @@ -0,0 +1,61 @@ +# Copyright (c) 2024 Airbyte, Inc., all rights reserved. +"""Custom Airbyte message classes. + +These classes override the default handling, in order to ensure that the data field is always a +jsonified string, rather than a dict. + +To use these classes, import them from this module, and use them in place of the default classes. 
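+
+The overrides are built by deep-copying the protocol models and replacing the `data`
+annotation with `str`, so the rest of the message structure is unchanged.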
+ +Example: +```python +from airbyte._airbyte_message_overrides import AirbyteMessageWithStrData + +for line in sys.stdin: + message = AirbyteMessageWithStrData.model_validate_json(line) +``` +""" + +from __future__ import annotations + +import copy +import json +from typing import Any + +from pydantic import BaseModel, Field, model_validator + +from airbyte_protocol.models import ( + AirbyteMessage, + AirbyteRecordMessage, + AirbyteStateMessage, +) + + +AirbyteRecordMessageWithStrData = copy.deepcopy(AirbyteRecordMessage) +AirbyteStateMessageWithStrData = copy.deepcopy(AirbyteStateMessage) +AirbyteMessageWithStrData = copy.deepcopy(AirbyteMessage) + +# Modify the data field in the copied class +AirbyteRecordMessageWithStrData.__annotations__["data"] = str +AirbyteStateMessageWithStrData.__annotations__["data"] = str + +AirbyteRecordMessageWithStrData.data = Field(..., description="jsonified record data as a str") +AirbyteStateMessageWithStrData.data = Field(..., description="jsonified state data as a str") + + +# Add a validator to ensure data is a JSON string +@model_validator(mode="before") +def ensure_data_is_string( + cls: BaseModel, # type: ignore # noqa: ARG001, PGH003 + values: dict[str, Any], +) -> None: + if "data" in values and not isinstance(values["data"], dict): + values["data"] = json.dumps(values["data"]) + if "data" in values and not isinstance(values["data"], str): + raise ValueError + + +AirbyteRecordMessageWithStrData.ensure_data_is_string = classmethod(ensure_data_is_string) # type: ignore [arg-type] +AirbyteStateMessageWithStrData.ensure_data_is_string = classmethod(ensure_data_is_string) # type: ignore [arg-type] + +AirbyteMessageWithStrData.__annotations__["record"] = AirbyteRecordMessageWithStrData | None +AirbyteMessageWithStrData.__annotations__["state"] = AirbyteStateMessageWithStrData | None From 9af597f23f72e5fdcf0199e17e36efb74038f52c Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Sat, 3 Aug 2024 12:33:02 -0700 Subject: [PATCH 10/27] refactor to parse stdout in separate method --- airbyte/_connector_base.py | 35 +++++++++++++++++++++++++++++++---- airbyte/destinations/base.py | 2 +- airbyte/sources/base.py | 6 +++--- 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/airbyte/_connector_base.py b/airbyte/_connector_base.py index f14a3924..188028cd 100644 --- a/airbyte/_connector_base.py +++ b/airbyte/_connector_base.py @@ -148,7 +148,7 @@ def _get_spec(self, *, force_refresh: bool = False) -> ConnectorSpecification: """ if force_refresh or self._spec is None: try: - for msg in self._execute(["spec"]): + for msg in self._execute_and_parse(["spec"]): if msg.type == Type.SPEC and msg.spec: self._spec = msg.spec break @@ -254,7 +254,7 @@ def check(self) -> None: """ with as_temp_files([self._config]) as [config_file]: try: - for msg in self._execute(["check", "--config", config_file]): + for msg in self._execute_and_parse(["check", "--config", config_file]): if msg.type == Type.CONNECTION_STATUS and msg.connectionStatus: if msg.connectionStatus.status != Status.FAILED: print(f"Connection check succeeded for `{self.name}`.") @@ -363,7 +363,7 @@ def _peek_airbyte_message( ) return - def _execute( + def _execute_and_parse( self, args: list[str], stdin: IO[str] | AirbyteMessageIterator | None = None, @@ -383,7 +383,7 @@ def _execute( self.executor.ensure_installation(auto_fix=False) try: - for line in self.executor.execute(args, stdin=stdin): + for line in self._execute(args, stdin=stdin): try: message: AirbyteMessage = 
AirbyteMessage.model_validate_json(json_data=line) self._peek_airbyte_message(message) @@ -399,6 +399,33 @@ def _execute( log_text=self._last_log_messages, ) from e + def _execute( + self, + args: list[str], + stdin: IO[str] | AirbyteMessageIterator | None = None, + ) -> Generator[str, None, None]: + """Execute the connector with the given arguments. + + This involves the following steps: + * Locate the right venv. It is called ".venv-" + * Spawn a subprocess with .venv-/bin/ + * Read the output line by line of the subprocess and yield (unparsed) strings. + + Raises: + AirbyteConnectorFailedError: If the process returns a failure status (non-zero). + """ + # Fail early if the connector is not installed. + self.executor.ensure_installation(auto_fix=False) + + try: + yield from self.executor.execute(args, stdin=stdin) + + except Exception as e: + raise exc.AirbyteConnectorFailedError( + connector_name=self.name, + log_text=self._last_log_messages, + ) from e + __all__ = [ "ConnectorBase", diff --git a/airbyte/destinations/base.py b/airbyte/destinations/base.py index ab52c3d3..3fbd421e 100644 --- a/airbyte/destinations/base.py +++ b/airbyte/destinations/base.py @@ -275,7 +275,7 @@ def _write_airbyte_message_stream( try: # We call the connector to write the data, tallying the inputs and outputs for destination_message in progress_tracker.tally_confirmed_writes( - messages=self._execute( + messages=self._execute_and_parse( args=[ "write", "--config", diff --git a/airbyte/sources/base.py b/airbyte/sources/base.py index 8b34b46c..6b464fa8 100644 --- a/airbyte/sources/base.py +++ b/airbyte/sources/base.py @@ -207,7 +207,7 @@ def _discover(self) -> AirbyteCatalog: - Make sure the subprocess is killed when the function returns. """ with as_temp_files([self._config]) as [config_file]: - for msg in self._execute(["discover", "--config", config_file]): + for msg in self._execute_and_parse(["discover", "--config", config_file]): if msg.type == Type.CATALOG and msg.catalog: return msg.catalog raise exc.AirbyteConnectorMissingCatalogError( @@ -236,7 +236,7 @@ def _get_spec(self, *, force_refresh: bool = False) -> ConnectorSpecification: * Make sure the subprocess is killed when the function returns. """ if force_refresh or self._spec is None: - for msg in self._execute(["spec"]): + for msg in self._execute_and_parse(["spec"]): if msg.type == Type.SPEC and msg.spec: self._spec = msg.spec break @@ -543,7 +543,7 @@ def _read_with_catalog( catalog_file, state_file, ]: - message_generator = self._execute( + message_generator = self._execute_and_parse( [ "read", "--config", From 46ba395f75f5316286f454e3d4041ddae26d5632 Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Sun, 4 Aug 2024 14:56:52 -0700 Subject: [PATCH 11/27] make flush_active_batch public --- airbyte/_processors/file/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/airbyte/_processors/file/base.py b/airbyte/_processors/file/base.py index 7951fa07..9662332d 100644 --- a/airbyte/_processors/file/base.py +++ b/airbyte/_processors/file/base.py @@ -65,7 +65,7 @@ def _open_new_file( """Open a new file for writing.""" return file_path.open("w", encoding="utf-8") - def _flush_active_batch( + def flush_active_batch( self, stream_name: str, progress_tracker: ProgressTracker, @@ -101,7 +101,7 @@ def _new_batch( This also flushes the active batch if one already exists for the given stream. 
""" if stream_name in self._active_batches: - self._flush_active_batch( + self.flush_active_batch( stream_name=stream_name, progress_tracker=progress_tracker, ) @@ -193,7 +193,7 @@ def flush_active_batches( """Flush active batches for all streams.""" streams = list(self._active_batches.keys()) for stream_name in streams: - self._flush_active_batch( + self.flush_active_batch( stream_name=stream_name, progress_tracker=progress_tracker, ) From e99e72f450fa5b7c6d5685392138b20b702068d3 Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Sun, 4 Aug 2024 14:58:21 -0700 Subject: [PATCH 12/27] cleanup --- airbyte/_message_iterators.py | 5 +++-- examples/run_polars_poc.py | 3 --- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/airbyte/_message_iterators.py b/airbyte/_message_iterators.py index 8103b436..9c43d177 100644 --- a/airbyte/_message_iterators.py +++ b/airbyte/_message_iterators.py @@ -4,9 +4,10 @@ from __future__ import annotations import datetime +import io import sys from collections.abc import Iterator -from typing import IO, TYPE_CHECKING, Callable, cast +from typing import IO, TYPE_CHECKING, Any, Callable, Literal, cast import pydantic from typing_extensions import final @@ -78,7 +79,7 @@ def readinto(self, b: Any) -> int: b[:n] = data return n - return FileLikeReader(self._iterator) + return cast(io.BytesIO, FileLikeReader(self._iterator)) @classmethod def from_read_result(cls, read_result: ReadResult) -> AirbyteMessageIterator: diff --git a/examples/run_polars_poc.py b/examples/run_polars_poc.py index 0c745b51..329ccb8e 100644 --- a/examples/run_polars_poc.py +++ b/examples/run_polars_poc.py @@ -7,9 +7,6 @@ from __future__ import annotations -from collections.abc import Iterator -from io import BytesIO, StringIO -from typing import TextIO import airbyte as ab import polars as pl From 4d163a085ad99d750bd8b53b9f2d74dc31233a06 Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Sun, 4 Aug 2024 14:59:23 -0700 Subject: [PATCH 13/27] add icerberg processor methods --- airbyte/_processors/sql/iceberg.py | 134 ++++++++++++++++++++++++++++- 1 file changed, 130 insertions(+), 4 deletions(-) diff --git a/airbyte/_processors/sql/iceberg.py b/airbyte/_processors/sql/iceberg.py index 625cf617..e9af6454 100644 --- a/airbyte/_processors/sql/iceberg.py +++ b/airbyte/_processors/sql/iceberg.py @@ -1,12 +1,138 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. 
from __future__ import annotations -from airbyte._future_cdk.sql_processor import SqlProcessorBase +from textwrap import dedent, indent +from typing import TYPE_CHECKING + +from airbyte_protocol.models import ( + AirbyteRecordMessage, + AirbyteStateMessage, + AirbyteStateType, +) + +from airbyte import exceptions as exc +from airbyte._future_cdk.sql_processor import SqlConfig, SqlProcessorBase +from airbyte._future_cdk.state_writers import StateWriterBase +from airbyte._processors.file.parquet import ParquetWriter + + +if TYPE_CHECKING: + from pathlib import Path + + +class IcebergConfig(SqlConfig): + """A Iceberg configuration.""" + + def __init__(self, db_path: str, schema_name: str) -> None: + """Initialize the Iceberg configuration.""" + self.db_path = db_path + self.schema_name = schema_name class IcebergSqlProcessor(SqlProcessorBase): """A Iceberg SQL processor.""" - def __init__(self, db_path: str, schema_name: str) -> None: - """Initialize the Iceberg SQL processor.""" - super().__init__(db_path=db_path, schema_name=schema_name) + supports_merge_insert = False + file_writer_class = ParquetWriter + sql_config: IcebergConfig + + class IcebergStateWriter(StateWriterBase): + """A state writer for the Parquet cache.""" + + def __init__(self, iceberg_processor: IcebergSqlProcessor) -> None: + self._iceberg_processor = iceberg_processor + super().__init__() + + def _write_state(self, state: AirbyteRecordMessage) -> None: + """Write the state to the cache.""" + self._iceberg_processor.write_state(state) + + @property + def get_state_writer(self) -> StateWriterBase: + if self._state_writer is None: + self._state_writer = self.IcebergStateWriter(self) + + return self._state_writer + + def write_state(self, state: AirbyteStateMessage) -> None: + """Write the state to the cache. + + Args: + state (AirbyteStateMessage): The state to write. + + Implementation: + - State messages are written a separate file. + - Any pending records are written to the cache file and the cache file is closed. + - For stream state messages, the matching stream batches are flushed and closed. + - For global state, all batches are flushed and closed. + """ + stream_names: list[str] = [] + if state.type == AirbyteStateType.STREAM: + stream_names = [state.record.stream] + if state.type == AirbyteStateType.GLOBAL: + stream_names = list(self._buffered_records.keys()) + else: + msg = f"Unexpected state type: {state.type}" + raise exc.PyAirbyteInternalError(msg) + + for stream_name in stream_names: + state_file_name = self.file_writer.get_active_batch(stream_name) + self.file_writer.flush_active_batch(stream_name) + self.file_writer._write_state_to_file(state) + return + + def _write_files_to_new_table( + self, + files: list[Path], + stream_name: str, + batch_id: str, + ) -> str: + """Write file(s) to a new table. + + This involves registering the table in the Iceberg catalog, creating a manifest file, + and registering the manifest file in the catalog. 
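+
+        As currently sketched, the load step stages each batch into a temporary table:
+        the normalized column list is rendered into an
+        `INSERT INTO ... SELECT ... FROM read_json_auto(...)` statement over the
+        newline-delimited batch files, and the temporary table name is returned to the
+        caller.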
+ """ + temp_table_name = self._create_table_for_loading( + stream_name=stream_name, + batch_id=batch_id, + ) + columns_list = list(self._get_sql_column_definitions(stream_name=stream_name).keys()) + columns_list_str = indent( + "\n, ".join([self._quote_identifier(col) for col in columns_list]), + " ", + ) + files_list = ", ".join([f"'{f!s}'" for f in files]) + columns_type_map = indent( + "\n, ".join( + [ + self._quote_identifier(self.normalizer.normalize(prop_name)) + + ': "' + + str( + self._get_sql_column_definitions(stream_name)[ + self.normalizer.normalize(prop_name) + ] + ) + + '"' + for prop_name in columns_list + ] + ), + " ", + ) + insert_statement = dedent( + f""" + INSERT INTO {self.sql_config.schema_name}.{temp_table_name} + ( + {columns_list_str} + ) + SELECT + {columns_list_str} + FROM read_json_auto( + [{files_list}], + format = 'newline_delimited', + union_by_name = true, + columns = {{ { columns_type_map } }} + ) + """ + ) + self._execute_sql(insert_statement) + return temp_table_name From 6337bd24a69f6727b2adb8f000286d716af26713 Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Sun, 4 Aug 2024 15:02:31 -0700 Subject: [PATCH 14/27] update parquet writer --- airbyte/_processors/file/parquet.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/airbyte/_processors/file/parquet.py b/airbyte/_processors/file/parquet.py index 8869bd4b..ec7337d0 100644 --- a/airbyte/_processors/file/parquet.py +++ b/airbyte/_processors/file/parquet.py @@ -6,12 +6,22 @@ import gzip import json from pathlib import Path -from typing import IO, TYPE_CHECKING, cast +from typing import IO, TYPE_CHECKING, Literal, cast import orjson import ulid from overrides import overrides +from pydantic import PrivateAttr +from airbyte_protocol.models import ( + AirbyteMessage, + AirbyteRecordMessage, + AirbyteStateMessage, + AirbyteStateType, +) + +from airbyte import exceptions as exc +from airbyte._future_cdk.state_writers import StateWriterBase from airbyte._processors.file.base import ( FileWriterBase, ) @@ -21,12 +31,24 @@ from airbyte.records import StreamRecord -class LocalIcebergWriter(FileWriterBase): - """An Iceberg file writer implementation.""" +class ParquetWriter(FileWriterBase): + """An Parquet file writer implementation.""" default_cache_file_suffix = ".parquet" prune_extra_fields = True + _state_writer: StateWriterBase | None = PrivateAttr(default=None) + _buffered_records: dict[str, list[AirbyteMessage]] = PrivateAttr(default_factory=dict) + + def _get_records_file_path( + self, + cache_dir: Path, + stream_name: str, + batch_id: str, + ) -> Path: + """Return the records file path for the given stream and batch.""" + return cache_dir / f"{stream_name}_{batch_id}.records.parquet" + def _get_records_file_path(self, cache_dir: Path, stream_name: str, batch_id: str) -> Path: """Return the records file path for the given stream and batch.""" return cache_dir / f"{stream_name}_{batch_id}.parquet" From 852a4b82242ffe7a1c6c04c4086bc985a242ba1d Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Sun, 4 Aug 2024 15:02:54 -0700 Subject: [PATCH 15/27] remove dupe implementation --- airbyte/_processors/file/parquet.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/airbyte/_processors/file/parquet.py b/airbyte/_processors/file/parquet.py index ec7337d0..94413365 100644 --- a/airbyte/_processors/file/parquet.py +++ b/airbyte/_processors/file/parquet.py @@ -49,10 +49,6 @@ def _get_records_file_path( """Return the records file path for the given stream and 
batch.""" return cache_dir / f"{stream_name}_{batch_id}.records.parquet" - def _get_records_file_path(self, cache_dir: Path, stream_name: str, batch_id: str) -> Path: - """Return the records file path for the given stream and batch.""" - return cache_dir / f"{stream_name}_{batch_id}.parquet" - @overrides def _open_new_file( self, From 049b38a0c23d093b1ccf9e4ee3bef9b72a109e06 Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Sun, 4 Aug 2024 16:01:06 -0700 Subject: [PATCH 16/27] refactor: remove unnecessary RecordProcessor class --- airbyte/_future_cdk/record_processor.py | 296 ------------------------ airbyte/_future_cdk/sql_processor.py | 188 +++++++++++++-- 2 files changed, 173 insertions(+), 311 deletions(-) delete mode 100644 airbyte/_future_cdk/record_processor.py diff --git a/airbyte/_future_cdk/record_processor.py b/airbyte/_future_cdk/record_processor.py deleted file mode 100644 index 89c3f0d0..00000000 --- a/airbyte/_future_cdk/record_processor.py +++ /dev/null @@ -1,296 +0,0 @@ -# Copyright (c) 2023 Airbyte, Inc., all rights reserved. -"""Abstract base class for Processors, including SQL processors. - -Processors accept Airbyte messages as input from STDIN or from another input stream. -""" - -from __future__ import annotations - -import abc -import io -import sys -from collections import defaultdict -from typing import IO, TYPE_CHECKING, cast, final - -from airbyte_cdk import AirbyteMessage -from airbyte_protocol.models import ( - AirbyteRecordMessage, - AirbyteStateMessage, - AirbyteStateType, - AirbyteStreamState, - AirbyteTraceMessage, - Type, -) - -from airbyte import exceptions as exc -from airbyte._future_cdk.state_writers import StdOutStateWriter -from airbyte._message_iterators import AirbyteMessageIterator -from airbyte.records import StreamRecordHandler -from airbyte.strategies import WriteStrategy - - -if TYPE_CHECKING: - from collections.abc import Iterable, Iterator - - from airbyte._batch_handles import BatchHandle - from airbyte._future_cdk.catalog_providers import CatalogProvider - from airbyte._future_cdk.state_writers import StateWriterBase - from airbyte.progress import ProgressTracker - - -class AirbyteMessageParsingError(Exception): - """Raised when an Airbyte message is invalid or cannot be parsed.""" - - -class RecordProcessorBase(abc.ABC): - """Abstract base class for classes which can process Airbyte messages from a source. - - This class is responsible for all aspects of handling Airbyte protocol. - - The class should be passed a catalog manager and stream manager class to handle the - catalog and state aspects of the protocol. - """ - - def __init__( - self, - *, - catalog_provider: CatalogProvider, - state_writer: StateWriterBase | None = None, - ) -> None: - """Initialize the processor. - - If a state writer is not provided, the processor will use the default (STDOUT) state writer. - """ - self._catalog_provider: CatalogProvider | None = catalog_provider - self._state_writer: StateWriterBase | None = state_writer or StdOutStateWriter() - - self._pending_state_messages: dict[str, list[AirbyteStateMessage]] = defaultdict(list, {}) - self._finalized_state_messages: dict[ - str, - list[AirbyteStateMessage], - ] = defaultdict(list, {}) - - self._setup() - - @property - def expected_streams(self) -> set[str]: - """Return the expected stream names.""" - return set(self.catalog_provider.stream_names) - - @property - def catalog_provider( - self, - ) -> CatalogProvider: - """Return the catalog manager. 
- - Subclasses should set this property to a valid catalog manager instance if one - is not explicitly passed to the constructor. - - Raises: - PyAirbyteInternalError: If the catalog manager is not set. - """ - if not self._catalog_provider: - raise exc.PyAirbyteInternalError( - message="Catalog manager should exist but does not.", - ) - - return self._catalog_provider - - @property - def state_writer( - self, - ) -> StateWriterBase: - """Return the state writer instance. - - Subclasses should set this property to a valid state manager instance if one - is not explicitly passed to the constructor. - - Raises: - PyAirbyteInternalError: If the state manager is not set. - """ - if not self._state_writer: - raise exc.PyAirbyteInternalError( - message="State manager should exist but does not.", - ) - - return self._state_writer - - @final - def process_stdin( - self, - *, - write_strategy: WriteStrategy = WriteStrategy.AUTO, - progress_tracker: ProgressTracker, - ) -> None: - """Process the input stream from stdin. - - Return a list of summaries for testing. - """ - input_stream = io.TextIOWrapper(sys.stdin.buffer, encoding="utf-8") - self.process_input_stream( - input_stream, - write_strategy=write_strategy, - progress_tracker=progress_tracker, - ) - - @final - def _airbyte_messages_from_buffer( - self, - buffer: io.TextIOBase, - ) -> Iterator[AirbyteMessage]: - """Yield messages from a buffer.""" - yield from (AirbyteMessage.model_validate_json(line) for line in buffer) - - @final - def process_input_stream( - self, - input_stream: IO[str], - *, - write_strategy: WriteStrategy = WriteStrategy.AUTO, - progress_tracker: ProgressTracker, - ) -> None: - """Parse the input stream and process data in batches. - - Return a list of summaries for testing. - """ - messages = AirbyteMessageIterator.from_str_buffer(input_stream) - self.process_airbyte_messages( - messages, - write_strategy=write_strategy, - progress_tracker=progress_tracker, - ) - - @abc.abstractmethod - def process_record_message( - self, - record_msg: AirbyteRecordMessage, - stream_record_handler: StreamRecordHandler, - progress_tracker: ProgressTracker, - ) -> None: - """Write a record. - - This method is called for each record message. - - In most cases, the SQL processor will not perform any action, but will pass this along to to - the file processor. - """ - - @final - def process_airbyte_messages( - self, - messages: Iterable[AirbyteMessage], - *, - write_strategy: WriteStrategy, - progress_tracker: ProgressTracker, - ) -> None: - """Process a stream of Airbyte messages.""" - if not isinstance(write_strategy, WriteStrategy): - raise exc.AirbyteInternalError( - message="Invalid `write_strategy` argument. 
Expected instance of WriteStrategy.", - context={"write_strategy": write_strategy}, - ) - - stream_record_handlers: dict[str, StreamRecordHandler] = {} - - # Process messages, writing to batches as we go - for message in messages: - if message.type is Type.RECORD: - record_msg = cast(AirbyteRecordMessage, message.record) - stream_name = record_msg.stream - - if stream_name not in stream_record_handlers: - stream_record_handlers[stream_name] = StreamRecordHandler( - json_schema=self.catalog_provider.get_stream_json_schema( - stream_name=stream_name, - ), - normalize_keys=True, - prune_extra_fields=True, - ) - - self.process_record_message( - record_msg, - stream_record_handler=stream_record_handlers[stream_name], - progress_tracker=progress_tracker, - ) - - elif message.type is Type.STATE: - state_msg = cast(AirbyteStateMessage, message.state) - if state_msg.type in {AirbyteStateType.GLOBAL, AirbyteStateType.LEGACY}: - self._pending_state_messages[f"_{state_msg.type}"].append(state_msg) - else: - stream_state = cast(AirbyteStreamState, state_msg.stream) - stream_name = stream_state.stream_descriptor.name - self._pending_state_messages[stream_name].append(state_msg) - - elif message.type is Type.TRACE: - trace_msg: AirbyteTraceMessage = cast(AirbyteTraceMessage, message.trace) - if trace_msg.stream_status and trace_msg.stream_status.status == "SUCCEEDED": - # This stream has completed successfully, so go ahead and write the data. - # This will also finalize any pending state messages. - self.write_stream_data( - stream_name=trace_msg.stream_status.stream_descriptor.name, - write_strategy=write_strategy, - progress_tracker=progress_tracker, - ) - - else: - # Ignore unexpected or unhandled message types: - # Type.LOG, Type.CONTROL, etc. - pass - - # We've finished processing input data. - # Finalize all received records and state messages: - self.write_all_stream_data( - write_strategy=write_strategy, - progress_tracker=progress_tracker, - ) - - self.cleanup_all() - - def write_all_stream_data( - self, - write_strategy: WriteStrategy, - progress_tracker: ProgressTracker, - ) -> None: - """Finalize any pending writes.""" - for stream_name in self.catalog_provider.stream_names: - self.write_stream_data( - stream_name, - write_strategy=write_strategy, - progress_tracker=progress_tracker, - ) - - @abc.abstractmethod - def write_stream_data( - self, - stream_name: str, - write_strategy: WriteStrategy, - progress_tracker: ProgressTracker, - ) -> list[BatchHandle]: - """Write pending stream data to the cache.""" - ... - - def _finalize_state_messages( - self, - state_messages: list[AirbyteStateMessage], - ) -> None: - """Handle state messages by passing them to the catalog manager.""" - if state_messages: - self.state_writer.write_state( - state_message=state_messages[-1], - ) - - def _setup(self) -> None: # noqa: B027 # Intentionally empty, not abstract - """Create the database. - - By default this is a no-op but subclasses can override this method to prepare - any necessary resources. - """ - pass - - def cleanup_all(self) -> None: # noqa: B027 # Intentionally empty, not abstract - """Clean up all resources. - - The default implementation is a no-op. 
- """ - pass diff --git a/airbyte/_future_cdk/sql_processor.py b/airbyte/_future_cdk/sql_processor.py index 925ab5e3..3cb7784b 100644 --- a/airbyte/_future_cdk/sql_processor.py +++ b/airbyte/_future_cdk/sql_processor.py @@ -6,6 +6,7 @@ import abc import contextlib import enum +from collections import defaultdict from contextlib import contextmanager from functools import cached_property from pathlib import Path @@ -29,8 +30,17 @@ ) from sqlalchemy.sql.elements import TextClause +from airbyte_protocol.models import ( + AirbyteMessage, + AirbyteRecordMessage, + AirbyteStateMessage, + AirbyteStateType, + AirbyteStreamState, + AirbyteTraceMessage, + Type, +) + from airbyte import exceptions as exc -from airbyte._future_cdk.record_processor import RecordProcessorBase from airbyte._future_cdk.state_writers import StdOutStateWriter from airbyte._util.name_normalizers import LowerCaseNormalizer from airbyte.constants import ( @@ -39,12 +49,13 @@ AB_RAW_ID_COLUMN, DEBUG_MODE, ) +from airbyte.records import StreamRecordHandler from airbyte.strategies import WriteStrategy from airbyte.types import SQLTypeConverter if TYPE_CHECKING: - from collections.abc import Generator + from collections.abc import Generator, Iterable from sqlalchemy.engine import Connection, Engine from sqlalchemy.engine.cursor import CursorResult @@ -52,17 +63,11 @@ from sqlalchemy.sql.base import Executable from sqlalchemy.sql.type_api import TypeEngine - from airbyte_protocol.models import ( - AirbyteRecordMessage, - AirbyteStateMessage, - ) - from airbyte._batch_handles import BatchHandle from airbyte._future_cdk.catalog_providers import CatalogProvider from airbyte._future_cdk.state_writers import StateWriterBase from airbyte._processors.file.base import FileWriterBase from airbyte.progress import ProgressTracker - from airbyte.records import StreamRecordHandler from airbyte.secrets.base import SecretString @@ -116,7 +121,7 @@ def get_vendor_client(self) -> object: ) -class SqlProcessorBase(RecordProcessorBase): +class SqlProcessorBase(abc.ABC): """A base class to be used for SQL Caches.""" type_converter_class: type[SQLTypeConverter] = SQLTypeConverter @@ -131,8 +136,6 @@ class SqlProcessorBase(RecordProcessorBase): supports_merge_insert = False """True if the database supports the MERGE INTO syntax.""" - # Constructor: - def __init__( self, *, @@ -152,10 +155,16 @@ def __init__( self._sql_config: SqlConfig = sql_config - super().__init__( - state_writer=state_writer, - catalog_provider=catalog_provider, - ) + self._catalog_provider: CatalogProvider | None = catalog_provider + self._state_writer: StateWriterBase | None = state_writer or StdOutStateWriter() + + self._pending_state_messages: dict[str, list[AirbyteStateMessage]] = defaultdict(list, {}) + self._finalized_state_messages: dict[ + str, + list[AirbyteStateMessage], + ] = defaultdict(list, {}) + + self._setup() self.file_writer = file_writer or self.file_writer_class( cache_dir=cast(Path, temp_dir), cleanup=temp_file_cleanup, @@ -166,6 +175,155 @@ def __init__( self._known_schemas_list: list[str] = [] self._ensure_schema_exists() + # Inherited methods + + # @property + # def expected_streams(self) -> set[str]: + # """Return the expected stream names.""" + # return set(self.catalog_provider.stream_names) + + @property + def catalog_provider( + self, + ) -> CatalogProvider: + """Return the catalog manager. + + Subclasses should set this property to a valid catalog manager instance if one + is not explicitly passed to the constructor. 
+ + Raises: + PyAirbyteInternalError: If the catalog manager is not set. + """ + if not self._catalog_provider: + raise exc.PyAirbyteInternalError( + message="Catalog manager should exist but does not.", + ) + + return self._catalog_provider + + @property + def state_writer( + self, + ) -> StateWriterBase: + """Return the state writer instance. + + Subclasses should set this property to a valid state manager instance if one + is not explicitly passed to the constructor. + + Raises: + PyAirbyteInternalError: If the state manager is not set. + """ + if not self._state_writer: + raise exc.PyAirbyteInternalError( + message="State manager should exist but does not.", + ) + + return self._state_writer + + @final + def process_airbyte_messages( + self, + messages: Iterable[AirbyteMessage], + *, + catalog_provider: CatalogProvider | None = None, + write_strategy: WriteStrategy, + progress_tracker: ProgressTracker, + ) -> None: + """Process a stream of Airbyte messages.""" + if not isinstance(write_strategy, WriteStrategy): + raise exc.AirbyteInternalError( + message="Invalid `write_strategy` argument. Expected instance of WriteStrategy.", + context={"write_strategy": write_strategy}, + ) + + stream_record_handlers: dict[str, StreamRecordHandler] = {} + + # Process messages, writing to batches as we go + for message in messages: + if message.type is Type.RECORD: + record_msg = cast(AirbyteRecordMessage, message.record) + stream_name = record_msg.stream + + if stream_name not in stream_record_handlers: + stream_record_handlers[stream_name] = StreamRecordHandler( + json_schema=self.catalog_provider.get_stream_json_schema( + stream_name=stream_name, + ), + normalize_keys=True, + prune_extra_fields=True, + ) + + self.process_record_message( + record_msg, + stream_record_handler=stream_record_handlers[stream_name], + progress_tracker=progress_tracker, + ) + + elif message.type is Type.STATE: + state_msg = cast(AirbyteStateMessage, message.state) + if state_msg.type in {AirbyteStateType.GLOBAL, AirbyteStateType.LEGACY}: + self._pending_state_messages[f"_{state_msg.type}"].append(state_msg) + else: + stream_state = cast(AirbyteStreamState, state_msg.stream) + stream_name = stream_state.stream_descriptor.name + self._pending_state_messages[stream_name].append(state_msg) + + elif message.type is Type.TRACE: + trace_msg: AirbyteTraceMessage = cast(AirbyteTraceMessage, message.trace) + if trace_msg.stream_status and trace_msg.stream_status.status == "SUCCEEDED": + # This stream has completed successfully, so go ahead and write the data. + # This will also finalize any pending state messages. + self.write_stream_data( + stream_name=trace_msg.stream_status.stream_descriptor.name, + write_strategy=write_strategy, + progress_tracker=progress_tracker, + ) + + else: + # Ignore unexpected or unhandled message types: + # Type.LOG, Type.CONTROL, etc. + pass + + # We've finished processing input data. 
+ # Finalize all received records and state messages: + self.write_all_stream_data( + write_strategy=write_strategy, + progress_tracker=progress_tracker, + ) + + self.cleanup_all() + + def write_all_stream_data( + self, + write_strategy: WriteStrategy, + progress_tracker: ProgressTracker, + ) -> None: + """Finalize any pending writes.""" + for stream_name in self.catalog_provider.stream_names: + self.write_stream_data( + stream_name, + write_strategy=write_strategy, + progress_tracker=progress_tracker, + ) + + def _finalize_state_messages( + self, + state_messages: list[AirbyteStateMessage], + ) -> None: + """Handle state messages by passing them to the catalog manager.""" + if state_messages: + self.state_writer.write_state( + state_message=state_messages[-1], + ) + + def _setup(self) -> None: # noqa: B027 # Intentionally empty, not abstract + """Create the database. + + By default this is a no-op but subclasses can override this method to prepare + any necessary resources. + """ + pass + # Public interface: @property From a671c97e0b19f427ffc47c624528058dd8a81f65 Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Mon, 5 Aug 2024 13:13:40 -0700 Subject: [PATCH 17/27] refactor: add AirbyteWritersInterface and WriteMethod, plus related refactoring --- airbyte/_future_cdk/catalog_providers.py | 69 +++++++++++++++ airbyte/_future_cdk/sql_processor.py | 88 +++++++------------ airbyte/_util/telemetry.py | 34 +++++-- airbyte/caches/base.py | 30 ++++++- airbyte/destinations/base.py | 22 ++--- airbyte/progress.py | 3 +- airbyte/results.py | 5 +- airbyte/sources/base.py | 21 ++--- airbyte/strategies.py | 66 ++++++++++++-- airbyte/writers.py | 58 ++++++++++++ .../test_source_to_destination.py | 1 + .../test_docker_executable.py | 8 +- .../test_source_faker_integration.py | 8 +- 13 files changed, 314 insertions(+), 99 deletions(-) create mode 100644 airbyte/writers.py diff --git a/airbyte/_future_cdk/catalog_providers.py b/airbyte/_future_cdk/catalog_providers.py index 2b1c0a93..470d7fa7 100644 --- a/airbyte/_future_cdk/catalog_providers.py +++ b/airbyte/_future_cdk/catalog_providers.py @@ -8,6 +8,7 @@ from __future__ import annotations +import copy from typing import TYPE_CHECKING, Any, final from airbyte_protocol.models import ( @@ -15,6 +16,7 @@ ) from airbyte import exceptions as exc +from airbyte.strategies import WriteMethod, WriteStrategy if TYPE_CHECKING: @@ -118,3 +120,70 @@ def from_read_result( ] ) ) + + def get_primary_keys( + self, + stream_name: str, + ) -> list[str]: + pks = self.get_configured_stream_info(stream_name).primary_key + if not pks: + return [] + + joined_pks = [".".join(pk) for pk in pks] + for pk in joined_pks: + if "." in pk: + msg = f"Nested primary keys are not yet supported. 
Found: {pk}" + raise NotImplementedError(msg) + + return joined_pks + + def get_cursor_key( + self, + stream_name: str, + ) -> str | None: + return self.get_configured_stream_info(stream_name).cursor_field + + def resolve_write_method( + self, + stream_name: str, + write_strategy: WriteStrategy, + ) -> WriteMethod: + """Return the write method for the given stream.""" + has_pks: bool = bool(self.get_primary_keys(stream_name)) + has_incremental_key: bool = bool(self.get_cursor_key(stream_name)) + if write_strategy == WriteStrategy.MERGE and not has_pks: + raise exc.PyAirbyteInputError( + message="Cannot use merge strategy on a stream with no primary keys.", + context={ + "stream_name": stream_name, + }, + ) + + if write_strategy != WriteStrategy.AUTO: + return WriteMethod(write_strategy) + + if has_pks: + return WriteMethod.MERGE + + if has_incremental_key: + return WriteMethod.APPEND + + return WriteMethod.REPLACE + + def with_write_strategy( + self, + write_strategy: WriteStrategy, + ) -> CatalogProvider: + """Return a new catalog provider with the specified write strategy applied. + + The original catalog provider is not modified. + """ + new_catalog: ConfiguredAirbyteCatalog = copy.deepcopy(self.configured_catalog) + for stream in new_catalog.streams: + write_method = self.resolve_write_method( + stream_name=stream.stream.name, + write_strategy=write_strategy, + ) + stream.destination_sync_mode = write_method.destination_sync_mode + + return CatalogProvider(new_catalog) diff --git a/airbyte/_future_cdk/sql_processor.py b/airbyte/_future_cdk/sql_processor.py index 3cb7784b..b265e01a 100644 --- a/airbyte/_future_cdk/sql_processor.py +++ b/airbyte/_future_cdk/sql_processor.py @@ -50,7 +50,7 @@ DEBUG_MODE, ) from airbyte.records import StreamRecordHandler -from airbyte.strategies import WriteStrategy +from airbyte.strategies import WriteMethod, WriteStrategy from airbyte.types import SQLTypeConverter @@ -225,11 +225,13 @@ def process_airbyte_messages( self, messages: Iterable[AirbyteMessage], *, - catalog_provider: CatalogProvider | None = None, - write_strategy: WriteStrategy, + write_strategy: WriteStrategy = WriteStrategy.AUTO, progress_tracker: ProgressTracker, ) -> None: - """Process a stream of Airbyte messages.""" + """Process a stream of Airbyte messages. + + This method assumes that the catalog is already registered with the processor. + """ if not isinstance(write_strategy, WriteStrategy): raise exc.AirbyteInternalError( message="Invalid `write_strategy` argument. Expected instance of WriteStrategy.", @@ -401,7 +403,7 @@ def process_record_message( # Protected members (non-public interface): - def _init_connection_settings(self, connection: Connection) -> None: + def _init_connection_settings(self, connection: Connection) -> None: # noqa: B027 # Intentionally empty, not abstract """This is called automatically whenever a new connection is created. By default this is a no-op. Subclasses can use this to set connection settings, such as @@ -636,7 +638,9 @@ def _get_sql_column_definitions( def write_stream_data( self, stream_name: str, - write_strategy: WriteStrategy, + *, + write_method: WriteMethod | None = None, + write_strategy: WriteStrategy | None = None, progress_tracker: ProgressTracker, ) -> list[BatchHandle]: """Finalize all uncommitted batches. @@ -649,6 +653,18 @@ def write_stream_data( Some sources will send us duplicate records within the same stream, although this is a fairly rare edge case we can ignore in V1. 
""" + if write_method and write_strategy and write_strategy != WriteStrategy.AUTO: + raise exc.PyAirbyteInternalError( + message=( + "Both `write_method` and `write_strategy` were provided. " + "Only one should be set." + ), + ) + if not write_method: + write_method = self.catalog_provider.resolve_write_method( + stream_name=stream_name, + write_strategy=write_strategy or WriteStrategy.AUTO, + ) # Flush any pending writes self.file_writer.flush_active_batches( progress_tracker=progress_tracker, @@ -686,7 +702,7 @@ def write_stream_data( stream_name=stream_name, temp_table_name=temp_table_name, final_table_name=final_table_name, - write_strategy=write_strategy, + write_method=write_method, ) finally: self._drop_temp_table(temp_table_name, if_exists=True) @@ -863,28 +879,10 @@ def _write_temp_table_to_final_table( stream_name: str, temp_table_name: str, final_table_name: str, - write_strategy: WriteStrategy, + write_method: WriteMethod, ) -> None: """Write the temp table into the final table using the provided write strategy.""" - has_pks: bool = bool(self._get_primary_keys(stream_name)) - has_incremental_key: bool = bool(self._get_incremental_key(stream_name)) - if write_strategy == WriteStrategy.MERGE and not has_pks: - raise exc.PyAirbyteInputError( - message="Cannot use merge strategy on a stream with no primary keys.", - context={ - "stream_name": stream_name, - }, - ) - - if write_strategy == WriteStrategy.AUTO: - if has_pks: - write_strategy = WriteStrategy.MERGE - elif has_incremental_key: - write_strategy = WriteStrategy.APPEND - else: - write_strategy = WriteStrategy.REPLACE - - if write_strategy == WriteStrategy.REPLACE: + if write_method == WriteMethod.REPLACE: # Note: No need to check for schema compatibility # here, because we are fully replacing the table. self._swap_temp_table_with_final_table( @@ -894,7 +892,7 @@ def _write_temp_table_to_final_table( ) return - if write_strategy == WriteStrategy.APPEND: + if write_method == WriteMethod.APPEND: self._ensure_compatible_table_schema( stream_name=stream_name, table_name=final_table_name, @@ -906,7 +904,7 @@ def _write_temp_table_to_final_table( ) return - if write_strategy == WriteStrategy.MERGE: + if write_method == WriteMethod.MERGE: self._ensure_compatible_table_schema( stream_name=stream_name, table_name=final_table_name, @@ -928,9 +926,9 @@ def _write_temp_table_to_final_table( return raise exc.PyAirbyteInternalError( - message="Write strategy is not supported.", + message="Write method is not supported.", context={ - "write_strategy": write_strategy, + "write_method": write_method, }, ) @@ -953,28 +951,6 @@ def _append_temp_table_to_final_table( """, ) - def _get_primary_keys( - self, - stream_name: str, - ) -> list[str]: - pks = self.catalog_provider.get_configured_stream_info(stream_name).primary_key - if not pks: - return [] - - joined_pks = [".".join(pk) for pk in pks] - for pk in joined_pks: - if "." in pk: - msg = f"Nested primary keys are not yet supported. 
Found: {pk}" - raise NotImplementedError(msg) - - return joined_pks - - def _get_incremental_key( - self, - stream_name: str, - ) -> str | None: - return self.catalog_provider.get_configured_stream_info(stream_name).cursor_field - def _swap_temp_table_with_final_table( self, stream_name: str, @@ -1017,7 +993,9 @@ def _merge_temp_table_to_final_table( """ nl = "\n" columns = {self._quote_identifier(c) for c in self._get_sql_column_definitions(stream_name)} - pk_columns = {self._quote_identifier(c) for c in self._get_primary_keys(stream_name)} + pk_columns = { + self._quote_identifier(c) for c in self.catalog_provider.get_primary_keys(stream_name) + } non_pk_columns = columns - pk_columns join_clause = f"{nl} AND ".join(f"tmp.{pk_col} = final.{pk_col}" for pk_col in pk_columns) set_clause = f"{nl} , ".join(f"{col} = tmp.{col}" for col in non_pk_columns) @@ -1073,7 +1051,7 @@ def _emulated_merge_temp_table_to_final_table( """ final_table = self._get_table_by_name(final_table_name) temp_table = self._get_table_by_name(temp_table_name) - pk_columns = self._get_primary_keys(stream_name) + pk_columns = self.catalog_provider.get_primary_keys(stream_name) columns_to_update: set[str] = self._get_sql_column_definitions( stream_name=stream_name diff --git a/airbyte/_util/telemetry.py b/airbyte/_util/telemetry.py index 2927742e..faa276b8 100644 --- a/airbyte/_util/telemetry.py +++ b/airbyte/_util/telemetry.py @@ -47,12 +47,13 @@ from airbyte import exceptions as exc from airbyte._util import meta +from airbyte.destinations.base import Destination from airbyte.version import get_version +from airbyte.writers import AirbyteWriterInterface if TYPE_CHECKING: from airbyte.caches.base import CacheBase - from airbyte.destinations.base import Destination from airbyte.sources.base import Source @@ -226,18 +227,35 @@ class DestinationTelemetryInfo: version: str | None @classmethod - def from_destination(cls, destination: Destination | str | None) -> DestinationTelemetryInfo: + def from_destination( + cls, + destination: Destination | AirbyteWriterInterface | str | None, + ) -> DestinationTelemetryInfo: if not destination: return cls(name=UNKNOWN, executor_type=UNKNOWN, version=UNKNOWN) if isinstance(destination, str): return cls(name=destination, executor_type=UNKNOWN, version=UNKNOWN) - # Else, `destination` should be a `Destination` at this point - return cls( - name=destination.name, - executor_type=type(destination.executor).__name__, - version=destination.executor.reported_version, + if isinstance(destination, Destination): + return cls( + name=destination.name, + executor_type=type(destination.executor).__name__, + version=destination.executor.reported_version, + ) + + # Else, `destination` should be a `AirbyteWriterInterface` at this point + if isinstance(destination, AirbyteWriterInterface): + return cls( + name=destination.name, + executor_type=UNKNOWN, + version=UNKNOWN, + ) + + return cls( # type: ignore [unreachable] + name=repr(destination), + executor_type=UNKNOWN, + version=UNKNOWN, ) @@ -274,7 +292,7 @@ def get_env_flags() -> dict[str, Any]: def send_telemetry( *, source: Source | str | None, - destination: Destination | str | None, + destination: Destination | AirbyteWriterInterface | str | None, cache: CacheBase | None, state: EventState, event_type: EventType, diff --git a/airbyte/caches/base.py b/airbyte/caches/base.py index 33f7cb2d..32cc5d58 100644 --- a/airbyte/caches/base.py +++ b/airbyte/caches/base.py @@ -4,7 +4,7 @@ from __future__ import annotations from pathlib import Path -from 
typing import TYPE_CHECKING, Any, Optional, final +from typing import IO, TYPE_CHECKING, Any, Optional, final import pandas as pd import pyarrow as pa @@ -23,6 +23,7 @@ from airbyte.caches._state_backend import SqlStateBackend from airbyte.constants import DEFAULT_ARROW_MAX_CHUNK_SIZE from airbyte.datasets._sql import CachedDataset +from airbyte.writers import AirbyteWriterInterface if TYPE_CHECKING: @@ -31,11 +32,14 @@ from airbyte._future_cdk.sql_processor import SqlProcessorBase from airbyte._future_cdk.state_providers import StateProviderBase from airbyte._future_cdk.state_writers import StateWriterBase + from airbyte._message_iterators import AirbyteMessageIterator from airbyte.caches._state_backend_base import StateBackendBase from airbyte.datasets._base import DatasetBase + from airbyte.progress import ProgressTracker + from airbyte.strategies import WriteStrategy -class CacheBase(SqlConfig): +class CacheBase(SqlConfig, AirbyteWriterInterface): """Base configuration for a cache. Caches inherit from the matching `SqlConfig` class, which provides the SQL config settings @@ -258,3 +262,25 @@ def __iter__( # type: ignore [override] # Overriding Pydantic model method ) -> Iterator[tuple[str, Any]]: """Iterate over the streams in the cache.""" return ((name, dataset) for name, dataset in self.streams.items()) + + def _write_airbyte_message_stream( + self, + stdin: IO[str] | AirbyteMessageIterator, + *, + catalog_provider: CatalogProvider, + write_strategy: WriteStrategy, + state_writer: StateWriterBase | None = None, + progress_tracker: ProgressTracker, + ) -> None: + """Read from the connector and write to the cache.""" + cache_processor = self.get_record_processor( + source_name=self.name, + catalog_provider=catalog_provider, + state_writer=state_writer, + ) + cache_processor.process_airbyte_messages( + messages=stdin, + write_strategy=write_strategy, + progress_tracker=progress_tracker, + ) + progress_tracker.log_cache_processing_complete() diff --git a/airbyte/destinations/base.py b/airbyte/destinations/base.py index ab52c3d3..31765bbc 100644 --- a/airbyte/destinations/base.py +++ b/airbyte/destinations/base.py @@ -23,7 +23,7 @@ StateProviderBase, StaticInputState, ) -from airbyte._future_cdk.state_writers import NoOpStateWriter, StateWriterBase, StdOutStateWriter +from airbyte._future_cdk.state_writers import NoOpStateWriter, StdOutStateWriter from airbyte._message_iterators import AirbyteMessageIterator from airbyte._util.temp_files import as_temp_files from airbyte.caches.util import get_default_cache @@ -31,6 +31,7 @@ from airbyte.results import ReadResult, WriteResult from airbyte.sources.base import Source from airbyte.strategies import WriteStrategy +from airbyte.writers import AirbyteWriterInterface if TYPE_CHECKING: @@ -39,7 +40,7 @@ from airbyte.caches.base import CacheBase -class Destination(ConnectorBase): +class Destination(ConnectorBase, AirbyteWriterInterface): """A class representing a destination that can be called.""" connector_type: Literal["destination"] = "destination" @@ -73,11 +74,12 @@ def write( # noqa: PLR0912, PLR0915 # Too many arguments/statements write_strategy: WriteStrategy = WriteStrategy.AUTO, force_full_refresh: bool = False, ) -> WriteResult: - """Write data to the destination. + """Write data from source connector or already cached source data. + + Caching is enabled by default, unless explicitly disabled. Args: - source_data: The source data to write to the destination. Can be a `Source`, a `Cache`, - or a `ReadResult` object. 
+ source_data: The source data to write. Can be a `Source` or a `ReadResult` object. streams: The streams to write to the destination. If omitted or if "*" is provided, all streams will be written. If `source_data` is a source, then streams must be selected here or on the source. If both are specified, this setting will override @@ -227,8 +229,8 @@ def write( # noqa: PLR0912, PLR0915 # Too many arguments/statements self._write_airbyte_message_stream( stdin=message_iterator, catalog_provider=catalog_provider, + write_strategy=write_strategy, state_writer=destination_state_writer, - skip_validation=False, progress_tracker=progress_tracker, ) except Exception as ex: @@ -251,18 +253,18 @@ def _write_airbyte_message_stream( stdin: IO[str] | AirbyteMessageIterator, *, catalog_provider: CatalogProvider, + write_strategy: WriteStrategy, state_writer: StateWriterBase | None = None, - skip_validation: bool = False, progress_tracker: ProgressTracker, ) -> None: """Read from the connector and write to the cache.""" # Run optional validation step - if not skip_validation: - self.validate_config() - if state_writer is None: state_writer = StdOutStateWriter() + # Apply the write strategy to the catalog provider before sending to the destination + catalog_provider = catalog_provider.with_write_strategy(write_strategy) + with as_temp_files( files_contents=[ self._config, diff --git a/airbyte/progress.py b/airbyte/progress.py index 7660c328..f9cf7395 100644 --- a/airbyte/progress.py +++ b/airbyte/progress.py @@ -46,6 +46,7 @@ from airbyte.caches.base import CacheBase from airbyte.destinations.base import Destination from airbyte.sources.base import Source + from airbyte.writers import AirbyteWriterInterface IS_REPL = hasattr(sys, "ps1") # True if we're in a Python REPL, in which case we can use Rich. HORIZONTAL_LINE = "------------------------------------------------\n" @@ -148,7 +149,7 @@ def __init__( *, source: Source | None, cache: CacheBase | None, - destination: Destination | None, + destination: AirbyteWriterInterface | Destination | None, expected_streams: list[str] | None = None, ) -> None: """Initialize the progress tracker.""" diff --git a/airbyte/results.py b/airbyte/results.py index 7035abab..d3de4dcf 100644 --- a/airbyte/results.py +++ b/airbyte/results.py @@ -26,6 +26,7 @@ from airbyte.destinations.base import Destination from airbyte.progress import ProgressTracker from airbyte.sources.base import Source + from airbyte.writers import AirbyteWriterInterface class ReadResult(Mapping[str, CachedDataset]): @@ -110,7 +111,7 @@ class WriteResult: def __init__( self, *, - destination: Destination, + destination: AirbyteWriterInterface | Destination, source_data: Source | ReadResult, catalog_provider: CatalogProvider, state_writer: StateWriterBase, @@ -121,7 +122,7 @@ def __init__( This class should not be created directly. Instead, it should be returned by the `write` method of the `Destination` class. 
""" - self._destination: Destination = destination + self._destination: AirbyteWriterInterface | Destination = destination self._source_data: Source | ReadResult = source_data self._catalog_provider: CatalogProvider = catalog_provider self._state_writer: StateWriterBase = state_writer diff --git a/airbyte/sources/base.py b/airbyte/sources/base.py index 8b34b46c..83df28ee 100644 --- a/airbyte/sources/base.py +++ b/airbyte/sources/base.py @@ -718,23 +718,20 @@ def _read_to_cache( # noqa: PLR0913 # Too many arguments if incremental_streams: self._log_incremental_streams(incremental_streams=incremental_streams) - airbyte_message_iterator: Iterator[AirbyteMessage] = self._read_with_catalog( - catalog=catalog_provider.configured_catalog, - state=state_provider, - progress_tracker=progress_tracker, + airbyte_message_iterator = AirbyteMessageIterator( + self._read_with_catalog( + catalog=catalog_provider.configured_catalog, + state=state_provider, + progress_tracker=progress_tracker, + ) ) - cache_processor = cache.get_record_processor( - source_name=self.name, + cache._write_airbyte_message_stream( # noqa: SLF001 # Non-public API + stdin=airbyte_message_iterator, catalog_provider=catalog_provider, - state_writer=state_writer, - ) - cache_processor.process_airbyte_messages( - messages=airbyte_message_iterator, write_strategy=write_strategy, + state_writer=state_writer, progress_tracker=progress_tracker, ) - progress_tracker.log_cache_processing_complete() - return ReadResult( source_name=self.name, progress_tracker=progress_tracker, diff --git a/airbyte/strategies.py b/airbyte/strategies.py index 05ab3ba9..e55b4d9a 100644 --- a/airbyte/strategies.py +++ b/airbyte/strategies.py @@ -6,11 +6,26 @@ from enum import Enum +from airbyte_protocol.models import DestinationSyncMode + + +_MERGE = "merge" +_REPLACE = "replace" +_APPEND = "append" +_AUTO = "auto" + class WriteStrategy(str, Enum): - """Read strategies for PyAirbyte.""" + """Read strategies for PyAirbyte. + + Read strategies set a preferred method for writing data to a destination. The actual method used + may differ based on the capabilities of the destination. - MERGE = "merge" + If a destination does not support the preferred method, it will fall back to the next best + method. + """ + + MERGE = _MERGE """Merge new records with existing records. This requires a primary key to be set on the stream. @@ -20,13 +35,13 @@ class WriteStrategy(str, Enum): please use the `auto` strategy instead. """ - APPEND = "append" + APPEND = _APPEND """Append new records to existing records.""" - REPLACE = "replace" + REPLACE = _REPLACE """Replace existing records with new records.""" - AUTO = "auto" + AUTO = _AUTO """Automatically determine the best strategy to use. This will use the following logic: @@ -34,3 +49,44 @@ class WriteStrategy(str, Enum): - Else, if there's an incremental key, use append. - Else, use full replace (table swap). """ + + +class WriteMethod(str, Enum): + """Write methods for PyAirbyte. + + Unlike write strategies, write methods are expected to be fully resolved and do not require any + additional logic to determine the best method to use. + + If a destination does not support the declared method, it will raise an exception. + """ + + MERGE = _MERGE + """Merge new records with existing records. + + This requires a primary key to be set on the stream. + If no primary key is set, this will raise an exception. 
+ + To apply this strategy in cases where some destination streams don't have a primary key, + please use the `auto` strategy instead. + """ + + APPEND = _APPEND + """Append new records to existing records.""" + + REPLACE = _REPLACE + """Replace existing records with new records.""" + + @property + def destination_sync_mode(self) -> DestinationSyncMode: + """Convert the write method to a destination sync mode.""" + if self == WriteMethod.MERGE: + return DestinationSyncMode.append_dedup + + if self == WriteMethod.APPEND: + return DestinationSyncMode.append + + if self == WriteMethod.REPLACE: + return DestinationSyncMode.overwrite + + msg = f"Unknown write method: {self}" # type: ignore [unreachable] + raise ValueError(msg) diff --git a/airbyte/writers.py b/airbyte/writers.py new file mode 100644 index 00000000..7f216168 --- /dev/null +++ b/airbyte/writers.py @@ -0,0 +1,58 @@ +# Copyright (c) 2024 Airbyte, Inc., all rights reserved. +"""Write interfaces for PyAirbyte.""" + +from __future__ import annotations + +import abc +from typing import IO, TYPE_CHECKING + + +if TYPE_CHECKING: + from airbyte._future_cdk.catalog_providers import CatalogProvider + from airbyte._future_cdk.state_writers import StateWriterBase + from airbyte._message_iterators import AirbyteMessageIterator + from airbyte.progress import ProgressTracker + from airbyte.strategies import WriteStrategy + + +class AirbyteWriterInterface(abc.ABC): + """An interface for writing Airbyte messages.""" + + @property + def name(self) -> str: + """Return the name of the writer. + + This is used for logging and state tracking. + """ + return self.__class__.__name__ + + def _write_airbyte_io_stream( + self, + stdin: IO[str], + *, + catalog_provider: CatalogProvider, + write_strategy: WriteStrategy, + state_writer: StateWriterBase | None = None, + progress_tracker: ProgressTracker, + ) -> None: + """Read from the connector and write to the cache.""" + self._write_airbyte_message_stream( + stdin, + catalog_provider=catalog_provider, + write_strategy=write_strategy, + state_writer=state_writer, + progress_tracker=progress_tracker, + ) + + @abc.abstractmethod + def _write_airbyte_message_stream( + self, + stdin: IO[str] | AirbyteMessageIterator, + *, + catalog_provider: CatalogProvider, + write_strategy: WriteStrategy, + state_writer: StateWriterBase | None = None, + progress_tracker: ProgressTracker, + ) -> None: + """Write the incoming data.""" + ... 
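For illustration, a minimal concrete writer built on the `AirbyteWriterInterface` above might look like the sketch below. Only the abstract `_write_airbyte_message_stream()` signature comes from this patch; the `MessageCountingWriter` class, its counting behavior, and the import paths are assumptions based on the modules as they exist at this point in the series (a later commit removes the duplicate `airbyte/writers.py`, so the final module location may differ).

from __future__ import annotations

from typing import IO, TYPE_CHECKING

from airbyte.writers import AirbyteWriterInterface

if TYPE_CHECKING:
    from airbyte._future_cdk.catalog_providers import CatalogProvider
    from airbyte._future_cdk.state_writers import StateWriterBase
    from airbyte._message_iterators import AirbyteMessageIterator
    from airbyte.progress import ProgressTracker
    from airbyte.strategies import WriteStrategy


class MessageCountingWriter(AirbyteWriterInterface):
    """A hypothetical writer that discards records and only counts incoming messages."""

    def __init__(self) -> None:
        self.messages_seen = 0

    def _write_airbyte_message_stream(
        self,
        stdin: IO[str] | AirbyteMessageIterator,
        *,
        catalog_provider: CatalogProvider,
        write_strategy: WriteStrategy,
        state_writer: StateWriterBase | None = None,
        progress_tracker: ProgressTracker,
    ) -> None:
        # Both `IO[str]` and `AirbyteMessageIterator` are iterable, so a plain loop
        # suffices for this illustration; a real writer would parse each message,
        # persist records, and forward state messages to `state_writer`.
        for _message in stdin:
            self.messages_seen += 1

Because the interface only requires `_write_airbyte_message_stream()`, caches and destinations can share the same write entry point, which is what the `CacheBase` and `Destination` changes in this patch rely on.
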
diff --git a/tests/integration_tests/destinations/test_source_to_destination.py b/tests/integration_tests/destinations/test_source_to_destination.py index 09fe6c56..d071f446 100644 --- a/tests/integration_tests/destinations/test_source_to_destination.py +++ b/tests/integration_tests/destinations/test_source_to_destination.py @@ -90,6 +90,7 @@ def test_duckdb_destination_write_components( catalog_provider=CatalogProvider( configured_catalog=new_source_faker.configured_catalog ), + write_strategy=WriteStrategy.AUTO, progress_tracker=ProgressTracker( source=None, cache=None, diff --git a/tests/integration_tests/test_docker_executable.py b/tests/integration_tests/test_docker_executable.py index 53869791..1869e25c 100644 --- a/tests/integration_tests/test_docker_executable.py +++ b/tests/integration_tests/test_docker_executable.py @@ -84,8 +84,12 @@ def test_faker_pks( read_result = source_docker_faker_seed_a.read( new_duckdb_cache, write_strategy="append" ) - assert read_result.cache.processor._get_primary_keys("products") == ["id"] - assert read_result.cache.processor._get_primary_keys("purchases") == ["id"] + assert read_result.cache.processor.catalog_provider.get_primary_keys( + "products" + ) == ["id"] + assert read_result.cache.processor.catalog_provider.get_primary_keys( + "purchases" + ) == ["id"] @pytest.mark.slow diff --git a/tests/integration_tests/test_source_faker_integration.py b/tests/integration_tests/test_source_faker_integration.py index 6058b06b..3b9d5e4f 100644 --- a/tests/integration_tests/test_source_faker_integration.py +++ b/tests/integration_tests/test_source_faker_integration.py @@ -133,8 +133,12 @@ def test_faker_pks( assert catalog.streams[1].primary_key read_result = source_faker_seed_a.read(duckdb_cache, write_strategy="append") - assert read_result.cache.processor._get_primary_keys("products") == ["id"] - assert read_result.cache.processor._get_primary_keys("purchases") == ["id"] + assert read_result.cache.processor.catalog_provider.get_primary_keys( + "products" + ) == ["id"] + assert read_result.cache.processor.catalog_provider.get_primary_keys( + "purchases" + ) == ["id"] @pytest.mark.slow From 34acdda406a2987f4c1619d07bd891dfb498b99d Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Mon, 5 Aug 2024 13:58:05 -0700 Subject: [PATCH 18/27] fix circular import loop --- airbyte/_util/telemetry.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/airbyte/_util/telemetry.py b/airbyte/_util/telemetry.py index faa276b8..7f64bf2d 100644 --- a/airbyte/_util/telemetry.py +++ b/airbyte/_util/telemetry.py @@ -47,14 +47,14 @@ from airbyte import exceptions as exc from airbyte._util import meta -from airbyte.destinations.base import Destination from airbyte.version import get_version -from airbyte.writers import AirbyteWriterInterface if TYPE_CHECKING: from airbyte.caches.base import CacheBase + from airbyte.destinations.base import Destination from airbyte.sources.base import Source + from airbyte.writers import AirbyteWriterInterface DEBUG = True @@ -237,21 +237,13 @@ def from_destination( if isinstance(destination, str): return cls(name=destination, executor_type=UNKNOWN, version=UNKNOWN) - if isinstance(destination, Destination): + if hasattr(destination, "executor"): return cls( name=destination.name, executor_type=type(destination.executor).__name__, version=destination.executor.reported_version, ) - # Else, `destination` should be a `AirbyteWriterInterface` at this point - if isinstance(destination, AirbyteWriterInterface): - return cls( 
- name=destination.name, - executor_type=UNKNOWN, - version=UNKNOWN, - ) - return cls( # type: ignore [unreachable] name=repr(destination), executor_type=UNKNOWN, From ed8edec3de32744d500ee8aa3847609a767b41f3 Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Mon, 5 Aug 2024 14:01:46 -0700 Subject: [PATCH 19/27] fix parquet writer import --- airbyte/_processors/file/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airbyte/_processors/file/__init__.py b/airbyte/_processors/file/__init__.py index 43c3da97..0f83652b 100644 --- a/airbyte/_processors/file/__init__.py +++ b/airbyte/_processors/file/__init__.py @@ -6,12 +6,12 @@ from airbyte._batch_handles import BatchHandle from airbyte._processors.file.base import FileWriterBase from airbyte._processors.file.jsonl import JsonlWriter -from airbyte._processors.file.parquet import LocalIcebergWriter +from airbyte._processors.file.parquet import ParquetWriter __all__ = [ "BatchHandle", "FileWriterBase", "JsonlWriter", - "LocalIcebergWriter", + "ParquetWriter", ] From a5efbd145c68d186e7a1954ed68e0f66d09079be Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Wed, 7 Aug 2024 07:24:54 -0700 Subject: [PATCH 20/27] updated polars test script --- examples/run_polars_poc.py | 253 ++++++++++++++++++++++++++++++++----- 1 file changed, 224 insertions(+), 29 deletions(-) diff --git a/examples/run_polars_poc.py b/examples/run_polars_poc.py index 329ccb8e..b5054f35 100644 --- a/examples/run_polars_poc.py +++ b/examples/run_polars_poc.py @@ -7,52 +7,247 @@ from __future__ import annotations +import logging +import time +from pathlib import Path +from typing import Callable, cast import airbyte as ab +import boto3 import polars as pl from airbyte import get_source -from airbyte._message_iterators import AirbyteMessageIterator -from airbyte._util.polars import PolarsStreamSchema -from airbyte.progress import ProgressTracker +from airbyte.secrets.google_gsm import GoogleGSMSecretManager +from typing_extensions import Literal + +logger = logging.getLogger() +for handler in logger.handlers[:]: + logger.removeHandler(handler) + +# Basic logging configuration +logger.setLevel(logging.INFO) + +# Create a console handler +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.INFO) + +# Create a formatter and set it for the handler +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +console_handler.setFormatter(formatter) + +# Add the handler to the logger +logger.addHandler(console_handler) + +AIRBYTE_INTERNAL_GCP_PROJECT = "dataline-integration-testing" +SECRET_NAME = "SECRET_SOURCE-S3_V4_JSONL_NEWLINE__CREDS" + +secret_mgr = GoogleGSMSecretManager( + project=AIRBYTE_INTERNAL_GCP_PROJECT, + credentials_json=ab.get_secret("GCP_GSM_CREDENTIALS"), +) +secret_handle = secret_mgr.get_secret(SECRET_NAME) +assert secret_handle is not None, "Secret not found." 
+secret_config = secret_handle.parse_json() + +# BUCKET_NAME = "performance-test-datasets" # << Can't find this bucket +BUCKET_NAME = "airbyte-internal-performance" # << From 'dataline-dev', until we find the other bucket +FILE_NAME_PREFIX = "json/no_op_source/stream1/2024_08_06_18/" +GLOB = f"{FILE_NAME_PREFIX}*.jsonl" +SAMPLE_FILE_1 = "json/no_op_source/stream1/2024_08_06_18/15836041.0.1722968833150.jsonl" +SAMPLE_FILE_2 = "json/no_op_source/stream1/2024_08_06_18/15836041.1.1722968852741.jsonl" +SAMPLE_FILES = [SAMPLE_FILE_1, SAMPLE_FILE_2] + +records_processed = 0 def get_my_source() -> ab.Source: + """Take the existing S3 config and modify it to use the perf-test bucket and glob.""" + secret_config["bucket"] = BUCKET_NAME + secret_config["streams"][0] = { + "name": "stream1", + "format": {"filetype": "jsonl"}, + "globs": ["json/no_op_source/stream1/**/*.jsonl"], + } return get_source( - "source-faker", - config={}, - streams=["users"], + "source-s3", + config=secret_config, + streams="*", ) -def main() -> None: - """Run the Polars proof of concept.""" +def test_my_source() -> None: + """Test the modified S3 source.""" source = get_my_source() + source.check() + source.select_all_streams() + source.read() - polars_stream_schema: PolarsStreamSchema = PolarsStreamSchema.from_json_schema( - json_schema=source.configured_catalog.streams[0].stream.json_schema, + +def get_s3_file_names( + bucket_name: str, + path_prefix: str, +) -> list[str]: + """Get the names of all files in an S3 bucket with a given prefix.""" + s3_client = boto3.client( + "s3", + aws_access_key_id=secret_config["aws_access_key_id"], + aws_secret_access_key=secret_config["aws_secret_access_key"], + region_name="us-east-2", ) - progress_tracker = ProgressTracker( - source=source, - cache=None, - destination=None, + + # List all objects in the S3 bucket with the specified prefix + response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=path_prefix) + + # Filter the objects to match the file prefix + matching_files = [match["Key"] for match in response.get("Contents", [])] + + return matching_files + + +def get_polars_df( + s3_file_path: str, + *, + lazy: bool = True, + expire_cache: bool, +) -> pl.LazyFrame | pl.DataFrame: + """Set up a Polars lazy DataFrame from an S3 file. + + This action will connect to the S3 bucket but it will not read data from the file until + the DataFrame is actually used. + """ + read_fn: Callable[..., pl.LazyFrame | pl.DataFrame] + if lazy: + read_fn = pl.scan_ndjson + else: + read_fn = pl.read_ndjson + + return read_fn( + source=s3_file_path, # TODO: Try to get this working with globs + storage_options={ + "aws_access_key_id": secret_config["aws_access_key_id"], + "aws_secret_access_key": secret_config["aws_secret_access_key"], + "region": "us-east-2", + }, + include_file_paths="_ab_source_file_path", + row_index_name="_ab_record_index", + row_index_offset=0, + infer_schema_length=100, + file_cache_ttl=0 if expire_cache else None, + ) + + +def add_custom_transforms( + df: pl.LazyFrame | pl.DataFrame, +) -> pl.LazyFrame | pl.DataFrame: + """Add custom transforms to the Polars lazy DataFrame.""" + # Add a sample custom column to the DataFrame with text 'Hello, world!' 
+ return df.with_columns( + [ + pl.lit("Hello, world!").alias("greeting"), + pl.col("_airbyte_ab_id").str.to_uppercase(), + # pl.lit(pl.now()).alias("current_timestamp"), + ], + # timestamp=datetime.now(), ) - msg_iterator = AirbyteMessageIterator( - msg - for msg in source._get_airbyte_message_iterator( - streams=["users"], - progress_tracker=progress_tracker, + + +def write_files( + s3_urls: list[str], + file_type: Literal["jsonl", "parquet"] = "jsonl", + *, + expire_cache: bool, + lazy: bool = True, +) -> None: + global records_processed + dataframes: list[pl.DataFrame] = [] + for n, s3_url in enumerate(s3_urls, start=1): + base_name = ".".join(Path(s3_url).name.split(".")[:-1]) + output_file = f"polars-perf-test-artifact.{base_name}.{file_type}" + if Path(output_file).exists(): + Path(output_file).unlink() + + df: pl.LazyFrame | pl.DataFrame = get_polars_df( + s3_file_path=f"s3://{BUCKET_NAME}/{SAMPLE_FILE_2}", + # expected_schema=polars_stream_schema, + expire_cache=expire_cache, + lazy=lazy, ) + + if lazy: + df = add_custom_transforms(df) + assert isinstance(df, pl.LazyFrame) + logger.info(f"Collecting file {n} of {len(s3_urls)} to '{output_file}'...") + df = df.collect() + logger.info(f"Writing file {n} of {len(s3_urls)} to '{output_file}'...") + + # Write the DataFrame to a file + if file_type == "parquet": + df.write_parquet(output_file) + elif file_type == "jsonl": + df.write_ndjson(output_file) + else: + raise ValueError(f"Invalid file type: {file_type}") + + records_processed += df.height + del df + else: + assert isinstance(df, pl.DataFrame) + dataframes.append(df) + + if not lazy: + combined_df: pl.DataFrame = pl.concat(dataframes) + combined_df = cast(pl.DataFrame, add_custom_transforms(combined_df)) + if file_type == "parquet": + combined_df.write_parquet(output_file) + elif file_type == "jsonl": + combined_df.write_ndjson(output_file) + else: + raise ValueError(f"Invalid file type: {file_type}") + + records_processed += combined_df.height + + +def run_polars_perf_test( + file_type: Literal["jsonl", "parquet"] = "jsonl", + *, + expire_cache: bool, + lazy: bool = True, +) -> None: + """Run the Polars proof of concept.""" + global records_processed + + logger.info("Finding S3 files...") + + s3_urls = get_s3_file_names(BUCKET_NAME, FILE_NAME_PREFIX) + logger.info("Creating polars dataframes...") + start_time = time.time() + + write_files( + s3_urls, + file_type=file_type, + lazy=lazy, + expire_cache=expire_cache, + ) + + logger.info("Finished write operation.") + elapsed_transfer_time = time.time() - start_time + + mb_per_record = 180 / (1024 * 1024) # << 180 bytes per record, converted to MB + logger.info( + f"Wrote {records_processed:,} records from {len(s3_urls)} files in" + f" {elapsed_transfer_time:,.2f} seconds." 
+ f" ({records_processed / elapsed_transfer_time:,.2f} records per second," + f" {(records_processed / elapsed_transfer_time) * mb_per_record:,.1f} MB/s," + f" {records_processed * mb_per_record:,.1f} MB total)" ) - # jsonl_iterator = (msg.model_dump_json() for msg in msg_iterator) - # df = pl.read_ndjson( - # StringIO("\n".join(jsonl_iterator)), - # schema=polars_stream_schema.polars_schema, - # ) - filelike = msg_iterator.as_filelike() - print(filelike.readlines()) - df = pl.read_ndjson( - filelike, - schema=polars_stream_schema.polars_schema, + + +def main() -> None: + """Run the Polars proof of concept.""" + # test_my_source() # We don't need to run this every time - only for debugging + run_polars_perf_test( + file_type="parquet", + lazy=False, + expire_cache=False, ) From c3d4e9e961d5456c3b1f95e113cc6c51cccb1aca Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Wed, 7 Aug 2024 07:25:23 -0700 Subject: [PATCH 21/27] poetry add --dev boto3 --- poetry.lock | 85 +++++++++++++++++++++++++++++++++++++++++++++----- pyproject.toml | 1 + 2 files changed, 78 insertions(+), 8 deletions(-) diff --git a/poetry.lock b/poetry.lock index fafe5132..ab59de84 100644 --- a/poetry.lock +++ b/poetry.lock @@ -151,6 +151,47 @@ files = [ {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, ] +[[package]] +name = "boto3" +version = "1.34.155" +description = "The AWS SDK for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "boto3-1.34.155-py3-none-any.whl", hash = "sha256:445239ea2ba7f4084ddbd71f721c14d0a6d08e06f6ba51b5403a16b6544b3f1e"}, + {file = "boto3-1.34.155.tar.gz", hash = "sha256:e8d2e128c74e84199edccdc3a6b4b1c6fb36d6fdb5688eb92931827f02c6fa5b"}, +] + +[package.dependencies] +botocore = ">=1.34.155,<1.35.0" +jmespath = ">=0.7.1,<2.0.0" +s3transfer = ">=0.10.0,<0.11.0" + +[package.extras] +crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] + +[[package]] +name = "botocore" +version = "1.34.155" +description = "Low-level, data-driven core of boto 3." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "botocore-1.34.155-py3-none-any.whl", hash = "sha256:f2696c11bb0cad627d42512937befd2e3f966aedd15de00d90ee13cf7a16b328"}, + {file = "botocore-1.34.155.tar.gz", hash = "sha256:3aa88abfef23909f68d3e6679a3d4b4bb3c6288a6cfbf9e253aa68dac8edad64"}, +] + +[package.dependencies] +jmespath = ">=0.7.1,<2.0.0" +python-dateutil = ">=2.1,<3.0.0" +urllib3 = [ + {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}, + {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, +] + +[package.extras] +crt = ["awscrt (==0.21.2)"] + [[package]] name = "bracex" version = "2.4" @@ -1230,6 +1271,17 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "jmespath" +version = "1.0.1" +description = "JSON Matching Expressions" +optional = false +python-versions = ">=3.7" +files = [ + {file = "jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980"}, + {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"}, +] + [[package]] name = "jsonpatch" version = "1.33" @@ -1912,17 +1964,17 @@ poetry-plugin = ["poetry (>=1.0,<2.0)"] [[package]] name = "polars" -version = "1.3.0" +version = "1.4.1" description = "Blazingly fast DataFrame library" optional = false python-versions = ">=3.8" files = [ - {file = "polars-1.3.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:85a338b8f617fdf5e5472567d32efeb46e6624a604c45622cc96669324f82961"}, - {file = "polars-1.3.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:5859a11d1c8ec14089127043d8d6bae01f015021113ed01a2e4953e6c21feee5"}, - {file = "polars-1.3.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a1ff7779315e6b0d17641af3eb4dd7aec2ab0bc1bee009efb12242bf6403aeb"}, - {file = "polars-1.3.0-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:e675269c17a83484c74165989d93572785d4298019f4f8ca65e25a49d4440236"}, - {file = "polars-1.3.0-cp38-abi3-win_amd64.whl", hash = "sha256:75cbbccc4a55a8ae8c5ea8b9daa8747aee5d182c2bba7c712496f32a8096562d"}, - {file = "polars-1.3.0.tar.gz", hash = "sha256:c7812d6c72ffdc9e70aaa8f9aa6378db80b393e7ecbe7005ad84b150c17c71cb"}, + {file = "polars-1.4.1-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:f02fc6a5c63dd86cfeb159caa66112e477c69fc7800a28e64609ac2780554865"}, + {file = "polars-1.4.1-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:bd2acd8b1977f61b9587c8d47d16f101e7e73edd8cdeb3a8a725f15f181cd120"}, + {file = "polars-1.4.1-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7cf834a328e292c31c06eb606496becb6d8a795e927c826e26e2af27087950f1"}, + {file = "polars-1.4.1-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:64eabf0ef7ac0d17fe15361e7daaeb4425a875d2d760c17d96803e9ac8bee244"}, + {file = "polars-1.4.1-cp38-abi3-win_amd64.whl", hash = "sha256:2313d63ecfa1d9f1e740b9fcabb8ae45d9d0b5acf1ddb401951daba4c0f3f74f"}, + {file = "polars-1.4.1.tar.gz", hash = "sha256:ed8009aff8cf91f94db5a38d947185603ad5bee48a28b764cf5a52048c7c4756"}, ] [package.extras] @@ -2976,6 +3028,23 @@ files = [ {file = "ruff-0.4.1.tar.gz", hash = "sha256:d592116cdbb65f8b1b7e2a2b48297eb865f6bdc20641879aa9d7b9c11d86db79"}, ] +[[package]] +name = "s3transfer" +version = "0.10.2" +description = "An Amazon S3 Transfer Manager" +optional = false +python-versions = ">=3.8" +files = [ + {file = "s3transfer-0.10.2-py3-none-any.whl", hash = 
"sha256:eca1c20de70a39daee580aef4986996620f365c4e0fda6a86100231d62f1bf69"}, + {file = "s3transfer-0.10.2.tar.gz", hash = "sha256:0711534e9356d3cc692fdde846b4a1e4b0cb6519971860796e6bc4c7aea00ef6"}, +] + +[package.dependencies] +botocore = ">=1.33.2,<2.0a.0" + +[package.extras] +crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] + [[package]] name = "setuptools" version = "71.1.0" @@ -3531,4 +3600,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "94681c4910dbdb28d4cc01e42dddd29711ecc877557f9719a71658d8f515e158" +content-hash = "69d59cce7a1417192f306b54c080dff66dc15a7ec9113543a06384e115b8dec8" diff --git a/pyproject.toml b/pyproject.toml index 6fb774cf..0b3a0c44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ poethepoet = "^0.26.1" coverage = "^7.5.1" pytest-timeout = "^2.3.1" viztracer = "^0.16.3" +boto3 = "^1.34.155" [build-system] requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"] From c6969dec2bb853b9b4e85f71b2ca185f2c952ce3 Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Wed, 7 Aug 2024 07:25:38 -0700 Subject: [PATCH 22/27] gitignore test-artifact files --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 9b752740..0f6dcc35 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# perf-test +*test-artifact* + # temp files temp .temp From 63b18ef45ffff0f78831deaf8ef0157bc207bb6a Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Wed, 7 Aug 2024 09:51:19 -0700 Subject: [PATCH 23/27] update transforms --- examples/run_polars_poc.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/run_polars_poc.py b/examples/run_polars_poc.py index b5054f35..1bd9283a 100644 --- a/examples/run_polars_poc.py +++ b/examples/run_polars_poc.py @@ -9,6 +9,7 @@ import logging import time +from datetime import datetime from pathlib import Path from typing import Callable, cast @@ -139,14 +140,12 @@ def add_custom_transforms( df: pl.LazyFrame | pl.DataFrame, ) -> pl.LazyFrame | pl.DataFrame: """Add custom transforms to the Polars lazy DataFrame.""" - # Add a sample custom column to the DataFrame with text 'Hello, world!' return df.with_columns( [ pl.lit("Hello, world!").alias("greeting"), pl.col("_airbyte_ab_id").str.to_uppercase(), - # pl.lit(pl.now()).alias("current_timestamp"), + pl.lit(datetime.now()).alias("current_timestamp"), ], - # timestamp=datetime.now(), ) From 9abcea42b3b606a7285175ef7c2827a24bfd4e5a Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Mon, 2 Sep 2024 13:00:09 -0700 Subject: [PATCH 24/27] remove dupe file --- airbyte/writers.py | 58 ---------------------------------------------- 1 file changed, 58 deletions(-) delete mode 100644 airbyte/writers.py diff --git a/airbyte/writers.py b/airbyte/writers.py deleted file mode 100644 index 7f216168..00000000 --- a/airbyte/writers.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) 2024 Airbyte, Inc., all rights reserved. 
-"""Write interfaces for PyAirbyte.""" - -from __future__ import annotations - -import abc -from typing import IO, TYPE_CHECKING - - -if TYPE_CHECKING: - from airbyte._future_cdk.catalog_providers import CatalogProvider - from airbyte._future_cdk.state_writers import StateWriterBase - from airbyte._message_iterators import AirbyteMessageIterator - from airbyte.progress import ProgressTracker - from airbyte.strategies import WriteStrategy - - -class AirbyteWriterInterface(abc.ABC): - """An interface for writing Airbyte messages.""" - - @property - def name(self) -> str: - """Return the name of the writer. - - This is used for logging and state tracking. - """ - return self.__class__.__name__ - - def _write_airbyte_io_stream( - self, - stdin: IO[str], - *, - catalog_provider: CatalogProvider, - write_strategy: WriteStrategy, - state_writer: StateWriterBase | None = None, - progress_tracker: ProgressTracker, - ) -> None: - """Read from the connector and write to the cache.""" - self._write_airbyte_message_stream( - stdin, - catalog_provider=catalog_provider, - write_strategy=write_strategy, - state_writer=state_writer, - progress_tracker=progress_tracker, - ) - - @abc.abstractmethod - def _write_airbyte_message_stream( - self, - stdin: IO[str] | AirbyteMessageIterator, - *, - catalog_provider: CatalogProvider, - write_strategy: WriteStrategy, - state_writer: StateWriterBase | None = None, - progress_tracker: ProgressTracker, - ) -> None: - """Write the incoming data.""" - ... From 80b7340a703e4ef48884c08447c77b2f3ed8d4c8 Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Mon, 2 Sep 2024 13:08:14 -0700 Subject: [PATCH 25/27] remove dupe import --- airbyte/_util/telemetry.py | 1 - 1 file changed, 1 deletion(-) diff --git a/airbyte/_util/telemetry.py b/airbyte/_util/telemetry.py index 1f26c5fc..73376d40 100644 --- a/airbyte/_util/telemetry.py +++ b/airbyte/_util/telemetry.py @@ -55,7 +55,6 @@ from airbyte.caches.base import CacheBase from airbyte.destinations.base import Destination from airbyte.sources.base import Source - from airbyte.writers import AirbyteWriterInterface DEBUG = True From 30d1afbc58a615981dc5f5971bf5862e79c603a1 Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Mon, 2 Sep 2024 13:09:12 -0700 Subject: [PATCH 26/27] remove import --- airbyte/destinations/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/airbyte/destinations/base.py b/airbyte/destinations/base.py index 665a0616..01427f58 100644 --- a/airbyte/destinations/base.py +++ b/airbyte/destinations/base.py @@ -30,7 +30,6 @@ from airbyte.shared.state_writers import NoOpStateWriter, StdOutStateWriter from airbyte.sources.base import Source from airbyte.strategies import WriteStrategy -from airbyte.writers import AirbyteWriterInterface if TYPE_CHECKING: From 54c9752db1f1e53e6088c6600a0f754762c5a0b8 Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Mon, 2 Sep 2024 13:10:29 -0700 Subject: [PATCH 27/27] remove extra imports --- airbyte/progress.py | 1 - airbyte/results.py | 1 - 2 files changed, 2 deletions(-) diff --git a/airbyte/progress.py b/airbyte/progress.py index 6cf944df..def54028 100644 --- a/airbyte/progress.py +++ b/airbyte/progress.py @@ -59,7 +59,6 @@ from airbyte.caches.base import CacheBase from airbyte.destinations.base import Destination from airbyte.sources.base import Source - from airbyte.writers import AirbyteWriterInterface IS_REPL = hasattr(sys, "ps1") # True if we're in a Python REPL, in which case we can use Rich. 
HORIZONTAL_LINE = "------------------------------------------------\n" diff --git a/airbyte/results.py b/airbyte/results.py index 700e80d0..68757304 100644 --- a/airbyte/results.py +++ b/airbyte/results.py @@ -27,7 +27,6 @@ from airbyte.shared.state_providers import StateProviderBase from airbyte.shared.state_writers import StateWriterBase from airbyte.sources.base import Source - from airbyte.writers import AirbyteWriterInterface class ReadResult(Mapping[str, CachedDataset]):
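
Taken together, the later patches in this series move write-strategy resolution off the SQL processor and onto the catalog provider. The sketch below restates those resolution rules in isolation; the fallback order and enum values mirror `CatalogProvider.resolve_write_method()` as added above, but the standalone `resolve()` helper and its boolean flags are illustrative and not part of the library (the library raises `PyAirbyteInputError` rather than `ValueError` in the merge-without-primary-key case).

from airbyte.strategies import WriteMethod, WriteStrategy


def resolve(write_strategy: WriteStrategy, *, has_pks: bool, has_cursor: bool) -> WriteMethod:
    """Mirror the per-stream resolution logic added to CatalogProvider."""
    if write_strategy == WriteStrategy.MERGE and not has_pks:
        raise ValueError("Cannot use merge strategy on a stream with no primary keys.")
    if write_strategy != WriteStrategy.AUTO:
        # MERGE, APPEND, and REPLACE map one-to-one onto the new WriteMethod enum.
        return WriteMethod(write_strategy)
    if has_pks:
        return WriteMethod.MERGE
    if has_cursor:
        return WriteMethod.APPEND
    return WriteMethod.REPLACE


# AUTO falls back based on stream metadata, in the order documented on WriteStrategy.AUTO:
assert resolve(WriteStrategy.AUTO, has_pks=True, has_cursor=True) is WriteMethod.MERGE
assert resolve(WriteStrategy.AUTO, has_pks=False, has_cursor=True) is WriteMethod.APPEND
assert resolve(WriteStrategy.AUTO, has_pks=False, has_cursor=False) is WriteMethod.REPLACE
assert resolve(WriteStrategy.APPEND, has_pks=True, has_cursor=False) is WriteMethod.APPEND

Each resolved `WriteMethod` then maps onto a protocol-level `DestinationSyncMode` (merge to `append_dedup`, append to `append`, replace to `overwrite`), which is how `with_write_strategy()` prepares the catalog that is handed to a destination.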