diff --git a/config.yaml b/config.yaml index ce511130..2fcedd66 100644 --- a/config.yaml +++ b/config.yaml @@ -53,6 +53,7 @@ geolocation: cluster_estimation: min_activation_threshold: 25 min_new_points_to_run: 5 + max_num_components: 10 random_state: 0 communications: diff --git a/main_2024.py b/main_2024.py index 7ce84c27..f927f220 100644 --- a/main_2024.py +++ b/main_2024.py @@ -137,6 +137,7 @@ def main() -> int: MIN_ACTIVATION_THRESHOLD = config["cluster_estimation"]["min_activation_threshold"] MIN_NEW_POINTS_TO_RUN = config["cluster_estimation"]["min_new_points_to_run"] + MAX_NUM_COMPONENTS = config["cluster_estimation"]["max_num_components"] RANDOM_STATE = config["cluster_estimation"]["random_state"] COMMUNICATIONS_TIMEOUT = config["communications"]["timeout"] @@ -327,7 +328,12 @@ def main() -> int: result, cluster_estimation_worker_properties = worker_manager.WorkerProperties.create( count=1, target=cluster_estimation_worker.cluster_estimation_worker, - work_arguments=(MIN_ACTIVATION_THRESHOLD, MIN_NEW_POINTS_TO_RUN, RANDOM_STATE), + work_arguments=( + MIN_ACTIVATION_THRESHOLD, + MIN_NEW_POINTS_TO_RUN, + MAX_NUM_COMPONENTS, + RANDOM_STATE, + ), input_queues=[geolocation_to_cluster_estimation_queue], output_queues=[cluster_estimation_to_communications_queue], controller=controller, diff --git a/modules/cluster_estimation/cluster_estimation.py b/modules/cluster_estimation/cluster_estimation.py index ae7ff0b3..ba5e80fb 100644 --- a/modules/cluster_estimation/cluster_estimation.py +++ b/modules/cluster_estimation/cluster_estimation.py @@ -28,6 +28,9 @@ class ClusterEstimation: min_new_points_to_run: int Minimum number of new data points that must be collected before running model. + max_num_components: int + Max number of real landing pads. + random_state: int Seed for randomizer, to get consistent results. 
@@ -62,9 +65,6 @@ class ClusterEstimation: __MEAN_PRECISION_PRIOR = 1e-6 __MAX_MODEL_ITERATIONS = 1000 - # Real-world scenario Hyperparameters - __MAX_NUM_COMPONENTS = 10 # assumed maximum number of real landing pads - # Hyperparameters to clean up model outputs __WEIGHT_DROP_THRESHOLD = 0.1 __MAX_COVARIANCE_THRESHOLD = 10 @@ -74,6 +74,7 @@ def create( cls, min_activation_threshold: int, min_new_points_to_run: int, + max_num_components: int, random_state: int, local_logger: logger.Logger, ) -> "tuple[bool, ClusterEstimation | None]": @@ -88,10 +89,15 @@ def create( if min_activation_threshold < 1: return False, None + # This must be greater than 0 + if max_num_components < 1: + return False, None + return True, ClusterEstimation( cls.__create_key, min_activation_threshold, min_new_points_to_run, + max_num_components, random_state, local_logger, ) @@ -101,6 +107,7 @@ def __init__( class_private_create_key: object, min_activation_threshold: int, min_new_points_to_run: int, + max_num_components: int, random_state: int, local_logger: logger.Logger, ) -> None: @@ -112,7 +119,7 @@ def __init__( # Initializes VGMM self.__vgmm = sklearn.mixture.BayesianGaussianMixture( covariance_type=self.__COVAR_TYPE, - n_components=self.__MAX_NUM_COMPONENTS, + n_components=max_num_components, init_params=self.__MODEL_INIT_PARAM, weight_concentration_prior=self.__WEIGHT_CONCENTRATION_PRIOR, mean_precision_prior=self.__MEAN_PRECISION_PRIOR, diff --git a/modules/cluster_estimation/cluster_estimation_worker.py b/modules/cluster_estimation/cluster_estimation_worker.py index f10c8313..17c58765 100644 --- a/modules/cluster_estimation/cluster_estimation_worker.py +++ b/modules/cluster_estimation/cluster_estimation_worker.py @@ -14,6 +14,7 @@ def cluster_estimation_worker( min_activation_threshold: int, min_new_points_to_run: int, + max_num_components: int, random_state: int, input_queue: queue_proxy_wrapper.QueueProxyWrapper, output_queue: queue_proxy_wrapper.QueueProxyWrapper, @@ -30,6 +31,9 
@@ def cluster_estimation_worker( min_new_points_to_run: int Minimum number of new data points that must be collected before running model. + max_num_components: int + Max number of real landing pads. + random_state: int Seed for randomizer, to get consistent results. @@ -56,6 +60,7 @@ def cluster_estimation_worker( result, estimator = cluster_estimation.ClusterEstimation.create( min_activation_threshold, min_new_points_to_run, + max_num_components, random_state, local_logger, ) diff --git a/tests/unit/test_cluster_detection.py b/tests/unit/test_cluster_detection.py index 6f6da0f7..155cca06 100644 --- a/tests/unit/test_cluster_detection.py +++ b/tests/unit/test_cluster_detection.py @@ -13,10 +13,10 @@ MIN_TOTAL_POINTS_THRESHOLD = 100 MIN_NEW_POINTS_TO_RUN = 10 +MAX_NUM_COMPONENTS = 10 RNG_SEED = 0 CENTRE_BOX_SIZE = 500 - # Test functions use test fixture signature names and access class privates # No enable # pylint: disable=protected-access,redefined-outer-name @@ -34,6 +34,7 @@ def cluster_model() -> cluster_estimation.ClusterEstimation: # type: ignore result, model = cluster_estimation.ClusterEstimation.create( MIN_TOTAL_POINTS_THRESHOLD, MIN_NEW_POINTS_TO_RUN, + MAX_NUM_COMPONENTS, RNG_SEED, test_logger, )