From 1beb67c2b19940f661ba2d57e0dc9a01881d3f25 Mon Sep 17 00:00:00 2001
From: Jesus Lara
Date: Fri, 24 Jan 2025 02:14:18 +0100
Subject: [PATCH] new stream_query using copy_from_query to create dataframes
 in chunks.

---
 asyncdb/drivers/outputs/arrow.py  |  49 ++++++++++++++
 asyncdb/drivers/outputs/dt.py     |  22 ++++++
 asyncdb/drivers/outputs/pandas.py |  12 ++++
 asyncdb/drivers/outputs/polars.py |  37 +++++++++-
 asyncdb/drivers/pg.py             | 109 ++++++++++++++++++++++++++++--
 asyncdb/version.py                |   2 +-
 examples/test_pg_stream.py        |  83 +++++++++++++++++++++++
 7 files changed, 306 insertions(+), 8 deletions(-)
 create mode 100644 examples/test_pg_stream.py

diff --git a/asyncdb/drivers/outputs/arrow.py b/asyncdb/drivers/outputs/arrow.py
index 271b947f..d5791f3f 100644
--- a/asyncdb/drivers/outputs/arrow.py
+++ b/asyncdb/drivers/outputs/arrow.py
@@ -1,8 +1,57 @@
 import logging
+from io import StringIO
 import pyarrow as pa
+import pyarrow.csv as pc
 from .base import OutputFormat
 
 
+def arrow_parser(csv_stream: StringIO, chunksize: int, delimiter: str = ",", columns: list = None):
+    """
+    Read a block of CSV text from `csv_stream` into a single PyArrow Table
+    and yield that Table.
+    - `columns`: list of column names taken from the first CSV header line;
+      Arrow uses them as the schema field names instead of reading a
+      separate header row from the text.
+    - `delimiter`: the field separator (default ',').
+    - `chunksize`: unused for Arrow, kept for signature compatibility.
+    """
+    # Convert the CSV text to bytes for Arrow
+    data_bytes = csv_stream.getvalue().encode("utf-8")
+
+    read_opts = pc.ReadOptions(
+        use_threads=True
+    )
+    parse_opts = pc.ParseOptions(
+        delimiter=delimiter,
+        quote_char='"',   # set to None to disable quoting entirely
+        double_quote=True,
+        escape_char=None
+    )
+    convert_opts = pc.ConvertOptions()
+
+    if columns is not None:
+        # The header line was already extracted by the caller, so tell Arrow
+        # to treat every line as data (no separate header) and assign the
+        # captured column names to the fields.
+        read_opts.column_names = columns
+        # Since the header line has already been stripped from the text,
+        # there is no first line to skip: skip_rows=0 is correct
+        # (the entire block consists of data lines).
+        read_opts.skip_rows = 0
+
+    # Read the CSV into a single Table
+    table = pc.read_csv(
+        pa.BufferReader(data_bytes),
+        read_options=read_opts,
+        parse_options=parse_opts,
+        convert_options=convert_opts
+    )
+
+    # Yield one Table per chunk
+    yield table
+
+
 class arrowFormat(OutputFormat):
     """
     Returns an Apache Arrow Table from a Resultset
diff --git a/asyncdb/drivers/outputs/dt.py b/asyncdb/drivers/outputs/dt.py
index bb3c6943..cfe64b23 100644
--- a/asyncdb/drivers/outputs/dt.py
+++ b/asyncdb/drivers/outputs/dt.py
@@ -1,8 +1,30 @@
 import logging
+from io import StringIO
 import datatable as dt
 from .base import OutputFormat
 
 
+def dt_parser(csv_stream: StringIO, chunksize: int, delimiter: str = ",", columns: list = None, quote: str = '"'):
+    """
+    Reads CSV text from `csv_stream` using datatable.fread and yields a single Frame.
+    """
+    # datatable.fread cannot read directly from a file-like object,
+    # so the CSV text is passed via the `text` parameter.
+    csv_text = csv_stream.getvalue()
+
+    # Create the Frame
+    yield dt.fread(
+        text=csv_text,
+        sep=delimiter,
+        nthreads=0,        # use all available threads
+        header=False,      # the header line was already stripped from the text
+        columns=columns,   # the list of column names captured from the header
+        quotechar=quote,   # pass None (or an unusual char) to avoid standard '"' quoting
+        fill=True          # fill shorter lines with NA if needed
+    )
+
+
 class dtFormat(OutputFormat):
     """
     Returns a Pandas Dataframe from a Resultset
diff --git a/asyncdb/drivers/outputs/pandas.py b/asyncdb/drivers/outputs/pandas.py
index 8756d904..418d9ee4 100644
--- a/asyncdb/drivers/outputs/pandas.py
+++ b/asyncdb/drivers/outputs/pandas.py
@@ -1,8 +1,20 @@
 import logging
+from io import StringIO
 import pandas
 from .base import OutputFormat
 
 
+def pandas_parser(csv_stream: StringIO, chunksize: int, delimiter: str = ",", columns: list = None):
+    """
+    Parser function that reads the CSV text in `csv_stream`
+    and yields DataFrame chunks using pandas' chunked read_csv.
+    """
+    # `pandas.read_csv(..., chunksize=...)` returns an iterator of DataFrames;
+    # yield every chunk from that iterator.
+    yield from pandas.read_csv(csv_stream, chunksize=chunksize, sep=delimiter, names=columns, header=None)
+
+
 class pandasFormat(OutputFormat):
     """
     Returns a Pandas Dataframe from a Resultset
diff --git a/asyncdb/drivers/outputs/polars.py b/asyncdb/drivers/outputs/polars.py
index 1ac337dc..12147bc1 100644
--- a/asyncdb/drivers/outputs/polars.py
+++ b/asyncdb/drivers/outputs/polars.py
@@ -1,9 +1,42 @@
 import logging
 import pandas
-import polars as polar
+import io
+import polars as pl
 from .base import OutputFormat
 
 
+def polars_parser(
+    csv_stream: io.StringIO,
+    chunksize: int,
+    delimiter: str = "|",
+    columns: list = None
+):
+    """
+    Parser for Polars. Reads the entire CSV text from `csv_stream` into one
+    Polars DataFrame and yields that DataFrame.
+
+    - If `columns` is provided, the CSV text is assumed to have *no* header row
+      (the caller already extracted it) and those column names are assigned manually.
+    - Otherwise, if columns=None, Polars interprets the first line in `csv_stream`
+      as the header (has_header=True).
+    """
+
+    csv_text = csv_stream.getvalue()
+    has_header = columns is None
+
+    # Polars' read_csv can either infer the header or take explicit names:
+    # pass has_header=False when the header line was already stripped,
+    # and assign new_columns=columns to name the fields.
+
+    yield pl.read_csv(
+        io.StringIO(csv_text),
+        separator=delimiter,
+        has_header=has_header,
+        new_columns=None if has_header else columns,
+        ignore_errors=True   # optional, tolerates type mismatches
+    )
+
+
 class polarsFormat(OutputFormat):
     """
     Returns a PyPolars Dataframe from a Resultset
@@ -14,7 +47,7 @@ async def serialize(self, result, error, *args, **kwargs):
         try:
             result = [dict(row) for row in result]
             a = pandas.DataFrame(data=result, **kwargs)
-            df = polar.from_pandas(a, **kwargs)
+            df = pl.from_pandas(a, **kwargs)
             self._result = df
         except ValueError as err:
             print(err)
diff --git a/asyncdb/drivers/pg.py b/asyncdb/drivers/pg.py
index 391a0975..033529a8 100644
--- a/asyncdb/drivers/pg.py
+++ b/asyncdb/drivers/pg.py
@@ -7,12 +7,13 @@
 import asyncio
 from enum import Enum
+import io
 import os
 import ssl
 import time
 import uuid
 from collections.abc import Callable, Iterable
-from typing import Any, Optional, Union
+from typing import Any, AsyncGenerator, Generator, Optional, Union
 from dataclasses import is_dataclass
 import contextlib
 from datamodel import BaseModel
@@ -81,6 +82,12 @@ class pgRecord(asyncpg.Record):
     def __getattr__(self, name: str):
         return self[name]
 
+    def to_dict(self):
+        """Return a dict representation of this record."""
+        # `asyncpg.Record` already exposes `.keys()` and mapping-style access,
+        # so a plain dict can be built directly from it:
+        return {key: self[key] for key in self.keys()}
+
 
 class pgPool(BasePool):
     """
@@ -751,9 +758,7 @@ async def query(self, sentence: Union[str, Any], *args, **kwargs):
             error = f"Error on Query: {err}"
         finally:
             self.generated_at()
-            if error:
-                return [None, error]
-            return await self._serializer(self._result, error)  # pylint: disable=W0150
+            return [None, error] if error else await self._serializer(self._result, error)  # pylint: disable=W0150
 
     async def queryrow(self, sentence: str, *args):
         self._result = None
@@ -1560,7 +1565,7 @@ async def _updating_(self, *args, _filter: dict = None, **kwargs):
                 source.append(value)
                 if name in _filter:
                     new_cond[name] = value
-                cols.append("{} = {}".format(name, "${}".format(n)))  # pylint: disable=C0209
+                cols.append(f"{name} = ${n}")
                 n += 1
             try:
                 set_fields = ", ".join(cols)
@@ -1634,3 +1639,97 @@ async def _deleting_(self, *args, _filter: dict = None, **kwargs):
             return [model(**dict(r)) for r in result]
         except Exception as err:
             raise DriverError(message=f"Error on DELETE over table {model.Meta.name}: {err!s}") from err
+
+    async def stream_query(
+        self,
+        query: str,
+        parser: Callable[[io.StringIO, int], Generator],
+        chunksize: int = 1000,
+        delimiter: str = '|',
+    ):
+        """
+        Stream a query's results via COPY TO STDOUT in CSV format,
+        parse them in memory in chunks and yield the parsed objects
+        (DataFrames, Arrow Tables, etc.).
+
+        :param query: the SQL SELECT you want to stream
+        :param parser: a callable like pandas_parser(csv_stream, chunksize, ...)
+                       that yields parsed objects (DataFrame, etc.)
+        :param chunksize: how many rows each parsed chunk should contain
+        :param delimiter: CSV field separator used by the COPY (default '|')
+        """
+        if not self._connection:
+            await self.connection()
+        # Accumulate CSV text in a buffer
+        buffer = io.StringIO()
+        columns = None  # column names taken from the CSV header row
+        accumulated_lines = 0  # number of data lines buffered for the current batch
+
+        # A queue to collect raw bytes from the producer
+        queue = asyncio.Queue()
+
+        # Coroutine that asyncpg calls with each chunk of bytes
+        async def producer_coroutine(chunk: bytes):
+            # asyncpg invokes this whenever new CSV bytes arrive;
+            # we simply push each block onto the queue for the consumer below.
+            await queue.put(chunk)
+
+        # Wrap copy_from_query in a task so it runs in parallel with consumption.
+        # Once copy_from_query returns, no more data is coming.
+        copy_task = asyncio.create_task(
+            self._connection.copy_from_query(
+                query,
+                output=producer_coroutine,   # called for every chunk of bytes
+                delimiter=delimiter,
+                format='csv',
+                quote='"',
+                escape='"',
+                # force_quote=True,
+                header=True,
+            )
+        )
+
+        # Consumer loop: wait briefly for new bytes; stop once the COPY task
+        # has finished and the queue is drained.
+        while True:
+            try:
+                chunk = await asyncio.wait_for(queue.get(), timeout=0.1)
+            except asyncio.TimeoutError:
+                if copy_task.done() and queue.empty():
+                    break  # No more data
+                continue
+
+            text_chunk = chunk.decode('utf-8', errors='replace')
+            if columns is None:
+                # the first line of the stream is the CSV header:
+                end_header = text_chunk.find('\n')
+                header = text_chunk[:end_header]
+                columns = header.strip().split(delimiter)
+                # remove the header from the current chunk:
+                text_chunk = text_chunk[end_header + 1:]
+            buffer.write(text_chunk)
+            accumulated_lines += text_chunk.count('\n')
+            # If we reached the threshold, parse & yield
+            if accumulated_lines >= chunksize:
+                buffer.seek(0)
+                for parsed_obj in parser(buffer, chunksize, delimiter=delimiter, columns=columns):
+                    yield parsed_obj
+                accumulated_lines = 0
+                buffer.seek(0)
+                buffer.truncate(0)
+
+            queue.task_done()
+
+            if copy_task.done() and queue.empty():
+                break  # No more data
+
+        # If there is leftover unparsed data, flush it through the parser
+        if accumulated_lines:
+            buffer.seek(0)
+            for parsed_obj in parser(buffer, chunksize, delimiter=delimiter, columns=columns):
+                yield parsed_obj
+
+        # Make sure the COPY command is fully done
+        # (awaiting copy_task re-raises any error from the copy).
+        await copy_task
+        await queue.join()
diff --git a/asyncdb/version.py b/asyncdb/version.py
index 18f8f81d..f10ab3c6 100644
--- a/asyncdb/version.py
+++ b/asyncdb/version.py
@@ -3,7 +3,7 @@
 __title__ = "asyncdb"
 __description__ = "Library for Asynchronous data source connections \
     Collection of asyncio drivers."
-__version__ = "2.10.2"
+__version__ = "2.10.3"
 __copyright__ = "Copyright (c) 2020-2024 Jesus Lara"
 __author__ = "Jesus Lara"
 __author_email__ = "jesuslarag@gmail.com"
diff --git a/examples/test_pg_stream.py b/examples/test_pg_stream.py
new file mode 100644
index 00000000..f2bae428
--- /dev/null
+++ b/examples/test_pg_stream.py
@@ -0,0 +1,83 @@
+import asyncio
+import pandas as pd
+import pyarrow as pa
+from asyncdb.drivers.pg import pg
+from asyncdb.drivers.outputs.pandas import pandas_parser
+from asyncdb.drivers.outputs.arrow import arrow_parser
+from asyncdb.drivers.outputs.dt import dt_parser, dt
+from asyncdb.drivers.outputs.polars import polars_parser, pl
+
+
+params = {
+    "user": "troc_pgdata",
+    "password": "12345678",
+    "host": "127.0.0.1",
+    "port": "5432",
+    "database": "navigator",
+    "DEBUG": True,
+}
+
+async def connect():
+    db = pg(params=params)
+    # create a connection
+    async with await db.connection() as conn:
+        print('Connection: ', conn)
+        result, error = await conn.test_connection()
+        print(result, error)
+
+        # create a Pandas dataframe from streaming
+        parts = []
+        async for chunk in conn.stream_query(
+            "SELECT * FROM epson.sales LIMIT 100000",
+            pandas_parser,
+            chunksize=1000
+        ):
+            print("Received chunk of size:", len(chunk))
+            parts.append(chunk)
+
+        df = pd.concat(parts, ignore_index=True)
+        print(df)
+
+        # create an Arrow Table from streaming
+        parts = []
+        async for chunk in conn.stream_query(
+            "SELECT * FROM epson.sales LIMIT 1000000",
+            arrow_parser,
+            chunksize=10000
+        ):
+            # print("Received chunk of size:", len(chunk))
+            parts.append(chunk)
+
+        df = pa.concat_tables(parts)
+        print("Final table has", df.num_rows, "rows in total.")
+
+        # create a Polars dataframe from streaming
+        parts = []
+        async for chunk in conn.stream_query(
+            "SELECT * FROM epson.sales LIMIT 1000000",
+            polars_parser,
+            chunksize=10000
+        ):
+            # print("Received chunk of size:", len(chunk))
+            parts.append(chunk)
+
+        df = pl.concat(parts, how="vertical")
+        print("Final dataframe has", df.height, "rows in total.")
+
+        # create a Datatable from streaming
+        parts = []
+        async for chunk in conn.stream_query(
+            "SELECT * FROM epson.sales LIMIT 1000000",
+            dt_parser,
+            chunksize=10000
+        ):
+            # print("Received chunk of size:", len(chunk))
+            parts.append(chunk)
+
+        df = dt.rbind(parts)  # or dt.rbind(parts, force=True)
+        print("Final table has", df.nrows, "rows in total.")
+
+if __name__ == "__main__":
+    loop = asyncio.get_event_loop()
+    asyncio.set_event_loop(loop)
+    loop.run_until_complete(connect())
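
Note on the parser contract: any callable with the same signature as the bundled parsers can be passed to stream_query. It receives the buffered CSV text as an io.StringIO (header line already stripped), plus chunksize, delimiter= and columns=, and must yield parsed objects. A minimal sketch of a custom parser using Python's stdlib csv module; rows_parser and some_table are illustrative names, not part of this patch:

import csv
import io

def rows_parser(csv_stream: io.StringIO, chunksize: int, delimiter: str = "|", columns: list = None):
    """Yield each batch of CSV rows as a list of dicts (column name -> value)."""
    reader = csv.reader(csv_stream, delimiter=delimiter)
    batch = []
    for row in reader:
        # `columns` is the header captured by stream_query; fall back to raw rows
        batch.append(dict(zip(columns, row)) if columns else row)
        if len(batch) >= chunksize:
            yield batch
            batch = []
    if batch:
        # flush whatever remains in the buffer
        yield batch

# usage (hypothetical table name):
#   async for rows in conn.stream_query("SELECT * FROM some_table", rows_parser, chunksize=5000):
#       process(rows)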