Skip to content

Commit

Permalink
Update instance_segmentation.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
ariannasole23 authored Jan 28, 2025
1 parent 9f48f50 commit 63aefc8
Showing 1 changed file with 17 additions and 43 deletions.
60 changes: 17 additions & 43 deletions torchgeo/trainers/instance_segmentation.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,7 +12,6 @@
from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from .base import BaseTask

import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from ..datasets import RGBBandsMissingError, unbind_samples
Expand Down Expand Up @@ -113,7 +112,12 @@ def training_step(self, batch: Any, batch_idx: int) -> Tensor:
"""
images, targets = batch['image'], batch['target']
loss_dict = self.model(images, targets)
loss = sum(loss for loss in loss_dict.values())
loss = sum(loss for loss in loss_dict.values())

print('\nTRAINING LOSS\n')
print(loss_dict, '\n\n')
print(loss)

self.log('train_loss', loss, batch_size=len(images))
return loss

Expand All @@ -130,29 +134,14 @@ def validation_step(self, batch: Any, batch_idx: int) -> None:
batch_size = images.shape[0]

outputs = self.model(images)
loss_dict = self.model(images, targets) # list of dictionaries
total_loss = sum(loss_item for loss_dict in loss_dict for loss_item in loss_dict.values() if loss_item.ndim == 0)

for target in targets:
target["masks"] = (target["masks"] > 0).to(torch.uint8)
target["boxes"] = target["boxes"].to(torch.float32)
target["labels"] = target["labels"].to(torch.int64)

# Compute the loss and predictions
loss_dict = self.model(images, targets) # list of dictionaries

print('\nDEBUG TRAINING LOSS\n')
print(f"Training loss: {loss_dict}")

# Post-process `loss_dict` to compute total loss
total_loss = 0.0
for loss in loss_dict:
if isinstance(loss, dict):
for key, value in loss.items():
# Ensure the loss component is a scalar tensor
if value.ndim == 0:
total_loss += value
else:
print(f"Skipping non-scalar loss: {key}, shape: {value.shape}")

# Post-process the outputs to ensure masks are in the correct format
for output in outputs:
if "masks" in output:
Expand All @@ -170,7 +159,7 @@ def validation_step(self, batch: Any, batch_idx: int) -> None:
value = value.to(torch.float32).mean()
scalar_metrics[key] = value

self.log_dict(scalar_metrics, batch_size=batch_size)
self.log_dict(scalar_metrics, batch_size=batch_size)

# check
if (
Expand Down Expand Up @@ -208,10 +197,8 @@ def test_step(self, batch: Any, batch_idx: int) -> None:
batch_size = images.shape[0]

outputs = self.model(images)

print('\nDEBUG THE PREDICTIONS\n')
print(f"Predictions for batch {batch_idx}: {outputs}")
print(f"Ground truth for batch {batch_idx}: {targets}")
loss_dict = self.model(images, targets) # Compute all losses
total_loss = sum(loss_item for loss_dict in loss_dict for loss_item in loss_dict.values() if loss_item.ndim == 0)

for target in targets:
target["masks"] = target["masks"].to(torch.uint8)
Expand All @@ -221,21 +208,6 @@ def test_step(self, batch: Any, batch_idx: int) -> None:
for output in outputs:
output["masks"] = (output["masks"] > 0.5).squeeze(1).to(torch.uint8)

loss_dict = self.model(images, targets) # Compute all losses

# Post-process `loss_dict` to compute total loss
total_loss = 0.0
for loss in loss_dict:
if isinstance(loss, dict):
for key, value in loss.items():
# Ensure the loss component is a scalar tensor
if value.ndim == 0:
total_loss += value
else:
print(f"Skipping non-scalar loss: {key}, shape: {value.shape}")


# Sum the losses
self.log('test_loss', total_loss, batch_size=batch_size)

metrics = self.val_metrics(outputs, targets)
Expand All @@ -249,10 +221,9 @@ def test_step(self, batch: Any, batch_idx: int) -> None:

self.log_dict(scalar_metrics, batch_size=batch_size)

print('\nDEBUG CAL METRICS\n')
print(f"Validation metrics: {metrics}")

return outputs
print('\nTESTING LOSS\n')
print(loss_dict, '\n\n')
print(total_loss)

def predict_step(self, batch: Any, batch_idx: int) -> Any:
"""Perform inference on a batch of images.
Expand All @@ -266,6 +237,9 @@ def predict_step(self, batch: Any, batch_idx: int) -> Any:
self.model.eval()
images = batch['image']
outputs = self.model(images)

for output in outputs:
output["masks"] = (output["masks"] > 0.5).to(torch.uint8)
return outputs


Expand Down

0 comments on commit 63aefc8

Please sign in to comment.