
Commit

Merge branch 'main' into main
lccol authored Jan 22, 2025
2 parents 63bbbd4 + ada5122 commit 99829ad
Showing 33 changed files with 47 additions and 107 deletions.
2 changes: 0 additions & 2 deletions .github/dependabot.yml
@@ -27,8 +27,6 @@ updates:
  # https://github.com/pytorch/pytorch_sphinx_theme/issues/175
  - dependency-name: 'sphinx'
  versions: '>=6'
- # segmentation-models-pytorch pins timm, must update in unison
- - dependency-name: 'timm'
  - package-ecosystem: 'npm'
  directory: '/'
  schedule:
4 changes: 0 additions & 4 deletions .github/workflows/tests.yaml
@@ -34,8 +34,6 @@ jobs:
  path: ${{ env.pythonLocation }}
  key: ${{ runner.os }}-${{ runner.arch }}-Python-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('requirements/required.txt', 'requirements/datasets.txt', 'requirements/tests.txt') }}
  if: ${{ runner.os != 'macOS' }}
- - name: Setup headless display for pyvista
- uses: pyvista/setup-headless-display-action@v3
  - name: Install pip dependencies
  if: steps.cache.outputs.cache-hit != 'true'
  run: |
@@ -68,8 +66,6 @@ jobs:
  with:
  path: ${{ env.pythonLocation }}
  key: ${{ runner.os }}-${{ runner.arch }}-Python-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('requirements/min-reqs.old') }}
- - name: Setup headless display for pyvista
- uses: pyvista/setup-headless-display-action@v3
  - name: Install pip dependencies
  if: steps.cache.outputs.cache-hit != 'true'
  run: |
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
@@ -28,7 +28,6 @@ repos:
  - numpy>=1.22
  - pillow>=10.4.0
  - pytest>=6.1.2
- - pyvista>=0.34.2
  - scikit-image>=0.22.0
  - torch>=2.3
  - torchmetrics>=0.10
2 changes: 1 addition & 1 deletion docs/api/datasets/non_geo_datasets.csv
@@ -62,4 +62,4 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
  `VHR-10`_,I,"Google Earth, Vaihingen","MIT",800,10,"358--1,728",0.08--2,RGB
  `Western USA Live Fuel Moisture`_,R,"Landsat8, Sentinel-1","CC-BY-NC-ND-4.0",2615,-,-,-,-
  `xView2`_,CD,Maxar,"CC-BY-NC-SA-4.0","3,732",4,"1,024x1,024",0.8,RGB
- `ZueriCrop`_,"I, T",Sentinel-2,-,116K,48,24x24,10,MSI
+ `ZueriCrop`_,"I, T",Sentinel-2,CC-BY-NC-4.0,116K,48,24x24,10,MSI
1 change: 0 additions & 1 deletion docs/conf.py
@@ -115,7 +115,6 @@
  'numpy': ('https://numpy.org/doc/stable/', None),
  'python': ('https://docs.python.org/3', None),
  'lightning': ('https://lightning.ai/docs/pytorch/stable/', None),
- 'pyvista': ('https://docs.pyvista.org/version/stable/', None),
  'rasterio': ('https://rasterio.readthedocs.io/en/stable/', None),
  'rtree': ('https://rtree.readthedocs.io/en/stable/', None),
  'segmentation_models_pytorch': ('https://smp.readthedocs.io/en/stable/', None),
2 changes: 1 addition & 1 deletion experiments/ssl4eo/flops.py
@@ -17,7 +17,7 @@
  for model in models:
  print(f'Model: {model}')

- m = timm.create_model(model, num_classes=num_classes, in_chans=in_channels)
+ m = timm.create_model(model, num_classes=num_classes, in_chans=in_channels)  # type: ignore[attr-defined]

  # Calculate memory requirements of model
  mem_params = sum([p.nelement() * p.element_size() for p in m.parameters()])
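
The `# type: ignore[attr-defined]` suffix added here is the same one-line mypy suppression repeated across the test and model files later in this commit. A minimal sketch of the pattern, assuming the error stems from how the bumped timm release (0.9.7 → 1.0.14, see requirements/required.txt below) exposes `create_model` to static type checkers:

    import timm

    # mypy reports create_model as an attr-defined error on the timm module
    # under the new pin, so each call site silences only that error inline:
    model = timm.create_model('resnet18', in_chans=13)  # type: ignore[attr-defined]
    print(sum(p.numel() for p in model.parameters()))
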
2 changes: 0 additions & 2 deletions pyproject.toml
@@ -93,8 +93,6 @@ datasets = [
  "pandas[parquet]>=2",
  # pycocotools 2.0.7+ required for wheels
  "pycocotools>=2.0.7",
- # pyvista 0.34.2+ required to avoid ImportError in CI
- "pyvista>=0.34.2",
  # scikit-image 0.19+ required for Python 3.10 wheels
  "scikit-image>=0.19",
  # scipy 1.7.2+ required for Python 3.10 wheels
3 changes: 1 addition & 2 deletions requirements/datasets.txt
@@ -1,9 +1,8 @@
  # datasets
  h5py==3.12.1
  laspy==2.5.4
- opencv-python==4.10.0.84
+ opencv-python==4.11.0.86
  pandas[parquet]==2.2.3
  pycocotools==2.0.8
- pyvista==0.44.2
  scikit-image==0.25.0
  scipy==1.15.1
2 changes: 0 additions & 2 deletions requirements/min-reqs.old
@@ -27,10 +27,8 @@ laspy==2.0.0
  opencv-python==4.5.4.58
  pycocotools==2.0.7
  pyarrow==15.0.0 # Remove when we upgrade min version of pandas to `pandas[parquet]>=2`
- pyvista==0.34.2
  scikit-image==0.19.0
  scipy==1.7.2
- vtk==9.3.1 # PyVista is not yet compatible with VTK 9.4+

  # tests
  pytest==7.3.0
8 changes: 4 additions & 4 deletions requirements/required.txt
@@ -5,18 +5,18 @@ setuptools==75.8.0
  einops==0.8.0
  fiona==1.10.1
  kornia==0.8.0
- lightly==1.5.16
+ lightly==1.5.17
  lightning[pytorch-extra]==2.5.0.post0
  matplotlib==3.10.0
- numpy==2.2.1
+ numpy==2.2.2
  pandas==2.2.3
  pillow==11.1.0
  pyproj==3.7.0
  rasterio==1.4.3
  rtree==1.3.0
- segmentation-models-pytorch==0.3.4
+ segmentation-models-pytorch==0.4.0
  shapely==2.0.6
- timm==0.9.7
+ timm==1.0.14
  torch==2.5.1
  torchmetrics==1.6.1
  torchvision==0.20.1
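
The segmentation-models-pytorch and timm bumps land together, matching the dependabot note removed above ("segmentation-models-pytorch pins timm, must update in unison"). A quick, hedged sanity check that the two pins resolve against each other, assuming both are installed at exactly the versions listed:

    import segmentation_models_pytorch as smp
    import timm

    print(timm.__version__)  # expected: 1.0.14
    print(smp.__version__)   # expected: 0.4.0

    # smp can wrap timm backbones via the 'tu-' encoder prefix, one place where
    # the coupling between the two packages is exercised in practice.
    model = smp.Unet(encoder_name='tu-resnet18', in_channels=13, classes=4)
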
2 changes: 1 addition & 1 deletion requirements/style.txt
@@ -1,3 +1,3 @@
  # style
  mypy==1.14.1
- ruff==0.9.1
+ ruff==0.9.2
8 changes: 0 additions & 8 deletions tests/datasets/test_idtrees.py
@@ -95,11 +95,3 @@ def test_plot(self, dataset: IDTReeS) -> None:
  x['prediction_label'] = x['label']
  dataset.plot(x, show_titles=False)
  plt.close()
-
- def test_plot_las(self, dataset: IDTReeS) -> None:
- pyvista = pytest.importorskip('pyvista', minversion='0.34.2')
- pyvista.OFF_SCREEN = True
-
- # Test point cloud without colors
- point_cloud = dataset.plot_las(index=0)
- pyvista.plot(point_cloud, scalars=point_cloud.points, cpos='yz', cmap='viridis')
12 changes: 3 additions & 9 deletions tests/datasets/test_zuericrop.py
@@ -18,17 +18,11 @@
  class TestZueriCrop:
  @pytest.fixture
  def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> ZueriCrop:
- data_dir = os.path.join('tests', 'data', 'zuericrop')
- urls = [
- os.path.join(data_dir, 'ZueriCrop.hdf5'),
- os.path.join(data_dir, 'labels.csv'),
- ]
- md5s = ['1635231df67f3d25f4f1e62c98e221a4', '5118398c7a5bbc246f5f6bb35d8d529b']
- monkeypatch.setattr(ZueriCrop, 'urls', urls)
- monkeypatch.setattr(ZueriCrop, 'md5s', md5s)
+ url = os.path.join('tests', 'data', 'zuericrop') + os.sep
+ monkeypatch.setattr(ZueriCrop, 'url', url)
  root = tmp_path
  transforms = nn.Identity()
- return ZueriCrop(root=root, transforms=transforms, download=True, checksum=True)
+ return ZueriCrop(root=root, transforms=transforms, download=True)

  def test_getitem(self, dataset: ZueriCrop) -> None:
  x = dataset[0]
6 changes: 3 additions & 3 deletions tests/models/test_resnet.py
@@ -34,7 +34,7 @@ def mocked_weights(
  load_state_dict_from_url: None,
  ) -> WeightsEnum:
  path = tmp_path / f'{weights}.pth'
- model = timm.create_model('resnet18', in_chans=weights.meta['in_chans'])
+ model = timm.create_model('resnet18', in_chans=weights.meta['in_chans']) # type: ignore[attr-defined]
  torch.save(model.state_dict(), path)
  try:
  monkeypatch.setattr(weights.value, 'url', str(path))
@@ -78,7 +78,7 @@ def mocked_weights(
  load_state_dict_from_url: None,
  ) -> WeightsEnum:
  path = tmp_path / f'{weights}.pth'
- model = timm.create_model('resnet50', in_chans=weights.meta['in_chans'])
+ model = timm.create_model('resnet50', in_chans=weights.meta['in_chans']) # type: ignore[attr-defined]
  torch.save(model.state_dict(), path)
  try:
  monkeypatch.setattr(weights.value, 'url', str(path))
@@ -122,7 +122,7 @@ def mocked_weights(
  load_state_dict_from_url: None,
  ) -> WeightsEnum:
  path = tmp_path / f'{weights}.pth'
- model = timm.create_model('resnet152', in_chans=weights.meta['in_chans'])
+ model = timm.create_model('resnet152', in_chans=weights.meta['in_chans']) # type: ignore[attr-defined]
  torch.save(model.state_dict(), path)
  try:
  monkeypatch.setattr(weights.value, 'url', str(path))
2 changes: 1 addition & 1 deletion tests/models/test_vit.py
@@ -27,7 +27,7 @@ def mocked_weights(
  load_state_dict_from_url: None,
  ) -> WeightsEnum:
  path = tmp_path / f'{weights}.pth'
- model = timm.create_model(
+ model = timm.create_model( # type: ignore[attr-defined]
  weights.meta['model'], in_chans=weights.meta['in_chans']
  )
  torch.save(model.state_dict(), path)
2 changes: 1 addition & 1 deletion tests/trainers/test_byol.py
@@ -89,7 +89,7 @@ def mocked_weights(
  load_state_dict_from_url: None,
  ) -> WeightsEnum:
  path = tmp_path / f'{weights}.pth'
- model = timm.create_model(
+ model = timm.create_model( # type: ignore[attr-defined]
  weights.meta['model'], in_chans=weights.meta['in_chans']
  )
  torch.save(model.state_dict(), path)
2 changes: 1 addition & 1 deletion tests/trainers/test_classification.py
@@ -126,7 +126,7 @@ def mocked_weights(
  load_state_dict_from_url: None,
  ) -> WeightsEnum:
  path = tmp_path / f'{weights}.pth'
- model = timm.create_model(
+ model = timm.create_model( # type: ignore[attr-defined]
  weights.meta['model'], in_chans=weights.meta['in_chans']
  )
  torch.save(model.state_dict(), path)
2 changes: 1 addition & 1 deletion tests/trainers/test_moco.py
@@ -91,7 +91,7 @@ def mocked_weights(
  load_state_dict_from_url: None,
  ) -> WeightsEnum:
  path = tmp_path / f'{weights}.pth'
- model = timm.create_model(
+ model = timm.create_model( # type: ignore[attr-defined]
  weights.meta['model'], in_chans=weights.meta['in_chans']
  )
  torch.save(model.state_dict(), path)
4 changes: 2 additions & 2 deletions tests/trainers/test_regression.py
@@ -115,7 +115,7 @@ def mocked_weights(
  load_state_dict_from_url: None,
  ) -> WeightsEnum:
  path = tmp_path / f'{weights}.pth'
- model = timm.create_model(
+ model = timm.create_model( # type: ignore[attr-defined]
  weights.meta['model'], in_chans=weights.meta['in_chans']
  )
  torch.save(model.state_dict(), path)
@@ -273,7 +273,7 @@ def mocked_weights(
  load_state_dict_from_url: None,
  ) -> WeightsEnum:
  path = tmp_path / f'{weights}.pth'
- model = timm.create_model(
+ model = timm.create_model( # type: ignore[attr-defined]
  weights.meta['model'], in_chans=weights.meta['in_chans']
  )
  torch.save(model.state_dict(), path)
2 changes: 1 addition & 1 deletion tests/trainers/test_segmentation.py
@@ -139,7 +139,7 @@ def mocked_weights(
  load_state_dict_from_url: None,
  ) -> WeightsEnum:
  path = tmp_path / f'{weights}.pth'
- model = timm.create_model(
+ model = timm.create_model( # type: ignore[attr-defined]
  weights.meta['model'], in_chans=weights.meta['in_chans']
  )
  torch.save(model.state_dict(), path)
2 changes: 1 addition & 1 deletion tests/trainers/test_simclr.py
@@ -89,7 +89,7 @@ def mocked_weights(
  load_state_dict_from_url: None,
  ) -> WeightsEnum:
  path = tmp_path / f'{weights}.pth'
- model = timm.create_model(
+ model = timm.create_model( # type: ignore[attr-defined]
  weights.meta['model'], in_chans=weights.meta['in_chans']
  )
  torch.save(model.state_dict(), path)
2 changes: 1 addition & 1 deletion tests/trainers/test_utils.py
@@ -34,7 +34,7 @@ def test_extract_backbone_unsupported_model(tmp_path: Path) -> None:


  def test_get_input_layer_name_and_module() -> None:
- key, module = _get_input_layer_name_and_module(timm.create_model('resnet18'))
+ key, module = _get_input_layer_name_and_module(timm.create_model('resnet18')) # type: ignore[attr-defined]
  assert key == 'conv1'
  assert isinstance(module, nn.Conv2d)
  assert module.in_channels == 3
36 changes: 1 addition & 35 deletions torchgeo/datasets/idtrees.py
@@ -92,10 +92,9 @@ class IDTReeS(NonGeoDataset):
  * https://doi.org/10.1101/2021.08.06.453503

- This dataset requires the following additional libraries to be installed:
+ This dataset requires the following additional library to be installed:

  * `laspy <https://pypi.org/project/laspy/>`_ to read lidar point clouds
- * `pyvista <https://pypi.org/project/pyvista/>`_ to plot lidar point clouds

  .. versionadded:: 0.2
  """
@@ -552,36 +551,3 @@ def normalize(x: Tensor) -> Tensor:
  plt.suptitle(suptitle)

  return fig
-
- def plot_las(self, index: int) -> 'pyvista.Plotter': # type: ignore[name-defined] # noqa: F821
- """Plot a sample point cloud at the index.
-
- Args:
- index: index to plot
-
- Returns:
- pyvista.PolyData object. Run pyvista.plot(point_cloud, ...) to display
-
- Raises:
- DependencyNotFoundError: If laspy or pyvista are not installed.
-
- .. versionchanged:: 0.4
- Ported from Open3D to PyVista, *colormap* parameter removed.
- """
- laspy = lazy_import('laspy')
- pyvista = lazy_import('pyvista')
- path = self.images[index]
- path = path.replace('RGB', 'LAS').replace('.tif', '.las')
- las = laspy.read(path)
- points: np.typing.NDArray[np.int_] = np.stack(
- [las.x, las.y, las.z], axis=0
- ).transpose((1, 0))
- point_cloud = pyvista.PolyData(points)
-
- # Some point cloud files have no color->points mapping
- if hasattr(las, 'red'):
- colors = np.stack([las.red, las.green, las.blue], axis=0)
- colors = colors.transpose((1, 0)) / np.iinfo(np.uint16).max
- point_cloud['colors'] = colors
-
- return point_cloud
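
With `plot_las` removed along with the pyvista dependency, point-cloud loading is left to the user. A minimal laspy-only sketch of what the deleted helper did before handing the points to pyvista (the LAS path is hypothetical; choose any plotting backend you like):

    import laspy
    import numpy as np

    las = laspy.read('train/RemoteSensing/LAS/MLBS_1.las')  # hypothetical path
    # Stack x/y/z into an (N, 3) array, exactly as the removed helper did.
    points = np.stack([las.x, las.y, las.z], axis=0).transpose((1, 0))
    print(points.shape)  # (num_points, 3)
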
2 changes: 1 addition & 1 deletion torchgeo/datasets/satlas.py
@@ -447,7 +447,7 @@ class _Task(TypedDict, total=False):
  class SatlasPretrain(NonGeoDataset):
  """SatlasPretrain dataset.

- `SatlasPretrain <https://satlas-pretrain.allen.ai/>`_ is a large-scale pre-training
+ `SatlasPretrain <https://satlas-pretrain.allen.ai/>`__ is a large-scale pre-training
  dataset for tasks that involve understanding satellite images. Regularly-updated
  satellite data is publicly available for much of the Earth through sources such as
  Sentinel-2 and NAIP, and can inform numerous applications from tackling illegal
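
The only change above is the trailing underscore count: in reStructuredText, `` `text <url>`_ `` creates a named hyperlink target while `` `text <url>`__ `` creates an anonymous one. Switching to the anonymous form presumably avoids a duplicate-explicit-target warning when the same "SatlasPretrain" link text appears elsewhere in the docs.
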
9 changes: 3 additions & 6 deletions torchgeo/datasets/zuericrop.py
@@ -52,10 +52,7 @@ class ZueriCrop(NonGeoDataset):
  * `h5py <https://pypi.org/project/h5py/>`_ to load the dataset
  """

- urls = (
- 'https://polybox.ethz.ch/index.php/s/uXfdr2AcXE3QNB6/download',
- 'https://raw.githubusercontent.com/0zgur0/multi-stage-convSTAR-network/fa92b5b3cb77f5171c5c3be740cd6e6395cc29b6/labels.csv',
- )
+ url = 'https://hf.co/datasets/torchgeo/zuericrop/resolve/8ac0f416fbaab032d8670cc55f984b9f079e86b2/'
  md5s = ('1635231df67f3d25f4f1e62c98e221a4', '5118398c7a5bbc246f5f6bb35d8d529b')
  filenames = ('ZueriCrop.hdf5', 'labels.csv')

@@ -221,11 +218,11 @@ def _verify(self) -> None:

  def _download(self) -> None:
  """Download the dataset."""
- for url, filename, md5 in zip(self.urls, self.filenames, self.md5s):
+ for filename, md5 in zip(self.filenames, self.md5s):
  filepath = os.path.join(self.root, filename)
  if not os.path.exists(filepath):
  download_url(
- url,
+ self.url + filename,
  self.root,
  filename=filename,
  md5=md5 if self.checksum else None,
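
A short sketch of how the new single base URL composes with the per-file names, and why the updated test fixture above can simply point `url` at a local directory ending in `os.sep` (printed paths are illustrative):

    import os

    # Hosted case: plain string concatenation of base URL and filename.
    url = 'https://hf.co/datasets/torchgeo/zuericrop/resolve/8ac0f416fbaab032d8670cc55f984b9f079e86b2/'
    filenames = ('ZueriCrop.hdf5', 'labels.csv')
    print([url + f for f in filenames])

    # Test case (from tests/datasets/test_zuericrop.py above): the same
    # concatenation works for a local directory with a trailing separator.
    local_url = os.path.join('tests', 'data', 'zuericrop') + os.sep
    print([local_url + f for f in filenames])
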
6 changes: 3 additions & 3 deletions torchgeo/models/resnet.py
@@ -768,7 +768,7 @@ def resnet18(
  if weights:
  kwargs['in_chans'] = weights.meta['in_chans']

- model: ResNet = timm.create_model('resnet18', *args, **kwargs)
+ model: ResNet = timm.create_model('resnet18', *args, **kwargs) # type: ignore[attr-defined]

  if weights:
  missing_keys, unexpected_keys = model.load_state_dict(
@@ -803,7 +803,7 @@ def resnet50(
  if weights:
  kwargs['in_chans'] = weights.meta['in_chans']

- model: ResNet = timm.create_model('resnet50', *args, **kwargs)
+ model: ResNet = timm.create_model('resnet50', *args, **kwargs) # type: ignore[attr-defined]

  if weights:
  missing_keys, unexpected_keys = model.load_state_dict(
@@ -837,7 +837,7 @@ def resnet152(
  if weights:
  kwargs['in_chans'] = weights.meta['in_chans']

- model: ResNet = timm.create_model('resnet152', *args, **kwargs)
+ model: ResNet = timm.create_model('resnet152', *args, **kwargs) # type: ignore[attr-defined]

  if weights:
  missing_keys, unexpected_keys = model.load_state_dict(
(Diffs for the remaining 7 of the 33 changed files were not loaded in this view.)
