From 9a91c04eb8aef22127931b73b04fc195994ed8e2 Mon Sep 17 00:00:00 2001 From: Robert Sznyida Date: Tue, 15 Jul 2025 08:52:36 +0200 Subject: [PATCH] include table name in DEFAULT-constraint names in MSSQL --- pysqlsync/dialect/mssql/object_types.py | 2 +- pysqlsync/dialect/mysql/discovery.py | 1 + pysqlsync/dialect/oracle/discovery.py | 1 + pysqlsync/dialect/postgresql/discovery.py | 1 + pysqlsync/formation/discovery.py | 2 + pysqlsync/formation/object_types.py | 2 + pysqlsync/formation/py_to_sql.py | 29 ++++-- tests/test_converter.py | 111 +++++++++++++++++----- tests/test_generator.py | 20 ++-- 9 files changed, 124 insertions(+), 45 deletions(-) diff --git a/pysqlsync/dialect/mssql/object_types.py b/pysqlsync/dialect/mssql/object_types.py index 45f3aff..bb6e14c 100644 --- a/pysqlsync/dialect/mssql/object_types.py +++ b/pysqlsync/dialect/mssql/object_types.py @@ -17,7 +17,7 @@ class MSSQLColumn(Column): def default_constraint_name(self) -> LocalId: "The name of the constraint for DEFAULT." - return LocalId(f"df_{self.name.local_id}") + return LocalId(f"df_{self.table_name.local_id}_{self.name.local_id}") @property def data_spec(self) -> str: diff --git a/pysqlsync/dialect/mysql/discovery.py b/pysqlsync/dialect/mysql/discovery.py index 44dca53..f1befba 100644 --- a/pysqlsync/dialect/mysql/discovery.py +++ b/pysqlsync/dialect/mysql/discovery.py @@ -90,6 +90,7 @@ async def get_columns(self, table_id: SupportsQualifiedId) -> list[Column]: for col in column_meta: columns.append( self.factory.column_class( + table_id, LocalId(col.column_name), self.discovery.sql_data_type_from_spec( type_name=col.data_type, diff --git a/pysqlsync/dialect/oracle/discovery.py b/pysqlsync/dialect/oracle/discovery.py index 3a7b8a4..53cef45 100644 --- a/pysqlsync/dialect/oracle/discovery.py +++ b/pysqlsync/dialect/oracle/discovery.py @@ -152,6 +152,7 @@ async def get_table(self, table_id: SupportsQualifiedId) -> Table: identity = bool(col.is_identity) columns.append( self.factory.column_class( + table_id, LocalId(col.column_name), data_type, nullable, diff --git a/pysqlsync/dialect/postgresql/discovery.py b/pysqlsync/dialect/postgresql/discovery.py index 173ea12..d5bbf37 100644 --- a/pysqlsync/dialect/postgresql/discovery.py +++ b/pysqlsync/dialect/postgresql/discovery.py @@ -277,6 +277,7 @@ async def get_table(self, table_id: SupportsQualifiedId) -> Table: data_type = SqlArrayType(data_type) columns.append( self.factory.column_class( + table_id, LocalId(col.column_name), data_type, bool(col.is_nullable), diff --git a/pysqlsync/formation/discovery.py b/pysqlsync/formation/discovery.py index a4b1024..f4d1fb2 100644 --- a/pysqlsync/formation/discovery.py +++ b/pysqlsync/formation/discovery.py @@ -183,6 +183,7 @@ async def _get_columns_limited(self, table_id: SupportsQualifiedId) -> list[Colu column_name, data_type, nullable = col columns.append( self.factory.column_class( + table_id, LocalId(column_name), self.discovery.sql_data_type_from_spec(type_name=data_type), bool(nullable), @@ -211,6 +212,7 @@ async def _get_columns_full(self, table_id: SupportsQualifiedId) -> list[Column] for col in column_meta: columns.append( self.factory.column_class( + table_id, LocalId(col.column_name), self.discovery.sql_data_type_from_spec( type_name=col.data_type, diff --git a/pysqlsync/formation/object_types.py b/pysqlsync/formation/object_types.py index 429a648..eedffdd 100644 --- a/pysqlsync/formation/object_types.py +++ b/pysqlsync/formation/object_types.py @@ -184,6 +184,7 @@ class Column(DatabaseObject): """ A column in a database table. + :param table_name: The name of the table that the column belongs to. :param name: The name of the column within its host table. :param data_type: The SQL data type of the column. :param nullable: True if the column can take the value NULL. @@ -192,6 +193,7 @@ class Column(DatabaseObject): :param description: The textual description of the column. """ + table_name: SupportsQualifiedId name: LocalId data_type: SqlDataType nullable: bool diff --git a/pysqlsync/formation/py_to_sql.py b/pysqlsync/formation/py_to_sql.py index 26a4e26..a0bcc31 100644 --- a/pysqlsync/formation/py_to_sql.py +++ b/pysqlsync/formation/py_to_sql.py @@ -662,7 +662,11 @@ def member_to_sql_data_type(self, typ: TypeLike, cls: type) -> SqlDataType: raise TypeError(f"unsupported data type: {typ}") def member_to_column( - self, field: DataclassField, cls: type[DataclassInstance], doc: Docstring + self, + table_name: SupportsQualifiedId, + field: DataclassField, + cls: type[DataclassInstance], + doc: Docstring, ) -> Column: "Converts a data-class field into a SQL table column." @@ -724,6 +728,7 @@ def member_to_column( ) return self.options.factory.column_class( + table_name=table_name, name=LocalId(field.name), data_type=data_type, nullable=props.nullable, @@ -739,10 +744,11 @@ def dataclass_to_table(self, cls: type[DataclassInstance]) -> Table: raise TypeError(f"expected: dataclass type; got: {cls}") doc = parse_type(cls) + id = self.create_qualified_id(cls.__module__, cls.__name__) try: columns = [ - self.member_to_column(field, cls, doc) + self.member_to_column(id, field, cls, doc) for field in dataclass_fields(cls) if self._get_relationship(field.type) is None ] @@ -809,7 +815,7 @@ def dataclass_to_table(self, cls: type[DataclassInstance]) -> Table: ) return self.options.factory.table_class( - name=self.create_qualified_id(cls.__module__, cls.__name__), + name=id, columns=columns, primary_key=(LocalId(dataclass_primary_key_name(cls)),), constraints=constraints or None, @@ -932,6 +938,7 @@ def _enum_table(self, enum_type: type[enum.Enum]) -> Table: id = self.create_qualified_id(enum_type.__module__, enum_table_name) columns = [ self.options.factory.column_class( + id, LocalId("id"), self._enumeration_key_type(), False, @@ -953,6 +960,7 @@ def _enum_table(self, enum_type: type[enum.Enum]) -> Table: if unadorned_member_type is int or unadorned_member_type is str: columns.append( self.options.factory.column_class( + id, LocalId("value"), self.member_to_sql_data_type(ENUM_LABEL_TYPE, type(None)), False, @@ -967,7 +975,7 @@ def _enum_table(self, enum_type: type[enum.Enum]) -> Table: elif is_dataclass_type(unadorned_member_type): columns.extend( self.member_to_column( - field, enum_member_type, parse_type(unadorned_member_type) + id, field, enum_member_type, parse_type(unadorned_member_type) ) for field in dataclass_fields(unadorned_member_type) ) @@ -1148,19 +1156,23 @@ def dataclasses_to_catalog( ) table_defs = tables.setdefault(entity.__module__, []) + table_id = self.create_qualified_id( + entity.__module__, + table_name, + ) + table_defs.append( self.options.factory.table_class( - self.create_qualified_id( - entity.__module__, - table_name, - ), + table_id, [ self.options.factory.column_class( + table_id, LocalId("uuid"), self.member_to_sql_data_type(uuid.UUID, entity), False, ), self.options.factory.column_class( + table_id, LocalId(column_left_name), self.member_to_sql_data_type( dataclass_primary_key_type(entity), entity @@ -1168,6 +1180,7 @@ def dataclasses_to_catalog( False, ), self.options.factory.column_class( + table_id, LocalId(column_right_name), primary_right_type, False, diff --git a/tests/test_converter.py b/tests/test_converter.py index f0d5a5d..c4707a9 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -47,9 +47,13 @@ def test_primary_key(self) -> None: self.assertListEqual( list(table_def.columns.values()), [ - Column(LocalId("id"), SqlIntegerType(8), False), - Column(LocalId("city"), SqlVariableCharacterType(), False), - Column(LocalId("state"), SqlVariableCharacterType(), True), + Column(table_def.name, LocalId("id"), SqlIntegerType(8), False), + Column( + table_def.name, LocalId("city"), SqlVariableCharacterType(), False + ), + Column( + table_def.name, LocalId("state"), SqlVariableCharacterType(), True + ), ], ) @@ -58,32 +62,37 @@ def test_default(self) -> None: self.assertListEqual( list(table_def.columns.values()), [ - Column(LocalId("id"), SqlIntegerType(8), False), + Column(table_def.name, LocalId("id"), SqlIntegerType(8), False), Column( + table_def.name, LocalId("integer_8"), SqlIntegerType(2), False, default="127", ), Column( + table_def.name, LocalId("integer_16"), SqlIntegerType(2), False, default="32767", ), Column( + table_def.name, LocalId("integer_32"), SqlIntegerType(4), False, default="2147483647", ), Column( + table_def.name, LocalId("integer_64"), SqlIntegerType(8), False, default="0", ), Column( + table_def.name, LocalId("integer"), SqlIntegerType(8), False, @@ -97,8 +106,19 @@ def test_identity(self) -> None: self.assertListEqual( list(table_def.columns.values()), [ - Column(LocalId("id"), SqlIntegerType(8), False, identity=True), - Column(LocalId("unique"), SqlVariableCharacterType(64), False), + Column( + table_def.name, + LocalId("id"), + SqlIntegerType(8), + False, + identity=True, + ), + Column( + table_def.name, + LocalId("unique"), + SqlVariableCharacterType(64), + False, + ), ], ) self.assertListEqual( @@ -113,14 +133,16 @@ def test_foreign_key(self) -> None: self.assertListEqual( list(table_def.columns.values()), [ - Column(LocalId("id"), SqlIntegerType(8), False), + Column(table_def.name, LocalId("id"), SqlIntegerType(8), False), Column( + table_def.name, LocalId("name"), SqlVariableCharacterType(), False, description="The person's full name.", ), Column( + table_def.name, LocalId("address"), SqlIntegerType(8), False, @@ -134,9 +156,11 @@ def test_recursive_table(self) -> None: self.assertListEqual( list(table_def.columns.values()), [ - Column(LocalId("id"), SqlUuidType(), False), - Column(LocalId("name"), SqlVariableCharacterType(), False), - Column(LocalId("reports_to"), SqlUuidType(), False), + Column(table_def.name, LocalId("id"), SqlUuidType(), False), + Column( + table_def.name, LocalId("name"), SqlVariableCharacterType(), False + ), + Column(table_def.name, LocalId("reports_to"), SqlUuidType(), False), ], ) @@ -148,13 +172,15 @@ def test_enum_type(self) -> None: self.assertListEqual( list(table_def.columns.values()), [ - Column(LocalId("id"), SqlIntegerType(8), False), + Column(table_def.name, LocalId("id"), SqlIntegerType(8), False), Column( + table_def.name, LocalId("state"), SqlUserDefinedType(QualifiedId(None, "WorkflowState")), False, ), Column( + table_def.name, LocalId("optional_state"), SqlUserDefinedType(QualifiedId(None, "WorkflowState")), True, @@ -173,9 +199,11 @@ def test_extensible_enum(self) -> None: self.assertListEqual( list(table_def.columns.values()), [ - Column(LocalId("id"), SqlIntegerType(8), False), - Column(LocalId("state"), SqlIntegerType(4), False), - Column(LocalId("optional_state"), SqlIntegerType(4), True), + Column(table_def.name, LocalId("id"), SqlIntegerType(8), False), + Column(table_def.name, LocalId("state"), SqlIntegerType(4), False), + Column( + table_def.name, LocalId("optional_state"), SqlIntegerType(4), True + ), ], ) enum_name = tables.ExtensibleEnum.__name__ @@ -183,9 +211,18 @@ def test_extensible_enum(self) -> None: self.assertListEqual( list(enum_def.columns.values()), [ - Column(LocalId("id"), SqlIntegerType(4), False, identity=True), Column( - LocalId("value"), SqlVariableCharacterType(ENUM_NAME_LENGTH), False + enum_def.name, + LocalId("id"), + SqlIntegerType(4), + False, + identity=True, + ), + Column( + enum_def.name, + LocalId("value"), + SqlVariableCharacterType(ENUM_NAME_LENGTH), + False, ), ], ) @@ -222,9 +259,11 @@ def test_enum_relation(self) -> None: self.assertListEqual( list(table_def.columns.values()), [ - Column(LocalId("id"), SqlIntegerType(8), False), - Column(LocalId("state"), SqlIntegerType(4), False), - Column(LocalId("optional_state"), SqlIntegerType(4), True), + Column(table_def.name, LocalId("id"), SqlIntegerType(8), False), + Column(table_def.name, LocalId("state"), SqlIntegerType(4), False), + Column( + table_def.name, LocalId("optional_state"), SqlIntegerType(4), True + ), ], ) enum_name = tables.WorkflowState.__name__ @@ -232,9 +271,18 @@ def test_enum_relation(self) -> None: self.assertListEqual( list(enum_def.columns.values()), [ - Column(LocalId("id"), SqlIntegerType(4), False, identity=True), Column( - LocalId("value"), SqlVariableCharacterType(ENUM_NAME_LENGTH), False + enum_def.name, + LocalId("id"), + SqlIntegerType(4), + False, + identity=True, + ), + Column( + enum_def.name, + LocalId("value"), + SqlVariableCharacterType(ENUM_NAME_LENGTH), + False, ), ], ) @@ -271,9 +319,11 @@ def test_dataclass_enum_relation(self) -> None: self.assertListEqual( list(table_def.columns.values()), [ - Column(LocalId("id"), SqlIntegerType(8), False), - Column(LocalId("country"), SqlIntegerType(4), False), - Column(LocalId("optional_country"), SqlIntegerType(4), True), + Column(table_def.name, LocalId("id"), SqlIntegerType(8), False), + Column(table_def.name, LocalId("country"), SqlIntegerType(4), False), + Column( + table_def.name, LocalId("optional_country"), SqlIntegerType(4), True + ), ], ) enum_name = country.CountryEnum.__name__ @@ -282,17 +332,20 @@ def test_dataclass_enum_relation(self) -> None: list(enum_def.columns.values()), [ Column( + enum_def.name, LocalId("id"), SqlIntegerType(4), False, identity=True, ), Column( + enum_def.name, LocalId("iso_code"), SqlVariableCharacterType(), False, ), Column( + enum_def.name, LocalId("name"), SqlVariableCharacterType(), False, @@ -326,23 +379,27 @@ def test_literal_type(self) -> None: self.assertListEqual( list(table_def.columns.values()), [ - Column(LocalId("id"), SqlIntegerType(8), False), + Column(table_def.name, LocalId("id"), SqlIntegerType(8), False), Column( + table_def.name, LocalId("single"), SqlFixedCharacterType(limit=5), False, ), Column( + table_def.name, LocalId("multiple"), SqlFixedCharacterType(limit=4), False, ), Column( + table_def.name, LocalId("union"), SqlVariableCharacterType(limit=255), False, ), Column( + table_def.name, LocalId("unbounded"), SqlVariableCharacterType(), False, @@ -374,8 +431,9 @@ def test_struct_reference(self) -> None: self.assertListEqual( list(table_def.columns.values()), [ - Column(LocalId("id"), SqlIntegerType(8), False), + Column(table_def.name, LocalId("id"), SqlIntegerType(8), False), Column( + table_def.name, LocalId("coords"), SqlUserDefinedType( QualifiedId("sample", tables.Coordinates.__name__) @@ -457,6 +515,7 @@ def test_mutate(self) -> None: user_table.columns["homepage_url"].nullable = False user_table.columns.add( Column( + user_table.name, LocalId("social_url"), SqlVariableCharacterType(), False, diff --git a/tests/test_generator.py b/tests/test_generator.py index c1ea0ca..f52e14f 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -109,11 +109,11 @@ def test_create_default_boolean_table(self) -> None: tables.DefaultBooleanTable, 'CREATE TABLE "DefaultBooleanTable" (\n' '"id" bigint NOT NULL,\n' - '"boolean_false" bit NOT NULL CONSTRAINT "df_boolean_false" DEFAULT 0,\n' - '"boolean_true" bit NOT NULL CONSTRAINT "df_boolean_true" DEFAULT 1,\n' + '"boolean_false" bit NOT NULL CONSTRAINT "df_DefaultBooleanTable_boolean_false" DEFAULT 0,\n' + '"boolean_true" bit NOT NULL CONSTRAINT "df_DefaultBooleanTable_boolean_true" DEFAULT 1,\n' '"nullable_boolean_null" bit,\n' - '"nullable_boolean_false" bit CONSTRAINT "df_nullable_boolean_false" DEFAULT 0,\n' - '"nullable_boolean_true" bit CONSTRAINT "df_nullable_boolean_true" DEFAULT 1,\n' + '"nullable_boolean_false" bit CONSTRAINT "df_DefaultBooleanTable_nullable_boolean_false" DEFAULT 0,\n' + '"nullable_boolean_true" bit CONSTRAINT "df_DefaultBooleanTable_nullable_boolean_true" DEFAULT 1,\n' 'CONSTRAINT "pk_DefaultBooleanTable" PRIMARY KEY ("id")\n' ");", ) @@ -172,11 +172,11 @@ def test_create_default_numeric_table(self) -> None: tables.DefaultNumericTable, 'CREATE TABLE "DefaultNumericTable" (\n' '"id" bigint NOT NULL,\n' - '"integer_8" smallint NOT NULL CONSTRAINT "df_integer_8" DEFAULT 127,\n' - '"integer_16" smallint NOT NULL CONSTRAINT "df_integer_16" DEFAULT 32767,\n' - '"integer_32" integer NOT NULL CONSTRAINT "df_integer_32" DEFAULT 2147483647,\n' - '"integer_64" bigint NOT NULL CONSTRAINT "df_integer_64" DEFAULT 0,\n' - '"integer" bigint NOT NULL CONSTRAINT "df_integer" DEFAULT 23,\n' + '"integer_8" smallint NOT NULL CONSTRAINT "df_DefaultNumericTable_integer_8" DEFAULT 127,\n' + '"integer_16" smallint NOT NULL CONSTRAINT "df_DefaultNumericTable_integer_16" DEFAULT 32767,\n' + '"integer_32" integer NOT NULL CONSTRAINT "df_DefaultNumericTable_integer_32" DEFAULT 2147483647,\n' + '"integer_64" bigint NOT NULL CONSTRAINT "df_DefaultNumericTable_integer_64" DEFAULT 0,\n' + '"integer" bigint NOT NULL CONSTRAINT "df_DefaultNumericTable_integer" DEFAULT 23,\n' 'CONSTRAINT "pk_DefaultNumericTable" PRIMARY KEY ("id")\n' ");", ) @@ -328,7 +328,7 @@ def test_create_default_datetime_table(self) -> None: tables.DefaultDateTimeTable, 'CREATE TABLE "DefaultDateTimeTable" (\n' '"id" bigint NOT NULL,\n' - """"iso_date_time" datetime2 NOT NULL CONSTRAINT "df_iso_date_time" DEFAULT '1989-10-24 23:59:59',\n""" + """"iso_date_time" datetime2 NOT NULL CONSTRAINT "df_DefaultDateTimeTable_iso_date_time" DEFAULT '1989-10-24 23:59:59',\n""" 'CONSTRAINT "pk_DefaultDateTimeTable" PRIMARY KEY ("id")\n' ");", )