Skip to content

Commit c0c4240

Browse files
authored
Merge branch 'v2' into feature/damian/v2/factor_out_transformation_utils
2 parents a90a20a + 0a50d1d commit c0c4240

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from deepsparse.v2.text_generation import CompilePromptLogits
16+
17+
18+
def test_compile_logits(mock_logits, mock_inference_state):
19+
mock_inference_state.update_state({"prompt_logits": [mock_logits]})
20+
compile_prompt_logits = CompilePromptLogits()
21+
# Can operate as long as we're not in generation but in prompt_inference. This
22+
# can_operate() will check for the `in_generation` flag in the input.
23+
assert compile_prompt_logits.can_operate({})
24+
output, state = compile_prompt_logits.run(
25+
logits=mock_logits, inference_state=mock_inference_state
26+
)
27+
# The CompilePromptLogits is responsible for updating a list of prompt logits
28+
# calculated at each step during prompt inference. After one step of running this
29+
# operator, the total number of prompt_logits in the inference state should be
30+
# the current length of prompt logits + 1
31+
assert len(state.get("prompt_logits")) == len([mock_logits]) + 1

0 commit comments

Comments
 (0)