Skip to content

Commit

Permalink
CV4A Kenya Crop Type: radiant mlhub -> source cooperative (#2090)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart authored Jul 10, 2024
1 parent 83cad60 commit 32aa349
Show file tree
Hide file tree
Showing 21 changed files with 134 additions and 254 deletions.
5 changes: 5 additions & 0 deletions tests/data/cv4a_kenya_crop_type/FieldIds.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
train,test
1,2
3,4
5
6
51 changes: 51 additions & 0 deletions tests/data/cv4a_kenya_crop_type/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import numpy as np
from PIL import Image

# Element type of the synthetic per-band rasters written by this script.
DTYPE = np.float32
# Height and width, in pixels, of every generated raster.
SIZE = 2

# Seed NumPy's global RNG so that re-running the script is reproducible.
np.random.seed(0)

# Band names, in the order their files are generated; each name becomes
# part of the corresponding file name.
all_bands = (
    'B01', 'B02', 'B03', 'B04',
    'B05', 'B06', 'B07', 'B08',
    'B8A', 'B09', 'B11', 'B12',
    'CLD',
)

# Write a tiny synthetic dataset under ./data/<tile>/: one field-ID raster,
# one label raster, and one single-band raster per (date, band) pair in
# ./data/<tile>/<date>/.
for tile in range(1):
    directory = os.path.join('data', str(tile))
    os.makedirs(directory, exist_ok=True)

    # Field IDs: random int32 values in [0, int32 max).
    arr = np.random.randint(np.iinfo(np.int32).max, size=(SIZE, SIZE), dtype=np.int32)
    img = Image.fromarray(arr)
    img.save(os.path.join(directory, f'{tile}_field_id.tif'))

    # Labels: random uint8 values in [0, 255).
    arr = np.random.randint(np.iinfo(np.uint8).max, size=(SIZE, SIZE), dtype=np.uint8)
    img = Image.fromarray(arr)
    img.save(os.path.join(directory, f'{tile}_label.tif'))

    for date in ['20190606']:
        # Derive each date directory from the tile directory instead of
        # rebinding `directory`; otherwise a second date would be nested
        # inside the first date's directory.
        date_directory = os.path.join(directory, date)
        os.makedirs(date_directory, exist_ok=True)

        for band in all_bands:
            # Random floats in [0, 1) scaled up to the full range of DTYPE.
            arr = np.random.rand(SIZE, SIZE).astype(DTYPE) * np.finfo(DTYPE).max
            img = Image.fromarray(arr)
            img.save(os.path.join(date_directory, f'{tile}_{band}_{date}.tif'))
Binary file not shown.
Binary file added tests/data/cv4a_kenya_crop_type/data/0/0_label.tif
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
85 changes: 18 additions & 67 deletions tests/datasets/test_cv4a_kenya_crop_type.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import glob
import os
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
Expand All @@ -18,44 +16,23 @@
DatasetNotFoundError,
RGBBandsMissingError,
)


class Collection:
def download(self, output_dir: str, **kwargs: str) -> None:
glob_path = os.path.join(
'tests', 'data', 'ref_african_crops_kenya_02', '*.tar.gz'
)
for tarball in glob.iglob(glob_path):
shutil.copy(tarball, output_dir)


def fetch(dataset_id: str, **kwargs: str) -> Collection:
return Collection()
from torchgeo.datasets.utils import Executable


