diff --git a/.gitignore b/.gitignore index d210b77b..779ee46e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ .idea/ .vscode/ +.cursor/ +.DS_Store venv/ .venv/ .python-version diff --git a/backend/app/admin/conf.py b/backend/app/admin/conf.py index b653a45a..ea222dc7 100644 --- a/backend/app/admin/conf.py +++ b/backend/app/admin/conf.py @@ -4,34 +4,31 @@ from pydantic_settings import BaseSettings, SettingsConfigDict -from backend.core.path_conf import BasePath +from backend.core.path_conf import BASE_PATH class AdminSettings(BaseSettings): - """Admin Settings""" + """Admin 配置""" - model_config = SettingsConfigDict(env_file=f'{BasePath}/.env', env_file_encoding='utf-8', extra='ignore') + model_config = SettingsConfigDict(env_file=f'{BASE_PATH}/.env', env_file_encoding='utf-8', extra='ignore') - # OAuth2:https://github.com/fastapi-practices/fastapi_oauth20 - # GitHub + # .env OAuth2 配置 OAUTH2_GITHUB_CLIENT_ID: str OAUTH2_GITHUB_CLIENT_SECRET: str - OAUTH2_GITHUB_REDIRECT_URI: str = 'http://127.0.0.1:8000/api/v1/oauth2/github/callback' - - # Linux Do OAUTH2_LINUX_DO_CLIENT_ID: str OAUTH2_LINUX_DO_CLIENT_SECRET: str - OAUTH2_LINUX_DO_REDIRECT_URI: str = 'http://127.0.0.1:8000/api/v1/oauth2/linux-do/callback' - # Front-end redirect address + # OAuth2 配置 + OAUTH2_GITHUB_REDIRECT_URI: str = 'http://127.0.0.1:8000/api/v1/oauth2/github/callback' + OAUTH2_LINUX_DO_REDIRECT_URI: str = 'http://127.0.0.1:8000/api/v1/oauth2/linux-do/callback' OAUTH2_FRONTEND_REDIRECT_URI: str = 'http://localhost:5173/oauth2/callback' - # Captcha + # 验证码配置 CAPTCHA_LOGIN_REDIS_PREFIX: str = 'fba:login:captcha' - CAPTCHA_LOGIN_EXPIRE_SECONDS: int = 60 * 5 # 过期时间,单位:秒 + CAPTCHA_LOGIN_EXPIRE_SECONDS: int = 60 * 5 # 3 分钟 - # Config - CONFIG_BUILT_IN_TYPES: list = ['website', 'protocol', 'policy'] + # 系统配置 + CONFIG_BUILT_IN_TYPES: list[str] = ['website', 'protocol', 'policy'] @lru_cache diff --git a/backend/app/admin/crud/crud_user.py b/backend/app/admin/crud/crud_user.py index 1c0a06f1..1a5c6609 100644 --- a/backend/app/admin/crud/crud_user.py +++ b/backend/app/admin/crud/crud_user.py @@ -176,10 +176,10 @@ async def get_list(self, dept: int = None, username: str = None, phone: str = No """ 获取用户列表 - :param dept: - :param username: - :param phone: - :param status: + :param dept: 部门 ID(可选) + :param username: 用户名(可选) + :param phone: 电话号码(可选) + :param status: 用户状态(可选) :return: """ stmt = ( @@ -191,17 +191,22 @@ async def get_list(self, dept: int = None, username: str = None, phone: str = No ) .order_by(desc(self.model.join_time)) ) - where_list = [] + + # 构建过滤条件 + filters = [] if dept: - where_list.append(self.model.dept_id == dept) + filters.append(self.model.dept_id == dept) if username: - where_list.append(self.model.username.like(f'%{username}%')) + filters.append(self.model.username.like(f'%{username}%')) if phone: - where_list.append(self.model.phone.like(f'%{phone}%')) + filters.append(self.model.phone.like(f'%{phone}%')) if status is not None: - where_list.append(self.model.status == status) - if where_list: - stmt = stmt.where(and_(*where_list)) + filters.append(self.model.status == status) + + # 应用过滤条件 + if filters: + stmt = stmt.where(and_(*filters)) + return stmt async def get_super(self, db: AsyncSession, user_id: int) -> bool: diff --git a/backend/app/admin/schema/captcha.py b/backend/app/admin/schema/captcha.py index 71ea24ca..0c1bee48 100644 --- a/backend/app/admin/schema/captcha.py +++ b/backend/app/admin/schema/captcha.py @@ -6,5 +6,7 @@ class GetCaptchaDetail(SchemaBase): + """验证码详情""" + image_type: str = Field(description='图片类型') image: str = Field(description='图片内容') diff --git a/backend/app/admin/schema/config.py b/backend/app/admin/schema/config.py index ff9a85d9..69b97379 100644 --- a/backend/app/admin/schema/config.py +++ b/backend/app/admin/schema/config.py @@ -2,37 +2,43 @@ # -*- coding: utf-8 -*- from datetime import datetime -from pydantic import ConfigDict +from pydantic import ConfigDict, Field from backend.common.schema import SchemaBase class SaveBuiltInConfigParam(SchemaBase): - name: str - key: str - value: str + """保存内置配置参数""" + + name: str = Field(description='配置名称') + key: str = Field(description='配置键名') + value: str = Field(description='配置值') class ConfigSchemaBase(SchemaBase): - name: str - type: str | None - key: str - value: str - is_frontend: bool - remark: str | None + """配置基础模型""" + + name: str = Field(description='配置名称') + type: str | None = Field(default=None, description='配置类型') + key: str = Field(description='配置键名') + value: str = Field(description='配置值') + is_frontend: bool = Field(description='是否前端配置') + remark: str | None = Field(default=None, description='备注') class CreateConfigParam(ConfigSchemaBase): - pass + """创建配置参数""" class UpdateConfigParam(ConfigSchemaBase): - pass + """更新配置参数""" class GetConfigDetail(ConfigSchemaBase): + """配置详情""" + model_config = ConfigDict(from_attributes=True) - id: int - created_time: datetime - updated_time: datetime | None = None + id: int = Field(description='配置 ID') + created_time: datetime = Field(description='创建时间') + updated_time: datetime | None = Field(default=None, description='更新时间') diff --git a/backend/app/admin/schema/data_rule.py b/backend/app/admin/schema/data_rule.py index 80f8211b..ce3a1202 100644 --- a/backend/app/admin/schema/data_rule.py +++ b/backend/app/admin/schema/data_rule.py @@ -9,28 +9,33 @@ class DataRuleSchemaBase(SchemaBase): - name: str - model: str - column: str - operator: RoleDataRuleOperatorType = Field(RoleDataRuleOperatorType.OR) - expression: RoleDataRuleExpressionType = Field(RoleDataRuleExpressionType.eq) - value: str + """数据规则基础模型""" + + name: str = Field(description='规则名称') + model: str = Field(description='模型名称') + column: str = Field(description='字段名称') + operator: RoleDataRuleOperatorType = Field(default=RoleDataRuleOperatorType.OR, description='操作符(AND/OR)') + expression: RoleDataRuleExpressionType = Field(default=RoleDataRuleExpressionType.eq, description='表达式类型') + value: str = Field(description='规则值') class CreateDataRuleParam(DataRuleSchemaBase): - pass + """创建数据规则参数""" class UpdateDataRuleParam(DataRuleSchemaBase): - pass + """更新数据规则参数""" class GetDataRuleDetail(DataRuleSchemaBase): + """数据规则详情""" + model_config = ConfigDict(from_attributes=True) - id: int - created_time: datetime - updated_time: datetime | None = None + id: int = Field(description='规则 ID') + created_time: datetime = Field(description='创建时间') + updated_time: datetime | None = Field(default=None, description='更新时间') - def __hash__(self): + def __hash__(self) -> int: + """计算哈希值""" return hash(self.name) diff --git a/backend/app/admin/schema/dept.py b/backend/app/admin/schema/dept.py index c8186076..19a90e82 100644 --- a/backend/app/admin/schema/dept.py +++ b/backend/app/admin/schema/dept.py @@ -9,27 +9,31 @@ class DeptSchemaBase(SchemaBase): - name: str - parent_id: int | None = Field(default=None, description='部门父级ID') + """部门基础模型""" + + name: str = Field(description='部门名称') + parent_id: int | None = Field(default=None, description='部门父级 ID') sort: int = Field(default=0, ge=0, description='排序') - leader: str | None = None - phone: CustomPhoneNumber | None = None - email: CustomEmailStr | None = None - status: StatusType = Field(default=StatusType.enable) + leader: str | None = Field(default=None, description='负责人') + phone: CustomPhoneNumber | None = Field(default=None, description='联系电话') + email: CustomEmailStr | None = Field(default=None, description='邮箱') + status: StatusType = Field(default=StatusType.enable, description='状态') class CreateDeptParam(DeptSchemaBase): - pass + """创建部门参数""" class UpdateDeptParam(DeptSchemaBase): - pass + """更新部门参数""" class GetDeptDetail(DeptSchemaBase): + """部门详情""" + model_config = ConfigDict(from_attributes=True) - id: int - del_flag: bool - created_time: datetime - updated_time: datetime | None = None + id: int = Field(description='部门 ID') + del_flag: bool = Field(description='是否删除') + created_time: datetime = Field(description='创建时间') + updated_time: datetime | None = Field(default=None, description='更新时间') diff --git a/backend/app/admin/schema/dict_data.py b/backend/app/admin/schema/dict_data.py index 013223ef..79aabe04 100644 --- a/backend/app/admin/schema/dict_data.py +++ b/backend/app/admin/schema/dict_data.py @@ -10,29 +10,35 @@ class DictDataSchemaBase(SchemaBase): - type_id: int - label: str - value: str - sort: int - status: StatusType = Field(default=StatusType.enable) - remark: str | None = None + """字典数据基础模型""" + + type_id: int = Field(description='字典类型 ID') + label: str = Field(description='字典标签') + value: str = Field(description='字典值') + sort: int = Field(description='排序') + status: StatusType = Field(default=StatusType.enable, description='状态') + remark: str | None = Field(default=None, description='备注') class CreateDictDataParam(DictDataSchemaBase): - pass + """创建字典数据参数""" class UpdateDictDataParam(DictDataSchemaBase): - pass + """更新字典数据参数""" class GetDictDataDetail(DictDataSchemaBase): + """字典数据详情""" + model_config = ConfigDict(from_attributes=True) - id: int - created_time: datetime - updated_time: datetime | None = None + id: int = Field(description='字典数据 ID') + created_time: datetime = Field(description='创建时间') + updated_time: datetime | None = Field(default=None, description='更新时间') class GetDictDataWithRelation(DictDataSchemaBase): - type: GetDictTypeDetail | None = None + """字典数据关联详情""" + + type: GetDictTypeDetail | None = Field(default=None, description='字典类型信息') diff --git a/backend/app/admin/schema/dict_type.py b/backend/app/admin/schema/dict_type.py index 42fe4916..7ee5d4f8 100644 --- a/backend/app/admin/schema/dict_type.py +++ b/backend/app/admin/schema/dict_type.py @@ -9,23 +9,31 @@ class DictTypeSchemaBase(SchemaBase): - name: str - code: str - status: StatusType = Field(default=StatusType.enable) - remark: str | None = None + """字典类型基础模型""" + + name: str = Field(description='字典名称') + code: str = Field(description='字典编码') + status: StatusType = Field(default=StatusType.enable, description='状态') + remark: str | None = Field(default=None, description='备注') class CreateDictTypeParam(DictTypeSchemaBase): + """创建字典类型参数""" + pass class UpdateDictTypeParam(DictTypeSchemaBase): + """更新字典类型参数""" + pass class GetDictTypeDetail(DictTypeSchemaBase): + """字典类型详情""" + model_config = ConfigDict(from_attributes=True) - id: int - created_time: datetime - updated_time: datetime | None = None + id: int = Field(description='字典类型 ID') + created_time: datetime = Field(description='创建时间') + updated_time: datetime | None = Field(default=None, description='更新时间') diff --git a/backend/app/admin/schema/login_log.py b/backend/app/admin/schema/login_log.py index 50218e8c..292fd822 100644 --- a/backend/app/admin/schema/login_log.py +++ b/backend/app/admin/schema/login_log.py @@ -2,37 +2,41 @@ # -*- coding: utf-8 -*- from datetime import datetime -from pydantic import ConfigDict +from pydantic import ConfigDict, Field from backend.common.schema import SchemaBase class LoginLogSchemaBase(SchemaBase): - user_uuid: str - username: str - status: int - ip: str - country: str | None - region: str | None - city: str | None - user_agent: str - browser: str | None - os: str | None - device: str | None - msg: str - login_time: datetime + """登录日志基础模型""" + + user_uuid: str = Field(description='用户 UUID') + username: str = Field(description='用户名') + status: int = Field(description='登录状态') + ip: str = Field(description='IP 地址') + country: str | None = Field(default=None, description='国家') + region: str | None = Field(default=None, description='地区') + city: str | None = Field(default=None, description='城市') + user_agent: str = Field(description='用户代理') + browser: str | None = Field(default=None, description='浏览器') + os: str | None = Field(default=None, description='操作系统') + device: str | None = Field(default=None, description='设备') + msg: str = Field(description='消息') + login_time: datetime = Field(description='登录时间') class CreateLoginLogParam(LoginLogSchemaBase): - pass + """创建登录日志参数""" class UpdateLoginLogParam(LoginLogSchemaBase): - pass + """更新登录日志参数""" class GetLoginLogDetail(LoginLogSchemaBase): + """登录日志详情""" + model_config = ConfigDict(from_attributes=True) - id: int - created_time: datetime + id: int = Field(description='日志 ID') + created_time: datetime = Field(description='创建时间') diff --git a/backend/app/admin/schema/menu.py b/backend/app/admin/schema/menu.py index 7961d1e7..25d17c9a 100644 --- a/backend/app/admin/schema/menu.py +++ b/backend/app/admin/schema/menu.py @@ -9,32 +9,36 @@ class MenuSchemaBase(SchemaBase): - title: str - name: str - parent_id: int | None = Field(default=None, description='菜单父级ID') + """菜单基础模型""" + + title: str = Field(description='菜单标题') + name: str = Field(description='菜单名称') + parent_id: int | None = Field(default=None, description='菜单父级 ID') sort: int = Field(default=0, ge=0, description='排序') - icon: str | None = None - path: str | None = None + icon: str | None = Field(default=None, description='图标') + path: str | None = Field(default=None, description='路由路径') menu_type: MenuType = Field(default=MenuType.directory, description='菜单类型(0目录 1菜单 2按钮)') - component: str | None = None - perms: str | None = None - status: StatusType = Field(default=StatusType.enable) - display: StatusType = Field(default=StatusType.enable) - cache: StatusType = Field(default=StatusType.enable) - remark: str | None = None + component: str | None = Field(default=None, description='组件路径') + perms: str | None = Field(default=None, description='权限标识') + status: StatusType = Field(default=StatusType.enable, description='状态') + display: StatusType = Field(default=StatusType.enable, description='是否显示') + cache: StatusType = Field(default=StatusType.enable, description='是否缓存') + remark: str | None = Field(default=None, description='备注') class CreateMenuParam(MenuSchemaBase): - pass + """创建菜单参数""" class UpdateMenuParam(MenuSchemaBase): - pass + """更新菜单参数""" class GetMenuDetail(MenuSchemaBase): + """菜单详情""" + model_config = ConfigDict(from_attributes=True) - id: int - created_time: datetime - updated_time: datetime | None = None + id: int = Field(description='菜单 ID') + created_time: datetime = Field(description='创建时间') + updated_time: datetime | None = Field(default=None, description='更新时间') diff --git a/backend/app/admin/schema/opera_log.py b/backend/app/admin/schema/opera_log.py index 0791c488..16683dee 100644 --- a/backend/app/admin/schema/opera_log.py +++ b/backend/app/admin/schema/opera_log.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- from datetime import datetime +from typing import Any from pydantic import ConfigDict, Field @@ -9,37 +10,41 @@ class OperaLogSchemaBase(SchemaBase): - trace_id: str - username: str | None = None - method: str - title: str - path: str - ip: str - country: str | None = None - region: str | None = None - city: str | None = None - user_agent: str - os: str | None = None - browser: str | None = None - device: str | None = None - args: dict | None = None - status: StatusType = Field(default=StatusType.enable) - code: str - msg: str | None = None - cost_time: float - opera_time: datetime + """操作日志基础模型""" + + trace_id: str = Field(description='追踪 ID') + username: str | None = Field(default=None, description='用户名') + method: str = Field(description='请求方法') + title: str = Field(description='操作标题') + path: str = Field(description='请求路径') + ip: str = Field(description='IP 地址') + country: str | None = Field(default=None, description='国家') + region: str | None = Field(default=None, description='地区') + city: str | None = Field(default=None, description='城市') + user_agent: str = Field(description='用户代理') + os: str | None = Field(default=None, description='操作系统') + browser: str | None = Field(default=None, description='浏览器') + device: str | None = Field(default=None, description='设备') + args: dict[str, Any] | None = Field(default=None, description='请求参数') + status: StatusType = Field(default=StatusType.enable, description='状态') + code: str = Field(description='状态码') + msg: str | None = Field(default=None, description='消息') + cost_time: float = Field(description='耗时') + opera_time: datetime = Field(description='操作时间') class CreateOperaLogParam(OperaLogSchemaBase): - pass + """创建操作日志参数""" class UpdateOperaLogParam(OperaLogSchemaBase): - pass + """更新操作日志参数""" class GetOperaLogDetail(OperaLogSchemaBase): + """操作日志详情""" + model_config = ConfigDict(from_attributes=True) - id: int - created_time: datetime + id: int = Field(description='日志 ID') + created_time: datetime = Field(description='创建时间') diff --git a/backend/app/admin/schema/role.py b/backend/app/admin/schema/role.py index 32c4e21c..29c2d4e7 100644 --- a/backend/app/admin/schema/role.py +++ b/backend/app/admin/schema/role.py @@ -11,35 +11,45 @@ class RoleSchemaBase(SchemaBase): - name: str - status: StatusType = Field(default=StatusType.enable) - remark: str | None = None + """角色基础模型""" + + name: str = Field(description='角色名称') + status: StatusType = Field(default=StatusType.enable, description='状态') + remark: str | None = Field(default=None, description='备注') class CreateRoleParam(RoleSchemaBase): - pass + """创建角色参数""" class UpdateRoleParam(RoleSchemaBase): - pass + """更新角色参数""" class UpdateRoleMenuParam(SchemaBase): - menus: list[int] + """更新角色菜单参数""" + + menus: list[int] = Field(description='菜单 ID 列表') class UpdateRoleRuleParam(SchemaBase): - rules: list[int] + """更新角色规则参数""" + + rules: list[int] = Field(description='数据规则 ID 列表') class GetRoleDetail(RoleSchemaBase): + """角色详情""" + model_config = ConfigDict(from_attributes=True) - id: int - created_time: datetime - updated_time: datetime | None = None + id: int = Field(description='角色 ID') + created_time: datetime = Field(description='创建时间') + updated_time: datetime | None = Field(default=None, description='更新时间') class GetRoleWithRelationDetail(GetRoleDetail): - menus: list[GetMenuDetail | None] = [] - rules: list[GetDataRuleDetail | None] = [] + """角色关联详情""" + + menus: list[GetMenuDetail | None] = Field(default=[], description='菜单详情列表') + rules: list[GetDataRuleDetail | None] = Field(default=[], description='数据规则详情列表') diff --git a/backend/app/admin/schema/token.py b/backend/app/admin/schema/token.py index 9e0b6301..2b25070f 100644 --- a/backend/app/admin/schema/token.py +++ b/backend/app/admin/schema/token.py @@ -2,44 +2,56 @@ # -*- coding: utf-8 -*- from datetime import datetime +from pydantic import Field + from backend.app.admin.schema.user import GetUserInfoDetail from backend.common.enums import StatusType from backend.common.schema import SchemaBase class GetSwaggerToken(SchemaBase): - access_token: str - token_type: str = 'Bearer' - user: GetUserInfoDetail + """Swagger 认证令牌""" + + access_token: str = Field(description='访问令牌') + token_type: str = Field(default='Bearer', description='令牌类型') + user: GetUserInfoDetail = Field(description='用户信息') class AccessTokenBase(SchemaBase): - access_token: str - access_token_expire_time: datetime - session_uuid: str + """访问令牌基础模型""" + + access_token: str = Field(description='访问令牌') + access_token_expire_time: datetime = Field(description='令牌过期时间') + session_uuid: str = Field(description='会话 UUID') class GetNewToken(AccessTokenBase): - pass + """获取新令牌""" class GetLoginToken(AccessTokenBase): - user: GetUserInfoDetail + """获取登录令牌""" + + user: GetUserInfoDetail = Field(description='用户信息') class KickOutToken(SchemaBase): - session_uuid: str + """踢出令牌""" + + session_uuid: str = Field(description='会话 UUID') class GetTokenDetail(SchemaBase): - id: int - session_uuid: str - username: str - nickname: str - ip: str - os: str - browser: str - device: str - status: StatusType - last_login_time: str - expire_time: datetime + """令牌详情""" + + id: int = Field(description='用户 ID') + session_uuid: str = Field(description='会话 UUID') + username: str = Field(description='用户名') + nickname: str = Field(description='昵称') + ip: str = Field(description='IP 地址') + os: str = Field(description='操作系统') + browser: str = Field(description='浏览器') + device: str = Field(description='设备') + status: StatusType = Field(description='状态') + last_login_time: str = Field(description='最后登录时间') + expire_time: datetime = Field(description='过期时间') diff --git a/backend/app/admin/schema/user.py b/backend/app/admin/schema/user.py index 284b127b..7925b24e 100644 --- a/backend/app/admin/schema/user.py +++ b/backend/app/admin/schema/user.py @@ -13,84 +13,106 @@ class AuthSchemaBase(SchemaBase): - username: str - password: str | None + """用户认证基础模型""" + + username: str = Field(description='用户名') + password: str | None = Field(description='密码') class AuthLoginParam(AuthSchemaBase): - captcha: str + """用户登录参数""" + + captcha: str = Field(description='验证码') class RegisterUserParam(AuthSchemaBase): - nickname: str | None = None - email: EmailStr = Field(examples=['user@example.com']) + """用户注册参数""" + + nickname: str | None = Field(default=None, description='昵称') + email: EmailStr = Field(examples=['user@example.com'], description='邮箱') class AddUserParam(AuthSchemaBase): - dept_id: int - roles: list[int] - nickname: str | None = None - email: EmailStr = Field(examples=['user@example.com']) + """添加用户参数""" + + dept_id: int = Field(description='部门 ID') + roles: list[int] = Field(description='角色 ID 列表') + nickname: str | None = Field(default=None, description='昵称') + email: EmailStr = Field(examples=['user@example.com'], description='邮箱') class ResetPasswordParam(SchemaBase): - old_password: str - new_password: str - confirm_password: str + """重置密码参数""" + + old_password: str = Field(description='旧密码') + new_password: str = Field(description='新密码') + confirm_password: str = Field(description='确认密码') class UserInfoSchemaBase(SchemaBase): - dept_id: int | None = None - username: str - nickname: str - email: EmailStr = Field(examples=['user@example.com']) - phone: CustomPhoneNumber | None = None + """用户信息基础模型""" + + dept_id: int | None = Field(default=None, description='部门 ID') + username: str = Field(description='用户名') + nickname: str = Field(description='昵称') + email: EmailStr = Field(examples=['user@example.com'], description='邮箱') + phone: CustomPhoneNumber | None = Field(default=None, description='手机号') class UpdateUserParam(UserInfoSchemaBase): - pass + """更新用户参数""" class UpdateUserRoleParam(SchemaBase): - roles: list[int] + """更新用户角色参数""" + + roles: list[int] = Field(description='角色 ID 列表') class AvatarParam(SchemaBase): + """更新头像参数""" + url: HttpUrl = Field(description='头像 http 地址') class GetUserInfoDetail(UserInfoSchemaBase): + """用户信息详情""" + model_config = ConfigDict(from_attributes=True) - dept_id: int | None = None - id: int - uuid: str - avatar: str | None = None - status: StatusType = Field(default=StatusType.enable) - is_superuser: bool - is_staff: bool - is_multi_login: bool - join_time: datetime = None - last_login_time: datetime | None = None + dept_id: int | None = Field(default=None, description='部门 ID') + id: int = Field(description='用户 ID') + uuid: str = Field(description='用户 UUID') + avatar: str | None = Field(default=None, description='头像') + status: StatusType = Field(default=StatusType.enable, description='状态') + is_superuser: bool = Field(description='是否超级管理员') + is_staff: bool = Field(description='是否管理员') + is_multi_login: bool = Field(description='是否允许多端登录') + join_time: datetime = Field(description='加入时间') + last_login_time: datetime | None = Field(default=None, description='最后登录时间') class GetUserInfoWithRelationDetail(GetUserInfoDetail): + """用户信息关联详情""" + model_config = ConfigDict(from_attributes=True) - dept: GetDeptDetail | None = None - roles: list[GetRoleWithRelationDetail] + dept: GetDeptDetail | None = Field(default=None, description='部门信息') + roles: list[GetRoleWithRelationDetail] = Field(description='角色列表') class GetCurrentUserInfoWithRelationDetail(GetUserInfoWithRelationDetail): + """当前用户信息关联详情""" + model_config = ConfigDict(from_attributes=True) - dept: str | None = None - roles: list[str] + dept: str | None = Field(default=None, description='部门名称') + roles: list[str] = Field(description='角色名称列表') @model_validator(mode='before') @classmethod def handel(cls, data: Any) -> Self: - """处理部门和角色""" + """处理部门和角色数据""" dept = data['dept'] if dept: data['dept'] = dept['name'] diff --git a/backend/app/admin/schema/user_social.py b/backend/app/admin/schema/user_social.py index 464d150f..9e057719 100644 --- a/backend/app/admin/schema/user_social.py +++ b/backend/app/admin/schema/user_social.py @@ -1,21 +1,27 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from pydantic import Field + from backend.common.enums import UserSocialType from backend.common.schema import SchemaBase class UserSocialSchemaBase(SchemaBase): - source: UserSocialType - open_id: str | None = None - uid: str | None = None - union_id: str | None = None - scope: str | None = None - code: str | None = None + """用户社交基础模型""" + + source: UserSocialType = Field(description='社交平台') + open_id: str | None = Field(default=None, description='开放平台 ID') + uid: str | None = Field(default=None, description='用户 ID') + union_id: str | None = Field(default=None, description='开放平台唯一 ID') + scope: str | None = Field(default=None, description='授权范围') + code: str | None = Field(default=None, description='授权码') class CreateUserSocialParam(UserSocialSchemaBase): - user_id: int + """创建用户社交参数""" + + user_id: int = Field(description='用户 ID') class UpdateUserSocialParam(SchemaBase): - pass + """更新用户社交参数""" diff --git a/backend/app/admin/tests/utils/db.py b/backend/app/admin/tests/utils/db.py index 203d3a54..9c03c017 100644 --- a/backend/app/admin/tests/utils/db.py +++ b/backend/app/admin/tests/utils/db.py @@ -1,5 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from typing import AsyncGenerator + +from sqlalchemy.ext.asyncio.session import AsyncSession from backend.database.db import create_async_engine_and_session, create_database_url @@ -8,7 +11,7 @@ _, async_test_db_session = create_async_engine_and_session(TEST_SQLALCHEMY_DATABASE_URL) -async def override_get_db(): +async def override_get_db() -> AsyncGenerator[AsyncSession, None]: """session 生成器""" async with async_test_db_session() as session: yield session diff --git a/backend/app/generator/crud/crud_gen.py b/backend/app/generator/crud/crud_gen.py index 5fe1ae9b..d789da26 100644 --- a/backend/app/generator/crud/crud_gen.py +++ b/backend/app/generator/crud/crud_gen.py @@ -9,8 +9,17 @@ class CRUDGen: + """代码生成 CRUD 类""" + @staticmethod - async def get_all_tables(db: AsyncSession, table_schema: str) -> Sequence[str]: + async def get_all_tables(db: AsyncSession, table_schema: str) -> list[str]: + """ + 获取所有表名 + + :param db: 数据库会话 + :param table_schema: 数据库 schema 名称 + :return: + """ if settings.DATABASE_TYPE == 'mysql': sql = """ SELECT table_name AS table_name FROM information_schema.tables @@ -30,6 +39,13 @@ async def get_all_tables(db: AsyncSession, table_schema: str) -> Sequence[str]: @staticmethod async def get_table(db: AsyncSession, table_name: str) -> Row[tuple]: + """ + 获取表信息 + + :param db: 数据库会话 + :param table_name: 表名 + :return: + """ if settings.DATABASE_TYPE == 'mysql': sql = """ SELECT table_name AS table_name, table_comment AS table_comment FROM information_schema.tables @@ -51,6 +67,14 @@ async def get_table(db: AsyncSession, table_name: str) -> Row[tuple]: @staticmethod async def get_all_columns(db: AsyncSession, table_schema: str, table_name: str) -> Sequence[Row[tuple]]: + """ + 获取所有列信息 + + :param db: 数据库会话 + :param table_schema: 数据库 schema 名称 + :param table_name: 表名 + :return: + """ if settings.DATABASE_TYPE == 'mysql': sql = """ SELECT column_name AS column_name, diff --git a/backend/app/generator/crud/crud_gen_business.py b/backend/app/generator/crud/crud_gen_business.py index a2e95c6b..e93f1e05 100644 --- a/backend/app/generator/crud/crud_gen_business.py +++ b/backend/app/generator/crud/crud_gen_business.py @@ -10,12 +10,14 @@ class CRUDGenBusiness(CRUDPlus[GenBusiness]): + """代码生成业务 CRUD 类""" + async def get(self, db: AsyncSession, pk: int) -> GenBusiness | None: """ 获取代码生成业务表 - :param db: - :param pk: + :param db: 数据库会话 + :param pk: 主键 ID :return: """ return await self.select_model(db, pk) @@ -24,8 +26,8 @@ async def get_by_name(self, db: AsyncSession, name: str) -> GenBusiness | None: """ 通过 name 获取代码生成业务表 - :param db: - :param name: + :param db: 数据库会话 + :param name: 表名 :return: """ return await self.select_model_by_column(db, table_name_en=name) @@ -34,6 +36,7 @@ async def get_all(self, db: AsyncSession) -> Sequence[GenBusiness]: """ 获取所有代码生成业务表 + :param db: 数据库会话 :return: """ return await self.select_models(db) @@ -42,8 +45,8 @@ async def create(self, db: AsyncSession, obj_in: CreateGenBusinessParam) -> None """ 创建代码生成业务表 - :param db: - :param obj_in: + :param db: 数据库会话 + :param obj_in: 创建参数 :return: """ await self.create_model(db, obj_in) @@ -52,9 +55,8 @@ async def update(self, db: AsyncSession, pk: int, obj_in: UpdateGenBusinessParam """ 更新代码生成业务表 - :param db: - :param pk: - :param obj_in: + :param db: 数据库会话 + :param obj_in: 更新参数 :return: """ return await self.update_model(db, pk, obj_in) @@ -63,8 +65,8 @@ async def delete(self, db: AsyncSession, pk: int) -> int: """ 删除代码生成业务表 - :param db: - :param pk: + :param db: 数据库会话 + :param pk: 主键 ID :return: """ return await self.delete_model(db, pk) diff --git a/backend/app/generator/crud/crud_gen_model.py b/backend/app/generator/crud/crud_gen_model.py index b681268c..e1d70324 100644 --- a/backend/app/generator/crud/crud_gen_model.py +++ b/backend/app/generator/crud/crud_gen_model.py @@ -10,10 +10,14 @@ class CRUDGenModel(CRUDPlus[GenModel]): + """代码生成模型 CRUD 类""" + async def get(self, db: AsyncSession, pk: int) -> GenModel | None: """ 获取代码生成模型列 + :param db: 数据库会话 + :param pk: 主键 ID :return: """ return await self.select_model(db, pk) @@ -22,8 +26,8 @@ async def get_all_by_business_id(self, db: AsyncSession, business_id: int) -> Se """ 获取所有代码生成模型列 - :param db: - :param business_id: + :param db: 数据库会话 + :param business_id: 业务 ID :return: """ return await self.select_models_order(db, sort_columns='sort', gen_business_id=business_id) @@ -32,21 +36,21 @@ async def create(self, db: AsyncSession, obj_in: CreateGenModelParam, pd_type: s """ 创建代码生成模型表 - :param db: - :param obj_in: - :param pd_type: + :param db: 数据库会话 + :param obj_in: 创建参数 + :param pd_type: Pydantic 类型 :return: """ await self.create_model(db, obj_in, pd_type=pd_type) async def update(self, db: AsyncSession, pk: int, obj_in: UpdateGenModelParam, pd_type: str | None = None) -> int: """ - 更细代码生成模型表 + 更新代码生成模型表 - :param db: - :param pk: - :param obj_in: - :param pd_type: + :param db: 数据库会话 + :param pk: 主键 ID + :param obj_in: 更新参数 + :param pd_type: Pydantic 类型 :return: """ return await self.update_model(db, pk, obj_in, pd_type=pd_type) @@ -55,8 +59,8 @@ async def delete(self, db: AsyncSession, pk: int) -> int: """ 删除代码生成模型表 - :param db: - :param pk: + :param db: 数据库会话 + :param pk: 主键 ID :return: """ return await self.delete_model(db, pk) diff --git a/backend/app/generator/model/gen_business.py b/backend/app/generator/model/gen_business.py index 15a9dfa5..19fb4b4d 100644 --- a/backend/app/generator/model/gen_business.py +++ b/backend/app/generator/model/gen_business.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from typing import TYPE_CHECKING + from sqlalchemy import String from sqlalchemy.dialects.mysql import LONGTEXT from sqlalchemy.dialects.postgresql import TEXT @@ -7,6 +9,9 @@ from backend.common.model import Base, id_key +if TYPE_CHECKING: + from backend.app.generator.model.gen_model import GenModel + class GenBusiness(Base): """代码生成业务表""" @@ -28,4 +33,4 @@ class GenBusiness(Base): LONGTEXT().with_variant(TEXT, 'postgresql'), default=None, comment='备注' ) # 代码生成业务模型一对多 - gen_model: Mapped[list['GenModel']] = relationship(init=False, back_populates='gen_business') # noqa: F821 + gen_model: Mapped[list['GenModel']] = relationship(init=False, back_populates='gen_business') diff --git a/backend/app/generator/model/gen_model.py b/backend/app/generator/model/gen_model.py index 74905f09..2a592b58 100644 --- a/backend/app/generator/model/gen_model.py +++ b/backend/app/generator/model/gen_model.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import Union +from typing import TYPE_CHECKING, Union from sqlalchemy import ForeignKey, String from sqlalchemy.dialects.mysql import LONGTEXT @@ -9,6 +9,9 @@ from backend.common.model import DataClassBase, id_key +if TYPE_CHECKING: + from backend.app.generator.model.gen_business import GenBusiness + class GenModel(DataClassBase): """代码生成模型表""" @@ -32,4 +35,4 @@ class GenModel(DataClassBase): gen_business_id: Mapped[int] = mapped_column( ForeignKey('sys_gen_business.id', ondelete='CASCADE'), default=0, comment='代码生成业务ID' ) - gen_business: Mapped[Union['GenBusiness', None]] = relationship(init=False, back_populates='gen_model') # noqa: F821 + gen_business: Mapped[Union['GenBusiness', None]] = relationship(init=False, back_populates='gen_model') diff --git a/backend/app/generator/schema/gen.py b/backend/app/generator/schema/gen.py index fce60d38..85e18cb9 100644 --- a/backend/app/generator/schema/gen.py +++ b/backend/app/generator/schema/gen.py @@ -6,6 +6,8 @@ class ImportParam(SchemaBase): + """导入参数""" + app: str = Field(description='应用名称,用于代码生成到指定 app') table_name: str = Field(description='数据库表名') table_schema: str = Field(description='数据库名') diff --git a/backend/app/generator/schema/gen_business.py b/backend/app/generator/schema/gen_business.py index 982c6cdd..5c5f894d 100644 --- a/backend/app/generator/schema/gen_business.py +++ b/backend/app/generator/schema/gen_business.py @@ -9,35 +9,40 @@ class GenBusinessSchemaBase(SchemaBase): - app_name: str - table_name_en: str - table_name_zh: str - table_simple_name_zh: str - table_comment: str | None = None - schema_name: str | None = None - default_datetime_column: bool = Field(default=True) - api_version: str = Field(default='v1') - gen_path: str | None = None - remark: str | None = None + """代码生成业务基础模型""" + + app_name: str = Field(description='应用名称(英文)') + table_name_en: str = Field(description='表名称(英文)') + table_name_zh: str = Field(description='表名称(中文)') + table_simple_name_zh: str = Field(description='表名称(中文简称)') + table_comment: str | None = Field(default=None, description='表描述') + schema_name: str | None = Field(default=None, description='Schema 名称 (默认为英文表名称)') + default_datetime_column: bool = Field(default=True, description='是否存在默认时间列') + api_version: str = Field(default='v1', description='代码生成 api 版本,默认为 v1') + gen_path: str | None = Field(default=None, description='代码生成路径(默认为 app 根路径)') + remark: str | None = Field(default=None, description='备注') @model_validator(mode='after') def check_schema_name(self) -> Self: + """检查并设置 schema 名称""" if self.schema_name is None: self.schema_name = self.table_name_en return self class CreateGenBusinessParam(GenBusinessSchemaBase): - pass + """创建代码生成业务参数""" class UpdateGenBusinessParam(GenBusinessSchemaBase): - pass + """更新代码生成业务参数""" class GetGenBusinessDetail(GenBusinessSchemaBase): + """获取代码生成业务详情""" + model_config = ConfigDict(from_attributes=True) - id: int - created_time: datetime - updated_time: datetime | None = None + id: int = Field(description='主键 ID') + created_time: datetime = Field(description='创建时间') + updated_time: datetime | None = Field(default=None, description='更新时间') diff --git a/backend/app/generator/schema/gen_model.py b/backend/app/generator/schema/gen_model.py index 524f96fa..90e489a5 100644 --- a/backend/app/generator/schema/gen_model.py +++ b/backend/app/generator/schema/gen_model.py @@ -7,32 +7,37 @@ class GenModelSchemaBase(SchemaBase): - name: str - comment: str | None = None - type: str - default: str | None = None - sort: int - length: int - is_pk: bool = Field(default=False) - is_nullable: bool = Field(default=False) - gen_business_id: int | None = Field(ge=1) + """代码生成模型基础模型""" + + name: str = Field(description='列名称') + comment: str | None = Field(default=None, description='列描述') + type: str = Field(description='SQLA 模型列类型') + default: str | None = Field(default=None, description='列默认值') + sort: int = Field(description='列排序') + length: int = Field(description='列长度') + is_pk: bool = Field(default=False, description='是否主键') + is_nullable: bool = Field(default=False, description='是否可为空') + gen_business_id: int | None = Field(ge=1, description='代码生成业务ID') @field_validator('type') @classmethod - def type_update(cls, v): + def type_update(cls, v: str) -> str: + """更新列类型""" return sql_type_to_sqlalchemy(v) class CreateGenModelParam(GenModelSchemaBase): - pass + """创建代码生成模型参数""" class UpdateGenModelParam(GenModelSchemaBase): - pass + """更新代码生成模型参数""" class GetGenModelDetail(GenModelSchemaBase): + """获取代码生成模型详情""" + model_config = ConfigDict(from_attributes=True) - id: int - pd_type: str + id: int = Field(description='主键 ID') + pd_type: str = Field(description='列类型对应的 pydantic 类型') diff --git a/backend/app/generator/service/gen_business_service.py b/backend/app/generator/service/gen_business_service.py index ec8ece11..3d04b873 100644 --- a/backend/app/generator/service/gen_business_service.py +++ b/backend/app/generator/service/gen_business_service.py @@ -10,8 +10,16 @@ class GenBusinessService: + """代码生成业务服务类""" + @staticmethod async def get(*, pk: int) -> GenBusiness: + """ + 获取指定 ID 的业务 + + :param pk: 业务 ID + :return: + """ async with async_db_session() as db: business = await gen_business_dao.get(db, pk) if not business: @@ -20,12 +28,18 @@ async def get(*, pk: int) -> GenBusiness: @staticmethod async def get_all() -> Sequence[GenBusiness]: + """获取所有业务""" async with async_db_session() as db: - businesses = await gen_business_dao.get_all(db) - return businesses + return await gen_business_dao.get_all(db) @staticmethod async def create(*, obj: CreateGenBusinessParam) -> None: + """ + 创建业务 + + :param obj: 创建参数 + :return: + """ async with async_db_session.begin() as db: business = await gen_business_dao.get_by_name(db, obj.table_name_en) if business: @@ -34,15 +48,26 @@ async def create(*, obj: CreateGenBusinessParam) -> None: @staticmethod async def update(*, pk: int, obj: UpdateGenBusinessParam) -> int: + """ + 更新业务 + + :param pk: 业务 ID + :param obj: 更新参数 + :return: + """ async with async_db_session.begin() as db: - count = await gen_business_dao.update(db, pk, obj) - return count + return await gen_business_dao.update(db, pk, obj) @staticmethod async def delete(*, pk: int) -> int: + """ + 删除业务 + + :param pk: 业务 ID + :return: + """ async with async_db_session.begin() as db: - count = await gen_business_dao.delete(db, pk) - return count + return await gen_business_dao.delete(db, pk) gen_business_service: GenBusinessService = GenBusinessService() diff --git a/backend/app/generator/service/gen_model_service.py b/backend/app/generator/service/gen_model_service.py index 3d284f3b..d660720d 100644 --- a/backend/app/generator/service/gen_model_service.py +++ b/backend/app/generator/service/gen_model_service.py @@ -12,50 +12,85 @@ class GenModelService: + """代码生成模型服务类""" + @staticmethod async def get(*, pk: int) -> GenModel: + """ + 获取指定 ID 的模型 + + :param pk: 模型 ID + :return: + """ async with async_db_session() as db: - gen_model = await gen_model_dao.get(db, pk) - return gen_model + model = await gen_model_dao.get(db, pk) + if not model: + raise errors.NotFoundError(msg='代码生成模型不存在') + return model @staticmethod async def get_types() -> list[str]: + """获取所有 MySQL 列类型""" types = GenModelMySQLColumnType.get_member_keys() types.sort() return types @staticmethod async def get_by_business(*, business_id: int) -> Sequence[GenModel]: + """ + 获取指定业务的所有模型 + + :param business_id: 业务 ID + :return: + """ async with async_db_session() as db: - gen_models = await gen_model_dao.get_all_by_business_id(db, business_id) - return gen_models + return await gen_model_dao.get_all_by_business_id(db, business_id) @staticmethod async def create(*, obj: CreateGenModelParam) -> None: + """ + 创建模型 + + :param obj: 创建参数 + :return: + """ async with async_db_session.begin() as db: gen_models = await gen_model_dao.get_all_by_business_id(db, obj.gen_business_id) if obj.name in [gen_model.name for gen_model in gen_models]: raise errors.ForbiddenError(msg='禁止添加相同列到同一模型表') + pd_type = sql_type_to_pydantic(obj.type) await gen_model_dao.create(db, obj, pd_type=pd_type) @staticmethod async def update(*, pk: int, obj: UpdateGenModelParam) -> int: + """ + 更新模型 + + :param pk: 模型 ID + :param obj: 更新参数 + :return: + """ async with async_db_session.begin() as db: model = await gen_model_dao.get(db, pk) if obj.name != model.name: gen_models = await gen_model_dao.get_all_by_business_id(db, obj.gen_business_id) if obj.name in [gen_model.name for gen_model in gen_models]: raise errors.ForbiddenError(msg='模型列名已存在') + pd_type = sql_type_to_pydantic(obj.type) - count = await gen_model_dao.update(db, pk, obj, pd_type=pd_type) - return count + return await gen_model_dao.update(db, pk, obj, pd_type=pd_type) @staticmethod async def delete(*, pk: int) -> int: + """ + 删除模型 + + :param pk: 模型 ID + :return: + """ async with async_db_session.begin() as db: - count = await gen_model_dao.delete(db, pk) - return count + return await gen_model_dao.delete(db, pk) gen_model_service: GenModelService = GenModelService() diff --git a/backend/app/generator/service/gen_service.py b/backend/app/generator/service/gen_service.py index 2a712fd6..19a480bb 100644 --- a/backend/app/generator/service/gen_service.py +++ b/backend/app/generator/service/gen_service.py @@ -5,7 +5,6 @@ import zipfile from pathlib import Path -from typing import Sequence import aiofiles @@ -20,27 +19,43 @@ from backend.app.generator.schema.gen_model import CreateGenModelParam from backend.app.generator.service.gen_model_service import gen_model_service from backend.common.exception import errors -from backend.core.path_conf import BasePath +from backend.core.path_conf import BASE_PATH from backend.database.db import async_db_session from backend.utils.gen_template import gen_template from backend.utils.type_conversion import sql_type_to_pydantic class GenService: + """代码生成服务类""" + @staticmethod - async def get_tables(*, table_schema: str) -> Sequence[str]: + async def get_tables(*, table_schema: str) -> list[str]: + """ + 获取指定 schema 下的所有表名 + + :param table_schema: 数据库 schema 名称 + :return: + """ async with async_db_session() as db: return await gen_dao.get_all_tables(db, table_schema) @staticmethod async def import_business_and_model(*, obj: ImportParam) -> None: + """ + 导入业务和模型数据 + + :param obj: 导入参数对象 + :return: + """ async with async_db_session.begin() as db: table_info = await gen_dao.get_table(db, obj.table_name) if not table_info: raise errors.NotFoundError(msg='数据库表不存在') + business_info = await gen_business_dao.get_by_name(db, obj.table_name) if business_info: raise errors.ForbiddenError(msg='已存在相同数据库表业务') + table_name = table_info[0] business_data = { 'app_name': obj.app, @@ -52,6 +67,7 @@ async def import_business_and_model(*, obj: ImportParam) -> None: new_business = GenBusiness(**CreateGenBusinessParam(**business_data).model_dump()) db.add(new_business) await db.flush() + column_info = await gen_dao.get_all_columns(db, obj.table_schema, table_name) for column in column_info: column_type = column[-1].split('(')[0].upper() @@ -70,20 +86,34 @@ async def import_business_and_model(*, obj: ImportParam) -> None: @staticmethod async def render_tpl_code(*, business: GenBusiness) -> dict[str, str]: + """ + 渲染模板代码 + + :param business: 业务对象 + :return: + """ gen_models = await gen_model_service.get_by_business(business_id=business.id) if not gen_models: raise errors.NotFoundError(msg='代码生成模型表为空') + gen_vars = gen_template.get_vars(business, gen_models) - tpl_code_map = {} - for tpl_path in gen_template.get_template_paths(): - tpl_code_map[tpl_path] = await gen_template.get_template(tpl_path).render_async(**gen_vars) - return tpl_code_map + return { + tpl_path: await gen_template.get_template(tpl_path).render_async(**gen_vars) + for tpl_path in gen_template.get_template_paths() + } async def preview(self, *, pk: int) -> dict[str, bytes]: + """ + 预览生成的代码 + + :param pk: 业务 ID + :return: + """ async with async_db_session() as db: business = await gen_business_dao.get(db, pk) if not business: raise errors.NotFoundError(msg='业务不存在') + tpl_code_map = await self.render_tpl_code(business=business) return { tpl.replace('.jinja', '.py') if tpl.startswith('py') else ...: code.encode('utf-8') @@ -92,42 +122,50 @@ async def preview(self, *, pk: int) -> dict[str, bytes]: @staticmethod async def get_generate_path(*, pk: int) -> list[str]: + """ + 获取代码生成路径 + + :param pk: 业务 ID + :return: + """ async with async_db_session() as db: business = await gen_business_dao.get(db, pk) if not business: raise errors.NotFoundError(msg='业务不存在') - gen_path = business.gen_path - if not gen_path: - # 伪加密路径 - gen_path = 'current-backend-app-path' + + gen_path = business.gen_path or 'fba-backend-app-path' target_files = gen_template.get_code_gen_paths(business) - code_gen_paths = [] - for target_file in target_files: - code_gen_paths.append(os.path.join(gen_path, *target_file.split('/')[1:])) - return code_gen_paths + return [os.path.join(gen_path, *target_file.split('/')[1:]) for target_file in target_files] async def generate(self, *, pk: int) -> None: + """ + 生成代码文件 + + :param pk: 业务 ID + :return: + """ async with async_db_session() as db: business = await gen_business_dao.get(db, pk) if not business: raise errors.NotFoundError(msg='业务不存在') + tpl_code_map = await self.render_tpl_code(business=business) - gen_path = business.gen_path - if not gen_path: - gen_path = os.path.join(BasePath, 'app') + gen_path = business.gen_path or os.path.join(BASE_PATH, 'app') + for tpl_path, code in tpl_code_map.items(): code_filepath = os.path.join( gen_path, *gen_template.get_code_gen_path(tpl_path, business).split('/')[1:], ) code_folder = Path(str(code_filepath)).parent - if not code_folder.exists(): - code_folder.mkdir(parents=True, exist_ok=True) + code_folder.mkdir(parents=True, exist_ok=True) + # 写入 init 文件 init_filepath = code_folder.joinpath('__init__.py') if not init_filepath.exists(): async with aiofiles.open(init_filepath, 'w', encoding='utf-8') as f: await f.write(gen_template.init_content) + if 'api' in str(code_folder): # api __init__.py api_init_filepath = code_folder.parent.joinpath('__init__.py') @@ -136,12 +174,14 @@ async def generate(self, *, pk: int) -> None: await f.write(gen_template.init_content) # app __init__.py app_init_filepath = api_init_filepath.parent.joinpath('__init__.py') - if not app_init_filepath: + if not app_init_filepath.exists(): async with aiofiles.open(app_init_filepath, 'w', encoding='utf-8') as f: await f.write(gen_template.init_content) - # 写入代码文件呢 + + # 写入代码文件 async with aiofiles.open(code_filepath, 'w', encoding='utf-8') as f: await f.write(code) + # model init 文件补充 if code_folder.name == 'model': async with aiofiles.open(init_filepath, 'a', encoding='utf-8') as f: @@ -151,33 +191,42 @@ async def generate(self, *, pk: int) -> None: ) async def download(self, *, pk: int) -> io.BytesIO: + """ + 下载生成的代码 + + :param pk: 业务 ID + :return: + """ async with async_db_session() as db: business = await gen_business_dao.get(db, pk) if not business: raise errors.NotFoundError(msg='业务不存在') + bio = io.BytesIO() - zf = zipfile.ZipFile(bio, 'w') - tpl_code_map = await self.render_tpl_code(business=business) - for tpl_path, code in tpl_code_map.items(): - # 写入代码文件 - new_code_path = gen_template.get_code_gen_path(tpl_path, business) - zf.writestr(new_code_path, code) - # 写入 init 文件 - init_filepath = os.path.join(*new_code_path.split('/')[:-1], '__init__.py') - if 'model' not in new_code_path.split('/'): - zf.writestr(init_filepath, gen_template.init_content) - else: - zf.writestr( - init_filepath, - f'{gen_template.init_content}' - f'from backend.app.{business.app_name}.model.{business.table_name_en} ' - f'import {to_pascal(business.table_name_en)}\n', - ) - if 'api' in new_code_path: - # api __init__.py - api_init_filepath = os.path.join(*new_code_path.split('/')[:-2], '__init__.py') - zf.writestr(api_init_filepath, gen_template.init_content) - zf.close() + with zipfile.ZipFile(bio, 'w') as zf: + tpl_code_map = await self.render_tpl_code(business=business) + for tpl_path, code in tpl_code_map.items(): + # 写入代码文件 + new_code_path = gen_template.get_code_gen_path(tpl_path, business) + zf.writestr(new_code_path, code) + + # 写入 init 文件 + init_filepath = os.path.join(*new_code_path.split('/')[:-1], '__init__.py') + if 'model' not in new_code_path.split('/'): + zf.writestr(init_filepath, gen_template.init_content) + else: + zf.writestr( + init_filepath, + f'{gen_template.init_content}' + f'from backend.app.{business.app_name}.model.{business.table_name_en} ' + f'import {to_pascal(business.table_name_en)}\n', + ) + + if 'api' in new_code_path: + # api __init__.py + api_init_filepath = os.path.join(*new_code_path.split('/')[:-2], '__init__.py') + zf.writestr(api_init_filepath, gen_template.init_content) + bio.seek(0) return bio diff --git a/backend/app/task/api/router.py b/backend/app/task/api/router.py index 18cb9ba6..1b5bbf6d 100644 --- a/backend/app/task/api/router.py +++ b/backend/app/task/api/router.py @@ -5,6 +5,6 @@ from backend.app.task.api.v1.task import router as task_router from backend.core.conf import settings -v1 = APIRouter(prefix=settings.FASTAPI_API_V1_PATH) +v1 = APIRouter(prefix=settings.FASTAPI_API_V1_PATH, tags=['任务']) -v1.include_router(task_router, prefix='/tasks', tags=['任务']) +v1.include_router(task_router, prefix='/tasks') diff --git a/backend/app/task/api/v1/task.py b/backend/app/task/api/v1/task.py index de27ef7d..5f66c5ec 100644 --- a/backend/app/task/api/v1/task.py +++ b/backend/app/task/api/v1/task.py @@ -27,7 +27,7 @@ async def get_all_tasks() -> ResponseSchemaModel[list[str]]: description='此接口被视为作废,建议使用 flower 查看任务详情', dependencies=[DependsJwtAuth], ) -async def get_task_detail(tid: Annotated[str, Path(description='任务ID')]) -> ResponseSchemaModel[TaskResult]: +async def get_task_detail(tid: Annotated[str, Path(description='任务 UUID')]) -> ResponseSchemaModel[TaskResult]: status = task_service.get_detail(tid=tid) return response_base.success(data=status) @@ -40,7 +40,7 @@ async def get_task_detail(tid: Annotated[str, Path(description='任务ID')]) -> DependsRBAC, ], ) -async def revoke_task(tid: Annotated[str, Path(description='任务ID')]) -> ResponseModel: +async def revoke_task(tid: Annotated[str, Path(description='任务 UUID')]) -> ResponseModel: task_service.revoke(tid=tid) return response_base.success() diff --git a/backend/app/task/celery.py b/backend/app/task/celery.py index 8750d0d4..4442a873 100644 --- a/backend/app/task/celery.py +++ b/backend/app/task/celery.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from typing import Any + import celery import celery_aio_pool @@ -9,61 +11,64 @@ __all__ = ['celery_app'] -def init_celery() -> celery.Celery: - """初始化 celery 应用""" - - # TODO: Update this work if celery version >= 6.0.0 - # https://github.com/fastapi-practices/fastapi_best_architecture/issues/321 - # https://github.com/celery/celery/issues/7874 - celery.app.trace.build_tracer = celery_aio_pool.build_async_tracer - celery.app.trace.reset_worker_optimizations() - - # Celery Schedule Tasks - # https://docs.celeryq.dev/en/stable/userguide/periodic-tasks.html - beat_schedule = task_settings.CELERY_SCHEDULE - - # Celery Config - # https://docs.celeryq.dev/en/stable/userguide/configuration.html - broker_url = ( - ( +def get_broker_url() -> str: + """获取消息代理 URL""" + if task_settings.CELERY_BROKER == 'redis': + return ( f'redis://:{settings.REDIS_PASSWORD}@{settings.REDIS_HOST}:' f'{settings.REDIS_PORT}/{task_settings.CELERY_BROKER_REDIS_DATABASE}' ) - if task_settings.CELERY_BROKER == 'redis' - else ( - f'amqp://{task_settings.RABBITMQ_USERNAME}:{task_settings.RABBITMQ_PASSWORD}@' - f'{task_settings.RABBITMQ_HOST}:{task_settings.RABBITMQ_PORT}' - ) + return ( + f'amqp://{task_settings.RABBITMQ_USERNAME}:{task_settings.RABBITMQ_PASSWORD}@' + f'{task_settings.RABBITMQ_HOST}:{task_settings.RABBITMQ_PORT}' ) - result_backend = ( + + +def get_result_backend() -> str: + """获取结果后端 URL""" + return ( f'redis://:{settings.REDIS_PASSWORD}@{settings.REDIS_HOST}:' f'{settings.REDIS_PORT}/{task_settings.CELERY_BACKEND_REDIS_DATABASE}' ) - result_backend_transport_options = { - 'global_keyprefix': f'{task_settings.CELERY_BACKEND_REDIS_PREFIX}', + + +def get_result_backend_transport_options() -> dict[str, Any]: + """获取结果后端传输选项""" + return { + 'global_keyprefix': task_settings.CELERY_BACKEND_REDIS_PREFIX, 'retry_policy': { 'timeout': task_settings.CELERY_BACKEND_REDIS_TIMEOUT, }, } + +def init_celery() -> celery.Celery: + """初始化 Celery 应用""" + + # TODO: Update this work if celery version >= 6.0.0 + # https://github.com/fastapi-practices/fastapi_best_architecture/issues/321 + # https://github.com/celery/celery/issues/7874 + celery.app.trace.build_tracer = celery_aio_pool.build_async_tracer + celery.app.trace.reset_worker_optimizations() + app = celery.Celery( 'fba_celery', enable_utc=False, timezone=settings.DATETIME_TIMEZONE, - beat_schedule=beat_schedule, - broker_url=broker_url, + beat_schedule=task_settings.CELERY_SCHEDULE, + broker_url=get_broker_url(), broker_connection_retry_on_startup=True, - result_backend=result_backend, - result_backend_transport_options=result_backend_transport_options, + result_backend=get_result_backend(), + result_backend_transport_options=get_result_backend_transport_options(), task_cls='app.task.celery_task.base:TaskBase', task_track_started=True, ) - # Load task modules + # 自动发现任务模块 app.autodiscover_tasks(task_settings.CELERY_TASK_PACKAGES) return app -# 创建 celery 实例 +# 创建 Celery 实例 celery_app: celery.Celery = init_celery() diff --git a/backend/app/task/celery_task/base.py b/backend/app/task/celery_task/base.py index fcca1e03..4e21f27c 100644 --- a/backend/app/task/celery_task/base.py +++ b/backend/app/task/celery_task/base.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from typing import Any from celery import Task from sqlalchemy.exc import SQLAlchemyError @@ -9,16 +10,37 @@ class TaskBase(Task): - """任务基类""" + """Celery 任务基类""" autoretry_for = (SQLAlchemyError,) max_retries = task_settings.CELERY_TASK_MAX_RETRIES - async def before_start(self, task_id, args, kwargs): + async def before_start(self, task_id: str, args, kwargs) -> None: + """ + 任务开始前执行钩子 + + :param task_id: 任务 ID + :return: + """ await task_notification(msg=f'任务 {task_id} 开始执行') - async def on_success(self, retval, task_id, args, kwargs): + async def on_success(self, retval: Any, task_id: str, args, kwargs) -> None: + """ + 任务成功后执行钩子 + + :param retval: 任务返回值 + :param task_id: 任务 ID + :return: + """ await task_notification(msg=f'任务 {task_id} 执行成功') - async def on_failure(self, exc, task_id, args, kwargs, einfo): + async def on_failure(self, exc: Exception, task_id: str, args, kwargs, einfo) -> None: + """ + 任务失败后执行钩子 + + :param exc: 异常对象 + :param task_id: 任务 ID + :param einfo: 异常信息 + :return: + """ await task_notification(msg=f'任务 {task_id} 执行失败') diff --git a/backend/app/task/celery_task/tasks.py b/backend/app/task/celery_task/tasks.py index 231fd0d6..08d79e29 100644 --- a/backend/app/task/celery_task/tasks.py +++ b/backend/app/task/celery_task/tasks.py @@ -7,5 +7,6 @@ @celery_app.task(name='task_demo_async') async def task_demo_async() -> str: + """异步示例任务,模拟耗时操作""" await sleep(20) return 'test async' diff --git a/backend/app/task/conf.py b/backend/app/task/conf.py index 73fb98ae..d6210b65 100644 --- a/backend/app/task/conf.py +++ b/backend/app/task/conf.py @@ -1,35 +1,32 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- from functools import lru_cache -from typing import Literal +from typing import Any, Literal from celery.schedules import crontab from pydantic import model_validator from pydantic_settings import BaseSettings, SettingsConfigDict -from backend.core.path_conf import BasePath +from backend.core.path_conf import BASE_PATH class TaskSettings(BaseSettings): - """Task Settings""" + """Celery 任务配置""" - model_config = SettingsConfigDict(env_file=f'{BasePath}/.env', env_file_encoding='utf-8', extra='ignore') + model_config = SettingsConfigDict(env_file=f'{BASE_PATH}/.env', env_file_encoding='utf-8', extra='ignore') - # Env Config - ENVIRONMENT: Literal['dev', 'pro'] - - # Env Celery - CELERY_BROKER_REDIS_DATABASE: int # 仅在 dev 模式时生效 + # .env Redis 配置 + CELERY_BROKER_REDIS_DATABASE: int CELERY_BACKEND_REDIS_DATABASE: int - # Env Rabbitmq + # .env RabbitMQ 配置 # docker run -d --hostname fba-mq --name fba-mq -p 5672:5672 -p 15672:15672 rabbitmq:latest RABBITMQ_HOST: str RABBITMQ_PORT: int RABBITMQ_USERNAME: str RABBITMQ_PASSWORD: str - # Celery + # Celery 基础配置 CELERY_BROKER: Literal['rabbitmq', 'redis'] = 'redis' CELERY_BACKEND_REDIS_PREFIX: str = 'fba:celery:' CELERY_BACKEND_REDIS_TIMEOUT: int = 5 @@ -38,7 +35,9 @@ class TaskSettings(BaseSettings): 'app.task.celery_task.db_log', ] CELERY_TASK_MAX_RETRIES: int = 5 - CELERY_SCHEDULE: dict = { + + # Celery 定时任务配置 + CELERY_SCHEDULE: dict[str, dict[str, Any]] = { 'exec-every-10-seconds': { 'task': 'task_demo_async', 'schedule': 10, @@ -55,7 +54,8 @@ class TaskSettings(BaseSettings): @model_validator(mode='before') @classmethod - def validate_celery_broker(cls, values): + def validate_celery_broker(cls, values: dict[str, Any]) -> dict[str, Any]: + """生产环境强制使用 RabbitMQ 作为消息代理""" if values['ENVIRONMENT'] == 'pro': values['CELERY_BROKER'] = 'rabbitmq' return values @@ -63,7 +63,7 @@ def validate_celery_broker(cls, values): @lru_cache def get_task_settings() -> TaskSettings: - """获取 task 配置""" + """获取 Celery 任务配置""" return TaskSettings() diff --git a/backend/app/task/schema/task.py b/backend/app/task/schema/task.py index 850e5860..af3a268d 100644 --- a/backend/app/task/schema/task.py +++ b/backend/app/task/schema/task.py @@ -1,23 +1,29 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from typing import Any + from pydantic import Field from backend.common.schema import SchemaBase class RunParam(SchemaBase): + """任务运行参数""" + name: str = Field(description='任务名称') - args: list | None = Field(default=None, description='任务函数位置参数') - kwargs: dict | None = Field(default=None, description='任务函数关键字参数') + args: list[Any] | None = Field(default=None, description='任务函数位置参数') + kwargs: dict[str, Any] | None = Field(default=None, description='任务函数关键字参数') class TaskResult(SchemaBase): - result: str - traceback: str - status: str - name: str - args: list | None - kwargs: dict | None - worker: str - retries: int | None - queue: str | None + """任务执行结果""" + + result: str = Field(description='任务执行结果') + traceback: str = Field(description='错误堆栈信息') + status: str = Field(description='任务状态') + name: str = Field(description='任务名称') + args: list[Any] | None = Field(default=None, description='任务函数位置参数') + kwargs: dict[str, Any] | None = Field(default=None, description='任务函数关键字参数') + worker: str = Field(description='执行任务的 worker') + retries: int | None = Field(default=None, description='重试次数') + queue: str | None = Field(default=None, description='任务队列') diff --git a/backend/app/task/service/task_service.py b/backend/app/task/service/task_service.py index a9f74054..5aeebfc4 100644 --- a/backend/app/task/service/task_service.py +++ b/backend/app/task/service/task_service.py @@ -13,14 +13,21 @@ class TaskService: @staticmethod async def get_list() -> list[str]: + """获取所有已注册的 Celery 任务列表""" registered_tasks = await run_in_threadpool(celery_app.control.inspect().registered) if not registered_tasks: - raise errors.ForbiddenError(msg='celery 服务未启动') + raise errors.ForbiddenError(msg='Celery 服务未启动') tasks = list(registered_tasks.values())[0] return tasks @staticmethod def get_detail(*, tid: str) -> TaskResult: + """ + 获取指定任务的详细信息 + + :param tid: 任务 UUID + :return: + """ try: result = AsyncResult(id=tid, app=celery_app) except NotRegistered: @@ -38,7 +45,13 @@ def get_detail(*, tid: str) -> TaskResult: ) @staticmethod - def revoke(*, tid: str): + def revoke(*, tid: str) -> None: + """ + 撤销指定的任务 + + :param tid: 任务 UUID + :return: + """ try: result = AsyncResult(id=tid, app=celery_app) except NotRegistered: @@ -47,6 +60,12 @@ def revoke(*, tid: str): @staticmethod def run(*, obj: RunParam) -> str: + """ + 运行指定的任务 + + :param obj: 任务运行参数 + :return: + """ task: AsyncResult = celery_app.send_task(name=obj.name, args=obj.args, kwargs=obj.kwargs) return task.task_id diff --git a/backend/core/conf.py b/backend/core/conf.py index 28b70e92..fe45d106 100644 --- a/backend/core/conf.py +++ b/backend/core/conf.py @@ -6,97 +6,100 @@ from pydantic import model_validator from pydantic_settings import BaseSettings, SettingsConfigDict -from backend.core.path_conf import BasePath +from backend.core.path_conf import BASE_PATH class Settings(BaseSettings): - """Global Settings""" + """全局配置""" - model_config = SettingsConfigDict(env_file=f'{BasePath}/.env', env_file_encoding='utf-8', extra='ignore') + model_config = SettingsConfigDict( + env_file=f'{BASE_PATH}/.env', + env_file_encoding='utf-8', + extra='ignore', + case_sensitive=True, + ) - # Env Config + # .env 环境配置 ENVIRONMENT: Literal['dev', 'pro'] - # Env Database Type + # .env 数据库配置 DATABASE_TYPE: Literal['mysql', 'postgresql'] - - # Env Database DATABASE_HOST: str DATABASE_PORT: int DATABASE_USER: str DATABASE_PASSWORD: str - # Env Redis + # .env Redis 配置 REDIS_HOST: str REDIS_PORT: int REDIS_PASSWORD: str REDIS_DATABASE: int - # Env Token + # .env Token 配置 TOKEN_SECRET_KEY: str # 密钥 secrets.token_urlsafe(32) - # Env Opera Log + # .env 操作日志加密密钥 OPERA_LOG_ENCRYPT_SECRET_KEY: str # 密钥 os.urandom(32), 需使用 bytes.hex() 方法转换为 str - # FastAPI - FASTAPI_API_V1_PATH: str = '/api/v1' - FASTAPI_TITLE: str = 'FastAPI' - FASTAPI_VERSION: str = '0.0.1' - FASTAPI_DESCRIPTION: str = 'FastAPI Best Architecture' - FASTAPI_DOCS_URL: str = '/docs' - FASTAPI_REDOC_URL: str = '/redoc' - FASTAPI_OPENAPI_URL: str | None = '/openapi' - FASTAPI_STATIC_FILES: bool = True - - # Upload - UPLOAD_READ_SIZE: int = 1024 # 上传文件时分片读取大小 - UPLOAD_IMAGE_EXT_INCLUDE: list[str] = ['jpg', 'jpeg', 'png', 'gif', 'webp'] - UPLOAD_IMAGE_SIZE_MAX: int = 1024 * 1024 * 5 - UPLOAD_VIDEO_EXT_INCLUDE: list[str] = ['mp4', 'mov', 'avi', 'flv'] - UPLOAD_VIDEO_SIZE_MAX: int = 1024 * 1024 * 20 - - # Database + # 数据库配置 DATABASE_ECHO: bool = False DATABASE_POOL_ECHO: bool = False DATABASE_SCHEMA: str = 'fba' DATABASE_CHARSET: str = 'utf8mb4' - # Redis + # Redis 配置 REDIS_TIMEOUT: int = 5 - # Socketio - WS_NO_AUTH_MARKER: str = 'internal' - - # Token - TOKEN_ALGORITHM: str = 'HS256' # 算法 - TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 1 # 过期时间,单位:秒 - TOKEN_REFRESH_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # refresh token 过期时间,单位:秒 + # Token 配置 + TOKEN_ALGORITHM: str = 'HS256' + TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 # 1 天 + TOKEN_REFRESH_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # 7 天 TOKEN_REDIS_PREFIX: str = 'fba:token' TOKEN_EXTRA_INFO_REDIS_PREFIX: str = 'fba:token_extra_info' TOKEN_ONLINE_REDIS_PREFIX: str = 'fba:token_online' TOKEN_REFRESH_REDIS_PREFIX: str = 'fba:refresh_token' - TOKEN_REQUEST_PATH_EXCLUDE: list[str] = [ # JWT / RBAC 白名单 - f'{FASTAPI_API_V1_PATH}/auth/login', + TOKEN_REQUEST_PATH_EXCLUDE: list[str] = [ # JWT / RBAC 路由白名单 + '/api/v1/auth/login', ] - # JWT + # JWT 配置 JWT_USER_REDIS_PREFIX: str = 'fba:user' - JWT_USER_REDIS_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 + JWT_USER_REDIS_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # 7 天 - # RBAC + # RBAC 配置 RBAC_ROLE_MENU_MODE: bool = False RBAC_ROLE_MENU_EXCLUDE: list[str] = [ 'sys:monitor:redis', 'sys:monitor:server', ] - # Cookies + # Cookie 配置 COOKIE_REFRESH_TOKEN_KEY: str = 'fba_refresh_token' - COOKIE_REFRESH_TOKEN_EXPIRE_SECONDS: int = TOKEN_REFRESH_EXPIRE_SECONDS + COOKIE_REFRESH_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # 7 天 + + # FastAPI 配置 + FASTAPI_API_V1_PATH: str = '/api/v1' + FASTAPI_TITLE: str = 'FastAPI' + FASTAPI_VERSION: str = '0.0.1' + FASTAPI_DESCRIPTION: str = 'FastAPI Best Architecture' + FASTAPI_DOCS_URL: str = '/docs' + FASTAPI_REDOC_URL: str = '/redoc' + FASTAPI_OPENAPI_URL: str | None = '/openapi' + FASTAPI_STATIC_FILES: bool = True + + # Socketio 配置 + WS_NO_AUTH_MARKER: str = 'internal' + + # 文件上传配置 + UPLOAD_READ_SIZE: int = 1024 + UPLOAD_IMAGE_EXT_INCLUDE: list[str] = ['jpg', 'jpeg', 'png', 'gif', 'webp'] + UPLOAD_IMAGE_SIZE_MAX: int = 5 * 1024 * 1024 # 5 MB + UPLOAD_VIDEO_EXT_INCLUDE: list[str] = ['mp4', 'mov', 'avi', 'flv'] + UPLOAD_VIDEO_SIZE_MAX: int = 20 * 1024 * 1024 # 20 MB - # Log + # 日志配置 LOG_CID_DEFAULT_VALUE: str = '-' - LOG_CID_UUID_LENGTH: int = 32 # must <= 32 + LOG_CID_UUID_LENGTH: int = 32 # 日志 correlation_id 长度,必须小于等于 32 LOG_STD_LEVEL: str = 'INFO' LOG_ACCESS_FILE_LEVEL: str = 'INFO' LOG_ERROR_FILE_LEVEL: str = 'ERROR' @@ -111,51 +114,51 @@ class Settings(BaseSettings): LOG_ACCESS_FILENAME: str = 'fba_access.log' LOG_ERROR_FILENAME: str = 'fba_error.log' - # Middleware + # 中间件配置 MIDDLEWARE_CORS: bool = True MIDDLEWARE_ACCESS: bool = True - # Trace ID + # 追踪 ID 配置 TRACE_ID_REQUEST_HEADER_KEY: str = 'X-Request-ID' - # CORS - CORS_ALLOWED_ORIGINS: list[str] = [ + # CORS 配置 + CORS_ALLOWED_ORIGINS: list[str] = [ # 末尾不带斜杠 'http://127.0.0.1:8000', - 'http://localhost:5173', # 前端地址,末尾不要带 '/' + 'http://localhost:5173', ] CORS_EXPOSE_HEADERS: list[str] = [ - TRACE_ID_REQUEST_HEADER_KEY, + 'X-Request-ID', ] - # DateTime + # 时间配置 DATETIME_TIMEZONE: str = 'Asia/Shanghai' DATETIME_FORMAT: str = '%Y-%m-%d %H:%M:%S' - # Request limiter + # 请求限制配置 REQUEST_LIMITER_REDIS_PREFIX: str = 'fba:limiter' - # Demo mode (Only GET, OPTIONS requests are allowed) + # 演示模式配置 DEMO_MODE: bool = False DEMO_MODE_EXCLUDE: set[tuple[str, str]] = { - ('POST', f'{FASTAPI_API_V1_PATH}/auth/login'), - ('POST', f'{FASTAPI_API_V1_PATH}/auth/logout'), - ('GET', f'{FASTAPI_API_V1_PATH}/auth/captcha'), + ('POST', '/api/v1/auth/login'), + ('POST', '/api/v1/auth/logout'), + ('GET', '/api/v1/auth/captcha'), } - # Ip location + # IP 定位配置 IP_LOCATION_PARSE: Literal['online', 'offline', 'false'] = 'offline' IP_LOCATION_REDIS_PREFIX: str = 'fba:ip:location' - IP_LOCATION_EXPIRE_SECONDS: int = 60 * 60 * 24 * 1 # 过期时间,单位:秒 + IP_LOCATION_EXPIRE_SECONDS: int = 60 * 60 * 24 # 1 天 - # Opera log + # 操作日志配置 OPERA_LOG_PATH_EXCLUDE: list[str] = [ '/favicon.ico', - FASTAPI_DOCS_URL, - FASTAPI_REDOC_URL, - FASTAPI_OPENAPI_URL, - f'{FASTAPI_API_V1_PATH}/auth/login/swagger', - f'{FASTAPI_API_V1_PATH}/oauth2/github/callback', - f'{FASTAPI_API_V1_PATH}/oauth2/linux-do/callback', + '/docs', + '/redoc', + '/openapi', + '/api/v1/auth/login/swagger', + '/api/v1/oauth2/github/callback', + '/api/v1/oauth2/linux-do/callback', ] OPERA_LOG_ENCRYPT_TYPE: int = 1 # 0: AES (性能损耗); 1: md5; 2: ItsDangerous; 3: 不加密, others: 替换为 ****** OPERA_LOG_ENCRYPT_KEY_INCLUDE: list[str] = [ # 将加密接口入参参数对应的值 @@ -165,10 +168,8 @@ class Settings(BaseSettings): 'confirm_password', ] - # Data permission - DATA_PERMISSION_MODELS: dict[ - str, str - ] = { # 允许进行数据过滤的 SQLA 模型,它必须以模块字符串的方式定义(它应该只用于前台数据,这里只是为了演示) + # 数据权限配置 + DATA_PERMISSION_MODELS: dict[str, str] = { # 允许进行数据过滤的 SQLA 模型,它必须以模块字符串的方式定义 'Api': 'backend.plugin.casbin.model.Api', } DATA_PERMISSION_COLUMN_EXCLUDE: list[str] = [ # 排除允许进行数据过滤的 SQLA 模型列 @@ -178,14 +179,15 @@ class Settings(BaseSettings): 'updated_time', ] - # Plugin + # 插件配置 PLUGIN_PIP_CHINA: bool = True PLUGIN_PIP_INDEX_URL: str = 'https://mirrors.aliyun.com/pypi/simple/' @model_validator(mode='before') @classmethod def check_env(cls, values: Any) -> Any: - if values['ENVIRONMENT'] == 'pro': + """生产环境下禁用 OpenAPI 文档和静态文件服务""" + if values.get('ENVIRONMENT') == 'pro': values['FASTAPI_OPENAPI_URL'] = None values['FASTAPI_STATIC_FILES'] = False return values @@ -193,9 +195,9 @@ def check_env(cls, values: Any) -> Any: @lru_cache def get_settings() -> Settings: - """获取全局配置""" + """获取全局配置单例""" return Settings() -# 创建配置实例 +# 创建全局配置实例 settings = get_settings() diff --git a/backend/core/path_conf.py b/backend/core/path_conf.py index edbc271a..903e1340 100644 --- a/backend/core/path_conf.py +++ b/backend/core/path_conf.py @@ -1,30 +1,27 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -import os - from pathlib import Path -# 获取项目根目录 -# 或使用绝对路径,指到backend目录为止,例如windows:BasePath = D:\git_project\fastapi_mysql\backend -BasePath = Path(__file__).resolve().parent.parent +# 项目根目录 +BASE_PATH = Path(__file__).resolve().parent.parent # alembic 迁移文件存放路径 -ALEMBIC_VERSION_DIR = os.path.join(BasePath, 'alembic', 'versions') +ALEMBIC_VERSION_DIR = BASE_PATH / 'alembic' / 'versions' # 日志文件路径 -LOG_DIR = os.path.join(BasePath, 'log') - -# 离线 IP 数据库路径 -IP2REGION_XDB = os.path.join(BasePath, 'static', 'ip2region.xdb') +LOG_DIR = BASE_PATH / 'log' # 静态资源目录 -STATIC_DIR = os.path.join(BasePath, 'static') +STATIC_DIR = BASE_PATH / 'static' # 上传文件目录 -UPLOAD_DIR = os.path.join(BasePath, 'static', 'upload') +UPLOAD_DIR = STATIC_DIR / 'upload' # jinja2 模版文件路径 -JINJA2_TEMPLATE_DIR = os.path.join(BasePath, 'templates') +JINJA2_TEMPLATE_DIR = BASE_PATH / 'templates' # 插件目录 -PLUGIN_DIR = os.path.join(BasePath, 'plugin') +PLUGIN_DIR = BASE_PATH / 'plugin' + +# 离线 IP 数据库路径 +IP2REGION_XDB = STATIC_DIR / 'ip2region.xdb' diff --git a/backend/core/registrar.py b/backend/core/registrar.py index b81af850..1063db8e 100644 --- a/backend/core/registrar.py +++ b/backend/core/registrar.py @@ -3,6 +3,7 @@ import os from contextlib import asynccontextmanager +from typing import AsyncGenerator import socketio @@ -30,10 +31,11 @@ @asynccontextmanager -async def register_init(app: FastAPI): +async def register_init(app: FastAPI) -> AsyncGenerator[None, None]: """ 启动初始化 + :param app: FastAPI 应用实例 :return: """ # 创建数据库表 @@ -55,8 +57,8 @@ async def register_init(app: FastAPI): await FastAPILimiter.close() -def register_app(): - # FastAPI +def register_app() -> FastAPI: + """注册并配置 FastAPI 应用""" app = FastAPI( title=settings.FASTAPI_TITLE, version=settings.FASTAPI_VERSION, @@ -67,79 +69,69 @@ def register_app(): default_response_class=MsgSpecJSONResponse, lifespan=register_init, ) - - # socketio register_socket_app(app) - - # 日志 register_logger() - - # 静态文件 register_static_file(app) - - # 中间件 register_middleware(app) - - # 路由 register_router(app) - - # 分页 register_page(app) - - # 全局异常处理 register_exception(app) return app def register_logger() -> None: - """ - 系统日志 - - :return: - """ + """配置系统日志""" setup_logging() set_custom_logfile() -def register_static_file(app: FastAPI): +def register_static_file(app: FastAPI) -> None: """ - 静态资源服务,生产应使用 nginx 代理静态资源服务 + 注册静态资源服务,生产环境应使用 nginx 代理静态资源服务 - :param app: + :param app: FastAPI 应用实例 :return: """ # 上传静态资源 if not os.path.exists(UPLOAD_DIR): os.makedirs(UPLOAD_DIR) app.mount('/static/upload', StaticFiles(directory=UPLOAD_DIR), name='upload') + # 固有静态资源 if settings.FASTAPI_STATIC_FILES: app.mount('/static', StaticFiles(directory=STATIC_DIR), name='static') -def register_middleware(app: FastAPI): +def register_middleware(app: FastAPI) -> None: """ - 中间件,执行顺序从下往上 + 注册中间件,执行顺序从下往上 - :param app: + :param app: FastAPI 应用实例 :return: """ # Opera log (required) app.add_middleware(OperaLogMiddleware) + # JWT auth (required) app.add_middleware( - AuthenticationMiddleware, backend=JwtAuthMiddleware(), on_error=JwtAuthMiddleware.auth_exception_handler + AuthenticationMiddleware, + backend=JwtAuthMiddleware(), + on_error=JwtAuthMiddleware.auth_exception_handler, ) + # Access log if settings.MIDDLEWARE_ACCESS: from backend.middleware.access_middleware import AccessMiddleware app.add_middleware(AccessMiddleware) + # State app.add_middleware(StateMiddleware) + # Trace ID (required) app.add_middleware(CorrelationIdMiddleware, validator=False) + # CORS: Always at the end if settings.MIDDLEWARE_CORS: from fastapi.middleware.cors import CORSMiddleware @@ -154,11 +146,11 @@ def register_middleware(app: FastAPI): ) -def register_router(app: FastAPI): +def register_router(app: FastAPI) -> None: """ - 路由 + 注册路由 - :param app: FastAPI + :param app: FastAPI 应用实例 :return: """ dependencies = [Depends(demo_site)] if settings.DEMO_MODE else None @@ -166,7 +158,8 @@ def register_router(app: FastAPI): # API plugin_router_inject() - from backend.app.router import router # 必须在插件路由注入后导入 + # 必须在插件路由注入后导入 + from backend.app.router import router app.include_router(router, dependencies=dependencies) @@ -175,21 +168,21 @@ def register_router(app: FastAPI): simplify_operation_ids(app) -def register_page(app: FastAPI): +def register_page(app: FastAPI) -> None: """ - 分页查询 + 注册分页查询功能 - :param app: + :param app: FastAPI 应用实例 :return: """ add_pagination(app) -def register_socket_app(app: FastAPI): +def register_socket_app(app: FastAPI) -> None: """ - socket 应用 + 注册 Socket.IO 应用 - :param app: + :param app: FastAPI 应用实例 :return: """ from backend.common.socketio.server import sio diff --git a/backend/database/db.py b/backend/database/db.py index ef521a99..172a23b6 100644 --- a/backend/database/db.py +++ b/backend/database/db.py @@ -1,6 +1,8 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- import sys -from typing import Annotated +from typing import Annotated, AsyncGenerator from uuid import uuid4 from fastapi import Depends @@ -33,7 +35,12 @@ def create_database_url(unittest: bool = False) -> URL: def create_async_engine_and_session(url: str | URL) -> tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: - """创建数据库引擎和 Session""" + """ + 创建数据库引擎和 Session + + :param url: 数据库连接 URL + :return: + """ try: # 数据库引擎 engine = create_async_engine( @@ -58,8 +65,8 @@ def create_async_engine_and_session(url: str | URL) -> tuple[AsyncEngine, async_ return engine, db_session -async def get_db(): - """session 生成器""" +async def get_db() -> AsyncGenerator[AsyncSession, None]: + """获取数据库会话""" async with async_db_session() as session: yield session diff --git a/backend/database/redis.py b/backend/database/redis.py index 7f587d64..081171d4 100644 --- a/backend/database/redis.py +++ b/backend/database/redis.py @@ -10,7 +10,10 @@ class RedisCli(Redis): - def __init__(self): + """Redis 客户端""" + + def __init__(self) -> None: + """初始化 Redis 客户端""" super(RedisCli, self).__init__( host=settings.REDIS_HOST, port=settings.REDIS_PORT, @@ -20,12 +23,8 @@ def __init__(self): decode_responses=True, # 转码 utf-8 ) - async def open(self): - """ - 触发初始化连接 - - :return: - """ + async def open(self) -> None: + """触发初始化连接""" try: await self.ping() except TimeoutError: @@ -38,12 +37,12 @@ async def open(self): log.error('❌ 数据库 redis 连接异常 {}', e) sys.exit() - async def delete_prefix(self, prefix: str, exclude: str | list = None): + async def delete_prefix(self, prefix: str, exclude: str | list[str] | None = None) -> None: """ - 删除指定前缀的所有key + 删除指定前缀的所有 key - :param prefix: - :param exclude: + :param prefix: 前缀 + :param exclude: 排除的 key :return: """ keys = [] diff --git a/backend/middleware/access_middleware.py b/backend/middleware/access_middleware.py index ad3cbedc..61fa38a4 100644 --- a/backend/middleware/access_middleware.py +++ b/backend/middleware/access_middleware.py @@ -11,6 +11,13 @@ class AccessMiddleware(BaseHTTPMiddleware): """请求日志中间件""" async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + """ + 处理请求并记录访问日志 + + :param request: FastAPI 请求对象 + :param call_next: 下一个中间件或路由处理函数 + :return: + """ start_time = timezone.now() response = await call_next(request) end_time = timezone.now() diff --git a/backend/middleware/jwt_auth_middleware.py b/backend/middleware/jwt_auth_middleware.py index 09356e83..5f04090d 100644 --- a/backend/middleware/jwt_auth_middleware.py +++ b/backend/middleware/jwt_auth_middleware.py @@ -18,7 +18,17 @@ class _AuthenticationError(AuthenticationError): """重写内部认证错误类""" - def __init__(self, *, code: int = None, msg: str = None, headers: dict[str, Any] | None = None): + def __init__( + self, *, code: int | None = None, msg: str | None = None, headers: dict[str, Any] | None = None + ) -> None: + """ + 初始化认证错误 + + :param code: 错误码 + :param msg: 错误信息 + :param headers: 响应头 + :return: + """ self.code = code self.msg = msg self.headers = headers @@ -29,10 +39,22 @@ class JwtAuthMiddleware(AuthenticationBackend): @staticmethod def auth_exception_handler(conn: HTTPConnection, exc: _AuthenticationError) -> Response: - """覆盖内部认证错误处理""" + """ + 覆盖内部认证错误处理 + + :param conn: HTTP 连接对象 + :param exc: 认证错误对象 + :return: + """ return MsgSpecJSONResponse(content={'code': exc.code, 'msg': exc.msg, 'data': None}, status_code=exc.code) async def authenticate(self, request: Request) -> tuple[AuthCredentials, GetUserInfoWithRelationDetail] | None: + """ + 认证请求 + + :param request: FastAPI 请求对象 + :return: + """ token = request.headers.get('Authorization') if not token: return diff --git a/backend/middleware/opera_log_middleware.py b/backend/middleware/opera_log_middleware.py index 747a9642..5cc70716 100644 --- a/backend/middleware/opera_log_middleware.py +++ b/backend/middleware/opera_log_middleware.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- from asyncio import create_task +from typing import Any from asgiref.sync import sync_to_async from fastapi import Response @@ -22,7 +23,14 @@ class OperaLogMiddleware(BaseHTTPMiddleware): """操作日志中间件""" - async def dispatch(self, request: Request, call_next) -> Response: + async def dispatch(self, request: Request, call_next: Any) -> Response: + """ + 处理请求并记录操作日志 + + :param request: FastAPI 请求对象 + :param call_next: 下一个中间件或路由处理函数 + :return: + """ # 排除记录白名单 path = request.url.path if path in settings.OPERA_LOG_PATH_EXCLUDE or not path.startswith(f'{settings.FASTAPI_API_V1_PATH}'): @@ -79,8 +87,14 @@ async def dispatch(self, request: Request, call_next) -> Response: return request_next.response - async def execute_request(self, request: Request, call_next) -> RequestCallNext: - """执行请求""" + async def execute_request(self, request: Request, call_next: Any) -> RequestCallNext: + """ + 执行请求并处理异常 + + :param request: FastAPI 请求对象 + :param call_next: 下一个中间件或路由处理函数 + :return: + """ code = 200 msg = 'Success' status = StatusType.enable @@ -101,7 +115,14 @@ async def execute_request(self, request: Request, call_next) -> RequestCallNext: @staticmethod def request_exception_handler(request: Request, code: int, msg: str) -> tuple[str, str]: - """请求异常处理器""" + """ + 请求异常处理器 + + :param request: FastAPI 请求对象 + :param code: 错误码 + :param msg: 错误信息 + :return: + """ exception_states = [ '__request_http_exception__', '__request_validation_exception__', @@ -121,8 +142,13 @@ def request_exception_handler(request: Request, code: int, msg: str) -> tuple[st return code, msg @staticmethod - async def get_request_args(request: Request) -> dict: - """获取请求参数""" + async def get_request_args(request: Request) -> dict[str, Any]: + """ + 获取请求参数 + + :param request: FastAPI 请求对象 + :return: + """ args = dict(request.query_params) args.update(request.path_params) # Tip: .body() 必须在 .form() 之前获取 @@ -149,11 +175,11 @@ async def get_request_args(request: Request) -> dict: @staticmethod @sync_to_async - def desensitization(args: dict) -> dict | None: + def desensitization(args: dict[str, Any]) -> dict[str, Any] | None: """ 脱敏处理 - :param args: + :param args: 需要脱敏的参数字典 :return: """ if not args: diff --git a/backend/middleware/state_middleware.py b/backend/middleware/state_middleware.py index b2246b7e..b7307524 100644 --- a/backend/middleware/state_middleware.py +++ b/backend/middleware/state_middleware.py @@ -7,9 +7,16 @@ class StateMiddleware(BaseHTTPMiddleware): - """请求 state 中间件""" + """请求 state 中间件,用于解析和设置请求的附加信息""" async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + """ + 处理请求并设置请求状态信息 + + :param request: FastAPI 请求对象 + :param call_next: 下一个中间件或路由处理函数 + :return: + """ ip_info = await parse_ip_info(request) ua_info = parse_user_agent_info(request) diff --git a/backend/plugin/tools.py b/backend/plugin/tools.py index 842a321b..23872cd3 100644 --- a/backend/plugin/tools.py +++ b/backend/plugin/tools.py @@ -6,6 +6,8 @@ import sys import warnings +from typing import Any + import rtoml from fastapi import APIRouter @@ -17,156 +19,193 @@ class PluginInjectError(Exception): - pass + """插件注入错误""" def get_plugins() -> list[str]: - """获取插件""" + """获取插件列表""" plugin_packages = [] + # 遍历插件目录 for item in os.listdir(PLUGIN_DIR): item_path = os.path.join(PLUGIN_DIR, item) - if os.path.isdir(item_path): - if '__init__.py' in os.listdir(item_path): - plugin_packages.append(item) + # 检查是否为目录且包含 __init__.py 文件 + if os.path.isdir(item_path) and '__init__.py' in os.listdir(item_path): + plugin_packages.append(item) return plugin_packages -def get_plugin_models() -> list: +def get_plugin_models() -> list[type]: """获取插件所有模型类""" classes = [] + + # 获取所有插件 plugins = get_plugins() + + # 遍历插件列表 for plugin in plugins: + # 导入插件的模型模块 module_path = f'backend.plugin.{plugin}.model' module = import_module_cached(module_path) + + # 获取模块中的所有类 for name, obj in inspect.getmembers(module): if inspect.isclass(obj): classes.append(obj) + return classes -def plugin_router_inject() -> None: +def _load_plugin_config(plugin: str) -> dict[str, Any]: """ - 插件路由注入 + 加载插件配置 + :param plugin: 插件名称 :return: """ - plugins = get_plugins() - for plugin in plugins: - toml_path = os.path.join(PLUGIN_DIR, plugin, 'plugin.toml') - if not os.path.exists(toml_path): - raise PluginInjectError(f'插件 {plugin} 缺少 plugin.toml 配置文件,请检查插件是否合法') - - # 获取 plugin.toml 配置 - with open(toml_path, 'r', encoding='utf-8') as f: - data = rtoml.load(f) - api = data.get('api', {}) - - # 非独立 app - if api: - app_include = data.get('app', {}).get('include', '') - if not app_include: - raise PluginInjectError(f'非独立 app 插件 {plugin} 配置文件存在错误,请检查') - - # 插件中 API 路由文件的路径 - plugin_api_path = os.path.join(PLUGIN_DIR, plugin, 'api') - if not os.path.exists(plugin_api_path): - raise PluginInjectError(f'插件 {plugin} 缺少 api 目录,请检查插件文件是否完整') - - # 将插件路由注入到对应模块的路由中 - for root, _, api_files in os.walk(plugin_api_path): - for file in api_files: - if file.endswith('.py') and file != '__init__.py': - # 解析插件路由配置 - prefix = data.get('api', {}).get(f'{file[:-3]}', {}).get('prefix', '') - tags = data.get('api', {}).get(f'{file[:-3]}', {}).get('tags', []) - - # 获取插件路由模块 - file_path = os.path.join(root, file) - path_to_module_str = os.path.relpath(file_path, PLUGIN_DIR).replace(os.sep, '.')[:-3] - module_path = f'backend.plugin.{path_to_module_str}' - try: - module = import_module_cached(module_path) - except PluginInjectError as e: - raise PluginInjectError(f'导入非独立 app 插件 {plugin} 模块 {module_path} 失败:{e}') from e - plugin_router = getattr(module, 'router', None) - if not plugin_router: - warnings.warn( - f'非独立 app 插件 {plugin} 模块 {module_path} 中没有有效的 router,' - '请检查插件文件是否完整', - FutureWarning, - ) - continue - - # 获取源程序路由模块 - relative_path = os.path.relpath(root, plugin_api_path) - target_module_path = f'backend.app.{app_include}.api.{relative_path.replace(os.sep, ".")}' - try: - target_module = import_module_cached(target_module_path) - except PluginInjectError as e: - raise PluginInjectError(f'导入源程序模块 {target_module_path} 失败:{e}') from e - target_router = getattr(target_module, 'router', None) - if not target_router or not isinstance(target_router, APIRouter): - raise PluginInjectError( - f'非独立 app 插件 {plugin} 模块 {module_path} 中没有有效的 router,' - '请检查插件文件是否完整' - ) - - # 将插件路由注入到目标 router 中 - target_router.include_router( - router=plugin_router, - prefix=prefix, - tags=[tags] if tags else [], - ) - # 独立 app - else: - # 将插件中的路由直接注入到总路由中 - module_path = f'backend.plugin.{plugin}.api.router' + toml_path = os.path.join(PLUGIN_DIR, plugin, 'plugin.toml') + if not os.path.exists(toml_path): + raise PluginInjectError(f'插件 {plugin} 缺少 plugin.toml 配置文件,请检查插件是否合法') + + with open(toml_path, 'r', encoding='utf-8') as f: + return rtoml.load(f) + + +def _inject_extra_router(plugin: str, data: dict[str, Any]) -> None: + """ + 扩展级插件路由注入 + + :param plugin: 插件名称 + :param data: 插件配置数据 + :return: + """ + app_include = data.get('app', {}).get('include', '') + if not app_include: + raise PluginInjectError(f'扩展级插件 {plugin} 配置文件存在错误,请检查') + + plugin_api_path = os.path.join(PLUGIN_DIR, plugin, 'api') + if not os.path.exists(plugin_api_path): + raise PluginInjectError(f'插件 {plugin} 缺少 api 目录,请检查插件文件是否完整') + + for root, _, api_files in os.walk(plugin_api_path): + for file in api_files: + if not (file.endswith('.py') and file != '__init__.py'): + continue + + file_config = data.get('api', {}).get(f'{file[:-3]}', {}) + prefix = file_config.get('prefix', '') + tags = file_config.get('tags', []) + + file_path = os.path.join(root, file) + path_to_module_str = os.path.relpath(file_path, PLUGIN_DIR).replace(os.sep, '.')[:-3] + module_path = f'backend.plugin.{path_to_module_str}' + try: module = import_module_cached(module_path) - except PluginInjectError as e: - raise PluginInjectError(f'导入独立 app 插件 {plugin} 模块 {module_path} 失败:{e}') from e - routers = data.get('app', {}).get('router', []) - if not routers or not isinstance(routers, list): - raise PluginInjectError(f'独立 app 插件 {plugin} 配置文件存在错误,请检查') - for router in routers: - plugin_router = getattr(module, router, None) - if not plugin_router or not isinstance(plugin_router, APIRouter): - raise PluginInjectError( - f'独立 app 插件 {plugin} 模块 {module_path} 中没有有效的 router,请检查插件文件是否完整' + plugin_router = getattr(module, 'router', None) + if not plugin_router: + warnings.warn( + f'扩展级插件 {plugin} 模块 {module_path} 中没有有效的 router,请检查插件文件是否完整', + FutureWarning, ) - target_module_path = 'backend.app.router' + continue + + relative_path = os.path.relpath(root, plugin_api_path) + target_module_path = f'backend.app.{app_include}.api.{relative_path.replace(os.sep, ".")}' target_module = import_module_cached(target_module_path) - target_router = getattr(target_module, 'router') + target_router = getattr(target_module, 'router', None) - # 将插件路由注入到目标 router 中 - target_router.include_router(plugin_router) + if not target_router or not isinstance(target_router, APIRouter): + raise PluginInjectError( + f'扩展级插件 {plugin} 模块 {module_path} 中没有有效的 router,请检查插件文件是否完整' + ) + + target_router.include_router( + router=plugin_router, + prefix=prefix, + tags=[tags] if tags else [], + ) + except Exception as e: + raise PluginInjectError(f'注入扩展级插件 {plugin} 路由失败:{str(e)}') from e + + +def _inject_app_router(plugin: str, data: dict[str, Any]) -> None: + """ + 应用级插件路由注入 + + :param plugin: 插件名称 + :param data: 插件配置数据 + :return: + """ + module_path = f'backend.plugin.{plugin}.api.router' + try: + module = import_module_cached(module_path) + routers = data.get('app', {}).get('router', []) + if not routers or not isinstance(routers, list): + raise PluginInjectError(f'应用级插件 {plugin} 配置文件存在错误,请检查') + + target_module = import_module_cached('backend.app.router') + target_router = getattr(target_module, 'router') + + for router in routers: + plugin_router = getattr(module, router, None) + if not plugin_router or not isinstance(plugin_router, APIRouter): + raise PluginInjectError( + f'应用级插件 {plugin} 模块 {module_path} 中没有有效的 router,请检查插件文件是否完整' + ) + target_router.include_router(plugin_router) + except Exception as e: + raise PluginInjectError(f'注入应用级插件 {plugin} 路由失败:{str(e)}') from e + + +def plugin_router_inject() -> None: + """插件路由注入""" + for plugin in get_plugins(): + try: + data = _load_plugin_config(plugin) + # 基于插件 plugin.toml 配置文件,判断插件类型 + if data.get('api'): + _inject_extra_router(plugin, data) + else: + _inject_app_router(plugin, data) + except Exception as e: + raise PluginInjectError(f'插件 {plugin} 路由注入失败:{str(e)}') from e + + +def _install_plugin_requirements(plugin: str, requirements_file: str) -> None: + """ + 安装单个插件的依赖 + + :param plugin: 插件名称 + :param requirements_file: 依赖文件路径 + :return: + """ + try: + ensurepip_install = [sys.executable, '-m', 'ensurepip', '--upgrade'] + pip_install = [sys.executable, '-m', 'pip', 'install', '-r', requirements_file] + if settings.PLUGIN_PIP_CHINA: + pip_install.extend(['-i', settings.PLUGIN_PIP_INDEX_URL]) + subprocess.check_call(ensurepip_install) + subprocess.check_call(pip_install) + except subprocess.CalledProcessError as e: + raise PluginInjectError(f'插件 {plugin} 依赖安装失败:{e.stderr}') from e def install_requirements() -> None: """安装插件依赖""" - plugins = get_plugins() - for plugin in plugins: + for plugin in get_plugins(): requirements_file = os.path.join(PLUGIN_DIR, plugin, 'requirements.txt') - if not os.path.exists(requirements_file): - continue - else: - try: - ensurepip_install = [sys.executable, '-m', 'ensurepip', '--upgrade'] - pip_install = [sys.executable, '-m', 'pip', 'install', '-r', requirements_file] - if settings.PLUGIN_PIP_CHINA: - pip_install.extend(['-i', settings.PLUGIN_PIP_INDEX_URL]) - subprocess.check_call(ensurepip_install) - subprocess.check_call(pip_install) - except subprocess.CalledProcessError as e: - raise PluginInjectError(f'插件 {plugin} 依赖安装失败:{e.stderr}') from e + if os.path.exists(requirements_file): + _install_plugin_requirements(plugin, requirements_file) async def install_requirements_async() -> None: """ - 异步安装插件依赖(由于 Windows 平台限制,无法实现完美的全异步方案),详情: + 异步安装插件依赖 + + 由于 Windows 平台限制,无法实现完美的全异步方案,详情: https://stackoverflow.com/questions/44633458/why-am-i-getting-notimplementederror-with-async-and-await-on-windows """ await run_in_threadpool(install_requirements) diff --git a/backend/utils/build_tree.py b/backend/utils/build_tree.py index 4a948ad1..c227a37c 100644 --- a/backend/utils/build_tree.py +++ b/backend/utils/build_tree.py @@ -7,7 +7,12 @@ def get_tree_nodes(row: Sequence[RowData]) -> list[dict[str, Any]]: - """获取所有树形结构节点""" + """ + 获取所有树形结构节点 + + :param row: 原始数据行序列 + :return: + """ tree_nodes = select_list_serialize(row) tree_nodes.sort(key=lambda x: x['sort']) return tree_nodes @@ -17,10 +22,10 @@ def traversal_to_tree(nodes: list[dict[str, Any]]) -> list[dict[str, Any]]: """ 通过遍历算法构造树形结构 - :param nodes: + :param nodes: 树节点列表 :return: """ - tree = [] + tree: list[dict[str, Any]] = [] node_dict = {node['id']: node for node in nodes} for node in nodes: @@ -45,16 +50,16 @@ def recursive_to_tree(nodes: list[dict[str, Any]], *, parent_id: int | None = No """ 通过递归算法构造树形结构(性能影响较大) - :param nodes: - :param parent_id: + :param nodes: 树节点列表 + :param parent_id: 父节点 ID,默认为 None 表示根节点 :return: """ - tree = [] + tree: list[dict[str, Any]] = [] for node in nodes: if node['parent_id'] == parent_id: - child_node = recursive_to_tree(nodes, parent_id=node['id']) - if child_node: - node['children'] = child_node + child_nodes = recursive_to_tree(nodes, parent_id=node['id']) + if child_nodes: + node['children'] = child_nodes tree.append(node) return tree @@ -65,9 +70,9 @@ def get_tree_data( """ 获取树形结构数据 - :param row: - :param build_type: - :param parent_id: + :param row: 原始数据行序列 + :param build_type: 构建树形结构的算法类型,默认为遍历算法 + :param parent_id: 父节点 ID,仅在递归算法中使用 :return: """ nodes = get_tree_nodes(row) diff --git a/backend/utils/demo_site.py b/backend/utils/demo_site.py index 4792b133..039577c0 100644 --- a/backend/utils/demo_site.py +++ b/backend/utils/demo_site.py @@ -6,9 +6,13 @@ from backend.core.conf import settings -async def demo_site(request: Request): - """演示站点""" +async def demo_site(request: Request) -> None: + """ + 演示站点 + :param request: FastAPI 请求对象 + :return: + """ method = request.method path = request.url.path if ( diff --git a/backend/utils/encrypt.py b/backend/utils/encrypt.py index 91f6db7c..855b026c 100644 --- a/backend/utils/encrypt.py +++ b/backend/utils/encrypt.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +import hashlib import os from typing import Any @@ -13,9 +14,14 @@ class AESCipher: - def __init__(self, key: bytes | str): + """AES 加密器""" + + def __init__(self, key: bytes | str) -> None: """ + 初始化 AES 加密器 + :param key: 密钥,16/24/32 bytes 或 16 进制字符串 + :return: """ self.key = key if isinstance(key, bytes) else bytes.fromhex(key) @@ -40,7 +46,7 @@ def decrypt(self, ciphertext: bytes | str) -> str: """ AES 解密 - :param ciphertext: 解密前的密文, bytes 或 16 进制字符串 + :param ciphertext: 解密前的密文,bytes 或 16 进制字符串 :return: """ ciphertext = ciphertext if isinstance(ciphertext, bytes) else bytes.fromhex(ciphertext) @@ -55,6 +61,8 @@ def decrypt(self, ciphertext: bytes | str) -> str: class Md5Cipher: + """MD5 加密器""" + @staticmethod def encrypt(plaintext: bytes | str) -> str: """ @@ -63,8 +71,6 @@ def encrypt(plaintext: bytes | str) -> str: :param plaintext: 加密前的明文 :return: """ - import hashlib - md5 = hashlib.md5() if not isinstance(plaintext, bytes): plaintext = str(plaintext).encode('utf-8') @@ -73,15 +79,20 @@ def encrypt(plaintext: bytes | str) -> str: class ItsDCipher: - def __init__(self, key: bytes | str): + """ItsDangerous 加密器""" + + def __init__(self, key: bytes | str) -> None: """ + 初始化 ItsDangerous 加密器 + :param key: 密钥,16/24/32 bytes 或 16 进制字符串 + :return: """ self.key = key if isinstance(key, bytes) else bytes.fromhex(key) def encrypt(self, plaintext: Any) -> str: """ - ItsDangerous 加密 (可能失败,如果 plaintext 无法序列化,则会加密为 MD5) + ItsDangerous 加密 :param plaintext: 加密前的明文 :return: @@ -96,7 +107,7 @@ def encrypt(self, plaintext: Any) -> str: def decrypt(self, ciphertext: str) -> Any: """ - ItsDangerous 解密 (可能失败,如果 ciphertext 无法反序列化,则解密失败, 返回原始密文) + ItsDangerous 解密 :param ciphertext: 解密前的密文 :return: diff --git a/backend/utils/file_ops.py b/backend/utils/file_ops.py index 0e88fc57..04cd5aa8 100644 --- a/backend/utils/file_ops.py +++ b/backend/utils/file_ops.py @@ -14,11 +14,11 @@ from backend.utils.timezone import timezone -def build_filename(file: UploadFile): +def build_filename(file: UploadFile) -> str: """ 构建文件名 - :param file: + :param file: FastAPI 上传文件对象 :return: """ timestamp = int(timezone.now().timestamp()) @@ -32,14 +32,15 @@ def file_verify(file: UploadFile, file_type: FileType) -> None: """ 文件验证 - :param file: - :param file_type: + :param file: FastAPI 上传文件对象 + :param file_type: 文件类型枚举 :return: """ filename = file.filename file_ext = filename.split('.')[-1].lower() if not file_ext: raise errors.ForbiddenError(msg='未知的文件类型') + if file_type == FileType.image: if file_ext not in settings.UPLOAD_IMAGE_EXT_INCLUDE: raise errors.ForbiddenError(msg='此图片格式暂不支持') @@ -52,11 +53,11 @@ def file_verify(file: UploadFile, file_type: FileType) -> None: raise errors.ForbiddenError(msg='视频超出最大限制,请重新选择') -async def upload_file(file: UploadFile): +async def upload_file(file: UploadFile) -> str: """ 上传文件 - :param file: + :param file: FastAPI 上传文件对象 :return: """ filename = build_filename(file) diff --git a/backend/utils/gen_template.py b/backend/utils/gen_template.py index 056a2a0b..5a3acbdd 100644 --- a/backend/utils/gen_template.py +++ b/backend/utils/gen_template.py @@ -13,6 +13,9 @@ class GenTemplate: def __init__(self): + """ + 初始化模板生成器 + """ self.env = Environment( loader=FileSystemLoader(JINJA2_TEMPLATE_DIR), autoescape=select_autoescape(enabled_extensions=['jinja']), @@ -25,18 +28,17 @@ def __init__(self): def get_template(self, jinja_file: str) -> Template: """ - 获取模版文件 + 获取模板文件 - :param jinja_file: + :param jinja_file: Jinja2 模板文件 :return: """ - return self.env.get_template(jinja_file) @staticmethod def get_template_paths() -> list[str]: """ - 获取模版文件路径 + 获取模板文件路径列表 :return: """ @@ -53,26 +55,25 @@ def get_code_gen_paths(business: GenBusiness) -> list[str]: """ 获取代码生成路径列表 - :param business: + :param business: 代码生成业务对象 :return: """ app_name = business.app_name module_name = business.table_name_en - target_files = [ + return [ f'{generator_settings.TEMPLATE_BACKEND_DIR_NAME}/{app_name}/api/{business.api_version}/{module_name}.py', f'{generator_settings.TEMPLATE_BACKEND_DIR_NAME}/{app_name}/crud/crud_{module_name}.py', f'{generator_settings.TEMPLATE_BACKEND_DIR_NAME}/{app_name}/model/{module_name}.py', f'{generator_settings.TEMPLATE_BACKEND_DIR_NAME}/{app_name}/schema/{module_name}.py', f'{generator_settings.TEMPLATE_BACKEND_DIR_NAME}/{app_name}/service/{module_name}_service.py', ] - return target_files def get_code_gen_path(self, tpl_path: str, business: GenBusiness) -> str: """ 获取代码生成路径 - :param tpl_path: - :param business: + :param tpl_path: 模板文件路径 + :param business: 代码生成业务对象 :return: """ target_files = self.get_code_gen_paths(business) @@ -80,12 +81,12 @@ def get_code_gen_path(self, tpl_path: str, business: GenBusiness) -> str: return code_gen_path_mapping[tpl_path] @staticmethod - def get_vars(business: GenBusiness, models: Sequence[GenModel]) -> dict: + def get_vars(business: GenBusiness, models: Sequence[GenModel]) -> dict[str, str | Sequence[GenModel]]: """ - 获取模版变量 + 获取模板变量 - :param business: - :param models: + :param business: 代码生成业务对象 + :param models: 代码生成模型对象列表 :return: """ return { diff --git a/backend/utils/health_check.py b/backend/utils/health_check.py index 2ca4d7c2..1a68a150 100644 --- a/backend/utils/health_check.py +++ b/backend/utils/health_check.py @@ -12,7 +12,7 @@ def ensure_unique_route_names(app: FastAPI) -> None: """ 检查路由名称是否唯一 - :param app: + :param app: FastAPI 应用实例 :return: """ temp_routes = set() @@ -23,13 +23,13 @@ def ensure_unique_route_names(app: FastAPI) -> None: temp_routes.add(route.name) -async def http_limit_callback(request: Request, response: Response, expire: int): +async def http_limit_callback(request: Request, response: Response, expire: int) -> None: """ 请求限制时的默认回调函数 - :param request: - :param response: - :param expire: 剩余毫秒 + :param request: FastAPI 请求对象 + :param response: FastAPI 响应对象 + :param expire: 剩余毫秒数 :return: """ expires = ceil(expire / 1000) diff --git a/backend/utils/import_parse.py b/backend/utils/import_parse.py index 9c50a037..55a3e26e 100644 --- a/backend/utils/import_parse.py +++ b/backend/utils/import_parse.py @@ -3,36 +3,36 @@ import importlib from functools import lru_cache -from typing import Any +from typing import Any, Type, TypeVar from backend.common.exception import errors from backend.common.log import log +T = TypeVar('T') + @lru_cache(maxsize=512) def import_module_cached(module_path: str) -> Any: """ 缓存导入模块 - :param module_path: + :param module_path: 模块路径 :return: """ return importlib.import_module(module_path) -def dynamic_import_data_model(module_path: str) -> Any: +def dynamic_import_data_model(module_path: str) -> Type[T]: """ 动态导入数据模型 - :param module_path: + :param module_path: 模块路径,格式为 'module_path.class_name' :return: """ - module_path, class_or_func = module_path.rsplit('.', 1) - try: + module_path, class_name = module_path.rsplit('.', 1) module = import_module_cached(module_path) - ins = getattr(module, class_or_func) + return getattr(module, class_name) except (ImportError, AttributeError) as e: - log.error(e) + log.error(f'动态导入数据模型失败:{e}') raise errors.ServerError(msg='数据模型列动态解析失败,请联系系统超级管理员') - return ins diff --git a/backend/utils/openapi.py b/backend/utils/openapi.py index fc53e82c..b5db5c13 100644 --- a/backend/utils/openapi.py +++ b/backend/utils/openapi.py @@ -6,9 +6,9 @@ def simplify_operation_ids(app: FastAPI) -> None: """ - 简化操作 ID,以便生成的客户端具有更简单的 api 函数名称 + 简化操作 ID,以便生成的客户端具有更简单的 API 函数名称 - :param app: + :param app: FastAPI 应用实例 :return: """ for route in app.routes: diff --git a/backend/utils/re_verify.py b/backend/utils/re_verify.py index b85f3f9f..b2d8eb22 100644 --- a/backend/utils/re_verify.py +++ b/backend/utils/re_verify.py @@ -3,41 +3,45 @@ import re -def search_string(pattern, text) -> bool: +def search_string(pattern: str, text: str) -> bool: """ 全字段正则匹配 - :param pattern: - :param text: + :param pattern: 正则表达式模式 + :param text: 待匹配的文本 :return: """ - result = re.search(pattern, text) - if result: - return True - else: + if not pattern or not text: return False + result = re.search(pattern, text) + return result is not None + -def match_string(pattern, text) -> bool: +def match_string(pattern: str, text: str) -> bool: """ 从字段开头正则匹配 - :param pattern: - :param text: + :param pattern: 正则表达式模式 + :param text: 待匹配的文本 :return: """ - result = re.match(pattern, text) - if result: - return True - else: + if not pattern or not text: return False + result = re.match(pattern, text) + return result is not None + def is_phone(text: str) -> bool: """ - 检查手机号码 + 检查手机号码格式 - :param text: + :param text: 待检查的手机号码 :return: """ - return match_string(r'^1[3-9]\d{9}$', text) + if not text: + return False + + phone_pattern = r'^1[3-9]\d{9}$' + return match_string(phone_pattern, text) diff --git a/backend/utils/redis_info.py b/backend/utils/redis_info.py index aab12eb7..bc43d3e4 100644 --- a/backend/utils/redis_info.py +++ b/backend/utils/redis_info.py @@ -6,27 +6,54 @@ class RedisInfo: @staticmethod - async def get_info(): + async def get_info() -> dict[str, str]: + """ + 获取 Redis 服务器信息 + + :return: + """ + # 获取原始信息 info = await redis_client.info() - fmt_info = {} + + # 格式化信息 + fmt_info: dict[str, str] = {} for key, value in info.items(): if isinstance(value, dict): - value = ','.join({f'{k}={v}' for k, v in value.items()}) + # 将字典格式化为字符串 + fmt_info[key] = ','.join(f'{k}={v}' for k, v in value.items()) else: - value = str(value) - fmt_info[key] = value + fmt_info[key] = str(value) + + # 添加数据库大小信息 db_size = await redis_client.dbsize() - fmt_info.update({'keys_num': db_size}) - fmt_uptime = server_info.fmt_seconds(fmt_info.get('uptime_in_seconds', 0)) - fmt_info.update({'uptime_in_seconds': fmt_uptime}) + fmt_info['keys_num'] = str(db_size) + + # 格式化运行时间 + uptime = int(fmt_info.get('uptime_in_seconds', '0')) + fmt_info['uptime_in_seconds'] = server_info.fmt_seconds(uptime) + return fmt_info @staticmethod - async def get_stats(): - stats_list = [] + async def get_stats() -> list[dict[str, str]]: + """ + 获取 Redis 命令统计信息 + + :return: + """ + # 获取命令统计信息 command_stats = await redis_client.info('commandstats') - for k, v in command_stats.items(): - stats_list.append({'name': k.split('_')[-1], 'value': str(v.get('calls', ''))}) + + # 格式化统计信息 + stats_list: list[dict[str, str]] = [] + for key, value in command_stats.items(): + if not isinstance(value, dict): + continue + + command_name = key.split('_')[-1] + call_count = str(value.get('calls', '0')) + stats_list.append({'name': command_name, 'value': call_count}) + return stats_list diff --git a/backend/utils/request_parse.py b/backend/utils/request_parse.py index 2ec142c1..5c61d739 100644 --- a/backend/utils/request_parse.py +++ b/backend/utils/request_parse.py @@ -15,28 +15,32 @@ def get_request_ip(request: Request) -> str: - """获取请求的 ip 地址""" + """ + 获取请求的 IP 地址 + + :param request: FastAPI 请求对象 + :return: + """ real = request.headers.get('X-Real-IP') if real: - ip = real - else: - forwarded = request.headers.get('X-Forwarded-For') - if forwarded: - ip = forwarded.split(',')[0] - else: - ip = request.client.host + return real + + forwarded = request.headers.get('X-Forwarded-For') + if forwarded: + return forwarded.split(',')[0] + # 忽略 pytest - if ip == 'testclient': - ip = '127.0.0.1' - return ip + if request.client.host == 'testclient': + return '127.0.0.1' + return request.client.host async def get_location_online(ip: str, user_agent: str) -> dict | None: """ - 在线获取 ip 地址属地,无法保证可用性,准确率较高 + 在线获取 IP 地址属地,无法保证可用性,准确率较高 - :param ip: - :param user_agent: + :param ip: IP 地址 + :param user_agent: 用户代理字符串 :return: """ async with httpx.AsyncClient(timeout=3) as client: @@ -47,16 +51,16 @@ async def get_location_online(ip: str, user_agent: str) -> dict | None: if response.status_code == 200: return response.json() except Exception as e: - log.error(f'在线获取 ip 地址属地失败,错误信息:{e}') + log.error(f'在线获取 IP 地址属地失败,错误信息:{e}') return None @sync_to_async def get_location_offline(ip: str) -> dict | None: """ - 离线获取 ip 地址属地,无法保证准确率,100%可用 + 离线获取 IP 地址属地,无法保证准确率,100% 可用 - :param ip: + :param ip: IP 地址 :return: """ try: @@ -71,23 +75,31 @@ def get_location_offline(ip: str) -> dict | None: 'city': data[3] if data[3] != '0' else None, } except Exception as e: - log.error(f'离线获取 ip 地址属地失败,错误信息:{e}') + log.error(f'离线获取 IP 地址属地失败,错误信息:{e}') return None async def parse_ip_info(request: Request) -> IpInfo: + """ + 解析请求的 IP 信息 + + :param request: FastAPI 请求对象 + :return: + """ country, region, city = None, None, None ip = get_request_ip(request) location = await redis_client.get(f'{settings.IP_LOCATION_REDIS_PREFIX}:{ip}') if location: country, region, city = location.split('|') return IpInfo(ip=ip, country=country, region=region, city=city) + if settings.IP_LOCATION_PARSE == 'online': location_info = await get_location_online(ip, request.headers.get('User-Agent')) elif settings.IP_LOCATION_PARSE == 'offline': location_info = await get_location_offline(ip) else: location_info = None + if location_info: country = location_info.get('country') region = location_info.get('regionName') @@ -101,6 +113,12 @@ async def parse_ip_info(request: Request) -> IpInfo: def parse_user_agent_info(request: Request) -> UserAgentInfo: + """ + 解析请求的用户代理信息 + + :param request: FastAPI 请求对象 + :return: + """ user_agent = request.headers.get('User-Agent') _user_agent = parse(user_agent) os = _user_agent.get_os() diff --git a/backend/utils/serializers.py b/backend/utils/serializers.py index 9d9cd6e2..ad2552e8 100644 --- a/backend/utils/serializers.py +++ b/backend/utils/serializers.py @@ -14,43 +14,38 @@ R = TypeVar('R', bound=RowData) -def select_columns_serialize(row: R) -> dict: +def select_columns_serialize(row: R) -> dict[str, Any]: """ - Serialize SQLAlchemy select table columns, does not contain relational columns + 序列化 SQLAlchemy 查询表的列,不包含关联列 - :param row: + :param row: SQLAlchemy 查询结果行 :return: """ result = {} for column in row.__table__.columns.keys(): - v = getattr(row, column) - if isinstance(v, Decimal): - v = decimal_encoder(v) - result[column] = v + value = getattr(row, column) + if isinstance(value, Decimal): + value = decimal_encoder(value) + result[column] = value return result def select_list_serialize(row: Sequence[R]) -> list[dict[str, Any]]: """ - Serialize SQLAlchemy select list + 序列化 SQLAlchemy 查询列表 - :param row: + :param row: SQLAlchemy 查询结果列表 :return: """ - result = [select_columns_serialize(_) for _ in row] - return result + return [select_columns_serialize(item) for item in row] -def select_as_dict(row: R, use_alias: bool = False) -> dict: +def select_as_dict(row: R, use_alias: bool = False) -> dict[str, Any]: """ - Converting SQLAlchemy select to dict, which can contain relational data, - depends on the properties of the select object itself - - If set use_alias is True, the column name will be returned as alias, - If alias doesn't exist in columns, we don't recommend setting it to True + 将 SQLAlchemy 查询结果转换为字典,可以包含关联数据 - :param row: - :param use_alias: + :param row: SQLAlchemy 查询结果行 + :param use_alias: 是否使用别名作为列名 :return: """ if not use_alias: @@ -70,7 +65,7 @@ def select_as_dict(row: R, use_alias: bool = False) -> dict: class MsgSpecJSONResponse(JSONResponse): """ - JSON response using the high-performance msgspec library to serialize data to JSON. + 使用高性能的 msgspec 库将数据序列化为 JSON 的响应类 """ def render(self, content: Any) -> bytes: diff --git a/backend/utils/server_info.py b/backend/utils/server_info.py index 2b2fd516..37a34404 100644 --- a/backend/utils/server_info.py +++ b/backend/utils/server_info.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- import os import platform import socket @@ -5,7 +7,6 @@ from datetime import datetime, timedelta from datetime import timezone as tz -from typing import List import psutil @@ -14,8 +15,13 @@ class ServerInfo: @staticmethod - def format_bytes(size) -> str: - """格式化字节""" + def format_bytes(size: int | float) -> str: + """ + 格式化字节大小 + + :param size: 字节大小 + :return: + """ factor = 1024 for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: if abs(size) < factor: @@ -25,65 +31,77 @@ def format_bytes(size) -> str: @staticmethod def fmt_seconds(seconds: int) -> str: + """ + 格式化秒数为可读的时间字符串 + + :param seconds: 秒数 + :return: + """ days, rem = divmod(int(seconds), 86400) hours, rem = divmod(rem, 3600) minutes, seconds = divmod(rem, 60) + parts = [] if days: - parts.append('{} 天'.format(days)) + parts.append(f'{days} 天') if hours: - parts.append('{} 小时'.format(hours)) + parts.append(f'{hours} 小时') if minutes: - parts.append('{} 分钟'.format(minutes)) + parts.append(f'{minutes} 分钟') if seconds: - parts.append('{} 秒'.format(seconds)) - if len(parts) == 0: - return '0 秒' - else: - return ' '.join(parts) + parts.append(f'{seconds} 秒') + + return ' '.join(parts) if parts else '0 秒' @staticmethod def fmt_timedelta(td: timedelta) -> str: - """格式化时间差""" + """ + 格式化时间差 + + :param td: 时间差对象 + :return: + """ total_seconds = round(td.total_seconds()) return ServerInfo.fmt_seconds(total_seconds) @staticmethod - def get_cpu_info() -> dict: - """获取 CPU 信息""" + def get_cpu_info() -> dict[str, float | int]: + """ + 获取 CPU 信息 + + :return: + """ cpu_info = {'usage': round(psutil.cpu_percent(percpu=False), 2)} # % # 检查是否是 Apple M系列芯片 if platform.system() == 'Darwin' and 'arm' in platform.machine().lower(): - cpu_info['max_freq'] = 0 - cpu_info['min_freq'] = 0 - cpu_info['current_freq'] = 0 + cpu_info.update({'max_freq': 0, 'min_freq': 0, 'current_freq': 0}) else: try: # CPU 频率信息,最大、最小和当前频率 cpu_freq = psutil.cpu_freq() - cpu_info['max_freq'] = round(cpu_freq.max, 2) # MHz - cpu_info['min_freq'] = round(cpu_freq.min, 2) # MHz - cpu_info['current_freq'] = round(cpu_freq.current, 2) # MHz - except FileNotFoundError: - # 处理无法获取频率的情况 - cpu_info['max_freq'] = 0 - cpu_info['min_freq'] = 0 - cpu_info['current_freq'] = 0 - except AttributeError: - # 处理属性不存在的情况(更安全的做法) - cpu_info['max_freq'] = 0 - cpu_info['min_freq'] = 0 - cpu_info['current_freq'] = 0 + cpu_info.update({ + 'max_freq': round(cpu_freq.max, 2), # MHz + 'min_freq': round(cpu_freq.min, 2), # MHz + 'current_freq': round(cpu_freq.current, 2), # MHz + }) + except (FileNotFoundError, AttributeError): + cpu_info.update({'max_freq': 0, 'min_freq': 0, 'current_freq': 0}) # CPU 逻辑核心数,物理核心数 - cpu_info['logical_num'] = psutil.cpu_count(logical=True) - cpu_info['physical_num'] = psutil.cpu_count(logical=False) + cpu_info.update({ + 'logical_num': psutil.cpu_count(logical=True), + 'physical_num': psutil.cpu_count(logical=False), + }) return cpu_info @staticmethod - def get_mem_info() -> dict: - """获取内存信息""" + def get_mem_info() -> dict[str, float]: + """ + 获取内存信息 + + :return: + """ mem = psutil.virtual_memory() return { 'total': round(mem.total / 1024 / 1024 / 1024, 2), # GB @@ -93,14 +111,19 @@ def get_mem_info() -> dict: } @staticmethod - def get_sys_info() -> dict: - """获取服务器信息""" + def get_sys_info() -> dict[str, str]: + """ + 获取服务器信息 + + :return: + """ try: with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sk: sk.connect(('8.8.8.8', 80)) ip = sk.getsockname()[0] except socket.gaierror: ip = '127.0.0.1' + return { 'name': socket.gethostname(), 'ip': ip, @@ -109,8 +132,12 @@ def get_sys_info() -> dict: } @staticmethod - def get_disk_info() -> List[dict]: - """获取磁盘信息""" + def get_disk_info() -> list[dict[str, str]]: + """ + 获取磁盘信息 + + :return: + """ disk_info = [] for disk in psutil.disk_partitions(): usage = psutil.disk_usage(disk.mountpoint) @@ -126,11 +153,16 @@ def get_disk_info() -> List[dict]: return disk_info @staticmethod - def get_service_info(): - """获取服务信息""" + def get_service_info() -> dict[str, str | datetime]: + """ + 获取服务信息 + + :return: + """ process = psutil.Process(os.getpid()) mem_info = process.memory_info() start_time = timezone.f_datetime(datetime.utcfromtimestamp(process.create_time()).replace(tzinfo=tz.utc)) + return { 'name': 'Python3', 'version': platform.python_version(), @@ -140,7 +172,7 @@ def get_service_info(): 'mem_rss': ServerInfo.format_bytes(mem_info.rss), # 常驻内存, 即当前进程实际使用的物理内存 'mem_free': ServerInfo.format_bytes(mem_info.vms - mem_info.rss), # 空闲内存 'startup': start_time, - 'elapsed': f'{ServerInfo.fmt_timedelta(timezone.now() - start_time)}', + 'elapsed': ServerInfo.fmt_timedelta(timezone.now() - start_time), } diff --git a/backend/utils/timezone.py b/backend/utils/timezone.py index 6836d5bb..e91a01f2 100644 --- a/backend/utils/timezone.py +++ b/backend/utils/timezone.py @@ -9,12 +9,18 @@ class TimeZone: - def __init__(self, tz: str = settings.DATETIME_TIMEZONE): + def __init__(self, tz: str = settings.DATETIME_TIMEZONE) -> None: + """ + 初始化时区转换器 + + :param tz: 时区名称,默认为 settings.DATETIME_TIMEZONE + :return: + """ self.tz_info = zoneinfo.ZoneInfo(tz) def now(self) -> datetime: """ - 获取时区时间 + 获取当前时区时间 :return: """ @@ -22,19 +28,19 @@ def now(self) -> datetime: def f_datetime(self, dt: datetime) -> datetime: """ - datetime 时间转时区时间 + 将 datetime 对象转换为当前时区时间 - :param dt: + :param dt: 需要转换的 datetime 对象 :return: """ return dt.astimezone(self.tz_info) def f_str(self, date_str: str, format_str: str = settings.DATETIME_FORMAT) -> datetime: """ - 时间字符串转时区时间 + 将时间字符串转换为当前时区的 datetime 对象 - :param date_str: - :param format_str: + :param date_str: 时间字符串 + :param format_str: 时间格式字符串,默认为 settings.DATETIME_FORMAT :return: """ return datetime.strptime(date_str, format_str).replace(tzinfo=self.tz_info) @@ -42,10 +48,10 @@ def f_str(self, date_str: str, format_str: str = settings.DATETIME_FORMAT) -> da @staticmethod def t_str(dt: datetime, format_str: str = settings.DATETIME_FORMAT) -> str: """ - 时间转时间字符串 + 将 datetime 对象转换为指定格式的时间字符串 - :param dt: - :param format_str: + :param dt: datetime 对象 + :param format_str: 时间格式字符串,默认为 settings.DATETIME_FORMAT :return: """ return dt.strftime(format_str) @@ -53,9 +59,9 @@ def t_str(dt: datetime, format_str: str = settings.DATETIME_FORMAT) -> str: @staticmethod def f_utc(dt: datetime) -> datetime: """ - 时区时间转 UTC(GMT)时区 + 将 datetime 对象转换为 UTC (GMT) 时区时间 - :param dt: + :param dt: 需要转换的 datetime 对象 :return: """ return dt.astimezone(datetime_timezone.utc) diff --git a/backend/utils/trace_id.py b/backend/utils/trace_id.py index a2df0451..4f683b9b 100644 --- a/backend/utils/trace_id.py +++ b/backend/utils/trace_id.py @@ -6,4 +6,10 @@ def get_request_trace_id(request: Request) -> str: + """ + 从请求头中获取追踪 ID + + :param request: FastAPI 请求对象 + :return: + """ return request.headers.get(settings.TRACE_ID_REQUEST_HEADER_KEY) or settings.LOG_CID_DEFAULT_VALUE diff --git a/backend/utils/type_conversion.py b/backend/utils/type_conversion.py index 96500df9..1c6de58d 100644 --- a/backend/utils/type_conversion.py +++ b/backend/utils/type_conversion.py @@ -6,9 +6,9 @@ def sql_type_to_sqlalchemy(typing: str) -> str: """ - Converts a sql type to a SQLAlchemy type. + 将 SQL 类型转换为 SQLAlchemy 类型 - :param typing: + :param typing: SQL 类型字符串 :return: """ if settings.DATABASE_TYPE == 'mysql': @@ -22,17 +22,16 @@ def sql_type_to_sqlalchemy(typing: str) -> str: def sql_type_to_pydantic(typing: str) -> str: """ - Converts a sql type to a pydantic type. + 将 SQL 类型转换为 Pydantic 类型 - :param typing: + :param typing: SQL 类型字符串 :return: """ try: if settings.DATABASE_TYPE == 'mysql': return GenModelMySQLColumnType[typing].value - else: - if typing == 'CHARACTER VARYING': # postgresql 中 DDL VARCHAR 的别名 - return 'str' - return GenModelPostgreSQLColumnType[typing].value + if typing == 'CHARACTER VARYING': # postgresql 中 DDL VARCHAR 的别名 + return 'str' + return GenModelPostgreSQLColumnType[typing].value except KeyError: return 'str'