Skip to content

Commit

Permalink
refactor: Refactor classes to match with native API (#689)
Browse files Browse the repository at this point in the history
  • Loading branch information
fbernaly authored Oct 7, 2024
1 parent 714b745 commit 9065223
Show file tree
Hide file tree
Showing 7 changed files with 271 additions and 295 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import 'package:google_mlkit_subject_segmentation/google_mlkit_subject_segmentat
import 'coordinates_translator.dart';

class SubjectSegmentationPainter extends CustomPainter {
final SubjectSegmenterMask mask;
final SubjectSegmentationResult mask;
final Size imageSize;
final Color color = Colors.red;
final InputImageRotation rotation;
Expand All @@ -19,18 +19,16 @@ class SubjectSegmentationPainter extends CustomPainter {

@override
void paint(Canvas canvas, Size size) {
final int width = mask.width;
final int height = mask.height;
final List<Subject> subjects = mask.subjects ?? [];
final List<Subject> subjects = mask.subjects;

final paint = Paint()..style = PaintingStyle.fill;

for (final Subject subject in subjects) {
final int startX = subject.startX;
final int startY = subject.startY;
final int subjectWidth = subject.subjectWidth;
final int subjectHeight = subject.subjectHeight;
final List<double> confidences = subject.confidences ?? [];
final int subjectWidth = subject.width;
final int subjectHeight = subject.height;
final List<double> confidences = subject.confidenceMask ?? [];

for (int y = 0; y < subjectHeight; y++) {
for (int x = 0; y < subjectWidth; x++) {
Expand All @@ -40,15 +38,15 @@ class SubjectSegmentationPainter extends CustomPainter {
final int tx = translateX(
absoluteX.toDouble(),
size,
Size(width.toDouble(), height.toDouble()),
Size(imageSize.width.toDouble(), imageSize.height.toDouble()),
rotation,
cameraLensDirection)
.round();

final int ty = translateY(
absoluteY.toDouble(),
size,
Size(width.toDouble(), height.toDouble()),
Size(imageSize.width.toDouble(), imageSize.height.toDouble()),
rotation,
cameraLensDirection)
.round();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,15 @@ class SubjectSegmenterView extends StatefulWidget {

class _SubjectSegmenterViewState extends State<SubjectSegmenterView> {
final SubjectSegmenter _segmenter = SubjectSegmenter(
options: SubjectSegmenterOptions(enableForegroundConfidenceMask: true));
options: SubjectSegmenterOptions(
enableForegroundConfidenceMask: false,
enableForegroundBitmap: false,
enableMultipleSubjects: SubjectResultOptions(
enableConfidenceMask: true,
enableSubjectBitmap: true,
),
),
);
bool _canProcess = true;
bool _isBusy = false;
CustomPaint? _customPaint;
Expand Down Expand Up @@ -45,7 +53,8 @@ class _SubjectSegmenterViewState extends State<SubjectSegmenterView> {
setState(() {
_text = '';
});
final SubjectSegmenterMask mask = await _segmenter.processImage(inputImage);
final SubjectSegmentationResult mask =
await _segmenter.processImage(inputImage);
if (inputImage.metadata?.size != null &&
inputImage.metadata?.rotation != null) {
final painter = SubjectSegmentationPainter(
Expand All @@ -57,7 +66,7 @@ class _SubjectSegmenterViewState extends State<SubjectSegmenterView> {
_customPaint = CustomPaint(painter: painter);
} else {
// TODO: set _customPaint to draw on top of image
_text = 'There is a mask with ${mask.subjects?.length} subjects';
_text = 'There is a mask with ${mask.subjects.length} subjects';
_customPaint = null;
}
_isBusy = false;
Expand Down
2 changes: 1 addition & 1 deletion packages/google_mlkit_subject_segmentation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ final segmenter = SubjectSegmenter(options: options);
#### Process image

```dart
final mask = await segmenter.processImage(inputImage);
final result = await segmenter.processImage(inputImage);
```

#### Release resources with `close()`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
import io.flutter.plugin.common.MethodChannel;

public class GoogleMlKitSubjectSegmentationPlugin implements FlutterPlugin {
private MethodChannel channel;
private static final String channelName = "google_mlkit_subject_segmentation";
private MethodChannel channel;
private static final String channelName = "google_mlkit_subject_segmentation";

@Override
public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) {
channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName);
channel.setMethodCallHandler(new SubjectSegmenterProcess(flutterPluginBinding.getApplicationContext()));
}
@Override
public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) {
channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName);
channel.setMethodCallHandler(new SubjectSegmenter(flutterPluginBinding.getApplicationContext()));
}

@Override
public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) {
channel.setMethodCallHandler(null);
}
@Override
public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) {
channel.setMethodCallHandler(null);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package com.google_mlkit_subject_segmentation;

import android.content.Context;
import android.graphics.Bitmap;

import androidx.annotation.NonNull;

import com.google.mlkit.vision.common.InputImage;
import com.google.mlkit.vision.segmentation.subject.Subject;
import com.google.mlkit.vision.segmentation.subject.SubjectSegmentation;
import com.google.mlkit.vision.segmentation.subject.SubjectSegmentationResult;

import java.io.ByteArrayOutputStream;
import java.util.ArrayList;
import java.util.List;
import java.nio.FloatBuffer;
import java.util.HashMap;
import java.util.Map;

import io.flutter.plugin.common.MethodCall;
import io.flutter.plugin.common.MethodChannel;

import com.google.mlkit.vision.segmentation.subject.SubjectSegmenterOptions;
import com.google_mlkit_commons.InputImageConverter;

public class SubjectSegmenter implements MethodChannel.MethodCallHandler {
private static final String START = "vision#startSubjectSegmenter";
private static final String CLOSE = "vision#closeSubjectSegmenter";

private final Context context;

private final Map<String, com.google.mlkit.vision.segmentation.subject.SubjectSegmenter> instances = new HashMap<>();

public SubjectSegmenter(Context context) {
this.context = context;
}

@Override
public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) {
String method = call.method;
switch (method) {
case START:
handleDetection(call, result);
break;
case CLOSE:
closeDetector(call);
result.success(null);
break;
default:
result.notImplemented();
break;
}
}

private void handleDetection(MethodCall call, MethodChannel.Result result) {
Map<String, Object> imageData = (Map<String, Object>) call.argument("imageData");
InputImage inputImage = InputImageConverter.getInputImageFromData(imageData, context, result);
if (inputImage == null) return;

String id = call.argument("id");
com.google.mlkit.vision.segmentation.subject.SubjectSegmenter subjectSegmenter = getOrCreateSegmenter(id, call);
subjectSegmenter.process(inputImage).addOnSuccessListener(subjectSegmentationResult -> processResult(subjectSegmentationResult, result)).addOnFailureListener(e -> result.error("Subject segmentation failure!", e.getMessage(), e));
}

private com.google.mlkit.vision.segmentation.subject.SubjectSegmenter getOrCreateSegmenter(String id, MethodCall call) {
return instances.computeIfAbsent(id, k -> initialize(call));
}

private com.google.mlkit.vision.segmentation.subject.SubjectSegmenter initialize(MethodCall call) {
Map<String, Object> options = call.argument("options");
SubjectSegmenterOptions.Builder builder = new SubjectSegmenterOptions.Builder();
assert options != null;
configureBuilder(builder, options);
return SubjectSegmentation.getClient(builder.build());
}

private void configureBuilder(SubjectSegmenterOptions.Builder builder, Map<String, Object> options) {
if (Boolean.TRUE.equals(options.get("enableForegroundBitmap"))) {
builder.enableForegroundBitmap();
}
if (Boolean.TRUE.equals(options.get("enableForegroundConfidenceMask"))) {
builder.enableForegroundConfidenceMask();
}
configureMultipleSubjects(builder, (Map<String, Object>) options.get("enableMultiSubjectBitmap"));
}

private void configureMultipleSubjects(SubjectSegmenterOptions.Builder builder, Map<String, Object> options) {
boolean enableConfidenceMask = Boolean.TRUE.equals(options.get("enableConfidenceMask"));
boolean enableSubjectBitmap = Boolean.TRUE.equals(options.get("enableSubjectBitmap"));
SubjectSegmenterOptions.SubjectResultOptions.Builder subjectResultOptionsBuilder = new SubjectSegmenterOptions.SubjectResultOptions.Builder();
if (enableConfidenceMask) subjectResultOptionsBuilder.enableConfidenceMask();
if (enableSubjectBitmap) subjectResultOptionsBuilder.enableSubjectBitmap();
if (enableConfidenceMask || enableSubjectBitmap) {
builder.enableMultipleSubjects(subjectResultOptionsBuilder.build());
}
}

private void processResult(SubjectSegmentationResult subjectSegmentationResult, MethodChannel.Result result) {
Map<String, Object> resultMap = new HashMap<>();
FloatBuffer foregroundConfidenceMask = subjectSegmentationResult.getForegroundConfidenceMask();
if (foregroundConfidenceMask != null) {
resultMap.put("foregroundConfidenceMask", getConfidenceMask(foregroundConfidenceMask));
}
Bitmap foregroundBitmap = subjectSegmentationResult.getForegroundBitmap();
if (foregroundBitmap != null) {
resultMap.put("foregroundBitmap", getBitmapBytes(foregroundBitmap));
}
List<Map<String, Object>> subjectsData = new ArrayList<>();
for (Subject subject : subjectSegmentationResult.getSubjects()) {
Map<String, Object> subjectData = getStringObjectMap(subject);
subjectsData.add(subjectData);
}
resultMap.put("subjects", subjectsData);
result.success(resultMap);
}

private static float[] getConfidenceMask(FloatBuffer floatBuffer) {
float[] mask = new float[floatBuffer.remaining()];
floatBuffer.get(mask);
return mask;
}

private static byte[] getBitmapBytes(Bitmap bitmap) {
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
bitmap.compress(Bitmap.CompressFormat.PNG, 100, outputStream);
return outputStream.toByteArray();
}

@NonNull
private static Map<String, Object> getStringObjectMap(Subject subject) {
Map<String, Object> subjectData = new HashMap<>();
subjectData.put("startX", subject.getStartX());
subjectData.put("startY", subject.getStartY());
subjectData.put("width", subject.getWidth());
subjectData.put("height", subject.getHeight());
FloatBuffer confidenceMask = subject.getConfidenceMask();
if (confidenceMask != null) {
subjectData.put("confidenceMask", getConfidenceMask(confidenceMask));
}
Bitmap bitmap = subject.getBitmap();
if (bitmap != null) {
subjectData.put("bitmap", getBitmapBytes(bitmap));
}
return subjectData;
}

private void closeDetector(MethodCall call) {
String id = call.argument("id");
com.google.mlkit.vision.segmentation.subject.SubjectSegmenter subjectSegmenter = instances.get(id);
if (subjectSegmenter == null) return;
subjectSegmenter.close();
instances.remove(id);
}
}
Loading

0 comments on commit 9065223

Please sign in to comment.