Skip to content

Commit ad2bb62

Browse files
authored
[PECOBLR-330] Support for complex params (#559)
* Basic testing * testing examples * Basic working prototype * ttypes fix * Refactored the ttypes * nit * Added inline support * Reordered boolean to be above int * Check Working e2e tests prototype * More tests * Added unit tests * refactor * nit * nit * nit * nit
1 parent 3842583 commit ad2bb62

File tree

10 files changed

+478
-67
lines changed

10 files changed

+478
-67
lines changed

src/databricks/sql/parameters/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,6 @@
1212
TimestampNTZParameter,
1313
TinyIntParameter,
1414
DecimalParameter,
15+
MapParameter,
16+
ArrayParameter,
1517
)

src/databricks/sql/parameters/native.py

Lines changed: 125 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import datetime
22
import decimal
33
from enum import Enum, auto
4-
from typing import Optional, Sequence
4+
from typing import Optional, Sequence, Any
55

66
from databricks.sql.exc import NotSupportedError
77
from databricks.sql.thrift_api.TCLIService.ttypes import (
88
TSparkParameter,
99
TSparkParameterValue,
10+
TSparkParameterValueArg,
1011
)
1112

1213
import datetime
@@ -54,7 +55,17 @@ class DatabricksSupportedType(Enum):
5455

5556

5657
TAllowedParameterValue = Union[
57-
str, int, float, datetime.datetime, datetime.date, bool, decimal.Decimal, None
58+
str,
59+
int,
60+
float,
61+
datetime.datetime,
62+
datetime.date,
63+
bool,
64+
decimal.Decimal,
65+
None,
66+
list,
67+
dict,
68+
tuple,
5869
]
5970

6071

@@ -82,6 +93,7 @@ class DbsqlParameterBase:
8293

8394
CAST_EXPR: str
8495
name: Optional[str]
96+
value: Any
8597

8698
def as_tspark_param(self, named: bool) -> TSparkParameter:
8799
"""Returns a TSparkParameter object that can be passed to the DBR thrift server."""
@@ -98,6 +110,10 @@ def as_tspark_param(self, named: bool) -> TSparkParameter:
98110
def _tspark_param_value(self):
99111
return TSparkParameterValue(stringValue=str(self.value))
100112

113+
def _tspark_value_arg(self):
    """Build the TSparkParameterValueArg for this parameter.

    The value is sent as ``str(self.value)`` and the type is whatever
    ``_cast_expr()`` returns for this parameter class.
    """
    value_text = str(self.value)
    return TSparkParameterValueArg(type=self._cast_expr(), value=value_text)
116+
101117
def _cast_expr(self):
102118
return self.CAST_EXPR
103119

@@ -428,6 +444,99 @@ def __init__(self, value: int, name: Optional[str] = None):
428444
CAST_EXPR = DatabricksSupportedType.TINYINT.name
429445

430446

447+
class ArrayParameter(DbsqlParameterBase):
    """Bind a Python `Sequence` to a Databricks SQL ARRAY parameter.

    Every element is converted with `dbsql_parameter_from_primitive`, so
    `self.value` holds parameter objects rather than the raw elements.
    """

    CAST_EXPR = DatabricksSupportedType.ARRAY.name

    def __init__(self, value: Sequence[Any], name: Optional[str] = None):
        """
        :value:
            The value to bind for this parameter. This will be cast to an ARRAY.
        :name:
            If None, your query must contain a `?` marker. Like:

            ```sql
            SELECT * FROM table WHERE field = ?
            ```
            If not None, your query should contain a named parameter marker. Like:
            ```sql
            SELECT * FROM table WHERE field = :my_param
            ```

            The `name` argument to this function would be `my_param`.
        """
        self.name = name
        # Wrap each element in its own parameter object, preserving order.
        self.value = list(map(dbsql_parameter_from_primitive, value))

    def as_tspark_param(self, named: bool = False) -> TSparkParameter:
        """Returns a TSparkParameter object that can be passed to the DBR thrift server.

        The elements are attached as `arguments`. The name is only set when
        `named` is truthy; `ordinal` is the negation of `named`.
        """
        param = TSparkParameter(type=self._cast_expr())
        param.arguments = [element._tspark_value_arg() for element in self.value]

        if named:
            param.name = self.name
        param.ordinal = not named
        return param

    def _tspark_value_arg(self):
        """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server.

        Used when this array is itself an element of another complex
        parameter; the elements are attached as nested `arguments`.
        """
        nested = TSparkParameterValueArg(type=self._cast_expr())
        nested.arguments = [element._tspark_value_arg() for element in self.value]
        return nested
