Add partition division #2951

Merged · 21 commits · Mar 11, 2024
2 changes: 2 additions & 0 deletions datasets/flwr_datasets/__init__.py
@@ -16,13 +16,15 @@


from flwr_datasets import partitioner, resplitter
from flwr_datasets import utils as utils
from flwr_datasets.common.version import package_version as _package_version
from flwr_datasets.federated_dataset import FederatedDataset

__all__ = [
"FederatedDataset",
"partitioner",
"resplitter",
"utils",
]

__version__ = _package_version
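With `utils` now re-exported from the package root, the division helper that `FederatedDataset.load_partition` uses below can presumably also be called directly on any Hugging Face `Dataset`. A minimal sketch, assuming `divide_dataset` is exposed via `flwr_datasets.utils` (the toy dataset is illustrative only):

>>> from datasets import Dataset
>>> from flwr_datasets import utils
>>>
>>> # Tiny in-memory dataset used only for illustration
>>> ds = Dataset.from_dict({"x": list(range(10))})
>>> # Unnamed divisions return a list of Datasets
>>> train, val = utils.divide_dataset(ds, [0.8, 0.2])
>>> # Named divisions return a DatasetDict
>>> dd = utils.divide_dataset(ds, {"train": 0.8, "val": 0.2})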
133 changes: 122 additions & 11 deletions datasets/flwr_datasets/federated_dataset.py
@@ -15,7 +15,7 @@
"""FederatedDataset."""


from typing import Dict, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union, cast

import datasets
from datasets import Dataset, DatasetDict
@@ -25,9 +25,12 @@
_check_if_dataset_tested,
_instantiate_partitioners,
_instantiate_resplitter_if_needed,
divide_dataset,
)


# flake8: noqa: E501
# pylint: disable=line-too-long
class FederatedDataset:
"""Representation of a dataset for federated learning/evaluation/analytics.

@@ -51,6 +54,19 @@ class FederatedDataset:
(representing the number of IID partitions that this split should be partitioned
into). One or multiple `Partitioner` objects can be specified in that manner,
but at most, one per split.
partition_division : Optional[Union[List[float], Tuple[float, ...],
Dict[str, float], Dict[str, Optional[Union[List[float], Tuple[float, ...],
Dict[str, float]]]]]]
Fractions specifying the division of the partition associated with a certain
split (and partitioner) that enable returning an already divided partition from
the `load_partition` method. You can think of it as an on-edge division of the
data into multiple divisions (e.g. into train and validation). You can name the
divisions by using a Dict, or specify them as a List/Tuple. If a single
partitioner is specified, you can provide the simplified form, e.g. [0.8, 0.2]
or {"partition_train": 0.8, "partition_test": 0.2}, but when multiple
partitioners are specified, you need to indicate the results of which
partitioner are further divided, e.g. {"train": [0.8, 0.2]} would result in
dividing only the partitions that are created from the "train" split.
shuffle : bool
Whether to randomize the order of samples. Applied prior to resplitting,
separately to each of the present splits in the dataset. It uses the `seed`
@@ -64,14 +80,18 @@ class FederatedDataset:
Use MNIST dataset for Federated Learning with 100 clients (edge devices):

>>> mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": 100})

Load partition for client with ID 10.

>>> # Load partition for client with ID 10.
>>> partition = mnist_fds.load_partition(10, "train")

Use test split for centralized evaluation.

>>> # Use test split for centralized evaluation.
>>> centralized = mnist_fds.load_full("test")

Automatically divide the data returned from `load_partition`:
>>> mnist_fds = FederatedDataset(
>>> dataset="mnist",
>>> partitioners={"train": 100},
>>> partition_division=[0.8, 0.2],
>>> )
>>> partition_train, partition_test = mnist_fds.load_partition(10, "train")
"""

# pylint: disable=too-many-instance-attributes
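The docstring example above uses the list form of `partition_division`; with the dict form the divisions are named and `load_partition` returns a `DatasetDict` keyed by those names rather than a list. A sketch of that variant, reusing the docstring's MNIST setup (the key names here are illustrative):

>>> mnist_fds = FederatedDataset(
>>>     dataset="mnist",
>>>     partitioners={"train": 100},
>>>     partition_division={"partition_train": 0.8, "partition_test": 0.2},
>>> )
>>> divided = mnist_fds.load_partition(10, "train")
>>> # `divided` is a DatasetDict with the requested keys
>>> partition_train = divided["partition_train"]
>>> partition_test = divided["partition_test"]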
@@ -82,6 +102,17 @@ def __init__(
subset: Optional[str] = None,
resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] = None,
partitioners: Dict[str, Union[Partitioner, int]],
partition_division: Optional[
Union[
List[float],
Tuple[float, ...],
Dict[str, float],
Dict[
str,
Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]],
],
]
] = None,
shuffle: bool = True,
seed: Optional[int] = 42,
) -> None:
@@ -94,6 +125,9 @@ def __init__(
self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
partitioners
)
self._partition_division = self._initialize_partition_division(
partition_division
)
self._shuffle = shuffle
self._seed = seed
# _dataset is prepared lazily on the first call to `load_partition`
@@ -102,7 +136,11 @@ def __init__(
# Indicate if the dataset is prepared for `load_partition` or `load_full`
self._dataset_prepared: bool = False

def load_partition(self, node_id: int, split: Optional[str] = None) -> Dataset:
def load_partition(
self,
node_id: int,
split: Optional[str] = None,
) -> Union[Dataset, List[Dataset], DatasetDict]:
"""Load the partition specified by the idx in the selected split.

The dataset is downloaded only when the first call to `load_partition` or
@@ -122,8 +160,13 @@ def load_partition(self, node_id: int, split: Optional[str] = None) -> Dataset:

Returns
-------
partition : Dataset
Single partition from the dataset split.
partition : Union[Dataset, List[Dataset], DatasetDict]
Undivided or divided partition from the dataset split.
If `partition_division` is not specified then `Dataset` is returned.
If `partition_division` is specified as `List` or `Tuple` then
`List[Dataset]` is returned.
If `partition_division` is specified as `Dict` then `DatasetDict` is
returned.
"""
if not self._dataset_prepared:
self._prepare_dataset()
@@ -136,7 +179,16 @@ def load_partition(self, node_id: int, split: Optional[str] = None) -> Dataset:
self._check_if_split_possible_to_federate(split)
partitioner: Partitioner = self._partitioners[split]
self._assign_dataset_to_partitioner(split)
return partitioner.load_partition(node_id)
partition = partitioner.load_partition(node_id)
if self._partition_division is None:
return partition
partition_division = self._partition_division.get(split)
if partition_division is None:
return partition
divided_partition: Union[List[Dataset], DatasetDict] = divide_dataset(
partition, partition_division
)
return divided_partition

def load_full(self, split: str) -> Dataset:
"""Load the full split of the dataset.
@@ -230,3 +282,62 @@ def _check_if_no_split_keyword_possible(self) -> None:
"Please set the `split` argument. You can only omit the split keyword "
"if there is exactly one partitioner specified."
)

def _initialize_partition_division(
self,
partition_division: Optional[
Union[
List[float],
Tuple[float, ...],
Dict[str, float],
Dict[
str,
Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]],
],
]
],
) -> Optional[
Dict[
str,
Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]],
]
]:
"""Create the partition division in the full format.

Reduced format (possible only if a single partitioner exists):

Union[List[float], Tuple[float, ...], Dict[str, float]]

Full format: Dict[str, Reduced format]
The full format represents the split-to-division mapping.
"""
# Check for simple dict, list, or tuple types directly
if isinstance(partition_division, (list, tuple)) or (
isinstance(partition_division, dict)
and all(isinstance(value, float) for value in partition_division.values())
):
if len(self._partitioners) > 1:
raise ValueError(
f"The specified partition_division {partition_division} does not "
f"provide mapping to split but more than one partitioners is "
f"specified. Please adjust the partition_division specification to "
f"have the split names as the keys."
)
return cast(
Dict[
str,
Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]],
],
{list(self._partitioners.keys())[0]: partition_division},
)
if isinstance(partition_division, dict):
return cast(
Dict[
str,
Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]],
],
partition_division,
)
if partition_division is None:
return None
raise TypeError("Unsupported type for partition_division")
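`_initialize_partition_division` normalizes the reduced format into the full split-to-division mapping described in its docstring. A small sketch of that equivalence, assuming an MNIST setup like the one in the class docstring (values are illustrative):

>>> # Reduced form, allowed only when a single partitioner is specified ...
>>> FederatedDataset(
>>>     dataset="mnist",
>>>     partitioners={"train": 100},
>>>     partition_division=[0.8, 0.2],
>>> )
>>> # ... is stored internally as the full, split-keyed form {"train": [0.8, 0.2]}.
>>> # With multiple partitioners, only the full form is accepted:
>>> FederatedDataset(
>>>     dataset="mnist",
>>>     partitioners={"train": 100, "test": 100},
>>>     partition_division={"train": [0.8, 0.2]},
>>> )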
44 changes: 43 additions & 1 deletion datasets/flwr_datasets/federated_dataset_test.py
@@ -17,7 +17,7 @@


import unittest
from typing import Dict, Union
from typing import Dict, List, Optional, Tuple, Union
from unittest.mock import Mock, patch

import pytest
@@ -67,6 +67,48 @@ def test_load_partition_size(self, _: str, train_num_partitions: int) -> None:
len(dataset_partition0), len(dataset["train"]) // train_num_partitions
)

@parameterized.expand( # type: ignore
[
((0.2, 0.8), 2, False),
({"train": 0.2, "test": 0.8}, 2, False),
({"train": {"train": 0.2, "test": 0.8}}, 2, True),
# Not full dataset
([0.2, 0.1], 2, False),
({"train": 0.2, "test": 0.1}, 2, False),
(None, None, False),
],
)
def test_divide_partition_integration_size(
self,
partition_division: Optional[
Union[
List[float],
Tuple[float, ...],
Dict[str, float],
Dict[
str,
Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]],
],
]
],
expected_length: Optional[int],
add_test_partitioner: bool,
) -> None:
"""Test is the `partition_division` create correct data."""
partitioners: Dict[str, Union[Partitioner, int]] = {"train": 10}
if add_test_partitioner:
partitioners[self.test_split] = 10
dataset_fds = FederatedDataset(
dataset=self.dataset_name,
partitioners=partitioners,
partition_division=partition_division,
)
partition = dataset_fds.load_partition(0, "train")
if partition_division is None:
self.assertEqual(expected_length, None)
else:
self.assertEqual(len(partition), expected_length)
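The single `expected_length` of 2 covers both return types: for list/tuple divisions `load_partition` returns a list of two `Dataset` objects, while for dict divisions it returns a `DatasetDict`, whose `len()` counts its keys. A sketch of the distinction (names follow the test above):

>>> divided = dataset_fds.load_partition(0, "train")
>>> # partition_division=[0.2, 0.1]                    -> list of 2 Datasets
>>> # partition_division={"train": 0.2, "test": 0.8}   -> DatasetDict with 2 keys
>>> len(divided)
2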

def test_load_full(self) -> None:
"""Test if the load_full works with the correct split name."""
dataset_fds = FederatedDataset(