11# Copyright (c) OpenMMLab. All rights reserved.
22from collections import namedtuple
3- from copy import deepcopy
3+ # from copy import deepcopy
44from itertools import product
55from typing import Any , List , Optional , Tuple
66
77import numpy as np
88import torch
9- from mmengine import dump
9+ # from mmengine import dump
1010from munkres import Munkres
1111from torch import Tensor
1212
@@ -77,7 +77,7 @@ def _init_group():
7777 tag_list = [])
7878 return _group
7979
80- group_history = []
80+ # group_history = []
8181
8282 for idx , i in enumerate (keypoint_order ):
8383 # Get all valid candidate of the i-th keypoints
@@ -105,7 +105,7 @@ def _init_group():
105105 group .tag_list .append (tag )
106106
107107 groups .append (group )
108- costs_copy = None
108+ # costs_copy = None
109109 matches = None
110110
111111 else : # Match keypoints to existing groups
@@ -126,7 +126,7 @@ def _init_group():
126126 if num_kpts > num_groups :
127127 padding = np .full ((num_kpts , num_kpts - num_groups ), 1e10 )
128128 costs = np .concatenate ((costs , padding ), axis = 1 )
129- costs_copy = costs .copy ()
129+ # costs_copy = costs.copy()
130130
131131 # Match keypoints and groups by Munkres algorithm
132132 matches = munkres .compute (costs )
@@ -148,18 +148,18 @@ def _init_group():
148148 group .scores [i ] = vals_i [kpt_idx ]
149149 group .tag_list .append (tags_i [kpt_idx ])
150150
151- out = {
152- 'idx' : idx ,
153- 'i' : i ,
154- 'costs' : costs_copy ,
155- 'matches' : matches ,
156- 'kpts' : np .array ([g .kpts for g in groups ]),
157- 'scores' : np .array ([g .scores for g in groups ]),
158- 'tag_list' : [np .array (g .tag_list ) for g in groups ],
159- }
160- group_history .append (deepcopy (out ))
151+ # out = {
152+ # 'idx': idx,
153+ # 'i': i,
154+ # 'costs': costs_copy,
155+ # 'matches': matches,
156+ # 'kpts': np.array([g.kpts for g in groups]),
157+ # 'scores': np.array([g.scores for g in groups]),
158+ # 'tag_list': [np.array(g.tag_list) for g in groups],
159+ # }
160+ # group_history.append(deepcopy(out))
161161
162- dump (group_history , 'group_history.pkl' )
162+ # dump(group_history, 'group_history.pkl')
163163
164164 groups = groups [:max_groups ]
165165 if groups :
@@ -369,10 +369,10 @@ def _get_batch_topk(self, batch_heatmaps: Tensor, batch_tags: Tensor,
369369 L = batch_tags .shape [1 ] // K
370370
371371 # Heatmap NMS
372- dump (batch_heatmaps .cpu ().numpy (), 'heatmaps.pkl' )
372+ # dump(batch_heatmaps.cpu().numpy(), 'heatmaps.pkl')
373373 batch_heatmaps = batch_heatmap_nms (batch_heatmaps ,
374374 self .decode_nms_kernel )
375- dump (batch_heatmaps .cpu ().numpy (), 'heatmaps_nms.pkl' )
375+ # dump(batch_heatmaps.cpu().numpy(), 'heatmaps_nms.pkl')
376376
377377 # shape of topk_val, top_indices: (B, K, TopK)
378378 topk_vals , topk_indices = batch_heatmaps .flatten (- 2 , - 1 ).topk (
@@ -534,7 +534,13 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
534534 blur_kernel_size = self .decode_gaussian_kernel )
535535 else :
536536 keypoints = refine_keypoints (keypoints , heatmaps )
537- # keypoints += 0.75
537+ # The following 0.5-pixel shift is adapted from mmpose 0.x
538+ # where the heatmap center is calculated by a biased
539+ # rounding ``mu=[int(x), int(y)]``. We keep this shift
540+ # operation for now to to compatible with 0.x checkpoints
541+ # In mmpose 1.x, AE heatmap center is calculated by the
542+ # unbiased rounding ``mu=[int(x+0.5), int(y+0.5)], so the
543+ # following shift will be removed in the future.
538544 keypoints += 0.5
539545
540546 batch_keypoints .append (keypoints )
0 commit comments