@@ -180,13 +180,19 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
180
180
)
181
181
pytorch_kv_tokens = api_runner .run_kv_model_on_pytorch (qeff_model .model )
182
182
183
+ < << << << HEAD
183
184
if model_name not in ModelConfig .SWIFTKV_MODELS :
184
185
assert (pytorch_hf_tokens == pytorch_kv_tokens ).all (), (
185
186
"Tokens don't match for HF PyTorch model output and KV PyTorch model output"
186
187
)
187
188
assert (
188
189
pytorch_hf_tokens == pytorch_kv_tokens
189
190
).all (), "Tokens don't match for HF PyTorch model output and KV PyTorch model output"
191
+ == == == =
192
+ assert (pytorch_hf_tokens == pytorch_kv_tokens ).all (), (
193
+ "Tokens don't match for HF PyTorch model output and KV PyTorch model output"
194
+ )
195
+ >> >> >> > 684644 e (nit : format the file )
190
196
191
197
onnx_model_path = qeff_model .export ()
192
198
ort_tokens = api_runner .run_kv_model_on_ort (onnx_model_path , is_tlm = is_tlm )
@@ -212,13 +218,13 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
212
218
:, :gen_len
213
219
] # Because we always run for single input and single batch size
214
220
if prefill_only :
215
- assert (
216
- ort_tokens [ 0 ][ 0 ] == cloud_ai_100_tokens [ 0 ][ 0 ]
217
- ). all (), "prefill run output tokens don't match for ONNXRT output and Cloud AI 100 output."
221
+ assert (ort_tokens [ 0 ][ 0 ] == cloud_ai_100_tokens [ 0 ][ 0 ]). all (), (
222
+ "prefill run output tokens don't match for ONNXRT output and Cloud AI 100 output."
223
+ )
218
224
else :
219
- assert (
220
- ort_tokens == cloud_ai_100_tokens
221
- ). all (), "Tokens don't match for ONNXRT output and Cloud AI 100 output."
225
+ assert (ort_tokens == cloud_ai_100_tokens ). all (), (
226
+ "Tokens don't match for ONNXRT output and Cloud AI 100 output."
227
+ )
222
228
assert os .path .isfile (os .path .join (os .path .dirname (qpc_path ), "qconfig.json" ))
223
229
if prefill_only is not None :
224
230
return
@@ -307,9 +313,9 @@ def test_causal_lm_export_with_deprecated_api(model_name):
307
313
new_api_ort_tokens = api_runner .run_kv_model_on_ort (new_api_onnx_model_path )
308
314
old_api_ort_tokens = api_runner .run_kv_model_on_ort (old_api_onnx_model_path )
309
315
310
- assert (
311
- new_api_ort_tokens == old_api_ort_tokens
312
- ). all (), "New API output does not match old API output for ONNX export function"
316
+ assert (new_api_ort_tokens == old_api_ort_tokens ). all (), (
317
+ "New API output does not match old API output for ONNX export function"
318
+ )
313
319
314
320
315
321
@pytest .mark .on_qaic
0 commit comments