Skip to content

Commit 12dc959

Browse files
committed
Add API for specifying activation functions for neural network layers
1 parent b40d096 commit 12dc959

File tree

3 files changed

+96
-58
lines changed

3 files changed

+96
-58
lines changed

src/lambda_ml/core.clj

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,6 @@
2828
[coll]
2929
(first (apply max-key second (frequencies coll))))
3030

31-
(defn sigmoid
32-
[z]
33-
(/ 1 (+ 1 (expt Math/E (- z)))))
34-
3531
(defn random-partition
3632
"Returns n partitions of elements randomly selected from coll."
3733
[n coll]
@@ -66,3 +62,27 @@
6662
(sample-without-replacement (subvec (assoc coll index (first coll)) 1)
6763
(dec n)
6864
(conj s (nth coll index)))))))
65+
66+
;; Common functions
67+
68+
(defn relu
69+
[z]
70+
(max 0 z))
71+
72+
(defn relu'
73+
[z]
74+
(if (> z 0) 1 0))
75+
76+
(defn sigmoid
77+
[z]
78+
(/ 1 (+ 1 (expt Math/E (- z)))))
79+
80+
(defn sigmoid'
81+
[z]
82+
(* z (- 1 z)))
83+
84+
(defn derivative
85+
[f]
86+
(cond
87+
(= f relu) relu'
88+
(= f sigmoid) sigmoid'))

src/lambda_ml/neural_network.clj

Lines changed: 60 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
```
66
(def data [[0 0 [0]] [0 1 [1]] [1 0 [1]] [1 1 [0]]])
77
(def fit
8-
(let [hidden-layers [3]
9-
alpha 0.5
10-
lambda 0.001]
11-
(-> #(neural-network-fit % data)
12-
(iterate (make-neural-network hidden-layers alpha lambda))
8+
(let [alpha 0.5
9+
lambda 0.001
10+
model (-> (make-neural-network alpha lambda)
11+
(add-neural-network-layer 2 sigmoid) ;; input layer
12+
(add-neural-network-layer 3 sigmoid) ;; hidden layer
13+
(add-neural-network-layer 1 sigmoid))] ;; output layer
14+
(-> (iterate #(neural-network-fit % data) model)
1315
(nth 5000))))
1416
(neural-network-predict fit (map butlast data))
1517
;;=> [[0.04262340225834812] [0.9582632706756758] [0.9581124103456861] [0.04103544440312673]]
@@ -29,38 +31,41 @@
2931
(defn feed-forward
3032
"Returns the activation values for nodes in a neural network after forward
3133
propagating the values of a single input example x through the network."
32-
[x theta]
33-
(reduce (fn [activations weights]
34+
[x theta fns]
35+
(reduce (fn [activations [weights f]]
3436
(let [inputs (if (empty? activations) (m/matrix x) (last activations))
3537
inputs+bias (m/join bias inputs)
36-
outputs (m/emap c/sigmoid (m/mmul weights inputs+bias))]
38+
outputs (m/emap f (m/mmul weights inputs+bias))]
3739
(conj activations outputs)))
3840
[]
39-
theta))
41+
(map vector theta fns)))
4042

4143
(defn feed-forward-batch
4244
"Returns the activation values for nodes in a neural network after forward
4345
propagating a collection of input examples x through the network."
44-
[x theta]
45-
(-> (reduce (fn [inputs weights]
46+
[x theta fns]
47+
(-> (reduce (fn [inputs [weights f]]
4648
(let [bias (m/broadcast 1.0 [1 (m/column-count inputs)])
4749
inputs+bias (m/join bias inputs)
48-
outputs (m/emap c/sigmoid (m/mmul weights inputs+bias))]
50+
outputs (m/emap f (m/mmul weights inputs+bias))]
4951
outputs))
5052
(m/transpose (m/matrix x))
51-
theta)
53+
(map vector theta fns))
5254
(m/transpose)))
5355

5456
(defn back-propagate
5557
"Returns the errors of each node in a neural network after propagating the
5658
errors at the output nodes, computed against a single target value y,
5759
backwards through the network."
58-
[y theta activations output-error]
59-
(->> (map vector (reverse (rest theta)) (reverse (butlast activations)))
60-
(reduce (fn [errors [w a]]
61-
(cons (m/mul a (m/sub 1 a) (m/mmul (first errors) (drop-bias w)))
60+
[y theta fns' activations output-error]
61+
(->> (map vector
62+
(reverse (rest theta))
63+
(reverse (butlast activations))
64+
(reverse (butlast fns')))
65+
(reduce (fn [errors [w a f]]
66+
(cons (m/mul (m/emap f a) (m/mmul (first errors) (drop-bias w)))
6267
errors))
63-
(list (output-error y (last activations))))
68+
(list (output-error y (last activations) (last fns'))))
6469
(vec)))
6570

6671
(defn compute-gradients
@@ -77,39 +82,40 @@
7782
"Returns the numeric approximations of the gradients for each weight given the
7883
input values of a single example x and label y. Used for debugging by checking
7984
against the computed gradients during backpropagation."
80-
[x y theta cost]
85+
[x y theta fns cost]
8186
(mapv (fn [k weights]
8287
(m/matrix (for [i (range (m/row-count weights))]
8388
(for [j (range (m/column-count weights))]
8489
(let [w (m/select weights i j)
8590
theta+ (assoc theta k (m/set-selection weights i j (+ w epsilon)))
8691
theta- (assoc theta k (m/set-selection weights i j (- w epsilon)))]
87-
(/ (- (cost (list x) (list y) theta+)
88-
(cost (list x) (list y) theta-))
92+
(/ (- (cost (list x) (list y) theta+ fns)
93+
(cost (list x) (list y) theta- fns))
8994
(* 2 epsilon)))))))
9095
(range)
9196
theta))
9297

9398
(defn gradient-descent-step
9499
"Performs a single gradient step on the input and target values of a single
95100
example x and label y, and returns the updated weights."
96-
[x y theta alpha lambda cost output-error]
97-
(let [activations (feed-forward x theta)
98-
errors (back-propagate y theta activations output-error)
101+
[x y theta fns alpha lambda cost output-error]
102+
(let [activations (feed-forward x theta fns)
103+
errors (back-propagate y theta (map c/derivative fns) activations output-error)
99104
gradients (compute-gradients x activations errors)
100105
regularization (map (fn [w]
101106
(-> (m/mul alpha lambda w)
102107
(m/set-column 0 (m/matrix (repeat (m/row-count w) 0)))))
103108
theta)]
104109
;; Numeric gradient checking
105-
;;(println (map (comp #(/ (m/esum %) (m/ecount %)) m/abs m/sub) gradients (numeric-gradients x y theta cost)))
110+
;;(println (map (comp #(/ (m/esum %) (m/ecount %)) m/abs m/sub) gradients (numeric-gradients x y theta fns cost)))
106111
(mapv m/sub theta (map #(m/mul % alpha) gradients) regularization)))
107112

108113
(defn gradient-descent
109114
"Performs gradient descent on input and target values of all examples x and
110115
y, and returns the updated weights."
111116
[model x y]
112-
(let [{alpha :alpha lambda :lambda theta :parameters cost :cost output-error :output-error} model]
117+
(let [{alpha :alpha lambda :lambda theta :parameters cost :cost
118+
fns :activation-fns output-error :output-error} model]
113119
(loop [inputs x
114120
targets y
115121
weights theta]
@@ -120,6 +126,7 @@
120126
(gradient-descent-step (first inputs)
121127
(first targets)
122128
weights
129+
fns
123130
alpha
124131
lambda
125132
cost
@@ -139,24 +146,25 @@
139146
;; Cost functions
140147

141148
(defn cross-entropy-cost
142-
[x y theta]
143-
(let [a (feed-forward-batch x theta)]
149+
[x y theta fns]
150+
(let [a (feed-forward-batch x theta fns)]
144151
(/ (m/esum (m/add (m/mul y (m/log a))
145152
(m/mul (m/sub 1 y) (m/log (m/sub 1 a)))))
146153
(- (count x)))))
147154

148155
(defn cross-entropy-output-error
149-
[y activations]
156+
[y activations f']
157+
;; Cross entropy error is independent of the derivative of output activation
150158
(m/sub activations y))
151159

152160
(defn quadratic-cost
153-
[x y theta]
154-
(/ (m/esum (m/square (m/sub (feed-forward-batch x theta) y)))
161+
[x y theta fns]
162+
(/ (m/esum (m/square (m/sub (feed-forward-batch x theta fns) y)))
155163
2))
156164

157165
(defn quadratic-output-error
158-
[y activations]
159-
(m/mul (m/sub activations y) activations (m/sub 1 activations)))
166+
[y activations f']
167+
(m/mul (m/sub activations y) (m/emap f' activations)))
160168

161169
;; API
162170

@@ -166,30 +174,25 @@
166174
([model data]
167175
(neural-network-fit model (map (comp vec butlast) data) (map (comp vec last) data)))
168176
([model x y]
169-
(let [{hidden :hidden layers :layers theta :parameters} model
170-
layers (or layers
171-
(concat [(count (first x))] ;; number of input nodes
172-
hidden ;; number of nodes at each hidden layer
173-
[(count (first y))])) ;; number of output nodes
177+
(let [{layers :layers theta :parameters} model
174178
model (-> model
175-
(assoc :layers layers)
176179
(assoc :parameters (or theta (init-parameters layers))))]
177180
(assoc model :parameters (gradient-descent model x y)))))
178181

179182
(defn neural-network-predict
180183
"Predicts the values of example data using a neural network model."
181184
[model x]
182-
(let [{theta :parameters} model]
185+
(let [{theta :parameters fns :activation-fns} model]
183186
(when (not (nil? theta))
184-
(mapv vec (feed-forward-batch x theta)))))
187+
(mapv vec (feed-forward-batch x theta fns)))))
185188

186189
(defn neural-network-cost
187190
([model data]
188191
(neural-network-cost model (map (comp vec butlast) data) (map (comp vec last) data)))
189192
([model x y]
190-
(let [{theta :parameters cost :cost} model]
193+
(let [{theta :parameters fns :activation-fns cost :cost} model]
191194
(when (not (nil? theta))
192-
(cost x y theta)))))
195+
(cost x y theta fns)))))
193196

194197
(defn print-neural-network
195198
"Prints information about a given neural network."
@@ -202,16 +205,23 @@
202205
(str (dec (count (first thetai))) " x " (count thetai))))))))
203206

204207
(defn make-neural-network
205-
"Returns a neural network model where alpha is the learning rate and hidden is
206-
a sequence of numbers where the ith element is the number of nodes in the ith
207-
hidden layer."
208-
([hidden alpha lambda]
209-
(make-neural-network hidden alpha lambda cross-entropy-cost))
210-
([hidden alpha lambda cost]
208+
"Returns a neural network model where alpha is the learning rate and lambda is the regularization parameter."
209+
([alpha lambda]
210+
(make-neural-network alpha lambda cross-entropy-cost))
211+
([alpha lambda cost]
211212
{:alpha alpha
212213
:lambda lambda
213-
:hidden hidden
214+
:layers []
215+
:activation-fns []
214216
:cost cost
215217
:output-error (cond
216218
(= cost cross-entropy-cost) cross-entropy-output-error
217219
(= cost quadratic-cost) quadratic-output-error)}))
220+
221+
(defn add-neural-network-layer
222+
"Adds a layer to a neural network model with n nodes and an activation
223+
function f."
224+
[model n f]
225+
(-> model
226+
(update :layers #(conj % n))
227+
(update :activation-fns #(conj % f))))

test/lambda_ml/neural_network_test.clj

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
(ns lambda-ml.neural-network-test
22
(:require [clojure.test :refer :all]
3+
[lambda-ml.core :refer :all]
34
[lambda-ml.neural-network :refer :all]))
45

56
(deftest test-feed-forward
67
(let [weights [[[0.35 0.15 0.20]
78
[0.35 0.25 0.30]]
89
[[0.60 0.40 0.45]
910
[0.60 0.50 0.55]]]
11+
fs [sigmoid sigmoid]
1012
x [0.05 0.1]
11-
[hidden output] (feed-forward x weights)]
13+
[hidden output] (feed-forward x weights fs)]
1214
(is (< (Math/abs (- 0.593269920 (first hidden))) 1E-6))
1315
(is (< (Math/abs (- 0.596884378 (second hidden))) 1E-6))
1416
(is (< (Math/abs (- 0.751365070 (first output))) 1E-6))
@@ -20,8 +22,9 @@
2022
[ 0.5 0.3 -0.4]]
2123
[[-0.1 -0.4 0.1 0.6]
2224
[ 0.6 0.2 -0.1 -0.2]]]
25+
fs [sigmoid sigmoid]
2326
x [0.6 0.1]
24-
[hidden output] (feed-forward x weights)]
27+
[hidden output] (feed-forward x weights fs)]
2528
(is (< (Math/abs (- 0.53494294 (nth hidden 0))) 1E-6))
2629
(is (< (Math/abs (- 0.55477923 (nth hidden 1))) 1E-6))
2730
(is (< (Math/abs (- 0.65475346 (nth hidden 2))) 1E-6))
@@ -33,11 +36,13 @@
3336
[0.35 0.25 0.30]]
3437
[[0.60 0.40 0.45]
3538
[0.60 0.50 0.55]]]
39+
fs [sigmoid sigmoid]
3640
x [0.05 0.1]
3741
y [0.01 0.99]
3842
alpha 0.5
3943
lambda 0
40-
[w0 w1] (gradient-descent-step x y weights alpha lambda
44+
[w0 w1] (gradient-descent-step x y weights fs
45+
alpha lambda
4146
quadratic-cost
4247
quadratic-output-error)]
4348
(is (< (Math/abs (- 0.149780716 (nth (nth w0 0) 1))) 1E-6))
@@ -54,7 +59,10 @@
5459
[0 1 [1]]
5560
[1 0 [1]]
5661
[1 1 [0]]]
57-
model (make-neural-network [3] 0.5 0.0)
62+
model (-> (make-neural-network 0.5 0.0)
63+
(add-neural-network-layer 2 sigmoid)
64+
(add-neural-network-layer 3 sigmoid)
65+
(add-neural-network-layer 1 sigmoid))
5866
fit (nth (iterate #(neural-network-fit % data) model) 5000)
5967
predictions (map first (neural-network-predict fit (map butlast data)))]
6068
(is (> 0.1 (nth predictions 0)))

0 commit comments

Comments
 (0)