diff --git a/yolo/config/config.py b/yolo/config/config.py
index 100fc52..274b3f8 100644
--- a/yolo/config/config.py
+++ b/yolo/config/config.py
@@ -127,6 +127,7 @@ class TrainConfig:
     task: str
     epoch: int
     data: DataConfig
+    eval_interval: int
     optimizer: OptimizerConfig
     loss: LossConfig
     scheduler: SchedulerConfig
diff --git a/yolo/config/task/train.yaml b/yolo/config/task/train.yaml
index d3eab6c..19c14d4 100644
--- a/yolo/config/task/train.yaml
+++ b/yolo/config/task/train.yaml
@@ -4,6 +4,7 @@ defaults:
   - validation: ../validation
 
 epoch: 500
+eval_interval: 10
 
 data:
   batch_size: 16
diff --git a/yolo/tools/solver.py b/yolo/tools/solver.py
index 69953e3..991a159 100644
--- a/yolo/tools/solver.py
+++ b/yolo/tools/solver.py
@@ -46,6 +46,7 @@ def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: Progres
         self.loss_fn = create_loss_function(cfg, vec2box)
         self.progress = progress
         self.num_epochs = cfg.task.epoch
+        self.eval_interval = cfg.task.eval_interval
         self.mAPs_dict = defaultdict(list)
 
         self.weights_dir = self.progress.save_path / "weights"
@@ -126,11 +127,11 @@ def save_checkpoint(self, epoch_idx: int, file_name: Optional[str] = None):
         torch.save(checkpoint, file_path)
 
     def good_epoch(self, mAPs: Dict[str, Tensor]) -> bool:
-        save_flag = True
+        save_flag = False
        for mAP_key, mAP_val in mAPs.items():
+            if not self.mAPs_dict[mAP_key] or mAP_val > max(self.mAPs_dict[mAP_key]):
+                save_flag = True
             self.mAPs_dict[mAP_key].append(mAP_val)
-            if mAP_val < max(self.mAPs_dict[mAP_key]):
-                save_flag = False
         return save_flag
 
     def solve(self, dataloader: DataLoader):
@@ -146,9 +147,13 @@ def solve(self, dataloader: DataLoader):
             epoch_loss = self.train_one_epoch(dataloader)
             self.progress.finish_one_epoch(epoch_loss, epoch_idx=epoch_idx)
 
-            mAPs = self.validator.solve(self.validation_dataloader, epoch_idx=epoch_idx)
-            if mAPs is not None and self.good_epoch(mAPs):
-                self.save_checkpoint(epoch_idx=epoch_idx)
+            if (epoch_idx + 1) % self.eval_interval == 0:
+                mAPs = self.validator.solve(self.validation_dataloader, epoch_idx=epoch_idx)
+                if mAPs is not None:
+                    self.save_checkpoint(epoch_idx=epoch_idx+1)
+                    if self.good_epoch(mAPs):
+                        self.save_checkpoint(epoch_idx=epoch_idx+1, file_name="best.pt")
+
             # TODO: save model if result are better than before
 
         self.progress.finish_train()
@@ -259,9 +264,15 @@ def solve(self, dataloader, epoch_idx=1):
             if self.progress.local_rank != 0:
                 return
             json.dump(predict_json, f)
-        if hasattr(self, "coco_gt"):
+
+        if predict_json and hasattr(self, "coco_gt"):
             self.progress.start_pycocotools()
             result = calculate_ap(self.coco_gt, predict_json)
             self.progress.finish_pycocotools(result, epoch_idx)
+        else:
+            if not predict_json:
+                logger.warning("⚠️ No predictions available for evaluation.")
+            if not hasattr(self, "coco_gt"):
+                logger.warning("⚠️ COCO ground truth not found. Please check dataset configuration.")
 
-        return avg_mAPs
+        return avg_mAPs
\ No newline at end of file
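For readers skimming the patch, here is a minimal standalone sketch of the two behavioral changes it introduces: validation gated by `eval_interval`, and the corrected "best epoch" check in `good_epoch` (an epoch now counts as good when any metric beats its previous maximum, rather than requiring every metric to do so). The `CheckpointTracker` and `training_loop` names and the placeholder mAP values below are illustrative only and are not part of the repository.

```python
# Standalone sketch (not part of the patch): eval_interval gating + best-epoch tracking.
from collections import defaultdict
from typing import Dict


class CheckpointTracker:
    """Keeps per-metric history and reports whether the latest epoch set a new best."""

    def __init__(self):
        self.mAPs_dict = defaultdict(list)

    def good_epoch(self, mAPs: Dict[str, float]) -> bool:
        # Mirrors the patched good_epoch: "good" if any metric exceeds its prior maximum
        # (or if there is no history yet); history is updated either way.
        save_flag = False
        for mAP_key, mAP_val in mAPs.items():
            if not self.mAPs_dict[mAP_key] or mAP_val > max(self.mAPs_dict[mAP_key]):
                save_flag = True
            self.mAPs_dict[mAP_key].append(mAP_val)
        return save_flag


def training_loop(num_epochs: int, eval_interval: int) -> None:
    tracker = CheckpointTracker()
    for epoch_idx in range(num_epochs):
        # ... train_one_epoch(...) would run here ...
        # Validate only every `eval_interval` epochs, as in the patched solve().
        if (epoch_idx + 1) % eval_interval == 0:
            mAPs = {"mAP@.5": 0.1 * (epoch_idx + 1)}  # placeholder validation result
            print(f"epoch {epoch_idx + 1}: periodic checkpoint saved")
            if tracker.good_epoch(mAPs):
                print(f"epoch {epoch_idx + 1}: new best, saving best.pt")


if __name__ == "__main__":
    training_loop(num_epochs=20, eval_interval=10)
```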