diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index be91f41d433..8ce91fafc68 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -294,6 +294,10 @@ DL4GAM ^^^^^^ .. autoclass:: DL4GAMAlps +DOTA +^^^^ +.. autoclass:: DOTA + ETCI2021 Flood Detection ^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/api/datasets/non_geo_datasets.csv b/docs/api/datasets/non_geo_datasets.csv index e7b63a774fa..5d93b2af150 100644 --- a/docs/api/datasets/non_geo_datasets.csv +++ b/docs/api/datasets/non_geo_datasets.csv @@ -16,6 +16,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands `DIOR`_,OD,Aerial,"CC-BY-NC-4.0","23,463",20,"800x800",0.5,RGB `Digital Typhoon`_,"C, R",Himawari,"CC-BY-4.0","189,364",8,512,5000,Infrared `DL4GAM`_,S,"Sentinel-2","CC-BY-4.0","2,251 or 11,440","2","256x256","10","MSI" +`DOTA`_,OD,"Google Earth, Gaofen-2, Jilin-1, CycloMedia B.V.","non-commercial","5,229",15,"1000--5000",RGB `ETCI2021 Flood Detection`_,S,Sentinel-1,-,"66,810",2,256x256,5--20,SAR `EuroSAT`_,C,Sentinel-2,"MIT","27,000",10,64x64,10,MSI `EverWatch`_,OD,Aerial,"CC0-1.0","5,325",7,"1,500x1500p",0.01,RGB diff --git a/tests/data/dota/data.py b/tests/data/dota/data.py new file mode 100644 index 00000000000..b2b1e46f7ea --- /dev/null +++ b/tests/data/dota/data.py @@ -0,0 +1,159 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import hashlib +import os +import shutil +import tarfile +from pathlib import Path + +import numpy as np +import pandas as pd +from PIL import Image + + +def create_dummy_image(path: Path, size: tuple[int, int] = (64, 64)) -> None: + """Create small dummy image.""" + img = np.random.randint(0, 255, (*size, 3), dtype=np.uint8) + Image.fromarray(img).save(path) + + +def create_annotation_file( + path: Path, is_hbb: bool = False, no_boxes: bool = False +) -> None: + """Create dummy annotation file with scaled coordinates.""" + if is_hbb: + # Horizontal boxes scaled for 64x64 + boxes = [ + '10.0 10.0 20.0 10.0 20.0 20.0 10.0 20.0 plane 0\n', + '30.0 30.0 40.0 30.0 40.0 40.0 30.0 40.0 ship 0\n', + ] + else: + # Oriented boxes scaled for 64x64 + boxes = [ + '10.0 10.0 20.0 12.0 18.0 20.0 8.0 18.0 plane 0\n', + '30.0 30.0 42.0 32.0 40.0 40.0 28.0 38.0 ship 0\n', + ] + + if no_boxes: + boxes = [] + + with open(path, 'w') as f: + f.write('imagesource:dummy\n') + f.write('gsd:1.0\n') + f.writelines(boxes) + + +def create_test_data(root: Path) -> None: + """Create DOTA test dataset.""" + splits = ['train', 'val'] + versions = ['1.0', '1.5', '2.0'] + + # Create directory structure + for split in splits: + num_samples = 3 if split == 'train' else 2 + + if os.path.exists(root / split): + shutil.rmtree(root / split) + for version in versions: + # Create images and annotations + for i in range(num_samples): + img_name = f'P{i:04d}.png' + ann_name = f'P{i:04d}.txt' + + # Create directories + (root / split / 'images').mkdir(parents=True, exist_ok=True) + (root / split / 'annotations' / f'version{version}').mkdir( + parents=True, exist_ok=True + ) + + # Create files + if i == 0: + no_boxes = True + else: + no_boxes = False + create_dummy_image(root / split / 'images' / img_name) + create_annotation_file( + root / split / 'annotations' / f'version{version}' / ann_name, + False, + no_boxes, + ) + + # Create tar archives + for type_ in ['images', 'annotations']: + src_dir = root / split / type_ + if src_dir.exists(): + tar_name = f'dotav{version}_{type_}_{split}.tar.gz' + with tarfile.open(root / tar_name, 'w:gz') as tar: + tar.add(src_dir, arcname=f'{split}/{type_}') + + # print md5sums + def md5(fname: str) -> str: + hash_md5 = hashlib.md5() + with open(fname, 'rb') as f: + for chunk in iter(lambda: f.read(4096), b''): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + print('file_info = {') + for split in splits: + print(f" '{split}': {{") + + for type_ in ['images', 'annotations']: + print(f" '{type_}': {{") + + for version in versions: + tar_name = f'dotav{version}_{type_}_{split}.tar.gz' + checksum = md5(tar_name) + + # version 1.0 and 1.5 have the same images + if version == '1.5' and type_ == 'images': + version_filename = '1.0' + else: + version_filename = version + + print(f" '{version}': {{") + print( + f" 'filename': 'dotav{version_filename}_{type_}_{split}.tar.gz'," + ) + print(f" 'md5': '{checksum}',") + print(' },') + + print(' },') + + print(' },') + print('}') + + +def create_sample_df(root: Path) -> pd.DataFrame: + """Create sample DataFrame for test data.""" + rows = [] + splits = ['train', 'val'] + versions = ['1.0', '1.5', '2.0'] + + for split in splits: + num_samples = 3 if split == 'train' else 2 + for version in versions: + for i in range(num_samples): + img_name = f'P{i:04d}.png' + ann_name = f'P{i:04d}.txt' + + row = { + 'image_path': str(Path(split) / 'images' / img_name), + 'annotation_path': str( + Path(split) / 'annotations' / f'version{version}' / ann_name + ), + 'split': split, + 'version': version, + } + rows.append(row) + + df = pd.DataFrame(rows) + df.to_csv(root / 'samples.csv') + return df + + +if __name__ == '__main__': + root = Path('.') + create_test_data(root) + df = create_sample_df(root) diff --git a/tests/data/dota/dotav1.0_annotations_train.tar.gz b/tests/data/dota/dotav1.0_annotations_train.tar.gz new file mode 100644 index 00000000000..0d23ac13202 Binary files /dev/null and b/tests/data/dota/dotav1.0_annotations_train.tar.gz differ diff --git a/tests/data/dota/dotav1.0_annotations_val.tar.gz b/tests/data/dota/dotav1.0_annotations_val.tar.gz new file mode 100644 index 00000000000..d432b12872c Binary files /dev/null and b/tests/data/dota/dotav1.0_annotations_val.tar.gz differ diff --git a/tests/data/dota/dotav1.0_images_train.tar.gz b/tests/data/dota/dotav1.0_images_train.tar.gz new file mode 100644 index 00000000000..82c9ccf1c04 Binary files /dev/null and b/tests/data/dota/dotav1.0_images_train.tar.gz differ diff --git a/tests/data/dota/dotav1.0_images_val.tar.gz b/tests/data/dota/dotav1.0_images_val.tar.gz new file mode 100644 index 00000000000..81c52c64200 Binary files /dev/null and b/tests/data/dota/dotav1.0_images_val.tar.gz differ diff --git a/tests/data/dota/dotav1.5_annotations_train.tar.gz b/tests/data/dota/dotav1.5_annotations_train.tar.gz new file mode 100644 index 00000000000..a63f3d683fe Binary files /dev/null and b/tests/data/dota/dotav1.5_annotations_train.tar.gz differ diff --git a/tests/data/dota/dotav1.5_annotations_val.tar.gz b/tests/data/dota/dotav1.5_annotations_val.tar.gz new file mode 100644 index 00000000000..0acce6ea66e Binary files /dev/null and b/tests/data/dota/dotav1.5_annotations_val.tar.gz differ diff --git a/tests/data/dota/dotav1.5_images_train.tar.gz b/tests/data/dota/dotav1.5_images_train.tar.gz new file mode 100644 index 00000000000..21765e3c65b Binary files /dev/null and b/tests/data/dota/dotav1.5_images_train.tar.gz differ diff --git a/tests/data/dota/dotav1.5_images_val.tar.gz b/tests/data/dota/dotav1.5_images_val.tar.gz new file mode 100644 index 00000000000..ac7d73f0614 Binary files /dev/null and b/tests/data/dota/dotav1.5_images_val.tar.gz differ diff --git a/tests/data/dota/dotav2.0_annotations_train.tar.gz b/tests/data/dota/dotav2.0_annotations_train.tar.gz new file mode 100644 index 00000000000..945a7ced291 Binary files /dev/null and b/tests/data/dota/dotav2.0_annotations_train.tar.gz differ diff --git a/tests/data/dota/dotav2.0_annotations_val.tar.gz b/tests/data/dota/dotav2.0_annotations_val.tar.gz new file mode 100644 index 00000000000..d3f95593b0d Binary files /dev/null and b/tests/data/dota/dotav2.0_annotations_val.tar.gz differ diff --git a/tests/data/dota/dotav2.0_images_train.tar.gz b/tests/data/dota/dotav2.0_images_train.tar.gz new file mode 100644 index 00000000000..d3280139fd1 Binary files /dev/null and b/tests/data/dota/dotav2.0_images_train.tar.gz differ diff --git a/tests/data/dota/dotav2.0_images_val.tar.gz b/tests/data/dota/dotav2.0_images_val.tar.gz new file mode 100644 index 00000000000..0e0f875749a Binary files /dev/null and b/tests/data/dota/dotav2.0_images_val.tar.gz differ diff --git a/tests/data/dota/samples.csv b/tests/data/dota/samples.csv new file mode 100644 index 00000000000..5733c21796d --- /dev/null +++ b/tests/data/dota/samples.csv @@ -0,0 +1,16 @@ +,image_path,annotation_path,split,version +0,train/images/P0000.png,train/annotations/version1.0/P0000.txt,train,1.0 +1,train/images/P0001.png,train/annotations/version1.0/P0001.txt,train,1.0 +2,train/images/P0002.png,train/annotations/version1.0/P0002.txt,train,1.0 +3,train/images/P0000.png,train/annotations/version1.5/P0000.txt,train,1.5 +4,train/images/P0001.png,train/annotations/version1.5/P0001.txt,train,1.5 +5,train/images/P0002.png,train/annotations/version1.5/P0002.txt,train,1.5 +6,train/images/P0000.png,train/annotations/version2.0/P0000.txt,train,2.0 +7,train/images/P0001.png,train/annotations/version2.0/P0001.txt,train,2.0 +8,train/images/P0002.png,train/annotations/version2.0/P0002.txt,train,2.0 +9,val/images/P0000.png,val/annotations/version1.0/P0000.txt,val,1.0 +10,val/images/P0001.png,val/annotations/version1.0/P0001.txt,val,1.0 +11,val/images/P0000.png,val/annotations/version1.5/P0000.txt,val,1.5 +12,val/images/P0001.png,val/annotations/version1.5/P0001.txt,val,1.5 +13,val/images/P0000.png,val/annotations/version2.0/P0000.txt,val,2.0 +14,val/images/P0001.png,val/annotations/version2.0/P0001.txt,val,2.0 diff --git a/tests/data/dota/samples.parquet b/tests/data/dota/samples.parquet new file mode 100644 index 00000000000..b97ce0de9d6 Binary files /dev/null and b/tests/data/dota/samples.parquet differ diff --git a/tests/data/dota/train/annotations/version1.0/P0000.txt b/tests/data/dota/train/annotations/version1.0/P0000.txt new file mode 100644 index 00000000000..0e9123426ea --- /dev/null +++ b/tests/data/dota/train/annotations/version1.0/P0000.txt @@ -0,0 +1,2 @@ +imagesource:dummy +gsd:1.0 diff --git a/tests/data/dota/train/annotations/version1.0/P0001.txt b/tests/data/dota/train/annotations/version1.0/P0001.txt new file mode 100644 index 00000000000..20b1dc50fc6 --- /dev/null +++ b/tests/data/dota/train/annotations/version1.0/P0001.txt @@ -0,0 +1,4 @@ +imagesource:dummy +gsd:1.0 +10.0 10.0 20.0 12.0 18.0 20.0 8.0 18.0 plane 0 +30.0 30.0 42.0 32.0 40.0 40.0 28.0 38.0 ship 0 diff --git a/tests/data/dota/train/annotations/version1.0/P0002.txt b/tests/data/dota/train/annotations/version1.0/P0002.txt new file mode 100644 index 00000000000..20b1dc50fc6 --- /dev/null +++ b/tests/data/dota/train/annotations/version1.0/P0002.txt @@ -0,0 +1,4 @@ +imagesource:dummy +gsd:1.0 +10.0 10.0 20.0 12.0 18.0 20.0 8.0 18.0 plane 0 +30.0 30.0 42.0 32.0 40.0 40.0 28.0 38.0 ship 0 diff --git a/tests/data/dota/train/annotations/version1.5/P0000.txt b/tests/data/dota/train/annotations/version1.5/P0000.txt new file mode 100644 index 00000000000..0e9123426ea --- /dev/null +++ b/tests/data/dota/train/annotations/version1.5/P0000.txt @@ -0,0 +1,2 @@ +imagesource:dummy +gsd:1.0 diff --git a/tests/data/dota/train/annotations/version1.5/P0001.txt b/tests/data/dota/train/annotations/version1.5/P0001.txt new file mode 100644 index 00000000000..20b1dc50fc6 --- /dev/null +++ b/tests/data/dota/train/annotations/version1.5/P0001.txt @@ -0,0 +1,4 @@ +imagesource:dummy +gsd:1.0 +10.0 10.0 20.0 12.0 18.0 20.0 8.0 18.0 plane 0 +30.0 30.0 42.0 32.0 40.0 40.0 28.0 38.0 ship 0 diff --git a/tests/data/dota/train/annotations/version1.5/P0002.txt b/tests/data/dota/train/annotations/version1.5/P0002.txt new file mode 100644 index 00000000000..20b1dc50fc6 --- /dev/null +++ b/tests/data/dota/train/annotations/version1.5/P0002.txt @@ -0,0 +1,4 @@ +imagesource:dummy +gsd:1.0 +10.0 10.0 20.0 12.0 18.0 20.0 8.0 18.0 plane 0 +30.0 30.0 42.0 32.0 40.0 40.0 28.0 38.0 ship 0 diff --git a/tests/data/dota/train/annotations/version2.0/P0000.txt b/tests/data/dota/train/annotations/version2.0/P0000.txt new file mode 100644 index 00000000000..0e9123426ea --- /dev/null +++ b/tests/data/dota/train/annotations/version2.0/P0000.txt @@ -0,0 +1,2 @@ +imagesource:dummy +gsd:1.0 diff --git a/tests/data/dota/train/annotations/version2.0/P0001.txt b/tests/data/dota/train/annotations/version2.0/P0001.txt new file mode 100644 index 00000000000..20b1dc50fc6 --- /dev/null +++ b/tests/data/dota/train/annotations/version2.0/P0001.txt @@ -0,0 +1,4 @@ +imagesource:dummy +gsd:1.0 +10.0 10.0 20.0 12.0 18.0 20.0 8.0 18.0 plane 0 +30.0 30.0 42.0 32.0 40.0 40.0 28.0 38.0 ship 0 diff --git a/tests/data/dota/train/annotations/version2.0/P0002.txt b/tests/data/dota/train/annotations/version2.0/P0002.txt new file mode 100644 index 00000000000..20b1dc50fc6 --- /dev/null +++ b/tests/data/dota/train/annotations/version2.0/P0002.txt @@ -0,0 +1,4 @@ +imagesource:dummy +gsd:1.0 +10.0 10.0 20.0 12.0 18.0 20.0 8.0 18.0 plane 0 +30.0 30.0 42.0 32.0 40.0 40.0 28.0 38.0 ship 0 diff --git a/tests/data/dota/train/images/P0000.png b/tests/data/dota/train/images/P0000.png new file mode 100644 index 00000000000..a613e42d70a Binary files /dev/null and b/tests/data/dota/train/images/P0000.png differ diff --git a/tests/data/dota/train/images/P0001.png b/tests/data/dota/train/images/P0001.png new file mode 100644 index 00000000000..db50a37c4c1 Binary files /dev/null and b/tests/data/dota/train/images/P0001.png differ diff --git a/tests/data/dota/train/images/P0002.png b/tests/data/dota/train/images/P0002.png new file mode 100644 index 00000000000..b51d3a2dc1b Binary files /dev/null and b/tests/data/dota/train/images/P0002.png differ diff --git a/tests/data/dota/val/annotations/version1.0/P0000.txt b/tests/data/dota/val/annotations/version1.0/P0000.txt new file mode 100644 index 00000000000..0e9123426ea --- /dev/null +++ b/tests/data/dota/val/annotations/version1.0/P0000.txt @@ -0,0 +1,2 @@ +imagesource:dummy +gsd:1.0 diff --git a/tests/data/dota/val/annotations/version1.0/P0001.txt b/tests/data/dota/val/annotations/version1.0/P0001.txt new file mode 100644 index 00000000000..20b1dc50fc6 --- /dev/null +++ b/tests/data/dota/val/annotations/version1.0/P0001.txt @@ -0,0 +1,4 @@ +imagesource:dummy +gsd:1.0 +10.0 10.0 20.0 12.0 18.0 20.0 8.0 18.0 plane 0 +30.0 30.0 42.0 32.0 40.0 40.0 28.0 38.0 ship 0 diff --git a/tests/data/dota/val/annotations/version1.5/P0000.txt b/tests/data/dota/val/annotations/version1.5/P0000.txt new file mode 100644 index 00000000000..0e9123426ea --- /dev/null +++ b/tests/data/dota/val/annotations/version1.5/P0000.txt @@ -0,0 +1,2 @@ +imagesource:dummy +gsd:1.0 diff --git a/tests/data/dota/val/annotations/version1.5/P0001.txt b/tests/data/dota/val/annotations/version1.5/P0001.txt new file mode 100644 index 00000000000..20b1dc50fc6 --- /dev/null +++ b/tests/data/dota/val/annotations/version1.5/P0001.txt @@ -0,0 +1,4 @@ +imagesource:dummy +gsd:1.0 +10.0 10.0 20.0 12.0 18.0 20.0 8.0 18.0 plane 0 +30.0 30.0 42.0 32.0 40.0 40.0 28.0 38.0 ship 0 diff --git a/tests/data/dota/val/annotations/version2.0/P0000.txt b/tests/data/dota/val/annotations/version2.0/P0000.txt new file mode 100644 index 00000000000..0e9123426ea --- /dev/null +++ b/tests/data/dota/val/annotations/version2.0/P0000.txt @@ -0,0 +1,2 @@ +imagesource:dummy +gsd:1.0 diff --git a/tests/data/dota/val/annotations/version2.0/P0001.txt b/tests/data/dota/val/annotations/version2.0/P0001.txt new file mode 100644 index 00000000000..20b1dc50fc6 --- /dev/null +++ b/tests/data/dota/val/annotations/version2.0/P0001.txt @@ -0,0 +1,4 @@ +imagesource:dummy +gsd:1.0 +10.0 10.0 20.0 12.0 18.0 20.0 8.0 18.0 plane 0 +30.0 30.0 42.0 32.0 40.0 40.0 28.0 38.0 ship 0 diff --git a/tests/data/dota/val/images/P0000.png b/tests/data/dota/val/images/P0000.png new file mode 100644 index 00000000000..72023b4a34c Binary files /dev/null and b/tests/data/dota/val/images/P0000.png differ diff --git a/tests/data/dota/val/images/P0001.png b/tests/data/dota/val/images/P0001.png new file mode 100644 index 00000000000..2258575293f Binary files /dev/null and b/tests/data/dota/val/images/P0001.png differ diff --git a/tests/datasets/test_dota.py b/tests/datasets/test_dota.py new file mode 100644 index 00000000000..23a34188d34 --- /dev/null +++ b/tests/datasets/test_dota.py @@ -0,0 +1,173 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from itertools import product +from pathlib import Path + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest +from pytest import MonkeyPatch + +from torchgeo.datasets import DOTA, DatasetNotFoundError + + +class TestDOTA: + @pytest.fixture( + params=product( + ['train', 'val'], ['1.0', '1.5', '2.0'], ['horizontal', 'oriented'] + ) + ) + def dataset( + self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + ) -> DOTA: + url = os.path.join('tests', 'data', 'dota', '{}') + monkeypatch.setattr(DOTA, 'url', url) + + file_info = { + 'train': { + 'images': { + '1.0': { + 'filename': 'dotav1.0_images_train.tar.gz', + 'md5': '126d42cc8b2c093e7914528ac01ea8fc', + }, + '1.5': { + 'filename': 'dotav1.0_images_train.tar.gz', + 'md5': 'fd187ea8acc3d429f0ba9e5ef96def75', + }, + '2.0': { + 'filename': 'dotav2.0_images_train.tar.gz', + 'md5': '613d192b70dc53fe7e10f95eed0e1a9d', + }, + }, + 'annotations': { + '1.0': { + 'filename': 'dotav1.0_annotations_train.tar.gz', + 'md5': '1fbdb35e2d55cab2632a8c20ed54a6de', + }, + '1.5': { + 'filename': 'dotav1.5_annotations_train.tar.gz', + 'md5': '7a7ed5a309acb45dd1885f088fa24783', + }, + '2.0': { + 'filename': 'dotav2.0_annotations_train.tar.gz', + 'md5': 'f8cd1bf53362bd372ddc2fba97cff2b6', + }, + }, + }, + 'val': { + 'images': { + '1.0': { + 'filename': 'dotav1.0_images_val.tar.gz', + 'md5': 'f73dbdc8aa4e580dda4ef6cb54cfbd68', + }, + '1.5': { + 'filename': 'dotav1.0_images_val.tar.gz', + 'md5': 'b1c618180e0ca3e4426ecf53b82c8d74', + }, + '2.0': { + 'filename': 'dotav2.0_images_val.tar.gz', + 'md5': '0950df7a4c700934572f3a9a85133520', + }, + }, + 'annotations': { + '1.0': { + 'filename': 'dotav1.0_annotations_val.tar.gz', + 'md5': '700fd2e7cba8dd543ca5bcbe411c9db4', + }, + '1.5': { + 'filename': 'dotav1.5_annotations_val.tar.gz', + 'md5': 'f0a32911fa3614a8de67f5fd8d04dd9e', + }, + '2.0': { + 'filename': 'dotav2.0_annotations_val.tar.gz', + 'md5': '4823cdc2c35d5f74254ffab0d99ea876', + }, + }, + }, + } + monkeypatch.setattr(DOTA, 'file_info', file_info) + + root = tmp_path + split, version, bbox_orientation = request.param + + transforms = nn.Identity() + + return DOTA( + root, + split, + version=version, + bbox_orientation=bbox_orientation, + transforms=transforms, + download=True, + checksum=True, + ) + + def test_getitem(self, dataset: DOTA) -> None: + for i in range(len(dataset)): + x = dataset[i] + assert isinstance(x, dict) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['labels'], torch.Tensor) + if dataset.bbox_orientation == 'oriented': + bbox_key = 'bbox' + else: + bbox_key = 'bbox_xyxy' + assert isinstance(x[bbox_key], torch.Tensor) + + if dataset.bbox_orientation == 'oriented': + assert x[bbox_key].shape[1] == 8 + else: + assert x[bbox_key].shape[1] == 4 + + assert x['labels'].shape[0] == x[bbox_key].shape[0] + + def test_len(self, dataset: DOTA) -> None: + if dataset.split == 'train': + assert len(dataset) == 3 + else: + assert len(dataset) == 2 + + def test_already_downloaded(self, dataset: DOTA) -> None: + DOTA(root=dataset.root, download=True) + + def test_not_yet_extracted(self, tmp_path: Path) -> None: + files = [ + 'dotav1.0_images_train.tar.gz', + 'dotav1.0_annotations_train.tar.gz', + 'dotav1.5_annotations_train.tar.gz', + 'dotav1.5_annotations_val.tar.gz', + 'dotav1.0_images_val.tar.gz', + 'dotav1.0_annotations_val.tar.gz', + 'dotav2.0_images_train.tar.gz', + 'dotav2.0_annotations_train.tar.gz', + 'dotav2.0_images_val.tar.gz', + 'dotav2.0_annotations_val.tar.gz', + 'samples.csv', + ] + for path in files: + shutil.copyfile( + os.path.join('tests', 'data', 'dota', path), + os.path.join(str(tmp_path), path), + ) + + DOTA(root=tmp_path) + + def test_corrupted(self, tmp_path: Path) -> None: + with open(os.path.join(tmp_path, 'dotav1.0_images_train.tar.gz'), 'w') as f: + f.write('bad') + with pytest.raises(RuntimeError, match='Archive'): + DOTA(root=tmp_path, checksum=True) + + def test_not_downloaded(self, tmp_path: Path) -> None: + with pytest.raises(DatasetNotFoundError, match='Dataset not found'): + DOTA(tmp_path) + + def test_plot(self, dataset: DOTA) -> None: + x = dataset[1] + dataset.plot(x, suptitle='Test') + plt.close() diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index ebfd2e1d5c5..5a92ce04a36 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -39,6 +39,7 @@ from .digital_typhoon import DigitalTyphoon from .dior import DIOR from .dl4gam import DL4GAMAlps +from .dota import DOTA from .eddmaps import EDDMapS from .enmap import EnMAP from .enviroatlas import EnviroAtlas @@ -162,6 +163,7 @@ 'COWC', 'DFC2022', 'DIOR', + 'DOTA', 'ETCI2021', 'EUDEM', 'FAIR1M', diff --git a/torchgeo/datasets/dota.py b/torchgeo/datasets/dota.py new file mode 100644 index 00000000000..0be663a88dd --- /dev/null +++ b/torchgeo/datasets/dota.py @@ -0,0 +1,506 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""DOTA dataset.""" + +import os +from collections.abc import Callable +from typing import Any, ClassVar, Literal + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch +from matplotlib import patches +from matplotlib.figure import Figure +from PIL import Image +from torch import Tensor + +from .errors import DatasetNotFoundError +from .geo import NonGeoDataset +from .utils import ( + Path, + check_integrity, + download_url, + extract_archive, + percentile_normalization, +) + + +class DOTA(NonGeoDataset): + """DOTA dataset. + + `DOTA `__ is a large-scale object + detection dataset for aerial imagery containing RGB and gray-scale imagery + from Google Earth, GF-2 and JL-1 satellites as well as additional aerial imagery + from CycloMedia. There are three versions of the dataset: v1.0, v1.5, and v2.0, where, + v1.0 and v1.5 have the same images but different annotations, + and v2.0 extends both the images and annotations with more samples + + Dataset features: + + * 1869 samples in v1.0 and v1.5 and 2423 samples in v2.0 + * multi-class object detection (15 classes in v1.0 and v1.5 and 18 classes in v2.0) + * horizontal and oriented bounding boxes + + Dataset format: + + * images are three channel PNGs with various pixel sizes + * annotations are text files with one line per bounding box + + Classes: + + 0. plane + 1. ship + 2. storage-tank + 3. baseball-diamond + 4. tennis-court + 5. basketball-court + 6. ground-track-field + 7. harbor + 8. bridge + 9. large-vehicle + 10. small-vehicle + 11. helicopter + 12. roundabout + 13. soccer-ball-field + 14. swimming-pool + 15. container-crane (v1.5+) + 16. airport (v2.0+) + 17. helipad (v2.0+) + + If you use this work in your research, please cite the following papers: + + * https://arxiv.org/abs/2102.12219 + * https://arxiv.org/abs/1711.10398 + + .. versionadded:: 0.7 + """ + + url = 'https://huggingface.co/datasets/torchgeo/dota/resolve/672e63236622f7da6ee37fca44c50ac368b77cab/{}' + + file_info: ClassVar[dict[str, dict[str, dict[str, dict[str, str]]]]] = { + 'train': { + 'images': { + '1.0': { + 'filename': 'dotav1.0_images_train.tar.gz', + 'md5': '363b472dc3c71e7fa2f4a60223b437ea', + }, + '1.5': { + 'filename': 'dotav1.0_images_train.tar.gz', + 'md5': '363b472dc3c71e7fa2f4a60223b437ea', + }, + '2.0': { + 'filename': 'dotav2.0_images_train.tar.gz', + 'md5': '91ae5212d170330ab9f65ccb6c675763', + }, + }, + 'annotations': { + '1.0': { + 'filename': 'dotav1.0_annotations_train.tar.gz', + 'md5': 'f6788257bcc4d29018344a4128e3734a', + }, + '1.5': { + 'filename': 'dotav1.5_annotations_train.tar.gz', + 'md5': '0da97e5623a87d7bec22e75f6978dbce', + }, + '2.0': { + 'filename': 'dotav2.0_annotations_train.tar.gz', + 'md5': '04d3d626df2203053b7f06581b3b0667', + }, + }, + }, + 'val': { + 'images': { + '1.0': { + 'filename': 'dotav1.0_images_val.tar.gz', + 'md5': '42293219ba61d61c417ae558bbe1f2ba', + }, + '1.5': { + 'filename': 'dotav1.0_images_val.tar.gz', + 'md5': '42293219ba61d61c417ae558bbe1f2ba', + }, + '2.0': { + 'filename': 'dotav2.0_images_val.tar.gz', + 'md5': '737f65edf54b5aa627b3d48b0e253095', + }, + }, + 'annotations': { + '1.0': { + 'filename': 'dotav1.0_annotations_val.tar.gz', + 'md5': '28155c05b1dc3a0f5cb6b9bdfef85a13', + }, + '1.5': { + 'filename': 'dotav1.5_annotations_val.tar.gz', + 'md5': '85bf945788784cf9b4f1c714453178fc', + }, + '2.0': { + 'filename': 'dotav2.0_annotations_val.tar.gz', + 'md5': 'ec53c1dbcfc125d7532bd6a065c647ac', + }, + }, + }, + } + + sample_df_path = 'samples.csv' + + classes = ( + 'plane', + 'ship', + 'storage-tank', + 'baseball-diamond', + 'tennis-court', + 'basketball-court', + 'ground-track-field', + 'harbor', + 'bridge', + 'large-vehicle', + 'small-vehicle', + 'helicopter', + 'roundabout', + 'soccer-ball-field', + 'swimming-pool', + 'container-crane', + 'airport', + 'helipad', + ) + + valid_splits = ('train', 'val') + valid_versions = ('1.0', '1.5', '2.0') + + valid_orientations = ('horizontal', 'oriented') + + def __init__( + self, + root: Path = 'data', + split: Literal['train', 'val'] = 'train', + version: Literal['1.0', '1.5', '2.0'] = '2.0', + bbox_orientation: Literal['horizontal', 'oriented'] = 'oriented', + transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new DOTA dataset instance. + + Args: + root: root directory where dataset can be found + split: split of the dataset to use, one of ['train', 'val'] + version: version of the dataset to use, one of ['1.0', '1.5', '2.0'] + bbox_orientation: bounding box orientation, one of ['horizontal', 'oriented'], where horizontal + returnx xyxy format and oriented returns x1y1x2y2x3y3x4y4 format + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + AssertionError: if *split*, *version*, or *bbox_orientation* argument are not valid + DatasetNotFoundError: if dataset is not found or corrupted, and *download* is False + """ + assert split in self.valid_splits, ( + f"Split '{split}' not supported, use one of {self.valid_splits}" + ) + assert version in self.valid_versions, ( + f"Version '{version}' not supported, use one of {self.valid_versions}" + ) + + assert bbox_orientation in self.valid_orientations, ( + f'Bounding box orientation must be one of {self.valid_orientations}' + ) + + self.root = root + self.split = split + self.version = version + self.transforms = transforms + self.download = download + self.checksum = checksum + self.bbox_orientation = bbox_orientation + + self._verify() + + self.sample_df = pd.read_csv(os.path.join(self.root, 'samples.csv')) + self.sample_df['version'] = self.sample_df['version'].astype(str) + self.sample_df = self.sample_df[self.sample_df['split'] == self.split] + self.sample_df = self.sample_df[ + self.sample_df['version'] == self.version + ].reset_index(drop=True) + + def __len__(self) -> int: + """Return the number of samples in the dataset. + + Returns: + length of the dataset + """ + return len(self.sample_df) + + def __getitem__(self, index: int) -> dict[str, Any]: + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + data and label at that index + """ + sample_row = self.sample_df.iloc[index] + + sample = {'image': self._load_image(sample_row['image_path'])} + + boxes, labels = self._load_annotations(sample_row['annotation_path']) + + if self.bbox_orientation == 'horizontal': + sample['bbox_xyxy'] = boxes + else: + sample['bbox'] = boxes + sample['labels'] = labels + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def _load_image(self, path: str) -> Tensor: + """Load image. + + Args: + path: path to image file + + Returns: + image: image tensor + """ + image = Image.open(os.path.join(self.root, path)).convert('RGB') + return torch.from_numpy(np.array(image).transpose(2, 0, 1)).float() + + def _load_annotations(self, path: str) -> tuple[Tensor, Tensor]: + """Load DOTA annotations from text file. + + Format: + x1 y1 x2 y2 x3 y3 x4 y4 class difficult + + Some files have 2 header lines that need to be skipped: + imagesource:GoogleEarth + gsd:0.146343590398 + + Args: + path: path to annotation file + + Returns: + tuple of: + boxes: tensor of shape (N, 8) with coordinates for oriented + and (N, 4) for horizontal + labels: tensor of shape (N,) with class indices + """ + with open(os.path.join(self.root, path)) as f: + lines = f.readlines() + + # Skip header if present + start_idx = 0 + if lines and lines[0].startswith('imagesource'): + start_idx = 2 + boxes = [] + labels = [] + + for line in lines[start_idx:]: + parts = line.strip().split(' ') + + # Always read 8 coordinates + coords = [float(p) for p in parts[:8]] + label = parts[8] + + labels.append(self.classes.index(label)) + + if self.bbox_orientation == 'horizontal': + # Convert to [xmin, ymin, xmax, ymax] format + x_coords = coords[::2] # even indices (0,2,4,6) + y_coords = coords[1::2] # odd indices (1,3,5,7) + xmin, xmax = min(x_coords), max(x_coords) + ymin, ymax = min(y_coords), max(y_coords) + boxes.append([xmin, ymin, xmax, ymax]) + else: + boxes.append(coords) + + if not boxes: + return ( + torch.zeros((0, 4 if self.bbox_orientation == 'horizontal' else 8)), + torch.zeros(0, dtype=torch.long), + ) + else: + return torch.tensor(boxes), torch.tensor(labels) + + def _verify(self) -> None: + """Verify dataset integrity and download/extract if needed.""" + # check if directories and sample file are present + required_dirs = [ + os.path.join(self.root, self.split, 'images'), + os.path.join( + self.root, self.split, 'annotations', f'version{self.version}' + ), + os.path.join(self.root, self.sample_df_path), + ] + if all(os.path.exists(d) for d in required_dirs): + return + + # Check for compressed files, v1.0 and v1.5 have the same images but different annotations + files_needed = [ + ( + self.file_info[self.split]['images'][self.version]['filename'], + self.file_info[self.split]['images'][self.version]['md5'], + ), + ( + self.file_info[self.split]['annotations'][self.version]['filename'], + self.file_info[self.split]['annotations'][self.version]['md5'], + ), + ] + # For v2.0, also need v1.0 image files, but only v2 annotations + if self.version == '2.0': + files_needed.append( + ( + self.file_info[self.split]['images']['1.0']['filename'], + self.file_info[self.split]['images']['1.0']['md5'], + ) + ) + + # Check if archives exist and verify checksums if requested + exists = [] + for filename, md5 in files_needed: + filepath = os.path.join(self.root, filename) + if os.path.exists(filepath): + if self.checksum: + if not check_integrity(filepath, md5): + raise RuntimeError(f'Archive {filename} corrupted') + exists.append(True) + self._extract([(filename, md5)]) + else: + exists.append(False) + + if all(exists): + return + + if not self.download: + raise DatasetNotFoundError(self) + + # also download the metadata file + self._download(files_needed) + self._extract(files_needed) + + def _download(self, files_needed: list[tuple[str, str]]) -> None: + """Download the dataset. + + Args: + files_needed: list of files to download for the particular version + """ + for filename, md5 in files_needed: + if not os.path.exists(os.path.join(self.root, filename)): + download_url( + url=self.url.format(filename), + root=self.root, + filename=filename, + md5=None if not self.checksum else md5, + ) + + if not os.path.exists(os.path.join(self.root, self.sample_df_path)): + download_url( + url=self.url.format(self.sample_df_path), + root=self.root, + filename=self.sample_df_path, + ) + + def _extract(self, files_needed: list[tuple[str, str]]) -> None: + """Extract the dataset. + + Args: + files_needed: list of files to extract for the particular version + """ + for filename, _ in files_needed: + filepath = os.path.join(self.root, filename) + extract_archive(filepath, self.root) + + def plot( + self, + sample: dict[str, Tensor], + show_titles: bool = True, + suptitle: str | None = None, + box_alpha: float = 0.7, + ) -> Figure: + """Plot a sample from the dataset. + + Args: + sample: a sample returned by __getitem__ + show_titles: flag indicating whether to show titles + suptitle: optional string to use as a suptitle + box_alpha: alpha value for boxes + + Returns: + a matplotlib Figure with the rendered sample + """ + image = percentile_normalization(sample['image'].permute(1, 2, 0).numpy()) + if self.bbox_orientation == 'horizontal': + boxes = sample['bbox_xyxy'].cpu().numpy() + else: + boxes = sample['bbox'].cpu().numpy() + labels = sample['labels'].cpu().numpy() + + fig, ax = plt.subplots(figsize=(10, 10)) + ax.imshow(image) + ax.axis('off') + + # Create color map for classes + cm = plt.get_cmap('gist_rainbow') + + for box, label_idx in zip(boxes, labels): + color = cm(label_idx / len(self.classes)) + label = self.classes[label_idx] + + if self.bbox_orientation == 'horizontal': + # Horizontal box: [xmin, ymin, xmax, ymax] + x1, y1, x2, y2 = box + rect = patches.Rectangle( + (x1, y1), + x2 - x1, + y2 - y1, + linewidth=2, + alpha=box_alpha, + linestyle='solid', + edgecolor=color, + facecolor='none', + ) + ax.add_patch(rect) + # Add label above box + ax.text( + x1, + y1 - 5, + label, + color='white', + fontsize=8, + bbox=dict(facecolor=color, alpha=box_alpha), + ) + else: + # Oriented box: [x1,y1,x2,y2,x3,y3,x4,y4] + vertices = box.reshape(4, 2) + polygon = patches.Polygon( + vertices, + linewidth=2, + alpha=box_alpha, + linestyle='solid', + edgecolor=color, + facecolor='none', + ) + ax.add_patch(polygon) + # Add label at centroid + centroid_x = vertices[:, 0].mean() + centroid_y = vertices[:, 1].mean() + ax.text( + centroid_x, + centroid_y, + label, + color='white', + fontsize=8, + bbox=dict(facecolor=color, alpha=box_alpha), + ha='center', + va='center', + ) + + if suptitle is not None: + plt.suptitle(suptitle) + + return fig