Skip to content

Commit

Permalink
wip: partition division utils
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Feb 14, 2024
1 parent 507065b commit 6707839
Showing 1 changed file with 85 additions and 1 deletion.
86 changes: 85 additions & 1 deletion datasets/flwr_datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@


import warnings
from typing import Dict, Optional, Tuple, Union, cast
from typing import Dict, Optional, Tuple, Union, cast, List

from datasets import Dataset, DatasetDict

from flwr_datasets.partitioner import IidPartitioner, Partitioner
from flwr_datasets.resplitter import Resplitter
Expand Down Expand Up @@ -85,3 +87,85 @@ def _check_if_dataset_tested(dataset: str) -> None:
f"The currently tested dataset are {tested_datasets}. Given: {dataset}.",
stacklevel=1,
)


def divide_dataset(
    dataset: Dataset,
    division: Union[List[float], Tuple[float, ...], Dict[str, float]],
) -> Union[Dataset, List[Dataset], DatasetDict]:
    """Divide `dataset` into splits described by `division`.

    The index range of every split is computed by
    `_create_division_indices_ranges` from ``len(dataset)`` and `division`;
    each split is then materialized with ``dataset.select(index_range)``.

    Parameters
    ----------
    dataset : Dataset
        Dataset to divide.
    division : Union[List[float], Tuple[float, ...], Dict[str, float]]
        Fractions of the dataset assigned to the consecutive splits. For a
        dict, the keys are used as the names of the resulting splits.

    Returns
    -------
    Union[Dataset, List[Dataset], DatasetDict]
        A list of datasets (in the order of `division`) when `division` is a
        list or tuple, or a `DatasetDict` keyed by the keys of `division`
        when it is a dict.

    Raises
    ------
    TypeError
        If `division` is not a list, tuple or dict.
    """
    # Validate the container type before doing any work so that an unsupported
    # `division` fails loudly instead of silently producing `None`.
    if not isinstance(division, (list, tuple, dict)):
        raise TypeError(
            f"The type of the `division` should be dict, tuple or list but is "
            f"{type(division)} instead."
        )

    ranges = _create_division_indices_ranges(len(dataset), division)

    if isinstance(division, dict):
        # Iterating a dict yields its keys in insertion order, which matches
        # the order in which the ranges were created from its values.
        return DatasetDict(
            {
                split_name: dataset.select(index_range)
                for split_name, index_range in zip(division, ranges)
            }
        )
    return [dataset.select(index_range) for index_range in ranges]


def _create_division_indices_ranges(
    dataset_length: int,
    division: Union[List[float], Tuple[float, ...], Dict[str, float]],
) -> List[range]:
    """Create consecutive, non-overlapping index ranges, one per division.

    Parameters
    ----------
    dataset_length : int
        Number of samples in the dataset that is being divided.
    division : Union[List[float], Tuple[float, ...], Dict[str, float]]
        Fractions of the dataset assigned to the consecutive splits. For a
        dict, only the values are used (in insertion order).

    Returns
    -------
    List[range]
        One `range` per fraction. Every range starts where the previous one
        ended and holds ``int(dataset_length * fraction)`` indices, so the
        fractional remainder of each split is truncated.

    Raises
    ------
    TypeError
        If `division` is not a list, tuple or dict.
    """
    if isinstance(division, (list, tuple)):
        fractions = list(division)
    elif isinstance(division, dict):
        fractions = list(division.values())
    else:
        raise TypeError(
            f"The type of the `division` should be dict, tuple or list but is "
            f"{type(division)} instead."
        )

    ranges = []
    start_idx = 0
    for fraction in fractions:
        # The split size is relative to the whole dataset, so the end index
        # has to be offset by the position at which the previous split ended.
        end_idx = start_idx + int(dataset_length * fraction)
        ranges.append(range(start_idx, end_idx))
        start_idx = end_idx
    return ranges


def _check_division_config_types_correctness(
    division: Union[List[float], Tuple[float, ...], Dict[str, float]]
) -> None:
    """Check that `division` is a list, tuple or dict that holds only floats.

    Parameters
    ----------
    division : Union[List[float], Tuple[float, ...], Dict[str, float]]
        Division specification to validate. For a dict only the values are
        checked.

    Raises
    ------
    TypeError
        If `division` is not a list, tuple or dict, or if any of its values
        is not a `float` (integers such as ``1`` are rejected as well).
    """
    if isinstance(division, (list, tuple)):
        if not all(isinstance(x, float) for x in division):
            raise TypeError(
                "List or tuple values of `division` must contain only floats, "
                "other types are not allowed."
            )
    elif isinstance(division, dict):
        if not all(isinstance(x, float) for x in division.values()):
            raise TypeError(
                "Dict values of `division` must be only floats, "
                "other types are not allowed."
            )
    else:
        raise TypeError("`division` must be a list, tuple, or dict.")

def _check_division_config_values_correctness(
    division: Union[List[float], Tuple[float, ...], Dict[str, float]]
) -> None:
    """Check that the fractions in `division` describe a valid division.

    Every fraction has to be in the range (0, 1] and the fractions must not
    sum up to more than 1. A sum below 1 is allowed but triggers a warning.

    Parameters
    ----------
    division : Union[List[float], Tuple[float, ...], Dict[str, float]]
        Division specification to validate. For a dict only the values are
        checked.

    Raises
    ------
    ValueError
        If any fraction is outside of (0, 1] or if the fractions sum up to
        more than 1.
    TypeError
        If `division` is not a list, tuple or dict.
    """
    if isinstance(division, (list, tuple)):
        if not all(0 < x <= 1 for x in division):
            raise ValueError(
                "All fractions for the division must be greater than 0 and "
                "smaller or equal to 1."
            )
        fraction_sum_from_list_tuple = sum(division)
        if fraction_sum_from_list_tuple > 1:
            raise ValueError("Sum of fractions for division must not exceed 1.")
        if fraction_sum_from_list_tuple < 1:
            warnings.warn(
                f"Sum of fractions for division is {fraction_sum_from_list_tuple}, "
                "which is below 1. Make sure that's the desired behavior. Some "
                "data will not be used in the current specification."
            )
    elif isinstance(division, dict):
        values = list(division.values())
        if not all(0 < x <= 1 for x in values):
            raise ValueError(
                "All fractions must be greater than 0 and smaller or equal to 1."
            )
        # Sum the values, not the dict itself: iterating a dict yields its
        # keys (the split names), which cannot be added up.
        fraction_sum_from_dict = sum(values)
        if fraction_sum_from_dict > 1:
            raise ValueError("Sum of fractions must not exceed 1.")
        if fraction_sum_from_dict < 1:
            warnings.warn(
                f"Sum of fractions in `division` is {fraction_sum_from_dict}, "
                "which is below 1. Make sure that's the desired behavior. Some "
                "data will not be used in the current specification."
            )
    else:
        raise TypeError("`division` must be a list, tuple, or dict.")

def _check_division_config_correctness(
    division: Union[List[float], Tuple[float, ...], Dict[str, float]]
) -> None:
    """Validate `division` by running the type check and then the value check.

    Any exception raised by either check propagates to the caller.
    """
    # Order matters: the value check compares the entries numerically, so the
    # type check has to run first.
    validators = (
        _check_division_config_types_correctness,
        _check_division_config_values_correctness,
    )
    for validate in validators:
        validate(division)

0 comments on commit 6707839

Please sign in to comment.