From 898a370dfc6286d75491abf436d4ad3bd759c3e6 Mon Sep 17 00:00:00 2001
From: Adam Narozniak
Date: Mon, 15 Jan 2024 13:47:55 +0100
Subject: [PATCH 01/16] Add shard partitioner

---
 .../partitioner/shard_partitioner.py          | 249 ++++++++++++++++++
 1 file changed, 249 insertions(+)
 create mode 100644 datasets/flwr_datasets/partitioner/shard_partitioner.py

diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner.py b/datasets/flwr_datasets/partitioner/shard_partitioner.py
new file mode 100644
index 000000000000..039bcd844575
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/shard_partitioner.py
@@ -0,0 +1,249 @@
+# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Shard partitioner class."""
+# pylint: disable=R0912
+import math
+from typing import Dict, List, Optional
+
+import numpy as np
+
+import datasets
+from flwr_datasets.partitioner.partitioner import Partitioner
+
+
+class ShardPartitioner(Partitioner):  # pylint: disable=R0902
+    """Partitioner based on shards of (typically) unique classes.
+
+    The algorithm works as follows: the dataset is sorted by label e.g. [samples with
+    label 1, samples with label 2 ...], then the shards are created, with each
+    shard of size - shard_size if provided or automatically calculated:
+    shard_size = len(dataset) / (num_partitions * num_shards_per_node).
+    Each partition is created from `num_shards_per_node` shards that are chosen randomly.
+
+    There are a few ways of partitioning data that result in certain properties
+    (depending on the parameters specification):
+    1) same number of shards per node + the same shard size (specify:
+    a) num_shards_per_node, shard_size or b) num_shards_per_node)
+    In case of b the shard_size is calculated as floor(len(dataset) /
+    (num_shards_per_node * num_partitions))
+    2) possibly different number of shards per node (use nearly all data) + the same
+    shard size (specify: shard_size + keep_incomplete_shard=False)
+    3) possibly different number of shards per node (use all data) + possibly different
+    shard size (specify: shard_size + keep_incomplete_shard=True)
+
+
+    Algorithm based on the description in Communication-Efficient Learning of Deep
+    Networks from Decentralized Data https://arxiv.org/abs/1602.05629. This
+    implementation expands on the initial idea by enabling more hyperparameters
+    specification therefore providing more control of the partitions. It enables the
+    division obtained in the original paper.
+
+    Parameters
+    ----------
+    num_partitions : int
+        The total number of partitions that the data will be divided into.
+    partition_by : str
+        Column name of the labels (targets) based on which the sorting is performed.
+    num_shards_per_node : Optional[int]
+        Number of shards to assign to a single partition. It's an alternative to
+        num_partitions.
+    shard_size : Optional[int]
+        Size of a single shard (a partition has one or more shards).
If the size is not + given it will be automatically computed such that. + keep_incomplete_shard : bool + Weather to drop the last shard which might be incomplete (smaller than the + others). If it is dropped each shard is equal size. (It does not mean that each + client gets equal number of shards, which only happens if + num_partitions % num_shards = 0). This parameter has no effect if + num_shards_per_nodes and shard_size are specified. + shuffle: bool + Whether to randomize the order of samples. Shuffling applied after the + samples assignment to nodes. + seed: int + Seed used for dataset shuffling. It has no effect if `shuffle` is False. + """ + + def __init__( # pylint: disable=R0913 + self, + num_partitions: int, + partition_by: str, + num_shards_per_node: Optional[int], + keep_incomplete_shard: bool, + shard_size: Optional[int] = None, + shuffle: bool = True, + seed: Optional[int] = 42, + ) -> None: + super().__init__() + # Attributes based on the constructor + self._num_partitions = num_partitions + self._check_num_partitions_greater_than_zero() + self._partition_by = partition_by + self._num_shards_per_node = num_shards_per_node + self._total_num_shards: Optional[int] = None + self._shard_size = shard_size + self._keep_incomplete_shard = keep_incomplete_shard + self._shuffle = shuffle + self._seed = seed + + # Utility attributes + self._rng = np.random.default_rng(seed=self._seed) # NumPy random generator + self._node_id_to_indices: Dict[int, List[int]] = {} + self._node_id_to_indices_determined = False + + def load_partition(self, node_id: int) -> datasets.Dataset: + """Load a partition based on the partition index. + + Parameters + ---------- + node_id : int + the index that corresponds to the requested partition + + Returns + ------- + dataset_partition : Dataset + single partition of a dataset + """ + # The partitioning is done lazily - only when the first partition is + # requested. Only the first call creates the indices assignments for all the + # partition indices. + self._check_num_partitions_correctness_if_needed() + self._check_possibility_of_partitions_creation() + self._sort_dataset_if_needed() + self._determine_node_id_to_indices_if_needed() + return self.dataset.select(self._node_id_to_indices[node_id]) + + def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0914 + """Assign sample indices to each node id. + + This method works on sorted datasets. "Shard" is a part of the dataset of the + consecutive samples (if self._keep_incomplete_shard is False, each shard is same + size). + """ + if self._node_id_to_indices_determined: + return + if self._num_shards_per_node is not None: + self._total_num_shards = int( + self._num_partitions * self._num_shards_per_node + ) + num_shards_per_node_array = ( + np.ones(self._num_partitions) * self._num_shards_per_node + ) + self._compute_shard_size_if_missing() + elif self._num_shards_per_node is None: + if self._shard_size is None: + raise ValueError( + "The shard_size needs to be specified if the " + "num_shards_per_node is None" + ) + if self._keep_incomplete_shard is False: + self._total_num_shards = int( + math.floor(len(self.dataset) / self._shard_size) + ) + elif self._keep_incomplete_shard is True: + self._total_num_shards = int( + math.ceil(len(self.dataset) / self._shard_size) + ) + else: + raise ValueError( + "The keep_incomplete_shards need to be specified " + "when _num_shards_per_node is None." 
+ ) + num_shards_per_node = int(self._total_num_shards / self._num_partitions) + # Assign the shards per nodes (so far, the same as in ideal case) + num_shards_per_node_array = ( + np.ones(self._num_partitions) * num_shards_per_node + ) + num_shards_assigned = self._num_partitions * num_shards_per_node + num_shards_to_assign = self._total_num_shards - num_shards_assigned + # Assign the "missing" shards + for i in range(num_shards_to_assign): + num_shards_per_node_array[i] += 1 + + else: + raise ValueError( + "The specification of nm_shards_per_node and " + "keep_incomplete_shards is not correct." + ) + + indices_on_which_to_split_shards = np.cumsum( + num_shards_per_node_array, dtype=int + ) + shard_indices_array = np.random.permutation(self._total_num_shards) + # Randomly assign shards to node_id + nid_to_shard_indices = np.split( + shard_indices_array, indices_on_which_to_split_shards + )[:-1] + node_id_to_indices: Dict[int, List[int]] = { + cid: [] for cid in range(self._num_partitions) + } + # Compute node_id to sample indices based on the shard indices + for node_id in range(self._num_partitions): + for shard_idx in nid_to_shard_indices[node_id]: + start_id = int(shard_idx * self._shard_size) + end_id = min(int((shard_idx + 1) * self._shard_size), len(self.dataset)) + node_id_to_indices[node_id].extend(list(range(start_id, end_id))) + if self._shuffle: + for indices in node_id_to_indices.values(): + # In place shuffling + self._rng.shuffle(indices) + self._node_id_to_indices = node_id_to_indices + self._node_id_to_indices_determined = True + + def _check_num_partitions_correctness_if_needed(self) -> None: + """Test num_partitions when the dataset is given (in load_partition).""" + if not self._node_id_to_indices_determined: + if self._num_partitions > self.dataset.num_rows: + raise ValueError( + "The number of partitions needs to be smaller than the number of " + "samples in the dataset." + ) + + def _check_num_partitions_greater_than_zero(self) -> None: + """Test num_partition left sides correctness.""" + if not self._num_partitions > 0: + raise ValueError("The number of partitions needs to be greater than zero.") + + def _sort_dataset_if_needed(self) -> None: + """Sort dataset prior to determining the partitions. + + Operation only needed to be performed one time. It's required for the creation + of shards with the same labels. + """ + if self._node_id_to_indices_determined: + return + self._dataset = self.dataset.sort(self._partition_by) + + def _compute_shard_size_if_missing(self) -> None: + """Compute the parameters needed to perform sharding. + + This method should be called after the dataset is assigned. + """ + if self._shard_size is None: + # If shard size is not specified it needs to be computed + num_rows = self.dataset.num_rows + self._shard_size = int(num_rows / self._total_num_shards) + + def _check_possibility_of_partitions_creation(self) -> None: + if self._shard_size is not None and self._num_shards_per_node is not None: + implied_min_dataset_size = ( + self._shard_size * self._num_shards_per_node * self._num_partitions + ) + if implied_min_dataset_size > len(self.dataset): + raise ValueError( + f"Based on the given arguments the creation of the " + "partitions is impossible. 
The implied minimum dataset "
+                f"size is {implied_min_dataset_size} but the dataset "
+                f"size is {len(self.dataset)}"
+            )

