Skip to content

Commit ecd2723

Browse files
committed
Support YDB native UUID type
1 parent 8729c2e commit ecd2723

File tree

4 files changed

+59
-0
lines changed

4 files changed

+59
-0
lines changed

test/test_core.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import datetime
3+
import uuid
34
from decimal import Decimal
45
from typing import NamedTuple
56

@@ -11,6 +12,7 @@
1112
from ydb._grpc.v4.protos import ydb_common_pb2
1213

1314
from ydb_sqlalchemy import IsolationLevel, dbapi
15+
1416
from ydb_sqlalchemy import sqlalchemy as ydb_sa
1517
from ydb_sqlalchemy.sqlalchemy import types
1618

@@ -238,6 +240,13 @@ def define_tables(cls, metadata: sa.MetaData):
238240
Column("date", sa.Date),
239241
# Column("interval", sa.Interval),
240242
)
243+
Table(
244+
"test_uuid_types",
245+
metadata,
246+
Column("id", Integer, primary_key=True),
247+
Column("uuid_native", sa.UUID),
248+
Column("uuid_str", sa.Uuid),
249+
)
241250

242251
def test_primitive_types(self, connection):
243252
table = self.tables.test_primitive_types
@@ -310,6 +319,22 @@ def test_datetime_types_timezone(self, connection):
310319
today,
311320
)
312321

322+
def test_uuid_types(self, connection):
323+
table = self.tables.test_uuid_types
324+
uuid_value = uuid.uuid4()
325+
326+
statement = sa.insert(table).values(id=1, uuid_native=uuid_value, uuid_str=uuid_value)
327+
connection.execute(statement)
328+
row = connection.execute(sa.select(table).where(table.c.id == 1)).fetchone()
329+
assert row == (1, uuid_value, uuid_value)
330+
331+
uuid_value_str = str(uuid_value)
332+
333+
statement = sa.insert(table).values(id=2, uuid_native=uuid_value_str, uuid_str=uuid_value)
334+
connection.execute(statement)
335+
row = connection.execute(sa.select(table).where(table.c.id == 2)).fetchone()
336+
assert row == (2, uuid_value, uuid_value)
337+
313338

314339
class TestWithClause(TablesTest):
315340
__backend__ = True

ydb_sqlalchemy/sqlalchemy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def upsert(table):
6666
ydb.PrimitiveType.Interval: sa.INTEGER,
6767
ydb.PrimitiveType.Bool: sa.BOOLEAN,
6868
ydb.PrimitiveType.DyNumber: sa.TEXT,
69+
ydb.PrimitiveType.UUID: sa.UUID,
6970
}
7071

7172

@@ -140,6 +141,7 @@ class YqlDialect(StrCompileDialect):
140141
sa.types.DateTime: types.YqlTimestamp, # Because YDB's DateTime doesn't store microseconds
141142
sa.types.DATETIME: types.YqlDateTime,
142143
sa.types.TIMESTAMP: types.YqlTimestamp,
144+
sa.types.UUID: types.YqlUUID,
143145
}
144146

145147
connection_characteristics = util.immutabledict(

ydb_sqlalchemy/sqlalchemy/compiler/sa20.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,29 @@
1111
BaseYqlIdentifierPreparer,
1212
BaseYqlTypeCompiler,
1313
)
14+
from .. import types
1415
from typing import Union
1516

1617

1718
class YqlTypeCompiler(BaseYqlTypeCompiler):
1819
def visit_uuid(self, type_: sa.Uuid, **kw):
1920
return "UTF8"
2021

22+
def visit_UUID(self, type_: Union[sa.UUID, types.YqlUUID], **kw):
23+
return "UUID"
24+
2125
def get_ydb_type(
2226
self, type_: sa.types.TypeEngine, is_optional: bool
2327
) -> Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]:
2428
if isinstance(type_, sa.TypeDecorator):
2529
type_ = type_.impl
2630

31+
if isinstance(type_, sa.UUID):
32+
ydb_type = ydb.PrimitiveType.UUID
33+
if is_optional:
34+
return ydb.OptionalType(ydb_type)
35+
return ydb_type
36+
2737
if isinstance(type_, sa.Uuid):
2838
ydb_type = ydb.PrimitiveType.Utf8
2939
if is_optional:

ydb_sqlalchemy/sqlalchemy/types.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,28 @@
1414
from .json import YqlJSON # noqa: F401
1515

1616

17+
class YqlUUID(types.UUID):
18+
__visit_name__ = "UUID"
19+
20+
def bind_processor(self, dialect):
21+
def process(value):
22+
if value is None:
23+
return None
24+
if isinstance(value, str):
25+
try:
26+
import uuid as uuid_module
27+
28+
value = uuid_module.UUID(value)
29+
except ValueError:
30+
raise ValueError(f"Invalid UUID string: {value}")
31+
return value
32+
33+
return process
34+
35+
def result_processor(self, dialect, coltype):
36+
return None
37+
38+
1739
class UInt64(types.Integer):
1840
__visit_name__ = "uint64"
1941

0 commit comments

Comments
 (0)