Commit 36d3ec4
SEA: Allow large metadata responses (databricks#653)
* remove redundant conversion.py Signed-off-by: varun-edachali-dbx <[email protected]>
* fix type issues Signed-off-by: varun-edachali-dbx <[email protected]>
* ValueError not ProgrammingError Signed-off-by: varun-edachali-dbx <[email protected]>
* reduce diff Signed-off-by: varun-edachali-dbx <[email protected]>
* introduce SEA cloudfetch e2e tests Signed-off-by: varun-edachali-dbx <[email protected]>
* allow empty cloudfetch result Signed-off-by: varun-edachali-dbx <[email protected]>
* add unit tests for CloudFetchQueue and SeaResultSet Signed-off-by: varun-edachali-dbx <[email protected]>
* skip pyarrow dependent tests Signed-off-by: varun-edachali-dbx <[email protected]>
* simplify download process: no pre-fetching Signed-off-by: varun-edachali-dbx <[email protected]>
* correct class name in logs Signed-off-by: varun-edachali-dbx <[email protected]>
* align with old impl Signed-off-by: varun-edachali-dbx <[email protected]>
* align next_n_rows with prev impl Signed-off-by: varun-edachali-dbx <[email protected]>
* align remaining_rows with prev impl Signed-off-by: varun-edachali-dbx <[email protected]>
* remove un-necessary Optional params Signed-off-by: varun-edachali-dbx <[email protected]>
* remove un-necessary changes in thrift field if tests Signed-off-by: varun-edachali-dbx <[email protected]>
* remove unused imports Signed-off-by: varun-edachali-dbx <[email protected]>
* init hybrid
* run large queries Signed-off-by: varun-edachali-dbx <[email protected]>
* hybrid disposition Signed-off-by: varun-edachali-dbx <[email protected]>
* remove un-necessary log Signed-off-by: varun-edachali-dbx <[email protected]>
* formatting (black) Signed-off-by: varun-edachali-dbx <[email protected]>
* remove redundant tests Signed-off-by: varun-edachali-dbx <[email protected]>
* multi frame decompression of lz4 Signed-off-by: varun-edachali-dbx <[email protected]>
* ensure no compression (temp) Signed-off-by: varun-edachali-dbx <[email protected]>
* introduce separate link fetcher Signed-off-by: varun-edachali-dbx <[email protected]>
* log time to create table Signed-off-by: varun-edachali-dbx <[email protected]>
* add chunk index to table creation time log Signed-off-by: varun-edachali-dbx <[email protected]>
* remove custom multi-frame decompressor for lz4 Signed-off-by: varun-edachali-dbx <[email protected]>
* remove excess logs
* remove redundant tests (temp) Signed-off-by: varun-edachali-dbx <[email protected]>
* add link to download manager before notifying consumer Signed-off-by: varun-edachali-dbx <[email protected]>
* move link fetching immediately before table creation so link expiry is not an issue Signed-off-by: varun-edachali-dbx <[email protected]>
* resolve merge artifacts Signed-off-by: varun-edachali-dbx <[email protected]>
* remove redundant methods Signed-off-by: varun-edachali-dbx <[email protected]>
* formatting (black) Signed-off-by: varun-edachali-dbx <[email protected]>
* introduce callback to handle link expiry Signed-off-by: varun-edachali-dbx <[email protected]>
* fix types Signed-off-by: varun-edachali-dbx <[email protected]>
* fix param type in unit tests Signed-off-by: varun-edachali-dbx <[email protected]>
* formatting + minor type fixes Signed-off-by: varun-edachali-dbx <[email protected]>
* Revert "introduce callback to handle link expiry". This reverts commit bd51b1c.
* remove unused callback (to be introduced later) Signed-off-by: varun-edachali-dbx <[email protected]>
* correct param extraction Signed-off-by: varun-edachali-dbx <[email protected]>
* remove common constructor for databricks client abc Signed-off-by: varun-edachali-dbx <[email protected]>
* make SEA Http Client instance a private member Signed-off-by: varun-edachali-dbx <[email protected]>
* make GetChunksResponse model more robust Signed-off-by: varun-edachali-dbx <[email protected]>
* add link to doc of GetChunk response model Signed-off-by: varun-edachali-dbx <[email protected]>
* pass result_data instead of "initial links" into SeaCloudFetchQueue Signed-off-by: varun-edachali-dbx <[email protected]>
* move download_manager init into parent CloudFetchQueue Signed-off-by: varun-edachali-dbx <[email protected]>
* raise ServerOperationError for no 0th chunk Signed-off-by: varun-edachali-dbx <[email protected]>
* unused imports Signed-off-by: varun-edachali-dbx <[email protected]>
* return None in case of empty response Signed-off-by: varun-edachali-dbx <[email protected]>
* ensure table is empty on no initial links Signed-off-by: varun-edachali-dbx <[email protected]>
* account for total chunk count Signed-off-by: varun-edachali-dbx <[email protected]>
* iterate by chunk index instead of link Signed-off-by: varun-edachali-dbx <[email protected]>
* make LinkFetcher convert link static Signed-off-by: varun-edachali-dbx <[email protected]>
* add helper for link addition, check for edge case to prevent inf wait Signed-off-by: varun-edachali-dbx <[email protected]>
* add unit tests for LinkFetcher Signed-off-by: varun-edachali-dbx <[email protected]>
* remove un-necessary download manager check Signed-off-by: varun-edachali-dbx <[email protected]>
* remove un-necessary string literals around param type Signed-off-by: varun-edachali-dbx <[email protected]>
* remove duplicate download_manager init Signed-off-by: varun-edachali-dbx <[email protected]>
* account for empty response in LinkFetcher init Signed-off-by: varun-edachali-dbx <[email protected]>
* make get_chunk_link return mandatory ExternalLink Signed-off-by: varun-edachali-dbx <[email protected]>
* set shutdown_event instead of breaking on completion so get_chunk_link is informed Signed-off-by: varun-edachali-dbx <[email protected]>
* docstrings, logging, pydoc Signed-off-by: varun-edachali-dbx <[email protected]>
* use total_chunk_count > 0 Signed-off-by: varun-edachali-dbx <[email protected]>
* clarify that link has already been submitted on getting row_offset Signed-off-by: varun-edachali-dbx <[email protected]>
* return None for out of range Signed-off-by: varun-edachali-dbx <[email protected]>
* default link_fetcher to None Signed-off-by: varun-edachali-dbx <[email protected]>

---------
Signed-off-by: varun-edachali-dbx <[email protected]>

* Chunk download latency (databricks#634)
* chunk download latency Signed-off-by: Sai Shree Pradhan <[email protected]>
* formatting Signed-off-by: Sai Shree Pradhan <[email protected]>
* test fixes Signed-off-by: Sai Shree Pradhan <[email protected]>
* sea-migration static type checking fixes Signed-off-by: Sai Shree Pradhan <[email protected]>
* check types fix Signed-off-by: Sai Shree Pradhan <[email protected]>
* fix type issues Signed-off-by: varun-edachali-dbx <[email protected]>
* type fix revert Signed-off-by: Sai Shree Pradhan <[email protected]>
* - Signed-off-by: Sai Shree Pradhan <[email protected]>
* statement id in get metadata functions Signed-off-by: Sai Shree Pradhan <[email protected]>
* removed result set extractor Signed-off-by: Sai Shree Pradhan <[email protected]>
* databricks client type Signed-off-by: Sai Shree Pradhan <[email protected]>
* formatting Signed-off-by: Sai Shree Pradhan <[email protected]>
* remove defaults, fix chunk id Signed-off-by: Sai Shree Pradhan <[email protected]>
* added statement type to command id Signed-off-by: Sai Shree Pradhan <[email protected]>
* check types fix Signed-off-by: Sai Shree Pradhan <[email protected]>
* renamed chunk_id to num_downloaded_chunks Signed-off-by: Sai Shree Pradhan <[email protected]>
* set statement type to query for chunk download Signed-off-by: Sai Shree Pradhan <[email protected]>
* comment fix Signed-off-by: Sai Shree Pradhan <[email protected]>
* removed dup check for trowset Signed-off-by: Sai Shree Pradhan <[email protected]>

---------
Signed-off-by: Sai Shree Pradhan <[email protected]>

* acquire lock before notif + formatting (black) Signed-off-by: varun-edachali-dbx <[email protected]>
* fix imports Signed-off-by: varun-edachali-dbx <[email protected]>
* add get_chunk_links Signed-off-by: varun-edachali-dbx <[email protected]>
* simplify description extraction Signed-off-by: varun-edachali-dbx <[email protected]>
* pass session_id_hex to ThriftResultSet Signed-off-by: varun-edachali-dbx <[email protected]>
* revert to main's extract description Signed-off-by: varun-edachali-dbx <[email protected]>
* validate row count for sync query tests as well Signed-off-by: varun-edachali-dbx <[email protected]>
* guid_hex -> hex_guid Signed-off-by: varun-edachali-dbx <[email protected]>
* reduce diff Signed-off-by: varun-edachali-dbx <[email protected]>
* reduce diff Signed-off-by: varun-edachali-dbx <[email protected]>
* reduce diff Signed-off-by: varun-edachali-dbx <[email protected]>
* set .value in compression Signed-off-by: varun-edachali-dbx <[email protected]>
* reduce diff Signed-off-by: varun-edachali-dbx <[email protected]>
* is_direct_results -> has_more_rows Signed-off-by: varun-edachali-dbx <[email protected]>
* preliminary large metadata results Signed-off-by: varun-edachali-dbx <[email protected]>
* account for empty table in arrow table filter Signed-off-by: varun-edachali-dbx <[email protected]>
* align flows Signed-off-by: varun-edachali-dbx <[email protected]>
* align flow of json with arrow Signed-off-by: varun-edachali-dbx <[email protected]>
* case sensitive support for arrow table Signed-off-by: varun-edachali-dbx <[email protected]>
* remove un-necessary comment Signed-off-by: varun-edachali-dbx <[email protected]>
* fix merge artifacts Signed-off-by: varun-edachali-dbx <[email protected]>
* remove redundant method Signed-off-by: varun-edachali-dbx <[email protected]>
* remove incorrect docstring Signed-off-by: varun-edachali-dbx <[email protected]>
* remove deepcopy Signed-off-by: varun-edachali-dbx <[email protected]>

---------
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent 3b0c882 commit 36d3ec4

File tree

4 files changed: 314 additions & 89 deletions

src/databricks/sql/backend/sea/backend.py

Lines changed: 5 additions & 4 deletions
@@ -158,6 +158,7 @@ def __init__(
         )

         self.use_hybrid_disposition = kwargs.get("use_hybrid_disposition", True)
+        self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)

         # Extract warehouse ID from http_path
         self.warehouse_id = self._extract_warehouse_id(http_path)
@@ -694,7 +695,7 @@ def get_catalogs(
             max_bytes=max_bytes,
             lz4_compression=False,
             cursor=cursor,
-            use_cloud_fetch=False,
+            use_cloud_fetch=self.use_cloud_fetch,
             parameters=[],
             async_op=False,
             enforce_embedded_schema_correctness=False,
@@ -727,7 +728,7 @@ def get_schemas(
             max_bytes=max_bytes,
             lz4_compression=False,
             cursor=cursor,
-            use_cloud_fetch=False,
+            use_cloud_fetch=self.use_cloud_fetch,
             parameters=[],
             async_op=False,
             enforce_embedded_schema_correctness=False,
@@ -768,7 +769,7 @@ def get_tables(
             max_bytes=max_bytes,
             lz4_compression=False,
             cursor=cursor,
-            use_cloud_fetch=False,
+            use_cloud_fetch=self.use_cloud_fetch,
             parameters=[],
             async_op=False,
             enforce_embedded_schema_correctness=False,
@@ -815,7 +816,7 @@ def get_columns(
             max_bytes=max_bytes,
             lz4_compression=False,
             cursor=cursor,
-            use_cloud_fetch=False,
+            use_cloud_fetch=self.use_cloud_fetch,
             parameters=[],
             async_op=False,
             enforce_embedded_schema_correctness=False,
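
The backend.py change above threads the connection-level use_cloud_fetch setting (alongside the new use_hybrid_disposition flag) through to the SEA metadata commands, which previously hard-coded use_cloud_fetch=False and so could only return inline results. A minimal usage sketch, assuming the standard databricks-sql-connector entry point, placeholder connection details, and a connection configured to use the SEA backend; cursor metadata calls such as tables() route through get_tables(), which now honors the flag:

    from databricks import sql

    # Placeholder host, path and token; not part of this commit.
    with sql.connect(
        server_hostname="example.cloud.databricks.com",
        http_path="/sql/1.0/warehouses/abc123",
        access_token="<personal-access-token>",
        use_cloud_fetch=True,  # matches kwargs.get("use_cloud_fetch", True) in the diff
    ) as connection:
        with connection.cursor() as cursor:
            # Large metadata responses can now arrive via cloud fetch instead of
            # being forced inline.
            cursor.tables(catalog_name="main", schema_name="default")
            print(len(cursor.fetchall()), "tables visible")

Passing use_cloud_fetch=False restores the previous inline-only behaviour for metadata calls.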

src/databricks/sql/backend/sea/utils/filters.py

Lines changed: 184 additions & 51 deletions
@@ -6,12 +6,12 @@

 from __future__ import annotations

+import io
 import logging
 from typing import (
     List,
     Optional,
     Any,
-    Callable,
     cast,
     TYPE_CHECKING,
 )
@@ -20,6 +20,16 @@
     from databricks.sql.backend.sea.result_set import SeaResultSet

 from databricks.sql.backend.types import ExecuteResponse
+from databricks.sql.backend.sea.models.base import ResultData
+from databricks.sql.backend.sea.backend import SeaDatabricksClient
+from databricks.sql.utils import CloudFetchQueue, ArrowQueue
+
+try:
+    import pyarrow
+    import pyarrow.compute as pc
+except ImportError:
+    pyarrow = None
+    pc = None

 logger = logging.getLogger(__name__)

@@ -30,32 +40,18 @@ class ResultSetFilter:
     """

     @staticmethod
-    def _filter_sea_result_set(
-        result_set: SeaResultSet, filter_func: Callable[[List[Any]], bool]
-    ) -> SeaResultSet:
+    def _create_execute_response(result_set: SeaResultSet) -> ExecuteResponse:
         """
-        Filter a SEA result set using the provided filter function.
+        Create an ExecuteResponse with parameters from the original result set.

         Args:
-            result_set: The SEA result set to filter
-            filter_func: Function that takes a row and returns True if the row should be included
+            result_set: Original result set to copy parameters from

         Returns:
-            A filtered SEA result set
+            ExecuteResponse: New execute response object
         """
-
-        # Get all remaining rows
-        all_rows = result_set.results.remaining_rows()
-
-        # Filter rows
-        filtered_rows = [row for row in all_rows if filter_func(row)]
-
-        # Reuse the command_id from the original result set
-        command_id = result_set.command_id
-
-        # Create an ExecuteResponse for the filtered data
-        execute_response = ExecuteResponse(
-            command_id=command_id,
+        return ExecuteResponse(
+            command_id=result_set.command_id,
             status=result_set.status,
             description=result_set.description,
             has_been_closed_server_side=result_set.has_been_closed_server_side,
@@ -64,32 +60,145 @@
             is_staging_operation=False,
         )

-        # Create a new ResultData object with filtered data
-        from databricks.sql.backend.sea.models.base import ResultData
+    @staticmethod
+    def _update_manifest(result_set: SeaResultSet, new_row_count: int):
+        """
+        Create a copy of the manifest with updated row count.
+
+        Args:
+            result_set: Original result set to copy manifest from
+            new_row_count: New total row count for filtered data

-        result_data = ResultData(data=filtered_rows, external_links=None)
+        Returns:
+            Updated manifest copy
+        """
+        filtered_manifest = result_set.manifest
+        filtered_manifest.total_row_count = new_row_count
+        return filtered_manifest

-        from databricks.sql.backend.sea.backend import SeaDatabricksClient
+    @staticmethod
+    def _create_filtered_result_set(
+        result_set: SeaResultSet,
+        result_data: ResultData,
+        row_count: int,
+    ) -> "SeaResultSet":
+        """
+        Create a new filtered SeaResultSet with the provided data.
+
+        Args:
+            result_set: Original result set to copy parameters from
+            result_data: New result data for the filtered set
+            row_count: Number of rows in the filtered data
+
+        Returns:
+            New filtered SeaResultSet
+        """
         from databricks.sql.backend.sea.result_set import SeaResultSet

-        # Create a new SeaResultSet with the filtered data
-        manifest = result_set.manifest
-        manifest.total_row_count = len(filtered_rows)
+        execute_response = ResultSetFilter._create_execute_response(result_set)
+        filtered_manifest = ResultSetFilter._update_manifest(result_set, row_count)

-        filtered_result_set = SeaResultSet(
+        return SeaResultSet(
             connection=result_set.connection,
             execute_response=execute_response,
             sea_client=cast(SeaDatabricksClient, result_set.backend),
             result_data=result_data,
-            manifest=manifest,
+            manifest=filtered_manifest,
             buffer_size_bytes=result_set.buffer_size_bytes,
             arraysize=result_set.arraysize,
         )

-        return filtered_result_set
+    @staticmethod
+    def _filter_arrow_table(
+        table: Any,  # pyarrow.Table
+        column_name: str,
+        allowed_values: List[str],
+        case_sensitive: bool = True,
+    ) -> Any:  # returns pyarrow.Table
+        """
+        Filter a PyArrow table by column values.
+
+        Args:
+            table: The PyArrow table to filter
+            column_name: The name of the column to filter on
+            allowed_values: List of allowed values for the column
+            case_sensitive: Whether to perform case-sensitive comparison
+
+        Returns:
+            A filtered PyArrow table
+        """
+        if not pyarrow:
+            raise ImportError("PyArrow is required for Arrow table filtering")
+
+        if table.num_rows == 0:
+            return table
+
+        # Handle case-insensitive filtering by normalizing both column and allowed values
+        if not case_sensitive:
+            # Convert allowed values to uppercase
+            allowed_values = [v.upper() for v in allowed_values]
+            # Get column values as uppercase
+            column = pc.utf8_upper(table[column_name])
+        else:
+            # Use column as-is
+            column = table[column_name]
+
+        # Convert allowed_values to PyArrow Array
+        allowed_array = pyarrow.array(allowed_values)
+
+        # Construct a boolean mask: True where column is in allowed_list
+        mask = pc.is_in(column, value_set=allowed_array)
+        return table.filter(mask)
+
+    @staticmethod
+    def _filter_arrow_result_set(
+        result_set: SeaResultSet,
+        column_index: int,
+        allowed_values: List[str],
+        case_sensitive: bool = True,
+    ) -> SeaResultSet:
+        """
+        Filter a SEA result set that contains Arrow tables.
+
+        Args:
+            result_set: The SEA result set to filter (containing Arrow data)
+            column_index: The index of the column to filter on
+            allowed_values: List of allowed values for the column
+            case_sensitive: Whether to perform case-sensitive comparison
+
+        Returns:
+            A filtered SEA result set
+        """
+        # Validate column index and get column name
+        if column_index >= len(result_set.description):
+            raise ValueError(f"Column index {column_index} is out of bounds")
+        column_name = result_set.description[column_index][0]
+
+        # Get all remaining rows as Arrow table and filter it
+        arrow_table = result_set.results.remaining_rows()
+        filtered_table = ResultSetFilter._filter_arrow_table(
+            arrow_table, column_name, allowed_values, case_sensitive
+        )
+
+        # Convert the filtered table to Arrow stream format for ResultData
+        sink = io.BytesIO()
+        with pyarrow.ipc.new_stream(sink, filtered_table.schema) as writer:
+            writer.write_table(filtered_table)
+        arrow_stream_bytes = sink.getvalue()
+
+        # Create ResultData with attachment containing the filtered data
+        result_data = ResultData(
+            data=None,  # No JSON data
+            external_links=None,  # No external links
+            attachment=arrow_stream_bytes,  # Arrow data as attachment
+        )
+
+        return ResultSetFilter._create_filtered_result_set(
+            result_set, result_data, filtered_table.num_rows
+        )

     @staticmethod
-    def filter_by_column_values(
+    def _filter_json_result_set(
         result_set: SeaResultSet,
         column_index: int,
         allowed_values: List[str],
@@ -107,22 +216,35 @@ def filter_by_column_values(
         Returns:
             A filtered result set
         """
+        # Validate column index (optional - not in arrow version but good practice)
+        if column_index >= len(result_set.description):
+            raise ValueError(f"Column index {column_index} is out of bounds")

-        # Convert to uppercase for case-insensitive comparison if needed
+        # Extract rows
+        all_rows = result_set.results.remaining_rows()
+
+        # Convert allowed values if case-insensitive
         if not case_sensitive:
             allowed_values = [v.upper() for v in allowed_values]
+        # Helper lambda to get column value based on case sensitivity
+        get_column_value = (
+            lambda row: row[column_index].upper()
+            if not case_sensitive
+            else row[column_index]
+        )
+
+        # Filter rows based on allowed values
+        filtered_rows = [
+            row
+            for row in all_rows
+            if len(row) > column_index and get_column_value(row) in allowed_values
+        ]
+
+        # Create filtered result set
+        result_data = ResultData(data=filtered_rows, external_links=None)

-        return ResultSetFilter._filter_sea_result_set(
-            result_set,
-            lambda row: (
-                len(row) > column_index
-                and (
-                    row[column_index].upper()
-                    if not case_sensitive
-                    else row[column_index]
-                )
-                in allowed_values
-            ),
+        return ResultSetFilter._create_filtered_result_set(
+            result_set, result_data, len(filtered_rows)
         )

     @staticmethod
@@ -143,14 +265,25 @@ def filter_tables_by_type(
         Returns:
             A filtered result set containing only tables of the specified types
         """
-
         # Default table types if none specified
         DEFAULT_TABLE_TYPES = ["TABLE", "VIEW", "SYSTEM TABLE"]
-        valid_types = (
-            table_types if table_types and len(table_types) > 0 else DEFAULT_TABLE_TYPES
-        )
+        valid_types = table_types if table_types else DEFAULT_TABLE_TYPES

+        # Check if we have an Arrow table (cloud fetch) or JSON data
         # Table type is the 6th column (index 5)
-        return ResultSetFilter.filter_by_column_values(
-            result_set, 5, valid_types, case_sensitive=True
-        )
+        if isinstance(result_set.results, (CloudFetchQueue, ArrowQueue)):
+            # For Arrow tables, we need to handle filtering differently
+            return ResultSetFilter._filter_arrow_result_set(
+                result_set,
+                column_index=5,
+                allowed_values=valid_types,
+                case_sensitive=True,
+            )
+        else:
+            # For JSON data, use the existing filter method
+            return ResultSetFilter._filter_json_result_set(
+                result_set,
+                column_index=5,
+                allowed_values=valid_types,
+                case_sensitive=True,
+            )
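
The Arrow path added in filters.py is plain PyArrow underneath: build a boolean mask with pyarrow.compute.is_in (optionally upper-casing both sides for case-insensitive matching), apply it with Table.filter, then serialize the result to the Arrow IPC stream format so the bytes can travel in ResultData(attachment=...). A standalone sketch of that technique with made-up table contents; note that filter_tables_by_type calls the helper with case_sensitive=True, while the sketch exercises the case-insensitive branch for illustration:

    import io

    import pyarrow as pa
    import pyarrow.compute as pc

    # Toy stand-in for the TABLE_TYPE column of a get_tables() result.
    table = pa.table(
        {
            "TABLE_NAME": ["t1", "v1", "tmp1"],
            "TABLE_TYPE": ["TABLE", "view", "FOREIGN"],
        }
    )
    allowed = ["TABLE", "VIEW", "SYSTEM TABLE"]

    # Case-insensitive match: upper-case both sides, then build a boolean mask.
    column = pc.utf8_upper(table["TABLE_TYPE"])
    mask = pc.is_in(column, value_set=pa.array([v.upper() for v in allowed]))
    filtered = table.filter(mask)  # keeps t1 (TABLE) and v1 (view)

    # Round-trip through the Arrow IPC stream format, as done before storing
    # the bytes in ResultData(attachment=...).
    sink = io.BytesIO()
    with pa.ipc.new_stream(sink, filtered.schema) as writer:
        writer.write_table(filtered)
    stream_bytes = sink.getvalue()

    restored = pa.ipc.open_stream(stream_bytes).read_all()
    assert restored.num_rows == 2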
