11"""
22Take in bounding box coordinates from Geolocation and use to estimate landing pad locations.
3- Returns an array of classes, each containing the x coordinate, y coordinate, and spherical
3+ Returns an array of classes, each containing the x coordinate, y coordinate, and spherical
44covariance of each landing pad estimation.
55"""
66
1414from ..common .modules .logger import logger
1515
1616
17+ # pylint: disable=too-many-instance-attributes
1718class ClusterEstimation :
1819 """
1920 Estimate landing pad locations based on landing pad ground detection. Estimation
@@ -63,6 +64,7 @@ def create(
6364 max_num_components : int ,
6465 random_state : int ,
6566 local_logger : logger .Logger ,
67+ min_points_per_cluster : int ,
6668 ) -> "tuple[bool, ClusterEstimation | None]" :
6769 """
6870 Data requirement conditions for estimation model to run.
@@ -83,6 +85,9 @@ def create(
8385 local_logger: logger.Logger
8486 The local logger to log this object's information.
8587
88+ min_points_per_cluster: int
89+ Minimum number of points that must be assigned to a cluster for it to be considered valid.
90+
8691 RETURNS: The ClusterEstimation object if all conditions pass, otherwise False, None
8792 """
8893 if min_activation_threshold < max_num_components :
@@ -97,13 +102,17 @@ def create(
97102 if random_state < 0 :
98103 return False , None
99104
105+ if min_points_per_cluster < 1 :
106+ return False , None
107+
100108 return True , ClusterEstimation (
101109 cls .__create_key ,
102110 min_activation_threshold ,
103111 min_new_points_to_run ,
104112 max_num_components ,
105113 random_state ,
106114 local_logger ,
115+ min_points_per_cluster ,
107116 )
108117
109118 def __init__ (
@@ -114,6 +123,7 @@ def __init__(
114123 max_num_components : int ,
115124 random_state : int ,
116125 local_logger : logger .Logger ,
126+ min_points_per_cluster : int ,
117127 ) -> None :
118128 """
119129 Private constructor, use create() method.
@@ -140,6 +150,7 @@ def __init__(
140150 self .__min_new_points_to_run = min_new_points_to_run
141151 self .__has_ran_once = False
142152 self .__logger = local_logger
153+ self .__min_points_per_cluster = min_points_per_cluster
143154
144155 def run (
145156 self , detections : "list[detection_in_world.DetectionInWorld]" , run_override : bool
@@ -337,15 +348,16 @@ def __filter_by_points_ownership(
337348 # List of each point's cluster index
338349 cluster_assignment = self .__vgmm .predict (self .__all_points ) # type: ignore
339350
340- # Find which cluster indices have points
341- clusters_with_points = np .unique (cluster_assignment )
351+ # Check each cluster has enough points associated to it by index
352+ unique , counts = np .unique (cluster_assignment , return_counts = True )
353+ cluster_counts = dict (zip (unique , counts ))
342354
343355 # Remove empty clusters
344356 filtered_output : "list[tuple[np.ndarray, float, float]]" = []
345357 # By cluster index
346358 # pylint: disable-next=consider-using-enumerate
347359 for i in range (len (model_output )):
348- if i in clusters_with_points :
360+ if cluster_counts . get ( i , 0 ) >= self . __min_points_per_cluster :
349361 filtered_output .append (model_output [i ])
350362
351363 return filtered_output
0 commit comments