Commit

new stream_query from using copy_from_query to create a dataframe in chunks.
phenobarbital committed Jan 24, 2025
1 parent a7b44ea commit 1beb67c
Showing 7 changed files with 306 additions and 8 deletions.
49 changes: 49 additions & 0 deletions asyncdb/drivers/outputs/arrow.py
@@ -1,8 +1,57 @@
import logging
from io import StringIO
import pyarrow as pa
import pyarrow.csv as pc
from .base import OutputFormat


def arrow_parser(csv_stream: StringIO, chunksize: int, delimiter: str = ",", columns: list = None):
"""
Read a block of CSV text from `csv_stream` into a single PyArrow Table
and yield that Table.
- `columns`: a list of column names from the first CSV header line.
We'll tell Arrow to use those as the schema field names, skipping
any separate header row in the text.
    - `delimiter`: the field separator (default ',').
- `chunksize`: (unused for Arrow, but kept for signature compatibility).
"""
# Convert CSV text to bytes for Arrow
data_bytes = csv_stream.getvalue().encode("utf-8")

read_opts = pc.ReadOptions(
use_threads=True
)
parse_opts = pc.ParseOptions(
delimiter=delimiter,
quote_char='"', # If you do not want quoting at all, set to None
double_quote=True,
escape_char=None
)
convert_opts = pc.ConvertOptions()

    if columns is not None:
        # The caller has already extracted the header line, so tell Arrow to
        # treat every line as data (no separate header) and assign the
        # captured column names to each field.
        read_opts.column_names = columns
        # The header was stripped before the text reached this parser, so
        # there is no leading row to skip (skip_rows=0: the whole block is data).
        read_opts.skip_rows = 0

# Read the CSV into a single Table
table = pc.read_csv(
pa.BufferReader(data_bytes),
read_options=read_opts,
parse_options=parse_opts,
convert_options=convert_opts
)

# We yield one table per chunk
yield table


class arrowFormat(OutputFormat):
"""
Returns an Apache Arrow Table from a Resultset
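A quick, standalone sketch (not part of this commit) of how the new arrow_parser could be exercised; the CSV text and column names below are invented for illustration, mirroring the header-stripped blocks that stream_query hands to the parser:

from io import StringIO
from asyncdb.drivers.outputs.arrow import arrow_parser

# Hypothetical header-less CSV block; the column names were captured separately.
block = StringIO("1|alice\n2|bob\n")
for table in arrow_parser(block, chunksize=1000, delimiter="|", columns=["id", "name"]):
    print(table.num_rows, table.column_names)  # 2 ['id', 'name']
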
22 changes: 22 additions & 0 deletions asyncdb/drivers/outputs/dt.py
@@ -1,8 +1,30 @@
import logging
from io import StringIO
import datatable as dt
from .base import OutputFormat


def dt_parser(csv_stream: StringIO, chunksize: int, delimiter: str = ",", columns: list = None, quote: str = '"'):
"""
Reads CSV text from `csv_stream` using datatable.fread, yields a single Frame.
"""
# datatable.fread cannot read directly from a file-like object,
# so we pass the CSV text via "text" parameter.
csv_text = csv_stream.getvalue()

# Create the Frame
yield dt.fread(
text=csv_text,
sep=delimiter,
nthreads=0, # use all available threads
        header=False,     # the header row was already stripped upstream; every line is data
        columns=columns,  # the list of column names we captured
        quotechar=quote,  # pass None or an unusual char if you want to avoid standard '"' quoting
fill=True # fill shorter lines with NA if needed
)



class dtFormat(OutputFormat):
"""
Returns a Pandas Dataframe from a Resultset
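Similarly, an illustrative call of the datatable parser (sample data invented; assumes the datatable package is installed):

from io import StringIO
from asyncdb.drivers.outputs.dt import dt_parser

# Hypothetical header-less CSV block with externally supplied column names.
block = StringIO("1|alice\n2|bob\n")
for frame in dt_parser(block, chunksize=1000, delimiter="|", columns=["id", "name"]):
    print(frame.shape)  # (2, 2)
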
12 changes: 12 additions & 0 deletions asyncdb/drivers/outputs/pandas.py
@@ -1,8 +1,20 @@
import logging
from io import StringIO
import pandas
from .base import OutputFormat



def pandas_parser(csv_stream: StringIO, chunksize: int, delimiter: str = ",", columns: list = None):
"""
Parser function that reads the CSV text in `csv_stream`
and yields DataFrame chunks using Pandas' chunked read_csv.
"""
    # `pandas.read_csv(..., chunksize=...)` returns an iterator of DataFrames;
    # yield each chunk from that iterator as it is produced.
yield from pandas.read_csv(csv_stream, chunksize=chunksize, sep=delimiter, names=columns, header=None)


class pandasFormat(OutputFormat):
"""
Returns a Pandas Dataframe from a Resultset
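Since this parser delegates to pandas' chunked read_csv, it is the one parser that genuinely yields several chunks per CSV block. A small sketch with invented data:

from io import StringIO
from asyncdb.drivers.outputs.pandas import pandas_parser

# chunksize=1 forces two one-row DataFrames out of this two-line block.
block = StringIO("1|alice\n2|bob\n")
for df in pandas_parser(block, chunksize=1, delimiter="|", columns=["id", "name"]):
    print(len(df), list(df.columns))  # 1 ['id', 'name'] (printed twice)
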
37 changes: 35 additions & 2 deletions asyncdb/drivers/outputs/polars.py
@@ -1,9 +1,42 @@
import logging
import pandas
import polars as polar
import io
import polars as pl
from .base import OutputFormat


def polars_parser(
csv_stream: io.StringIO,
chunksize: int,
delimiter: str = "|",
columns: list = None
):
"""
Parser for Polars. Reads entire CSV text from `csv_stream` into one Polars DataFrame
and yields that DataFrame.
    - If `columns` is provided, we assume the CSV text has *no* header row
      (because the caller already extracted it). We then assign those column names manually.
- Otherwise, if columns=None, Polars will interpret the first line in `csv_stream`
as the header (has_header=True).
"""

csv_text = csv_stream.getvalue()
has_header = columns is None

# Polars read_csv can either infer the header or we can specify new_columns=...
# We'll pass 'has_header=False' if we already stripped the header line,
# and assign 'new_columns=columns' to rename them.

yield pl.read_csv(
io.StringIO(csv_text),
separator=delimiter,
has_header=has_header,
new_columns=None if has_header else columns,
ignore_errors=True # optional, in case of mismatch
)


class polarsFormat(OutputFormat):
"""
Returns a PyPolars Dataframe from a Resultset
@@ -14,7 +47,7 @@ async def serialize(self, result, error, *args, **kwargs):
try:
result = [dict(row) for row in result]
a = pandas.DataFrame(data=result, **kwargs)
df = polar.from_pandas(a, **kwargs)
df = pl.from_pandas(a, **kwargs)
self._result = df
except ValueError as err:
print(err)
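And an illustrative call of the polars parser (sample data invented); passing columns switches has_header off, as the docstring describes:

from io import StringIO
from asyncdb.drivers.outputs.polars import polars_parser

# Hypothetical header-less CSV block; column names supplied by the caller.
block = StringIO("1|alice\n2|bob\n")
for df in polars_parser(block, chunksize=1000, delimiter="|", columns=["id", "name"]):
    print(df.shape, df.columns)  # (2, 2) ['id', 'name']
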
109 changes: 104 additions & 5 deletions asyncdb/drivers/pg.py
Expand Up @@ -7,12 +7,13 @@

import asyncio
from enum import Enum
import io
import os
import ssl
import time
import uuid
from collections.abc import Callable, Iterable
from typing import Any, Optional, Union
from typing import Any, AsyncGenerator, Generator, Optional, Union
from dataclasses import is_dataclass
import contextlib
from datamodel import BaseModel
@@ -81,6 +82,12 @@ class pgRecord(asyncpg.Record):
def __getattr__(self, name: str):
return self[name]

def to_dict(self):
"""Return a dict representation of this record."""
# `asyncpg.Record` supports an iterator over column indexes,
# but we can also do something like:
return {key: self[key] for key in self.keys()}


class pgPool(BasePool):
"""
@@ -751,9 +758,7 @@ async def query(self, sentence: Union[str, Any], *args, **kwargs):
error = f"Error on Query: {err}"
finally:
self.generated_at()
if error:
return [None, error]
return await self._serializer(self._result, error) # pylint: disable=W0150
return [None, error] if error else await self._serializer(self._result, error) # pylint: disable=W0150

async def queryrow(self, sentence: str, *args):
self._result = None
@@ -1560,7 +1565,7 @@ async def _updating_(self, *args, _filter: dict = None, **kwargs):
source.append(value)
if name in _filter:
new_cond[name] = value
cols.append("{} = {}".format(name, "${}".format(n))) # pylint: disable=C0209
cols.append(f"{name} = ${n}") # pylint: disable=C0209
n += 1
try:
set_fields = ", ".join(cols)
@@ -1634,3 +1639,97 @@ async def _deleting_(self, *args, _filter: dict = None, **kwargs):
return [model(**dict(r)) for r in result]
except Exception as err:
raise DriverError(message=f"Error on DELETE over table {model.Meta.name}: {err!s}") from err

async def stream_query(
self,
query: str,
parser: Callable[[io.StringIO, int], Generator],
chunksize: int = 1000,
delimiter='|',
):
"""
Stream a query's results via COPY TO STDOUT in CSV format,
parse them in memory in chunks, yield parsed DataFrames (or anything else).
:param conn: an active asyncpg.Connection
:param query: the SQL SELECT you want to stream
:param parser: a function like pandas_parser(csv_stream, chunksize)
that yields parsed objects (DataFrame, etc.)
:param chunksize: how many rows each parser chunk should contain
"""
if not self._connection:
await self.connection()
# We'll accumulate CSV text in a buffer
buffer = io.StringIO()
columns = None # Store the header row (column names)
        accumulated_lines = 0  # number of buffered data lines in the current batch

# A queue to collect raw bytes from the producer
queue = asyncio.Queue()

# Define the coroutine that asyncpg will call with each chunk of bytes
async def producer_coroutine(chunk: bytes):
# This function is called by asyncpg whenever new bytes arrive.
# We just stuff them into the queue.
await queue.put(chunk)

# We wrap copy_from_query in a task so it runs in parallel with consumption.
# Once copy_from_query returns, we know no more data is coming.
copy_task = asyncio.create_task(
self._connection.copy_from_query(
query,
output=producer_coroutine, # call our coroutine for every chunk
delimiter=delimiter,
format='csv',
quote='"',
escape='"',
# force_quote=True,
header=True,
)
)

# Consumer loop
        while True:
            try:
                # Use a short timeout so the loop can notice when the COPY task
                # has finished even if no further chunks arrive.
                chunk = await asyncio.wait_for(queue.get(), timeout=0.5)
            except asyncio.TimeoutError:
                if copy_task.done() and queue.empty():
                    break  # No more data
                continue

text_chunk = chunk.decode('utf-8', errors='replace')
if columns is None:
# this is the first row: the header:
end_header = text_chunk.find('\n', 0)
header = text_chunk[:end_header]
columns = header.strip().split(delimiter)
# remove the header from the current chunk:
text_chunk = text_chunk[end_header + 1:]
            accumulated_lines += text_chunk.count('\n')
# If we reached our threshold, parse & yield
if accumulated_lines >= chunksize:
buffer.write(text_chunk)
buffer.seek(0)
for parsed_obj in parser(buffer, chunksize, delimiter=delimiter, columns=columns):
yield parsed_obj
accumulated_lines = 0
buffer.seek(0)
buffer.truncate(0)
else:
buffer.write(text_chunk)

queue.task_done()

if copy_task.done() and queue.empty():
break # No more data

# If there's leftover unparsed data
if accumulated_lines:
buffer.seek(0)
for parsed_obj in parser(buffer, chunksize, delimiter=delimiter, columns=columns):
yield parsed_obj

# Make sure the COPY command is fully done
# (awaiting copy_task will raise if there's any error in the copy).
await copy_task
await queue.join()
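
Putting the pieces together, a sketch of how the new async generator might be consumed from application code. The DSN, table name and connection setup below are placeholders following the usual asyncdb pattern; only stream_query itself is introduced by this commit:

import asyncio

from asyncdb.drivers.pg import pg
from asyncdb.drivers.outputs.pandas import pandas_parser

async def main():
    driver = pg(dsn="postgres://user:password@localhost:5432/mydb")  # placeholder DSN
    await driver.connection()
    try:
        async for df in driver.stream_query(
            "SELECT * FROM big_table",  # placeholder query
            parser=pandas_parser,
            chunksize=5000,
            delimiter="|",
        ):
            print(len(df))  # one pandas DataFrame per parsed chunk
    finally:
        await driver.close()

asyncio.run(main())
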
2 changes: 1 addition & 1 deletion asyncdb/version.py
@@ -3,7 +3,7 @@
__title__ = "asyncdb"
__description__ = "Library for Asynchronous data source connections \
Collection of asyncio drivers."
__version__ = "2.10.2"
__version__ = "2.10.3"
__copyright__ = "Copyright (c) 2020-2024 Jesus Lara"
__author__ = "Jesus Lara"
__author_email__ = "[email protected]"