diff --git a/datasets/flwr_datasets/__init__.py b/datasets/flwr_datasets/__init__.py index 48d993037708..0b9a6685427b 100644 --- a/datasets/flwr_datasets/__init__.py +++ b/datasets/flwr_datasets/__init__.py @@ -16,6 +16,7 @@ from flwr_datasets import partitioner, resplitter +from flwr_datasets import utils as utils from flwr_datasets.common.version import package_version as _package_version from flwr_datasets.federated_dataset import FederatedDataset @@ -23,6 +24,7 @@ "FederatedDataset", "partitioner", "resplitter", + "utils", ] __version__ = _package_version diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index c40f8cc34857..588d1ab40aec 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -15,7 +15,7 @@ """FederatedDataset.""" -from typing import Dict, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union, cast import datasets from datasets import Dataset, DatasetDict @@ -25,9 +25,12 @@ _check_if_dataset_tested, _instantiate_partitioners, _instantiate_resplitter_if_needed, + divide_dataset, ) +# flake8: noqa: E501 +# pylint: disable=line-too-long class FederatedDataset: """Representation of a dataset for federated learning/evaluation/analytics. @@ -51,6 +54,19 @@ class FederatedDataset: (representing the number of IID partitions that this split should be partitioned into). One or multiple `Partitioner` objects can be specified in that manner, but at most, one per split. + partition_division : Optional[Union[List[float], Tuple[float, ...], + Dict[str, float], Dict[str, Optional[Union[List[float], Tuple[float, ...], + Dict[str, float]]]]]] + Fractions specifying the division of the partition associated with a certain split + (and partitioner) that enable returning already divided partition from the + `load_partition` method. You can think of this as on-edge division of the data + into multiple divisions (e.g. 
into train and validation). You can also name the + divisions by using the Dict or specify it as a List/Tuple. If you + specified a single partitioner you can provide the simplified form e.g. + [0.8, 0.2] or {"partition_train": 0.8, "partition_test": 0.2} but when multiple + partitioners are specified you need to indicate the results of which partitioner + are further divided e.g. {"train": [0.8, 0.2]} would result in dividing only the + partitions that are created from the "train" split. shuffle : bool Whether to randomize the order of samples. Applied prior to resplitting, speratelly to each of the present splits in the dataset. It uses the `seed` @@ -64,14 +80,18 @@ class FederatedDataset: Use MNIST dataset for Federated Learning with 100 clients (edge devices): >>> mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": 100}) - - Load partition for client with ID 10. - + >>> # Load partition for client with ID 10. >>> partition = mnist_fds.load_partition(10, "train") - - Use test split for centralized evaluation. - + >>> # Use test split for centralized evaluation. 
>>> centralized = mnist_fds.load_full("test") + + Automatically divide the data returned from `load_partition` + >>> mnist_fds = FederatedDataset( + ... dataset="mnist", + ... partitioners={"train": 100}, + ... partition_division=[0.8, 0.2], + ... ) + >>> partition_train, partition_test = mnist_fds.load_partition(10, "train") """ # pylint: disable=too-many-instance-attributes @@ -82,6 +102,17 @@ def __init__( subset: Optional[str] = None, resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] = None, partitioners: Dict[str, Union[Partitioner, int]], + partition_division: Optional[ + Union[ + List[float], + Tuple[float, ...], + Dict[str, float], + Dict[ + str, + Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]], + ], + ] + ] = None, shuffle: bool = True, seed: Optional[int] = 42, ) -> None: @@ -94,6 +125,9 @@ def __init__( self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners( partitioners ) + self._partition_division = self._initialize_partition_division( + partition_division + ) self._shuffle = shuffle self._seed = seed # _dataset is prepared lazily on the first call to `load_partition` @@ -102,7 +136,11 @@ def __init__( # Indicate if the dataset is prepared for `load_partition` or `load_full` self._dataset_prepared: bool = False - def load_partition(self, node_id: int, split: Optional[str] = None) -> Dataset: + def load_partition( + self, + node_id: int, + split: Optional[str] = None, + ) -> Union[Dataset, List[Dataset], DatasetDict]: """Load the partition specified by the idx in the selected split. The dataset is downloaded only when the first call to `load_partition` or @@ -122,8 +160,13 @@ def load_partition(self, node_id: int, split: Optional[str] = None) -> Dataset: Returns ------- - partition : Dataset - Single partition from the dataset split. + partition : Union[Dataset, List[Dataset], DatasetDict] + Undivided or divided partition from the dataset split. 
+ If `partition_division` is not specified then `Dataset` is returned. + If `partition_division` is specified as `List` or `Tuple` then + `List[Dataset]` is returned. + If `partition_division` is specified as `Dict` then `DatasetDict` is + returned. """ if not self._dataset_prepared: self._prepare_dataset() @@ -136,7 +179,16 @@ def load_partition(self, node_id: int, split: Optional[str] = None) -> Dataset: self._check_if_split_possible_to_federate(split) partitioner: Partitioner = self._partitioners[split] self._assign_dataset_to_partitioner(split) - return partitioner.load_partition(node_id) + partition = partitioner.load_partition(node_id) + if self._partition_division is None: + return partition + partition_division = self._partition_division.get(split) + if partition_division is None: + return partition + divided_partition: Union[List[Dataset], DatasetDict] = divide_dataset( + partition, partition_division + ) + return divided_partition def load_full(self, split: str) -> Dataset: """Load the full split of the dataset. @@ -230,3 +282,62 @@ def _check_if_no_split_keyword_possible(self) -> None: "Please set the `split` argument. You can only omit the split keyword " "if there is exactly one partitioner specified." ) + + def _initialize_partition_division( + self, + partition_division: Optional[ + Union[ + List[float], + Tuple[float, ...], + Dict[str, float], + Dict[ + str, + Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]], + ], + ] + ], + ) -> Optional[ + Dict[ + str, + Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]], + ] + ]: + """Create the partition division in the full format. + + Reduced format (possible if only one partitioner exist): + + Union[List[float], Tuple[float, ...], Dict[str, float] + + Full format: Dict[str, Reduced format] + Full format represents the split to division mapping. 
+ """ + # Check for simple dict, list, or tuple types directly + if isinstance(partition_division, (list, tuple)) or ( + isinstance(partition_division, dict) + and all(isinstance(value, float) for value in partition_division.values()) + ): + if len(self._partitioners) > 1: + raise ValueError( + f"The specified partition_division {partition_division} does not " + f"provide mapping to split but more than one partitioners is " + f"specified. Please adjust the partition_division specification to " + f"have the split names as the keys." + ) + return cast( + Dict[ + str, + Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]], + ], + {list(self._partitioners.keys())[0]: partition_division}, + ) + if isinstance(partition_division, dict): + return cast( + Dict[ + str, + Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]], + ], + partition_division, + ) + if partition_division is None: + return None + raise TypeError("Unsupported type for partition_division") diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py index e02b6ed5add8..e01f56342954 100644 --- a/datasets/flwr_datasets/federated_dataset_test.py +++ b/datasets/flwr_datasets/federated_dataset_test.py @@ -17,7 +17,7 @@ import unittest -from typing import Dict, Union +from typing import Dict, List, Optional, Tuple, Union from unittest.mock import Mock, patch import pytest @@ -67,6 +67,48 @@ def test_load_partition_size(self, _: str, train_num_partitions: int) -> None: len(dataset_partition0), len(dataset["train"]) // train_num_partitions ) + @parameterized.expand( # type: ignore + [ + ((0.2, 0.8), 2, False), + ({"train": 0.2, "test": 0.8}, 2, False), + ({"train": {"train": 0.2, "test": 0.8}}, 2, True), + # Not full dataset + ([0.2, 0.1], 2, False), + ({"train": 0.2, "test": 0.1}, 2, False), + (None, None, False), + ], + ) + def test_divide_partition_integration_size( + self, + partition_division: Optional[ + Union[ + List[float], + 
Tuple[float, ...], + Dict[str, float], + Dict[ + str, + Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]], + ], + ] + ], + expected_length: Optional[int], + add_test_partitioner: bool, + ): + """Test is the `partition_division` create correct data.""" + partitioners: Dict[str, Union[Partitioner, int]] = {"train": 10} + if add_test_partitioner: + partitioners[self.test_split] = 10 + dataset_fds = FederatedDataset( + dataset=self.dataset_name, + partitioners=partitioners, + partition_division=partition_division, + ) + partition = dataset_fds.load_partition(0, "train") + if partition_division is None: + self.assertEqual(expected_length, None) + else: + self.assertEqual(len(partition), expected_length) + def test_load_full(self) -> None: """Test if the load_full works with the correct split name.""" dataset_fds = FederatedDataset( diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index e3d0fdfffa63..38382508035c 100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -16,8 +16,9 @@ import warnings -from typing import Dict, Optional, Tuple, Union, cast +from typing import Dict, List, Optional, Tuple, Union, cast +from datasets import Dataset, DatasetDict from flwr_datasets.partitioner import IidPartitioner, Partitioner from flwr_datasets.resplitter import Resplitter from flwr_datasets.resplitter.merge_resplitter import MergeResplitter @@ -85,3 +86,156 @@ def _check_if_dataset_tested(dataset: str) -> None: f"The currently tested dataset are {tested_datasets}. Given: {dataset}.", stacklevel=1, ) + + +def divide_dataset( + dataset: Dataset, division: Union[List[float], Tuple[float, ...], Dict[str, float]] +) -> Union[List[Dataset], DatasetDict]: + """Divide the dataset according to the `division`. + + The division support varying number of splits, which you can name. The splits are + created from the beginning of the dataset. + + Parameters + ---------- + dataset : Dataset + Dataset to be divided. 
+ division: Union[List[float], Tuple[float, ...], Dict[str, float]] + Configuration specifying how the dataset is divided. Each fraction has to be + >0 and <=1. They have to sum up to at most 1 (smaller sum is possible). + + Returns + ------- + divided_dataset : Union[List[Dataset], DatasetDict] + If `division` is `List` or `Tuple` then `List[Dataset]` is returned else if + `division` is `Dict` then `DatasetDict` is returned. + + Examples + -------- + Use `divide_dataset` with division specified as a list. + >>> from flwr_datasets import FederatedDataset + >>> from flwr_datasets.utils import divide_dataset + >>> + >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100}) + >>> partition = fds.load_partition(0) + >>> division = [0.8, 0.2] + >>> train, test = divide_dataset(dataset=partition, division=division) + + Use `divide_dataset` with division specified as a dict. + >>> from flwr_datasets import FederatedDataset + >>> from flwr_datasets.utils import divide_dataset + >>> + >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100}) + >>> partition = fds.load_partition(0) + >>> division = {"train": 0.8, "test": 0.2} + >>> train_test = divide_dataset(dataset=partition, division=division) + >>> train, test = train_test["train"], train_test["test"] + """ + dataset_length = len(dataset) + ranges = _create_division_indices_ranges(dataset_length, division) + if isinstance(division, (list, tuple)): + split_partition: List[Dataset] = [] + for single_range in ranges: + split_partition.append(dataset.select(single_range)) + return split_partition + if isinstance(division, dict): + split_partition_dict: Dict[str, Dataset] = {} + for split_name, single_range in zip(division.keys(), ranges): + split_partition_dict[split_name] = dataset.select(single_range) + return DatasetDict(split_partition_dict) + raise TypeError( + f"The type of the `division` should be dict, " + f"tuple or list but is {type(division)} instead." 
+ ) + + +def _create_division_indices_ranges( + dataset_length: int, + division: Union[List[float], Tuple[float, ...], Dict[str, float]], +) -> List[range]: + ranges = [] + if isinstance(division, (list, tuple)): + start_idx = 0 + end_idx = 0 + for fraction in division: + end_idx += int(dataset_length * fraction) + ranges.append(range(start_idx, end_idx)) + start_idx = end_idx + elif isinstance(division, dict): + ranges = [] + start_idx = 0 + end_idx = 0 + for fraction in division.values(): + end_idx += int(dataset_length * fraction) + ranges.append(range(start_idx, end_idx)) + start_idx = end_idx + else: + raise TypeError( + f"The type of the `division` should be dict, " + f"tuple or list but is {type(division)} instead. " + ) + return ranges + + +def _check_division_config_types_correctness( + division: Union[List[float], Tuple[float, ...], Dict[str, float]] +) -> None: + if isinstance(division, (list, tuple)): + if not all(isinstance(x, float) for x in division): + raise TypeError( + "List or tuple values of `division` must contain only floats, " + "other types are not allowed." + ) + elif isinstance(division, dict): + if not all(isinstance(x, float) for x in division.values()): + raise TypeError( + "Dict values of `division` must be only floats, " + "other types are not allowed." + ) + else: + raise TypeError("`division` must be a list, tuple, or dict.") + + +def _check_division_config_values_correctness( + division: Union[List[float], Tuple[float, ...], Dict[str, float]] +) -> None: + if isinstance(division, (list, tuple)): + if not all(0 < x <= 1 for x in division): + raise ValueError( + "All fractions for the division must be greater than 0 and smaller or " + "equal to 1." + ) + fraction_sum_from_list_tuple = sum(division) + if fraction_sum_from_list_tuple > 1: + raise ValueError("Sum of fractions for division must not exceed 1.") + if fraction_sum_from_list_tuple < 1: + warnings.warn( + f"Sum of fractions for division is {sum(division)}, which is below 1. 
" + f"Make sure that's the desired behavior. Some data will not be used " + f"in the current specification.", + stacklevel=1, + ) + elif isinstance(division, dict): + values = list(division.values()) + if not all(0 < x <= 1 for x in values): + raise ValueError( + "All fractions must be greater than 0 and smaller or equal to 1." + ) + if sum(values) > 1: + raise ValueError("Sum of fractions must not exceed 1.") + if sum(values) < 1: + warnings.warn( + f"Sum of fractions in `division` is {sum(values)}, which is below 1. " + f"Make sure that's the desired behavior. Some data will not be used " + f"in the current specification.", + stacklevel=1, + ) + else: + raise TypeError("`division` must be a list, tuple, or dict.") + + +def _check_division_config_correctness( + division: Union[List[float], Tuple[float, ...], Dict[str, float]] +) -> None: + _check_division_config_types_correctness(division) + _check_division_config_values_correctness(division) diff --git a/datasets/flwr_datasets/utils_test.py b/datasets/flwr_datasets/utils_test.py new file mode 100644 index 000000000000..26f24519eb76 --- /dev/null +++ b/datasets/flwr_datasets/utils_test.py @@ -0,0 +1,70 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Utils tests.""" +import unittest +from typing import Dict, List, Tuple, Union + +from parameterized import parameterized_class + +from datasets import Dataset, DatasetDict +from flwr_datasets.utils import divide_dataset + + +@parameterized_class( + ( + "divide", + "sizes", + ), + [ + ((0.2, 0.8), [8, 32]), + ([0.2, 0.8], [8, 32]), + ({"train": 0.2, "test": 0.8}, [8, 32]), + # Not full dataset + ([0.2, 0.1], [8, 4]), + ((0.2, 0.1), [8, 4]), + ({"train": 0.2, "test": 0.1}, [8, 4]), + ], +) +class UtilsTests(unittest.TestCase): + """Utils tests.""" + + divide: Union[List[float], Tuple[float, ...], Dict[str, float]] + sizes: Tuple[int] + + def setUp(self) -> None: + """Set up a dataset.""" + self.dataset = Dataset.from_dict({"data": range(40)}) + + def test_correct_sizes(self) -> None: + """Test correct size of the division.""" + divided_dataset = divide_dataset(self.dataset, self.divide) + if isinstance(divided_dataset, (list, tuple)): + lengths = [len(split) for split in divided_dataset] + else: + lengths = [len(split) for split in divided_dataset.values()] + + self.assertEqual(lengths, self.sizes) + + def test_correct_return_types(self) -> None: + """Test correct types of the divided dataset based on the config.""" + divided_dataset = divide_dataset(self.dataset, self.divide) + if isinstance(self.divide, (list, tuple)): + self.assertIsInstance(divided_dataset, list) + else: + self.assertIsInstance(divided_dataset, DatasetDict) + + +if __name__ == "__main__": + unittest.main()