Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Refactor classes to match with native API #689

Merged
merged 1 commit into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import 'package:google_mlkit_subject_segmentation/google_mlkit_subject_segmentat
import 'coordinates_translator.dart';

class SubjectSegmentationPainter extends CustomPainter {
final SubjectSegmenterMask mask;
final SubjectSegmentationResult mask;
final Size imageSize;
final Color color = Colors.red;
final InputImageRotation rotation;
Expand All @@ -19,18 +19,16 @@ class SubjectSegmentationPainter extends CustomPainter {

@override
void paint(Canvas canvas, Size size) {
final int width = mask.width;
final int height = mask.height;
final List<Subject> subjects = mask.subjects ?? [];
final List<Subject> subjects = mask.subjects;

final paint = Paint()..style = PaintingStyle.fill;

for (final Subject subject in subjects) {
final int startX = subject.startX;
final int startY = subject.startY;
final int subjectWidth = subject.subjectWidth;
final int subjectHeight = subject.subjectHeight;
final List<double> confidences = subject.confidences ?? [];
final int subjectWidth = subject.width;
final int subjectHeight = subject.height;
final List<double> confidences = subject.confidenceMask ?? [];

for (int y = 0; y < subjectHeight; y++) {
for (int x = 0; y < subjectWidth; x++) {
Expand All @@ -40,15 +38,15 @@ class SubjectSegmentationPainter extends CustomPainter {
final int tx = translateX(
absoluteX.toDouble(),
size,
Size(width.toDouble(), height.toDouble()),
Size(imageSize.width.toDouble(), imageSize.height.toDouble()),
rotation,
cameraLensDirection)
.round();

final int ty = translateY(
absoluteY.toDouble(),
size,
Size(width.toDouble(), height.toDouble()),
Size(imageSize.width.toDouble(), imageSize.height.toDouble()),
rotation,
cameraLensDirection)
.round();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,15 @@ class SubjectSegmenterView extends StatefulWidget {

class _SubjectSegmenterViewState extends State<SubjectSegmenterView> {
final SubjectSegmenter _segmenter = SubjectSegmenter(
options: SubjectSegmenterOptions(enableForegroundConfidenceMask: true));
options: SubjectSegmenterOptions(
enableForegroundConfidenceMask: false,
enableForegroundBitmap: false,
enableMultipleSubjects: SubjectResultOptions(
enableConfidenceMask: true,
enableSubjectBitmap: true,
),
),
);
bool _canProcess = true;
bool _isBusy = false;
CustomPaint? _customPaint;
Expand Down Expand Up @@ -45,7 +53,8 @@ class _SubjectSegmenterViewState extends State<SubjectSegmenterView> {
setState(() {
_text = '';
});
final SubjectSegmenterMask mask = await _segmenter.processImage(inputImage);
final SubjectSegmentationResult mask =
await _segmenter.processImage(inputImage);
if (inputImage.metadata?.size != null &&
inputImage.metadata?.rotation != null) {
final painter = SubjectSegmentationPainter(
Expand All @@ -57,7 +66,7 @@ class _SubjectSegmenterViewState extends State<SubjectSegmenterView> {
_customPaint = CustomPaint(painter: painter);
} else {
// TODO: set _customPaint to draw on top of image
_text = 'There is a mask with ${mask.subjects?.length} subjects';
_text = 'There is a mask with ${mask.subjects.length} subjects';
_customPaint = null;
}
_isBusy = false;
Expand Down
2 changes: 1 addition & 1 deletion packages/google_mlkit_subject_segmentation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ final segmenter = SubjectSegmenter(options: options);
#### Process image

```dart
final mask = await segmenter.processImage(inputImage);
final result = await segmenter.processImage(inputImage);
```

#### Release resources with `close()`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
import io.flutter.plugin.common.MethodChannel;

public class GoogleMlKitSubjectSegmentationPlugin implements FlutterPlugin {
private MethodChannel channel;
private static final String channelName = "google_mlkit_subject_segmentation";
private MethodChannel channel;
private static final String channelName = "google_mlkit_subject_segmentation";

@Override
public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) {
channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName);
channel.setMethodCallHandler(new SubjectSegmenterProcess(flutterPluginBinding.getApplicationContext()));
}
@Override
public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) {
channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), channelName);
channel.setMethodCallHandler(new SubjectSegmenter(flutterPluginBinding.getApplicationContext()));
}

@Override
public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) {
channel.setMethodCallHandler(null);
}
@Override
public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) {
channel.setMethodCallHandler(null);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package com.google_mlkit_subject_segmentation;

import android.content.Context;
import android.graphics.Bitmap;

import androidx.annotation.NonNull;

import com.google.mlkit.vision.common.InputImage;
import com.google.mlkit.vision.segmentation.subject.Subject;
import com.google.mlkit.vision.segmentation.subject.SubjectSegmentation;
import com.google.mlkit.vision.segmentation.subject.SubjectSegmentationResult;

import java.io.ByteArrayOutputStream;
import java.util.ArrayList;
import java.util.List;
import java.nio.FloatBuffer;
import java.util.HashMap;
import java.util.Map;

import io.flutter.plugin.common.MethodCall;
import io.flutter.plugin.common.MethodChannel;

import com.google.mlkit.vision.segmentation.subject.SubjectSegmenterOptions;
import com.google_mlkit_commons.InputImageConverter;

public class SubjectSegmenter implements MethodChannel.MethodCallHandler {
private static final String START = "vision#startSubjectSegmenter";
private static final String CLOSE = "vision#closeSubjectSegmenter";

private final Context context;

private final Map<String, com.google.mlkit.vision.segmentation.subject.SubjectSegmenter> instances = new HashMap<>();

public SubjectSegmenter(Context context) {
this.context = context;
}

@Override
public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) {
String method = call.method;
switch (method) {
case START:
handleDetection(call, result);
break;
case CLOSE:
closeDetector(call);
result.success(null);
break;
default:
result.notImplemented();
break;
}
}

private void handleDetection(MethodCall call, MethodChannel.Result result) {
Map<String, Object> imageData = (Map<String, Object>) call.argument("imageData");
InputImage inputImage = InputImageConverter.getInputImageFromData(imageData, context, result);
if (inputImage == null) return;

String id = call.argument("id");
com.google.mlkit.vision.segmentation.subject.SubjectSegmenter subjectSegmenter = getOrCreateSegmenter(id, call);
subjectSegmenter.process(inputImage).addOnSuccessListener(subjectSegmentationResult -> processResult(subjectSegmentationResult, result)).addOnFailureListener(e -> result.error("Subject segmentation failure!", e.getMessage(), e));
}

private com.google.mlkit.vision.segmentation.subject.SubjectSegmenter getOrCreateSegmenter(String id, MethodCall call) {
return instances.computeIfAbsent(id, k -> initialize(call));
}

private com.google.mlkit.vision.segmentation.subject.SubjectSegmenter initialize(MethodCall call) {
Map<String, Object> options = call.argument("options");
SubjectSegmenterOptions.Builder builder = new SubjectSegmenterOptions.Builder();
assert options != null;
configureBuilder(builder, options);
return SubjectSegmentation.getClient(builder.build());
}

private void configureBuilder(SubjectSegmenterOptions.Builder builder, Map<String, Object> options) {
if (Boolean.TRUE.equals(options.get("enableForegroundBitmap"))) {
builder.enableForegroundBitmap();
}
if (Boolean.TRUE.equals(options.get("enableForegroundConfidenceMask"))) {
builder.enableForegroundConfidenceMask();
}
configureMultipleSubjects(builder, (Map<String, Object>) options.get("enableMultiSubjectBitmap"));
}

private void configureMultipleSubjects(SubjectSegmenterOptions.Builder builder, Map<String, Object> options) {
boolean enableConfidenceMask = Boolean.TRUE.equals(options.get("enableConfidenceMask"));
boolean enableSubjectBitmap = Boolean.TRUE.equals(options.get("enableSubjectBitmap"));
SubjectSegmenterOptions.SubjectResultOptions.Builder subjectResultOptionsBuilder = new SubjectSegmenterOptions.SubjectResultOptions.Builder();
if (enableConfidenceMask) subjectResultOptionsBuilder.enableConfidenceMask();
if (enableSubjectBitmap) subjectResultOptionsBuilder.enableSubjectBitmap();
if (enableConfidenceMask || enableSubjectBitmap) {
builder.enableMultipleSubjects(subjectResultOptionsBuilder.build());
}
}

private void processResult(SubjectSegmentationResult subjectSegmentationResult, MethodChannel.Result result) {
Map<String, Object> resultMap = new HashMap<>();
FloatBuffer foregroundConfidenceMask = subjectSegmentationResult.getForegroundConfidenceMask();
if (foregroundConfidenceMask != null) {
resultMap.put("foregroundConfidenceMask", getConfidenceMask(foregroundConfidenceMask));
}
Bitmap foregroundBitmap = subjectSegmentationResult.getForegroundBitmap();
if (foregroundBitmap != null) {
resultMap.put("foregroundBitmap", getBitmapBytes(foregroundBitmap));
}
List<Map<String, Object>> subjectsData = new ArrayList<>();
for (Subject subject : subjectSegmentationResult.getSubjects()) {
Map<String, Object> subjectData = getStringObjectMap(subject);
subjectsData.add(subjectData);
}
resultMap.put("subjects", subjectsData);
result.success(resultMap);
}

private static float[] getConfidenceMask(FloatBuffer floatBuffer) {
float[] mask = new float[floatBuffer.remaining()];
floatBuffer.get(mask);
return mask;
}

private static byte[] getBitmapBytes(Bitmap bitmap) {
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
bitmap.compress(Bitmap.CompressFormat.PNG, 100, outputStream);
return outputStream.toByteArray();
}

@NonNull
private static Map<String, Object> getStringObjectMap(Subject subject) {
Map<String, Object> subjectData = new HashMap<>();
subjectData.put("startX", subject.getStartX());
subjectData.put("startY", subject.getStartY());
subjectData.put("width", subject.getWidth());
subjectData.put("height", subject.getHeight());
FloatBuffer confidenceMask = subject.getConfidenceMask();
if (confidenceMask != null) {
subjectData.put("confidenceMask", getConfidenceMask(confidenceMask));
}
Bitmap bitmap = subject.getBitmap();
if (bitmap != null) {
subjectData.put("bitmap", getBitmapBytes(bitmap));
}
return subjectData;
}

private void closeDetector(MethodCall call) {
String id = call.argument("id");
com.google.mlkit.vision.segmentation.subject.SubjectSegmenter subjectSegmenter = instances.get(id);
if (subjectSegmenter == null) return;
subjectSegmenter.close();
instances.remove(id);
}
}
Loading