Skip to content

Commit 9f9880f

Browse files
committed
Upgrade to DJL 0.17.0 and JSR-381 1.0.5
1 parent e813403 commit 9f9880f

19 files changed

+83
-98
lines changed

build.gradle

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
allprojects {
22
repositories {
3-
jcenter()
3+
mavenCentral()
44
maven {
55
url "https://oss.sonatype.org/content/repositories/snapshots/"
66
}

examples/build.gradle

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ plugins {
55

66
dependencies {
77
implementation project(":jsr381")
8-
implementation "org.slf4j:slf4j-simple:1.7.30"
8+
implementation "org.slf4j:slf4j-simple:1.7.36"
99

10-
testImplementation("org.testng:testng:6.14.3") {
10+
testImplementation('org.testng:testng:7.6.0') {
1111
exclude group: "junit", module: "junit"
1212
}
1313
}
@@ -18,7 +18,7 @@ test {
1818
}
1919

2020
application {
21-
mainClassName = System.getProperty("main", "ai.djl.jsr381.detection.ObjectDetectorExample")
21+
mainClass = System.getProperty("main", "ai.djl.jsr381.detection.ObjectDetectorExample")
2222
}
2323

2424
run {

examples/src/main/java/ai/djl/jsr381/classification/BinaryClassifierExample.java

+6-7
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,23 @@
1212
*/
1313
package ai.djl.jsr381.classification;
1414

15-
import java.io.File;
16-
import javax.visrec.ml.ClassificationException;
17-
import javax.visrec.ml.ClassifierCreationException;
15+
import java.nio.file.Path;
16+
import java.nio.file.Paths;
1817
import javax.visrec.ml.classification.BinaryClassifier;
1918
import javax.visrec.ml.classification.NeuralNetBinaryClassifier;
19+
import javax.visrec.ml.model.ModelCreationException;
2020

2121
public class BinaryClassifierExample {
2222

23-
public static void main(String[] args)
24-
throws ClassificationException, ClassifierCreationException {
25-
File trainingFile = new File("../jsr381/src/test/resources/spam.csv");
23+
public static void main(String[] args) throws ModelCreationException {
24+
Path trainingFile = Paths.get("../jsr381/src/test/resources/spam.csv");
2625
BinaryClassifier<float[]> spamClassifier =
2726
NeuralNetBinaryClassifier.builder()
2827
.inputClass(float[].class)
2928
.inputsNum(57)
3029
.hiddenLayers(5)
3130
.maxEpochs(2)
32-
.trainingFile(trainingFile)
31+
.trainingPath(trainingFile)
3332
.build();
3433

3534
// create test email feature

examples/src/main/java/ai/djl/jsr381/classification/CatDogRecognition.java

+6-7
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,21 @@
1414

1515
import ai.djl.util.ZipUtils;
1616
import java.awt.image.BufferedImage;
17-
import java.io.File;
1817
import java.io.IOException;
1918
import java.io.InputStream;
2019
import java.net.URL;
2120
import java.nio.file.Files;
2221
import java.nio.file.Path;
2322
import java.nio.file.Paths;
2423
import java.util.Map;
25-
import javax.visrec.ml.ClassifierCreationException;
2624
import javax.visrec.ml.classification.ImageClassifier;
2725
import javax.visrec.ml.classification.NeuralNetImageClassifier;
26+
import javax.visrec.ml.model.ModelCreationException;
2827

2928
public class CatDogRecognition {
3029

31-
public static void main(String[] args) throws IOException, ClassifierCreationException {
32-
File trainingFile = downloadTrainingData();
30+
public static void main(String[] args) throws IOException, ModelCreationException {
31+
Path trainingFile = downloadTrainingData();
3332
Path modelDir = Paths.get("build/model");
3433

3534
ImageClassifier<BufferedImage> classifier =
@@ -42,14 +41,14 @@ public static void main(String[] args) throws IOException, ClassifierCreationExc
4241
.maxEpochs(20)
4342
.build();
4443

45-
File input = new File(trainingFile, "cat/cat_1.png");
44+
Path input = trainingFile.resolve("cat/cat_1.png");
4645
Map<String, Float> result = classifier.classify(input);
4746
for (Map.Entry<String, Float> entry : result.entrySet()) {
4847
System.out.println(entry.getKey() + ": " + entry.getValue());
4948
}
5049
}
5150

52-
private static File downloadTrainingData() throws IOException {
51+
private static Path downloadTrainingData() throws IOException {
5352
String link =
5453
"https://github.com/JavaVisRec/jsr381-examples-datasets/raw/master/cats_and_dogs_training_data_png.zip";
5554
URL url = new URL(link);
@@ -60,6 +59,6 @@ private static File downloadTrainingData() throws IOException {
6059
ZipUtils.unzip(is, dir);
6160
}
6261
}
63-
return dir.resolve("training").toFile();
62+
return dir.resolve("training");
6463
}
6564
}

examples/src/main/java/ai/djl/jsr381/classification/ImageClassifierExample.java

+3-6
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,17 @@
1313
package ai.djl.jsr381.classification;
1414

1515
import java.awt.image.BufferedImage;
16-
import java.io.File;
1716
import java.nio.file.Path;
1817
import java.nio.file.Paths;
1918
import java.util.Map;
20-
import javax.visrec.ml.ClassificationException;
21-
import javax.visrec.ml.ClassifierCreationException;
2219
import javax.visrec.ml.classification.ImageClassifier;
2320
import javax.visrec.ml.classification.NeuralNetImageClassifier;
21+
import javax.visrec.ml.model.ModelCreationException;
2422

2523
public class ImageClassifierExample {
2624

27-
public static void main(String[] args)
28-
throws ClassifierCreationException, ClassificationException {
29-
File input = new File("../jsr381/src/test/resources/0.png");
25+
public static void main(String[] args) throws ModelCreationException {
26+
Path input = Paths.get("../jsr381/src/test/resources/0.png");
3027

3128
// use pre-trained mlp model
3229
Path modelDir = Paths.get("../jsr381/src/test/resources/mlp");

examples/src/main/java/ai/djl/jsr381/detection/ObjectDetectorExample.java

+3-6
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,26 @@
1717
import ai.djl.modality.cv.Image;
1818
import ai.djl.modality.cv.output.DetectedObjects;
1919
import ai.djl.repository.zoo.Criteria;
20-
import ai.djl.repository.zoo.ModelZoo;
2120
import ai.djl.repository.zoo.ZooModel;
2221
import java.awt.image.BufferedImage;
2322
import java.io.IOException;
2423
import java.net.URL;
2524
import java.util.List;
2625
import java.util.Map;
2726
import javax.imageio.ImageIO;
28-
import javax.visrec.ml.ClassificationException;
29-
import javax.visrec.util.BoundingBox;
27+
import javax.visrec.ml.detection.BoundingBox;
3028

3129
public class ObjectDetectorExample {
3230

33-
public static void main(String[] args)
34-
throws ClassificationException, IOException, ModelException {
31+
public static void main(String[] args) throws IOException, ModelException {
3532
Criteria<Image, DetectedObjects> criteria =
3633
Criteria.builder()
3734
.setTypes(Image.class, DetectedObjects.class)
3835
.optApplication(Application.CV.OBJECT_DETECTION)
3936
.optArtifactId("yolo")
4037
.optArgument("threshold", 0.3)
4138
.build();
42-
try (ZooModel<Image, DetectedObjects> model = ModelZoo.loadModel(criteria)) {
39+
try (ZooModel<Image, DetectedObjects> model = criteria.loadModel()) {
4340
SimpleObjectDetector objectDetector = new SimpleObjectDetector(model);
4441
URL imageUrl =
4542
new URL("https://djl-ai.s3.amazonaws.com/resources/images/dog_bike_car.jpg");

gradle/wrapper/gradle-wrapper.properties

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ distributionBase=GRADLE_USER_HOME
33
distributionPath=wrapper/dists
44
zipStoreBase=GRADLE_USER_HOME
55
zipStorePath=wrapper/dists
6-
distributionUrl=https\://services.gradle.org/distributions/gradle-6.5.1-bin.zip
6+
distributionUrl=https\://services.gradle.org/distributions/gradle-7.4.2-bin.zip

jsr381/build.gradle

+8-6
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,25 @@ plugins {
44
id "signing"
55
}
66

7+
repositories {
8+
mavenCentral()
9+
}
10+
711
group "ai.djl.jsr381"
812
boolean isRelease = project.hasProperty("release") || project.hasProperty("staging")
913
version = "0.8.0" + (isRelease ? "" : "-SNAPSHOT")
1014

1115
dependencies {
12-
api "javax.visrec:visrec-api:1.0.1"
13-
api platform("ai.djl:bom:0.8.0")
16+
api 'javax.visrec:visrec-api:1.0.5'
17+
api platform("ai.djl:bom:0.17.0")
1418
api "ai.djl:api"
1519
api "ai.djl:basicdataset"
1620
api "ai.djl:model-zoo"
1721
api "ai.djl.mxnet:mxnet-model-zoo"
1822
api 'ai.djl.mxnet:mxnet-engine'
19-
api 'ai.djl.mxnet:mxnet-native-auto'
20-
api "org.apache.commons:commons-csv:1.7"
2123

22-
testImplementation "org.slf4j:slf4j-simple:1.7.30"
23-
testImplementation("org.testng:testng:6.14.3") {
24+
testImplementation 'org.slf4j:slf4j-simple:1.7.36'
25+
testImplementation('org.testng:testng:7.6.0') {
2426
exclude group: "junit", module: "junit"
2527
}
2628
}

jsr381/src/main/java/ai/djl/jsr381/classification/SimpleBinaryClassifier.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import ai.djl.inference.Predictor;
44
import ai.djl.repository.zoo.ZooModel;
55
import ai.djl.translate.TranslateException;
6-
import javax.visrec.ml.ClassificationException;
76
import javax.visrec.ml.classification.BinaryClassifier;
7+
import javax.visrec.ml.classification.ClassificationException;
88

99
/** Implementation of a {@link BinaryClassifier} with DJL. */
1010
public class SimpleBinaryClassifier implements BinaryClassifier<float[]> {

jsr381/src/main/java/ai/djl/jsr381/classification/SimpleImageClassifier.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
import ai.djl.repository.zoo.ZooModel;
88
import ai.djl.translate.TranslateException;
99
import java.awt.image.BufferedImage;
10-
import java.io.File;
1110
import java.io.IOException;
1211
import java.io.InputStream;
12+
import java.nio.file.Path;
1313
import java.util.List;
1414
import java.util.Map;
1515
import java.util.stream.Collectors;
1616
import javax.imageio.ImageIO;
17-
import javax.visrec.ml.ClassificationException;
17+
import javax.visrec.ml.classification.ClassificationException;
1818
import javax.visrec.ml.classification.ImageClassifier;
1919

2020
/**
@@ -33,9 +33,9 @@ public SimpleImageClassifier(ZooModel<Image, Classifications> model, int topK) {
3333
}
3434

3535
@Override
36-
public Map<String, Float> classify(File input) throws ClassificationException {
36+
public Map<String, Float> classify(Path input) throws ClassificationException {
3737
try {
38-
return classify(ImageIO.read(input));
38+
return classify(ImageIO.read(input.toFile()));
3939
} catch (IOException e) {
4040
throw new ClassificationException("Couldn't transform input into a BufferedImage", e);
4141
}

jsr381/src/main/java/ai/djl/jsr381/dataset/CsvDataset.java

+10-7
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
import ai.djl.training.dataset.RandomAccessDataset;
77
import ai.djl.training.dataset.Record;
88
import ai.djl.util.Progress;
9-
import java.io.File;
109
import java.io.IOException;
1110
import java.io.Reader;
1211
import java.nio.file.Files;
12+
import java.nio.file.Path;
1313
import java.util.List;
1414
import org.apache.commons.csv.CSVFormat;
1515
import org.apache.commons.csv.CSVParser;
@@ -53,27 +53,30 @@ public static final class Builder extends BaseBuilder<Builder> {
5353

5454
List<CSVRecord> records;
5555

56-
private File file;
56+
private Path file;
5757

5858
@Override
5959
protected Builder self() {
6060
return this;
6161
}
6262

63-
public Builder setCsvFile(File file) {
63+
public Builder setCsvFile(Path file) {
6464
this.file = file;
6565
return this;
6666
}
6767

6868
public CsvDataset build() throws IOException {
69-
try (Reader reader = Files.newBufferedReader(file.toPath());
69+
try (Reader reader = Files.newBufferedReader(file);
7070
CSVParser csvParser =
7171
new CSVParser(
7272
reader,
7373
CSVFormat.DEFAULT
74-
.withFirstRecordAsHeader()
75-
.withIgnoreHeaderCase()
76-
.withTrim())) {
74+
.builder()
75+
.setHeader()
76+
.setSkipHeaderRecord(true)
77+
.setIgnoreHeaderCase(true)
78+
.setTrim(true)
79+
.build())) {
7780
records = csvParser.getRecords();
7881
}
7982
return new CsvDataset(this);

jsr381/src/main/java/ai/djl/jsr381/detection/SimpleObjectDetector.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
import java.util.List;
1313
import java.util.Map;
1414
import java.util.concurrent.ConcurrentHashMap;
15-
import javax.visrec.ml.ClassificationException;
15+
import javax.visrec.ml.classification.ClassificationException;
16+
import javax.visrec.ml.detection.BoundingBox;
1617
import javax.visrec.ml.detection.ObjectDetector;
17-
import javax.visrec.util.BoundingBox;
1818

1919
/** A simple object detector implemented with DJL. */
2020
public class SimpleObjectDetector implements ObjectDetector<BufferedImage> {

jsr381/src/main/java/ai/djl/jsr381/spi/DjlBinaryClassifierFactory.java

+5-5
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
import ai.djl.translate.Translator;
2828
import ai.djl.translate.TranslatorContext;
2929
import java.io.IOException;
30-
import javax.visrec.ml.ClassifierCreationException;
3130
import javax.visrec.ml.classification.BinaryClassifier;
3231
import javax.visrec.ml.classification.NeuralNetBinaryClassifier;
32+
import javax.visrec.ml.model.ModelCreationException;
3333
import javax.visrec.spi.BinaryClassifierFactory;
3434

3535
public class DjlBinaryClassifierFactory implements BinaryClassifierFactory<float[]> {
@@ -41,7 +41,7 @@ public Class<float[]> getTargetClass() {
4141

4242
@Override
4343
public BinaryClassifier<float[]> create(NeuralNetBinaryClassifier.BuildingBlock<float[]> block)
44-
throws ClassifierCreationException {
44+
throws ModelCreationException {
4545
int inputSize = block.getInputsNum();
4646
int[] hiddenLayers = block.getHiddenLayers();
4747
int epochs = block.getMaxEpochs();
@@ -62,12 +62,12 @@ public BinaryClassifier<float[]> create(NeuralNetBinaryClassifier.BuildingBlock<
6262
try {
6363
CsvDataset csv =
6464
CsvDataset.builder()
65-
.setCsvFile(block.getTrainingFile())
65+
.setCsvFile(block.getTrainingPath())
6666
.setSampling(batchSize, true)
6767
.build();
6868
dataset = csv.randomSplit(8, 2);
6969
} catch (IOException | TranslateException e) {
70-
throw new ClassifierCreationException("Failed to load dataset.", e);
70+
throw new ModelCreationException("Failed to load dataset.", e);
7171
}
7272

7373
// setup training configuration
@@ -97,7 +97,7 @@ public BinaryClassifier<float[]> create(NeuralNetBinaryClassifier.BuildingBlock<
9797
trainer.notifyListeners(listener -> listener.onEpoch(trainer));
9898
}
9999
} catch (IOException | TranslateException e) {
100-
throw new ClassifierCreationException("Failed to process dataset.", e);
100+
throw new ModelCreationException("Failed to process dataset.", e);
101101
}
102102

103103
return new SimpleBinaryClassifier(new ZooModel<>(model, new BinaryClassifierTranslator()));

0 commit comments

Comments
 (0)