490+
491+
492+
class MapParameter(DbsqlParameterBase):
    """Bind a Python `dict` to a Databricks SQL MAP parameter.

    The dict is stored flattened as ``[key1, value1, key2, value2, ...]``,
    each entry converted with `dbsql_parameter_from_primitive`.
    """

    CAST_EXPR = DatabricksSupportedType.MAP.name

    def __init__(self, value: dict, name: Optional[str] = None):
        """
        :value:
            The value to bind for this parameter. This will be cast to a MAP.
        :name:
            If None, your query must contain a `?` marker. Like:

            ```sql
            SELECT * FROM table WHERE field = ?
            ```
            If not None, your query should contain a named parameter marker. Like:
            ```sql
            SELECT * FROM table WHERE field = :my_param
            ```

            The `name` argument to this function would be `my_param`.
        """
        self.name = name

        # Flatten the mapping into alternating key/value parameter objects.
        flattened = []
        for key, val in value.items():
            flattened.append(dbsql_parameter_from_primitive(key))
            flattened.append(dbsql_parameter_from_primitive(val))
        self.value = flattened

    def as_tspark_param(self, named: bool = False) -> TSparkParameter:
        """Returns a TSparkParameter object that can be passed to the DBR thrift server.

        The flattened key/value entries are attached as `arguments`. The name
        is only set when `named` is truthy; `ordinal` is the negation of `named`.
        """
        param = TSparkParameter(type=self._cast_expr())
        param.arguments = [entry._tspark_value_arg() for entry in self.value]

        if named:
            param.name = self.name
        param.ordinal = not named
        return param

    def _tspark_value_arg(self):
        """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server.

        Used when this map is itself an element of another complex
        parameter; the flattened entries are attached as nested `arguments`.
        """
        nested = TSparkParameterValueArg(type=self._cast_expr())
        nested.arguments = [entry._tspark_value_arg() for entry in self.value]
        return nested
538+
539+
431540
class DecimalParameter(DbsqlParameterBase):
432541
"""Wrap a Python `Decimal` that will be bound to a Databricks SQL DECIMAL type."""
433542

@@ -543,23 +652,26 @@ def dbsql_parameter_from_primitive(
543652
# havoc. We can't use TYPE_INFERRENCE_MAP because mypy doesn't trust
544653
# its logic
545654

546-
if type(value) is int:
655+
if isinstance(value, bool):
656+
return BooleanParameter(value=value, name=name)
657+
elif isinstance(value, int):
547658
return dbsql_parameter_from_int(value, name=name)
548-
elif type(value) is str:
659+
elif isinstance(value, str):
549660
return StringParameter(value=value, name=name)
550-
elif type(value) is float:
661+
elif isinstance(value, float):
551662
return FloatParameter(value=value, name=name)
552-
elif type(value) is datetime.datetime:
663+
elif isinstance(value, datetime.datetime):
553664
return TimestampParameter(value=value, name=name)
554-
elif type(value) is datetime.date:
665+
elif isinstance(value, datetime.date):
555666
return DateParameter(value=value, name=name)
556-
elif type(value) is bool:
557-
return BooleanParameter(value=value, name=name)
558-
elif type(value) is decimal.Decimal:
667+
elif isinstance(value, decimal.Decimal):
559668
return DecimalParameter(value=value, name=name)
669+
elif isinstance(value, dict):
670+
return MapParameter(value=value, name=name)
671+
elif isinstance(value, Sequence) and not isinstance(value, str):
672+
return ArrayParameter(value=value, name=name)
560673
elif value is None:
561674
return VoidParameter(value=value, name=name)
562-
563675
else:
564676
raise NotSupportedError(
565677
f"Could not infer parameter type from value: {value} - {type(value)} \n"
@@ -581,6 +693,8 @@ def dbsql_parameter_from_primitive(
581693
TimestampNTZParameter,
582694
TinyIntParameter,
583695
DecimalParameter,
696+
ArrayParameter,
697+
MapParameter,
584698
]
585699

586700

src/databricks/sql/utils.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
import decimal
66
from abc import ABC, abstractmethod
77
from collections import OrderedDict, namedtuple
8-
from collections.abc import Iterable
8+
from collections.abc import Mapping
99
from decimal import Decimal
1010
from enum import Enum
11-
from typing import Any, Dict, List, Optional, Union
11+
from typing import Any, Dict, List, Optional, Union, Sequence
1212
import re
1313

1414
import lz4.frame
@@ -429,7 +429,7 @@ def user_friendly_error_message(self, no_retry_reason, attempt, elapsed):
429429
# Taken from PyHive
430430
class ParamEscaper:
431431
_DATE_FORMAT = "%Y-%m-%d"
432-
_TIME_FORMAT = "%H:%M:%S.%f"
432+
_TIME_FORMAT = "%H:%M:%S.%f %z"
433433
_DATETIME_FORMAT = "{} {}".format(_DATE_FORMAT, _TIME_FORMAT)
434434

435435
def escape_args(self, parameters):
@@ -458,13 +458,22 @@ def escape_string(self, item):
458458
return "'{}'".format(item.replace("\\", "\\\\").replace("'", "\\'"))
459459

460460
def escape_sequence(self, item):
    """Render a sequence as a SQL ``ARRAY(...)`` literal.

    Each element is escaped with `escape_item`, converted to `str`, and the
    results are joined with commas (no spaces).
    """
    escaped = [str(self.escape_item(element)) for element in item]
    return "ARRAY({})".format(",".join(escaped))
464+
465+
def escape_mapping(self, item):
    """Render a mapping as a SQL ``MAP(...)`` literal.

    Keys and values are emitted alternately (``k1,v1,k2,v2,...``), each
    escaped with `escape_item` and converted to `str`.
    """
    escaped = []
    for key, value in item.items():
        escaped.append(str(self.escape_item(key)))
        escaped.append(str(self.escape_item(value)))
    return "MAP({})".format(",".join(escaped))
463472

464473
def escape_datetime(self, item, format, cutoff=0):
    """Format a date/datetime with `format` and wrap it in single quotes.

    :item: object exposing ``strftime`` (e.g. ``datetime.date`` / ``datetime.datetime``).
    :format: ``strftime`` format string.
    :cutoff: number of trailing characters to drop; only applied when it is
        non-zero and `format` ends with ``.%f``.

    Surrounding whitespace is stripped before quoting, so a format ending in
    `` %z`` leaves no trailing space when ``%z`` renders as an empty string.
    """
    rendered = item.strftime(format)
    if cutoff and format.endswith(".%f"):
        rendered = rendered[:-cutoff]
    return "'" + rendered.strip() + "'"
468477

469478
def escape_decimal(self, item):
470479
return str(item)
@@ -476,14 +485,16 @@ def escape_item(self, item):
476485
return self.escape_number(item)
477486
elif isinstance(item, str):
478487
return self.escape_string(item)
479-
elif isinstance(item, Iterable):
480-
return self.escape_sequence(item)
481488
elif isinstance(item, datetime.datetime):
482489
return self.escape_datetime(item, self._DATETIME_FORMAT)
483490
elif isinstance(item, datetime.date):
484491
return self.escape_datetime(item, self._DATE_FORMAT)
485492
elif isinstance(item, decimal.Decimal):
486493
return self.escape_decimal(item)
494+
elif isinstance(item, Sequence):
495+
return self.escape_sequence(item)
496+
elif isinstance(item, Mapping):
497+
return self.escape_mapping(item)
487498
else:
488499
raise exc.ProgrammingError("Unsupported object {}".format(item))
489500

tests/e2e/test_complex_types.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22
from numpy import ndarray
3+
from typing import Sequence
34

45
from tests.e2e.test_driver import PySQLPytestTestCase
56

@@ -14,50 +15,73 @@ def table_fixture(self, connection_details):
1415
# Create the table
1516
cursor.execute(
1617
"""
17-
CREATE TABLE IF NOT EXISTS pysql_e2e_test_complex_types_table (
18+
CREATE TABLE IF NOT EXISTS pysql_test_complex_types_table (
1819
array_col ARRAY<STRING>,
1920
map_col MAP<STRING, INTEGER>,
20-
struct_col STRUCT<field1: STRING, field2: INTEGER>
21-
)
21+
struct_col STRUCT<field1: STRING, field2: INTEGER>,
22+
array_array_col ARRAY<ARRAY<STRING>>,
23+
array_map_col ARRAY<MAP<STRING, INTEGER>>,
24+
map_array_col MAP<STRING, ARRAY<STRING>>
25+
) USING DELTA
2226
"""
2327
)
2428
# Insert a record
2529
cursor.execute(
2630
"""
27-
INSERT INTO pysql_e2e_test_complex_types_table
31+
INSERT INTO pysql_test_complex_types_table
2832
VALUES (
2933
ARRAY('a', 'b', 'c'),
3034
MAP('a', 1, 'b', 2, 'c', 3),
31-
NAMED_STRUCT('field1', 'a', 'field2', 1)
35+
NAMED_STRUCT('field1', 'a', 'field2', 1),
36+
ARRAY(ARRAY('a','b','c')),
37+
ARRAY(MAP('a', 1, 'b', 2, 'c', 3)),
38+
MAP('a', ARRAY('a', 'b', 'c'), 'b', ARRAY('d', 'e'))
3239
)
3340
"""
3441
)
3542
yield
3643
# Clean up the table after the test
37-
cursor.execute("DROP TABLE IF EXISTS pysql_e2e_test_complex_types_table")
44+
cursor.execute("DELETE FROM pysql_test_complex_types_table")
3845

3946
@pytest.mark.parametrize(
    "field,expected_type",
    [
        ("array_col", ndarray),
        ("map_col", list),
        ("struct_col", dict),
        ("array_array_col", ndarray),
        ("array_map_col", ndarray),
        ("map_array_col", list),
    ],
)
def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture):
    """Check the Python type returned for each complex column when reading as arrow."""
    query = "SELECT * FROM pysql_test_complex_types_table LIMIT 1"

    with self.cursor() as cursor:
        row = cursor.execute(query).fetchone()

    assert isinstance(row[field], expected_type)
5266

53-
@pytest.mark.parametrize(
    "field",
    [
        "array_col",
        "map_col",
        "struct_col",
        "array_array_col",
        "array_map_col",
        "map_array_col",
    ],
)
def test_read_complex_types_as_string(self, field, table_fixture):
    """Check that each complex column is returned as `str` when native complex types are off."""
    query = "SELECT * FROM pysql_test_complex_types_table LIMIT 1"
    cursor_options = {"_use_arrow_native_complex_types": False}

    with self.cursor(extra_params=cursor_options) as cursor:
        row = cursor.execute(query).fetchone()

    assert isinstance(row[field], str)

tests/e2e/test_driver.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,9 @@ def test_closing_a_closed_connection_doesnt_fail(self, caplog):
856856
raise KeyboardInterrupt("Simulated interrupt")
857857
finally:
858858
if conn is not None:
859-
assert not conn.open, "Connection should be closed after KeyboardInterrupt"
859+
assert (
860+
not conn.open
861+
), "Connection should be closed after KeyboardInterrupt"
860862

861863
def test_cursor_close_properly_closes_operation(self):
862864
"""Test that Cursor.close() properly closes the active operation handle on the server."""
@@ -883,7 +885,9 @@ def test_cursor_close_properly_closes_operation(self):
883885
raise KeyboardInterrupt("Simulated interrupt")
884886
finally:
885887
if cursor is not None:
886-
assert not cursor.open, "Cursor should be closed after KeyboardInterrupt"
888+
assert (
889+
not cursor.open
890+
), "Cursor should be closed after KeyboardInterrupt"
887891

888892
def test_nested_cursor_context_managers(self):
889893
"""Test that nested cursor context managers properly close operations on the server."""

0 commit comments

Comments
 (0)