16 changes: 14 additions & 2 deletions .github/workflows/python-package-conda.yml
@@ -24,27 +24,35 @@ jobs:
run: pipx run pre-commit run --all-files

test-core:
name: Core Tests - ${{ matrix.platform }} py${{ matrix.python-version }} ilp=${{ matrix.ilp }}
name: Core Tests - ${{ matrix.platform }} py${{ matrix.python-version }} ilp=${{ matrix.ilp }} sam2=${{ matrix.sam2 }}
runs-on: ${{ matrix.platform }}
defaults:
run:
shell: bash -l {0}
strategy:
fail-fast: false
fail-fast: true
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13"]
platform: [ubuntu-latest]
ilp: ["false"]
sam2: ["false"]
include:
- platform: windows-latest
python-version: "3.10"
ilp: "false"
sam2: "false"
- platform: macos-latest
python-version: "3.10"
ilp: "false"
sam2: "false"
- platform: ubuntu-latest
python-version: "3.10"
ilp: "true"
sam2: "false"
- platform: macos-latest
python-version: "3.12"
ilp: "false"
sam2: "true"

steps:
- uses: actions/checkout@v4
@@ -64,6 +72,10 @@ jobs:
- name: Install package with core dependencies
run: python -m pip install -e .[test]

- name: Install additional dependencies for SAM2 features tests
if: matrix.sam2 == 'true'
run: python -m pip install -e .[test,etultra]

- name: Run core tests
run: pytest tests -v -s --color=yes -m core

3 changes: 3 additions & 0 deletions .gitignore
@@ -180,3 +180,6 @@ cython_debug/
#.idea/

trackastra/_version.py
# Used by pre-trained feature extractors
embeddings/
tests/embeddings/
58 changes: 34 additions & 24 deletions README.md
@@ -6,43 +6,34 @@


[![PyPI](https://img.shields.io/pypi/v/trackastra)](https://pypi.org/project/trackastra/)
![Python](https://img.shields.io/pypi/pyversions/trackastra)
![Build](https://img.shields.io/github/actions/workflow/status/weigertlab/trackastra/python-package-conda.yml?branch=main)
[![Python](https://img.shields.io/pypi/pyversions/trackastra)](https://pypi.org/project/trackastra/)
[![Build](https://img.shields.io/github/actions/workflow/status/weigertlab/trackastra/python-package-conda.yml?branch=main)](https://github.com/weigertlab/trackastra/actions/workflows/python-package-conda.yml)
[![Downloads](https://static.pepy.tech/badge/trackastra)](https://pepy.tech/project/trackastra)
[![License](https://img.shields.io/github/license/weigertlab/trackastra)](https://github.com/weigertlab/trackastra/blob/main/LICENSE)


</div>


# *Trackastra* - Tracking by Association with Transformers


*Trackastra* is a cell tracking approach that links already segmented cells in a microscopy timelapse by predicting associations with a transformer model that was trained on a diverse set of microscopy videos.
*Trackastra* is a cell tracking approach that links already segmented cells in a microscopy timelapse by predicting associations with a transformer model. It comes with [pretrained models](trackastra/model/pretrained.json) that perform well out of the box for many types of live-cell imaging data.

![Overview](overview.png)

## Reference

Paper: [Trackastra: Transformer-based cell tracking for live-cell microscopy](https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf)
Give it a try with our easy-to-use [napari plugin](https://github.com/weigertlab/napari-trackastra/)!

```
@inproceedings{gallusser2024trackastra,
title={Trackastra: Transformer-based cell tracking for live-cell microscopy},
author={Gallusser, Benjamin and Weigert, Martin},
booktitle={European conference on computer vision},
pages={467--484},
year={2024},
organization={Springer}
}
```
<p align="center">
<img src="https://github.com/weigertlab/napari-trackastra/assets/8866751/097eb82d-0fef-423e-9275-3fb528c20f7d" alt="demo" width="80%">
</p>

## Examples
## Example tracking results
Nuclei tracking | Bacteria tracking
:-: | :-:
<video src='https://github.com/weigertlab/trackastra/assets/8866751/807a8545-2f65-4697-a175-89b90dfdc435' width=180></video> | <video src='https://github.com/weigertlab/trackastra/assets/8866751/e7426d34-4407-4acb-ad79-fae3bc7ee6f9' width=180></video>

## Installation
This repository contains the Python implementation of Trackastra.
If you use Trackastra in your research, please cite [our paper](https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf) ([BibTeX](#reference)).

Please first set up a Python environment (with Python version 3.10 or higher), preferably via [conda](https://conda.io/projects/conda/en/latest/user-guide/install/index.html) or [mamba](https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html#mamba-install).

@@ -61,6 +52,14 @@ conda install -c conda-forge -c gurobi -c funkelab ilpy
pip install "trackastra[ilp]"
```

### 🆕😎 With pretrained features

Our [new model variant](https://github.com/C-Achard/Trackastra-et-Ultra) uses SAM2 features and improves tracking performance on certain data, for example for tracking bacteria:
```bash
pip install "trackastra[etultra]"
```
and select the `general_2d_w_SAM2_features` pre-trained model for predictions, preferably on a machine with a GPU (slow on CPU!).

### Installation with training support
```bash
pip install "trackastra[train]"
@@ -94,11 +93,7 @@ pip install -e "./trackastra[all]"

## Usage: Tracking with a pretrained model

The input to Trackastra is a sequence of images and their corresponding cell (instance) segmentations.

![demo](https://github.com/weigertlab/napari-trackastra/assets/8866751/097eb82d-0fef-423e-9275-3fb528c20f7d)

> The available pretrained models are described in detail [here](trackastra/model/pretrained.json).
The input to Trackastra is a sequence of images and their corresponding cell (instance) segmentations. The available pretrained models are described in detail [here](trackastra/model/pretrained.json).

Tracking with Trackastra can be done via:

@@ -241,3 +236,18 @@ python train.py --config example_config.yaml
```

Generally, training data needs to be provided in the [Cell Tracking Challenge (CTC) format](http://public.celltrackingchallenge.net/documents/Naming%20and%20file%20content%20conventions.pdf), i.e. annotations are located in a folder containing one or several subfolders named `TRA`, with masks and tracklet information.
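The expected layout can be illustrated with a short, self-contained sketch (the dataset and sequence folder names below are hypothetical):

```python
# Hypothetical sketch of a CTC-style training data layout: a dataset folder
# containing one or several subfolders named TRA, which hold the masks and
# tracklet information. Folder names are made up for illustration.
import tempfile
from pathlib import Path

root = Path(tempfile.mkdtemp()) / "my_dataset"
for seq in ("01_GT", "02_GT"):  # hypothetical sequence names
    (root / seq / "TRA").mkdir(parents=True)

# Training code would discover annotation folders like this:
tra_dirs = sorted(p for p in root.rglob("TRA") if p.is_dir())
print(len(tra_dirs))  # 2
```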

## Reference

Paper: [Trackastra: Transformer-based cell tracking for live-cell microscopy](https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/09819.pdf)

```bibtex
@inproceedings{gallusser2024trackastra,
title={Trackastra: Transformer-based cell tracking for live-cell microscopy},
author={Gallusser, Benjamin and Weigert, Martin},
booktitle={European conference on computer vision},
pages={467--484},
year={2024},
organization={Springer}
}
```
2 changes: 2 additions & 0 deletions setup.cfg
@@ -66,6 +66,8 @@ all =
trackastra[train,ilp,dev]
test =
pytest
etultra =
trackastra_pretrained_feats @ git+https://github.com/bentaculum/Trackastra-et-Ultra.git#egg=trackastra_pretrained_feats

[options.entry_points]
console_scripts =
62 changes: 60 additions & 2 deletions tests/test_pretrained.py
@@ -4,9 +4,17 @@

import pytest
import torch
from trackastra.data import example_data_hela
from trackastra.data import example_data_bacteria, example_data_hela
from trackastra.model import Trackastra

try:
import trackastra_pretrained_feats # noqa: F401

SAM2_TEST = True
except ModuleNotFoundError:
SAM2_TEST = False


# Mark all tests in this module as core/inference tests
pytestmark = pytest.mark.core

@@ -69,5 +77,55 @@ def test_integration(name, device, batch_size):
assert (len(track_graph.edges), len(track_graph.nodes)) == length_edges_nodes[name]


def limit_mps_memory(target=5 * 2**30):
"""Limit MPS memory usage to target bytes."""
if torch.backends.mps.is_available():
# Get Metal's recommended max working set (bytes)
max_bytes = torch.mps.recommended_max_memory()
print(f"Recommended max MPS memory: {max_bytes / 2**30:.2f} GB.")
# The recommended value is 80% of the device limit;
# divide by 0.8 to recover the full amount.
max_bytes = int(max_bytes / 0.8)

fraction = target / max_bytes
if fraction > 1.0:
raise ValueError(
f"Target memory limit ({target / 2**30:.2f} GB) exceeds maximum available memory ({max_bytes / 2**30:.2f} GB)."
)
else:
torch.mps.set_per_process_memory_fraction(fraction)
print(
f"Set MPS memory limit to {target / 2**30:.2f} GB ({fraction * 100:.1f}% of max)."
)


@pytest.mark.skipif(
not SAM2_TEST, reason="Package for using SAM2 features not installed"
)
@pytest.mark.parametrize("device", ["mps", "cuda"])
@pytest.mark.parametrize("batch_size", [1])
def test_integration_SAM2(device, batch_size):
"""
Test that the number of edges and nodes in the track graph is consistent with the pretrained model.
"""

if device == "cuda" and not torch.cuda.is_available():
pytest.skip("cuda not available")
elif device == "mps":
if not torch.backends.mps.is_available():
pytest.skip("mps not available")
limit_mps_memory()

model = Trackastra.from_pretrained(
name="general_2d_w_SAM2_features",
device=device,
)
imgs, masks = example_data_bacteria()
track_graph, _ = model.track(imgs, masks, batch_size=batch_size)

assert len(track_graph.edges) == 126
assert len(track_graph.nodes) == 128


if __name__ == "__main__":
test_integration("ctc", "mps", 3)
test_integration_SAM2("mps", 1)
13 changes: 12 additions & 1 deletion trackastra/data/data.py
@@ -1462,9 +1462,16 @@ def collate_sequence_padding(batch):
normal_keys = {
"coords": 0,
"features": 0,
"pretrained_feats": 0,
"labels": 0, # Not needed, remove for speed.
"timepoints": -1, # There are real timepoints with t=0. -1 for distinction from that.
}
set_keys = {
k: v
for k, v in normal_keys.items()
if k in batch[0] and batch[0][k] is not None
}
none_keys = [k for k in normal_keys.keys() if k in batch[0] and batch[0][k] is None]
n_pads = tuple(n_max_len - s for s in lens)
batch_new = dict(
(
@@ -1473,8 +1480,12 @@
[pad_tensor(x[k], n_max=n_max_len, value=v) for x in batch], dim=0
),
)
for k, v in normal_keys.items()
for k, v in set_keys.items()
)
for k in none_keys:
# keys that are present in the input dicts but None stay None
batch_new[k] = None
if "assoc_matrix" in batch[0]:
batch_new["assoc_matrix"] = torch.stack(
[
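The `None`-handling added above can be illustrated with a pure-Python sketch (`pad_seq` is a simplified stand-in for `pad_tensor`, which pads torch tensors; the batch contents are made up for illustration):

```python
# Simplified sketch of collate_sequence_padding's semantics: each sequence is
# right-padded to the longest length in the batch using a per-key pad value,
# and keys whose value is None (e.g. absent pretrained_feats) stay None.
def pad_seq(seq, n_max, value):
    # stand-in for pad_tensor: right-pad a list to length n_max
    return seq + [value] * (n_max - len(seq))


batch = [
    {"timepoints": [0, 1, 2], "pretrained_feats": None},
    {"timepoints": [0, 1], "pretrained_feats": None},
]
pad_values = {"timepoints": -1, "pretrained_feats": 0}
n_max = max(len(x["timepoints"]) for x in batch)

collated = {
    k: None if batch[0][k] is None else [pad_seq(x[k], n_max, v) for x in batch]
    for k, v in pad_values.items()
}
print(collated["timepoints"])  # [[0, 1, 2], [0, 1, -1]]
print(collated["pretrained_feats"])  # None
```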
61 changes: 56 additions & 5 deletions trackastra/data/wrfeat.py
@@ -7,7 +7,7 @@
from collections import OrderedDict
from collections.abc import Iterable, Sequence
from functools import reduce
from typing import Literal
from typing import TYPE_CHECKING, Literal, Optional

import joblib
import numpy as np
@@ -19,6 +19,15 @@

from trackastra.data.utils import load_tiff_timeseries

try:
PRETRAINED_FEATS_INSTALLED = True
if TYPE_CHECKING:
from trackastra_pretrained_feats import FeatureExtractor
except ImportError:
PRETRAINED_FEATS_INSTALLED = False
if TYPE_CHECKING:
FeatureExtractor = None # type: ignore

logger = logging.getLogger(__name__)

_PROPERTIES = {
@@ -133,7 +142,23 @@ def __repr__(self):

@property
def features_stacked(self):
return np.concatenate([v for k, v in self.features.items()], axis=-1)
# Do not include pretrained_feats here
# They are handled separately and should not be added to shallow features
if not self.features or (
len(self.features) == 1 and "pretrained_feats" in self.features
):
return None
feats = np.concatenate(
[v for k, v in self.features.items() if k != "pretrained_feats"], axis=-1
)
return feats

@property
def pretrained_feats(self):
# for compatibility with WRPretrainedFeatures
if "pretrained_feats" in self.features:
return self.features["pretrained_feats"]
return None

def __len__(self):
return len(self.labels)
@@ -504,14 +529,17 @@ def __call__(self, feats: WRFeatures):
def get_features(
detections: np.ndarray,
imgs: np.ndarray | None = None,
features: Literal["none", "wrfeat"] = "wrfeat",
features: Literal[
"none", "wrfeat", "pretrained_feats", "pretrained_feats_aug"
] = "wrfeat",
ndim: int = 2,
n_workers=0,
progbar_class=tqdm,
feature_extractor: Optional["FeatureExtractor"] = None,
) -> list[WRFeatures]:
detections = _check_dimensions(detections, ndim)
imgs = _check_dimensions(imgs, ndim)
logger.info(f"Extracting features from {len(detections)} detections")
logger.info(f"Extracting features from {len(detections)} frames.")
if n_workers > 0:
logger.info(f"Using {n_workers} processes for feature extraction")
features = joblib.Parallel(n_jobs=n_workers, backend="loky")(
@@ -527,6 +555,20 @@
desc="Extracting features",
)
)
elif features == "pretrained_feats" or features == "pretrained_feats_aug":
feature_extractor.precompute_image_embeddings(imgs)
from trackastra_pretrained_feats import WRPretrainedFeatures

features = [
WRPretrainedFeatures.from_mask_img(
img=img[np.newaxis],
mask=mask[np.newaxis],
feature_extractor=feature_extractor,
t_start=t,
additional_properties=feature_extractor.additional_features,
)
for t, (mask, img) in enumerate(zip(detections, imgs))
]
else:
logger.info("Using single process for feature extraction")
features = tuple(
@@ -572,7 +614,12 @@ def build_windows(
desc="Building windows",
):
feat = WRFeatures.concat(features[t1:t2])

# WRFeatures without pretrained features simply return None here
pt_feats = getattr(feat, "pretrained_feats", None)
labels = feat.labels
timepoints = feat.timepoints
coords = feat.coords
Expand All @@ -589,6 +636,10 @@ def build_windows(
if as_torch
else feat.features_stacked,
)
# Add pre-trained features
if pt_feats is not None:
w["pretrained_feats"] = torch.from_numpy(pt_feats) if as_torch else pt_feats

windows.append(w)

logger.debug(f"Built {len(windows)} track windows.\n")