diff --git a/README.md b/README.md
index 95403f5..969483e 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,91 @@
-# virtual_stain_flow
-For developing virtual staining models
+# `virtual_stain_flow` - A framework for developing virtual staining models
+
+## Overview
+`virtual_stain_flow` is a framework for the reproducible development and training of image-to-image translation models that enable virtual staining (the prediction of "virtual" stains) from label-free microscopy images.
+The package provides comprehensive experiment tracking that spans the entire model development workflow, from dataset construction and augmentation through model customization to training.
+
+---
+
+## Supported Model Architectures
+
+- **U-Net** – A classical encoder–decoder architecture with skip connections.
+  This lightweight design has been widely used for virtual staining tasks, as demonstrated by [Ounkomol et al., 2018](https://doi.org/10.1038/s41592-018-0111-2).
+
+- **wGAN-GP** – A Wasserstein GAN with Gradient Penalty.
+  This generative adversarial setup combines a U-Net generator with a convolutional discriminator regularized via a gradient penalty for stable training.
+  As shown by [Cross-Zamirski et al., 2022](https://doi.org/10.1038/s41598-022-12914-x), adversarial training enhances the realism of synthetic stains.
+
+- **ConvNeXt-UNet** – A fully convolutional architecture inspired by recent computer vision advances.
+  Drawing from [Liu et al., 2022](https://doi.org/10.48550/arXiv.2201.03545) and [Liu et al., 2025](https://doi.org/10.1038/s42256-025-01046-2), this variant incorporates transformer-like architectural refinements that improve the fidelity of virtual staining details, at the cost of higher computational demand.
+
+### Example model prediction
+
+![Input/Target/Prediction](./assets/epoch_299.png)
+**Prediction (DNA, Hoechst 33342)** generated by `virtual_stain_flow` using the ConvNeXt-UNet model, from brightfield microscopy images of the U2-OS cell line.
+
+## Core Components
+
+- **[datasets/](./src/virtual_stain_flow/datasets/)** - Data loading and preprocessing pipelines
+- **[models/](./src/virtual_stain_flow/models/)** - Virtual staining models and building blocks
+- **[trainers/](./src/virtual_stain_flow/trainers/)** - Training loops
+- **[transforms/](./src/virtual_stain_flow/transforms/)** - Image normalization and augmentation
+- **[vsf_logging/](./src/virtual_stain_flow/vsf_logging/)** - Experiment tracking and logging
+
+---
+
+## Quick Start
+
+Check out the **[examples/](./examples/)** directory for complete training scripts and tutorials demonstrating various use cases and configurations.
+
+Define a `UNet` model that predicts stains for 3 target channels from 1 input channel (phase, in this example).
+- Here `comp_block` is set to `Conv2DNormActBlock`, the conventional Conv2D > Normalize > ReLU block.
+  Alternative compute blocks are available in this package, including the `Conv2DConvNeXtBlock` (see the sketch after the training example below).
+```python
+from virtual_stain_flow.models import UNet, Conv2DNormActBlock
+
+model = UNet(
+    input_channels=1,   # number of input (label-free) channels
+    output_channels=3,  # number of predicted stain channels
+    comp_block=Conv2DNormActBlock,
+    _num_units=2,
+    depth=4
+)
+```
+
+Build a dataset tailored to the `UNet` model specification.
+- The input channel is configured as `"phase"`, which matches the `input_channels=1` model setting.
+- Likewise, the target channels are configured as `["dapi", "tubulin", "actin"]` to match the `output_channels=3` model setting.
+- Note that these channel keys must exist in the supplied `file_index_df` as columns of image file paths.
+- Specify post-processing appropriate to the image bit depth and/or the model's output activation.
+  By default, all models apply a sigmoid output activation, so a max-scale normalization that divides by the maximum pixel value is used here.
+```python
+from virtual_stain_flow.datasets import BaseImageDataset
+from virtual_stain_flow.transforms import MaxScaleNormalize
+from virtual_stain_flow.trainers import Trainer
+from virtual_stain_flow.vsf_logging import MlflowLogger
+
+# Normalize 16-bit images into [0, 1] to match the sigmoid output range.
+dataset = BaseImageDataset(
+    file_index=file_index_df,
+    input_channel_keys="phase",
+    target_channel_keys=["dapi", "tubulin", "actin"],
+    transform=MaxScaleNormalize(normalization_factor='16bit')
+)
+
+logger = MlflowLogger(experiment_name="virtual_staining")
+trainer = Trainer(model, dataset)
+trainer.train(logger=logger)
+```
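+
+To try the ConvNeXt-style compute block mentioned above, swap `Conv2DConvNeXtBlock` in as the `comp_block`.
+This is a minimal sketch that assumes `Conv2DConvNeXtBlock` is importable from `virtual_stain_flow.models` and accepts the same wiring as `Conv2DNormActBlock`; check [models/](./src/virtual_stain_flow/models/) for the exact import path and arguments.
+```python
+from virtual_stain_flow.models import UNet, Conv2DConvNeXtBlock
+
+# Same topology as above, but each compute block uses the ConvNeXt-style
+# refinements (higher fidelity, higher computational cost).
+convnext_model = UNet(
+    input_channels=1,
+    output_channels=3,
+    comp_block=Conv2DConvNeXtBlock,
+    _num_units=2,
+    depth=4
+)
+```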
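+
+Once training completes, predictions can be generated by calling the trained model directly.
+The sketch below is not part of the package's documented API: it assumes the trained model is a standard PyTorch `nn.Module` with sigmoid-activated outputs in `[0, 1]` (the package default noted above), and that indexing the dataset yields an `(input, target)` tensor pair.
+```python
+import numpy as np
+import torch
+
+model.eval()
+with torch.no_grad():
+    input_image, _ = dataset[0]                   # assumed (input, target) pair
+    prediction = model(input_image.unsqueeze(0))  # add a batch dimension
+
+# Rescale the sigmoid-activated output from [0, 1] back to the
+# 16-bit range used by MaxScaleNormalize above.
+prediction_16bit = (prediction.squeeze(0).cpu().numpy() * 65535).astype(np.uint16)
+```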
+
+## Installation
+
+```bash
+pip install virtual-stain-flow
+```
+
+---
+
+## License
+
+See the [LICENSE](LICENSE) file for full details.
diff --git a/assets/epoch_299.png b/assets/epoch_299.png
new file mode 100644
index 0000000..7f33f9e
Binary files /dev/null and b/assets/epoch_299.png differ