Commit cc33312

fix(python/adbc_driver_manager): don't consume result for description (#3554)

Closes #3543.

1 parent: 7076bee

File tree: 5 files changed (+114, -12 lines)
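The gist of the fix: reading `cursor.description` used to force the result stream into a reader, consuming it; now it only peeks at the schema. A minimal sketch of the intended behavior, assuming `adbc_driver_sqlite` is installed (`connect()` defaults to an in-memory database):

```python
import adbc_driver_sqlite.dbapi

with adbc_driver_sqlite.dbapi.connect() as conn:
    with conn.cursor() as cur:
        cur.execute("SELECT 1 AS x, 'foo' AS y")
        # With this commit, description only peeks at the schema via
        # __arrow_c_schema__; the stream is still intact afterwards.
        print([col[0] for col in cur.description])  # ['x', 'y']
        print(cur.fetchall())                       # [(1, 'foo')]
```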

python/adbc_driver_manager/adbc_driver_manager/_dbapi_backend.py

Lines changed: 43 additions & 0 deletions

```diff
@@ -28,6 +28,9 @@
 
 from . import _lib
 
+if typing.TYPE_CHECKING:
+    from typing_extensions import CapsuleType
+
 
 class DbapiBackend(abc.ABC):
     """
@@ -87,6 +90,27 @@ def convert_executemany_parameters(
         """
         ...
 
+    @abc.abstractmethod
+    def convert_description(self, schema: "CapsuleType") -> typing.List[typing.Tuple]:
+        """Convert a schema capsule into a DB-API description.
+
+        Parameters
+        ----------
+        schema
+            A PyCapsule of type "arrow_schema".
+
+        Returns
+        -------
+        description : list[tuple]
+            A DB-API description, as a list of 7-item tuples.
+
+        See Also
+        --------
+        https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
+
+        """
+        ...
+
     @abc.abstractmethod
     def import_array_stream(self, handle: _lib.ArrowArrayStreamHandle) -> typing.Any:
         """Import an Arrow stream."""
@@ -120,6 +144,12 @@ def convert_executemany_parameters(
             status_code=_lib.AdbcStatusCode.INVALID_STATE,
         )
 
+    def convert_description(self, schema: "CapsuleType") -> typing.List[typing.Tuple]:
+        raise _lib.ProgrammingError(
+            "This API requires PyArrow or another suitable backend to be installed",
+            status_code=_lib.AdbcStatusCode.INVALID_STATE,
+        )
+
     def import_array_stream(
         self, handle: _lib.ArrowArrayStreamHandle
     ) -> _lib.ArrowArrayStreamHandle:
@@ -174,6 +204,11 @@ def convert_executemany_parameters(
         cols, bind_by_name = param_iterable_to_dict(parameters)
         return polars.DataFrame(cols), bind_by_name
 
+    def convert_description(
+        self, schema: "CapsuleType"
+    ) -> typing.List[typing.Tuple]:
+        raise _lib.NotSupportedError("Polars does not support __arrow_c_schema__")
+
     def import_array_stream(
         self, handle: _lib.ArrowArrayStreamHandle
     ) -> typing.Any:
@@ -207,6 +242,14 @@ def convert_executemany_parameters(
         cols, bind_by_name = param_iterable_to_dict(parameters)
         return pyarrow.RecordBatch.from_pydict(cols), bind_by_name
 
+    def convert_description(
+        self, schema: "CapsuleType"
+    ) -> typing.List[typing.Tuple]:
+        s = pyarrow.Schema._import_from_c_capsule(schema)
+        return [
+            (field.name, field.type, None, None, None, None, None) for field in s
+        ]
+
     def import_array_stream(
         self, handle: _lib.ArrowArrayStreamHandle
     ) -> pyarrow.RecordBatchReader:
```
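The new `convert_description` hook pins down the contract: given an `"arrow_schema"` PyCapsule, return PEP 249 description entries, i.e. 7-tuples of `(name, type_code, display_size, internal_size, precision, scale, null_ok)`. A sketch of a conforming implementation outside the ABC, using PyArrow as the backend above does (the free-function form is illustrative only):

```python
import typing

import pyarrow


def convert_description(schema_capsule) -> typing.List[typing.Tuple]:
    # Import the "arrow_schema" capsule; only name and type_code are
    # populated, the other five PEP 249 fields stay None.
    schema = pyarrow.Schema._import_from_c_capsule(schema_capsule)
    return [(f.name, f.type, None, None, None, None, None) for f in schema]


# Any object implementing __arrow_c_schema__ can produce the capsule,
# e.g. a pyarrow.Schema itself (PyArrow >= 14):
capsule = pyarrow.schema([("x", pyarrow.int64())]).__arrow_c_schema__()
print(convert_description(capsule))  # [('x', DataType(int64), None, ...)]
```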

python/adbc_driver_manager/adbc_driver_manager/_lib.pxd

Lines changed: 2 additions & 2 deletions

```diff
@@ -32,9 +32,9 @@ cdef extern from "arrow-adbc/adbc.h" nogil:
     cdef struct CArrowArray"ArrowArray":
         CArrowArrayRelease release
 
-    ctypedef int (*CArrowArrayStreamGetLastError)(void*)
+    ctypedef char* (*CArrowArrayStreamGetLastError)(void*)
     ctypedef int (*CArrowArrayStreamGetNext)(void*, CArrowArray*)
-    ctypedef char* (*CArrowArrayStreamGetSchema)(void*, CArrowSchema*)
+    ctypedef int (*CArrowArrayStreamGetSchema)(void*, CArrowSchema*)
     ctypedef void (*CArrowArrayStreamRelease)(void*)
 
     cdef struct CArrowArrayStream"ArrowArrayStream":
```
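The `.pxd` change untangles two return types that had been transposed: in the Arrow C stream interface, `get_schema` (like `get_next`) returns an `int` error code, while `get_last_error` returns `const char*` (or `NULL`). For illustration only, the corrected callback shapes rendered with ctypes:

```python
import ctypes

# int         (*get_schema)(struct ArrowArrayStream*, struct ArrowSchema* out)
# const char* (*get_last_error)(struct ArrowArrayStream*)
GetSchema = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p)
GetLastError = ctypes.CFUNCTYPE(ctypes.c_char_p, ctypes.c_void_p)
```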

python/adbc_driver_manager/adbc_driver_manager/_lib.pyx

Lines changed: 22 additions & 0 deletions

```diff
@@ -479,6 +479,28 @@ cdef class ArrowArrayStreamHandle:
             self.stream.release(&self.stream)
             self.stream.release = NULL
 
+    def __arrow_c_schema__(self) -> object:
+        """Get a PyCapsule without consuming this object."""
+        cdef const char* err = NULL
+
+        if not self.is_valid:
+            raise ValueError("ArrowArrayStreamHandle already consumed")
+
+        cdef CArrowSchema* allocated = <CArrowSchema*> malloc(sizeof(CArrowSchema))
+        allocated.release = NULL
+        capsule = PyCapsule_New(
+            <void*>allocated, "arrow_schema", &pycapsule_schema_deleter,
+        )
+        rc = self.stream.get_schema(&self.stream, allocated)
+        if rc != 0:
+            err = self.stream.get_last_error(&self.stream)
+            if err == NULL:
+                raise RuntimeError(f"Failed to get schema: ({rc})")
+            else:
+                s = err.decode()
+                raise RuntimeError(f"Failed to get schema: ({rc}) {s}")
+        return capsule
+
     def __arrow_c_stream__(self, requested_schema=None) -> object:
         """Consume this object to get a PyCapsule."""
         if requested_schema is not None:
```
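With `__arrow_c_schema__` in place, the schema of a live `ArrowArrayStreamHandle` can be inspected without consuming it; only `__arrow_c_stream__` transfers ownership. A hedged usage sketch, assuming PyArrow and a cursor `cur` that has executed a query:

```python
import pyarrow

handle = cur.fetch_arrow()  # _lib.ArrowArrayStreamHandle
schema = pyarrow.Schema._import_from_c_capsule(
    handle.__arrow_c_schema__()  # peeks; the handle stays valid
)
print(schema.names)

# The stream can still be imported and read afterwards; this call
# consumes the handle, so it works only once:
reader = pyarrow.RecordBatchReader._import_from_c_capsule(
    handle.__arrow_c_stream__()
)
print(reader.read_all())
```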

python/adbc_driver_manager/adbc_driver_manager/dbapi.py

Lines changed: 23 additions & 10 deletions

```diff
@@ -370,7 +370,7 @@ def cursor(
         adbc_stmt_kwargs : dict, optional
             ADBC-specific options to pass to the underlying ADBC statement.
         """
-        return Cursor(self, adbc_stmt_kwargs)
+        return Cursor(self, adbc_stmt_kwargs, dbapi_backend=self._backend)
 
     def rollback(self) -> None:
         """Explicitly rollback."""
@@ -624,9 +624,12 @@ def __init__(
         self,
         conn: Connection,
         adbc_stmt_kwargs: Optional[Dict[str, Any]] = None,
+        *,
+        dbapi_backend: Optional[_dbapi_backend.DbapiBackend] = None,
     ) -> None:
         # Must be at top in case __init__ is interrupted and then __del__ is called
         self._closed = True
+        self._backend = dbapi_backend or _dbapi_backend.default_backend()
         self._conn = conn
         self._stmt = _lib.AdbcStatement(conn._conn)
         self._closed = False
@@ -772,7 +775,7 @@ def execute(self, operation: Union[bytes, str], parameters=None) -> "Self":
         handle, self._rowcount = _blocking_call(
             self._stmt.execute_query, (), {}, self._stmt.cancel
         )
-        self._results = _RowIterator(self._stmt, handle)
+        self._results = _RowIterator(self._stmt, handle, self._backend)
         return self
 
     def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None:
@@ -1141,7 +1144,7 @@ def adbc_read_partition(self, partition: bytes) -> None:
             self._conn._conn.read_partition, (partition,), {}, self._stmt.cancel
         )
         self._rowcount = -1
-        self._results = _RowIterator(self._stmt, handle)
+        self._results = _RowIterator(self._stmt, handle, self._backend)
 
     @property
     def adbc_statement(self) -> _lib.AdbcStatement:
@@ -1261,8 +1264,8 @@ def fetch_arrow(self) -> _lib.ArrowArrayStreamHandle:
         Fetch the result as an object implementing the Arrow PyCapsule interface.
 
         This can only be called once. It must be called before any other
-        method that inspect the data (e.g. description, fetchone,
-        fetch_arrow_table, etc.). Once this is called, other methods that
+        method that consume data (e.g. fetchone, fetch_arrow_table, etc.;
+        description is allowed). Once this is called, other methods that
         inspect the data may not be called.
 
         Notes
@@ -1285,10 +1288,14 @@ class _RowIterator(_Closeable):
     """Track state needed to iterate over the result set."""
 
     def __init__(
-        self, stmt: _lib.AdbcStatement, handle: _lib.ArrowArrayStreamHandle
+        self,
+        stmt: _lib.AdbcStatement,
+        handle: _lib.ArrowArrayStreamHandle,
+        dbapi_backend: _dbapi_backend.DbapiBackend,
     ) -> None:
         self._stmt = stmt
         self._handle: Optional[_lib.ArrowArrayStreamHandle] = handle
+        self._backend = dbapi_backend
         self._reader: Optional["_reader.AdbcRecordBatchReader"] = None
         self._current_batch = None
         self._next_row = 0
@@ -1321,10 +1328,16 @@ def reader(self) -> "_reader.AdbcRecordBatchReader":
 
     @property
     def description(self) -> List[tuple]:
-        return [
-            (field.name, field.type, None, None, None, None, None)
-            for field in self.reader.schema
-        ]
+        if self._handle is None:
+            # Invalid state, or already imported into the reader
+            # (we assume PyArrow here for now)
+            return [
+                (field.name, field.type, None, None, None, None, None)
+                for field in self.reader.schema
+            ]
+        else:
+            # Not yet imported into the reader. Do not force consumption
+            return self._backend.convert_description(self._handle.__arrow_c_schema__())
 
     def fetchone(self) -> Optional[tuple]:
         if self._current_batch is None or self._next_row >= len(self._current_batch):
```
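The upshot in `_RowIterator.description`: before the stream has been imported into a reader, the backend converts a schema capsule; afterwards, the reader's own schema is used. Both paths produce the same 7-tuples, so `description` is now safe to read at any point. A small sketch, assuming an open DB-API connection `conn`:

```python
with conn.cursor() as cur:
    cur.execute("SELECT 1 AS x")
    before = cur.description  # schema peeked; stream not consumed
    rows = cur.fetchall()     # stream imported into the reader
    after = cur.description   # now served from reader.schema
    assert [c[0] for c in before] == [c[0] for c in after] == ["x"]
```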

python/adbc_driver_manager/tests/test_dbapi.py

Lines changed: 24 additions & 0 deletions

```diff
@@ -243,6 +243,30 @@ def test_query_fetch_arrow(sqlite):
             cur.fetch_arrow()
 
 
+@pytest.mark.sqlite
+def test_query_fetch_arrow_3543(sqlite):
+    # Regression test for https://github.com/apache/arrow-adbc/issues/3543
+    with sqlite.cursor() as cur:
+        cur.execute("SELECT 1, 'foo' AS foo, 2.0")
+
+        # This should not consume the result
+        assert cur.description == [
+            ("1", dbapi.NUMBER, None, None, None, None, None),
+            ("foo", dbapi.STRING, None, None, None, None, None),
+            ("2.0", dbapi.NUMBER, None, None, None, None, None),
+        ]
+
+        capsule = cur.fetch_arrow().__arrow_c_stream__()
+        reader = pyarrow.RecordBatchReader._import_from_c_capsule(capsule)
+        assert reader.read_all() == pyarrow.table(
+            {
+                "1": [1],
+                "foo": ["foo"],
+                "2.0": [2.0],
+            }
+        )
+
+
 @pytest.mark.sqlite
 def test_query_fetch_arrow_table(sqlite):
     with sqlite.cursor() as cur:
```
