Skip to content
This repository was archived by the owner on Nov 13, 2025. It is now read-only.

Commit f2d992b

Browse files
authored
Cluster estimation filter (#253)
* min_points_per_cluster * config * format * fix * format * update * update * 'format' * 'format' * 'update' * debug * integration * format * main_edit
1 parent 030287d commit f2d992b

File tree

7 files changed

+56
-5
lines changed

7 files changed

+56
-5
lines changed

config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ cluster_estimation:
7676
min_new_points_to_run: 3
7777
max_num_components: 10
7878
random_state: 0
79+
min_points_per_cluster: 3
7980

8081
communications:
8182
timeout: 60.0 # seconds

main_2025.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def main() -> int:
154154
MIN_NEW_POINTS_TO_RUN = config["cluster_estimation"]["min_new_points_to_run"]
155155
MAX_NUM_COMPONENTS = config["cluster_estimation"]["max_num_components"]
156156
RANDOM_STATE = config["cluster_estimation"]["random_state"]
157+
MIN_POINTS_PER_CLUSTER = config["cluster_estimation"]["min_points_per_cluster"]
157158

158159
COMMUNICATIONS_TIMEOUT = config["communications"]["timeout"]
159160
COMMUNICATIONS_WORKER_PERIOD = config["communications"]["worker_period"]
@@ -354,6 +355,7 @@ def main() -> int:
354355
MIN_NEW_POINTS_TO_RUN,
355356
MAX_NUM_COMPONENTS,
356357
RANDOM_STATE,
358+
MIN_POINTS_PER_CLUSTER,
357359
),
358360
input_queues=[geolocation_to_cluster_estimation_queue],
359361
output_queues=[cluster_estimation_to_communications_queue],

