Commit d1683b4

Merge branch 'v2' into feature/damian/v2/factor_out_transformation_utils
2 parents 98f7a6d + bbd534d commit d1683b4

File tree: 4 files changed, +185 -96 lines


src/deepsparse/transformers/utils/helpers.py
+91 -1

@@ -14,7 +14,7 @@
 import logging
 import pathlib
 import uuid
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy
 from transformers import AutoTokenizer, GenerationConfig
@@ -33,6 +33,7 @@
     "override_config",
     "process_generation_config",
     "validate_session_ids",
+    "compute_engine_inputs",
     "set_generated_length",
 ]

@@ -82,6 +83,95 @@ def set_generated_length(
     )


+def compute_engine_inputs(onnx_input_names: List[str], **kwargs) -> List[numpy.ndarray]:
+    """
+    Given the names of the onnx inputs, compute the inputs
+    to the engine. The inputs will be calculated from the
+    passed kwargs. The information about the required kwargs
+    can be found in the docstring of the individual compute
+    functions.
+
+    :param onnx_input_names: The names of the onnx inputs
+    :param kwargs: The kwargs to compute the inputs from
+    :return: The computed inputs to the engine
+    """
+    engine_inputs = []
+    for input_name in onnx_input_names:
+        if input_name == "causal_mask":
+            # delay the computation of the causal mask
+            continue
+        # fetch the compute function for the
+        # given input_name
+        compute_func = _get_compute_func(input_name)
+        # compute the engine input from the kwargs
+        # and append it to the engine_inputs
+        engine_inputs.append(compute_func(**kwargs))
+
+    if "causal_mask" in onnx_input_names:
+        # compute the causal mask and append it to the engine_inputs
+        input_ids, attention_mask, *_ = engine_inputs
+        engine_inputs.append(create_causal_mask(input_ids, attention_mask))
+
+    return engine_inputs
+
+
+def _get_compute_func(input_name: str) -> Callable[..., numpy.ndarray]:
+    # given the input_name, return the appropriate compute function
+    compute_func = {
+        "input_ids": _compute_input_ids,
+        "attention_mask": _compute_attention_mask,
+        "positions": _compute_positions,
+    }.get(input_name)
+    if compute_func is None:
+        raise ValueError(
+            "Could not find compute function " f"for the input_name: {input_name}"
+        )
+    return compute_func
+
+
+def _compute_input_ids(token_batch: List[int], **kwargs) -> numpy.ndarray:
+    # convert the token_batch to a numpy array
+    return numpy.array([token_batch])
+
+
+def _compute_attention_mask(
+    sequence_length: int,
+    prompt_sequence_length: int,
+    num_total_processed_tokens: int,
+    **kwargs,
+) -> numpy.ndarray:
+    # create a fully masked attention mask with the appropriate
+    # shape (equal to the sequence_length)
+    attention_mask = numpy.zeros((1, sequence_length), dtype=numpy.int64)
+    # unmask the appropriate number of tokens, the sum of
+    # - the number of tokens already processed and cached (num_total_processed_tokens)
+    # - the number of tokens currently processed (prompt_sequence_length)
+    # the sum cannot exceed the maximum length of the attention_mask
+    num_attention_entries_to_unmask = min(
+        num_total_processed_tokens + prompt_sequence_length, sequence_length
+    )
+    # unmask the bits from the right-hand side
+    attention_mask[:, -num_attention_entries_to_unmask:] = 1
+    return attention_mask
+
+
+def _compute_positions(
+    num_total_processed_tokens: int, prompt_sequence_length: int, **kwargs
+):
+    # create the positions array with the appropriate shape
+    # positions count starts from the number of tokens already processed
+    # and ends at the number of tokens already processed + the number of tokens
+    # currently processed
+    return (
+        numpy.arange(
+            num_total_processed_tokens,
+            num_total_processed_tokens + prompt_sequence_length,
+        )
+        .reshape(1, -1)
+        .astype(numpy.int64)
+    )
+
+
 def validate_session_ids(
     session_ids: Optional[str], other_attributes: Dict[str, Any]
 ) -> Optional[List[str]]:
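
As a quick orientation to the new API, a minimal usage sketch follows. The values mirror the second test case added below, and it assumes deepsparse is installed so the helpers module imports cleanly.

import numpy

from deepsparse.transformers.utils.helpers import compute_engine_inputs

# a 3-token prompt chunk, with 2 tokens already processed into the KV cache
# and a model sequence length of 6
engine_inputs = compute_engine_inputs(
    onnx_input_names=["input_ids", "attention_mask", "positions", "causal_mask"],
    token_batch=[1, 2, 3],
    prompt_sequence_length=3,
    sequence_length=6,
    num_total_processed_tokens=2,
)
input_ids, attention_mask, positions, causal_mask = engine_inputs

assert numpy.array_equal(input_ids, numpy.array([[1, 2, 3]]))
# 2 cached + 3 current tokens -> 5 entries unmasked from the right
assert numpy.array_equal(attention_mask, numpy.array([[0, 1, 1, 1, 1, 1]]))
# positions continue from the number of already-processed tokens
assert numpy.array_equal(positions, numpy.array([[2, 3, 4]]))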

src/deepsparse/v2/text_generation/autoregressive_preprocess_operator.py
+9 -25

@@ -15,9 +15,7 @@
 import logging
 from typing import Any

-import numpy
-
-from deepsparse.transformers.utils.helpers import create_causal_mask
+from deepsparse.transformers.utils.helpers import compute_engine_inputs
 from deepsparse.v2.operators import Operator
 from deepsparse.v2.utils import PipelineState

@@ -66,30 +64,16 @@ def run(self, tokens: Any, kv_cache: Any, pipeline_state: PipelineState, **kwargs):

         num_total_processed_tokens = kv_cache.total_num_processed_tokens
         new_token = tokens[num_total_processed_tokens]
-        engine_input_names = pipeline_state.current_state.get(
-            "onnx_input_names_no_cache"
-        )
-
-        # padding is added to left, so attention mask is 1s from the
-        # right up to the number of total tokens (prompt + generated)
-        attention_mask = numpy.zeros((1, self.sequence_length), dtype=numpy.int64)
-        num_attention_entries_to_unmask = min(
-            num_total_processed_tokens + 1, self.sequence_length
-        )  # cap by seq len
-        attention_mask[:, -num_attention_entries_to_unmask:] = 1
-        positions = numpy.array([[num_total_processed_tokens]], dtype=numpy.int64)
-        input_ids = numpy.array([[new_token]])
-        causal_mask = create_causal_mask(input_ids, attention_mask)

-        engine_inputs_map = dict(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            causal_mask=causal_mask,
-            positions=positions,
+        engine_inputs = compute_engine_inputs(
+            onnx_input_names=pipeline_state.current_state.get(
+                "onnx_input_names_no_cache"
+            ),
+            token_batch=[new_token],
+            prompt_sequence_length=1,
+            sequence_length=self.sequence_length,
+            num_total_processed_tokens=num_total_processed_tokens,
         )
-
-        engine_inputs = [engine_inputs_map[name] for name in engine_input_names]
-
         return {
             "engine_inputs": engine_inputs,
             "kv_cache": kv_cache,

src/deepsparse/v2/text_generation/multi_engine_prefill_operator.py
+11 -70

@@ -13,12 +13,9 @@
 # limitations under the License.

 import logging
-from enum import Enum
 from typing import Any

-import numpy
-
-from deepsparse.transformers.utils.helpers import create_causal_mask
+from deepsparse.transformers.utils.helpers import compute_engine_inputs
 from deepsparse.v2.operators import Operator
 from deepsparse.v2.utils import PipelineState

@@ -28,34 +25,14 @@
 __all__ = ["MultiEnginePrefill"]


-class OnnxInputNames(Enum):
-    INPUT_IDS = "input_ids"
-    ATTN_MASK = "attention_mask"
-    CAUSAL_MASK = "causal_mask"
-    POSITIONS = "positions"
-
-
-# NOTE: A possible clean-up could involve combining this Operator and the
-# autoregressive_preprocess_operator
-
-
 class MultiEnginePrefill(Operator):
     def __init__(self, prompt_sequence_length, sequence_length):
         """
         Prepare the tokens for the multi-token engine. This requires creating the
-        attention mask, positions, and causal mask. The output contains these three
-        arrays to be passed into the multi-token engine.
+        appropriate engine_inputs to be passed into the multi-token engine.
         """
         self.prompt_sequence_length = prompt_sequence_length
         self.sequence_length = sequence_length
-        self.cases = {
-            OnnxInputNames.ATTN_MASK.value: self._case_attn_mask,
-            OnnxInputNames.POSITIONS.value: self._case_positions,
-        }
-        _LOGGER.warn(
-            "This operator requires the PipelineState to be set-up with the "
-            "onnx_input_names_no_cache attribute set from the NLEngineOperator."
-        )

     def can_operate(self, inp: Any):
         """
@@ -75,59 +52,23 @@ def can_operate(self, inp: Any):
             return True
         return False

-    def _case_attn_mask(self, num_total_processed_tokens: int):
-        # create an empty attention mask
-        engine_input = numpy.zeros((1, self.sequence_length), dtype=numpy.int64)
-        # calculate the number of entries in attention mask that should be set to 1
-        num_attention_entries_to_unmask = min(
-            num_total_processed_tokens + self.prompt_sequence_length,
-            self.sequence_length,
-        )
-        engine_input[:, -num_attention_entries_to_unmask:] = 1
-        return engine_input
-
-    def _case_positions(self, num_total_processed_tokens: int):
-        return (
-            numpy.arange(
-                num_total_processed_tokens,
-                num_total_processed_tokens + self.prompt_sequence_length,
-            )
-            .reshape(1, -1)
-            .astype(numpy.int64)
-        )
-
     def run(self, tokens: Any, kv_cache: Any, pipeline_state: PipelineState, **kwargs):
         kv_cache.set_capacity(self.sequence_length - self.prompt_sequence_length)

-        onnx_input_names_no_cache = pipeline_state.current_state.get(
-            "onnx_input_names_no_cache"
-        )
-
         num_total_processed_tokens = kv_cache.total_num_processed_tokens
         start = num_total_processed_tokens
         end = start + self.prompt_sequence_length
         token_batch = tokens[start:end]

-        engine_inputs = []
-        for name in onnx_input_names_no_cache:
-            if name == OnnxInputNames.INPUT_IDS.value:
-                engine_input = numpy.array([token_batch])
-            elif (
-                name == OnnxInputNames.ATTN_MASK.value
-                or name == OnnxInputNames.POSITIONS.value
-            ):
-                engine_input = self.cases[name](num_total_processed_tokens)
-            elif name == OnnxInputNames.CAUSAL_MASK.value:
-                continue
-
-            engine_inputs.append(engine_input)
-
-        if OnnxInputNames.CAUSAL_MASK.value in onnx_input_names_no_cache:
-            causal_mask = create_causal_mask(
-                input_ids=engine_inputs[0],
-                attention_mask=engine_inputs[1],
-            )
-            engine_inputs.append(causal_mask)
+        engine_inputs = compute_engine_inputs(
+            onnx_input_names=pipeline_state.current_state.get(
+                "onnx_input_names_no_cache"
+            ),
+            token_batch=token_batch,
+            prompt_sequence_length=self.prompt_sequence_length,
+            sequence_length=self.sequence_length,
+            num_total_processed_tokens=num_total_processed_tokens,
+        )

         return {
             "engine_inputs": engine_inputs,

tests/deepsparse/transformers/utils/test_helpers.py
+74 -0

@@ -16,12 +16,86 @@

 import pytest
 from deepsparse.transformers.utils.helpers import (
+    compute_engine_inputs,
     create_causal_mask,
     initialize_kv_cache_state,
     validate_session_ids,
 )


+@pytest.mark.parametrize(
+    "onnx_input_names, "
+    "token_batch, "
+    "prompt_sequence_length, "
+    "sequence_length, "
+    "num_total_processed_tokens, "
+    "expected_engine_inputs",
+    [
+        (
+            ["input_ids", "attention_mask", "positions"],
+            [1, 2, 3],
+            3,
+            6,
+            2,
+            [
+                numpy.array([[1, 2, 3]]),
+                numpy.array([[0, 1, 1, 1, 1, 1]]),
+                numpy.array([[2, 3, 4]]),
+            ],
+        ),
+        (
+            ["input_ids", "attention_mask", "positions", "causal_mask"],
+            [1, 2, 3],
+            3,
+            6,
+            2,
+            [
+                numpy.array([[1, 2, 3]]),
+                numpy.array([[0, 1, 1, 1, 1, 1]]),
+                numpy.array([[2, 3, 4]]),
+                create_causal_mask(
+                    input_ids=numpy.array([[1, 2, 3]]),
+                    attention_mask=numpy.array([[0, 1, 1, 1, 1, 1]]),
+                ),
+            ],
+        ),
+        (
+            ["input_ids", "attention_mask", "positions", "causal_mask"],
+            [15],
+            1,
+            5,
+            3,
+            [
+                numpy.array([[15]]),
+                numpy.array([[0, 1, 1, 1, 1]]),
+                numpy.array([[3]]),
+                create_causal_mask(
+                    input_ids=numpy.array([[15]]),
+                    attention_mask=numpy.array([[0, 1, 1, 1, 1]]),
+                ),
+            ],
+        ),
+    ],
+)
+def test_compute_engine_inputs(
+    onnx_input_names,
+    token_batch,
+    prompt_sequence_length,
+    sequence_length,
+    num_total_processed_tokens,
+    expected_engine_inputs,
+):
+    engine_inputs = compute_engine_inputs(
+        onnx_input_names=onnx_input_names,
+        token_batch=token_batch,
+        prompt_sequence_length=prompt_sequence_length,
+        sequence_length=sequence_length,
+        num_total_processed_tokens=num_total_processed_tokens,
+    )
+    for x, y in zip(engine_inputs, expected_engine_inputs):
+        assert numpy.array_equal(x, y)
+
+
 @pytest.mark.parametrize(
     "input_ids, attention_mask, expected_causal_mask",
     [
