Skip to content

Commit

Permalink
Update subject segmentation painter
Browse files Browse the repository at this point in the history
  • Loading branch information
bensonarafat committed Aug 10, 2024
1 parent 7036418 commit 3bb7882
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 99 deletions.
2 changes: 1 addition & 1 deletion packages/example/android/app/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ android {

defaultConfig {
applicationId "com.google.ml.kit.flutter.example"
minSdkVersion 21
minSdkVersion 24
targetSdkVersion 33
versionCode flutterVersionCode.toInteger()
versionName flutterVersionName
Expand Down
Binary file added packages/example/assets/images/child_dog.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added packages/example/assets/images/friends.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 1 addition & 2 deletions packages/example/lib/main.dart
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import 'dart:io';

import 'package:flutter/material.dart';
import 'package:google_ml_kit_example/vision_detector_views/subject_segmenter_view.dart';

import 'nlp_detector_views/entity_extraction_view.dart';
import 'nlp_detector_views/language_identifier_view.dart';
import 'nlp_detector_views/language_translator_view.dart';
Expand All @@ -16,6 +14,7 @@ import 'vision_detector_views/label_detector_view.dart';
import 'vision_detector_views/object_detector_view.dart';
import 'vision_detector_views/pose_detector_view.dart';
import 'vision_detector_views/selfie_segmenter_view.dart';
import 'vision_detector_views/subject_segmenter_view.dart';
import 'vision_detector_views/text_detector_view.dart';

Future<void> main() async {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,32 +19,44 @@ class SubjectSegmentationPainter extends CustomPainter {

@override
void paint(Canvas canvas, Size size) {
final width = mask.width;
final height = mask.height;
final confidences = mask.confidences;
final int width = mask.width;
final int height = mask.height;
final List<Subject> subjects = mask.subjects;

final paint = Paint()..style = PaintingStyle.fill;

for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
final int tx = translateX(
x.toDouble(),
size,
Size(mask.width.toDouble(), mask.height.toDouble()),
rotation,
cameraLensDirection,
).round();
final int ty = translateY(
y.toDouble(),
size,
Size(mask.width.toDouble(), mask.height.toDouble()),
rotation,
cameraLensDirection,
).round();

final double opacity = confidences[(y * width) + x] * 0.5;
paint.color = color.withOpacity(opacity);
canvas.drawCircle(Offset(tx.toDouble(), ty.toDouble()), 2, paint);
for (final Subject subject in subjects) {
final int startX = subject.startX;
final int startY = subject.startY;
final int subjectWidth = subject.subjectWidth;
final int subjectHeight = subject.subjectHeight;
final List<double> confidences = subject.confidences;

for (int y = 0; y < subjectHeight; y++) {
for (int x = 0; y < subjectWidth; x++) {
final int absoluteX = startX;
final int absoluteY = startY;

final int tx = translateX(
absoluteX.toDouble(),
size,
Size(width.toDouble(), height.toDouble()),
rotation,
cameraLensDirection)
.round();

final int ty = translateY(
absoluteY.toDouble(),
size,
Size(width.toDouble(), height.toDouble()),
rotation,
cameraLensDirection)
.round();

final double opacity = confidences[(y * subjectWidth) + x] * 0.5;
paint.color = color.withOpacity(opacity);
canvas.drawCircle(Offset(tx.toDouble(), ty.toDouble()), 2, paint);
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class _SubjectSegmenterViewState extends State<SubjectSegmenterView> {
bool _isBusy = false;
CustomPaint? _customPaint;
String? _text;
var _cameraLensDirection = CameraLensDirection.front;
var _cameraLensDirection = CameraLensDirection.back;

@override
void dispose() async {
Expand Down Expand Up @@ -44,10 +44,9 @@ class _SubjectSegmenterViewState extends State<SubjectSegmenterView> {
setState(() {
_text = '';
});
final mask = await _segmenter.processImage(inputImage);
final SubjectSegmenterMask mask = await _segmenter.processImage(inputImage);
if (inputImage.metadata?.size != null &&
inputImage.metadata?.rotation != null &&
mask != null) {
inputImage.metadata?.rotation != null) {
final painter = SubjectSegmentationPainter(
mask,
inputImage.metadata!.size,
Expand All @@ -57,8 +56,8 @@ class _SubjectSegmenterViewState extends State<SubjectSegmenterView> {
_customPaint = CustomPaint(painter: painter);
} else {
// TODO: set _customPaint to draw on top of image
_text =
'There is a mask with ${(mask?.confidences ?? []).where((element) => element > 0.8).length} elements';
_text = 'There is a mask with ${mask.subjects.length} subjects';

_customPaint = null;
}
_isBusy = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ android {
}

defaultConfig {
minSdkVersion 21
minSdkVersion 24
}
}

dependencies {
implementation 'com.google.android.gms:play-services-mlkit-subject-segmentation:16.0.0-beta1'
implementation files('/Users/bensonarafat/development/flutter/bin/cache/artifacts/engine/android-arm/flutter.jar')
//implementation files('/Users/bensonarafat/development/flutter/bin/cache/artifacts/engine/android-arm/flutter.jar')
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

public class GoogleMlKitSubjectSegmentationPlugin implements FlutterPlugin {
private MethodChannel channel;
private static final String channelName = "google_mlkit_subject_segmenter";
private static final String channelName = "google_mlkit_subject_segmentation";

@Override
public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
import com.google.mlkit.vision.segmentation.subject.SubjectSegmentation;
import com.google.mlkit.vision.segmentation.subject.SubjectSegmenter;

import java.util.ArrayList;
import java.util.List;
import java.nio.FloatBuffer;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import io.flutter.Log;
import io.flutter.plugin.common.MethodCall;
import io.flutter.plugin.common.MethodChannel;

Expand All @@ -26,6 +28,8 @@ public class SubjectSegmenterProcess implements MethodChannel.MethodCallHandler

private final Context context;

private static final String TAG = "Logger";

private int imageWidth;
private int imageHeight;

Expand All @@ -37,8 +41,6 @@ public SubjectSegmenterProcess(Context context) {
@Override
public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result result) {
String method = call.method;


switch (method) {
case START:
handleDetection(call, result);
Expand All @@ -54,17 +56,9 @@ public void onMethodCall(@NonNull MethodCall call, @NonNull MethodChannel.Result
}

private SubjectSegmenter initialize(MethodCall call) {
Boolean enableForegroundConfidenceMask = call.argument("enableForegroundConfidenceMask");
Boolean enableForegroundBitmap = call.argument("enableForegroundBitmap");
SubjectSegmenterOptions.Builder builder = new SubjectSegmenterOptions.Builder();

if(Boolean.TRUE.equals(enableForegroundConfidenceMask)){
builder.enableForegroundConfidenceMask();
}
if(Boolean.TRUE.equals(enableForegroundBitmap)){
builder.enableForegroundBitmap();
}

SubjectSegmenterOptions.Builder builder = new SubjectSegmenterOptions.Builder()
.enableMultipleSubjects(new SubjectSegmenterOptions.SubjectResultOptions.Builder()
.enableConfidenceMask().build());
SubjectSegmenterOptions options = builder.build();
return SubjectSegmentation.getClient(options);
}
Expand All @@ -84,27 +78,35 @@ private void handleDetection(MethodCall call, MethodChannel.Result result){

subjectSegmenter.process(inputImage)
.addOnSuccessListener( subjectSegmentationResult -> {
Map<String, Object> map = new HashMap<>();
map.put("maxWidth", imageWidth);
map.put("maxHeight", imageHeight);
List<Subject> subjects = subjectSegmentationResult.getSubjects();
final float[] confidences = new float[imageWidth * imageHeight];
for (int k =0; k < subjects.size(); k++){
Subject subject = subjects.get(k);
FloatBuffer mask = subject.getConfidenceMask();
for(int j = 0; j < subject.getHeight(); j++){
for (int i = 0; j < subject.getWidth(); i++){
if (mask.get() > 0.5f) {
confidences[ (subject.getStartY() + j) * imageWidth + subject.getStartX() + i] = mask.get();
}
}
}
}
map.put("confidences", confidences);
result.success(map);
List<Map<String, Object>> subjectsData = new ArrayList<>();
for(Subject subject : subjectSegmentationResult.getSubjects()){
Map<String, Object> subjectData = getStringObjectMap(subject);
subjectsData.add(subjectData);
}
Map<String, Object> map = new HashMap<>();
map.put("subjects", subjectsData);
map.put("width", imageWidth);
map.put("height", imageHeight);
result.success(map);
}).addOnFailureListener( e -> result.error("Subject segmentation failed!", e.getMessage(), e) );
}

@NonNull
private static Map<String, Object> getStringObjectMap(Subject subject) {
Map<String, Object> subjectData = new HashMap<>();
subjectData.put("startX", subject.getStartX());
subjectData.put("startY", subject.getStartY());
subjectData.put("width", subject.getWidth());
subjectData.put("height", subject.getHeight());

FloatBuffer confidenceMask = subject.getConfidenceMask();
assert confidenceMask != null;
float[] confidences = new float[confidenceMask.remaining()];
confidenceMask.get(confidences);
subjectData.put("confidences", confidences);
return subjectData;
}

private void closeDetector(MethodCall call) {
String id = call.argument("id");
SubjectSegmenter subjectSegmenter = instances.get(id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,72 +6,82 @@ class SubjectSegmenter {
static const MethodChannel _channel =
MethodChannel('google_mlkit_subject_segmentation');

///
/// TODO: Comment here
///
final bool enableForegroundConfidenceMask;

///
/// TODO: comment here
///
final bool enableForegroundBitmap;

SubjectSegmenter({
this.enableForegroundConfidenceMask = true,
this.enableForegroundBitmap = false,
});

/// Instance id.
final id = DateTime.now().microsecondsSinceEpoch.toString();

/// Processes the given [InputImage] for segmentation.
/// Returns the segmentation mask in the given image or nil if there was an error.
Future<SubjectSegmenterMask> processImage(InputImage inputImage) async {
final results = await _channel
.invokeMethod('vision#startSubjectSegmentation', <String, dynamic>{
.invokeMethod('vision#startSubjectSegmenter', <String, dynamic>{
'id': id,
'imageData': inputImage.toJson(),
'enableForegroundConfidenceMask': enableForegroundBitmap,
'enableForegroundBitmap': enableForegroundBitmap,
});

SubjectSegmenterMask masks = SubjectSegmenterMask.fromJson(results);
return masks;
}

/// Closes the detector and releases its resources.
Future<void> close() =>
_channel.invokeMethod('vision#closeSubjectSegmentation', {'id': id});
_channel.invokeMethod('vision#closeSubjectSegmenter', {'id': id});
}

class SubjectSegmenterMask {
/// The width of the mask.
final int width;

/// The height of the mask.
final int height;

/// The confidence of the pixel in the mask being in the foreground.
final List<double> confidences;
final List<Subject> subjects;

/// Constructir to create a instance of [SubjectSegmenterMask].
SubjectSegmenterMask({
required this.width,
required this.height,
required this.confidences,
required this.subjects,
});

/// Returns an instance of [SubjectSegmenterMask] from json
factory SubjectSegmenterMask.fromJson(Map<String, dynamic> json) {
final values = json['confidences'];
final List<double> confidences = [];
for (final item in values) {
confidences.add(double.parse(item.toString()));
}
factory SubjectSegmenterMask.fromJson(Map<dynamic, dynamic> json) {
List<dynamic> list = json['subjects'];
List<Subject> subjects = list.map((e) => Subject.fromJson(e)).toList();
return SubjectSegmenterMask(
width: json['width'] as int,
height: json['height'] as int,
confidences: confidences,
subjects: subjects,
);
}
}

class Subject {
final int startX;
final int startY;
final int subjectWidth;
final int subjectHeight;
final List<double> confidences;

Subject(
{required this.startX,
required this.startY,
required this.subjectWidth,
required this.subjectHeight,
required this.confidences});

factory Subject.fromJson(Map<dynamic, dynamic> json) {
return Subject(
startX: json['startX'] as int,
startY: json['startY'] as int,
subjectWidth: json['width'] as int,
subjectHeight: json['height'] as int,
confidences: json['confidences']);
}

Map<dynamic, dynamic> toJson() {
return {
"startX": startX,
"startY": startY,
"subjectWidth": subjectWidth,
"subjectHeight": subjectHeight,
"confidences": confidences,
};
}
}

0 comments on commit 3bb7882

Please sign in to comment.