Skip to content

Commit c1fa49e

Browse files
authored
Update decision_tree.py
1 parent c0ad5bb commit c1fa49e

File tree

1 file changed

+21
-19
lines changed

1 file changed

+21
-19
lines changed

machine_learning/decision_tree.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,14 @@ def train(self, x, y):
8787
if y.ndim != 1:
8888
raise ValueError("Data set labels must be one-dimensional")
8989

90-
if len(x) < 2 * self.min_leaf_size:
91-
self.prediction = np.mean(y)
92-
return
90+
mean_y = np.mean(y)
9391

94-
if self.depth == 1:
95-
self.prediction = np.mean(y)
92+
if len(x) < 2 * self.min_leaf_size or self.depth == 1:
93+
self.prediction = mean_y
9694
return
97-
95+
9896
best_split = 0
99-
min_error = self.mean_squared_error(x, np.mean(y)) * 2
97+
min_error = self.mean_squared_error(x, mean_y) * 2
10098

10199
"""
102100
loop over all possible splits for the decision tree. find the best split.
@@ -105,17 +103,21 @@ def train(self, x, y):
105103
the predictor
106104
"""
107105
for i in range(len(x)):
108-
if len(x[:i]) < self.min_leaf_size: # noqa: SIM114
109-
continue
110-
elif len(x[i:]) < self.min_leaf_size:
106+
if len(x[:i]) < self.min_leaf_size or len(x[i:]) < self.min_leaf_size:
111107
continue
112-
else:
113-
error_left = self.mean_squared_error(x[:i], np.mean(y[:i]))
114-
error_right = self.mean_squared_error(x[i:], np.mean(y[i:]))
115-
error = error_left + error_right
116-
if error < min_error:
117-
best_split = i
118-
min_error = error
108+
109+
left_y = y[:i]
110+
right_y = y[i:]
111+
mean_left = np.mean(left_y)
112+
mean_right = np.mean(right_y)
113+
114+
error_left = self.mean_squared_error(left_y, mean_left)
115+
error_right = self.mean_squared_error(right_y, mean_right)
116+
error = error_left + error_right
117+
118+
if error < min_error:
119+
best_split = i
120+
min_error = error
119121

120122
if best_split != 0:
121123
left_x = x[:best_split]
@@ -184,7 +186,7 @@ def main():
184186
x = np.arange(-1.0, 1.0, 0.005)
185187
y = np.sin(x)
186188

187-
tree = DecisionTree(depth=10, min_leaf_size=10)
189+
tree = DecisionTree(depth=6, min_leaf_size=10)
188190
tree.train(x, y)
189191

190192
rng = np.random.default_rng()
@@ -201,4 +203,4 @@ def main():
201203
main()
202204
import doctest
203205

204-
doctest.testmod(name="mean_squarred_error", verbose=True)
206+
doctest.testmod(name="mean_squared_error", verbose=True)

0 commit comments

Comments
 (0)