diff --git a/aerich/inspectdb/__init__.py b/aerich/inspectdb/__init__.py index 2f7bf553..da6368ed 100644 --- a/aerich/inspectdb/__init__.py +++ b/aerich/inspectdb/__init__.py @@ -20,6 +20,47 @@ class ColumnInfoDict(TypedDict): FieldMapDict = dict[str, Callable[..., str]] +class EnumDataType(BaseModel): + row_name: str + row_values: str + class_name: str | None = None + class_values: list[str] = [] + + def get_enum_class(self) -> str: + self.class_name = self.get_class_name() + if len(self.class_values) != len(self.row_values.split(";")): + self.row_values = self.row_values.strip("{") + self.row_values = self.row_values.strip("}") + row_values = self.row_values.split(";") + for value in row_values: + if value.isdigit(): + continue + name_value = value.strip('"') + name_value = name_value.replace(" ", "_") + name_value = name_value.replace(",", "_") + name_value = name_value.replace("-", "_") + name_value = name_value.replace("@", "") + self.class_values.append(f' {name_value.upper()} = "{value}"') + + result = f"class {self.class_name}(str, Enum):\n" + result += "\n".join(self.class_values) + + return result + + def get_class_name(self) -> str: + if not self.class_name: + class_name = self.row_name + class_name = class_name.replace("_", " ") + class_name = class_name.replace("@", "") + class_name = class_name.title() + class_name = class_name.replace(" ", "") + self.class_name = class_name + "Enum" + return self.class_name + + def enum_type(self) -> dict: + return {"enum_type": f"enum_type={self.get_class_name()}, "} + + class Column(BaseModel): name: str data_type: str @@ -102,20 +143,45 @@ def __init__(self, conn: BaseDBAsyncClient, tables: list[str] | None = None) -> def field_map(self) -> FieldMapDict: raise NotImplementedError + def get_field(self, table: str, column, enums_types: dict[str, EnumDataType]) -> str: + enum_key = enums_types.get(column.data_type) or enums_types.get(f"{table}_{column.name}") + if enum_key: + return self.field_map["enum"](**enum_key.enum_type(), **column.translate()) + return self.field_map[column.data_type](**column.translate()) + async def inspect(self) -> str: if not self.tables: self.tables = await self.get_all_tables() - result = "from tortoise import Model, fields\n\n\n" + + result = ["from tortoise import Model, fields"] + enums_types: dict[str, EnumDataType] = await self.get_enums_data_types() + + if enums_types: + result.append("from enum import Enum") + + enums = [value.get_enum_class() for value in enums_types.values()] + tables = [] for table in self.tables: columns = await self.get_columns(table) - fields = [] + fields = [f" {self.get_field(table, column, enums_types)}" for column in columns] + model = self._table_template.format(table=table.title().replace("_", "")) - for column in columns: - field = self.field_map[column.data_type](**column.translate()) - fields.append(" " + field) - tables.append(model + "\n".join(fields)) - return result + "\n\n\n".join(tables) + meta = f" class Meta:\n table = '{table}'\n" + + tables.append(f"{model}\n{'\n'.join(fields)}\n\n{meta}\n") + + result.extend(enums + tables) + return "\n\n\n".join(result) + + async def _get_enums(self): + raise NotImplementedError + + async def get_enums_data_types(self) -> dict[str, EnumDataType]: + raise NotImplementedError + + async def get_enums_names(self) -> set[str]: + raise NotImplementedError async def get_columns(self, table: str) -> list[Column]: raise NotImplementedError @@ -127,7 +193,15 @@ async def get_all_tables(self) -> list[str]: def get_field_string( field_class: str, arguments: str = "{null}{default}{comment}", **kwargs ) -> str: - name = kwargs["name"] + name: str = kwargs["name"] + arguments += "{source_field}" + kwargs["source_field"] = f"source_field='{name}'" + if "-" in name: + name = name.replace("-", "_") + if name[0].isdigit(): + name = "_" + name + name = name.replace("@", "") + field_params = arguments.format(**kwargs).strip().rstrip(",") return f"{name} = fields.{field_class}({field_params})" @@ -189,3 +263,8 @@ def json_field(cls, **kwargs) -> str: @classmethod def binary_field(cls, **kwargs) -> str: return cls.get_field_string("BinaryField", **kwargs) + + @classmethod + def charenum_field(cls, **kwargs) -> str: + arguments = "{enum_type}{null}" + return cls.get_field_string("CharEnumField", arguments, **kwargs) diff --git a/aerich/inspectdb/mysql.py b/aerich/inspectdb/mysql.py index 1034fd92..86b9fe68 100644 --- a/aerich/inspectdb/mysql.py +++ b/aerich/inspectdb/mysql.py @@ -1,6 +1,6 @@ from __future__ import annotations -from aerich.inspectdb import Column, FieldMapDict, Inspect +from aerich.inspectdb import Column, EnumDataType, FieldMapDict, Inspect class InspectMySQL(Inspect): @@ -23,8 +23,62 @@ def field_map(self) -> FieldMapDict: "decimal": self.decimal_field, "json": self.json_field, "longblob": self.binary_field, + "enum": self.charenum_field, } + async def _get_enums(self): + sql = """SELECT c.TABLE_NAME, + c.COLUMN_NAME, + REPLACE( + REPLACE( + SUBSTRING_INDEX(SUBSTRING_INDEX(c.COLUMN_TYPE, '(', -1), ')', 1), + "','", "';'" + ), + "'", "" + ) AS ENUM_VALUES +FROM information_schema.COLUMNS c +LEFT JOIN information_schema.STATISTICS s + ON c.TABLE_NAME = s.TABLE_NAME + AND c.TABLE_SCHEMA = s.TABLE_SCHEMA + AND c.COLUMN_NAME = s.COLUMN_NAME +WHERE c.TABLE_SCHEMA = %s + AND c.DATA_TYPE = 'enum'; +""" + ret = await self.conn.execute_query_dict(sql, [self.database]) + return ret + + async def get_enums_data_types(self) -> dict[str, EnumDataType]: + enums = {} + ret = await self._get_enums() + for row in ret: + table_name = row.get("TABLE_NAME") + column_name = row.get("COLUMN_NAME") + + if not table_name or not column_name: + continue + + name = f"{table_name}_{column_name}" + type_values = row.get("ENUM_VALUES") + if name in enums: + continue + enums[name] = EnumDataType(row_name=name, row_values=type_values) + + return enums + + async def get_enums_names(self) -> set[str]: + enum_names = set() + ret = await self._get_enums() + for row in ret: + table_name = row.get("TABLE_NAME") + column_name = row.get("COLUMN_NAME") + + if not table_name or not column_name: + continue + + name = f"{table_name}_{column_name}" + enum_names.add(name) + return enum_names + async def get_all_tables(self) -> list[str]: sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s" ret = await self.conn.execute_query_dict(sql, [self.database]) diff --git a/aerich/inspectdb/postgres.py b/aerich/inspectdb/postgres.py index c9bf133f..6d03578c 100644 --- a/aerich/inspectdb/postgres.py +++ b/aerich/inspectdb/postgres.py @@ -3,7 +3,7 @@ import re from typing import TYPE_CHECKING -from aerich.inspectdb import Column, FieldMapDict, Inspect +from aerich.inspectdb import Column, EnumDataType, FieldMapDict, Inspect if TYPE_CHECKING: from tortoise.backends.base_postgres.client import BasePostgresClient @@ -18,24 +18,46 @@ def __init__(self, conn: BasePostgresClient, tables: list[str] | None = None) -> def field_map(self) -> FieldMapDict: return { "int2": self.smallint_field, + "_int2": self.smallint_field, "int4": self.int_field, - "int8": self.bigint_field, + "_int4": self.int_field, + "int8": self.int_field, + "_int8": self.int_field, "smallint": self.smallint_field, - "bigint": self.bigint_field, + "_smallint": self.smallint_field, "varchar": self.char_field, + "_varchar": self.char_field, "text": self.text_field, + "_text": self.text_field, + "bigint": self.bigint_field, + "_bigint": self.bigint_field, "timestamptz": self.datetime_field, + "_timestamptz": self.datetime_field, "float4": self.float_field, + "_float4": self.float_field, "float8": self.float_field, + "_float8": self.float_field, "date": self.date_field, + "_date": self.date_field, "time": self.time_field, + "_time": self.time_field, + "timetz": self.time_field, + "_timetz": self.time_field, "decimal": self.decimal_field, + "_decimal": self.decimal_field, "numeric": self.decimal_field, + "_numeric": self.decimal_field, "uuid": self.uuid_field, + "_uuid": self.uuid_field, "jsonb": self.json_field, + "_jsonb": self.json_field, "bytea": self.binary_field, + "_bytea": self.binary_field, "bool": self.bool_field, + "_bool": self.bool_field, "timestamp": self.datetime_field, + "_timestamp": self.datetime_field, + "enum": self.charenum_field, } async def get_all_tables(self) -> list[str]: @@ -43,6 +65,55 @@ async def get_all_tables(self) -> list[str]: ret = await self.conn.execute_query_dict(sql, [self.database, self.schema]) return list(map(lambda x: x["table_name"], ret)) + async def _get_enums(self): + sql = """WITH enum_types AS ( + SELECT n.nspname AS schema_name, + t.typname AS type_name, + t.typtype AS type_category + FROM pg_type t + JOIN pg_namespace n ON n.oid = t.typnamespace + WHERE t.typtype = 'e' + AND n.nspname NOT IN ('pg_catalog', 'information_schema') +) +SELECT ct.type_name, + CASE + WHEN ct.type_category = 'e' THEN array_to_string(array_agg(e.enumlabel ORDER BY e.enumsortorder), ';') -- ENUM значения с разделителем ; + ELSE NULL + END AS type_values +FROM enum_types ct +LEFT JOIN pg_enum e ON ct.type_category = 'e' AND e.enumtypid = (SELECT oid FROM pg_type WHERE typname = ct.type_name) +WHERE ct.schema_name = $2 + AND current_database() = $1 -- Параметр для проверки базы данных +GROUP BY ct.schema_name, ct.type_name, ct.type_category +ORDER BY ct.schema_name, ct.type_name; +""" + ret = await self.conn.execute_query_dict(sql, [self.database, self.schema]) + return ret + + async def get_enums_data_types(self) -> dict[str, EnumDataType]: + enums = {} + ret = await self._get_enums() + for row in ret: + name = row.get("type_name") + if not name: + continue + type_values = row.get("type_values") + if name in enums: + continue + enums[name] = EnumDataType(row_name=name, row_values=type_values) + + return enums + + async def get_enums_names(self) -> set[str]: + enum_names = set() + ret = await self._get_enums() + for row in ret: + name = row.get("type_name") + if not name: + continue + enum_names.add(name) + return enum_names + async def get_columns(self, table: str) -> list[Column]: columns = [] sql = f"""select c.column_name, @@ -75,8 +146,11 @@ async def get_columns(self, table: str) -> list[Column]: max_digits=row["numeric_precision"], decimal_places=row["numeric_scale"], comment=row["column_comment"], - pk=row["column_key"] == "PRIMARY KEY", - unique=False, # can't get this simply + pk=row["column_key"] == "PRIMARY KEY" + or (row["column_key"] == "UNIQUE" and row["column_name"] == "id"), + extra=None, + unique=row["column_key"] == "UNIQUE" + and row["column_name"] != "id", # can't get this simply index=False, # can't get this simply ) ) diff --git a/aerich/inspectdb/sqlite.py b/aerich/inspectdb/sqlite.py index b729c738..ff0dd339 100644 --- a/aerich/inspectdb/sqlite.py +++ b/aerich/inspectdb/sqlite.py @@ -1,26 +1,52 @@ from __future__ import annotations -from aerich.inspectdb import Column, FieldMapDict, Inspect +from aerich.inspectdb import Column, EnumDataType, FieldMapDict, Inspect class InspectSQLite(Inspect): @property def field_map(self) -> FieldMapDict: return { + "INT": self.int_field, "INTEGER": self.int_field, - "INT": self.bool_field, + "INTEGER UNSIGNED": self.int_field, "SMALLINT": self.smallint_field, + "SMALLINT UNSIGNED": self.smallint_field, + "BIGINT": self.bigint_field, + "UNSIGNED BIG INT": self.bigint_field, + "INT2": self.smallint_field, + "INT8": self.bigint_field, + "TINYINT": self.smallint_field, "VARCHAR": self.char_field, "TEXT": self.text_field, "TIMESTAMP": self.datetime_field, "REAL": self.float_field, - "BIGINT": self.bigint_field, "DATE": self.date_field, + "BOOL": self.bool_field, + "DATETIME": self.datetime_field, "TIME": self.time_field, "JSON": self.json_field, "BLOB": self.binary_field, + "CHARACTER": self.char_field, + "VARYING CHARACTER": self.char_field, + "NCHAR": self.char_field, + "NATIVE CHARACTER": self.char_field, + "NVARCHAR": self.char_field, + "CLOB": self.text_field, + "NUMERIC": self.decimal_field, + "DECIMAL": self.decimal_field, + "BOOLEAN": self.bool_field, } + async def _get_enums(self): + return [] + + async def get_enums_data_types(self) -> dict[str, EnumDataType]: + return {} + + async def get_enums_names(self) -> set[str]: + return set() + async def get_columns(self, table: str) -> list[Column]: columns = [] sql = f"PRAGMA table_info({table})" @@ -34,7 +60,7 @@ async def get_columns(self, table: str) -> list[Column]: columns.append( Column( name=row["name"], - data_type=row["type"].split("(")[0], + data_type=row["type"].split("(")[0].upper(), null=row["notnull"] == 0, default=row["dflt_value"], length=length,