diff --git a/datasets/flwr_datasets/partitioner/vertical_size_partitioner_test.py b/datasets/flwr_datasets/partitioner/vertical_size_partitioner_test.py index bc6b8324ac52..d2c483c2be88 100644 --- a/datasets/flwr_datasets/partitioner/vertical_size_partitioner_test.py +++ b/datasets/flwr_datasets/partitioner/vertical_size_partitioner_test.py @@ -181,6 +181,26 @@ def test_sum_of_int_partition_sizes_exceeds_num_columns(self) -> None: with self.assertRaises(ValueError): partitioner.load_partition(0) + def test_sum_of_int_partition_sizes_indirectly_exceeds_num_columns(self) -> None: + """Check ValueError if sum of int sizes > total columns.""" + columns = ["f1", "f2", "f3"] + dataset = _create_dummy_dataset(columns) + partitioner = VerticalSizePartitioner( + partition_sizes=[1, 1], drop_columns=["f3", "f2"], shuffle=False + ) + partitioner.dataset = dataset + with self.assertRaises(ValueError): + partitioner.load_partition(0) + + def test_sum_of_int_partition_sizes_is_smaller_than_num_columns(self) -> None: + """Check ValueError if sum of int sizes < total columns.""" + columns = ["f1", "f2", "f3"] + dataset = _create_dummy_dataset(columns) + partitioner = VerticalSizePartitioner(partition_sizes=[2], shuffle=False) + partitioner.dataset = dataset + with self.assertRaises(ValueError): + partitioner.load_partition(0) + if __name__ == "__main__": unittest.main()