Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add partition division #2951

Merged
merged 21 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions datasets/flwr_datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@


from flwr_datasets import partitioner, resplitter
from flwr_datasets import utils as utils
from flwr_datasets.common.version import package_version as _package_version
from flwr_datasets.federated_dataset import FederatedDataset

# Public API of the package: the names exported by `from flwr_datasets import *`.
__all__ = [
"FederatedDataset",
"partitioner",
"resplitter",
"utils",
]

# Package version, re-exported from `flwr_datasets.common.version`.
__version__ = _package_version
53 changes: 42 additions & 11 deletions datasets/flwr_datasets/federated_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""FederatedDataset."""


from typing import Dict, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import datasets
from datasets import Dataset, DatasetDict
Expand All @@ -25,9 +25,12 @@
_check_if_dataset_tested,
_instantiate_partitioners,
_instantiate_resplitter_if_needed,
divide_dataset,
)


# flake8: noqa: E501
# pylint: disable=line-too-long
class FederatedDataset:
"""Representation of a dataset for federated learning/evaluation/analytics.

Expand All @@ -51,6 +54,11 @@ class FederatedDataset:
(representing the number of IID partitions that this split should be partitioned
into). One or multiple `Partitioner` objects can be specified in that manner,
but at most, one per split.
partition_division : Optional[Union[List[float], Tuple[float, ...], Dict[str,
float]]]
Divide each partition into splits (e.g. into train and validation) and
control the size of the splits - fractions of the data. You can also name the
splits for verification.
shuffle : bool
Whether to randomize the order of samples. Applied prior to resplitting,
separately to each of the present splits in the dataset. It uses the `seed`
Expand All @@ -64,14 +72,18 @@ class FederatedDataset:
Use MNIST dataset for Federated Learning with 100 clients (edge devices):

>>> mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})

Load partition for client with ID 10.

>>> # Load partition for client with ID 10.
>>> partition = mnist_fds.load_partition(10, "train")

Use test split for centralized evaluation.

>>> # Use test split for centralized evaluation.
>>> centralized = mnist_fds.load_full("test")

Automatically divide the data returned from `load_partition`.

>>> mnist_fds = FederatedDataset(
...     dataset="mnist",
...     partitioners={"train": 100},
...     partition_division=[0.8, 0.2],
... )
>>> partition_train, partition_test = mnist_fds.load_partition(10, "train")
"""

# pylint: disable=too-many-instance-attributes
Expand All @@ -82,6 +94,9 @@ def __init__(
subset: Optional[str] = None,
resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] = None,
partitioners: Dict[str, Union[Partitioner, int]],
partition_division: Optional[
Union[List[float], Tuple[float, ...], Dict[str, float]]
] = None,
shuffle: bool = True,
seed: Optional[int] = 42,
) -> None:
Expand All @@ -94,6 +109,7 @@ def __init__(
self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
partitioners
)
self._partition_division = partition_division
self._shuffle = shuffle
self._seed = seed
# _dataset is prepared lazily on the first call to `load_partition`
Expand All @@ -102,7 +118,11 @@ def __init__(
# Indicate if the dataset is prepared for `load_partition` or `load_full`
self._dataset_prepared: bool = False

def load_partition(self, node_id: int, split: Optional[str] = None) -> Dataset:
def load_partition(
self,
node_id: int,
split: Optional[str] = None,
) -> Union[Dataset, List[Dataset], DatasetDict]:
"""Load the partition specified by the idx in the selected split.

The dataset is downloaded only when the first call to `load_partition` or
Expand All @@ -122,8 +142,13 @@ def load_partition(self, node_id: int, split: Optional[str] = None) -> Dataset:

Returns
-------
partition : Dataset
Single partition from the dataset split.
partition : Union[Dataset, List[Dataset], DatasetDict]
Undivided or divided partition from the dataset split.
If `partition_division` is not specified then `Dataset` is returned.
If `partition_division` is specified as `List` or `Tuple` then
`List[Dataset]` is returned.
If `partition_division` is specified as `Dict` then `DatasetDict` is
returned.
"""
if not self._dataset_prepared:
self._prepare_dataset()
Expand All @@ -136,7 +161,13 @@ def load_partition(self, node_id: int, split: Optional[str] = None) -> Dataset:
self._check_if_split_possible_to_federate(split)
partitioner: Partitioner = self._partitioners[split]
self._assign_dataset_to_partitioner(split)
return partitioner.load_partition(node_id)
partition = partitioner.load_partition(node_id)
if self._partition_division is None:
return partition
divided_partition: Union[List[Dataset], DatasetDict] = divide_dataset(
partition, self._partition_division
)
return divided_partition

def load_full(self, split: str) -> Dataset:
"""Load the full split of the dataset.
Expand Down
31 changes: 30 additions & 1 deletion datasets/flwr_datasets/federated_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


