+from unittest.mock import patch
+
 import numpy as np
 import pytest
 
 from rocketpy.sensitivity import SensitivityModel
 
-# TODO: for some weird reason, these tests are not passing in the CI, but
-# passing locally. Need to investigate why.
-
 
-@pytest.mark.skip(reason="legacy test")
 def test_initialization():
     parameters_names = ["param1", "param2"]
     target_variables_names = ["target1", "target2"]
@@ -21,7 +19,6 @@ def test_initialization():
     assert not model._fitted
 
 
-@pytest.mark.skip(reason="legacy test")
 def test_set_parameters_nominal():
     parameters_names = ["param1", "param2"]
     target_variables_names = ["target1", "target2"]
@@ -35,8 +32,16 @@ def test_set_parameters_nominal():
     assert model.parameters_info["param1"]["nominal_mean"] == 1.0
     assert model.parameters_info["param2"]["nominal_sd"] == 0.2
 
+    # check that a dimension mismatch raises a ValueError
+    incorrect_nominal_mean = np.array([1.0])
+    with pytest.raises(ValueError):
+        model.set_parameters_nominal(incorrect_nominal_mean, parameters_nominal_sd)
+
+    incorrect_nominal_sd = np.array([0.1])
+    with pytest.raises(ValueError):
+        model.set_parameters_nominal(parameters_nominal_mean, incorrect_nominal_sd)
+
 
-@pytest.mark.skip(reason="legacy test")
 def test_set_target_variables_nominal():
     parameters_names = ["param1", "param2"]
     target_variables_names = ["target1", "target2"]
@@ -49,9 +54,13 @@ def test_set_target_variables_nominal():
     assert model.target_variables_info["target1"]["nominal_value"] == 10.0
     assert model.target_variables_info["target2"]["nominal_value"] == 20.0
 
+    # check that a dimension mismatch raises a ValueError
+    incorrect_nominal_value = np.array([10.0])
+    with pytest.raises(ValueError):
+        model.set_target_variables_nominal(incorrect_nominal_value)
+
 
-@pytest.mark.skip(reason="legacy test")
-def test_fit_method():
+def test_fit_method_one_target():
     parameters_names = ["param1", "param2"]
     target_variables_names = ["target1"]
     model = SensitivityModel(parameters_names, target_variables_names)
@@ -65,7 +74,20 @@ def test_fit_method():
     assert model.number_of_samples == 3
 
 
-@pytest.mark.skip(reason="legacy test")
+def test_fit_method_multiple_target():
+    parameters_names = ["param1", "param2"]
+    target_variables_names = ["target1", "target2"]
+    model = SensitivityModel(parameters_names, target_variables_names)
+
+    parameters_matrix = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
+    target_data = np.array([[10.0, 12.0, 14.0], [11.0, 13.0, 17.0]]).T
+
+    model.fit(parameters_matrix, target_data)
+
+    assert model._fitted
+    assert model.number_of_samples == 3
+
+
 def test_fit_raises_error_on_mismatched_dimensions():
     parameters_names = ["param1", "param2"]
     target_variables_names = ["target1"]
@@ -78,7 +100,6 @@ def test_fit_raises_error_on_mismatched_dimensions():
         model.fit(parameters_matrix, target_data)
 
 
-@pytest.mark.skip(reason="legacy test")
 def test_check_conformity():
     parameters_names = ["param1", "param2"]
     target_variables_names = ["target1", "target2"]
@@ -90,7 +111,6 @@ def test_check_conformity():
     model._SensitivityModel__check_conformity(parameters_matrix, target_data)
 
 
-@pytest.mark.skip(reason="legacy test")
 def test_check_conformity_raises_error():
     parameters_names = ["param1", "param2"]
     target_variables_names = ["target1", "target2"]
@@ -101,3 +121,46 @@ def test_check_conformity_raises_error():
 
     with pytest.raises(ValueError):
         model._SensitivityModel__check_conformity(parameters_matrix, target_data)
+
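+    # a 2x3 parameters matrix has three columns, one more than the two declared parameters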
+    parameters_matrix2 = np.array([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]])
+
+    with pytest.raises(ValueError):
+        model._SensitivityModel__check_conformity(parameters_matrix2, target_data)
+
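+    # only two target samples against the three rows of parameters_matrix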
+    target_data2 = np.array([10.0, 12.0])
+
+    with pytest.raises(ValueError):
+        model._SensitivityModel__check_conformity(parameters_matrix, target_data2)
+
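+    # a model rebuilt with a single target must reject two-column target data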
+    target_variables_names = ["target1"]
+    model = SensitivityModel(parameters_names, target_variables_names)
+
+    target_data = np.array([[10.0, 20.0], [12.0, 22.0], [14.0, 24.0]])
+
+    with pytest.raises(ValueError):
+        model._SensitivityModel__check_conformity(parameters_matrix, target_data)
+
+
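+# plt.show is patched so any plots drawn during the test do not open windows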
+@patch("matplotlib.pyplot.show")
+def test_prints_and_plots(mock_show):  # pylint: disable=unused-argument
+    parameters_names = ["param1", "param2"]
+    target_variables_names = ["target1"]
+    model = SensitivityModel(parameters_names, target_variables_names)
+
+    parameters_matrix = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
+    target_data = np.array([10.0, 12.0, 14.0])
+
+    # calling info() before the model has been fitted should raise an error
+    with pytest.raises(ValueError):
+        model.info()
+
+    model.fit(parameters_matrix, target_data)
+    assert model.all_info() is None
+
+    nominal_target = np.array([12.0])
+    model.set_target_variables_nominal(nominal_target)
+    assert model.all_info() is None