From 1e68e0cea073779ef95e29ea32028eecad9c1c46 Mon Sep 17 00:00:00 2001
From: Adam Narozniak
Date: Mon, 15 Jan 2024 13:48:03 +0100
Subject: [PATCH 02/16] Add shard partitioner tests

---
 .../partitioner/shard_partitioner_test.py     | 371 ++++++++++++++++++
 1 file changed, 371 insertions(+)
 create mode 100644 datasets/flwr_datasets/partitioner/shard_partitioner_test.py

diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner_test.py b/datasets/flwr_datasets/partitioner/shard_partitioner_test.py
new file mode 100644
index 000000000000..1718b1409dc9
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/shard_partitioner_test.py
@@ -0,0 +1,371 @@
+# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test DirichletPartitioner."""
+# pylint: disable=W0212, R0913
+import unittest
+from typing import Optional, Tuple
+
+from datasets import Dataset
+from flwr_datasets.partitioner.shard_partitioner import ShardPartitioner
+
+
+def _dummy_setup(
+    num_rows: int,
+    partition_by: str,
+    num_partitions: int,
+    num_shards_per_node: Optional[int],
+    shard_size: Optional[int],
+    keep_incomplete_shard: bool = False,
+) -> Tuple[Dataset, ShardPartitioner]:
+    """Create a dummy dataset for testing.."""
+    data = {
+        partition_by: [i % 3 for i in range(num_rows)],
+        "features": list(range(num_rows)),
+    }
+    dataset = Dataset.from_dict(data)
+    partitioner = ShardPartitioner(
+        num_partitions=num_partitions,
+        num_shards_per_node=num_shards_per_node,
+        partition_by=partition_by,
+        shard_size=shard_size,
+        keep_incomplete_shard=keep_incomplete_shard,
+    )
+    partitioner.dataset = dataset
+    return dataset, partitioner
+
+
+class TestShardPartitionerSpec1(unittest.TestCase):
+    """Test first possible initialization of ShardPartitioner.
+
+    Specify num_shards_per_node and shard_size arguments.
+ """ + + def test_correct_num_partitions(self) -> None: + """Test the correct number of partitions is created.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = 3 + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + _ = partitioner.load_partition(0) + num_partitions_created = len(partitioner._node_id_to_indices.keys()) + self.assertEqual(num_partitions_created, num_partitions) + + def test_correct_partition_sizes(self) -> None: + """Test if the partitions sizes are as theoretically calculated.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = 3 + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + sizes = [len(partitioner.load_partition(i)) for i in range(num_partitions)] + sizes = sorted(sizes) + self.assertEqual(sizes, [30, 30, 30]) + + def test_unique_samples(self) -> None: + """Test if each partition has unique samples. + + (No duplicates along partitions). + """ + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = 3 + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + partitions = [ + partitioner.load_partition(i)["features"] for i in range(num_partitions) + ] + combined_list = [item for sublist in partitions for item in sublist] + combined_set = set(combined_list) + self.assertEqual(len(combined_list), len(combined_set)) + + +class TestShardPartitionerSpec2(unittest.TestCase): + """Test second possible initialization of ShardPartitioner. + + Specify shard_size and keep_incomplete_shard=False. This setting creates partitions + that might have various sizes (each shard is same size). + """ + + def test_correct_num_partitions(self) -> None: + """Test the correct number of partitions is created.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = None + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + _ = partitioner.load_partition(0) + num_partitions_created = len(partitioner._node_id_to_indices.keys()) + self.assertEqual(num_partitions_created, num_partitions) + + def test_correct_partition_sizes(self) -> None: + """Test if the partitions sizes are as theoretically calculated.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = None + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + sizes = [len(partitioner.load_partition(i)) for i in range(num_partitions)] + sizes = sorted(sizes) + self.assertEqual(sizes, [30, 40, 40]) + + def test_unique_samples(self) -> None: + """Test if each partition has unique samples. + + (No duplicates along partitions). 
+ """ + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = None + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + partitions = [ + partitioner.load_partition(i)["features"] for i in range(num_partitions) + ] + combined_list = [item for sublist in partitions for item in sublist] + combined_set = set(combined_list) + self.assertEqual(len(combined_list), len(combined_set)) + + +class TestShardPartitionerSpec3(unittest.TestCase): + """Test third possible initialization of ShardPartitioner. + + Specify shard_size and keep_incomplete_shard=True. This setting creates partitions + that might have various sizes (each shard is same size). + """ + + def test_correct_num_partitions(self) -> None: + """Test the correct number of partitions is created.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = None + shard_size = 10 + keep_incomplete_shard = True + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + _ = partitioner.load_partition(0) + num_partitions_created = len(partitioner._node_id_to_indices.keys()) + self.assertEqual(num_partitions_created, num_partitions) + + def test_correct_partition_sizes(self) -> None: + """Test if the partitions sizes are as theoretically calculated.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = None + shard_size = 10 + keep_incomplete_shard = True + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + sizes = [len(partitioner.load_partition(i)) for i in range(num_partitions)] + sizes = sorted(sizes) + self.assertEqual(sizes, [33, 40, 40]) + + def test_unique_samples(self) -> None: + """Test if each partition has unique samples. + + (No duplicates along partitions). + """ + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = None + shard_size = 10 + keep_incomplete_shard = True + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + partitions = [ + partitioner.load_partition(i)["features"] for i in range(num_partitions) + ] + combined_list = [item for sublist in partitions for item in sublist] + combined_set = set(combined_list) + self.assertEqual(len(combined_list), len(combined_set)) + + +class TestShardPartitionerSpec4(unittest.TestCase): + """Test fourth possible initialization of ShardPartitioner. + + Specify num_shards_per_node but not shard_size arguments. 
+ """ + + def test_correct_num_partitions(self) -> None: + """Test the correct number of partitions is created.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = 3 + shard_size = None + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + _ = partitioner.load_partition(0) + num_partitions_created = len(partitioner._node_id_to_indices.keys()) + self.assertEqual(num_partitions_created, num_partitions) + + def test_correct_partition_sizes(self) -> None: + """Test if the partitions sizes are as theoretically calculated.""" + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = 3 + shard_size = None + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + sizes = [len(partitioner.load_partition(i)) for i in range(num_partitions)] + sizes = sorted(sizes) + self.assertEqual(sizes, [36, 36, 36]) + + def test_unique_samples(self) -> None: + """Test if each partition has unique samples. + + (No duplicates along partitions). + """ + partition_by = "label" + num_rows = 113 + num_partitions = 3 + num_shards_per_node = 3 + shard_size = None + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + partitions = [ + partitioner.load_partition(i)["features"] for i in range(num_partitions) + ] + combined_list = [item for sublist in partitions for item in sublist] + combined_set = set(combined_list) + self.assertEqual(len(combined_list), len(combined_set)) + + +class TestShardPartitionerIncorrectSpec(unittest.TestCase): + """Test the incorrect specification cases. + + The lack of correctness can be caused by the num_partitions, shard_size and + num_shards_per_partition can create. 
+ """ + + def test_incorrect_specification(self) -> None: + """Test if the given specification makes the partitioning possible.""" + partition_by = "label" + num_rows = 10 + num_partitions = 3 + num_shards_per_node = 2 + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + with self.assertRaises(ValueError): + _ = partitioner.load_partition(0) + + +if __name__ == "__main__": + unittest.main() From 529b3d79c7b5be3848a6654985b21124ad3b3ac7 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Tue, 16 Jan 2024 12:53:10 +0100 Subject: [PATCH 03/16] Make ShardPartitioner visible from the package level --- datasets/flwr_datasets/partitioner/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datasets/flwr_datasets/partitioner/__init__.py b/datasets/flwr_datasets/partitioner/__init__.py index 5e7c86718f67..bdfbfe937e04 100644 --- a/datasets/flwr_datasets/partitioner/__init__.py +++ b/datasets/flwr_datasets/partitioner/__init__.py @@ -20,6 +20,7 @@ from .linear_partitioner import LinearPartitioner from .natural_id_partitioner import NaturalIdPartitioner from .partitioner import Partitioner +from .shard_partitioner import ShardPartitioner from .size_partitioner import SizePartitioner from .square_partitioner import SquarePartitioner @@ -30,5 +31,6 @@ "SizePartitioner", "LinearPartitioner", "SquarePartitioner", + "ShardPartitioner", "ExponentialPartitioner", ] From d28a710af656d2524f5ddecf9029098848b697ff Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Tue, 16 Jan 2024 13:48:01 +0100 Subject: [PATCH 04/16] Add examples to the docs --- .../partitioner/shard_partitioner.py | 46 +++++++++++++++++-- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner.py b/datasets/flwr_datasets/partitioner/shard_partitioner.py index 039bcd844575..5e4973468195 100644 --- a/datasets/flwr_datasets/partitioner/shard_partitioner.py +++ b/datasets/flwr_datasets/partitioner/shard_partitioner.py @@ -73,15 +73,55 @@ class ShardPartitioner(Partitioner): # pylint: disable=R0902 samples assignment to nodes. seed: int Seed used for dataset shuffling. It has no effect if `shuffle` is False. 
- """ + Examples + -------- + 1) If you need same number of shards per nodes + the same shard size (and you know + both of these values) + >>> from flwr_datasets import FederatedDataset + >>> from flwr_datasets.partitioner import ShardPartitioner + >>> + >>> partitioner = ShardPartitioner(num_partitions=10, partition_by="label", + >>> num_shards_per_node=2, shard_size=1_000) + >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner}) + >>> partition = fds.load_partition(0) + >>> print(partition[0]) # Print the first example + {'image': , + 'label': 3} + >>> partition_sizes = [len(fds.load_partition(node_id)) for node_id in range(10)] + >>> print(partition_sizes) + [2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000] + + 2) If you want to use nearly all the data and do not need to have the number of + shard per each node + >>> from flwr_datasets import FederatedDataset + >>> from flwr_datasets.partitioner import ShardPartitioner + >>> + >>> partitioner = ShardPartitioner(num_partitions=9, partition_by="label", + >>> shard_size=1_000) + >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner}) + >>> partition_sizes = [len(fds.load_partition(node_id)) for node_id in range(9)] + >>> print(partition_sizes) + [7000, 7000, 7000, 7000, 7000, 7000, 6000, 6000, 6000] + + 3) If you want ot use all the data + >>> from flwr_datasets import FederatedDataset + >>> from flwr_datasets.partitioner import ShardPartitioner + >>> + >>> partitioner = ShardPartitioner(num_partitions=10, partition_by="label", + >>> shard_size=990, keep_incomplete_shard=True) + >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": partitioner}) + >>> partition_sizes = [len(fds.load_partition(node_id)) for node_id in range(10)] + >>> print(sorted(partition_sizes)) + [5550, 5940, 5940, 5940, 5940, 5940, 5940, 5940, 5940, 6930] + """ def __init__( # pylint: disable=R0913 self, num_partitions: int, partition_by: str, - num_shards_per_node: Optional[int], - keep_incomplete_shard: bool, + num_shards_per_node: Optional[int] = None, shard_size: Optional[int] = None, + keep_incomplete_shard: bool = False, shuffle: bool = True, seed: Optional[int] = 42, ) -> None: From bb90fd99fcf8741ecbd6d9dd4673befb0bf56ed2 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Tue, 16 Jan 2024 13:57:37 +0100 Subject: [PATCH 05/16] Fix formatting --- datasets/flwr_datasets/partitioner/shard_partitioner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner.py b/datasets/flwr_datasets/partitioner/shard_partitioner.py index 5e4973468195..bcdeeeec34b0 100644 --- a/datasets/flwr_datasets/partitioner/shard_partitioner.py +++ b/datasets/flwr_datasets/partitioner/shard_partitioner.py @@ -115,6 +115,7 @@ class ShardPartitioner(Partitioner): # pylint: disable=R0902 >>> print(sorted(partition_sizes)) [5550, 5940, 5940, 5940, 5940, 5940, 5940, 5940, 5940, 6930] """ + def __init__( # pylint: disable=R0913 self, num_partitions: int, From c38d80125cc5e85e95bebb912dfeafbc5540fc53 Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Tue, 30 Jan 2024 14:07:18 +0100 Subject: [PATCH 06/16] Apply text formatting suggestions from code review Co-authored-by: Javier --- .../partitioner/shard_partitioner.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner.py b/datasets/flwr_datasets/partitioner/shard_partitioner.py index 
bcdeeeec34b0..78d835ac7512 100644
--- a/datasets/flwr_datasets/partitioner/shard_partitioner.py
+++ b/datasets/flwr_datasets/partitioner/shard_partitioner.py
@@ -28,27 +28,27 @@ class ShardPartitioner(Partitioner):  # pylint: disable=R0902
 
     The algorithm works as follows: the dataset is sorted by label e.g. [samples with
     label 1, samples with label 2 ...], then the shards are created, with each
