diff --git a/README.md b/README.md index 350fac9418..be82acc678 100644 --- a/README.md +++ b/README.md @@ -177,7 +177,7 @@ bash scripts/download_data.sh The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively (note that the `imgs` and `masks` folder should not contain any sub-folder or any other files, due to the greedy data-loader). For Carvana, images are RGB and masks are black and white. -You can use your own dataset as long as you make sure it is loaded properly in `utils/data_loading.py`. +You can use your own dataset as long as you make sure it is loaded properly in `utils/data_loading.py`. For example, loading RGB-coded black and white masks will not work with the data-loading code as-is. --- diff --git a/utils/data_loading.py b/utils/data_loading.py index 24893b02db..1e18817681 100644 --- a/utils/data_loading.py +++ b/utils/data_loading.py @@ -33,7 +33,11 @@ def preprocess(pil_img, scale, is_mask): pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC) img_ndarray = np.asarray(pil_img) - if not is_mask: + if is_mask: + if img_ndarray.ndim > 2: + # customize this function if you want it to support RGB masks + raise RuntimeError("Only black-and-white images are supported as masks.") + else: if img_ndarray.ndim == 2: img_ndarray = img_ndarray[np.newaxis, ...] else: