diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 9c529cbb92ca..e1f582aa19eb 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -38,6 +38,7 @@
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from torch.utils.data import Dataset
+from torch.utils.data.sampler import BatchSampler
 from torchvision import transforms
 from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
@@ -57,7 +58,9 @@
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
+    find_nearest_bucket,
     free_memory,
+    parse_buckets_string,
 )
 from diffusers.utils import (
     check_min_version,
@@ -65,6 +68,7 @@
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
 from diffusers.utils.torch_utils import is_compiled_module
 
 
@@ -76,6 +80,9 @@
 logger = get_logger(__name__)
 
+if is_torch_npu_available():
+    torch.npu.config.allow_internal_format = False
+
 
 def save_model_card(
     repo_id: str,
@@ -398,6 +405,16 @@ def parse_args(input_args=None):
             " resolution"
         ),
     )
+    parser.add_argument(
+        "--aspect_ratio_buckets",
+        type=str,
+        default=None,
+        help=(
+            "Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. "
+            "e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248' "
+            "Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored."
+        ),
+    )
     parser.add_argument(
         "--center_crop",
         default=False,
@@ -715,17 +732,18 @@ def __init__(
         class_prompt,
         class_data_root=None,
         class_num=None,
-        size=1024,
         repeats=1,
         center_crop=False,
+        buckets=None,
     ):
-        self.size = size
         self.center_crop = center_crop
 
         self.instance_prompt = instance_prompt
         self.custom_instance_prompts = None
         self.class_prompt = class_prompt
+        self.buckets = buckets
+
         # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
         # we load the training data using load_dataset
         if args.dataset_name is not None:
@@ -790,32 +808,38 @@ def __init__(
                 self.instance_images.extend(itertools.repeat(img, repeats))
 
         self.pixel_values = []
-        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
-        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
-        train_flip = transforms.RandomHorizontalFlip(p=1.0)
-        train_transforms = transforms.Compose(
-            [
-                transforms.ToTensor(),
-                transforms.Normalize([0.5], [0.5]),
-            ]
-        )
         for image in self.instance_images:
             image = exif_transpose(image)
             if not image.mode == "RGB":
                 image = image.convert("RGB")
+
+            width, height = image.size
+
+            # Find the closest bucket
+            bucket_idx = find_nearest_bucket(height, width, self.buckets)
+            target_height, target_width = self.buckets[bucket_idx]
+            self.size = (target_height, target_width)
+
+            # based on the bucket assignment, define the transformations
+            train_resize = transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR)
+            train_crop = transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size)
+            train_flip = transforms.RandomHorizontalFlip(p=1.0)
+            train_transforms = transforms.Compose(
+                [
+                    transforms.ToTensor(),
+                    transforms.Normalize([0.5], [0.5]),
+                ]
+            )
             image = train_resize(image)
-            if args.random_flip and random.random() < 0.5:
-                # flip
-                image = train_flip(image)
             if args.center_crop:
-                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
-                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
                 image = train_crop(image)
             else:
-                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+                y1, x1, h, w = train_crop.get_params(image, self.size)
                 image = crop(image, y1, x1, h, w)
+            if args.random_flip and random.random() < 0.5:
+                image = train_flip(image)
             image = train_transforms(image)
-            self.pixel_values.append(image)
+            self.pixel_values.append((image, bucket_idx))
 
         self.num_instance_images = len(self.instance_images)
         self._length = self.num_instance_images
@@ -834,8 +858,8 @@ def __init__(
 
         self.image_transforms = transforms.Compose(
             [
-                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
-                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size),
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),
             ]
@@ -846,8 +870,9 @@ def __len__(self):
 
     def __getitem__(self, index):
         example = {}
-        instance_image = self.pixel_values[index % self.num_instance_images]
+        instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
         example["instance_images"] = instance_image
+        example["bucket_idx"] = bucket_idx
 
         if self.custom_instance_prompts:
             caption = self.custom_instance_prompts[index % self.num_instance_images]
@@ -888,6 +913,47 @@ def collate_fn(examples, with_prior_preservation=False):
     return batch
 
 
+class BucketBatchSampler(BatchSampler):
+    def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
+        if not isinstance(batch_size, int) or batch_size <= 0:
+            raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
+        if not isinstance(drop_last, bool):
+            raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last))
+
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+
+        # Group indices by bucket
+        self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
+        for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values):
+            self.bucket_indices[bucket_idx].append(idx)
+
+        self.sampler_len = 0
+        self.batches = []
+
+        # Pre-generate batches for each bucket
+        for indices_in_bucket in self.bucket_indices:
+            # Shuffle indices within the bucket
+            random.shuffle(indices_in_bucket)
+            # Create batches
+            for i in range(0, len(indices_in_bucket), self.batch_size):
+                batch = indices_in_bucket[i : i + self.batch_size]
+                if len(batch) < self.batch_size and self.drop_last:
+                    continue  # Skip partial batch if drop_last is True
+                self.batches.append(batch)
+                self.sampler_len += 1  # Count the number of batches
+
+    def __iter__(self):
+        # Shuffle the order of the batches each epoch
+        random.shuffle(self.batches)
+        for batch in self.batches:
+            yield batch
+
+    def __len__(self):
+        return self.sampler_len
+
+
 class PromptDataset(Dataset):
     "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
@@ -1142,8 +1208,7 @@ def main(args):
                 image.save(image_filename)
 
         del pipeline
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        free_memory()
 
     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1438,6 +1503,12 @@ def load_model_hook(models, input_dir):
             safeguard_warmup=args.prodigy_safeguard_warmup,
         )
 
+    if args.aspect_ratio_buckets is not None:
+        buckets = parse_buckets_string(args.aspect_ratio_buckets)
+    else:
+        buckets = [(args.resolution, args.resolution)]
+    logger.info(f"Using parsed aspect ratio buckets: {buckets}")
+
     # Dataset and DataLoaders creation:
     train_dataset = DreamBoothDataset(
         instance_data_root=args.instance_data_dir,
@@ -1445,15 +1516,14 @@ def load_model_hook(models, input_dir):
         class_prompt=args.class_prompt,
         class_data_root=args.class_data_dir if args.with_prior_preservation else None,
         class_num=args.num_class_images,
-        size=args.resolution,
+        buckets=buckets,
         repeats=args.repeats,
         center_crop=args.center_crop,
     )
-
+    batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
-        batch_size=args.train_batch_size,
-        shuffle=True,
+        batch_sampler=batch_sampler,
         collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
         num_workers=args.dataloader_num_workers,
     )
@@ -1892,7 +1962,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 free_memory()
 
             images = None
-            del pipeline
+            free_memory()
 
     # Save the lora layers
     accelerator.wait_for_everyone()
@@ -1944,6 +2014,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             is_final_validation=True,
             torch_dtype=weight_dtype,
         )
+        del pipeline
+        free_memory()
 
         if args.push_to_hub:
             save_model_card(
@@ -1963,7 +2035,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         )
 
         images = None
-        del pipeline
 
     accelerator.end_training()
diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py
index a1337e8dbaa4..32139f5b0362 100644
--- a/examples/dreambooth/train_dreambooth_lora_hidream.py
+++ b/examples/dreambooth/train_dreambooth_lora_hidream.py
@@ -39,6 +39,7 @@
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from torch.utils.data import Dataset
+from torch.utils.data.sampler import BatchSampler
 from torchvision import transforms
 from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
@@ -57,7 +58,9 @@
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
+    find_nearest_bucket,
     free_memory,
+    parse_buckets_string,
 )
 from diffusers.utils import (
     check_min_version,
@@ -452,6 +455,16 @@ def parse_args(input_args=None):
             " resolution"
         ),
     )
+    parser.add_argument(
+        "--aspect_ratio_buckets",
+        type=str,
+        default=None,
+        help=(
+            "Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. "
+            "e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248' "
+            "Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored."
+        ),
+    )
     parser.add_argument(
         "--center_crop",
         default=False,
@@ -741,17 +754,18 @@ def __init__(
         class_prompt,
         class_data_root=None,
         class_num=None,
-        size=1024,
         repeats=1,
         center_crop=False,
+        buckets=None,
     ):
-        self.size = size
         self.center_crop = center_crop
 
         self.instance_prompt = instance_prompt
         self.custom_instance_prompts = None
         self.class_prompt = class_prompt
+        self.buckets = buckets
+
         # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
         # we load the training data using load_dataset
         if args.dataset_name is not None:
@@ -816,32 +830,37 @@ def __init__(
                 self.instance_images.extend(itertools.repeat(img, repeats))
 
         self.pixel_values = []
-        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
-        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
-        train_flip = transforms.RandomHorizontalFlip(p=1.0)
-        train_transforms = transforms.Compose(
-            [
-                transforms.ToTensor(),
-                transforms.Normalize([0.5], [0.5]),
-            ]
-        )
         for image in self.instance_images:
             image = exif_transpose(image)
             if not image.mode == "RGB":
                 image = image.convert("RGB")
+
+            width, height = image.size
+            # Find the closest bucket
+            bucket_idx = find_nearest_bucket(height, width, self.buckets)
+            target_height, target_width = self.buckets[bucket_idx]
+            self.size = (target_height, target_width)
+
+            # based on the bucket assignment, define the transformations
+            train_resize = transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR)
+            train_crop = transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size)
+            train_flip = transforms.RandomHorizontalFlip(p=1.0)
+            train_transforms = transforms.Compose(
+                [
+                    transforms.ToTensor(),
+                    transforms.Normalize([0.5], [0.5]),
+                ]
+            )
             image = train_resize(image)
-            if args.random_flip and random.random() < 0.5:
-                # flip
-                image = train_flip(image)
             if args.center_crop:
-                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
-                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
                 image = train_crop(image)
             else:
-                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+                y1, x1, h, w = train_crop.get_params(image, self.size)
                 image = crop(image, y1, x1, h, w)
+            if args.random_flip and random.random() < 0.5:
+                image = train_flip(image)
             image = train_transforms(image)
-            self.pixel_values.append(image)
+            self.pixel_values.append((image, bucket_idx))
 
         self.num_instance_images = len(self.instance_images)
         self._length = self.num_instance_images
@@ -860,8 +879,8 @@ def __init__(
 
         self.image_transforms = transforms.Compose(
             [
-                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
-                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size),
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),
             ]
@@ -872,8 +891,9 @@ def __len__(self):
 
     def __getitem__(self, index):
         example = {}
-        instance_image = self.pixel_values[index % self.num_instance_images]
+        instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
         example["instance_images"] = instance_image
+        example["bucket_idx"] = bucket_idx
 
         if self.custom_instance_prompts:
             caption = self.custom_instance_prompts[index % self.num_instance_images]
@@ -914,6 +934,47 @@ def collate_fn(examples, with_prior_preservation=False):
     return batch
 
 
+class BucketBatchSampler(BatchSampler):
+    def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
+        if not isinstance(batch_size, int) or batch_size <= 0:
+            raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
+        if not isinstance(drop_last, bool):
+            raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last))
+
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+
+        # Group indices by bucket
+        self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
+        for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values):
+            self.bucket_indices[bucket_idx].append(idx)
+
+        self.sampler_len = 0
+        self.batches = []
+
+        # Pre-generate batches for each bucket
+        for indices_in_bucket in self.bucket_indices:
+            # Shuffle indices within the bucket
+            random.shuffle(indices_in_bucket)
+            # Create batches
+            for i in range(0, len(indices_in_bucket), self.batch_size):
+                batch = indices_in_bucket[i : i + self.batch_size]
+                if len(batch) < self.batch_size and self.drop_last:
+                    continue  # Skip partial batch if drop_last is True
+                self.batches.append(batch)
+                self.sampler_len += 1  # Count the number of batches
+
+    def __iter__(self):
+        # Shuffle the order of the batches each epoch
+        random.shuffle(self.batches)
+        for batch in self.batches:
+            yield batch
+
+    def __len__(self):
+        return self.sampler_len
+
+
 class PromptDataset(Dataset):
     "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
 
@@ -1321,6 +1382,12 @@ def load_model_hook(models, input_dir):
             safeguard_warmup=args.prodigy_safeguard_warmup,
         )
 
+    if args.aspect_ratio_buckets is not None:
+        buckets = parse_buckets_string(args.aspect_ratio_buckets)
+    else:
+        buckets = [(args.resolution, args.resolution)]
+    logger.info(f"Using parsed aspect ratio buckets: {buckets}")
+
     # Dataset and DataLoaders creation:
     train_dataset = DreamBoothDataset(
         instance_data_root=args.instance_data_dir,
@@ -1328,15 +1395,14 @@ def load_model_hook(models, input_dir):
         class_prompt=args.class_prompt,
         class_data_root=args.class_data_dir if args.with_prior_preservation else None,
         class_num=args.num_class_images,
-        size=args.resolution,
+        buckets=buckets,
         repeats=args.repeats,
         center_crop=args.center_crop,
     )
-
+    batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
-        batch_size=args.train_batch_size,
-        shuffle=True,
+        batch_sampler=batch_sampler,
        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
         num_workers=args.dataloader_num_workers,
     )
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index bc30411d8726..755ff818830c 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -3,6 +3,8 @@
 import gc
 import math
 import random
+import re
+import warnings
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
@@ -316,6 +318,46 @@ def free_memory():
         torch.xpu.empty_cache()
 
 
+def parse_buckets_string(buckets_str):
+    """Parses a string defining buckets into a list of (height, width) tuples."""
+    if not buckets_str:
+        raise ValueError("Bucket string cannot be empty.")
+
+    bucket_pairs = buckets_str.strip().split(";")
+    parsed_buckets = []
+    for pair_str in bucket_pairs:
+        match = re.match(r"^\s*(\d+)\s*,\s*(\d+)\s*$", pair_str)
+        if not match:
+            raise ValueError(f"Invalid bucket format: '{pair_str}'. Expected 'height,width'.")
+        try:
+            height = int(match.group(1))
+            width = int(match.group(2))
+            if height <= 0 or width <= 0:
+                raise ValueError("Bucket dimensions must be positive integers.")
+            if height % 8 != 0 or width % 8 != 0:
+                warnings.warn(f"Bucket dimension ({height},{width}) not divisible by 8. This might cause issues.")
+            parsed_buckets.append((height, width))
+        except ValueError as e:
+            raise ValueError(f"Invalid integer in bucket pair '{pair_str}': {e}") from e
+
+    if not parsed_buckets:
+        raise ValueError("No valid buckets found in the provided string.")
+
+    return parsed_buckets
+
+
+def find_nearest_bucket(h, w, bucket_options):
+    """Finds the closest bucket to the given height and width."""
+    min_metric = float("inf")
+    best_bucket_idx = None
+    for bucket_idx, (bucket_h, bucket_w) in enumerate(bucket_options):
+        metric = abs(h * bucket_w - w * bucket_h)
+        if metric <= min_metric:
+            min_metric = metric
+            best_bucket_idx = bucket_idx
+    return best_bucket_idx
+
+
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
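
For reference, the two new `training_utils` helpers compose as below. This is a minimal sketch, not part of the patch; it assumes a `diffusers` install with this patch applied, so both functions import from `diffusers.training_utils` (the same import the training scripts use). `find_nearest_bucket` minimizes `abs(h * bucket_w - w * bucket_h)`, which is zero exactly when the image and bucket aspect ratios match; because the comparison is `<=`, the last bucket wins ties.

```python
# Minimal sketch (not part of the patch): exercising the new bucket helpers.
from diffusers.training_utils import find_nearest_bucket, parse_buckets_string

buckets = parse_buckets_string("1024,1024;768,1360;1360,768")
print(buckets)  # [(1024, 1024), (768, 1360), (1360, 768)]

# An 800x1200 (height x width) landscape image: the cross-multiplied metric
# |h*bucket_w - w*bucket_h| is 409600 for (1024, 1024), 166400 for (768, 1360),
# and 1017600 for (1360, 768), so the landscape bucket is selected.
idx = find_nearest_bucket(800, 1200, buckets)
print(buckets[idx])  # (768, 1360)
```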
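And a sketch of how `BucketBatchSampler` (defined identically in both training scripts) drives a `DataLoader`. `_ToyBucketDataset` is a hypothetical stand-in for `DreamBoothDataset`: the sampler only reads `dataset.buckets` and `dataset.pixel_values`, so any object exposing those attributes works, and the `BucketBatchSampler` definition from the diff above (with its module's `random` import) is assumed to be in scope. Because `batch_sampler` is passed, `batch_size` and `shuffle` must be omitted from the `DataLoader`, which is why the scripts drop those two arguments.

```python
# Hypothetical stand-in (not part of the patch) showing the sampler contract.
import torch
from torch.utils.data import Dataset


class _ToyBucketDataset(Dataset):
    def __init__(self, buckets, assignments):
        self.buckets = buckets  # list of (height, width) tuples
        # (tensor, bucket_idx) pairs, mirroring DreamBoothDataset.__init__
        self.pixel_values = [(torch.zeros(3, *buckets[i]), i) for i in assignments]

    def __len__(self):
        return len(self.pixel_values)

    def __getitem__(self, index):
        image, _bucket_idx = self.pixel_values[index]
        return image


dataset = _ToyBucketDataset([(1024, 1024), (768, 1360)], [0, 0, 0, 1, 1])
sampler = BucketBatchSampler(dataset, batch_size=2, drop_last=False)  # class from the diff above
loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler, collate_fn=torch.stack)

for batch in loader:
    # Every batch stacks cleanly because all of its images share one bucket,
    # e.g. [2, 3, 1024, 1024] or [2, 3, 768, 1360]; the odd image in bucket 0
    # arrives as a partial batch since drop_last=False.
    print(batch.shape)
```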