diff --git a/tests/cache/test_neuronx_cache.py b/tests/cache/test_neuronx_cache.py index 684e3b6fe..db1a50182 100644 --- a/tests/cache/test_neuronx_cache.py +++ b/tests/cache/test_neuronx_cache.py @@ -18,13 +18,13 @@ import socket import subprocess from tempfile import TemporaryDirectory +from time import time import PIL import pytest import torch from huggingface_hub import HfApi from transformers import AutoTokenizer -from transformers.testing_utils import ENDPOINT_STAGING from optimum.neuron import ( NeuronModelForCausalLM, @@ -37,27 +37,21 @@ @pytest.fixture -def cache_repos(staging): +def cache_repos(): # Setup: create temporary Hub repository and local cache directory - token = staging["token"] - user = staging["user"] - api = HfApi(endpoint=ENDPOINT_STAGING, token=token) + api = HfApi() hostname = socket.gethostname() - cache_repo_id = f"{user}/{hostname}-optimum-neuron-cache" - if api.repo_exists(cache_repo_id): - api.delete_repo(cache_repo_id) + cache_repo_id = f"{hostname}-{time()}-optimum-neuron-cache" cache_repo_id = api.create_repo(cache_repo_id, private=True).repo_id cache_dir = TemporaryDirectory() cache_path = cache_dir.name # Modify environment to force neuronx cache to use temporary caches previous_env = {} - env_vars = ["NEURON_COMPILE_CACHE_URL", "CUSTOM_CACHE_REPO", "HF_ENDPOINT", "HF_TOKEN"] + env_vars = ["NEURON_COMPILE_CACHE_URL", "CUSTOM_CACHE_REPO"] for var in env_vars: previous_env[var] = os.environ.get(var) os.environ["NEURON_COMPILE_CACHE_URL"] = cache_path os.environ["CUSTOM_CACHE_REPO"] = cache_repo_id - os.environ["HF_ENDPOINT"] = ENDPOINT_STAGING - os.environ["HF_TOKEN"] = token yield (cache_path, cache_repo_id) # Teardown api.delete_repo(cache_repo_id) @@ -173,8 +167,7 @@ def check_traced_cache_entry(cache_path): def assert_local_and_hub_cache_sync(cache_path, cache_repo_id): - # Since created models are public on the staging endpoint we don't need a token - api = HfApi(endpoint=ENDPOINT_STAGING) + api = HfApi() remote_files = api.list_repo_files(cache_repo_id) local_files = get_local_cached_files(cache_path) for file in local_files: