Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 69 additions & 28 deletions ui/ModelModal/src/ModelModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ export const ModelModal: React.FC<ModelModalProps> = ({
messageComponent,
is_close_model_remark = false,
addingModelTutorialURL = 'https://github.com/chaitin/ModelKit/blob/main/docs/AddinModelTutorial.md',
beforeSubmit,
onOk,
}: ModelModalProps) => {
const theme = useTheme();

Expand Down Expand Up @@ -192,7 +194,7 @@ export const ModelModal: React.FC<ModelModalProps> = ({
setModelUserList(
(res.models || [])
.filter((item): item is { model: string } => !!item.model)
.sort((a, b) => a.model!.localeCompare(b.model!))
.sort((a, b) => a.model!.localeCompare(b.model!)),
);
if (
data &&
Expand Down Expand Up @@ -334,6 +336,36 @@ export const ModelModal: React.FC<ModelModalProps> = ({
});
};

// 处理提交前的验证逻辑
const handleOk = async (value: AddModelForm) => {
// 如果有 beforeSubmit,先执行验证
if (beforeSubmit) {
try {
// 调用 beforeSubmit,支持同步和异步返回值
const result = beforeSubmit(value);

// 如果返回的是 Promise,等待其 resolve
const shouldSubmit = result instanceof Promise ? await result : result;

// 如果 beforeSubmit 返回 false,不继续执行
if (!shouldSubmit) {
return;
}
} catch (error) {
// 如果 beforeSubmit 抛出异常或 Promise reject,不继续执行
console.error('beforeSubmit error:', error);
return;
}
}

// beforeSubmit 验证通过后,执行 onOk 或 onSubmit
if (onOk) {
onOk(value);
} else {
onSubmit(value);
}
};

const resetCurData = (value: Model) => {
// @ts-ignore
if (value.provider && value.provider !== 'Other') {
Expand All @@ -348,7 +380,10 @@ export const ModelModal: React.FC<ModelModalProps> = ({
api_header_value: value.api_header?.split('=')[1] || '',
model_type,
show_name: value.show_name || '',
resource_name: value.provider === 'AzureOpenAI' ? extractResourceNameFromUrl(value.base_url || '') : '',
resource_name:
value.provider === 'AzureOpenAI'
? extractResourceNameFromUrl(value.base_url || '')
: '',
context_window_size: 64000,
max_output_tokens: 8192,
enable_r1_params: false,
Expand All @@ -367,11 +402,17 @@ export const ModelModal: React.FC<ModelModalProps> = ({
api_header_key: value.api_header?.split('=')[0] || '',
api_header_value: value.api_header?.split('=')[1] || '',
show_name: value.show_name || '',
resource_name: value.provider === 'AzureOpenAI' ? extractResourceNameFromUrl(value.base_url || '') : '',
resource_name:
value.provider === 'AzureOpenAI'
? extractResourceNameFromUrl(value.base_url || '')
: '',
context_window_size: value.param?.context_window || 64000,
max_output_tokens: value.param?.max_tokens || 8192,
enable_r1_params: value.param?.r1_enabled || false,
support_image: model_type === 'analysis-vl' ? true : (value.param?.support_images || false),
support_image:
model_type === 'analysis-vl'
? true
: value.param?.support_images || false,
support_compute: value.param?.support_computer_use || false,
support_prompt_caching: value.param?.support_prompt_cache || false,
});
Expand Down Expand Up @@ -419,7 +460,7 @@ export const ModelModal: React.FC<ModelModalProps> = ({
onCancel={handleReset}
cancelText={getLocaleMessage('cancel', language)}
okText={getLocaleMessage('save', language)}
onOk={handleSubmit(onSubmit)}
onOk={handleSubmit(handleOk)}
okButtonProps={{
loading,
disabled: !success && providerBrand !== 'Other',
Expand Down Expand Up @@ -517,7 +558,7 @@ export const ModelModal: React.FC<ModelModalProps> = ({
...(providerBrand === it.label && {
bgcolor: addOpacityToColor(
theme.palette.primary.main,
0.1
0.1,
),
color: 'primary.main',
}),
Expand Down Expand Up @@ -552,7 +593,8 @@ export const ModelModal: React.FC<ModelModalProps> = ({
context_window_size: 64000,
max_output_tokens: 8192,
enable_r1_params: false,
support_image: model_type === 'analysis-vl' ? true : false,
support_image:
model_type === 'analysis-vl' ? true : false,
support_compute: false,
support_prompt_caching: false,
});
Expand Down Expand Up @@ -607,12 +649,7 @@ export const ModelModal: React.FC<ModelModalProps> = ({
ml: 1,
textAlign: 'right',
}}
onClick={() =>
window.open(
addingModelTutorialURL,
'_blank'
)
}
onClick={() => window.open(addingModelTutorialURL, '_blank')}
>
添加模型教程
</Box>
Expand Down Expand Up @@ -677,7 +714,7 @@ export const ModelModal: React.FC<ModelModalProps> = ({
(() => {
const processedUrl = getProcessedUrl(
baseUrl,
providerBrand
providerBrand,
);
if (baseUrl.endsWith('#')) {
return processedUrl;
Expand Down Expand Up @@ -727,7 +764,7 @@ export const ModelModal: React.FC<ModelModalProps> = ({
onClick={() =>
window.open(
providers[providerBrand].modelDocumentUrl,
'_blank'
'_blank',
)
}
>
Expand Down Expand Up @@ -839,7 +876,10 @@ export const ModelModal: React.FC<ModelModalProps> = ({
// 动态更新base_url
const resourceName = e.target.value;
if (resourceName) {
setValue('base_url', `https://${resourceName}.openai.azure.com`);
setValue(
'base_url',
`https://${resourceName}.openai.azure.com`,
);
} else {
setValue('base_url', '');
}
Expand Down Expand Up @@ -966,15 +1006,15 @@ export const ModelModal: React.FC<ModelModalProps> = ({
filteredModelList.length > 0
? filteredModelList
: modelUserList.map((item) => ({
model: item.model,
provider: providerBrand,
}));
model: item.model,
provider: providerBrand,
}));

const query = modelSearchQuery.trim().toLowerCase();
const modelsToShow = query
? modelsBase.filter((m) =>
m.model.toLowerCase().includes(query)
)
m.model.toLowerCase().includes(query),
)
: modelsBase;

// 按组分类模型
Expand All @@ -987,7 +1027,7 @@ export const ModelModal: React.FC<ModelModalProps> = ({
acc[group].push(model);
return acc;
},
{} as Record<string, typeof modelsToShow>
{} as Record<string, typeof modelsToShow>,
);

// 渲染分组后的模型
Expand Down Expand Up @@ -1073,7 +1113,9 @@ export const ModelModal: React.FC<ModelModalProps> = ({
<Stack direction='column' spacing={1}>
<TextField
value={modelSearchQuery}
onChange={(e) => setModelSearchQuery(e.target.value)}
onChange={(e) =>
setModelSearchQuery(e.target.value)
}
fullWidth
size='small'
placeholder='搜索模型名称'
Expand Down Expand Up @@ -1210,7 +1252,9 @@ export const ModelModal: React.FC<ModelModalProps> = ({
name='support_image'
render={({ field }) => {
const isAnalysisVl = model_type === 'analysis-vl';
const isChecked = isAnalysisVl ? true : field.value;
const isChecked = isAnalysisVl
? true
: field.value;

return (
<FormControlLabel
Expand Down Expand Up @@ -1239,8 +1283,7 @@ export const ModelModal: React.FC<ModelModalProps> = ({
>
{isAnalysisVl
? '(图像分析模型默认启用图片功能)'
: '(支持图片输入的模型可以启用此选项)'
}
: '(支持图片输入的模型可以启用此选项)'}
</Box>
</Box>
}
Expand Down Expand Up @@ -1362,8 +1405,6 @@ export const ModelModal: React.FC<ModelModalProps> = ({
)}
/>
</Box>


</Stack>
</AccordionDetails>
</Accordion>
Expand Down
18 changes: 14 additions & 4 deletions ui/ModelModal/src/types/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,18 @@ export interface UpdateModelReq {

// 模型服务接口
export interface ModelService {
createModel: (data: CreateModelReq) => Promise<{ model: Model; error?: string }>;
listModel: (data: ListModelReq) => Promise<{ models: ModelListItem[]; error?: string }>;
checkModel: (data: CheckModelReq) => Promise<{ model: Model; error?: string }>;
updateModel: (data: UpdateModelReq) => Promise<{ model: Model; error?: string }>;
createModel: (
data: CreateModelReq,
) => Promise<{ model: Model; error?: string }>;
listModel: (
data: ListModelReq,
) => Promise<{ models: ModelListItem[]; error?: string }>;
checkModel: (
data: CheckModelReq,
) => Promise<{ model: Model; error?: string }>;
updateModel: (
data: UpdateModelReq,
) => Promise<{ model: Model; error?: string }>;
}

export interface ModelListItem {
Expand Down Expand Up @@ -177,4 +185,6 @@ export interface ModelModalProps {
messageComponent?: MessageComponent;
is_close_model_remark?: boolean;
addingModelTutorialURL?: string;
beforeSubmit?: (data: AddModelForm) => boolean | Promise<boolean>;
onOk?: (data: AddModelForm) => void;
}
Loading