@@ -87,16 +87,14 @@ def train(self, x, y):
         if y.ndim != 1:
             raise ValueError("Data set labels must be one-dimensional")

-        if len(x) < 2 * self.min_leaf_size:
-            self.prediction = np.mean(y)
-            return
+        mean_y = np.mean(y)

-        if self.depth == 1:
-            self.prediction = np.mean(y)
+        if len(x) < 2 * self.min_leaf_size or self.depth == 1:
+            self.prediction = mean_y
             return
-
+
         best_split = 0
-        min_error = self.mean_squared_error(x, np.mean(y)) * 2
+        min_error = self.mean_squared_error(x, mean_y) * 2

         """
         loop over all possible splits for the decision tree. find the best split.
@@ -105,17 +103,21 @@ def train(self, x, y):
         the predictor
         """
         for i in range(len(x)):
-            if len(x[:i]) < self.min_leaf_size:  # noqa: SIM114
-                continue
-            elif len(x[i:]) < self.min_leaf_size:
+            if len(x[:i]) < self.min_leaf_size or len(x[i:]) < self.min_leaf_size:
                 continue
-            else:
-                error_left = self.mean_squared_error(x[:i], np.mean(y[:i]))
-                error_right = self.mean_squared_error(x[i:], np.mean(y[i:]))
-                error = error_left + error_right
-                if error < min_error:
-                    best_split = i
-                    min_error = error
+
+            left_y = y[:i]
+            right_y = y[i:]
+            mean_left = np.mean(left_y)
+            mean_right = np.mean(right_y)
+
+            error_left = self.mean_squared_error(left_y, mean_left)
+            error_right = self.mean_squared_error(right_y, mean_right)
+            error = error_left + error_right
+
+            if error < min_error:
+                best_split = i
+                min_error = error

         if best_split != 0:
             left_x = x[:best_split]
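Read on its own, the refactored loop in this hunk scans every candidate split index, skips any split that would leave either side with fewer than min_leaf_size points, and keeps the index whose combined left/right error beats the running minimum. Below is a minimal standalone sketch of that search; the helper mean_squared_error and the function name best_split_index are illustrative, and the helper is only assumed to behave like the class method used in the diff (mean squared deviation of an array from a constant prediction), which is not shown here.

import numpy as np


def mean_squared_error(values: np.ndarray, prediction: float) -> float:
    # Assumption: mirrors DecisionTree.mean_squared_error, i.e. the mean
    # squared deviation of `values` from a single constant `prediction`.
    return float(np.mean((values - prediction) ** 2))


def best_split_index(x: np.ndarray, y: np.ndarray, min_leaf_size: int) -> int:
    # Returns the split index with the lowest combined left/right error,
    # or 0 if no admissible split beats the baseline.
    mean_y = np.mean(y)
    best_split = 0
    # Baseline taken from the diff: twice the error of predicting mean_y
    # for the whole feature array x.
    min_error = mean_squared_error(x, mean_y) * 2

    for i in range(len(x)):
        # Reject splits that leave either side under the leaf-size floor.
        if len(x[:i]) < min_leaf_size or len(x[i:]) < min_leaf_size:
            continue

        left_y, right_y = y[:i], y[i:]
        error = mean_squared_error(left_y, np.mean(left_y)) + mean_squared_error(
            right_y, np.mean(right_y)
        )

        if error < min_error:
            best_split = i
            min_error = error

    return best_split

Called on the demo data from main(), e.g. best_split_index(x, np.sin(x), min_leaf_size=10), it returns the index at which train() would recurse into left and right subtrees.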
@@ -184,7 +186,7 @@ def main():
     x = np.arange(-1.0, 1.0, 0.005)
     y = np.sin(x)

-    tree = DecisionTree(depth=10, min_leaf_size=10)
+    tree = DecisionTree(depth=6, min_leaf_size=10)
     tree.train(x, y)

     rng = np.random.default_rng()
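For scale, the demo training set has about 400 points, and the leaf rule from the first hunk stops splitting any node holding fewer than 2 * min_leaf_size = 20 of them, so a roughly balanced tree runs out of admissible splits after about five levels regardless of the depth cap. A quick back-of-the-envelope check (illustrative only, not the stated rationale for the new value):

import numpy as np

x = np.arange(-1.0, 1.0, 0.005)
print(len(x))  # about 400 samples in the demo training set

# With min_leaf_size=10, a node with fewer than 20 samples becomes a leaf,
# so balanced halving allows roughly log2(400 / 20) ~ 4.3 levels of splits.
print(np.log2(len(x) / 20))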
@@ -201,4 +203,4 @@ def main():
     main()
     import doctest

-    doctest.testmod(name="mean_squarred_error", verbose=True)
+    doctest.testmod(name="mean_squared_error", verbose=True)
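The final hunk only corrects the spelling of the label passed to doctest.testmod; the name argument does not choose which docstrings run, it merely sets the module name shown in doctest's (verbose) report. A small self-contained illustration, using a hypothetical function and docstring example that are not taken from the file:

import doctest


def mse_of_constant(values, prediction):
    """
    Hypothetical docstring example; the verbose run below reports it
    under the name passed to doctest.testmod.

    >>> mse_of_constant([1.0, 3.0], 2.0)
    1.0
    """
    return sum((v - prediction) ** 2 for v in values) / len(values)


if __name__ == "__main__":
    # `name` only relabels the tested module in the report output;
    # all docstring examples in the module are still collected and run.
    doctest.testmod(name="mean_squared_error", verbose=True)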