Skip to content

Commit e813403

Browse files
committedOct 27, 2020
Implement ImageClassifier training
1 parent 1317233 commit e813403

File tree

4 files changed

+204
-12
lines changed

4 files changed

+204
-12
lines changed
 

‎.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ libs
1313
*.dylib
1414
*.dll
1515
*.class
16+
datasets/
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
5+
* with the License. A copy of the License is located at
6+
*
7+
* http://aws.amazon.com/apache2.0/
8+
*
9+
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
10+
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
11+
* and limitations under the License.
12+
*/
13+
package ai.djl.jsr381.classification;
14+
15+
import ai.djl.util.ZipUtils;
16+
import java.awt.image.BufferedImage;
17+
import java.io.File;
18+
import java.io.IOException;
19+
import java.io.InputStream;
20+
import java.net.URL;
21+
import java.nio.file.Files;
22+
import java.nio.file.Path;
23+
import java.nio.file.Paths;
24+
import java.util.Map;
25+
import javax.visrec.ml.ClassifierCreationException;
26+
import javax.visrec.ml.classification.ImageClassifier;
27+
import javax.visrec.ml.classification.NeuralNetImageClassifier;
28+
29+
public class CatDogRecognition {
30+
31+
public static void main(String[] args) throws IOException, ClassifierCreationException {
32+
File trainingFile = downloadTrainingData();
33+
Path modelDir = Paths.get("build/model");
34+
35+
ImageClassifier<BufferedImage> classifier =
36+
NeuralNetImageClassifier.builder()
37+
.inputClass(BufferedImage.class)
38+
.imageHeight(128)
39+
.imageWidth(128)
40+
.trainingFile(trainingFile)
41+
.exportModel(modelDir)
42+
.maxEpochs(20)
43+
.build();
44+
45+
File input = new File(trainingFile, "cat/cat_1.png");
46+
Map<String, Float> result = classifier.classify(input);
47+
for (Map.Entry<String, Float> entry : result.entrySet()) {
48+
System.out.println(entry.getKey() + ": " + entry.getValue());
49+
}
50+
}
51+
52+
private static File downloadTrainingData() throws IOException {
53+
String link =
54+
"https://github.com/JavaVisRec/jsr381-examples-datasets/raw/master/cats_and_dogs_training_data_png.zip";
55+
URL url = new URL(link);
56+
Path dir = Paths.get("datasets", "cats_and_dogs");
57+
if (!Files.exists(dir)) {
58+
Files.createDirectories(dir);
59+
try (InputStream is = url.openStream()) {
60+
ZipUtils.unzip(is, dir);
61+
}
62+
}
63+
return dir.resolve("training").toFile();
64+
}
65+
}

‎jsr381/src/main/java/ai/djl/jsr381/spi/DjlImageClassifierFactory.java

+97-12
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,36 @@
22

