diff --git a/.gitignore b/.gitignore index 9b752740..0f6dcc35 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# perf-test +*test-artifact* + # temp files temp .temp diff --git a/airbyte/_airbyte_message_overrides.py b/airbyte/_airbyte_message_overrides.py new file mode 100644 index 00000000..70196f9c --- /dev/null +++ b/airbyte/_airbyte_message_overrides.py @@ -0,0 +1,61 @@ +# Copyright (c) 2024 Airbyte, Inc., all rights reserved. +"""Custom Airbyte message classes. + +These classes override the default handling, in order to ensure that the data field is always a +jsonified string, rather than a dict. + +To use these classes, import them from this module, and use them in place of the default classes. + +Example: +```python +from airbyte._airbyte_message_overrides import AirbyteMessageWithStrData + +for line in sys.stdin: + message = AirbyteMessageWithStrData.model_validate_json(line) +``` +""" + +from __future__ import annotations + +import copy +import json +from typing import Any + +from pydantic import BaseModel, Field, model_validator + +from airbyte_protocol.models import ( + AirbyteMessage, + AirbyteRecordMessage, + AirbyteStateMessage, +) + + +AirbyteRecordMessageWithStrData = copy.deepcopy(AirbyteRecordMessage) +AirbyteStateMessageWithStrData = copy.deepcopy(AirbyteStateMessage) +AirbyteMessageWithStrData = copy.deepcopy(AirbyteMessage) + +# Modify the data field in the copied class +AirbyteRecordMessageWithStrData.__annotations__["data"] = str +AirbyteStateMessageWithStrData.__annotations__["data"] = str + +AirbyteRecordMessageWithStrData.data = Field(..., description="jsonified record data as a str") +AirbyteStateMessageWithStrData.data = Field(..., description="jsonified state data as a str") + + +# Add a validator to ensure data is a JSON string +@model_validator(mode="before") +def ensure_data_is_string( + cls: BaseModel, # type: ignore # noqa: ARG001, PGH003 + values: dict[str, Any], +) -> None: + if "data" in values and not isinstance(values["data"], dict): + values["data"] = json.dumps(values["data"]) + if "data" in values and not isinstance(values["data"], str): + raise ValueError + + +AirbyteRecordMessageWithStrData.ensure_data_is_string = classmethod(ensure_data_is_string) # type: ignore [arg-type] +AirbyteStateMessageWithStrData.ensure_data_is_string = classmethod(ensure_data_is_string) # type: ignore [arg-type] + +AirbyteMessageWithStrData.__annotations__["record"] = AirbyteRecordMessageWithStrData | None +AirbyteMessageWithStrData.__annotations__["state"] = AirbyteStateMessageWithStrData | None diff --git a/airbyte/_connector_base.py b/airbyte/_connector_base.py index 6247a3e2..b4e19637 100644 --- a/airbyte/_connector_base.py +++ b/airbyte/_connector_base.py @@ -169,7 +169,7 @@ def _get_spec(self, *, force_refresh: bool = False) -> ConnectorSpecification: """ if force_refresh or self._spec is None: try: - for msg in self._execute(["spec"]): + for msg in self._execute_and_parse(["spec"]): if msg.type == Type.SPEC and msg.spec: self._spec = msg.spec break @@ -275,7 +275,7 @@ def check(self) -> None: """ with as_temp_files([self._config]) as [config_file]: try: - for msg in self._execute(["check", "--config", config_file]): + for msg in self._execute_and_parse(["check", "--config", config_file]): if msg.type == Type.CONNECTION_STATUS and msg.connectionStatus: if msg.connectionStatus.status != Status.FAILED: rich.print(f"Connection check succeeded for `{self.name}`.") @@ -349,7 +349,7 @@ def _peek_airbyte_message( ) return - def _execute( + def 
_execute_and_parse( self, args: list[str], stdin: IO[str] | AirbyteMessageIterator | None = None, @@ -371,7 +371,7 @@ def _execute( self.executor.ensure_installation(auto_fix=False) try: - for line in self.executor.execute(args, stdin=stdin): + for line in self._execute(args, stdin=stdin): try: message: AirbyteMessage = AirbyteMessage.model_validate_json(json_data=line) if progress_tracker and message.record: @@ -403,6 +403,33 @@ def _execute( original_exception=e, ) from None + def _execute( + self, + args: list[str], + stdin: IO[str] | AirbyteMessageIterator | None = None, + ) -> Generator[str, None, None]: + """Execute the connector with the given arguments. + + This involves the following steps: + * Locate the right venv. It is called ".venv-" + * Spawn a subprocess with .venv-/bin/ + * Read the output line by line of the subprocess and yield (unparsed) strings. + + Raises: + AirbyteConnectorFailedError: If the process returns a failure status (non-zero). + """ + # Fail early if the connector is not installed. + self.executor.ensure_installation(auto_fix=False) + + try: + yield from self.executor.execute(args, stdin=stdin) + + except Exception as e: + raise exc.AirbyteConnectorFailedError( + connector_name=self.name, + log_text=self._last_log_messages, + ) from e + __all__ = [ "ConnectorBase", diff --git a/airbyte/_message_iterators.py b/airbyte/_message_iterators.py index d50dad7a..4389b57d 100644 --- a/airbyte/_message_iterators.py +++ b/airbyte/_message_iterators.py @@ -4,12 +4,13 @@ from __future__ import annotations import datetime +import io import sys from collections.abc import Iterator from typing import IO, TYPE_CHECKING, cast import pydantic -from typing_extensions import final +from typing_extensions import Literal, final from airbyte_protocol.models import ( AirbyteMessage, @@ -57,6 +58,30 @@ def read(self) -> str: """Read the next message from the iterator.""" return next(self).model_dump_json() + def as_filelike(self) -> io.BytesIO: + """Return a file-like object that reads from the iterator.""" + + class FileLikeReader(io.RawIOBase): + def __init__(self, iterator: Iterator[AirbyteMessage]) -> None: + self.iterator = (msg.model_dump_json() for msg in iterator) + self.buffer = "" + + def readable(self) -> Literal[True]: + return True + + def readinto(self, b: Any) -> int: + try: + chunk = next(self.iterator) + except StopIteration: + return 0 # EOF + + data = chunk.encode() + n = len(data) + b[:n] = data + return n + + return cast(io.BytesIO, FileLikeReader(self._iterator)) + @classmethod def from_read_result(cls, read_result: ReadResult) -> AirbyteMessageIterator: """Create a iterator from a `ReadResult` object.""" diff --git a/airbyte/_processors/sql/duckdb.py b/airbyte/_processors/sql/duckdb.py index 35162f63..2ab36261 100644 --- a/airbyte/_processors/sql/duckdb.py +++ b/airbyte/_processors/sql/duckdb.py @@ -87,12 +87,7 @@ def get_sql_engine(self) -> Engine: class DuckDBSqlProcessor(SqlProcessorBase): - """A DuckDB implementation of the cache. - - Jsonl is used for local file storage before bulk loading. - Unlike the Snowflake implementation, we can't use the COPY command to load data - so we insert as values instead. 
- """ + """A DuckDB implementation of the cache.""" supports_merge_insert = False file_writer_class = JsonlWriter diff --git a/airbyte/_processors/sql/iceberg.py b/airbyte/_processors/sql/iceberg.py new file mode 100644 index 00000000..e9af6454 --- /dev/null +++ b/airbyte/_processors/sql/iceberg.py @@ -0,0 +1,138 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +from __future__ import annotations + +from textwrap import dedent, indent +from typing import TYPE_CHECKING + +from airbyte_protocol.models import ( + AirbyteRecordMessage, + AirbyteStateMessage, + AirbyteStateType, +) + +from airbyte import exceptions as exc +from airbyte._future_cdk.sql_processor import SqlConfig, SqlProcessorBase +from airbyte._future_cdk.state_writers import StateWriterBase +from airbyte._processors.file.parquet import ParquetWriter + + +if TYPE_CHECKING: + from pathlib import Path + + +class IcebergConfig(SqlConfig): + """A Iceberg configuration.""" + + def __init__(self, db_path: str, schema_name: str) -> None: + """Initialize the Iceberg configuration.""" + self.db_path = db_path + self.schema_name = schema_name + + +class IcebergSqlProcessor(SqlProcessorBase): + """A Iceberg SQL processor.""" + + supports_merge_insert = False + file_writer_class = ParquetWriter + sql_config: IcebergConfig + + class IcebergStateWriter(StateWriterBase): + """A state writer for the Parquet cache.""" + + def __init__(self, iceberg_processor: IcebergSqlProcessor) -> None: + self._iceberg_processor = iceberg_processor + super().__init__() + + def _write_state(self, state: AirbyteRecordMessage) -> None: + """Write the state to the cache.""" + self._iceberg_processor.write_state(state) + + @property + def get_state_writer(self) -> StateWriterBase: + if self._state_writer is None: + self._state_writer = self.IcebergStateWriter(self) + + return self._state_writer + + def write_state(self, state: AirbyteStateMessage) -> None: + """Write the state to the cache. + + Args: + state (AirbyteStateMessage): The state to write. + + Implementation: + - State messages are written a separate file. + - Any pending records are written to the cache file and the cache file is closed. + - For stream state messages, the matching stream batches are flushed and closed. + - For global state, all batches are flushed and closed. + """ + stream_names: list[str] = [] + if state.type == AirbyteStateType.STREAM: + stream_names = [state.record.stream] + if state.type == AirbyteStateType.GLOBAL: + stream_names = list(self._buffered_records.keys()) + else: + msg = f"Unexpected state type: {state.type}" + raise exc.PyAirbyteInternalError(msg) + + for stream_name in stream_names: + state_file_name = self.file_writer.get_active_batch(stream_name) + self.file_writer.flush_active_batch(stream_name) + self.file_writer._write_state_to_file(state) + return + + def _write_files_to_new_table( + self, + files: list[Path], + stream_name: str, + batch_id: str, + ) -> str: + """Write file(s) to a new table. + + This involves registering the table in the Iceberg catalog, creating a manifest file, + and registering the manifest file in the catalog. 
+ """ + temp_table_name = self._create_table_for_loading( + stream_name=stream_name, + batch_id=batch_id, + ) + columns_list = list(self._get_sql_column_definitions(stream_name=stream_name).keys()) + columns_list_str = indent( + "\n, ".join([self._quote_identifier(col) for col in columns_list]), + " ", + ) + files_list = ", ".join([f"'{f!s}'" for f in files]) + columns_type_map = indent( + "\n, ".join( + [ + self._quote_identifier(self.normalizer.normalize(prop_name)) + + ': "' + + str( + self._get_sql_column_definitions(stream_name)[ + self.normalizer.normalize(prop_name) + ] + ) + + '"' + for prop_name in columns_list + ] + ), + " ", + ) + insert_statement = dedent( + f""" + INSERT INTO {self.sql_config.schema_name}.{temp_table_name} + ( + {columns_list_str} + ) + SELECT + {columns_list_str} + FROM read_json_auto( + [{files_list}], + format = 'newline_delimited', + union_by_name = true, + columns = {{ { columns_type_map } }} + ) + """ + ) + self._execute_sql(insert_statement) + return temp_table_name diff --git a/airbyte/_util/polars.py b/airbyte/_util/polars.py new file mode 100644 index 00000000..7b750436 --- /dev/null +++ b/airbyte/_util/polars.py @@ -0,0 +1,129 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +"""Polars utility functions.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +import polars as pl +import pyarrow as pa + + +@dataclass +class PolarsStreamSchema: + """A Polars stream schema.""" + + expressions: list[pl.Expr] = field(default_factory=list) + arrow_schema: pa.Schema = field(default_factory=lambda: pa.schema([])) + polars_schema: pl.Schema = field(default_factory=lambda: pl.Schema([])) + + @classmethod + def from_json_schema(cls, json_schema: dict[str, Any]) -> PolarsStreamSchema: + """Create a Polars stream schema from a JSON schema.""" + expressions: list[pl.Expr] = [] + arrow_columns: list[pa.Field] = [] + polars_columns: list[pl.Col] = [] + + _json_schema_to_polars( + json_schema=json_schema, + expressions=expressions, + arrow_columns=arrow_columns, + polars_columns=polars_columns, + ) + return cls( + expressions=expressions, + arrow_schema=pa.schema(arrow_columns), + polars_schema=pl.Schema(polars_columns), + ) + + +def _json_schema_to_polars( + *, + json_schema: dict[str, Any], + expressions: list[pl.Expr], + arrow_columns: list[pa.Field], + polars_columns: list[pl.Col], + _breadcrumbs: list[str] | None = None, +) -> None: + """Get Polars transformations and PyArrow column definitions from the provided JSON schema. + + Recursive operations are tracked with a breadcrumb list. 
+ """ + _breadcrumbs = _breadcrumbs or [] + json_schema_node = json_schema + for crumb in _breadcrumbs: + json_schema_node = json_schema_node["properties"][crumb] + + # Determine the primary type from the schema node + tidy_type: str | list[str] = json_schema_node["type"] + + # Handle multiple types, focusing on non-nullable types if present + if isinstance(tidy_type, list): + if "null" in tidy_type: + tidy_type.remove("null") + if len(tidy_type) == 1: + tidy_type = tidy_type[0] + else: + msg = f"Invalid type: {tidy_type}" + raise ValueError(msg) + + for key, value in json_schema_node.get("properties", {}).items(): + # Handle nested objects recursively + if tidy_type == "object": + # Use breadcrumbs to navigate into nested properties + _json_schema_to_polars( + json_schema=json_schema, + expressions=expressions, + arrow_columns=arrow_columns, + polars_columns=polars_columns, + _breadcrumbs=[*_breadcrumbs, key], + ) + continue + + # Map JSON schema types to Arrow and Polars types + if tidy_type == "integer": + expressions.append(pl.col(key).cast(pl.Int64)) + arrow_columns.append(pa.field(key, pa.int64())) + polars_columns.append(pl.Int64) # Add corresponding Polars column type + continue + + if tidy_type == "number": + expressions.append(pl.col(key).cast(pl.Float64)) + arrow_columns.append(pa.field(key, pa.float64())) + polars_columns.append(pl.Float64) # Add corresponding Polars column type + continue + + if tidy_type == "boolean": + expressions.append(pl.col(key).cast(pl.Boolean)) + arrow_columns.append(pa.field(key, pa.bool_())) + polars_columns.append(pl.Boolean) # Add corresponding Polars column type + continue + + if tidy_type == "string": + str_format = value.get("format") + if str_format == "date-time": + expressions.append(pl.col(key).cast(pl.Datetime)) + arrow_columns.append(pa.field(key, pa.timestamp("ms"))) + polars_columns.append(pl.Datetime) # Add corresponding Polars column type + continue + + if str_format == "date": + expressions.append(pl.col(key).cast(pl.Date)) + arrow_columns.append(pa.field(key, pa.date32())) + polars_columns.append(pl.Date) # Add corresponding Polars column type + continue + + if str_format == "time": + expressions.append(pl.col(key).cast(pl.Time)) + arrow_columns.append(pa.field(key, pa.time32("ms"))) + polars_columns.append(pl.Time) # Add corresponding Polars column type + continue + + expressions.append(pl.col(key).cast(pl.Utf8)) + arrow_columns.append(pa.field(key, pa.string())) + polars_columns.append(pl.Utf8) # Add corresponding Polars column type + continue + + msg = f"Invalid type: {tidy_type}" + raise ValueError(msg) diff --git a/airbyte/_writers/__init__.py b/airbyte/_writers/__init__.py index fd2c0072..cd102939 100644 --- a/airbyte/_writers/__init__.py +++ b/airbyte/_writers/__init__.py @@ -11,4 +11,5 @@ "BatchHandle", "FileWriterBase", "JsonlWriter", + "ParquetWriter", ] diff --git a/airbyte/_writers/file_writers.py b/airbyte/_writers/file_writers.py index b82b4a74..770007f6 100644 --- a/airbyte/_writers/file_writers.py +++ b/airbyte/_writers/file_writers.py @@ -70,7 +70,7 @@ def _open_new_file( """Open a new file for writing.""" return file_path.open("w", encoding="utf-8") - def _flush_active_batch( + def flush_active_batch( self, stream_name: str, progress_tracker: ProgressTracker, @@ -106,7 +106,7 @@ def _new_batch( This also flushes the active batch if one already exists for the given stream. 
""" if stream_name in self._active_batches: - self._flush_active_batch( + self.flush_active_batch( stream_name=stream_name, progress_tracker=progress_tracker, ) @@ -217,7 +217,7 @@ def flush_active_batches( """Flush active batches for all streams.""" streams = list(self._active_batches.keys()) for stream_name in streams: - self._flush_active_batch( + self.flush_active_batch( stream_name=stream_name, progress_tracker=progress_tracker, ) diff --git a/airbyte/_writers/parquet.py b/airbyte/_writers/parquet.py new file mode 100644 index 00000000..94413365 --- /dev/null +++ b/airbyte/_writers/parquet.py @@ -0,0 +1,84 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +"""A Parquet cache implementation.""" + +from __future__ import annotations + +import gzip +import json +from pathlib import Path +from typing import IO, TYPE_CHECKING, Literal, cast + +import orjson +import ulid +from overrides import overrides +from pydantic import PrivateAttr + +from airbyte_protocol.models import ( + AirbyteMessage, + AirbyteRecordMessage, + AirbyteStateMessage, + AirbyteStateType, +) + +from airbyte import exceptions as exc +from airbyte._future_cdk.state_writers import StateWriterBase +from airbyte._processors.file.base import ( + FileWriterBase, +) + + +if TYPE_CHECKING: + from airbyte.records import StreamRecord + + +class ParquetWriter(FileWriterBase): + """An Parquet file writer implementation.""" + + default_cache_file_suffix = ".parquet" + prune_extra_fields = True + + _state_writer: StateWriterBase | None = PrivateAttr(default=None) + _buffered_records: dict[str, list[AirbyteMessage]] = PrivateAttr(default_factory=dict) + + def _get_records_file_path( + self, + cache_dir: Path, + stream_name: str, + batch_id: str, + ) -> Path: + """Return the records file path for the given stream and batch.""" + return cache_dir / f"{stream_name}_{batch_id}.records.parquet" + + @overrides + def _open_new_file( + self, + file_path: Path, + ) -> IO[str]: + """Open a new file for writing.""" + return cast(IO[str], gzip.open(file_path, "w")) + + @overrides + def _get_new_cache_file_path( + self, + stream_name: str, + batch_id: str | None = None, # ULID of the batch + ) -> Path: + """Return a new cache file path for the given stream.""" + batch_id = batch_id or str(ulid.ULID()) + target_dir = Path(self._cache_dir) + target_dir.mkdir(parents=True, exist_ok=True) + return target_dir / f"{stream_name}_{batch_id}{self.default_cache_file_suffix}" + + @overrides + def _write_record_dict( + self, + record_dict: StreamRecord, + open_file_writer: IO[str], + ) -> None: + # If the record is too nested, `orjson` will fail with error `TypeError: Recursion + # limit reached`. If so, fall back to the slower `json.dumps`. + try: + open_file_writer.write(orjson.dumps(record_dict).decode("utf-8") + "\n") + except TypeError: + # Using isoformat method for datetime serialization + open_file_writer.write(json.dumps(record_dict, default=lambda _: _.isoformat()) + "\n") diff --git a/airbyte/caches/iceberg.py b/airbyte/caches/iceberg.py new file mode 100644 index 00000000..1464db2f --- /dev/null +++ b/airbyte/caches/iceberg.py @@ -0,0 +1,41 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +"""A Iceberg implementation of the PyAirbyte cache. 
+ +## Usage Example + +```python +from airbyte as ab +from airbyte.caches import IcebergCache + +cache = IcebergCache( + db_path="/path/to/my/Iceberg-file", + schema_name="myschema", +) +``` +""" + +from __future__ import annotations + +import warnings + +from Iceberg_engine import IcebergEngineWarning +from pydantic import PrivateAttr + +from airbyte._processors.sql.iceberg import IcebergConfig, IcebergSqlProcessor +from airbyte.caches.base import CacheBase + + +# Suppress warnings from Iceberg about reflection on indices. +# https://github.com/Mause/Iceberg_engine/issues/905 +warnings.filterwarnings( + "ignore", + message="Iceberg-engine doesn't yet support reflection on indices", + category=IcebergEngineWarning, +) + + +# @dataclass +class IcebergCache(IcebergConfig, CacheBase): + """A Iceberg cache.""" + + _sql_processor_class: type[IcebergSqlProcessor] = PrivateAttr(default=IcebergSqlProcessor) diff --git a/airbyte/destinations/base.py b/airbyte/destinations/base.py index 1b34d40f..01427f58 100644 --- a/airbyte/destinations/base.py +++ b/airbyte/destinations/base.py @@ -275,7 +275,7 @@ def _write_airbyte_message_stream( try: # We call the connector to write the data, tallying the inputs and outputs for destination_message in progress_tracker.tally_confirmed_writes( - messages=self._execute( + messages=self._execute_and_parse( args=[ "write", "--config", diff --git a/airbyte/sources/base.py b/airbyte/sources/base.py index fed5a973..c2a422cc 100644 --- a/airbyte/sources/base.py +++ b/airbyte/sources/base.py @@ -206,7 +206,7 @@ def _discover(self) -> AirbyteCatalog: - Make sure the subprocess is killed when the function returns. """ with as_temp_files([self._config]) as [config_file]: - for msg in self._execute(["discover", "--config", config_file]): + for msg in self._execute_and_parse(["discover", "--config", config_file]): if msg.type == Type.CATALOG and msg.catalog: return msg.catalog raise exc.AirbyteConnectorMissingCatalogError( @@ -235,7 +235,7 @@ def _get_spec(self, *, force_refresh: bool = False) -> ConnectorSpecification: * Make sure the subprocess is killed when the function returns. """ if force_refresh or self._spec is None: - for msg in self._execute(["spec"]): + for msg in self._execute_and_parse(["spec"]): if msg.type == Type.SPEC and msg.spec: self._spec = msg.spec break @@ -542,7 +542,7 @@ def _read_with_catalog( catalog_file, state_file, ]: - message_generator = self._execute( + message_generator = self._execute_and_parse( [ "read", "--config", diff --git a/examples/run_polars_poc.py b/examples/run_polars_poc.py new file mode 100644 index 00000000..1bd9283a --- /dev/null +++ b/examples/run_polars_poc.py @@ -0,0 +1,254 @@ +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +"""A test of Polars in PyAirbyte. 
+ +Usage (from PyAirbyte root directory): +> poetry run python ./examples/run_polars_poc.py +""" + +from __future__ import annotations + +import logging +import time +from datetime import datetime +from pathlib import Path +from typing import Callable, cast + +import airbyte as ab +import boto3 +import polars as pl +from airbyte import get_source +from airbyte.secrets.google_gsm import GoogleGSMSecretManager +from typing_extensions import Literal + +logger = logging.getLogger() +for handler in logger.handlers[:]: + logger.removeHandler(handler) + +# Basic logging configuration +logger.setLevel(logging.INFO) + +# Create a console handler +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.INFO) + +# Create a formatter and set it for the handler +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +console_handler.setFormatter(formatter) + +# Add the handler to the logger +logger.addHandler(console_handler) + +AIRBYTE_INTERNAL_GCP_PROJECT = "dataline-integration-testing" +SECRET_NAME = "SECRET_SOURCE-S3_V4_JSONL_NEWLINE__CREDS" + +secret_mgr = GoogleGSMSecretManager( + project=AIRBYTE_INTERNAL_GCP_PROJECT, + credentials_json=ab.get_secret("GCP_GSM_CREDENTIALS"), +) +secret_handle = secret_mgr.get_secret(SECRET_NAME) +assert secret_handle is not None, "Secret not found." +secret_config = secret_handle.parse_json() + +# BUCKET_NAME = "performance-test-datasets" # << Can't find this bucket +BUCKET_NAME = "airbyte-internal-performance" # << From 'dataline-dev', until we find the other bucket +FILE_NAME_PREFIX = "json/no_op_source/stream1/2024_08_06_18/" +GLOB = f"{FILE_NAME_PREFIX}*.jsonl" +SAMPLE_FILE_1 = "json/no_op_source/stream1/2024_08_06_18/15836041.0.1722968833150.jsonl" +SAMPLE_FILE_2 = "json/no_op_source/stream1/2024_08_06_18/15836041.1.1722968852741.jsonl" +SAMPLE_FILES = [SAMPLE_FILE_1, SAMPLE_FILE_2] + +records_processed = 0 + + +def get_my_source() -> ab.Source: + """Take the existing S3 config and modify it to use the perf-test bucket and glob.""" + secret_config["bucket"] = BUCKET_NAME + secret_config["streams"][0] = { + "name": "stream1", + "format": {"filetype": "jsonl"}, + "globs": ["json/no_op_source/stream1/**/*.jsonl"], + } + return get_source( + "source-s3", + config=secret_config, + streams="*", + ) + + +def test_my_source() -> None: + """Test the modified S3 source.""" + source = get_my_source() + source.check() + source.select_all_streams() + source.read() + + +def get_s3_file_names( + bucket_name: str, + path_prefix: str, +) -> list[str]: + """Get the names of all files in an S3 bucket with a given prefix.""" + s3_client = boto3.client( + "s3", + aws_access_key_id=secret_config["aws_access_key_id"], + aws_secret_access_key=secret_config["aws_secret_access_key"], + region_name="us-east-2", + ) + + # List all objects in the S3 bucket with the specified prefix + response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=path_prefix) + + # Filter the objects to match the file prefix + matching_files = [match["Key"] for match in response.get("Contents", [])] + + return matching_files + + +def get_polars_df( + s3_file_path: str, + *, + lazy: bool = True, + expire_cache: bool, +) -> pl.LazyFrame | pl.DataFrame: + """Set up a Polars lazy DataFrame from an S3 file. + + This action will connect to the S3 bucket but it will not read data from the file until + the DataFrame is actually used. 
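+
+    For example (an illustrative sketch; the local file path is hypothetical):
+
+    ```python
+    import polars as pl
+
+    lazy_df = pl.scan_ndjson("records.jsonl")    # builds a query plan; nothing is read yet
+    eager_df = pl.read_ndjson("records.jsonl")   # reads the whole file immediately
+    counts = lazy_df.select(pl.len()).collect()  # data is only read at .collect()
+    ```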
+ """ + read_fn: Callable[..., pl.LazyFrame | pl.DataFrame] + if lazy: + read_fn = pl.scan_ndjson + else: + read_fn = pl.read_ndjson + + return read_fn( + source=s3_file_path, # TODO: Try to get this working with globs + storage_options={ + "aws_access_key_id": secret_config["aws_access_key_id"], + "aws_secret_access_key": secret_config["aws_secret_access_key"], + "region": "us-east-2", + }, + include_file_paths="_ab_source_file_path", + row_index_name="_ab_record_index", + row_index_offset=0, + infer_schema_length=100, + file_cache_ttl=0 if expire_cache else None, + ) + + +def add_custom_transforms( + df: pl.LazyFrame | pl.DataFrame, +) -> pl.LazyFrame | pl.DataFrame: + """Add custom transforms to the Polars lazy DataFrame.""" + return df.with_columns( + [ + pl.lit("Hello, world!").alias("greeting"), + pl.col("_airbyte_ab_id").str.to_uppercase(), + pl.lit(datetime.now()).alias("current_timestamp"), + ], + ) + + +def write_files( + s3_urls: list[str], + file_type: Literal["jsonl", "parquet"] = "jsonl", + *, + expire_cache: bool, + lazy: bool = True, +) -> None: + global records_processed + dataframes: list[pl.DataFrame] = [] + for n, s3_url in enumerate(s3_urls, start=1): + base_name = ".".join(Path(s3_url).name.split(".")[:-1]) + output_file = f"polars-perf-test-artifact.{base_name}.{file_type}" + if Path(output_file).exists(): + Path(output_file).unlink() + + df: pl.LazyFrame | pl.DataFrame = get_polars_df( + s3_file_path=f"s3://{BUCKET_NAME}/{SAMPLE_FILE_2}", + # expected_schema=polars_stream_schema, + expire_cache=expire_cache, + lazy=lazy, + ) + + if lazy: + df = add_custom_transforms(df) + assert isinstance(df, pl.LazyFrame) + logger.info(f"Collecting file {n} of {len(s3_urls)} to '{output_file}'...") + df = df.collect() + logger.info(f"Writing file {n} of {len(s3_urls)} to '{output_file}'...") + + # Write the DataFrame to a file + if file_type == "parquet": + df.write_parquet(output_file) + elif file_type == "jsonl": + df.write_ndjson(output_file) + else: + raise ValueError(f"Invalid file type: {file_type}") + + records_processed += df.height + del df + else: + assert isinstance(df, pl.DataFrame) + dataframes.append(df) + + if not lazy: + combined_df: pl.DataFrame = pl.concat(dataframes) + combined_df = cast(pl.DataFrame, add_custom_transforms(combined_df)) + if file_type == "parquet": + combined_df.write_parquet(output_file) + elif file_type == "jsonl": + combined_df.write_ndjson(output_file) + else: + raise ValueError(f"Invalid file type: {file_type}") + + records_processed += combined_df.height + + +def run_polars_perf_test( + file_type: Literal["jsonl", "parquet"] = "jsonl", + *, + expire_cache: bool, + lazy: bool = True, +) -> None: + """Run the Polars proof of concept.""" + global records_processed + + logger.info("Finding S3 files...") + + s3_urls = get_s3_file_names(BUCKET_NAME, FILE_NAME_PREFIX) + logger.info("Creating polars dataframes...") + start_time = time.time() + + write_files( + s3_urls, + file_type=file_type, + lazy=lazy, + expire_cache=expire_cache, + ) + + logger.info("Finished write operation.") + elapsed_transfer_time = time.time() - start_time + + mb_per_record = 180 / (1024 * 1024) # << 180 bytes per record, converted to MB + logger.info( + f"Wrote {records_processed:,} records from {len(s3_urls)} files in" + f" {elapsed_transfer_time:,.2f} seconds." 
+ f" ({records_processed / elapsed_transfer_time:,.2f} records per second," + f" {(records_processed / elapsed_transfer_time) * mb_per_record:,.1f} MB/s," + f" {records_processed * mb_per_record:,.1f} MB total)" + ) + + +def main() -> None: + """Run the Polars proof of concept.""" + # test_my_source() # We don't need to run this every time - only for debugging + run_polars_perf_test( + file_type="parquet", + lazy=False, + expire_cache=False, + ) + + +if __name__ == "__main__": + main() diff --git a/poetry.lock b/poetry.lock index 2c55f50c..4db2b0c8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -138,6 +138,47 @@ files = [ {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, ] +[[package]] +name = "boto3" +version = "1.34.155" +description = "The AWS SDK for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "boto3-1.34.155-py3-none-any.whl", hash = "sha256:445239ea2ba7f4084ddbd71f721c14d0a6d08e06f6ba51b5403a16b6544b3f1e"}, + {file = "boto3-1.34.155.tar.gz", hash = "sha256:e8d2e128c74e84199edccdc3a6b4b1c6fb36d6fdb5688eb92931827f02c6fa5b"}, +] + +[package.dependencies] +botocore = ">=1.34.155,<1.35.0" +jmespath = ">=0.7.1,<2.0.0" +s3transfer = ">=0.10.0,<0.11.0" + +[package.extras] +crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] + +[[package]] +name = "botocore" +version = "1.34.155" +description = "Low-level, data-driven core of boto 3." +optional = false +python-versions = ">=3.8" +files = [ + {file = "botocore-1.34.155-py3-none-any.whl", hash = "sha256:f2696c11bb0cad627d42512937befd2e3f966aedd15de00d90ee13cf7a16b328"}, + {file = "botocore-1.34.155.tar.gz", hash = "sha256:3aa88abfef23909f68d3e6679a3d4b4bb3c6288a6cfbf9e253aa68dac8edad64"}, +] + +[package.dependencies] +jmespath = ">=0.7.1,<2.0.0" +python-dateutil = ">=2.1,<3.0.0" +urllib3 = [ + {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}, + {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, +] + +[package.extras] +crt = ["awscrt (==0.21.2)"] + [[package]] name = "bracex" version = "2.5" @@ -1367,8 +1408,8 @@ files = [ [package.dependencies] orjson = ">=3.9.14,<4.0.0" pydantic = [ - {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}, {version = ">=1,<3", markers = "python_full_version < \"3.12.4\""}, + {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}, ] requests = ">=2,<3" @@ -1801,8 +1842,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""}, {version = ">=1.23.2,<2", markers = "python_version == \"3.11\""}, + {version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""}, {version = ">=1.22.4,<2", markers = "python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" @@ -1961,6 +2002,47 @@ tomli = ">=1.2.2" [package.extras] poetry-plugin = ["poetry (>=1.0,<2.0)"] +[[package]] +name = "polars" +version = "1.4.1" +description = "Blazingly fast DataFrame library" +optional = false +python-versions = ">=3.8" +files = [ + {file = "polars-1.4.1-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:f02fc6a5c63dd86cfeb159caa66112e477c69fc7800a28e64609ac2780554865"}, + {file = "polars-1.4.1-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:bd2acd8b1977f61b9587c8d47d16f101e7e73edd8cdeb3a8a725f15f181cd120"}, + {file = "polars-1.4.1-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7cf834a328e292c31c06eb606496becb6d8a795e927c826e26e2af27087950f1"}, + {file = 
"polars-1.4.1-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:64eabf0ef7ac0d17fe15361e7daaeb4425a875d2d760c17d96803e9ac8bee244"}, + {file = "polars-1.4.1-cp38-abi3-win_amd64.whl", hash = "sha256:2313d63ecfa1d9f1e740b9fcabb8ae45d9d0b5acf1ddb401951daba4c0f3f74f"}, + {file = "polars-1.4.1.tar.gz", hash = "sha256:ed8009aff8cf91f94db5a38d947185603ad5bee48a28b764cf5a52048c7c4756"}, +] + +[package.extras] +adbc = ["adbc-driver-manager[dbapi]", "adbc-driver-sqlite[dbapi]"] +all = ["polars[async,cloudpickle,database,deltalake,excel,fsspec,graph,iceberg,numpy,pandas,plot,pyarrow,pydantic,style,timezone]"] +async = ["gevent"] +calamine = ["fastexcel (>=0.9)"] +cloudpickle = ["cloudpickle"] +connectorx = ["connectorx (>=0.3.2)"] +database = ["nest-asyncio", "polars[adbc,connectorx,sqlalchemy]"] +deltalake = ["deltalake (>=0.15.0)"] +excel = ["polars[calamine,openpyxl,xlsx2csv,xlsxwriter]"] +fsspec = ["fsspec"] +gpu = ["cudf-polars-cu12"] +graph = ["matplotlib"] +iceberg = ["pyiceberg (>=0.5.0)"] +numpy = ["numpy (>=1.16.0)"] +openpyxl = ["openpyxl (>=3.0.0)"] +pandas = ["pandas", "polars[pyarrow]"] +plot = ["hvplot (>=0.9.1)", "polars[pandas]"] +pyarrow = ["pyarrow (>=7.0.0)"] +pydantic = ["pydantic"] +sqlalchemy = ["polars[pandas]", "sqlalchemy"] +style = ["great-tables (>=0.8.0)"] +timezone = ["backports-zoneinfo", "tzdata"] +xlsx2csv = ["xlsx2csv (>=0.8.0)"] +xlsxwriter = ["xlsxwriter"] + [[package]] name = "proto-plus" version = "1.24.0" @@ -2178,8 +2260,8 @@ files = [ annotated-types = ">=0.4.0" pydantic-core = "2.20.1" typing-extensions = [ - {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, {version = ">=4.6.1", markers = "python_version < \"3.13\""}, + {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, ] [package.extras] @@ -3073,6 +3155,23 @@ files = [ {file = "ruff-0.4.1.tar.gz", hash = "sha256:d592116cdbb65f8b1b7e2a2b48297eb865f6bdc20641879aa9d7b9c11d86db79"}, ] +[[package]] +name = "s3transfer" +version = "0.10.2" +description = "An Amazon S3 Transfer Manager" +optional = false +python-versions = ">=3.8" +files = [ + {file = "s3transfer-0.10.2-py3-none-any.whl", hash = "sha256:eca1c20de70a39daee580aef4986996620f365c4e0fda6a86100231d62f1bf69"}, + {file = "s3transfer-0.10.2.tar.gz", hash = "sha256:0711534e9356d3cc692fdde846b4a1e4b0cb6519971860796e6bc4c7aea00ef6"}, +] + +[package.dependencies] +botocore = ">=1.33.2,<2.0a.0" + +[package.extras] +crt = ["botocore[crt] (>=1.33.2,<2.0a.0)"] + [[package]] name = "setuptools" version = "72.2.0" diff --git a/pyproject.toml b/pyproject.toml index 2494dc7f..e3a78494 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ airbyte-api = "^0.49.2" google-cloud-bigquery-storage = "^2.25.0" pyiceberg = "^0.6.1" uuid7 = "^0.1.0" +polars = "^1.3.0" grpcio = "<=1.65.0" structlog = "^24.4.0" @@ -70,6 +71,7 @@ poethepoet = "^0.26.1" coverage = "^7.5.1" pytest-timeout = "^2.3.1" viztracer = "^0.16.3" +boto3 = "^1.34.155" [build-system] requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"]