import unittest
from typing import Dict, Union
from typing import Dict, List, Optional, Tuple, Union
from unittest.mock import Mock, patch

import pytest
Expand Down Expand Up @@ -67,6 +67,35 @@ def test_load_partition_size(self, _: str, train_num_partitions: int) -> None:
len(dataset_partition0), len(dataset["train"]) // train_num_partitions
)

@parameterized.expand(  # type: ignore
    [
        ((0.2, 0.8), 2),
        ({"train": 0.2, "test": 0.8}, 2),
        # Not full dataset
        ([0.2, 0.1], 2),
        ({"train": 0.2, "test": 0.1}, 2),
        (None, None),
    ],
)
def test_divide_partition_integration_size(
    self,
    partition_division: Optional[
        Union[List[float], Tuple[float, ...], Dict[str, float]]
    ],
    expected_length: Optional[int],
):
    """Test if `partition_division` yields the expected number of splits."""
    fds = FederatedDataset(
        dataset=self.dataset_name,
        partitioners={"train": 100},
        partition_division=partition_division,
    )
    loaded = fds.load_partition(0, "train")
    if partition_division is not None:
        # A divided partition is a collection with one entry per split.
        self.assertEqual(len(loaded), expected_length)
        return
    # Without a division there is no split count to compare against.
    self.assertEqual(expected_length, None)

def test_load_full(self) -> None:
"""Test if the load_full works with the correct split name."""
dataset_fds = FederatedDataset(
Expand Down
153 changes: 152 additions & 1 deletion datasets/flwr_datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@


import warnings
from typing import Dict, Optional, Tuple, Union, cast
from typing import Dict, List, Optional, Tuple, Union, cast

from datasets import Dataset, DatasetDict
from flwr_datasets.partitioner import IidPartitioner, Partitioner
from flwr_datasets.resplitter import Resplitter
from flwr_datasets.resplitter.merge_resplitter import MergeResplitter
Expand Down Expand Up @@ -85,3 +86,153 @@ def _check_if_dataset_tested(dataset: str) -> None:
f"The currently tested dataset are {tested_datasets}. Given: {dataset}.",
stacklevel=1,
)


def divide_dataset(
    dataset: Dataset, division: Union[List[float], Tuple[float, ...], Dict[str, float]]
) -> Union[List[Dataset], DatasetDict]:
    """Divide the dataset according to the `division`.

    The division supports a varying number of splits, which you can name. The
    splits are created from the beginning of the dataset.

    Parameters
    ----------
    dataset: Dataset
        Dataset to be divided.
    division: Union[List[float], Tuple[float, ...], Dict[str, float]]
        Configuration specifying how the dataset is divided. Each fraction has to be
        >0 and <=1. They have to sum up to at most 1 (smaller sum is possible).

    Returns
    -------
    divided_dataset : Union[List[Dataset], DatasetDict]
        `List[Dataset]` is returned in case of `List` or `Tuple` of the division
        `DatasetDict` in case of `Dict` specification

    Raises
    ------
    TypeError
        If `division` is not a list, tuple or dict, or contains non-float fractions.
    ValueError
        If any fraction is outside of (0, 1] or the fractions sum up to more than 1.

    Examples
    --------
    Use `divide_dataset` with division specified as a list.

    >>> from flwr_datasets import FederatedDataset
    >>> from flwr_datasets.utils import divide_dataset
    >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})
    >>> partition = fds.load_partition(0)
    >>> division = [0.8, 0.2]
    >>> train, test = divide_dataset(dataset=partition, division=division)

    Use `divide_dataset` with division specified as a dict.

    >>> from flwr_datasets import FederatedDataset
    >>> from flwr_datasets.utils import divide_dataset
    >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})
    >>> partition = fds.load_partition(0)
    >>> division = {"train": 0.8, "test": 0.2}
    >>> train_test = divide_dataset(dataset=partition, division=division)
    >>> train, test = train_test["train"], train_test["test"]
    """
    # Fail fast on a malformed configuration instead of selecting indices that
    # are out of range for the dataset (e.g. fractions summing up to more than 1).
    _check_division_config_correctness(division)
    ranges = _create_division_indices_ranges(len(dataset), division)
    if isinstance(division, (list, tuple)):
        return [dataset.select(single_range) for single_range in ranges]
    if isinstance(division, dict):
        # Dict order is preserved, so the i-th key matches the i-th range.
        return DatasetDict(
            {
                split_name: dataset.select(single_range)
                for split_name, single_range in zip(division, ranges)
            }
        )
    raise TypeError(
        f"The type of the `division` should be dict, "
        f"tuple or list but is {type(division)} instead."
    )


