Skip to content

Commit d157514

Browse files
author
Kangyi Zhang
authored
Add a new Node.js example: abalone (#328)
* add new example * switch to cpu * address comments * address comments * add test data * update * remove nit * update
1 parent 7977a6c commit d157514

11 files changed

+1036
-0
lines changed

README.md

+11
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,17 @@ to another project.
1919
<th>Inference</th>
2020
<th>API type</th>
2121
<th>Save-load operations</th>
22+
<tr>
23+
<td><a href="./abalone-node">abalone-node</a></td>
24+
<td></td>
25+
<td>Numeric</td>
26+
<td>Loading data from local file and training in Node.js</td>
27+
<td>Multilayer perceptron</td>
28+
<td>Node.js</td>
29+
<td>Node.js</td>
30+
<td>Layers</td>
31+
<td>Saving to filesystem and loading in Node.js</td>
32+
</tr>
2233
<tr>
2334
<td><a href="./addition-rnn">addition-rnn</a></td>
2435
<td><a href="https://storage.googleapis.com/tfjs-examples/addition-rnn/dist/index.html">🔗</a></td>

abalone-node/README.md

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# TensorFlow.js Example: Abalone Age
2+
3+
This example shows how to predict the age of abalone from physical measurements in Node.js.
4+
5+
The data set is available at the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/Abalone).
6+
7+
This example shows how to
8+
* load a `Dataset` from a local csv file
9+
* prepare the Dataset for training
10+
* create a `tf.LayersModel` from scratch
11+
* train the model through `model.fitDataset()`
12+
* save the trained model to a local folder.
13+
14+
To launch the demo, run:
15+
16+
```sh
17+
yarn
18+
yarn train
19+
```
20+
21+
By default, the training uses tfjs-node, which runs on the CPU.
22+
If you have a CUDA-enabled GPU and have the CUDA and CuDNN libraries
23+
set up properly on your system, you can run the training on the GPU
24+
by replacing the tfjs-node package with tfjs-node-gpu.

abalone-node/data.js

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
const tf = require('@tensorflow/tfjs-node');
19+
20+
/**
21+
* Load a local csv file and prepare the data for training. Data source:
22+
* https://archive.ics.uci.edu/ml/datasets/Abalone
23+
*
24+
* @param {string} csvPath The path to csv file.
25+
* @returns {tf.data.Dataset} The loaded and prepared Dataset.
26+
*/
27+
async function createDataset(csvPath) {
28+
const dataset = tf.data.csv(
29+
csvPath, {hasHeader: true, columnConfigs: {'rings': {isLabel: true}}});
30+
const numOfColumns = (await dataset.columnNames()).length - 1;
31+
// Convert features and labels.
32+
return {
33+
dataset: dataset.map(row => {
34+
const rawFeatures = row['xs'];
35+
const rawLabel = row['ys'];
36+
const convertedFeatures = Object.keys(rawFeatures).map(key => {
37+
switch (rawFeatures[key]) {
38+
case 'F':
39+
return 0;
40+
case 'M':
41+
return 1;
42+
case 'I':
43+
return 2;
44+
default:
45+
return Number(rawFeatures[key]);
46+
}
47+
});
48+
const convertedLabel = [rawLabel['rings']];
49+
return {xs: convertedFeatures, ys: convertedLabel};
50+
}),
51+
numOfColumns: numOfColumns
52+
};
53+
}
54+
55+
module.exports = createDataset;

abalone-node/data_test.js

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
const tf = require('@tensorflow/tfjs-node');
19+
const createDataset = require('./data');
20+
21+
// Checks createDataset() against the small test_data.csv fixture: the number
// of feature columns, the encoding of the first row's features, and its label.
describe('Dataset', () => {
  it('Created dataset and numOfColumns', async () => {
    const {dataset, numOfColumns} =
        await createDataset('file://./test_data.csv');
    const [firstRow] = await dataset.take(1).toArray();
    expect(numOfColumns).toBe(8);

    const {xs: features, ys: label} = firstRow;
    expect(features.length).toBe(8);
    // The first feature is the encoded sex column (0, 1 or 2).
    expect([0, 1, 2].includes(features[0])).toBeTruthy();
    // The remaining seven features are expected to lie strictly in (0, 1).
    for (let col = 1; col < 8; col++) {
      const value = features[col];
      expect(value).toBeLessThan(1);
      expect(value).toBeGreaterThan(0);
    }

    const [rings] = label;
    expect(Number.isInteger(rings)).toBeTruthy();
    expect(rings).toBeGreaterThanOrEqual(2);
    expect(rings).toBeLessThanOrEqual(16);
  });
});

abalone-node/model.js

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
const tf = require('@tensorflow/tfjs-node');
19+
20+
/**
 * Builds, compiles and returns a multilayer perceptron regression model: two
 * 50-unit sigmoid dense layers followed by a single-unit dense output layer,
 * trained with SGD (learning rate 0.01) on mean squared error.
 *
 * @param {number[]} inputShape The input shape of the model, excluding the
 *     batch dimension (e.g. `[8]` for eight features).
 * @returns {tf.Sequential} The compiled multilayer perceptron regression model.
 */
function createModel(inputShape) {
  // Dense layer configurations, listed from input to output.
  const denseConfigs = [
    {inputShape: inputShape, activation: 'sigmoid', units: 50},
    {activation: 'sigmoid', units: 50},
    {units: 1},
  ];

  const mlp = tf.sequential();
  for (const config of denseConfigs) {
    mlp.add(tf.layers.dense(config));
  }
  mlp.compile({optimizer: tf.train.sgd(0.01), loss: 'meanSquaredError'});
  return mlp;
}
43+
44+
module.exports = createModel;

abalone-node/model_test.js

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
const tf = require('@tensorflow/tfjs-node');
19+
const shelljs = require('shelljs');
20+
const tmp = require('tmp');
21+
const fs = require('fs');
22+
const createModel = require('./model');
23+
24+
// Descriptor returned by tmp.dirSync() for the current test; the directory
// path itself is `tempDir.name`.
let tempDir;

// Tests for createModel(): the model's input/output shapes, that it can be
// trained, and that it survives a save/load round trip through the filesystem.
describe('Model', () => {
  beforeEach(() => {
    tempDir = tmp.dirSync();
  });

  afterEach(() => {
    // Use the directory path, not the descriptor object: passing the object
    // to fs.existsSync() always yields false, so the directory was never
    // cleaned up.
    if (fs.existsSync(tempDir.name)) {
      shelljs.rm('-rf', tempDir.name);
    }
  });

  it('Created model can train', async () => {
    const inputLength = 6;
    const outputLength = 1;
    const model = createModel([inputLength]);
    expect(model.inputs.length).toEqual(1);
    expect(model.inputs[0].shape).toEqual([null, inputLength]);
    expect(model.outputs.length).toEqual(1);
    expect(model.outputs[0].shape).toEqual([null, outputLength]);

    const numExamples = 3;
    const inputFeature = tf.ones([numExamples, inputLength]);
    const inputLabel = tf.ones([numExamples, outputLength]);
    const history = await model.fit(inputFeature, inputLabel, {epochs: 2});
    // One loss value is recorded per epoch.
    expect(history.history.loss.length).toEqual(2);
  });

  it('Model save-load roundtrip', async () => {
    const inputLength = 6;
    const model = createModel([inputLength]);
    const numExamples = 3;
    const feature = tf.ones([numExamples, inputLength]);
    const y = model.predict(feature);

    await model.save(`file://${tempDir.name}`);
    const modelPrime =
        await tf.loadLayersModel(`file://${tempDir.name}/model.json`);
    const yPrime = modelPrime.predict([feature]);
    // Compare the underlying values rather than the Tensor objects:
    // expectArraysClose is written for arrays / typed arrays, and handing it
    // tensors does not perform an element-wise comparison.
    tf.test_util.expectArraysClose(await yPrime.data(), await y.data());
  });
});

abalone-node/package.json

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
{
2+
"name": "tfjs-abalone-node",
3+
"version": "1.0.0",
4+
"description": "",
5+
"main": "index.js",
6+
"license": "Apache-2.0",
7+
"scripts": {
8+
"train": "node train.js",
9+
"test": "node run_tests.js"
10+
},
11+
"dependencies": {
12+
"@tensorflow/tfjs-node": "1.2.8",
13+
"argparse": "^1.0.10"
14+
},
15+
"devDependencies": {
16+
"jasmine": "^3.2.0",
17+
"jasmine-core": "^3.2.1",
18+
"shelljs": "^0.8.3",
19+
"tmp": "^0.0.33"
20+
}
21+
}

abalone-node/run_tests.js

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
const jasmine_util = require('@tensorflow/tfjs-core/dist/jasmine_util');
19+
const runTests = require('../test_util').runTests;
20+
21+
runTests(jasmine_util, ['./*test.js']);

abalone-node/test_data.csv

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
sex,length,diameter,height,weight.w,weight.s,weight.v,weight.sh,rings
2+
M,0.455,0.365,0.095,0.514,0.2245,0.101,0.15,15
3+
M,0.35,0.265,0.09,0.2255,0.0995,0.0485,0.07,7
4+
F,0.53,0.42,0.135,0.677,0.2565,0.1415,0.21,9

abalone-node/train.js

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
const tf = require('@tensorflow/tfjs-node');
19+
const argparse = require('argparse');
20+
const https = require('https');
21+
const fs = require('fs');
22+
const createModel = require('./model');
23+
const createDataset = require('./data');
24+
25+
26+
const csvUrl =
27+
'https://storage.googleapis.com/tfjs-examples/abalone-node/abalone.csv';
28+
const csvPath = './abalone.csv';
29+
30+
/**
 * Train a model with the dataset stored at `csvPath`, save the model to a
 * local folder, then reload it and print a prediction for one sample.
 *
 * @param {number} epochs Number of epochs to train for.
 * @param {number} batchSize Number of rows per batch.
 * @param {string} savePath Destination passed to `model.save()`, e.g.
 *     `file://trainedModel`.
 * @returns {Promise<void>} Resolves once the prediction has been logged.
 */
async function run(epochs, batchSize, savePath) {
  const datasetObj = await createDataset(`file://${csvPath}`);
  const model = createModel([datasetObj.numOfColumns]);
  // The dataset has 4177 rows. Split them into 2 groups, one for training and
  // one for validation. Take about 3500 rows (a whole number of batches) as
  // train dataset, and the rest as validation dataset.
  const trainBatches = Math.floor(3500 / batchSize);
  const numTrainRows = trainBatches * batchSize;
  // Split the rows BEFORE shuffling. shuffle() produces a new order every
  // time the dataset is iterated, so taking/skipping batches of an
  // already-shuffled dataset lets validation rows leak into the training set
  // (and vice versa) from one pass to the next.
  const trainDataset =
      datasetObj.dataset.take(numTrainRows).shuffle(1000).batch(batchSize);
  const validationDataset =
      datasetObj.dataset.skip(numTrainRows).batch(batchSize);

  await model.fitDataset(
      trainDataset, {epochs: epochs, validationData: validationDataset});

  await model.save(savePath);

  // Reload the saved model to demonstrate the save/load round trip.
  const loadedModel = await tf.loadLayersModel(savePath + '/model.json');
  const result = loadedModel.predict(
      tf.tensor2d([[0, 0.625, 0.495, 0.165, 1.262, 0.507, 0.318, 0.39]]));
  console.log(
      'The actual test abalone age is 10, the inference result from the model is ' +
      result.dataSync());
}
56+
57+
// Command-line interface of the training script.
const parser = new argparse.ArgumentParser({
  description: 'TensorFlow.js-Node Abalone Example.',
  addHelp: true,
});

// [flag, options] pairs, registered on the parser in the order listed.
const cliOptions = [
  [
    '--epochs',
    {
      type: 'int',
      defaultValue: 100,
      help: 'Number of epochs to train the model for.',
    },
  ],
  [
    '--batch_size',
    {
      type: 'int',
      defaultValue: 500,
      help: 'Batch size to be used during model training.',
    },
  ],
  [
    '--savePath',
    {type: 'string', defaultValue: 'file://trainedModel', help: 'Path.'},
  ],
];
for (const [flag, options] of cliOptions) {
  parser.addArgument(flag, options);
}

const args = parser.parseArgs();
73+
74+
75+
// Download the abalone csv to `csvPath`, then start training once the file
// has been completely written and closed.
const file = fs.createWriteStream(csvPath);

// Log a fatal error and make the process finish with a non-zero exit code.
const reportFailure = err => {
  console.error(err);
  process.exitCode = 1;
};

file.on('error', reportFailure);
https.get(csvUrl, response => {
  if (response.statusCode !== 200) {
    // Do not train on an error page: drain the response so the socket is
    // released, close the output file and report the failure.
    response.resume();
    file.close();
    reportFailure(new Error(
        `Failed to download ${csvUrl}: HTTP status ${response.statusCode}`));
    return;
  }
  response.pipe(file).on('close', () => {
    // run() is async; surface its rejection instead of leaving the promise
    // floating.
    run(args.epochs, args.batch_size, args.savePath).catch(reportFailure);
  });
}).on('error', err => {
  // Network-level failure (DNS, connection reset, ...).
  file.close();
  reportFailure(err);
});

0 commit comments

Comments
 (0)