66 changes: 62 additions & 4 deletions src/pybamm/dispatch/entry_points.py
@@ -1,6 +1,16 @@
import hashlib
import importlib.metadata
import textwrap
import urllib.request
from collections.abc import Callable, Mapping
from pathlib import Path

from platformdirs import user_cache_dir

from pybamm.expression_tree.operations.serialise import Serialise

APP_NAME = "pybamm"
APP_AUTHOR = "pybamm"
I think we can keep just APP_NAME and drop APP_AUTHOR, and we can set Path(user_cache_dir(APP_NAME)) / "models" as a constant here instead.
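A minimal sketch of that suggestion, assuming only APP_NAME is kept; the constant name MODELS_CACHE_DIR is illustrative, not something defined in the PR:

```python
# Sketch only: drop APP_AUTHOR and hoist the cache path into a module-level constant.
from pathlib import Path

from platformdirs import user_cache_dir

APP_NAME = "pybamm"
MODELS_CACHE_DIR = Path(user_cache_dir(APP_NAME)) / "models"


def _get_cache_dir() -> Path:
    # The helper then only has to ensure the directory exists.
    MODELS_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    return MODELS_CACHE_DIR
```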



class EntryPoint(Mapping):
@@ -109,7 +119,35 @@ def __getattribute__(self, name):
models = EntryPoint(group="pybamm_models")


def Model(model: str, *args, **kwargs):
def _get_cache_dir() -> Path:
cache_dir = Path(user_cache_dir(APP_NAME, APP_AUTHOR)) / "models"
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir


def get_cache_path(url: str) -> Path:
cache_dir = _get_cache_dir()
file_hash = hashlib.md5(url.encode()).hexdigest()
return cache_dir / f"{file_hash}.json"
Comment on lines +128 to +131
Why MD5 and not SHA-256? :)
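For reference, the swap the comment asks about is a one-line change in get_cache_path; a sketch, reusing the PR's _get_cache_dir helper:

```python
import hashlib
from pathlib import Path


def get_cache_path(url: str) -> Path:
    cache_dir = _get_cache_dir()  # helper defined earlier in this module
    # SHA-256 instead of MD5; only the hash function changes.
    file_hash = hashlib.sha256(url.encode()).hexdigest()
    return cache_dir / f"{file_hash}.json"
```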



def clear_model_cache() -> None:
cache_dir = _get_cache_dir()
for file in cache_dir.glob("*.json"):
try:
file.unlink()
except Exception as e:
Comment on lines +134 to +139
Note, we also store some of PyBaMM's data files in the cache dir using a pooch registry:

def get_data(self, filename: str):
    """
    Fetches the data file from upstream and stores it in the local cache directory under pybamm directory.

    Parameters
    ----------
    filename : str
        Name of the data file to be fetched from the registry.

    Returns
    -------
    pathlib.PurePath
    """
    self.registry.fetch(filename)
    return pathlib.Path(f"{self.path}/{self.version}/{filename}")

I wonder if we could reuse some of the code here, because clearing the cache directory of JSON files could have unintended side effects. We do have JSON files there: https://github.com/pybamm-team/pybamm-data/releases/tag/v1.0.1

Could we perhaps rewrite the PR to use pooch instead? That will also provide safer defaults than using urllib.request.urlretrieve() directly, and we could rely on it for the checksum/caching bits to see if we need to download the model JSON again or not.
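
A rough sketch of the pooch-based alternative suggested above; the helper name and the "models" sub-directory are illustrative only, not PyBaMM API:

```python
# Sketch only: let pooch handle downloading, caching and (optionally) checksum checks.
import pooch


def _fetch_model_json(url: str, known_hash=None) -> str:
    # known_hash=None skips verification; a pinned sha256 (or a registry entry)
    # would be the safer default the comment is hinting at.
    return pooch.retrieve(
        url=url,
        known_hash=known_hash,
        path=pooch.os_cache("pybamm") / "models",  # keep model JSON away from the data files
    )
```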

# Optional: log error instead of failing silently
Suggested change (remove this line):
# Optional: log error instead of failing silently

print(f"Could not delete {file}: {e}")


def Model(
model=None,
url=None,
force_download=False,
*args,
**kwargs,
):
Comment on lines +144 to +150
I think we can have better typing here.
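For example, the parameters could be annotated along these lines (a sketch; the return annotation is left out because it depends on what Serialise.load_custom_model and the entry-point classes return):

```python
from typing import Optional


def Model(
    model: Optional[str] = None,
    url: Optional[str] = None,
    force_download: bool = False,
    *args,
    **kwargs,
):
    ...
```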

"""
Returns the loaded model object
Note: This feature is in its experimental phase.
@@ -137,6 +175,26 @@ def Model(model: str, *args, **kwargs):
>>> pybamm.Model('SPM') # doctest: +SKIP
<pybamm.models.full_battery_models.lithium_ion.spm.SPM object>
"""
model_class = models._get_class(model)

return model_class(*args, **kwargs)
if (model is None and url is None) or (model and url):
raise ValueError("You must provide exactly one of `model` or `url`.")

if url is not None:
cache_path = get_cache_path(url)
if not cache_path.exists() or force_download:
try:
print(f"Downloading model from {url}...")
urllib.request.urlretrieve(url, cache_path)
print(f"Model cached at: {cache_path}")
except Exception as e:
raise RuntimeError(f"Failed to download model from URL: {e}") from e
else:
print(f"Using cached model at: {cache_path}")

return Serialise.load_custom_model(str(cache_path))

if model is not None:
try:
model_class = models._get_class(model)
return model_class(*args, **kwargs)
except Exception as e:
raise ValueError(f"Could not load model '{model}': {e}") from e
122 changes: 121 additions & 1 deletion tests/unit/test_dispatch.py
@@ -1,10 +1,23 @@
#
# Test dispatching mechanism in entry points
#
import pytest

import pybamm
from pybamm.dispatch.entry_points import (
_get_cache_dir,
clear_model_cache,
get_cache_path,
)

MODEL_URL = "https://raw.githubusercontent.com/pybamm-team/pybamm-reservoir-example/refs/heads/main/dfn.py"


class TestDispatch:
def setup_method(self):
"""Clear cache before each test"""
clear_model_cache()

def test_model_loads_through_entry_points(self):
"""Test that a model loaded through Model() function is actually functional"""
# Load model with build=False to avoid full initialization for faster testing
@@ -31,5 +44,112 @@ def test_model_function(self):

# Test Model function with options parameter
options = {"thermal": "isothermal"}
model = pybamm.Model("SPM", options)
model = pybamm.Model("SPM", options=options)
assert model.__class__.__name__ == "SPM"

def test_model_value_error(self):
"""Test that Model raises ValueError when given invalid arguments"""

# Neither model nor url provided
with pytest.raises(
ValueError, match="You must provide exactly one of `model` or `url`."
):
pybamm.Model()

# Both model and url provided
with pytest.raises(
ValueError, match="You must provide exactly one of `model` or `url`."
):
pybamm.Model(model="SPM", url="http://example.com/dfn.py")

def test_model_download_runtime_error(self):
"""Test that Model raises RuntimeError when download fails"""

bad_url = "h://example.invalid/model.json"

with pytest.raises(RuntimeError, match="Failed to download model from URL:"):
pybamm.Model(url=bad_url, force_download=True)

def test_invalid_model_name_raises_value_error(self):
"""Test that Model raises ValueError for an invalid model name"""

bad_model = "NonExistentModel123"

with pytest.raises(ValueError, match=f"Could not load model '{bad_model}':"):
pybamm.Model(model=bad_model)

def test_model_download_and_cache(self):
"""Force exception in clear_model_cache without interfering with real cache files"""
cache_dir = _get_cache_dir()
bad_path = cache_dir / "force_exception_test_dir"

try:
bad_path.mkdir(exist_ok=True) # not .json
try:
bad_path.unlink()
except Exception as e:
print(f"Expected error: {e}")
finally:
if bad_path.exists():
bad_path.rmdir()

def test_force_download_overwrites_cache(self):
"""Force an exception when trying to unlink a directory"""
cache_dir = _get_cache_dir()
bad_path = cache_dir / "force_exception_test_dir"

try:
bad_path.mkdir(exist_ok=True)
clear_model_cache()
try:
bad_path.unlink()
except Exception as e:
print(f"Expected error: {e}")
finally:
if bad_path.exists():
bad_path.rmdir()

def test_clear_model_cache_exception_branch(self, capsys):
"""Test that clear_model_cache gracefully handles deletion errors"""
cache_dir = pybamm.dispatch.entry_points._get_cache_dir()
cache_dir.mkdir(parents=True, exist_ok=True)

bad_path = cache_dir / "bad.json"
bad_path.mkdir(exist_ok=True)

try:
clear_model_cache()

captured = capsys.readouterr()
assert "Could not delete" in captured.out
assert "bad.json" in captured.out

assert bad_path.exists()
finally:
bad_path.rmdir()

def test_model_download_and_cache_integration(self, capsys):
"""Integration test using a real model URL"""

cache_path = get_cache_path(MODEL_URL)

# Clean up: if cache path exists as directory, remove it
if cache_path.exists():
if cache_path.is_dir():
import shutil

shutil.rmtree(cache_path)
else:
clear_model_cache()

# First call -> should download and print "Model cached at"
model = pybamm.Model(url=MODEL_URL, force_download=True)
captured = capsys.readouterr()
assert "Model cached at:" in captured.out
assert hasattr(model, "name")

# Second call -> should use cached file
model2 = pybamm.Model(url=MODEL_URL)
captured = capsys.readouterr()
assert "Using cached model at:" in captured.out
assert hasattr(model2, "name")