diff --git a/.github/workflows/build-main.yml b/.github/workflows/build-main.yml index 7731c2a..e85cac7 100644 --- a/.github/workflows/build-main.yml +++ b/.github/workflows/build-main.yml @@ -15,7 +15,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - python-version: [3.9, '3.10', 3.11] + python-version: ['3.10', '3.11'] os: [ ubuntu-latest, windows-latest, @@ -36,17 +36,17 @@ jobs: run: | pytest --cov-report xml --cov=mmv_im2im mmv_im2im/tests/ - name: Upload codecov - uses: codecov/codecov-action@v1 + uses: codecov/codecov-action@v4 lint: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v5 with: - python-version: 3.9.15 + python-version: 3.11 - name: Install Dependencies run: | python -m pip install --upgrade pip @@ -68,7 +68,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: 3.9 + python-version: 3.11 - name: Install Dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/test-and-lint.yml b/.github/workflows/test-and-lint.yml index 86478e5..0bb5e5b 100644 --- a/.github/workflows/test-and-lint.yml +++ b/.github/workflows/test-and-lint.yml @@ -7,7 +7,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - python-version: [3.9, '3.10', 3.11] + python-version: ['3.10', '3.11'] os: [ ubuntu-latest, windows-latest, @@ -28,17 +28,17 @@ jobs: run: | pytest mmv_im2im/tests/ - name: Upload codecov - uses: codecov/codecov-action@v1 + uses: codecov/codecov-action@v4 lint: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Set up Python uses: actions/setup-python@v5 with: - python-version: 3.9.15 + python-version: 3.11 - name: Install Dependencies run: | python -m pip install --upgrade pip diff --git a/README.md b/README.md index 14a4787..c4dc8ef 100644 --- a/README.md +++ b/README.md @@ -19,14 +19,19 @@ The overall package is designed with a generic image-to-image transformation fra ## Installation -Before starting, we recommend to [create a new conda environment](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-with-commands) or [a virtual environment](https://docs.python.org/3/library/venv.html) with Python 3.9+. +Before starting, we recommend to [create a new conda environment](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-with-commands) or [a virtual environment](https://docs.python.org/3/library/venv.html) with Python 3.10+. + +```bash +conda create -y -n im2im -c conda-forge python=3.11 +conda activate im2im +``` Please note that the proper setup of hardware is beyond the scope of this pacakge. This package was tested with GPU/CPU on Linux/Windows and CPU on MacOS. [Special note for MacOS users: Directly pip install in MacOS may need [additional setup of xcode](https://developer.apple.com/forums/thread/673827).] ### Install MONAI To reproduce our results, we need to install MONAI's code version of a specific commit. To do this: -``` +```bash git clone https://github.com/Project-MONAI/MONAI.git cd ./MONAI git checkout 37b58fcec48f3ec1f84d7cabe9c7ad08a93882c0 @@ -49,7 +54,7 @@ For MacOS users, additional ' ' marks are need when using installation tags in z ### Install MMV_Im2Im for customization or extension: -``` +```bash git clone https://github.com/MMV-Lab/mmv_im2im.git cd mmv_im2im pip install -e .[all] @@ -71,10 +76,10 @@ You can try out on a simple example following [the quick start guide](tutorials/ Basically, you can specify your training configuration in a yaml file and run training with `run_im2im --config /path/to/train_config.yaml`. Then, you can specify the inference configuration in another yaml file and run inference with `run_im2im --config /path/to/inference_config.yaml`. You can also run the inference as a function with the provided API. This will be useful if you want to run the inference within another python script or workflow. Here is an example: -``` +```python from pathlib import Path -from aicsimageio import AICSImage -from aicsimageio.writers import OmeTiffWriter +from bioio import BioImage +from bioio.writers import OmeTiffWriter from mmv_im2im.configs.config_base import ProgramConfig, parse_adaptor, configuration_validation from mmv_im2im import ProjectTester @@ -89,9 +94,9 @@ executor.setup_data_processing() # get the data, run inference, and save the result fn = Path("./data/img_00_IM.tiff") -img = AICSImage(fn).get_image_data("YX", Z=0, C=0, T=0) +img = BioImage(fn).get_image_data("YX", Z=0, C=0, T=0) # or using delayed loading if the data is large -# img = AICSImage(fn).get_image_dask_data("YX", Z=0, C=0, T=0) +# img = BioImage(fn).get_image_dask_data("YX", Z=0, C=0, T=0) seg = executor.process_one_image(img) OmeTiffWriter.save(seg, "output.tiff", dim_orders="YX") ``` diff --git a/docs/conf.py b/docs/conf.py old mode 100755 new mode 100644 diff --git a/mmv_im2im/configs/config_base.py b/mmv_im2im/configs/config_base.py index 75e5095..3d2a531 100644 --- a/mmv_im2im/configs/config_base.py +++ b/mmv_im2im/configs/config_base.py @@ -405,9 +405,9 @@ def configuration_validation(cfg): cfg.data.dataloader.train.dataset_params["cache_dir"] != cfg.data.dataloader.val.dataset_params["cache_dir"] ): - cfg.data.dataloader.val.dataset_params[ - "cache_dir" - ] = cfg.data.dataloader.train.dataset_params["cache_dir"] + cfg.data.dataloader.val.dataset_params["cache_dir"] = ( + cfg.data.dataloader.train.dataset_params["cache_dir"] + ) warnings.warn( UserWarning( "The cache dir of PersistentDataset for validation was" diff --git a/mmv_im2im/data_modules/data_loader_basic.py b/mmv_im2im/data_modules/data_loader_basic.py index dc9fb6d..71aca23 100644 --- a/mmv_im2im/data_modules/data_loader_basic.py +++ b/mmv_im2im/data_modules/data_loader_basic.py @@ -106,7 +106,7 @@ def train_dataloader(self): train_dataset = train_dataset_func( data=train_data, transform=self.transform, - **train_loader_info.dataset_params + **train_loader_info.dataset_params, ) else: train_dataset = train_dataset_func( @@ -117,7 +117,7 @@ def train_dataloader(self): train_dataset, shuffle=True, collate_fn=list_data_collate, - **train_loader_info.dataloader_params + **train_loader_info.dataloader_params, ) return train_dataloader @@ -130,7 +130,7 @@ def val_dataloader(self): val_dataset = val_dataset_func( data=self.val_data, transform=self.preproc, - **val_loader_info.dataset_params + **val_loader_info.dataset_params, ) else: val_dataset = val_dataset_func(data=self.val_data, transform=self.preproc) @@ -138,6 +138,6 @@ def val_dataloader(self): val_dataset, shuffle=False, collate_fn=list_data_collate, - **val_loader_info.dataloader_params + **val_loader_info.dataloader_params, ) return val_dataloader diff --git a/mmv_im2im/models/nets/BranchedERFNet_2d.py b/mmv_im2im/models/nets/BranchedERFNet_2d.py index 819f24b..bbacefd 100644 --- a/mmv_im2im/models/nets/BranchedERFNet_2d.py +++ b/mmv_im2im/models/nets/BranchedERFNet_2d.py @@ -1,9 +1,10 @@ """ Author: Davy Neven -Licensed under the CC BY-NC 4.0 license +Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) https://github.com/davyneven/SpatialEmbeddings """ + import torch import torch.nn as nn import mmv_im2im.models.nets.erfnet as erfnet diff --git a/mmv_im2im/models/nets/ProbUnet.py b/mmv_im2im/models/nets/ProbUnet.py new file mode 100644 index 0000000..ed64e67 --- /dev/null +++ b/mmv_im2im/models/nets/ProbUnet.py @@ -0,0 +1,257 @@ +# Save this as ProbUnet.py (or mmv_im2im/models/ProbUnet.py if that's its actual path) +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def get_valid_num_groups(channels): + """Returns a valid number of groups for GroupNorm.""" + for g in [8, 4, 2, 1]: + if channels % g == 0: + return g + return 1 + + +class ConvBlock(nn.Module): + """Standard 2D Convolutional Block.""" + + def __init__(self, in_channels, out_channels): + super().__init__() + gn_groups1 = get_valid_num_groups(out_channels) + gn_groups2 = get_valid_num_groups(out_channels) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) + self.gn1 = nn.GroupNorm(gn_groups1, out_channels) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) + self.gn2 = nn.GroupNorm(gn_groups2, out_channels) + self.relu2 = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.relu1(self.gn1(self.conv1(x))) + x = self.relu2(self.gn2(self.conv2(x))) + return x + + +class Down(nn.Module): + """Downsampling block (MaxPool + ConvBlock).""" + + def __init__(self, in_channels, out_channels): + super().__init__() + self.pool = nn.MaxPool2d(2) + self.conv_block = ConvBlock(in_channels, out_channels) + + def forward(self, x): + x = self.pool(x) + x = self.conv_block(x) + return x + + +class Up(nn.Module): + """Upsampling block (ConvTranspose + Concat + ConvBlock). + + Args: + in_channels_x1_before_upsample (int): Number of channels of the feature map (x1) + before being upsampled by ConvTranspose2d. + in_channels_x2_skip_connection (int): Number of channels of the skip connection (x2). + out_channels (int): Number of output channels for the final ConvBlock in this Up stage. + """ + + def __init__( + self, + in_channels_x1_before_upsample, + in_channels_x2_skip_connection, + out_channels, + ): + super().__init__() + + self.up = nn.ConvTranspose2d( + in_channels_x1_before_upsample, + in_channels_x1_before_upsample // 2, + kernel_size=2, + stride=2, + ) + + channels_for_conv_block = ( + in_channels_x1_before_upsample // 2 + ) + in_channels_x2_skip_connection + self.conv_block = ConvBlock(channels_for_conv_block, out_channels) + + def forward(self, x1, x2): + x1 = self.up(x1) + # Adjust dimensions if there's a mismatch due to padding or odd sizes + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) + x = torch.cat([x2, x1], dim=1) + return self.conv_block(x) + + +class PriorNet(nn.Module): + """Network to predict prior distribution (mu, logvar).""" + + def __init__(self, in_channels, latent_dim): + super().__init__() + self.conv = nn.Conv2d(in_channels, 2 * latent_dim, kernel_size=1) + + def forward(self, x): + mu_logvar = self.conv(x) + mu = mu_logvar[:, : self.conv.out_channels // 2, :, :] + logvar = mu_logvar[:, self.conv.out_channels // 2 :, :, :] + return mu, logvar + + +class PosteriorNet(nn.Module): + """Network to predict posterior distribution (mu, logvar).""" + + def __init__(self, in_channels, latent_dim): + super().__init__() + self.conv = nn.Conv2d(in_channels, 2 * latent_dim, kernel_size=1) + + def forward(self, x): + mu_logvar = self.conv(x) + mu = mu_logvar[:, : self.conv.out_channels // 2, :, :] + logvar = mu_logvar[:, self.conv.out_channels // 2 :, :, :] + return mu, logvar + + +class ProbabilisticUNet(nn.Module): + """Probabilistic UNet model.""" + + def __init__( + self, in_channels, n_classes, latent_dim=6, **kwargs + ): # Added **kwargs to capture extra params + super().__init__() + self.in_channels = in_channels + self.n_classes = n_classes + self.latent_dim = latent_dim + # self.beta is no longer needed here as it's handled by the loss function + + # Encoder path (U-Net) + self.inc = ConvBlock(in_channels, 32) + self.down1 = Down(32, 64) + self.down2 = Down(64, 128) + self.down3 = Down(128, 256) + self.down4 = Down(256, 512) # Bottleneck features + + # Prior and Posterior Networks + self.prior_net = PriorNet(512, latent_dim) + # PosteriorNet input channels: 512 (features) + n_classes (one-hot y) + self.posterior_net = PosteriorNet(512 + n_classes, latent_dim) + + # Decoder Path (U-Net upsampling path) + # Input channels for Up blocks adjusted to include latent_dim + self.up1 = Up( + in_channels_x1_before_upsample=512 + latent_dim, + in_channels_x2_skip_connection=256, + out_channels=256, + ) + + self.up2 = Up( + in_channels_x1_before_upsample=256, + in_channels_x2_skip_connection=128, + out_channels=128, + ) + + self.up3 = Up( + in_channels_x1_before_upsample=128, + in_channels_x2_skip_connection=64, + out_channels=64, + ) + + self.up4 = Up( + in_channels_x1_before_upsample=64, + in_channels_x2_skip_connection=32, + out_channels=32, + ) + + self.outc = nn.Conv2d(32, n_classes, kernel_size=1) + + def forward(self, x, y=None): + """ + Forward pass of the Probabilistic UNet. + + Args: + x (torch.Tensor): Input image tensor (B, C, H, W). + y (torch.Tensor, optional): Ground truth segmentation mask (B, 1, H, W or B, H, W) + used for training to calculate posterior. + Defaults to None (for inference). + + Returns: + tuple: A tuple containing: + - logits (torch.Tensor): Output logits of the UNet (B, n_classes, H, W). + - prior_mu (torch.Tensor): Mean of the prior distribution. + - prior_logvar (torch.Tensor): Log-variance of the prior distribution. + - post_mu (torch.Tensor or None): Mean of the posterior distribution (None if y is None). + - post_logvar (torch.Tensor or None): Log-variance of the posterior distribution (None if y is None). + """ + # Encoder (U-Net) + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + features = self.down4(x4) # Bottleneck + + # Prior distribution + prior_mu, prior_logvar = self.prior_net(features) + + # Posterior calculation and latent variable sampling + post_mu, post_logvar = None, None + if y is not None: + # Ensure y is one-hot encoded and downsampled to match features spatial dimensions. + # y typically comes as [B, 1, H, W] with integer class labels. + # Convert to [B, n_classes, H, W] for one-hot, then permute for channel dim. + y_one_hot = ( + F.one_hot(y.long().squeeze(1), num_classes=self.n_classes) + .permute(0, 3, 1, 2) + .float() + ) + + # Downsample y_one_hot to match features' spatial dimensions + y_downsampled = F.interpolate( + y_one_hot, size=features.shape[2:], mode="nearest" + ) + + # Concatenate features and downsampled one-hot y for posterior network + post_mu, post_logvar = self.posterior_net( + torch.cat([features, y_downsampled], dim=1) + ) + + # Sample 'z' from the posterior distribution + std_post = torch.exp(0.5 * post_logvar) + eps = torch.randn_like(std_post) + z = post_mu + eps * std_post + else: + # If 'y' is not provided (inference), sample 'z' from the prior. + std_prior = torch.exp(0.5 * prior_logvar) + eps = torch.randn_like(std_prior) + z = prior_mu + eps * std_prior + + # Expand 'z' to spatial dimensions for concatenation + if z.dim() == 2: # [B, latent_dim] + z_expanded = ( + z.unsqueeze(-1) + .unsqueeze(-1) + .repeat(1, 1, features.size(2), features.size(3)) + ) + elif z.dim() == 4: # [B, latent_dim, H, W] + if z.size(2) != features.size(2) or z.size(3) != features.size(3): + z_expanded = F.interpolate( + z, size=(features.size(2), features.size(3)), mode="nearest" + ) + else: + z_expanded = z + else: + raise ValueError(f"Unexpected latent vector z dimension: {z.dim()}") + + # Concatenate bottleneck features with latent vector + concat_bottleneck = torch.cat([features, z_expanded], dim=1) + + # Decoder (U-Net upsampling path) + x_up = self.up1(concat_bottleneck, x4) + x_up = self.up2(x_up, x3) + x_up = self.up3(x_up, x2) + x_up = self.up4(x_up, x1) + output = self.outc(x_up) + + # Return all necessary components for ELBO calculation + return output, prior_mu, prior_logvar, post_mu, post_logvar diff --git a/mmv_im2im/models/nets/gans.py b/mmv_im2im/models/nets/gans.py index 2b39896..5e70066 100644 --- a/mmv_im2im/models/nets/gans.py +++ b/mmv_im2im/models/nets/gans.py @@ -129,7 +129,7 @@ def __init__(self, model_info): down_block( in_channels=prev_channel, out_channels=this_channel, - **model_info["down_block"]["params"] + **model_info["down_block"]["params"], ) ) @@ -142,7 +142,7 @@ def __init__(self, model_info): res_block( in_channels=prev_channel, out_channels=this_channel, - **model_info["res_block"]["params"] + **model_info["res_block"]["params"], ) ) @@ -155,7 +155,7 @@ def __init__(self, model_info): up_block( in_channels=prev_channel, out_channels=this_channel, - **model_info["up_block"]["params"] + **model_info["up_block"]["params"], ) ) diff --git a/mmv_im2im/models/pl_FCN.py b/mmv_im2im/models/pl_FCN.py index db37947..e98e59a 100644 --- a/mmv_im2im/models/pl_FCN.py +++ b/mmv_im2im/models/pl_FCN.py @@ -4,7 +4,7 @@ from random import randint import lightning as pl import torch -from aicsimageio.writers import OmeTiffWriter +from bioio.writers import OmeTiffWriter from mmv_im2im.utils.misc import ( parse_config, @@ -53,7 +53,16 @@ def configure_optimizers(self): lr_scheduler = scheduler_func( optimizer, **self.model_info.scheduler["params"] ) - return {"optimizer": optimizer, "lr_scheduler": lr_scheduler} + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": lr_scheduler, + "monitor": "val_loss", + "interval": "epoch", + "frequency": 1, + "strict": True, + }, + } def prepare_batch(self, batch): return @@ -120,40 +129,17 @@ def training_step(self, batch, batch_idx): tar_out = np.squeeze(tar[0,].detach().cpu().numpy()).astype(float) prd_out = np.squeeze(yhat_act[0,].detach().cpu().numpy()).astype(float) - if len(src_out.shape) == 2: - src_order = "YX" - elif len(src_out.shape) == 3: - src_order = "ZYX" - elif len(src_out.shape) == 4: - src_order = "CZYX" - else: - raise ValueError("unexpected source dims") - - if len(tar_out.shape) == 2: - tar_order = "YX" - elif len(tar_out.shape) == 3: - tar_order = "ZYX" - elif len(tar_out.shape) == 4: - tar_order = "CZYX" - else: - raise ValueError("unexpected target dims") - - if len(prd_out.shape) == 2: - prd_order = "YX" - elif len(prd_out.shape) == 3: - prd_order = "ZYX" - elif len(prd_out.shape) == 4: - prd_order = "CZYX" - else: - raise ValueError(f"unexpected pred dims {prd_out.shape}") + def get_dim_order(arr): + dims = len(arr.shape) + return {2: "YX", 3: "ZYX", 4: "CZYX"}.get(dims, "YX") rand_tag = randint(1, 1000) out_fn = save_path / f"epoch_{self.current_epoch}_src_{rand_tag}.tiff" - OmeTiffWriter.save(src_out, out_fn, dim_order=src_order) + OmeTiffWriter.save(src_out, out_fn, dim_order=get_dim_order(src_out)) out_fn = save_path / f"epoch_{self.current_epoch}_tar_{rand_tag}.tiff" - OmeTiffWriter.save(tar_out, out_fn, dim_order=tar_order) + OmeTiffWriter.save(tar_out, out_fn, dim_order=get_dim_order(tar_out)) out_fn = save_path / f"epoch_{self.current_epoch}_prd_{rand_tag}_.tiff" - OmeTiffWriter.save(prd_out, out_fn, dim_order=prd_order) + OmeTiffWriter.save(prd_out, out_fn, dim_order=get_dim_order(prd_out)) return loss @@ -170,3 +156,4 @@ def validation_step(self, batch, batch_idx): ) return loss + \ No newline at end of file diff --git a/mmv_im2im/models/pl_ProbUnet.py b/mmv_im2im/models/pl_ProbUnet.py new file mode 100644 index 0000000..f25aa89 --- /dev/null +++ b/mmv_im2im/models/pl_ProbUnet.py @@ -0,0 +1,163 @@ +import numpy as np +from typing import Dict +from pathlib import Path +from random import randint +import lightning as pl +import torch +from bioio.writers import OmeTiffWriter + +from mmv_im2im.utils.misc import ( + parse_config, + parse_config_func, + parse_config_func_without_params, +) +from mmv_im2im.utils.model_utils import init_weights + + +class Model(pl.LightningModule): + def __init__(self, model_info_xx: Dict, train: bool = True, verbose: bool = False): + super().__init__() + self.net = parse_config(model_info_xx.net) + init_weights(self.net, init_type="kaiming") + + self.model_info = model_info_xx + self.verbose = verbose + self.weighted_loss = False + if train: + self.criterion = parse_config(model_info_xx.criterion) + self.optimizer_func = parse_config_func(model_info_xx.optimizer) + + # Store these as attributes for access in run_step/training_step/validation_step + self.last_prior_mu = None + self.last_prior_logvar = None + self.last_post_mu = None + self.last_post_logvar = None + + def forward(self, x, y=None): + # The underlying ProbabilisticUNet returns multiple values. + # Capture them here and store them as instance attributes. + logits, prior_mu, prior_logvar, post_mu, post_logvar = self.net(x, y) + + # Store for use in run_step (which calculates loss) + self.last_prior_mu = prior_mu + self.last_prior_logvar = prior_logvar + self.last_post_mu = post_mu + self.last_post_logvar = post_logvar + + # For the 'Model' (LightningModule) forward, only return the logits + # This makes the API consistent with other models in your framework. + return logits + + def configure_optimizers(self): + optimizer = self.optimizer_func(self.parameters()) + if self.model_info.scheduler is None: + return optimizer + else: + scheduler_func = parse_config_func_without_params(self.model_info.scheduler) + lr_scheduler = scheduler_func( + optimizer, **self.model_info.scheduler["params"] + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": lr_scheduler, + "monitor": "val_loss", + "interval": "epoch", + "frequency": 1, + "strict": True, + }, + } + + def run_step(self, batch, validation_stage): + x = batch["IM"] + y = batch["GT"] + + if x.size(-1) == 1: + x = torch.squeeze(x, dim=-1) + y = torch.squeeze(y, dim=-1) + + # Call forward pass of the LightningModule. + # This will internally call self.net(x,y) and store the extra outputs. + logits = self(x, y) # This is now just 'logits' + + # Calculate loss using the stored attributes + # Ensure post_mu and post_logvar are not None if y was provided + # The ELBOLoss expects these to be tensors, not None. + if self.last_post_mu is None or self.last_post_logvar is None: + raise ValueError( + "Posterior distributions (mu, logvar) were not computed. Ensure 'y' is provided during training." + ) + + loss = self.criterion( + logits, + y, + self.last_prior_mu, + self.last_prior_logvar, + self.last_post_mu, + self.last_post_logvar, + ) + + return loss, logits + + def training_step(self, batch, batch_idx): + loss, y_hat = self.run_step(batch, validation_stage=False) + self.log( + "train_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + + if self.verbose and batch_idx == 0: + self.log_images(batch, y_hat, "train") + + return loss + + def validation_step(self, batch, batch_idx): + loss, y_hat = self.run_step(batch, validation_stage=True) + self.log( + "val_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + + if self.verbose and batch_idx == 0: + self.log_images(batch, y_hat, "val") + + return loss + + def log_images(self, batch, y_hat, stage): + src = batch["IM"] + tar = batch["GT"] + + save_path = Path(self.trainer.log_dir) + save_path.mkdir(parents=True, exist_ok=True) + + act = torch.nn.Softmax(dim=1) + yhat_act = act(y_hat) + + src_out = np.squeeze(src[0].detach().cpu().numpy()).astype(float) + tar_out = np.squeeze(tar[0].detach().cpu().numpy()).astype(float) + prd_out = np.squeeze(yhat_act[0].detach().cpu().numpy()).astype(float) + + def get_dim_order(arr): + dims = len(arr.shape) + return {2: "YX", 3: "ZYX", 4: "CZYX"}.get(dims, "YX") + + rand_tag = randint(1, 1000) + + out_fn = save_path / f"epoch_{self.current_epoch}_{stage}_src_{rand_tag}.tiff" + OmeTiffWriter.save(src_out, out_fn, dim_order=get_dim_order(src_out)) + + out_fn = save_path / f"epoch_{self.current_epoch}_{stage}_tar_{rand_tag}.tiff" + OmeTiffWriter.save(tar_out, out_fn, dim_order=get_dim_order(tar_out)) + + out_fn = save_path / f"epoch_{self.current_epoch}_{stage}_prd_{rand_tag}.tiff" + OmeTiffWriter.save(prd_out, out_fn, dim_order=get_dim_order(prd_out)) diff --git a/mmv_im2im/models/pl_cyclegan.py b/mmv_im2im/models/pl_cyclegan.py index 6498adf..fc0e15f 100644 --- a/mmv_im2im/models/pl_cyclegan.py +++ b/mmv_im2im/models/pl_cyclegan.py @@ -1,6 +1,7 @@ """ This module provides lighting module for cycleGAN """ + import torch import lightning as pl from pathlib import Path diff --git a/mmv_im2im/models/pl_embedseg.py b/mmv_im2im/models/pl_embedseg.py index c50a80b..cc4bb10 100644 --- a/mmv_im2im/models/pl_embedseg.py +++ b/mmv_im2im/models/pl_embedseg.py @@ -1,7 +1,7 @@ import numpy as np from typing import Dict from pathlib import Path -from aicsimageio.writers import OmeTiffWriter +from bioio.writers import OmeTiffWriter import lightning as pl from mmv_im2im.postprocessing.embedseg_cluster import generate_instance_clusters from mmv_im2im.utils.embedseg_utils import prepare_embedseg_tensor diff --git a/mmv_im2im/models/pl_pix2pix.py b/mmv_im2im/models/pl_pix2pix.py index 2126517..ae863e2 100644 --- a/mmv_im2im/models/pl_pix2pix.py +++ b/mmv_im2im/models/pl_pix2pix.py @@ -58,7 +58,7 @@ def __init__(self, model_info_xx: Dict, train: bool = True, verbose: bool = Fals except RuntimeError: pre_train = torch.load( Path(dis_init), map_location=torch.device("cpu") - ) + ) self.discriminator.load_state_dict(pre_train["state_dict"]) else: init_weights(self.discriminator, init_type=dis_init) diff --git a/mmv_im2im/postprocessing/embedseg_cluster.py b/mmv_im2im/postprocessing/embedseg_cluster.py index 7751850..ba117cd 100644 --- a/mmv_im2im/postprocessing/embedseg_cluster.py +++ b/mmv_im2im/postprocessing/embedseg_cluster.py @@ -115,9 +115,9 @@ def cluster( ): instance_map_masked[proposal.squeeze()] = count instance_mask = torch.zeros(height, width).short() - instance_mask[ - mask.squeeze().cpu() - ] = proposal.short().cpu() # TODO + instance_mask[mask.squeeze().cpu()] = ( + proposal.short().cpu() + ) # TODO center_image = torch.zeros(height, width).short() center[0] = int( diff --git a/mmv_im2im/proj_tester.py b/mmv_im2im/proj_tester.py index 4ed383a..af0a8c7 100644 --- a/mmv_im2im/proj_tester.py +++ b/mmv_im2im/proj_tester.py @@ -9,8 +9,8 @@ import tempfile import shutil import numpy as np -from aicsimageio import AICSImage -from aicsimageio.writers import OmeTiffWriter +from bioio import BioImage +from bioio.writers import OmeTiffWriter import torch from mmv_im2im.utils.misc import generate_test_dataset_dict, parse_config from mmv_im2im.utils.for_transform import parse_monai_ops_vanilla @@ -61,8 +61,7 @@ def setup_model(self): ): self.cpu = True pre_train = torch.load( - self.model_cfg.checkpoint, - map_location=torch.device('cpu') + self.model_cfg.checkpoint, map_location=torch.device("cpu") ) else: pre_train = torch.load(self.model_cfg.checkpoint) @@ -92,11 +91,13 @@ def setup_data_processing(self): def process_one_image( self, img: Union[DaskArray, NumpyArray], out_fn: Union[str, Path] = None ): + if isinstance(img, DaskArray): # Perform the prediction x = img.compute() elif isinstance(img, NumpyArray): x = img + else: raise ValueError("invalid image") @@ -108,6 +109,7 @@ def process_one_image( x = torch.tensor(x.astype(np.float32)) # run pre-processing on tensor if needed + if self.pre_process is not None: x = self.pre_process(x) @@ -115,6 +117,7 @@ def process_one_image( # the input here is assumed to be a tensor with torch.no_grad(): # add batch dimension and move to GPU + if self.cpu: x = torch.unsqueeze(x, dim=0) else: @@ -132,6 +135,7 @@ def process_one_image( device=torch.device("cpu"), **self.model_cfg.model_extra["sliding_window_params"], ) + # currently, we keep sliding window stiching step on CPU, but assume # the output is on GPU (see note below). So, we manually move the data # back to GPU @@ -238,6 +242,7 @@ def run_inference(self): # loop through all images and apply the model for i, ds in enumerate(dataset_list): + # Read the image print(f"Reading the image {i}/{dataset_length}") @@ -259,12 +264,12 @@ def run_inference(self): print(f"making a temp folder at {tmppath}") # get the number of time points - reader = AICSImage(ds) + reader = BioImage(ds) timelapse_data = reader.dims.T tmpfile_list = [] for t_idx in range(timelapse_data): - img = AICSImage(ds).reader.get_image_dask_data( + img = BioImage(ds).get_image_data( T=[t_idx], **self.data_cfg.inference_input.reader_params ) print(f"Predicting the image timepoint {t_idx}") @@ -314,10 +319,10 @@ def run_inference(self): # clean up temporary dir shutil.rmtree(tmppath) else: - img = AICSImage(ds).reader.get_image_dask_data( + img = BioImage(ds).get_image_data( **self.data_cfg.inference_input.reader_params ) - + # prepare output filename if "." in suffix: if ( diff --git a/mmv_im2im/proj_trainer.py b/mmv_im2im/proj_trainer.py index 315e669..9f7682a 100644 --- a/mmv_im2im/proj_trainer.py +++ b/mmv_im2im/proj_trainer.py @@ -22,7 +22,7 @@ class ProjectTrainer(object): """ - entry for training models + Entry point for training models. Parameters ---------- @@ -30,23 +30,15 @@ class ProjectTrainer(object): """ def __init__(self, cfg): - # seed everything before start pl.seed_everything(123, workers=True) - - # extract the three major chuck of the config self.model_cfg = cfg.model self.train_cfg = cfg.trainer self.data_cfg = cfg.data - - # define variables self.model = None self.data = None def run_training(self): - # set up data self.data = get_data_module(self.data_cfg) - - # set up model model_category = self.model_cfg.framework model_module = import_module(f"mmv_im2im.models.pl_{model_category}") my_model_func = getattr(model_module, "Model") @@ -59,18 +51,37 @@ def run_training(self): ) elif "pre-train" in self.model_cfg.model_extra: pre_train = torch.load(self.model_cfg.model_extra["pre-train"]) - # TODO: hacky solution to remove a wrongly registered key - pre_train["state_dict"].pop("criterion.xym", None) - self.model.load_state_dict(pre_train["state_dict"]) - # set up training + if "extend" in self.model_cfg.model_extra: + if ( + self.model_cfg.model_extra["extend"] is not None + and self.model_cfg.model_extra["extend"] is True + ): + pre_train["state_dict"].pop("criterion.xym", None) + model_state = self.model.state_dict() + pretrained_dict = pre_train["state_dict"] + filtered_dict = {} + + for k, v in pretrained_dict.items(): + if k in model_state and v.shape == model_state[k].shape: + filtered_dict[k] = v + else: + print( + f"Skipped loading layer: {k} due to shape mismatch." + ) + + model_state.update(filtered_dict) + self.model.load_state_dict(model_state) + else: + pre_train["state_dict"].pop("criterion.xym", None) + self.model.load_state_dict(pre_train["state_dict"]) + if self.train_cfg.callbacks is None: trainer = pl.Trainer(**self.train_cfg.params) else: callback_list = parse_ops_list(self.train_cfg.callbacks) trainer = pl.Trainer(callbacks=callback_list, **self.train_cfg.params) - # save the configuration in the log directory save_path = Path(trainer.log_dir) if trainer.local_rank == 0: save_path.mkdir(parents=True, exist_ok=True) @@ -84,6 +95,6 @@ def run_training(self): self.data_cfg, open(save_path / Path("data_config.yaml"), "w") ) - # start training print("start training ... ") trainer.fit(model=self.model, datamodule=self.data) + \ No newline at end of file diff --git a/mmv_im2im/utils/elbo_loss.py b/mmv_im2im/utils/elbo_loss.py new file mode 100644 index 0000000..2a3cbc0 --- /dev/null +++ b/mmv_im2im/utils/elbo_loss.py @@ -0,0 +1,82 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class KLDivergence(nn.Module): + """Calculates KL Divergence between two diagonal Gaussians.""" + + def __init__(self): + super().__init__() + + def forward(self, mu_q, logvar_q, mu_p, logvar_p): + """ + Calculates the KL Divergence between two diagonal Gaussian distributions. + + Args: + mu_q (torch.Tensor): Mean of the approximate posterior distribution. + logvar_q (torch.Tensor): Log-variance of the approximate posterior distribution. + mu_p (torch.Tensor): Mean of the prior distribution. + logvar_p (torch.Tensor): Log-variance of the prior distribution. + + Returns: + torch.Tensor: The mean KL divergence over the batch. + """ + kl_batch_sum = 0.5 * torch.sum( + logvar_p + - logvar_q + + (torch.exp(logvar_q) + (mu_q - mu_p) ** 2) / torch.exp(logvar_p) + - 1, + dim=[1, 2, 3], # Sum over latent channels, H, W + ) + return torch.mean(kl_batch_sum) # Average over batch + + +class ELBOLoss(nn.Module): + """ + Calculates the Evidence Lower Bound (ELBO) loss for Probabilistic UNet. + + Args: + beta (float): Weighting factor for the KL divergence term. + n_classes (int): Number of classes in the segmentation task. + """ + + def __init__(self, beta: float = 1.0, n_classes: int = 2): + super().__init__() + self.beta = beta + self.n_classes = n_classes + self.kl_divergence_calculator = KLDivergence() + + def forward(self, logits, y_true, prior_mu, prior_logvar, post_mu, post_logvar): + """ + Computes the ELBO loss. + + Args: + logits (torch.Tensor): Output logits from the Probabilistic UNet (B, C, H, W). + y_true (torch.Tensor): Ground truth segmentation mask (B, 1, H, W or B, H, W). + prior_mu (torch.Tensor): Mean of the prior distribution. + prior_logvar (torch.Tensor): Log-variance of the prior distribution. + post_mu (torch.Tensor): Mean of the approximate posterior distribution. + post_logvar (torch.Tensor): Log-variance of the approximate posterior distribution. + + Returns: + torch.Tensor: The calculated ELBO loss. + """ + # Ensure y_true has correct dimensions (e.g., [B, H, W]) for cross_entropy + if y_true.ndim == 4 and y_true.shape[1] == 1: + y_true = y_true.squeeze(1) # Squeeze channel dim to [B, H, W] + + # Negative Cross-Entropy (Log-Likelihood) + # Using reduction='mean' to get a scalar loss per batch + log_likelihood = -F.cross_entropy(logits, y_true.long(), reduction="mean") + + # KL-Divergence + kl_div = self.kl_divergence_calculator( + post_mu, post_logvar, prior_mu, prior_logvar + ) + + # ELBO = Log-Likelihood - beta * KL_Divergence + # We minimize the negative ELBO to maximize the ELBO + elbo_loss = -(log_likelihood - self.beta * kl_div) + + return elbo_loss diff --git a/mmv_im2im/utils/embedseg_utils.py b/mmv_im2im/utils/embedseg_utils.py index cbfa858..a343e77 100644 --- a/mmv_im2im/utils/embedseg_utils.py +++ b/mmv_im2im/utils/embedseg_utils.py @@ -3,8 +3,8 @@ from numba import jit from scipy.ndimage.measurements import find_objects from scipy.ndimage.morphology import binary_fill_holes -from aicsimageio.writers import OmeTiffWriter -from aicsimageio import AICSImage +from bioio.writers import OmeTiffWriter +from bioio import BioImage from tqdm import tqdm from pathlib import Path import warnings @@ -163,7 +163,7 @@ def prepare_embedseg_cache( spatial_dim = 2 for ds in dataset_list: fn = ds["source_fn"] - reader = AICSImage(fn) + reader = BioImage(fn) this_minXY = min(reader.dims.X, reader.dims.Y) min_xy = min((this_minXY, min_xy)) if reader.dims.Z > 1 and spatial_dim == 2: @@ -204,11 +204,11 @@ def prepare_embedseg_cache( # loop through the dataset for ds in tqdm(dataset_list): # get instance segmentation labels - instance_reader = AICSImage(ds["target_fn"]) + instance_reader = BioImage(ds["target_fn"]) instance = instance_reader.get_image_data(**reader_params) # get raw image - image_reader = AICSImage(ds["source_fn"]) + image_reader = BioImage(ds["source_fn"]) image = image_reader.get_image_data(**raw_reader_params) # check if costmap exists @@ -217,7 +217,7 @@ def prepare_embedseg_cache( costmap_flag = False if cm_fn.is_file(): costmap_flag = True - cm_reader = AICSImage(cm_fn) + cm_reader = BioImage(cm_fn) costmap = cm_reader.get_image_data(**reader_params) # parse filename diff --git a/mmv_im2im/utils/lovasz_losses.py b/mmv_im2im/utils/lovasz_losses.py index b8fc349..2e3d831 100644 --- a/mmv_im2im/utils/lovasz_losses.py +++ b/mmv_im2im/utils/lovasz_losses.py @@ -78,7 +78,7 @@ def iou(preds, labels, C, EMPTY=1.0, ignore=None, per_image=False): def lovasz_hinge(logits, labels, per_image=True, ignore=None): """ Binary Lovasz hinge loss - logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) # noqa W605 + logits: [B, H, W] Variable, logits at each pixel (between -infty and +infty) # noqa W605 labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) per_image: compute the loss per image instead of per batch ignore: void class id @@ -98,7 +98,7 @@ def lovasz_hinge(logits, labels, per_image=True, ignore=None): def lovasz_hinge_flat(logits, labels): """ Binary Lovasz hinge loss - logits: [P] Variable, logits at each prediction (between -\infty and +\infty) # noqa W605 + logits: [P] Variable, logits at each prediction (between -infty and +infty) # noqa W605 labels: [P] Tensor, binary ground truth labels (0 or 1) ignore: label to ignore """ @@ -143,7 +143,7 @@ def forward(self, input, target): def binary_xloss(logits, labels, ignore=None): """ Binary Cross entropy loss - logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) # noqa W605 + logits: [B, H, W] Variable, logits at each pixel (between -infty and +infty) # noqa W605 labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) ignore: void class id """ @@ -169,7 +169,7 @@ def lovasz_softmax(probas, labels, only_present=False, per_image=False, ignore=N loss = mean( lovasz_softmax_flat( *flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), - only_present=only_present + only_present=only_present, ) for prob, lab in zip(probas, labels) ) diff --git a/mmv_im2im/utils/misc.py b/mmv_im2im/utils/misc.py index b90f1ba..8a0f2a3 100644 --- a/mmv_im2im/utils/misc.py +++ b/mmv_im2im/utils/misc.py @@ -4,7 +4,7 @@ import importlib import numpy as np import inspect -from aicsimageio import AICSImage +from bioio import BioImage from typing import Sequence, Tuple from monai.data import ImageReader from monai.utils import ensure_tuple, require_pkg @@ -12,7 +12,7 @@ from monai.data.image_reader import _stack_images -@require_pkg(pkg_name="aicsimageio") +@require_pkg(pkg_name="bioio") class monai_bio_reader(ImageReader): def __init__(self, **kwargs): super().__init__() @@ -22,7 +22,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike]): filenames: Sequence[PathLike] = ensure_tuple(data) img_ = [] for name in filenames: - img_.append(AICSImage(f"{name}")) + img_.append(BioImage(f"{name}")) return img_ if len(filenames) > 1 else img_[0] diff --git a/paper_configs/prepare_data/denoising.ipynb b/paper_configs/prepare_data/denoising.ipynb index f912dc1..a3e09b6 100644 --- a/paper_configs/prepare_data/denoising.ipynb +++ b/paper_configs/prepare_data/denoising.ipynb @@ -21,8 +21,8 @@ "outputs": [], "source": [ "import pooch\n", - "from aicsimageio import AICSImage\n", - "from aicsimageio.writers import OmeTiffWriter\n", + "from bioio import BioImage\n", + "from bioio.writers import OmeTiffWriter\n", "import matplotlib.pyplot as plt\n", "import tarfile\n", "from pathlib import Path\n", diff --git a/paper_configs/prepare_data/instance_seg_2d.ipynb b/paper_configs/prepare_data/instance_seg_2d.ipynb index 4ecab67..2d81e97 100644 --- a/paper_configs/prepare_data/instance_seg_2d.ipynb +++ b/paper_configs/prepare_data/instance_seg_2d.ipynb @@ -15,13 +15,14 @@ "source": [ "import pooch\n", "from skimage.io import imread\n", - "from aicsimageio import AICSImage\n", - "from aicsimageio.writers import OmeTiffWriter\n", + "from bioio import BioImage\n", + "from bioio.writers import OmeTiffWriter\n", "import zipfile\n", "from pathlib import Path\n", "from random import random\n", "import numpy as np\n", "\n", + "\n", "data_path = Path(\"../../data/instance2D\")\n", "data_path.mkdir(exist_ok=True, parents=True)\n", "\n", @@ -77,21 +78,17 @@ "filenames = sorted(download_path.glob(\"*_w2_*.tif\"))\n", "gt_path = download_path / Path(\"BBBC010_v1_foreground_eachworm\")\n", "for fn in filenames:\n", - " # extract the file key\n", " fn_key = fn.name[33:36]\n", "\n", - " # load raw image\n", - " reader = AICSImage(fn)\n", + " reader = BioImage(fn)\n", " raw = reader.get_image_data(\"YX\", Z=0, C=0, T=0)\n", "\n", - " # load ground truth\n", " gt = np.zeros(raw.shape, dtype=np.uint8)\n", " gt_filenames = sorted(Path(gt_path).glob(f\"{fn_key}_*.png\"))\n", " for gt_idx, gt_fn in enumerate(gt_filenames):\n", " gt_item = imread(gt_fn)\n", " gt[gt_item > 0] = gt_idx + 1\n", "\n", - " # since the dataset is very small, we only reserve 5% for testing\n", " if random() < 0.05:\n", " out_path = test_path\n", " else:\n", @@ -114,13 +111,6 @@ "from shutil import rmtree\n", "rmtree(download_path)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/paper_configs/prepare_data/instance_seg_3d.ipynb b/paper_configs/prepare_data/instance_seg_3d.ipynb index 85643ad..4bf5616 100644 --- a/paper_configs/prepare_data/instance_seg_3d.ipynb +++ b/paper_configs/prepare_data/instance_seg_3d.ipynb @@ -31,8 +31,8 @@ "import pandas as pd\n", "import quilt3\n", "from pathlib import Path\n", - "from aicsimageio import AICSImage\n", - "from aicsimageio.writers import OmeTiffWriter\n", + "from bioio import BioImage\n", + "from bioio.writers import OmeTiffWriter\n", "from random import random\n", "import numpy as np" ] @@ -109,7 +109,7 @@ " pkg[subdir_name][file_name].fetch(local_fn)\n", "\n", " # extract the bf and DNA dye channel\n", - " reader = AICSImage(local_fn)\n", + " reader = BioImage(local_fn)\n", " bf_img = reader.get_image_data(\n", " \"ZYX\", C=row.ChannelNumberBrightfield, S=0, T=0\n", " )\n", @@ -137,7 +137,7 @@ " all_cells = cell_df[\"this_cell_index\"].tolist()\n", "\n", " # extract the DNA segmentation\n", - " reader = AICSImage(local_fn)\n", + " reader = BioImage(local_fn)\n", " dna_seg = reader.get_image_data(\n", " \"ZYX\", C=0, S=0, T=0\n", " ).astype(np.uint8)\n", diff --git a/paper_configs/prepare_data/labelfree_2d.ipynb b/paper_configs/prepare_data/labelfree_2d.ipynb index a19555d..cd23f2c 100644 --- a/paper_configs/prepare_data/labelfree_2d.ipynb +++ b/paper_configs/prepare_data/labelfree_2d.ipynb @@ -33,8 +33,8 @@ "outputs": [], "source": [ "import pooch\n", - "from aicsimageio import AICSImage\n", - "from aicsimageio.writers import OmeTiffWriter\n", + "from bioio import BioImage\n", + "from bioio.writers import OmeTiffWriter\n", "import matplotlib.pyplot as plt\n", "import zipfile\n", "from pathlib import Path\n", @@ -74,7 +74,7 @@ "metadata": {}, "outputs": [], "source": [ - "reader = AICSImage(source_part1)\n", + "reader = BioImage(source_part1)\n", "print(reader.dims)" ] }, @@ -154,11 +154,11 @@ " fn_base = fn.stem.replace(\" \", \"\")\n", "\n", " # get bright field image\n", - " bf_reader = AICSImage(fn)\n", + " bf_reader = BioImage(fn)\n", " im = bf_reader.get_image_data(\"YX\", Z=0, T=0, C=0)\n", "\n", " # get H2b fluorescent image\n", - " h2b_reader = AICSImage(fn_fluo)\n", + " h2b_reader = BioImage(fn_fluo)\n", " gt = h2b_reader.get_image_data(\"YX\", Z=0, C=1, T=0)\n", "\n", " if random() < 0.15:\n", diff --git a/paper_configs/prepare_data/labelfree_3d.ipynb b/paper_configs/prepare_data/labelfree_3d.ipynb index 17ebf37..7416324 100644 --- a/paper_configs/prepare_data/labelfree_3d.ipynb +++ b/paper_configs/prepare_data/labelfree_3d.ipynb @@ -29,8 +29,8 @@ "import pandas as pd\n", "import quilt3\n", "from pathlib import Path\n", - "from aicsimageio import AICSImage\n", - "from aicsimageio.writers import OmeTiffWriter\n", + "from bioio import BioImage\n", + "from bioio.writers import OmeTiffWriter\n", "from random import random" ] }, @@ -105,7 +105,7 @@ " pkg[subdir_name][file_name].fetch(local_fn)\n", "\n", " # extract the bf and structures channel\n", - " reader = AICSImage(local_fn)\n", + " reader = BioImage(local_fn)\n", " bf_img = reader.get_image_data(\n", " \"ZYX\", C=row.ChannelNumberBrightfield, S=0, T=0\n", " )\n", diff --git a/paper_configs/prepare_data/modaity_transfer.ipynb b/paper_configs/prepare_data/modaity_transfer.ipynb index 0f72cf9..59a6a53 100644 --- a/paper_configs/prepare_data/modaity_transfer.ipynb +++ b/paper_configs/prepare_data/modaity_transfer.ipynb @@ -21,8 +21,8 @@ "outputs": [], "source": [ "import pooch\n", - "from aicsimageio import AICSImage\n", - "from aicsimageio.writers import OmeTiffWriter\n", + "from bioio import BioImage\n", + "from bioio.writers import OmeTiffWriter\n", "import matplotlib.pyplot as plt\n", "import zipfile\n", "from pathlib import Path\n", @@ -31,11 +31,11 @@ "from shutil import move\n", "\n", "\n", - "data_path = Path(\"../../data/modalityTransfer\")\n", + "data_path = Path(\"../../data/modality_transfer\")\n", "data_path.mkdir(exist_ok=True, parents=True)\n", "\n", "p = data_path / Path(\"download\")\n", - "p.mkdir(exist_ok=True)\n", + "p.mkdir(exist_ok=True, parents=True)\n", "p = data_path / Path(\"train\")\n", "p.mkdir(exist_ok=True)\n", "p = data_path / Path(\"test\")\n", diff --git a/paper_configs/prepare_data/semantic_seg_2d.ipynb b/paper_configs/prepare_data/semantic_seg_2d.ipynb index c22ee36..35e8774 100644 --- a/paper_configs/prepare_data/semantic_seg_2d.ipynb +++ b/paper_configs/prepare_data/semantic_seg_2d.ipynb @@ -16,7 +16,7 @@ "source": [ "import pooch\n", "from skimage.io import imread\n", - "from aicsimageio.writers import OmeTiffWriter\n", + "from bioio.writers import OmeTiffWriter\n", "import zipfile\n", "from pathlib import Path\n", "from random import random\n", diff --git a/paper_configs/prepare_data/semantic_seg_3d.ipynb b/paper_configs/prepare_data/semantic_seg_3d.ipynb index 3c485b6..bf7838b 100644 --- a/paper_configs/prepare_data/semantic_seg_3d.ipynb +++ b/paper_configs/prepare_data/semantic_seg_3d.ipynb @@ -15,8 +15,8 @@ "import pandas as pd\n", "import quilt3\n", "from pathlib import Path\n", - "from aicsimageio import AICSImage\n", - "from aicsimageio.writers import OmeTiffWriter\n", + "from bioio import BioImage\n", + "from bioio.writers import OmeTiffWriter\n", "from random import random\n", "import numpy as np\n", "\n", @@ -103,12 +103,12 @@ " pkg[subdir_name][file_name].fetch(local_fn)\n", "\n", " # extract the fbl channel\n", - " reader = AICSImage(local_fn)\n", + " reader = BioImage(local_fn)\n", " img = reader.get_image_data(\n", " \"ZYX\", C=row.ChannelNumberStruct, S=0, T=0\n", " )\n", "\n", - " mean_intensity = img.mean() \n", + " mean_intensity = img.mean()\n", " if mean_intensity < 450 or mean_intensity > 500:\n", " continue\n", "\n", @@ -128,7 +128,7 @@ " pkg[subdir_name][file_name].fetch(local_fn)\n", "\n", " # extract the Cell segmentation\n", - " reader = AICSImage(local_fn)\n", + " reader = BioImage(local_fn)\n", " cell_seg = reader.get_image_data(\n", " \"ZYX\", C=1, S=0, T=0\n", " ).astype(np.uint8)\n", diff --git a/paper_configs/prepare_data/sythetic_gen_2d.ipynb b/paper_configs/prepare_data/sythetic_gen_2d.ipynb index cb1f1b3..4ceac2d 100644 --- a/paper_configs/prepare_data/sythetic_gen_2d.ipynb +++ b/paper_configs/prepare_data/sythetic_gen_2d.ipynb @@ -31,11 +31,11 @@ "import pandas as pd\n", "import quilt3\n", "from pathlib import Path\n", - "from aicsimageio import AICSImage\n", "from aicsimageio.writers import OmeTiffWriter\n", "from random import random, sample\n", "from shutil import move\n", - "import numpy as np" + "import numpy as np\n", + "from bioio import BioImage" ] }, { @@ -104,7 +104,7 @@ " pkg[subdir_name][file_name].fetch(local_fn)\n", "\n", " # extract the structure channel\n", - " reader = AICSImage(local_fn)\n", + " reader = BioImage(local_fn)\n", " img = reader.get_image_data(\"ZYX\", C=row.ChannelNumberStruct, S=0, T=0)\n", "\n", " subdir_name = row.struct_seg_path.split(\"/\")[0]\n", @@ -114,7 +114,7 @@ " pkg[subdir_name][file_name].fetch(local_fn)\n", "\n", " # extract the structure segmentation\n", - " reader = AICSImage(local_fn)\n", + " reader = BioImage(local_fn)\n", " seg = reader.get_image_data(\"ZYX\", C=0, S=0, T=0).astype(np.uint8)\n", " seg[seg > 0] = 1\n", "\n", diff --git a/paper_configs/prepare_data/sythetic_gen_3d.ipynb b/paper_configs/prepare_data/sythetic_gen_3d.ipynb index 22d2ba8..236b509 100644 --- a/paper_configs/prepare_data/sythetic_gen_3d.ipynb +++ b/paper_configs/prepare_data/sythetic_gen_3d.ipynb @@ -31,11 +31,11 @@ "import pandas as pd\n", "import quilt3\n", "from pathlib import Path\n", - "from aicsimageio import AICSImage\n", "from aicsimageio.writers import OmeTiffWriter\n", "from random import random, sample\n", "from shutil import move\n", - "import numpy as np" + "import numpy as np\n", + "from bioio import BioImage" ] }, { @@ -105,7 +105,7 @@ " pkg[subdir_name][file_name].fetch(local_fn)\n", "\n", " # extract the structure channel\n", - " reader = AICSImage(local_fn)\n", + " reader = BioImage(local_fn)\n", " img = reader.get_image_data(\"ZYX\", C=row.ChannelNumberStruct, S=0, T=0)\n", "\n", " # fetch segmentation (use nuclear segmentation for H2B,\n", @@ -121,7 +121,7 @@ " pkg[subdir_name][file_name].fetch(local_fn)\n", "\n", " # extract the structure segmentation\n", - " reader = AICSImage(local_fn)\n", + " reader = BioImage(local_fn)\n", " seg = reader.get_image_data(\"ZYX\", C=0, S=0, T=0).astype(np.uint8)\n", " seg[seg > 0] = 1\n", "\n", diff --git a/paper_configs/prepare_data/unsupervised_seg_2d.ipynb b/paper_configs/prepare_data/unsupervised_seg_2d.ipynb index 0b8d5ae..f66ffb1 100644 --- a/paper_configs/prepare_data/unsupervised_seg_2d.ipynb +++ b/paper_configs/prepare_data/unsupervised_seg_2d.ipynb @@ -31,11 +31,11 @@ "import pandas as pd\n", "import quilt3\n", "from pathlib import Path\n", - "from aicsimageio import AICSImage\n", "from aicsimageio.writers import OmeTiffWriter\n", "from random import random, sample\n", "from shutil import move\n", - "import numpy as np" + "import numpy as np\n", + "from bioio import BioImage" ] }, { @@ -106,7 +106,7 @@ " pkg[subdir_name][file_name].fetch(local_fn)\n", "\n", " # extract the structure channel\n", - " reader = AICSImage(local_fn)\n", + " reader = BioImage(local_fn)\n", " img = reader.get_image_data(\"ZYX\", C=row.ChannelNumberStruct, S=0, T=0)\n", " img_proj = np.amax(img, axis=0)\n", "\n", @@ -118,7 +118,7 @@ " pkg[subdir_name][file_name].fetch(local_fn)\n", "\n", " # extract the structure segmentation\n", - " reader = AICSImage(local_fn)\n", + " reader = BioImage(local_fn)\n", " seg = reader.get_image_data(\"ZYX\", C=0, S=0, T=0).astype(np.uint8)\n", " seg[seg > 0] = 1\n", " seg_proj = np.amax(seg, axis=0)\n", diff --git a/paper_configs/prepare_data/unsupervised_seg_3d.ipynb b/paper_configs/prepare_data/unsupervised_seg_3d.ipynb index 87cb4c6..0fb4c93 100644 --- a/paper_configs/prepare_data/unsupervised_seg_3d.ipynb +++ b/paper_configs/prepare_data/unsupervised_seg_3d.ipynb @@ -31,11 +31,11 @@ "import pandas as pd\n", "import quilt3\n", "from pathlib import Path\n", - "from aicsimageio import AICSImage\n", "from aicsimageio.writers import OmeTiffWriter\n", "from random import random, sample\n", "from shutil import move\n", - "import numpy as np" + "import numpy as np\n", + "from bioio import BioImage" ] }, { @@ -106,7 +106,7 @@ " pkg[subdir_name][file_name].fetch(local_fn)\n", "\n", " # extract the structure channel\n", - " reader = AICSImage(local_fn)\n", + " reader = BioImage(local_fn)\n", " img = reader.get_image_data(\"ZYX\", C=row.ChannelNumberStruct, S=0, T=0)\n", "\n", " # fetch segmentation (use nuclear segmentation for H2B,\n", @@ -122,7 +122,7 @@ " pkg[subdir_name][file_name].fetch(local_fn)\n", "\n", " # extract the structure segmentation\n", - " reader = AICSImage(local_fn)\n", + " reader = BioImage(local_fn)\n", " seg = reader.get_image_data(\"ZYX\", C=0, S=0, T=0).astype(np.uint8)\n", " seg[seg > 0] = 1\n", "\n", diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..11728cc --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,146 @@ +[build-system] +requires = ["setuptools>=64", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "mmv_im2im" +version = "0.5.2" +authors = [ + { name="Jianxu Chen", email="jianxuchen.ai@gmail.com" }, +] +description = "A python package for deep learning based image to image transformation" +readme = "README.md" +license = { text="MIT license" } +requires-python = ">=3.10" +keywords = ["deep learning", "microscopy image analysis", "biomedical image analysis"] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Natural Language :: English", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", +] +dependencies = [ + "lightning>=2.0.1", + "torch==2.0.1", + "monai>=1.2.0", + "bioio==1.6.1", + "pandas", + "scikit-image", + "protobuf<4.21.0", + "pyrallis", + "scikit-learn", + "tensorboard", + "numba", + "numpy<2", + "pydantic==2.11.7", + "fastapi", + "uvicorn", + "botocore==1.38.38", + "bioio-ome-tiff==1.1.0", + "bioio-ome-zarr==1.2.0", + "pydantic-zarr", + "bioio-tifffile==1.1.0", + "bioio-lif==1.1.0", + "ngff-zarr", + "tifffile", + "ome-types", + "imageio", + "zarr", +] + +[project.optional-dependencies] +advanced = [ + "tensorboard", +] +paper = [ + "quilt3", + "pooch", + "matplotlib", + "notebook", +] + + +image-io = [] + +dev = [ + "pytest-runner>=5.2", + "black>=19.10b0", + "codecov>=2.1.4", + "flake8>=3.8.3", + "flake8-debugger>=3.2.1", + "pytest>=5.4.3", + "pytest-cov>=2.9.0", + "pytest-raises>=0.11", + "numpy<2", + "bump2version>=1.0.1", + "coverage>=5.1", + "ipython>=7.15.0", + "m2r2>=0.2.7", + "Sphinx>=3.4.3", + "sphinx_rtd_theme>=0.5.1", + "tox>=3.15.2", + "twine>=3.1.1", + "wheel>=0.34.2", +] + +test = [ + "black>=19.10b0", + "codecov>=2.1.4", + "flake8>=3.8.3", + "flake8-debugger>=3.2.1", + "pytest>=5.4.3", + "pytest-cov>=2.9.0", + "pytest-raises>=0.11", + "numpy<2", +] + +all = [ + "mmv_im2im[advanced]", + "mmv_im2im[paper]", + "mmv_im2im[image-io]", + "mmv_im2im[dev]", + "mmv_im2im[test]", +] + +dynamic = ["optional-dependencies"] + +[project.urls] +"Homepage" = "https://github.com/MMV-Lab/mmv_im2im" + +[tool.setuptools] +include-package-data = true +zip-safe = false + +[project.scripts] +run_im2im = "mmv_im2im.bin.run_im2im:main" + +[tool.pytest.ini_options] +addopts = "--cov=mmv_im2im --no-cov-on-fail --cov-report=term-missing --cov-report=xml --cov-branch --durations=10" +testpaths = [ + "mmv_im2im/tests", +] +python_files = "test_*.py" + +[tool.flake8] +exclude = [ + "docs/", + ".git/", + "__pycache__/", + "build/", + "dist/", + ".venv/", + ".tox/", + "*.egg-info/", +] +ignore = [ + "E203", + "E402", + "W291", + "W503", + "W293", + "W292", + "E501", +] +max-line-length = 88 \ No newline at end of file diff --git a/script/pull_labelfree_sample_data.py b/script/pull_labelfree_sample_data.py index db5c446..cb4ea5f 100644 --- a/script/pull_labelfree_sample_data.py +++ b/script/pull_labelfree_sample_data.py @@ -4,8 +4,8 @@ import quilt3 from pathlib import Path import random -from aicsimageio import AICSImage -from aicsimageio.writers import OmeTiffWriter +from bioio import BioImage +from bioio.writers import OmeTiffWriter import os import sys import logging @@ -158,7 +158,7 @@ def execute(self, args): pkg[subdir_name][file_name].fetch(local_fn) # extract the bf and structures channel - reader = AICSImage(local_fn) + reader = BioImage(local_fn) bf_img = reader.get_image_data( "ZYX", C=row.ChannelNumberBrightfield, S=0, T=0 ) @@ -197,4 +197,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index ea198fb..0642911 100644 --- a/setup.cfg +++ b/setup.cfg @@ -11,21 +11,11 @@ replace = version="{new_version}" search = {current_version} replace = {new_version} -[bdist_wheel] -universal = 1 - -[aliases] -test = pytest - -[tool:pytest] -collect_ignore = ['setup.py'] +[bumpversion:file:pyproject.toml] +search = version = "{current_version}" +replace = version = "{new_version}" [flake8] -exclude = - docs/ -ignore = - E203 - E402 - W291 - W503 -max-line-length = 88 +exclude = docs/, .git/, __pycache__/, build/, dist/, .venv/, .tox/, *.egg-info/ +ignore = E203, E402, W291, W503, W293, W292, E501 +max-line-length = 88 \ No newline at end of file diff --git a/setup.py b/setup.py index 2e922db..beb61a9 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,9 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -"""The setup script.""" -from setuptools import find_packages, setup +from setuptools import setup, find_packages -with open("README.md") as readme_file: - readme = readme_file.read() setup_requirements = [ "pytest-runner>=5.2", @@ -20,6 +17,7 @@ "pytest>=5.4.3", "pytest-cov>=2.9.0", "pytest-raises>=0.11", + "numpy<2", ] dev_requirements = [ @@ -48,75 +46,10 @@ "tensorboard", ] -requirements = [ - "lightning==2.0.0", - "torch==2.0.1", - "monai>=1.2.0", - "aicsimageio==4.10.0", - "pandas", - "scikit-image", - "protobuf<4.21.0", - "pyrallis", - "scikit-learn", - "tensorboard", - "numba", -] - -extra_requirements = { - "setup": setup_requirements, - "test": test_requirements, - "dev": dev_requirements, - "paper": [ - *requirements, - *data_requirements, - ], - "advanced": [ - *requirements, - *logger_requirements, - ], - "all": [ - *requirements, - *logger_requirements, - *data_requirements, - *dev_requirements, - ] -} setup( - author="Jianxu Chen", - author_email="jianxuchen.ai@gmail.com", - classifiers=[ - "Development Status :: 3 - Alpha", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "Natural Language :: English", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - ], - description="A python package for deep learing based image to image transformation", # noqa E501 - entry_points={ - "console_scripts": [ - "run_im2im=mmv_im2im.bin.run_im2im:main" - ], - }, - install_requires=requirements, - license="MIT license", - long_description=readme, - long_description_content_type="text/markdown", - include_package_data=True, - keywords="deep learning, microscopy image analysis, biomedical image analysis", - name="mmv_im2im", - packages=find_packages(exclude=["tests", "*.tests", "*.tests.*"]), - python_requires=">=3.9", - setup_requires=setup_requirements, - test_suite="mmv_im2im/tests", - tests_require=test_requirements, - extras_require=extra_requirements, - url="https://github.com/MMV-Lab/mmv_im2im", - # Do not edit this string manually, always use bumpversion - # Details in CONTRIBUTING.rst - version="0.5.2", - zip_safe=False, -) + packages=["mmv_im2im"], + package_dir={"mmv_im2im": "mmv_im2im"}, + +) \ No newline at end of file diff --git a/tox.ini b/tox.ini index d2bf26b..72ba0b7 100644 --- a/tox.ini +++ b/tox.ini @@ -1,18 +1,20 @@ [tox] skipsdist = True -envlist = py37, py38, py39, lint +envlist = py310, py311, lint [testenv:lint] +basepython = python3.11 deps = .[test] + black commands = flake8 mmv_im2im --count --verbose --show-source --statistics black --check mmv_im2im -[testenv] +[testenv:py311] setenv = PYTHONPATH = {toxinidir} deps = .[test] commands = - pytest --basetemp={envtmpdir} --cov-report html --cov=mmv_im2im mmv_im2im/tests/ + pytest --basetemp={envtmpdir} --cov-report html --cov=mmv_im2im mmv_im2im/tests/ \ No newline at end of file diff --git a/tutorials/colab/labelfree_2d.ipynb b/tutorials/colab/labelfree_2d.ipynb index 5d5ac4b..9687bd4 100644 --- a/tutorials/colab/labelfree_2d.ipynb +++ b/tutorials/colab/labelfree_2d.ipynb @@ -88,7 +88,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": { "cellView": "form", "id": "ZuJQDKpkfkcK" @@ -97,8 +97,8 @@ "source": [ "# @title 3.Download the dataset from zenodo\n", "import pooch\n", - "from aicsimageio import AICSImage\n", - "from aicsimageio.writers import OmeTiffWriter\n", + "from bioio import BioImage\n", + "from bioio.writers import OmeTiffWriter\n", "import matplotlib.pyplot as plt\n", "import zipfile\n", "from pathlib import Path\n", @@ -124,7 +124,7 @@ " path=data_path / Path(\"download\"),\n", ")\n", "\n", - "reader = AICSImage(source_part1)\n", + "reader = BioImage(source_part1)\n", "\n", "# input (bright field) channel: 2\n", "# ground truth (mCherry-H2B) channel: 4\n", @@ -160,11 +160,11 @@ " fn_base = fn.stem.replace(\" \", \"\")\n", "\n", " # get bright field image\n", - " bf_reader = AICSImage(fn)\n", + " bf_reader = BioImage(fn)\n", " im = bf_reader.get_image_data(\"YX\", Z=0, T=0, C=0)\n", "\n", " # get H2b fluorescent image\n", - " h2b_reader = AICSImage(fn_fluo)\n", + " h2b_reader = BioImage(fn_fluo)\n", " gt = h2b_reader.get_image_data(\"YX\", Z=0, C=1, T=0)\n", "\n", " if random() < 0.15:\n", diff --git a/tutorials/how_to_understand_boilerplates.md b/tutorials/how_to_understand_boilerplates.md index b2b4a37..76b475c 100644 --- a/tutorials/how_to_understand_boilerplates.md +++ b/tutorials/how_to_understand_boilerplates.md @@ -1,7 +1,7 @@ # Notes on the package design -1. The four main packages we build upon: [pytorch-lightning](https://www.pytorchlightning.ai/), [MONAI](https://monai.io/), [pyrallis](https://eladrich.github.io/pyrallis/), and [aicsimageio](https://github.com/AllenCellModeling/aicsimageio). +1. The four main packages we build upon: [pytorch-lightning](https://www.pytorchlightning.ai/), [MONAI](https://monai.io/), [pyrallis](https://eladrich.github.io/pyrallis/), and [bioio](https://github.com/bioio-devs/bioio). The whole package uses [pytorch-lightning](https://www.pytorchlightning.ai/) as the core of its backend, in the sense that the package is implemented following the boilerplate components in pytorch-lightning, such as `LightningModule`, `DataModule` and `Trainer`. All small building blocks, like network architecture, optimizer, etc., can be swapped easily without changing the boilerplate. @@ -9,7 +9,7 @@ We adopt the [PersistentDataset](https://docs.monai.io/en/stable/data.html#persi [Pyrallis](https://eladrich.github.io/pyrallis/) provides a handy configuration system. Combining pyrallis and the boilerplate concepts in pytorch-lightning, it is very easy to configure your method at any level of details (as high level as only providing the path to the training data, all the way to as low level as changing which type of normalization to use in the model). -Finally, [aicsimageio](https://github.com/AllenCellModeling/aicsimageio) is adopted for efficient data I/O, which not only supports all major bio-formats and OME-TIFF, but also makes it painless to handle hugh data by delayed loading. +Finally, [bioio](https://github.com/bioio-devs/bioio) is adopted for efficient data I/O, which not only supports all major bio-formats and OME-TIFF, but also makes it painless to handle hugh data by delayed loading. 2. The codebase is modularized and organized at three levels: