Skip to content

Commit 6b46322

Browse files
committedJan 19, 2025
TST: increase test coverage (#756)
* TST: add flight fixture for solid propulsion equations of motion * DEV: add pragma comments to exclude specific lines from coverage * TST: adds more unit tests to the codebase MNT: linters TST: complementing tests for sensitivity analysis and removing duplicate piece of code. DEV: add pragma comments to exclude specific lines from coverage MNT: fix pylint error * TST: add fixture for solid propulsion equations of motion in flight tests * TST: fix tests not passing
1 parent 48efc87 commit 6b46322

27 files changed

+544
-119
lines changed
 

‎.vscode/settings.json

+1
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@
262262
"rtol",
263263
"rtype",
264264
"rucsoundings",
265+
"runslow",
265266
"rwork",
266267
"savetxt",
267268
"savgol",

‎rocketpy/environment/environment.py

+11-12
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,7 @@ def __initialize_utm_coordinates(self):
459459
flattening=self.ellipsoid.flattening,
460460
semi_major_axis=self.ellipsoid.semi_major_axis,
461461
)
462-
else:
463-
# pragma: no cover
462+
else: # pragma: no cover
464463
warnings.warn(
465464
"UTM coordinates are not available for latitudes "
466465
"above 84 or below -80 degrees. The UTM conversions will fail."
@@ -715,8 +714,8 @@ def set_location(self, latitude, longitude):
715714

716715
if not isinstance(latitude, NUMERICAL_TYPES) and isinstance(
717716
longitude, NUMERICAL_TYPES
718-
):
719-
# pragma: no cover
717+
): # pragma: no cover
718+
720719
raise TypeError("Latitude and Longitude must be numbers!")
721720

722721
# Store latitude and longitude
@@ -812,8 +811,8 @@ def max_expected_height(self):
812811

813812
@max_expected_height.setter
814813
def max_expected_height(self, value):
815-
if value < self.elevation:
816-
raise ValueError( # pragma: no cover
814+
if value < self.elevation: # pragma: no cover
815+
raise ValueError(
817816
"Max expected height cannot be lower than the surface elevation"
818817
)
819818
self._max_expected_height = value
@@ -952,8 +951,8 @@ def get_elevation_from_topographic_profile(self, lat, lon):
952951
Elevation provided by the topographic data, in meters.
953952
"""
954953
# TODO: refactor this method. pylint: disable=too-many-statements
955-
if self.topographic_profile_activated is False:
956-
raise ValueError( # pragma: no cover
954+
if self.topographic_profile_activated is False: # pragma: no cover
955+
raise ValueError(
957956
"You must define a Topographic profile first, please use the "
958957
"Environment.set_topographic_profile() method first."
959958
)
@@ -1285,8 +1284,8 @@ def set_atmospheric_model( # pylint: disable=too-many-statements
12851284
self.process_forecast_reanalysis(dataset, dictionary)
12861285
else:
12871286
self.process_ensemble(dataset, dictionary)
1288-
else:
1289-
raise ValueError(f"Unknown model type '{type}'.") # pragma: no cover
1287+
else: # pragma: no cover
1288+
raise ValueError(f"Unknown model type '{type}'.")
12901289

12911290
if type not in ["ensemble"]:
12921291
# Ensemble already computed these values
@@ -2578,7 +2577,7 @@ def set_earth_geometry(self, datum):
25782577
}
25792578
try:
25802579
return ellipsoid[datum]
2581-
except KeyError as e:
2580+
except KeyError as e: # pragma: no cover
25822581
available_datums = ', '.join(ellipsoid.keys())
25832582
raise AttributeError(
25842583
f"The reference system '{datum}' is not recognized. Please use one of "
@@ -2845,7 +2844,7 @@ def from_dict(cls, data): # pylint: disable=too-many-statements
28452844
return env
28462845

28472846

2848-
if __name__ == "__main__":
2847+
if __name__ == "__main__": # pragma: no cover
28492848
import doctest
28502849

28512850
results = doctest.testmod()

‎rocketpy/environment/fetchers.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,12 @@ def fetch_atmospheric_data_from_windy(lat, lon, model):
7979

8080
try:
8181
response = requests.get(url).json()
82-
if "data" not in response.keys():
82+
if "data" not in response.keys(): # pragma: no cover
8383
raise ValueError(
8484
f"Could not get a valid response for '{model}' from Windy. "
8585
"Check if the coordinates are set inside the model's domain."
8686
)
87-
except requests.exceptions.RequestException as e:
87+
except requests.exceptions.RequestException as e: # pragma: no cover
8888
if model == "iconEu":
8989
raise ValueError(
9090
"Could not get a valid response for Icon-EU from Windy. "
@@ -315,8 +315,8 @@ def fetch_wyoming_sounding(file):
315315
If the response indicates the output format is invalid.
316316
"""
317317
response = requests.get(file)
318-
if response.status_code != 200:
319-
raise ImportError(f"Unable to load {file}.") # pragma: no cover
318+
if response.status_code != 200: # pragma: no cover
319+
raise ImportError(f"Unable to load {file}.")
320320
if len(re.findall("Can't get .+ Observations at", response.text)):
321321
raise ValueError(
322322
re.findall("Can't get .+ Observations at .+", response.text)[0]
@@ -330,7 +330,7 @@ def fetch_wyoming_sounding(file):
330330

331331

332332
@exponential_backoff(max_attempts=5, base_delay=2, max_delay=60)
333-
def fetch_noaaruc_sounding(file):
333+
def fetch_noaaruc_sounding(file): # pragma: no cover
334334
"""Fetches sounding data from a specified file using the NOAA RUC soundings.
335335
336336
Parameters

‎rocketpy/environment/tools.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,7 @@ def utm_to_geodesic( # pylint: disable=too-many-locals,too-many-statements
590590
return lat, lon
591591

592592

593-
if __name__ == "__main__":
593+
if __name__ == "__main__": # pragma: no cover
594594
import doctest
595595

596596
results = doctest.testmod()

‎rocketpy/mathutils/function.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -1481,7 +1481,7 @@ def plot(self, *args, **kwargs):
14811481
else:
14821482
print("Error: Only functions with 1D or 2D domains can be plotted.")
14831483

1484-
def plot1D(self, *args, **kwargs):
1484+
def plot1D(self, *args, **kwargs): # pragma: no cover
14851485
"""Deprecated method, use Function.plot_1d instead."""
14861486
warnings.warn(
14871487
"The `Function.plot1D` method is set to be deprecated and fully "
@@ -1581,7 +1581,7 @@ def plot_1d( # pylint: disable=too-many-statements
15811581
if return_object:
15821582
return fig, ax
15831583

1584-
def plot2D(self, *args, **kwargs):
1584+
def plot2D(self, *args, **kwargs): # pragma: no cover
15851585
"""Deprecated method, use Function.plot_2d instead."""
15861586
warnings.warn(
15871587
"The `Function.plot2D` method is set to be deprecated and fully "
@@ -2772,7 +2772,7 @@ def differentiate_complex_step(self, x, dx=1e-200, order=1):
27722772
"""
27732773
if order == 1:
27742774
return float(self.get_value_opt(x + dx * 1j).imag / dx)
2775-
else:
2775+
else: # pragma: no cover
27762776
raise NotImplementedError(
27772777
"Only 1st order derivatives are supported yet. Set order=1."
27782778
)
@@ -3119,12 +3119,12 @@ def compose(self, func, extrapolate=False):
31193119
The result of inputting the function into the function.
31203120
"""
31213121
# Check if the input is a function
3122-
if not isinstance(func, Function):
3122+
if not isinstance(func, Function): # pragma: no cover
31233123
raise TypeError("Input must be a Function object.")
31243124

31253125
if isinstance(self.source, np.ndarray) and isinstance(func.source, np.ndarray):
31263126
# Perform bounds check for composition
3127-
if not extrapolate:
3127+
if not extrapolate: # pragma: no cover
31283128
if func.min < self.x_initial or func.max > self.x_final:
31293129
raise ValueError(
31303130
f"Input Function image {func.min, func.max} must be within "
@@ -3197,7 +3197,7 @@ def savetxt(
31973197

31983198
# create the datapoints
31993199
if callable(self.source):
3200-
if lower is None or upper is None or samples is None:
3200+
if lower is None or upper is None or samples is None: # pragma: no cover
32013201
raise ValueError(
32023202
"If the source is a callable, lower, upper and samples"
32033203
+ " must be provided."
@@ -3264,7 +3264,7 @@ def __validate_source(self, source): # pylint: disable=too-many-statements
32643264
self.__inputs__ = header[:-1]
32653265
if self.__outputs__ is None:
32663266
self.__outputs__ = [header[-1]]
3267-
except Exception as e:
3267+
except Exception as e: # pragma: no cover
32683268
raise ValueError(
32693269
"Could not read the csv or txt file to create Function source."
32703270
) from e
@@ -3323,6 +3323,7 @@ def __validate_inputs(self, inputs):
33233323
if isinstance(inputs, (list, tuple)):
33243324
if len(inputs) == 1:
33253325
return inputs
3326+
# pragma: no cover
33263327
raise ValueError(
33273328
"Inputs must be a string or a list of strings with "
33283329
"the length of the domain dimension."
@@ -3335,6 +3336,7 @@ def __validate_inputs(self, inputs):
33353336
isinstance(i, str) for i in inputs
33363337
):
33373338
return inputs
3339+
# pragma: no cover
33383340
raise ValueError(
33393341
"Inputs must be a list of strings with "
33403342
"the length of the domain dimension."
@@ -3611,7 +3613,7 @@ def reset_funcified_methods(instance):
36113613
instance.__dict__.pop(key)
36123614

36133615

3614-
if __name__ == "__main__":
3616+
if __name__ == "__main__": # pragma: no cover
36153617
import doctest
36163618

36173619
results = doctest.testmod()

‎rocketpy/mathutils/vector_matrix.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1103,7 +1103,7 @@ def from_dict(cls, data):
11031103
return cls(data)
11041104

11051105

1106-
if __name__ == "__main__":
1106+
if __name__ == "__main__": # pragma: no cover
11071107
import doctest
11081108

11091109
results = doctest.testmod()

‎rocketpy/motors/fluid.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ def __post_init__(self):
3030
If the density is not a positive number.
3131
"""
3232

33-
if not isinstance(self.name, str):
33+
if not isinstance(self.name, str): # pragma: no cover
3434
raise ValueError("The name must be a string.")
35-
if self.density < 0:
35+
if self.density < 0: # pragma: no cover
3636
raise ValueError("The density must be a positive number.")
3737

3838
# Initialize plots and prints object

‎rocketpy/motors/motor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ class Function. Thrust units are Newtons.
247247
self._csys = 1
248248
elif coordinate_system_orientation == "combustion_chamber_to_nozzle":
249249
self._csys = -1
250-
else:
250+
else: # pragma: no cover
251251
raise ValueError(
252252
"Invalid coordinate system orientation. Options are "
253253
"'nozzle_to_combustion_chamber' and 'combustion_chamber_to_nozzle'."
@@ -346,7 +346,7 @@ def burn_time(self, burn_time):
346346
else:
347347
if not callable(self.thrust.source):
348348
self._burn_time = (self.thrust.x_array[0], self.thrust.x_array[-1])
349-
else:
349+
else: # pragma: no cover
350350
raise ValueError(
351351
"When using a float or callable as thrust source, a burn_time"
352352
" argument must be specified."

‎rocketpy/plots/rocket_plots.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ def _draw_generic_surface(
385385
x_pos = position[2]
386386
# y position of the surface is the y position in the plot
387387
y_pos = position[1]
388-
else:
388+
else: # pragma: no cover
389389
raise ValueError("Plane must be 'xz' or 'yz'.")
390390

391391
ax.scatter(
@@ -633,7 +633,7 @@ def _draw_sensors(self, ax, sensors, plane):
633633
# y position of the sensor is the y position in the plot
634634
y_pos = pos[1]
635635
normal_y = sensor.normal_vector.y
636-
else:
636+
else: # pragma: no cover
637637
raise ValueError("Plane must be 'xz' or 'yz'.")
638638

639639
# line length is 2/5 of the rocket radius

‎rocketpy/prints/compare_prints.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
class _ComparePrints:
1+
class _ComparePrints: # pragma: no cover
22
def __init__(self) -> None:
33
pass

‎rocketpy/rocket/aero_surface/nose_cone.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def __init__( # pylint: disable=too-many-statements
135135
self._base_radius = base_radius
136136
self._length = length
137137
if bluffness is not None:
138-
if bluffness > 1 or bluffness < 0:
138+
if bluffness > 1 or bluffness < 0: # pragma: no cover
139139
raise ValueError(
140140
f"Bluffness ratio of {bluffness} is out of range. "
141141
"It must be between 0 and 1."
@@ -286,7 +286,7 @@ def theta(x):
286286
self.y_nosecone = Function(
287287
lambda x: self.base_radius * np.power(x / self.length, self.power)
288288
)
289-
else:
289+
else: # pragma: no cover
290290
raise ValueError(
291291
f"Nose Cone kind '{self.kind}' not found, "
292292
+ "please use one of the following Nose Cone kinds:"
@@ -317,12 +317,11 @@ def bluffness(self, value):
317317
raise ValueError(
318318
"Parameter 'bluffness' must be None or 0 when using a nose cone kind 'powerseries'."
319319
)
320-
if value is not None:
321-
if value > 1 or value < 0:
322-
raise ValueError(
323-
f"Bluffness ratio of {value} is out of range. "
324-
"It must be between 0 and 1."
325-
)
320+
if value is not None and not 0 <= value <= 1: # pragma: no cover
321+
raise ValueError(
322+
f"Bluffness ratio of {value} is out of range. "
323+
"It must be between 0 and 1."
324+
)
326325
self._bluffness = value
327326
self.evaluate_nose_shape()
328327

‎rocketpy/rocket/rocket.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def __init__( # pylint: disable=too-many-statements
278278
self._csys = 1
279279
elif coordinate_system_orientation == "nose_to_tail":
280280
self._csys = -1
281-
else:
281+
else: # pragma: no cover
282282
raise TypeError(
283283
"Invalid coordinate system orientation. Please choose between "
284284
+ '"tail_to_nose" and "nose_to_tail".'
@@ -1173,7 +1173,7 @@ def add_nose(
11731173
self.add_surfaces(nose, position)
11741174
return nose
11751175

1176-
def add_fins(self, *args, **kwargs):
1176+
def add_fins(self, *args, **kwargs): # pragma: no cover
11771177
"""See Rocket.add_trapezoidal_fins for documentation.
11781178
This method is set to be deprecated in version 1.0.0 and fully removed
11791179
by version 2.0.0. Use Rocket.add_trapezoidal_fins instead. It keeps the

‎rocketpy/sensitivity/sensitivity_model.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,6 @@ def set_target_variables_nominal(self, target_variables_nominal_value):
140140
self.target_variables_info[target_variable]["nominal_value"] = (
141141
target_variables_nominal_value[i]
142142
)
143-
for i, target_variable in enumerate(self.target_variables_names):
144-
self.target_variables_info[target_variable]["nominal_value"] = (
145-
target_variables_nominal_value[i]
146-
)
147143

148144
self._nominal_target_passed = True
149145

@@ -356,12 +352,12 @@ def __check_requirements(self):
356352
version = ">=0" if not version else version
357353
try:
358354
check_requirement_version(module_name, version)
359-
except (ValueError, ImportError) as e:
355+
except (ValueError, ImportError) as e: # pragma: no cover
360356
has_error = True
361357
print(
362358
f"The following error occurred while importing {module_name}: {e}"
363359
)
364-
if has_error:
360+
if has_error: # pragma: no cover
365361
print(
366362
"Given the above errors, some methods may not work. Please run "
367363
+ "'pip install rocketpy[sensitivity]' to install extra requirements."

‎rocketpy/simulation/flight.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
615615
self.env = environment
616616
self.rocket = rocket
617617
self.rail_length = rail_length
618-
if self.rail_length <= 0:
618+
if self.rail_length <= 0: # pragma: no cover
619619
raise ValueError("Rail length must be a positive value.")
620620
self.parachutes = self.rocket.parachutes[:]
621621
self.inclination = inclination
@@ -872,11 +872,11 @@ def __simulate(self, verbose):
872872
for t_root in t_roots
873873
if 0 < t_root.real < t1 and abs(t_root.imag) < 0.001
874874
]
875-
if len(valid_t_root) > 1:
875+
if len(valid_t_root) > 1: # pragma: no cover
876876
raise ValueError(
877877
"Multiple roots found when solving for rail exit time."
878878
)
879-
if len(valid_t_root) == 0:
879+
if len(valid_t_root) == 0: # pragma: no cover
880880
raise ValueError(
881881
"No valid roots found when solving for rail exit time."
882882
)
@@ -951,7 +951,7 @@ def __simulate(self, verbose):
951951
for t_root in t_roots
952952
if abs(t_root.imag) < 0.001 and 0 < t_root.real < t1
953953
]
954-
if len(valid_t_root) > 1:
954+
if len(valid_t_root) > 1: # pragma: no cover
955955
raise ValueError(
956956
"Multiple roots found when solving for impact time."
957957
)
@@ -1226,7 +1226,7 @@ def __init_controllers(self):
12261226
self._controllers = self.rocket._controllers[:]
12271227
self.sensors = self.rocket.sensors.get_components()
12281228
if self._controllers or self.sensors:
1229-
if self.time_overshoot:
1229+
if self.time_overshoot: # pragma: no cover
12301230
self.time_overshoot = False
12311231
warnings.warn(
12321232
"time_overshoot has been set to False due to the presence "
@@ -1266,7 +1266,7 @@ def __set_ode_solver(self, solver):
12661266
else:
12671267
try:
12681268
self._solver = ODE_SOLVER_MAP[solver]
1269-
except KeyError as e:
1269+
except KeyError as e: # pragma: no cover
12701270
raise ValueError(
12711271
f"Invalid ``ode_solver`` input: {solver}. "
12721272
f"Available options are: {', '.join(ODE_SOLVER_MAP.keys())}"
@@ -1398,7 +1398,7 @@ def udot_rail1(self, t, u, post_processing=False):
13981398

13991399
return [vx, vy, vz, ax, ay, az, 0, 0, 0, 0, 0, 0, 0]
14001400

1401-
def udot_rail2(self, t, u, post_processing=False):
1401+
def udot_rail2(self, t, u, post_processing=False): # pragma: no cover
14021402
"""[Still not implemented] Calculates derivative of u state vector with
14031403
respect to time when rocket is flying in 3 DOF motion in the rail.
14041404
@@ -3531,7 +3531,7 @@ def __len__(self):
35313531
def __repr__(self):
35323532
return str(self.list)
35333533

3534-
def display_warning(self, *messages):
3534+
def display_warning(self, *messages): # pragma: no cover
35353535
"""A simple function to print a warning message."""
35363536
print("WARNING:", *messages)
35373537

‎rocketpy/tools.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,10 @@
2929

3030

3131
def tuple_handler(value):
32-
"""Transforms the input value into a tuple that
33-
represents a range. If the input is an int or float,
34-
the output is a tuple from zero to the input value. If
35-
the input is a tuple or list, the output is a tuple with
36-
the same range.
32+
"""Transforms the input value into a tuple that represents a range. If the
33+
input is an int or float, the output is a tuple from zero to the input
34+
value. If the input is a tuple or list, the output is a tuple with the same
35+
range.
3736
3837
Parameters
3938
----------
@@ -265,7 +264,7 @@ def get_distribution(distribution_function_name):
265264
}
266265
try:
267266
return distributions[distribution_function_name]
268-
except KeyError as e:
267+
except KeyError as e: # pragma: no cover
269268
raise ValueError(
270269
f"Distribution function '{distribution_function_name}' not found, "
271270
+ "please use one of the following np.random distribution function:"
@@ -915,7 +914,7 @@ def import_optional_dependency(name):
915914
"""
916915
try:
917916
module = importlib.import_module(name)
918-
except ImportError as exc:
917+
except ImportError as exc: # pragma: no cover
919918
module_name = name.split(".")[0]
920919
package_name = INSTALL_MAPPING.get(module_name, module_name)
921920
raise ImportError(
@@ -979,7 +978,8 @@ def wrapper(*args, **kwargs):
979978
for i in range(max_attempts):
980979
try:
981980
return func(*args, **kwargs)
982-
except Exception as e: # pylint: disable=broad-except
981+
# pylint: disable=broad-except
982+
except Exception as e: # pragma: no cover
983983
if i == max_attempts - 1:
984984
raise e from None
985985
delay = min(delay * 2, max_delay)
@@ -1205,7 +1205,7 @@ def from_hex_decode(obj_bytes, decoder=base64.b85decode):
12051205
return dill.loads(decoder(bytes.fromhex(obj_bytes)))
12061206

12071207

1208-
if __name__ == "__main__":
1208+
if __name__ == "__main__": # pragma: no cover
12091209
import doctest
12101210

12111211
res = doctest.testmod()

‎rocketpy/utilities.py

+30-34
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from .simulation.flight import Flight
1515

1616

17-
# TODO: Needs tests
1817
def compute_cd_s_from_drop_test(
1918
terminal_velocity, rocket_mass, air_density=1.225, g=9.80665
2019
):
@@ -39,13 +38,34 @@ def compute_cd_s_from_drop_test(
3938
-------
4039
cd_s : float
4140
Number equal to drag coefficient times reference area for parachute.
41+
"""
42+
return 2 * rocket_mass * g / ((terminal_velocity**2) * air_density)
43+
4244

45+
def check_constant(f, eps):
4346
"""
47+
Check for three consecutive elements in the list that are approximately
48+
equal within a tolerance.
4449
45-
return 2 * rocket_mass * g / ((terminal_velocity**2) * air_density)
50+
Parameters
51+
----------
52+
f : list or array
53+
A list or array of numerical values.
54+
eps : float
55+
The tolerance level for comparing the elements.
56+
57+
Returns
58+
-------
59+
int or None
60+
The index of the first element in the first sequence of three
61+
consecutive elements that are approximately equal within the tolerance.
62+
Returns None if no such sequence is found.
63+
"""
64+
for i in range(len(f) - 2):
65+
if abs(f[i + 2] - f[i + 1]) < eps and abs(f[i + 1] - f[i]) < eps:
66+
return i
4667

4768

48-
# TODO: Needs tests
4969
def calculate_equilibrium_altitude(
5070
rocket_mass,
5171
cd_s,
@@ -90,7 +110,6 @@ def calculate_equilibrium_altitude(
90110
affect the final result if the value is not high enough. Increase the
91111
estimative in case the final solution is not founded.
92112
93-
94113
Returns
95114
-------
96115
altitude_function: Function
@@ -103,30 +122,8 @@ def calculate_equilibrium_altitude(
103122
"""
104123
final_sol = {}
105124

106-
if v0 >= 0:
107-
print("Please set a valid negative value for v0")
108-
return None
109-
110-
# TODO: Improve docs
111-
def check_constant(f, eps):
112-
"""_summary_
113-
114-
Parameters
115-
----------
116-
f : array, list
117-
118-
_description_
119-
eps : float
120-
_description_
121-
122-
Returns
123-
-------
124-
int, None
125-
_description_
126-
"""
127-
for i in range(len(f) - 2):
128-
if abs(f[i + 2] - f[i + 1]) < eps and abs(f[i + 1] - f[i]) < eps:
129-
return i
125+
if v0 >= 0: # pragma: no cover
126+
raise ValueError("Please set a valid negative value for v0")
130127

131128
if env is None:
132129
environment = Environment(
@@ -138,21 +135,20 @@ def check_constant(f, eps):
138135
else:
139136
environment = env
140137

141-
# TODO: Improve docs
142138
def du(z, u):
143-
"""_summary_
139+
"""Returns the derivative of the velocity at a given altitude.
144140
145141
Parameters
146142
----------
147143
z : float
148-
_description_
144+
altitude, in meters, at a given time
149145
u : float
150146
velocity, in m/s, at a given z altitude
151147
152148
Returns
153149
-------
154150
float
155-
_description_
151+
velocity at a given altitude
156152
"""
157153
return (
158154
u[1],
@@ -258,7 +254,7 @@ def fin_flutter_analysis(
258254
found_fin = True
259255
else:
260256
warnings.warn("More than one fin set found. The last one will be used.")
261-
if not found_fin:
257+
if not found_fin: # pragma: no cover
262258
raise AttributeError(
263259
"There is no TrapezoidalFins in the rocket, can't run Flutter Analysis."
264260
)
@@ -442,7 +438,7 @@ def _flutter_prints(
442438

443439

444440
# TODO: deprecate and delete this function. Never used and now we have Monte Carlo.
445-
def create_dispersion_dictionary(filename):
441+
def create_dispersion_dictionary(filename): # pragma: no cover
446442
"""Creates a dictionary with the rocket data provided by a .csv file.
447443
File should be organized in four columns: attribute_class, parameter_name,
448444
mean_value, standard_deviation. The first row should be the header.

‎tests/fixtures/flight/flight_fixtures.py

+16
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,22 @@ def flight_calisto_robust(calisto_robust, example_spaceport_env):
9393
)
9494

9595

96+
@pytest.fixture
97+
def flight_calisto_robust_solid_eom(calisto_robust, example_spaceport_env):
98+
"""Similar to flight_calisto_robust, but with the equations of motion set to
99+
"solid_propulsion".
100+
"""
101+
return Flight(
102+
environment=example_spaceport_env,
103+
rocket=calisto_robust,
104+
rail_length=5.2,
105+
inclination=85,
106+
heading=0,
107+
terminate_on_apogee=False,
108+
equations_of_motion="solid_propulsion",
109+
)
110+
111+
96112
@pytest.fixture
97113
def flight_calisto_liquid_modded(calisto_liquid_modded, example_plain_env):
98114
"""A rocketpy.Flight object of the Calisto rocket modded for a liquid

‎tests/fixtures/surfaces/surface_fixtures.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
import pytest
22

3-
from rocketpy import NoseCone, RailButtons, Tail, TrapezoidalFins
4-
from rocketpy.rocket.aero_surface.fins.free_form_fins import FreeFormFins
3+
from rocketpy.rocket.aero_surface import (
4+
EllipticalFins,
5+
FreeFormFins,
6+
NoseCone,
7+
RailButtons,
8+
Tail,
9+
TrapezoidalFins,
10+
)
511

612

713
@pytest.fixture
@@ -94,3 +100,16 @@ def calisto_rail_buttons():
94100
angular_position=45,
95101
name="Rail Buttons",
96102
)
103+
104+
105+
@pytest.fixture
106+
def elliptical_fin_set():
107+
return EllipticalFins(
108+
n=4,
109+
span=0.100,
110+
root_chord=0.120,
111+
rocket_radius=0.0635,
112+
cant_angle=0,
113+
airfoil=None,
114+
name="Test Elliptical Fins",
115+
)

‎tests/integration/test_flight.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@
1010
plt.rcParams.update({"figure.max_open_warning": 0})
1111

1212

13+
@pytest.mark.parametrize(
14+
"flight_fixture", ["flight_calisto_robust", "flight_calisto_robust_solid_eom"]
15+
)
1316
@patch("matplotlib.pyplot.show")
1417
# pylint: disable=unused-argument
15-
def test_all_info(mock_show, flight_calisto_robust):
18+
def test_all_info(mock_show, request, flight_fixture):
1619
"""Test that the flight class is working as intended. This basically calls
1720
the all_info() method and checks if it returns None. It is not testing if
1821
the values are correct, but whether the method is working without errors.
@@ -21,11 +24,13 @@ def test_all_info(mock_show, flight_calisto_robust):
2124
----------
2225
mock_show : unittest.mock.MagicMock
2326
Mock object to replace matplotlib.pyplot.show
24-
flight_calisto_robust : rocketpy.Flight
25-
Flight object to be tested. See the conftest.py file for more info
26-
regarding this pytest fixture.
27+
request : _pytest.fixtures.FixtureRequest
28+
Request object to access the fixture dynamically.
29+
flight_fixture : str
30+
Name of the flight fixture to be tested.
2731
"""
28-
assert flight_calisto_robust.all_info() is None
32+
flight = request.getfixturevalue(flight_fixture)
33+
assert flight.all_info() is None
2934

3035

3136
@pytest.mark.slow

‎tests/unit/test_aero_surfaces.py

+68
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from unittest.mock import patch
2+
13
import pytest
24

35
from rocketpy import NoseCone
@@ -71,3 +73,69 @@ def test_powerseries_nosecones_setters(power, invalid_power, new_power):
7173
expected_k = (2 * new_power) / ((2 * new_power) + 1)
7274

7375
assert pytest.approx(test_nosecone.k) == expected_k
76+
77+
78+
@patch("matplotlib.pyplot.show")
79+
def test_elliptical_fins_draw(
80+
mock_show, elliptical_fin_set
81+
): # pylint: disable=unused-argument
82+
assert elliptical_fin_set.plots.draw(filename=None) is None
83+
84+
85+
def test_nose_cone_info(calisto_nose_cone):
86+
assert calisto_nose_cone.info() is None
87+
88+
89+
@patch("matplotlib.pyplot.show")
90+
def test_nose_cone_draw(
91+
mock_show, calisto_nose_cone
92+
): # pylint: disable=unused-argument
93+
assert calisto_nose_cone.draw(filename=None) is None
94+
95+
96+
def test_trapezoidal_fins_info(calisto_trapezoidal_fins):
97+
assert calisto_trapezoidal_fins.info() is None
98+
99+
100+
def test_trapezoidal_fins_tip_chord_setter(calisto_trapezoidal_fins):
101+
calisto_trapezoidal_fins.tip_chord = 0.1
102+
assert calisto_trapezoidal_fins.tip_chord == 0.1
103+
104+
105+
def test_trapezoidal_fins_root_chord_setter(calisto_trapezoidal_fins):
106+
calisto_trapezoidal_fins.root_chord = 0.1
107+
assert calisto_trapezoidal_fins.root_chord == 0.1
108+
109+
110+
def test_trapezoidal_fins_sweep_angle_setter(calisto_trapezoidal_fins):
111+
calisto_trapezoidal_fins.sweep_angle = 0.1
112+
assert calisto_trapezoidal_fins.sweep_angle == 0.1
113+
114+
115+
def test_trapezoidal_fins_sweep_length_setter(calisto_trapezoidal_fins):
116+
calisto_trapezoidal_fins.sweep_length = 0.1
117+
assert calisto_trapezoidal_fins.sweep_length == 0.1
118+
119+
120+
def test_tail_info(calisto_tail):
121+
assert calisto_tail.info() is None
122+
123+
124+
def test_tail_length_setter(calisto_tail):
125+
calisto_tail.length = 0.1
126+
assert calisto_tail.length == 0.1
127+
128+
129+
def test_tail_rocket_radius_setter(calisto_tail):
130+
calisto_tail.rocket_radius = 0.1
131+
assert calisto_tail.rocket_radius == 0.1
132+
133+
134+
def test_tail_bottom_radius_setter(calisto_tail):
135+
calisto_tail.bottom_radius = 0.1
136+
assert calisto_tail.bottom_radius == 0.1
137+
138+
139+
def test_tail_top_radius_setter(calisto_tail):
140+
calisto_tail.top_radius = 0.1
141+
assert calisto_tail.top_radius == 0.1

‎tests/unit/test_flight_time_nodes.py

+10
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,13 @@ def test_time_node_lt(flight_calisto):
9999
node2 = flight_calisto.TimeNodes.TimeNode(2.0, [], [], [])
100100
assert node1 < node2
101101
assert not node2 < node1
102+
103+
104+
def test_time_node_repr(flight_calisto):
105+
node = flight_calisto.TimeNodes.TimeNode(1.0, [], [], [])
106+
assert isinstance(repr(node), str)
107+
108+
109+
def test_time_nodes_repr(flight_calisto):
110+
time_nodes = flight_calisto.TimeNodes()
111+
assert isinstance(repr(time_nodes), str)

‎tests/unit/test_function.py

+104
Original file line numberDiff line numberDiff line change
@@ -787,3 +787,107 @@ def test_low_pass_filter(alpha):
787787
f"The filtered value at index {i} is not the expected value. "
788788
f"Expected: {expected}, Actual: {filtered_func.source[i][1]}"
789789
)
790+
791+
792+
def test_average_function_ndarray():
793+
794+
dummy_function = Function(
795+
source=[
796+
[0, 0],
797+
[1, 1],
798+
[2, 0],
799+
[3, 1],
800+
[4, 0],
801+
[5, 1],
802+
[6, 0],
803+
[7, 1],
804+
[8, 0],
805+
[9, 1],
806+
],
807+
inputs=["x"],
808+
outputs=["y"],
809+
)
810+
avg_function = dummy_function.average_function()
811+
812+
assert isinstance(avg_function, Function)
813+
assert np.isclose(avg_function(0), 0)
814+
assert np.isclose(avg_function(9), 0.5)
815+
816+
817+
def test_average_function_callable():
818+
819+
dummy_function = Function(lambda x: 2)
820+
avg_function = dummy_function.average_function(lower=0)
821+
822+
assert isinstance(avg_function, Function)
823+
assert np.isclose(avg_function(1), 2)
824+
assert np.isclose(avg_function(9), 2)
825+
826+
827+
@pytest.mark.parametrize(
828+
"lower, upper, sampling_frequency, window_size, step_size, remove_dc, only_positive",
829+
[
830+
(0, 10, 100, 1, 0.5, True, True),
831+
(0, 10, 100, 1, 0.5, True, False),
832+
(0, 10, 100, 1, 0.5, False, True),
833+
(0, 10, 100, 1, 0.5, False, False),
834+
(0, 20, 200, 2, 1, True, True),
835+
],
836+
)
837+
def test_short_time_fft(
838+
lower, upper, sampling_frequency, window_size, step_size, remove_dc, only_positive
839+
):
840+
"""Test the short_time_fft method of the Function class.
841+
842+
Parameters
843+
----------
844+
lower : float
845+
Lower bound of the time range.
846+
upper : float
847+
Upper bound of the time range.
848+
sampling_frequency : float
849+
Sampling frequency at which to perform the Fourier transform.
850+
window_size : float
851+
Size of the window for the STFT, in seconds.
852+
step_size : float
853+
Step size for the window, in seconds.
854+
remove_dc : bool
855+
If True, the DC component is removed from each window before
856+
computing the Fourier transform.
857+
only_positive: bool
858+
If True, only the positive frequencies are returned.
859+
"""
860+
# Generate a test signal
861+
t = np.linspace(lower, upper, int((upper - lower) * sampling_frequency))
862+
signal = np.sin(2 * np.pi * 5 * t) # 5 Hz sine wave
863+
func = Function(np.column_stack((t, signal)))
864+
865+
# Perform STFT
866+
stft_results = func.short_time_fft(
867+
lower=lower,
868+
upper=upper,
869+
sampling_frequency=sampling_frequency,
870+
window_size=window_size,
871+
step_size=step_size,
872+
remove_dc=remove_dc,
873+
only_positive=only_positive,
874+
)
875+
876+
# Check the results
877+
assert isinstance(stft_results, list)
878+
assert all(isinstance(f, Function) for f in stft_results)
879+
880+
for f in stft_results:
881+
assert f.get_inputs() == ["Frequency (Hz)"]
882+
assert f.get_outputs() == ["Amplitude"]
883+
assert f.get_interpolation_method() == "linear"
884+
assert f.get_extrapolation_method() == "zero"
885+
886+
frequencies = f.source[:, 0]
887+
# amplitudes = f.source[:, 1]
888+
889+
if only_positive:
890+
assert np.all(frequencies >= 0)
891+
else:
892+
assert np.all(frequencies >= -sampling_frequency / 2)
893+
assert np.all(frequencies <= sampling_frequency / 2)

‎tests/unit/test_sensitivity.py

+70-11
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1+
from unittest.mock import patch
2+
13
import numpy as np
24
import pytest
35

46
from rocketpy.sensitivity import SensitivityModel
57

6-
# TODO: for some weird reason, these tests are not passing in the CI, but
7-
# passing locally. Need to investigate why.
8-
98

10-
@pytest.mark.skip(reason="legacy test")
119
def test_initialization():
1210
parameters_names = ["param1", "param2"]
1311
target_variables_names = ["target1", "target2"]
@@ -21,7 +19,6 @@ def test_initialization():
2119
assert not model._fitted
2220

2321

24-
@pytest.mark.skip(reason="legacy test")
2522
def test_set_parameters_nominal():
2623
parameters_names = ["param1", "param2"]
2724
target_variables_names = ["target1", "target2"]
@@ -35,8 +32,16 @@ def test_set_parameters_nominal():
3532
assert model.parameters_info["param1"]["nominal_mean"] == 1.0
3633
assert model.parameters_info["param2"]["nominal_sd"] == 0.2
3734

35+
# check dimensions mismatch error raise
36+
incorrect_nominal_mean = np.array([1.0])
37+
with pytest.raises(ValueError):
38+
model.set_parameters_nominal(incorrect_nominal_mean, parameters_nominal_sd)
39+
40+
incorrect_nominal_sd = np.array([0.1])
41+
with pytest.raises(ValueError):
42+
model.set_parameters_nominal(parameters_nominal_mean, incorrect_nominal_sd)
43+
3844

39-
@pytest.mark.skip(reason="legacy test")
4045
def test_set_target_variables_nominal():
4146
parameters_names = ["param1", "param2"]
4247
target_variables_names = ["target1", "target2"]
@@ -49,9 +54,13 @@ def test_set_target_variables_nominal():
4954
assert model.target_variables_info["target1"]["nominal_value"] == 10.0
5055
assert model.target_variables_info["target2"]["nominal_value"] == 20.0
5156

57+
# check dimensions mismatch error raise
58+
incorrect_nominal_value = np.array([10.0])
59+
with pytest.raises(ValueError):
60+
model.set_target_variables_nominal(incorrect_nominal_value)
61+
5262

53-
@pytest.mark.skip(reason="legacy test")
54-
def test_fit_method():
63+
def test_fit_method_one_target():
5564
parameters_names = ["param1", "param2"]
5665
target_variables_names = ["target1"]
5766
model = SensitivityModel(parameters_names, target_variables_names)
@@ -65,7 +74,20 @@ def test_fit_method():
6574
assert model.number_of_samples == 3
6675

6776

68-
@pytest.mark.skip(reason="legacy test")
77+
def test_fit_method_multiple_target():
78+
parameters_names = ["param1", "param2"]
79+
target_variables_names = ["target1", "target2"]
80+
model = SensitivityModel(parameters_names, target_variables_names)
81+
82+
parameters_matrix = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
83+
target_data = np.array([[10.0, 12.0, 14.0], [11.0, 13.0, 17.0]]).T
84+
85+
model.fit(parameters_matrix, target_data)
86+
87+
assert model._fitted
88+
assert model.number_of_samples == 3
89+
90+
6991
def test_fit_raises_error_on_mismatched_dimensions():
7092
parameters_names = ["param1", "param2"]
7193
target_variables_names = ["target1"]
@@ -78,7 +100,6 @@ def test_fit_raises_error_on_mismatched_dimensions():
78100
model.fit(parameters_matrix, target_data)
79101

80102

81-
@pytest.mark.skip(reason="legacy test")
82103
def test_check_conformity():
83104
parameters_names = ["param1", "param2"]
84105
target_variables_names = ["target1", "target2"]
@@ -90,7 +111,6 @@ def test_check_conformity():
90111
model._SensitivityModel__check_conformity(parameters_matrix, target_data)
91112

92113

93-
@pytest.mark.skip(reason="legacy test")
94114
def test_check_conformity_raises_error():
95115
parameters_names = ["param1", "param2"]
96116
target_variables_names = ["target1", "target2"]
@@ -101,3 +121,42 @@ def test_check_conformity_raises_error():
101121

102122
with pytest.raises(ValueError):
103123
model._SensitivityModel__check_conformity(parameters_matrix, target_data)
124+
125+
parameters_matrix2 = np.array([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]])
126+
127+
with pytest.raises(ValueError):
128+
model._SensitivityModel__check_conformity(parameters_matrix2, target_data)
129+
130+
target_data2 = np.array([10.0, 12.0])
131+
132+
with pytest.raises(ValueError):
133+
model._SensitivityModel__check_conformity(parameters_matrix, target_data2)
134+
135+
target_variables_names = ["target1"]
136+
model = SensitivityModel(parameters_names, target_variables_names)
137+
138+
target_data = np.array([[10.0, 20.0], [12.0, 22.0], [14.0, 24.0]])
139+
140+
with pytest.raises(ValueError):
141+
model._SensitivityModel__check_conformity(parameters_matrix, target_data)
142+
143+
144+
@patch("matplotlib.pyplot.show")
145+
def test_prints_and_plots(mock_show): # pylint: disable=unused-argument
146+
parameters_names = ["param1", "param2"]
147+
target_variables_names = ["target1"]
148+
model = SensitivityModel(parameters_names, target_variables_names)
149+
150+
parameters_matrix = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
151+
target_data = np.array([10.0, 12.0, 14.0])
152+
153+
# tests if an error is raised if summary is called before print
154+
with pytest.raises(ValueError):
155+
model.info()
156+
157+
model.fit(parameters_matrix, target_data)
158+
assert model.all_info() is None
159+
160+
nominal_target = np.array([12.0])
161+
model.set_target_variables_nominal(nominal_target)
162+
assert model.all_info() is None

‎tests/unit/test_tank.py

+8
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from math import isclose
22
from pathlib import Path
3+
from unittest.mock import patch
34

45
import numpy as np
56
import pytest
67
import scipy.integrate as spi
78

9+
from rocketpy.motors import TankGeometry
10+
811
BASE_PATH = Path("./data/rockets/berkeley/")
912

1013

@@ -355,3 +358,8 @@ def expected_gas_inertia(t):
355358
atol=1e-3,
356359
rtol=1e-2,
357360
)
361+
362+
363+
@patch("matplotlib.pyplot.show")
364+
def test_tank_geometry_plots_info(mock_show): # pylint: disable=unused-argument
365+
assert TankGeometry({(0, 5): 1}).plots.all() is None

‎tests/unit/test_tools.py

+28
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
euler313_to_quaternions,
77
find_roots_cubic_function,
88
haversine,
9+
tuple_handler,
910
)
1011

1112

@@ -72,3 +73,30 @@ def test_cardanos_root_finding():
7273
def test_haversine(lat0, lon0, lat1, lon1, expected_distance):
7374
distance = haversine(lat0, lon0, lat1, lon1)
7475
assert np.isclose(distance, expected_distance, rtol=1e-2)
76+
77+
78+
@pytest.mark.parametrize(
79+
"input_value, expected_output",
80+
[
81+
(5, (0, 5)),
82+
(3.5, (0, 3.5)),
83+
([7], (0, 7)),
84+
((8,), (0, 8)),
85+
([2, 4], (2, 4)),
86+
((1, 3), (1, 3)),
87+
],
88+
)
89+
def test_tuple_handler(input_value, expected_output):
90+
assert tuple_handler(input_value) == expected_output
91+
92+
93+
@pytest.mark.parametrize(
94+
"input_value, expected_exception",
95+
[
96+
([1, 2, 3], ValueError),
97+
((4, 5, 6), ValueError),
98+
],
99+
)
100+
def test_tuple_handler_exceptions(input_value, expected_exception):
101+
with pytest.raises(expected_exception):
102+
tuple_handler(input_value)

‎tests/unit/test_units.py

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import pytest
2+
3+
from rocketpy.units import conversion_factor, convert_temperature, convert_units
4+
5+
6+
class TestConvertTemperature:
7+
"""Tests for the convert_temperature function."""
8+
9+
def test_convert_temperature_same_unit(self):
10+
assert convert_temperature(300, "K", "K") == 300
11+
assert convert_temperature(27, "degC", "degC") == 27
12+
assert convert_temperature(80, "degF", "degF") == 80
13+
14+
def test_convert_temperature_kelvin_to_celsius(self):
15+
assert convert_temperature(300, "K", "degC") == pytest.approx(26.85, rel=1e-2)
16+
17+
def test_convert_temperature_kelvin_to_fahrenheit(self):
18+
assert convert_temperature(300, "K", "degF") == pytest.approx(80.33, rel=1e-2)
19+
20+
def test_convert_temperature_celsius_to_kelvin(self):
21+
assert convert_temperature(27, "degC", "K") == pytest.approx(300.15, rel=1e-2)
22+
23+
def test_convert_temperature_celsius_to_fahrenheit(self):
24+
assert convert_temperature(27, "degC", "degF") == pytest.approx(80.6, rel=1e-2)
25+
26+
def test_convert_temperature_fahrenheit_to_kelvin(self):
27+
assert convert_temperature(80, "degF", "K") == pytest.approx(299.817, rel=1e-2)
28+
29+
def test_convert_temperature_fahrenheit_to_celsius(self):
30+
assert convert_temperature(80, "degF", "degC") == pytest.approx(26.67, rel=1e-2)
31+
32+
def test_convert_temperature_invalid_conversion(self):
33+
with pytest.raises(ValueError):
34+
convert_temperature(300, "K", "invalid_unit")
35+
with pytest.raises(ValueError):
36+
convert_temperature(300, "invalid_unit", "K")
37+
38+
39+
class TestConversionFactor:
40+
"""Tests for the conversion_factor function."""
41+
42+
def test_conversion_factor_same_unit(self):
43+
assert conversion_factor("m", "m") == 1
44+
assert conversion_factor("ft", "ft") == 1
45+
assert conversion_factor("s", "s") == 1
46+
47+
def test_conversion_factor_m_to_ft(self):
48+
assert conversion_factor("m", "ft") == pytest.approx(3.28084, rel=1e-2)
49+
50+
def test_conversion_factor_ft_to_m(self):
51+
assert conversion_factor("ft", "m") == pytest.approx(0.3048, rel=1e-2)
52+
53+
def test_conversion_factor_s_to_min(self):
54+
assert conversion_factor("s", "min") == pytest.approx(1 / 60, rel=1e-2)
55+
56+
def test_conversion_factor_min_to_s(self):
57+
assert conversion_factor("min", "s") == pytest.approx(60, rel=1e-2)
58+
59+
def test_conversion_factor_invalid_conversion(self):
60+
with pytest.raises(ValueError):
61+
conversion_factor("m", "invalid_unit")
62+
with pytest.raises(ValueError):
63+
conversion_factor("invalid_unit", "m")
64+
65+
66+
class TestConvertUnits:
67+
"""Tests for the convert_units function."""
68+
69+
def test_convert_units_same_unit(self):
70+
assert convert_units(300, "K", "K") == 300
71+
assert convert_units(27, "degC", "degC") == 27
72+
assert convert_units(80, "degF", "degF") == 80
73+
74+
def test_convert_units_kelvin_to_celsius(self):
75+
assert convert_units(300, "K", "degC") == pytest.approx(26.85, rel=1e-2)
76+
77+
def test_convert_units_kelvin_to_fahrenheit(self):
78+
assert convert_units(300, "K", "degF") == pytest.approx(80.33, rel=1e-2)
79+
80+
def test_convert_units_kilogram_to_pound(self):
81+
assert convert_units(1, "kg", "lb") == pytest.approx(2.20462, rel=1e-2)
82+
83+
def test_convert_units_kilometer_to_mile(self):
84+
assert convert_units(1, "km", "mi") == pytest.approx(0.621371, rel=1e-2)

‎tests/unit/test_utilities.py

+31
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,34 @@ def test_get_instance_attributes(flight_calisto_robust):
178178
assert np.allclose(attr, value)
179179
else:
180180
assert attr == value
181+
182+
183+
@pytest.mark.parametrize(
184+
"f, eps, expected",
185+
[
186+
([1.0, 1.0, 1.0, 2.0, 3.0], 1e-6, 0),
187+
([1.0, 1.0, 1.0, 2.0, 3.0], 1e-1, 0),
188+
([1.0, 1.1, 1.2, 2.0, 3.0], 1e-1, None),
189+
([1.0, 1.0, 1.0, 1.0, 1.0], 1e-6, 0),
190+
([1.0, 1.0, 1.0, 1.0, 1.0], 1e-1, 0),
191+
([1.0, 1.0, 1.0], 1e-6, 0),
192+
([1.0, 1.0], 1e-6, None),
193+
([1.0], 1e-6, None),
194+
([], 1e-6, None),
195+
],
196+
)
197+
def test_check_constant(f, eps, expected):
198+
"""Test if the function `check_constant` returns the correct index or None
199+
for different scenarios.
200+
201+
Parameters
202+
----------
203+
f : list or array
204+
A list or array of numerical values.
205+
eps : float
206+
The tolerance level for comparing the elements.
207+
expected : int or None
208+
The expected result of the function.
209+
"""
210+
result = utilities.check_constant(f, eps)
211+
assert result == expected

0 commit comments

Comments
 (0)
Please sign in to comment.