def _create_division_indices_ranges(
    dataset_length: int,
    division: Union[List[float], Tuple[float, ...], Dict[str, float]],
) -> List[range]:
    """Create consecutive index ranges, one per fraction in `division`.

    Parameters
    ----------
    dataset_length: int
        Number of samples in the dataset that is being divided.
    division: Union[List[float], Tuple[float, ...], Dict[str, float]]
        Fractions of the dataset assigned to each split (dict values in case of
        a dict).

    Returns
    -------
    ranges : List[range]
        Non-overlapping ranges in the order of `division`. Each range starts where
        the previous one ended and spans `int(dataset_length * fraction)` indices.

    Raises
    ------
    TypeError
        If `division` is not a list, tuple or dict.
    """
    if isinstance(division, (list, tuple)):
        fractions = list(division)
    elif isinstance(division, dict):
        fractions = list(division.values())
    else:
        raise TypeError(
            f"The type of the `division` should be dict, "
            f"tuple or list but is {type(division)} instead. "
        )
    ranges: List[range] = []
    start_idx = 0
    for fraction in fractions:
        end_idx = start_idx + int(dataset_length * fraction)
        ranges.append(range(start_idx, end_idx))
        # The next split begins exactly where this one ended.
        start_idx = end_idx
    return ranges


def _check_division_config_types_correctness(
    division: Union[List[float], Tuple[float, ...], Dict[str, float]]
) -> None:
    """Ensure `division` is a list, tuple or dict holding only float fractions.

    Raises
    ------
    TypeError
        If `division` is not a list, tuple or dict, or if any of its fractions
        (the dict values in case of a dict) is not a float.
    """
    candidates: List[float]
    if isinstance(division, dict):
        candidates = list(division.values())
        failure_message = (
            "Dict values of `division` must be only floats, "
            "other types are not allowed."
        )
    elif isinstance(division, (list, tuple)):
        candidates = list(division)
        failure_message = (
            "List or tuple values of `division` must contain only floats, "
            "other types are not allowed."
        )
    else:
        raise TypeError("`division` must be a list, tuple, or dict.")

    for candidate in candidates:
        if not isinstance(candidate, float):
            raise TypeError(failure_message)


def _check_division_config_values_correctness(
    division: Union[List[float], Tuple[float, ...], Dict[str, float]]
) -> None:
    """Ensure the fractions in `division` are in (0, 1] and sum up to at most 1.

    A warning is emitted when the fractions sum up to less than 1, because part
    of the data is then left out of every split.

    Raises
    ------
    ValueError
        If any fraction is not in (0, 1] or if the fractions sum up to more than 1.
    TypeError
        If `division` is not a list, tuple or dict.
    """
    if isinstance(division, (list, tuple)):
        fractions = list(division)
        range_error = (
            "All fractions for the division must be greater than 0 and smaller or "
            "equal to 1."
        )
        sum_error = "Sum of fractions for division must not exceed 1."
        sum_label = "Sum of fractions for division"
    elif isinstance(division, dict):
        fractions = list(division.values())
        range_error = "All fractions must be greater than 0 and smaller or equal to 1."
        sum_error = "Sum of fractions must not exceed 1."
        sum_label = "Sum of fractions in `division`"
    else:
        raise TypeError("`division` must be a list, tuple, or dict.")

    if not all(0 < fraction <= 1 for fraction in fractions):
        raise ValueError(range_error)
    fraction_sum = sum(fractions)
    if fraction_sum > 1:
        raise ValueError(sum_error)
    if fraction_sum < 1:
        # Report the actual sum (not the list of fractions) in the message.
        warnings.warn(
            f"{sum_label} is {fraction_sum}, which is below 1. "
            f"Make sure that's the desired behavior. Some data will not be used "
            f"in the current specification.",
            stacklevel=1,
        )


def _check_division_config_correctness(
    division: Union[List[float], Tuple[float, ...], Dict[str, float]]
) -> None:
    """Validate `division`: first the types of its fractions, then their values."""
    validators = (
        _check_division_config_types_correctness,
        _check_division_config_values_correctness,
    )
    # Types are checked first so the value comparisons only ever see floats.
    for validate in validators:
        validate(division)
Loading