-    shard of size - shard_size if provided or automatically calculated:
-    shard_size = len(dataset) / (num_partitions * num_shards_per_node).
+    shard of size = `shard_size` if provided or automatically calculated:
+    shard_size = len(dataset) / (`num_partitions` * `num_shards_per_node`).
     Each partition is created from `num_shards_per_node` shards that are chosen randomly.
 
     There are a few ways of partitioning data that result in certain properties
     (depending on the parameters specification):
     1) same number of shards per node + the same shard size (specify:
-    a) num_shards_per_node, shard_size or b) num_shards_per_node)
-    In case of b the shard_size is calculated as floor(len(dataset) /
-    (num_shards_per_node * num_partitions))
+    a) `num_shards_per_node`, `shard_size`; or b) `num_shards_per_node`)
+    In case of b the `shard_size` is calculated as floor(len(dataset) /
+    (`num_shards_per_node` * `num_partitions`))
     2) possibly different number of shards per node (use nearly all data) + the same
-    shard size (specify: shard_size + keep_incomplete_shard=False)
+    shard size (specify: `shard_size` + `keep_incomplete_shard=False`)
     3) possibly different number of shards per node (use all data) + possibly different
-    shard size (specify: shard_size + keep_incomplete_shard=True)
+    shard size (specify: `shard_size` + `keep_incomplete_shard=True`)
 
 
-    Algorithm based on the description in Communication-Efficient Learning of Deep 
+    Algorithm based on the description in Communication-Efficient Learning of Deep
     Networks from Decentralized Data https://arxiv.org/abs/1602.05629. This
     implementation expands on the initial idea by enabling more hyperparameters
