Skip to content

Commit e025d92

Browse files
cleaner handling of VOCS data
1 parent 4333d41 commit e025d92

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

xopt/generators/ga/nsga2.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,18 @@
2424
# Helper functions
2525
########################################################################################################################
2626

27+
def vocs_data_to_arr(data: list | np.ndarray) -> np.ndarray:
28+
"""Force data coming from VOCS object into 2D numpy array (or None) for compatibility with helper functions"""
29+
if isinstance(data, list):
30+
data = np.ndarray(list)
31+
if len(data.shape) == 0:
32+
return None
33+
if len(data.shape) == 1:
34+
return data[:, None]
35+
if len(data.shape) == 2:
36+
return data
37+
raise ValueError(f"Unrecognized shape from VOCS data: {data.shape}")
38+
2739

2840
def get_crowding_distance(pop_f: np.ndarray) -> np.ndarray:
2941
"""
@@ -368,7 +380,7 @@ def _generate(self, n_candidates: int) -> list[dict]:
368380
candidates = []
369381
pop_x = self.vocs.variable_data(self.pop).to_numpy()
370382
pop_f = self.vocs.objective_data(self.pop).to_numpy()
371-
pop_g = self.vocs.constraint_data(self.pop).to_numpy() if self.vocs.constraint_names else None
383+
pop_g = vocs_data_to_arr(self.vocs.constraint_data(self.pop).to_numpy())
372384
fitness = get_fitness(pop_f, pop_g)
373385
for _ in range(n_candidates):
374386
candidates.append(
@@ -440,7 +452,7 @@ def add_data(self, new_data: pd.DataFrame):
440452
idx = cull_population(
441453
self.vocs.variable_data(self.pop).to_numpy(),
442454
self.vocs.objective_data(self.pop).to_numpy(),
443-
self.vocs.constraint_data(self.pop).to_numpy() if self.vocs.constraint_names else None,
455+
vocs_data_to_arr(self.vocs.constraint_data(self.pop).to_numpy()),
444456
self.population_size,
445457
)
446458
self.pop = [self.pop[i] for i in idx]

0 commit comments

Comments
 (0)