diff --git a/torchgeo/datasets/mdas.py b/torchgeo/datasets/mdas.py index ff1bf672046..ea0bb30e292 100644 --- a/torchgeo/datasets/mdas.py +++ b/torchgeo/datasets/mdas.py @@ -11,7 +11,6 @@ import numpy as np import rasterio as rio import torch -from matplotlib.colormaps import get_cmap from matplotlib.colors import ListedColormap from matplotlib.figure import Figure from torch import Tensor @@ -356,7 +355,7 @@ def plot( axs[idx].imshow(img) case 'osm_landuse_mask': img = data.numpy().squeeze(0) - cmap = ListedColormap([get_cmap('tab20')(i) for i in range(20)]) + cmap = ListedColormap([plt.get_cmap('tab20')(i) for i in range(20)]) im = axs[idx].imshow(img, cmap=cmap) cbar = plt.colorbar(im, ax=axs[idx], ticks=range(19)) cbar.ax.set_yticklabels(