Skip to content

Commit

Permalink
🐛 [Fix] target filtering
Browse files Browse the repository at this point in the history
Changed dummy target filtering to be based on class IDs being -1 and moved it inside the to_metrics_format function.
  • Loading branch information
Adamusen authored Jan 3, 2025
1 parent 1b44260 commit 75bcb34
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
2 changes: 1 addition & 1 deletion yolo/tools/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def validation_step(self, batch, batch_idx):
H, W = images.shape[2:]
predicts = self.post_process(self.ema(images), image_size=[W, H])
self.metric.update([to_metrics_format(predict) for predict in predicts],
[to_metrics_format(target[target.sum(1) > 0]) for target in targets])
[to_metrics_format(target) for target in targets])
return predicts

def on_validation_epoch_end(self):
Expand Down
1 change: 1 addition & 0 deletions yolo/utils/bounding_box_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ def calculate_map(predictions, ground_truths) -> Dict[str, Tensor]:


def to_metrics_format(prediction: Tensor) -> Dict[str, Union[float, Tensor]]:
prediction = prediction[prediction[:, 0] != -1]
bbox = {"boxes": prediction[:, 1:5], "labels": prediction[:, 0].int()}
if prediction.size(1) == 6:
bbox["scores"] = prediction[:, 5]
Expand Down

0 comments on commit 75bcb34

Please sign in to comment.