diff --git a/setup.cfg b/setup.cfg index be7dc3f..970d96f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -38,6 +38,7 @@ install_requires = tqdm requests psutil + platformdirs # zarr>=3 # Will need a python 3.11+ version python_requires = >=3.10 include_package_data = True diff --git a/trackastra/model/model_api.py b/trackastra/model/model_api.py index 243bfaf..71abc51 100644 --- a/trackastra/model/model_api.py +++ b/trackastra/model/model_api.py @@ -172,7 +172,7 @@ def from_pretrained( Args: name: Name of pretrained model (e.g. "general_2d"). device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None). - download_dir: Directory to download model to (defaults to ~/.cache/trackastra). + download_dir: Directory to download model. Default handled by platformdirs. Returns: Trackastra model instance. diff --git a/trackastra/model/pretrained.py b/trackastra/model/pretrained.py index 767da6a..eb8f7aa 100644 --- a/trackastra/model/pretrained.py +++ b/trackastra/model/pretrained.py @@ -2,10 +2,10 @@ import shutil import tempfile import zipfile -from importlib.resources import files from pathlib import Path import requests +from platformdirs import user_data_dir from tqdm import tqdm logger = logging.getLogger(__name__) @@ -60,7 +60,7 @@ def download(url: str, fname: Path): def download_pretrained(name: str, download_dir: Path | None = None): # TODO make safe, introduce versioning if download_dir is None: - download_dir = files("trackastra").joinpath(".models") + download_dir = Path(user_data_dir("trackastra")) / "models" else: download_dir = Path(download_dir)