@@ -20,6 +20,7 @@
 from deepsparse.transformers.utils.helpers import set_generated_length
 from deepsparse.v2.operators import Operator
 from deepsparse.v2.text_generation import TokenGeneratorOperator
+from deepsparse.v2.text_generation.nl_engine_operator import NLEngineOutputs
 from deepsparse.v2.utils import InferenceState


@@ -41,10 +42,11 @@ def can_operate(self, inp: Any):
         kv_cache = inp.get("kv_cache")
         tokens = inp.get("tokens")

-        # If the number of prompt tokens is greater than what we've processed,
-        # don't start generation. Should be equal when started as all prompt logits
-        # should be accounted for and we should have updated the kv_cache for the single
-        # token engine.
+        # If the number of prompt tokens is greater
+        # than what we've processed, don't start generation.
+        # Should be equal when started as all prompt logits
+        # should be accounted for, and we should have updated
+        # the kv_cache for the single token engine.
         if len(tokens) == kv_cache.total_num_processed_tokens:
             return True
         return False
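
The gate above only passes once every prompt token has been processed into the kv_cache. A minimal runnable sketch of the same check, where `StubKVCache` is a hypothetical stand-in exposing only the `total_num_processed_tokens` attribute that the real kv_cache object in deepsparse provides:

```python
# Sketch only: StubKVCache is a hypothetical stand-in for the real
# kv_cache object, which exposes total_num_processed_tokens.
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class StubKVCache:
    total_num_processed_tokens: int


def can_operate(inp: Dict[str, Any]) -> bool:
    kv_cache = inp.get("kv_cache")
    tokens = inp.get("tokens")
    # Generation starts only once every prompt token has been run
    # through the engine and recorded in the kv_cache.
    return len(tokens) == kv_cache.total_num_processed_tokens


# A 4-token prompt: ready only once all 4 tokens are processed.
assert can_operate({"tokens": [1, 2, 3, 4], "kv_cache": StubKVCache(4)})
assert not can_operate({"tokens": [1, 2, 3, 4], "kv_cache": StubKVCache(3)})
```
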
@@ -90,10 +92,13 @@ def run(
             "finished_reason": [],
             "token_generator": token_generator,
         }
+
         output = {
-            "logits": prompt_logits,
             "tokens": token_generator.tokens,
             "kv_cache": kv_cache,
             "in_generation": True,
         }
+        if kv_cache is None:
+            output = NLEngineOutputs(**output, engine_outputs=prompt_logits)
+
         return output, state_update
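
This hunk drops the raw `"logits"` key in favor of a conditional wrap: when `kv_cache` is `None`, the output dict is promoted to an `NLEngineOutputs` object that carries the prompt logits as `engine_outputs`. A hedged sketch of that branch follows; the dataclass below is a stand-in inferred only from the fields used here (the real `NLEngineOutputs` is defined in `deepsparse.v2.text_generation.nl_engine_operator` and may differ), and `wrap_output` is a hypothetical helper name:

```python
# Sketch only: this NLEngineOutputs is a stand-in inferred from the
# fields the diff passes; the real class lives in
# deepsparse.v2.text_generation.nl_engine_operator.
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class NLEngineOutputs:
    tokens: List[int]
    kv_cache: Optional[Any]
    in_generation: bool
    engine_outputs: Any


def wrap_output(tokens: List[int], kv_cache: Optional[Any], prompt_logits: Any):
    # Mirrors the tail of run(): build the plain dict first ...
    output = {
        "tokens": tokens,
        "kv_cache": kv_cache,
        "in_generation": True,
    }
    # ... then, when no kv_cache is present, promote it so downstream
    # operators still receive the prompt logits as engine_outputs.
    if kv_cache is None:
        return NLEngineOutputs(**output, engine_outputs=prompt_logits)
    return output


print(wrap_output([1, 2, 3], kv_cache=None, prompt_logits=[[0.1, 0.9]]))
```

Presumably this keeps the dict path unchanged for pipelines that do carry a kv_cache, while the kv-cache-less path gets a typed container instead of a bare `"logits"` entry.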