-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainModel.js
79 lines (76 loc) · 2.24 KB
/
trainModel.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
// Per-digit training images. preload() fills these by index and setup()
// reads them back. The bindings themselves are never reassigned, so they are
// const; only their contents change.
const zeros = [];
const ones = [];
const twos = [];
const threes = [];
const fours = [];
const fives = [];
const sixes = [];
const sevens = [];
const eights = [];
const nines = [];
/**
 * p5.js preload hook: requests the digit training images.
 *
 * For every digit and every k in 1..samplesPerDigit, this requests
 * `data/<folder>/<stem> (k).jpg` via loadImage(). The returned image is stored
 * at index k - 1 of that digit's module-level array (`zeros` … `nines`).
 * Requests are issued interleaved by sample number (zero 1, one 1, …, nine 1,
 * zero 2, …).
 *
 * @param {number} [samplesPerDigit=140] - how many images to request per digit.
 */
function preload(samplesPerDigit = 140) {
  console.log("Loading Data...");

  // [target array, folder under data/, file-name stem], in digit order.
  const digitSets = [
    [zeros, "zeros", "zero"],
    [ones, "ones", "one"],
    [twos, "twos", "two"],
    [threes, "threes", "three"],
    [fours, "fours", "four"],
    [fives, "fives", "five"],
    [sixes, "sixes", "six"],
    [sevens, "sevens", "seven"],
    [eights, "eights", "eight"],
    [nines, "nines", "nine"],
  ];

  // File numbering on disk starts at 1; array indices start at 0.
  for (let k = 1; k <= samplesPerDigit; k++) {
    for (const [images, folder, stem] of digitSets) {
      images[k - 1] = loadImage(`data/${folder}/${stem} (${k}).jpg`);
    }
  }

  // loadImage() is asynchronous, so here the requests have only been issued,
  // not completed. Per p5.js preload semantics, setup() is expected to run
  // only after they finish.
  console.log("Data load requested...");
}
// ml5 neural network instance; created and assigned inside setup().
let nn;

// Not referenced by the code visible in this file. Presumably reserved for a
// drawing canvas and a text element — confirm before removing.
let canvas;
let textP;
/**
 * p5.js setup hook: builds an ml5 image-classification network and feeds it
 * the digit images gathered in preload(). It then trains the network and saves
 * the result.
 *
 * Writes the module-level `nn` binding. Reads the per-digit image arrays
 * (`zeros` … `nines`).
 */
function setup() {
  console.log("Init NN");
  nn = ml5.neuralNetwork({
    // 56 x 56 x 4 input; presumably RGBA images of that size — TODO confirm
    // the files under data/ really have these dimensions.
    inputs: [56, 56, 4],
    task: "imageClassification",
    debug: true,
  });
  console.log("Init NN, done!");

  // Each digit's image array paired with the label it is trained under.
  const datasets = [
    [zeros, "zero"],
    [ones, "one"],
    [twos, "two"],
    [threes, "three"],
    [fours, "four"],
    [fives, "five"],
    [sixes, "six"],
    [sevens, "seven"],
    [eights, "eight"],
    [nines, "nine"],
  ];

  // Add samples interleaved by index (zero #0, one #0, …, nine #0, zero #1, …).
  // Each array's real length is used instead of a hard-coded count, so a
  // different sample count in preload() never feeds `undefined` images.
  const longest = Math.max(0, ...datasets.map(([images]) => images.length));
  for (let i = 0; i < longest; i++) {
    for (const [images, label] of datasets) {
      if (i < images.length) {
        nn.addData({ image: images[i] }, { target: label });
      }
    }
  }
  console.log("added data");

  nn.normalizeData();
  console.log("normalized");

  console.log("Training...");
  nn.train(
    {
      epochs: 200,
      batchSize: 70,
    },
    finishedTraining
  );

  // Invoked by ml5 once training has finished.
  function finishedTraining() {
    // save() completes asynchronously; report success only from its callback.
    nn.save("model", () => console.log("Model saved!"));
  }
}