From ecd272335a017b171fee80e33088bb1bf97724c6 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Tue, 15 Jul 2025 12:03:47 +0300 Subject: [PATCH] Support YDB native UUID type --- test/test_core.py | 25 ++++++++++++++++++++++ ydb_sqlalchemy/sqlalchemy/__init__.py | 2 ++ ydb_sqlalchemy/sqlalchemy/compiler/sa20.py | 10 +++++++++ ydb_sqlalchemy/sqlalchemy/types.py | 22 +++++++++++++++++++ 4 files changed, 59 insertions(+) diff --git a/test/test_core.py b/test/test_core.py index 3f7d808..ddad3e6 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -1,5 +1,6 @@ import asyncio import datetime +import uuid from decimal import Decimal from typing import NamedTuple @@ -11,6 +12,7 @@ from ydb._grpc.v4.protos import ydb_common_pb2 from ydb_sqlalchemy import IsolationLevel, dbapi + from ydb_sqlalchemy import sqlalchemy as ydb_sa from ydb_sqlalchemy.sqlalchemy import types @@ -238,6 +240,13 @@ def define_tables(cls, metadata: sa.MetaData): Column("date", sa.Date), # Column("interval", sa.Interval), ) + Table( + "test_uuid_types", + metadata, + Column("id", Integer, primary_key=True), + Column("uuid_native", sa.UUID), + Column("uuid_str", sa.Uuid), + ) def test_primitive_types(self, connection): table = self.tables.test_primitive_types @@ -310,6 +319,22 @@ def test_datetime_types_timezone(self, connection): today, ) + def test_uuid_types(self, connection): + table = self.tables.test_uuid_types + uuid_value = uuid.uuid4() + + statement = sa.insert(table).values(id=1, uuid_native=uuid_value, uuid_str=uuid_value) + connection.execute(statement) + row = connection.execute(sa.select(table).where(table.c.id == 1)).fetchone() + assert row == (1, uuid_value, uuid_value) + + uuid_value_str = str(uuid_value) + + statement = sa.insert(table).values(id=2, uuid_native=uuid_value_str, uuid_str=uuid_value) + connection.execute(statement) + row = connection.execute(sa.select(table).where(table.c.id == 2)).fetchone() + assert row == (2, uuid_value, uuid_value) + class TestWithClause(TablesTest): __backend__ = True diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index 0f271f3..7f59dd3 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -66,6 +66,7 @@ def upsert(table): ydb.PrimitiveType.Interval: sa.INTEGER, ydb.PrimitiveType.Bool: sa.BOOLEAN, ydb.PrimitiveType.DyNumber: sa.TEXT, + ydb.PrimitiveType.UUID: sa.UUID, } @@ -140,6 +141,7 @@ class YqlDialect(StrCompileDialect): sa.types.DateTime: types.YqlTimestamp, # Because YDB's DateTime doesn't store microseconds sa.types.DATETIME: types.YqlDateTime, sa.types.TIMESTAMP: types.YqlTimestamp, + sa.types.UUID: types.YqlUUID, } connection_characteristics = util.immutabledict( diff --git a/ydb_sqlalchemy/sqlalchemy/compiler/sa20.py b/ydb_sqlalchemy/sqlalchemy/compiler/sa20.py index 702d7aa..79140e3 100644 --- a/ydb_sqlalchemy/sqlalchemy/compiler/sa20.py +++ b/ydb_sqlalchemy/sqlalchemy/compiler/sa20.py @@ -11,6 +11,7 @@ BaseYqlIdentifierPreparer, BaseYqlTypeCompiler, ) +from .. import types from typing import Union @@ -18,12 +19,21 @@ class YqlTypeCompiler(BaseYqlTypeCompiler): def visit_uuid(self, type_: sa.Uuid, **kw): return "UTF8" + def visit_UUID(self, type_: Union[sa.UUID, types.YqlUUID], **kw): + return "UUID" + def get_ydb_type( self, type_: sa.types.TypeEngine, is_optional: bool ) -> Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]: if isinstance(type_, sa.TypeDecorator): type_ = type_.impl + if isinstance(type_, sa.UUID): + ydb_type = ydb.PrimitiveType.UUID + if is_optional: + return ydb.OptionalType(ydb_type) + return ydb_type + if isinstance(type_, sa.Uuid): ydb_type = ydb.PrimitiveType.Utf8 if is_optional: diff --git a/ydb_sqlalchemy/sqlalchemy/types.py b/ydb_sqlalchemy/sqlalchemy/types.py index 261eb9f..8328bd6 100644 --- a/ydb_sqlalchemy/sqlalchemy/types.py +++ b/ydb_sqlalchemy/sqlalchemy/types.py @@ -14,6 +14,28 @@ from .json import YqlJSON # noqa: F401 +class YqlUUID(types.UUID): + __visit_name__ = "UUID" + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + if isinstance(value, str): + try: + import uuid as uuid_module + + value = uuid_module.UUID(value) + except ValueError: + raise ValueError(f"Invalid UUID string: {value}") + return value + + return process + + def result_processor(self, dialect, coltype): + return None + + class UInt64(types.Integer): __visit_name__ = "uint64"