From 40cde7994f31b49318b0730bacbd6bad8c463371 Mon Sep 17 00:00:00 2001 From: "Evgeny (Krymmy) Momotov" Date: Tue, 25 Feb 2025 12:00:06 +0300 Subject: [PATCH 01/17] [UP] inspectdb/__init__ class Inspect {+ def charenum_field} --- aerich/inspectdb/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/aerich/inspectdb/__init__.py b/aerich/inspectdb/__init__.py index 512946ee..26093f24 100644 --- a/aerich/inspectdb/__init__.py +++ b/aerich/inspectdb/__init__.py @@ -190,3 +190,8 @@ def json_field(cls, **kwargs) -> str: @classmethod def binary_field(cls, **kwargs) -> str: return cls.get_field_string("BinaryField", **kwargs) + + @classmethod + def charenum_field(cls, **kwargs) -> str: + arguments = "{enum_type}{null}" + return cls.get_field_string("CharEnumField", arguments, **kwargs) \ No newline at end of file From b51b4ea6522eaedd853c79783b8f5dfac65b1462 Mon Sep 17 00:00:00 2001 From: "Evgeny (Krymmy) Momotov" Date: Tue, 25 Feb 2025 12:04:39 +0300 Subject: [PATCH 02/17] [UP] inspectdb/__init__ class Inspect {~ def inspect +119 "class Meta ..." } --- aerich/inspectdb/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aerich/inspectdb/__init__.py b/aerich/inspectdb/__init__.py index 26093f24..ca43c1c1 100644 --- a/aerich/inspectdb/__init__.py +++ b/aerich/inspectdb/__init__.py @@ -116,6 +116,7 @@ async def inspect(self) -> str: field = self.field_map[column.data_type](**column.translate()) fields.append(" " + field) tables.append(model + "\n".join(fields)) + tables.append(" class Meta:\n table = '" + table + "'\n\n") return result + "\n\n\n".join(tables) async def get_columns(self, table: str) -> list[Column]: From 3a7cfbe70fe01252a35a509898082a00b537822c Mon Sep 17 00:00:00 2001 From: "Evgeny (Krymmy) Momotov" Date: Tue, 25 Feb 2025 13:03:56 +0300 Subject: [PATCH 03/17] [UP] inspectdb/__init__.py +class EnumDataType --- aerich/inspectdb/__init__.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/aerich/inspectdb/__init__.py b/aerich/inspectdb/__init__.py index ca43c1c1..93ca755e 100644 --- a/aerich/inspectdb/__init__.py +++ b/aerich/inspectdb/__init__.py @@ -20,6 +20,39 @@ class ColumnInfoDict(TypedDict): # TODO: use dict to replace typing.Dict when dropping support for Python3.8 FieldMapDict = Dict[str, Callable[..., str]] +class EnumDataType(BaseModel): + row_name: str + row_values: str + class_name: str | None = None + class_values: list[str] = [] + + + def get_enum_class(self) -> str: + if not self.class_name: + class_name = self.row_name + class_name = class_name.replace("_", " ") + class_name = class_name.replace("@", "") + class_name = class_name.title() + class_name = class_name.replace(" ", "") + self.class_name = class_name + "Enum" + if len(self.class_values) != len(self.row_values.split(";")): + self.row_values = self.row_values.strip("{") + self.row_values = self.row_values.strip("}") + row_values = self.row_values.split(";") + for value in row_values: + if value.isdigit(): + continue + name_value = value.strip("\"") + name_value = name_value.replace(" ", "_") + name_value = name_value.replace(",", "_") + name_value = name_value.replace("-", "_") + name_value = name_value.replace("@", "") + self.class_values.append(f' {name_value.upper()} = "{value}"') + + result = f"class {self.class_name}(str, Enum):\n" + result += "\n".join(self.class_values) + + return result class Column(BaseModel): name: str From e80e3ed531be6ddf33f38218e40745b9fd3a06f5 Mon Sep 17 00:00:00 2001 From: "Evgeny (Krymmy) Momotov" Date: Tue, 25 Feb 2025 13:22:49 +0300 Subject: [PATCH 04/17] [UP] inspectdb/postgres.py class InspectPostgres {+ def get_enums*} --- aerich/inspectdb/postgres.py | 55 ++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/aerich/inspectdb/postgres.py b/aerich/inspectdb/postgres.py index 4b6fa0a6..6c7d1bbb 100644 --- a/aerich/inspectdb/postgres.py +++ b/aerich/inspectdb/postgres.py @@ -42,6 +42,61 @@ async def get_all_tables(self) -> list[str]: ret = await self.conn.execute_query_dict(sql, [self.database, self.schema]) return list(map(lambda x: x["table_name"], ret)) + async def _get_enums(self): + sql = """WITH enum_types AS ( + SELECT n.nspname AS schema_name, + t.typname AS type_name, + t.typtype AS type_category + FROM pg_type t + JOIN pg_namespace n ON n.oid = t.typnamespace + WHERE t.typtype = 'e' + AND n.nspname NOT IN ('pg_catalog', 'information_schema') +) +SELECT ct.type_name, + CASE + WHEN ct.type_category = 'e' THEN array_to_string(array_agg(e.enumlabel ORDER BY e.enumsortorder), ';') -- ENUM значения с разделителем ; + ELSE NULL + END AS type_values +FROM enum_types ct +LEFT JOIN pg_enum e ON ct.type_category = 'e' AND e.enumtypid = (SELECT oid FROM pg_type WHERE typname = ct.type_name) +WHERE ct.schema_name = $2 + AND current_database() = $1 -- Параметр для проверки базы данных +GROUP BY ct.schema_name, ct.type_name, ct.type_category +ORDER BY ct.schema_name, ct.type_name; +""" + ret = await self.conn.execute_query_dict(sql, [self.database, self.schema]) + return ret + + async def get_enums_data_types(self) -> dict[str, EnumDataType]: + enums = {} + enum_names = set() + ret = await self._get_enums() + for row in ret: + name = row.get("type_name") + category = row.get("type_category") + if not name or not category: + continue + type_values = row.get("type_values") + if name in enums: + continue + enums[name] = (EnumDataType( + row_name=name, + row_values=type_values + ) + ) + + return enums + + async def get_enums_names(self) -> set[str]: + enum_names = set() + ret = await self._get_enums() + for row in ret: + name = row.get("type_name") + if not name: + continue + enum_names.add(name) + return enum_names + async def get_columns(self, table: str) -> list[Column]: columns = [] sql = f"""select c.column_name, From cd39edc12ef41168c8567d6ef06acf915d03692b Mon Sep 17 00:00:00 2001 From: "Evgeny (Krymmy) Momotov" Date: Tue, 25 Feb 2025 13:56:53 +0300 Subject: [PATCH 05/17] [UP] inspectdb/postgres.py class InspectPostgres {~ def field_map} --- aerich/inspectdb/postgres.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/aerich/inspectdb/postgres.py b/aerich/inspectdb/postgres.py index 6c7d1bbb..7048b00a 100644 --- a/aerich/inspectdb/postgres.py +++ b/aerich/inspectdb/postgres.py @@ -17,24 +17,46 @@ def __init__(self, conn: BasePostgresClient, tables: list[str] | None = None) -> def field_map(self) -> FieldMapDict: return { "int2": self.smallint_field, + "_int2": self.smallint_field, "int4": self.int_field, - "int8": self.bigint_field, + "_int4": self.int_field, + "int8": self.int_field, + "_int8": self.int_field, "smallint": self.smallint_field, - "bigint": self.bigint_field, + "_smallint": self.smallint_field, "varchar": self.char_field, + "_varchar": self.char_field, "text": self.text_field, + "_text": self.text_field, + "bigint": self.bigint_field, + "_bigint": self.bigint_field, "timestamptz": self.datetime_field, + "_timestamptz": self.datetime_field, "float4": self.float_field, + "_float4": self.float_field, "float8": self.float_field, + "_float8": self.float_field, "date": self.date_field, + "_date": self.date_field, "time": self.time_field, + "_time": self.time_field, + "timetz": self.time_field, + "_timetz": self.time_field, "decimal": self.decimal_field, + "_decimal": self.decimal_field, "numeric": self.decimal_field, + "_numeric": self.decimal_field, "uuid": self.uuid_field, + "_uuid": self.uuid_field, "jsonb": self.json_field, + "_jsonb": self.json_field, "bytea": self.binary_field, + "_bytea": self.binary_field, "bool": self.bool_field, + "_bool": self.bool_field, "timestamp": self.datetime_field, + "_timestamp": self.datetime_field, + "enum": self.charenum_field, } async def get_all_tables(self) -> list[str]: From 764913f4b21d0d5a77fac9ad5c7266bfe561a8a4 Mon Sep 17 00:00:00 2001 From: "Evgeny (Krymmy) Momotov" Date: Fri, 28 Feb 2025 22:24:51 +0300 Subject: [PATCH 06/17] [UP] inspectdb/__init__ ~class EnumDataType {+ def get_class_name; enum_type} --- aerich/inspectdb/__init__.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/aerich/inspectdb/__init__.py b/aerich/inspectdb/__init__.py index 93ca755e..06983384 100644 --- a/aerich/inspectdb/__init__.py +++ b/aerich/inspectdb/__init__.py @@ -20,6 +20,7 @@ class ColumnInfoDict(TypedDict): # TODO: use dict to replace typing.Dict when dropping support for Python3.8 FieldMapDict = Dict[str, Callable[..., str]] + class EnumDataType(BaseModel): row_name: str row_values: str @@ -28,13 +29,7 @@ class EnumDataType(BaseModel): def get_enum_class(self) -> str: - if not self.class_name: - class_name = self.row_name - class_name = class_name.replace("_", " ") - class_name = class_name.replace("@", "") - class_name = class_name.title() - class_name = class_name.replace(" ", "") - self.class_name = class_name + "Enum" + self.class_name = self.get_class_name() if len(self.class_values) != len(self.row_values.split(";")): self.row_values = self.row_values.strip("{") self.row_values = self.row_values.strip("}") @@ -53,6 +48,20 @@ def get_enum_class(self) -> str: result += "\n".join(self.class_values) return result + + def get_class_name(self) -> str: + if not self.class_name: + class_name = self.row_name + class_name = class_name.replace("_", " ") + class_name = class_name.replace("@", "") + class_name = class_name.title() + class_name = class_name.replace(" ", "") + self.class_name = class_name + return self.class_name + + def enum_type(self) -> dict: + return {"enum_type": f"enum_type={self.get_class_name()}, "} + class Column(BaseModel): name: str From c1ea756b5f1b8bf622b7c8c2e18bd07e464453d4 Mon Sep 17 00:00:00 2001 From: "Evgeny (Krymmy) Momotov" Date: Fri, 28 Feb 2025 22:35:18 +0300 Subject: [PATCH 07/17] [UP/FIX] inspectdb/postgres.py ~class InspectPostgres {~get_enums_data_types;get_columns} --- aerich/inspectdb/postgres.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/aerich/inspectdb/postgres.py b/aerich/inspectdb/postgres.py index 7048b00a..f73708fd 100644 --- a/aerich/inspectdb/postgres.py +++ b/aerich/inspectdb/postgres.py @@ -95,8 +95,7 @@ async def get_enums_data_types(self) -> dict[str, EnumDataType]: ret = await self._get_enums() for row in ret: name = row.get("type_name") - category = row.get("type_category") - if not name or not category: + if not name: continue type_values = row.get("type_values") if name in enums: @@ -109,6 +108,7 @@ async def get_enums_data_types(self) -> dict[str, EnumDataType]: return enums + async def get_enums_names(self) -> set[str]: enum_names = set() ret = await self._get_enums() @@ -149,8 +149,9 @@ async def get_columns(self, table: str) -> list[Column]: max_digits=row["numeric_precision"], decimal_places=row["numeric_scale"], comment=row["column_comment"], - pk=row["column_key"] == "PRIMARY KEY", - unique=False, # can't get this simply + pk=row["column_key"] == "PRIMARY KEY" or (row["column_key"] == "UNIQUE" and row["column_name"] == "id"), + extra=None, + unique=row["column_key"] == "UNIQUE" and row["column_name"] != "id", # can't get this simply index=False, # can't get this simply ) ) From 0bc2e0a440e24840dd39395dfb61b47673fa295a Mon Sep 17 00:00:00 2001 From: "Evgeny (Krymmy) Momotov" Date: Fri, 28 Feb 2025 23:02:12 +0300 Subject: [PATCH 08/17] [FIX] inspectdb/postgres.py +5| +EnumDataType --- aerich/inspectdb/postgres.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aerich/inspectdb/postgres.py b/aerich/inspectdb/postgres.py index f73708fd..4d85da46 100644 --- a/aerich/inspectdb/postgres.py +++ b/aerich/inspectdb/postgres.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from aerich.inspectdb import Column, FieldMapDict, Inspect +from aerich.inspectdb import Column, FieldMapDict, Inspect, EnumDataType if TYPE_CHECKING: from tortoise.backends.base_postgres.client import BasePostgresClient From dd40655449a1ca1dccf757d96865fcd7a4a95ead Mon Sep 17 00:00:00 2001 From: "Evgeny (Krymmy) Momotov" Date: Fri, 28 Feb 2025 23:04:12 +0300 Subject: [PATCH 09/17] [FIX] inspectdb/sqlite.py class InspectSQLite {~ field_map} --- aerich/inspectdb/sqlite.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/aerich/inspectdb/sqlite.py b/aerich/inspectdb/sqlite.py index b729c738..41a8e9d6 100644 --- a/aerich/inspectdb/sqlite.py +++ b/aerich/inspectdb/sqlite.py @@ -7,18 +7,37 @@ class InspectSQLite(Inspect): @property def field_map(self) -> FieldMapDict: return { + "INT": self.int_field, "INTEGER": self.int_field, - "INT": self.bool_field, + "INTEGER UNSIGNED": self.int_field, "SMALLINT": self.smallint_field, + "SMALLINT UNSIGNED": self.smallint_field, + "BIGINT": self.bigint_field, + "UNSIGNED BIG INT": self.bigint_field, + "INT2": self.smallint_field, + "INT8": self.bigint_field, + "TINYINT": self.smallint_field, "VARCHAR": self.char_field, "TEXT": self.text_field, "TIMESTAMP": self.datetime_field, "REAL": self.float_field, "BIGINT": self.bigint_field, "DATE": self.date_field, + "BOOL": self.bool_field, + "DATETIME": self.datetime_field, "TIME": self.time_field, "JSON": self.json_field, "BLOB": self.binary_field, + "CHARACTER": self.char_field, + "VARCHAR": self.char_field, + "VARYING CHARACTER": self.char_field, + "NCHAR": self.char_field, + "NATIVE CHARACTER": self.char_field, + "NVARCHAR": self.char_field, + "CLOB": self.text_field, + "NUMERIC": self.float_field, + "DECIMAL": self.float_field, + "BOOLEAN": self.bool_field, } async def get_columns(self, table: str) -> list[Column]: @@ -34,7 +53,7 @@ async def get_columns(self, table: str) -> list[Column]: columns.append( Column( name=row["name"], - data_type=row["type"].split("(")[0], + data_type=row["type"].split("(")[0].upper(), null=row["notnull"] == 0, default=row["dflt_value"], length=length, From 97b6472806c0dd926fb11a178b41938a6daa023f Mon Sep 17 00:00:00 2001 From: "Evgeny (Krymmy) Momotov" Date: Fri, 28 Feb 2025 23:26:19 +0300 Subject: [PATCH 10/17] [FIX] inspectdb/sqlite.py class InspectSQLite {~ field_map} --- aerich/inspectdb/sqlite.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aerich/inspectdb/sqlite.py b/aerich/inspectdb/sqlite.py index 41a8e9d6..b7f07eb4 100644 --- a/aerich/inspectdb/sqlite.py +++ b/aerich/inspectdb/sqlite.py @@ -35,8 +35,8 @@ def field_map(self) -> FieldMapDict: "NATIVE CHARACTER": self.char_field, "NVARCHAR": self.char_field, "CLOB": self.text_field, - "NUMERIC": self.float_field, - "DECIMAL": self.float_field, + "NUMERIC": self.decimal_field, + "DECIMAL": self.decimal_field, "BOOLEAN": self.bool_field, } From 5cbf42e36fa4643452d29dece3f0ec574ecf9385 Mon Sep 17 00:00:00 2001 From: "Evgeny (Krymmy) Momotov" Date: Fri, 28 Feb 2025 23:28:10 +0300 Subject: [PATCH 11/17] [UP] inspectdb/mysql.py ~class InspectMySQL {+def *enum*; ~def field_map} --- aerich/inspectdb/mysql.py | 63 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 62 insertions(+), 1 deletion(-) diff --git a/aerich/inspectdb/mysql.py b/aerich/inspectdb/mysql.py index 1034fd92..cb80778c 100644 --- a/aerich/inspectdb/mysql.py +++ b/aerich/inspectdb/mysql.py @@ -1,6 +1,6 @@ from __future__ import annotations -from aerich.inspectdb import Column, FieldMapDict, Inspect +from aerich.inspectdb import Column, FieldMapDict, Inspect, EnumDataType class InspectMySQL(Inspect): @@ -23,8 +23,69 @@ def field_map(self) -> FieldMapDict: "decimal": self.decimal_field, "json": self.json_field, "longblob": self.binary_field, + "enum": self.charenum_field, } + async def _get_enums(self): + sql = """SELECT c.TABLE_NAME, + c.COLUMN_NAME, + REPLACE( + REPLACE( + SUBSTRING_INDEX(SUBSTRING_INDEX(c.COLUMN_TYPE, '(', -1), ')', 1), + "','", "';'" + ), + "'", "" + ) AS ENUM_VALUES +FROM information_schema.COLUMNS c +LEFT JOIN information_schema.STATISTICS s + ON c.TABLE_NAME = s.TABLE_NAME + AND c.TABLE_SCHEMA = s.TABLE_SCHEMA + AND c.COLUMN_NAME = s.COLUMN_NAME +WHERE c.TABLE_SCHEMA = %s + AND c.DATA_TYPE = 'enum'; +""" + ret = await self.conn.execute_query_dict(sql, [self.database]) + return ret + + + async def get_enums_data_types(self) -> dict[str, EnumDataType]: + enums = {} + ret = await self._get_enums() + for row in ret: + table_name = row.get("TABLE_NAME") + column_name = row.get("COLUMN_NAME") + + if not table_name or not column_name: + continue + + name = f"{table_name}_{column_name}" + type_values = row.get("ENUM_VALUES") + if name in enums: + continue + enums[name] = (EnumDataType( + row_name=name, + row_values=type_values + ) + ) + + return enums + + + async def get_enums_names(self) -> set[str]: + enum_names = set() + ret = await self._get_enums() + for row in ret: + table_name = row.get("TABLE_NAME") + column_name = row.get("COLUMN_NAME") + + if not table_name or not column_name: + continue + + name = f"{table_name}_{column_name}" + enum_names.add(name) + return enum_names + + async def get_all_tables(self) -> list[str]: sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s" ret = await self.conn.execute_query_dict(sql, [self.database]) From 1ec012bf8c6c86aa5cbfe0619f0178ff984609d3 Mon Sep 17 00:00:00 2001 From: "Evgeny (Krymmy) Momotov" Date: Fri, 28 Feb 2025 23:30:03 +0300 Subject: [PATCH 12/17] [UP] inspectdb/__init__ class Inspect {~ def inspect } --- aerich/inspectdb/__init__.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/aerich/inspectdb/__init__.py b/aerich/inspectdb/__init__.py index 06983384..da2a0748 100644 --- a/aerich/inspectdb/__init__.py +++ b/aerich/inspectdb/__init__.py @@ -61,7 +61,7 @@ def get_class_name(self) -> str: def enum_type(self) -> dict: return {"enum_type": f"enum_type={self.get_class_name()}, "} - + class Column(BaseModel): name: str @@ -149,17 +149,37 @@ async def inspect(self) -> str: if not self.tables: self.tables = await self.get_all_tables() result = "from tortoise import Model, fields\n\n\n" + + enums_types = {} + if getattr(self, "get_enums_data_types", False): + enums_types = await self.get_enums_data_types() + + if enums_types: + result += "from enum import Enum\n\n" + enums = [] + for key, value in enums_types.items(): + enums.append(value.get_enum_class() ) + tables = [] for table in self.tables: columns = await self.get_columns(table) fields = [] model = self._table_template.format(table=table.title().replace("_", "")) for column in columns: - field = self.field_map[column.data_type](**column.translate()) + if column.data_type in enums_types: + field = self.field_map["enum"](**enums_types[column.data_type].enum_type(), **column.translate()) + elif f"{table}_{column.name}" in enums_types: + field = self.field_map["enum"](**enums_types[f"{table}_{column.name}"].enum_type(), **column.translate()) + else: + field = self.field_map[column.data_type](**column.translate()) + fields.append(" " + field) tables.append(model + "\n".join(fields)) tables.append(" class Meta:\n table = '" + table + "'\n\n") - return result + "\n\n\n".join(tables) + + enums.extend(tables) + + return result + "\n\n\n".join(enums) async def get_columns(self, table: str) -> list[Column]: raise NotImplementedError From dea80cc81c2c41f16cc2b6206a849794a95cac5a Mon Sep 17 00:00:00 2001 From: "Evgeny (Krymmy) Momotov" Date: Fri, 28 Feb 2025 23:44:08 +0300 Subject: [PATCH 13/17] [UP] inspectdb/__init__ class Inspect {~ def get_field_string } --- aerich/inspectdb/__init__.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/aerich/inspectdb/__init__.py b/aerich/inspectdb/__init__.py index da2a0748..c5d12eb7 100644 --- a/aerich/inspectdb/__init__.py +++ b/aerich/inspectdb/__init__.py @@ -191,7 +191,15 @@ async def get_all_tables(self) -> list[str]: def get_field_string( field_class: str, arguments: str = "{null}{default}{comment}", **kwargs ) -> str: - name = kwargs["name"] + name: str = kwargs["name"] + arguments +="{source_field}" + kwargs["source_field"] = f"source_field='{name}'" + if "-" in name: + name = name.replace("-", "_") + if name[0].isdigit(): + name = "_" + name + name = name.replace("@", "") + field_params = arguments.format(**kwargs).strip().rstrip(",") return f"{name} = fields.{field_class}({field_params})" From 70268a358faf550f54bebcf571a37d9a97b401a7 Mon Sep 17 00:00:00 2001 From: "Evgeny (Krymmy) Momotov" Date: Fri, 28 Feb 2025 23:44:24 +0300 Subject: [PATCH 14/17] [UP] inspectdb/__init__ class EnumDataType {~ def get_class_name } --- aerich/inspectdb/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aerich/inspectdb/__init__.py b/aerich/inspectdb/__init__.py index c5d12eb7..1c1b127d 100644 --- a/aerich/inspectdb/__init__.py +++ b/aerich/inspectdb/__init__.py @@ -56,7 +56,7 @@ def get_class_name(self) -> str: class_name = class_name.replace("@", "") class_name = class_name.title() class_name = class_name.replace(" ", "") - self.class_name = class_name + self.class_name = class_name+"Enum" return self.class_name def enum_type(self) -> dict: From 52c90f4351025d6586fa61b9e682b4c19277bf09 Mon Sep 17 00:00:00 2001 From: "Evgeny (Krymmy) Momotov" Date: Fri, 28 Feb 2025 23:45:15 +0300 Subject: [PATCH 15/17] [UP] inspectdb/postgres.py -112| --- aerich/inspectdb/postgres.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aerich/inspectdb/postgres.py b/aerich/inspectdb/postgres.py index 1e065124..98303ca1 100644 --- a/aerich/inspectdb/postgres.py +++ b/aerich/inspectdb/postgres.py @@ -109,7 +109,6 @@ async def get_enums_data_types(self) -> dict[str, EnumDataType]: return enums - async def get_enums_names(self) -> set[str]: enum_names = set() ret = await self._get_enums() From 86aa1764c58307426514ea6f2a5e78cbb10505c5 Mon Sep 17 00:00:00 2001 From: "Evgeny (Krymmy) Momotov" Date: Sat, 1 Mar 2025 10:00:48 +0300 Subject: [PATCH 16/17] [FIX] Inspectdb follow to code style --- aerich/inspectdb/__init__.py | 73 +++++++++++++++++++----------------- aerich/inspectdb/mysql.py | 15 ++------ aerich/inspectdb/postgres.py | 19 ++++------ aerich/inspectdb/sqlite.py | 15 ++++++-- 4 files changed, 62 insertions(+), 60 deletions(-) diff --git a/aerich/inspectdb/__init__.py b/aerich/inspectdb/__init__.py index 1c1b127d..605fb518 100644 --- a/aerich/inspectdb/__init__.py +++ b/aerich/inspectdb/__init__.py @@ -27,7 +27,6 @@ class EnumDataType(BaseModel): class_name: str | None = None class_values: list[str] = [] - def get_enum_class(self) -> str: self.class_name = self.get_class_name() if len(self.class_values) != len(self.row_values.split(";")): @@ -37,18 +36,18 @@ def get_enum_class(self) -> str: for value in row_values: if value.isdigit(): continue - name_value = value.strip("\"") + name_value = value.strip('"') name_value = name_value.replace(" ", "_") name_value = name_value.replace(",", "_") name_value = name_value.replace("-", "_") name_value = name_value.replace("@", "") self.class_values.append(f' {name_value.upper()} = "{value}"') - + result = f"class {self.class_name}(str, Enum):\n" result += "\n".join(self.class_values) - + return result - + def get_class_name(self) -> str: if not self.class_name: class_name = self.row_name @@ -56,7 +55,7 @@ def get_class_name(self) -> str: class_name = class_name.replace("@", "") class_name = class_name.title() class_name = class_name.replace(" ", "") - self.class_name = class_name+"Enum" + self.class_name = class_name + "Enum" return self.class_name def enum_type(self) -> dict: @@ -145,41 +144,47 @@ def __init__(self, conn: BaseDBAsyncClient, tables: list[str] | None = None) -> def field_map(self) -> FieldMapDict: raise NotImplementedError + def get_field(self, table: str, column, enums_types: dict[str, EnumDataType]) -> str: + enum_key = enums_types.get(column.data_type) or enums_types.get(f"{table}_{column.name}") + if enum_key: + return self.field_map["enum"](**enum_key.enum_type(), **column.translate()) + return self.field_map[column.data_type](**column.translate()) + async def inspect(self) -> str: if not self.tables: self.tables = await self.get_all_tables() - result = "from tortoise import Model, fields\n\n\n" - - enums_types = {} - if getattr(self, "get_enums_data_types", False): - enums_types = await self.get_enums_data_types() + + result = ["from tortoise import Model, fields"] + enums_types: dict[str, EnumDataType] = await self.get_enums_data_types() if enums_types: - result += "from enum import Enum\n\n" - enums = [] - for key, value in enums_types.items(): - enums.append(value.get_enum_class() ) + result.append("from enum import Enum") + + # Генерация enum-классов + enums = [value.get_enum_class() for value in enums_types.values()] + # Генерация моделей tables = [] for table in self.tables: columns = await self.get_columns(table) - fields = [] + fields = [f" {self.get_field(table, column, enums_types)}" for column in columns] + model = self._table_template.format(table=table.title().replace("_", "")) - for column in columns: - if column.data_type in enums_types: - field = self.field_map["enum"](**enums_types[column.data_type].enum_type(), **column.translate()) - elif f"{table}_{column.name}" in enums_types: - field = self.field_map["enum"](**enums_types[f"{table}_{column.name}"].enum_type(), **column.translate()) - else: - field = self.field_map[column.data_type](**column.translate()) - - fields.append(" " + field) - tables.append(model + "\n".join(fields)) - tables.append(" class Meta:\n table = '" + table + "'\n\n") - - enums.extend(tables) - - return result + "\n\n\n".join(enums) + meta = f" class Meta:\n table = '{table}'\n" + + tables.append(f"{model}\n{'\n'.join(fields)}\n\n{meta}\n") + + result.extend(enums + tables) + return "\n\n\n".join(result) + + async def _get_enums(self): + raise NotImplementedError + + async def get_enums_data_types(self) -> dict[str, EnumDataType]: + raise NotImplementedError + + async def get_enums_names(self) -> set[str]: + raise NotImplementedError async def get_columns(self, table: str) -> list[Column]: raise NotImplementedError @@ -192,14 +197,14 @@ def get_field_string( field_class: str, arguments: str = "{null}{default}{comment}", **kwargs ) -> str: name: str = kwargs["name"] - arguments +="{source_field}" + arguments += "{source_field}" kwargs["source_field"] = f"source_field='{name}'" if "-" in name: name = name.replace("-", "_") if name[0].isdigit(): name = "_" + name name = name.replace("@", "") - + field_params = arguments.format(**kwargs).strip().rstrip(",") return f"{name} = fields.{field_class}({field_params})" @@ -265,4 +270,4 @@ def binary_field(cls, **kwargs) -> str: @classmethod def charenum_field(cls, **kwargs) -> str: arguments = "{enum_type}{null}" - return cls.get_field_string("CharEnumField", arguments, **kwargs) \ No newline at end of file + return cls.get_field_string("CharEnumField", arguments, **kwargs) diff --git a/aerich/inspectdb/mysql.py b/aerich/inspectdb/mysql.py index cb80778c..86b9fe68 100644 --- a/aerich/inspectdb/mysql.py +++ b/aerich/inspectdb/mysql.py @@ -1,6 +1,6 @@ from __future__ import annotations -from aerich.inspectdb import Column, FieldMapDict, Inspect, EnumDataType +from aerich.inspectdb import Column, EnumDataType, FieldMapDict, Inspect class InspectMySQL(Inspect): @@ -43,10 +43,9 @@ async def _get_enums(self): AND c.COLUMN_NAME = s.COLUMN_NAME WHERE c.TABLE_SCHEMA = %s AND c.DATA_TYPE = 'enum'; -""" +""" ret = await self.conn.execute_query_dict(sql, [self.database]) return ret - async def get_enums_data_types(self) -> dict[str, EnumDataType]: enums = {} @@ -62,14 +61,9 @@ async def get_enums_data_types(self) -> dict[str, EnumDataType]: type_values = row.get("ENUM_VALUES") if name in enums: continue - enums[name] = (EnumDataType( - row_name=name, - row_values=type_values - ) - ) - - return enums + enums[name] = EnumDataType(row_name=name, row_values=type_values) + return enums async def get_enums_names(self) -> set[str]: enum_names = set() @@ -85,7 +79,6 @@ async def get_enums_names(self) -> set[str]: enum_names.add(name) return enum_names - async def get_all_tables(self) -> list[str]: sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s" ret = await self.conn.execute_query_dict(sql, [self.database]) diff --git a/aerich/inspectdb/postgres.py b/aerich/inspectdb/postgres.py index 98303ca1..6d03578c 100644 --- a/aerich/inspectdb/postgres.py +++ b/aerich/inspectdb/postgres.py @@ -3,7 +3,7 @@ import re from typing import TYPE_CHECKING -from aerich.inspectdb import Column, FieldMapDict, Inspect, EnumDataType +from aerich.inspectdb import Column, EnumDataType, FieldMapDict, Inspect if TYPE_CHECKING: from tortoise.backends.base_postgres.client import BasePostgresClient @@ -86,13 +86,12 @@ async def _get_enums(self): AND current_database() = $1 -- Параметр для проверки базы данных GROUP BY ct.schema_name, ct.type_name, ct.type_category ORDER BY ct.schema_name, ct.type_name; -""" +""" ret = await self.conn.execute_query_dict(sql, [self.database, self.schema]) return ret async def get_enums_data_types(self) -> dict[str, EnumDataType]: enums = {} - enum_names = set() ret = await self._get_enums() for row in ret: name = row.get("type_name") @@ -101,12 +100,8 @@ async def get_enums_data_types(self) -> dict[str, EnumDataType]: type_values = row.get("type_values") if name in enums: continue - enums[name] = (EnumDataType( - row_name=name, - row_values=type_values - ) - ) - + enums[name] = EnumDataType(row_name=name, row_values=type_values) + return enums async def get_enums_names(self) -> set[str]: @@ -151,9 +146,11 @@ async def get_columns(self, table: str) -> list[Column]: max_digits=row["numeric_precision"], decimal_places=row["numeric_scale"], comment=row["column_comment"], - pk=row["column_key"] == "PRIMARY KEY" or (row["column_key"] == "UNIQUE" and row["column_name"] == "id"), + pk=row["column_key"] == "PRIMARY KEY" + or (row["column_key"] == "UNIQUE" and row["column_name"] == "id"), extra=None, - unique=row["column_key"] == "UNIQUE" and row["column_name"] != "id", # can't get this simply + unique=row["column_key"] == "UNIQUE" + and row["column_name"] != "id", # can't get this simply index=False, # can't get this simply ) ) diff --git a/aerich/inspectdb/sqlite.py b/aerich/inspectdb/sqlite.py index b7f07eb4..ff0dd339 100644 --- a/aerich/inspectdb/sqlite.py +++ b/aerich/inspectdb/sqlite.py @@ -1,6 +1,6 @@ from __future__ import annotations -from aerich.inspectdb import Column, FieldMapDict, Inspect +from aerich.inspectdb import Column, EnumDataType, FieldMapDict, Inspect class InspectSQLite(Inspect): @@ -12,7 +12,7 @@ def field_map(self) -> FieldMapDict: "INTEGER UNSIGNED": self.int_field, "SMALLINT": self.smallint_field, "SMALLINT UNSIGNED": self.smallint_field, - "BIGINT": self.bigint_field, + "BIGINT": self.bigint_field, "UNSIGNED BIG INT": self.bigint_field, "INT2": self.smallint_field, "INT8": self.bigint_field, @@ -21,7 +21,6 @@ def field_map(self) -> FieldMapDict: "TEXT": self.text_field, "TIMESTAMP": self.datetime_field, "REAL": self.float_field, - "BIGINT": self.bigint_field, "DATE": self.date_field, "BOOL": self.bool_field, "DATETIME": self.datetime_field, @@ -29,7 +28,6 @@ def field_map(self) -> FieldMapDict: "JSON": self.json_field, "BLOB": self.binary_field, "CHARACTER": self.char_field, - "VARCHAR": self.char_field, "VARYING CHARACTER": self.char_field, "NCHAR": self.char_field, "NATIVE CHARACTER": self.char_field, @@ -40,6 +38,15 @@ def field_map(self) -> FieldMapDict: "BOOLEAN": self.bool_field, } + async def _get_enums(self): + return [] + + async def get_enums_data_types(self) -> dict[str, EnumDataType]: + return {} + + async def get_enums_names(self) -> set[str]: + return set() + async def get_columns(self, table: str) -> list[Column]: columns = [] sql = f"PRAGMA table_info({table})" From 792956764d4091189a2b895bcb6784036a745a43 Mon Sep 17 00:00:00 2001 From: "Evgeny (Krymmy) Momotov" Date: Sat, 1 Mar 2025 10:04:00 +0300 Subject: [PATCH 17/17] [FIX]inspectdb/__init__ -163;165| --- aerich/inspectdb/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/aerich/inspectdb/__init__.py b/aerich/inspectdb/__init__.py index 605fb518..f5f670cb 100644 --- a/aerich/inspectdb/__init__.py +++ b/aerich/inspectdb/__init__.py @@ -160,10 +160,8 @@ async def inspect(self) -> str: if enums_types: result.append("from enum import Enum") - # Генерация enum-классов enums = [value.get_enum_class() for value in enums_types.values()] - # Генерация моделей tables = [] for table in self.tables: columns = await self.get_columns(table)