Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
请在此处填写插件使用说明和您的联系方式
## AI

如果插件需要付费,请提供付费相关说明

如有配套前端插件,请添加前端插件仓库链接说明

插件开发文档:[fba plugin dev](https://fastapi-practices.github.io/fastapi_best_architecture_docs/plugin/dev.html)
此插件提供了 AI 能力
Empty file added __init__.py
Empty file.
12 changes: 12 additions & 0 deletions api/router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from fastapi import APIRouter

from backend.core.conf import settings
from backend.plugin.ai.api.v1.chat import router as chat_router
from backend.plugin.ai.api.v1.model import router as model_router
from backend.plugin.ai.api.v1.provider import router as provider_router

v1 = APIRouter(prefix=settings.FASTAPI_API_V1_PATH)

v1.include_router(chat_router, prefix='/chat', tags=['AI 文本生成'])
v1.include_router(model_router, prefix='/models', tags=['AI 模型管理'])
v1.include_router(provider_router, prefix='/providers', tags=['AI 供应商管理'])
Empty file added api/v1/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions api/v1/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from fastapi import APIRouter
from starlette.responses import StreamingResponse

from backend.database.db import CurrentSession
from backend.plugin.ai.schema.chat import AIChat
from backend.plugin.ai.service.chat_service import ai_chat_service

router = APIRouter()


@router.post('/completions', summary='文本生成(对话)')
async def completions(db: CurrentSession, chat: AIChat) -> StreamingResponse:
return StreamingResponse(ai_chat_service.stream_messages(db=db, chat=chat), media_type='text/plain')
85 changes: 85 additions & 0 deletions api/v1/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from typing import Annotated

from fastapi import APIRouter, Depends, Path

from backend.common.pagination import DependsPagination, PageData
from backend.common.response.response_schema import ResponseModel, ResponseSchemaModel, response_base
from backend.common.security.jwt import DependsJwtAuth
from backend.common.security.permission import RequestPermission
from backend.common.security.rbac import DependsRBAC
from backend.database.db import CurrentSession, CurrentSessionTransaction
from backend.plugin.ai.schema.model import (
CreateAIModelParam,
DeleteAIModelParam,
GetAIModelDetail,
UpdateAIModelParam,
)
from backend.plugin.ai.service.model_service import ai_model_service

router = APIRouter()


@router.get('/{pk}', summary='获取模型详情', dependencies=[DependsJwtAuth])
async def get_ai_model(
db: CurrentSession, pk: Annotated[int, Path(description='模型 ID')]
) -> ResponseSchemaModel[GetAIModelDetail]:
ai_model = await ai_model_service.get(db=db, pk=pk)
return response_base.success(data=ai_model)


@router.get(
'',
summary='分页获取所有模型',
dependencies=[
DependsJwtAuth,
DependsPagination,
],
)
async def get_ai_models_paginated(db: CurrentSession) -> ResponseSchemaModel[PageData[GetAIModelDetail]]:
page_data = await ai_model_service.get_list(db=db)
return response_base.success(data=page_data)


@router.post(
'',
summary='创建模型',
dependencies=[
Depends(RequestPermission('ai:model:add')),
DependsRBAC,
],
)
async def create_ai_model(db: CurrentSessionTransaction, obj: CreateAIModelParam) -> ResponseModel:
await ai_model_service.create(db=db, obj=obj)
return response_base.success()


@router.put(
'/{pk}',
summary='更新模型',
dependencies=[
Depends(RequestPermission('ai:model:edit')),
DependsRBAC,
],
)
async def update_ai_model(
db: CurrentSessionTransaction, pk: Annotated[int, Path(description='模型 ID')], obj: UpdateAIModelParam
) -> ResponseModel:
count = await ai_model_service.update(db=db, pk=pk, obj=obj)
if count > 0:
return response_base.success()
return response_base.fail()


@router.delete(
'',
summary='批量删除模型',
dependencies=[
Depends(RequestPermission('ai:model:del')),
DependsRBAC,
],
)
async def delete_ai_models(db: CurrentSessionTransaction, obj: DeleteAIModelParam) -> ResponseModel:
count = await ai_model_service.delete(db=db, obj=obj)
if count > 0:
return response_base.success()
return response_base.fail()
104 changes: 104 additions & 0 deletions api/v1/provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from typing import Annotated

from fastapi import APIRouter, Depends, Path

from backend.common.pagination import DependsPagination, PageData
from backend.common.response.response_schema import ResponseModel, ResponseSchemaModel, response_base
from backend.common.security.jwt import DependsJwtAuth
from backend.common.security.permission import RequestPermission
from backend.common.security.rbac import DependsRBAC
from backend.database.db import CurrentSession, CurrentSessionTransaction
from backend.plugin.ai.schema.provider import (
CreateAIProviderParam,
DeleteAIProviderParam,
GetAIProviderDetail,
GetAIProviderModelDetail,
UpdateAIProviderParam,
)
from backend.plugin.ai.service.provider_service import ai_provider_service

router = APIRouter()


@router.get('/{pk}', summary='获取供应商详情', dependencies=[DependsJwtAuth])
async def get_ai_provider(
db: CurrentSession, pk: Annotated[int, Path(description='provider ID')]
) -> ResponseSchemaModel[GetAIProviderDetail]:
ai_provider = await ai_provider_service.get(db=db, pk=pk)
return response_base.success(data=ai_provider)


@router.get('/{pk}/models', summary='获取供应商模型列表', dependencies=[DependsJwtAuth])
async def get_ai_provider_models(
db: CurrentSession,
pk: Annotated[int, Path(description='provider ID')],
) -> ResponseSchemaModel[list[GetAIProviderModelDetail]]:
ai_provider_modes = await ai_provider_service.get_models(db=db, pk=pk)
return response_base.success(data=ai_provider_modes)


@router.get('/{pk}/models/sync', summary='同步供应商模型', dependencies=[DependsJwtAuth])
async def sync_ai_provider_models(
db: CurrentSessionTransaction,
pk: Annotated[int, Path(description='provider ID')],
) -> ResponseModel:
await ai_provider_service.sync_models(db=db, pk=pk)
return response_base.success()


@router.get(
'',
summary='分页获取所有供应商',
dependencies=[
DependsJwtAuth,
DependsPagination,
],
)
async def get_ai_providers_paginated(db: CurrentSession) -> ResponseSchemaModel[PageData[GetAIProviderDetail]]:
page_data = await ai_provider_service.get_list(db=db)
return response_base.success(data=page_data)


@router.post(
'',
summary='创建供应商',
dependencies=[
Depends(RequestPermission('ai:provider:add')),
DependsRBAC,
],
)
async def create_ai_provider(db: CurrentSessionTransaction, obj: CreateAIProviderParam) -> ResponseModel:
await ai_provider_service.create(db=db, obj=obj)
return response_base.success()


@router.put(
'/{pk}',
summary='更新供应商',
dependencies=[
Depends(RequestPermission('ai:provider:edit')),
DependsRBAC,
],
)
async def update_ai_provider(
db: CurrentSessionTransaction, pk: Annotated[int, Path(description='供应商 ID')], obj: UpdateAIProviderParam
) -> ResponseModel:
count = await ai_provider_service.update(db=db, pk=pk, obj=obj)
if count > 0:
return response_base.success()
return response_base.fail()


@router.delete(
'',
summary='批量删除供应商',
dependencies=[
Depends(RequestPermission('ai:provider:del')),
DependsRBAC,
],
)
async def delete_ai_providers(db: CurrentSessionTransaction, obj: DeleteAIProviderParam) -> ResponseModel:
count = await ai_provider_service.delete(db=db, obj=obj)
if count > 0:
return response_base.success()
return response_base.fail()
Empty file added crud/__init__.py
Empty file.
99 changes: 99 additions & 0 deletions crud/crud_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from collections.abc import Sequence
from typing import Any

from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus

from backend.plugin.ai.model import AIModel
from backend.plugin.ai.schema.model import CreateAIModelParam, UpdateAIModelParam


class CRUDAIModel(CRUDPlus[AIModel]):
async def get(self, db: AsyncSession, pk: int) -> AIModel | None:
"""
获取模型

:param db: 数据库会话
:param pk: 模型 ID
:return:
"""
return await self.select_model(db, pk)

async def get_by_model_and_provider(self, db: AsyncSession, model_id: str, provider_id: int) -> AIModel | None:
"""
通过模型和供应商获取模型

:param db: 数据库会话
:param model_id: 模型
:param provider_id: 供应商
:return:
"""
return await self.select_model_by_column(db, model_id=model_id, provider_id=provider_id)

async def get_select(self) -> Select:
"""获取模型列表查询表达式"""
return await self.select_order('id', 'desc')

async def get_all(self, db: AsyncSession) -> Sequence[AIModel]:
"""
获取所有模型

:param db: 数据库会话
:return:
"""
return await self.select_models(db)

async def create(self, db: AsyncSession, obj: CreateAIModelParam) -> None:
"""
创建模型

:param db: 数据库会话
:param obj: 创建模型参数
:return:
"""
await self.create_model(db, obj)

async def bulk_create(self, db: AsyncSession, objs: list[dict[str, Any]]) -> None:
"""
批量创建模型

:param db:数据库会话
:param objs: 批量创建模型参数
:return:
"""
await self.bulk_create_models(db, objs)

async def update(self, db: AsyncSession, pk: int, obj: UpdateAIModelParam) -> int:
"""
更新模型

:param db: 数据库会话
:param pk: 模型 ID
:param obj: 更新 模型参数
:return:
"""
return await self.update_model(db, pk, obj)

async def delete(self, db: AsyncSession, pks: list[int]) -> int:
"""
批量删除模型

:param db: 数据库会话
:param pks: 模型 ID 列表
:return:
"""
return await self.delete_model_by_column(db, allow_multiple=True, id__in=pks)

async def delete_by_provider(self, db: AsyncSession, provider_id: int) -> int:
"""
通过供应商 ID 删除模型

:param db: 数据库会话
:param provider_id: 供应商 ID
:return:
"""
return await self.delete_model_by_column(db, allow_multiple=True, provider_id=provider_id)


ai_model_dao: CRUDAIModel = CRUDAIModel(AIModel)
67 changes: 67 additions & 0 deletions crud/crud_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from collections.abc import Sequence

from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_crud_plus import CRUDPlus

from backend.plugin.ai.model import AIProvider
from backend.plugin.ai.schema.provider import CreateAIProviderParam, UpdateAIProviderParam


class CRUDAIProvider(CRUDPlus[AIProvider]):
async def get(self, db: AsyncSession, pk: int) -> AIProvider | None:
"""
获取供应商

:param db: 数据库会话
:param pk: 供应商 ID
:return:
"""
return await self.select_model(db, pk)

async def get_select(self) -> Select:
"""获取供应商列表查询表达式"""
return await self.select_order('id', 'desc')

async def get_all(self, db: AsyncSession) -> Sequence[AIProvider]:
"""
获取所有供应商

:param db: 数据库会话
:return:
"""
return await self.select_models(db)

async def create(self, db: AsyncSession, obj: CreateAIProviderParam) -> None:
"""
创建供应商

:param db: 数据库会话
:param obj: 创建供应商参数
:return:
"""
await self.create_model(db, obj)

async def update(self, db: AsyncSession, pk: int, obj: UpdateAIProviderParam) -> int:
"""
更新供应商

:param db: 数据库会话
:param pk: 供应商 ID
:param obj: 更新 供应商参数
:return:
"""
return await self.update_model(db, pk, obj)

async def delete(self, db: AsyncSession, pks: list[int]) -> int:
"""
批量删除供应商

:param db: 数据库会话
:param pks: 供应商 ID 列表
:return:
"""
return await self.delete_model_by_column(db, allow_multiple=True, id__in=pks)


ai_provider_dao: CRUDAIProvider = CRUDAIProvider(AIProvider)
Loading