diff --git a/config.yaml b/config.yaml index d110c0ca..da379748 100644 --- a/config.yaml +++ b/config.yaml @@ -74,6 +74,7 @@ cluster_estimation: min_new_points_to_run: 3 max_num_components: 10 random_state: 0 + min_points_per_cluster: 3 communications: timeout: 60.0 # seconds diff --git a/main_2025.py b/main_2025.py index 81383697..642a524c 100644 --- a/main_2025.py +++ b/main_2025.py @@ -154,6 +154,7 @@ def main() -> int: MIN_NEW_POINTS_TO_RUN = config["cluster_estimation"]["min_new_points_to_run"] MAX_NUM_COMPONENTS = config["cluster_estimation"]["max_num_components"] RANDOM_STATE = config["cluster_estimation"]["random_state"] + MIN_POINTS_PER_CLUSTER = config["cluster_estimation"]["min_points_per_cluster"] COMMUNICATIONS_TIMEOUT = config["communications"]["timeout"] COMMUNICATIONS_WORKER_PERIOD = config["communications"]["worker_period"] @@ -354,6 +355,7 @@ def main() -> int: MIN_NEW_POINTS_TO_RUN, MAX_NUM_COMPONENTS, RANDOM_STATE, + MIN_POINTS_PER_CLUSTER, ), input_queues=[geolocation_to_cluster_estimation_queue], output_queues=[cluster_estimation_to_communications_queue], diff --git a/modules/cluster_estimation/cluster_estimation.py b/modules/cluster_estimation/cluster_estimation.py index 10c7bb91..c3d6c082 100644 --- a/modules/cluster_estimation/cluster_estimation.py +++ b/modules/cluster_estimation/cluster_estimation.py @@ -1,6 +1,6 @@ """ Take in bounding box coordinates from Geolocation and use to estimate landing pad locations. -Returns an array of classes, each containing the x coordinate, y coordinate, and spherical +Returns an array of classes, each containing the x coordinate, y coordinate, and spherical covariance of each landing pad estimation. """ @@ -14,6 +14,7 @@ from ..common.modules.logger import logger +# pylint: disable=too-many-instance-attributes class ClusterEstimation: """ Estimate landing pad locations based on landing pad ground detection. 
Estimation @@ -63,6 +64,7 @@ def create( max_num_components: int, random_state: int, local_logger: logger.Logger, + min_points_per_cluster: int, ) -> "tuple[bool, ClusterEstimation | None]": """ Data requirement conditions for estimation model to run. @@ -83,6 +85,9 @@ def create( local_logger: logger.Logger The local logger to log this object's information. + min_points_per_cluster: int + Minimum number of points that must be assigned to a cluster for it to be considered valid. + RETURNS: The ClusterEstimation object if all conditions pass, otherwise False, None """ if min_activation_threshold < max_num_components: @@ -97,6 +102,9 @@ def create( if random_state < 0: return False, None + if min_points_per_cluster < 1: + return False, None + return True, ClusterEstimation( cls.__create_key, min_activation_threshold, @@ -104,6 +112,7 @@ def create( max_num_components, random_state, local_logger, + min_points_per_cluster, ) def __init__( @@ -114,6 +123,7 @@ def __init__( max_num_components: int, random_state: int, local_logger: logger.Logger, + min_points_per_cluster: int, ) -> None: """ Private constructor, use create() method. 
@@ -140,6 +150,7 @@ def __init__( self.__min_new_points_to_run = min_new_points_to_run self.__has_ran_once = False self.__logger = local_logger + self.__min_points_per_cluster = min_points_per_cluster def run( self, detections: "list[detection_in_world.DetectionInWorld]", run_override: bool @@ -337,15 +348,16 @@ def __filter_by_points_ownership( # List of each point's cluster index cluster_assignment = self.__vgmm.predict(self.__all_points) # type: ignore - # Find which cluster indices have points - clusters_with_points = np.unique(cluster_assignment) + # Check each cluster has enough points associated to it by index + unique, counts = np.unique(cluster_assignment, return_counts=True) + cluster_counts = dict(zip(unique, counts)) - # Remove empty clusters + # Remove clusters with fewer than the minimum required points filtered_output: "list[tuple[np.ndarray, float, float]]" = [] # By cluster index # pylint: disable-next=consider-using-enumerate for i in range(len(model_output)): - if i in clusters_with_points: + if cluster_counts.get(i, 0) >= self.__min_points_per_cluster: filtered_output.append(model_output[i]) return filtered_output diff --git a/modules/cluster_estimation/cluster_estimation_worker.py b/modules/cluster_estimation/cluster_estimation_worker.py index 0f378625..48a74bd5 100644 --- a/modules/cluster_estimation/cluster_estimation_worker.py +++ b/modules/cluster_estimation/cluster_estimation_worker.py @@ -20,6 +20,7 @@ def cluster_estimation_worker( + min_points_per_cluster: int, input_queue: queue_proxy_wrapper.QueueProxyWrapper, output_queue: queue_proxy_wrapper.QueueProxyWrapper, controller: worker_controller.WorkerController, ) -> None: """ Estimation worker process. 
@@ -64,6 +65,7 @@ def cluster_estimation_worker( max_num_components, random_state, local_logger, + min_points_per_cluster, ) if not result: local_logger.error("Worker failed to create class object", True) diff --git a/tests/integration/test_cluster_estimation_worker.py b/tests/integration/test_cluster_estimation_worker.py index de3392d1..595cda3c 100644 --- a/tests/integration/test_cluster_estimation_worker.py +++ b/tests/integration/test_cluster_estimation_worker.py @@ -17,6 +17,7 @@ MIN_NEW_POINTS_TO_RUN = 0 MAX_NUM_COMPONENTS = 3 RANDOM_STATE = 0 +MIN_POINTS_PER_CLUSTER = 3 def check_output_results(output_queue: queue_proxy_wrapper.QueueProxyWrapper) -> None: @@ -49,6 +50,7 @@ def test_cluster_estimation_worker() -> int: MIN_NEW_POINTS_TO_RUN, MAX_NUM_COMPONENTS, RANDOM_STATE, + MIN_POINTS_PER_CLUSTER, input_queue, output_queue, controller, diff --git a/tests/integration/test_communications_to_ground_station.py b/tests/integration/test_communications_to_ground_station.py index e9856d5b..d9132102 100644 --- a/tests/integration/test_communications_to_ground_station.py +++ b/tests/integration/test_communications_to_ground_station.py @@ -1,5 +1,5 @@ """ -Test MAVLink integration test +Test MAVLink integration test """ import multiprocessing as mp diff --git a/tests/unit/test_cluster_detection.py b/tests/unit/test_cluster_detection.py index 155cca06..372c5e28 100644 --- a/tests/unit/test_cluster_detection.py +++ b/tests/unit/test_cluster_detection.py @@ -16,6 +16,7 @@ MAX_NUM_COMPONENTS = 10 RNG_SEED = 0 CENTRE_BOX_SIZE = 500 +MIN_POINTS_PER_CLUSTER = 3 # Test functions use test fixture signature names and access class privates # No enable @@ -37,6 +38,7 @@ def cluster_model() -> cluster_estimation.ClusterEstimation: # type: ignore MAX_NUM_COMPONENTS, RNG_SEED, test_logger, + MIN_POINTS_PER_CLUSTER, ) assert result assert model is not None @@ -489,3 +491,33 @@ def test_position_regular_data( break assert is_match + + +class TestMinimumPointsPerCluster: + """ + Tests 
that clusters with fewer than the minimum required points are filtered out. + """ + + __STD_DEV_REG = 1 + + def test_outlier_is_filtered(self, cluster_model: cluster_estimation.ClusterEstimation) -> None: + """ + Verify that a single outlier (cluster with only one point) is filtered out, + while a valid cluster with enough points is retained. + """ + # Setup + valid_detections, valid_cluster_positions = generate_cluster_data([100], self.__STD_DEV_REG) + outlier_detections = generate_points_away_from_cluster( + num_points_to_generate=1, + minimum_distance_from_cluster=20, + cluster_positions=valid_cluster_positions, + ) + generated_detections = valid_detections + outlier_detections + + # Run + result, detections_in_world = cluster_model.run(generated_detections, False) + + # Test + assert result + assert detections_in_world is not None + assert len(detections_in_world) == 1