-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy path__init__.py
450 lines (371 loc) · 16 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
"""
Experimental
Work in progress, breaking changes are possible.
"""
import collections
import collections.abc
from typing import Any, Mapping, Optional, Sequence, Tuple, Union
import sqlalchemy as sa
import ydb
from sqlalchemy import util
from sqlalchemy.engine import characteristics, reflection
from sqlalchemy.engine.default import DefaultExecutionContext, StrCompileDialect
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.sql import functions
from sqlalchemy.sql.elements import ClauseList
import ydb_dbapi
from ydb_sqlalchemy.sqlalchemy.dbapi_adapter import AdaptedAsyncConnection
from ydb_sqlalchemy.sqlalchemy.dml import Upsert
from ydb_sqlalchemy.sqlalchemy.compiler import YqlCompiler, YqlDDLCompiler, YqlIdentifierPreparer, YqlTypeCompiler
from . import types
# True when the installed SQLAlchemy is a 1.x release. The major version is
# compared numerically: a plain string comparison misorders versions such as
# "10.0", which sorts before "2." lexicographically.
OLD_SA = int(sa.__version__.split(".", 1)[0]) < 2
class ParametrizedFunction(functions.Function):
    """SQL function that carries an additional, separately grouped parameter list.

    ``name``, ``*args`` and ``**kwargs`` are forwarded unchanged to
    :class:`sqlalchemy.sql.functions.Function`. ``params`` is kept as given and
    also wrapped into a grouped, comma-separated ``ClauseList`` exposed as
    ``params_expr`` (presumably consumed by the YQL compiler's
    ``parametrized_function`` visitor -- confirm in the compiler module).
    """

    __visit_name__ = "parametrized_function"

    def __init__(self, name, params, *args, **kwargs):
        super().__init__(name, *args, **kwargs)
        self._func_name = name
        self._func_params = params
        grouped_params = ClauseList(
            *params,
            operator=functions.operators.comma_op,
            group_contents=True,
        )
        self.params_expr = grouped_params.self_group()
def upsert(table):
    """Return an ``Upsert`` statement construct targeting *table*.

    Convenience factory, analogous to SQLAlchemy's ``insert(table)``.
    """
    statement = Upsert(table)
    return statement
# YDB column type -> SQLAlchemy type used when reflecting table columns.
# Decimal columns are not resolved through this mapping's DecimalType entry in
# practice: _get_column_info builds a DECIMAL with the column's own
# precision/scale before consulting this table.
COLUMN_TYPES = dict(
    [
        # Signed integers.
        (ydb.PrimitiveType.Int8, sa.INTEGER),
        (ydb.PrimitiveType.Int16, sa.INTEGER),
        (ydb.PrimitiveType.Int32, sa.INTEGER),
        (ydb.PrimitiveType.Int64, sa.INTEGER),
        # Unsigned integers; the wide ones get dedicated dialect types.
        (ydb.PrimitiveType.Uint8, sa.INTEGER),
        (ydb.PrimitiveType.Uint16, sa.INTEGER),
        (ydb.PrimitiveType.Uint32, types.UInt32),
        (ydb.PrimitiveType.Uint64, types.UInt64),
        # Floating point.
        (ydb.PrimitiveType.Float, sa.FLOAT),
        (ydb.PrimitiveType.Double, sa.FLOAT),
        # Binary and text.
        (ydb.PrimitiveType.String, sa.BINARY),
        (ydb.PrimitiveType.Utf8, sa.TEXT),
        # JSON flavours.
        (ydb.PrimitiveType.Json, sa.JSON),
        (ydb.PrimitiveType.JsonDocument, sa.JSON),
        (ydb.DecimalType, sa.DECIMAL),
        (ydb.PrimitiveType.Yson, sa.TEXT),
        # Date and time.
        (ydb.PrimitiveType.Date, sa.DATE),
        (ydb.PrimitiveType.Datetime, sa.DATETIME),
        (ydb.PrimitiveType.Timestamp, sa.TIMESTAMP),
        (ydb.PrimitiveType.Interval, sa.INTEGER),
        # Miscellaneous.
        (ydb.PrimitiveType.Bool, sa.BOOLEAN),
        (ydb.PrimitiveType.DyNumber, sa.TEXT),
    ]
)
def _get_column_info(t):
    """Translate a YDB column type into ``(sqlalchemy_type, nullable)``.

    ``Optional<...>`` wrappers are stripped and mark the column as nullable.
    Decimal types keep their precision and scale. Every other type is looked up
    in ``COLUMN_TYPES``. A type absent from that mapping is reported with a
    warning and reflected as ``NullType``, the same fallback SQLAlchemy's
    built-in dialects use for unrecognized types, so that one unknown column
    does not make the whole table unreflectable.

    :param t: YDB type object of a described table column.
    :returns: tuple ``(type, nullable)``.
    """
    nullable = False
    # Unwrap every Optional level defensively; any level means NULL is allowed.
    while isinstance(t, ydb.OptionalType):
        nullable = True
        t = t.item

    if isinstance(t, ydb.DecimalType):
        return sa.DECIMAL(precision=t.precision, scale=t.scale), nullable

    try:
        return COLUMN_TYPES[t], nullable
    except KeyError:
        util.warn(f"Did not recognize YDB type '{t}'; reflecting it as NullType")
        return sa.types.NullType(), nullable
