Skip to content

Commit

Permalink
Address comments
Browse files (browse the repository at this point in the history)
  • Loading branch information
chongshenng committed Jul 23, 2024
1 parent 145dfe6 commit 5ee5a13
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions datasets/flwr_datasets/partitioner/distribution_partitioner.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -19,11 +19,11 @@
from typing import Dict, List, Optional, Union

import numpy as np

import datasets
from flwr_datasets.common.typing import NDArray, NDArrayFloat, NDArrayInt
from flwr_datasets.partitioner.partitioner import Partitioner

import datasets


class DistributionPartitioner(Partitioner): # pylint: disable=R0902
"""Partitioner based on a distribution.
Expand Down Expand Up @@ -204,7 +204,6 @@ def load_partition(self, partition_id: int) -> datasets.Dataset:
self._check_num_unique_labels_per_partition_if_needed()
self._check_distribution_array_sum_if_needed()
self._check_num_partitions_correctness_if_needed()
self._check_num_partitions_greater_than_zero()
self._determine_partition_id_to_indices_if_needed()
return self.dataset.select(self._partition_id_to_indices[partition_id])

Expand Down Expand Up @@ -371,7 +370,8 @@ def _check_num_unique_labels_per_partition_if_needed(self) -> None:
raise ValueError(
"The specified `num_unique_labels_per_partition`"
f"={self._num_unique_labels_per_partition} is greater than the number "
f"of unique classes in the given dataset={self._num_unique_labels}. "
f"of unique labels in the given dataset={self._num_unique_labels} "
f"as specified by the label column `{self._partition_by}`."
"Reduce the `num_unique_labels_per_partition` or make use of a "
"different dataset to apply this partitioning."
)
Expand All @@ -380,17 +380,18 @@ def _check_distribution_array_sum_if_needed(self) -> None:
"""Test correctness of distribution array sum."""
if not self._partition_id_to_indices_determined and not self._rescale:
labels = self.dataset[self._partition_by]
distribution = sorted(Counter(labels).items())
distribution_vals = [v for _, v in distribution]
unique_labels_counter = sorted(Counter(labels).items())
unique_labels_counter_vals = [v for _, v in unique_labels_counter]

if any(self._distribution_array.sum(1) > distribution_vals):
if any(self._distribution_array.sum(1) > unique_labels_counter_vals):
raise ValueError(
"The sum of at least one label distribution array "
"exceeds the original class label distribution."
"The sum of at least one unique label distribution array "
"exceeds that of the unique labels counter in the given dataset= "
f"{dict(unique_labels_counter)}."
)

def _check_num_partitions_correctness_if_needed(self) -> None:
"""Test num_partitions when the dataset is given (in load_partition)."""
"""Test num_partitions when the dataset is given."""
if not self._partition_id_to_indices_determined:
if self._num_partitions > self.dataset.num_rows:
raise ValueError(
Expand All @@ -404,11 +405,10 @@ def _check_num_partitions_correctness_if_needed(self) -> None:
f"divisible by the number of unique labels "
f"{({self._num_unique_labels})}."
)

def _check_num_partitions_greater_than_zero(self) -> None:
"""Test num_partition left sides correctness."""
if not self._num_partitions > 0:
raise ValueError("The number of partitions needs to be greater than zero.")
if not self._num_partitions > 0:
raise ValueError(
"The number of partitions needs to be greater than zero."
)

def _check_total_preassigned_samples_within_limit(
self, label_distribution: NDArray, total_preassigned_samples: int
Expand Down

0 comments on commit 5ee5a13

Please sign in to comment.