Skip to content

Commit fd81c5a

Browse files
authored
Concat tables to be backward compatible (#647)
* fixed * Minor fix * more types
1 parent 701f7f6 commit fd81c5a

File tree

3 files changed

+77
-34
lines changed

3 files changed

+77
-34
lines changed

src/databricks/sql/result_set.py

Lines changed: 15 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from databricks.sql.utils import (
2121
ColumnTable,
2222
ColumnQueue,
23+
concat_table_chunks,
2324
)
2425
from databricks.sql.backend.types import CommandId, CommandState, ExecuteResponse
2526
from databricks.sql.telemetry.models.event import StatementType
@@ -296,23 +297,6 @@ def _convert_columnar_table(self, table):
296297

297298
return result
298299

299-
def merge_columnar(self, result1, result2) -> "ColumnTable":
300-
"""
301-
Function to merge / combining the columnar results into a single result
302-
:param result1:
303-
:param result2:
304-
:return:
305-
"""
306-
307-
if result1.column_names != result2.column_names:
308-
raise ValueError("The columns in the results don't match")
309-
310-
merged_result = [
311-
result1.column_table[i] + result2.column_table[i]
312-
for i in range(result1.num_columns)
313-
]
314-
return ColumnTable(merged_result, result1.column_names)
315-
316300
def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
317301
"""
318302
Fetch the next set of rows of a query result, returning a PyArrow table.
@@ -337,7 +321,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
337321
n_remaining_rows -= partial_results.num_rows
338322
self._next_row_index += partial_results.num_rows
339323

340-
return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
324+
return concat_table_chunks(partial_result_chunks)
341325

342326
def fetchmany_columnar(self, size: int):
343327
"""
@@ -350,19 +334,19 @@ def fetchmany_columnar(self, size: int):
350334
results = self.results.next_n_rows(size)
351335
n_remaining_rows = size - results.num_rows
352336
self._next_row_index += results.num_rows
353-
337+
partial_result_chunks = [results]
354338
while (
355339
n_remaining_rows > 0
356340
and not self.has_been_closed_server_side
357341
and self.has_more_rows
358342
):
359343
self._fill_results_buffer()
360344
partial_results = self.results.next_n_rows(n_remaining_rows)
361-
results = self.merge_columnar(results, partial_results)
345+
partial_result_chunks.append(partial_results)
362346
n_remaining_rows -= partial_results.num_rows
363347
self._next_row_index += partial_results.num_rows
364348

365-
return results
349+
return concat_table_chunks(partial_result_chunks)
366350

367351
def fetchall_arrow(self) -> "pyarrow.Table":
368352
"""Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
@@ -372,36 +356,34 @@ def fetchall_arrow(self) -> "pyarrow.Table":
372356
while not self.has_been_closed_server_side and self.has_more_rows:
373357
self._fill_results_buffer()
374358
partial_results = self.results.remaining_rows()
375-
if isinstance(results, ColumnTable) and isinstance(
376-
partial_results, ColumnTable
377-
):
378-
results = self.merge_columnar(results, partial_results)
379-
else:
380-
partial_result_chunks.append(partial_results)
359+
partial_result_chunks.append(partial_results)
381360
self._next_row_index += partial_results.num_rows
382361

362+
result_table = concat_table_chunks(partial_result_chunks)
383363
# If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table
384364
# Valid only for metadata commands result set
385-
if isinstance(results, ColumnTable) and pyarrow:
365+
if isinstance(result_table, ColumnTable) and pyarrow:
386366
data = {
387367
name: col
388-
for name, col in zip(results.column_names, results.column_table)
368+
for name, col in zip(
369+
result_table.column_names, result_table.column_table
370+
)
389371
}
390372
return pyarrow.Table.from_pydict(data)
391-
return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
373+
return result_table
392374

393375
def fetchall_columnar(self):
    """Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
    # Drain whatever is already buffered locally first.
    results = self.results.remaining_rows()
    self._next_row_index += results.num_rows
    partial_result_chunks = [results]
    # Keep pulling batches from the server until it signals no more rows
    # (or has already closed the operation on its side).
    while not self.has_been_closed_server_side and self.has_more_rows:
        self._fill_results_buffer()
        partial_results = self.results.remaining_rows()
        # Accumulate chunks and concatenate once at the end, instead of
        # merging pairwise on every iteration as the old merge_columnar did.
        partial_result_chunks.append(partial_results)
        self._next_row_index += partial_results.num_rows

    return concat_table_chunks(partial_result_chunks)
405387

406388
def fetchone(self) -> Optional[Row]:
407389
"""

src/databricks/sql/utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,3 +853,25 @@ def _create_python_tuple(t_col_value_wrapper):
853853
result[i] = None
854854

855855
return tuple(result)
856+
857+
858+
def concat_table_chunks(
859+
table_chunks: List[Union["pyarrow.Table", ColumnTable]]
860+
) -> Union["pyarrow.Table", ColumnTable]:
861+
if len(table_chunks) == 0:
862+
return table_chunks
863+
864+
if isinstance(table_chunks[0], ColumnTable):
865+
## Check if all have the same column names
866+
if not all(
867+
table.column_names == table_chunks[0].column_names for table in table_chunks
868+
):
869+
raise ValueError("The columns in the results don't match")
870+
871+
result_table: List[List[Any]] = [[] for _ in range(table_chunks[0].num_columns)]
872+
for i in range(0, len(table_chunks)):
873+
for j in range(table_chunks[i].num_columns):
874+
result_table[j].extend(table_chunks[i].column_table[j])
875+
return ColumnTable(result_table, table_chunks[0].column_names)
876+
else:
877+
return pyarrow.concat_tables(table_chunks, use_threads=True)

tests/unit/test_util.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
11
import decimal
22
import datetime
33
from datetime import timezone, timedelta
4+
import pytest
5+
from databricks.sql.utils import (
6+
convert_to_assigned_datatypes_in_column_table,
7+
ColumnTable,
8+
concat_table_chunks,
9+
)
410

5-
from databricks.sql.utils import convert_to_assigned_datatypes_in_column_table
11+
try:
12+
import pyarrow
13+
except ImportError:
14+
pyarrow = None
615

716

817
class TestUtils:
@@ -122,3 +131,33 @@ def test_convert_to_assigned_datatypes_in_column_table(self):
122131
for index, entry in enumerate(converted_column_table):
123132
assert entry[0] == expected_convertion[index][0]
124133
assert isinstance(entry[0], expected_convertion[index][1])
134+
135+
def test_concat_table_chunks_column_table(self):
    """Merging two ColumnTable chunks concatenates each column in order."""
    chunks = [
        ColumnTable([[1, 2], [5, 6]], ["col1", "col2"]),
        ColumnTable([[3, 4], [7, 8]], ["col1", "col2"]),
    ]

    merged = concat_table_chunks(chunks)

    assert merged.column_table == [[1, 2, 3, 4], [5, 6, 7, 8]]
    assert merged.column_names == ["col1", "col2"]
143+
144+
@pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed")
145+
def test_concat_table_chunks_arrow_table(self):
146+
arrow_table1 = pyarrow.Table.from_pydict({"col1": [1, 2], "col2": [5, 6]})
147+
arrow_table2 = pyarrow.Table.from_pydict({"col1": [3, 4], "col2": [7, 8]})
148+
149+
result_table = concat_table_chunks([arrow_table1, arrow_table2])
150+
assert result_table.column_names == ["col1", "col2"]
151+
assert result_table.column("col1").to_pylist() == [1, 2, 3, 4]
152+
assert result_table.column("col2").to_pylist() == [5, 6, 7, 8]
153+
154+
def test_concat_table_chunks_empty(self):
    """An empty chunk list is returned unchanged (no empty-table synthesis)."""
    assert concat_table_chunks([]) == []
157+
158+
def test_concat_table_chunks__incorrect_column_names_error(self):
    """Chunks whose column names disagree must raise ValueError."""
    mismatched = [
        ColumnTable([[1, 2], [5, 6]], ["col1", "col2"]),
        ColumnTable([[3, 4], [7, 8]], ["col1", "col3"]),
    ]

    with pytest.raises(ValueError):
        concat_table_chunks(mismatched)

0 commit comments

Comments
 (0)