class YdbRequestSettingsCharacteristic(characteristics.ConnectionCharacteristic):
    """Connection characteristic that manages YDB request settings.

    Each hook simply delegates to the corresponding ``*_ydb_request_settings``
    method of the dialect, passing the raw DBAPI connection through.
    """

    def get_characteristic(
        self, dialect: "YDBDialect", dbapi_connection: ydb_dbapi.Connection
    ) -> ydb.BaseRequestSettings:
        """Return the request settings currently held by ``dbapi_connection``."""
        current_settings = dialect.get_ydb_request_settings(dbapi_connection)
        return current_settings

    def set_characteristic(
        self,
        dialect: "YDBDialect",
        dbapi_connection: ydb_dbapi.Connection,
        value: ydb.BaseRequestSettings,
    ) -> None:
        """Install ``value`` as the request settings of ``dbapi_connection``."""
        dialect.set_ydb_request_settings(dbapi_connection, value)

    def reset_characteristic(self, dialect: "YDBDialect", dbapi_connection: ydb_dbapi.Connection) -> None:
        """Ask the dialect to reset the request settings of ``dbapi_connection``."""
        dialect.reset_ydb_request_settings(dbapi_connection)
class YDBDialect(StrCompileDialect):
    """SQLAlchemy dialect for YDB on top of the synchronous ``ydb_dbapi`` driver.

    Besides the capability flags below, the dialect implements reflection via YDB
    table descriptions and rewrites statements before execution: ``%``-style named
    placeholders become YQL ``$`name``` variables, parameter keys are prefixed with
    ``$``, and values are wrapped in ``ydb.TypedValue`` when the compiler reports
    bind types for them.
    """

    name = "ydb"
    driver = "ydb_sync"

    supports_alter = False
    max_identifier_length = 63
    supports_sane_rowcount = False
    supports_statement_cache = True

    supports_native_enum = False
    supports_native_boolean = True
    supports_native_decimal = True
    supports_smallserial = False
    supports_schemas = False
    supports_constraint_comments = False
    supports_json_type = True

    insert_returning = False
    update_returning = False
    delete_returning = False

    supports_sequences = False
    sequences_optional = False
    preexecute_autoincrement_sequences = True
    postfetch_lastrowid = False

    supports_default_values = False
    supports_empty_insert = False
    supports_multivalues_insert = True
    default_paramstyle = "qmark"

    isolation_level = None

    preparer = YqlIdentifierPreparer
    statement_compiler = YqlCompiler
    ddl_compiler = YqlDDLCompiler
    type_compiler = YqlTypeCompiler
    colspecs = {
        sa.types.JSON: types.YqlJSON,
        sa.types.JSON.JSONPathType: types.YqlJSON.YqlJSONPathType,
        sa.types.DateTime: types.YqlTimestamp,  # Because YDB's DateTime doesn't store microseconds
        sa.types.DATETIME: types.YqlDateTime,
        sa.types.TIMESTAMP: types.YqlTimestamp,
    }

    connection_characteristics = util.immutabledict(
        {
            "isolation_level": characteristics.IsolationLevelCharacteristic(),
            "ydb_request_settings": YdbRequestSettingsCharacteristic(),
        }
    )

    # Dialect-specific keyword arguments (and their defaults) accepted on Table
    # and Index constructs, spelled with the "ydb_" prefix by users.
    construct_arguments = [
        (
            sa.schema.Table,
            {
                "auto_partitioning_by_size": None,
                "auto_partitioning_by_load": None,
                "auto_partitioning_partition_size_mb": None,
                "auto_partitioning_min_partitions_count": None,
                "auto_partitioning_max_partitions_count": None,
                "uniform_partitions": None,
                "partition_at_keys": None,
            },
        ),
        (
            sa.schema.Index,
            {
                "async": False,
                "cover": [],
            },
        ),
    ]

    @classmethod
    def import_dbapi(cls: Any):
        """Return the DBAPI module (entry point used by SQLAlchemy 2.x)."""
        return ydb_dbapi

    @classmethod
    def dbapi(cls):
        """Return the DBAPI module (legacy entry point used by SQLAlchemy 1.x)."""
        return cls.import_dbapi()

    def __init__(
        self,
        json_serializer=None,
        json_deserializer=None,
        _add_declare_for_yql_stmt_vars=False,
        **kwargs,
    ):
        """Create the dialect.

        :param json_serializer: optional callable stored as ``_json_serializer``.
        :param json_deserializer: optional callable stored as ``_json_deserializer``.
        :param _add_declare_for_yql_stmt_vars: when true, ``DECLARE`` lines for typed
            parameters are prepended to every parameterized statement.
        :param kwargs: forwarded to the base dialect.
        """
        super().__init__(**kwargs)
        self._json_deserializer = json_deserializer
        self._json_serializer = json_serializer
        # NOTE: _add_declare_for_yql_stmt_vars is temporary and is soon to be removed.
        # no need in declare in yql statement here since ydb 24-1
        self._add_declare_for_yql_stmt_vars = _add_declare_for_yql_stmt_vars

    def _describe_table(self, connection, table_name, schema=None) -> ydb.TableDescription:
        """Describe ``table_name`` through the raw DBAPI connection.

        :param table_name: table name string, or an object exposing ``.name``.
        :raises ydb_dbapi.NotSupportedError: if ``schema`` is not None.
        :raises NoSuchTableError: if the describe call raises ``ydb_dbapi.DatabaseError``.
        """
        if schema is not None:
            raise ydb_dbapi.NotSupportedError("unsupported on non empty schema")

        qt = table_name if isinstance(table_name, str) else table_name.name
        raw_conn = connection.connection
        try:
            return raw_conn.describe(qt)
        except ydb_dbapi.DatabaseError as e:
            # NOTE(review): every DatabaseError (not only "not found") is reported as a
            # missing table here -- confirm this is intended.
            raise NoSuchTableError(qt) from e

    def get_view_names(self, connection, schema=None, **kw: Any):
        """Views are not reflected; always return an empty list."""
        return []

    @reflection.cache
    def get_columns(self, connection, table_name, schema=None, **kw):
        """Reflect the columns of ``table_name`` as SQLAlchemy column dicts."""
        table = self._describe_table(connection, table_name, schema)
        as_compatible = []
        for column in table.columns:
            col_type, nullable = _get_column_info(column.type)
            as_compatible.append(
                {
                    "name": column.name,
                    "type": col_type,
                    "nullable": nullable,
                    "default": None,
                }
            )
        return as_compatible

    @reflection.cache
    def get_table_names(self, connection, schema=None, **kw):
        """Return table names known to the raw DBAPI connection.

        :raises ydb_dbapi.NotSupportedError: if a non-empty ``schema`` is given.
        """
        if schema:
            raise ydb_dbapi.NotSupportedError("unsupported on non empty schema")

        raw_conn = connection.connection
        return raw_conn.get_table_names()

    @reflection.cache
    def has_table(self, connection, table_name, schema=None, **kwargs):
        """Return True if ``table_name`` can be described, False on NoSuchTableError."""
        try:
            self._describe_table(connection, table_name, schema)
            return True
        except NoSuchTableError:
            return False

    @reflection.cache
    def get_pk_constraint(self, connection, table_name, schema=None, **kwargs):
        """Return the (unnamed) primary key constraint of ``table_name``."""
        table = self._describe_table(connection, table_name, schema)
        return {"constrained_columns": table.primary_key, "name": None}

    @reflection.cache
    def get_foreign_keys(self, connection, table_name, schema=None, **kwargs):
        """Foreign keys are unsupported; always return an empty list."""
        return []

    @reflection.cache
    def get_indexes(self, connection, table_name, schema=None, **kwargs):
        """Reflect secondary indexes of ``table_name``.

        On SQLAlchemy 2.x ``ReflectedIndex`` is a ``TypedDict`` (calling it builds a
        plain dict), so one dict literal serves both 1.4 and 2.x.
        """
        table = self._describe_table(connection, table_name, schema)
        return [
            {
                "name": index.name,
                "column_names": index.index_columns,
                "unique": False,
                "dialect_options": {
                    "ydb_async": False,  # TODO After https://github.com/ydb-platform/ydb-python-sdk/issues/351
                    "ydb_cover": [],  # TODO After https://github.com/ydb-platform/ydb-python-sdk/issues/409
                },
            }
            for index in table.indexes
        ]

    def set_isolation_level(self, dbapi_connection: ydb_dbapi.Connection, level: str) -> None:
        """Delegate isolation level changes to the DBAPI connection."""
        dbapi_connection.set_isolation_level(level)

    def get_default_isolation_level(self, dbapi_conn: ydb_dbapi.Connection) -> str:
        """Connections default to ``ydb_dbapi.IsolationLevel.AUTOCOMMIT``."""
        return ydb_dbapi.IsolationLevel.AUTOCOMMIT

    def get_isolation_level(self, dbapi_connection: ydb_dbapi.Connection) -> str:
        """Return the isolation level reported by the DBAPI connection."""
        return dbapi_connection.get_isolation_level()

    def set_ydb_request_settings(
        self,
        dbapi_connection: ydb_dbapi.Connection,
        value: ydb.BaseRequestSettings,
    ) -> None:
        """Install ``value`` as the request settings of ``dbapi_connection``."""
        dbapi_connection.set_ydb_request_settings(value)

    def reset_ydb_request_settings(self, dbapi_connection: ydb_dbapi.Connection) -> None:
        """Replace the connection's request settings with a fresh ``BaseRequestSettings()``."""
        self.set_ydb_request_settings(dbapi_connection, ydb.BaseRequestSettings())

    def get_ydb_request_settings(self, dbapi_connection: ydb_dbapi.Connection) -> ydb.BaseRequestSettings:
        """Return the request settings currently held by ``dbapi_connection``."""
        return dbapi_connection.get_ydb_request_settings()

    def create_connect_args(self, url):
        """Build DBAPI connect arguments, ensuring ``database`` starts with '/'."""
        args, kwargs = super().create_connect_args(url)
        # YDB database name should start with '/'
        if "database" in kwargs:
            if not kwargs["database"].startswith("/"):
                kwargs["database"] = "/" + kwargs["database"]

        return [args, kwargs]

    def connect(self, *cargs, **cparams):
        """Open a DBAPI connection with the arguments from create_connect_args."""
        return self.dbapi.connect(*cargs, **cparams)

    def do_begin(self, dbapi_connection: ydb_dbapi.Connection) -> None:
        """Begin a transaction on the DBAPI connection."""
        dbapi_connection.begin()

    def do_rollback(self, dbapi_connection: ydb_dbapi.Connection) -> None:
        """Roll back the DBAPI connection's transaction."""
        dbapi_connection.rollback()

    def do_commit(self, dbapi_connection: ydb_dbapi.Connection) -> None:
        """Commit the DBAPI connection's transaction."""
        dbapi_connection.commit()

    def _handle_column_name(self, variable):
        """Quote ``variable`` with backticks for use in a YQL ``$`name``` variable."""
        return "`" + variable + "`"

    def _format_variables(
        self,
        statement: str,
        parameters: Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]],
        execute_many: bool,
    ) -> Tuple[str, Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]]]:
        """Rewrite named placeholders into YQL variables and ``$``-prefix parameter keys.

        ``statement`` is ``%``-formatted with a mapping ``name -> "$`name`"`` built from
        every parameter name seen (across all rows when ``execute_many``), then any
        remaining ``%%`` is collapsed to ``%``.

        :returns: ``(statement, parameters)``; parameters is None when none were given.
        """
        formatted_statement = statement
        formatted_parameters = None

        if parameters:
            if execute_many:
                parameters_sequence: Sequence[Mapping[str, Any]] = parameters
                variable_names = set()
                formatted_parameters = []
                for parameter_map in parameters_sequence:
                    variable_names.update(parameter_map.keys())
                    formatted_parameters.append({f"${k}": v for k, v in parameter_map.items()})
            else:
                variable_names = set(parameters.keys())
                formatted_parameters = {f"${k}": v for k, v in parameters.items()}

            formatted_variable_names = {
                variable_name: f"${self._handle_column_name(variable_name)}" for variable_name in variable_names
            }
            formatted_statement = formatted_statement % formatted_variable_names

        formatted_statement = formatted_statement.replace("%%", "%")
        return formatted_statement, formatted_parameters

    def _add_declare_for_yql_stmt_vars_impl(self, statement, parameters_types):
        """Prepend one ``DECLARE $`name` as <type>;`` line per typed parameter."""
        declarations = "\n".join(
            [
                f"DECLARE $`{param_name[1:] if param_name.startswith('$') else param_name}` as {str(param_type)};"
                for param_name, param_type in parameters_types.items()
            ]
        )
        return f"{declarations}\n{statement}"

    def __merge_parameters_values_and_types(
        self, values: Mapping[str, Any], types: Mapping[str, Any], execute_many: bool
    ) -> Sequence[Mapping[str, ydb.TypedValue]]:
        """Wrap each value whose key has a known type in ``ydb.TypedValue``.

        A single mapping is treated as one row. Returns the list of rows when
        ``execute_many`` is true, otherwise the single (first) row.
        """
        if isinstance(values, collections.abc.Mapping):
            values = [values]

        result_list = []
        for value_map in values:
            result = {}
            for key in value_map.keys():
                if key in types:
                    result[key] = ydb.TypedValue(value_map[key], types[key])
                else:
                    result[key] = value_map[key]
            result_list.append(result)
        return result_list if execute_many else result_list[0]

    def _prepare_ydb_query(
        self,
        statement: str,
        context: Optional[DefaultExecutionContext] = None,
        parameters: Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]] = None,
        execute_many: bool = False,
    ) -> Tuple[str, Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]]]:
        """Turn a statement and its parameters into a YQL query for the cursor.

        For non-DDL statements that come from a compiled construct, the bind types
        reported by the compiler are attached to the values as ``ydb.TypedValue``
        and, when ``_add_declare_for_yql_stmt_vars`` is set, ``DECLARE`` lines are
        prepended. Statements without a compiled construct (``context`` is None, or
        ``context.compiled`` is None as for driver-level SQL) have no bind-type
        information and are only formatted, instead of failing with
        ``AttributeError`` on ``None.get_bind_types``.

        :returns: ``(statement, parameters)`` ready for ``execute``/``executemany``.
        """
        is_ddl = context.isddl if context is not None else False
        # Driver-level SQL has an execution context but no compiled statement.
        compiled = getattr(context, "compiled", None)

        if not is_ddl and parameters and compiled is not None:
            parameters_types = compiled.get_bind_types(parameters)
            if parameters_types != {}:
                parameters = self.__merge_parameters_values_and_types(parameters, parameters_types, execute_many)
            statement, parameters = self._format_variables(statement, parameters, execute_many)
            if self._add_declare_for_yql_stmt_vars:
                statement = self._add_declare_for_yql_stmt_vars_impl(statement, parameters_types)
            return statement, parameters

        statement, parameters = self._format_variables(statement, parameters, execute_many)
        return statement, parameters

    def do_ping(self, dbapi_connection: ydb_dbapi.Connection) -> bool:
        """Run the dialect's ``SELECT 1`` probe; always closes the cursor."""
        cursor = dbapi_connection.cursor()
        statement, _ = self._prepare_ydb_query(self._dialect_specific_select_one)
        try:
            cursor.execute(statement)
        finally:
            cursor.close()
        return True

    def do_executemany(
        self,
        cursor: ydb_dbapi.Cursor,
        statement: str,
        parameters: Optional[Sequence[Mapping[str, Any]]],
        context: Optional[DefaultExecutionContext] = None,
    ) -> None:
        """Prepare the statement for YQL and run it once per parameter row."""
        operation, parameters = self._prepare_ydb_query(statement, context, parameters, execute_many=True)
        cursor.executemany(operation, parameters)

    def do_execute(
        self,
        cursor: ydb_dbapi.Cursor,
        statement: str,
        parameters: Optional[Mapping[str, Any]] = None,
        context: Optional[DefaultExecutionContext] = None,
    ) -> None:
        """Prepare the statement for YQL; DDL goes through ``execute_scheme``."""
        operation, parameters = self._prepare_ydb_query(statement, context, parameters, execute_many=False)

        is_ddl = context.isddl if context is not None else False
        if is_ddl:
            cursor.execute_scheme(operation, parameters)
        else:
            cursor.execute(operation, parameters)
class AsyncYDBDialect(YDBDialect):
    """Asyncio variant of :class:`YDBDialect`.

    Connections come from the DBAPI module's ``async_connect`` coroutine, are
    awaited through ``util.await_only`` and then wrapped in
    :class:`AdaptedAsyncConnection` (presumably to present the synchronous DBAPI
    interface SQLAlchemy expects -- see that adapter).
    """

    driver = "ydb_async"
    is_async = True
    supports_statement_cache = True

    def connect(self, *cargs, **cparams):
        """Open an async YDB connection and return it wrapped in the adapter."""
        pending_connection = self.dbapi.async_connect(*cargs, **cparams)
        async_connection = util.await_only(pending_connection)
        return AdaptedAsyncConnection(async_connection)