-    specification therefore providing more control of the partitions. It enables the
-    division obtained in the original paper.
+    specification therefore providing more control on how partitions are created.
+    It enables the division obtained in the original paper.
 
     Parameters
     ----------
@@ -58,7 +58,7 @@ class ShardPartitioner(Partitioner):  # pylint: disable=R0902
         Column name of the labels (targets) based on which the sorting is performed.
     num_shards_per_node : Optional[int]
         Number of shards to assign to a single partition. It's an alternative to
-        num_partitions.
+        `num_partitions`.
     shard_size : Optional[int]
         Size of a single shard (a partition has one or more shards). If the size is not
         given it will be automatically computed such that.
@@ -66,8 +66,8 @@ class ShardPartitioner(Partitioner):  # pylint: disable=R0902
     keep_incomplete_shard : bool
         Weather to drop the last shard which might be incomplete (smaller than the
        others). If it is dropped each shard is equal size. (It does not mean that each
        client gets equal number of shards, which only happens if
-        num_partitions % num_shards = 0). This parameter has no effect if
-        num_shards_per_node and shard_size are specified.
+        `num_partitions` % `num_shards` = 0). This parameter has no effect if
+        `num_shards_per_node` and `shard_size` are specified.
     shuffle: bool
         Whether to randomize the order of samples. Shuffling applied after the
        samples assignment to nodes.
