Skip to content

Commit

Permalink
Add concatenate_divisions
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Mar 11, 2024
1 parent ebc486b commit d651c10
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 1 deletion.
70 changes: 69 additions & 1 deletion datasets/flwr_datasets/federated_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
"""FederatedDataset."""


import warnings
from typing import Dict, List, Optional, Tuple, Union, cast

import datasets
from datasets import Dataset, DatasetDict
from datasets import Dataset, DatasetDict, concatenate_datasets
from flwr_datasets.partitioner import Partitioner
from flwr_datasets.resplitter import Resplitter
from flwr_datasets.utils import (
Expand Down Expand Up @@ -125,6 +126,7 @@ def __init__(
self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
partitioners
)
self._partition_type = None
self._partition_division = self._initialize_partition_division(
partition_division
)
Expand Down Expand Up @@ -213,6 +215,72 @@ def load_full(self, split: str) -> Dataset:
self._check_if_split_present(split)
return self._dataset[split]

def concatenate_divisions(
self, division_id: Union[int, str], split: Optional[str] = None
) -> Dataset:
"""Concatenate the divisions of the partitions.
The divisions are created based on the `partition_division` and accessed based
on the `division_id`. If you specified the divisions for more than 1 partitioner
you need to specify the split associated with a partitioner that creates the
divisions. It can be used to create e.g. centralized dataset from federated
on-edge test sets.
sets
Parameters
----------
division_id: Union[int, str]
The way to access the division (from a List or DatasetDict).
split: Optional[str]
Split associated with a partitioner that creates the division. It needs to
be specified if `partition_division` specifies divisions for more than
one partitioner.
Returns
-------
concatenated_divisions: Dataset
A dataset created as concatenation of the divisions from all partitions.
"""
if self._partition_division is None:
raise ValueError(
"The division concatenation is possible only if the partition_division "
"is specified. To access all of the undivided partitions use "
"load_split."
)
if split is None:
self._check_if_no_split_keyword_possible()
split = list(self._partitioners.keys())[0]
divisions = []
zero_len_divisions = 0
for partition_id in range(self._partitioners[split].num_partitions):
partition = self.load_partition(partition_id, split)
if isinstance(partition, List):
if not isinstance(division_id, int):
raise TypeError(
"The division_id needs to be an int in case of "
"partition_division specification as List."
)
division = partition[division_id]
elif isinstance(partition, DatasetDict):
division = partition[division_id]
else:
raise TypeError(
"The type of partition needs to be List of DatasetDict in this "
"context."
)
if len(division) == 0:
zero_len_divisions += 1
divisions.append(division)

if zero_len_divisions == self._partitioners[split].num_partitions:
raise ValueError(
"The concatenated dataset is of length 0. Please change the "
"partition_division parameter to change this behavior."
)
if zero_len_divisions != 0:
warnings.warn(f"{zero_len_divisions} division(s) have length zero.")
return concatenate_datasets(divisions)

def _check_if_split_present(self, split: str) -> None:
"""Check if the split (for partitioning or full return) is in the dataset."""
if self._dataset is None:
Expand Down
74 changes: 74 additions & 0 deletions datasets/flwr_datasets/federated_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,59 @@ def resplit(dataset: DatasetDict) -> DatasetDict:
dataset_length = sum([len(ds) for ds in dataset.values()])
self.assertEqual(len(full), dataset_length)

@parameterized.expand( # type: ignore
[
(
[0.8, 0.2],
1,
),
({"train": 0.8, "test": 0.2}, "test"),
]
)
def test_concatenate_divisions(
self,
partition_division: Optional[
Union[
List[float],
Tuple[float, ...],
Dict[str, float],
Dict[
str,
Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]],
],
]
],
division_id: Union[int, str],
) -> None:
"""Test if the length of the divisions match the concatenated dataset."""
dataset_fds = FederatedDataset(
dataset=self.dataset_name,
partitioners={"train": 10},
partition_division=partition_division,
)
centralized_from_federated_test = dataset_fds.concatenate_divisions(division_id)

lengths = []
for partition_id in range(dataset_fds._partitioners["train"].num_partitions):
partition = dataset_fds.load_partition(partition_id, "train")
if isinstance(partition, List):
if not isinstance(division_id, int):
raise TypeError(
"The division_id needs to be an int in case of "
"partition_division specification as List."
)
division = partition[division_id]
elif isinstance(partition, DatasetDict):
division = partition[division_id]
else:
raise TypeError(
"The type of partition needs to be List of DatasetDict in this "
"context."
)
lengths.append(len(division))

self.assertEqual(len(centralized_from_federated_test), sum(lengths))


class ArtificialDatasetTest(unittest.TestCase):
"""Test using small artificial dataset, mocked load_dataset."""
Expand Down Expand Up @@ -393,6 +446,27 @@ def test_cannot_use_the_old_split_names(self) -> None:
with self.assertRaises(ValueError):
fds.load_partition(0, "train")

def test_concatenate_division_without_partition_division_param(self) -> None:
"""Test raises when no specification of partition_division and using concat."""
dataset_fds = FederatedDataset(
dataset="mnist",
partitioners={"train": 10},
)
division_id = 1
with self.assertRaises(ValueError):
_ = dataset_fds.concatenate_divisions(division_id)

def test_all_divisions_to_concat_size_zero(self) -> None:
"""Test raises when all divisions for concatenations are zero."""
dataset_fds = FederatedDataset(
dataset="mnist",
partitioners={"train": 10},
partition_division=[0.8, 0.0],
)
division_id = 1
with self.assertRaises(ValueError):
_ = dataset_fds.concatenate_divisions(division_id)


def datasets_are_equal(ds1: Dataset, ds2: Dataset) -> bool:
"""Check if two Datasets have the same values."""
Expand Down

0 comments on commit d651c10

Please sign in to comment.