Skip to content

Commit

Permalink
Update naming convention for active_party_column
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-narozniak committed Dec 13, 2024
1 parent 31e1020 commit 8e14318
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 33 deletions.
39 changes: 20 additions & 19 deletions datasets/flwr_datasets/partitioner/vertical_even_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ class VerticalEvenPartitioner(Partitioner):
----------
num_partitions : int
Number of partitions to create.
active_party_columns : Optional[list[str]]
Columns associated with the "active party" (which can be the server).
active_party_columns_mode : Union[Literal[["add_to_first", "add_to_last", "create_as_first", "create_as_last", "add_to_all"], int]
active_party_column : Optional[Union[str, list[str]]]
Column(s) (typically representing labels) associated with the
"active party" (which can be the server).
active_party_column_mode : Union[Literal["add_to_first", "add_to_last", "create_as_first", "create_as_last", "add_to_all"], int]
Determines how to assign the active party columns:
- "add_to_first": Append active party columns to the first partition.
- "add_to_last": Append active party columns to the last partition.
Expand All @@ -64,8 +65,8 @@ class VerticalEvenPartitioner(Partitioner):
--------
>>> partitioner = VerticalEvenPartitioner(
... num_partitions=3,
... active_party_columns=["income"],
... active_party_columns_mode="add_to_last",
... active_party_column=["income"],
... active_party_column_mode="add_to_last",
... shuffle=True,
... seed=42
... )
Expand All @@ -80,8 +81,8 @@ class VerticalEvenPartitioner(Partitioner):
def __init__(
self,
num_partitions: int,
active_party_columns: Optional[list[str]] = None,
active_party_columns_mode: Union[
active_party_column: Optional[list[str]] = None,
active_party_column_mode: Union[
Literal[
"add_to_first",
"add_to_last",
Expand All @@ -99,8 +100,8 @@ def __init__(
super().__init__()

self._num_partitions = num_partitions
self._active_party_columns = active_party_columns or []
self._active_party_columns_mode = active_party_columns_mode
self._active_party_column = active_party_column or []
self._active_party_column_mode = active_party_column_mode
self._drop_columns = drop_columns or []
self._shared_columns = shared_columns or []
self._shuffle = shuffle
Expand All @@ -121,20 +122,20 @@ def _determine_partitions_if_needed(self) -> None:

all_columns = list(self.dataset.column_names)
self._validate_parameters_while_partitioning(
all_columns, self._shared_columns, self._active_party_columns
all_columns, self._shared_columns, self._active_party_column
)
columns = [column for column in all_columns if column not in self._drop_columns]
columns = [column for column in columns if column not in self._shared_columns]
columns = [
column for column in columns if column not in self._active_party_columns
column for column in columns if column not in self._active_party_column
]

if self._shuffle:
self._rng.shuffle(columns)
partition_columns = _list_split(columns, self._num_partitions)
partition_columns = _add_active_party_columns(
self._active_party_columns,
self._active_party_columns_mode,
self._active_party_column,
self._active_party_column_mode,
partition_columns,
)

Expand Down Expand Up @@ -181,7 +182,7 @@ def _validate_parameters_in_init(self) -> None:
for parameter_name, parameter_list in [
("drop_columns", self._drop_columns),
("shared_columns", self._shared_columns),
("active_party_columns", self._active_party_columns),
("active_party_column", self._active_party_column),
]:
if not all(isinstance(column, str) for column in parameter_list):
raise ValueError(f"All entries in {parameter_name} must be strings.")
Expand All @@ -194,11 +195,11 @@ def _validate_parameters_in_init(self) -> None:
"add_to_all",
}
if not (
isinstance(self._active_party_columns_mode, int)
or self._active_party_columns_mode in valid_modes
isinstance(self._active_party_column_mode, int)
or self._active_party_column_mode in valid_modes
):
raise ValueError(
"active_party_columns_mode must be an int or one of "
"active_party_column_mode must be an int or one of "
"'add_to_first', 'add_to_last', 'create_as_first', 'create_as_last', "
"'add_to_all'."
)
Expand All @@ -207,14 +208,14 @@ def _validate_parameters_while_partitioning(
self,
all_columns: list[str],
shared_columns: list[str],
active_party_columns: list[str],
active_party_column: list[str],
) -> None:
# Shared columns existence check
for column in shared_columns:
if column not in all_columns:
raise ValueError(f"Shared column '{column}' not found in the dataset.")
# Active party columns existence check
for column in active_party_columns:
for column in active_party_column:
if column not in all_columns:
raise ValueError(
f"Active party column '{column}' not found in the dataset."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ def test_init_with_invalid_num_partitions(self) -> None:
VerticalEvenPartitioner(num_partitions=0)

def test_init_with_invalid_active_party_mode(self) -> None:
"""Test initialization with invalid active_party_columns_mode."""
"""Test initialization with invalid active_party_column_mode."""
with self.assertRaises(ValueError):
VerticalEvenPartitioner(
num_partitions=2,
active_party_columns_mode="invalid_mode", # type: ignore[arg-type]
active_party_column_mode="invalid_mode", # type: ignore[arg-type]
)

def test_init_with_non_string_drop_columns(self) -> None:
Expand All @@ -58,11 +58,11 @@ def test_init_with_non_string_shared_columns(self) -> None:
with self.assertRaises(ValueError):
VerticalEvenPartitioner(num_partitions=2, shared_columns=["col1", 123])

def test_init_with_non_string_active_party_columns(self) -> None:
"""Test initialization with non-string elements in active_party_columns."""
def test_init_with_non_string_active_party_column(self) -> None:
"""Test initialization with non-string elements in active_party_column."""
with self.assertRaises(ValueError):
VerticalEvenPartitioner(
num_partitions=2, active_party_columns=["col1", None]
num_partitions=2, active_party_column=["col1", None]
)

def test_partitioning_basic(self) -> None:
Expand Down Expand Up @@ -120,14 +120,14 @@ def test_partitioning_with_shared_columns(self) -> None:
self.assertIn("shared_col", p0.column_names)
self.assertIn("shared_col", p1.column_names)

def test_partitioning_with_active_party_columns_add_to_last(self) -> None:
def test_partitioning_with_active_party_column_add_to_last(self) -> None:
"""Test active party columns are appended to the last partition."""
columns = ["f1", "f2", "f3", "f4", "income"]
dataset = _create_dummy_dataset(columns, num_rows=50)
partitioner = VerticalEvenPartitioner(
num_partitions=2,
active_party_columns=["income"],
active_party_columns_mode="add_to_last",
active_party_column=["income"],
active_party_column_mode="add_to_last",
shuffle=False,
seed=42,
)
Expand All @@ -140,14 +140,14 @@ def test_partitioning_with_active_party_columns_add_to_last(self) -> None:
self.assertNotIn("income", p0.column_names)
self.assertIn("income", p1.column_names)

def test_partitioning_with_active_party_columns_create_as_first(self) -> None:
def test_partitioning_with_active_party_column_create_as_first(self) -> None:
"""Test creating a new partition solely for active party columns."""
columns = ["f1", "f2", "f3", "f4", "income"]
dataset = _create_dummy_dataset(columns, num_rows=50)
partitioner = VerticalEvenPartitioner(
num_partitions=2,
active_party_columns=["income"],
active_party_columns_mode="create_as_first",
active_party_column=["income"],
active_party_column_mode="create_as_first",
shuffle=False,
)
partitioner.dataset = dataset
Expand All @@ -166,14 +166,14 @@ def test_partitioning_with_active_party_columns_create_as_first(self) -> None:
self.assertIn("f3", p2.column_names)
self.assertIn("f4", p2.column_names)

def test_partitioning_with_nonexistent_active_party_columns(self) -> None:
def test_partitioning_with_nonexistent_active_party_column(self) -> None:
"""Test that a ValueError is raised if active party column does not exist."""
columns = ["f1", "f2", "f3", "f4"]
dataset = _create_dummy_dataset(columns, num_rows=50)
partitioner = VerticalEvenPartitioner(
num_partitions=2,
active_party_columns=["income"], # Not present in dataset
active_party_columns_mode="add_to_last",
active_party_column=["income"], # Not present in dataset
active_party_column_mode="add_to_last",
shuffle=False,
)
partitioner.dataset = dataset
Expand Down

0 comments on commit 8e14318

Please sign in to comment.