1
1
import datetime
2
2
import decimal
3
3
from enum import Enum , auto
4
- from typing import Optional , Sequence
4
+ from typing import Optional , Sequence , Any
5
5
6
6
from databricks .sql .exc import NotSupportedError
7
7
from databricks .sql .thrift_api .TCLIService .ttypes import (
8
8
TSparkParameter ,
9
9
TSparkParameterValue ,
10
+ TSparkParameterValueArg ,
10
11
)
11
12
12
13
import datetime
@@ -54,7 +55,17 @@ class DatabricksSupportedType(Enum):
54
55
55
56
56
57
# Every Python type that may be passed as a SQL parameter value, including the
# collection types (list / dict / tuple) that map to ARRAY and MAP.
TAllowedParameterValue = Union[
    str,
    int,
    float,
    datetime.datetime,
    datetime.date,
    bool,
    decimal.Decimal,
    None,
    list,
    dict,
    tuple,
]
59
70
60
71
@@ -82,6 +93,7 @@ class DbsqlParameterBase:
82
93
83
94
CAST_EXPR : str
84
95
name : Optional [str ]
96
+ value : Any
85
97
86
98
def as_tspark_param (self , named : bool ) -> TSparkParameter :
87
99
"""Returns a TSparkParameter object that can be passed to the DBR thrift server."""
@@ -98,6 +110,10 @@ def as_tspark_param(self, named: bool) -> TSparkParameter:
98
110
def _tspark_param_value(self):
    # Scalar parameters travel over thrift as their string rendering; the
    # server casts them back using CAST_EXPR.
    return TSparkParameterValue(stringValue=str(self.value))
100
112
113
def _tspark_value_arg(self):
    """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server."""
    # Used when this parameter is nested inside an ARRAY or MAP argument list.
    return TSparkParameterValueArg(value=str(self.value), type=self._cast_expr())
116
+
101
117
def _cast_expr(self):
    # Concrete subclasses pin CAST_EXPR to a DatabricksSupportedType name.
    return self.CAST_EXPR
103
119
@@ -428,6 +444,99 @@ def __init__(self, value: int, name: Optional[str] = None):
428
444
CAST_EXPR = DatabricksSupportedType .TINYINT .name
429
445
430
446
447
class ArrayParameter(DbsqlParameterBase):
    """Wrap a Python `Sequence` that will be bound to a Databricks SQL ARRAY type."""

    def __init__(self, value: Sequence[Any], name: Optional[str] = None):
        """
        :value:
            The value to bind for this parameter. This will be cast to an ARRAY.
        :name:
            If None, your query must contain a `?` marker. Like:

            ```sql
            SELECT * FROM table WHERE field = ?
            ```
            If not None, your query should contain a named parameter marker. Like:
            ```sql
            SELECT * FROM table WHERE field = :my_param
            ```

            The `name` argument to this function would be `my_param`.
        """
        self.name = name
        # Wrap every element in its own parameter object so nested lists and
        # dicts serialize recursively via _tspark_value_arg.
        self.value = [dbsql_parameter_from_primitive(val) for val in value]

    def as_tspark_param(self, named: bool = False) -> TSparkParameter:
        """Returns a TSparkParameter object that can be passed to the DBR thrift server."""
        tsp = TSparkParameter(type=self._cast_expr())
        tsp.arguments = [val._tspark_value_arg() for val in self.value]
        if named:
            tsp.name = self.name
            tsp.ordinal = False
        else:
            # Positional (`?`) parameter marker.
            tsp.ordinal = True
        return tsp

    def _tspark_value_arg(self):
        """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server."""
        # Complex values carry their elements in `arguments` instead of `value`.
        tva = TSparkParameterValueArg(type=self._cast_expr())
        tva.arguments = [val._tspark_value_arg() for val in self.value]
        return tva

    CAST_EXPR = DatabricksSupportedType.ARRAY.name
