Skip to content

Commit

Permalink
Make naming consistent in VerticalSizePartitioner
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Dec 19, 2024
1 parent 4a821c0 commit fa0aa34
Showing 1 changed file with 24 additions and 14 deletions.
38 changes: 24 additions & 14 deletions datasets/flwr_datasets/partitioner/vertical_size_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class VerticalSizePartitioner(Partitioner):
toward the partition sizes.
In case of list[int]: sum(partition_sizes) == len(columns) - len(drop_columns) -
len(shared_columns) - len(active_party_columns)
active_party_column : Optional[Union[str, list[str]]]
active_party_columns : Optional[Union[str, list[str]]]
Column(s) (typically representing labels) associated with the
"active party" (which can be the server).
active_party_columns_mode : Union[Literal["add_to_first", "add_to_last", "create_as_first", "create_as_last", "add_to_all"], int]
Expand Down Expand Up @@ -93,7 +93,7 @@ class VerticalSizePartitioner(Partitioner):
def __init__(
self,
partition_sizes: Union[list[int], list[float]],
active_party_column: Optional[Union[str, list[str]]] = None,
active_party_columns: Optional[Union[str, list[str]]] = None,
active_party_columns_mode: Union[
Literal[
"add_to_first",
Expand All @@ -112,7 +112,7 @@ def __init__(
super().__init__()

self._partition_sizes = partition_sizes
self._active_party_columns = self._init_active_party_column(active_party_column)
self._active_party_columns = _init_active_party_columns(active_party_columns)
self._active_party_columns_mode = active_party_columns_mode
self._drop_columns = drop_columns or []
self._shared_columns = shared_columns or []
Expand Down Expand Up @@ -201,8 +201,17 @@ def _validate_parameters_in_init(self) -> None:
raise ValueError("partition_sizes must be a list.")
if all(isinstance(fraction, float) for fraction in self._partition_sizes):
fraction_sum = sum(self._partition_sizes)
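# Tolerance of 0.01 to absorb floating-point rounding errors in the sum.
# NOTE(review): the correction below replaces the last entry with
# `1.0 - self._partition_sizes[-1]`; it presumably should be
# `1.0 - sum(self._partition_sizes[:-1])` so the list sums to 1.0 - confirm.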
if fraction_sum < 1.01 and fraction_sum > 0.99:
self._partition_sizes = self._partition_sizes[:-1] + [
1.0 - self._partition_sizes[-1]
]
fraction_sum = 1.0
if fraction_sum != 1.0:
raise ValueError("Float ratios in `partition_sizes` must sum to 1.0.")
raise ValueError(
"Float ratios in `partition_sizes` must sum to 1.0. "
f"Instead got: {fraction_sum}."
)
if any(
fraction < 0.0 or fraction > 1.0 for fraction in self._partition_sizes
):
Expand Down Expand Up @@ -276,16 +285,17 @@ def _validate_parameters_while_partitioning(
"active_party_columns are not included in the division."
)

def _init_active_party_column(
self, active_party_column: Optional[Union[str, list[str]]]
) -> list[str]:
if active_party_column is None:
return []
if isinstance(active_party_column, str):
return [active_party_column]
if isinstance(active_party_column, list):
return active_party_column
raise ValueError("active_party_column must be a string or a list of strings.")

def _init_active_party_columns(
active_party_column: Optional[Union[str, list[str]]]
) -> list[str]:
if active_party_column is None:
return []
if isinstance(active_party_column, str):
return [active_party_column]
if isinstance(active_party_column, list):
return active_party_column
raise ValueError("active_party_column must be a string or a list of strings.")


def _count_split(columns: list[str], counts: list[int]) -> list[list[str]]:
Expand Down

0 comments on commit fa0aa34

Please sign in to comment.