/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

const tf = require('@tensorflow/tfjs-node');
const argparse = require('argparse');
const https = require('https');
const fs = require('fs');
const createModel = require('./model');
const createDataset = require('./data');


const csvUrl =
    'https://storage.googleapis.com/tfjs-examples/abalone-node/abalone.csv';
const csvPath = './abalone.csv';

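// NOTE: `createDataset` lives in ./data.js, which is not included in this
// diff. The commented block below is only an illustrative sketch of what such
// a helper could look like; it assumes tf.data.csv and a `rings` label column
// and is not the actual ./data.js:
//
//   async function createDataset(csvPath) {
//     // Parse the CSV and treat the `rings` column as the label.
//     const csvDataset = tf.data.csv(
//         csvPath, {columnConfigs: {rings: {isLabel: true}}});
//     // Every column except the label is a feature.
//     const numOfColumns = (await csvDataset.columnNames()).length - 1;
//     return {
//       // Flatten the {xs, ys} objects into plain arrays for fitDataset().
//       dataset: csvDataset.map(
//           ({xs, ys}) => ({xs: Object.values(xs), ys: Object.values(ys)})),
//       numOfColumns
//     };
//   }
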
/**
 * Train a model with the dataset, then save the model to a local folder.
 */
async function run(epochs, batchSize, savePath) {
  const datasetObj = await createDataset('file://' + csvPath);
  const model = createModel([datasetObj.numOfColumns]);
  // The dataset has 4177 rows. Split them into 2 groups, one for training and
  // one for validation. Take about 3500 rows as the training dataset, and the
  // rest as the validation dataset.
  const trainBatches = Math.floor(3500 / batchSize);
  const dataset = datasetObj.dataset.shuffle(1000).batch(batchSize);
  const trainDataset = dataset.take(trainBatches);
  const validationDataset = dataset.skip(trainBatches);

  await model.fitDataset(
      trainDataset, {epochs: epochs, validationData: validationDataset});

  await model.save(savePath);

  const loadedModel = await tf.loadLayersModel(savePath + '/model.json');
  const result = loadedModel.predict(
      tf.tensor2d([[0, 0.625, 0.495, 0.165, 1.262, 0.507, 0.318, 0.39]]));
  console.log(
      'The actual test abalone age is 10, the inference result from the model is ' +
      result.dataSync());
}
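
// NOTE: `createModel` comes from ./model.js, which is also not shown here.
// As a rough, assumed sketch (not the actual file), a small regression model
// with a single output unit and mean-squared-error loss could look like:
//
//   function createModel(inputShape) {
//     const model = tf.sequential();
//     // One small hidden layer followed by a single linear output unit.
//     model.add(tf.layers.dense(
//         {inputShape: inputShape, units: 10, activation: 'sigmoid'}));
//     model.add(tf.layers.dense({units: 1}));
//     model.compile({loss: 'meanSquaredError', optimizer: tf.train.sgd(0.01)});
//     return model;
//   }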

const parser = new argparse.ArgumentParser(
    {description: 'TensorFlow.js-Node Abalone Example.', addHelp: true});
parser.addArgument('--epochs', {
  type: 'int',
  defaultValue: 100,
  help: 'Number of epochs to train the model for.'
});
parser.addArgument('--batch_size', {
  type: 'int',
  defaultValue: 500,
  help: 'Batch size to be used during model training.'
});
parser.addArgument('--savePath', {
  type: 'string',
  defaultValue: 'file://trainedModel',
  help: 'Path to which the trained model will be saved.'
});
const args = parser.parseArgs();


// Download the CSV file, then kick off training once the file has been fully
// written to disk.
const file = fs.createWriteStream(csvPath);
https.get(csvUrl, function(response) {
  response.pipe(file).on('close', async () => {
    run(args.epochs, args.batch_size, args.savePath);
  });
});
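
// Example invocation (assuming this script is saved as train.js; flag values
// are illustrative, defaults are defined above):
//
//   node train.js --epochs 50 --batch_size 256 --savePath file://./myModel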