
Commit fc8330d

Merge branch 'main' into luka/better-legacy-pipeline-warning
2 parents 9346df5 + e29211c commit fc8330d

15 files changed: +180 -82 lines changed

src/deepsparse/transformers/pipelines/text_generation/autoregressive_preprocess_operator.py

+6
@@ -51,6 +51,12 @@ def can_operate(self, inp: Any) -> bool:
         if inp.get("in_generation"):
             return True
 
+        if kv_cache.total_num_processed_tokens >= kv_cache.capacity:
+            raise RuntimeError(
+                "Not enough kv_cache capacity to run generation. Please use a larger "
+                "sequence_length or a shorter prompt"
+            )
+
         remaining_tokens = len(tokens) - kv_cache.total_num_processed_tokens
         can_process = (
             remaining_tokens > 0 and remaining_tokens < self.prompt_sequence_length
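For readers skimming the diff, here is a minimal sketch of what the new guard protects against. The stand-in class below is hypothetical (only the two attributes the check reads are modeled); the error message is taken from the diff, everything else is illustrative.

# Hypothetical stand-in for the real kv_cache object, which is not part of this diff.
class FakeKVCache:
    def __init__(self, total_num_processed_tokens, capacity):
        self.total_num_processed_tokens = total_num_processed_tokens
        self.capacity = capacity


def check_capacity(kv_cache):
    # Mirrors the guard added above: once the cache is full there is no room
    # left for another decode step, so fail loudly up front.
    if kv_cache.total_num_processed_tokens >= kv_cache.capacity:
        raise RuntimeError(
            "Not enough kv_cache capacity to run generation. Please use a larger "
            "sequence_length or a shorter prompt"
        )


check_capacity(FakeKVCache(total_num_processed_tokens=100, capacity=512))  # passes
# check_capacity(FakeKVCache(total_num_processed_tokens=512, capacity=512))  # raises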

src/deepsparse/transformers/pipelines/text_generation/compile_generations.py

-4
@@ -17,7 +17,6 @@
 from pydantic import BaseModel, Field
 
 from deepsparse.operators import Operator
-from deepsparse.transformers.schemas.text_generation_schemas import FinishReason
 from deepsparse.utils import InferenceState
 
 
@@ -43,9 +42,6 @@ def run(self, inference_state: InferenceState, **kwargs):
         generated_logits = inference_state.current_state.get("generated_logits")
         finished_reason = inference_state.current_state.get("finished_reason")
 
-        if len(finished_reason) == 0:
-            finished_reason.append(FinishReason.LENGTH)
-
         generated_tokens = numpy.array([generated_tokens])
         generated_logits = numpy.concatenate(generated_logits, axis=1)
         return {

src/deepsparse/transformers/pipelines/text_generation/generate_new_token.py

+19 -13
@@ -19,10 +19,7 @@
 from deepsparse.transformers.pipelines.text_generation.nl_engine_operator import (
     NLEngineOutputs,
 )
-from deepsparse.transformers.schemas.text_generation_schemas import (
-    FinishReason,
-    PromptLogitsNoKVCacheInference,
-)
+from deepsparse.transformers.schemas.text_generation_schemas import FinishReason
 from deepsparse.utils import InferenceState
 
 
@@ -36,14 +33,16 @@ def __init__(
         self.force_max_tokens = force_max_tokens
         self.tokenizer = tokenizer
 
-    def can_operate(self, inp: Union[PromptLogitsNoKVCacheInference, NLEngineOutputs]):
+    def can_operate(
+        self, inp: Union[NLEngineOutputs, "PrepareForGenerationOutput"]  # noqa: F821
+    ):
         if inp.in_generation:
             return True
         return False
 
     def run(
         self,
-        inp: Union[PromptLogitsNoKVCacheInference, NLEngineOutputs],
+        inp: Union[NLEngineOutputs, "PrepareForGenerationOutput"],  # noqa: F821
         inference_state: InferenceState,
         **kwargs,
     ):
@@ -52,21 +51,26 @@ def run(
             if isinstance(inp, NLEngineOutputs)
             else inp.prompt_logits
         )
-        kv_cache = inp.kv_cache if isinstance(inp, NLEngineOutputs) else None
+        kv_cache = inp.kv_cache
+
+        max_tokens = inference_state.current_state.get("max_tokens")
+        length_finish_reason = inference_state.current_state.get("length_finish_reason")
+        generated_tokens = inference_state.current_state.get("generated_tokens")
+        num_generated_tokens = len(generated_tokens)
 
         token_generator = inference_state.current_state.get("token_generator")
         token = token_generator.generate(logits=logits[0, -1, :])
         finish_reason = None
 
-        callback = inference_state.current_state.get("callback")
-        stop = inference_state.current_state.get("stop")
-
         if (
             kv_cache is not None
             and kv_cache.total_num_processed_tokens >= kv_cache.capacity
         ):
             finish_reason = FinishReason.CAPACITY
 
+        callback = inference_state.current_state.get("callback")
+        stop = inference_state.current_state.get("stop")
+
         if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
             finish_reason = FinishReason.STOP
 
@@ -84,9 +88,11 @@ def run(
             )
             finish_reason = FinishReason.CALLBACK
 
-        max_tokens = inference_state.current_state.get("max_tokens")
-        if len(inference_state.current_state.get("generated_tokens")) + 1 >= max_tokens:
-            finish_reason = inference_state.current_state.get("length_finish_reason")
+        # Note: this is +1 as the inference state variable keeping track of all the
+        # generated tokens has not yet been updated with the most recently generated
+        # token from this operator
+        if num_generated_tokens + 1 == max_tokens:
+            finish_reason = length_finish_reason
 
         state_update = {
             "token_generator": token_generator,

src/deepsparse/transformers/pipelines/text_generation/multi_engine_prefill_operator.py

+6
@@ -42,6 +42,12 @@ def can_operate(self, inp: Any):
         kv_cache = inp.get("kv_cache")
         tokens = inp.get("tokens")
 
+        if kv_cache.total_num_processed_tokens >= kv_cache.capacity:
+            raise RuntimeError(
+                "Not enough kv_cache capacity to run generation. Please use a larger "
+                "sequence_length or a shorter prompt"
+            )
+
         if len(tokens) < self.prompt_sequence_length:
             return False
 

src/deepsparse/transformers/pipelines/text_generation/pipeline.py

+1 -2
@@ -239,7 +239,6 @@ def __init__(
             sequence_length=sequence_length,
             prompt_sequence_length=prompt_sequence_length,
             token_generator=token_generator,
-            process_output_operator=process_output,
         )
 
         # TODO: do we want to support lists for different engines?
@@ -286,7 +285,7 @@ def __init__(
                 "compile_logits",
                 "generate_new_token",
             ],
-            "prep_for_generation": "autoregressive_preprocess",
+            "prep_for_generation": "generate_new_token",
             "generate_new_token": "compile_generated_tokens",
         }
 

src/deepsparse/transformers/pipelines/text_generation/pipeline_no_kv_cache.py

+5 -1
@@ -19,6 +19,7 @@
 from deepsparse.routers import GraphRouter
 from deepsparse.schedulers import OperatorScheduler
 from deepsparse.transformers.pipelines.text_generation import (
+    CompileGeneratedTokens,
     CompileGenerations,
     GenerateNewTokenOperator,
     JoinOutput,
@@ -73,6 +74,7 @@ def __init__(
             tokenizer=self.tokenizer, force_max_tokens=True
         )
         compile_generations = CompileGenerations()
+        compile_generated_tokens = CompileGeneratedTokens()
         join_output = JoinOutput(tokenizer=self.tokenizer)
         process_outputs = ProcessOutputs(tokenizer=self.tokenizer)
 
@@ -82,6 +84,7 @@ def __init__(
             "engine_operator": engine_operator,
             "prepare_generation": prepare_generation,
             "generate_new_token": generate_new_token,
+            "compile_generated_tokens": compile_generated_tokens,
             "compile_generations": compile_generations,
             "join_output": join_output,
             "process_outputs": process_outputs,
@@ -92,7 +95,8 @@ def __init__(
             "SPLIT": "engine_operator",
             "engine_operator": "prepare_generation",
             "prepare_generation": "generate_new_token",
-            "generate_new_token": "compile_generations",
+            "generate_new_token": "compile_generated_tokens",
+            "compile_generated_tokens": "compile_generations",
             "compile_generations": "JOIN",
             "JOIN": "join_output",
             "join_output": "process_outputs",

src/deepsparse/transformers/pipelines/text_generation/prep_for_generation.py

+19 -39
@@ -15,35 +15,38 @@
 from typing import Any, Optional
 
 import numpy
+from pydantic import BaseModel, Field
 
 from deepsparse.operators import Operator
-from deepsparse.subgraph_execute import StreamingOutput
 from deepsparse.transformers.pipelines.text_generation import TokenGeneratorOperator
-from deepsparse.transformers.schemas.text_generation_schemas import (
-    FinishReason,
-    PromptLogitsNoKVCacheInference,
-)
+from deepsparse.transformers.schemas.text_generation_schemas import FinishReason
 from deepsparse.transformers.utils.helpers import set_generated_length
 from deepsparse.utils import InferenceState
 
 
-__all__ = ["PrepareGeneration"]
+__all__ = ["PrepareGeneration", "PrepareForGenerationOutput"]
+
+
+class PrepareForGenerationOutput(BaseModel):
+    prompt_logits: Any = Field(
+        description="A set of prompt logits generated during prefill"
+    )
+    kv_cache: Optional[Any] = Field(description="kv cache")
+    in_generation: Optional[bool] = Field(description="in_generation flag")
 
 
 class PrepareGeneration(Operator):
+    output_schema = PrepareForGenerationOutput
+
     def __init__(
         self,
        token_generator: TokenGeneratorOperator,
        prompt_sequence_length: int,
        sequence_length: int,
-        process_output_operator: Optional[Operator] = None,
    ):
        self.sequence_length = sequence_length
        self.token_generator_creator = token_generator
        self.prompt_sequence_length = prompt_sequence_length
-        # Needed for streaming as currently both setting up generation and generating
-        # Will split this up soon
-        self.process_output_operator = process_output_operator
 
     def can_operate(self, inp: Any):
         kv_cache = inp.get("kv_cache")
@@ -79,7 +82,6 @@ def run(
             **inference_state.current_state,
         )
         token_generator = token_generator_creator_output.get("token_generator")
-        token_generator.generate(prompt_logits[0, -1, :])
 
         max_tokens, length_finish_reason = set_generated_length(
             max_length=generation_config.max_length,
@@ -93,43 +95,21 @@ def run(
         state_update = {
             "max_tokens": max_tokens,
             "length_finish_reason": length_finish_reason,
-            "generated_tokens": [token_generator.tokens[-1]],
-            "generated_logits": [prompt_logits]
+            "generated_tokens": [],
+            "generated_logits": [prompt_logits[:, 0:-1, :]]
             if include_prompt_logits
-            else [numpy.expand_dims(prompt_logits[:, -1, :], 0)],
+            else [],
             "finished_reason": [],
             "token_generator": token_generator,
         }
+
         if kv_cache is None:
-            output = PromptLogitsNoKVCacheInference(prompt_logits=prompt_logits)
+            output = {"prompt_logits": numpy.expand_dims(prompt_logits[:, -1, :], 0)}
         else:
            output = {
-                "tokens": token_generator.tokens,
                "kv_cache": kv_cache,
                "in_generation": True,
+                "prompt_logits": numpy.expand_dims(prompt_logits[:, -1, :], 0),
            }
-        # TODO: maybe break this operator up since it is both generating and setting
-        # up values needed for generation? Holding off on this as this will change
-        # routes slighty and want to confirm wont break anything for non-kv cache
-        if inference_state.current_state.get("streaming") and max_tokens >= 1:
-            finished_reason = [length_finish_reason] if max_tokens == 1 else [None]
-
-            if self.process_output_operator is None:
-                raise ValueError(
-                    "An operator must be provided to process outputs"
-                    "while streaming."
-                )
-            data_to_yield = self.process_output_operator.run(
-                generated_tokens=state_update.get("generated_tokens"),
-                finished_reason=finished_reason,
-                inference_state=inference_state,
-                generated_logits=prompt_logits[0, -1, :],
-            )
-            output = StreamingOutput(
-                data_to_yield=self.process_output_operator.output_schema(
-                    **data_to_yield
-                ),
-                data_to_return=output,
-            )
 
         return output, state_update
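The new output schema can be exercised on its own. The copy below is lightly adapted (explicit None defaults added so the snippet runs standalone regardless of pydantic version); the vocabulary size is made up for illustration.

from typing import Any, Optional

import numpy
from pydantic import BaseModel, Field


# Lightly adapted copy of PrepareForGenerationOutput; the upstream
# definition is in the diff above.
class PrepareForGenerationOutput(BaseModel):
    prompt_logits: Any = Field(
        None, description="A set of prompt logits generated during prefill"
    )
    kv_cache: Optional[Any] = Field(None, description="kv cache")
    in_generation: Optional[bool] = Field(None, description="in_generation flag")


# PrepareGeneration now emits only the last prompt position's logits with an
# extra leading dim, i.e. shape (1, 1, vocab_size); 32000 is illustrative.
prompt_logits = numpy.random.rand(1, 7, 32000).astype(numpy.float32)
out = PrepareForGenerationOutput(
    prompt_logits=numpy.expand_dims(prompt_logits[:, -1, :], 0),
    kv_cache=None,
    in_generation=True,
)
print(out.prompt_logits.shape)  # (1, 1, 32000)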

src/deepsparse/transformers/schemas/text_generation_schemas.py

-8
@@ -166,11 +166,3 @@ class TextGenerationOutput(BaseModel):
     class Config:
         arbitrary_types_allowed = True
         extra = "allow"
-
-
-class PromptLogitsNoKVCacheInference(BaseModel):
-    prompt_logits: Any = Field(
-        description="A set of prompt logits generated "
-        "during the inference pass with a "
-        "non-kv cache model"
-    )

src/deepsparse/transformers/utils/helpers.py

+3 -1
@@ -104,8 +104,10 @@ def set_generated_length(
     :param max_new_tokens: the max_new_tokens attribute, which may be provided
         as part of the input during inference
     """
-    if max_length:
+    if max_length is not None:
         # if max_length provided, use that to cap total tokens generated
+        if max_length == 0:
+            raise ValueError("max_length must be greater than 0")
         max_tokens = max_length
         finish_reason = finish_reason_choices.LENGTH
     else:
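The switch from a truthiness test to an explicit None check matters for max_length=0, which previously fell through to the else branch silently and is now rejected up front. A small sketch of just this branch; the fallback when max_length is None is simplified here and is not taken from the diff.

def resolve_max_tokens(max_length, fallback_max_tokens):
    # Only the max_length branch mirrors the diff; the None fallback below is
    # a simplification for illustration.
    if max_length is not None:
        if max_length == 0:
            raise ValueError("max_length must be greater than 0")
        return max_length, "LENGTH"
    return fallback_max_tokens, None


print(resolve_max_tokens(128, 1024))   # (128, 'LENGTH')
print(resolve_max_tokens(None, 1024))  # (1024, None)
# resolve_max_tokens(0, 1024) now raises ValueError instead of being treated
# as "no max_length provided"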

tests/deepsparse/schedulers/test_continuous_batching_scheduler.py

+2
@@ -16,10 +16,12 @@
 
 import numpy
 
+import pytest
 from deepsparse.operators import EngineOperator
 from deepsparse.schedulers import ContinuousBatchingScheduler
 
 
+@pytest.mark.skip("skip continuous batching tests")
 def test_continuous_batching_executor_thread():
     # simple test that ContinuousBatchingScheduler can be instantiated and return
     # a result from a request, for testing multi-batch execution, making enough

tests/deepsparse/schedulers/utils/test_continuous_batching_executor.py

+2
@@ -16,13 +16,15 @@
 
 import numpy
 
+import pytest
 from deepsparse.operators import EngineOperator
 from deepsparse.schedulers.utils import (
     ContinuousBatchingExecutorThread,
     ContinuousBatchingQueues,
 )
 
 
+@pytest.mark.skip("skip continuous batching tests")
 def test_continuous_batching_executor_thread():
     # mobilenet model with batch_size=2
     engine_operator = EngineOperator("zoo:mobilenet_v2-1.0-imagenet-base")
