Skip to content

Commit ecd8231

Browse files
committed
Update for yolov4-full
1 parent 087052f commit ecd8231

File tree

1 file changed

+183
-75
lines changed

1 file changed

+183
-75
lines changed

android/app/src/main/java/org/tensorflow/lite/examples/detection/tflite/YoloV4Classifier.java

+183-75
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2-
32
Licensed under the Apache License, Version 2.0 (the "License");
43
you may not use this file except in compliance with the License.
54
You may obtain a copy of the License at
6-
75
http://www.apache.org/licenses/LICENSE-2.0
8-
96
Unless required by applicable law or agreed to in writing, software
107
distributed under the License is distributed on an "AS IS" BASIS,
118
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -181,10 +178,11 @@ public float getObjThresh() {
181178
private static boolean isGPU = true;
182179

183180
// tiny or not
184-
private static boolean isTiny = true;
181+
private static boolean isTiny = false;
185182

186183
// config yolov4 tiny
187184
private static final int[] OUTPUT_WIDTH_TINY = new int[]{2535, 2535};
185+
private static final int[] OUTPUT_WIDTH_FULL = new int[]{10647, 10647};
188186
private static final int[][] MASKS_TINY = new int[][]{{3, 4, 5}, {1, 2, 3}};
189187
private static final int[] ANCHORS_TINY = new int[]{
190188
23, 27, 37, 58, 81, 82, 81, 82, 135, 169, 344, 319};
@@ -304,84 +302,127 @@ protected ByteBuffer convertBitmapToByteBuffer(Bitmap bitmap) {
304302
return byteBuffer;
305303
}
306304

307-
private ArrayList<Recognition> getDetections(ByteBuffer byteBuffer, Bitmap bitmap) {
308-
ArrayList<Recognition> detections = new ArrayList<Recognition>();
309-
Map<Integer, Object> outputMap = new HashMap<>();
310-
for (int i = 0; i < OUTPUT_WIDTH.length; i++) {
311-
float[][][][][] out = new float[1][OUTPUT_WIDTH[i]][OUTPUT_WIDTH[i]][3][5 + labels.size()];
312-
outputMap.put(i, out);
313-
}
305+
// private ArrayList<Recognition> getDetections(ByteBuffer byteBuffer, Bitmap bitmap) {
306+
// ArrayList<Recognition> detections = new ArrayList<Recognition>();
307+
// Map<Integer, Object> outputMap = new HashMap<>();
308+
// for (int i = 0; i < OUTPUT_WIDTH.length; i++) {
309+
// float[][][][][] out = new float[1][OUTPUT_WIDTH[i]][OUTPUT_WIDTH[i]][3][5 + labels.size()];
310+
// outputMap.put(i, out);
311+
// }
312+
//
313+
// Log.d("YoloV4Classifier", "mObjThresh: " + getObjThresh());
314+
//
315+
// Object[] inputArray = {byteBuffer};
316+
// tfLite.runForMultipleInputsOutputs(inputArray, outputMap);
317+
//
318+
// for (int i = 0; i < OUTPUT_WIDTH.length; i++) {
319+
// int gridWidth = OUTPUT_WIDTH[i];
320+
// float[][][][][] out = (float[][][][][]) outputMap.get(i);
321+
//
322+
// Log.d("YoloV4Classifier", "out[" + i + "] detect start");
323+
// for (int y = 0; y < gridWidth; ++y) {
324+
// for (int x = 0; x < gridWidth; ++x) {
325+
// for (int b = 0; b < NUM_BOXES_PER_BLOCK; ++b) {
326+
// final int offset =
327+
// (gridWidth * (NUM_BOXES_PER_BLOCK * (labels.size() + 5))) * y
328+
// + (NUM_BOXES_PER_BLOCK * (labels.size() + 5)) * x
329+
// + (labels.size() + 5) * b;
330+
//
331+
// final float confidence = expit(out[0][y][x][b][4]);
332+
// int detectedClass = -1;
333+
// float maxClass = 0;
334+
//
335+
// final float[] classes = new float[labels.size()];
336+
// for (int c = 0; c < labels.size(); ++c) {
337+
// classes[c] = out[0][y][x][b][5 + c];
338+
// }
339+
//
340+
// for (int c = 0; c < labels.size(); ++c) {
341+
// if (classes[c] > maxClass) {
342+
// detectedClass = c;
343+
// maxClass = classes[c];
344+
// }
345+
// }
346+
//
347+
// final float confidenceInClass = maxClass * confidence;
348+
// if (confidenceInClass > getObjThresh()) {
349+
//// final float xPos = (x + (expit(out[0][y][x][b][0]) * XYSCALE[i]) - (0.5f * (XYSCALE[i] - 1))) * (INPUT_SIZE / gridWidth);
350+
//// final float yPos = (y + (expit(out[0][y][x][b][1]) * XYSCALE[i]) - (0.5f * (XYSCALE[i] - 1))) * (INPUT_SIZE / gridWidth);
351+
//
352+
// final float xPos = (x + expit(out[0][y][x][b][0])) * (1.0f * INPUT_SIZE / gridWidth);
353+
// final float yPos = (y + expit(out[0][y][x][b][1])) * (1.0f * INPUT_SIZE / gridWidth);
354+
//
355+
// final float w = (float) (Math.exp(out[0][y][x][b][2]) * ANCHORS[2 * MASKS[i][b]]);
356+
// final float h = (float) (Math.exp(out[0][y][x][b][3]) * ANCHORS[2 * MASKS[i][b] + 1]);
357+
//
358+
// final RectF rect =
359+
// new RectF(
360+
// Math.max(0, xPos - w / 2),
361+
// Math.max(0, yPos - h / 2),
362+
// Math.min(bitmap.getWidth() - 1, xPos + w / 2),
363+
// Math.min(bitmap.getHeight() - 1, yPos + h / 2));
364+
// detections.add(new Recognition("" + offset, labels.get(detectedClass),
365+
// confidenceInClass, rect, detectedClass));
366+
// }
367+
// }
368+
// }
369+
// }
370+
// Log.d("YoloV4Classifier", "out[" + i + "] detect end");
371+
// }
372+
// return detections;
373+
// }
314374

315-
Log.d("YoloV4Classifier", "mObjThresh: " + getObjThresh());
375+
/**
376+
* For yolov4-tiny, the situation would be a little different from the yolov4, it only has two
377+
* outputs. Both have three dimensions. The first one is a tensor with dimensions [1, 2535, 4], containing all the bounding boxes.
378+
* The second one is a tensor with dimension [1, 2535, class_num], containing all the classes score.
379+
* @param byteBuffer input ByteBuffer, which contains the image information
380+
* @param bitmap pixel density used to resize the output images
381+
* @return an array list containing the recognitions
382+
*/
316383

384+
private ArrayList<Recognition> getDetectionsForFull(ByteBuffer byteBuffer, Bitmap bitmap) {
385+
ArrayList<Recognition> detections = new ArrayList<Recognition>();
386+
Map<Integer, Object> outputMap = new HashMap<>();
387+
outputMap.put(0, new float[1][OUTPUT_WIDTH_FULL[0]][4]);
388+
outputMap.put(1, new float[1][OUTPUT_WIDTH_FULL[1]][labels.size()]);
317389
Object[] inputArray = {byteBuffer};
318390
tfLite.runForMultipleInputsOutputs(inputArray, outputMap);
319391

320-
for (int i = 0; i < OUTPUT_WIDTH.length; i++) {
321-
int gridWidth = OUTPUT_WIDTH[i];
322-
float[][][][][] out = (float[][][][][]) outputMap.get(i);
323-
324-
Log.d("YoloV4Classifier", "out[" + i + "] detect start");
325-
for (int y = 0; y < gridWidth; ++y) {
326-
for (int x = 0; x < gridWidth; ++x) {
327-
for (int b = 0; b < NUM_BOXES_PER_BLOCK; ++b) {
328-
final int offset =
329-
(gridWidth * (NUM_BOXES_PER_BLOCK * (labels.size() + 5))) * y
330-
+ (NUM_BOXES_PER_BLOCK * (labels.size() + 5)) * x
331-
+ (labels.size() + 5) * b;
332-
333-
final float confidence = expit(out[0][y][x][b][4]);
334-
int detectedClass = -1;
335-
float maxClass = 0;
336-
337-
final float[] classes = new float[labels.size()];
338-
for (int c = 0; c < labels.size(); ++c) {
339-
classes[c] = out[0][y][x][b][5 + c];
340-
}
341-
342-
for (int c = 0; c < labels.size(); ++c) {
343-
if (classes[c] > maxClass) {
344-
detectedClass = c;
345-
maxClass = classes[c];
346-
}
347-
}
348-
349-
final float confidenceInClass = maxClass * confidence;
350-
if (confidenceInClass > getObjThresh()) {
351-
// final float xPos = (x + (expit(out[0][y][x][b][0]) * XYSCALE[i]) - (0.5f * (XYSCALE[i] - 1))) * (INPUT_SIZE / gridWidth);
352-
// final float yPos = (y + (expit(out[0][y][x][b][1]) * XYSCALE[i]) - (0.5f * (XYSCALE[i] - 1))) * (INPUT_SIZE / gridWidth);
353-
354-
final float xPos = (x + expit(out[0][y][x][b][0])) * (1.0f * INPUT_SIZE / gridWidth);
355-
final float yPos = (y + expit(out[0][y][x][b][1])) * (1.0f * INPUT_SIZE / gridWidth);
356-
357-
final float w = (float) (Math.exp(out[0][y][x][b][2]) * ANCHORS[2 * MASKS[i][b]]);
358-
final float h = (float) (Math.exp(out[0][y][x][b][3]) * ANCHORS[2 * MASKS[i][b] + 1]);
359-
360-
final RectF rect =
361-
new RectF(
362-
Math.max(0, xPos - w / 2),
363-
Math.max(0, yPos - h / 2),
364-
Math.min(bitmap.getWidth() - 1, xPos + w / 2),
365-
Math.min(bitmap.getHeight() - 1, yPos + h / 2));
366-
detections.add(new Recognition("" + offset, labels.get(detectedClass),
367-
confidenceInClass, rect, detectedClass));
368-
}
369-
}
392+
int gridWidth = OUTPUT_WIDTH_FULL[0];
393+
float[][][] bboxes = (float [][][]) outputMap.get(0);
394+
float[][][] out_score = (float[][][]) outputMap.get(1);
395+
396+
for (int i = 0; i < gridWidth;i++){
397+
float maxClass = 0;
398+
int detectedClass = -1;
399+
final float[] classes = new float[labels.size()];
400+
for (int c = 0;c< labels.size();c++){
401+
classes [c] = out_score[0][i][c];
402+
}
403+
for (int c = 0;c<labels.size();++c){
404+
if (classes[c] > maxClass){
405+
detectedClass = c;
406+
maxClass = classes[c];
370407
}
371408
}
372-
Log.d("YoloV4Classifier", "out[" + i + "] detect end");
409+
final float score = maxClass;
410+
if (score > getObjThresh()){
411+
final float xPos = bboxes[0][i][0];
412+
final float yPos = bboxes[0][i][1];
413+
final float w = bboxes[0][i][2];
414+
final float h = bboxes[0][i][3];
415+
final RectF rectF = new RectF(
416+
Math.max(0, xPos - w / 2),
417+
Math.max(0, yPos - h / 2),
418+
Math.min(bitmap.getWidth() - 1, xPos + w / 2),
419+
Math.min(bitmap.getHeight() - 1, yPos + h / 2));
420+
detections.add(new Recognition("" + i, labels.get(detectedClass),score,rectF,detectedClass ));
421+
}
373422
}
374423
return detections;
375424
}
376425

377-
/**
378-
* For yolov4-tiny, the situation would be a little different from the yolov4, it only has two
379-
* outputs. Both have three dimensions. The first one is a tensor with dimensions [1, 2535, 4], containing all the bounding boxes.
380-
* The second one is a tensor with dimension [1, 2535, class_num], containing all the classes score.
381-
* @param byteBuffer input ByteBuffer, which contains the image information
382-
* @param bitmap pixel density used to resize the output images
383-
* @return an array list containing the recognitions
384-
*/
385426
private ArrayList<Recognition> getDetectionsForTiny(ByteBuffer byteBuffer, Bitmap bitmap) {
386427
ArrayList<Recognition> detections = new ArrayList<Recognition>();
387428
Map<Integer, Object> outputMap = new HashMap<>();
@@ -418,20 +459,87 @@ private ArrayList<Recognition> getDetectionsForTiny(ByteBuffer byteBuffer, Bitma
418459
Math.max(0, yPos - h / 2),
419460
Math.min(bitmap.getWidth() - 1, xPos + w / 2),
420461
Math.min(bitmap.getHeight() - 1, yPos + h / 2));
421-
detections.add(new Recognition("" + i, labels.get(detectedClass),score,rectF,detectedClass ));
462+
detections.add(new Recognition("" + i, labels.get(detectedClass),score,rectF,detectedClass ));
422463
}
423464
}
424465
return detections;
425466
}
426467

427468
public ArrayList<Recognition> recognizeImage(Bitmap bitmap) {
428469
ByteBuffer byteBuffer = convertBitmapToByteBuffer(bitmap);
470+
471+
// Map<Integer, Object> outputMap = new HashMap<>();
472+
// for (int i = 0; i < OUTPUT_WIDTH.length; i++) {
473+
// float[][][][][] out = new float[1][OUTPUT_WIDTH[i]][OUTPUT_WIDTH[i]][3][5 + labels.size()];
474+
// outputMap.put(i, out);
475+
// }
476+
//
477+
// Log.d("YoloV4Classifier", "mObjThresh: " + getObjThresh());
478+
//
479+
// Object[] inputArray = {byteBuffer};
480+
// tfLite.runForMultipleInputsOutputs(inputArray, outputMap);
481+
//
482+
// ArrayList<Recognition> detections = new ArrayList<Recognition>();
483+
//
484+
// for (int i = 0; i < OUTPUT_WIDTH.length; i++) {
485+
// int gridWidth = OUTPUT_WIDTH[i];
486+
// float[][][][][] out = (float[][][][][]) outputMap.get(i);
487+
//
488+
// Log.d("YoloV4Classifier", "out[" + i + "] detect start");
489+
// for (int y = 0; y < gridWidth; ++y) {
490+
// for (int x = 0; x < gridWidth; ++x) {
491+
// for (int b = 0; b < NUM_BOXES_PER_BLOCK; ++b) {
492+
// final int offset =
493+
// (gridWidth * (NUM_BOXES_PER_BLOCK * (labels.size() + 5))) * y
494+
// + (NUM_BOXES_PER_BLOCK * (labels.size() + 5)) * x
495+
// + (labels.size() + 5) * b;
496+
//
497+
// final float confidence = expit(out[0][y][x][b][4]);
498+
// int detectedClass = -1;
499+
// float maxClass = 0;
500+
//
501+
// final float[] classes = new float[labels.size()];
502+
// for (int c = 0; c < labels.size(); ++c) {
503+
// classes[c] = out[0][y][x][b][5 + c];
504+
// }
505+
//
506+
// for (int c = 0; c < labels.size(); ++c) {
507+
// if (classes[c] > maxClass) {
508+
// detectedClass = c;
509+
// maxClass = classes[c];
510+
// }
511+
// }
512+
//
513+
// final float confidenceInClass = maxClass * confidence;
514+
// if (confidenceInClass > getObjThresh()) {
515+
//// final float xPos = (x + (expit(out[0][y][x][b][0]) * XYSCALE[i]) - (0.5f * (XYSCALE[i] - 1))) * (INPUT_SIZE / gridWidth);
516+
//// final float yPos = (y + (expit(out[0][y][x][b][1]) * XYSCALE[i]) - (0.5f * (XYSCALE[i] - 1))) * (INPUT_SIZE / gridWidth);
517+
//
518+
// final float xPos = (x + expit(out[0][y][x][b][0])) * (1.0f * INPUT_SIZE / gridWidth);
519+
// final float yPos = (y + expit(out[0][y][x][b][1])) * (1.0f * INPUT_SIZE / gridWidth);
520+
//
521+
// final float w = (float) (Math.exp(out[0][y][x][b][2]) * ANCHORS[2 * MASKS[i][b]]);
522+
// final float h = (float) (Math.exp(out[0][y][x][b][3]) * ANCHORS[2 * MASKS[i][b] + 1]);
523+
//
524+
// final RectF rect =
525+
// new RectF(
526+
// Math.max(0, xPos - w / 2),
527+
// Math.max(0, yPos - h / 2),
528+
// Math.min(bitmap.getWidth() - 1, xPos + w / 2),
529+
// Math.min(bitmap.getHeight() - 1, yPos + h / 2));
530+
// detections.add(new Recognition("" + offset, labels.get(detectedClass),
531+
// confidenceInClass, rect, detectedClass));
532+
// }
533+
// }
534+
// }
535+
// }
536+
// Log.d("YoloV4Classifier", "out[" + i + "] detect end");
537+
// }
429538
ArrayList<Recognition> detections;
430-
//check whether the tiny version is specified
431539
if (isTiny) {
432540
detections = getDetectionsForTiny(byteBuffer, bitmap);
433541
} else {
434-
detections = getDetections(byteBuffer, bitmap);
542+
detections = getDetectionsForFull(byteBuffer, bitmap);
435543
}
436544
final ArrayList<Recognition> recognitions = nms(detections);
437545
return recognitions;
@@ -488,4 +596,4 @@ public boolean checkInvalidateBox(float x, float y, float width, float height, f
488596

489597
return true;
490598
}
491-
}
599+
}

0 commit comments

Comments
 (0)