Skip to content

Commit e52f92b

Browse files
lint :(
1 parent e025d92 commit e52f92b

File tree

2 files changed

+21
-22
lines changed

2 files changed

+21
-22
lines changed

xopt/generators/ga/nsga2.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
# Helper functions
2525
########################################################################################################################
2626

27+
2728
def vocs_data_to_arr(data: list | np.ndarray) -> np.ndarray:
2829
"""Force data coming from VOCS object into 2D numpy array (or None) for compatibility with helper functions"""
2930
if isinstance(data, list):
@@ -106,33 +107,33 @@ def crowded_comparison_argsort(
106107
has_nan = has_nan | np.any(~np.isfinite(pop_g), axis=1)
107108
nan_indices = np.where(has_nan)[0]
108109
finite_indices = np.where(~has_nan)[0]
109-
110+
110111
# If all values are non-finite, return the original indices
111112
if len(finite_indices) == 0:
112113
return np.arange(pop_f.shape[0])
113-
114+
114115
# Extract only finite values for processing
115116
pop_f_finite = pop_f[finite_indices, :]
116-
117+
117118
# Handle constraints if provided
118119
pop_g_finite = None
119120
if pop_g is not None:
120121
pop_g_finite = pop_g[finite_indices, :]
121-
122+
122123
# Apply domination ranking
123124
ranks = fast_dominated_argsort(pop_f_finite, pop_g_finite)
124-
125+
125126
# Calculate crowding distance and sort within each rank
126127
sorted_finite_indices = []
127128
for rank in ranks:
128129
dist = get_crowding_distance(pop_f_finite[rank, :])
129130
sorted_rank = np.array(rank)[np.argsort(dist)[::-1]]
130131
sorted_finite_indices.extend(sorted_rank)
131-
132+
132133
# Map back to original indices and put nans at end
133134
sorted_original_indices = finite_indices[sorted_finite_indices]
134135
final_sorted_indices = np.concatenate([sorted_original_indices, nan_indices])
135-
136+
136137
return final_sorted_indices[::-1]
137138

138139

xopt/tests/generators/ga/test_nsga2.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -405,28 +405,24 @@ def compare(val_a, val_b):
405405
"pop_f, pop_g, expected_indices_options",
406406
[
407407
# Two individuals in different ranks
408-
(
409-
np.array([[1.0, 2.0], [2.0, 3.0]]),
410-
None,
411-
[np.array([1, 0])]
412-
),
408+
(np.array([[1.0, 2.0], [2.0, 3.0]]), None, [np.array([1, 0])]),
413409
# Non-dominated, different crowding distances
414410
(
415411
np.array([[1.0, 3.0], [2.0, 2.0], [3.0, 1.0]]),
416412
None,
417-
[np.array([1, 2, 0]), np.array([1, 0, 2])]
413+
[np.array([1, 2, 0]), np.array([1, 0, 2])],
418414
),
419415
# NaN values
420416
(
421417
np.array([[1.0, 2.0], [np.nan, 3.0], [2.0, 1.0]]),
422418
None,
423-
[np.array([1, 2, 0]), np.array([1, 0, 2])]
419+
[np.array([1, 2, 0]), np.array([1, 0, 2])],
424420
),
425421
# Constrained
426422
(
427423
np.array([[2.0, 2.0], [1.0, 1.0]]),
428424
np.array([[-1.0, -1.0], [1.0, -1.0]]),
429-
[np.array([1, 0])]
425+
[np.array([1, 0])],
430426
),
431427
# Multiple individuals with same rank but potentially same crowding distances
432428
(
@@ -437,26 +433,26 @@ def compare(val_a, val_b):
437433
np.array([1, 3, 2, 0]),
438434
np.array([3, 1, 0, 2]),
439435
np.array([1, 3, 0, 2]),
440-
]
436+
],
441437
),
442438
# NaN values and constraints
443439
(
444440
np.array([[1.0, 2.0], [np.nan, 3.0], [2.0, 1.0]]),
445441
np.array([[-1.0, -1.0, -1.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]),
446-
[np.array([1, 2, 0])]
442+
[np.array([1, 2, 0])],
447443
),
448444
# All NaN
449445
(
450446
np.array([[np.nan, np.nan], [np.nan, np.nan]]),
451447
None,
452-
[np.array([0, 1]), np.array([1, 0])]
448+
[np.array([0, 1]), np.array([1, 0])],
453449
),
454450
],
455451
)
456452
def test_crowded_comparison_argsort(pop_f, pop_g, expected_indices_options):
457453
"""
458454
Test the crowded_comparison_argsort function with various explicit input.
459-
455+
460456
Parameters
461457
----------
462458
pop_f : numpy.ndarray
@@ -468,10 +464,12 @@ def test_crowded_comparison_argsort(pop_f, pop_g, expected_indices_options):
468464
"""
469465
# Call the function
470466
result = crowded_comparison_argsort(pop_f, pop_g)
471-
467+
472468
# Check if the result matches any of the expected options
473-
matches_any = any(np.array_equal(result, expected) for expected in expected_indices_options)
474-
469+
matches_any = any(
470+
np.array_equal(result, expected) for expected in expected_indices_options
471+
)
472+
475473
if not matches_any:
476474
message = (
477475
f"Result {result} doesn't match any expected ordering.\n"

0 commit comments

Comments
 (0)