Skip to content

Commit dbf736c

Browse files
committed
nit: format the file and rebase
Signed-off-by: vbaddi <[email protected]>
1 parent 7e26c7a commit dbf736c

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

tests/transformers/models/test_causal_lm_models.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -180,13 +180,19 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
180180
)
181181
pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model)
182182

183+
<<<<<<< HEAD
183184
if model_name not in ModelConfig.SWIFTKV_MODELS:
184185
assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), (
185186
"Tokens don't match for HF PyTorch model output and KV PyTorch model output"
186187
)
187188
assert (
188189
pytorch_hf_tokens == pytorch_kv_tokens
189190
).all(), "Tokens don't match for HF PyTorch model output and KV PyTorch model output"
191+
=======
192+
assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), (
193+
"Tokens don't match for HF PyTorch model output and KV PyTorch model output"
194+
)
195+
>>>>>>> 684644e (nit: format the file)
190196

191197
onnx_model_path = qeff_model.export()
192198
ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm)
@@ -212,13 +218,13 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
212218
:, :gen_len
213219
] # Because we always run for single input and single batch size
214220
if prefill_only:
215-
assert (
216-
ort_tokens[0][0] == cloud_ai_100_tokens[0][0]
217-
).all(), "prefill run output tokens don't match for ONNXRT output and Cloud AI 100 output."
221+
assert (ort_tokens[0][0] == cloud_ai_100_tokens[0][0]).all(), (
222+
"prefill run output tokens don't match for ONNXRT output and Cloud AI 100 output."
223+
)
218224
else:
219-
assert (
220-
ort_tokens == cloud_ai_100_tokens
221-
).all(), "Tokens don't match for ONNXRT output and Cloud AI 100 output."
225+
assert (ort_tokens == cloud_ai_100_tokens).all(), (
226+
"Tokens don't match for ONNXRT output and Cloud AI 100 output."
227+
)
222228
assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json"))
223229
if prefill_only is not None:
224230
return
@@ -307,9 +313,9 @@ def test_causal_lm_export_with_deprecated_api(model_name):
307313
new_api_ort_tokens = api_runner.run_kv_model_on_ort(new_api_onnx_model_path)
308314
old_api_ort_tokens = api_runner.run_kv_model_on_ort(old_api_onnx_model_path)
309315

310-
assert (
311-
new_api_ort_tokens == old_api_ort_tokens
312-
).all(), "New API output does not match old API output for ONNX export function"
316+
assert (new_api_ort_tokens == old_api_ort_tokens).all(), (
317+
"New API output does not match old API output for ONNX export function"
318+
)
313319

314320

315321
@pytest.mark.on_qaic

0 commit comments

Comments
 (0)