Skip to content

Commit b3b37a7

Browse files
committed
Improve code documentation and readability
1 parent 22692b1 commit b3b37a7

File tree

11 files changed

+243
-261
lines changed

11 files changed

+243
-261
lines changed

android/app/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ android {
3434

3535
lintOptions {
3636
disable 'InvalidPackage'
37+
checkReleaseBuilds false // Adding this line to avoid release build failure
3738
}
3839

3940
defaultConfig {

lib/tflite/classifier.dart

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@ import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';
99

1010
import 'stats.dart';
1111

12+
/// Classifier
1213
class Classifier {
1314
/// Instance of Interpreter
1415
Interpreter _interpreter;
1516

16-
/// Instance of loaded labels
17+
/// Labels file loaded as list
1718
List<String> _labels;
1819

1920
static const String MODEL_FILE_NAME = "detect.tflite";
@@ -79,6 +80,7 @@ class Classifier {
7980
}
8081
}
8182

83+
/// Pre-process the image
8284
TensorImage getProcessedImage(TensorImage inputImage) {
8385
padSize = max(inputImage.height, inputImage.width);
8486
if (imageProcessor == null) {
@@ -91,7 +93,8 @@ class Classifier {
9193
return inputImage;
9294
}
9395

94-
List predict(imageLib.Image image) {
96+
/// Runs object detection on the input image
97+
Map<String, dynamic> predict(imageLib.Image image) {
9598
var predictStartTime = DateTime.now().millisecondsSinceEpoch;
9699

97100
if (_interpreter == null) {
@@ -110,16 +113,16 @@ class Classifier {
110113
var preProcessElapsedTime =
111114
DateTime.now().millisecondsSinceEpoch - preProcessStart;
112115

113-
// Inputs object for runForMultipleInputs
114-
// Use [TensorImage.buffer] or [TensorBuffer.buffer] to pass by reference
115-
List<Object> inputs = [inputImage.buffer];
116-
117116
// TensorBuffers for output tensors
118117
TensorBuffer outputLocations = TensorBufferFloat(_outputShapes[0]);
119118
TensorBuffer outputClasses = TensorBufferFloat(_outputShapes[1]);
120119
TensorBuffer outputScores = TensorBufferFloat(_outputShapes[2]);
121120
TensorBuffer numLocations = TensorBufferFloat(_outputShapes[3]);
122121

122+
// Inputs object for runForMultipleInputs
123+
// Use [TensorImage.buffer] or [TensorBuffer.buffer] to pass by reference
124+
List<Object> inputs = [inputImage.buffer];
125+
123126
// Outputs map
124127
Map<int, Object> outputs = {
125128
0: outputLocations.buffer,
@@ -129,6 +132,8 @@ class Classifier {
129132
};
130133

131134
var inferenceTimeStart = DateTime.now().millisecondsSinceEpoch;
135+
136+
// run inference
132137
_interpreter.runForMultipleInputs(inputs, outputs);
133138

134139
var inferenceTimeElapsed =
@@ -137,8 +142,6 @@ class Classifier {
137142
// Maximum number of results to show
138143
int resultsCount = min(NUM_RESULTS, numLocations.getIntValue(0));
139144

140-
List<Recognition> recognitions = [];
141-
142145
// Using labelOffset = 1 as ??? at index 0
143146
int labelOffset = 1;
144147

@@ -153,15 +156,23 @@ class Classifier {
153156
width: INPUT_SIZE,
154157
);
155158

159+
List<Recognition> recognitions = [];
160+
156161
for (int i = 0; i < resultsCount; i++) {
157-
var label = _labels.elementAt(outputClasses.getIntValue(i) + labelOffset);
162+
// Prediction score
158163
var score = outputScores.getDoubleValue(i);
159164

160-
// inverse of rect
161-
Rect transformedRect = imageProcessor.inverseTransformRect(
162-
locations[i], image.height, image.width);
165+
// Label string
166+
var labelIndex = outputClasses.getIntValue(i) + labelOffset;
167+
var label = _labels.elementAt(labelIndex);
163168

164169
if (score > THRESHOLD) {
170+
// inverse of rect
171+
// [locations] corresponds to the image size 300 X 300
172+
// inverseTransformRect transforms it our [inputImage]
173+
Rect transformedRect = imageProcessor.inverseTransformRect(
174+
locations[i], image.height, image.width);
175+
165176
recognitions.add(
166177
Recognition(i, label, score, transformedRect),
167178
);
@@ -171,16 +182,13 @@ class Classifier {
171182
var predictElapsedTime =
172183
DateTime.now().millisecondsSinceEpoch - predictStartTime;
173184

174-
// print(
175-
// 'Classifier.predict | Pre-process: $preProcessElapsedTime ms | Inference: $inferenceTimeElapsed ms | Total: $predictElapsedTime');
176-
177-
return [
178-
recognitions,
179-
Stats(
185+
return {
186+
"recognitions": recognitions,
187+
"stats": Stats(
180188
totalPredictTime: predictElapsedTime,
181189
inferenceTime: inferenceTimeElapsed,
182190
preProcessingTime: preProcessElapsedTime)
183-
];
191+
};
184192
}
185193

186194
/// Gets the interpreter instance

lib/tflite/recognition.dart

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,21 @@ import 'dart:math';
33
import 'package:flutter/cupertino.dart';
44
import 'package:object_detection/ui/camera_view_singleton.dart';
55

6+
/// Represents the recognition output from the model
67
class Recognition {
8+
/// Index of the result
79
int _id;
10+
11+
/// Label of the result
812
String _label;
13+
14+
/// Confidence [0.0, 1.0]
915
double _score;
16+
17+
/// Location of bounding box rect
18+
///
19+
/// The rectangle corresponds to the raw input image
20+
/// passed for inference
1021
Rect _location;
1122

1223
Recognition(this._id, this._label, this._score, [this._location]);
@@ -19,6 +30,11 @@ class Recognition {
1930

2031
Rect get location => _location;
2132

33+
/// Returns bounding box rectangle corresponding to the
34+
/// displayed image on screen
35+
///
36+
/// This is the actual location where rectangle is rendered on
37+
/// the screen
2238
Rect get renderLocation {
2339
// ratioX = screenWidth / imageInputWidth
2440
// ratioY = ratioX if image fits screenWidth with aspectRatio = constant

lib/tflite/stats.dart

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,26 @@
1+
/// Bundles different elapsed times
12
class Stats {
2-
int totalElapsedTime;
3+
/// Total time taken in the isolate where the inference runs
34
int totalPredictTime;
5+
6+
/// [totalPredictTime] + communication overhead time
7+
/// between main isolate and another isolate
8+
int totalElapsedTime;
9+
10+
/// Time for which inference runs
411
int inferenceTime;
12+
13+
/// Time taken to pre-process the image
514
int preProcessingTime;
615

7-
Stats({this.totalElapsedTime, this.totalPredictTime, this.inferenceTime,
16+
Stats(
17+
{this.totalPredictTime,
18+
this.totalElapsedTime,
19+
this.inferenceTime,
820
this.preProcessingTime});
921

1022
@override
1123
String toString() {
12-
return 'Stats{totalElapsedTime: $totalElapsedTime, inferenceTime: $inferenceTime, preProcessingTime: $preProcessingTime}';
24+
return 'Stats{totalPredictTime: $totalPredictTime, totalElapsedTime: $totalElapsedTime, inferenceTime: $inferenceTime, preProcessingTime: $preProcessingTime}';
1325
}
1426
}

lib/ui/box_widget.dart

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ class BoxWidget extends StatelessWidget {
88
const BoxWidget({Key key, this.result}) : super(key: key);
99
@override
1010
Widget build(BuildContext context) {
11-
11+
// Color for bounding box
1212
Color color = Colors.primaries[
1313
(result.label.length + result.label.codeUnitAt(0) + result.id) %
1414
Colors.primaries.length];
@@ -22,9 +22,8 @@ class BoxWidget extends StatelessWidget {
2222
width: result.renderLocation.width,
2323
height: result.renderLocation.height,
2424
decoration: BoxDecoration(
25-
border: Border.all(color: color, width: 3),
26-
borderRadius: BorderRadius.all(Radius.circular(2))
27-
),
25+
border: Border.all(color: color, width: 3),
26+
borderRadius: BorderRadius.all(Radius.circular(2))),
2827
child: Align(
2928
alignment: Alignment.topLeft,
3029
child: FittedBox(

0 commit comments

Comments
 (0)