Skip to content

Commit 085db1b

Browse files
Added related model support
1 parent 28b5dca commit 085db1b

File tree

1 file changed

+59
-4
lines changed

1 file changed

+59
-4
lines changed

invokeai/frontend/web/src/features/prompt/PromptTriggerSelect.tsx

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
2-
import { Combobox, FormControl } from '@invoke-ai/ui-library';
2+
import { Combobox, FormControl, Icon } from '@invoke-ai/ui-library';
33
import { skipToken } from '@reduxjs/toolkit/query';
44
import { useAppSelector } from 'app/store/storeHooks';
55
import type { GroupBase } from 'chakra-react-select';
@@ -10,12 +10,16 @@ import type { PromptTriggerSelectProps } from 'features/prompt/types';
1010
import { t } from 'i18next';
1111
import { memo, useCallback, useMemo } from 'react';
1212
import { useTranslation } from 'react-i18next';
13+
import { PiLinkSimple } from 'react-icons/pi';
14+
import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships';
1315
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
1416
import { useEmbeddingModels, useLoRAModels } from 'services/api/hooks/modelsByType';
1517
import { isNonRefinerMainModelConfig } from 'services/api/types';
1618

1719
const noOptionsMessage = () => t('prompt.noMatchingTriggers');
1820

21+
type RelatedEmbedding = ComboboxOption & { starred?: boolean };
22+
1923
export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSelectProps) => {
2024
const { t } = useTranslation();
2125

@@ -27,6 +31,27 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
2731
const [loraModels, { isLoading: isLoadingLoRAs }] = useLoRAModels();
2832
const [tiModels, { isLoading: isLoadingTIs }] = useEmbeddingModels();
2933

34+
// Get related model keys for current selected models
35+
const selectedModelKeys = useMemo(() => {
36+
const keys: string[] = [];
37+
if (mainModel) {
38+
keys.push(mainModel.key);
39+
}
40+
for (const { model } of addedLoRAs) {
41+
keys.push(model.key);
42+
}
43+
return keys;
44+
}, [mainModel, addedLoRAs]);
45+
46+
const { relatedModelKeys } = useGetRelatedModelIdsBatchQuery(selectedModelKeys, {
47+
selectFromResult: ({ data }) => {
48+
if (!data) {
49+
return { relatedModelKeys: [] };
50+
}
51+
return { relatedModelKeys: data };
52+
},
53+
});
54+
3055
const _onChange = useCallback<ComboboxOnChange>(
3156
(v) => {
3257
if (!v) {
@@ -62,9 +87,25 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
6287
}
6388

6489
if (tiModels) {
65-
const embeddingOptions = tiModels
90+
// Create embedding options with starred property for related models
91+
const embeddingOptions: RelatedEmbedding[] = tiModels
6692
.filter((ti) => ti.base === mainModelConfig?.base)
67-
.map((model) => ({ label: model.name, value: `<${model.name}>` }));
93+
.map((model) => ({
94+
label: model.name,
95+
value: `<${model.name}>`,
96+
starred: relatedModelKeys.includes(model.key),
97+
}));
98+
99+
// Sort so related embeddings come first
100+
embeddingOptions.sort((a, b) => {
101+
if (a.starred && !b.starred) {
102+
return -1;
103+
}
104+
if (!a.starred && b.starred) {
105+
return 1;
106+
}
107+
return 0;
108+
});
68109

69110
if (embeddingOptions.length > 0) {
70111
_options.push({
@@ -85,7 +126,20 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
85126
}
86127

87128
return _options;
88-
}, [tiModels, loraModels, mainModelConfig, t, addedLoRAs]);
129+
}, [tiModels, loraModels, mainModelConfig, t, addedLoRAs, relatedModelKeys]);
130+
131+
const formatOptionLabel = useCallback((option: ComboboxOption) => {
132+
const embeddingOption = option as RelatedEmbedding;
133+
if (embeddingOption.starred) {
134+
return (
135+
<div style={{ display: 'flex', alignItems: 'center', gap: '8px' }}>
136+
<Icon as={PiLinkSimple} color="invokeYellow.500" boxSize={3} />
137+
{option.label}
138+
</div>
139+
);
140+
}
141+
return option.label;
142+
}, []);
89143

90144
return (
91145
<FormControl>
@@ -104,6 +158,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
104158
onMenuClose={onClose}
105159
data-testid="add-prompt-trigger"
106160
sx={selectStyles}
161+
formatOptionLabel={formatOptionLabel}
107162
/>
108163
</FormControl>
109164
);

0 commit comments

Comments
 (0)