Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 78 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,78 @@
# virtual_stain_flow
For developing virtual staining models
# `virtual_stain_flow` - For developing virtual staining models

## Overview
`virtual_stain_flow` is a framework for the reproducible development and training of image-to-image translation models that enable virtual staining (the prediction of "virtual" stains) from label-free microscopy images.
The package provides comprehensive experiment tracking that spans the entire model development workflow, from dataset construction, augmetnation, model customization to training.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The package provides comprehensive experiment tracking that spans the entire model development workflow, from dataset construction, augmetnation, model customization to training.
The package provides comprehensive experiment tracking that spans the entire model development workflow, from dataset construction, augmentation, model customization to training.


---

## Supported Model Architectures

- **U-Net** – A classical encoder–decoder architecture with skip connections.
This lightweight design has been widely used for virtual staining tasks, as demonstrated by [Ounkomol et al., 2018](https://doi.org/10.1038/s41592-018-0111-2).

- **wGAN-GP** – A Wasserstein GAN with Gradient Penalty.
This generative adversarial setup combines a U-Net generator with a convolutional discriminator regularized via gradient penalty for stable training.
As shown by [Cross-Zamirski et al., 2022](https://doi.org/10.1038/s41598-022-12914-x), adversarial training enhances the realism of synthetic stains.

- **ConvNeXt-UNet** – A fully convolutional architecture inspired by recent computer vision advances.
Drawing from [Liu et al., 2022](https://doi.org/10.48550/arXiv.2201.03545) and [Liu et al., 2025](https://doi.org/10.1038/s42256-025-01046-2), this variant incorporates transformer-like architectural refinements to improve the fidelity of virtual staining details, at the cost of higher computational demand.

### Showcasing of model prediction

![Input/Target/Prediction](./assets/epoch_299.png)
**Prediction (DNA, Hoechst 33342)** generated by `virtual_stain_flow` using the ConvNeXt-UNet model, from brightfield microscopy images of the U2-OS cell line.

## Core Components

- **[datasets/](./src/virtual_stain_flow/datasets/)** - Data loading and preprocessing pipelines
- **[models/](./src/virtual_stain_flow/models/)** - Virtual staining models and building blocks
- **[trainers/](./src/virtual_stain_flow/trainers/)** - Training loops
- **[transforms/](./src/virtual_stain_flow/transforms/)** - Image normalization and augmnentation
- **[vsf_logging/](./src/virtual_stain_flow/vsf_logging/)** - Experiment tracking and logging

---

## Quick Start

Check out the **[examples/](./examples/)** directory for complete training scripts and tutorials demonstrating various use cases and configurations.

```python
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like two separable blocks of code. Consider describing each with a bit more specific detail, separately

from virtual_stain_flow.models import UNet, Conv2DNormActBlock

model = UNet(
input_channels=1,
output_channels=3,
comp_block=Conv2DNormActBlock,
depth=4
)

from virtual_stain_flow.datasets import BaseImageDataset
from virtual_stain_flow.transforms import MaxScaleNormalize

dataset = BaseImageDataset(
file_index=file_index_df,
input_channel_keys="phase",
target_channel_keys=["dapi", "tubulin", "actin"],
transform=MaxScaleNormalize(normalization_factor='16bit')
)

from virtual_stain_flow.trainers import Trainer
from virtual_stain_flow.vsf_logging import MlflowLogger

logger = MlflowLogger(experiment_name="virtual_staining")
trainer = Trainer(model, dataset)
trainer.train(logger=logger)
```

## Installation

```bash
pip install virtual-stain-flow
```

---

## License

See the [LICENSE](LICENSE) file for full details.
Binary file added assets/epoch_299.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.