diff --git a/datasets/flwr_datasets/partitioner/vertical_even_partitioner.py b/datasets/flwr_datasets/partitioner/vertical_even_partitioner.py index 54d70b7c8389..3a382d8e2a2e 100644 --- a/datasets/flwr_datasets/partitioner/vertical_even_partitioner.py +++ b/datasets/flwr_datasets/partitioner/vertical_even_partitioner.py @@ -39,9 +39,10 @@ class VerticalEvenPartitioner(Partitioner): ---------- num_partitions : int Number of partitions to create. - active_party_columns : Optional[list[str]] - Columns associated with the "active party" (which can be the server). - active_party_columns_mode : Union[Literal[["add_to_first", "add_to_last", "create_as_first", "create_as_last", "add_to_all"], int] + active_party_column : Optional[list[str]] + Column(s) (typically representing labels) associated with the + "active party" (which can be the server). + active_party_column_mode : Union[Literal["add_to_first", "add_to_last", "create_as_first", "create_as_last", "add_to_all"], int] Determines how to assign the active party columns: - "add_to_first": Append active party columns to the first partition. - "add_to_last": Append active party columns to the last partition. @@ -64,8 +65,8 @@ class VerticalEvenPartitioner(Partitioner): -------- >>> partitioner = VerticalEvenPartitioner( ... num_partitions=3, - ... active_party_columns=["income"], - ... active_party_columns_mode="add_to_last", + ... active_party_column=["income"], + ... active_party_column_mode="add_to_last", ... shuffle=True, ... seed=42 ... 
) @@ -80,8 +81,8 @@ class VerticalEvenPartitioner(Partitioner): def __init__( self, num_partitions: int, - active_party_columns: Optional[list[str]] = None, - active_party_columns_mode: Union[ + active_party_column: Optional[list[str]] = None, + active_party_column_mode: Union[ Literal[ "add_to_first", "add_to_last", @@ -99,8 +100,8 @@ def __init__( super().__init__() self._num_partitions = num_partitions - self._active_party_columns = active_party_columns or [] - self._active_party_columns_mode = active_party_columns_mode + self._active_party_column = active_party_column or [] + self._active_party_column_mode = active_party_column_mode self._drop_columns = drop_columns or [] self._shared_columns = shared_columns or [] self._shuffle = shuffle @@ -121,20 +122,20 @@ def _determine_partitions_if_needed(self) -> None: all_columns = list(self.dataset.column_names) self._validate_parameters_while_partitioning( - all_columns, self._shared_columns, self._active_party_columns + all_columns, self._shared_columns, self._active_party_column ) columns = [column for column in all_columns if column not in self._drop_columns] columns = [column for column in columns if column not in self._shared_columns] columns = [ - column for column in columns if column not in self._active_party_columns + column for column in columns if column not in self._active_party_column ] if self._shuffle: self._rng.shuffle(columns) partition_columns = _list_split(columns, self._num_partitions) partition_columns = _add_active_party_columns( - self._active_party_columns, - self._active_party_columns_mode, + self._active_party_column, + self._active_party_column_mode, partition_columns, ) @@ -181,7 +182,7 @@ def _validate_parameters_in_init(self) -> None: for parameter_name, parameter_list in [ ("drop_columns", self._drop_columns), ("shared_columns", self._shared_columns), - ("active_party_columns", self._active_party_columns), + ("active_party_column", self._active_party_column), ]: if not 
all(isinstance(column, str) for column in parameter_list): raise ValueError(f"All entries in {parameter_name} must be strings.") @@ -194,11 +195,11 @@ def _validate_parameters_in_init(self) -> None: "add_to_all", } if not ( - isinstance(self._active_party_columns_mode, int) - or self._active_party_columns_mode in valid_modes + isinstance(self._active_party_column_mode, int) + or self._active_party_column_mode in valid_modes ): raise ValueError( - "active_party_columns_mode must be an int or one of " + "active_party_column_mode must be an int or one of " "'add_to_first', 'add_to_last', 'create_as_first', 'create_as_last', " "'add_to_all'." ) @@ -207,14 +208,14 @@ def _validate_parameters_while_partitioning( self, all_columns: list[str], shared_columns: list[str], - active_party_columns: list[str], + active_party_column: list[str], ) -> None: # Shared columns existance check for column in shared_columns: if column not in all_columns: raise ValueError(f"Shared column '{column}' not found in the dataset.") # Active party columns existence check - for column in active_party_columns: + for column in active_party_column: if column not in all_columns: raise ValueError( f"Active party column '{column}' not found in the dataset." 
diff --git a/datasets/flwr_datasets/partitioner/vertical_even_partitioner_test.py b/datasets/flwr_datasets/partitioner/vertical_even_partitioner_test.py index 8e766617d609..b561fa11ce06 100644 --- a/datasets/flwr_datasets/partitioner/vertical_even_partitioner_test.py +++ b/datasets/flwr_datasets/partitioner/vertical_even_partitioner_test.py @@ -41,11 +41,11 @@ def test_init_with_invalid_num_partitions(self) -> None: VerticalEvenPartitioner(num_partitions=0) def test_init_with_invalid_active_party_mode(self) -> None: - """Test initialization with invalid active_party_columns_mode.""" + """Test initialization with invalid active_party_column_mode.""" with self.assertRaises(ValueError): VerticalEvenPartitioner( num_partitions=2, - active_party_columns_mode="invalid_mode", # type: ignore[arg-type] + active_party_column_mode="invalid_mode", # type: ignore[arg-type] ) def test_init_with_non_string_drop_columns(self) -> None: @@ -58,11 +58,11 @@ def test_init_with_non_string_shared_columns(self) -> None: with self.assertRaises(ValueError): VerticalEvenPartitioner(num_partitions=2, shared_columns=["col1", 123]) - def test_init_with_non_string_active_party_columns(self) -> None: - """Test initialization with non-string elements in active_party_columns.""" + def test_init_with_non_string_active_party_column(self) -> None: + """Test initialization with non-string elements in active_party_column.""" with self.assertRaises(ValueError): VerticalEvenPartitioner( - num_partitions=2, active_party_columns=["col1", None] + num_partitions=2, active_party_column=["col1", None] ) def test_partitioning_basic(self) -> None: @@ -120,14 +120,14 @@ def test_partitioning_with_shared_columns(self) -> None: self.assertIn("shared_col", p0.column_names) self.assertIn("shared_col", p1.column_names) - def test_partitioning_with_active_party_columns_add_to_last(self) -> None: + def test_partitioning_with_active_party_column_add_to_last(self) -> None: """Test active party columns are appended to the 
last partition.""" columns = ["f1", "f2", "f3", "f4", "income"] dataset = _create_dummy_dataset(columns, num_rows=50) partitioner = VerticalEvenPartitioner( num_partitions=2, - active_party_columns=["income"], - active_party_columns_mode="add_to_last", + active_party_column=["income"], + active_party_column_mode="add_to_last", shuffle=False, seed=42, ) @@ -140,14 +140,14 @@ def test_partitioning_with_active_party_columns_add_to_last(self) -> None: self.assertNotIn("income", p0.column_names) self.assertIn("income", p1.column_names) - def test_partitioning_with_active_party_columns_create_as_first(self) -> None: + def test_partitioning_with_active_party_column_create_as_first(self) -> None: """Test creating a new partition solely for active party columns.""" columns = ["f1", "f2", "f3", "f4", "income"] dataset = _create_dummy_dataset(columns, num_rows=50) partitioner = VerticalEvenPartitioner( num_partitions=2, - active_party_columns=["income"], - active_party_columns_mode="create_as_first", + active_party_column=["income"], + active_party_column_mode="create_as_first", shuffle=False, ) partitioner.dataset = dataset @@ -166,14 +166,14 @@ def test_partitioning_with_active_party_columns_create_as_first(self) -> None: self.assertIn("f3", p2.column_names) self.assertIn("f4", p2.column_names) - def test_partitioning_with_nonexistent_active_party_columns(self) -> None: + def test_partitioning_with_nonexistent_active_party_column(self) -> None: """Test that a ValueError is raised if active party column does not exist.""" columns = ["f1", "f2", "f3", "f4"] dataset = _create_dummy_dataset(columns, num_rows=50) partitioner = VerticalEvenPartitioner( num_partitions=2, - active_party_columns=["income"], # Not present in dataset - active_party_columns_mode="add_to_last", + active_party_column=["income"], # Not present in dataset + active_party_column_mode="add_to_last", shuffle=False, ) partitioner.dataset = dataset