Skip to content

Commit 21b3ad1

Browse files
committed
correct two small bugs
1 parent f0b1d32 commit 21b3ad1

File tree

3 files changed

+3
-6
lines changed

3 files changed

+3
-6
lines changed

src/main/java/ai/nets/samj/models/EfficientSamJ.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,6 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr,
420420
code = String.format(code, size);
421421
code += "])" + System.lineSeparator();
422422
code += "labeled_array, num_features = label(mask_batch)" + System.lineSeparator();
423-
code += "num_features -= 1" + System.lineSeparator();
424423
}
425424
code += ""
426425
+ "contours_x = []" + System.lineSeparator()
@@ -442,7 +441,7 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr,
442441
+ " random_positions = np.random.choice(inds[0].shape[0], n_points, replace=False)" + System.lineSeparator()
443442
+ " for pp in range(n_points):" + System.lineSeparator()
444443
+ " extracted_point_prompts += [[inds[0][random_positions[pp]], inds[1][random_positions[pp]]]]" + System.lineSeparator()
445-
+ " extracted_point_labels += [n_feat]" + System.lineSeparator()
444+
+ " extracted_point_labels += [1]" + System.lineSeparator()
446445
+ " ip = torch.reshape(torch.tensor(np.array(extracted_point_prompts).reshape(len(extracted_point_prompts), 2)), [1, 1, -1, 2])" + System.lineSeparator()
447446
+ " il = torch.reshape(torch.tensor(np.array(extracted_point_labels)), [1, 1, -1])" + System.lineSeparator()
448447
+ " predicted_logits, predicted_iou = predictor.predict_masks(predictor.encoded_images," + System.lineSeparator()

src/main/java/ai/nets/samj/models/EfficientViTSamJ.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,6 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr,
501501
code = String.format(code, size);
502502
code += "])" + System.lineSeparator();
503503
code += "labeled_array, num_features = label(mask_batch)" + System.lineSeparator();
504-
code += "num_features -= 1" + System.lineSeparator();
505504
}
506505
code += ""
507506
+ "contours_x = []" + System.lineSeparator()
@@ -523,7 +522,7 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr,
523522
+ " random_positions = np.random.choice(inds[0].shape[0], n_points, replace=False)" + System.lineSeparator()
524523
+ " for pp in range(n_points):" + System.lineSeparator()
525524
+ " extracted_point_prompts += [[inds[0][random_positions[pp]], inds[1][random_positions[pp]]]]" + System.lineSeparator()
526-
+ " extracted_point_labels += [n_feat]" + System.lineSeparator()
525+
+ " extracted_point_labels += [1]" + System.lineSeparator()
527526
+ " mask, _, _ = predictor.predict(" + System.lineSeparator()
528527
+ " point_coords=np.array(extracted_point_prompts)," + System.lineSeparator()
529528
+ " point_labels=np.array(extracted_point_labels)," + System.lineSeparator()

src/main/java/ai/nets/samj/models/Sam2.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,6 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean retu
475475
code = String.format(code, size);
476476
code += "])" + System.lineSeparator();
477477
code += "labeled_array, num_features = label(mask_batch)" + System.lineSeparator();
478-
code += "num_features -= 1" + System.lineSeparator();
479478
}
480479
code += ""
481480
+ "contours_x = []" + System.lineSeparator()
@@ -497,7 +496,7 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean retu
497496
+ " random_positions = np.random.choice(inds[0].shape[0], n_points, replace=False)" + System.lineSeparator()
498497
+ " for pp in range(n_points):" + System.lineSeparator()
499498
+ " extracted_point_prompts += [[inds[0][random_positions[pp]], inds[1][random_positions[pp]]]]" + System.lineSeparator()
500-
+ " extracted_point_labels += [n_feat]" + System.lineSeparator()
499+
+ " extracted_point_labels += [1]" + System.lineSeparator()
501500
+ " mask, _, _ = predictor.predict(" + System.lineSeparator()
502501
+ " point_coords=np.array(extracted_point_prompts)," + System.lineSeparator()
503502
+ " point_labels=np.array(extracted_point_labels)," + System.lineSeparator()

0 commit comments

Comments
 (0)