Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mmv_im2im/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
__email__ = "[email protected]"
# Do not edit this string manually, always use bumpversion
# Details in CONTRIBUTING.md
__version__ = "0.8.0"
__version__ = "0.7.1"


def get_module_version():
Expand Down
25 changes: 25 additions & 0 deletions mmv_im2im/bin/run_im2im.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
parse_adaptor,
configuration_validation,
)
from mmv_im2im.proj_trainer_multishape import VariableSizeProjectTrainer

###############################################################################

Expand All @@ -33,6 +34,7 @@

###############################################################################
TRAIN_MODE = "train"
TRAIN_MULTISHAPE_MODE = "train-multishape"
INFER_MODE = "inference"
MAP_MODE = "uncertainty_map"

Expand Down Expand Up @@ -70,12 +72,35 @@ def main():
else:
exe = ProjectTrainer(cfg)
exe.run_training()

elif cfg.mode.lower() == TRAIN_MULTISHAPE_MODE:
if (
cfg.data.dataloader.train.dataloader_type["func_name"]
== "PersistentDataset"
):
# Mirror the same PersistentDataset cache handling used in
# standard training mode.
cache_root = Path(cfg.data.dataloader.train.dataset_params["cache_dir"])
cache_root.mkdir(exist_ok=True)
with tempfile.TemporaryDirectory(dir=cache_root) as tmp_exp:
train_cache = Path(tmp_exp) / "train"
val_cache = Path(tmp_exp) / "val"
cfg.data.dataloader.train.dataset_params["cache_dir"] = train_cache
cfg.data.dataloader.val.dataset_params["cache_dir"] = val_cache
exe = VariableSizeProjectTrainer(cfg)
exe.run_training()
else:
exe = VariableSizeProjectTrainer(cfg)
exe.run_training()

elif cfg.mode.lower() == INFER_MODE:
exe = ProjectTester(cfg)
exe.run_inference()

elif cfg.mode.lower() == MAP_MODE:
exe = MapExtractor(cfg)
exe.run_inference()

else:
log.error(f"Mode {cfg.mode} is not supported yet")
sys.exit(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ mode: train

data:
category: "pair"
data_path: "/mnt/eternus/users/Jair/Trainings/Retrain2/DataMax"
data_path: "path/to/your/data"
dataloader:
train:
dataloader_type:
Expand Down
3 changes: 3 additions & 0 deletions mmv_im2im/map_extractor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)
import logging
from typing import Union
from dask.array.core import Array as DaskArray
Expand Down
149 changes: 103 additions & 46 deletions mmv_im2im/models/pl_FCN.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Dict

import lightning as pl
import torch
import torch.nn as nn
from mmv_im2im.utils.gdl_regularized import (
RegularizedGeneralizedDiceFocalLoss as regularized,
)
Expand All @@ -19,20 +21,32 @@ def __init__(self, model_info_xx: Dict, train: bool = True, verbose: bool = Fals
if isinstance(model_info_xx.net["params"], dict):
self.task = model_info_xx.net["params"].pop("task", "segmentation")

if self.task != "regression" and self.task != "segmentation":
if self.task not in ("regression", "segmentation"):
raise ValueError(
f"Task should be regression/segmentation : {self.task} was given"
f"Task should be 'regression' or 'segmentation'; got '{self.task}'"
)

self.net = parse_config(model_info_xx.net)

init_weights(self.net, init_type="kaiming")

self.model_info = model_info_xx
self.verbose = verbose
self.weighted_loss = False
self.seg_flag = False

# ── Regression: adaptive global average pooling ────────────────
# Works for ANY spatial size (2-D or 3-D) and is more robust than
# the manual .view(...).mean(-1) idiom.
# AdaptiveAvgPool reduces every spatial dimension to size 1, after
# which squeeze() gives a flat [B, C] tensor.
self._gap_2d = nn.AdaptiveAvgPool2d(1)
self._gap_3d = nn.AdaptiveAvgPool3d(1)

# _logged_shapes → verbose print fires exactly once (any task)
# _shape_checked → regression shape validation fires exactly once
self._logged_shapes = False
self._shape_checked = False

if train:
if "use_costmap" in model_info_xx.criterion[
"params"
Expand All @@ -48,70 +62,114 @@ def __init__(self, model_info_xx: Dict, train: bool = True, verbose: bool = Fals
):
self.seg_flag = True

# ------------------------------------------------------------------
def configure_optimizers(self):
    """Build the optimizer and, when configured, attach an epoch-wise LR scheduler.

    Returns the bare optimizer when ``model_info.scheduler`` is None;
    otherwise a Lightning optimizer/scheduler dict monitoring ``val_loss``.
    """
    opt = self.optimizer_func(self.parameters())
    if self.model_info.scheduler is None:
        return opt
    make_scheduler = parse_config_func_without_params(self.model_info.scheduler)
    sched = make_scheduler(opt, **self.model_info.scheduler["params"])
    return {
        "optimizer": opt,
        "lr_scheduler": {
            "scheduler": sched,
            "monitor": "val_loss",
            "interval": "epoch",
            "frequency": 1,
            "strict": True,
        },
    }

def prepare_batch(self, batch):
    """No-op hook: batches require no extra preprocessing for this model."""
    return None

scheduler_func = parse_config_func_without_params(self.model_info.scheduler)
lr_scheduler = scheduler_func(optimizer, **self.model_info.scheduler["params"])
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": lr_scheduler,
"monitor": "val_loss",
"interval": "epoch",
"frequency": 1,
"strict": True,
},
}