490
+
491
+
492
class MapParameter(DbsqlParameterBase):
    """Wrap a Python `dict` that will be bound to a Databricks SQL MAP type."""

    def __init__(self, value: dict, name: Optional[str] = None):
        """
        :value:
            The value to bind for this parameter. This will be cast to a MAP.
        :name:
            If None, your query must contain a `?` marker. Like:

            ```sql
            SELECT * FROM table WHERE field = ?
            ```
            If not None, your query should contain a named parameter marker. Like:
            ```sql
            SELECT * FROM table WHERE field = :my_param
            ```

            The `name` argument to this function would be `my_param`.
        """
        self.name = name
        # A MAP is serialized as a flat, alternating [key, value, key, value, ...]
        # argument list; each entry is wrapped recursively.
        self.value = [
            dbsql_parameter_from_primitive(item)
            for key, val in value.items()
            for item in (key, val)
        ]

    def as_tspark_param(self, named: bool = False) -> TSparkParameter:
        """Returns a TSparkParameter object that can be passed to the DBR thrift server."""
        tsp = TSparkParameter(type=self._cast_expr())
        tsp.arguments = [val._tspark_value_arg() for val in self.value]
        if named:
            tsp.name = self.name
            tsp.ordinal = False
        else:
            # Positional (`?`) parameter marker.
            tsp.ordinal = True
        return tsp

    def _tspark_value_arg(self):
        """Returns a TSparkParameterValueArg object that can be passed to the DBR thrift server."""
        # Complex values carry their entries in `arguments` instead of `value`.
        tva = TSparkParameterValueArg(type=self._cast_expr())
        tva.arguments = [val._tspark_value_arg() for val in self.value]
        return tva

    CAST_EXPR = DatabricksSupportedType.MAP.name
538
+
539
+
431
540
class DecimalParameter (DbsqlParameterBase ):
432
541
"""Wrap a Python `Decimal` that will be bound to a Databricks SQL DECIMAL type."""
433
542
@@ -543,23 +652,26 @@ def dbsql_parameter_from_primitive(
543
652
# havoc. We can't use TYPE_INFERRENCE_MAP because mypy doesn't trust
544
653
# its logic
545
654
546
- if type (value ) is int :
655
+ if isinstance (value , bool ):
656
+ return BooleanParameter (value = value , name = name )
657
+ elif isinstance (value , int ):
547
658
return dbsql_parameter_from_int (value , name = name )
548
- elif type (value ) is str :
659
+ elif isinstance (value , str ) :
549
660
return StringParameter (value = value , name = name )
550
- elif type (value ) is float :
661
+ elif isinstance (value , float ) :
551
662
return FloatParameter (value = value , name = name )
552
- elif type (value ) is datetime .datetime :
663
+ elif isinstance (value , datetime .datetime ) :
553
664
return TimestampParameter (value = value , name = name )
554
- elif type (value ) is datetime .date :
665
+ elif isinstance (value , datetime .date ) :
555
666
return DateParameter (value = value , name = name )
556
- elif type (value ) is bool :
557
- return BooleanParameter (value = value , name = name )
558
- elif type (value ) is decimal .Decimal :
667
+ elif isinstance (value , decimal .Decimal ):
559
668
return DecimalParameter (value = value , name = name )
669
+ elif isinstance (value , dict ):
670
+ return MapParameter (value = value , name = name )
671
+ elif isinstance (value , Sequence ) and not isinstance (value , str ):
672
+ return ArrayParameter (value = value , name = name )
560
673
elif value is None :
561
674
return VoidParameter (value = value , name = name )
562
-
563
675
else :
564
676
raise NotSupportedError (
565
677
f"Could not infer parameter type from value: { value } - { type (value )} \n "
@@ -581,6 +693,8 @@ def dbsql_parameter_from_primitive(
581
693
TimestampNTZParameter ,
582
694
TinyIntParameter ,
583
695
DecimalParameter ,
696
+ ArrayParameter ,
697
+ MapParameter ,
584
698
]
585
699
586
700
0 commit comments