@@ -15,35 +15,38 @@
 from typing import Any, Optional
 
 import numpy
+from pydantic import BaseModel, Field
 
 from deepsparse.operators import Operator
-from deepsparse.subgraph_execute import StreamingOutput
 from deepsparse.transformers.pipelines.text_generation import TokenGeneratorOperator
-from deepsparse.transformers.schemas.text_generation_schemas import (
-    FinishReason,
-    PromptLogitsNoKVCacheInference,
-)
+from deepsparse.transformers.schemas.text_generation_schemas import FinishReason
 from deepsparse.transformers.utils.helpers import set_generated_length
 from deepsparse.utils import InferenceState
 
 
-__all__ = ["PrepareGeneration"]
+__all__ = ["PrepareGeneration", "PrepareForGenerationOutput"]
+
+
+class PrepareForGenerationOutput(BaseModel):
+    prompt_logits: Any = Field(
+        description="A set of prompt logits generated during prefill"
+    )
+    kv_cache: Optional[Any] = Field(description="kv cache")
+    in_generation: Optional[bool] = Field(description="in_generation flag")
 
 
 class PrepareGeneration(Operator):
+    output_schema = PrepareForGenerationOutput
+
     def __init__(
         self,
         token_generator: TokenGeneratorOperator,
         prompt_sequence_length: int,
         sequence_length: int,
-        process_output_operator: Optional[Operator] = None,
     ):
         self.sequence_length = sequence_length
         self.token_generator_creator = token_generator
         self.prompt_sequence_length = prompt_sequence_length
-        # Needed for streaming as currently both setting up generation and generating
-        # Will split this up soon
-        self.process_output_operator = process_output_operator
 
     def can_operate(self, inp: Any):
         kv_cache = inp.get("kv_cache")
@@ -79,7 +82,6 @@ def run(
             **inference_state.current_state,
         )
         token_generator = token_generator_creator_output.get("token_generator")
-        token_generator.generate(prompt_logits[0, -1, :])
 
         max_tokens, length_finish_reason = set_generated_length(
             max_length=generation_config.max_length,
@@ -93,43 +95,21 @@ def run(
         state_update = {
             "max_tokens": max_tokens,
             "length_finish_reason": length_finish_reason,
-            "generated_tokens": [token_generator.tokens[-1]],
-            "generated_logits": [prompt_logits]
+            "generated_tokens": [],
+            "generated_logits": [prompt_logits[:, 0:-1, :]]
             if include_prompt_logits
-            else [numpy.expand_dims(prompt_logits[:, -1, :], 0)],
+            else [],
             "finished_reason": [],
             "token_generator": token_generator,
         }
+
         if kv_cache is None:
-            output = PromptLogitsNoKVCacheInference(prompt_logits=prompt_logits)
+            output = {"prompt_logits": numpy.expand_dims(prompt_logits[:, -1, :], 0)}
         else:
             output = {
-                "tokens": token_generator.tokens,
                 "kv_cache": kv_cache,
                 "in_generation": True,
+                "prompt_logits": numpy.expand_dims(prompt_logits[:, -1, :], 0),
             }
-            # TODO: maybe break this operator up since it is both generating and setting
-            # up values needed for generation? Holding off on this as this will change
-            # routes slighty and want to confirm wont break anything for non-kv cache
-            if inference_state.current_state.get("streaming") and max_tokens >= 1:
-                finished_reason = [length_finish_reason] if max_tokens == 1 else [None]
-
-                if self.process_output_operator is None:
-                    raise ValueError(
-                        "An operator must be provided to process outputs"
-                        "while streaming."
-                    )
-                data_to_yield = self.process_output_operator.run(
-                    generated_tokens=state_update.get("generated_tokens"),
-                    finished_reason=finished_reason,
-                    inference_state=inference_state,
-                    generated_logits=prompt_logits[0, -1, :],
-                )
-                output = StreamingOutput(
-                    data_to_yield=self.process_output_operator.output_schema(
-                        **data_to_yield
-                    ),
-                    data_to_return=output,
-                )
 
         return output, state_update
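
Net effect: PrepareGeneration no longer generates a first token or handles streaming itself; it only stages generation state and emits a typed payload through the new PrepareForGenerationOutput output_schema, with both the kv-cache and non-kv-cache branches forwarding the last position's prompt logits. Below is a minimal runnable sketch, not part of this PR, of the two payloads the operator now produces. The array shapes and the explicit default=None on the optional fields are assumptions added here so the snippet runs under both pydantic v1 and v2.

from typing import Any, Optional

import numpy
from pydantic import BaseModel, Field


class PrepareForGenerationOutput(BaseModel):
    # Mirrors the schema added in the diff; default=None is made explicit
    # here only for pydantic v1/v2 portability.
    prompt_logits: Any = Field(
        description="A set of prompt logits generated during prefill"
    )
    kv_cache: Optional[Any] = Field(default=None, description="kv cache")
    in_generation: Optional[bool] = Field(
        default=None, description="in_generation flag"
    )


# Illustrative shapes only: batch=1, 4 prompt tokens, vocab size 32.
prompt_logits = numpy.random.rand(1, 4, 32)

# Non-kv-cache path: only the final position's logits move forward.
no_cache = PrepareForGenerationOutput(
    prompt_logits=numpy.expand_dims(prompt_logits[:, -1, :], 0)
)

# kv-cache path: same logits, plus the cache object and the in_generation flag.
with_cache = PrepareForGenerationOutput(
    prompt_logits=numpy.expand_dims(prompt_logits[:, -1, :], 0),
    kv_cache=object(),  # placeholder for the pipeline's kv cache object
    in_generation=True,
)
print(no_cache.in_generation, with_cache.in_generation)  # None True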