33
import ai.djl.MalformedModelException;
44
import ai.djl.Model;
5+
import ai.djl.basicdataset.ImageFolder;
6+
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
57
import ai.djl.jsr381.classification.SimpleImageClassifier;
8+
import ai.djl.metric.Metrics;
69
import ai.djl.modality.Classifications;
710
import ai.djl.modality.cv.Image;
811
import ai.djl.modality.cv.Image.Flag;
912
import ai.djl.modality.cv.transform.CenterCrop;
1013
import ai.djl.modality.cv.transform.Resize;
1114
import ai.djl.modality.cv.transform.ToTensor;
1215
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
16+
import ai.djl.ndarray.NDArray;
17+
import ai.djl.ndarray.types.Shape;
18+
import ai.djl.nn.Block;
1319
import ai.djl.repository.zoo.ZooModel;
14-
import ai.djl.translate.Pipeline;
20+
import ai.djl.training.DefaultTrainingConfig;
21+
import ai.djl.training.EasyTrain;
22+
import ai.djl.training.Trainer;
23+
import ai.djl.training.dataset.Batch;
24+
import ai.djl.training.dataset.RandomAccessDataset;
25+
import ai.djl.training.evaluator.Accuracy;
26+
import ai.djl.training.listener.TrainingListener;
27+
import ai.djl.training.loss.Loss;
28+
import ai.djl.translate.TranslateException;
1529
import ai.djl.translate.Translator;
1630
import java.awt.image.BufferedImage;
31+
import java.io.File;
1732
import java.io.IOException;
1833
import java.nio.file.Path;
34+
import java.util.List;
1935
import javax.visrec.ml.ClassifierCreationException;
2036
import javax.visrec.ml.classification.ImageClassifier;
2137
import javax.visrec.ml.classification.NeuralNetImageClassifier;
@@ -38,34 +54,103 @@ public ImageClassifier<BufferedImage> create(
3854
throws ClassifierCreationException {
3955
int width = block.getImageWidth();
4056
int height = block.getImageHeight();
41-
Flag flag = width < 50 ? Flag.GRAYSCALE : Flag.COLOR;
57+
58+
Model model = Model.newInstance("imageClassifier"); // TODO generate better model name
59+
ZooModel<Image, Classifications> zooModel;
4260

4361
Path modelPath = block.getImportPath();
4462
if (modelPath != null) {
4563
// load pre-trained model from model zoo
4664
logger.info("Loading pre-trained model ...");
4765

4866
try {
49-
Pipeline pipeline = new Pipeline();
50-
pipeline.add(new CenterCrop()).add(new Resize(width, height)).add(new ToTensor());
67+
model.load(modelPath);
68+
Flag flag = width < 50 ? Flag.GRAYSCALE : Flag.COLOR;
5169
Translator<Image, Classifications> translator =
5270
ImageClassificationTranslator.builder()
5371
.optFlag(flag)
54-
.setPipeline(pipeline)
72+
.addTransform(new CenterCrop())
73+
.addTransform(new Resize(width, height))
74+
.addTransform(new ToTensor())
5575
.optSynsetArtifactName("synset.txt")
5676
.optApplySoftmax(true)
5777
.build();
58-
59-
Model model =
60-
Model.newInstance("imageClassifier"); // TODO generate better model name
61-
model.load(modelPath);
62-
ZooModel<Image, Classifications> zooModel = new ZooModel<>(model, translator);
63-
return new SimpleImageClassifier(zooModel, 5);
78+
zooModel = new ZooModel<>(model, translator);
6479
} catch (MalformedModelException | IOException e) {
6580
throw new ClassifierCreationException("Failed load model from model zoo.", e);
6681
}
82+
} else {
83+
try {
84+
zooModel = trainWithResnet(model, block);
85+
} catch (IOException | TranslateException e) {
86+
throw new ClassifierCreationException("Failed train model.", e);
87+
}
6788
}
89+
return new SimpleImageClassifier(zooModel, 5);
90+
}
91+
92+
private ZooModel<Image, Classifications> trainWithResnet(
93+
Model model, NeuralNetImageClassifier.BuildingBlock<BufferedImage> block)
94+
throws IOException, TranslateException {
95+
int width = block.getImageWidth();
96+
int height = block.getImageHeight();
97+
int epochs = block.getMaxEpochs();
98+
int batch = 1;
99+
100+
File trainingFile = block.getTrainingFile();
101+
if (trainingFile == null) {
102+
throw new IllegalArgumentException("TrainingFile is required.");
103+
}
104+
ImageFolder dataset =
105+
ImageFolder.builder()
106+
.setSampling(batch, true)
107+
.setRepositoryPath(trainingFile.toPath())
108+
.addTransform(new CenterCrop(width, height))
109+
.addTransform(new Resize(width, height))
110+
.addTransform(new ToTensor())
111+
.build();
112+
113+
RandomAccessDataset[] set = dataset.randomSplit(9, 1);
114+
115+
List<String> synset = dataset.getSynset();
116+
117+
Block resNet18 =
118+
ResNetV1.builder()
119+
.setImageShape(new Shape(3, width, height))
120+
.setNumLayers(18)
121+
.setOutSize(synset.size())
122+
.build();
123+
model.setBlock(resNet18);
124+
125+
Path exportDir = block.getExportPath();
126+
// setup training configuration
127+
DefaultTrainingConfig config =
128+
new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
129+
.addEvaluator(new Accuracy())
130+
.addTrainingListeners(TrainingListener.Defaults.logging());
131+
132+
try (Trainer trainer = model.newTrainer(config)) {
133+
trainer.setMetrics(new Metrics());
134+
// initialize trainer with proper input shape
135+
trainer.initialize(new Shape(1, 3, width, height));
136+
EasyTrain.fit(trainer, epochs, set[0], set[1]);
137+
}
138+
139+
if (exportDir != null) {
140+
model.save(exportDir, model.getName());
141+
}
142+
143+
Batch b = dataset.getData(model.getNDManager()).iterator().next();
144+
NDArray array = b.getData().singletonOrThrow();
68145

69-
return null;
146+
Translator<Image, Classifications> translator =
147+
ImageClassificationTranslator.builder()
148+
.addTransform(new CenterCrop(width, height))
149+
.addTransform(new Resize(width, height))
150+
.addTransform(new ToTensor())
151+
.optSynset(synset)
152+
.optApplySoftmax(true)
153+
.build();
154+
return new ZooModel<>(model, translator);
70155
}
71156
}

