diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py index 5bd9b8684d42..984e0c50c32a 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py +++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py @@ -1614,7 +1614,7 @@ def load_model_hook(models, input_dir): ) if args.cond_image_column is not None: logger.info("I2I fine-tuning enabled.") - batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False) + batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler,