1
1
from sympy_descent .helpers import *
2
2
from sympy_descent .data_loader import DataLoader
3
3
4
+
5
+ class NoThetaException (Exception ):
6
+ """No thetas found. Should start with `t`."""
7
+
8
+ class DivergentModelException (Exception ):
9
+ """Whoops"""
10
+
11
+
4
12
class Model :
5
13
6
14
# =========================================================================
@@ -37,7 +45,7 @@ def search(head):
37
45
search (self .model )
38
46
39
47
if len (thetas ) == 0 :
40
- raise Exception ( f'No Thetas found. Should start with { THETA_PREFIX } ' )
48
+ raise NoThetaException
41
49
42
50
self .thetas = sorted (tuple (thetas ), key = str )
43
51
@@ -66,29 +74,34 @@ def cost_partial(t):
66
74
67
75
# =========================================================================
68
76
def descend (self , initial_thetas ,
69
- alpha = .05 , momentum = 0.5 , threshold = .01 ):
77
+ alpha = .01 , momentum = 0.5 , threshold = .01 ):
78
+
79
+ if initial_thetas == 'rand' :
80
+ initial_thetas = list (np .random .rand (len (self .thetas )))
70
81
71
82
assert len (initial_thetas ) == len (self .thetas )
72
83
73
84
self .theta_vals = initial_thetas
74
85
75
86
deltas = np .zeros (len (self .thetas ))
76
87
77
- for i in range (999 ):
78
-
79
- self .save_weighted_model ()
88
+ for i in range (1 ,500 ):
80
89
81
- # print(i, '\t', *(f'{float(s):.4f}'.ljust(8) for s in (*self.theta_vals, self.error())))
90
+ # if not i%50:
91
+ # self.save_weighted_model()
92
+ # print(i, '\t', *(f'{float(s):.4f}'.ljust(8) for s in (*self.theta_vals, self.error())))
82
93
83
94
deltas = (momentum * deltas
84
95
+ np .array ([grad .subs (self .theta_dict ) for grad in self .grad ]))
85
96
86
- if sum (abs (deltas )) < threshold : break
97
+ if sum (abs (deltas )) < threshold :
98
+ break
87
99
88
100
self .theta_vals -= deltas * alpha
89
101
90
- else :
91
- raise Exception ('maximum iterations exceeded. something is probably wrong' )
102
+ if any (abs (self .theta_vals ) > 9999 ):
103
+ raise DivergentModelException ()
104
+
92
105
93
106
self .save_weighted_model ()
94
107
@@ -100,7 +113,6 @@ def theta_dict(self):
100
113
def save_weighted_model (self ):
101
114
self .weighted_model = simplify (self .model .subs (self .theta_dict ))
102
115
103
-
104
116
def error (self ):
105
117
return np .sum (self .weighted_model .subs (row )** 2 for row in self .data )
106
118
0 commit comments