-
Notifications
You must be signed in to change notification settings - Fork 942
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b909a55
commit 6ac5a7b
Showing
5 changed files
with
675 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
226 changes: 226 additions & 0 deletions
226
datasets/flwr_datasets/partitioner/vertical_even_partitioner.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,226 @@ | ||
# Copyright 2024 Flower Labs GmbH. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""VerticalEvenPartitioner class.""" | ||
# flake8: noqa: E501 | ||
from typing import Literal, Optional, Union | ||
|
||
import numpy as np | ||
|
||
import datasets | ||
from flwr_datasets.partitioner.partitioner import Partitioner | ||
from flwr_datasets.partitioner.vertical_partitioner_utils import ( | ||
_add_active_party_columns, | ||
_list_split, | ||
) | ||
|
||
|
||
class VerticalEvenPartitioner(Partitioner):
    """Partitioner that splits features (columns) evenly into vertical partitions.

    Enables selection of "active party" column(s) and placement into
    a specific partition or creation of a new partition just for it.
    Also enables dropping columns and sharing specified columns across
    all partitions.

    Parameters
    ----------
    num_partitions : int
        Number of partitions to create.
    active_party_columns : Optional[list[str]]
        Columns associated with the "active party" (which can be the server).
    active_party_columns_mode : Union[Literal["add_to_first", "add_to_last", "create_as_first", "create_as_last", "add_to_all"], int]
        Determines how to assign the active party columns:

        - "add_to_first": Append active party columns to the first partition.
        - "add_to_last": Append active party columns to the last partition.
        - int: Append active party columns to the specified partition index.
        - "create_as_first": Create a new partition at the start containing only
          these columns.
        - "create_as_last": Create a new partition at the end containing only
          these columns.
        - "add_to_all": Append active party columns to all partitions.
    drop_columns : Optional[list[str]]
        Columns to remove entirely from the dataset before partitioning.
    shared_columns : Optional[list[str]]
        Columns to duplicate into every partition after initial partitioning.
    shuffle : bool
        Whether to shuffle the order of columns before partitioning.
    seed : Optional[int]
        Random seed for shuffling columns. Has no effect if `shuffle=False`.

    Examples
    --------
    >>> partitioner = VerticalEvenPartitioner(
    ...     num_partitions=3,
    ...     active_party_columns=["income"],
    ...     active_party_columns_mode="add_to_last",
    ...     shuffle=True,
    ...     seed=42
    ... )
    >>> fds = FederatedDataset(
    ...     dataset="scikit-learn/adult-census-income",
    ...     partitioners={"train": partitioner}
    ... )
    >>> partitions = [fds.load_partition(i) for i in range(partitioner.num_partitions)]
    >>> print([partition.column_names for partition in partitions])
    """

    def __init__(
        self,
        num_partitions: int,
        active_party_columns: Optional[list[str]] = None,
        active_party_columns_mode: Union[
            Literal[
                "add_to_first",
                "add_to_last",
                "create_as_first",
                "create_as_last",
                "add_to_all",
            ],
            int,
        ] = "add_to_last",
        drop_columns: Optional[list[str]] = None,
        shared_columns: Optional[list[str]] = None,
        shuffle: bool = True,
        seed: Optional[int] = 42,
    ) -> None:
        super().__init__()

        self._num_partitions = num_partitions
        self._active_party_columns = active_party_columns or []
        self._active_party_columns_mode = active_party_columns_mode
        self._drop_columns = drop_columns or []
        self._shared_columns = shared_columns or []
        self._shuffle = shuffle
        self._seed = seed
        self._rng = np.random.default_rng(seed=self._seed)

        # Column assignment is computed lazily on first access because it
        # requires the dataset (set by the framework after construction).
        self._partition_columns: Optional[list[list[str]]] = None
        self._partitions_determined = False

        self._validate_parameters_in_init()

    def _determine_partitions_if_needed(self) -> None:
        """Compute the column-to-partition assignment once, on first use."""
        if self._partitions_determined:
            return

        if self.dataset is None:
            raise ValueError("No dataset is set for this partitioner.")

        all_columns = list(self.dataset.column_names)
        self._validate_parameters_while_partitioning(
            all_columns, self._shared_columns, self._active_party_columns
        )
        # Dropped, shared, and active-party columns are excluded from the
        # even split; shared and active-party columns are reinserted below.
        columns = [column for column in all_columns if column not in self._drop_columns]
        columns = [column for column in columns if column not in self._shared_columns]
        columns = [
            column for column in columns if column not in self._active_party_columns
        ]

        if self._shuffle:
            self._rng.shuffle(columns)
        partition_columns = _list_split(columns, self._num_partitions)
        partition_columns = _add_active_party_columns(
            self._active_party_columns,
            self._active_party_columns_mode,
            partition_columns,
        )

        # Add shared columns to all partitions
        for partition in partition_columns:
            partition.extend(self._shared_columns)

        self._partition_columns = partition_columns
        self._partitions_determined = True

    def load_partition(self, partition_id: int) -> datasets.Dataset:
        """Load a partition based on the partition index.

        Parameters
        ----------
        partition_id : int
            The index that corresponds to the requested partition.

        Returns
        -------
        dataset_partition : Dataset
            Single partition of a dataset.
        """
        self._determine_partitions_if_needed()
        assert self._partition_columns is not None
        if partition_id < 0 or partition_id >= len(self._partition_columns):
            raise ValueError(f"Invalid partition_id {partition_id}.")
        columns = self._partition_columns[partition_id]
        return self.dataset.select_columns(columns)

    @property
    def num_partitions(self) -> int:
        """Number of partitions.

        Note: may differ from the ``num_partitions`` passed to ``__init__``
        when ``active_party_columns_mode`` creates an extra partition.
        """
        self._determine_partitions_if_needed()
        assert self._partition_columns is not None
        return len(self._partition_columns)

    def _validate_parameters_in_init(self) -> None:
        """Validate constructor arguments that do not require the dataset."""
        if self._num_partitions < 1:
            raise ValueError("num_partitions must be >= 1.")

        # Validate columns lists
        for parameter_name, parameter_list in [
            ("drop_columns", self._drop_columns),
            ("shared_columns", self._shared_columns),
            ("active_party_columns", self._active_party_columns),
        ]:
            if not all(isinstance(column, str) for column in parameter_list):
                raise ValueError(f"All entries in {parameter_name} must be strings.")

        valid_modes = {
            "add_to_first",
            "add_to_last",
            "create_as_first",
            "create_as_last",
            "add_to_all",
        }
        if not (
            isinstance(self._active_party_columns_mode, int)
            or self._active_party_columns_mode in valid_modes
        ):
            raise ValueError(
                "active_party_columns_mode must be an int or one of "
                "'add_to_first', 'add_to_last', 'create_as_first', 'create_as_last', "
                "'add_to_all'."
            )

    def _validate_parameters_while_partitioning(
        self,
        all_columns: list[str],
        shared_columns: list[str],
        active_party_columns: list[str],
    ) -> None:
        """Validate column names against the dataset's actual columns."""
        # Shared columns existence check
        for column in shared_columns:
            if column not in all_columns:
                raise ValueError(f"Shared column '{column}' not found in the dataset.")
        # Active party columns existence check
        for column in active_party_columns:
            if column not in all_columns:
                raise ValueError(
                    f"Active party column '{column}' not found in the dataset."
                )
Oops, something went wrong.