@@ -93,7 +93,7 @@ class ShardPartitioner(Partitioner):  # pylint: disable=R0902
     [2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000, 2000]
 
     2) If you want to use nearly all the data and do not need to have the number of
-    shard per each node
+    shards per node to be the same
     >>> from flwr_datasets import FederatedDataset
     >>> from flwr_datasets.partitioner import ShardPartitioner
     >>>
@@ -104,7 +104,7 @@ class ShardPartitioner(Partitioner):  # pylint: disable=R0902
     >>> print(partition_sizes)
     [7000, 7000, 7000, 7000, 7000, 7000, 6000, 6000, 6000]
 
-    3) If you want ot use all the data
+    3) If you want to use all the data
     >>> from flwr_datasets import FederatedDataset
     >>> from flwr_datasets.partitioner import ShardPartitioner
     >>>
@@ -168,7 +168,7 @@ def load_partition(self, node_id: int) -> datasets.Dataset:
     def _determine_node_id_to_indices_if_needed(self) -> None:  # pylint: disable=R0914
         """Assign sample indices to each node id.
 
-        This method works on sorted datasets. "Shard" is a part of the dataset of the
+        This method works on sorted datasets. A "shard" is a part of the dataset of
         consecutive samples (if self._keep_incomplete_shard is False, each shard is same
         size).

From 5408c4f10e9c1df82701361ab1b2c9cc355e9d2f Mon Sep 17 00:00:00 2001
From: Adam Narozniak
Date: Tue, 30 Jan 2024 16:29:24 +0100
Subject: [PATCH 07/16] Use the random generator instead of np.random

---
 datasets/flwr_datasets/partitioner/shard_partitioner.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner.py b/datasets/flwr_datasets/partitioner/shard_partitioner.py
index 78d835ac7512..2c4862766116 100644
--- a/datasets/flwr_datasets/partitioner/shard_partitioner.py
+++ b/datasets/flwr_datasets/partitioner/shard_partitioner.py
@@ -221,7 +221,7 @@ def _determine_node_id_to_indices_if_needed(self) -> None:  # pylint: disable=R0
         indices_on_which_to_split_shards = np.cumsum(
             num_shards_per_node_array, dtype=int
         )
-        shard_indices_array = np.random.permutation(self._total_num_shards)
+        shard_indices_array = self._rng.permutation(self._total_num_shards)
         # Randomly assign shards to node_id
         nid_to_shard_indices = np.split(
             shard_indices_array, indices_on_which_to_split_shards

From 8dd777193da980c09703e520e85d80b672550d04 Mon Sep 17 00:00:00 2001
From: Adam Narozniak
Date: Wed, 31 Jan 2024 15:07:54 +0100
Subject: [PATCH 08/16] Add an explanation of what a shard is

---
 .../flwr_datasets/partitioner/shard_partitioner.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner.py b/datasets/flwr_datasets/partitioner/shard_partitioner.py
index 2c4862766116..358338742688 100644
--- a/datasets/flwr_datasets/partitioner/shard_partitioner.py
+++ b/datasets/flwr_datasets/partitioner/shard_partitioner.py
@@ -30,6 +30,16 @@ class ShardPartitioner(Partitioner):  # pylint: disable=R0902
     label 1, samples with label 2 ...], then the shards are created, with each
     shard of size = `shard_size` if provided or automatically calculated:
     shard_size = len(dataset) / (`num_partitions` * `num_shards_per_node`).
+
+    A shard is just a block (part) of a `dataset` that contains `shard_size` consecutive
+    samples. There might be shards that contain samples associated with more than
+    a single unique label. The first case is (remember we have a sorted dataset which
+    is always the prepocessing step) we are at the border between the samples of two
+    classes the shard contains samples of two different classes e.g. the "leftover" of
the "leftover" of + samples of class 1 and the majority of class 2. The another scenario when a shard + has samples with more than one unique label is when the shard size is bigger than + the number of samples of a certain class. + Each partition is created from `num_shards_per_node` that are chosen randomly. There are a few ways of partitioning data that result in certain properties From 000af21230ed1d5ce1b57ca2e037ec538ae5bdd6 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Thu, 1 Feb 2024 14:21:38 +0100 Subject: [PATCH 09/16] Fix the shards assignment for when shard_size is given The fix works when the shard_size is given along the num_shards_per_partition --- .../partitioner/shard_partitioner.py | 46 ++++++++++++++----- 1 file changed, 34 insertions(+), 12 deletions(-) diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner.py b/datasets/flwr_datasets/partitioner/shard_partitioner.py index 358338742688..00ba21f8bec0 100644 --- a/datasets/flwr_datasets/partitioner/shard_partitioner.py +++ b/datasets/flwr_datasets/partitioner/shard_partitioner.py @@ -57,7 +57,7 @@ class ShardPartitioner(Partitioner): # pylint: disable=R0902 Algorithm based on the description in Communication-Efficient Learning of Deep Networks from Decentralized Data https://arxiv.org/abs/1602.05629. This implementation expands on the initial idea by enabling more hyperparameters - specification therefore providing more control on how partitions are created. + specification therefore providing more control on how partitions are created. It enables the division obtained in original paper. Parameters @@ -71,9 +71,9 @@ class ShardPartitioner(Partitioner): # pylint: disable=R0902 `num_partitions`. shard_size : Optional[int] Size of a single shards (a partition has one or more shards). If the size is not - given it will be automatically computed such that. + given it will be automatically computed. keep_incomplete_shard : bool - Weather to drop the last shard which might be incomplete (smaller than the + Whether to drop the last shard which might be incomplete (smaller than the others). If it is dropped each shard is equal size. (It does not mean that each client gets equal number of shards, which only happens if `num_partitions` % `num_shards` = 0). This parameter has no effect if @@ -142,7 +142,7 @@ def __init__( # pylint: disable=R0913 self._check_num_partitions_greater_than_zero() self._partition_by = partition_by self._num_shards_per_node = num_shards_per_node - self._total_num_shards: Optional[int] = None + self._num_shards_used: Optional[int] = None self._shard_size = shard_size self._keep_incomplete_shard = keep_incomplete_shard self._shuffle = shuffle @@ -182,16 +182,33 @@ def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0 consecutive samples (if self._keep_incomplete_shard is False, each shard is same size). 
""" + # No need to do anything if that node_id_to_indices are already determined if self._node_id_to_indices_determined: return + + # One of the specification allows to skip the `num_shards_per_node` param if self._num_shards_per_node is not None: - self._total_num_shards = int( + self._num_shards_used = int( self._num_partitions * self._num_shards_per_node ) num_shards_per_node_array = ( np.ones(self._num_partitions) * self._num_shards_per_node ) - self._compute_shard_size_if_missing() + if self._shard_size is None: + self._compute_shard_size_if_missing() + assert self._shard_size is not None + if self._keep_incomplete_shard: + num_usable_shards_in_dataset = int( + math.ceil(len(self.dataset) / self._shard_size) + ) + else: + num_usable_shards_in_dataset = int( + math.floor(len(self.dataset) / self._shard_size) + ) + else: + num_usable_shards_in_dataset = int( + math.floor(len(self.dataset) / self._shard_size) + ) elif self._num_shards_per_node is None: if self._shard_size is None: raise ValueError( @@ -199,25 +216,27 @@ def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0 "num_shards_per_node is None" ) if self._keep_incomplete_shard is False: - self._total_num_shards = int( + self._num_shards_used = int( math.floor(len(self.dataset) / self._shard_size) ) + num_usable_shards_in_dataset = self._num_shards_used elif self._keep_incomplete_shard is True: - self._total_num_shards = int( + self._num_shards_used = int( math.ceil(len(self.dataset) / self._shard_size) ) + num_usable_shards_in_dataset = self._num_shards_used else: raise ValueError( "The keep_incomplete_shards need to be specified " "when _num_shards_per_node is None." ) - num_shards_per_node = int(self._total_num_shards / self._num_partitions) + num_shards_per_node = int(self._num_shards_used / self._num_partitions) # Assign the shards per nodes (so far, the same as in ideal case) num_shards_per_node_array = ( np.ones(self._num_partitions) * num_shards_per_node ) num_shards_assigned = self._num_partitions * num_shards_per_node - num_shards_to_assign = self._total_num_shards - num_shards_assigned + num_shards_to_assign = self._num_shards_used - num_shards_assigned # Assign the "missing" shards for i in range(num_shards_to_assign): num_shards_per_node_array[i] += 1 @@ -231,7 +250,10 @@ def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0 indices_on_which_to_split_shards = np.cumsum( num_shards_per_node_array, dtype=int ) - shard_indices_array = self._rng.permutation(self._total_num_shards) + + shard_indices_array = self._rng.permutation(num_usable_shards_in_dataset)[ + : self._num_shards_used + ] # Randomly assign shards to node_id nid_to_shard_indices = np.split( shard_indices_array, indices_on_which_to_split_shards @@ -284,7 +306,7 @@ def _compute_shard_size_if_missing(self) -> None: if self._shard_size is None: # If shard size is not specified it needs to be computed num_rows = self.dataset.num_rows - self._shard_size = int(num_rows / self._total_num_shards) + self._shard_size = int(num_rows / self._num_shards_used) def _check_possibility_of_partitions_creation(self) -> None: if self._shard_size is not None and self._num_shards_per_node is not None: From 81d2048aa73b8dc6dc1259159f6d101e87d750fc Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Tue, 20 Feb 2024 12:06:40 +0100 Subject: [PATCH 10/16] Apply suggestions from code review Co-authored-by: Javier --- .../flwr_datasets/partitioner/shard_partitioner.py | 12 
+++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner.py b/datasets/flwr_datasets/partitioner/shard_partitioner.py
index 00ba21f8bec0..91d07910917f 100644
--- a/datasets/flwr_datasets/partitioner/shard_partitioner.py
+++ b/datasets/flwr_datasets/partitioner/shard_partitioner.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 # ==============================================================================
 """Shard partitioner class."""
+
+
 # pylint: disable=R0912
 import math
 from typing import Dict, List, Optional
@@ -31,11 +33,11 @@ class ShardPartitioner(Partitioner):  # pylint: disable=R0902
     shard of size = `shard_size` if provided or automatically calculated:
     shard_size = len(dataset) / (`num_partitions` * `num_shards_per_node`).
 
-    A shard is just a block (part) of a `dataset` that contains `shard_size` consecutive
-    samples. There might be shards that contain samples associated with more than
-    a single unique label. The first case is (remember we have a sorted dataset which
-    is always the prepocessing step) we are at the border between the samples of two
-    classes the shard contains samples of two different classes e.g. the "leftover" of
+    A shard is just a block (chunk) of a `dataset` that contains `shard_size` consecutive
+    samples. There might be shards that contain samples associated with more than a
+    single unique label. The first case is (remember the preprocessing step sorts the
+    dataset by label) when a shard is constructed from samples at the boundaries of the
+    sorted dataset and therefore belonging to different classes e.g. the "leftover" of
     samples of class 1 and the majority of class 2. Another scenario in which a shard
     has samples with more than one unique label is when the shard size is bigger than
     the number of samples of a certain class.

From d2b1877e59ad295ba0d66b08bf2add43e7ae64f0 Mon Sep 17 00:00:00 2001
From: Adam Narozniak
Date: Tue, 20 Feb 2024 12:42:50 +0100
Subject: [PATCH 11/16] Improve type checks for natural numbers

