Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit a90a20a

Browse files
committed
Merge remote-tracking branch 'origin/features/v2/unit_testing' into feature/damian/v2/factor_out_transformation_utils
2 parents 6f1b175 + 379481e commit a90a20a

File tree

14 files changed, with 570 additions and 65 deletions.

src/deepsparse/v2/routers/router.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,4 +158,5 @@ def next(
158158

159159
@staticmethod
def validate(ops) -> bool:
    """
    Placeholder validation hook; not yet implemented for the GraphRouter.

    NOTE(review): the body is only ``pass``, so the call currently returns
    ``None`` rather than the annotated ``bool`` -- confirm how callers treat
    a ``None`` result before relying on it.

    :param ops: presumably the operators the router would run over; unused
        until the validation is implemented
    """
    # TODO: still needs to be implemented for the GraphRouter
    pass

src/deepsparse/v2/schedulers/scheduler.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
from concurrent.futures import Future, ThreadPoolExecutor
17+
from typing import Callable
1718

1819
from deepsparse.v2.operators import Operator
1920

@@ -64,3 +65,13 @@ def can_process(
6465
Base OperatorScheduler always returns True
6566
"""
6667
return True
68+
69+
def map(self, *args, func: Callable):
    """
    Submit ``func`` once for every group of positionally aligned arguments.

    The iterables in ``args`` are zipped together, so the i-th submission
    receives the i-th element of each iterable. Because ``zip`` stops at the
    shortest iterable, surplus elements of longer iterables are ignored and
    no iterables at all yields an empty list.

    :param args: iterables whose aligned elements are unpacked as the
        positional arguments of each ``submit`` call
    :param func: generic callable run for each arg; forwarded to ``submit``
        through its ``operator`` keyword
    :return: list of futures for each submit, in input order
    """
    return [self.submit(*values, operator=func) for values in zip(*args)]

src/deepsparse/v2/schedulers/scheduler_group.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
from concurrent.futures import Future
17-
from typing import Callable, List
17+
from typing import List
1818

1919
from deepsparse.v2.operators import Operator
2020
from deepsparse.v2.schedulers.scheduler import OperatorScheduler
@@ -55,13 +55,3 @@ def submit(
5555
operator=operator,
5656
**kwargs,
5757
)
58-
59-
def map(self, *args, func: Callable):
60-
"""
61-
:param func: generic callable run for each arg
62-
:return: list of futures for each submit
63-
"""
64-
futures = []
65-
for _, values in enumerate(zip(*args)):
66-
futures.append(self.submit(*values, operator=func))
67-
return futures

src/deepsparse/v2/text_generation/kv_cache_operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from deepsparse.v2.operators import Operator
2525

2626

27-
__all__ = ["KVCacheCreator"]
27+
__all__ = ["KVCacheCreator", "KVCacheCreatorInput"]
2828

2929

3030
class KVCacheCreatorOutput(BaseModel):

src/deepsparse/v2/text_generation/nl_engine_operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
)
3030

3131

32-
__all__ = ["NLEngineOperator"]
32+
__all__ = ["NLEngineOperator", "NlEngineInput"]
3333

3434

3535
class NlEngineInput(BaseModel):

src/deepsparse/v2/text_generation/prep_for_generation.py

Lines changed: 5 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import copy
1415
from typing import Any
1516

1617
import numpy
1718

1819
from deepsparse.transformers.pipelines.text_generation import FinishReason
20+
from deepsparse.transformers.utils.helpers import set_generated_length
1921
from deepsparse.v2.operators import Operator
2022
from deepsparse.v2.text_generation import TokenGeneratorOperator
2123
from deepsparse.v2.utils import InferenceState
@@ -31,9 +33,9 @@ def __init__(
3133
prompt_sequence_length: int,
3234
sequence_length: int,
3335
):
34-
self.prompt_sequence_length = prompt_sequence_length
3536
self.sequence_length = sequence_length
3637
self.token_generator_creator = token_generator
38+
self.prompt_sequence_length = prompt_sequence_length
3739

3840
def can_operate(self, inp: Any):
3941
kv_cache = inp.get("kv_cache")
@@ -47,49 +49,6 @@ def can_operate(self, inp: Any):
4749
return True
4850
return False
4951

50-
@staticmethod
51-
def set_generated_length(
52-
max_length: int,
53-
prompt_tokens_length: int,
54-
sequence_length: int,
55-
prompt_sequence_length: int,
56-
max_new_tokens: int,
57-
finish_reason_choices: "FinishReason", # noqa
58-
):
59-
"""
60-
Determine the length of the generated tokens. The hard cap on the total number
61-
of tokens is based on the sequence length. If max_length is provided and is less
62-
than the sequence length, it will be used to cap the total number of tokens
63-
generated. If it is not provided, the max_new_tokens attribute will be used and
64-
also capped by the sequence length.
65-
66-
:param max_length: max_length attribute, provided as input during inference
67-
:param prompt_tokens_length: the number of prompt tokens used as part of the
68-
generated output
69-
:param sequence_length: the sequence length used for the pipeline
70-
:param prompt_sequence_length: the prompt sequence length used for the pipeline
71-
:param max_new_tokens: the max_new_tokens attribute, which may be provided
72-
as part of the input during inference
73-
"""
74-
if max_length:
75-
# if max_length provided, use that to cap total tokens generated
76-
max_tokens = max_length
77-
finish_reason = finish_reason_choices.LENGTH
78-
else:
79-
# if not provided, max tokens is based on max_new_tokens + prompt tokens
80-
max_tokens = (
81-
min(max_new_tokens, sequence_length - prompt_sequence_length)
82-
+ prompt_tokens_length
83-
)
84-
finish_reason = finish_reason_choices.MAX_NEW_TOKENS
85-
86-
# hard model/pipeline cap
87-
return (
88-
(sequence_length, finish_reason_choices.CAPACITY)
89-
if sequence_length < max_tokens
90-
else (max_tokens, finish_reason)
91-
)
92-
9352
def run(
9453
self, tokens: Any, kv_cache: Any, inference_state: InferenceState, **kwargs
9554
):
@@ -107,13 +66,13 @@ def run(
10766
logits_shape=prompt_logits[0, -1, :].shape,
10867
deterministic=not generation_config.do_sample,
10968
sampling_temperature=generation_config.temperature,
110-
tokens=tokens,
69+
tokens=copy.copy(tokens),
11170
**inference_state.current_state,
11271
)
11372
token_generator = token_generator_creator_output.get("token_generator")
11473
token_generator.generate(prompt_logits[0, -1, :])
11574

116-
max_tokens, length_finish_reason = PrepareGeneration.set_generated_length(
75+
max_tokens, length_finish_reason = set_generated_length(
11776
max_length=generation_config.max_length,
11877
prompt_tokens_length=1,
11978
max_new_tokens=generation_config.max_new_tokens,
@@ -131,7 +90,6 @@ def run(
13190
"finished_reason": [],
13291
"token_generator": token_generator,
13392
}
134-
13593
output = {
13694
"tokens": token_generator.tokens,
13795
"kv_cache": kv_cache,

src/deepsparse/v2/text_generation/process_inputs.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
from deepsparse.v2.operators import Operator
2727

2828

29+
__all__ = ["ProcessInputsTextGeneration", "GenerationDefaults"]
30+
31+
2932
class GenerationDefaults:
3033
num_return_sequences = 1
3134
max_length = 100
@@ -38,9 +41,6 @@ class GenerationDefaults:
3841
temperature = 1.0
3942

4043

41-
__all__ = ["ProcessInputsTextGeneration"]
42-
43-
4444
class ProcessInputsTextGeneration(Operator):
4545
"""
4646
Input processing operator. Responsible for tokenizing the input, handling the
@@ -54,10 +54,10 @@ class ProcessInputsTextGeneration(Operator):
5454
def __init__(
5555
self,
5656
tokenizer: transformers.PreTrainedTokenizerBase,
57+
sequence_length: int,
5758
generation_config: Union[
5859
str, pathlib.Path, Dict, transformers.GenerationConfig
59-
],
60-
sequence_length: int,
60+
] = None,
6161
):
6262
self.generation_config = generation_config
6363
self.tokenizer = tokenizer
(new file; its path was not captured in this extract)

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import copy
16+
17+
import numpy
18+
from transformers import AutoTokenizer
19+
20+
import pytest
21+
from deepsparse.transformers.helpers import get_deployment_path
22+
from deepsparse.transformers.pipelines.text_generation import TextGenerationInput
23+
from deepsparse.transformers.utils import DecoderKVCache
24+
from deepsparse.transformers.utils.helpers import initialize_kv_cache_state
25+
from deepsparse.v2 import InferenceState, PipelineState
26+
from deepsparse.v2.text_generation import (
27+
GenerationDefaults,
28+
NLEngineOperator,
29+
TokenGeneratorOperator,
30+
)
31+
32+
33+
@pytest.fixture(scope="module")
def text_generation_attributes():
    """
    Sizing constants shared by the text generation fixtures in this module.

    :return: tuple of (sequence_length, prompt_sequence_length)
    """
    sequence_length, prompt_sequence_length = 5, 1
    return sequence_length, prompt_sequence_length
38+
39+
40+
@pytest.fixture(scope="module")
def model_attributes(text_generation_attributes):
    """
    Resolve the TinyStories stub and load a left-padding tokenizer for it.

    :param text_generation_attributes: tuple of
        (sequence_length, prompt_sequence_length); only the first is used
    :return: tuple of (tokenizer, model_path), where model_path is the second
        value returned by get_deployment_path
    """
    sequence_length = text_generation_attributes[0]
    deployment_path, model_path = get_deployment_path(
        "hf:mgoin/TinyStories-1M-deepsparse"
    )

    tokenizer = AutoTokenizer.from_pretrained(
        deployment_path,
        trust_remote_code=False,
        model_max_length=sequence_length,
    )
    tokenizer.padding_side = "left"

    # fall back to the EOS token when no pad token is set
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    return tokenizer, model_path
57+
58+
59+
@pytest.fixture(scope="module")
def single_token_engine_no_internal_cache(text_generation_attributes, model_attributes):
    """
    Build an NLEngineOperator that takes one input id per call.

    :param text_generation_attributes: tuple of
        (sequence_length, prompt_sequence_length); only the first is used
    :param model_attributes: tuple of (tokenizer, model_path); only the
        model path is used
    :return: the constructed NLEngineOperator
    """
    sequence_length = text_generation_attributes[0]
    model_path = model_attributes[1]

    return NLEngineOperator(
        sequence_length=sequence_length,
        input_ids_length=1,
        model_path=model_path,
    )
68+
69+
70+
@pytest.fixture(scope="module")
def pipeline_state(single_token_engine_no_internal_cache):
    """
    Create a PipelineState seeded with attributes read from the engine.

    :param single_token_engine_no_internal_cache: engine operator whose
        onnx_input_names_no_cache, cache_shape, output_names and
        kv_cache_data_type attributes are copied into the state
    :return: the populated PipelineState
    """
    engine = single_token_engine_no_internal_cache
    state = PipelineState()
    state.create_state(
        {
            "onnx_input_names_no_cache": engine.onnx_input_names_no_cache,
            "cache_shape": engine.cache_shape,
            "output_names": engine.output_names,
            "kv_cache_data_type": engine.kv_cache_data_type,
        }
    )
    return state
88+
89+
90+
@pytest.fixture(scope="module")
def large_prompt():
    """
    Multi-word prompt paired with an explicit generation config.

    :return: TextGenerationInput with top_p and top_k set to 0 and
        max_length set to 10
    """
    return TextGenerationInput(
        prompt="Hello, how are you doing today?",
        generation_config={"top_p": 0, "top_k": 0, "max_length": 10},
    )
95+
96+
97+
@pytest.fixture(scope="module")
def small_prompt():
    """
    Single-word prompt that passes no generation config.

    :return: TextGenerationInput holding only the prompt "Hello"
    """
    return TextGenerationInput(prompt="Hello")
101+
102+
103+
@pytest.fixture(scope="module")
def mock_kv_cache():
    """
    DecoderKVCache set up with a single hand-written dummy cache array.

    :return: the DecoderKVCache after setup() was called with the dummy state
    """
    kv_cache = DecoderKVCache()
    dummy_state = {
        "dummy_cache_name": numpy.array([[[[0], [0], [1], [2], [3]]]]),
    }
    kv_cache.setup(state=dummy_state)
    return kv_cache
110+
111+
112+
@pytest.fixture(scope="module")
def mock_kv_cache_three_tokens_processed():
    """
    Same dummy cache as mock_kv_cache, set up with num_processed_tokens=3.

    :return: the DecoderKVCache after setup() was called with the dummy state
    """
    kv_cache = DecoderKVCache()
    dummy_state = {
        "dummy_cache_name": numpy.array([[[[0], [0], [1], [2], [3]]]]),
    }
    kv_cache.setup(state=dummy_state, num_processed_tokens=3)
    return kv_cache
120+
121+
122+
@pytest.fixture(scope="module")
def mock_kv_cache_single_token_engine(pipeline_state, text_generation_attributes):
    """
    DecoderKVCache whose state is built from the values held in pipeline_state.

    The cache length passed to initialize_kv_cache_state is the sequence
    length minus the prompt sequence length.

    :param pipeline_state: state providing cache_shape, kv_cache_data_type
        and output_names
    :param text_generation_attributes: tuple of
        (sequence_length, prompt_sequence_length)
    :return: the DecoderKVCache after setup() was called with the built state
    """
    seq_len, prompt_seq_len = text_generation_attributes
    current_state = pipeline_state.current_state

    kv_cache = DecoderKVCache()
    kv_cache.setup(
        state=initialize_kv_cache_state(
            cache_shape=current_state.get("cache_shape"),
            kv_cache_data_type=current_state.get("kv_cache_data_type"),
            output_names=current_state.get("output_names"),
            length=seq_len - prompt_seq_len,
            empty=False,
        )
    )
    return kv_cache
135+
136+
137+
@pytest.fixture(scope="module")
def mock_tokens():
    """
    Single-token input.

    :return: list containing the one token id 15496
    """
    token_ids = [15496]
    return token_ids
140+
141+
142+
@pytest.fixture(scope="module")
def mock_tokens_multiple():
    """
    Three-token input made of the same token id repeated.

    :return: list containing the token id 15496 three times
    """
    token_ids = [15496] * 3
    return token_ids
145+
146+
147+
@pytest.fixture(scope="module")
def mock_inference_state():
    """
    InferenceState carrying a GenerationDefaults instance.

    The state is created empty and then updated with the defaults under the
    "generation_config" key.

    :return: the populated InferenceState
    """
    defaults = GenerationDefaults()
    state = InferenceState()
    state.create_state({})
    state.update_state({"generation_config": defaults})
    return state
154+
155+
156+
@pytest.fixture(scope="module")
def mock_token_generator(model_attributes, mock_tokens_multiple):
    """
    Token generator created by running TokenGeneratorOperator on mock tokens.

    Random logits are drawn only to derive the logits shape handed to the
    operator. The operator receives a shallow copy of the tokens, not the
    shared fixture list itself.

    :param model_attributes: tuple of (tokenizer, model_path); the tokenizer
        length sets the last logits dimension
    :param mock_tokens_multiple: token ids the generator starts from
    :return: the "token_generator" entry of the operator output
    """
    tokenizer = model_attributes[0]
    creator = TokenGeneratorOperator()
    prompt_logits = numpy.random.rand(1, len(mock_tokens_multiple), len(tokenizer))

    creator_output = creator.run(
        logits_shape=prompt_logits[0, -1, :].shape,
        deterministic=True,
        sampling_temperature=1.0,
        tokens=copy.copy(mock_tokens_multiple),
    )
    return creator_output.get("token_generator")
168+
169+
170+
@pytest.fixture(scope="module")
def mock_logits(model_attributes):
    """
    Random logits for one token of one sequence.

    :param model_attributes: tuple of (tokenizer, model_path); the tokenizer
        length sets the last dimension
    :return: numpy array of shape (1, 1, len(tokenizer)) drawn with
        numpy.random.rand
    """
    vocab_size = len(model_attributes[0])
    return numpy.random.rand(1, 1, vocab_size)

0 commit comments

Comments
 (0)