@@ -231,7 +231,8 @@ def main(conf: Dict,
231
231
as_half : bool = True ,
232
232
image_list : Optional [Union [Path , List [str ]]] = None ,
233
233
feature_path : Optional [Path ] = None ,
234
- overwrite : bool = False ) -> Path :
234
+ overwrite : bool = False ,
235
+ mask_dir : Optional [Path ] = None ) -> Path :
235
236
logger .info ('Extracting local features with configuration:'
236
237
f'\n { pprint .pformat (conf )} ' )
237
238
@@ -256,6 +257,14 @@ def main(conf: Dict,
256
257
name = dataset .names [idx ]
257
258
pred = model ({'image' : data ['image' ].to (device , non_blocking = True )})
258
259
pred = {k : v [0 ].cpu ().numpy () for k , v in pred .items ()}
260
+ if mask_dir is not None :
261
+ mask_name = str (mask_dir / name ) + '.png'
262
+ # print(mask_name)
263
+ mask = cv2 .imread (mask_name )[:, :, 0 ]
264
+ valid_keypoint = mask [pred ['keypoints' ][:, 1 ].astype ('int' ), pred ['keypoints' ][:, 0 ].astype ('int' )]
265
+ pred ['keypoints' ] = pred ['keypoints' ][valid_keypoint > 0 ]
266
+ pred ['descriptors' ] = pred ['descriptors' ][:, valid_keypoint > 0 ]
267
+ pred ['scores' ] = pred ['scores' ][valid_keypoint > 0 ]
259
268
260
269
pred ['image_size' ] = original_size = data ['original_size' ][0 ].numpy ()
261
270
if 'keypoints' in pred :
0 commit comments