---
 .../partitioner/shard_partitioner.py          | 27 ++++++++++++++-----
 1 file changed, 21 insertions(+), 6 deletions(-)

diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner.py b/datasets/flwr_datasets/partitioner/shard_partitioner.py
index 00ba21f8bec0..bc73f8a5420b 100644
--- a/datasets/flwr_datasets/partitioner/shard_partitioner.py
+++ b/datasets/flwr_datasets/partitioner/shard_partitioner.py
@@ -138,11 +138,13 @@ def __init__(  # pylint: disable=R0913
     ) -> None:
         super().__init__()
         # Attributes based on the constructor
+        _check_if_natural_number(num_partitions, "num_partitions")
         self._num_partitions = num_partitions
-        self._check_num_partitions_greater_than_zero()
         self._partition_by = partition_by
+        _check_if_natural_number(num_shards_per_node, "num_shards_per_node", True)
         self._num_shards_per_node = num_shards_per_node
         self._num_shards_used: Optional[int] = None
+        _check_if_natural_number(shard_size, "shard_size", True)
         self._shard_size = shard_size
         self._keep_incomplete_shard = keep_incomplete_shard
         self._shuffle = shuffle
@@ -283,11 +285,6 @@ def _check_num_partitions_correctness_if_needed(self) -> None:
                     "samples in the dataset."
                 )
 
-    def _check_num_partitions_greater_than_zero(self) -> None:
-        """Test num_partition left sides correctness."""
-        if not self._num_partitions > 0:
-            raise ValueError("The number of partitions needs to be greater than zero.")
-
     def _sort_dataset_if_needed(self) -> None:
         """Sort dataset prior to determining the partitions.
@@ -320,3 +317,21 @@ def _check_possibility_of_partitions_creation(self) -> None:
                     f"size is {implied_min_dataset_size} but the dataset "
                     f"size is {len(self.dataset)}"
                 )
+
+
+def _check_if_natural_number(
+    number: int, parameter_name: str, none_acceptable: bool = False
+) -> None:
+    if none_acceptable and number is None:
+        return
+    if not isinstance(number, int):
+        raise TypeError(
+            f"The expected type of {parameter_name} is int but given: {number} of type "
+            f"{type(number)}. Please specify the correct type."
+        )
+    if not number >= 1:
+        raise ValueError(
+            f"The expected value of {parameter_name} is >= 1 (greater or equal to 1) "
+            f"but given: {number} which does not meet this condition. Please "
+            f"provide a correct number."
+        )

From 63926df30124a285adc8f63f9cc81f7424053065 Mon Sep 17 00:00:00 2001
From: Adam Narozniak
Date: Tue, 20 Feb 2024 12:42:59 +0100
Subject: [PATCH 12/16] Fix formatting

---
 datasets/flwr_datasets/partitioner/shard_partitioner_test.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner_test.py b/datasets/flwr_datasets/partitioner/shard_partitioner_test.py
index 1718b1409dc9..8aede604e188 100644
--- a/datasets/flwr_datasets/partitioner/shard_partitioner_test.py
+++ b/datasets/flwr_datasets/partitioner/shard_partitioner_test.py
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Test DirichletPartitioner."""
+"""Test ShardPartitioner."""
+
+
 # pylint: disable=W0212, R0913
 import unittest
 from typing import Optional, Tuple

From ef08f25b74def820b5256b5d8948a4d13ec08d09 Mon Sep 17 00:00:00 2001
From: Adam Narozniak
Date: Tue, 20 Feb 2024 12:46:42 +0100
Subject: [PATCH 13/16] Fix formatting

---
 .../flwr_datasets/partitioner/shard_partitioner.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner.py b/datasets/flwr_datasets/partitioner/shard_partitioner.py