modules/cluster_estimation/cluster_estimation.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""
22
Take in bounding box coordinates from Geolocation and use to estimate landing pad locations.
3-
Returns an array of classes, each containing the x coordinate, y coordinate, and spherical
3+
Returns an array of classes, each containing the x coordinate, y coordinate, and spherical
44
covariance of each landing pad estimation.
55
"""
66

@@ -14,6 +14,7 @@
1414
from ..common.modules.logger import logger
1515

1616

17+
# pylint: disable=too-many-instance-attributes
1718
class ClusterEstimation:
1819
"""
1920
Estimate landing pad locations based on landing pad ground detection. Estimation
@@ -63,6 +64,7 @@ def create(
6364
max_num_components: int,
6465
random_state: int,
6566
local_logger: logger.Logger,
67+
min_points_per_cluster: int,
6668
) -> "tuple[bool, ClusterEstimation | None]":
6769
"""
6870
Data requirement conditions for estimation model to run.
@@ -83,6 +85,9 @@ def create(
8385
local_logger: logger.Logger
8486
The local logger to log this object's information.
8587
88+
min_points_per_cluster: int
89+
Minimum number of points that must be assigned to a cluster for it to be considered valid.
90+
8691
RETURNS: True and the ClusterEstimation object if all conditions pass, otherwise False, None
8792
"""
8893
if min_activation_threshold < max_num_components:
@@ -97,13 +102,17 @@ def create(
97102
if random_state < 0:
98103
return False, None
99104

105+
if min_points_per_cluster < 1:
106+
return False, None
107+
100108
return True, ClusterEstimation(
101109
cls.__create_key,
102110
min_activation_threshold,
103111
min_new_points_to_run,
104112
max_num_components,
105113
random_state,
106114
local_logger,
115+
min_points_per_cluster,
107116
)
108117

109118
def __init__(
@@ -114,6 +123,7 @@ def __init__(
114123
max_num_components: int,
115124
random_state: int,
116125
local_logger: logger.Logger,
126+
min_points_per_cluster: int,
117127
) -> None:
118128
"""
119129
Private constructor, use create() method.
@@ -140,6 +150,7 @@ def __init__(
140150
self.__min_new_points_to_run = min_new_points_to_run
141151
self.__has_ran_once = False
142152
self.__logger = local_logger
153+
self.__min_points_per_cluster = min_points_per_cluster
143154

144155
def run(
145156
self, detections: "list[detection_in_world.DetectionInWorld]", run_override: bool
@@ -337,15 +348,16 @@ def __filter_by_points_ownership(
337348
# List of each point's cluster index
338349
cluster_assignment = self.__vgmm.predict(self.__all_points) # type: ignore
339350

340-
# Find which cluster indices have points
341-
clusters_with_points = np.unique(cluster_assignment)
351+
# Count how many points are assigned to each cluster index
352+
unique, counts = np.unique(cluster_assignment, return_counts=True)
353+
cluster_counts = dict(zip(unique, counts))
342354

343355
# Remove clusters with fewer than the minimum required points (this includes empty clusters)
344356
filtered_output: "list[tuple[np.ndarray, float, float]]" = []
345357
# By cluster index
346358
# pylint: disable-next=consider-using-enumerate
347359
for i in range(len(model_output)):
348-
if i in clusters_with_points:
360+
if cluster_counts.get(i, 0) >= self.__min_points_per_cluster:
349361
filtered_output.append(model_output[i])
350362

351363
return filtered_output

modules/cluster_estimation/cluster_estimation_worker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def cluster_estimation_worker(
2121
input_queue: queue_proxy_wrapper.QueueProxyWrapper,
2222
output_queue: queue_proxy_wrapper.QueueProxyWrapper,
2323
controller: worker_controller.WorkerController,
24+
min_points_per_cluster: int,
2425
) -> None:
2526
"""
2627
Estimation worker process.
@@ -67,6 +68,7 @@ def cluster_estimation_worker(
6768
max_num_components,
6869
random_state,
6970
local_logger,
71+
min_points_per_cluster,
7072
)
7173
if not result:
7274
local_logger.error("Worker failed to create class object", True)

tests/integration/test_cluster_estimation_worker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
MIN_NEW_POINTS_TO_RUN = 0
1818
MAX_NUM_COMPONENTS = 3
1919
RANDOM_STATE = 0
20+
MIN_POINTS_PER_CLUSTER = 3
2021

2122

2223
def check_output_results(output_queue: queue_proxy_wrapper.QueueProxyWrapper) -> None:
@@ -49,6 +50,7 @@ def test_cluster_estimation_worker() -> int:
4950
MIN_NEW_POINTS_TO_RUN,
5051
MAX_NUM_COMPONENTS,
5152
RANDOM_STATE,
53+
MIN_POINTS_PER_CLUSTER,
5254
input_queue,
5355
output_queue,
5456
controller,

tests/integration/test_communications_to_ground_station.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Test MAVLink integration test
2+
Test MAVLink integration test
33
"""
44

55
import multiprocessing as mp

tests/unit/test_cluster_detection.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
MAX_NUM_COMPONENTS = 10
1717
RNG_SEED = 0
1818
CENTRE_BOX_SIZE = 500
19+
MIN_POINTS_PER_CLUSTER = 3
1920

2021
# Test functions use test fixture signature names and access class privates
2122
# No enable
@@ -37,6 +38,7 @@ def cluster_model() -> cluster_estimation.ClusterEstimation: # type: ignore
3738
MAX_NUM_COMPONENTS,
3839
RNG_SEED,
3940
test_logger,
41+
MIN_POINTS_PER_CLUSTER,
4042
)
4143
assert result
4244
assert model is not None
@@ -489,3 +491,33 @@ def test_position_regular_data(
489491
break
490492

491493
assert is_match
494+
495+
496+
class TestMinimumPointsPerCluster:
497+
"""
498+
Tests that clusters with fewer than the minimum required points are filtered out.
499+
"""
500+
501+
__STD_DEV_REG = 1
502+
503+
def test_outlier_is_filtered(self, cluster_model: cluster_estimation.ClusterEstimation) -> None:
504+
"""
505+
Verify that a single outlier (cluster with only one point) is filtered out,
506+
while a valid cluster with enough points is retained.
507+
"""
508+
# Setup
509+
valid_detections, valid_cluster_positions = generate_cluster_data([100], self.__STD_DEV_REG)
510+
outlier_detections = generate_points_away_from_cluster(
511+
num_points_to_generate=1,
512+
minimum_distance_from_cluster=20,
513+
cluster_positions=valid_cluster_positions,
514+
)
515+
generated_detections = valid_detections + outlier_detections
516+
517+
# Run
518+
result, detections_in_world = cluster_model.run(generated_detections, False)
519+
520+
# Test
521+
assert result
522+
assert detections_in_world is not None
523+
assert len(detections_in_world) == 1

0 commit comments

Comments
 (0)