diff --git a/tests/models/test_dofa.py b/tests/models/test_dofa.py new file mode 100644 index 00000000000..5cceccdc52c --- /dev/null +++ b/tests/models/test_dofa.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from pathlib import Path +from typing import Any + +import pytest +import torch +import torchvision +from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch +from torchvision.models._api import WeightsEnum + +from torchgeo.models import ( + DOFABase16_Weights, + dofa_base_patch16_224, + dofa_huge_patch16_224, + dofa_large_patch16_224, + dofa_small_patch16_224, +) + + +def load(url: str, *args: Any, **kwargs: Any) -> dict[str, Any]: + state_dict: dict[str, Any] = torch.load(url) + return state_dict + + +class TestDOFASmall16: + def test_dofa(self) -> None: + dofa_small_patch16_224() + + +class TestDOFABase16: + @pytest.fixture(params=[*DOFABase16_Weights]) + def weights(self, request: SubRequest) -> WeightsEnum: + return request.param + + @pytest.fixture + def mocked_weights( + self, tmp_path: Path, monkeypatch: MonkeyPatch, weights: WeightsEnum + ) -> WeightsEnum: + path = tmp_path / f"{weights}.pth" + model = dofa_base_patch16_224() + torch.save(model.state_dict(), path) + try: + monkeypatch.setattr(weights.value, "url", str(path)) + except AttributeError: + monkeypatch.setattr(weights, "url", str(path)) + monkeypatch.setattr(torchvision.models._api, "load_state_dict_from_url", load) + return weights + + def test_dofa(self) -> None: + dofa_base_patch16_224() + + def test_dofa_weights(self, mocked_weights: WeightsEnum) -> None: + dofa_base_patch16_224(weights=mocked_weights) + + def test_transforms(self, mocked_weights: WeightsEnum) -> None: + c = 4 + sample = { + "image": torch.arange(c * 224 * 224, dtype=torch.float).view(c, 224, 224) + } + mocked_weights.transforms(sample) + + @pytest.mark.slow + def test_dofa_download(self, weights: WeightsEnum) -> None: + dofa_base_patch16_224(weights=weights) + + +class TestDOFALarge16: + def test_dofa(self) -> None: + dofa_large_patch16_224() + + +class TestDOFAHuge16: + def test_dofa(self) -> None: + dofa_huge_patch16_224() diff --git a/torchgeo/models/__init__.py b/torchgeo/models/__init__.py index d40eac47c6f..f77ec0f0f2f 100644 --- a/torchgeo/models/__init__.py +++ b/torchgeo/models/__init__.py @@ -5,6 +5,14 @@ from .api import get_model, get_model_weights, get_weight, list_models from .changestar import ChangeMixin, ChangeStar, ChangeStarFarSeg +from .dofa import ( + DOFABase16_Weights, + OFAViT, + dofa_base_patch16_224, + dofa_huge_patch16_224, + dofa_large_patch16_224, + dofa_small_patch16_224, +) from .farseg import FarSeg from .fcn import FCN from .fcsiam import FCSiamConc, FCSiamDiff diff --git a/torchgeo/models/dofa.py b/torchgeo/models/dofa.py index bbcb7d8e8ff..78377ad53cb 100644 --- a/torchgeo/models/dofa.py +++ b/torchgeo/models/dofa.py @@ -673,7 +673,7 @@ def dofa_large_patch16_224(*args: Any, **kwargs: Any) -> OFAViT: return model -def dofa_huge_patch14_224(*args: Any, **kwargs: Any) -> OFAViT: +def dofa_huge_patch16_224(*args: Any, **kwargs: Any) -> OFAViT: """Dynamic One-For-All (DOFA) huge patch size 16 model. If you use this model in your research, please cite the following paper: