-
Notifications
You must be signed in to change notification settings - Fork 451
Trainers: add Instance Segmentation Task #2513
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+422
−35
Merged
Changes from all commits
Commits
Show all changes
58 commits
Select commit
Hold shift + click to select a range
d00c087
Add files via upload
ariannasole23 52daa1c
Add files via upload
ariannasole23 68756a7
Update instancesegmentation.py
ariannasole23 e249883
Merge branch 'microsoft:main' into main
ariannasole23 7676ac3
Update and rename instancesegmentation.py to instance_segmentation.py
ariannasole23 0fa7b07
Update test_instancesegmentation.py
ariannasole23 b4334f0
Update instance_segmentation.py
ariannasole23 a160baa
Update __init__.py
ariannasole23 fa8697b
Update instance_segmentation.py
ariannasole23 f6ceed1
Update instance_segmentation.py
ariannasole23 619760b
Add files via upload
ariannasole23 d9158a0
Update test_instancesegmentation.py
ariannasole23 9f48f50
Update and rename test_instancesegmentation.py to test_trainer_instan…
ariannasole23 63aefc8
Update instance_segmentation.py
ariannasole23 70074e7
Add files via upload
ariannasole23 b3de001
Creato con Colab
ariannasole23 d70f1e3
Creato con Colab
ariannasole23 1e68d2d
Creato con Colab
ariannasole23 98c836a
Merge branch 'microsoft:main' into main
ariannasole23 9664834
Update instance_segmentation.py
ariannasole23 f802574
Delete test_trainer.ipynb
ariannasole23 3c86306
Delete test_trainer_instancesegmentation.py
ariannasole23 7ec3930
Update and rename test_instancesegmentation.py to test_instance_segme…
ariannasole23 927f7fc
Update instance_segmentation.py
ariannasole23 4f1cecf
Update test_instance_segmentation.py
ariannasole23 21e0af2
Update instance_segmentation.py
ariannasole23 3956d23
Update instance_segmentation.py
ariannasole23 0e458a5
Update instance_segmentation.py run ruff
ariannasole23 870845b
Merge remote-tracking branch 'upstream/main'
adamjstewart fafb001
Ruff
adamjstewart ad7197d
dos2unix
adamjstewart 954e898
Add support for MSI, weights
adamjstewart 3c6ee68
Update tests
adamjstewart 7c4e30c
timm and torchvision are not compatible
adamjstewart 7c34d4a
Finalize trainer code, simpler
adamjstewart 649a877
Update VHR10 tests
adamjstewart 4f201fd
Uniformity
adamjstewart 006cfa9
Fix most tests
adamjstewart b3a4e44
100% coverage
adamjstewart 1d80adc
Fix datasets tests
adamjstewart d8e8fe6
Fix weight tests
adamjstewart f774875
Fix MSI support
adamjstewart c823fd0
Fix parameter replacement
adamjstewart 94e8001
Fix minimum tests
adamjstewart 5e01c96
Fix minimum tests
adamjstewart 2460b26
Add all unpacked data
adamjstewart d63cf85
Fix tests
adamjstewart f85a72e
Undo FTW changes
adamjstewart 683c162
Undo FTW changes
adamjstewart 8a9c0e9
Undo FTW changes
adamjstewart b072a38
Remove dead code
adamjstewart c4b5d17
Remove dead code, match detection style
adamjstewart 801c0ba
Try newer torchmetrics
adamjstewart 4640d6c
Try newer torchmetrics
adamjstewart 1d2a595
Try newer torchmetrics
adamjstewart 8f165ab
More metrics
adamjstewart 7b6182d
Fix mypy
adamjstewart 335f072
Fix and test weights=True, num_classes!=91
adamjstewart File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
model: | ||
class_path: InstanceSegmentationTask | ||
init_args: | ||
model: 'mask-rcnn' | ||
backbone: 'resnet50' | ||
num_classes: 11 | ||
data: | ||
class_path: VHR10DataModule | ||
init_args: | ||
batch_size: 1 | ||
num_workers: 0 | ||
patch_size: 4 | ||
dict_kwargs: | ||
root: 'tests/data/vhr10' |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import os | ||
from typing import Any | ||
|
||
import pytest | ||
from lightning.pytorch import Trainer | ||
from pytest import MonkeyPatch | ||
|
||
from torchgeo.datamodules import MisconfigurationException, VHR10DataModule | ||
from torchgeo.datasets import VHR10, RGBBandsMissingError | ||
from torchgeo.main import main | ||
from torchgeo.trainers import InstanceSegmentationTask | ||
|
||
# mAP metric requires pycocotools to be installed | ||
pytest.importorskip('pycocotools') | ||
|
||
|
||
class PredictInstanceSegmentationDataModule(VHR10DataModule): | ||
def setup(self, stage: str) -> None: | ||
self.predict_dataset = VHR10(**self.kwargs) | ||
|
||
|
||
def plot(*args: Any, **kwargs: Any) -> None: | ||
return None | ||
|
||
|
||
def plot_missing_bands(*args: Any, **kwargs: Any) -> None: | ||
raise RGBBandsMissingError() | ||
|
||
|
||
class TestInstanceSegmentationTask: | ||
@pytest.mark.parametrize('name', ['vhr10_ins_seg']) | ||
def test_trainer( | ||
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool | ||
) -> None: | ||
config = os.path.join('tests', 'conf', name + '.yaml') | ||
|
||
args = [ | ||
'--config', | ||
config, | ||
'--trainer.accelerator', | ||
'cpu', | ||
'--trainer.fast_dev_run', | ||
str(fast_dev_run), | ||
'--trainer.max_epochs', | ||
'1', | ||
'--trainer.log_every_n_steps', | ||
'1', | ||
] | ||
|
||
main(['fit', *args]) | ||
try: | ||
main(['test', *args]) | ||
except MisconfigurationException: | ||
pass | ||
try: | ||
main(['predict', *args]) | ||
except MisconfigurationException: | ||
pass | ||
|
||
def test_invalid_model(self) -> None: | ||
match = 'Invalid model type' | ||
with pytest.raises(ValueError, match=match): | ||
InstanceSegmentationTask(model='invalid_model') | ||
|
||
def test_invalid_backbone(self) -> None: | ||
match = 'Invalid backbone type' | ||
with pytest.raises(ValueError, match=match): | ||
InstanceSegmentationTask(backbone='invalid_backbone') | ||
|
||
def test_weights(self) -> None: | ||
InstanceSegmentationTask(weights=True, num_classes=3) | ||
InstanceSegmentationTask(weights=True, num_classes=91) | ||
|
||
def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: | ||
monkeypatch.setattr(VHR10DataModule, 'plot', plot) | ||
datamodule = VHR10DataModule( | ||
root='tests/data/vhr10', batch_size=1, num_workers=0 | ||
) | ||
model = InstanceSegmentationTask(in_channels=3, num_classes=11) | ||
trainer = Trainer( | ||
accelerator='cpu', | ||
fast_dev_run=fast_dev_run, | ||
log_every_n_steps=1, | ||
max_epochs=1, | ||
) | ||
trainer.validate(model=model, datamodule=datamodule) | ||
|
||
def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: | ||
monkeypatch.setattr(VHR10DataModule, 'plot', plot_missing_bands) | ||
datamodule = VHR10DataModule( | ||
root='tests/data/vhr10', batch_size=1, num_workers=0 | ||
) | ||
model = InstanceSegmentationTask(in_channels=3, num_classes=11) | ||
trainer = Trainer( | ||
accelerator='cpu', | ||
fast_dev_run=fast_dev_run, | ||
log_every_n_steps=1, | ||
max_epochs=1, | ||
) | ||
trainer.validate(model=model, datamodule=datamodule) | ||
|
||
def test_predict(self, fast_dev_run: bool) -> None: | ||
datamodule = PredictInstanceSegmentationDataModule( | ||
root='tests/data/vhr10', batch_size=1, num_workers=0 | ||
) | ||
model = InstanceSegmentationTask(num_classes=11) | ||
trainer = Trainer( | ||
accelerator='cpu', | ||
fast_dev_run=fast_dev_run, | ||
log_every_n_steps=1, | ||
max_epochs=1, | ||
) | ||
trainer.predict(model=model, datamodule=datamodule) | ||
|
||
def test_freeze_backbone(self) -> None: | ||
task = InstanceSegmentationTask(freeze_backbone=True) | ||
for param in task.model.backbone.parameters(): | ||
assert param.requires_grad is False | ||
|
||
for head in ['rpn', 'roi_heads']: | ||
for param in getattr(task.model, head).parameters(): | ||
assert param.requires_grad is True |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yet another test that downloads data on the fly (#1088), but let's fix that another day when we figure out how to support custom weights.