# ------------------------------------------------------------------
def forward(self, x):
    """Run the wrapped network on input ``x`` and return its raw output."""
    prediction = self.net(x)
    return prediction

def run_step(self, batch, validation_stage):
x = batch["IM"]
y = batch["GT"]
if "CM" in batch.keys():
# ------------------------------------------------------------------
def _global_average_pool(self, y_hat: torch.Tensor) -> torch.Tensor:
"""
Reduce [B, C, *spatial] → [B, C] via global average pooling.

Handles both 2-D [B, C, H, W] and 3-D [B, C, D, H, W] tensors
of arbitrary spatial size.
"""
ndim = y_hat.dim()
if ndim == 4: # 2-D input
return self._gap_2d(y_hat).squeeze(-1).squeeze(-1)
elif ndim == 5: # 3-D input
return self._gap_3d(y_hat).squeeze(-1).squeeze(-1).squeeze(-1)
else:
# Fallback: manual mean over all spatial dims
return y_hat.view(y_hat.size(0), y_hat.size(1), -1).mean(dim=-1)

# ------------------------------------------------------------------
def run_step(self, batch, validation_stage: bool):
x = batch["IM"] # [B, 1, Z, Y, X] – spatial dims vary per batch
y = batch["GT"] # [B, 3 + n_coeffs]

if "CM" in batch:
assert (
self.weighted_loss
), "Costmap is detected, but no use_costmap param in criterion"
), "Costmap detected in batch but use_costmap=False in criterion config."
cm = batch["CM"]

# only for badly formatted data files
if x.size()[-1] == 1:
x = torch.squeeze(x, dim=-1)
y = torch.squeeze(y, dim=-1)
# ── Remove accidental trailing singleton (badly formatted data) ──
if x.dim() > 4 and x.size(-1) == 1:
x = x.squeeze(-1)
if y.dim() > 2 and y.size(-1) == 1:
y = y.squeeze(-1)

y_hat = self(x)
# ── Forward pass ──────────────────────────────────────────────
y_hat = self(x) # [B, C, Z', Y', X'] (spatial may still vary)

# Verbose: print shapes only once at the start, regardless of the task
if self.verbose and not self._logged_shapes:
print(
f"[pl_FCN] x.shape={tuple(x.shape)} "
f"y_hat.shape (pre-GAP)={tuple(y_hat.shape)}"
)
self._logged_shapes = True

# ── Task-specific output processing ────────────────────────────
if self.task == "regression":
# Global Average Pooling: Dim reduction 2D / 3D
# [B, C, H, W] -> [B, C]
y_hat = y_hat.view(y_hat.size(0), y_hat.size(1), -1).mean(dim=-1)
# GT: [B, C]
# Global Average Pool: [B, C, *spatial] → [B, C]
y_hat = self._global_average_pool(y_hat)

# GT: ensure [B, N]
y = y.view(y.size(0), -1).float()
else:

# One-time shape sanity check
if not self._shape_checked:
C_pred = y_hat.shape[1]
N_gt = y.shape[1]
if C_pred != N_gt:
raise ValueError(
f"[pl_FCN] Shape mismatch: network predicts {C_pred} "
f"values (out_channels={C_pred}) but GT vector has "
f"{N_gt} elements. "
f"Set out_channels: {N_gt} in your YAML."
)
if self.verbose:
print(
f"[pl_FCN] Regression shapes OK "
f"y_hat={tuple(y_hat.shape)} y={tuple(y.shape)}"
)
self._shape_checked = True

else: # segmentation
if isinstance(self.criterion, torch.nn.CrossEntropyLoss):
# in case of CrossEntropy related error
y = torch.squeeze(y, dim=1) # remove C dimension
y = y.squeeze(dim=1)

# ── Loss ─────────────────────────────────────────────────────
if isinstance(self.criterion, regularized):
current_epoch = self.current_epoch
loss = self.criterion(y_hat, y, epoch=current_epoch)
loss = self.criterion(y_hat, y, epoch=self.current_epoch)
elif self.weighted_loss:
loss = self.criterion(y_hat, y, cm)
else:
if self.weighted_loss:
loss = self.criterion(y_hat, y, cm)
else:
loss = self.criterion(y_hat, y)
loss = self.criterion(y_hat, y)

return loss

# ------------------------------------------------------------------
def on_train_epoch_end(self):
    """Flush pending CUDA kernels at the end of each training epoch.

    Guarded on availability: calling ``torch.cuda.synchronize()``
    unconditionally raises on CPU-only machines, breaking CPU training
    and unit tests. On CPU there is nothing to synchronize, so the
    guard is a safe no-op.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()

Expand Down Expand Up @@ -142,5 +200,4 @@ def validation_step(self, batch, batch_idx):
logger=True,
sync_dist=True,
)

return loss
Loading
Loading