class TestCV4AKenyaCropType:
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CV4AKenyaCropType:
radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch)
source_md5 = '7f4dcb3f33743dddd73f453176308bfb'
labels_md5 = '95fc59f1d94a85ec00931d4d1280bec9'
monkeypatch.setitem(CV4AKenyaCropType.image_meta, 'md5', source_md5)
monkeypatch.setitem(CV4AKenyaCropType.target_meta, 'md5', labels_md5)
monkeypatch.setattr(
CV4AKenyaCropType, 'tile_names', ['ref_african_crops_kenya_02_tile_00']
)
def dataset(
self, azcopy: Executable, monkeypatch: MonkeyPatch, tmp_path: Path
) -> CV4AKenyaCropType:
url = os.path.join('tests', 'data', 'cv4a_kenya_crop_type')
monkeypatch.setattr(CV4AKenyaCropType, 'url', url)
monkeypatch.setattr(CV4AKenyaCropType, 'tiles', list(map(str, range(1))))
monkeypatch.setattr(CV4AKenyaCropType, 'dates', ['20190606'])
monkeypatch.setattr(CV4AKenyaCropType, 'tile_height', 2)
monkeypatch.setattr(CV4AKenyaCropType, 'tile_width', 2)
root = str(tmp_path)
transforms = nn.Identity()
return CV4AKenyaCropType(
root,
transforms=transforms,
download=True,
api_key='',
checksum=True,
verbose=True,
)
return CV4AKenyaCropType(root, transforms=transforms, download=True)

def test_getitem(self, dataset: CV4AKenyaCropType) -> None:
x = dataset[0]
Expand All @@ -66,60 +43,34 @@ def test_getitem(self, dataset: CV4AKenyaCropType) -> None:
assert isinstance(x['y'], torch.Tensor)

def test_len(self, dataset: CV4AKenyaCropType) -> None:
assert len(dataset) == 345
assert len(dataset) == 1

def test_add(self, dataset: CV4AKenyaCropType) -> None:
ds = dataset + dataset
assert isinstance(ds, ConcatDataset)
assert len(ds) == 690

def test_get_splits(self, dataset: CV4AKenyaCropType) -> None:
train_field_ids, test_field_ids = dataset.get_splits()
assert isinstance(train_field_ids, list)
assert isinstance(test_field_ids, list)
assert len(train_field_ids) == 18
assert len(test_field_ids) == 9
assert 336 in train_field_ids
assert 336 not in test_field_ids
assert 4793 in test_field_ids
assert 4793 not in train_field_ids
assert len(ds) == 2

def test_already_downloaded(self, dataset: CV4AKenyaCropType) -> None:
CV4AKenyaCropType(root=dataset.root, download=True, api_key='')
CV4AKenyaCropType(root=dataset.root, download=True)

def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
CV4AKenyaCropType(str(tmp_path))

def test_invalid_tile(self, dataset: CV4AKenyaCropType) -> None:
with pytest.raises(AssertionError):
dataset._load_label_tile('foo')

with pytest.raises(AssertionError):
dataset._load_all_image_tiles('foo', ('B01', 'B02'))

with pytest.raises(AssertionError):
dataset._load_single_image_tile('foo', '20190606', ('B01', 'B02'))

def test_invalid_bands(self) -> None:
with pytest.raises(AssertionError):
CV4AKenyaCropType(bands=['B01', 'B02']) # type: ignore[arg-type]

with pytest.raises(ValueError, match='is an invalid band name.'):
CV4AKenyaCropType(bands=('foo', 'bar'))

def test_plot(self, dataset: CV4AKenyaCropType) -> None:
dataset.plot(dataset[0], time_step=0, suptitle='Test')
plt.close()

sample = dataset[0]
dataset.plot(sample, time_step=0, suptitle='Test')
plt.close()
sample['prediction'] = sample['mask'].clone()
dataset.plot(sample, time_step=0, suptitle='Pred')
plt.close()

def test_plot_rgb(self, dataset: CV4AKenyaCropType) -> None:
dataset = CV4AKenyaCropType(root=dataset.root, bands=tuple(['B01']))
with pytest.raises(
RGBBandsMissingError, match='Dataset does not contain some of the RGB bands'
):
dataset.plot(dataset[0], time_step=0, suptitle='Single Band')
match = 'Dataset does not contain some of the RGB bands'
with pytest.raises(RGBBandsMissingError, match=match):
dataset.plot(dataset[0])
Loading

0 comments on commit 32aa349

Please sign in to comment.