diff --git a/src/pybamm/dispatch/entry_points.py b/src/pybamm/dispatch/entry_points.py index 715d1cc876..ceefbf50a8 100644 --- a/src/pybamm/dispatch/entry_points.py +++ b/src/pybamm/dispatch/entry_points.py @@ -1,6 +1,16 @@ +import hashlib import importlib.metadata import textwrap +import urllib.request from collections.abc import Callable, Mapping +from pathlib import Path + +from platformdirs import user_cache_dir + +from pybamm.expression_tree.operations.serialise import Serialise + +APP_NAME = "pybamm" +APP_AUTHOR = "pybamm" class EntryPoint(Mapping): @@ -109,7 +119,35 @@ def __getattribute__(self, name): models = EntryPoint(group="pybamm_models") -def Model(model: str, *args, **kwargs): +def _get_cache_dir() -> Path: + cache_dir = Path(user_cache_dir(APP_NAME, APP_AUTHOR)) / "models" + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir + + +def get_cache_path(url: str) -> Path: + cache_dir = _get_cache_dir() + file_hash = hashlib.md5(url.encode()).hexdigest() + return cache_dir / f"{file_hash}.json" + + +def clear_model_cache() -> None: + cache_dir = _get_cache_dir() + for file in cache_dir.glob("*.json"): + try: + file.unlink() + except Exception as e: + # Optional: log error instead of failing silently + print(f"Could not delete {file}: {e}") + + +def Model( + model=None, + url=None, + force_download=False, + *args, + **kwargs, +): """ Returns the loaded model object Note: This feature is in its experimental phase. @@ -137,6 +175,26 @@ def Model(model: str, *args, **kwargs): >>> pybamm.Model('SPM') # doctest: +SKIP """ - model_class = models._get_class(model) - - return model_class(*args, **kwargs) + if (model is None and url is None) or (model and url): + raise ValueError("You must provide exactly one of `model` or `url`.") + + if url is not None: + cache_path = get_cache_path(url) + if not cache_path.exists() or force_download: + try: + print(f"Downloading model from {url}...") + urllib.request.urlretrieve(url, cache_path) + print(f"Model cached at: {cache_path}") + except Exception as e: + raise RuntimeError(f"Failed to download model from URL: {e}") from e + else: + print(f"Using cached model at: {cache_path}") + + return Serialise.load_custom_model(str(cache_path)) + + if model is not None: + try: + model_class = models._get_class(model) + return model_class(*args, **kwargs) + except Exception as e: + raise ValueError(f"Could not load model '{model}': {e}") from e diff --git a/tests/unit/test_dispatch.py b/tests/unit/test_dispatch.py index 7128cd74c3..0cebf78164 100644 --- a/tests/unit/test_dispatch.py +++ b/tests/unit/test_dispatch.py @@ -1,10 +1,23 @@ # # Test dispatching mechanism in entry points # +import pytest + import pybamm +from pybamm.dispatch.entry_points import ( + _get_cache_dir, + clear_model_cache, + get_cache_path, +) + +MODEL_URL = "https://raw.githubusercontent.com/pybamm-team/pybamm-reservoir-example/refs/heads/main/dfn.py" class TestDispatch: + def setup_method(self): + """Clear cache before each test""" + clear_model_cache() + def test_model_loads_through_entry_points(self): """Test that a model loaded through Model() function is actually functional""" # Load model with build=False to avoid full initialization for faster testing @@ -31,5 +44,112 @@ def test_model_function(self): # Test Model function with options parameter options = {"thermal": "isothermal"} - model = pybamm.Model("SPM", options) + model = pybamm.Model("SPM", options=options) assert model.__class__.__name__ == "SPM" + + def test_model_value_error(self): + """Test that Model raises ValueError when given invalid arguments""" + + # Neither model nor url provided + with pytest.raises( + ValueError, match="You must provide exactly one of `model` or `url`." + ): + pybamm.Model() + + # Both model and url provided + with pytest.raises( + ValueError, match="You must provide exactly one of `model` or `url`." + ): + pybamm.Model(model="SPM", url="http://example.com/dfn.py") + + def test_model_download_runtime_error(self): + """Test that Model raises RuntimeError when download fails""" + + bad_url = "h://example.invalid/model.json" + + with pytest.raises(RuntimeError, match="Failed to download model from URL:"): + pybamm.Model(url=bad_url, force_download=True) + + def test_invalid_model_name_raises_value_error(self): + """Test that Model raises ValueError for an invalid model name""" + + bad_model = "NonExistentModel123" + + with pytest.raises(ValueError, match=f"Could not load model '{bad_model}':"): + pybamm.Model(model=bad_model) + + def test_model_download_and_cache(self): + """Force exception in clear_model_cache without interfering with real cache files""" + cache_dir = _get_cache_dir() + bad_path = cache_dir / "force_exception_test_dir" + + try: + bad_path.mkdir(exist_ok=True) # not .json + try: + bad_path.unlink() + except Exception as e: + print(f"Expected error: {e}") + finally: + if bad_path.exists(): + bad_path.rmdir() + + def test_force_download_overwrites_cache(self): + """Force an exception when trying to unlink a directory""" + cache_dir = _get_cache_dir() + bad_path = cache_dir / "force_exception_test_dir" + + try: + bad_path.mkdir(exist_ok=True) + clear_model_cache() + try: + bad_path.unlink() + except Exception as e: + print(f"Expected error: {e}") + finally: + if bad_path.exists(): + bad_path.rmdir() + + def test_clear_model_cache_exception_branch(self, capsys): + """Test that clear_model_cache gracefully handles deletion errors""" + cache_dir = pybamm.dispatch.entry_points._get_cache_dir() + cache_dir.mkdir(parents=True, exist_ok=True) + + bad_path = cache_dir / "bad.json" + bad_path.mkdir(exist_ok=True) + + try: + clear_model_cache() + + captured = capsys.readouterr() + assert "Could not delete" in captured.out + assert "bad.json" in captured.out + + assert bad_path.exists() + finally: + bad_path.rmdir() + + def test_model_download_and_cache_integration(self, capsys): + """Integration test using a real model URL""" + + cache_path = get_cache_path(MODEL_URL) + + # Clean up: if cache path exists as directory, remove it + if cache_path.exists(): + if cache_path.is_dir(): + import shutil + + shutil.rmtree(cache_path) + else: + clear_model_cache() + + # First call -> should download and print "Model cached at" + model = pybamm.Model(url=MODEL_URL, force_download=True) + captured = capsys.readouterr() + assert "Model cached at:" in captured.out + assert hasattr(model, "name") + + # Second call -> should use cached file + model2 = pybamm.Model(url=MODEL_URL) + captured = capsys.readouterr() + assert "Using cached model at:" in captured.out + assert hasattr(model2, "name")