From 95edc68cb31beddfd1c3f930a2b0bc131ee4dcb2 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Tue, 6 Aug 2024 11:50:15 -0400 Subject: [PATCH] Fix gated test (#2993) * Fix gated test * Clean * Finally, adjust test --- src/accelerate/commands/estimate.py | 2 +- tests/test_cli.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/accelerate/commands/estimate.py b/src/accelerate/commands/estimate.py index 56da3c5ad9e..2cd731b2221 100644 --- a/src/accelerate/commands/estimate.py +++ b/src/accelerate/commands/estimate.py @@ -38,7 +38,7 @@ def verify_on_hub(repo: str, token: str = None): "Verifies that the model is on the hub and returns the model info." try: return model_info(repo, token=token) - except GatedRepoError: + except (OSError, GatedRepoError): return "gated" except RepositoryNotFoundError: return "repo" diff --git a/tests/test_cli.py b/tests/test_cli.py index 20e61e7a3b4..c4d35569ce8 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -435,7 +435,10 @@ def test_no_metadata(self): estimate_command(args) def test_gated(self): - with self.assertRaises(GatedRepoError, msg="Repo for model `meta-llama/Llama-2-7b-hf` is gated"): + with self.assertRaises( + (GatedRepoError, EnvironmentError), + msg="Repo for model `meta-llama/Llama-2-7b-hf` is gated or environment error occurred", + ): args = self.parser.parse_args(["meta-llama/Llama-2-7b-hf"]) with patch_environment(hf_hub_disable_implicit_token="1"): estimate_command(args)