1
1
import type { ChakraProps , ComboboxOnChange , ComboboxOption } from '@invoke-ai/ui-library' ;
2
- import { Combobox , FormControl } from '@invoke-ai/ui-library' ;
2
+ import { Combobox , FormControl , Icon } from '@invoke-ai/ui-library' ;
3
3
import { skipToken } from '@reduxjs/toolkit/query' ;
4
4
import { useAppSelector } from 'app/store/storeHooks' ;
5
5
import type { GroupBase } from 'chakra-react-select' ;
@@ -10,12 +10,16 @@ import type { PromptTriggerSelectProps } from 'features/prompt/types';
10
10
import { t } from 'i18next' ;
11
11
import { memo , useCallback , useMemo } from 'react' ;
12
12
import { useTranslation } from 'react-i18next' ;
13
+ import { PiLinkSimple } from 'react-icons/pi' ;
14
+ import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships' ;
13
15
import { useGetModelConfigQuery } from 'services/api/endpoints/models' ;
14
16
import { useEmbeddingModels , useLoRAModels } from 'services/api/hooks/modelsByType' ;
15
17
import { isNonRefinerMainModelConfig } from 'services/api/types' ;
16
18
17
19
const noOptionsMessage = ( ) => t ( 'prompt.noMatchingTriggers' ) ;
18
20
21
+ type RelatedEmbedding = ComboboxOption & { starred ?: boolean } ;
22
+
19
23
export const PromptTriggerSelect = memo ( ( { onSelect, onClose } : PromptTriggerSelectProps ) => {
20
24
const { t } = useTranslation ( ) ;
21
25
@@ -27,6 +31,27 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
27
31
const [ loraModels , { isLoading : isLoadingLoRAs } ] = useLoRAModels ( ) ;
28
32
const [ tiModels , { isLoading : isLoadingTIs } ] = useEmbeddingModels ( ) ;
29
33
34
+ // Get related model keys for current selected models
35
+ const selectedModelKeys = useMemo ( ( ) => {
36
+ const keys : string [ ] = [ ] ;
37
+ if ( mainModel ) {
38
+ keys . push ( mainModel . key ) ;
39
+ }
40
+ for ( const { model } of addedLoRAs ) {
41
+ keys . push ( model . key ) ;
42
+ }
43
+ return keys ;
44
+ } , [ mainModel , addedLoRAs ] ) ;
45
+
46
+ const { relatedModelKeys } = useGetRelatedModelIdsBatchQuery ( selectedModelKeys , {
47
+ selectFromResult : ( { data } ) => {
48
+ if ( ! data ) {
49
+ return { relatedModelKeys : [ ] } ;
50
+ }
51
+ return { relatedModelKeys : data } ;
52
+ } ,
53
+ } ) ;
54
+
30
55
const _onChange = useCallback < ComboboxOnChange > (
31
56
( v ) => {
32
57
if ( ! v ) {
@@ -62,9 +87,25 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
62
87
}
63
88
64
89
if ( tiModels ) {
65
- const embeddingOptions = tiModels
90
+ // Create embedding options with starred property for related models
91
+ const embeddingOptions : RelatedEmbedding [ ] = tiModels
66
92
. filter ( ( ti ) => ti . base === mainModelConfig ?. base )
67
- . map ( ( model ) => ( { label : model . name , value : `<${ model . name } >` } ) ) ;
93
+ . map ( ( model ) => ( {
94
+ label : model . name ,
95
+ value : `<${ model . name } >` ,
96
+ starred : relatedModelKeys . includes ( model . key ) ,
97
+ } ) ) ;
98
+
99
+ // Sort so related embeddings come first
100
+ embeddingOptions . sort ( ( a , b ) => {
101
+ if ( a . starred && ! b . starred ) {
102
+ return - 1 ;
103
+ }
104
+ if ( ! a . starred && b . starred ) {
105
+ return 1 ;
106
+ }
107
+ return 0 ;
108
+ } ) ;
68
109
69
110
if ( embeddingOptions . length > 0 ) {
70
111
_options . push ( {
@@ -85,7 +126,20 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
85
126
}
86
127
87
128
return _options ;
88
- } , [ tiModels , loraModels , mainModelConfig , t , addedLoRAs ] ) ;
129
+ } , [ tiModels , loraModels , mainModelConfig , t , addedLoRAs , relatedModelKeys ] ) ;
130
+
131
+ const formatOptionLabel = useCallback ( ( option : ComboboxOption ) => {
132
+ const embeddingOption = option as RelatedEmbedding ;
133
+ if ( embeddingOption . starred ) {
134
+ return (
135
+ < div style = { { display : 'flex' , alignItems : 'center' , gap : '8px' } } >
136
+ < Icon as = { PiLinkSimple } color = "invokeYellow.500" boxSize = { 3 } />
137
+ { option . label }
138
+ </ div >
139
+ ) ;
140
+ }
141
+ return option . label ;
142
+ } , [ ] ) ;
89
143
90
144
return (
91
145
< FormControl >
@@ -104,6 +158,7 @@ export const PromptTriggerSelect = memo(({ onSelect, onClose }: PromptTriggerSel
104
158
onMenuClose = { onClose }
105
159
data-testid = "add-prompt-trigger"
106
160
sx = { selectStyles }
161
+ formatOptionLabel = { formatOptionLabel }
107
162
/>
108
163
</ FormControl >
109
164
) ;
0 commit comments