diff --git a/datasets/flwr_datasets/partitioner/image_semantic_partitioner.py b/datasets/flwr_datasets/partitioner/image_semantic_partitioner.py index 5632ce41b070..26ea0a7e750c 100644 --- a/datasets/flwr_datasets/partitioner/image_semantic_partitioner.py +++ b/datasets/flwr_datasets/partitioner/image_semantic_partitioner.py @@ -391,9 +391,7 @@ def _determine_partition_id_to_indices_if_needed(self) -> None: def _preprocess_dataset_images(self, indices: List[int]) -> NDArrayFloat: """Preprocess the images in the dataset.""" - images = np.array( - self.dataset[indices][self._image_column_name], dtype=np.float32 - ) + images = np.array(self.dataset[indices][self._image_column_name], dtype=float) if len(images.shape) == 3: # [B, H, W] images = np.reshape( images, (images.shape[0], 1, images.shape[1], images.shape[2])