2
2
3
3
import ai .djl .MalformedModelException ;
4
4
import ai .djl .Model ;
5
+ import ai .djl .basicdataset .ImageFolder ;
6
+ import ai .djl .basicmodelzoo .cv .classification .ResNetV1 ;
5
7
import ai .djl .jsr381 .classification .SimpleImageClassifier ;
8
+ import ai .djl .metric .Metrics ;
6
9
import ai .djl .modality .Classifications ;
7
10
import ai .djl .modality .cv .Image ;
8
11
import ai .djl .modality .cv .Image .Flag ;
9
12
import ai .djl .modality .cv .transform .CenterCrop ;
10
13
import ai .djl .modality .cv .transform .Resize ;
11
14
import ai .djl .modality .cv .transform .ToTensor ;
12
15
import ai .djl .modality .cv .translator .ImageClassificationTranslator ;
16
+ import ai .djl .ndarray .NDArray ;
17
+ import ai .djl .ndarray .types .Shape ;
18
+ import ai .djl .nn .Block ;
13
19
import ai .djl .repository .zoo .ZooModel ;
14
- import ai .djl .translate .Pipeline ;
20
+ import ai .djl .training .DefaultTrainingConfig ;
21
+ import ai .djl .training .EasyTrain ;
22
+ import ai .djl .training .Trainer ;
23
+ import ai .djl .training .dataset .Batch ;
24
+ import ai .djl .training .dataset .RandomAccessDataset ;
25
+ import ai .djl .training .evaluator .Accuracy ;
26
+ import ai .djl .training .listener .TrainingListener ;
27
+ import ai .djl .training .loss .Loss ;
28
+ import ai .djl .translate .TranslateException ;
15
29
import ai .djl .translate .Translator ;
16
30
import java .awt .image .BufferedImage ;
31
+ import java .io .File ;
17
32
import java .io .IOException ;
18
33
import java .nio .file .Path ;
34
+ import java .util .List ;
19
35
import javax .visrec .ml .ClassifierCreationException ;
20
36
import javax .visrec .ml .classification .ImageClassifier ;
21
37
import javax .visrec .ml .classification .NeuralNetImageClassifier ;
@@ -38,34 +54,103 @@ public ImageClassifier<BufferedImage> create(
38
54
throws ClassifierCreationException {
39
55
int width = block .getImageWidth ();
40
56
int height = block .getImageHeight ();
41
- Flag flag = width < 50 ? Flag .GRAYSCALE : Flag .COLOR ;
57
+
58
+ Model model = Model .newInstance ("imageClassifier" ); // TODO generate better model name
59
+ ZooModel <Image , Classifications > zooModel ;
42
60
43
61
Path modelPath = block .getImportPath ();
44
62
if (modelPath != null ) {
45
63
// load pre-trained model from model zoo
46
64
logger .info ("Loading pre-trained model ..." );
47
65
48
66
try {
49
- Pipeline pipeline = new Pipeline ( );
50
- pipeline . add ( new CenterCrop ()). add ( new Resize ( width , height )). add ( new ToTensor ()) ;
67
+ model . load ( modelPath );
68
+ Flag flag = width < 50 ? Flag . GRAYSCALE : Flag . COLOR ;
51
69
Translator <Image , Classifications > translator =
52
70
ImageClassificationTranslator .builder ()
53
71
.optFlag (flag )
54
- .setPipeline (pipeline )
72
+ .addTransform (new CenterCrop ())
73
+ .addTransform (new Resize (width , height ))
74
+ .addTransform (new ToTensor ())
55
75
.optSynsetArtifactName ("synset.txt" )
56
76
.optApplySoftmax (true )
57
77
.build ();
58
-
59
- Model model =
60
- Model .newInstance ("imageClassifier" ); // TODO generate better model name
61
- model .load (modelPath );
62
- ZooModel <Image , Classifications > zooModel = new ZooModel <>(model , translator );
63
- return new SimpleImageClassifier (zooModel , 5 );
78
+ zooModel = new ZooModel <>(model , translator );
64
79
} catch (MalformedModelException | IOException e ) {
65
80
throw new ClassifierCreationException ("Failed load model from model zoo." , e );
66
81
}
82
+ } else {
83
+ try {
84
+ zooModel = trainWithResnet (model , block );
85
+ } catch (IOException | TranslateException e ) {
86
+ throw new ClassifierCreationException ("Failed train model." , e );
87
+ }
67
88
}
89
+ return new SimpleImageClassifier (zooModel , 5 );
90
+ }
91
+
92
+ private ZooModel <Image , Classifications > trainWithResnet (
93
+ Model model , NeuralNetImageClassifier .BuildingBlock <BufferedImage > block )
94
+ throws IOException , TranslateException {
95
+ int width = block .getImageWidth ();
96
+ int height = block .getImageHeight ();
97
+ int epochs = block .getMaxEpochs ();
98
+ int batch = 1 ;
99
+
100
+ File trainingFile = block .getTrainingFile ();
101
+ if (trainingFile == null ) {
102
+ throw new IllegalArgumentException ("TrainingFile is required." );
103
+ }
104
+ ImageFolder dataset =
105
+ ImageFolder .builder ()
106
+ .setSampling (batch , true )
107
+ .setRepositoryPath (trainingFile .toPath ())
108
+ .addTransform (new CenterCrop (width , height ))
109
+ .addTransform (new Resize (width , height ))
110
+ .addTransform (new ToTensor ())
111
+ .build ();
112
+
113
+ RandomAccessDataset [] set = dataset .randomSplit (9 , 1 );
114
+
115
+ List <String > synset = dataset .getSynset ();
116
+
117
+ Block resNet18 =
118
+ ResNetV1 .builder ()
119
+ .setImageShape (new Shape (3 , width , height ))
120
+ .setNumLayers (18 )
121
+ .setOutSize (synset .size ())
122
+ .build ();
123
+ model .setBlock (resNet18 );
124
+
125
+ Path exportDir = block .getExportPath ();
126
+ // setup training configuration
127
+ DefaultTrainingConfig config =
128
+ new DefaultTrainingConfig (Loss .softmaxCrossEntropyLoss ())
129
+ .addEvaluator (new Accuracy ())
130
+ .addTrainingListeners (TrainingListener .Defaults .logging ());
131
+
132
+ try (Trainer trainer = model .newTrainer (config )) {
133
+ trainer .setMetrics (new Metrics ());
134
+ // initialize trainer with proper input shape
135
+ trainer .initialize (new Shape (1 , 3 , width , height ));
136
+ EasyTrain .fit (trainer , epochs , set [0 ], set [1 ]);
137
+ }
138
+
139
+ if (exportDir != null ) {
140
+ model .save (exportDir , model .getName ());
141
+ }
142
+
143
+ Batch b = dataset .getData (model .getNDManager ()).iterator ().next ();
144
+ NDArray array = b .getData ().singletonOrThrow ();
68
145
69
- return null ;
146
+ Translator <Image , Classifications > translator =
147
+ ImageClassificationTranslator .builder ()
148
+ .addTransform (new CenterCrop (width , height ))
149
+ .addTransform (new Resize (width , height ))
150
+ .addTransform (new ToTensor ())
151
+ .optSynset (synset )
152
+ .optApplySoftmax (true )
153
+ .build ();
154
+ return new ZooModel <>(model , translator );
70
155
}
71
156
}
0 commit comments