isaaccorley
Collaborator

Adds the DFC2025 Track 1 Baseline U-Net weights trained by @cliffbb in the baseline repo. The model is a land-use/land-cover semantic segmentation model trained on the OpenEarthMap-SAR dataset, which contains high-resolution, 1-channel uint8 SAR images from Umbra Space, specifically the Geo-Ellipsoid Corrected (GEC) GeoTIFFs.

The model has an IoU of 0.41 on the private test set leaderboard.
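
For context on the metric, here is a minimal sketch of how a mean IoU score can be computed from a predicted mask and a reference mask. The thread does not say how the leaderboard averages over classes or treats ignored labels, so the averaging below and the `mean_iou` helper name are assumptions, not the challenge's scoring code.

```python
import numpy as np


def mean_iou(pred: np.ndarray, target: np.ndarray, num_classes: int) -> float:
    """Average per-class IoU over classes present in either mask (assumed averaging)."""
    ious = []
    for c in range(num_classes):
        pred_c = pred == c
        target_c = target == c
        union = np.logical_or(pred_c, target_c).sum()
        if union == 0:
            continue  # class absent from both masks, so it does not affect the mean
        intersection = np.logical_and(pred_c, target_c).sum()
        ious.append(intersection / union)
    return float(np.mean(ious)) if ious else float("nan")


# Toy 2x2 masks with two classes
pred = np.array([[0, 0], [1, 1]])
target = np.array([[0, 1], [1, 1]])
score = mean_iou(pred, target, num_classes=2)
```

In the toy masks, class 0 has IoU 1/2 and class 1 has IoU 2/3, so the mean is 7/12.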

Before @calebrob6 asks: I've checked that it runs, and the results are decent, but the model seems to get confused by the noise in the SAR imagery. Below is a minimal working example:

import random
import glob

import torch
import torchvision.transforms.v2.functional as F
import rasterio
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgb, ListedColormap
from torchgeo.models import Unet_Weights, unet

h = w = 1024
images = sorted(glob.glob("/data/train/sar_images/*.tif"))
labels = sorted(glob.glob("/data/train/labels/*.tif"))

# Fall back to CPU when no GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
weights = Unet_Weights.UMBRA_GEC_OPENEARTHMAP_SAR
classes = weights.meta["classes"]
colormap = ListedColormap([to_rgb(hex) for hex in weights.meta["colormap"]])

transforms = weights.transforms
model = unet(weights)
model.eval()
transforms.eval()
model = model.to(device).to(dtype)
transforms = transforms.to(device).to(dtype)

idx = random.randint(0, len(images) - 1)
path = images[idx]
label = labels[idx]
with rasterio.open(path) as src:
    image = torch.from_numpy(src.read())
    image = F.center_crop(image, (h, w))
    x = image.clone().unsqueeze(dim=0).to(device).to(dtype)
    image = image.repeat(3, 1, 1)

with rasterio.open(label) as src:
    mask = torch.from_numpy(src.read())
    mask = F.center_crop(mask, (h, w))
    mask = mask.squeeze(dim=0).numpy()

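# Apply the preprocessing bundled with the weights, then predict without tracking gradients
x = transforms(dict(image=x))["image"]
with torch.no_grad():
    preds = model(x).cpu().squeeze().argmax(dim=0)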

fig, axs = plt.subplots(1, 3, figsize=(12, 5))
axs[0].imshow(image.permute(1, 2, 0).numpy())
axs[1].imshow(mask, vmin=0, vmax=len(classes), cmap=colormap, interpolation='none')
axs[2].imshow(preds.numpy(), vmin=0, vmax=len(classes), cmap=colormap, interpolation='none')
for ax, title in zip(axs, ('Image', 'Mask', 'Predictions')):
    ax.set_title(title)
    ax.axis('off')  # hide ticks on all three panels
plt.show()

@isaaccorley isaaccorley added this to the 0.8.0 milestone Aug 1, 2025
@isaaccorley isaaccorley self-assigned this Aug 1, 2025
@github-actions github-actions bot added documentation Improvements or additions to documentation models Models and pretrained weights labels Aug 1, 2025
@isaaccorley isaaccorley force-pushed the models/dfc2025-oem-sar-baseline branch from f3273bb to 5adf3fa Compare August 1, 2025 18:50

@Copilot Copilot AI left a comment

Pull Request Overview

This PR adds DFC2025 Track 1 Baseline U-Net weights for land cover segmentation using SAR imagery from the OpenEarthMap-SAR dataset. The model achieves 0.41 IoU on the private test set and is trained on high-resolution Umbra Space SAR data.

  • Adds new UMBRA_GEC_OPENEARTHMAP_SAR weight configuration to the U-Net model
  • Updates documentation to rename "Sentinel-1" section to "Synthetic Aperture Radar (SAR)" for broader SAR model coverage
  • Includes model metadata with 9 land cover classes, colormap definitions, and encoder specifications

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| torchgeo/models/unet.py | Adds new SAR weight configuration with metadata and model kwargs support |
| docs/api/weights/sar.csv | Adds entry for the new U-Net SAR weights in documentation table |
| docs/api/models.rst | Updates section title from "Sentinel-1" to "Synthetic Aperture Radar (SAR)" |

@isaaccorley isaaccorley requested a review from calebrob6 August 1, 2025 18:51
@github-actions github-actions bot added the testing Continuous integration testing label Aug 1, 2025
@calebrob6
Collaborator

I see my job here is not needed

Comment on lines +178 to +179
Synthetic Aperture Radar (SAR)
------------------------------
Member

I disagree with this change. There are many different kinds of SAR (C-band vs. L-band) just like there are many different kinds of optical imagery (Sentinel-2, Landsat, etc.). Just because two models have the same number of bands doesn't mean they will be compatible with any SAR imagery.

Collaborator Author

That's okay with me, but you're not proposing a solution here. Do you prefer that I make a separate group for "Umbra Space" in this case, since this isn't Sentinel-1?

Member

Yes, separate group/table for Umbra
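
The point of this thread, that a matching channel count does not make weights compatible with a scene, can be illustrated with a small check. This is only a sketch: the `compatibility_problems` helper and the metadata keys (`sensor`, `wavelength_band`, `num_channels`) are hypothetical and are not part of torchgeo's weight metadata.

```python
def compatibility_problems(weights_meta: dict, scene_meta: dict) -> list[str]:
    """Return the metadata fields on which the weights and the scene disagree."""
    problems = []
    # Hypothetical keys: compare more than just the number of channels
    for key in ("sensor", "wavelength_band", "num_channels"):
        expected = weights_meta.get(key)
        actual = scene_meta.get(key)
        if expected != actual:
            problems.append(f"{key}: weights expect {expected!r}, scene has {actual!r}")
    return problems


# Two made-up SAR sources with the same channel count but different sensors and bands
weights_meta = {"sensor": "sensor_a", "wavelength_band": "C", "num_channels": 2}
scene_meta = {"sensor": "sensor_b", "wavelength_band": "L", "num_channels": 2}
problems = compatibility_problems(weights_meta, scene_meta)
```

Here the channel counts agree, yet the check still reports the sensor and wavelength band mismatches.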

@@ -2,6 +2,9 @@ Weight,Channels,Source,Citation,License
ResNet50_Weights.SENTINEL1_GRD_DECUR, 2,`link <https://github.com/zhu-xlab/DeCUR>`__,`link <https://arxiv.org/abs/2309.05300>`__,"Apache-2.0"
ResNet50_Weights.SENTINEL1_GRD_MOCO, 2,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0"
ResNet50_Weights.SENTINEL1_GRD_SOFTCON, 2,`link <https://github.com/zhu-xlab/softcon>`__,`link <https://arxiv.org/abs/2405.20462>`__,"CC-BY-4.0"
Swin_V2_B_Weights.SENTINEL1_MI_SATLAS,2,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY
Member

Let's move sorting to a separate PR to make this easier to review.

'num_classes': 9,
'model': 'U-Net',
'encoder': 'efficientnet-b4',
'classes': (
Member

This is a lot more metadata than other models. Is this coming from STAC MLM?

Collaborator Author

Most of our weights are foundation models with no class-specific outputs. I think it makes sense to add this here so the user has information about the output order and categories.
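
As a rough illustration of why the class order matters, the sketch below maps the integer indices in a prediction mask back to class names and reports the fraction of pixels per class. The names in `classes` are placeholders for illustration only; in practice the tuple would come from `weights.meta["classes"]`, and `class_fractions` is a hypothetical helper.

```python
import numpy as np

# Placeholder names; the real tuple and its order come from weights.meta["classes"]
classes = ("class_0", "class_1", "class_2")


def class_fractions(preds: np.ndarray, classes: tuple[str, ...]) -> dict[str, float]:
    """Map each class name to the fraction of pixels predicted as that class."""
    counts = np.bincount(preds.ravel(), minlength=len(classes))
    return {name: float(n) / preds.size for name, n in zip(classes, counts)}


# Toy 2x2 prediction mask holding class indices
preds = np.array([[0, 2], [2, 2]])
fractions = class_fractions(preds, classes)
```

Index `i` in the mask corresponds to `classes[i]`, which is only meaningful if the metadata records the same order the model was trained with.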

),
'publication': 'https://arxiv.org/abs/2501.10891',
'repo': 'https://github.com/cliffbb/DFC2025-OEM-SAR-Baseline',
'bands': ['B1'],
Member

@adamjstewart adamjstewart Aug 3, 2025

B1 doesn't sound right. Is it VV or VH? These are very different polarizations, and an incompatible choice could result in the poor performance you are seeing.

Collaborator Author

@isaaccorley isaaccorley Aug 3, 2025

It's a different processing level from Umbra, in which the data is converted to a grayscale uint8 image. Unfortunately, this doesn't use the raw band data.
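
Since the weights were trained on single-channel uint8 imagery, a quick sanity check on the input array before inference may help catch a mismatched product. This is a minimal sketch: the `input_problems` helper is hypothetical, and it assumes a channels-first `(C, H, W)` array such as the one returned by `rasterio`'s `src.read()`.

```python
import numpy as np


def input_problems(image: np.ndarray) -> list[str]:
    """List reasons why a (C, H, W) array does not look like 1-channel uint8 imagery."""
    problems = []
    if image.ndim != 3 or image.shape[0] != 1:
        problems.append(f"expected shape (1, H, W), got {image.shape}")
    if image.dtype != np.uint8:
        problems.append(f"expected dtype uint8, got {image.dtype}")
    return problems


# A valid single-channel uint8 tile and a two-channel float tile
good = np.zeros((1, 4, 4), dtype=np.uint8)
bad = np.zeros((2, 4, 4), dtype=np.float32)
good_problems = input_problems(good)
bad_problems = input_problems(bad)
```

The check only covers shape and dtype; it cannot tell whether the values came from the same processing level as the training data.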

@@ -2,6 +2,9 @@ Weight,Channels,Source,Citation,License
ResNet50_Weights.SENTINEL1_GRD_DECUR, 2,`link <https://github.com/zhu-xlab/DeCUR>`__,`link <https://arxiv.org/abs/2309.05300>`__,"Apache-2.0"
ResNet50_Weights.SENTINEL1_GRD_MOCO, 2,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0"
ResNet50_Weights.SENTINEL1_GRD_SOFTCON, 2,`link <https://github.com/zhu-xlab/softcon>`__,`link <https://arxiv.org/abs/2405.20462>`__,"CC-BY-4.0"
Swin_V2_B_Weights.SENTINEL1_MI_SATLAS,2,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY
Swin_V2_B_Weights.SENTINEL1_SI_SATLAS,2,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,ODC-BY
Unet_Weights.UMBRA_GEC_OPENEARTHMAP_SAR, 1,`link <https://github.com/cliffbb/DFC2025-OEM-SAR-Baseline/>`__,`link <https://arxiv.org/abs/2501.10891>`__,"CC-BY-4.0"
Member

Where can I find this license documented?

Collaborator Author

The dataset is licensed under CC-BY-4.0, so the model uses the same license, which is documented here: https://registry.opendata.aws/umbra-open-data/

Collaborator Author

@isaaccorley isaaccorley Aug 3, 2025

@cliffbb can you clarify here what license you are releasing the model weights under?
