diff --git a/paddle3d/apis/trainer.py b/paddle3d/apis/trainer.py
index 84358932..bfbb25a6 100644
--- a/paddle3d/apis/trainer.py
+++ b/paddle3d/apis/trainer.py
@@ -113,76 +113,79 @@ def __init__(
             **dataloader_fn) if isinstance(dataloader_fn,
                                            dict) else dataloader_fn
 
-        self.train_dataloader = _dataloader_build_fn(train_dataset, self.model)
+        self.train_dataloader = _dataloader_build_fn(
+            train_dataset, self.model) if train_dataset else None
         self.eval_dataloader = _dataloader_build_fn(
             val_dataset, self.model) if val_dataset else None
         self.val_dataset = val_dataset
 
-        self.resume = resume
-        vdl_file_name = None
-        self.iters_per_epoch = len(self.train_dataloader)
+        if train_dataset:
+            self.resume = resume
+            vdl_file_name = None
+            self.iters_per_epoch = len(self.train_dataloader)
 
-        if iters is None:
-            self.epochs = epochs
-            self.iters = epochs * self.iters_per_epoch
-            self.train_by_epoch = True
-        else:
-            self.iters = iters
-            self.epochs = (iters - 1) // self.iters_per_epoch + 1
-            self.train_by_epoch = False
-
-        self.cur_iter = 0
-        self.cur_epoch = 0
+            if iters is None:
+                self.epochs = epochs
+                self.iters = epochs * self.iters_per_epoch
+                self.train_by_epoch = True
+            else:
+                self.iters = iters
+                self.epochs = (iters - 1) // self.iters_per_epoch + 1
+                self.train_by_epoch = False
 
-        if self.optimizer.__class__.__name__ == 'OneCycleAdam':
-            self.optimizer.before_run(max_iters=self.iters)
+            self.cur_iter = 0
+            self.cur_epoch = 0
 
-        self.checkpoint = default_checkpoint_build_fn(
-            **checkpoint) if isinstance(checkpoint, dict) else checkpoint
+            if self.optimizer.__class__.__name__ == 'OneCycleAdam':
+                self.optimizer.before_run(max_iters=self.iters)
 
-        if isinstance(scheduler, dict):
-            scheduler.setdefault('train_by_epoch', self.train_by_epoch)
-            scheduler.setdefault('iters_per_epoch', self.iters_per_epoch)
-            self.scheduler = default_scheduler_build_fn(**scheduler)
-        else:
-            self.scheduler = scheduler
-
-        if self.checkpoint is None:
-            return
-
-        if not self.checkpoint.empty:
-            if not resume:
-                raise RuntimeError(
-                    'The checkpoint {} is not emtpy! Set `resume=True` to continue training or use another dir as checkpoint'
-                    .format(self.checkpoint.rootdir))
-
-            if self.checkpoint.meta.get(
-                    'train_by_epoch') != self.train_by_epoch:
-                raise RuntimeError(
-                    'Unable to resume training since the train_by_epoch is inconsistent with that saved in the checkpoint'
-                )
-
-            params_dict, opt_dict = self.checkpoint.get()
-            self.model.set_dict(params_dict)
-            self.optimizer.set_state_dict(opt_dict)
-            self.cur_iter = self.checkpoint.meta.get('iters')
-            self.cur_epoch = self.checkpoint.meta.get('epochs')
-            self.scheduler.step(self.cur_iter)
-
-            logger.info(
-                'Resume model from checkpoint {}, current iter set to {}'.
-                format(self.checkpoint.rootdir, self.cur_iter))
-            vdl_file_name = self.checkpoint.meta['vdl_file_name']
-        elif resume:
-            logger.warning(
-                "Attempt to restore parameters from an empty checkpoint")
+            self.checkpoint = default_checkpoint_build_fn(
+                **checkpoint) if isinstance(checkpoint, dict) else checkpoint
 
-        if env.local_rank == 0:
-            self.log_writer = LogWriter(
-                logdir=self.checkpoint.rootdir, file_name=vdl_file_name)
-            self.checkpoint.record('vdl_file_name',
-                                   os.path.basename(self.log_writer.file_name))
-            self.checkpoint.record('train_by_epoch', self.train_by_epoch)
+            if isinstance(scheduler, dict):
+                scheduler.setdefault('train_by_epoch', self.train_by_epoch)
+                scheduler.setdefault('iters_per_epoch', self.iters_per_epoch)
+                self.scheduler = default_scheduler_build_fn(**scheduler)
+            else:
+                self.scheduler = scheduler
+
+            if self.checkpoint is None:
+                return
+
+            if not self.checkpoint.empty:
+                if not resume:
+                    raise RuntimeError(
+                        'The checkpoint {} is not emtpy! Set `resume=True` to continue training or use another dir as checkpoint'
+                        .format(self.checkpoint.rootdir))
+
+                if self.checkpoint.meta.get(
+                        'train_by_epoch') != self.train_by_epoch:
+                    raise RuntimeError(
+                        'Unable to resume training since the train_by_epoch is inconsistent with that saved in the checkpoint'
+                    )
+
+                params_dict, opt_dict = self.checkpoint.get()
+                self.model.set_dict(params_dict)
+                self.optimizer.set_state_dict(opt_dict)
+                self.cur_iter = self.checkpoint.meta.get('iters')
+                self.cur_epoch = self.checkpoint.meta.get('epochs')
+                self.scheduler.step(self.cur_iter)
+
+                logger.info(
+                    'Resume model from checkpoint {}, current iter set to {}'.
+                    format(self.checkpoint.rootdir, self.cur_iter))
+                vdl_file_name = self.checkpoint.meta['vdl_file_name']
+            elif resume:
+                logger.warning(
+                    "Attempt to restore parameters from an empty checkpoint")
+
+            if env.local_rank == 0:
+                self.log_writer = LogWriter(
+                    logdir=self.checkpoint.rootdir, file_name=vdl_file_name)
+                self.checkpoint.record(
+                    'vdl_file_name',
+                    os.path.basename(self.log_writer.file_name))
+                self.checkpoint.record('train_by_epoch', self.train_by_epoch)
 
     def train(self):
         """
diff --git a/tools/evaluate.py b/tools/evaluate.py
index 692dfa1b..e5a124f3 100644
--- a/tools/evaluate.py
+++ b/tools/evaluate.py
@@ -80,7 +80,8 @@ def main(args):
         'dataloader_fn': {
             'batch_size': batch_size,
             'num_workers': args.num_workers
-        }
+        },
+        'train_dataset': None
     })
 
     if args.model is not None:
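For context, a minimal sketch of the usage this patch enables, not code from the PR itself: Trainer can now be constructed with train_dataset=None, in which case __init__ skips the training dataloader and everything guarded by the new `if train_dataset:` block (iters/epochs bookkeeping, checkpoint resume, the VisualDL LogWriter). The helper below is hypothetical, and the keyword arguments other than train_dataset, val_dataset, and dataloader_fn are assumptions about the Trainer signature rather than verified parameters:

    # Hypothetical helper, not part of this PR; argument names are assumed.
    from paddle3d.apis.trainer import Trainer

    def build_eval_trainer(model, optimizer, val_dataset, batch_size,
                           num_workers):
        # train_dataset=None short-circuits the training-only setup in
        # Trainer.__init__, mirroring the 'train_dataset': None entry that
        # tools/evaluate.py now adds to its trainer config above.
        return Trainer(
            model=model,
            optimizer=optimizer,  # assumed to still be required by __init__
            train_dataset=None,   # evaluation-only: no training data needed
            val_dataset=val_dataset,
            dataloader_fn={
                'batch_size': batch_size,
                'num_workers': num_workers
            })

Assuming Trainer keeps an evaluate() entry point, as tools/evaluate.py suggests, running evaluation would then be a single call such as metrics = build_eval_trainer(...).evaluate().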