Skip to content

Commit 6b67f01

Browse files
committed
fix(tests): handle error case in clustering tests
1 parent 369cd77 commit 6b67f01

File tree

1 file changed

+78
-2
lines changed

1 file changed

+78
-2
lines changed

test/unit_tests/dimred/svd/test_clustering_betas.py

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,87 @@ def test_group_betas(self):
4747

4848
for cluster_type in ClusterType.get_cluster_type_name():
4949
for detector_type in DetectorType.get_detector_type_name():
50-
beta_clusters, name_clusters = group_betas(
51-
fake_names, fake_betas, cluster=cluster_type, detector=detector_type
50+
51+
# As some clustering algorithms require parameters, we need to provide them
52+
# to avoid errors
53+
if cluster_type == "KMeans" or cluster_type == "SpectralClustering":
54+
cluster_params = {"n_clusters": 2}
55+
else:
56+
cluster_params = {}
57+
58+
result = group_betas(
59+
fake_names,
60+
fake_betas,
61+
cluster=cluster_type,
62+
detector=detector_type,
63+
cluster_params=cluster_params,
5264
)
5365

66+
if isinstance(result, str):
67+
self.fail(f"group_betas returned an error: {result}")
68+
69+
beta_clusters, name_clusters = result
70+
5471
# verify correct output
5572
self.assertIsInstance(beta_clusters, list)
5673
self.assertIsInstance(name_clusters, list)
5774
self.assertEqual(len(beta_clusters), len(name_clusters))
75+
76+
77+
def test_group_betas_no_detector(self):
78+
"""tests correct function of the group_betas function
79+
in clustering_betas.py without an outlier detector"""
80+
81+
fake_names = np.array([f"betas_{i}" for i in range(25)])
82+
fake_cluster_0 = np.random.rand(12, 3) + 5
83+
fake_cluster_1 = np.random.rand(12, 3) - 5
84+
fake_betas = np.stack([*fake_cluster_0, *fake_cluster_1, np.array([0, 0, 0])])
85+
expected_clusters = 2
86+
87+
# test with recommended settings
88+
beta_clusters, name_clusters = group_betas(
89+
fake_names,
90+
fake_betas,
91+
cluster=ClusterType.KMeans,
92+
cluster_params={"n_clusters": expected_clusters},
93+
)
94+
95+
# verify correct type of output
96+
self.assertIsInstance(beta_clusters, list)
97+
self.assertIsInstance(name_clusters, list)
98+
99+
# verify that beta_clusters and name_clusters correspond to each other
100+
self.assertEqual(len(beta_clusters), len(name_clusters))
101+
# verify that beta_clusters contains as many clusters as searched for
102+
self.assertEqual(len(beta_clusters), expected_clusters)
103+
104+
# verify that entries correspond to each other
105+
for c, cluster in enumerate(name_clusters):
106+
for e, entry in enumerate(cluster):
107+
index = np.where(fake_names == entry)[0]
108+
self.assertTrue((fake_betas[index] - beta_clusters[c][e]).max() == 0)
109+
110+
# verify different keyword combinations
111+
112+
for cluster_type in ClusterType.get_cluster_type_name():
113+
114+
# As some clustering algorithms require parameters, we need to provide them
115+
# to avoid errors
116+
if cluster_type == "KMeans" or cluster_type == "SpectralClustering":
117+
cluster_params = {"n_clusters": 2}
118+
else:
119+
cluster_params = {}
120+
121+
result = group_betas(
122+
fake_names, fake_betas, cluster=cluster_type, cluster_params=cluster_params
123+
)
124+
125+
if isinstance(result, str):
126+
self.fail(f"group_betas returned an error: {result}")
127+
128+
beta_clusters, name_clusters = result
129+
130+
# verify correct output
131+
self.assertIsInstance(beta_clusters, list)
132+
self.assertIsInstance(name_clusters, list)
133+
self.assertEqual(len(beta_clusters), len(name_clusters))

0 commit comments

Comments
 (0)