Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ install_requires =
tqdm
requests
psutil
platformdirs
# zarr>=3 # Will need a python 3.11+ version
python_requires = >=3.10
include_package_data = True
Expand Down
2 changes: 1 addition & 1 deletion trackastra/model/model_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def from_pretrained(
Args:
name: Name of pretrained model (e.g. "general_2d").
device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None).
download_dir: Directory to download model to (defaults to ~/.cache/trackastra).
download_dir: Directory to download model. Default handled by platformdirs.

Returns:
Trackastra model instance.
Expand Down
4 changes: 2 additions & 2 deletions trackastra/model/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import shutil
import tempfile
import zipfile
from importlib.resources import files
from pathlib import Path

import requests
from platformdirs import user_data_dir
from tqdm import tqdm

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -60,7 +60,7 @@ def download(url: str, fname: Path):
def download_pretrained(name: str, download_dir: Path | None = None):
# TODO make safe, introduce versioning
if download_dir is None:
download_dir = files("trackastra").joinpath(".models")
download_dir = Path(user_data_dir("trackastra")) / "models"
else:
download_dir = Path(download_dir)

Expand Down
Loading