Skip to content

Commit 9af0480

Browse files
authored
Merge pull request #310 from electronsandstuff/pierce/crowded-comparison-argsort-issues
NSGA2 Generator NaN Handling in `crowded_comparison_argsort` Bugfix
2 parents 75d37ef + 647d7cd commit 9af0480

File tree

3 files changed

+160
-48
lines changed

3 files changed

+160
-48
lines changed

xopt/generators/ga/nsga2.py

Lines changed: 77 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
from pydantic import Field, Discriminator
3-
from typing import Dict, List, Optional, Annotated, Union
3+
from typing import Annotated
44
import pandas as pd
55
import os
66
from datetime import datetime
@@ -25,6 +25,19 @@
2525
########################################################################################################################
2626

2727

28+
def vocs_data_to_arr(data: list | np.ndarray) -> np.ndarray:
29+
"""Force data coming from VOCS object into 2D numpy array (or None) for compatibility with helper functions"""
30+
if isinstance(data, list):
31+
data = np.ndarray(list)
32+
if data.size == 0:
33+
return None
34+
if len(data.shape) == 1:
35+
return data[:, None]
36+
if len(data.shape) == 2:
37+
return data
38+
raise ValueError(f"Unrecognized shape from VOCS data: {data.shape}")
39+
40+
2841
def get_crowding_distance(pop_f: np.ndarray) -> np.ndarray:
2942
"""
3043
Calculates NSGA-II style crowding distance as described in [1].
@@ -66,16 +79,20 @@ def get_crowding_distance(pop_f: np.ndarray) -> np.ndarray:
6679

6780

6881
def crowded_comparison_argsort(
69-
pop_f: np.ndarray, pop_g: Optional[np.ndarray] = None
82+
pop_f: np.ndarray, pop_g: np.ndarray | None = None
7083
) -> np.ndarray:
7184
"""
7285
Sorts the objective functions by domination rank and then by crowding distance (crowded comparison operator).
73-
Indices to individuals are returned in order of increasing value by crowded comparison operator.
86+
Indices to individuals are returned in order of increasing value of fitness by crowded comparison operator.
87+
That is, the least fit individuals are returned first.
88+
89+
Notes: NaN values are removed from the comparison and added back at the beginning (least fit direction) of
90+
the sorted indices.
7491
7592
Parameters
7693
----------
7794
pop_f : np.ndarray
78-
(M, N) numpy array where N is the number of individuals and M is the number of objectives
95+
(N, M) numpy array where N is the number of individuals and M is the number of objectives
7996
pop_g : np.ndarray, optional
8097
The constraints, by default None
8198
@@ -84,51 +101,69 @@ def crowded_comparison_argsort(
84101
np.ndarray
85102
Numpy array of indices to sorted individuals
86103
"""
87-
# Deal with NaNs
88-
pop_f = np.copy(pop_f)
89-
pop_f[~np.isfinite(pop_g)] = 1e300
104+
# Check for non-finite values in both pop_f and pop_g
105+
has_nan = np.any(~np.isfinite(pop_f), axis=1)
90106
if pop_g is not None:
91-
pop_g = np.copy(pop_g)
92-
pop_g[~np.isfinite(pop_g)] = 1e300
107+
has_nan = has_nan | np.any(~np.isfinite(pop_g), axis=1)
108+
nan_indices = np.where(has_nan)[0]
109+
finite_indices = np.where(~has_nan)[0]
93110

94-
ranks = fast_dominated_argsort(pop_f, pop_g)
95-
inds = []
111+
# If all values are non-finite, return the original indices
112+
if len(finite_indices) == 0:
113+
return np.arange(pop_f.shape[0])
114+
115+
# Extract only finite values for processing
116+
pop_f_finite = pop_f[finite_indices, :]
117+
118+
# Handle constraints if provided
119+
pop_g_finite = None
120+
if pop_g is not None:
121+
pop_g_finite = pop_g[finite_indices, :]
122+
123+
# Apply domination ranking
124+
ranks = fast_dominated_argsort(pop_f_finite, pop_g_finite)
125+
126+
# Calculate crowding distance and sort within each rank
127+
sorted_finite_indices = []
96128
for rank in ranks:
97-
dist = get_crowding_distance(pop_f[rank, :])
98-
inds.extend(np.array(rank)[np.argsort(dist)[::-1]])
129+
dist = get_crowding_distance(pop_f_finite[rank, :])
130+
sorted_rank = np.array(rank)[np.argsort(dist)[::-1]]
131+
sorted_finite_indices.extend(sorted_rank)
132+
133+
# Map back to original indices and put nans at end
134+
sorted_original_indices = finite_indices[sorted_finite_indices]
135+
final_sorted_indices = np.concatenate([sorted_original_indices, nan_indices])
99136

100-
return np.array(inds)[::-1]
137+
return final_sorted_indices[::-1]
101138

102139

103-
def get_fitness(pop_f: np.ndarray, pop_g: np.ndarray) -> np.ndarray:
140+
def get_fitness(pop_f: np.ndarray, pop_g: np.ndarray | None = None) -> np.ndarray:
104141
"""
105142
Get the "fitness" of each individual according to domination and crowding distance.
106143
107144
Parameters
108145
----------
109146
pop_f : np.ndarray
110147
The objectives
111-
pop_g : np.ndarray
112-
The constraints
148+
pop_g : np.ndarray / None
149+
The constraints, or None of no constraints
113150
114151
Returns
115152
-------
116153
np.ndarray
117154
The fitness of each individual
118155
"""
119-
sort_ind = crowded_comparison_argsort(pop_f, pop_g)
120-
fitness = np.argsort(sort_ind)
121-
return fitness
156+
return np.argsort(crowded_comparison_argsort(pop_f, pop_g))
122157

123158

124159
def generate_child_binary_tournament(
125160
pop_x: np.ndarray,
126161
pop_f: np.ndarray,
127-
pop_g: np.ndarray,
162+
pop_g: np.ndarray | None,
128163
bounds: np.ndarray,
129164
mutate: MutationOperator,
130165
crossover: CrossoverOperator,
131-
fitness: Optional[np.ndarray] = None,
166+
fitness: np.ndarray | None = None,
132167
) -> np.ndarray:
133168
"""
134169
Creates a single child from the population using binary tournament selection, crossover, and mutation.
@@ -143,8 +178,9 @@ def generate_child_binary_tournament(
143178
Decision variables of the population, shape (n_individuals, n_variables).
144179
pop_f : numpy.ndarray
145180
Objective function values of the population, shape (n_individuals, n_objectives).
146-
pop_g : numpy.ndarray
181+
pop_g : numpy.ndarray / None
147182
Constraint violation values of the population, shape (n_individuals, n_constraints).
183+
None if no constraints.
148184
bounds : numpy.ndarray
149185
Bounds for decision variables, shape (2, n_variables) where bounds[0] are lower bounds
150186
and bounds[1] are upper bounds.
@@ -186,7 +222,7 @@ def generate_child_binary_tournament(
186222

187223

188224
def cull_population(
189-
pop_x: np.ndarray, pop_f: np.ndarray, pop_g: np.ndarray, population_size: int
225+
pop_x: np.ndarray, pop_f: np.ndarray, pop_g: np.ndarray | None, population_size: int
190226
) -> np.ndarray:
191227
"""
192228
Reduce population size by selecting the best individuals based on crowded comparison.
@@ -198,8 +234,8 @@ def cull_population(
198234
----------
199235
pop_x : numpy.ndarray
200236
Decision variables of the population, shape (n_individuals, n_variables).
201-
pop_f : numpy.ndarray
202-
Objective function values of the population, shape (n_individuals, n_objectives).
237+
pop_f : numpy.ndarray / None
238+
Objective function values of the population, shape (n_individuals, n_objectives), None if no constraints.
203239
pop_g : numpy.ndarray
204240
Constraint violation values of the population, shape (n_individuals, n_constraints).
205241
population_size : int
@@ -210,9 +246,7 @@ def cull_population(
210246
numpy.ndarray
211247
Indices of selected individuals, shape (population_size,).
212248
"""
213-
inds = crowded_comparison_argsort(pop_f, pop_g)[::-1]
214-
inds = inds[:population_size]
215-
return inds
249+
return crowded_comparison_argsort(pop_f, pop_g)[-population_size:]
216250

217251

218252
########################################################################################################################
@@ -278,20 +312,20 @@ class NSGA2Generator(DeduplicatedGeneratorBase, StateOwner):
278312

279313
population_size: int = Field(50, description="Population size")
280314
crossover_operator: Annotated[
281-
Union[
282-
SimulatedBinaryCrossover, DummyCrossover
283-
], # Dummy placeholder to keep discriminator code from failing
315+
(
316+
SimulatedBinaryCrossover | DummyCrossover
317+
), # Dummy placeholder to keep discriminator code from failing
284318
Discriminator("name"),
285319
] = SimulatedBinaryCrossover()
286320
mutation_operator: Annotated[
287-
Union[
288-
PolynomialMutation, DummyMutation
289-
], # Dummy placeholder to keep discriminator code from failing
321+
(
322+
PolynomialMutation | DummyMutation
323+
), # Dummy placeholder to keep discriminator code from failing
290324
Discriminator("name"),
291325
] = PolynomialMutation()
292326

293327
# Output options
294-
output_dir: Optional[str] = None
328+
output_dir: str | None = None
295329
checkpoint_freq: int = Field(
296330
-1,
297331
description="How often (in generations) to save checkpoints (set to -1 to disable)",
@@ -302,7 +336,7 @@ class NSGA2Generator(DeduplicatedGeneratorBase, StateOwner):
302336
_output_dir_setup: bool = (
303337
False # Used in initializing the directory. PLEASE DO NOT CHANGE
304338
)
305-
_logger: Optional[logging.Logger] = None
339+
_logger: logging.Logger | None = None
306340

307341
# Metadata
308342
fevals: int = Field(
@@ -315,7 +349,7 @@ class NSGA2Generator(DeduplicatedGeneratorBase, StateOwner):
315349
n_candidates: int = Field(
316350
0, description="The number of candidate solutions generated so far"
317351
)
318-
history_idx: List[List[int]] = Field(
352+
history_idx: list[list[int]] = Field(
319353
default=[],
320354
description="Xopt indices of the individuals in each population",
321355
)
@@ -326,15 +360,15 @@ class NSGA2Generator(DeduplicatedGeneratorBase, StateOwner):
326360
)
327361

328362
# The population and returned children
329-
pop: List[Dict] = Field(default=[])
330-
child: List[Dict] = Field(default=[])
363+
pop: list[dict] = Field(default=[])
364+
child: list[dict] = Field(default=[])
331365

332366
def model_post_init(self, context):
333367
# Get a unique logger per object
334368
self._logger = logging.getLogger(f"{__name__}.NSGA2Generator.{id(self)}")
335369
self._logger.setLevel(self.log_level)
336370

337-
def _generate(self, n_candidates: int) -> List[Dict]:
371+
def _generate(self, n_candidates: int) -> list[dict]:
338372
self.ensure_output_dir_setup()
339373
start_t = time.perf_counter()
340374

@@ -347,7 +381,7 @@ def _generate(self, n_candidates: int) -> List[Dict]:
347381
candidates = []
348382
pop_x = self.vocs.variable_data(self.pop).to_numpy()
349383
pop_f = self.vocs.objective_data(self.pop).to_numpy()
350-
pop_g = self.vocs.constraint_data(self.pop).to_numpy()
384+
pop_g = vocs_data_to_arr(self.vocs.constraint_data(self.pop).to_numpy())
351385
fitness = get_fitness(pop_f, pop_g)
352386
for _ in range(n_candidates):
353387
candidates.append(
@@ -419,7 +453,7 @@ def add_data(self, new_data: pd.DataFrame):
419453
idx = cull_population(
420454
self.vocs.variable_data(self.pop).to_numpy(),
421455
self.vocs.objective_data(self.pop).to_numpy(),
422-
self.vocs.constraint_data(self.pop).to_numpy(),
456+
vocs_data_to_arr(self.vocs.constraint_data(self.pop).to_numpy()),
423457
self.population_size,
424458
)
425459
self.pop = [self.pop[i] for i in idx]

xopt/generators/utils.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
# from xopt.generator import Generator
22
import numpy as np
3-
from typing import List, Optional
43

54

6-
def get_domination(pop_f: np.ndarray, pop_g: Optional[np.ndarray] = None) -> np.ndarray:
5+
def get_domination(pop_f: np.ndarray, pop_g: np.ndarray | None = None) -> np.ndarray:
76
"""
87
Compute domination matrix for a population based on objective values and constraints. Determines domination
98
relationships between all pairs of individuals in a population.
@@ -46,7 +45,7 @@ def get_domination(pop_f: np.ndarray, pop_g: Optional[np.ndarray] = None) -> np.
4645
return dom
4746

4847

49-
def fast_dominated_argsort_internal(dom: np.ndarray) -> List[np.ndarray]:
48+
def fast_dominated_argsort_internal(dom: np.ndarray) -> list[np.ndarray]:
5049
"""
5150
Used inside of `fast_dominated_argsort`. Call that function instead.
5251
@@ -83,8 +82,8 @@ def fast_dominated_argsort_internal(dom: np.ndarray) -> List[np.ndarray]:
8382

8483

8584
def fast_dominated_argsort(
86-
pop_f: np.ndarray, pop_g: Optional[np.ndarray] = None
87-
) -> List[np.ndarray]:
85+
pop_f: np.ndarray, pop_g: np.ndarray | None = None
86+
) -> list[np.ndarray]:
8887
"""
8988
Performs a dominated sort on matrix of objective function values O. This is a numpy implementation of the algorithm
9089
described in [1].

xopt/tests/generators/ga/test_nsga2.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55

66
import numpy as np
77
import pandas as pd
8+
import pytest
89

910
from xopt.base import Xopt
1011
from xopt.evaluator import Evaluator
1112
from xopt.generators.ga.nsga2 import (
1213
NSGA2Generator,
1314
generate_child_binary_tournament,
15+
crowded_comparison_argsort,
1416
)
1517
from xopt.generators.ga.operators import PolynomialMutation, SimulatedBinaryCrossover
1618
from xopt.resources.test_functions.tnk import evaluate_TNK, tnk_vocs
@@ -397,3 +399,80 @@ def compare(val_a, val_b):
397399
np.random.seed(42)
398400
samples = X.generator.generate(1)
399401
X.evaluate_data(samples)
402+
403+
404+
@pytest.mark.parametrize(
405+
"pop_f, pop_g, expected_indices_options",
406+
[
407+
# Two individuals in different ranks
408+
(np.array([[1.0, 2.0], [2.0, 3.0]]), None, [np.array([1, 0])]),
409+
# Non-dominated, different crowding distances
410+
(
411+
np.array([[1.0, 3.0], [2.0, 2.0], [3.0, 1.0]]),
412+
None,
413+
[np.array([1, 2, 0]), np.array([1, 0, 2])],
414+
),
415+
# NaN values
416+
(
417+
np.array([[1.0, 2.0], [np.nan, 3.0], [2.0, 1.0]]),
418+
None,
419+
[np.array([1, 2, 0]), np.array([1, 0, 2])],
420+
),
421+
# Constrained
422+
(
423+
np.array([[2.0, 2.0], [1.0, 1.0]]),
424+
np.array([[-1.0, -1.0], [1.0, -1.0]]),
425+
[np.array([1, 0])],
426+
),
427+
# Multiple individuals with same rank but potentially same crowding distances
428+
(
429+
np.array([[1.0, 3.0], [2.0, 2.0], [3.0, 1.0], [1.5, 2.5]]),
430+
None,
431+
[
432+
np.array([3, 1, 2, 0]),
433+
np.array([1, 3, 2, 0]),
434+
np.array([3, 1, 0, 2]),
435+
np.array([1, 3, 0, 2]),
436+
],
437+
),
438+
# NaN values and constraints
439+
(
440+
np.array([[1.0, 2.0], [np.nan, 3.0], [2.0, 1.0]]),
441+
np.array([[-1.0, -1.0, -1.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]),
442+
[np.array([1, 2, 0])],
443+
),
444+
# All NaN
445+
(
446+
np.array([[np.nan, np.nan], [np.nan, np.nan]]),
447+
None,
448+
[np.array([0, 1]), np.array([1, 0])],
449+
),
450+
],
451+
)
452+
def test_crowded_comparison_argsort(pop_f, pop_g, expected_indices_options):
453+
"""
454+
Test the crowded_comparison_argsort function with various explicit input.
455+
456+
Parameters
457+
----------
458+
pop_f : numpy.ndarray
459+
Objective values
460+
pop_g : numpy.ndarray or None
461+
Constraint values
462+
expected_indices_options : list of numpy.ndarray
463+
List of valid expected sorted indices
464+
"""
465+
# Call the function
466+
result = crowded_comparison_argsort(pop_f, pop_g)
467+
468+
# Check if the result matches any of the expected options
469+
matches_any = any(
470+
np.array_equal(result, expected) for expected in expected_indices_options
471+
)
472+
473+
if not matches_any:
474+
message = (
475+
f"Result {result} doesn't match any expected ordering.\n"
476+
f"Expected one of: {expected_indices_options}"
477+
)
478+
assert False, message

0 commit comments

Comments
 (0)