-
-
Notifications
You must be signed in to change notification settings - Fork 713
[GSoC 2025] Load model JSON files from URLs #5137
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
f55a0a9
0b3d671
6ee08b0
5112f8b
ed0c714
3e3658b
fc437a1
00ff44f
4214681
71ae0c4
51bf407
c39d1c0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,6 +1,16 @@ | ||||||||||||||||||||||||||||||
| import hashlib | ||||||||||||||||||||||||||||||
| import importlib.metadata | ||||||||||||||||||||||||||||||
| import textwrap | ||||||||||||||||||||||||||||||
| import urllib.request | ||||||||||||||||||||||||||||||
| from collections.abc import Callable, Mapping | ||||||||||||||||||||||||||||||
| from pathlib import Path | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| from platformdirs import user_cache_dir | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| from pybamm.expression_tree.operations.serialise import Serialise | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| APP_NAME = "pybamm" | ||||||||||||||||||||||||||||||
| APP_AUTHOR = "pybamm" | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| class EntryPoint(Mapping): | ||||||||||||||||||||||||||||||
|
|
@@ -109,7 +119,35 @@ def __getattribute__(self, name): | |||||||||||||||||||||||||||||
| models = EntryPoint(group="pybamm_models") | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def Model(model: str, *args, **kwargs): | ||||||||||||||||||||||||||||||
| def _get_cache_dir() -> Path: | ||||||||||||||||||||||||||||||
| cache_dir = Path(user_cache_dir(APP_NAME, APP_AUTHOR)) / "models" | ||||||||||||||||||||||||||||||
| cache_dir.mkdir(parents=True, exist_ok=True) | ||||||||||||||||||||||||||||||
| return cache_dir | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def get_cache_path(url: str) -> Path: | ||||||||||||||||||||||||||||||
| cache_dir = _get_cache_dir() | ||||||||||||||||||||||||||||||
| file_hash = hashlib.md5(url.encode()).hexdigest() | ||||||||||||||||||||||||||||||
| return cache_dir / f"{file_hash}.json" | ||||||||||||||||||||||||||||||
|
Comment on lines
+128
to
+131
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why MD5 and not SHA-256? :) |
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def clear_model_cache() -> None: | ||||||||||||||||||||||||||||||
| cache_dir = _get_cache_dir() | ||||||||||||||||||||||||||||||
| for file in cache_dir.glob("*.json"): | ||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||
| file.unlink() | ||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||
|
Comment on lines
+134
to
+139
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note, we also store some of PyBaMM's data files in the cache dir using a PyBaMM/src/pybamm/pybamm_data.py Lines 126 to 139 in a1aa02c
I wonder if we could reuse some of the code here, because clearing the cache directory of JSON files could have unintended side effects. We do have JSON files there: https://github.com/pybamm-team/pybamm-data/releases/tag/v1.0.1 Could we perhaps rewrite the PR to use |
||||||||||||||||||||||||||||||
| # Optional: log error instead of failing silently | ||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||
| print(f"Could not delete {file}: {e}") | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def Model( | ||||||||||||||||||||||||||||||
| model=None, | ||||||||||||||||||||||||||||||
| url=None, | ||||||||||||||||||||||||||||||
| force_download=False, | ||||||||||||||||||||||||||||||
| *args, | ||||||||||||||||||||||||||||||
| **kwargs, | ||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||
|
Comment on lines
+144
to
+150
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can have better typing here. |
||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| Returns the loaded model object | ||||||||||||||||||||||||||||||
| Note: This feature is in its experimental phase. | ||||||||||||||||||||||||||||||
|
|
@@ -137,6 +175,26 @@ def Model(model: str, *args, **kwargs): | |||||||||||||||||||||||||||||
| >>> pybamm.Model('SPM') # doctest: +SKIP | ||||||||||||||||||||||||||||||
| <pybamm.models.full_battery_models.lithium_ion.spm.SPM object> | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| model_class = models._get_class(model) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| return model_class(*args, **kwargs) | ||||||||||||||||||||||||||||||
| if (model is None and url is None) or (model and url): | ||||||||||||||||||||||||||||||
| raise ValueError("You must provide exactly one of `model` or `url`.") | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if url is not None: | ||||||||||||||||||||||||||||||
| cache_path = get_cache_path(url) | ||||||||||||||||||||||||||||||
| if not cache_path.exists() or force_download: | ||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||
| print(f"Downloading model from {url}...") | ||||||||||||||||||||||||||||||
| urllib.request.urlretrieve(url, cache_path) | ||||||||||||||||||||||||||||||
| print(f"Model cached at: {cache_path}") | ||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||
| raise RuntimeError(f"Failed to download model from URL: {e}") from e | ||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||
| print(f"Using cached model at: {cache_path}") | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| return Serialise.load_custom_model(str(cache_path)) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| if model is not None: | ||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||
| model_class = models._get_class(model) | ||||||||||||||||||||||||||||||
| return model_class(*args, **kwargs) | ||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||
| raise ValueError(f"Could not load model '{model}': {e}") from e | ||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can keep just
APP_NAMEand dropAPP_AUTHOR, and we can setPath(user_cache_dir(APP_NAME)) / "models"as a constant here instead.