Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
40cde79
[UP] inspectdb/__init__ class Inspect {+ def charenum_field}
KrymmyEM Feb 25, 2025
b51b4ea
[UP] inspectdb/__init__ class Inspect {~ def inspect +119 "class Meta…
KrymmyEM Feb 25, 2025
3a7cfbe
[UP] inspectdb/__init__.py +class EnumDataType
KrymmyEM Feb 25, 2025
e80e3ed
[UP] inspectdb/postgres.py class InspectPostgres {+ def get_enums*}
KrymmyEM Feb 25, 2025
cd39edc
[UP] inspectdb/postgres.py class InspectPostgres {~ def field_map}
KrymmyEM Feb 25, 2025
764913f
[UP] inspectdb/__init__ ~class EnumDataType {+ def get_class_name; en…
KrymmyEM Feb 28, 2025
c1ea756
[UP/FIX] inspectdb/postgres.py ~class InspectPostgres {~get_enums_dat…
KrymmyEM Feb 28, 2025
0bc2e0a
[FIX] inspectdb/postgres.py +5| +EnumDataType
KrymmyEM Feb 28, 2025
dd40655
[FIX] inspectdb/sqlite.py class InspectSQLite {~ field_map}
KrymmyEM Feb 28, 2025
97b6472
[FIX] inspectdb/sqlite.py class InspectSQLite {~ field_map}
KrymmyEM Feb 28, 2025
5cbf42e
[UP] inspectdb/mysql.py ~class InspectMySQL {+def *enum*; ~def field_…
KrymmyEM Feb 28, 2025
1ec012b
[UP] inspectdb/__init__ class Inspect {~ def inspect }
KrymmyEM Feb 28, 2025
645f0e2
Merge branch 'tortoise:dev' into fix/inspectdb
KrymmyEM Feb 28, 2025
dea80cc
[UP] inspectdb/__init__ class Inspect {~ def get_field_string }
KrymmyEM Feb 28, 2025
70268a3
[UP] inspectdb/__init__ class EnumDataType {~ def get_class_name }
KrymmyEM Feb 28, 2025
52c90f4
[UP] inspectdb/postgres.py -112|
KrymmyEM Feb 28, 2025
86aa176
[FIX] Inspectdb follow to code style
KrymmyEM Mar 1, 2025
7929567
[FIX]inspectdb/__init__ -163;165|
KrymmyEM Mar 1, 2025
3624962
Merge branch 'tortoise:dev' into fix/inspectdb
KrymmyEM Mar 20, 2025
18bfa22
Merge branch 'tortoise:dev' into fix/inspectdb
KrymmyEM Jun 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 87 additions & 8 deletions aerich/inspectdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,47 @@ class ColumnInfoDict(TypedDict):
FieldMapDict = dict[str, Callable[..., str]]


class EnumDataType(BaseModel):
row_name: str
row_values: str
class_name: str | None = None
class_values: list[str] = []

def get_enum_class(self) -> str:
self.class_name = self.get_class_name()
if len(self.class_values) != len(self.row_values.split(";")):
self.row_values = self.row_values.strip("{")
self.row_values = self.row_values.strip("}")
row_values = self.row_values.split(";")
for value in row_values:
if value.isdigit():
continue
name_value = value.strip('"')
name_value = name_value.replace(" ", "_")
name_value = name_value.replace(",", "_")
name_value = name_value.replace("-", "_")
name_value = name_value.replace("@", "")
self.class_values.append(f' {name_value.upper()} = "{value}"')

result = f"class {self.class_name}(str, Enum):\n"
result += "\n".join(self.class_values)

return result

def get_class_name(self) -> str:
if not self.class_name:
class_name = self.row_name
class_name = class_name.replace("_", " ")
class_name = class_name.replace("@", "")
class_name = class_name.title()
class_name = class_name.replace(" ", "")
self.class_name = class_name + "Enum"
return self.class_name

def enum_type(self) -> dict:
return {"enum_type": f"enum_type={self.get_class_name()}, "}


class Column(BaseModel):
name: str
data_type: str
Expand Down Expand Up @@ -102,20 +143,45 @@ def __init__(self, conn: BaseDBAsyncClient, tables: list[str] | None = None) ->
def field_map(self) -> FieldMapDict:
raise NotImplementedError

def get_field(self, table: str, column, enums_types: dict[str, EnumDataType]) -> str:
enum_key = enums_types.get(column.data_type) or enums_types.get(f"{table}_{column.name}")
if enum_key:
return self.field_map["enum"](**enum_key.enum_type(), **column.translate())
return self.field_map[column.data_type](**column.translate())

async def inspect(self) -> str:
if not self.tables:
self.tables = await self.get_all_tables()
result = "from tortoise import Model, fields\n\n\n"

result = ["from tortoise import Model, fields"]
enums_types: dict[str, EnumDataType] = await self.get_enums_data_types()

if enums_types:
result.append("from enum import Enum")

enums = [value.get_enum_class() for value in enums_types.values()]

tables = []
for table in self.tables:
columns = await self.get_columns(table)
fields = []
fields = [f" {self.get_field(table, column, enums_types)}" for column in columns]

model = self._table_template.format(table=table.title().replace("_", ""))
for column in columns:
field = self.field_map[column.data_type](**column.translate())
fields.append(" " + field)
tables.append(model + "\n".join(fields))
return result + "\n\n\n".join(tables)
meta = f" class Meta:\n table = '{table}'\n"

tables.append(f"{model}\n{'\n'.join(fields)}\n\n{meta}\n")

result.extend(enums + tables)
return "\n\n\n".join(result)

async def _get_enums(self):
raise NotImplementedError

async def get_enums_data_types(self) -> dict[str, EnumDataType]:
raise NotImplementedError

async def get_enums_names(self) -> set[str]:
raise NotImplementedError

async def get_columns(self, table: str) -> list[Column]:
raise NotImplementedError
Expand All @@ -127,7 +193,15 @@ async def get_all_tables(self) -> list[str]:
def get_field_string(
field_class: str, arguments: str = "{null}{default}{comment}", **kwargs
) -> str:
name = kwargs["name"]
name: str = kwargs["name"]
arguments += "{source_field}"
kwargs["source_field"] = f"source_field='{name}'"
if "-" in name:
name = name.replace("-", "_")
if name[0].isdigit():
name = "_" + name
name = name.replace("@", "")

field_params = arguments.format(**kwargs).strip().rstrip(",")
return f"{name} = fields.{field_class}({field_params})"

Expand Down Expand Up @@ -189,3 +263,8 @@ def json_field(cls, **kwargs) -> str:
@classmethod
def binary_field(cls, **kwargs) -> str:
return cls.get_field_string("BinaryField", **kwargs)

@classmethod
def charenum_field(cls, **kwargs) -> str:
arguments = "{enum_type}{null}"
return cls.get_field_string("CharEnumField", arguments, **kwargs)
56 changes: 55 additions & 1 deletion aerich/inspectdb/mysql.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from aerich.inspectdb import Column, FieldMapDict, Inspect
from aerich.inspectdb import Column, EnumDataType, FieldMapDict, Inspect


class InspectMySQL(Inspect):
Expand All @@ -23,8 +23,62 @@ def field_map(self) -> FieldMapDict:
"decimal": self.decimal_field,
"json": self.json_field,
"longblob": self.binary_field,
"enum": self.charenum_field,
}

async def _get_enums(self):
sql = """SELECT c.TABLE_NAME,
c.COLUMN_NAME,
REPLACE(
REPLACE(
SUBSTRING_INDEX(SUBSTRING_INDEX(c.COLUMN_TYPE, '(', -1), ')', 1),
"','", "';'"
),
"'", ""
) AS ENUM_VALUES
FROM information_schema.COLUMNS c
LEFT JOIN information_schema.STATISTICS s
ON c.TABLE_NAME = s.TABLE_NAME
AND c.TABLE_SCHEMA = s.TABLE_SCHEMA
AND c.COLUMN_NAME = s.COLUMN_NAME
WHERE c.TABLE_SCHEMA = %s
AND c.DATA_TYPE = 'enum';
"""
ret = await self.conn.execute_query_dict(sql, [self.database])
return ret

async def get_enums_data_types(self) -> dict[str, EnumDataType]:
enums = {}
ret = await self._get_enums()
for row in ret:
table_name = row.get("TABLE_NAME")
column_name = row.get("COLUMN_NAME")

if not table_name or not column_name:
continue

name = f"{table_name}_{column_name}"
type_values = row.get("ENUM_VALUES")
if name in enums:
continue
enums[name] = EnumDataType(row_name=name, row_values=type_values)

return enums

async def get_enums_names(self) -> set[str]:
enum_names = set()
ret = await self._get_enums()
for row in ret:
table_name = row.get("TABLE_NAME")
column_name = row.get("COLUMN_NAME")

if not table_name or not column_name:
continue

name = f"{table_name}_{column_name}"
enum_names.add(name)
return enum_names

async def get_all_tables(self) -> list[str]:
sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s"
ret = await self.conn.execute_query_dict(sql, [self.database])
Expand Down
84 changes: 79 additions & 5 deletions aerich/inspectdb/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
from typing import TYPE_CHECKING

from aerich.inspectdb import Column, FieldMapDict, Inspect
from aerich.inspectdb import Column, EnumDataType, FieldMapDict, Inspect

if TYPE_CHECKING:
from tortoise.backends.base_postgres.client import BasePostgresClient
Expand All @@ -18,31 +18,102 @@ def __init__(self, conn: BasePostgresClient, tables: list[str] | None = None) ->
def field_map(self) -> FieldMapDict:
return {
"int2": self.smallint_field,
"_int2": self.smallint_field,
"int4": self.int_field,
"int8": self.bigint_field,
"_int4": self.int_field,
"int8": self.int_field,
"_int8": self.int_field,
"smallint": self.smallint_field,
"bigint": self.bigint_field,
"_smallint": self.smallint_field,
"varchar": self.char_field,
"_varchar": self.char_field,
"text": self.text_field,
"_text": self.text_field,
"bigint": self.bigint_field,
"_bigint": self.bigint_field,
"timestamptz": self.datetime_field,
"_timestamptz": self.datetime_field,
"float4": self.float_field,
"_float4": self.float_field,
"float8": self.float_field,
"_float8": self.float_field,
"date": self.date_field,
"_date": self.date_field,
"time": self.time_field,
"_time": self.time_field,
"timetz": self.time_field,
"_timetz": self.time_field,
"decimal": self.decimal_field,
"_decimal": self.decimal_field,
"numeric": self.decimal_field,
"_numeric": self.decimal_field,
"uuid": self.uuid_field,
"_uuid": self.uuid_field,
"jsonb": self.json_field,
"_jsonb": self.json_field,
"bytea": self.binary_field,
"_bytea": self.binary_field,
"bool": self.bool_field,
"_bool": self.bool_field,
"timestamp": self.datetime_field,
"_timestamp": self.datetime_field,
"enum": self.charenum_field,
}

async def get_all_tables(self) -> list[str]:
sql = "select TABLE_NAME from information_schema.TABLES where table_catalog=$1 and table_schema=$2"
ret = await self.conn.execute_query_dict(sql, [self.database, self.schema])
return list(map(lambda x: x["table_name"], ret))

async def _get_enums(self):
sql = """WITH enum_types AS (
SELECT n.nspname AS schema_name,
t.typname AS type_name,
t.typtype AS type_category
FROM pg_type t
JOIN pg_namespace n ON n.oid = t.typnamespace
WHERE t.typtype = 'e'
AND n.nspname NOT IN ('pg_catalog', 'information_schema')
)
SELECT ct.type_name,
CASE
WHEN ct.type_category = 'e' THEN array_to_string(array_agg(e.enumlabel ORDER BY e.enumsortorder), ';') -- ENUM значения с разделителем ;
ELSE NULL
END AS type_values
FROM enum_types ct
LEFT JOIN pg_enum e ON ct.type_category = 'e' AND e.enumtypid = (SELECT oid FROM pg_type WHERE typname = ct.type_name)
WHERE ct.schema_name = $2
AND current_database() = $1 -- Параметр для проверки базы данных
GROUP BY ct.schema_name, ct.type_name, ct.type_category
ORDER BY ct.schema_name, ct.type_name;
"""
ret = await self.conn.execute_query_dict(sql, [self.database, self.schema])
return ret

async def get_enums_data_types(self) -> dict[str, EnumDataType]:
enums = {}
ret = await self._get_enums()
for row in ret:
name = row.get("type_name")
if not name:
continue
type_values = row.get("type_values")
if name in enums:
continue
enums[name] = EnumDataType(row_name=name, row_values=type_values)

return enums

async def get_enums_names(self) -> set[str]:
enum_names = set()
ret = await self._get_enums()
for row in ret:
name = row.get("type_name")
if not name:
continue
enum_names.add(name)
return enum_names

async def get_columns(self, table: str) -> list[Column]:
columns = []
sql = f"""select c.column_name,
Expand Down Expand Up @@ -75,8 +146,11 @@ async def get_columns(self, table: str) -> list[Column]:
max_digits=row["numeric_precision"],
decimal_places=row["numeric_scale"],
comment=row["column_comment"],
pk=row["column_key"] == "PRIMARY KEY",
unique=False, # can't get this simply
pk=row["column_key"] == "PRIMARY KEY"
or (row["column_key"] == "UNIQUE" and row["column_name"] == "id"),
extra=None,
unique=row["column_key"] == "UNIQUE"
and row["column_name"] != "id", # can't get this simply
index=False, # can't get this simply
)
)
Expand Down
34 changes: 30 additions & 4 deletions aerich/inspectdb/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,52 @@
from __future__ import annotations

from aerich.inspectdb import Column, FieldMapDict, Inspect
from aerich.inspectdb import Column, EnumDataType, FieldMapDict, Inspect


class InspectSQLite(Inspect):
@property
def field_map(self) -> FieldMapDict:
return {
"INT": self.int_field,
"INTEGER": self.int_field,
"INT": self.bool_field,
"INTEGER UNSIGNED": self.int_field,
"SMALLINT": self.smallint_field,
"SMALLINT UNSIGNED": self.smallint_field,
"BIGINT": self.bigint_field,
"UNSIGNED BIG INT": self.bigint_field,
"INT2": self.smallint_field,
"INT8": self.bigint_field,
"TINYINT": self.smallint_field,
"VARCHAR": self.char_field,
"TEXT": self.text_field,
"TIMESTAMP": self.datetime_field,
"REAL": self.float_field,
"BIGINT": self.bigint_field,
"DATE": self.date_field,
"BOOL": self.bool_field,
"DATETIME": self.datetime_field,
"TIME": self.time_field,
"JSON": self.json_field,
"BLOB": self.binary_field,
"CHARACTER": self.char_field,
"VARYING CHARACTER": self.char_field,
"NCHAR": self.char_field,
"NATIVE CHARACTER": self.char_field,
"NVARCHAR": self.char_field,
"CLOB": self.text_field,
"NUMERIC": self.decimal_field,
"DECIMAL": self.decimal_field,
"BOOLEAN": self.bool_field,
}

async def _get_enums(self):
return []

async def get_enums_data_types(self) -> dict[str, EnumDataType]:
return {}

async def get_enums_names(self) -> set[str]:
return set()

async def get_columns(self, table: str) -> list[Column]:
columns = []
sql = f"PRAGMA table_info({table})"
Expand All @@ -34,7 +60,7 @@ async def get_columns(self, table: str) -> list[Column]:
columns.append(
Column(
name=row["name"],
data_type=row["type"].split("(")[0],
data_type=row["type"].split("(")[0].upper(),
null=row["notnull"] == 0,
default=row["dflt_value"],
length=length,
Expand Down