
Commit

Update and rename test_instancesegmentation.py to test_trainer_instancesegmentation.py

ariannasole23 authored Jan 28, 2025
1 parent d9158a0 commit 9f48f50
Showing 2 changed files with 123 additions and 123 deletions.
123 changes: 0 additions & 123 deletions test_instancesegmentation.py

This file was deleted.

123 changes: 123 additions & 0 deletions test_trainer_instancesegmentation.py
@@ -0,0 +1,123 @@
"""Manual test script for TorchGeo's InstanceSegmentationTask on the VHR-10 dataset."""

from typing import Any

import lightning.pytorch as pl
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from matplotlib.patches import Rectangle
from torch.utils.data import DataLoader, Subset
from torchgeo.datasets import VHR10
from torchgeo.trainers import InstanceSegmentationTask
from torchvision.transforms.functional import to_pil_image

def collate_fn(batch: list[dict[str, Any]]) -> dict[str, Any]:
    """Custom collate function for DataLoader."""
    max_height = max(sample['image'].shape[1] for sample in batch)
    max_width = max(sample['image'].shape[2] for sample in batch)

    # Pad every image to the largest height/width in the batch before stacking
    images = torch.stack([
        F.pad(
            sample['image'],
            (0, max_width - sample['image'].shape[2], 0, max_height - sample['image'].shape[1]),
        )
        for sample in batch
    ])

    # Pad the masks the same way and cast targets to the dtypes Mask R-CNN expects
    targets = [
        {
            'labels': sample['labels'].to(torch.int64),
            'boxes': sample['boxes'].to(torch.float32),
            'masks': F.pad(
                sample['masks'],
                (0, max_width - sample['masks'].shape[2], 0, max_height - sample['masks'].shape[1]),
            ).to(torch.uint8),
        }
        for sample in batch
    ]

    return {'image': images, 'target': targets}

def visualize_predictions(
    image: torch.Tensor,
    predictions: dict[str, torch.Tensor],
    targets: dict[str, Any],
) -> None:
    """Visualize predictions and ground truth."""
    image = to_pil_image(image)

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(image)

    # Predicted boxes and labels in red
    for box, label in zip(predictions['boxes'], predictions['labels']):
        x1, y1, x2, y2 = box
        rect = Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='red', facecolor='none')
        ax.add_patch(rect)
        ax.text(x1, y1, f'Pred: {label.item()}', color='red', fontsize=12)

    # Ground-truth boxes and labels in blue
    for box, label in zip(targets['boxes'], targets['labels']):
        x1, y1, x2, y2 = box
        rect = Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='blue', facecolor='none')
        ax.add_patch(rect)
        ax.text(x1, y1, f'GT: {label.item()}', color='blue', fontsize=12)

    plt.show()

def plot_losses(train_losses: list[float], val_losses: list[float]) -> None:
    """Plot training and validation losses over epochs."""
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss', marker='o')
    plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss', marker='s')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss Over Epochs')
    plt.legend()
    plt.grid()
    plt.show()

# Initialize VHR-10 dataset
train_dataset = VHR10(root="data", split="positive", transforms=None, download=True)
val_dataset = VHR10(root="data", split="positive", transforms=None)

# Subset for quick experimentation (adjust N as needed)
N = 100
train_subset = Subset(train_dataset, list(range(N)))
val_subset = Subset(val_dataset, list(range(N)))


if __name__ == '__main__':
    import multiprocessing

    multiprocessing.set_start_method('spawn', force=True)

    train_loader = DataLoader(train_subset, batch_size=8, shuffle=True, num_workers=1, collate_fn=collate_fn)
    val_loader = DataLoader(val_subset, batch_size=8, shuffle=False, num_workers=1, collate_fn=collate_fn)

    # Trainer setup
    trainer = pl.Trainer(
        max_epochs=5,
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=1,
    )

    task = InstanceSegmentationTask(
        model='mask_rcnn',
        backbone='resnet50',
        weights='imagenet',  # Pretrained on ImageNet
        num_classes=11,  # VHR-10 has 10 classes + 1 background
        lr=1e-3,
        freeze_backbone=False,
    )

    print('\nSTART TRAINING\n')
    # trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)
    train_losses, val_losses = [], []
    for epoch in range(5):
        trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)
        train_loss = task.trainer.callback_metrics.get('train_loss')
        val_loss = task.trainer.callback_metrics.get('val_loss')
        if train_loss is not None:
            train_losses.append(train_loss.item())
        if val_loss is not None:
            val_losses.append(val_loss.item())

    plot_losses(train_losses, val_losses)

    # trainer.test(task, dataloaders=val_loader)

    # Inference and visualization
    sample = train_dataset[1]
    image = sample['image'].unsqueeze(0)
    predictions = task.predict_step({'image': image}, batch_idx=0)
    visualize_predictions(image[0], predictions[0], sample)
