6
6
7
7
from __future__ import annotations
8
8
9
+ import io
9
10
import logging
10
11
from typing import (
11
12
List ,
12
13
Optional ,
13
14
Any ,
14
- Callable ,
15
15
cast ,
16
16
TYPE_CHECKING ,
17
17
)
20
20
from databricks .sql .backend .sea .result_set import SeaResultSet
21
21
22
22
from databricks .sql .backend .types import ExecuteResponse
23
+ from databricks .sql .backend .sea .models .base import ResultData
24
+ from databricks .sql .backend .sea .backend import SeaDatabricksClient
25
+ from databricks .sql .utils import CloudFetchQueue , ArrowQueue
26
+
27
+ try :
28
+ import pyarrow
29
+ import pyarrow .compute as pc
30
+ except ImportError :
31
+ pyarrow = None
32
+ pc = None
23
33
24
34
logger = logging .getLogger (__name__ )
25
35
@@ -30,32 +40,18 @@ class ResultSetFilter:
30
40
"""
31
41
32
42
@staticmethod
33
- def _filter_sea_result_set (
34
- result_set : SeaResultSet , filter_func : Callable [[List [Any ]], bool ]
35
- ) -> SeaResultSet :
43
+ def _create_execute_response (result_set : SeaResultSet ) -> ExecuteResponse :
36
44
"""
37
- Filter a SEA result set using the provided filter function .
45
+ Create an ExecuteResponse with parameters from the original result set .
38
46
39
47
Args:
40
- result_set: The SEA result set to filter
41
- filter_func: Function that takes a row and returns True if the row should be included
48
+ result_set: Original result set to copy parameters from
42
49
43
50
Returns:
44
- A filtered SEA result set
51
+ ExecuteResponse: New execute response object
45
52
"""
46
-
47
- # Get all remaining rows
48
- all_rows = result_set .results .remaining_rows ()
49
-
50
- # Filter rows
51
- filtered_rows = [row for row in all_rows if filter_func (row )]
52
-
53
- # Reuse the command_id from the original result set
54
- command_id = result_set .command_id
55
-
56
- # Create an ExecuteResponse for the filtered data
57
- execute_response = ExecuteResponse (
58
- command_id = command_id ,
53
+ return ExecuteResponse (
54
+ command_id = result_set .command_id ,
59
55
status = result_set .status ,
60
56
description = result_set .description ,
61
57
has_been_closed_server_side = result_set .has_been_closed_server_side ,
@@ -64,32 +60,145 @@ def _filter_sea_result_set(
64
60
is_staging_operation = False ,
65
61
)
66
62
67
- # Create a new ResultData object with filtered data
68
- from databricks .sql .backend .sea .models .base import ResultData
63
@staticmethod
def _update_manifest(result_set: "SeaResultSet", new_row_count: int):
    """
    Create a copy of the manifest with an updated row count.

    The manifest attached to ``result_set`` is left untouched: a shallow
    copy is made first so that the original result set keeps reporting
    its own ``total_row_count`` after filtering.

    Args:
        result_set: Original result set to copy the manifest from
        new_row_count: New total row count for the filtered data

    Returns:
        A shallow copy of ``result_set.manifest`` whose
        ``total_row_count`` is set to ``new_row_count``
    """
    # Local import keeps this block self-contained; ``copy`` is stdlib.
    import copy

    # Copy before mutating: assigning on ``result_set.manifest`` directly
    # would silently change the row count seen through the source result set.
    filtered_manifest = copy.copy(result_set.manifest)
    filtered_manifest.total_row_count = new_row_count
    return filtered_manifest
71
78
72
- from databricks .sql .backend .sea .backend import SeaDatabricksClient
79
@staticmethod
def _create_filtered_result_set(
    result_set: SeaResultSet,
    result_data: ResultData,
    row_count: int,
) -> "SeaResultSet":
    """
    Build a new SeaResultSet that wraps already-filtered data.

    Connection, backend client, buffer size and array size are taken from
    ``result_set``; the execute response and manifest are derived from it
    via the sibling helpers on ``ResultSetFilter``.

    Args:
        result_set: Original result set whose parameters are reused
        result_data: Result data holding the filtered rows
        row_count: Number of rows contained in ``result_data``

    Returns:
        New SeaResultSet backed by ``result_data``
    """
    # NOTE(review): imported at call time, presumably to avoid a circular
    # import with the result_set module - confirm before hoisting.
    from databricks.sql.backend.sea.result_set import SeaResultSet

    response = ResultSetFilter._create_execute_response(result_set)
    manifest = ResultSetFilter._update_manifest(result_set, row_count)

    constructor_args = {
        "connection": result_set.connection,
        "execute_response": response,
        "sea_client": cast(SeaDatabricksClient, result_set.backend),
        "result_data": result_data,
        "manifest": manifest,
        "buffer_size_bytes": result_set.buffer_size_bytes,
        "arraysize": result_set.arraysize,
    }
    return SeaResultSet(**constructor_args)
88
110
89
- return filtered_result_set
111
@staticmethod
def _filter_arrow_table(
    table: Any,  # pyarrow.Table
    column_name: str,
    allowed_values: List[str],
    case_sensitive: bool = True,
) -> Any:  # returns pyarrow.Table
    """
    Keep only the rows of a PyArrow table whose column value is allowed.

    Args:
        table: The PyArrow table to filter
        column_name: Name of the column whose values are tested
        allowed_values: Values that a row's column may hold to be kept
        case_sensitive: When False, both the column and the allowed
            values are upper-cased before comparison

    Returns:
        The input table unchanged when it has no rows, otherwise a
        filtered PyArrow table

    Raises:
        ImportError: If PyArrow is not available
    """
    if not pyarrow:
        raise ImportError("PyArrow is required for Arrow table filtering")

    # An empty table has nothing to filter; hand it back as-is.
    if table.num_rows == 0:
        return table

    if case_sensitive:
        # Compare the raw column against the values exactly as given.
        candidates = table[column_name]
    else:
        # Normalise both sides to upper case so the comparison ignores case.
        allowed_values = [value.upper() for value in allowed_values]
        candidates = pc.utf8_upper(table[column_name])

    # Boolean mask: True for every row whose value appears in the allowed set.
    value_set = pyarrow.array(allowed_values)
    keep_mask = pc.is_in(candidates, value_set=value_set)
    return table.filter(keep_mask)
152
+
153
@staticmethod
def _filter_arrow_result_set(
    result_set: SeaResultSet,
    column_index: int,
    allowed_values: List[str],
    case_sensitive: bool = True,
) -> SeaResultSet:
    """
    Filter an Arrow-backed SEA result set on one column.

    The remaining rows are pulled from ``result_set.results``, filtered
    with ``_filter_arrow_table``, serialised to the Arrow IPC stream
    format and wrapped in a new result set as an attachment.

    Args:
        result_set: The SEA result set to filter (containing Arrow data)
        column_index: Position of the column to filter on
        allowed_values: Values that a row's column may hold to be kept
        case_sensitive: Whether the comparison is case-sensitive

    Returns:
        A filtered SEA result set

    Raises:
        ValueError: If ``column_index`` is beyond the last described column
    """
    # Resolve the column name from the cursor description, rejecting
    # indexes past the end.
    if column_index >= len(result_set.description):
        raise ValueError(f"Column index {column_index} is out of bounds")
    target_column = result_set.description[column_index][0]

    # Drain whatever is left in the queue (expected to be an Arrow table
    # for Arrow-backed results) and apply the column filter.
    remaining = result_set.results.remaining_rows()
    kept = ResultSetFilter._filter_arrow_table(
        remaining, target_column, allowed_values, case_sensitive
    )

    # Serialise the filtered table into an in-memory Arrow IPC stream.
    buffer = io.BytesIO()
    with pyarrow.ipc.new_stream(buffer, kept.schema) as stream_writer:
        stream_writer.write_table(kept)
    serialized = buffer.getvalue()

    # The filtered rows travel as an attachment: no JSON rows, no external links.
    filtered_data = ResultData(
        data=None,
        external_links=None,
        attachment=serialized,
    )

    return ResultSetFilter._create_filtered_result_set(
        result_set, filtered_data, kept.num_rows
    )
90
199
91
200
@staticmethod
92
- def filter_by_column_values (
201
+ def _filter_json_result_set (
93
202
result_set : SeaResultSet ,
94
203
column_index : int ,
95
204
allowed_values : List [str ],
@@ -107,22 +216,35 @@ def filter_by_column_values(
107
216
Returns:
108
217
A filtered result set
109
218
"""
219
+ # Validate column index (optional - not in arrow version but good practice)
220
+ if column_index >= len (result_set .description ):
221
+ raise ValueError (f"Column index { column_index } is out of bounds" )
110
222
111
- # Convert to uppercase for case-insensitive comparison if needed
223
+ # Extract rows
224
+ all_rows = result_set .results .remaining_rows ()
225
+
226
+ # Convert allowed values if case-insensitive
112
227
if not case_sensitive :
113
228
allowed_values = [v .upper () for v in allowed_values ]
229
+ # Helper lambda to get column value based on case sensitivity
230
+ get_column_value = (
231
+ lambda row : row [column_index ].upper ()
232
+ if not case_sensitive
233
+ else row [column_index ]
234
+ )
235
+
236
+ # Filter rows based on allowed values
237
+ filtered_rows = [
238
+ row
239
+ for row in all_rows
240
+ if len (row ) > column_index and get_column_value (row ) in allowed_values
241
+ ]
242
+
243
+ # Create filtered result set
244
+ result_data = ResultData (data = filtered_rows , external_links = None )
114
245
115
- return ResultSetFilter ._filter_sea_result_set (
116
- result_set ,
117
- lambda row : (
118
- len (row ) > column_index
119
- and (
120
- row [column_index ].upper ()
121
- if not case_sensitive
122
- else row [column_index ]
123
- )
124
- in allowed_values
125
- ),
246
+ return ResultSetFilter ._create_filtered_result_set (
247
+ result_set , result_data , len (filtered_rows )
126
248
)
127
249
128
250
@staticmethod
@@ -143,14 +265,25 @@ def filter_tables_by_type(
143
265
Returns:
144
266
A filtered result set containing only tables of the specified types
145
267
"""
146
-
147
268
# Default table types if none specified
148
269
DEFAULT_TABLE_TYPES = ["TABLE" , "VIEW" , "SYSTEM TABLE" ]
149
- valid_types = (
150
- table_types if table_types and len (table_types ) > 0 else DEFAULT_TABLE_TYPES
151
- )
270
+ valid_types = table_types if table_types else DEFAULT_TABLE_TYPES
152
271
272
+ # Check if we have an Arrow table (cloud fetch) or JSON data
153
273
# Table type is the 6th column (index 5)
154
- return ResultSetFilter .filter_by_column_values (
155
- result_set , 5 , valid_types , case_sensitive = True
156
- )
274
+ if isinstance (result_set .results , (CloudFetchQueue , ArrowQueue )):
275
+ # For Arrow tables, we need to handle filtering differently
276
+ return ResultSetFilter ._filter_arrow_result_set (
277
+ result_set ,
278
+ column_index = 5 ,
279
+ allowed_values = valid_types ,
280
+ case_sensitive = True ,
281
+ )
282
+ else :
283
+ # For JSON data, use the existing filter method
284
+ return ResultSetFilter ._filter_json_result_set (
285
+ result_set ,
286
+ column_index = 5 ,
287
+ allowed_values = valid_types ,
288
+ case_sensitive = True ,
289
+ )
0 commit comments