diff --git a/projects/6DofPoseEstimation/README.md b/projects/6DofPoseEstimation/README.md
new file mode 100644
index 0000000000..3fa8e1f89c
--- /dev/null
+++ b/projects/6DofPoseEstimation/README.md
@@ -0,0 +1,29 @@
+Download the preprocessed LineMOD dataset and arrange it as follows:
+```
+.
+├── data
+│   └── ape
+│       ├── depth
+│       ├── mask
+│       ├── rgb
+│       ├── linemod_preprocessed_train.json
+│       └── linemod_preprocessed_test.json
+└── models
+    ├── models_info.yml
+    └── obj_01.ply
+```
+
+Pretrained weights:
+
+| Model                   | Download                                                                  |
+| ----------------------- | ------------------------------------------------------------------------- |
+| RTMDet-tiny-linemod-ape | Link: https://pan.baidu.com/s/1eJUQyoprhInleDRYS3gBew (access code: t26i) |
+| RTMPose-s-linemod-ape   | Link: https://pan.baidu.com/s/14K7E9EhxHE1Kud-my51Rqg (access code: t4fp) |
diff --git a/projects/6DofPoseEstimation/configs/datasets/linemod.py b/projects/6DofPoseEstimation/configs/datasets/linemod.py
new file mode 100644
index 0000000000..2ab3e99242
--- /dev/null
+++ b/projects/6DofPoseEstimation/configs/datasets/linemod.py
@@ -0,0 +1,89 @@
+# The eight keypoints are the corners of the object's 3D bounding box,
+# named after their (x, y, z) extremes. The cuboid has no left/right
+# symmetry, so no swap pairs are defined.
+dataset_info = dict(
+    dataset_name='linemod',
+    paper_info=dict(
+        author='',
+        title='',
+        container='',
+        year='',
+        homepage='',
+    ),
+    keypoint_info={
+        0:
+        dict(name='min_min_min', id=0, color=[0, 0, 0], type='', swap=''),
+        1:
+        dict(name='min_min_max', id=1, color=[0, 0, 0], type='', swap=''),
+        2:
+        dict(name='min_max_min', id=2, color=[0, 0, 0], type='', swap=''),
+        3:
+        dict(name='min_max_max', id=3, color=[0, 0, 0], type='', swap=''),
+        4:
+        dict(name='max_min_min', id=4, color=[0, 0, 0], type='', swap=''),
+        5:
+        dict(name='max_min_max', id=5, color=[0, 0, 0], type='', swap=''),
+        6:
+        dict(name='max_max_min', id=6, color=[0, 0, 0], type='', swap=''),
+        7:
+        dict(name='max_max_max', id=7, color=[0, 0, 0], type='', swap=''),
+    },
+    # cuboid edges, coloured by axis: red along x, green along y, blue along z
+    skeleton_info={
+        0:
+        dict(link=('min_min_min', 'max_min_min'), id=0, color=[255, 0, 0]),
+        1:
+        dict(link=('min_min_max', 'max_min_max'), id=1, color=[255, 0, 0]),
+        2:
+        dict(link=('min_max_max', 'max_max_max'), id=2, color=[255, 0, 0]),
+        3:
+        dict(link=('max_max_min', 'min_max_min'), id=3, color=[255, 0, 0]),
+        4:
+        dict(link=('min_min_min', 'min_max_min'), id=4, color=[0, 255, 0]),
+        5:
+        dict(link=('min_min_max', 'min_max_max'), id=5, color=[0, 255, 0]),
+        6:
+        dict(link=('max_max_max', 'max_min_max'), id=6, color=[0, 255, 0]),
+        7:
+        dict(link=('max_min_min', 'max_max_min'), id=7, color=[0, 255, 0]),
+        8:
+        dict(link=('min_min_min', 'min_min_max'), id=8, color=[0, 0, 255]),
+        9:
+        dict(link=('max_max_max', 'max_max_min'), id=9, color=[0, 0, 255]),
+        10:
+        dict(link=('max_min_max', 'max_min_min'), id=10, color=[0, 0, 255]),
+        11:
+        dict(link=('min_max_min', 'min_max_max'), id=11, color=[0, 0, 255])
+    },
+    # one weight and one sigma per keypoint (8 in total)
+    joint_weights=[1., 1., 1., 1., 1., 1., 1., 1.],
+    sigmas=[0.025, 0.025, 0.025, 0.025, 0.025, 0.025, 0.025, 0.025])
diff --git a/projects/6DofPoseEstimation/configs/rtmdet_tiny_ape.py b/projects/6DofPoseEstimation/configs/rtmdet_tiny_ape.py
new file mode 100644
index 0000000000..18da106d5d
--- /dev/null
+++ b/projects/6DofPoseEstimation/configs/rtmdet_tiny_ape.py
@@ -0,0 +1,356 @@
+dataset_type = 'CocoDataset'
+data_root = 'data/ape/'
+metainfo = dict(classes=('ape', ))
+NUM_CLASSES = 1
+load_from = None
+MAX_EPOCHS = 200
+TRAIN_BATCH_SIZE = 4
+VAL_BATCH_SIZE = 2
+stage2_num_epochs = 20
+base_lr = 0.004
+VAL_INTERVAL = 5
+default_scope = 'mmdet'
+default_hooks = dict(
+    timer=dict(type='IterTimerHook'),
+    logger=dict(type='LoggerHook', interval=1),
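+    # save a checkpoint every 10 epochs, keep only the two most recent ones
+    # and additionally track the best coco/bbox_mAP weights
+    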
param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict( + type='CheckpointHook', + interval=10, + max_keep_ckpts=2, + save_best='coco/bbox_mAP'), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='DetVisualizationHook')) +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl')) +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='DetLocalVisualizer', + vis_backends=[dict(type='LocalVisBackend')], + name='visualizer') +log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True) +log_level = 'INFO' +resume = False +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=200, + val_interval=5, + dynamic_intervals=[(180, 1)]) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') +param_scheduler = [ + dict( + type='LinearLR', start_factor=1e-05, by_epoch=False, begin=0, + end=1000), + dict( + type='CosineAnnealingLR', + eta_min=0.0002, + begin=150, + end=300, + T_max=150, + by_epoch=True, + convert_to_iter_based=True) +] +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.004, weight_decay=0.05), + paramwise_cfg=dict( + norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True)) +auto_scale_lr = dict(enable=False, base_batch_size=16) +backend_args = None +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args=None), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='CachedMosaic', + img_scale=(640, 640), + pad_val=114.0, + max_cached_images=20, + random_pop=False), + dict( + type='RandomResize', + scale=(1280, 1280), + ratio_range=(0.5, 2.0), + keep_ratio=True), + dict(type='RandomCrop', crop_size=(640, 640)), + dict(type='YOLOXHSVRandomAug'), + dict(type='RandomFlip', prob=0.5), + dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))), + dict( + type='CachedMixUp', + img_scale=(640, 640), + ratio_range=(1.0, 1.0), + max_cached_images=10, + random_pop=False, + pad_val=(114, 114, 114), + prob=0.5), + dict(type='PackDetInputs') +] +test_pipeline = [ + dict(type='LoadImageFromFile', backend_args=None), + dict(type='Resize', scale=(640, 640), keep_ratio=True), + dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] +train_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + batch_sampler=None, + dataset=dict( + type='CocoDataset', + data_root='data/ape/', + metainfo=dict(classes=('ape', )), + ann_file='linemod_preprocessed_train.json', + data_prefix=dict(img='rgb/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=[ + dict(type='LoadImageFromFile', backend_args=None), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='CachedMosaic', + img_scale=(640, 640), + pad_val=114.0, + max_cached_images=20, + random_pop=False), + dict( + type='RandomResize', + scale=(1280, 1280), + ratio_range=(0.5, 2.0), + keep_ratio=True), + dict(type='RandomCrop', crop_size=(640, 640)), + dict(type='YOLOXHSVRandomAug'), + dict(type='RandomFlip', prob=0.5), + dict( + type='Pad', size=(640, 640), + pad_val=dict(img=(114, 114, 114))), + dict( + type='CachedMixUp', + img_scale=(640, 640), + ratio_range=(1.0, 1.0), + max_cached_images=10, + random_pop=False, + pad_val=(114, 114, 114), + prob=0.5), + dict(type='PackDetInputs') + ], + backend_args=None), + 
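+    # pin_memory=True stages batches in page-locked RAM for faster GPU copies
+    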
pin_memory=True) +val_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type='CocoDataset', + data_root='data/ape/', + metainfo=dict(classes=('ape', )), + ann_file='linemod_preprocessed_test.json', + data_prefix=dict(img='rgb/'), + test_mode=True, + pipeline=[ + dict(type='LoadImageFromFile', backend_args=None), + dict(type='Resize', scale=(640, 640), keep_ratio=True), + dict( + type='Pad', size=(640, 640), + pad_val=dict(img=(114, 114, 114))), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) + ], + backend_args=None)) +test_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type='CocoDataset', + data_root='data/ape/', + metainfo=dict(classes=('ape', )), + ann_file='linemod_preprocessed_test.json', + data_prefix=dict(img='rgb/'), + test_mode=True, + pipeline=[ + dict(type='LoadImageFromFile', backend_args=None), + dict(type='Resize', scale=(640, 640), keep_ratio=True), + dict( + type='Pad', size=(640, 640), + pad_val=dict(img=(114, 114, 114))), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) + ], + backend_args=None)) +val_evaluator = dict( + type='CocoMetric', + ann_file='data/ape/linemod_preprocessed_test.json', + metric=['bbox'], + format_only=False, + backend_args=None, + proposal_nums=(100, 1, 10)) +test_evaluator = dict( + type='CocoMetric', + ann_file='data/ape/linemod_preprocessed_test.json', + metric=['bbox'], + format_only=False, + backend_args=None, + proposal_nums=(100, 1, 10)) +tta_model = dict( + type='DetTTAModel', + tta_cfg=dict(nms=dict(type='nms', iou_threshold=0.6), max_per_img=100)) +tta_pipeline = [ + dict(type='LoadImageFromFile', backend_args=None), + dict( + type='TestTimeAug', + transforms=[[{ + 'type': 'Resize', + 'scale': (640, 640), + 'keep_ratio': True + }, { + 'type': 'Resize', + 'scale': (320, 320), + 'keep_ratio': True + }, { + 'type': 'Resize', + 'scale': (960, 960), + 'keep_ratio': True + }], + [{ + 'type': 'RandomFlip', + 'prob': 1.0 + }, { + 'type': 'RandomFlip', + 'prob': 0.0 + }], + [{ + 'type': 'Pad', + 'size': (960, 960), + 'pad_val': { + 'img': (114, 114, 114) + } + }], + [{ + 'type': + 'PackDetInputs', + 'meta_keys': + ('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction') + }]]) +] +model = dict( + type='RTMDet', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[103.53, 116.28, 123.675], + std=[57.375, 57.12, 58.395], + bgr_to_rgb=False, + batch_augments=None), + backbone=dict( + type='CSPNeXt', + arch='P5', + expand_ratio=0.5, + deepen_factor=0.167, + widen_factor=0.375, + channel_attention=True, + norm_cfg=dict(type='SyncBN'), + act_cfg=dict(type='SiLU', inplace=True), + init_cfg=dict( + type='Pretrained', + prefix='backbone.', + checkpoint= + 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-tiny_imagenet_600e.pth' + )), + neck=dict( + type='CSPNeXtPAFPN', + in_channels=[96, 192, 384], + out_channels=96, + num_csp_blocks=1, + expand_ratio=0.5, + norm_cfg=dict(type='SyncBN'), + act_cfg=dict(type='SiLU', inplace=True)), + bbox_head=dict( + type='RTMDetSepBNHead', + num_classes=1, + in_channels=96, + stacked_convs=2, + feat_channels=96, + anchor_generator=dict( + type='MlvlPointGenerator', offset=0, 
strides=[8, 16, 32]), + bbox_coder=dict(type='DistancePointBBoxCoder'), + loss_cls=dict( + type='QualityFocalLoss', + use_sigmoid=True, + beta=2.0, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=2.0), + with_objectness=False, + exp_on_reg=False, + share_conv=True, + pred_kernel_size=1, + norm_cfg=dict(type='SyncBN'), + act_cfg=dict(type='SiLU', inplace=True)), + train_cfg=dict( + assigner=dict(type='DynamicSoftLabelAssigner', topk=13), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=30000, + min_bbox_size=0, + score_thr=0.001, + nms=dict(type='nms', iou_threshold=0.65), + max_per_img=300)) +train_pipeline_stage2 = [ + dict(type='LoadImageFromFile', backend_args=None), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='RandomResize', + scale=(640, 640), + ratio_range=(0.5, 2.0), + keep_ratio=True), + dict(type='RandomCrop', crop_size=(640, 640)), + dict(type='YOLOXHSVRandomAug'), + dict(type='RandomFlip', prob=0.5), + dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))), + dict(type='PackDetInputs') +] +custom_hooks = [ + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0002, + update_buffers=True, + priority=49), + dict( + type='PipelineSwitchHook', + switch_epoch=180, + switch_pipeline=[ + dict(type='LoadImageFromFile', backend_args=None), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='RandomResize', + scale=(640, 640), + ratio_range=(0.5, 2.0), + keep_ratio=True), + dict(type='RandomCrop', crop_size=(640, 640)), + dict(type='YOLOXHSVRandomAug'), + dict(type='RandomFlip', prob=0.5), + dict( + type='Pad', size=(640, 640), + pad_val=dict(img=(114, 114, 114))), + dict(type='PackDetInputs') + ]) +] +launcher = 'none' +work_dir = './work_dirs/rtmdet_tiny_ape' diff --git a/projects/6DofPoseEstimation/configs/rtmpose-s_ape.py b/projects/6DofPoseEstimation/configs/rtmpose-s_ape.py new file mode 100644 index 0000000000..219f29d76d --- /dev/null +++ b/projects/6DofPoseEstimation/configs/rtmpose-s_ape.py @@ -0,0 +1,359 @@ +default_scope = 'mmpose' +default_hooks = dict( + timer=dict(type='IterTimerHook', _scope_='mmpose'), + logger=dict(type='LoggerHook', interval=1, _scope_='mmpose'), + param_scheduler=dict(type='ParamSchedulerHook', _scope_='mmpose'), + checkpoint=dict( + type='CheckpointHook', + interval=10, + _scope_='mmpose', + save_best='PCK', + rule='greater', + max_keep_ckpts=2), + sampler_seed=dict(type='DistSamplerSeedHook', _scope_='mmpose'), + visualization=dict( + type='PoseVisualizationHook', enable=False, _scope_='mmpose')) +custom_hooks = [ + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0002, + update_buffers=True, + priority=49), + dict( + type='mmdet.PipelineSwitchHook', + switch_epoch=300, + switch_pipeline=[ + dict(type='LoadImage', backend_args=dict(backend='local')), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict( + type='RandomBBoxTransform', + shift_factor=0.0, + scale_factor=[0.75, 1.25], + rotate_factor=60), + dict(type='TopdownAffine', input_size=(256, 256)), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict( + type='Albumentation', + transforms=[ + dict(type='Blur', p=0.1), + dict(type='MedianBlur', p=0.1), + dict( + type='CoarseDropout', + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=0.5) + ]), + dict( + type='GenerateTarget', + encoder=dict( + type='SimCCLabel', + input_size=(256, 256), + sigma=(12, 12), + 
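+                    # SimCC treats x/y localisation as per-axis classification;
+                    # split_ratio 2.0 allocates two classification bins per pixel
+                    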
simcc_split_ratio=2.0, + normalize=False, + use_dark=False)), + dict(type='PackPoseInputs') + ]) +] +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl')) +vis_backends = [dict(type='LocalVisBackend', _scope_='mmpose')] +visualizer = dict( + type='PoseLocalVisualizer', + vis_backends=[dict(type='LocalVisBackend')], + name='visualizer', + _scope_='mmpose') +log_processor = dict( + type='LogProcessor', + window_size=50, + by_epoch=True, + num_digits=6, + _scope_='mmpose') +log_level = 'INFO' +load_from = None +resume = False +backend_args = dict(backend='local') +train_cfg = dict(by_epoch=True, max_epochs=300, val_interval=10) +val_cfg = dict() +test_cfg = dict() +dataset_type = 'CocoDataset' +data_mode = 'topdown' +data_root = 'data/ape/' +dataset_info = dict(from_file='configs/_base_/datasets/linemod.py') +NUM_KEYPOINTS = 8 +max_epochs = 300 +val_interval = 10 +train_batch_size = 32 +val_batch_size = 8 +stage2_num_epochs = 0 +base_lr = 0.0005 +randomness = dict(seed=21) +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.0005, weight_decay=0.05), + paramwise_cfg=dict( + norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True)) +param_scheduler = [ + dict(type='LinearLR', start_factor=1e-05, by_epoch=False, begin=0, end=20), + dict( + type='CosineAnnealingLR', + eta_min=2.5e-05, + begin=150, + end=300, + T_max=150, + by_epoch=True, + convert_to_iter_based=True) +] +auto_scale_lr = dict(base_batch_size=1024) +codec = dict( + type='SimCCLabel', + input_size=(256, 256), + sigma=(12, 12), + simcc_split_ratio=2.0, + normalize=False, + use_dark=False) +model = dict( + type='TopdownPoseEstimator', + data_preprocessor=dict( + type='PoseDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True), + backbone=dict( + _scope_='mmdet', + type='CSPNeXt', + arch='P5', + expand_ratio=0.5, + deepen_factor=0.67, + widen_factor=0.75, + out_indices=(4, ), + channel_attention=True, + norm_cfg=dict(type='SyncBN'), + act_cfg=dict(type='SiLU'), + init_cfg=dict( + type='Pretrained', + prefix='backbone.', + checkpoint= + 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-m_8xb256-rsb-a1-600e_in1k-ecb3bbd9.pth' + )), + head=dict( + type='RTMCCHead', + in_channels=768, + out_channels=8, + input_size=(256, 256), + in_featuremap_size=(8, 8), + simcc_split_ratio=2.0, + final_layer_kernel_size=7, + gau_cfg=dict( + hidden_dims=256, + s=128, + expansion_factor=2, + dropout_rate=0.0, + drop_path=0.0, + act_fn='SiLU', + use_rel_bias=False, + pos_enc=False), + loss=dict( + type='KLDiscretLoss', + use_target_weight=True, + beta=10.0, + label_softmax=True), + decoder=dict( + type='SimCCLabel', + input_size=(256, 256), + sigma=(12, 12), + simcc_split_ratio=2.0, + normalize=False, + use_dark=False)), + test_cfg=dict(flip_test=True)) +train_pipeline = [ + dict(type='LoadImage', backend_args=dict(backend='local')), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict( + type='RandomBBoxTransform', scale_factor=[0.8, 1.2], rotate_factor=30), + dict(type='TopdownAffine', input_size=(256, 256)), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict( + type='Albumentation', + transforms=[ + dict(type='ChannelShuffle', p=0.5), + dict(type='CLAHE', p=0.5), + dict(type='ColorJitter', p=0.5), + dict( + type='CoarseDropout', + max_holes=4, + max_height=0.3, + max_width=0.3, + min_holes=1, + min_height=0.2, + 
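+                    # random rectangular occlusion; encourages the head to
+                    # localise cuboid corners that are not directly visible
+                    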
min_width=0.2, + p=0.5) + ]), + dict( + type='GenerateTarget', + encoder=dict( + type='SimCCLabel', + input_size=(256, 256), + sigma=(12, 12), + simcc_split_ratio=2.0, + normalize=False, + use_dark=False)), + dict(type='PackPoseInputs') +] +val_pipeline = [ + dict(type='LoadImage', backend_args=dict(backend='local')), + dict(type='GetBBoxCenterScale'), + dict(type='TopdownAffine', input_size=(256, 256)), + dict(type='PackPoseInputs') +] +train_pipeline_stage2 = [ + dict(type='LoadImage', backend_args=dict(backend='local')), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomHalfBody'), + dict( + type='RandomBBoxTransform', + shift_factor=0.0, + scale_factor=[0.75, 1.25], + rotate_factor=60), + dict(type='TopdownAffine', input_size=(256, 256)), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict( + type='Albumentation', + transforms=[ + dict(type='Blur', p=0.1), + dict(type='MedianBlur', p=0.1), + dict( + type='CoarseDropout', + max_holes=1, + max_height=0.4, + max_width=0.4, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=0.5) + ]), + dict( + type='GenerateTarget', + encoder=dict( + type='SimCCLabel', + input_size=(256, 256), + sigma=(12, 12), + simcc_split_ratio=2.0, + normalize=False, + use_dark=False)), + dict(type='PackPoseInputs') +] +train_dataloader = dict( + batch_size=32, + num_workers=1, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type='CocoDataset', + data_root='data/ape/', + metainfo=dict(from_file='configs/_base_/datasets/linemod.py'), + data_mode='topdown', + ann_file='linemod_preprocessed_train.json', + data_prefix=dict(img='rgb/'), + pipeline=[ + dict(type='LoadImage', backend_args=dict(backend='local')), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict( + type='RandomBBoxTransform', + scale_factor=[0.8, 1.2], + rotate_factor=30), + dict(type='TopdownAffine', input_size=(256, 256)), + dict(type='mmdet.YOLOXHSVRandomAug'), + dict( + type='Albumentation', + transforms=[ + dict(type='ChannelShuffle', p=0.5), + dict(type='CLAHE', p=0.5), + dict(type='ColorJitter', p=0.5), + dict( + type='CoarseDropout', + max_holes=4, + max_height=0.3, + max_width=0.3, + min_holes=1, + min_height=0.2, + min_width=0.2, + p=0.5) + ]), + dict( + type='GenerateTarget', + encoder=dict( + type='SimCCLabel', + input_size=(256, 256), + sigma=(12, 12), + simcc_split_ratio=2.0, + normalize=False, + use_dark=False)), + dict(type='PackPoseInputs') + ])) +val_dataloader = dict( + batch_size=8, + num_workers=1, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type='CocoDataset', + data_root='data/ape/', + metainfo=dict(from_file='configs/_base_/datasets/linemod.py'), + data_mode='topdown', + ann_file='linemod_preprocessed_test.json', + data_prefix=dict(img='rgb/'), + pipeline=[ + dict(type='LoadImage', backend_args=dict(backend='local')), + dict(type='GetBBoxCenterScale'), + dict(type='TopdownAffine', input_size=(256, 256)), + dict(type='PackPoseInputs') + ])) +test_dataloader = dict( + batch_size=8, + num_workers=1, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type='CocoDataset', + data_root='data/ape/', + metainfo=dict(from_file='configs/_base_/datasets/linemod.py'), + data_mode='topdown', + ann_file='linemod_preprocessed_test.json', + data_prefix=dict(img='rgb/'), + pipeline=[ + 
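+            # evaluation pipeline: deterministic bbox-centred affine crop to
+            # 256x256, no augmentation
+            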
dict(type='LoadImage', backend_args=dict(backend='local')), + dict(type='GetBBoxCenterScale'), + dict(type='TopdownAffine', input_size=(256, 256)), + dict(type='PackPoseInputs') + ])) +val_evaluator = [ + dict( + type='CocoMetric', ann_file='data/ape/linemod_preprocessed_test.json'), + dict(type='PCKAccuracy'), + dict(type='AUC'), + dict(type='NME', norm_mode='keypoint_distance', keypoint_indices=[1, 2]) +] +test_evaluator = [ + dict( + type='CocoMetric', ann_file='data/ape/linemod_preprocessed_test.json'), + dict(type='PCKAccuracy'), + dict(type='AUC'), + dict(type='NME', norm_mode='keypoint_distance', keypoint_indices=[1, 2]) +] +launcher = 'none' +work_dir = './work_dirs/rtmpose-s_ape' diff --git a/projects/6DofPoseEstimation/datasets/__init__.py b/projects/6DofPoseEstimation/datasets/__init__.py new file mode 100644 index 0000000000..3a2ad6a3c1 --- /dev/null +++ b/projects/6DofPoseEstimation/datasets/__init__.py @@ -0,0 +1,6 @@ +from .linemod_det import LineMODDetCocoDataset +from .linemod_keypoint import LineMODKeypointCocoDataset +from .transforms import CopyPaste6D + +__all__ = ['LineMODDetCocoDataset', 'LineMODKeypointCocoDataset', + 'CopyPaste6D'] diff --git a/projects/6DofPoseEstimation/datasets/linemod_det.py b/projects/6DofPoseEstimation/datasets/linemod_det.py new file mode 100644 index 0000000000..fe0b474e4b --- /dev/null +++ b/projects/6DofPoseEstimation/datasets/linemod_det.py @@ -0,0 +1,64 @@ +import os.path as osp +from mmdet.datasets import CocoDataset +from mmdet.registry import DATASETS + +from typing import Optional, Sequence, Union, List, Callable + +@DATASETS.register_module() +class LineMODDetCocoDataset(CocoDataset): + """Dataset for LineMOD with COCO keypoint style""" + def __init__(self, + *args, + background_path: str = '', + **kwargs): + + super().__init__(*args, **kwargs) + self.background_path = background_path + + def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]: + """Parse raw annotation to target format.""" + img_info = raw_data_info['raw_img_info'] + ann_info = raw_data_info['raw_ann_info'] + + data_info = {} + + img_path = osp.join(self.data_prefix['img'], img_info['file_name']) + data_info['img_path'] = img_path + + mask_path = img_path.replace('rgb', 'mask') + data_info['mask_path'] = mask_path + + data_info['img_id'] = img_info['img_id'] + data_info['height'] = img_info['height'] + data_info['width'] = img_info['width'] + + instances = [] + for i, ann in enumerate(ann_info): + instance = {} + + if ann.get('ignore', False): + continue + x1, y1, w, h = ann['bbox'] + inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0)) + inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0)) + if inter_w * inter_h == 0: + continue + if ann['area'] <= 0 or w < 1 or h < 1: + continue + if ann['category_id'] not in self.cat_ids: + continue + bbox = [x1, y1, x1 + w, y1 + h] + + if ann.get('iscrowd', False): + instance['ignore_flag'] = 1 + else: + instance['ignore_flag'] = 0 + instance['bbox'] = bbox + instance['bbox_label'] = self.cat2label[ann['category_id']] + + if ann.get('segmentation', None): + instance['mask'] = ann['segmentation'] + + instances.append(instance) + data_info['instances'] = instances + return data_info \ No newline at end of file diff --git a/projects/6DofPoseEstimation/datasets/linemod_keypoint.py b/projects/6DofPoseEstimation/datasets/linemod_keypoint.py new file mode 100644 index 0000000000..5d6983ee48 --- /dev/null +++ b/projects/6DofPoseEstimation/datasets/linemod_keypoint.py @@ -0,0 +1,76 @@ +import copy 
+import numpy as np
+from mmpose.datasets import BaseCocoStyleDataset
+from mmpose.registry import DATASETS
+
+from typing import List, Union
+
+
+@DATASETS.register_module()
+class LineMODKeypointCocoDataset(BaseCocoStyleDataset):
+    """Dataset for LineMOD in COCO keypoint style."""
+
+    def __init__(self, *args, background_path: str = '', **kwargs):
+        super().__init__(*args, **kwargs)
+        self.background_path = background_path
+
+    def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
+        """Parse raw annotation to target format."""
+        img = raw_data_info['raw_img_info']
+        ann = raw_data_info['raw_ann_info']
+
+        img_path = img['img_path']
+        mask_path = img_path.replace('rgb', 'mask')
+
+        # filter invalid instance
+        if 'bbox' not in ann or 'keypoints' not in ann:
+            return None
+
+        img_w, img_h = img['width'], img['height']
+
+        # clip the xywh annotation and convert it to xyxy in shape [1, 4]
+        x, y, w, h = ann['bbox']
+        x1 = np.clip(x, 0, img_w - 1)
+        y1 = np.clip(y, 0, img_h - 1)
+        x2 = np.clip(x + w, 0, img_w - 1)
+        y2 = np.clip(y + h, 0, img_h - 1)
+
+        bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4)
+
+        # keypoints in shape [1, K, 2] and keypoints_visible in [1, K]
+        _keypoints = np.array(
+            ann['keypoints'], dtype=np.float32).reshape(1, -1, 3)
+        keypoints = _keypoints[..., :2]
+        keypoints_visible = np.minimum(1, _keypoints[..., 2])
+
+        if 'num_keypoints' in ann:
+            num_keypoints = ann['num_keypoints']
+        else:
+            num_keypoints = np.count_nonzero(keypoints.max(axis=2))
+
+        data_info = {
+            'img_id': ann['image_id'],
+            'img_path': img_path,
+            'mask_path': mask_path,
+            'bbox': bbox,
+            'bbox_score': np.ones(1, dtype=np.float32),
+            'num_keypoints': num_keypoints,
+            'keypoints': keypoints,
+            'keypoints_visible': keypoints_visible,
+            'iscrowd': ann.get('iscrowd', 0),
+            'segmentation': ann.get('segmentation', None),
+            'id': ann['id'],
+            'category_id': ann['category_id'],
+            # store the raw annotation of the instance
+            # it is useful for evaluation without providing ann_file
+            'raw_ann_info': copy.deepcopy(ann),
+        }
+
+        if 'crowdIndex' in img:
+            data_info['crowd_index'] = img['crowdIndex']
+
+        return data_info
diff --git a/projects/6DofPoseEstimation/datasets/transforms.py b/projects/6DofPoseEstimation/datasets/transforms.py
new file mode 100644
index 0000000000..8cd3f4c5fb
--- /dev/null
+++ b/projects/6DofPoseEstimation/datasets/transforms.py
@@ -0,0 +1,53 @@
+import os
+import random
+
+import cv2
+import numpy as np
+from mmcv.transforms import BaseTransform
+from PIL import Image, ImageMath
+
+from mmpose.registry import TRANSFORMS
+
+
+def change_background(img, mask, bg):
+    # composite channel by channel: keep the object pixels where the mask is
+    # set and fill the remaining pixels from the background image
+    ow, oh = img.size
+    bg = bg.resize((ow, oh)).convert('RGB')
+
+    imcs = list(img.split())
+    bgcs = list(bg.split())
+    maskcs = list(mask.split())
+    fics = list(Image.new(img.mode, img.size).split())
+
+    for c in range(len(imcs)):
+        negmask = maskcs[c].point(lambda i: 1 - i / 255)
+        posmask = maskcs[c].point(lambda i: i / 255)
+        fics[c] = ImageMath.eval(
+            'a * c + b * d', a=imcs[c], b=bgcs[c], c=posmask,
+            d=negmask).convert('L')
+    out = Image.merge(img.mode, tuple(fics))
+
+    return out
+
+
+@TRANSFORMS.register_module()
+class CopyPaste6D(BaseTransform):
+    """Replace the image background with a randomly chosen picture."""
+
+    def __init__(self, background_path=None):
+        background_path_list = []
+        for filename in os.listdir(background_path):
+            if filename.endswith(('.jpg', '.png')):
+                background_path_list.append(
+                    os.path.join(background_path, filename))
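+        # cache the candidate background file paths once at construction;
+        # transform() draws one of them at random for every image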
+        self.background_path_list = background_path_list
+
+    def transform(self, results: dict) -> dict:
+        # paste the masked object onto a random background image
+        img = results['img']
+        maskpath = results.get('mask_path')
+
+        if maskpath is not None:
+            bgpath = random.choice(self.background_path_list)
+            img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+            mask = Image.open(maskpath)
+            bg = Image.open(bgpath)
+            img = change_background(img, mask, bg)
+            img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
+            results['img'] = img
+
+        return results
diff --git a/projects/6DofPoseEstimation/demo.py b/projects/6DofPoseEstimation/demo.py
new file mode 100644
index 0000000000..9488b6f00a
--- /dev/null
+++ b/projects/6DofPoseEstimation/demo.py
@@ -0,0 +1,135 @@
+import argparse
+import os
+import time
+
+import matplotlib.pyplot as plt
+import mmcv
+import mmengine
+import numpy as np
+import torch
+from mmdet.apis import inference_detector, init_detector
+from mmengine.registry import init_default_scope
+from mmengine.visualization import Visualizer
+from mmpose.apis import inference_topdown
+from mmpose.apis import init_model as init_pose_estimator
+from mmpose.evaluation.functional import nms
+from mmpose.structures import merge_data_samples
+
+from utils import *
+
+
+def predict(image_path):
+
+    Visualizer.get_instance('visualization_hook')
+
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+    # NOTE: hard-coded to the author's work_dirs; point these at your own
+    # config files and checkpoints before running.
+    detector = init_detector(
+        '/home/liuyoufu/code/mmpose-openmmlab/mmpose/work_dirs/rtmdet-tiny_ape/rtmdet-tiny_ape.py',
+        '/home/liuyoufu/code/mmpose-openmmlab/mmpose/work_dirs/rtmdet-tiny_ape/best_coco_bbox_mAP_epoch_198.pth',
+        device=device)
+
+    pose_estimator = init_pose_estimator(
+        '/home/liuyoufu/code/mmpose-openmmlab/mmpose/work_dirs/rtmpose-s_ape/rtmpose-s_ape.py',
+        '/home/liuyoufu/code/mmpose-openmmlab/mmpose/work_dirs/rtmpose-s_ape/best_PCK_epoch_260.pth',
+        device=device,
+        cfg_options={'model': {'test_cfg': {'output_heatmaps': True}}})
+
+    init_default_scope(detector.cfg.get('default_scope', 'mmdet'))
+
+    # start timing once the models are initialised
+    start_time = time.time()
+
+    detect_result = inference_detector(detector, image_path)
+    CONF_THRES = 0.004
+
+    pred_instance = detect_result.pred_instances.cpu().numpy()
+    bboxes = np.concatenate(
+        (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
+    bboxes = bboxes[np.logical_and(pred_instance.labels == 0,
+                                   pred_instance.scores > CONF_THRES)]
+    bboxes = bboxes[nms(bboxes, 0.3)][:, :4].astype('int')
+
+    pose_results = inference_topdown(pose_estimator, image_path, bboxes)
+    data_samples = merge_data_samples(pose_results)
+    keypoints = data_samples.pred_instances.keypoints.astype('int')
+
+    return keypoints, start_time
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='demo')
+    parser.add_argument('--image-path', help='image path')
+    parser.add_argument('--id', type=str, help='object id, for example: 01')
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    image_path = args.image_path
+    obj_id = args.id
+
+    # split the image path into the dataset root and the image file name
+    file_name = os.path.basename(image_path)
+    root_path = '/'.join(image_path.split('/')[:-4]) + '/'
+    # get model_info_dict for this object id
+    model_info_path = root_path + 'models/'
+    model_info_dict = mmengine.load(model_info_path + 'models_info.yml')
+    object_path = os.path.join(root_path, f'data/{args.id}/')
+    info_dict = mmengine.load(object_path + 'info.yml')
+    gt_dict = mmengine.load(object_path + 'gt.yml')
+
+    # look up the camera intrinsics of this frame by its numeric file name
+    intrinsic = np.array(
+        info_dict[int(file_name.split('.')[0])]['cam_K']).reshape(3, 3)
+
+    # get corners3D (8x3)
+    corners3D = get_3D_corners(model_info_dict, obj_id)
+
+    # get the predicted 2D corners
+    keypoint_pr, start_time = predict(image_path)
+    corners2D_pr = keypoint_pr.reshape(-1, 2)
+
+    # compute [R|t] by PnP ===== pred
+    R_pr, t_pr = pnp(corners3D, corners2D_pr,
+                     np.array(intrinsic, dtype='float32'))
+    Rt_pr = np.concatenate((R_pr, t_pr), axis=1)
+    proj_corners_pr = np.transpose(
+        compute_projection(corners3D, Rt_pr, intrinsic))
+    print(f'prediction + PnP took {time.time() - start_time:.3f} s')
+
+    # read the ground-truth [R|t] from gt.yml ===== gt
+    R_gt = np.array(
+        gt_dict[int(file_name.split('.')[0])][0]['cam_R_m2c']).reshape(3, 3)
+    t_gt = np.array(
+        gt_dict[int(file_name.split('.')[0])][0]['cam_t_m2c']).reshape(3, 1)
+    Rt_gt = np.concatenate((R_gt, t_gt), axis=1)
+    proj_corners_gt = np.transpose(
+        compute_projection(corners3D, Rt_gt, intrinsic))
+
+    image = mmcv.imread(image_path)
+    height, width = image.shape[:2]
+
+    plt.xlim((0, width))
+    plt.ylim((0, height))
+    plt.imshow(image[:, :, ::-1])  # BGR -> RGB for matplotlib
+
+    # draw the projected cuboid edges: green = prediction, blue = ground truth
+    edges_corners = [[0, 1], [0, 2], [0, 4], [1, 3], [1, 5], [2, 3],
+                     [2, 6], [3, 7], [4, 5], [4, 6], [5, 7], [6, 7]]
+    for edge in edges_corners:
+        plt.plot(proj_corners_pr[edge, 0], proj_corners_pr[edge, 1],
+                 color='g', linewidth=2.0)
+        plt.plot(proj_corners_gt[edge, 0], proj_corners_gt[edge, 1],
+                 color='b', linewidth=2.0)
+    plt.gca().invert_yaxis()
+
+    os.makedirs('./photo', exist_ok=True)
+    plt.savefig(os.path.join('./photo', file_name))
+    plt.show()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/projects/6DofPoseEstimation/inference.py b/projects/6DofPoseEstimation/inference.py
new file mode 100644
index 0000000000..56020875ef
--- /dev/null
+++ b/projects/6DofPoseEstimation/inference.py
@@ -0,0 +1,123 @@
+import argparse
+import os
+
+import matplotlib.pyplot as plt
+import mmcv
+import mmengine
+import numpy as np
+import torch
+from mmdet.apis import inference_detector, init_detector
+from mmengine.registry import init_default_scope
+from mmengine.visualization import Visualizer
+from mmpose.apis import inference_topdown
+from mmpose.apis import init_model as init_pose_estimator
+from mmpose.evaluation.functional import nms
+from mmpose.structures import merge_data_samples
+
+from utils import *
+
+
+def predict(image_path):
+
+    Visualizer.get_instance('visualization_hook')
+
+    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+    # NOTE: hard-coded to the author's work_dirs; point these at your own
+    # config files and checkpoints before running.
+    detector = init_detector(
+        '/home/liuyoufu/code/mmpose-openmmlab/mmpose/work_dirs/rtmdet-tiny_ape/rtmdet-tiny_ape.py',
+        '/home/liuyoufu/code/mmpose-openmmlab/mmpose/work_dirs/rtmdet-tiny_ape/best_coco_bbox_mAP_epoch_198.pth',
+        device=device)
+
+    pose_estimator = init_pose_estimator(
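+        # top-down RTMPose model that predicts the 8 projected cuboid corners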
+        '/home/liuyoufu/code/mmpose-openmmlab/mmpose/work_dirs/rtmpose-s_ape/rtmpose-s_ape.py',
+        '/home/liuyoufu/code/mmpose-openmmlab/mmpose/work_dirs/rtmpose-s_ape/best_PCK_epoch_260.pth',
+        device=device,
+        cfg_options={'model': {'test_cfg': {'output_heatmaps': True}}})
+
+    init_default_scope(detector.cfg.get('default_scope', 'mmdet'))
+
+    detect_result = inference_detector(detector, image_path)
+    CONF_THRES = 0.5
+
+    pred_instance = detect_result.pred_instances.cpu().numpy()
+    bboxes = np.concatenate(
+        (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
+    bboxes = bboxes[np.logical_and(pred_instance.labels == 0,
+                                   pred_instance.scores > CONF_THRES)]
+    bboxes = bboxes[nms(bboxes, 0.3)][:, :4].astype('int')
+
+    pose_results = inference_topdown(pose_estimator, image_path, bboxes)
+    data_samples = merge_data_samples(pose_results)
+    keypoints = data_samples.pred_instances.keypoints.astype('int')
+
+    return keypoints
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='demo')
+    parser.add_argument('--image-path', help='image path')
+    parser.add_argument(
+        '--root-path', type=str, help='dataset root path (trailing slash)')
+    parser.add_argument('--id', type=str, help='object id, for example: 01')
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    image_path = args.image_path
+    obj_id = args.id
+
+    # the dataset root comes from the command line here
+    root_path = args.root_path
+
+    # get model_info_dict for this object id
+    model_info_path = root_path + 'models/'
+    model_info_dict = mmengine.load(model_info_path + 'models_info.yml')
+    object_path = os.path.join(root_path, f'data/{args.id}/')
+    info_dict = mmengine.load(object_path + 'info.yml')
+
+    # LineMOD uses the same intrinsics for every frame, so frame 0 suffices
+    intrinsic = np.array(info_dict[0]['cam_K']).reshape(3, 3)
+
+    # get corners3D (8x3)
+    corners3D = get_3D_corners(model_info_dict, obj_id)
+
+    # get the predicted 2D corners
+    keypoint_pr = predict(image_path)
+    corners2D_pr = keypoint_pr.reshape(-1, 2)
+
+    # compute [R|t] by PnP ===== pred
+    R_pr, t_pr = pnp(corners3D, corners2D_pr,
+                     np.array(intrinsic, dtype='float32'))
+    Rt_pr = np.concatenate((R_pr, t_pr), axis=1)
+    proj_corners_pr = np.transpose(
+        compute_projection(corners3D, Rt_pr, intrinsic))
+
+    image = mmcv.imread(image_path)
+    height, width = image.shape[:2]
+
+    plt.xlim((0, width))
+    plt.ylim((0, height))
+    plt.imshow(image[:, :, ::-1])  # BGR -> RGB for matplotlib
+
+    # draw the projected cuboid edges of the prediction
+    edges_corners = [[0, 1], [0, 2], [0, 4], [1, 3], [1, 5], [2, 3],
+                     [2, 6], [3, 7], [4, 5], [4, 6], [5, 7], [6, 7]]
+    for edge in edges_corners:
+        plt.plot(proj_corners_pr[edge, 0], proj_corners_pr[edge, 1],
+                 color='g', linewidth=2.0)
+    plt.gca().invert_yaxis()
+
+    os.makedirs('./photo', exist_ok=True)
+    plt.savefig(os.path.join('./photo', os.path.basename(image_path)))
+    plt.show()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/projects/6DofPoseEstimation/tools/linemod_to_coco.py b/projects/6DofPoseEstimation/tools/linemod_to_coco.py
new file mode 100644
index 0000000000..9006b83e6d
--- /dev/null
+++ b/projects/6DofPoseEstimation/tools/linemod_to_coco.py
@@ -0,0 +1,200 @@
+import argparse
+import os
+
+import mmcv
+import mmengine
+import numpy as np
+
+
+def parse_examples(data_file):
+    if not os.path.isfile(data_file):
+        print(f'Error: file {data_file} does not exist!')
+        return None
+
+    with open(data_file) as fid:
+        data_examples = [
+            example.strip() for example in fid if example.strip() != ''
+        ]
+
+    return data_examples
+
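+# Build the COCO 'images' list: one record per RGB frame named in the split
+# file, with ids following the enumeration order.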
+def images_info(object_path, data_examples):
+    all_images_path = os.path.join(object_path, 'rgb')
+    all_filenames = [
+        filename for filename in os.listdir(all_images_path)
+        if '.png' in filename and filename.replace('.png', '') in data_examples
+    ]
+    image_paths = [
+        os.path.join(all_images_path, filename) for filename in all_filenames
+    ]
+    images = []
+    for id, image_path in enumerate(image_paths):
+        img = mmcv.imread(image_path)
+        height, width = img.shape[:2]
+        images.append(
+            dict(file_name=all_filenames[id], height=height, width=width,
+                 id=id))
+    return images
+
+
+def project_points_3D_to_2D(points_3D, rotation_vector, translation_vector,
+                            camera_matrix):
+    points_3D = points_3D.reshape(3, 1)
+    rotation_vector = rotation_vector.reshape(3, 3)
+    translation_vector = translation_vector.reshape(3, 1)
+    pixel = camera_matrix.dot(
+        rotation_vector.dot(points_3D) + translation_vector)
+    pixel /= pixel[-1]
+    points_2D = pixel[:2]
+
+    return points_2D
+
+
+def insert_np_cam_calibration(filtered_infos):
+    for info in filtered_infos:
+        info['cam_K_np'] = np.reshape(np.array(info['cam_K']), newshape=(3, 3))
+
+    return filtered_infos
+
+
+def get_bbox_from_mask(mask, mask_value=None):
+    if mask_value is None:
+        seg = np.where(mask != 0)
+    else:
+        seg = np.where(mask == mask_value)
+    # an empty mask yields an all-zero box
+    if seg[0].size <= 0 or seg[1].size <= 0:
+        return np.zeros((4, ), dtype=np.float32)
+    min_x = np.min(seg[1])
+    min_y = np.min(seg[0])
+    max_x = np.max(seg[1])
+    max_y = np.max(seg[0])
+
+    return np.array([min_x, min_y, max_x - min_x, max_y - min_y],
+                    dtype=np.float32)
+
+
+def annotations_info(object_path, data_examples, gt_dict, info_dict,
+                     model_info_dict, obj_id):
+    all_images_path = os.path.join(object_path, 'rgb')
+    all_filenames = [
+        filename for filename in os.listdir(all_images_path)
+        if '.png' in filename and filename.replace('.png', '') in data_examples
+    ]
+    image_paths = [
+        os.path.join(all_images_path, filename) for filename in all_filenames
+    ]
+    mask_paths = [
+        image_path.replace('rgb', 'mask') for image_path in image_paths
+    ]
+
+    example_ids = [int(filename.split('.')[0]) for filename in all_filenames]
+    filtered_gt_lists = [gt_dict[key] for key in example_ids]
+    filtered_gts = []
+    for gt_list in filtered_gt_lists:
+        all_annos = [anno for anno in gt_list if anno['obj_id'] == int(obj_id)]
+        if len(all_annos) <= 0:
+            print('\nError: No annotation found!')
+            filtered_gts.append(None)
+        elif len(all_annos) > 1:
+            print('\nWarning: found more than one annotation, '
+                  'using only the first one')
+            filtered_gts.append(all_annos[0])
+        else:
+            filtered_gts.append(all_annos[0])
+
+    filtered_infos = [info_dict[key] for key in example_ids]
+    info_list = insert_np_cam_calibration(filtered_infos)
+
+    id = 0
+    annotations = []
+    # derive the bbox from the mask and the keypoints from the 3D model
+    for gt, info, mask_path in zip(filtered_gts, info_list, mask_paths):
+        if gt is None:
+            # no annotation for this frame; keep ids aligned with the images
+            id += 1
+            continue
+        mask = mmcv.imread(mask_path)
+        annotation = {}
+        annotation['category_id'] = 1
+        annotation['segmentation'] = []
+        annotation['iscrowd'] = 0
+        annotation['image_id'] = id
+        # each image contains a single instance, so image_id == id
+        annotation['id'] = id
+        bbox = get_bbox_from_mask(mask)
+        annotation['bbox'] = bbox.tolist()
+        annotation['area'] = float(bbox[2] * bbox[3])
+        annotation['num_keypoints'] = 8
+
+        # COCO visibility flags: 0 = not labeled (point stored as [0, 0]),
+        # 1 = labeled but occluded, 2 = labeled and visible
+        min_x = model_info_dict[int(obj_id)]['min_x']
+        min_y = model_info_dict[int(obj_id)]['min_y']
+        min_z = model_info_dict[int(obj_id)]['min_z']
+        max_x = min_x + model_info_dict[int(obj_id)]['size_x']
+        max_y = min_y + model_info_dict[int(obj_id)]['size_y']
+        max_z = min_z + model_info_dict[int(obj_id)]['size_z']
+        corners = np.array([[max_x, max_y, min_z],
+                            [max_x, max_y, max_z],
+                            [max_x, min_y, min_z],
+                            [max_x, min_y, max_z],
+                            [min_x, max_y, min_z],
+                            [min_x, max_y, max_z],
+                            [min_x, min_y, min_z],
+                            [min_x, min_y, max_z]])
+        corners = [
+            project_points_3D_to_2D(corner, np.array(gt['cam_R_m2c']),
+                                    np.array(gt['cam_t_m2c']),
+                                    info['cam_K_np'])
+            for corner in corners
+        ]
+        corners = np.array(corners).reshape(8, 2)
+        # mark every projected corner as visible (flag 2)
+        tmp = np.array([2] * 8).reshape(8, 1)
+        corners = np.hstack((corners, tmp))
+        annotation['keypoints'] = corners.reshape(-1).tolist()
+
+        id += 1
+        annotations.append(annotation)
+    return annotations
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Create_linemod_json')
+    parser.add_argument('--root', help='root path (with trailing slash)')
+    parser.add_argument('--id', type=str, help='object id, for example: 01')
+    parser.add_argument('--mode', type=str, help='mode, for example: train')
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+
+    object_path = os.path.join(args.root, f'data/{args.id}/')
+    data_examples = parse_examples(object_path + args.mode + '.txt')
+    gt_dict = mmengine.load(object_path + 'gt.yml')
+    info_dict = mmengine.load(object_path + 'info.yml')
+    obj_id = args.id
+    model_info_path = args.root + 'models/'
+    model_info_dict = mmengine.load(model_info_path + 'models_info.yml')
+
+    # images
+    images = images_info(object_path, data_examples)
+
+    # annotations
+    annotations = annotations_info(object_path, data_examples, gt_dict,
+                                   info_dict, model_info_dict, obj_id)
+
+    # categories
+    object = [{
+        'supercategory': 'ape',
+        'id': 1,
+        'name': 'ape',
+        'keypoints': [
+            'min_min_min', 'min_min_max',
+            'min_max_min', 'min_max_max',
+            'max_min_min', 'max_min_max',
+            'max_max_min', 'max_max_max'],
+        'skeleton': [[0, 4], [1, 5], [3, 7], [6, 2],
+                     [0, 2], [1, 3], [7, 5], [4, 6],
+                     [0, 1], [7, 6], [5, 4], [2, 3]],
+    }]
+
+    # assemble the COCO-format dict and dump it to json
+    linemod_coco = {
+        'categories': object,
+        'images': images,
+        'annotations': annotations
+    }
+    os.makedirs(args.root + 'json/', exist_ok=True)
+    out_file = args.root + 'json/linemod_preprocessed_' + args.mode + '.json'
+    mmengine.dump(linemod_coco, out_file)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/projects/6DofPoseEstimation/utils/__init__.py b/projects/6DofPoseEstimation/utils/__init__.py
new file mode 100644
index 0000000000..ecae66b5f4
--- /dev/null
+++ b/projects/6DofPoseEstimation/utils/__init__.py
@@ -0,0 +1,4 @@
+from .vis import *
+
+__all__ = ['get_3D_corners', 'pnp', 'project_points_3D_to_2D',
+           'compute_projection']
diff --git a/projects/6DofPoseEstimation/utils/vis.py b/projects/6DofPoseEstimation/utils/vis.py
new file mode 100644
index 0000000000..ed596c1f66
--- /dev/null
+++ b/projects/6DofPoseEstimation/utils/vis.py
@@ -0,0 +1,64 @@
+import cv2
+import numpy as np
+
+
+def get_3D_corners(model_info_dict, obj_id):
+    # corners of the model's axis-aligned 3D bounding box from models_info.yml
+    min_x = model_info_dict[int(obj_id)]['min_x']
+    min_y = model_info_dict[int(obj_id)]['min_y']
+    min_z = model_info_dict[int(obj_id)]['min_z']
+    max_x = min_x + model_info_dict[int(obj_id)]['size_x']
+    max_y = min_y + model_info_dict[int(obj_id)]['size_y']
+    max_z = min_z + model_info_dict[int(obj_id)]['size_z']
+    corners = np.array([[max_x, max_y, min_z],
+                        [max_x, max_y, max_z],
+                        [max_x, min_y, min_z],
+                        [max_x, min_y, max_z],
+                        [min_x, max_y, min_z],
+                        [min_x, max_y, max_z],
+                        [min_x, min_y, min_z],
+                        [min_x, min_y, max_z]])
+    return corners
+
+
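+# Recover the object rotation R and translation t from 2D-3D corner
+# correspondences with OpenCV's PnP solver.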
+def pnp(points_3D, points_2D, cameraMatrix):
+    # distortion coefficients may be attached to the function as an
+    # attribute (pnp.distCoeffs); default to a zero-distortion model
+    distCoeffs = getattr(pnp, 'distCoeffs',
+                         np.zeros((8, 1), dtype='float32'))
+
+    assert points_3D.shape[0] == points_2D.shape[0], \
+        'points 3D and points 2D must have same number of vertices'
+
+    points_2D = points_2D.astype(np.float32)[0:8]
+    points_3D = points_3D.astype(np.float32)
+    _, R_exp, t = cv2.solvePnP(points_3D, points_2D.reshape((-1, 1, 2)),
+                               cameraMatrix, distCoeffs)
+
+    R, _ = cv2.Rodrigues(R_exp)
+    return R, t
+
+
+def project_points_3D_to_2D(points_3D, rotation_vector, translation_vector,
+                            camera_matrix):
+    points_3D = points_3D.reshape(3, 1)
+    rotation_vector = rotation_vector.reshape(3, 3)
+    translation_vector = translation_vector.reshape(3, 1)
+    pixel = camera_matrix.dot(
+        rotation_vector.dot(points_3D) + translation_vector)
+    pixel /= pixel[-1]
+    points_2D = pixel[:2]
+
+    return points_2D
+
+
+def compute_projection(points_3D, transformation, internal_calibration):
+    # homogenise the N corner points and project them with K @ [R|t]
+    points_3D = points_3D.T
+    ones = np.ones((1, points_3D.shape[1]))
+    points_3D = np.concatenate((points_3D, ones))
+
+    projections_2d = np.zeros((2, points_3D.shape[1]), dtype='float32')
+    camera_projection = (internal_calibration.dot(transformation)).dot(points_3D)
+    projections_2d[0, :] = camera_projection[0, :] / camera_projection[2, :]
+    projections_2d[1, :] = camera_projection[1, :] / camera_projection[2, :]
+    return projections_2d