From 6f3fef46340c2d4d4da427d3c9f30269c878aec2 Mon Sep 17 00:00:00 2001 From: Evelyn Fu Date: Wed, 9 Apr 2025 11:46:19 -0400 Subject: [PATCH 01/14] formatting --- .gitignore | 2 + .../segment_moving_object_data.py | 71 +++++++++++++------ 2 files changed, 50 insertions(+), 23 deletions(-) diff --git a/.gitignore b/.gitignore index d41c1aa..e57d4e6 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,8 @@ .venv_nerfstudio data checkpoints +configs __pycache__ outputs .vscode +~/ \ No newline at end of file diff --git a/scalable_real2sim/segmentation/segment_moving_object_data.py b/scalable_real2sim/segmentation/segment_moving_object_data.py index 67ece8c..13866f0 100644 --- a/scalable_real2sim/segmentation/segment_moving_object_data.py +++ b/scalable_real2sim/segmentation/segment_moving_object_data.py @@ -14,9 +14,11 @@ import numpy as np import torch +from mmdet.apis import inference_detector, init_detector from PIL import Image from sam2.build_sam import build_sam2, build_sam2_video_predictor from sam2.sam2_image_predictor import SAM2ImagePredictor +from torch.cuda.amp import autocast from tqdm import tqdm from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor @@ -112,7 +114,12 @@ def segment_moving_obj_data( "Text prompt must be provided if GUI frames are not specified." ) - sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt" + if txt_prompt == "gripper": + sam2_checkpoint = ( + "./checkpoints/checkpoint_gripper_finetune_sam2_200epoch_4_1.pt" + ) + else: + sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt" model_cfg = "sam2_hiera_l.yaml" # Download checkpoint if not exist. @@ -142,12 +149,19 @@ def segment_moving_obj_data( # Build Grounding DINO from Hugging Face (used only if not using GUI) if gui_frames is None: - model_id = "IDEA-Research/grounding-dino-tiny" - device = "cuda" if torch.cuda.is_available() else "cpu" - processor = AutoProcessor.from_pretrained(model_id) - grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained( - model_id - ).to(device) + if txt_prompt == "gripper": + config_file = "./configs/grounding_dino_swin-t_finetune_gripper.py" + checkpoint_file = "./checkpoints/best_coco_bbox_mAP_epoch_8.pth" + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model = init_detector(config_file, checkpoint_file, device=device) + else: + model_id = "IDEA-Research/grounding-dino-tiny" + model_id = "./checkpoints/best_coco_bbox_mAP_epoch_8.pth" + device = "cuda" if torch.cuda.is_available() else "cpu" + processor = AutoProcessor.from_pretrained(model_id) + grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained( + model_id + ).to(device) # Convert PNG to JPG as required for video predictor. 
jpg_dir = os.path.join(rgb_dir, "jpg") @@ -425,23 +439,34 @@ def mouse_callback(event, x, y, flags, param): # Function to get DINO boxes def get_dino_boxes(text, frame_idx): img_path = os.path.join(jpg_dir, frame_names[frame_idx]) - image = Image.open(img_path) - - inputs = processor(images=image, text=text, return_tensors="pt").to(device) - with torch.no_grad(): - outputs = grounding_model(**inputs) - - results = processor.post_process_grounded_object_detection( - outputs, - inputs.input_ids, - box_threshold=0.4, - text_threshold=0.3, - target_sizes=[image.size[::-1]], - ) + if txt_prompt == "gripper": + # Use mmdetection api for gripper + with autocast(enabled=False): + results = inference_detector(model, img_path, text_prompt=text) + + input_boxes = results.pred_instances[0].bboxes.cpu().numpy() + confidences = results.pred_instances[0].scores.cpu().numpy().tolist() + class_names = results.pred_instances[0].label_names + else: + image = Image.open(img_path) + + inputs = processor(images=image, text=text, return_tensors="pt").to( + device + ) + with torch.no_grad(): + outputs = grounding_model(**inputs) + + results = processor.post_process_grounded_object_detection( + outputs, + inputs.input_ids, + box_threshold=0.4, + text_threshold=0.3, + target_sizes=[image.size[::-1]], + ) + input_boxes = results[0]["boxes"].cpu().numpy() + confidences = results[0]["scores"].cpu().numpy().tolist() + class_names = results[0]["labels"] - input_boxes = results[0]["boxes"].cpu().numpy() - confidences = results[0]["scores"].cpu().numpy().tolist() - class_names = results[0]["labels"] return input_boxes, confidences, class_names input_boxes, confidences, class_names = get_dino_boxes( From 1fd0029995f30375037fa66e47436381f7c60456 Mon Sep 17 00:00:00 2001 From: Evelyn Fu Date: Wed, 9 Apr 2025 18:05:54 -0400 Subject: [PATCH 02/14] skip initial frames with no gripper --- .../segment_moving_object_data.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/scalable_real2sim/segmentation/segment_moving_object_data.py b/scalable_real2sim/segmentation/segment_moving_object_data.py index 13866f0..da8b536 100644 --- a/scalable_real2sim/segmentation/segment_moving_object_data.py +++ b/scalable_real2sim/segmentation/segment_moving_object_data.py @@ -28,6 +28,7 @@ PROMPT_TYPE_FOR_VIDEO = "point" # Choose from ["point", "box", "mask"] OFFLOAD_VIDEO_TO_CPU = True # Prevents OOM for large videos but is slower. 
OFFLOAD_STATE_TO_CPU = True +DINO_CONFIDENCE_THRESHOLD = 0.6 def convert_png_to_jpg(input_folder, output_folder): @@ -469,9 +470,14 @@ def get_dino_boxes(text, frame_idx): return input_boxes, confidences, class_names - input_boxes, confidences, class_names = get_dino_boxes( - txt_prompt, txt_prompt_index - ) + while True: + input_boxes, confidences, class_names = get_dino_boxes( + txt_prompt, txt_prompt_index + ) + if confidences[0] > DINO_CONFIDENCE_THRESHOLD: + break + else: + txt_prompt_index += 1 assert ( len(input_boxes) > 0 @@ -678,6 +684,16 @@ def get_dino_boxes(text, frame_idx): except Exception as e: logging.error(f"Error deleting {f}: {e}") + # save black masks for gripper not found frames + for frame_idx in range(txt_prompt_index): + image_name = frame_names[frame_idx] + img_path = os.path.join(jpg_dir, image_name) + image = cv2.imread(img_path) + mask = np.zeros_like(image).astype(np.uint8) + + mask_name = os.path.splitext(image_name)[0] + ".png" + cv2.imwrite(os.path.join(output_dir, mask_name), mask) + for frame_idx, segments in video_segments.items(): if gui_frames is None and neg_txt_prompt is not None: pos_segments = {k: v for k, v in segments.items() if k < neg_id_start_orig} From f7bb54e820525020f71243c2bf1e7c11a996f991 Mon Sep 17 00:00:00 2001 From: evelyn-fu Date: Wed, 9 Apr 2025 18:39:51 -0400 Subject: [PATCH 03/14] Update README.md --- README.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/README.md b/README.md index 59cdb88..c4e5be0 100644 --- a/README.md +++ b/README.md @@ -108,6 +108,24 @@ automatic annotations using DINO work well in many cases but can struggle with t gripper masks. All downstream object tracking and reconstruction results are sensitive to the segmentation quality and thus spending a bit of effort here might be worthwhile. +##### Gripper masking with fine tuned models +We provide fine tuned networks for SAM2 and GroundingDINO for the segmentation and annotation +of the gripper used in our provided dataset which can be downloaded from [insert link here]. + +Please put the checkpoint files in `./checkpoints` and the `config` folder in the root directory. +We used mmdetection's implementation to fine tune Grounding DINO. Please see the +[mmdetection Official Github](https://github.com/open-mmlab/mmdetection/tree/main) +for installation instructions. + +When `txt_prompt` is set to `gripper`, the segmentation script will use the gripper fine tuned +models for annotation and segmentation. + +To fine tune your own object detection model for your gripper, see the instructions +from the [mmdetection Grounding DINO README](https://github.com/open-mmlab/mmdetection/blob/main/configs/grounding_dino/README.md). 
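For reference, an illustrative fine-tuning sketch (not taken from the patches) using the mmengine `Runner`, which is what mmdetection's `tools/train.py` wraps. It assumes the fine-tuning config added later in this series (PATCH 04) is present under `configs/`, that your gripper annotations have been converted to COCO format, and that `data_root`/`ann_file` in `coco_detection.py` have been pointed at them; the work directory name is a placeholder.

```python
from mmengine.config import Config
from mmengine.runner import Runner

# Load the provided fine-tuning config (it pulls in the official pretrained
# Grounding DINO weights via `load_from` and the COCO-style dataset settings).
cfg = Config.fromfile("configs/grounding_dino_swin-t_finetune_16xb2_1x_coco.py")
cfg.work_dir = "work_dirs/gripper_grounding_dino"  # illustrative output directory

# Build the runner and start fine-tuning; checkpoints are written to cfg.work_dir.
runner = Runner.from_cfg(cfg)
runner.train()
```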
+ +To fine tune your own segmentation model for your gripper, see the instructions from the +[SAM2 Training README](https://github.com/facebookresearch/sam2/blob/main/training/README.md) + ### Submodules #### robot_payload_id From 16958eb40a7d26361c93f5de6dd880d4ec19a2a8 Mon Sep 17 00:00:00 2001 From: Evelyn Fu Date: Thu, 10 Apr 2025 01:29:23 -0400 Subject: [PATCH 04/14] formatting --- .gitignore | 1 - configs/coco_detection.py | 102 +++++++ configs/default_runtime.py | 28 ++ ...ding_dino_swin-t_finetune_16xb2_1x_coco.py | 250 ++++++++++++++++++ configs/schedule_1x.py | 29 ++ .../segment_moving_object_data.py | 2 +- 6 files changed, 410 insertions(+), 2 deletions(-) create mode 100644 configs/coco_detection.py create mode 100644 configs/default_runtime.py create mode 100644 configs/grounding_dino_swin-t_finetune_16xb2_1x_coco.py create mode 100644 configs/schedule_1x.py diff --git a/.gitignore b/.gitignore index e57d4e6..33ae31e 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,6 @@ .venv_nerfstudio data checkpoints -configs __pycache__ outputs .vscode diff --git a/configs/coco_detection.py b/configs/coco_detection.py new file mode 100644 index 0000000..622c7ff --- /dev/null +++ b/configs/coco_detection.py @@ -0,0 +1,102 @@ +# This configuration file is taken from https://github.com/open-mmlab/mmdetection/tree/main/configs + +# dataset settings +dataset_type = "CocoDataset" +data_root = "data/coco/" + +# Example to use different file client +# Method 1: simply set the data root and let the file I/O module +# automatically infer from prefix (not support LMDB and Memcache yet) + +# data_root = 's3://openmmlab/datasets/detection/coco/' + +# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6 +# backend_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +backend_args = None + +train_pipeline = [ + dict(type="LoadImageFromFile", backend_args=backend_args), + dict(type="LoadAnnotations", with_bbox=True), + dict(type="Resize", scale=(1333, 800), keep_ratio=True), + dict(type="RandomFlip", prob=0.5), + dict(type="PackDetInputs"), +] +test_pipeline = [ + dict(type="LoadImageFromFile", backend_args=backend_args), + dict(type="Resize", scale=(1333, 800), keep_ratio=True), + # If you don't have a gt annotation, delete the pipeline + dict(type="LoadAnnotations", with_bbox=True), + dict( + type="PackDetInputs", + meta_keys=("img_id", "img_path", "ori_shape", "img_shape", "scale_factor"), + ), +] +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type="DefaultSampler", shuffle=True), + batch_sampler=dict(type="AspectRatioBatchSampler"), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file="annotations/instances_train2017.json", + data_prefix=dict(img="train2017/"), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args, + ), +) +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type="DefaultSampler", shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file="annotations/instances_val2017.json", + data_prefix=dict(img="val2017/"), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args, + ), +) +test_dataloader = val_dataloader + +val_evaluator = dict( + type="CocoMetric", + ann_file=data_root + "annotations/instances_val2017.json", + 
metric="bbox", + format_only=False, + backend_args=backend_args, +) +test_evaluator = val_evaluator + +# inference on test dataset and +# format the output results for submission. +# test_dataloader = dict( +# batch_size=1, +# num_workers=2, +# persistent_workers=True, +# drop_last=False, +# sampler=dict(type='DefaultSampler', shuffle=False), +# dataset=dict( +# type=dataset_type, +# data_root=data_root, +# ann_file=data_root + 'annotations/image_info_test-dev2017.json', +# data_prefix=dict(img='test2017/'), +# test_mode=True, +# pipeline=test_pipeline)) +# test_evaluator = dict( +# type='CocoMetric', +# metric='bbox', +# format_only=True, +# ann_file=data_root + 'annotations/image_info_test-dev2017.json', +# outfile_prefix='./work_dirs/coco_detection/test') diff --git a/configs/default_runtime.py b/configs/default_runtime.py new file mode 100644 index 0000000..088a3a7 --- /dev/null +++ b/configs/default_runtime.py @@ -0,0 +1,28 @@ +# This configuration file is taken from https://github.com/open-mmlab/mmdetection/tree/main/configs + +default_scope = "mmdet" + +default_hooks = dict( + timer=dict(type="IterTimerHook"), + logger=dict(type="LoggerHook", interval=50), + param_scheduler=dict(type="ParamSchedulerHook"), + checkpoint=dict(type="CheckpointHook", interval=1), + sampler_seed=dict(type="DistSamplerSeedHook"), + visualization=dict(type="DetVisualizationHook"), +) + +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method="fork", opencv_num_threads=0), + dist_cfg=dict(backend="nccl"), +) + +vis_backends = [dict(type="LocalVisBackend")] +visualizer = dict( + type="DetLocalVisualizer", vis_backends=vis_backends, name="visualizer" +) +log_processor = dict(type="LogProcessor", window_size=50, by_epoch=True) + +log_level = "INFO" +load_from = None +resume = False diff --git a/configs/grounding_dino_swin-t_finetune_16xb2_1x_coco.py b/configs/grounding_dino_swin-t_finetune_16xb2_1x_coco.py new file mode 100644 index 0000000..19f022a --- /dev/null +++ b/configs/grounding_dino_swin-t_finetune_16xb2_1x_coco.py @@ -0,0 +1,250 @@ +# This configuration file is taken from https://github.com/open-mmlab/mmdetection/tree/main/configs + +_base_ = ["coco_detection.py", "schedule_1x.py", "default_runtime.py"] +load_from = "https://download.openmmlab.com/mmdetection/v3.0/grounding_dino/groundingdino_swint_ogc_mmdet-822d7e9d.pth" # noqa +lang_model_name = "bert-base-uncased" + +model = dict( + type="GroundingDINO", + num_queries=900, + with_box_refine=True, + as_two_stage=True, + data_preprocessor=dict( + type="DetDataPreprocessor", + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_mask=False, + ), + language_model=dict( + type="BertModel", + name=lang_model_name, + pad_to_max=False, + use_sub_sentence_represent=True, + special_tokens_list=["[CLS]", "[SEP]", ".", "?"], + add_pooling_layer=False, + ), + backbone=dict( + type="SwinTransformer", + embed_dims=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.2, + patch_norm=True, + out_indices=(1, 2, 3), + with_cp=True, + convert_weights=False, + ), + neck=dict( + type="ChannelMapper", + in_channels=[192, 384, 768], + kernel_size=1, + out_channels=256, + act_cfg=None, + bias=True, + norm_cfg=dict(type="GN", num_groups=32), + num_outs=4, + ), + encoder=dict( + num_layers=6, + num_cp=6, + # visual layer config + layer_cfg=dict( + self_attn_cfg=dict(embed_dims=256, num_levels=4, 
dropout=0.0), + ffn_cfg=dict(embed_dims=256, feedforward_channels=2048, ffn_drop=0.0), + ), + # text layer config + text_layer_cfg=dict( + self_attn_cfg=dict(num_heads=4, embed_dims=256, dropout=0.0), + ffn_cfg=dict(embed_dims=256, feedforward_channels=1024, ffn_drop=0.0), + ), + # fusion layer config + fusion_layer_cfg=dict( + v_dim=256, l_dim=256, embed_dim=1024, num_heads=4, init_values=1e-4 + ), + ), + decoder=dict( + num_layers=6, + return_intermediate=True, + layer_cfg=dict( + # query self attention layer + self_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0), + # cross attention layer query to text + cross_attn_text_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0), + # cross attention layer query to image + cross_attn_cfg=dict(embed_dims=256, num_heads=8, dropout=0.0), + ffn_cfg=dict(embed_dims=256, feedforward_channels=2048, ffn_drop=0.0), + ), + post_norm_cfg=None, + ), + positional_encoding=dict(num_feats=128, normalize=True, offset=0.0, temperature=20), + bbox_head=dict( + type="GroundingDINOHead", + num_classes=80, + sync_cls_avg_factor=True, + contrastive_cfg=dict(max_text_len=256, log_scale=0.0, bias=False), + loss_cls=dict( + type="FocalLoss", use_sigmoid=True, gamma=2.0, alpha=0.25, loss_weight=1.0 + ), # 2.0 in DeformDETR + loss_bbox=dict(type="L1Loss", loss_weight=5.0), + loss_iou=dict(type="GIoULoss", loss_weight=2.0), + ), + dn_cfg=dict( # TODO: Move to model.train_cfg ? + label_noise_scale=0.5, + box_noise_scale=1.0, # 0.4 for DN-DETR + group_cfg=dict(dynamic=True, num_groups=None, num_dn_queries=100), + ), # TODO: half num_dn_queries + # training and testing settings + train_cfg=dict( + assigner=dict( + type="HungarianAssigner", + match_costs=[ + dict(type="BinaryFocalLossCost", weight=2.0), + dict(type="BBoxL1Cost", weight=5.0, box_format="xywh"), + dict(type="IoUCost", iou_mode="giou", weight=2.0), + ], + ) + ), + test_cfg=dict(max_per_img=300), +) + +# dataset settings +train_pipeline = [ + dict(type="LoadImageFromFile", backend_args=_base_.backend_args), + dict(type="LoadAnnotations", with_bbox=True), + dict(type="RandomFlip", prob=0.5), + dict( + type="RandomChoice", + transforms=[ + [ + dict( + type="RandomChoiceResize", + scales=[ + (480, 1333), + (512, 1333), + (544, 1333), + (576, 1333), + (608, 1333), + (640, 1333), + (672, 1333), + (704, 1333), + (736, 1333), + (768, 1333), + (800, 1333), + ], + keep_ratio=True, + ) + ], + [ + dict( + type="RandomChoiceResize", + # The radio of all image in train dataset < 7 + # follow the original implement + scales=[(400, 4200), (500, 4200), (600, 4200)], + keep_ratio=True, + ), + dict( + type="RandomCrop", + crop_type="absolute_range", + crop_size=(384, 600), + allow_negative_crop=True, + ), + dict( + type="RandomChoiceResize", + scales=[ + (480, 1333), + (512, 1333), + (544, 1333), + (576, 1333), + (608, 1333), + (640, 1333), + (672, 1333), + (704, 1333), + (736, 1333), + (768, 1333), + (800, 1333), + ], + keep_ratio=True, + ), + ], + ], + ), + dict( + type="PackDetInputs", + meta_keys=( + "img_id", + "img_path", + "ori_shape", + "img_shape", + "scale_factor", + "flip", + "flip_direction", + "text", + "custom_entities", + ), + ), +] + +test_pipeline = [ + dict(type="LoadImageFromFile", backend_args=_base_.backend_args), + dict(type="FixScaleResize", scale=(800, 1333), keep_ratio=True), + dict(type="LoadAnnotations", with_bbox=True), + dict( + type="PackDetInputs", + meta_keys=( + "img_id", + "img_path", + "ori_shape", + "img_shape", + "scale_factor", + "text", + "custom_entities", + ), + ), +] + 
+train_dataloader = dict( + dataset=dict( + filter_cfg=dict(filter_empty_gt=False), + pipeline=train_pipeline, + return_classes=True, + ) +) +val_dataloader = dict(dataset=dict(pipeline=test_pipeline, return_classes=True)) +test_dataloader = val_dataloader + +optim_wrapper = dict( + _delete_=True, + type="OptimWrapper", + optimizer=dict(type="AdamW", lr=0.0001, weight_decay=0.0001), + clip_grad=dict(max_norm=0.1, norm_type=2), + paramwise_cfg=dict( + custom_keys={ + "absolute_pos_embed": dict(decay_mult=0.0), + "backbone": dict(lr_mult=0.1), + } + ), +) +# learning policy +max_epochs = 12 +param_scheduler = [ + dict( + type="MultiStepLR", + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[11], + gamma=0.1, + ) +] + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# USER SHOULD NOT CHANGE ITS VALUES. +# base_batch_size = (16 GPUs) x (2 samples per GPU) +auto_scale_lr = dict(base_batch_size=32) diff --git a/configs/schedule_1x.py b/configs/schedule_1x.py new file mode 100644 index 0000000..12b2057 --- /dev/null +++ b/configs/schedule_1x.py @@ -0,0 +1,29 @@ +# training schedule for 1x +train_cfg = dict(type="EpochBasedTrainLoop", max_epochs=12, val_interval=1) +val_cfg = dict(type="ValLoop") +test_cfg = dict(type="TestLoop") + +# learning rate +param_scheduler = [ + dict(type="LinearLR", start_factor=0.001, by_epoch=False, begin=0, end=500), + dict( + type="MultiStepLR", + begin=0, + end=12, + by_epoch=True, + milestones=[8, 11], + gamma=0.1, + ), +] + +# optimizer +optim_wrapper = dict( + type="OptimWrapper", + optimizer=dict(type="SGD", lr=0.02, momentum=0.9, weight_decay=0.0001), +) + +# Default setting for scaling LR automatically +# - `enable` means enable scaling LR automatically +# or not by default. +# - `base_batch_size` = (8 GPUs) x (2 samples per GPU). +auto_scale_lr = dict(enable=False, base_batch_size=16) diff --git a/scalable_real2sim/segmentation/segment_moving_object_data.py b/scalable_real2sim/segmentation/segment_moving_object_data.py index da8b536..b8a82fa 100644 --- a/scalable_real2sim/segmentation/segment_moving_object_data.py +++ b/scalable_real2sim/segmentation/segment_moving_object_data.py @@ -151,7 +151,7 @@ def segment_moving_obj_data( # Build Grounding DINO from Hugging Face (used only if not using GUI) if gui_frames is None: if txt_prompt == "gripper": - config_file = "./configs/grounding_dino_swin-t_finetune_gripper.py" + config_file = "./configs/grounding_dino_swin-t_finetune_16xb2_1x_coco.py" checkpoint_file = "./checkpoints/best_coco_bbox_mAP_epoch_8.pth" device = "cuda:0" if torch.cuda.is_available() else "cpu" model = init_detector(config_file, checkpoint_file, device=device) From 6564dbba568a07aafa46b134a61bc78093ad7f72 Mon Sep 17 00:00:00 2001 From: evelyn-fu Date: Thu, 10 Apr 2025 01:31:30 -0400 Subject: [PATCH 05/14] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c4e5be0..e4457b4 100644 --- a/README.md +++ b/README.md @@ -112,12 +112,12 @@ to the segmentation quality and thus spending a bit of effort here might be wort We provide fine tuned networks for SAM2 and GroundingDINO for the segmentation and annotation of the gripper used in our provided dataset which can be downloaded from [insert link here]. -Please put the checkpoint files in `./checkpoints` and the `config` folder in the root directory. +Please put the downloaded checkpoint files in the `./checkpoints` directory. We used mmdetection's implementation to fine tune Grounding DINO. 
Please see the [mmdetection Official Github](https://github.com/open-mmlab/mmdetection/tree/main) for installation instructions. -When `txt_prompt` is set to `gripper`, the segmentation script will use the gripper fine tuned +When `--txt_prompt` is set to `gripper`, the segmentation script will use the gripper fine tuned models for annotation and segmentation. To fine tune your own object detection model for your gripper, see the instructions From 3a398353055fb56a5562980f8e33028ca5f52390 Mon Sep 17 00:00:00 2001 From: evelyn-fu Date: Thu, 10 Apr 2025 21:20:04 -0400 Subject: [PATCH 06/14] Update README.md --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index e4457b4..60a5d86 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,7 @@ to the segmentation quality and thus spending a bit of effort here might be wort ##### Gripper masking with fine tuned models We provide fine tuned networks for SAM2 and GroundingDINO for the segmentation and annotation -of the gripper used in our provided dataset which can be downloaded from [insert link here]. +of the gripper used in our provided dataset which can be downloaded from [here](https://mitprod-my.sharepoint.com/personal/nepfaff_mit_edu/_layouts/15/onedrive.aspx?id=%2Fpersonal%2Fnepfaff%5Fmit%5Fedu%2FDocuments%2Fscalable%5Freal2sim%5Fmodel%5Fweights&ga=1). Please put the downloaded checkpoint files in the `./checkpoints` directory. We used mmdetection's implementation to fine tune Grounding DINO. Please see the @@ -120,11 +120,11 @@ for installation instructions. When `--txt_prompt` is set to `gripper`, the segmentation script will use the gripper fine tuned models for annotation and segmentation. -To fine tune your own object detection model for your gripper, see the instructions -from the [mmdetection Grounding DINO README](https://github.com/open-mmlab/mmdetection/blob/main/configs/grounding_dino/README.md). +To fine tune your own object detection model for your gripper, see [these instructions](https://github.com/open-mmlab/mmdetection/blob/main/configs/grounding_dino/README.md) +from the mmdetection Official Github. -To fine tune your own segmentation model for your gripper, see the instructions from the -[SAM2 Training README](https://github.com/facebookresearch/sam2/blob/main/training/README.md) +To fine tune your own segmentation model for your gripper, see [these instructions](https://github.com/facebookresearch/sam2/blob/main/training/README.md) for training from the +SAM2 Official Github. 
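For reference, a minimal inference sketch with the fine-tuned detector (an illustration, not part of the patch). It mirrors the `init_detector`/`inference_detector` calls the segmentation script uses for the `gripper` prompt, with the default checkpoint and config paths from this series (PATCH 14 later moves the config under `scalable_real2sim/segmentation/finetuned_grounding_dino_utils/`); the frame path is a placeholder.

```python
import torch
from mmdet.apis import inference_detector, init_detector
from torch.cuda.amp import autocast

config_file = "./configs/grounding_dino_swin-t_finetune_16xb2_1x_coco.py"
checkpoint_file = "./checkpoints/best_coco_bbox_mAP_epoch_8.pth"
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the fine-tuned Grounding DINO checkpoint through the mmdetection API.
model = init_detector(config_file, checkpoint_file, device=device)

# Query a single frame with the same text prompt the script passes ("gripper").
# autocast is disabled here, matching the segmentation script.
with autocast(enabled=False):
    results = inference_detector(model, "path/to/frame.jpg", text_prompt="gripper")

# Boxes, scores, and label names are read off pred_instances, as in get_dino_boxes().
boxes = results.pred_instances[0].bboxes.cpu().numpy()
scores = results.pred_instances[0].scores.cpu().numpy().tolist()
labels = results.pred_instances[0].label_names
print(boxes, scores, labels)
```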
### Submodules From 9e84dfac54a093e4e9e29ebc9bed39fbce8efd03 Mon Sep 17 00:00:00 2001 From: evelyn-fu Date: Thu, 10 Apr 2025 21:22:55 -0400 Subject: [PATCH 07/14] Update schedule_1x.py with credit --- configs/schedule_1x.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/configs/schedule_1x.py b/configs/schedule_1x.py index 12b2057..d0da729 100644 --- a/configs/schedule_1x.py +++ b/configs/schedule_1x.py @@ -1,3 +1,5 @@ +# This configuration file is taken from https://github.com/open-mmlab/mmdetection/tree/main/configs + # training schedule for 1x train_cfg = dict(type="EpochBasedTrainLoop", max_epochs=12, val_interval=1) val_cfg = dict(type="ValLoop") From 9118391fa99ce50ef70a0a2d20015101a4151ecf Mon Sep 17 00:00:00 2001 From: Evelyn Fu Date: Thu, 10 Apr 2025 21:53:30 -0400 Subject: [PATCH 08/14] update segmentation args --- .../segment_moving_object_data.py | 42 +++++++++++++++---- scripts/segment_moving_obj_data.py | 16 +++++++ 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/scalable_real2sim/segmentation/segment_moving_object_data.py b/scalable_real2sim/segmentation/segment_moving_object_data.py index b8a82fa..660d506 100644 --- a/scalable_real2sim/segmentation/segment_moving_object_data.py +++ b/scalable_real2sim/segmentation/segment_moving_object_data.py @@ -104,6 +104,8 @@ def segment_moving_obj_data( num_neg_frames: int = 10, debug_dir: str | None = None, gui_frames: list[str] | None = None, + gripper_sam2_path: str | None = None, + gripper_grounding_dino_path: str | None = None, ): # Ensure mutual exclusivity between GUI and text prompts if gui_frames is not None: @@ -116,9 +118,19 @@ def segment_moving_obj_data( ) if txt_prompt == "gripper": - sam2_checkpoint = ( + default_gripper_sam2_path = ( "./checkpoints/checkpoint_gripper_finetune_sam2_200epoch_4_1.pt" ) + if gripper_sam2_path is None: + sam2_checkpoint = default_gripper_sam2_path + else: + sam2_checkpoint = gripper_sam2_path + + if not os.path.exists(sam2_checkpoint): + logging.info( + "Custom gripper segmentation model not found, using default SAM2 model." + ) + sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt" else: sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt" model_cfg = "sam2_hiera_l.yaml" @@ -149,15 +161,31 @@ def segment_moving_obj_data( image_predictor = SAM2ImagePredictor(sam2_image_model) # Build Grounding DINO from Hugging Face (used only if not using GUI) + use_mmdetection = False if gui_frames is None: if txt_prompt == "gripper": - config_file = "./configs/grounding_dino_swin-t_finetune_16xb2_1x_coco.py" - checkpoint_file = "./checkpoints/best_coco_bbox_mAP_epoch_8.pth" - device = "cuda:0" if torch.cuda.is_available() else "cpu" - model = init_detector(config_file, checkpoint_file, device=device) - else: + default_gripper_grounding_dino_path = ( + "./checkpoints/best_coco_bbox_mAP_epoch_8.pth" + ) + if gripper_grounding_dino_path is None: + checkpoint_file = default_gripper_grounding_dino_path + else: + checkpoint_file = gripper_grounding_dino_path + + if os.path.exists(checkpoint_file): + use_mmdetection = True + config_file = ( + "./configs/grounding_dino_swin-t_finetune_16xb2_1x_coco.py" + ) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model = init_detector(config_file, checkpoint_file, device=device) + else: + logging.info( + "Custom gripper grounding dino model not found, using default model." 
+ ) + + if not use_mmdetection: model_id = "IDEA-Research/grounding-dino-tiny" - model_id = "./checkpoints/best_coco_bbox_mAP_epoch_8.pth" device = "cuda" if torch.cuda.is_available() else "cpu" processor = AutoProcessor.from_pretrained(model_id) grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained( diff --git a/scripts/segment_moving_obj_data.py b/scripts/segment_moving_obj_data.py index b1804ce..c677ec1 100644 --- a/scripts/segment_moving_obj_data.py +++ b/scripts/segment_moving_obj_data.py @@ -96,6 +96,18 @@ def downsample_images(rgb_dir: str, num_images: int) -> None: help="Number of images to subsample to.", default=None, ) + parser.add_argument( + "--gripper_sam2_path", + type=str, + help="Path to custom SAM2 model for gripper segmentation.", + default=None, + ) + parser.add_argument( + "--gripper_grounding_dino_path", + type=str, + help="Path to custom GroundingDINO model for gripper object detection.", + default=None, + ) args = parser.parse_args() rgb_dir = args.rgb_dir output_dir = args.output_dir @@ -106,6 +118,8 @@ def downsample_images(rgb_dir: str, num_images: int) -> None: debug_dir = args.debug_dir gui_frames = args.gui_frames num_images = args.num_images + gripper_sam2_path = args.gripper_sam2_path + gripper_grounding_dino_path = args.gripper_grounding_dino_path if num_images is not None: downsample_images(rgb_dir, num_images) @@ -119,4 +133,6 @@ def downsample_images(rgb_dir: str, num_images: int) -> None: num_neg_frames=num_neg_frames, debug_dir=debug_dir, gui_frames=gui_frames, + gripper_sam2_path=gripper_sam2_path, + gripper_grounding_dino_path=gripper_grounding_dino_path, ) From 622ae7309adfdabf5c202f6701e470bb80215866 Mon Sep 17 00:00:00 2001 From: evelyn-fu Date: Thu, 10 Apr 2025 21:57:04 -0400 Subject: [PATCH 09/14] Add files via upload --- assets/mask_sam2_custom.png | Bin 0 -> 1985 bytes assets/mask_sam2_default.png | Bin 0 -> 2112 bytes 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 assets/mask_sam2_custom.png create mode 100644 assets/mask_sam2_default.png diff --git a/assets/mask_sam2_custom.png b/assets/mask_sam2_custom.png new file mode 100644 index 0000000000000000000000000000000000000000..7b7bc1c33af19d2cf189e01d50c7b7b550da1754 GIT binary patch literal 1985 zcmeH||4-9L7{`mwA*>(DoRfajLb+rU5ReQw+gj(K8BYtuUMbbv)e1Wu>yP9)xZGnQrW?@(XNx% z^6T_dPBH4l@3J0{QCVbmkP_){I{})(t?V(7QD@8gRH?eW@9{JT)!%e3CyQIz&&jUER5nbD^glV} zB)P-PwtYA!4M@Px1pa*Sm8+MOUkdDEVRg$nMTch<=+0* zpnPvqj!N{m6FG37>pa&D!p8A(Zx3D@pJe=j=tBKeVg!V1@^onlS`lO^pNZdso$MIh z)F*FuE6@sflZt6P8(fz-9SGa(y7U2j4*kJgur>|JJCYKVgWYTrn@YBh9SZj!UNn{@SvzucPDWq+Uzbv zWAXcxFElJ$R?pjAJ~oCNLchkNP31Bjc$oP<6dm^Fu-kAp`zr7muN>s&rVD`2Or~1? z0w;wAu#I)DLyAzi`C>oOicMUiPUVyMUb?F8%6@XEe)w2cqDoan#<-2}MnEPj(WDkf zMmPx=2D-A!L-;0spBV{M?w`7zSPU(JG0BjT=%|$a8~ByF(#G|?(WPa-1Ag#uO{0x( zAaE3j)C6dLJAp&MKL-}xNnjrkY5s#9z&{6dyqkFbg<4luPt<8(?^%cH z+IJ=^P$^nUDU_aaax4eEBwkPX%%djKI113R$T=dbQ}}o00(~Ccrwl4j7U?e;;GuYx ua3&PZO)Y0))~1i1Ev4tJ{D1PPJu)n8y~Mq7YxdCG-%Quq&8JP<5Bvj>r&~z? 
literal 0 HcmV?d00001 diff --git a/assets/mask_sam2_default.png b/assets/mask_sam2_default.png new file mode 100644 index 0000000000000000000000000000000000000000..3fcc3b6b6cb658bd6856d27c217d1f261ff9ec38 GIT binary patch literal 2112 zcmeH}{ZA8T0LP2E!5F1ZBlZO+)GKunsR(UXx$)L_fnt2gM#fsJEWXT$J6#adF$a28 zk<#AjkjP6mxA23RZI(E~#)|HIxo1;|^an&s_2Ok_nZw&Hjy9fk4~xl`Wq-ha@Rv`% zFM00nd7isGcXXY@Do&TDi$o%^-DX-Z63sa%5~bWtgMn!NH*K$rL}{sZQ^m&pDZgHP z>9eii{R`!HlU~ORO|0DUwbS-f@W2s&LjY+<>JXr`XW7595oRKvcpm-7N`Q#?Y0hMh zL?@I?779KTN5?rU8^K5J<_wYpvSg3v8kdf(X(pIP+L{qmdP%P|+&($bsl0S;e97%V)Ll~uNWwCN!-i*A|l8kX#Sn<$!s45xZO0m^Lqg*~D;L0JdaMh3$ z_ZH+ut{y8k+~FD^W88MgDQ*{Jl-mP&o7)F@i`)Mf_hCB2B-A>?&H0AY*}IPP)QoCP znd+SVA!qA_+BMnP2jL&E1V4vsGFR(VF|7ALW z(_*%ox(8+%0gL`c3M_%iEZ0$IfC`cXD}_!f6jvA>;I{vV_N;DP#C{sIow%R_0=$(mr%hczs~NWItiE7 zL)6!d>I7-x11c3v!BFC#kd|5s)Y}x{uHkFkYgh@~-FMh8@G<`{fWb#p8oj4ATAQgC zRL$+oTsDQ>L-mt`vY4Z<+@Ti=n+KRF_7!#?wbb9Fh*kAfz)+t!#PqVc?0#y5l*nSX zzRFR(&<77w3Bz3gLge~v_%$fWMmq%C=itK*jwWm-=k+Rk>06|5T>(dP#S z0)p2aKz*zWf0svzMqE?=s_lDj5xo#lVpkhVqLbVv+6cCaT{!v)8sQGpd%^D+u48C7 z3b;q~8BmnrBGDf7B4-8>FokQXl@s#t0(V@pz^@Ryjtdb^LOfLs%EH;8KhU1xBG_86 z5)az>#q0uV0!ggEjxs9HVkaKs0LD*mw*!A0k0@lPqTDU&OkY8U)T;7Bli$ST`d^8kJo#r-Ro)a~^q! z49c7~-ie)NHX)7b37NB!Pr-Vb345t=x=4_NtMp;KU!6dPZg3lr{=6o6%ZY}MqdVg% z#1bsXd~5rIp2A1+>QsY@d|Va>no literal 0 HcmV?d00001 From c6199f5c3fddc293d80a1ec094e9809e057b3f84 Mon Sep 17 00:00:00 2001 From: evelyn-fu Date: Thu, 10 Apr 2025 21:59:10 -0400 Subject: [PATCH 10/14] Update README.md --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index 60a5d86..400acd8 100644 --- a/README.md +++ b/README.md @@ -126,6 +126,12 @@ from the mmdetection Official Github. To fine tune your own segmentation model for your gripper, see [these instructions](https://github.com/facebookresearch/sam2/blob/main/training/README.md) for training from the SAM2 Official Github. 
+An example of segmentation failure on the gripper with default models: +![gripper segment default](assets/mask_sam2_default.png) + +Gripper segmentation on the same image with custom models: +![gripper segment custom](assets/mask_sam2_custom.png) + ### Submodules #### robot_payload_id From a97de4f572c4970e96df5a8fdf630fe754ad2942 Mon Sep 17 00:00:00 2001 From: evelyn-fu Date: Thu, 10 Apr 2025 22:02:40 -0400 Subject: [PATCH 11/14] Add files via upload --- assets/mask_sam2_custom.png | Bin 1985 -> 755 bytes assets/mask_sam2_default.png | Bin 2112 -> 799 bytes 2 files changed, 0 insertions(+), 0 deletions(-) diff --git a/assets/mask_sam2_custom.png b/assets/mask_sam2_custom.png index 7b7bc1c33af19d2cf189e01d50c7b7b550da1754..85d722e489bbe49b5d39024860844d56eedae50a 100644 GIT binary patch literal 755 zcmV=xP> z+CV;v$#%QXLv1u4?Y{j~aDJy<_PzPf%+7mn9sq(M2!bF8f*}7p7Ufe|uPrY8B?oN2 zy85k+n;m_ZGUhwr+*sSNtEt*$0f5E(kzW%niCOYJE2Jz7#yVizcets>&z2c2E|R2> zjgv>W(t0JO4Ayh=Lk~qru2*53p&8V=M}CYrHS|DD zA<{U&oMdH22z(^`ikDRA(`CTE=2tS$^}URVYSbwm5I?V>o_;&PjptOV(`Mi!S!KJ( zj)$MU*KJc=oralmMK4CLw)>112LN#Nm&&KY=Yv*ruRmg$23(9z`OXS z{IfVcI>PZO(}D*8fQlC@!ROwb4%o~U2LOO8%j#e6p3uAO=zal=3iiy@$n&x3CA~wV z><2uZ68CA2=&k-QprR=)#=X6bxgJ+Lo3aMV!+LA#w`|z9!!0>4s%^-qGhk)2gt7m~ zzq|a^&7s;#*Rp|yV}n{Ep`p3M1j3PbQ$7G|u;X4q zkDBXm(V>*6I5^QBIB+d0790#8I$b8o1ML15`!n-|5Q*=qGD)1-dA4`5_)t#VY&0wT za404f%cmcPG%aIg^px}@$ukb^X6c)nQnBGXI^5Ff9VqGglee?l^%ZQzlZCIDvEC;- z=qn)b`uEyPM7sB@xd>{bu4h6dge=VrbhmhPv#MkjTdIy9^?Vr|ifXtBX~k11cTYtS l1VIo4K@bE%5ClP#@)HjUz}2nOC3gS-002ovPDHLkV1gi9Wpw}m literal 1985 zcmeH||4-9L7{`mwA*>(DoRfajLb+rU5ReQw+gj(K8BYtuUMbbv)e1Wu>yP9)xZGnQrW?@(XNx% z^6T_dPBH4l@3J0{QCVbmkP_){I{})(t?V(7QD@8gRH?eW@9{JT)!%e3CyQIz&&jUER5nbD^glV} zB)P-PwtYA!4M@Px1pa*Sm8+MOUkdDEVRg$nMTch<=+0* zpnPvqj!N{m6FG37>pa&D!p8A(Zx3D@pJe=j=tBKeVg!V1@^onlS`lO^pNZdso$MIh z)F*FuE6@sflZt6P8(fz-9SGa(y7U2j4*kJgur>|JJCYKVgWYTrn@YBh9SZj!UNn{@SvzucPDWq+Uzbv zWAXcxFElJ$R?pjAJ~oCNLchkNP31Bjc$oP<6dm^Fu-kAp`zr7muN>s&rVD`2Or~1? z0w;wAu#I)DLyAzi`C>oOicMUiPUVyMUb?F8%6@XEe)w2cqDoan#<-2}MnEPj(WDkf zMmPx=2D-A!L-;0spBV{M?w`7zSPU(JG0BjT=%|$a8~ByF(#G|?(WPa-1Ag#uO{0x( zAaE3j)C6dLJAp&MKL-}xNnjrkY5s#9z&{6dyqkFbg<4luPt<8(?^%cH z+IJ=^P$^nUDU_aaax4eEBwkPX%%djKI113R$T=dbQ}}o00(~Ccrwl4j7U?e;;GuYx ua3&PZO)Y0))~1i1Ev4tJ{D1PPJu)n8y~Mq7YxdCG-%Quq&8JP<5Bvj>r&~z? diff --git a/assets/mask_sam2_default.png b/assets/mask_sam2_default.png index 3fcc3b6b6cb658bd6856d27c217d1f261ff9ec38..508e0496a83eb5b582d33068eea0f834b599d1b3 100644 GIT binary patch literal 799 zcmV+)1K|9LP)=Gs=8E~mk! 
zv-OXWe?}590_&xW=q(GP9(o9bUV4z8dKkR~1(i}S*+V@A2EpjTNg0MvA=8*O6lR&a zHFQgt`*TkZQqA3a<(}(lFa17`_x$;M&$;~0Jr@8$5ClOG1VNDh9S!{uy;(OeNH>*n zKr-YX2GCmT_U&;t=%3t5irNg0W|yZNvQrD+1OR~76O04!HmAFJ|IV!(0AzwMdGoXY zTQUNPE^FdRsPJ~S?BELf2hJ=r44`f4^g1dH61_=AfK(cItFEXa?3<(7-3{Tc#0V`x zo}C{$;?#cfK;hbybgI6-jXv}k!vJ0F^#1fDy8wn`Do%^}PwWEZf6)6Oq_75*1jvo6 z2oPLl7XUufovr`?krygL%X09$WTCBQWovNe1WP}d--;Ev$w#MqGBW<;Ii`L9fAQKQ z_s-XGaAZZ=pY72y4siVE-`R!5k-MkrIRJq6lyrcnCdPhnME~sk#EnCy-^arCQAz(R zJ8M6b53X(P?7-!9lV8nA_}_V$`vLrfWTzf?H5t}FwhnIEuT$2opXWx0XYz%QG)9Zf znHKeYf9V5N02QY^0O-y3wpN$TVicWYw0<=KEbgNo=i1e*7~xGfrJphYS5;H*_3bvw zj7s#~o~K=o;s6?>ycO48HN~Nj^3d7>_ah=FonQ%TkN9eii{R`!HlU~ORO|0DUwbS-f@W2s&LjY+<>JXr`XW7595oRKvcpm-7N`Q#?Y0hMh zL?@I?779KTN5?rU8^K5J<_wYpvSg3v8kdf(X(pIP+L{qmdP%P|+&($bsl0S;e97%V)Ll~uNWwCN!-i*A|l8kX#Sn<$!s45xZO0m^Lqg*~D;L0JdaMh3$ z_ZH+ut{y8k+~FD^W88MgDQ*{Jl-mP&o7)F@i`)Mf_hCB2B-A>?&H0AY*}IPP)QoCP znd+SVA!qA_+BMnP2jL&E1V4vsGFR(VF|7ALW z(_*%ox(8+%0gL`c3M_%iEZ0$IfC`cXD}_!f6jvA>;I{vV_N;DP#C{sIow%R_0=$(mr%hczs~NWItiE7 zL)6!d>I7-x11c3v!BFC#kd|5s)Y}x{uHkFkYgh@~-FMh8@G<`{fWb#p8oj4ATAQgC zRL$+oTsDQ>L-mt`vY4Z<+@Ti=n+KRF_7!#?wbb9Fh*kAfz)+t!#PqVc?0#y5l*nSX zzRFR(&<77w3Bz3gLge~v_%$fWMmq%C=itK*jwWm-=k+Rk>06|5T>(dP#S z0)p2aKz*zWf0svzMqE?=s_lDj5xo#lVpkhVqLbVv+6cCaT{!v)8sQGpd%^D+u48C7 z3b;q~8BmnrBGDf7B4-8>FokQXl@s#t0(V@pz^@Ryjtdb^LOfLs%EH;8KhU1xBG_86 z5)az>#q0uV0!ggEjxs9HVkaKs0LD*mw*!A0k0@lPqTDU&OkY8U)T;7Bli$ST`d^8kJo#r-Ro)a~^q! z49c7~-ie)NHX)7b37NB!Pr-Vb345t=x=4_NtMp;KU!6dPZg3lr{=6o6%ZY}MqdVg% z#1bsXd~5rIp2A1+>QsY@d|Va>no From 7d4d31b3ffa94033c4739a6f52c719f3feabc2f8 Mon Sep 17 00:00:00 2001 From: evelyn-fu Date: Thu, 10 Apr 2025 22:07:14 -0400 Subject: [PATCH 12/14] Update README.md --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 400acd8..62892b8 100644 --- a/README.md +++ b/README.md @@ -126,11 +126,11 @@ from the mmdetection Official Github. To fine tune your own segmentation model for your gripper, see [these instructions](https://github.com/facebookresearch/sam2/blob/main/training/README.md) for training from the SAM2 Official Github. -An example of segmentation failure on the gripper with default models: -![gripper segment default](assets/mask_sam2_default.png) +An example of segmentation failure on the gripper with default models: \ + -Gripper segmentation on the same image with custom models: -![gripper segment custom](assets/mask_sam2_custom.png) +Gripper segmentation on the same image with custom models: \ + ### Submodules From 8abf814505e25fbbfd960c15f90ae39aeb19409e Mon Sep 17 00:00:00 2001 From: evelyn-fu Date: Thu, 10 Apr 2025 22:09:13 -0400 Subject: [PATCH 13/14] Add files via upload --- assets/mask_sam2_custom.png | Bin 755 -> 1985 bytes assets/mask_sam2_default.png | Bin 799 -> 2112 bytes 2 files changed, 0 insertions(+), 0 deletions(-) diff --git a/assets/mask_sam2_custom.png b/assets/mask_sam2_custom.png index 85d722e489bbe49b5d39024860844d56eedae50a..7b7bc1c33af19d2cf189e01d50c7b7b550da1754 100644 GIT binary patch literal 1985 zcmeH||4-9L7{`mwA*>(DoRfajLb+rU5ReQw+gj(K8BYtuUMbbv)e1Wu>yP9)xZGnQrW?@(XNx% z^6T_dPBH4l@3J0{QCVbmkP_){I{})(t?V(7QD@8gRH?eW@9{JT)!%e3CyQIz&&jUER5nbD^glV} zB)P-PwtYA!4M@Px1pa*Sm8+MOUkdDEVRg$nMTch<=+0* zpnPvqj!N{m6FG37>pa&D!p8A(Zx3D@pJe=j=tBKeVg!V1@^onlS`lO^pNZdso$MIh z)F*FuE6@sflZt6P8(fz-9SGa(y7U2j4*kJgur>|JJCYKVgWYTrn@YBh9SZj!UNn{@SvzucPDWq+Uzbv zWAXcxFElJ$R?pjAJ~oCNLchkNP31Bjc$oP<6dm^Fu-kAp`zr7muN>s&rVD`2Or~1? 
z0w;wAu#I)DLyAzi`C>oOicMUiPUVyMUb?F8%6@XEe)w2cqDoan#<-2}MnEPj(WDkf zMmPx=2D-A!L-;0spBV{M?w`7zSPU(JG0BjT=%|$a8~ByF(#G|?(WPa-1Ag#uO{0x( zAaE3j)C6dLJAp&MKL-}xNnjrkY5s#9z&{6dyqkFbg<4luPt<8(?^%cH z+IJ=^P$^nUDU_aaax4eEBwkPX%%djKI113R$T=dbQ}}o00(~Ccrwl4j7U?e;;GuYx ua3&PZO)Y0))~1i1Ev4tJ{D1PPJu)n8y~Mq7YxdCG-%Quq&8JP<5Bvj>r&~z? literal 755 zcmV=xP> z+CV;v$#%QXLv1u4?Y{j~aDJy<_PzPf%+7mn9sq(M2!bF8f*}7p7Ufe|uPrY8B?oN2 zy85k+n;m_ZGUhwr+*sSNtEt*$0f5E(kzW%niCOYJE2Jz7#yVizcets>&z2c2E|R2> zjgv>W(t0JO4Ayh=Lk~qru2*53p&8V=M}CYrHS|DD zA<{U&oMdH22z(^`ikDRA(`CTE=2tS$^}URVYSbwm5I?V>o_;&PjptOV(`Mi!S!KJ( zj)$MU*KJc=oralmMK4CLw)>112LN#Nm&&KY=Yv*ruRmg$23(9z`OXS z{IfVcI>PZO(}D*8fQlC@!ROwb4%o~U2LOO8%j#e6p3uAO=zal=3iiy@$n&x3CA~wV z><2uZ68CA2=&k-QprR=)#=X6bxgJ+Lo3aMV!+LA#w`|z9!!0>4s%^-qGhk)2gt7m~ zzq|a^&7s;#*Rp|yV}n{Ep`p3M1j3PbQ$7G|u;X4q zkDBXm(V>*6I5^QBIB+d0790#8I$b8o1ML15`!n-|5Q*=qGD)1-dA4`5_)t#VY&0wT za404f%cmcPG%aIg^px}@$ukb^X6c)nQnBGXI^5Ff9VqGglee?l^%ZQzlZCIDvEC;- z=qn)b`uEyPM7sB@xd>{bu4h6dge=VrbhmhPv#MkjTdIy9^?Vr|ifXtBX~k11cTYtS l1VIo4K@bE%5ClP#@)HjUz}2nOC3gS-002ovPDHLkV1gi9Wpw}m diff --git a/assets/mask_sam2_default.png b/assets/mask_sam2_default.png index 508e0496a83eb5b582d33068eea0f834b599d1b3..3fcc3b6b6cb658bd6856d27c217d1f261ff9ec38 100644 GIT binary patch literal 2112 zcmeH}{ZA8T0LP2E!5F1ZBlZO+)GKunsR(UXx$)L_fnt2gM#fsJEWXT$J6#adF$a28 zk<#AjkjP6mxA23RZI(E~#)|HIxo1;|^an&s_2Ok_nZw&Hjy9fk4~xl`Wq-ha@Rv`% zFM00nd7isGcXXY@Do&TDi$o%^-DX-Z63sa%5~bWtgMn!NH*K$rL}{sZQ^m&pDZgHP z>9eii{R`!HlU~ORO|0DUwbS-f@W2s&LjY+<>JXr`XW7595oRKvcpm-7N`Q#?Y0hMh zL?@I?779KTN5?rU8^K5J<_wYpvSg3v8kdf(X(pIP+L{qmdP%P|+&($bsl0S;e97%V)Ll~uNWwCN!-i*A|l8kX#Sn<$!s45xZO0m^Lqg*~D;L0JdaMh3$ z_ZH+ut{y8k+~FD^W88MgDQ*{Jl-mP&o7)F@i`)Mf_hCB2B-A>?&H0AY*}IPP)QoCP znd+SVA!qA_+BMnP2jL&E1V4vsGFR(VF|7ALW z(_*%ox(8+%0gL`c3M_%iEZ0$IfC`cXD}_!f6jvA>;I{vV_N;DP#C{sIow%R_0=$(mr%hczs~NWItiE7 zL)6!d>I7-x11c3v!BFC#kd|5s)Y}x{uHkFkYgh@~-FMh8@G<`{fWb#p8oj4ATAQgC zRL$+oTsDQ>L-mt`vY4Z<+@Ti=n+KRF_7!#?wbb9Fh*kAfz)+t!#PqVc?0#y5l*nSX zzRFR(&<77w3Bz3gLge~v_%$fWMmq%C=itK*jwWm-=k+Rk>06|5T>(dP#S z0)p2aKz*zWf0svzMqE?=s_lDj5xo#lVpkhVqLbVv+6cCaT{!v)8sQGpd%^D+u48C7 z3b;q~8BmnrBGDf7B4-8>FokQXl@s#t0(V@pz^@Ryjtdb^LOfLs%EH;8KhU1xBG_86 z5)az>#q0uV0!ggEjxs9HVkaKs0LD*mw*!A0k0@lPqTDU&OkY8U)T;7Bli$ST`d^8kJo#r-Ro)a~^q! z49c7~-ie)NHX)7b37NB!Pr-Vb345t=x=4_NtMp;KU!6dPZg3lr{=6o6%ZY}MqdVg% z#1bsXd~5rIp2A1+>QsY@d|Va>no literal 799 zcmV+)1K|9LP)=Gs=8E~mk! 
zv-OXWe?}590_&xW=q(GP9(o9bUV4z8dKkR~1(i}S*+V@A2EpjTNg0MvA=8*O6lR&a zHFQgt`*TkZQqA3a<(}(lFa17`_x$;M&$;~0Jr@8$5ClOG1VNDh9S!{uy;(OeNH>*n zKr-YX2GCmT_U&;t=%3t5irNg0W|yZNvQrD+1OR~76O04!HmAFJ|IV!(0AzwMdGoXY zTQUNPE^FdRsPJ~S?BELf2hJ=r44`f4^g1dH61_=AfK(cItFEXa?3<(7-3{Tc#0V`x zo}C{$;?#cfK;hbybgI6-jXv}k!vJ0F^#1fDy8wn`Do%^}PwWEZf6)6Oq_75*1jvo6 z2oPLl7XUufovr`?krygL%X09$WTCBQWovNe1WP}d--;Ev$w#MqGBW<;Ii`L9fAQKQ z_s-XGaAZZ=pY72y4siVE-`R!5k-MkrIRJq6lyrcnCdPhnME~sk#EnCy-^arCQAz(R zJ8M6b53X(P?7-!9lV8nA_}_V$`vLrfWTzf?H5t}FwhnIEuT$2opXWx0XYz%QG)9Zf znHKeYf9V5N02QY^0O-y3wpN$TVicWYw0<=KEbgNo=i1e*7~xGfrJphYS5;H*_3bvw zj7s#~o~K=o;s6?>ycO48HN~Nj^3d7>_ah=FonQ%TkN Date: Tue, 15 Apr 2025 02:48:58 -0400 Subject: [PATCH 14/14] address pr comments --- .gitignore | 4 +- README.md | 19 ++++-- ...sam2_default.png => mask_sam2_failure.png} | Bin checkpoints/.gitignore | 2 + run_asset_generation.py | 26 ++++++-- .../coco_detection.py | 0 .../default_runtime.py | 0 ...ding_dino_swin-t_finetune_16xb2_1x_coco.py | 0 .../schedule_1x.py | 0 .../segment_moving_object_data.py | 59 +++++++++--------- 10 files changed, 70 insertions(+), 40 deletions(-) rename assets/{mask_sam2_default.png => mask_sam2_failure.png} (100%) create mode 100644 checkpoints/.gitignore rename {configs => scalable_real2sim/segmentation/finetuned_grounding_dino_utils}/coco_detection.py (100%) rename {configs => scalable_real2sim/segmentation/finetuned_grounding_dino_utils}/default_runtime.py (100%) rename {configs => scalable_real2sim/segmentation/finetuned_grounding_dino_utils}/grounding_dino_swin-t_finetune_16xb2_1x_coco.py (100%) rename {configs => scalable_real2sim/segmentation/finetuned_grounding_dino_utils}/schedule_1x.py (100%) diff --git a/.gitignore b/.gitignore index 33ae31e..42cf9d6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,6 @@ .venv .venv_nerfstudio data -checkpoints __pycache__ outputs -.vscode -~/ \ No newline at end of file +.vscode \ No newline at end of file diff --git a/README.md b/README.md index 62892b8..8450ea3 100644 --- a/README.md +++ b/README.md @@ -105,17 +105,24 @@ Our segmentation pipeline for obtaining object and gripper masks. You might want do human-in-the-loop segmentation by annotating specific frames with positive/ negative labels for more robust results. We provide a simple GUI for this purpose. The default automatic annotations using DINO work well in many cases but can struggle with the -gripper masks. All downstream object tracking and reconstruction results are sensitive -to the segmentation quality and thus spending a bit of effort here might be worthwhile. +gripper masks. This is possibly because our particular gripper seems to be out of +distribution for SAM2, and thus it looses track of it for long videos. This can be +solved with re-prompting it after failure. All downstream object tracking and +reconstruction results are sensitive to the segmentation quality and thus spending a bit +of effort here might be worthwhile. ##### Gripper masking with fine tuned models +Using fine tuned SAM2 and GroundingDINO networks for a gripper that is out of distribution +can help to remove the extra step of reprompting after failure. + We provide fine tuned networks for SAM2 and GroundingDINO for the segmentation and annotation of the gripper used in our provided dataset which can be downloaded from [here](https://mitprod-my.sharepoint.com/personal/nepfaff_mit_edu/_layouts/15/onedrive.aspx?id=%2Fpersonal%2Fnepfaff%5Fmit%5Fedu%2FDocuments%2Fscalable%5Freal2sim%5Fmodel%5Fweights&ga=1). 
Please put the downloaded checkpoint files in the `./checkpoints` directory. We used mmdetection's implementation to fine tune Grounding DINO. Please see the [mmdetection Official Github](https://github.com/open-mmlab/mmdetection/tree/main) -for installation instructions. +for installation instructions. Make sure to be in the virtual environment set up with poetry, +not the Nerfstudio virtual environment. When `--txt_prompt` is set to `gripper`, the segmentation script will use the gripper fine tuned models for annotation and segmentation. @@ -127,7 +134,7 @@ To fine tune your own segmentation model for your gripper, see [these instructio SAM2 Official Github. An example of segmentation failure on the gripper with default models: \ - + Gripper segmentation on the same image with custom models: \ @@ -201,7 +208,9 @@ Note that this needs to be done once per environment for the robot data from ste ### 4. Run asset generation -The asset generation can be run with `scalable_real2sim/run_asset_generation.py`. +The asset generation can be run with `scalable_real2sim/run_asset_generation.py`. The `--use-finetuned-gripper-segmentation` flag can be specified to use fine tuned SAM2 and GroundingDINO +models for gripper segmentation. See the section on `segment_moving_obj_data.py` for installation +instructions. ## Figures diff --git a/assets/mask_sam2_default.png b/assets/mask_sam2_failure.png similarity index 100% rename from assets/mask_sam2_default.png rename to assets/mask_sam2_failure.png diff --git a/checkpoints/.gitignore b/checkpoints/.gitignore new file mode 100644 index 0000000..c96a04f --- /dev/null +++ b/checkpoints/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore \ No newline at end of file diff --git a/run_asset_generation.py b/run_asset_generation.py index fe6db18..4b6d001 100644 --- a/run_asset_generation.py +++ b/run_asset_generation.py @@ -73,7 +73,9 @@ def downsample_images(data_dir: str, num_images: int) -> None: ) -def run_segmentation(data_dir: str, output_dir: str) -> None: +def run_segmentation( + data_dir: str, output_dir: str, use_finetuned_gripper_networks: bool = False +) -> None: start = time.perf_counter() # Detect the object of interest. Need to add a dot for the DINO model. @@ -89,8 +91,12 @@ def run_segmentation(data_dir: str, output_dir: str) -> None: logging.info(f"Detected object of interest: {object_of_interest}") gripper_txt = ( - "Blue plastic robotic gripper with two symmetrical, curved arms " - "attached to the end of a metallic robotic arm." + ("gripper") + if use_finetuned_gripper_networks + else ( + "Blue plastic robotic gripper with two symmetrical, curved arms " + "attached to the end of a metallic robotic arm." + ) ) # Generate the object masks. @@ -383,6 +389,7 @@ def main( skip_segmentation: bool = False, bundle_sdf_interpolate_missing_vertices: bool = False, use_depth: bool = False, + use_finetuned_gripper_segmentation: bool = False, ): logging.info("Starting asset generation...") @@ -422,7 +429,11 @@ def main( # Generate object and gripper masks. 
if not skip_segmentation: logging.info("Running segmentation...") - run_segmentation(data_dir=object_dir, output_dir=object_dir) + run_segmentation( + data_dir=object_dir, + output_dir=object_dir, + use_finetuned_gripper_networks=use_finetuned_gripper_segmentation, + ) else: logging.info("Skipping segmentation...") if not os.path.exists(os.path.join(object_dir, "masks")): @@ -607,6 +618,12 @@ def main( help="If specified, use depth images for geometric reconstruction when " "supported by the reconstruction method.", ) + parser.add_argument( + "--use-finetuned-gripper-segmentation", + action="store_true", + help="If specified, use fine tuned SAM2 and GroundingDINO models for gripper" + "segmentation.", + ) args = parser.parse_args() if not os.path.exists(args.data_dir): @@ -622,4 +639,5 @@ def main( skip_segmentation=args.skip_segmentation, bundle_sdf_interpolate_missing_vertices=args.bundle_sdf_interpolate_missing_vertices, use_depth=args.use_depth, + use_finetuned_gripper_segmentation=args.use_finetuned_gripper_segmentation, ) diff --git a/configs/coco_detection.py b/scalable_real2sim/segmentation/finetuned_grounding_dino_utils/coco_detection.py similarity index 100% rename from configs/coco_detection.py rename to scalable_real2sim/segmentation/finetuned_grounding_dino_utils/coco_detection.py diff --git a/configs/default_runtime.py b/scalable_real2sim/segmentation/finetuned_grounding_dino_utils/default_runtime.py similarity index 100% rename from configs/default_runtime.py rename to scalable_real2sim/segmentation/finetuned_grounding_dino_utils/default_runtime.py diff --git a/configs/grounding_dino_swin-t_finetune_16xb2_1x_coco.py b/scalable_real2sim/segmentation/finetuned_grounding_dino_utils/grounding_dino_swin-t_finetune_16xb2_1x_coco.py similarity index 100% rename from configs/grounding_dino_swin-t_finetune_16xb2_1x_coco.py rename to scalable_real2sim/segmentation/finetuned_grounding_dino_utils/grounding_dino_swin-t_finetune_16xb2_1x_coco.py diff --git a/configs/schedule_1x.py b/scalable_real2sim/segmentation/finetuned_grounding_dino_utils/schedule_1x.py similarity index 100% rename from configs/schedule_1x.py rename to scalable_real2sim/segmentation/finetuned_grounding_dino_utils/schedule_1x.py diff --git a/scalable_real2sim/segmentation/segment_moving_object_data.py b/scalable_real2sim/segmentation/segment_moving_object_data.py index 660d506..8d267ed 100644 --- a/scalable_real2sim/segmentation/segment_moving_object_data.py +++ b/scalable_real2sim/segmentation/segment_moving_object_data.py @@ -14,7 +14,14 @@ import numpy as np import torch -from mmdet.apis import inference_detector, init_detector +try: + from mmdet.apis import inference_detector, init_detector + + MMDET_AVAILABLE = True +except ImportError: + logging.warning("... not installed. Finetuned segmentation model not available.") + MMDET_AVAILABLE = False + from PIL import Image from sam2.build_sam import build_sam2, build_sam2_video_predictor from sam2.sam2_image_predictor import SAM2ImagePredictor @@ -118,18 +125,18 @@ def segment_moving_obj_data( ) if txt_prompt == "gripper": - default_gripper_sam2_path = ( - "./checkpoints/checkpoint_gripper_finetune_sam2_200epoch_4_1.pt" - ) - if gripper_sam2_path is None: - sam2_checkpoint = default_gripper_sam2_path - else: - sam2_checkpoint = gripper_sam2_path - - if not os.path.exists(sam2_checkpoint): - logging.info( - "Custom gripper segmentation model not found, using default SAM2 model." 
+ if MMDET_AVAILABLE: + default_gripper_sam2_path = ( + "./checkpoints/checkpoint_gripper_finetune_sam2_200epoch_4_1.pt" ) + if gripper_sam2_path is None: + sam2_checkpoint = default_gripper_sam2_path + else: + sam2_checkpoint = gripper_sam2_path + + if not os.path.exists(sam2_checkpoint): + raise RuntimeError("Custom gripper segmentation model not found.") + else: sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt" else: sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt" @@ -161,7 +168,6 @@ def segment_moving_obj_data( image_predictor = SAM2ImagePredictor(sam2_image_model) # Build Grounding DINO from Hugging Face (used only if not using GUI) - use_mmdetection = False if gui_frames is None: if txt_prompt == "gripper": default_gripper_grounding_dino_path = ( @@ -173,18 +179,12 @@ def segment_moving_obj_data( checkpoint_file = gripper_grounding_dino_path if os.path.exists(checkpoint_file): - use_mmdetection = True - config_file = ( - "./configs/grounding_dino_swin-t_finetune_16xb2_1x_coco.py" - ) + config_file = "./scalable_real2sim/segmentation/finetuned_grounding_dino_utils/grounding_dino_swin-t_finetune_16xb2_1x_coco.py" device = "cuda:0" if torch.cuda.is_available() else "cpu" model = init_detector(config_file, checkpoint_file, device=device) else: - logging.info( - "Custom gripper grounding dino model not found, using default model." - ) - - if not use_mmdetection: + raise RuntimeError("Custom gripper grounding dino model not found") + else: model_id = "IDEA-Research/grounding-dino-tiny" device = "cuda" if torch.cuda.is_available() else "cpu" processor = AutoProcessor.from_pretrained(model_id) @@ -470,7 +470,7 @@ def get_dino_boxes(text, frame_idx): img_path = os.path.join(jpg_dir, frame_names[frame_idx]) if txt_prompt == "gripper": # Use mmdetection api for gripper - with autocast(enabled=False): + with autocast(enabled=False): # needed to avoid error results = inference_detector(model, img_path, text_prompt=text) input_boxes = results.pred_instances[0].bboxes.cpu().numpy() @@ -498,17 +498,20 @@ def get_dino_boxes(text, frame_idx): return input_boxes, confidences, class_names - while True: + # Find the first frame in which the gripper appears with confidence > threshold + result_found = False + for txt_prompt_index in range(txt_prompt_index, frame_count): input_boxes, confidences, class_names = get_dino_boxes( txt_prompt, txt_prompt_index ) + if len(input_boxes) == 0: + continue if confidences[0] > DINO_CONFIDENCE_THRESHOLD: + result_found = True break - else: - txt_prompt_index += 1 assert ( - len(input_boxes) > 0 + result_found ), "No results found for the text prompt. Make sure that the prompt ends with a dot '.'!" # Prompt SAM image predictor to get the mask for the object @@ -712,7 +715,7 @@ def get_dino_boxes(text, frame_idx): except Exception as e: logging.error(f"Error deleting {f}: {e}") - # save black masks for gripper not found frames + # Writing black masks until the gripper appears in a frame for frame_idx in range(txt_prompt_index): image_name = frame_names[frame_idx] img_path = os.path.join(jpg_dir, image_name)
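To wrap up the series, a minimal usage sketch of the segmentation entry point with the fine-tuned gripper options added in PATCH 08 (an illustration, not part of the patches). The directory paths are placeholders, the import path assumes the repository layout shown above, and both `*_path` arguments can be omitted to fall back to the defaults under `./checkpoints`.

```python
from scalable_real2sim.segmentation.segment_moving_object_data import (
    segment_moving_obj_data,
)

# txt_prompt="gripper" switches the script onto the fine-tuned SAM2 and
# Grounding DINO checkpoints; the two *_path arguments override the default
# checkpoint locations.
segment_moving_obj_data(
    rgb_dir="data/example_object/rgb",               # placeholder input frames
    output_dir="data/example_object/gripper_masks",  # placeholder mask output
    txt_prompt="gripper",
    gripper_sam2_path="./checkpoints/checkpoint_gripper_finetune_sam2_200epoch_4_1.pt",
    gripper_grounding_dino_path="./checkpoints/best_coco_bbox_mAP_epoch_8.pth",
)
```

The same behavior is available from the command line via `scripts/segment_moving_obj_data.py` with `--txt_prompt gripper` and the optional `--gripper_sam2_path`/`--gripper_grounding_dino_path` flags, or end to end through `run_asset_generation.py` with `--use-finetuned-gripper-segmentation`.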