diff --git a/Pretrain.py b/Pretrain.py index 8dd794d..a3423b2 100644 --- a/Pretrain.py +++ b/Pretrain.py @@ -49,9 +49,9 @@ def train(model, data_loader, optimizer, tokenizer, epoch, warmup_steps, device, if args.distributed: data_loader.sampler.set_epoch(epoch) - +import mixgen as mg for i, (image, text) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): - + image, text = mg.mixgen(image, text, num=16) optimizer.zero_grad() image = image.to(device,non_blocking=True) @@ -200,4 +200,4 @@ def main(args, config): yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) - main(args, config) \ No newline at end of file + main(args, config)