From 670783911166eb94908ffb4c439210f6b9f0a125 Mon Sep 17 00:00:00 2001 From: Adam Narozniak Date: Wed, 14 Feb 2024 16:23:02 +0100 Subject: [PATCH] wip: partition division utils --- datasets/flwr_datasets/utils.py | 86 ++++++++++++++++++++++++++++++++- 1 file changed, 85 insertions(+), 1 deletion(-) diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index e3d0fdfffa63..7431a9a5d0b2 100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -16,7 +16,9 @@ import warnings -from typing import Dict, Optional, Tuple, Union, cast +from typing import Dict, Optional, Tuple, Union, cast, List + +from datasets import Dataset, DatasetDict from flwr_datasets.partitioner import IidPartitioner, Partitioner from flwr_datasets.resplitter import Resplitter @@ -85,3 +87,85 @@ def _check_if_dataset_tested(dataset: str) -> None: f"The currently tested dataset are {tested_datasets}. Given: {dataset}.", stacklevel=1, ) + + +def divide_dataset(dataset: Dataset, division: Union[List[float], Tuple[float, ...], Dict[str, float]]) -> Union[Dataset, List[Dataset], DatasetDict]: + + dataset_length = len(dataset) + ranges = _create_division_indices_ranges(dataset_length, division) + if isinstance(division, (list, tuple)): + split_partition = [] + for r in ranges: + split_partition.append(dataset.select(r)) + return split_partition + elif isinstance(division, dict): + split_partition = {} + ranges = _create_division_indices_ranges(dataset_length, division) + for split_name, r in zip(division.keys(), ranges): + split_partition[split_name] = dataset.select(r) + return DatasetDict(split_partition) + else: + TypeError( + f"The type of the `division` should be dict, tuple or list but is {type(division)} instead.") + + +def _create_division_indices_ranges(dataset_length: int, division: Union[List[float], Tuple[float, ...], Dict[str, float]]) -> List[range]: + ranges = [] + if isinstance(division, (list, tuple)): + start_idx = 0 + end_idx = 0 + for fraction in division: + end_idx = int(dataset_length * fraction) + ranges.append(range(start_idx, end_idx)) + start_idx = end_idx + elif isinstance(division, dict): + ranges = [] + start_idx = 0 + end_idx = 0 + for fraction in division.values(): + end_idx = int(dataset_length * fraction) + ranges.append(range(start_idx, end_idx)) + start_idx = end_idx + else: + TypeError("The type of the `partition_split` should be dict, tuple or list but is {type(self.partition_split)} instead. ") + return ranges + + +def _check_division_config_types_correctness(division: Union[List[float], Tuple[float, ...], Dict[str, float]]) -> None: + if isinstance(division, (list, tuple)): + if not all(isinstance(x, float) for x in division): + raise TypeError( + "List or tuple values of `partition_split` must contain only floats, other types are not allowed.") + elif isinstance(division, dict): + if not all(isinstance(x, float) for x in division.values()): + raise TypeError( + "Dict values of `partition_split` must be only floats, other types are not allowed.") + else: + raise TypeError("`partition_split` must be a list, tuple, or dict.") + +def _check_division_config_values_correctness(division: Union[List[float], Tuple[float, ...], Dict[str, float]]) -> None: + if isinstance(division, (list, tuple)): + if not all(0 < x <= 1 for x in division): + raise ValueError( + "All fractions for the division must be greater than 0 and smaller or equal to 1.") + fraction_sum_from_list_tuple = sum(division) + if fraction_sum_from_list_tuple > 1: + raise ValueError("Sum of fractions for division must not exceed 1.") + if fraction_sum_from_list_tuple < 1: + warnings.warn(f"Sum of fractions for division is {sum(division)}, which is below 1. Make sure that's the desired behavior. Some data will not be used in the current specification.") + elif isinstance(division, dict): + values = list(division.values()) + if not all(0 < x <= 1 for x in values): + raise ValueError( + "All fractions must be greater than 0 and smaller or equal to 1.") + if sum(values) > 1: + raise ValueError("Sum of fractions must not exceed 1.") + if sum(division) < 1: + warnings.warn( + f"Sum of fractions in `partition_split` is {values}, which is below 1. Make sure that's the desired behavior. Some data will not be used in the current specification.") + else: + raise TypeError("`partition_split` must be a list, tuple, or dict.") + +def _check_division_config_correctness(division: Union[List[float], Tuple[float, ...], Dict[str, float]]) -> None: + _check_division_config_types_correctness(division) + _check_division_config_values_correctness(division)