From 84e4a07f20b0434e2b564e4c0f715ab5c3823951 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 9 Feb 2023 08:12:13 +0000 Subject: [PATCH] Revert 6380 for test_classification and test_video --- test/test_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index d4dab1bbc9d..5826cc77164 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -686,7 +686,7 @@ def test_classification_model(model_fn, dev): model.eval().to(device=dev) x = _get_image(input_shape=input_shape, real_image=real_image, device=dev) out = model(x) - _assert_expected(out.cpu(), model_name, prec=1e-3) + _assert_expected(out.cpu(), model_name, prec=0.1) assert out.shape[-1] == num_classes _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out) _check_fx_compatible(model, x, eager_out=out) @@ -917,7 +917,7 @@ def test_video_model(model_fn, dev): # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests x = torch.rand(input_shape).to(device=dev) out = model(x) - _assert_expected(out.cpu(), model_name, prec=1e-5) + _assert_expected(out.cpu(), model_name, prec=0.1) assert out.shape[-1] == num_classes _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out) _check_fx_compatible(model, x, eager_out=out)