Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;

import org.tensorflow.EagerSession;
Expand Down Expand Up @@ -59,25 +58,29 @@ public static void main(String[] args) {
String modelDir = args[0];
String imageFile = args[1];

Graph g = new Graph();
byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "tensorflow_inception_graph.pb"));
g.importGraphDef(graphDef);
Session s = new Session(g);

List<String> labels =
readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt"));
byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));

Tensor<Float> image = normalizeImage(imageBytes);
float[] labelProbabilities = executeInceptionGraph(graphDef, image);
TFloat image = normalizeImage(imageBytes);
Tensor<Float> labelProbabilities = executeInceptionGraph(graphDef, image);
int bestLabelIdx = maxIndex(labelProbabilities);
System.out.println(
String.format("BEST MATCH: %s (%.2f%% likely)",
labels.get(bestLabelIdx),
labelProbabilities[bestLabelIdx] * 100f));
labelProbabilities.get(bestLabelIdx) * 100f));
}

private static Tensor<Float> normalizeImage(byte[] imageBytes) {
private static TFloat normalizeImage(byte[] imageBytes) {
// Normalize image eagerly
try (EagerSession session = EagerSession.create()) {
try (EagerSession session = EagerSession.create()) {
Ops tf = Ops.create(session);

// Some constants specific to the pre-trained model at:
// https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
//
Expand All @@ -89,46 +92,37 @@ private static Tensor<Float> normalizeImage(byte[] imageBytes) {
final float mean = 117f;
final float scale = 1f;

final Operand<Float> decodedImage =
tf.dtypes.cast(tf.image.decodeJpeg(tf.constant(imageBytes), DecodeJpeg.channels(3L)), Float.class);

final Operand<Float> resizedImage =
tf.image.resizeBilinear(tf.expandDims(decodedImage, tf.constant(0)), tf.constant(new int[] {H, W}));
final Operand<TFloat> decodedImage =
tf.dtypes.cast(tf.image.decodeJpeg(tf.constant(imageBytes), DecodeJpeg.channels(3L)), TFloat.DTYPE);
final Operand<TFloat> resizedImage =
tf.image.resizeBilinear(tf.expandDims(decodedImage, tf.constant(0)), tf.constant(H, W));

final Operand<Float> normalizedImage =
final Operand<TFloat> normalizedImage =
tf.math.div(tf.math.sub(resizedImage, tf.constant(mean)), tf.constant(scale));

return normalizedImage.asOutput().tensor();
return normalizedImage.tensor();
}
}

// NOTE(review): two versions of executeInceptionGraph appear back to back here, and the
// first one's outer try block and method body are never closed before the second signature
// starts. This looks like removed and added diff lines rendered together; confirm which
// version is the live one.
//
// Array-returning version: imports graphDef into a fresh Graph, opens a Session on it, feeds
// `image` to "input", fetches "output", requires a 2-D result whose first dimension is 1, and
// returns row 0 copied into a float[]. The Graph, the Session and the result tensor are all
// declared in try-with-resources headers. Throws RuntimeException on any other result shape.
private static float[] executeInceptionGraph(byte[] graphDef, Tensor<Float> image) {
try (Graph g = new Graph()) {
g.importGraphDef(graphDef);
try (Session s = new Session(g);
// Generally, there may be multiple output tensors; all of them must be closed to prevent resource leaks.
Tensor<Float> result =
s.runner().feed("input", image).fetch("output").run().get(0).expect(Float.class)) {
final long[] rshape = result.shape();
// The model is expected to emit exactly one row of per-label values.
if (result.numDimensions() != 2 || rshape[0] != 1) {
throw new RuntimeException(
String.format(
"Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
Arrays.toString(rshape)));
}
int nlabels = (int) rshape[1];
// Copy into a [1][nlabels] array and hand back its only row.
return result.copyTo(new float[1][nlabels])[0];
}
// Session-based version: runs the caller-supplied session with the same "input" feed and
// "output" fetch, applies the same [1 N] check through rank() and shape().numElements(0),
// and returns result.get(0). Throws RuntimeException on any other result shape.
// NOTE(review): nothing in this version closes `result` — confirm who owns the tensor.
private static Tensor<Float> executeInceptionGraph(Session graphSession, TFloat image) {
TFloat result = graphSession.runner().feed("input", image).fetch("output").run().get(0).expect(TFloat.class);
if (result.rank() != 2 || result.shape().numElements(0) != 1) {
throw new RuntimeException(
String.format(
"Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
result.shape()));
}
// NOTE(review): presumably get(0) selects the single row of the [1 N] result — verify.
return result.get(0);
}

private static int maxIndex(float[] probabilities) {
private static int maxIndex(Tensor<Float> probabilities) {
int best = 0;
for (int i = 1; i < probabilities.length; ++i) {
if (probabilities[i] > probabilities[best]) {
best = i;
probabilities.values().forEachRemaining(p -> {
if (p > probabilities[best]) {
best = p;
}
}
});
return best;
}

Expand Down