index ca05e92ef77d..28e89853cc56 100644
--- a/datasets/flwr_datasets/partitioner/shard_partitioner.py
+++ b/datasets/flwr_datasets/partitioner/shard_partitioner.py
@@ -33,12 +33,12 @@ class ShardPartitioner(Partitioner):  # pylint: disable=R0902
     shard of size = `shard_size` if provided or automatically calculated:
     shard_size = len(dataset) / (`num_partitions` * `num_shards_per_node`).
 
-    A shard is just a block (chunk) of a `dataset` that contains `shard_size` consecutive
-    samples. There might be shards that contain samples associated with more than a
-    single unique label. The first case is (remember the preprocessing step sorts the
-    dataset by label) when a shard is constructed from samples at the boundaries of the
-    sorted dataset and therefore belonging to different classes e.g. the "leftover" of
-    samples of class 1 and the majority of class 2. Another scenario in which a shard
+    A shard is just a block (chunk) of a `dataset` that contains `shard_size`
+    consecutive samples. There might be shards that contain samples associated with more
+    than a single unique label. The first case is (remember the preprocessing step sorts
+    the dataset by label) when a shard is constructed from samples at the boundaries of
+    the sorted dataset and therefore belonging to different classes e.g. the "leftover"
+    of samples of class 1 and the majority of class 2. Another scenario in which a shard
     has samples with more than one unique label is when the shard size is bigger than
     the number of samples of a certain class.

From 1897cbc56806cf9f7a640057e9dc8d111338c1c7 Mon Sep 17 00:00:00 2001
From: Adam Narozniak
Date: Tue, 20 Feb 2024 13:00:55 +0100
Subject: [PATCH 14/16] Fix formatting

---
 datasets/flwr_datasets/partitioner/shard_partitioner.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner.py b/datasets/flwr_datasets/partitioner/shard_partitioner.py
index 28e89853cc56..d6f6b6261833 100644
--- a/datasets/flwr_datasets/partitioner/shard_partitioner.py
+++ b/datasets/flwr_datasets/partitioner/shard_partitioner.py
@@ -322,7 +322,7 @@ def _check_possibility_of_partitions_creation(self) -> None:
 
 
 def _check_if_natural_number(
-    number: int, parameter_name: str, none_acceptable: bool = False
+    number: Optional[int], parameter_name: str, none_acceptable: bool = False
 ) -> None:
     if none_acceptable and number is None:
         return

From 5789d79ac2a7fdf55b4c592d0a693afa29fc941d Mon Sep 17 00:00:00 2001
From: Adam Narozniak
Date: Mon, 26 Feb 2024 09:59:09 +0100
Subject: [PATCH 15/16] Add checks against empty partitions

---
 .../partitioner/shard_partitioner.py          | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner.py b/datasets/flwr_datasets/partitioner/shard_partitioner.py
index d6f6b6261833..7c86570fe487 100644
--- a/datasets/flwr_datasets/partitioner/shard_partitioner.py
+++ b/datasets/flwr_datasets/partitioner/shard_partitioner.py
@@ -229,6 +229,14 @@ def _determine_node_id_to_indices_if_needed(self) -> None:  # pylint: disable=R0
                     math.ceil(len(self.dataset) / self._shard_size)
                 )
                 num_usable_shards_in_dataset = self._num_shards_used
+                if num_usable_shards_in_dataset < self._num_partitions:
+                    raise ValueError(
+                        "Based on the given arguments the creation of the partitions "
+                        "is impossible. The implied number of partitions that can be "
+                        "used is lower than the number of requested partitions "
+                        "resulting in empty partitions. Please decrease the size of "
+                        "shards: `shard_size`."
+                    )
             else:
                 raise ValueError(
                     "The keep_incomplete_shards need to be specified "
@@ -251,6 +259,13 @@ def _determine_node_id_to_indices_if_needed(self) -> None:  # pylint: disable=R0
                 "keep_incomplete_shards is not correct."
             )
 
+        if num_usable_shards_in_dataset < self._num_partitions:
+            raise ValueError(
+                "The specified configuration results in empty partitions because the "
+                "number of usable shards is smaller than the number of partitions. "
+                "Try decreasing the shard size or the number of partitions.
" + ) + indices_on_which_to_split_shards = np.cumsum( num_shards_per_node_array, dtype=int ) From a2f9a1c5e0dbd0e681ea810e7eac65b0967eba5c Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Mon, 26 Feb 2024 09:59:19 +0100 Subject: [PATCH 16/16] Test for empty partitions --- .../partitioner/shard_partitioner_test.py | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner_test.py b/datasets/flwr_datasets/partitioner/shard_partitioner_test.py index 8aede604e188..47968699bba7 100644 --- a/datasets/flwr_datasets/partitioner/shard_partitioner_test.py +++ b/datasets/flwr_datasets/partitioner/shard_partitioner_test.py @@ -31,7 +31,7 @@ def _dummy_setup( shard_size: Optional[int], keep_incomplete_shard: bool = False, ) -> Tuple[Dataset, ShardPartitioner]: - """Create a dummy dataset for testing..""" + """Create a dummy dataset for testing.""" data = { partition_by: [i % 3 for i in range(num_rows)], "features": list(range(num_rows)), @@ -368,6 +368,25 @@ def test_incorrect_specification(self) -> None: with self.assertRaises(ValueError): _ = partitioner.load_partition(0) + def test_too_big_shard_size(self) -> None: + """Test if it is impossible to create an empty partition.""" + partition_by = "label" + num_rows = 20 + num_partitions = 3 + num_shards_per_node = None + shard_size = 10 + keep_incomplete_shard = False + _, partitioner = _dummy_setup( + num_rows, + partition_by, + num_partitions, + num_shards_per_node, + shard_size, + keep_incomplete_shard, + ) + with self.assertRaises(ValueError): + _ = partitioner.load_partition(2).num_rows + if __name__ == "__main__": unittest.main()