‎jsr381/src/test/java/ai/djl/jsr381/classification/ImageClassifierTest.java

+41
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,13 @@
1212
*/
1313
package ai.djl.jsr381.classification;
1414

15+
import ai.djl.util.ZipUtils;
1516
import java.awt.image.BufferedImage;
1617
import java.io.File;
18+
import java.io.IOException;
19+
import java.io.InputStream;
1720
import java.net.URL;
21+
import java.nio.file.Files;
1822
import java.nio.file.Path;
1923
import java.nio.file.Paths;
2024
import java.util.Map;
@@ -46,4 +50,41 @@ public void testImageClassifier() throws ClassifierCreationException, Classifica
4650
System.out.println(entry.getKey() + ": " + entry.getValue());
4751
}
4852
}
53+
54+
@Test
55+
public void testImageClassifierTraining()
56+
throws ClassifierCreationException, ClassificationException, IOException {
57+
File trainingFile = downloadTrainingData();
58+
Path modelDir = Paths.get("build/model");
59+
60+
ImageClassifier<BufferedImage> classifier =
61+
NeuralNetImageClassifier.builder()
62+
.inputClass(BufferedImage.class)
63+
.imageHeight(128)
64+
.imageWidth(128)
65+
.trainingFile(trainingFile)
66+
.exportModel(modelDir)
67+
.maxEpochs(2)
68+
.build();
69+
70+
File input = new File(trainingFile, "cat/cat_1.png");
71+
Map<String, Float> result = classifier.classify(input);
72+
for (Map.Entry<String, Float> entry : result.entrySet()) {
73+
System.out.println(entry.getKey() + ": " + entry.getValue());
74+
}
75+
}
76+
77+
private File downloadTrainingData() throws IOException {
78+
String link =
79+
"https://github.com/JavaVisRec/jsr381-examples-datasets/raw/master/cats_and_dogs_training_data_png.zip";
80+
URL url = new URL(link);
81+
Path dir = Paths.get("datasets", "cats_and_dogs");
82+
if (!Files.exists(dir)) {
83+
Files.createDirectories(dir);
84+
try (InputStream is = url.openStream()) {
85+
ZipUtils.unzip(is, dir);
86+
}
87+
}
88+
return dir.resolve("training").toFile();
89+
}
4990
}

0 commit comments

Comments
 (0)
Please sign in to comment.