|
24 | 24 | # Helper functions |
25 | 25 | ######################################################################################################################## |
26 | 26 |
|
| 27 | +def vocs_data_to_arr(data: list | np.ndarray) -> np.ndarray: |
| 28 | + """Force data coming from VOCS object into 2D numpy array (or None) for compatibility with helper functions""" |
| 29 | + if isinstance(data, list): |
| 30 | + data = np.ndarray(list) |
| 31 | + if len(data.shape) == 0: |
| 32 | + return None |
| 33 | + if len(data.shape) == 1: |
| 34 | + return data[:, None] |
| 35 | + if len(data.shape) == 2: |
| 36 | + return data |
| 37 | + raise ValueError(f"Unrecognized shape from VOCS data: {data.shape}") |
| 38 | + |
27 | 39 |
|
28 | 40 | def get_crowding_distance(pop_f: np.ndarray) -> np.ndarray: |
29 | 41 | """ |
@@ -368,7 +380,7 @@ def _generate(self, n_candidates: int) -> list[dict]: |
368 | 380 | candidates = [] |
369 | 381 | pop_x = self.vocs.variable_data(self.pop).to_numpy() |
370 | 382 | pop_f = self.vocs.objective_data(self.pop).to_numpy() |
371 | | - pop_g = self.vocs.constraint_data(self.pop).to_numpy() if self.vocs.constraint_names else None |
| 383 | + pop_g = vocs_data_to_arr(self.vocs.constraint_data(self.pop).to_numpy()) |
372 | 384 | fitness = get_fitness(pop_f, pop_g) |
373 | 385 | for _ in range(n_candidates): |
374 | 386 | candidates.append( |
@@ -440,7 +452,7 @@ def add_data(self, new_data: pd.DataFrame): |
440 | 452 | idx = cull_population( |
441 | 453 | self.vocs.variable_data(self.pop).to_numpy(), |
442 | 454 | self.vocs.objective_data(self.pop).to_numpy(), |
443 | | - self.vocs.constraint_data(self.pop).to_numpy() if self.vocs.constraint_names else None, |
| 455 | + vocs_data_to_arr(self.vocs.constraint_data(self.pop).to_numpy()), |
444 | 456 | self.population_size, |
445 | 457 | ) |
446 | 458 | self.pop = [self.pop[i] for i in idx] |
|
0 commit comments