
Commit 105b1d5

simplify after PR review round
1 parent 4397c80 commit 105b1d5

3 files changed: +14, -11 lines

src/deepsparse/v2/text_generation/generate_new_token.py (+3, -3)

@@ -36,9 +36,9 @@ def can_operate(self, inp: NLEngineOutputs):
             return True
         return False

-    def run(self, *args, inference_state: InferenceState, **kwargs):
-        logits = args[0].engine_outputs if args else kwargs.get("logits")
-        kv_cache = args[0].kv_cache if args else kwargs.get("kv_cache")
+    def run(self, inp: NLEngineOutputs, inference_state: InferenceState, **kwargs):
+        logits = inp.engine_outputs
+        kv_cache = inp.kv_cache

         token_generator = inference_state.current_state.get("token_generator")
         token = token_generator.generate(logits=logits[0, -1, :])
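The point of this hunk is that run() no longer reconstructs its inputs from either positional args or kwargs; it reads fields directly off a typed NLEngineOutputs input. A minimal, self-contained sketch of that pattern, where FakeOutputs and the dict-based inference state are hypothetical stand-ins rather than the real DeepSparse classes:

from dataclasses import dataclass
from typing import Any, Optional

import numpy


@dataclass
class FakeOutputs:
    # hypothetical stand-in for NLEngineOutputs: logits plus the
    # kv_cache handle threaded through the pipeline
    engine_outputs: numpy.ndarray
    kv_cache: Optional[Any] = None


def run(inp: FakeOutputs, inference_state: dict, **kwargs):
    # fields come straight off the typed input; no
    # `args[0] if args else kwargs.get(...)` branching needed
    logits = inp.engine_outputs
    kv_cache = inp.kv_cache
    # take the logits of the last position, as the operator above does
    return logits[0, -1, :], kv_cache


last_logits, cache = run(FakeOutputs(numpy.zeros((1, 4, 32))), inference_state={})
print(last_logits.shape)  # (32,)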

src/deepsparse/v2/text_generation/prep_for_generation.py (+10, -5)

@@ -20,6 +20,7 @@
 from deepsparse.transformers.utils.helpers import set_generated_length
 from deepsparse.v2.operators import Operator
 from deepsparse.v2.text_generation import TokenGeneratorOperator
+from deepsparse.v2.text_generation.nl_engine_operator import NLEngineOutputs
 from deepsparse.v2.utils import InferenceState


@@ -41,10 +42,11 @@ def can_operate(self, inp: Any):
         kv_cache = inp.get("kv_cache")
         tokens = inp.get("tokens")

-        # If the number of prompt tokens is greater than what we've processed,
-        # don't start generation. Should be equal when started as all prompt logits
-        # should be accounted for and we should have updated the kv_cache for the single
-        # token engine.
+        # If the number of prompt tokens is greater
+        # than what we've processed, don't start generation.
+        # Should be equal when started as all prompt logits
+        # should be accounted for, and we should have updated
+        # the kv_cache for the single token engine.
         if len(tokens) == kv_cache.total_num_processed_tokens:
             return True
         return False

@@ -90,10 +92,13 @@ def run(
             "finished_reason": [],
             "token_generator": token_generator,
         }
+
         output = {
-            "logits": prompt_logits,
             "tokens": token_generator.tokens,
             "kv_cache": kv_cache,
             "in_generation": True,
         }
+        if kv_cache is None:
+            output = NLEngineOutputs(**output, engine_outputs=prompt_logits)
+
         return output, state_update
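The last hunk drops "logits" from the output dict and instead, when there is no kv_cache, wraps the output in a typed NLEngineOutputs carrying the prompt logits as engine_outputs. A sketch of that branch under stated assumptions (SketchEngineOutputs is a hypothetical stand-in for NLEngineOutputs, and the comment about why the logits are carried forward is my reading, not stated in the diff):

from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class SketchEngineOutputs:
    # hypothetical stand-in for NLEngineOutputs
    tokens: List[int]
    kv_cache: Optional[Any]
    in_generation: bool
    engine_outputs: Any = None


def build_output(prompt_logits, tokens, kv_cache):
    output = {"tokens": tokens, "kv_cache": kv_cache, "in_generation": True}
    if kv_cache is None:
        # no kv_cache path: carry the prompt logits forward explicitly as
        # engine_outputs on the typed object (mirrors the diff above)
        return SketchEngineOutputs(**output, engine_outputs=prompt_logits)
    return output  # kv-cache path keeps the plain dict


out = build_output(prompt_logits=[[0.1, 0.9]], tokens=[1, 2], kv_cache=None)
print(type(out).__name__)  # SketchEngineOutputs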

tests/deepsparse/v2/unit/text_generation/test_token_generation.py (+1, -3)

@@ -93,9 +93,7 @@ def test_generate_new_token(
         in_generation=True,
     )
     outputs, state = generate_new_token.run(
-        logits=inp.engine_outputs,
-        kv_cache=inp.kv_cache,
-        inference_state=mock_inference_state,
+        inp=inp, inference_state=mock_inference_state
     )
     # The new_token generated/returned by this operator should match the last token in
     # token_generator
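The test now hands the whole typed object to run() by keyword instead of unpacking its fields. A hedged sketch of the same call shape, with StubInputs and StubOperator as illustrative stand-ins for the real fixtures and operator:

import numpy


class StubInputs:
    # hypothetical stand-in for the NLEngineOutputs fixture used in the test
    def __init__(self, engine_outputs, kv_cache=None):
        self.engine_outputs = engine_outputs
        self.kv_cache = kv_cache


class StubOperator:
    # mirrors the new signature: the whole typed object is passed as `inp`
    def run(self, inp, inference_state, **kwargs):
        return inp.engine_outputs[0, -1, :], inference_state


outputs, state = StubOperator().run(
    inp=StubInputs(numpy.ones((1, 2, 8))), inference_state={}
)
assert outputs.shape == (8,)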
