27
27
import ai .djl .translate .Translator ;
28
28
import ai .djl .translate .TranslatorContext ;
29
29
import java .io .IOException ;
30
- import javax .visrec .ml .ClassifierCreationException ;
31
30
import javax .visrec .ml .classification .BinaryClassifier ;
32
31
import javax .visrec .ml .classification .NeuralNetBinaryClassifier ;
32
+ import javax .visrec .ml .model .ModelCreationException ;
33
33
import javax .visrec .spi .BinaryClassifierFactory ;
34
34
35
35
public class DjlBinaryClassifierFactory implements BinaryClassifierFactory <float []> {
@@ -41,7 +41,7 @@ public Class<float[]> getTargetClass() {
41
41
42
42
@ Override
43
43
public BinaryClassifier <float []> create (NeuralNetBinaryClassifier .BuildingBlock <float []> block )
44
- throws ClassifierCreationException {
44
+ throws ModelCreationException {
45
45
int inputSize = block .getInputsNum ();
46
46
int [] hiddenLayers = block .getHiddenLayers ();
47
47
int epochs = block .getMaxEpochs ();
@@ -62,12 +62,12 @@ public BinaryClassifier<float[]> create(NeuralNetBinaryClassifier.BuildingBlock<
62
62
try {
63
63
CsvDataset csv =
64
64
CsvDataset .builder ()
65
- .setCsvFile (block .getTrainingFile ())
65
+ .setCsvFile (block .getTrainingPath ())
66
66
.setSampling (batchSize , true )
67
67
.build ();
68
68
dataset = csv .randomSplit (8 , 2 );
69
69
} catch (IOException | TranslateException e ) {
70
- throw new ClassifierCreationException ("Failed to load dataset." , e );
70
+ throw new ModelCreationException ("Failed to load dataset." , e );
71
71
}
72
72
73
73
// setup training configuration
@@ -97,7 +97,7 @@ public BinaryClassifier<float[]> create(NeuralNetBinaryClassifier.BuildingBlock<
97
97
trainer .notifyListeners (listener -> listener .onEpoch (trainer ));
98
98
}
99
99
} catch (IOException | TranslateException e ) {
100
- throw new ClassifierCreationException ("Failed to process dataset." , e );
100
+ throw new ModelCreationException ("Failed to process dataset." , e );
101
101
}
102
102
103
103
return new SimpleBinaryClassifier (new ZooModel <>(model , new BinaryClassifierTranslator ()));
0 commit comments