From 507065b2ae253a199f986cf3970c04b73533b6e0 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Wed, 14 Feb 2024 16:22:33 +0100 Subject: [PATCH 01/18] Add partition_division to FederatedDataset --- datasets/flwr_datasets/federated_dataset.py | 40 ++++++++++++++++++--- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index c40f8cc34857..63b307f24600 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -15,7 +15,7 @@ """FederatedDataset.""" -from typing import Dict, Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union, List import datasets from datasets import Dataset, DatasetDict @@ -25,6 +25,7 @@ _check_if_dataset_tested, _instantiate_partitioners, _instantiate_resplitter_if_needed, + divide_dataset, ) @@ -51,6 +52,10 @@ class FederatedDataset: (representing the number of IID partitions that this split should be partitioned into). One or multiple `Partitioner` objects can be specified in that manner, but at most, one per split. + partition_division : Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]] + Divide each partition into a splits (e.g. into a train and evaluation) and + control the size of the splits - fractions of the data. You can also name the + splits for verification. shuffle : bool Whether to randomize the order of samples. Applied prior to resplitting, speratelly to each of the present splits in the dataset. 
It uses the `seed` @@ -82,6 +87,7 @@ def __init__( subset: Optional[str] = None, resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] = None, partitioners: Dict[str, Union[Partitioner, int]], + partition_division: Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]] = None, shuffle: bool = True, seed: Optional[int] = 42, ) -> None: @@ -94,6 +100,7 @@ def __init__( self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners( partitioners ) + self._partition_division = partition_division self._shuffle = shuffle self._seed = seed # _dataset is prepared lazily on the first call to `load_partition` @@ -102,9 +109,14 @@ def __init__( # Indicate if the dataset is prepared for `load_partition` or `load_full` self._dataset_prepared: bool = False - def load_partition(self, node_id: int, split: Optional[str] = None) -> Dataset: + def load_partition(self, node_id: int, split: Optional[str] = None, inner_split: Optional[Union[int, str]]=None) -> Union[Dataset, List[Dataset], DatasetDict]: """Load the partition specified by the idx in the selected split. + If the partition is divided as specified in `partition_split` then you can + return only the specific split by specifying `inner_split` otherwise all split + are return. If the `partition_split` was specified as `Dict` give the string + name, otherwise give the index. + The dataset is downloaded only when the first call to `load_partition` or `load_full` is made. @@ -119,11 +131,21 @@ def load_partition(self, node_id: int, split: Optional[str] = None) -> Dataset: not need to provide this argument, but if `partitioners={"train": 10, "test": 100}`, you need to set it to differentiate which partitioner should be used. + inner_split: Optional[Union[int, str]] + In case of the `partition_split` specification you can identify the split + after the division that you want to return, otherwise all are returned. If + `partition_split` is list or tuple specify int, if it is dict specify the + str. 
Returns ------- - partition : Dataset - Single partition from the dataset split. + partition : Union[Dataset, List[Dataset], DatasetDict] + Single partition from the dataset split which can be further divided. + If `partition_split` specified and `inner_split` not given then + `List[Dataset]` is returned in case of `List` or `Tuple` specification of + the `partition_split` and `DatasetDict` in case of `Dict` specification. + If `partition_split` specified and the `inner_split` given, then Dataset is + returned. """ if not self._dataset_prepared: self._prepare_dataset() @@ -136,7 +158,15 @@ def load_partition(self, node_id: int, split: Optional[str] = None) -> Dataset: self._check_if_split_possible_to_federate(split) partitioner: Partitioner = self._partitioners[split] self._assign_dataset_to_partitioner(split) - return partitioner.load_partition(node_id) + partition = partitioner.load_partition(node_id) + if self._partition_division is None: + return partition + else: + divided_partition = divide_dataset(partition, self._partition_division) + if inner_split is None: + return divided_partition + else: + return divided_partition[inner_split] def load_full(self, split: str) -> Dataset: """Load the full split of the dataset. 
From 670783911166eb94908ffb4c439210f6b9f0a125 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Wed, 14 Feb 2024 16:23:02 +0100 Subject: [PATCH 02/18] wip: partition division utils --- datasets/flwr_datasets/utils.py | 86 ++++++++++++++++++++++++++++++++- 1 file changed, 85 insertions(+), 1 deletion(-) diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index e3d0fdfffa63..7431a9a5d0b2 100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -16,7 +16,9 @@ import warnings -from typing import Dict, Optional, Tuple, Union, cast +from typing import Dict, Optional, Tuple, Union, cast, List + +from datasets import Dataset, DatasetDict from flwr_datasets.partitioner import IidPartitioner, Partitioner from flwr_datasets.resplitter import Resplitter @@ -85,3 +87,85 @@ def _check_if_dataset_tested(dataset: str) -> None: f"The currently tested dataset are {tested_datasets}. Given: {dataset}.", stacklevel=1, ) + + +def divide_dataset(dataset: Dataset, division: Union[List[float], Tuple[float, ...], Dict[str, float]]) -> Union[Dataset, List[Dataset], DatasetDict]: + + dataset_length = len(dataset) + ranges = _create_division_indices_ranges(dataset_length, division) + if isinstance(division, (list, tuple)): + split_partition = [] + for r in ranges: + split_partition.append(dataset.select(r)) + return split_partition + elif isinstance(division, dict): + split_partition = {} + ranges = _create_division_indices_ranges(dataset_length, division) + for split_name, r in zip(division.keys(), ranges): + split_partition[split_name] = dataset.select(r) + return DatasetDict(split_partition) + else: + TypeError( + f"The type of the `division` should be dict, tuple or list but is {type(division)} instead.") + + +def _create_division_indices_ranges(dataset_length: int, division: Union[List[float], Tuple[float, ...], Dict[str, float]]) -> List[range]: + ranges = [] + if isinstance(division, (list, tuple)): + start_idx = 0 + end_idx = 0 
+ for fraction in division: + end_idx = int(dataset_length * fraction) + ranges.append(range(start_idx, end_idx)) + start_idx = end_idx + elif isinstance(division, dict): + ranges = [] + start_idx = 0 + end_idx = 0 + for fraction in division.values(): + end_idx = int(dataset_length * fraction) + ranges.append(range(start_idx, end_idx)) + start_idx = end_idx + else: + TypeError("The type of the `partition_split` should be dict, tuple or list but is {type(self.partition_split)} instead. ") + return ranges + + +def _check_division_config_types_correctness(division: Union[List[float], Tuple[float, ...], Dict[str, float]]) -> None: + if isinstance(division, (list, tuple)): + if not all(isinstance(x, float) for x in division): + raise TypeError( + "List or tuple values of `partition_split` must contain only floats, other types are not allowed.") + elif isinstance(division, dict): + if not all(isinstance(x, float) for x in division.values()): + raise TypeError( + "Dict values of `partition_split` must be only floats, other types are not allowed.") + else: + raise TypeError("`partition_split` must be a list, tuple, or dict.") + +def _check_division_config_values_correctness(division: Union[List[float], Tuple[float, ...], Dict[str, float]]) -> None: + if isinstance(division, (list, tuple)): + if not all(0 < x <= 1 for x in division): + raise ValueError( + "All fractions for the division must be greater than 0 and smaller or equal to 1.") + fraction_sum_from_list_tuple = sum(division) + if fraction_sum_from_list_tuple > 1: + raise ValueError("Sum of fractions for division must not exceed 1.") + if fraction_sum_from_list_tuple < 1: + warnings.warn(f"Sum of fractions for division is {sum(division)}, which is below 1. Make sure that's the desired behavior. 
Some data will not be used in the current specification.") + elif isinstance(division, dict): + values = list(division.values()) + if not all(0 < x <= 1 for x in values): + raise ValueError( + "All fractions must be greater than 0 and smaller or equal to 1.") + if sum(values) > 1: + raise ValueError("Sum of fractions must not exceed 1.") + if sum(division) < 1: + warnings.warn( + f"Sum of fractions in `partition_split` is {values}, which is below 1. Make sure that's the desired behavior. Some data will not be used in the current specification.") + else: + raise TypeError("`partition_split` must be a list, tuple, or dict.") + +def _check_division_config_correctness(division: Union[List[float], Tuple[float, ...], Dict[str, float]]) -> None: + _check_division_config_types_correctness(division) + _check_division_config_values_correctness(division) From 9f4b3106ae7cd6c99f3d5012744654d5b6bf094c Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Thu, 15 Feb 2024 14:00:56 +0100 Subject: [PATCH 03/18] Update load_partition --- datasets/flwr_datasets/federated_dataset.py | 48 ++++++++++---- datasets/flwr_datasets/utils_test.py | 70 +++++++++++++++++++++ 2 files changed, 106 insertions(+), 12 deletions(-) create mode 100644 datasets/flwr_datasets/utils_test.py diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index 63b307f24600..11281f95b3ec 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -13,9 +13,8 @@ # limitations under the License. 
# ============================================================================== """FederatedDataset.""" - - -from typing import Dict, Optional, Tuple, Union, List +import warnings +from typing import Dict, List, Optional, Tuple, Union import datasets from datasets import Dataset, DatasetDict @@ -29,6 +28,8 @@ ) +# flake8: noqa: E501 +# pylint: disable=line-too-long class FederatedDataset: """Representation of a dataset for federated learning/evaluation/analytics. @@ -52,8 +53,9 @@ class FederatedDataset: (representing the number of IID partitions that this split should be partitioned into). One or multiple `Partitioner` objects can be specified in that manner, but at most, one per split. - partition_division : Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]] - Divide each partition into a splits (e.g. into a train and evaluation) and + partition_division : Optional[Union[List[float], Tuple[float, ...], Dict[str, + float]]] + Divide each partition into splits (e.g. into train and evaluation) and control the size of the splits - fractions of the data. You can also name the splits for verification. 
shuffle : bool @@ -87,7 +89,9 @@ def __init__( subset: Optional[str] = None, resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] = None, partitioners: Dict[str, Union[Partitioner, int]], - partition_division: Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]] = None, + partition_division: Optional[ + Union[List[float], Tuple[float, ...], Dict[str, float]] + ] = None, shuffle: bool = True, seed: Optional[int] = 42, ) -> None: @@ -109,7 +113,12 @@ def __init__( # Indicate if the dataset is prepared for `load_partition` or `load_full` self._dataset_prepared: bool = False - def load_partition(self, node_id: int, split: Optional[str] = None, inner_split: Optional[Union[int, str]]=None) -> Union[Dataset, List[Dataset], DatasetDict]: + def load_partition( + self, + node_id: int, + split: Optional[str] = None, + inner_split: Optional[Union[int, str]] = None, + ) -> Union[Dataset, List[Dataset], DatasetDict]: """Load the partition specified by the idx in the selected split. If the partition is divided as specified in `partition_split` then you can @@ -160,13 +169,28 @@ def load_partition(self, node_id: int, split: Optional[str] = None, inner_split: self._assign_dataset_to_partitioner(split) partition = partitioner.load_partition(node_id) if self._partition_division is None: + if inner_split is not None: + warnings.warn( + "`inner_split` was specified but it does not have any " + "effect when the `partition_division` is None." 
+ ) return partition - else: - divided_partition = divide_dataset(partition, self._partition_division) - if inner_split is None: - return divided_partition - else: + divided_partition: Union[List[Dataset], DatasetDict] = divide_dataset( + partition, self._partition_division + ) + if inner_split is None: + return divided_partition + if isinstance(divided_partition, list): + if isinstance(inner_split, int): return divided_partition[inner_split] + raise ValueError( + "The type of `inner_split` = {type(inner_split)}, does not " + "work with the specified type `divide_partition` that " + "results in list." + ) + if isinstance(divided_partition, DatasetDict): + return divided_partition[inner_split] + raise ValueError("The types of divided_partition should be list or dict only.") def load_full(self, split: str) -> Dataset: """Load the full split of the dataset. diff --git a/datasets/flwr_datasets/utils_test.py b/datasets/flwr_datasets/utils_test.py new file mode 100644 index 000000000000..debb98976e37 --- /dev/null +++ b/datasets/flwr_datasets/utils_test.py @@ -0,0 +1,70 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Utils tests.""" +import unittest +from typing import Dict, List, Tuple, Union + +from parameterized import parameterized_class + +from datasets import Dataset, DatasetDict +from flwr_datasets.utils import divide_dataset + + +@parameterized_class( + ( + "divide", + "sizes", + ), + [ + ((0.2, 0.8), [8, 32]), + ([0.2, 0.8], [8, 32]), + ({"train": 0.2, "test": 0.8}, [8, 32]), + # Not full dataset + ([0.2, 0.1], [8, 4]), + ((0.2, 0.1), [8, 4]), + ({"train": 0.2, "test": 0.1}, [8, 4]), + ], +) +class UtilsTests(unittest.TestCase): + """Utils tests.""" + + divide: Union[List[float], Tuple[float, ...], Dict[str, float]] + sizes: Tuple[int] + + def setUp(self) -> None: + """Set up a dataset.""" + self.dataset = Dataset.from_dict({"data": range(40)}) + + def test_correct_sizes(self) -> None: + """Test correct size of the division.""" + divided_dataset = divide_dataset(self.dataset, self.divide) + if isinstance(divided_dataset, (list, tuple)): + lengths = [len(split) for split in divided_dataset] + else: + lengths = [len(split) for split in divided_dataset.values()] + + self.assertEqual(lengths, self.sizes) + + def test_correct_return_types(self) -> None: + """Test correct types of the divided dataset based on the config.""" + divided_dataset = divide_dataset(self.dataset, self.divide) + if isinstance(self.divide, (list, tuple)): + self.assertIsInstance(divided_dataset, list) + else: + self.assertIsInstance(divided_dataset, DatasetDict) + + +if __name__ == "__main__": + unittest.main() From 876dad5ba697cdcb979870b7cab5c34a14590177 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Thu, 15 Feb 2024 14:01:16 +0100 Subject: [PATCH 04/18] Update divsion utils --- datasets/flwr_datasets/utils.py | 119 ++++++++++++++++++++++---------- 1 file changed, 84 insertions(+), 35 deletions(-) diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index 7431a9a5d0b2..d09cf938d635 
100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -16,10 +16,9 @@ import warnings -from typing import Dict, Optional, Tuple, Union, cast, List +from typing import Dict, List, Optional, Tuple, Union, cast from datasets import Dataset, DatasetDict - from flwr_datasets.partitioner import IidPartitioner, Partitioner from flwr_datasets.resplitter import Resplitter from flwr_datasets.resplitter.merge_resplitter import MergeResplitter @@ -89,83 +88,133 @@ def _check_if_dataset_tested(dataset: str) -> None: ) -def divide_dataset(dataset: Dataset, division: Union[List[float], Tuple[float, ...], Dict[str, float]]) -> Union[Dataset, List[Dataset], DatasetDict]: +def divide_dataset( + dataset: Dataset, division: Union[List[float], Tuple[float, ...], Dict[str, float]] +) -> Union[List[Dataset], DatasetDict]: + """Divide the dataset according to the `division`. + + The division support varying number of splits, which you can name. The splits are + created from the beginning of the dataset. + + Parameters + ---------- + dataset: Dataset + Dataset to be divided. + division: Union[List[float], Tuple[float, ...], Dict[str, float]] + Configuration specifying how the dataset is divided. Each fraction has to be + >0 and <1. They have to sum up to at most 1 (smaller sum is possible). 
+ Returns + ------- + `List[Dataset]` is returned in case of `List` or `Tuple` of the division + `DatasetDict` in case of `Dict` specification + """ dataset_length = len(dataset) ranges = _create_division_indices_ranges(dataset_length, division) if isinstance(division, (list, tuple)): - split_partition = [] - for r in ranges: - split_partition.append(dataset.select(r)) - return split_partition - elif isinstance(division, dict): - split_partition = {} + split_partition: List[Dataset] = [] + for single_range in ranges: + split_partition.append(dataset.select(single_range)) + return split_partition + if isinstance(division, dict): + split_partition_dict: Dict[str, Dataset] = {} ranges = _create_division_indices_ranges(dataset_length, division) - for split_name, r in zip(division.keys(), ranges): - split_partition[split_name] = dataset.select(r) - return DatasetDict(split_partition) - else: - TypeError( - f"The type of the `division` should be dict, tuple or list but is {type(division)} instead.") - - -def _create_division_indices_ranges(dataset_length: int, division: Union[List[float], Tuple[float, ...], Dict[str, float]]) -> List[range]: + for split_name, single_range in zip(division.keys(), ranges): + split_partition_dict[split_name] = dataset.select(single_range) + return DatasetDict(split_partition_dict) + raise TypeError( + f"The type of the `division` should be dict, " + f"tuple or list but is {type(division)} instead." 
+ ) + + +def _create_division_indices_ranges( + dataset_length: int, + division: Union[List[float], Tuple[float, ...], Dict[str, float]], +) -> List[range]: ranges = [] if isinstance(division, (list, tuple)): start_idx = 0 end_idx = 0 for fraction in division: - end_idx = int(dataset_length * fraction) + end_idx += int(dataset_length * fraction) ranges.append(range(start_idx, end_idx)) - start_idx = end_idx + start_idx += end_idx elif isinstance(division, dict): ranges = [] start_idx = 0 end_idx = 0 for fraction in division.values(): - end_idx = int(dataset_length * fraction) + end_idx += int(dataset_length * fraction) ranges.append(range(start_idx, end_idx)) - start_idx = end_idx + start_idx += end_idx else: - TypeError("The type of the `partition_split` should be dict, tuple or list but is {type(self.partition_split)} instead. ") + TypeError( + f"The type of the `division` should be dict, " + f"tuple or list but is {type(division)} instead. " + ) return ranges -def _check_division_config_types_correctness(division: Union[List[float], Tuple[float, ...], Dict[str, float]]) -> None: +def _check_division_config_types_correctness( + division: Union[List[float], Tuple[float, ...], Dict[str, float]] +) -> None: if isinstance(division, (list, tuple)): if not all(isinstance(x, float) for x in division): raise TypeError( - "List or tuple values of `partition_split` must contain only floats, other types are not allowed.") + "List or tuple values of `division` must contain only floats, " + "other types are not allowed." + ) elif isinstance(division, dict): if not all(isinstance(x, float) for x in division.values()): raise TypeError( - "Dict values of `partition_split` must be only floats, other types are not allowed.") + "Dict values of `division` must be only floats, " + "other types are not allowed." 
+ ) else: - raise TypeError("`partition_split` must be a list, tuple, or dict.") + raise TypeError("`division` must be a list, tuple, or dict.") + -def _check_division_config_values_correctness(division: Union[List[float], Tuple[float, ...], Dict[str, float]]) -> None: +def _check_division_config_values_correctness( + division: Union[List[float], Tuple[float, ...], Dict[str, float]] +) -> None: if isinstance(division, (list, tuple)): if not all(0 < x <= 1 for x in division): raise ValueError( - "All fractions for the division must be greater than 0 and smaller or equal to 1.") + "All fractions for the division must be greater than 0 and smaller or " + "equal to 1." + ) fraction_sum_from_list_tuple = sum(division) if fraction_sum_from_list_tuple > 1: raise ValueError("Sum of fractions for division must not exceed 1.") if fraction_sum_from_list_tuple < 1: - warnings.warn(f"Sum of fractions for division is {sum(division)}, which is below 1. Make sure that's the desired behavior. Some data will not be used in the current specification.") + warnings.warn( + f"Sum of fractions for division is {sum(division)}, which is below 1. " + f"Make sure that's the desired behavior. Some data will not be used " + f"in the current specification.", + stacklevel=1, + ) elif isinstance(division, dict): values = list(division.values()) if not all(0 < x <= 1 for x in values): raise ValueError( - "All fractions must be greater than 0 and smaller or equal to 1.") + "All fractions must be greater than 0 and smaller or equal to 1." + ) if sum(values) > 1: raise ValueError("Sum of fractions must not exceed 1.") - if sum(division) < 1: + if sum(values) < 1: warnings.warn( - f"Sum of fractions in `partition_split` is {values}, which is below 1. Make sure that's the desired behavior. Some data will not be used in the current specification.") + f"Sum of fractions in `division` is {values}, which is below 1. " + f"Make sure that's the desired behavior. 
Some data will not be used " + f"in the current specification.", + stacklevel=1, + ) else: - raise TypeError("`partition_split` must be a list, tuple, or dict.") + raise TypeError("`division` must be a list, tuple, or dict.") + -def _check_division_config_correctness(division: Union[List[float], Tuple[float, ...], Dict[str, float]]) -> None: +def _check_division_config_correctness( + division: Union[List[float], Tuple[float, ...], Dict[str, float]] +) -> None: _check_division_config_types_correctness(division) _check_division_config_values_correctness(division) From 1e77c06cdd80fe2a218c7e6f1398f34f7f7d1844 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Thu, 15 Feb 2024 14:01:24 +0100 Subject: [PATCH 05/18] Update tests --- .../flwr_datasets/federated_dataset_test.py | 40 ++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py index e02b6ed5add8..f768f6eaa923 100644 --- a/datasets/flwr_datasets/federated_dataset_test.py +++ b/datasets/flwr_datasets/federated_dataset_test.py @@ -17,7 +17,7 @@ import unittest -from typing import Dict, Union +from typing import Dict, List, Optional, Tuple, Union from unittest.mock import Mock, patch import pytest @@ -67,6 +67,44 @@ def test_load_partition_size(self, _: str, train_num_partitions: int) -> None: len(dataset_partition0), len(dataset["train"]) // train_num_partitions ) + @parameterized.expand( # type: ignore + [ + ((0.2, 0.8), None, 2), + ((0.2, 0.8), 0, None), + ([0.2, 0.8], None, 2), + ({"train": 0.2, "test": 0.8}, None, 2), + # Not full dataset + ([0.2, 0.1], None, 2), + ((0.2, 0.1), 1, None), + ({"train": 0.2, "test": 0.1}, None, 2), + ({"train": 0.2, "test": 0.1}, "test", None), + (None, None, None), + (None, "train", None), + ], + ) + def test_divide_partition_integration_size( + self, + partition_division: Optional[ + Union[List[float], Tuple[float, ...], Dict[str, float]] + ], + 
inner_split: Optional[str], + expected_length: Optional[int], + ): + """Test is the `partition_division` and `inner_split` create correct data.""" + dataset_fds = FederatedDataset( + dataset=self.dataset_name, + partitioners={"train": 100}, + partition_division=partition_division, + ) + partition = dataset_fds.load_partition(0, "train", inner_split=inner_split) + if partition_division is None: + self.assertEqual(expected_length, None) + else: + if inner_split is None: + self.assertEqual(len(partition), expected_length) + else: + self.assertIsInstance(partition, Dataset) + def test_load_full(self) -> None: """Test if the load_full works with the correct split name.""" dataset_fds = FederatedDataset( From ecf2b401fb8976ee6fbfb11739fb69ddce1a163b Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Tue, 5 Mar 2024 10:20:33 +0100 Subject: [PATCH 06/18] Remove inner_split keyword from load_partition --- datasets/flwr_datasets/federated_dataset.py | 28 ++--------------- .../flwr_datasets/federated_dataset_test.py | 30 ++++++++----------- 2 files changed, 16 insertions(+), 42 deletions(-) diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index 11281f95b3ec..b0c67915cf87 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -13,7 +13,8 @@ # limitations under the License. # ============================================================================== """FederatedDataset.""" -import warnings + + from typing import Dict, List, Optional, Tuple, Union import datasets @@ -117,7 +118,6 @@ def load_partition( self, node_id: int, split: Optional[str] = None, - inner_split: Optional[Union[int, str]] = None, ) -> Union[Dataset, List[Dataset], DatasetDict]: """Load the partition specified by the idx in the selected split. 
@@ -140,11 +140,6 @@ def load_partition( not need to provide this argument, but if `partitioners={"train": 10, "test": 100}`, you need to set it to differentiate which partitioner should be used. - inner_split: Optional[Union[int, str]] - In case of the `partition_split` specification you can identify the split - after the division that you want to return, otherwise all are returned. If - `partition_split` is list or tuple specify int, if it is dict specify the - str. Returns ------- @@ -169,28 +164,11 @@ def load_partition( self._assign_dataset_to_partitioner(split) partition = partitioner.load_partition(node_id) if self._partition_division is None: - if inner_split is not None: - warnings.warn( - "`inner_split` was specified but it does not have any " - "effect when the `partition_division` is None." - ) return partition divided_partition: Union[List[Dataset], DatasetDict] = divide_dataset( partition, self._partition_division ) - if inner_split is None: - return divided_partition - if isinstance(divided_partition, list): - if isinstance(inner_split, int): - return divided_partition[inner_split] - raise ValueError( - "The type of `inner_split` = {type(inner_split)}, does not " - "work with the specified type `divide_partition` that " - "results in list." - ) - if isinstance(divided_partition, DatasetDict): - return divided_partition[inner_split] - raise ValueError("The types of divided_partition should be list or dict only.") + return divided_partition def load_full(self, split: str) -> Dataset: """Load the full split of the dataset. 
diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py index f768f6eaa923..0500fb5b821c 100644 --- a/datasets/flwr_datasets/federated_dataset_test.py +++ b/datasets/flwr_datasets/federated_dataset_test.py @@ -69,17 +69,17 @@ def test_load_partition_size(self, _: str, train_num_partitions: int) -> None: @parameterized.expand( # type: ignore [ - ((0.2, 0.8), None, 2), - ((0.2, 0.8), 0, None), - ([0.2, 0.8], None, 2), - ({"train": 0.2, "test": 0.8}, None, 2), + ((0.2, 0.8), 2), + ((0.2, 0.8), None), + ([0.2, 0.8], 2), + ({"train": 0.2, "test": 0.8}, 2), # Not full dataset - ([0.2, 0.1], None, 2), - ((0.2, 0.1), 1, None), - ({"train": 0.2, "test": 0.1}, None, 2), - ({"train": 0.2, "test": 0.1}, "test", None), - (None, None, None), - (None, "train", None), + ([0.2, 0.1], 2), + ((0.2, 0.1), None), + ({"train": 0.2, "test": 0.1}, 2), + ({"train": 0.2, "test": 0.1}, None), + (None, None), + (None, None), ], ) def test_divide_partition_integration_size( @@ -87,23 +87,19 @@ def test_divide_partition_integration_size( partition_division: Optional[ Union[List[float], Tuple[float, ...], Dict[str, float]] ], - inner_split: Optional[str], expected_length: Optional[int], ): - """Test is the `partition_division` and `inner_split` create correct data.""" + """Test is the `partition_division` create correct data.""" dataset_fds = FederatedDataset( dataset=self.dataset_name, partitioners={"train": 100}, partition_division=partition_division, ) - partition = dataset_fds.load_partition(0, "train", inner_split=inner_split) + partition = dataset_fds.load_partition(0, "train") if partition_division is None: self.assertEqual(expected_length, None) else: - if inner_split is None: - self.assertEqual(len(partition), expected_length) - else: - self.assertIsInstance(partition, Dataset) + self.assertEqual(len(partition), expected_length) def test_load_full(self) -> None: """Test if the load_full works with the correct split name.""" From 
d08b44bdf493e9f3d8b0d05f044631aa53063f20 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Tue, 5 Mar 2024 11:06:49 +0100 Subject: [PATCH 07/18] Expose utils module --- datasets/flwr_datasets/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datasets/flwr_datasets/__init__.py b/datasets/flwr_datasets/__init__.py index 48d993037708..0b9a6685427b 100644 --- a/datasets/flwr_datasets/__init__.py +++ b/datasets/flwr_datasets/__init__.py @@ -16,6 +16,7 @@ from flwr_datasets import partitioner, resplitter +from flwr_datasets import utils as utils from flwr_datasets.common.version import package_version as _package_version from flwr_datasets.federated_dataset import FederatedDataset @@ -23,6 +24,7 @@ "FederatedDataset", "partitioner", "resplitter", + "utils", ] __version__ = _package_version From 0ed75856403104c17dfed4b630f2eefe0e365da3 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Tue, 5 Mar 2024 11:35:01 +0100 Subject: [PATCH 08/18] Expose utils module --- datasets/flwr_datasets/federated_dataset_test.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py index 0500fb5b821c..66071af8e348 100644 --- a/datasets/flwr_datasets/federated_dataset_test.py +++ b/datasets/flwr_datasets/federated_dataset_test.py @@ -70,15 +70,10 @@ def test_load_partition_size(self, _: str, train_num_partitions: int) -> None: @parameterized.expand( # type: ignore [ ((0.2, 0.8), 2), - ((0.2, 0.8), None), - ([0.2, 0.8], 2), ({"train": 0.2, "test": 0.8}, 2), # Not full dataset ([0.2, 0.1], 2), - ((0.2, 0.1), None), ({"train": 0.2, "test": 0.1}, 2), - ({"train": 0.2, "test": 0.1}, None), - (None, None), (None, None), ], ) From 494fdeb4125fac84f1b599067f21aca47e93ae4c Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Tue, 5 Mar 2024 16:58:34 +0100 Subject: [PATCH 09/18] Add an example to FDS docs --- datasets/flwr_datasets/federated_dataset.py | 18 +++++++++++------- 1 file 
changed, 11 insertions(+), 7 deletions(-) diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index b0c67915cf87..b75c1e41cc12 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -56,7 +56,7 @@ class FederatedDataset: but at most, one per split. partition_division : Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]] - Divide each partition into splits (e.g. into train and evaluation) and + Divide each partition into splits (e.g. into train and validation) and control the size of the splits - fractions of the data. You can also name the splits for verification. shuffle : bool @@ -72,14 +72,18 @@ class FederatedDataset: Use MNIST dataset for Federated Learning with 100 clients (edge devices): >>> mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": 100}) - - Load partition for client with ID 10. - + >>> # Load partition for client with ID 10. >>> partition = mnist_fds.load_partition(10, "train") - - Use test split for centralized evaluation. - + >>> # Use test split for centralized evaluation. 
>>> centralized = mnist_fds.load_full("test")
+
+    Automatically divide the data returned from `load_partition`
+    >>> mnist_fds = FederatedDataset(
+    >>>     dataset="mnist",
+    >>>     partitioners={"train": 100},
+    >>>     partition_division=[0.8, 0.2],
+    >>> )
+    >>> partition_train, partition_test = mnist_fds.load_partition(10, "train")
     """

     # pylint: disable=too-many-instance-attributes

From 2777491d1efd24add554351f384c718423c64c78 Mon Sep 17 00:00:00 2001
From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com>
Date: Tue, 5 Mar 2024 16:59:12 +0100
Subject: [PATCH 10/18] Update datasets/flwr_datasets/utils.py

Co-authored-by: Javier
---
 datasets/flwr_datasets/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py
index d09cf938d635..afc0da170b86 100644
--- a/datasets/flwr_datasets/utils.py
+++ b/datasets/flwr_datasets/utils.py
@@ -102,7 +102,7 @@ def divide_dataset(
         Dataset to be divided.
     division: Union[List[float], Tuple[float, ...], Dict[str, float]]
         Configuration specifying how the dataset is divided. Each fraction has to be
-        >0 and <1. They have to sum up to at most 1 (smaller sum is possible).
+        >0 and <=1. They have to sum up to at most 1 (smaller sum is possible).
Returns ------- From 7c7412385f14eee30f37a0653ac6c5d5c94fd0a1 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Tue, 5 Mar 2024 17:00:43 +0100 Subject: [PATCH 11/18] Remove redundant computation --- datasets/flwr_datasets/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index d09cf938d635..59ca8c7b1e91 100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -118,7 +118,6 @@ def divide_dataset( return split_partition if isinstance(division, dict): split_partition_dict: Dict[str, Dataset] = {} - ranges = _create_division_indices_ranges(dataset_length, division) for split_name, single_range in zip(division.keys(), ranges): split_partition_dict[split_name] = dataset.select(single_range) return DatasetDict(split_partition_dict) From 57c7c63fd72c20f4b52c6fc4725fbf51cfa10e0e Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Tue, 5 Mar 2024 17:06:21 +0100 Subject: [PATCH 12/18] Add an example to divide_dataset function --- datasets/flwr_datasets/utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index eacb8ebacdca..194afc220338 100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -108,6 +108,15 @@ def divide_dataset( ------- `List[Dataset]` is returned in case of `List` or `Tuple` of the division `DatasetDict` in case of `Dict` specification + + Examples + -------- + Use `divide_dataset` with division specified as a list. 
+ >>> from flwr_datasets import FederatedDataset + >>> from flwr_datasets.utils import divide_dataset + >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100}) + >>> partition = fds.load_partition(0) + >>> train, test = divide_dataset(dataset=partition, division=[0.8, 0.2]) """ dataset_length = len(dataset) ranges = _create_division_indices_ranges(dataset_length, division) From 6486c8aca5eb3d7eaf882865523654100dc876e6 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Tue, 5 Mar 2024 17:09:31 +0100 Subject: [PATCH 13/18] Add a second example to divide_dataset fnc --- datasets/flwr_datasets/utils.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index 194afc220338..1b9971c65280 100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -116,7 +116,17 @@ def divide_dataset( >>> from flwr_datasets.utils import divide_dataset >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100}) >>> partition = fds.load_partition(0) - >>> train, test = divide_dataset(dataset=partition, division=[0.8, 0.2]) + >>> division = [0.8, 0.2] + >>> train, test = divide_dataset(dataset=partition, division=division) + + Use `divide_dataset` with division specified as a dict. 
+ >>> from flwr_datasets import FederatedDataset + >>> from flwr_datasets.utils import divide_dataset + >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100}) + >>> partition = fds.load_partition(0) + >>> division = {"train": 0.8, "test": 0.2} + >>> train_test = divide_dataset(dataset=partition, division=division) + >>> train, test = train_test["train"], train_test["test"] """ dataset_length = len(dataset) ranges = _create_division_indices_ranges(dataset_length, division) From a8b0bfd5dbfb626a753c380ab702c957cb57d000 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Tue, 5 Mar 2024 17:14:48 +0100 Subject: [PATCH 14/18] Fix load_partition docs --- datasets/flwr_datasets/federated_dataset.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index b75c1e41cc12..176ded7cf431 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -125,11 +125,6 @@ def load_partition( ) -> Union[Dataset, List[Dataset], DatasetDict]: """Load the partition specified by the idx in the selected split. - If the partition is divided as specified in `partition_split` then you can - return only the specific split by specifying `inner_split` otherwise all split - are return. If the `partition_split` was specified as `Dict` give the string - name, otherwise give the index. - The dataset is downloaded only when the first call to `load_partition` or `load_full` is made. @@ -148,11 +143,11 @@ def load_partition( Returns ------- partition : Union[Dataset, List[Dataset], DatasetDict] - Single partition from the dataset split which can be further divided. - If `partition_split` specified and `inner_split` not given then - `List[Dataset]` is returned in case of `List` or `Tuple` specification of - the `partition_split` and `DatasetDict` in case of `Dict` specification. 
- If `partition_split` specified and the `inner_split` given, then Dataset is + Undivided or divided partition from the dataset split. + If `partition_division` is not specified then `Dataset` is returned. + If `partition_division` is specified as `List` or `Tuple` then + `List[Dataset]` is returned. + If `partition_division` is specified as `Dict` then `DatasetDict` is returned. """ if not self._dataset_prepared: From 6052e2a846854ce6bfd810b211acb76712e0b604 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Wed, 6 Mar 2024 09:12:41 +0100 Subject: [PATCH 15/18] Update year in the Copyright notice --- datasets/flwr_datasets/utils_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/flwr_datasets/utils_test.py b/datasets/flwr_datasets/utils_test.py index debb98976e37..26f24519eb76 100644 --- a/datasets/flwr_datasets/utils_test.py +++ b/datasets/flwr_datasets/utils_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 497a5cc5b0ba2337d2723458cbb7fcc09ab5dcf1 Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Sun, 10 Mar 2024 14:28:06 +0100 Subject: [PATCH 16/18] Apply suggestions from code review Co-authored-by: Daniel J. Beutel --- datasets/flwr_datasets/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index 1b9971c65280..dcd016c516b5 100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -98,7 +98,7 @@ def divide_dataset( Parameters ---------- - dataset: Dataset + dataset : Dataset Dataset to be divided. division: Union[List[float], Tuple[float, ...], Dict[str, float]] Configuration specifying how the dataset is divided. 
Each fraction has to be @@ -114,6 +114,7 @@ def divide_dataset( Use `divide_dataset` with division specified as a list. >>> from flwr_datasets import FederatedDataset >>> from flwr_datasets.utils import divide_dataset + >>> >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100}) >>> partition = fds.load_partition(0) >>> division = [0.8, 0.2] @@ -122,6 +123,7 @@ def divide_dataset( Use `divide_dataset` with division specified as a dict. >>> from flwr_datasets import FederatedDataset >>> from flwr_datasets.utils import divide_dataset + >>> >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100}) >>> partition = fds.load_partition(0) >>> division = {"train": 0.8, "test": 0.2} From f0a906e6ad1c1a9194e5fcf0409cfeaacf51d37e Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Sun, 10 Mar 2024 15:56:15 +0100 Subject: [PATCH 17/18] Fix return type docs --- datasets/flwr_datasets/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index dcd016c516b5..38382508035c 100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -106,8 +106,9 @@ def divide_dataset( Returns ------- - `List[Dataset]` is returned in case of `List` or `Tuple` of the division - `DatasetDict` in case of `Dict` specification + divided_dataset : Union[List[Dataset], DatasetDict] + If `division` is `List` or `Tuple` then `List[Dataset]` is returned else if + `division` is `Dict` then `DatasetDict` is returned. 
 
     Examples
     --------
From 4a4de2608c0dfc705a8724030d54b9570ff52665 Mon Sep 17 00:00:00 2001
From: Adam Narozniak
Date: Sun, 10 Mar 2024 18:05:32 +0100
Subject: [PATCH 18/18] Support partitioner-specific partition_division

---
 datasets/flwr_datasets/federated_dataset.py   | 98 +++++++++++++++++--
 .../flwr_datasets/federated_dataset_test.py   | 27 +++--
 2 files changed, 109 insertions(+), 16 deletions(-)

diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py
index 176ded7cf431..588d1ab40aec 100644
--- a/datasets/flwr_datasets/federated_dataset.py
+++ b/datasets/flwr_datasets/federated_dataset.py
@@ -15,7 +15,7 @@
 """FederatedDataset."""
 
 
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union, cast
 
 import datasets
 from datasets import Dataset, DatasetDict
@@ -54,11 +54,19 @@ class FederatedDataset:
         (representing the number of IID partitions that this split should be
         partitioned into). One or multiple `Partitioner` objects can be specified in
         that manner, but at most, one per split.
-    partition_division : Optional[Union[List[float], Tuple[float, ...], Dict[str,
-        float]]]
-        Divide each partition into splits (e.g. into train and validation) and
-        control the size of the splits - fractions of the data. You can also name the
-        splits for verification.
+    partition_division : Optional[Union[List[float], Tuple[float, ...],
+        Dict[str, float], Dict[str, Optional[Union[List[float], Tuple[float, ...],
+        Dict[str, float]]]]]]
+        Fractions specifying the division of the partition associated with a certain
+        split (and partitioner) that enables returning already divided partition from
+        the `load_partition` method. You can think of this as on-edge division of the
+        data into multiple divisions (e.g. into train and validation). You can also
+        name the divisions by using the Dict or specify it as a List/Tuple. 
If you + specified a single partitioner you can provide the simplified form e.g. + [0.8, 0.2] or {"partition_train": 0.8, "partition_test": 0.2} but when multiple + partitioners are specified you need to indicate the result of which partitioner + are further divided e.g. {"train": [0.8, 0.2]} would result in dividing only the + partitions that are created from the "train" split. shuffle : bool Whether to randomize the order of samples. Applied prior to resplitting, speratelly to each of the present splits in the dataset. It uses the `seed` @@ -95,7 +103,15 @@ def __init__( resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] = None, partitioners: Dict[str, Union[Partitioner, int]], partition_division: Optional[ - Union[List[float], Tuple[float, ...], Dict[str, float]] + Union[ + List[float], + Tuple[float, ...], + Dict[str, float], + Dict[ + str, + Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]], + ], + ] ] = None, shuffle: bool = True, seed: Optional[int] = 42, @@ -109,7 +125,9 @@ def __init__( self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners( partitioners ) - self._partition_division = partition_division + self._partition_division = self._initialize_partition_division( + partition_division + ) self._shuffle = shuffle self._seed = seed # _dataset is prepared lazily on the first call to `load_partition` @@ -164,8 +182,11 @@ def load_partition( partition = partitioner.load_partition(node_id) if self._partition_division is None: return partition + partition_division = self._partition_division.get(split) + if partition_division is None: + return partition divided_partition: Union[List[Dataset], DatasetDict] = divide_dataset( - partition, self._partition_division + partition, partition_division ) return divided_partition @@ -261,3 +282,62 @@ def _check_if_no_split_keyword_possible(self) -> None: "Please set the `split` argument. You can only omit the split keyword " "if there is exactly one partitioner specified." 
) + + def _initialize_partition_division( + self, + partition_division: Optional[ + Union[ + List[float], + Tuple[float, ...], + Dict[str, float], + Dict[ + str, + Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]], + ], + ] + ], + ) -> Optional[ + Dict[ + str, + Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]], + ] + ]: + """Create the partition division in the full format. + + Reduced format (possible if only one partitioner exist): + + Union[List[float], Tuple[float, ...], Dict[str, float] + + Full format: Dict[str, Reduced format] + Full format represents the split to division mapping. + """ + # Check for simple dict, list, or tuple types directly + if isinstance(partition_division, (list, tuple)) or ( + isinstance(partition_division, dict) + and all(isinstance(value, float) for value in partition_division.values()) + ): + if len(self._partitioners) > 1: + raise ValueError( + f"The specified partition_division {partition_division} does not " + f"provide mapping to split but more than one partitioners is " + f"specified. Please adjust the partition_division specification to " + f"have the split names as the keys." 
+ ) + return cast( + Dict[ + str, + Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]], + ], + {list(self._partitioners.keys())[0]: partition_division}, + ) + if isinstance(partition_division, dict): + return cast( + Dict[ + str, + Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]], + ], + partition_division, + ) + if partition_division is None: + return None + raise TypeError("Unsupported type for partition_division") diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py index 66071af8e348..e01f56342954 100644 --- a/datasets/flwr_datasets/federated_dataset_test.py +++ b/datasets/flwr_datasets/federated_dataset_test.py @@ -69,25 +69,38 @@ def test_load_partition_size(self, _: str, train_num_partitions: int) -> None: @parameterized.expand( # type: ignore [ - ((0.2, 0.8), 2), - ({"train": 0.2, "test": 0.8}, 2), + ((0.2, 0.8), 2, False), + ({"train": 0.2, "test": 0.8}, 2, False), + ({"train": {"train": 0.2, "test": 0.8}}, 2, True), # Not full dataset - ([0.2, 0.1], 2), - ({"train": 0.2, "test": 0.1}, 2), - (None, None), + ([0.2, 0.1], 2, False), + ({"train": 0.2, "test": 0.1}, 2, False), + (None, None, False), ], ) def test_divide_partition_integration_size( self, partition_division: Optional[ - Union[List[float], Tuple[float, ...], Dict[str, float]] + Union[ + List[float], + Tuple[float, ...], + Dict[str, float], + Dict[ + str, + Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]], + ], + ] ], expected_length: Optional[int], + add_test_partitioner: bool, ): """Test is the `partition_division` create correct data.""" + partitioners: Dict[str, Union[Partitioner, int]] = {"train": 10} + if add_test_partitioner: + partitioners[self.test_split] = 10 dataset_fds = FederatedDataset( dataset=self.dataset_name, - partitioners={"train": 100}, + partitioners=partitioners, partition_division=partition_division, ) partition = dataset_fds.load_partition(0, "train")