This repository has been archived by the owner on Mar 11, 2021. It is now read-only.
forked from shiffman/Tensorflow-JS-Examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path classifier.js
65 lines (60 loc) · 1.77 KB
/
classifier.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
// Based on: https://github.com/tensorflow/tfjs-examples/tree/master/mnist

/**
 * Small convolutional classifier for 28x28 single-channel images, built with
 * TensorFlow.js.
 *
 * Relies on globals defined elsewhere: `tf` (TensorFlow.js), `CLASSES` (number
 * of output units of the final softmax layer) and `nf` (used only for
 * formatting the per-batch log line; presumably p5.js `nf` — confirm).
 */
class Classifier {
  /**
   * Builds the network (conv -> pool -> conv -> pool -> flatten -> softmax)
   * and compiles it with plain SGD and categorical cross-entropy.
   *
   * @param {number} [learningRate=0.15] - SGD learning rate.
   */
  constructor(learningRate = 0.15) {
    this.model = tf.sequential();
    this.model.add(tf.layers.conv2d({
      inputShape: [28, 28, 1],
      kernelSize: 5,
      filters: 8,
      strides: 1,
      activation: 'relu',
      kernelInitializer: 'VarianceScaling',
    }));
    this.model.add(tf.layers.maxPooling2d({
      poolSize: [2, 2],
      strides: [2, 2],
    }));
    this.model.add(tf.layers.conv2d({
      kernelSize: 5,
      filters: 16,
      strides: 1,
      activation: 'relu',
      kernelInitializer: 'VarianceScaling',
    }));
    this.model.add(tf.layers.maxPooling2d({
      poolSize: [2, 2],
      strides: [2, 2],
    }));
    this.model.add(tf.layers.flatten());
    this.model.add(tf.layers.dense({
      units: CLASSES,
      kernelInitializer: 'VarianceScaling',
      activation: 'softmax',
    }));
    this.model.compile({
      optimizer: tf.train.sgd(learningRate),
      loss: 'categoricalCrossentropy',
      metrics: ['accuracy'],
    });
  }

  /**
   * Makes one pass over the training data, calling `model.fit` once per
   * mini-batch and logging loss and accuracy after each batch.
   *
   * @param {object} data - Must expose `total` and
   *   `getTrainBatch(size, start)` returning `{ data, labels }` tensors.
   *   Assumes each call returns freshly allocated tensors, which are disposed
   *   here (TODO: confirm against the data class).
   * @param {number} [batchSize=100] - Examples per `fit` call.
   * @returns {Promise<void>}
   */
  async train(data, batchSize = 100) {
    // Only full batches are trained. A trailing partial batch could not be
    // reshaped to [batchSize, 28, 28, 1].
    const iterations = Math.floor(data.total / batchSize);
    for (let i = 0; i < iterations; i++) {
      const batch = data.getTrainBatch(batchSize, i * batchSize);
      const batchData = batch.data.reshape([batchSize, 28, 28, 1]);
      const batchLabels = batch.labels;
      try {
        const history = await this.model.fit(batchData, batchLabels, {
          batchSize,
          validationData: null,
          epochs: 1,
        });
        const loss = history.history.loss[0];
        // This code has read the 'accuracy' metric from the 'acc' key. Fall
        // back to 'accuracy' in case a tfjs version keys it by metric name.
        const accuracyHistory = history.history.acc || history.history.accuracy;
        const accuracy = accuracyHistory[0];
        console.log(`batch: ${i} loss: ${nf(loss, 2, 2)} accuracy: ${nf(accuracy, 2, 2)}`);
      } finally {
        // TF.js tensors are not garbage-collected. Free this batch's memory
        // even if fit() rejects.
        tf.dispose([batch.data, batchData, batchLabels]);
      }
    }
  }
}