Fds add concatenate divisions #3103

Merged (37 commits) on Mar 14, 2024

Commits
507065b
Add partition_division to FederatedDataset
adam-narozniak Feb 14, 2024
6707839
wip: partition division utils
adam-narozniak Feb 14, 2024
9f4b310
Update load_partition
adam-narozniak Feb 15, 2024
876dad5
Update divsion utils
adam-narozniak Feb 15, 2024
1e77c06
Update tests
adam-narozniak Feb 15, 2024
ecf2b40
Remove inner_split keyword from load_partition
adam-narozniak Mar 5, 2024
d08b44b
Expose utils module
adam-narozniak Mar 5, 2024
0ed7585
Expose utils module
adam-narozniak Mar 5, 2024
494fdeb
Add an example to FDS docs
adam-narozniak Mar 5, 2024
2777491
Update datasets/flwr_datasets/utils.py
adam-narozniak Mar 5, 2024
7c74123
Remove redundant computation
adam-narozniak Mar 5, 2024
ecef48b
Merge remote-tracking branch 'origin/fds-add-partition-splitting' int…
adam-narozniak Mar 5, 2024
57c7c63
Add an example to divide_dataset function
adam-narozniak Mar 5, 2024
6486c8a
Add a second example to divide_dataset fnc
adam-narozniak Mar 5, 2024
a8b0bfd
Fix load_partition docs
adam-narozniak Mar 5, 2024
6052e2a
Update year in the Copyright notice
adam-narozniak Mar 6, 2024
5f13155
Merge branch 'main' into fds-add-partition-splitting
danieljanes Mar 6, 2024
497a5cc
Apply suggestions from code review
adam-narozniak Mar 10, 2024
f0a906e
Fix return type docs
adam-narozniak Mar 10, 2024
4a4de26
Support partitioner-specific partition_division
adam-narozniak Mar 10, 2024
4211941
Add num_partition property
adam-narozniak Mar 11, 2024
02411f7
Trigger the partitioning in the num_partitions
adam-narozniak Mar 11, 2024
ebc486b
Merge branch 'fds-add-num-partitions' into fds-add-concatenate-divisions
adam-narozniak Mar 11, 2024
d651c10
Add concatenate_divisions
adam-narozniak Mar 11, 2024
bd04e37
Merge branch 'main' into fds-add-concatenate-divisions
adam-narozniak Mar 12, 2024
b9392aa
Apply suggestions from code review
jafermarq Mar 13, 2024
61efc4f
Apply suggestions from code review
jafermarq Mar 13, 2024
63be596
Remove concatenate division form FederatedDataset
adam-narozniak Mar 13, 2024
07faad0
Remove concatenate division form FederatedDataset tests
adam-narozniak Mar 13, 2024
e5448b8
Add concatenate division to utils
adam-narozniak Mar 13, 2024
eec8dac
Add concatenate division to utils tests
adam-narozniak Mar 13, 2024
708cff6
Fix formatting
adam-narozniak Mar 13, 2024
e55e1e8
Merge remote-tracking branch 'origin/main' into fds-add-concatenate-d…
adam-narozniak Mar 14, 2024
caae36a
Merge branch 'main' into fds-add-concatenate-divisions
adam-narozniak Mar 14, 2024
0b4e761
Apply suggestions from code review
adam-narozniak Mar 14, 2024
4775e3e
Fix formatting
adam-narozniak Mar 14, 2024
297cf52
Merge branch 'main' into fds-add-concatenate-divisions
jafermarq Mar 14, 2024
119 changes: 5 additions & 114 deletions datasets/flwr_datasets/federated_dataset.py
@@ -15,7 +15,7 @@
"""FederatedDataset."""


from typing import Dict, List, Optional, Tuple, Union, cast
from typing import Dict, Optional, Tuple, Union

import datasets
from datasets import Dataset, DatasetDict
@@ -25,7 +25,6 @@
_check_if_dataset_tested,
_instantiate_partitioners,
_instantiate_resplitter_if_needed,
divide_dataset,
)


@@ -54,19 +53,6 @@ class FederatedDataset:
(representing the number of IID partitions that this split should be partitioned
into). One or multiple `Partitioner` objects can be specified in that manner,
but at most, one per split.
partition_division : Optional[Union[List[float], Tuple[float, ...],
Dict[str, float], Dict[str, Optional[Union[List[float], Tuple[float, ...],
Dict[str, float]]]]]]
Fractions specifying the division of the partition associated with a certain
split (and partitioner) that enable returning an already divided partition from
the `load_partition` method. You can think of this as on-edge division of the
data into multiple divisions (e.g. into train and validation). You can name the
divisions by using a Dict or specify them as a List/Tuple. If you specified a
single partitioner, you can provide the simplified form, e.g. [0.8, 0.2] or
{"partition_train": 0.8, "partition_test": 0.2}, but when multiple partitioners
are specified you need to indicate which partitioner's results are further
divided, e.g. {"train": [0.8, 0.2]} would result in dividing only the
partitions created from the "train" split.
shuffle : bool
Whether to randomize the order of samples. Applied prior to resplitting,
separately to each of the present splits in the dataset. It uses the `seed`
@@ -84,14 +70,6 @@ class FederatedDataset:
>>> partition = mnist_fds.load_partition(10, "train")
>>> # Use test split for centralized evaluation.
>>> centralized = mnist_fds.load_split("test")

Automatically divide the data returned from `load_partition`
>>> mnist_fds = FederatedDataset(
>>> dataset="mnist",
>>> partitioners={"train": 100},
>>> partition_division=[0.8, 0.2],
>>> )
>>> partition_train, partition_test = mnist_fds.load_partition(10, "train")
"""

# pylint: disable=too-many-instance-attributes
@@ -102,17 +80,6 @@ def __init__(
subset: Optional[str] = None,
resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] = None,
partitioners: Dict[str, Union[Partitioner, int]],
partition_division: Optional[
Union[
List[float],
Tuple[float, ...],
Dict[str, float],
Dict[
str,
Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]],
],
]
] = None,
shuffle: bool = True,
seed: Optional[int] = 42,
) -> None:
@@ -125,9 +92,6 @@ def __init__(
self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
partitioners
)
self._partition_division = self._initialize_partition_division(
partition_division
)
self._shuffle = shuffle
self._seed = seed
# _dataset is prepared lazily on the first call to `load_partition`
@@ -140,7 +104,7 @@ def load_partition(
self,
partition_id: int,
split: Optional[str] = None,
) -> Union[Dataset, List[Dataset], DatasetDict]:
) -> Dataset:
"""Load the partition specified by the idx in the selected split.

The dataset is downloaded only when the first call to `load_partition` or
@@ -160,13 +124,8 @@

Returns
-------
partition : Union[Dataset, List[Dataset], DatasetDict]
Undivided or divided partition from the dataset split.
If `partition_division` is not specified then `Dataset` is returned.
If `partition_division` is specified as `List` or `Tuple` then
`List[Dataset]` is returned.
If `partition_division` is specified as `Dict` then `DatasetDict` is
returned.
partition : Dataset
Single partition from the dataset split.
"""
if not self._dataset_prepared:
self._prepare_dataset()
@@ -179,16 +138,7 @@
self._check_if_split_possible_to_federate(split)
partitioner: Partitioner = self._partitioners[split]
self._assign_dataset_to_partitioner(split)
partition = partitioner.load_partition(partition_id)
if self._partition_division is None:
return partition
partition_division = self._partition_division.get(split)
if partition_division is None:
return partition
divided_partition: Union[List[Dataset], DatasetDict] = divide_dataset(
partition, partition_division
)
return divided_partition
return partitioner.load_partition(partition_id)

def load_split(self, split: str) -> Dataset:
"""Load the full split of the dataset.
@@ -301,62 +251,3 @@ def _check_if_no_split_keyword_possible(self) -> None:
"Please set the `split` argument. You can only omit the split keyword "
"if there is exactly one partitioner specified."
)

def _initialize_partition_division(
self,
partition_division: Optional[
Union[
List[float],
Tuple[float, ...],
Dict[str, float],
Dict[
str,
Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]],
],
]
],
) -> Optional[
Dict[
str,
Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]],
]
]:
"""Create the partition division in the full format.

Reduced format (possible if only one partitioner exist):

Union[List[float], Tuple[float, ...], Dict[str, float]

Full format: Dict[str, Reduced format]
Full format represents the split to division mapping.
"""
# Check for simple dict, list, or tuple types directly
if isinstance(partition_division, (list, tuple)) or (
isinstance(partition_division, dict)
and all(isinstance(value, float) for value in partition_division.values())
):
if len(self._partitioners) > 1:
raise ValueError(
f"The specified partition_division {partition_division} does not "
f"provide mapping to split but more than one partitioners is "
f"specified. Please adjust the partition_division specification to "
f"have the split names as the keys."
)
return cast(
Dict[
str,
Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]],
],
{list(self._partitioners.keys())[0]: partition_division},
)
if isinstance(partition_division, dict):
return cast(
Dict[
str,
Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]],
],
partition_division,
)
if partition_division is None:
return None
raise TypeError("Unsupported type for partition_division")
44 changes: 1 addition & 43 deletions datasets/flwr_datasets/federated_dataset_test.py
@@ -17,7 +17,7 @@


import unittest
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, Union
from unittest.mock import Mock, patch

import pytest
@@ -67,48 +67,6 @@ def test_load_partition_size(self, _: str, train_num_partitions: int) -> None:
len(dataset_partition0), len(dataset["train"]) // train_num_partitions
)

@parameterized.expand( # type: ignore
[
((0.2, 0.8), 2, False),
({"train": 0.2, "test": 0.8}, 2, False),
({"train": {"train": 0.2, "test": 0.8}}, 2, True),
# Not full dataset
([0.2, 0.1], 2, False),
({"train": 0.2, "test": 0.1}, 2, False),
(None, None, False),
],
)
def test_divide_partition_integration_size(
self,
partition_division: Optional[
Union[
List[float],
Tuple[float, ...],
Dict[str, float],
Dict[
str,
Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]],
],
]
],
expected_length: Optional[int],
add_test_partitioner: bool,
):
"""Test is the `partition_division` create correct data."""
partitioners: Dict[str, Union[Partitioner, int]] = {"train": 10}
if add_test_partitioner:
partitioners[self.test_split] = 10
dataset_fds = FederatedDataset(
dataset=self.dataset_name,
partitioners=partitioners,
partition_division=partition_division,
)
partition = dataset_fds.load_partition(0, "train")
if partition_division is None:
self.assertEqual(expected_length, None)
else:
self.assertEqual(len(partition), expected_length)

def test_load_split(self) -> None:
"""Test if the load_split works with the correct split name."""
dataset_fds = FederatedDataset(
69 changes: 68 additions & 1 deletion datasets/flwr_datasets/utils.py
@@ -18,7 +18,7 @@
import warnings
from typing import Dict, List, Optional, Tuple, Union, cast

from datasets import Dataset, DatasetDict
from datasets import Dataset, DatasetDict, concatenate_datasets
from flwr_datasets.partitioner import IidPartitioner, Partitioner
from flwr_datasets.resplitter import Resplitter
from flwr_datasets.resplitter.merge_resplitter import MergeResplitter
@@ -239,3 +251,70 @@ def _check_division_config_correctness(
) -> None:
    _check_division_config_types_correctness(division)
    _check_division_config_values_correctness(division)


def concatenate_divisions(
    partitioner: Partitioner,
    partition_division: Union[List[float], Tuple[float, ...], Dict[str, float]],
    division_id: Union[int, str],
) -> Dataset:
    """Create a dataset by concatenation of all partitions in the same division.

    The divisions are created based on the `partition_division` and accessed based
    on the `division_id`. It can be used to create e.g. a centralized dataset from
    federated on-edge test sets.

    Parameters
    ----------
    partitioner : Partitioner
        Partitioner object with assigned dataset.
    partition_division : Union[List[float], Tuple[float, ...], Dict[str, float]]
        Fractions specifying the division of the partitions of a `partitioner`. You
        can think of this as on-edge division of the data into multiple divisions
        (e.g. into train and validation). E.g. [0.8, 0.2] or
        {"partition_train": 0.8, "partition_test": 0.2}.
    division_id : Union[int, str]
        The way to access the division (from a List or DatasetDict). If your
        `partition_division` is specified as a list, then `division_id` represents
        an index to an element in that list. If `partition_division` is passed as a
        `Dict`, then `division_id` is a key of such a dictionary.

    Returns
    -------
    concatenated_divisions : Dataset
        A dataset created as the concatenation of the divisions from all partitions.
    """
    divisions = []
    zero_len_divisions = 0
    for partition_id in range(partitioner.num_partitions):
        partition = partitioner.load_partition(partition_id)
        if isinstance(partition_division, (list, tuple)):
            if not isinstance(division_id, int):
                raise TypeError(
                    "The `division_id` needs to be an int in case of "
                    "`partition_division` specification as List."
                )
            partition = divide_dataset(partition, partition_division)
            division = partition[division_id]
        elif isinstance(partition_division, Dict):
            partition = divide_dataset(partition, partition_division)
            division = partition[division_id]
        else:
            raise TypeError(
                "The type of `partition_division` needs to be a List, Tuple, "
                "or a Dict in this context."
            )
        if len(division) == 0:
            zero_len_divisions += 1
        divisions.append(division)

    if zero_len_divisions == partitioner.num_partitions:
        raise ValueError(
            "The concatenated dataset is of length 0. Please change the "
            "`partition_division` parameter to change this behavior."
        )
    if zero_len_divisions != 0:
        warnings.warn(
            f"{zero_len_divisions} division(s) have length zero.", stacklevel=1
        )
    return concatenate_datasets(divisions)
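
A usage sketch for `concatenate_divisions`, assuming a partitioner with an already assigned dataset (the base `Partitioner` is assumed to expose a `dataset` setter) and using MNIST purely for illustration; it rebuilds a centralized evaluation set from the 20% division of every partition:

>>> from datasets import load_dataset
>>> from flwr_datasets.partitioner import IidPartitioner
>>> from flwr_datasets.utils import concatenate_divisions
>>>
>>> partitioner = IidPartitioner(num_partitions=100)
>>> partitioner.dataset = load_dataset("mnist", split="train")
>>> # Concatenate the second (20%) division of every partition.
>>> centralized_eval = concatenate_divisions(
>>>     partitioner=partitioner,
>>>     partition_division=[0.8, 0.2],
>>>     division_id=1,  # an int index for List divisions, a key for Dict divisions
>>> )

Accessing `partitioner.num_partitions` inside the helper triggers the partitioning lazily, which is why the `num_partitions` property was introduced in the commits merged into this branch.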