-
Notifications
You must be signed in to change notification settings - Fork 942
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0269569
commit fd0cc0a
Showing
2 changed files
with
253 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# Copyright 2024 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""SizePartitioner class.""" | ||
|
||
|
||
import warnings | ||
from typing import Dict, List, Sequence | ||
|
||
import datasets | ||
from flwr_datasets.partitioner.partitioner import Partitioner | ||
|
||
|
||
class SizePartitioner(Partitioner): | ||
"""Partitioner that creates each partition with the size specified by a user. | ||
Parameters | ||
---------- | ||
partition_sizes : Sequence[int] | ||
The size of each partition. partition_id 0 will have partition_sizes[0] | ||
samples, partition_id 1 will have partition_sizes[1] samples, etc. | ||
Examples | ||
-------- | ||
>>> from flwr_datasets import FederatedDataset | ||
>>> from flwr_datasets.partitioner import SizePartitioner | ||
>>> | ||
>>> partition_sizes = [20_000, 10_000 30_000] | ||
>>> partitioner = SizePartitioner(partition_sizes) | ||
>>> fds = FederatedDataset(dataset="cifar10", partitioners={"train": partitioner}) | ||
""" | ||
|
||
def __init__(self, partition_sizes: Sequence[int]) -> None: | ||
super().__init__() | ||
self._pre_ds_validate_partition_sizes(partition_sizes) | ||
self._partition_sizes = partition_sizes | ||
self._partition_id_to_indices: Dict[int, List[int]] = {} | ||
self._partition_id_to_indices_determined = False | ||
|
||
def load_partition(self, partition_id: int) -> datasets.Dataset: | ||
"""Load a single partition of the size of partition_sizes[partition_id]. | ||
For example if given partition_sizes=[20_000, 10_000, 30_000], | ||
then partition_id=0 will return a partition of size 20_000, | ||
partition_id=1 will return a partition of size 10_000, etc. | ||
Parameters | ||
---------- | ||
partition_id : int | ||
The index that corresponds to the requested partition. | ||
Returns | ||
------- | ||
dataset_partition : Dataset | ||
Single dataset partition. | ||
""" | ||
self._determine_partition_id_to_indices_if_needed() | ||
return self.dataset.select(self._partition_id_to_indices[partition_id]) | ||
|
||
@property | ||
def num_partitions(self) -> int: | ||
"""Total number of partitions.""" | ||
self._determine_partition_id_to_indices_if_needed() | ||
return len(self._partition_sizes) | ||
|
||
@property | ||
def partition_id_to_indices(self) -> Dict[int, List[int]]: | ||
"""Partition id to indices (the result of partitioning).""" | ||
self._determine_partition_id_to_indices_if_needed() | ||
return self._partition_id_to_indices | ||
|
||
def _determine_partition_id_to_indices_if_needed( | ||
self, | ||
) -> None: | ||
"""Create an assignment of indices to the partition indices.""" | ||
if self._partition_id_to_indices_determined: | ||
return | ||
self._post_ds_validate_partition_sizes() | ||
start = 0 | ||
end = 0 | ||
for partition_id, partition_size in enumerate(self._partition_sizes): | ||
end += partition_size | ||
indices = list(range(start, end)) | ||
self._partition_id_to_indices[partition_id] = indices | ||
start = end | ||
self._partition_id_to_indices_determined = True | ||
|
||
def _pre_ds_validate_partition_sizes(self, partition_sizes: Sequence[int]) -> None: | ||
"""Check if the partition sizes are valid (no information about the dataset).""" | ||
if not isinstance(partition_sizes, Sequence): | ||
raise ValueError("Partition sizes must be a sequence.") | ||
if len(partition_sizes) == 0: | ||
raise ValueError("Partition sizes must not be empty.") | ||
if not all( | ||
isinstance(partition_size, int) for partition_size in partition_sizes | ||
): | ||
raise ValueError("All partition sizes must be integers.") | ||
if not all(partition_size > 0 for partition_size in partition_sizes): | ||
raise ValueError("All partition sizes must be greater than zero.") | ||
|
||
def _post_ds_validate_partition_sizes(self) -> None: | ||
"""Validate the partition sizes against the dataset size.""" | ||
desired_partition_sizes = sum(self._partition_sizes) | ||
dataset_size = len(self.dataset) | ||
if desired_partition_sizes > dataset_size: | ||
raise ValueError( | ||
f"The sum of partition sizes sum({self._partition_sizes})" | ||
f"= {desired_partition_sizes} is greater than the size of" | ||
f" the dataset {dataset_size}." | ||
) | ||
if desired_partition_sizes < dataset_size: | ||
warnings.warn( | ||
f"The sum of partition sizes is {desired_partition_sizes}, which is" | ||
f"smaller than the size of the dataset: {dataset_size}. " | ||
f"Ignore this warning if it is the desired behavior.", | ||
stacklevel=1, | ||
) |
125 changes: 125 additions & 0 deletions
125
datasets/flwr_datasets/partitioner/size_partitioner_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# Copyright 2024 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Test the SizePartitioner class.""" | ||
|
||
# pylint: disable=W0212 | ||
import unittest | ||
from typing import Sequence | ||
|
||
from parameterized import parameterized | ||
|
||
from datasets import Dataset | ||
from flwr_datasets.partitioner.size_partitioner import SizePartitioner | ||
|
||
|
||
def _dummy_setup_size(partition_sizes: Sequence[int], num_rows: int) -> SizePartitioner: | ||
"""Create a dummy dataset and SizePartitioner for testing.""" | ||
data = { | ||
"features": list(range(num_rows)), | ||
} | ||
dataset = Dataset.from_dict(data) | ||
partitioner = SizePartitioner(partition_sizes=partition_sizes) | ||
partitioner.dataset = dataset | ||
return partitioner | ||
|
||
|
||
tested_valid_intits = [ | ||
((10, 20, 30), 60), | ||
# Non growing order | ||
((20, 40, 10), 70), | ||
# Different lengths | ||
((10, 10), 20), | ||
# Single partition | ||
((10,), 10), | ||
] | ||
|
||
|
||
class TestSizePartitionerSuccess(unittest.TestCase): | ||
"""Test SizePartitioner used with no exceptions.""" | ||
|
||
@parameterized.expand(tested_valid_intits) # type: ignore | ||
def test_valid_initialization( | ||
self, partition_sizes: Sequence[int], dataset_size: int | ||
) -> None: | ||
"""Test that the SizePartitioner initializes correctly with valid sizes.""" | ||
partitioner = _dummy_setup_size(partition_sizes, dataset_size) | ||
self.assertEqual(partitioner.num_partitions, len(partition_sizes)) | ||
|
||
@parameterized.expand(tested_valid_intits) # type: ignore | ||
def test_partition_size_assignment( | ||
self, partition_sizes: Sequence[int], dataset_size: int | ||
) -> None: | ||
"""Test that partitions are assigned the correct size.""" | ||
partitioner = _dummy_setup_size(partition_sizes, dataset_size) | ||
partitioner._determine_partition_id_to_indices_if_needed() | ||
self.assertEqual( | ||
{ | ||
pid: len(indices) | ||
for pid, indices in partitioner.partition_id_to_indices.items() | ||
}, | ||
dict(enumerate(partition_sizes)), | ||
) | ||
|
||
def test_correct_partition_loading(self) -> None: | ||
"""Test that partitions are loaded correctly.""" | ||
partition_sizes = [10, 20, 30] | ||
partitioner = _dummy_setup_size(partition_sizes, 60) | ||
partition = partitioner.load_partition(1) | ||
self.assertEqual(len(partition), 20) | ||
|
||
def test_warning_for_smaller_partition_sizes(self) -> None: | ||
"""Test a warning is raised if sum of partition sizes < len(ds).""" | ||
partition_sizes = [10, 5, 20] | ||
partitioner = _dummy_setup_size(partition_sizes, 50) | ||
with self.assertWarns(Warning): | ||
partitioner._determine_partition_id_to_indices_if_needed() | ||
|
||
def test_no_exception_for_exact_size(self) -> None: | ||
"""Test no exception is raised when len(ds) == sum(patition_sizes).""" | ||
partition_sizes = [10, 20, 30] | ||
partitioner = _dummy_setup_size(partition_sizes, 60) | ||
partitioner._determine_partition_id_to_indices_if_needed() | ||
|
||
|
||
class TestSizePartitionerFailure(unittest.TestCase): | ||
"""Test SizePartitioner failures (exceptions) by incorrect usage.""" | ||
|
||
def test_invalid_partition_size(self) -> None: | ||
"""Test if raises ValueError when partition sizes are non-positive.""" | ||
with self.assertRaises(ValueError): | ||
SizePartitioner(partition_sizes=[-1, 10, 20]) | ||
|
||
def test_invalid_partition_type(self) -> None: | ||
"""Test if raises ValueError when partition sizes are non-positive.""" | ||
with self.assertRaises(ValueError): | ||
SizePartitioner(partition_sizes=[0.2, 0.3]) # type: ignore[list-item] | ||
|
||
def test_partition_size_exceeds_dataset(self) -> None: | ||
"""Test if raises ValueError when partition sizes exceed dataset size.""" | ||
partition_sizes = [10, 20, 30] | ||
partitioner = _dummy_setup_size(partition_sizes, 40) | ||
with self.assertRaises(ValueError): | ||
partitioner._determine_partition_id_to_indices_if_needed() | ||
|
||
def test_load_invalid_partition_index(self) -> None: | ||
"""Test if raises KeyError when an invalid partition index is loaded.""" | ||
partition_sizes = [10, 20, 30] | ||
partitioner = _dummy_setup_size(partition_sizes, 60) | ||
with self.assertRaises(KeyError): | ||
partitioner.load_partition(3) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |