diff --git a/datasets/flwr_datasets/partitioner/vertical_size_partitioner.py b/datasets/flwr_datasets/partitioner/vertical_size_partitioner.py index 462a76a2e3f5..6326af0a0ea6 100644 --- a/datasets/flwr_datasets/partitioner/vertical_size_partitioner.py +++ b/datasets/flwr_datasets/partitioner/vertical_size_partitioner.py @@ -51,7 +51,7 @@ class VerticalSizePartitioner(Partitioner): toward the partition sizes. In case fo list[int]: sum(partition_sizes) == len(columns) - len(drop_columns) - len(shared_columns) - len(active_party_columns) - active_party_column : Optional[Union[str, list[str]]] + active_party_columns : Optional[Union[str, list[str]]] Column(s) (typically representing labels) associated with the "active party" (which can be the server). active_party_columns_mode : Union[Literal[["add_to_first", "add_to_last", "create_as_first", "create_as_last", "add_to_all"], int] @@ -93,7 +93,7 @@ class VerticalSizePartitioner(Partitioner): def __init__( self, partition_sizes: Union[list[int], list[float]], - active_party_column: Optional[Union[str, list[str]]] = None, + active_party_columns: Optional[Union[str, list[str]]] = None, active_party_columns_mode: Union[ Literal[ "add_to_first", @@ -112,7 +112,7 @@ def __init__( super().__init__() self._partition_sizes = partition_sizes - self._active_party_columns = self._init_active_party_column(active_party_column) + self._active_party_columns = _init_active_party_columns(active_party_columns) self._active_party_columns_mode = active_party_columns_mode self._drop_columns = drop_columns or [] self._shared_columns = shared_columns or [] @@ -201,8 +201,17 @@ def _validate_parameters_in_init(self) -> None: raise ValueError("partition_sizes must be a list.") if all(isinstance(fraction, float) for fraction in self._partition_sizes): fraction_sum = sum(self._partition_sizes) + # Tolerance 0.01 for the floating point numerical problems + if fraction_sum < 1.01 and fraction_sum > 0.99: + self._partition_sizes = self._partition_sizes[:-1] + [ + 1.0 - self._partition_sizes[-1] + ] + fraction_sum = 1.0 if fraction_sum != 1.0: - raise ValueError("Float ratios in `partition_sizes` must sum to 1.0.") + raise ValueError( + "Float ratios in `partition_sizes` must sum to 1.0. " + f"Instead got: {fraction_sum}." + ) if any( fraction < 0.0 or fraction > 1.0 for fraction in self._partition_sizes ): @@ -276,16 +285,17 @@ def _validate_parameters_while_partitioning( "active_party_columns are not included in the division." ) - def _init_active_party_column( - self, active_party_column: Optional[Union[str, list[str]]] - ) -> list[str]: - if active_party_column is None: - return [] - if isinstance(active_party_column, str): - return [active_party_column] - if isinstance(active_party_column, list): - return active_party_column - raise ValueError("active_party_column must be a string or a list of strings.") + +def _init_active_party_columns( + active_party_column: Optional[Union[str, list[str]]] +) -> list[str]: + if active_party_column is None: + return [] + if isinstance(active_party_column, str): + return [active_party_column] + if isinstance(active_party_column, list): + return active_party_column + raise ValueError("active_party_column must be a string or a list of strings.") def _count_split(columns: list[str], counts: list[int]) -> list[list[str]]: