From 2da7c204bd77e483546e691a30bf11f710d70e27 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 25 Apr 2025 16:45:33 +0300 Subject: [PATCH 01/12] initial commit --- .../train_dreambooth_lora_hidream.py | 107 ++++++++++++++---- 1 file changed, 87 insertions(+), 20 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 26a920906b3e..1d8b18bb32bc 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -24,6 +24,7 @@ import warnings from contextlib import nullcontext from pathlib import Path +from torch.utils.data.sampler import Sampler, BatchSampler import numpy as np import torch @@ -67,7 +68,6 @@ from diffusers.utils.import_utils import is_torch_npu_available from diffusers.utils.torch_utils import is_compiled_module - if is_wandb_available(): import wandb @@ -735,14 +735,19 @@ def __init__( size=1024, repeats=1, center_crop=False, + buckets=[(1024, 1024), (768, 1360), (1360, 768), (880, 1168), (1168, 880), (1248, 832), (832, 1248)], + bucket_aspects=[1.0, 9 / 16, 16 / 9, 3 / 4, 4 / 3, 3 / 2, 2 / 3], ): - self.size = size + # self.size = size self.center_crop = center_crop self.instance_prompt = instance_prompt self.custom_instance_prompts = None self.class_prompt = class_prompt + self.buckets = np.array(buckets) + self.bucket_aspects = np.array(bucket_aspects) + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, # we load the training data using load_dataset if args.dataset_name is not None: @@ -807,32 +812,46 @@ def __init__( self.instance_images.extend(itertools.repeat(img, repeats)) self.pixel_values = [] - train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) - train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) - train_flip = transforms.RandomHorizontalFlip(p=1.0) - train_transforms = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) + # self.aspect_assignments = [] for image in self.instance_images: image = exif_transpose(image) if not image.mode == "RGB": image = image.convert("RGB") - image = train_resize(image) + if args.random_flip and random.random() < 0.5: # flip image = train_flip(image) + + width, height = image.size + print("width, height", width, height) + aspect_ratio = width / float(height) + # Find the closest bucket + bucket_idx = np.argmin(np.abs(self.bucket_aspects - aspect_ratio)) + target_height, target_width = self.buckets[bucket_idx] + size = (target_height, target_width) + print("WTF", size, type(size)) + # based on the bucket assignment, define the transformations + train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + train_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + image = train_resize(image) if args.center_crop: - y1 = max(0, int(round((image.height - args.resolution) / 2.0))) - x1 = max(0, int(round((image.width - args.resolution) / 2.0))) image = train_crop(image) + print("cropped", image.size) else: - y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + y1, x1, h, w = train_crop.get_params(image, size) image = crop(image, y1, x1, h, w) + 
print("cropped", image.size) + image = train_transforms(image) - self.pixel_values.append(image) + self.pixel_values.append((image, bucket_idx)) + # self.assignments.append((image, bucket_idx)) self.num_instance_images = len(self.instance_images) self._length = self.num_instance_images @@ -863,8 +882,9 @@ def __len__(self): def __getitem__(self, index): example = {} - instance_image = self.pixel_values[index % self.num_instance_images] + instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images] example["instance_images"] = instance_image + example["bucket_idx"] = bucket_idx if self.custom_instance_prompts: caption = self.custom_instance_prompts[index % self.num_instance_images] @@ -905,6 +925,49 @@ def collate_fn(examples, with_prior_preservation=False): return batch +class BucketBatchSampler(BatchSampler): + def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError("batch_size should be a positive integer value, " + "but got batch_size={}".format(batch_size)) + if not isinstance(drop_last, bool): + raise ValueError("drop_last should be a boolean value, but got " + "drop_last={}".format(drop_last)) + + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + + # Group indices by bucket + self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] + for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values): + self.bucket_indices[bucket_idx].append(idx) + + self.sampler_len = 0 + self.batches = [] + + # Pre-generate batches for each bucket + for indices_in_bucket in self.bucket_indices: + # Shuffle indices within the bucket + random.shuffle(indices_in_bucket) + # Create batches + for i in range(0, len(indices_in_bucket), self.batch_size): + batch = indices_in_bucket[i:i + self.batch_size] + if len(batch) < self.batch_size and self.drop_last: + continue # Skip partial batch if drop_last is True + self.batches.append(batch) + self.sampler_len += 1 # Count the number of batches + + def __iter__(self): + # Shuffle the order of the batches each epoch + random.shuffle(self.batches) + for batch in self.batches: + yield batch + + def __len__(self): + return self.sampler_len + + class PromptDataset(Dataset): "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 
@@ -1297,11 +1360,15 @@ def load_model_hook(models, input_dir): repeats=args.repeats, center_crop=args.center_crop, ) - - train_dataloader = torch.utils.data.DataLoader( + batch_sampler = BucketBatchSampler( train_dataset, batch_size=args.train_batch_size, - shuffle=True, + drop_last=False) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + # batch_size=args.train_batch_size, + # shuffle=True, collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), num_workers=args.dataloader_num_workers, ) From 6517a705f6b33faa79fe98432772d0ea5694e1a2 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 25 Apr 2025 19:51:17 +0300 Subject: [PATCH 02/12] initial commit --- .../train_dreambooth_lora_hidream.py | 71 +++++++++++-------- 1 file changed, 40 insertions(+), 31 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 1d8b18bb32bc..94d5ff504b8b 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -81,12 +81,12 @@ def save_model_card( - repo_id: str, - images=None, - base_model: str = None, - instance_prompt=None, - validation_prompt=None, - repo_folder=None, + repo_id: str, + images=None, + base_model: str = None, + instance_prompt=None, + validation_prompt=None, + repo_folder=None, ): widget_dict = [] if images is not None: @@ -188,13 +188,13 @@ def load_text_encoders(class_one, class_two, class_three): def log_validation( - pipeline, - args, - accelerator, - pipeline_args, - epoch, - torch_dtype, - is_final_validation=False, + pipeline, + args, + accelerator, + pipeline_args, + epoch, + torch_dtype, + is_final_validation=False, ): args.num_validation_images = args.num_validation_images if args.num_validation_images else 1 logger.info( @@ -244,7 +244,7 @@ def log_validation( def import_model_class_from_model_name_or_path( - pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" ): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder=subfolder, revision=revision @@ -331,8 +331,8 @@ def parse_args(input_args=None): type=str, default="image", help="The column of the dataset containing the target image. By " - "default, the standard Image Dataset maps out 'file_name' " - "to 'image'.", + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", ) parser.add_argument( "--caption_column", @@ -588,7 +588,7 @@ def parse_args(input_args=None): type=float, default=None, help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " - "uses the value of square root of beta2. Ignored if optimizer is adamW", + "uses the value of square root of beta2. Ignored if optimizer is adamW", ) parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") @@ -619,7 +619,7 @@ def parse_args(input_args=None): type=bool, default=True, help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. 
" - "Ignored if optimizer is adamW", + "Ignored if optimizer is adamW", ) parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") @@ -726,17 +726,26 @@ class DreamBoothDataset(Dataset): """ def __init__( - self, - instance_data_root, - instance_prompt, - class_prompt, - class_data_root=None, - class_num=None, - size=1024, - repeats=1, - center_crop=False, - buckets=[(1024, 1024), (768, 1360), (1360, 768), (880, 1168), (1168, 880), (1248, 832), (832, 1248)], - bucket_aspects=[1.0, 9 / 16, 16 / 9, 3 / 4, 4 / 3, 3 / 2, 2 / 3], + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, + # buckets=[(1024,1024),(768,1360),(1360, 768),(880, 1168),(1168, 880), (1248, 832), (832, 1248)], + buckets=[(1024, 1024)], + # "1024 × 1024 (Square)", + # "768 × 1360 (Portrait)", + # "1360 × 768 (Landscape)", + # "880 × 1168 (Portrait)", + # "1168 × 880 (Landscape)", + # "1248 × 832 (Landscape)", + # "832 × 1248 (Portrait)" + # bucket_aspects = [1.0, 9/16, 16/9, 3/4, 4/3, 3/2, 2/3], + bucket_aspects=[1.0], ): # self.size = size self.center_crop = center_crop @@ -1062,7 +1071,7 @@ def main(args): pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): images = pipeline(example["prompt"]).images @@ -1276,7 +1285,7 @@ def load_model_hook(models, input_dir): if args.scale_lr: args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Make sure the trainable params are in float32. From 873fe895797aa517af49e0f2a73a7249103f1b31 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Sun, 27 Apr 2025 14:33:31 +0300 Subject: [PATCH 03/12] fix best bucket --- .../train_dreambooth_lora_hidream.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 94d5ff504b8b..0802992b5bb9 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -725,6 +725,18 @@ class DreamBoothDataset(Dataset): It pre-processes the images. 
""" + def find_nearest_bucket(self, h, w, bucket_options): + min_metric = float('inf') + best_bucket = None + best_bucket_idx = None + for bucket_idx, (bucket_h, bucket_w) in enumerate(bucket_options): + metric = abs(h * bucket_w - w * bucket_h) + if metric <= min_metric: + min_metric = metric + best_bucket = (bucket_h, bucket_w) + best_bucket_idx = bucket_idx + return best_bucket_idx + def __init__( self, instance_data_root, @@ -745,7 +757,6 @@ def __init__( # "1248 × 832 (Landscape)", # "832 × 1248 (Portrait)" # bucket_aspects = [1.0, 9/16, 16/9, 3/4, 4/3, 3/2, 2/3], - bucket_aspects=[1.0], ): # self.size = size self.center_crop = center_crop @@ -754,8 +765,7 @@ def __init__( self.custom_instance_prompts = None self.class_prompt = class_prompt - self.buckets = np.array(buckets) - self.bucket_aspects = np.array(bucket_aspects) + self.buckets = buckets # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, # we load the training data using load_dataset @@ -835,7 +845,7 @@ def __init__( print("width, height", width, height) aspect_ratio = width / float(height) # Find the closest bucket - bucket_idx = np.argmin(np.abs(self.bucket_aspects - aspect_ratio)) + bucket_idx = find_nearest_bucket(height, width, self.buckets) target_height, target_width = self.buckets[bucket_idx] size = (target_height, target_width) print("WTF", size, type(size)) From fa4765cb42986276246835269cd6f3aebb0acffd Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Sun, 27 Apr 2025 16:15:33 +0300 Subject: [PATCH 04/12] fix best bucket --- examples/dreambooth/train_dreambooth_lora_hidream.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 0802992b5bb9..75bec1b8f3ef 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -725,7 +725,8 @@ class DreamBoothDataset(Dataset): It pre-processes the images. 
""" - def find_nearest_bucket(self, h, w, bucket_options): + @staticmethod + def find_nearest_bucket(h, w, bucket_options): min_metric = float('inf') best_bucket = None best_bucket_idx = None From 9ad4b611ebe3df6bf03573ef1e479faffe84d5c7 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 28 Apr 2025 12:29:32 +0300 Subject: [PATCH 05/12] fix best bucket --- examples/dreambooth/train_dreambooth_lora_hidream.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 75bec1b8f3ef..70d2cafe4493 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -748,16 +748,8 @@ def __init__( size=1024, repeats=1, center_crop=False, - # buckets=[(1024,1024),(768,1360),(1360, 768),(880, 1168),(1168, 880), (1248, 832), (832, 1248)], - buckets=[(1024, 1024)], - # "1024 × 1024 (Square)", - # "768 × 1360 (Portrait)", - # "1360 × 768 (Landscape)", - # "880 × 1168 (Portrait)", - # "1168 × 880 (Landscape)", - # "1248 × 832 (Landscape)", - # "832 × 1248 (Portrait)" - # bucket_aspects = [1.0, 9/16, 16/9, 3/4, 4/3, 3/2, 2/3], + buckets=[(1024,1024),(768,1360),(1360, 768),(880, 1168),(1168, 880), (1248, 832), (832, 1248)], + # buckets=[(1024, 1024)], ): # self.size = size self.center_crop = center_crop From 4130560f0f57e2ee0c93de66b6592a2d41693f10 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 28 Apr 2025 15:10:08 +0300 Subject: [PATCH 06/12] make it configurable add to flux lora too --- .../dreambooth/train_dreambooth_lora_flux.py | 173 +++++++++++++++--- .../train_dreambooth_lora_hidream.py | 60 +++++- 2 files changed, 199 insertions(+), 34 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 193c5affe600..d457d82a5c81 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -24,6 +24,7 @@ import warnings from contextlib import nullcontext from pathlib import Path +from torch.utils.data.sampler import Sampler, BatchSampler import numpy as np import torch @@ -65,6 +66,7 @@ is_wandb_available, ) from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_torch_npu_available from diffusers.utils.torch_utils import is_compiled_module @@ -76,6 +78,38 @@ logger = get_logger(__name__) +if is_torch_npu_available(): + torch.npu.config.allow_internal_format = False + + +def parse_buckets_string(buckets_str): + """ Parses a string defining buckets into a list of (height, width) tuples. """ + if not buckets_str: + raise ValueError("Bucket string cannot be empty.") + + bucket_pairs = buckets_str.strip().split(';') + parsed_buckets = [] + for pair_str in bucket_pairs: + match = re.match(r"^\s*(\d+)\s*,\s*(\d+)\s*$", pair_str) + if not match: + raise ValueError(f"Invalid bucket format: '{pair_str}'. Expected 'height,width'.") + try: + height = int(match.group(1)) + width = int(match.group(2)) + if height <= 0 or width <= 0: + raise ValueError("Bucket dimensions must be positive integers.") + if height % 8 != 0 or width % 8 != 0: + logger.warning(f"Bucket dimension ({height},{width}) not divisible by 8. 
This might cause issues.") + parsed_buckets.append((height, width)) + except ValueError as e: + raise ValueError(f"Invalid integer in bucket pair '{pair_str}': {e}") from e + + if not parsed_buckets: + raise ValueError("No valid buckets found in the provided string.") + + logger.info(f"Using parsed aspect ratio buckets: {parsed_buckets}") + return parsed_buckets + def save_model_card( repo_id: str, @@ -390,6 +424,16 @@ def parse_args(input_args=None): " resolution" ), ) + parser.add_argument( + "--aspect_ratio_buckets", + type=str, + default=None, + help=( + "Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. " + "e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'" + "Images will be resized and cropped to fit the nearest bucket." + ), + ) parser.add_argument( "--center_crop", default=False, @@ -700,6 +744,19 @@ class DreamBoothDataset(Dataset): It pre-processes the images. """ + @staticmethod + def find_nearest_bucket(h, w, bucket_options): + min_metric = float('inf') + best_bucket = None + best_bucket_idx = None + for bucket_idx, (bucket_h, bucket_w) in enumerate(bucket_options): + metric = abs(h * bucket_w - w * bucket_h) + if metric <= min_metric: + min_metric = metric + best_bucket = (bucket_h, bucket_w) + best_bucket_idx = bucket_idx + return best_bucket_idx + def __init__( self, instance_data_root, @@ -710,14 +767,18 @@ def __init__( size=1024, repeats=1, center_crop=False, + buckets=[(1024, 1024), (768, 1360), (1360, 768), (880, 1168), (1168, 880), (1248, 832), (832, 1248)], + # buckets=[(1024, 1024)], ): - self.size = size + # self.size = (size, size) self.center_crop = center_crop self.instance_prompt = instance_prompt self.custom_instance_prompts = None self.class_prompt = class_prompt + self.buckets = buckets + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, # we load the training data using load_dataset if args.dataset_name is not None: @@ -782,32 +843,40 @@ def __init__( self.instance_images.extend(itertools.repeat(img, repeats)) self.pixel_values = [] - train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) - train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) - train_flip = transforms.RandomHorizontalFlip(p=1.0) - train_transforms = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) for image in self.instance_images: image = exif_transpose(image) if not image.mode == "RGB": image = image.convert("RGB") - image = train_resize(image) + if args.random_flip and random.random() < 0.5: # flip image = train_flip(image) + + width, height = image.size + print("width, height", width, height) + # Find the closest bucket + bucket_idx = find_nearest_bucket(height, width, self.buckets) + target_height, target_width = self.buckets[bucket_idx] + self.size = (target_height, target_width) + + # based on the bucket assignment, define the transformations + train_resize = transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + train_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + image = train_resize(image) if args.center_crop: - y1 = max(0, int(round((image.height - args.resolution) / 2.0))) - x1 = max(0, int(round((image.width 
- args.resolution) / 2.0))) image = train_crop(image) else: - y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + y1, x1, h, w = train_crop.get_params(image, self.size) image = crop(image, y1, x1, h, w) image = train_transforms(image) - self.pixel_values.append(image) + self.pixel_values.append((image, bucket_idx)) self.num_instance_images = len(self.instance_images) self._length = self.num_instance_images @@ -826,8 +895,8 @@ def __init__( self.image_transforms = transforms.Compose( [ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -838,8 +907,9 @@ def __len__(self): def __getitem__(self, index): example = {} - instance_image = self.pixel_values[index % self.num_instance_images] + instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images] example["instance_images"] = instance_image + example["bucket_idx"] = bucket_idx if self.custom_instance_prompts: caption = self.custom_instance_prompts[index % self.num_instance_images] @@ -880,6 +950,49 @@ def collate_fn(examples, with_prior_preservation=False): return batch +class BucketBatchSampler(BatchSampler): + def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError("batch_size should be a positive integer value, " + "but got batch_size={}".format(batch_size)) + if not isinstance(drop_last, bool): + raise ValueError("drop_last should be a boolean value, but got " + "drop_last={}".format(drop_last)) + + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + + # Group indices by bucket + self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] + for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values): + self.bucket_indices[bucket_idx].append(idx) + + self.sampler_len = 0 + self.batches = [] + + # Pre-generate batches for each bucket + for indices_in_bucket in self.bucket_indices: + # Shuffle indices within the bucket + random.shuffle(indices_in_bucket) + # Create batches + for i in range(0, len(indices_in_bucket), self.batch_size): + batch = indices_in_bucket[i:i + self.batch_size] + if len(batch) < self.batch_size and self.drop_last: + continue # Skip partial batch if drop_last is True + self.batches.append(batch) + self.sampler_len += 1 # Count the number of batches + + def __iter__(self): + # Shuffle the order of the batches each epoch + random.shuffle(self.batches) + for batch in self.batches: + yield batch + + def __len__(self): + return self.sampler_len + + class PromptDataset(Dataset): "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 
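
The `find_nearest_bucket` helper added for both scripts scores each bucket with the cross-multiplied metric `|h * bucket_w - w * bucket_h|`, which is zero exactly when the image and bucket aspect ratios coincide, and avoids floating-point division. A small worked check against the default bucket list (the example image size is arbitrary):

```python
buckets = [(1024, 1024), (768, 1360), (1360, 768), (880, 1168), (1168, 880), (1248, 832), (832, 1248)]

def find_nearest_bucket(h, w, bucket_options):
    # Same loop as in the patch; `<=` means a later bucket wins an exact tie.
    min_metric = float("inf")
    best_bucket_idx = None
    for bucket_idx, (bucket_h, bucket_w) in enumerate(bucket_options):
        metric = abs(h * bucket_w - w * bucket_h)
        if metric <= min_metric:
            min_metric = metric
            best_bucket_idx = bucket_idx
    return best_bucket_idx

# A 1200x900 (height x width) portrait image: |1200*880 - 900*1168| = 4800 is
# the smallest metric over the list, so it lands in the 1168x880 bucket.
print(buckets[find_nearest_bucket(1200, 900, buckets)])  # (1168, 880)
```
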
@@ -1134,8 +1247,7 @@ def main(args): image.save(image_filename) del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() + free_memory() # Handle the repository creation if accelerator.is_main_process: @@ -1425,6 +1537,11 @@ def load_model_hook(models, input_dir): safeguard_warmup=args.prodigy_safeguard_warmup, ) + if args.aspect_ratio_buckets: + buckets = parse_buckets_string(args.aspect_ratio_buckets) + else: + buckets = [(args.resolution, args.resolution)] + # Dataset and DataLoaders creation: train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, @@ -1433,14 +1550,19 @@ def load_model_hook(models, input_dir): class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_num=args.num_class_images, size=args.resolution, + buckets=buckets, repeats=args.repeats, center_crop=args.center_crop, ) - - train_dataloader = torch.utils.data.DataLoader( + batch_sampler = BucketBatchSampler( train_dataset, batch_size=args.train_batch_size, - shuffle=True, + drop_last=False) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + # batch_size=args.train_batch_size, + # shuffle=True, collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), num_workers=args.dataloader_num_workers, ) @@ -1879,7 +2001,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): free_memory() images = None - del pipeline + free_memory() # Save the lora layers accelerator.wait_for_everyone() @@ -1927,6 +2049,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): is_final_validation=True, torch_dtype=weight_dtype, ) + del pipeline + free_memory() if args.push_to_hub: save_model_card( @@ -1946,7 +2070,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) images = None - del pipeline accelerator.end_training() diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 70d2cafe4493..4af1c55c8d8c 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -79,6 +79,33 @@ if is_torch_npu_available(): torch.npu.config.allow_internal_format = False +def parse_buckets_string(buckets_str): + """ Parses a string defining buckets into a list of (height, width) tuples. """ + if not buckets_str: + raise ValueError("Bucket string cannot be empty.") + + bucket_pairs = buckets_str.strip().split(';') + parsed_buckets = [] + for pair_str in bucket_pairs: + match = re.match(r"^\s*(\d+)\s*,\s*(\d+)\s*$", pair_str) + if not match: + raise ValueError(f"Invalid bucket format: '{pair_str}'. Expected 'height,width'.") + try: + height = int(match.group(1)) + width = int(match.group(2)) + if height <= 0 or width <= 0: + raise ValueError("Bucket dimensions must be positive integers.") + if height % 8 != 0 or width % 8 != 0: + logger.warning(f"Bucket dimension ({height},{width}) not divisible by 8. 
This might cause issues.") + parsed_buckets.append((height, width)) + except ValueError as e: + raise ValueError(f"Invalid integer in bucket pair '{pair_str}': {e}") from e + + if not parsed_buckets: + raise ValueError("No valid buckets found in the provided string.") + + logger.info(f"Using parsed aspect ratio buckets: {parsed_buckets}") + return parsed_buckets def save_model_card( repo_id: str, @@ -443,6 +470,16 @@ def parse_args(input_args=None): " resolution" ), ) + parser.add_argument( + "--aspect_ratio_buckets", + type=str, + default=None, + help=( + "Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. " + "e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'" + "Images will be resized and cropped to fit the nearest bucket." + ), + ) parser.add_argument( "--center_crop", default=False, @@ -751,7 +788,7 @@ def __init__( buckets=[(1024,1024),(768,1360),(1360, 768),(880, 1168),(1168, 880), (1248, 832), (832, 1248)], # buckets=[(1024, 1024)], ): - # self.size = size + # self.size = (size, size) self.center_crop = center_crop self.instance_prompt = instance_prompt @@ -836,15 +873,14 @@ def __init__( width, height = image.size print("width, height", width, height) - aspect_ratio = width / float(height) # Find the closest bucket bucket_idx = find_nearest_bucket(height, width, self.buckets) target_height, target_width = self.buckets[bucket_idx] - size = (target_height, target_width) - print("WTF", size, type(size)) + self.size = (target_height, target_width) + # based on the bucket assignment, define the transformations - train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) - train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + train_resize = transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose( [ @@ -857,7 +893,7 @@ def __init__( image = train_crop(image) print("cropped", image.size) else: - y1, x1, h, w = train_crop.get_params(image, size) + y1, x1, h, w = train_crop.get_params(image, self.size) image = crop(image, y1, x1, h, w) print("cropped", image.size) @@ -882,8 +918,8 @@ def __init__( self.image_transforms = transforms.Compose( [ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -1361,6 +1397,11 @@ def load_model_hook(models, input_dir): safeguard_warmup=args.prodigy_safeguard_warmup, ) + if args.aspect_ratio_buckets: + buckets = parse_buckets_string(args.aspect_ratio_buckets) + else: + buckets = [(args.resolution, args.resolution)] + # Dataset and DataLoaders creation: train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, @@ -1369,6 +1410,7 @@ def load_model_hook(models, input_dir): class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_num=args.num_class_images, size=args.resolution, + buckets=buckets, repeats=args.repeats, center_crop=args.center_crop, ) From b817ca1a0b99a504eedbb09096186707129dfdfd Mon Sep 17 00:00:00 2001 From: linoytsaban Date: 
Fri, 2 May 2025 13:18:23 +0300 Subject: [PATCH 07/12] move `find_nearest_bucket`, `parse_buckets_string` to training_utils.py --- .../dreambooth/train_dreambooth_lora_flux.py | 46 ++----------------- .../train_dreambooth_lora_hidream.py | 45 ++---------------- src/diffusers/training_utils.py | 37 +++++++++++++++ 3 files changed, 43 insertions(+), 85 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index c3a177b1c4b2..f29ec8347b82 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -59,6 +59,8 @@ compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory, + parse_buckets_string, + find_nearest_bucket ) from diffusers.utils import ( check_min_version, @@ -82,35 +84,6 @@ torch.npu.config.allow_internal_format = False -def parse_buckets_string(buckets_str): - """ Parses a string defining buckets into a list of (height, width) tuples. """ - if not buckets_str: - raise ValueError("Bucket string cannot be empty.") - - bucket_pairs = buckets_str.strip().split(';') - parsed_buckets = [] - for pair_str in bucket_pairs: - match = re.match(r"^\s*(\d+)\s*,\s*(\d+)\s*$", pair_str) - if not match: - raise ValueError(f"Invalid bucket format: '{pair_str}'. Expected 'height,width'.") - try: - height = int(match.group(1)) - width = int(match.group(2)) - if height <= 0 or width <= 0: - raise ValueError("Bucket dimensions must be positive integers.") - if height % 8 != 0 or width % 8 != 0: - logger.warning(f"Bucket dimension ({height},{width}) not divisible by 8. This might cause issues.") - parsed_buckets.append((height, width)) - except ValueError as e: - raise ValueError(f"Invalid integer in bucket pair '{pair_str}': {e}") from e - - if not parsed_buckets: - raise ValueError("No valid buckets found in the provided string.") - - logger.info(f"Using parsed aspect ratio buckets: {parsed_buckets}") - return parsed_buckets - - def save_model_card( repo_id: str, images=None, @@ -744,19 +717,6 @@ class DreamBoothDataset(Dataset): It pre-processes the images. 
""" - @staticmethod - def find_nearest_bucket(h, w, bucket_options): - min_metric = float('inf') - best_bucket = None - best_bucket_idx = None - for bucket_idx, (bucket_h, bucket_w) in enumerate(bucket_options): - metric = abs(h * bucket_w - w * bucket_h) - if metric <= min_metric: - min_metric = metric - best_bucket = (bucket_h, bucket_w) - best_bucket_idx = bucket_idx - return best_bucket_idx - def __init__( self, instance_data_root, @@ -1541,7 +1501,7 @@ def load_model_hook(models, input_dir): buckets = parse_buckets_string(args.aspect_ratio_buckets) else: buckets = [(args.resolution, args.resolution)] - + logger.info(f"Using parsed aspect ratio buckets: {buckets}") # Dataset and DataLoaders creation: train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index aba589410f53..ced2efc34f66 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -58,6 +58,8 @@ compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory, + parse_buckets_string, + find_nearest_bucket ) from diffusers.utils import ( check_min_version, @@ -79,34 +81,6 @@ if is_torch_npu_available(): torch.npu.config.allow_internal_format = False -def parse_buckets_string(buckets_str): - """ Parses a string defining buckets into a list of (height, width) tuples. """ - if not buckets_str: - raise ValueError("Bucket string cannot be empty.") - - bucket_pairs = buckets_str.strip().split(';') - parsed_buckets = [] - for pair_str in bucket_pairs: - match = re.match(r"^\s*(\d+)\s*,\s*(\d+)\s*$", pair_str) - if not match: - raise ValueError(f"Invalid bucket format: '{pair_str}'. Expected 'height,width'.") - try: - height = int(match.group(1)) - width = int(match.group(2)) - if height <= 0 or width <= 0: - raise ValueError("Bucket dimensions must be positive integers.") - if height % 8 != 0 or width % 8 != 0: - logger.warning(f"Bucket dimension ({height},{width}) not divisible by 8. This might cause issues.") - parsed_buckets.append((height, width)) - except ValueError as e: - raise ValueError(f"Invalid integer in bucket pair '{pair_str}': {e}") from e - - if not parsed_buckets: - raise ValueError("No valid buckets found in the provided string.") - - logger.info(f"Using parsed aspect ratio buckets: {parsed_buckets}") - return parsed_buckets - def save_model_card( repo_id: str, images=None, @@ -761,19 +735,6 @@ class DreamBoothDataset(Dataset): It pre-processes the images. 
""" - @staticmethod - def find_nearest_bucket(h, w, bucket_options): - min_metric = float('inf') - best_bucket = None - best_bucket_idx = None - for bucket_idx, (bucket_h, bucket_w) in enumerate(bucket_options): - metric = abs(h * bucket_w - w * bucket_h) - if metric <= min_metric: - min_metric = metric - best_bucket = (bucket_h, bucket_w) - best_bucket_idx = bucket_idx - return best_bucket_idx - def __init__( self, instance_data_root, @@ -1400,7 +1361,7 @@ def load_model_hook(models, input_dir): buckets = parse_buckets_string(args.aspect_ratio_buckets) else: buckets = [(args.resolution, args.resolution)] - + logger.info(f"Using parsed aspect ratio buckets: {buckets}") # Dataset and DataLoaders creation: train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index b98c4e33f862..c06844842108 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -307,6 +307,43 @@ def free_memory(): elif hasattr(torch, "xpu") and torch.xpu.is_available(): torch.xpu.empty_cache() +def parse_buckets_string(buckets_str): + """ Parses a string defining buckets into a list of (height, width) tuples. """ + if not buckets_str: + raise ValueError("Bucket string cannot be empty.") + + bucket_pairs = buckets_str.strip().split(';') + parsed_buckets = [] + for pair_str in bucket_pairs: + match = re.match(r"^\s*(\d+)\s*,\s*(\d+)\s*$", pair_str) + if not match: + raise ValueError(f"Invalid bucket format: '{pair_str}'. Expected 'height,width'.") + try: + height = int(match.group(1)) + width = int(match.group(2)) + if height <= 0 or width <= 0: + raise ValueError("Bucket dimensions must be positive integers.") + if height % 8 != 0 or width % 8 != 0: + warnings.warn(f"Bucket dimension ({height},{width}) not divisible by 8. This might cause issues.") + parsed_buckets.append((height, width)) + except ValueError as e: + raise ValueError(f"Invalid integer in bucket pair '{pair_str}': {e}") from e + + if not parsed_buckets: + raise ValueError("No valid buckets found in the provided string.") + + return parsed_buckets + +def find_nearest_bucket(h, w, bucket_options): + """ Finds the closes bucket to the given height and width. 
""" + min_metric = float('inf') + best_bucket_idx = None + for bucket_idx, (bucket_h, bucket_w) in enumerate(bucket_options): + metric = abs(h * bucket_w - w * bucket_h) + if metric <= min_metric: + min_metric = metric + best_bucket_idx = bucket_idx + return best_bucket_idx # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 class EMAModel: From b0d77eec66786e42ad1f23d5c69fadf02044d54d Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 2 May 2025 13:21:11 +0300 Subject: [PATCH 08/12] move `find_nearest_bucket`, `parse_buckets_string` to training_utils.py --- src/diffusers/training_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index c06844842108..081c8d78b50c 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -3,6 +3,7 @@ import gc import math import random +import warnings from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np From f5636c6883ac750b09418d36c475234f3cd25436 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 2 May 2025 16:30:03 +0300 Subject: [PATCH 09/12] fix import --- src/diffusers/training_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 081c8d78b50c..36471729bf9a 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -5,7 +5,7 @@ import random import warnings from typing import Any, Dict, Iterable, List, Optional, Tuple, Union - +import re import numpy as np import torch From 0df0ea1e10cc6da5bf3decfbd7f573f50df0ab36 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Fri, 2 May 2025 16:35:33 +0300 Subject: [PATCH 10/12] fix flip --- examples/dreambooth/train_dreambooth_lora_flux.py | 8 +++----- examples/dreambooth/train_dreambooth_lora_hidream.py | 10 ++-------- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index f29ec8347b82..50b6523a89a5 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -808,12 +808,8 @@ def __init__( if not image.mode == "RGB": image = image.convert("RGB") - if args.random_flip and random.random() < 0.5: - # flip - image = train_flip(image) - width, height = image.size - print("width, height", width, height) + # Find the closest bucket bucket_idx = find_nearest_bucket(height, width, self.buckets) target_height, target_width = self.buckets[bucket_idx] @@ -835,6 +831,8 @@ def __init__( else: y1, x1, h, w = train_crop.get_params(image, self.size) image = crop(image, y1, x1, h, w) + if args.random_flip and random.random() < 0.5: + image = train_flip(image) image = train_transforms(image) self.pixel_values.append((image, bucket_idx)) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index ced2efc34f66..82293d72b5ac 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -827,12 +827,7 @@ def __init__( if not image.mode == "RGB": image = image.convert("RGB") - if args.random_flip and random.random() < 0.5: - # flip - image = train_flip(image) - width, height = image.size - print("width, height", width, height) # Find the closest bucket bucket_idx = find_nearest_bucket(height, width, self.buckets) target_height, 
target_width = self.buckets[bucket_idx] @@ -851,12 +846,11 @@ def __init__( image = train_resize(image) if args.center_crop: image = train_crop(image) - print("cropped", image.size) else: y1, x1, h, w = train_crop.get_params(image, self.size) image = crop(image, y1, x1, h, w) - print("cropped", image.size) - + if args.random_flip and random.random() < 0.5: + image = train_flip(image) image = train_transforms(image) self.pixel_values.append((image, bucket_idx)) # self.assignments.append((image, bucket_idx)) From 4646c6049224824be205c370f208d16de741d2dc Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 2 May 2025 13:39:02 +0000 Subject: [PATCH 11/12] Apply style fixes --- .../dreambooth/train_dreambooth_lora_flux.py | 17 ++-- .../train_dreambooth_lora_hidream.py | 81 +++++++++---------- src/diffusers/training_utils.py | 14 ++-- 3 files changed, 54 insertions(+), 58 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 50b6523a89a5..5a8451be3cc7 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -24,7 +24,6 @@ import warnings from contextlib import nullcontext from pathlib import Path -from torch.utils.data.sampler import Sampler, BatchSampler import numpy as np import torch @@ -40,6 +39,7 @@ from PIL import Image from PIL.ImageOps import exif_transpose from torch.utils.data import Dataset +from torch.utils.data.sampler import BatchSampler from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm @@ -58,9 +58,9 @@ cast_training_params, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, + find_nearest_bucket, free_memory, parse_buckets_string, - find_nearest_bucket ) from diffusers.utils import ( check_min_version, @@ -911,11 +911,9 @@ def collate_fn(examples, with_prior_preservation=False): class BucketBatchSampler(BatchSampler): def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): if not isinstance(batch_size, int) or batch_size <= 0: - raise ValueError("batch_size should be a positive integer value, " - "but got batch_size={}".format(batch_size)) + raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) if not isinstance(drop_last, bool): - raise ValueError("drop_last should be a boolean value, but got " - "drop_last={}".format(drop_last)) + raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last)) self.dataset = dataset self.batch_size = batch_size @@ -935,7 +933,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool random.shuffle(indices_in_bucket) # Create batches for i in range(0, len(indices_in_bucket), self.batch_size): - batch = indices_in_bucket[i:i + self.batch_size] + batch = indices_in_bucket[i : i + self.batch_size] if len(batch) < self.batch_size and self.drop_last: continue # Skip partial batch if drop_last is True self.batches.append(batch) @@ -1512,10 +1510,7 @@ def load_model_hook(models, input_dir): repeats=args.repeats, center_crop=args.center_crop, ) - batch_sampler = BucketBatchSampler( - train_dataset, - batch_size=args.train_batch_size, - drop_last=False) + batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler, diff --git 
a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 82293d72b5ac..58ab8c706e0b 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -24,7 +24,6 @@ import warnings from contextlib import nullcontext from pathlib import Path -from torch.utils.data.sampler import Sampler, BatchSampler import numpy as np import torch @@ -40,6 +39,7 @@ from PIL import Image from PIL.ImageOps import exif_transpose from torch.utils.data import Dataset +from torch.utils.data.sampler import BatchSampler from torchvision import transforms from torchvision.transforms.functional import crop from tqdm.auto import tqdm @@ -57,9 +57,9 @@ cast_training_params, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, + find_nearest_bucket, free_memory, parse_buckets_string, - find_nearest_bucket ) from diffusers.utils import ( check_min_version, @@ -70,6 +70,7 @@ from diffusers.utils.import_utils import is_torch_npu_available from diffusers.utils.torch_utils import is_compiled_module + if is_wandb_available(): import wandb @@ -81,13 +82,14 @@ if is_torch_npu_available(): torch.npu.config.allow_internal_format = False + def save_model_card( - repo_id: str, - images=None, - base_model: str = None, - instance_prompt=None, - validation_prompt=None, - repo_folder=None, + repo_id: str, + images=None, + base_model: str = None, + instance_prompt=None, + validation_prompt=None, + repo_folder=None, ): widget_dict = [] if images is not None: @@ -189,13 +191,13 @@ def load_text_encoders(class_one, class_two, class_three): def log_validation( - pipeline, - args, - accelerator, - pipeline_args, - epoch, - torch_dtype, - is_final_validation=False, + pipeline, + args, + accelerator, + pipeline_args, + epoch, + torch_dtype, + is_final_validation=False, ): args.num_validation_images = args.num_validation_images if args.num_validation_images else 1 logger.info( @@ -244,7 +246,7 @@ def log_validation( def import_model_class_from_model_name_or_path( - pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" ): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder=subfolder, revision=revision @@ -331,8 +333,8 @@ def parse_args(input_args=None): type=str, default="image", help="The column of the dataset containing the target image. By " - "default, the standard Image Dataset maps out 'file_name' " - "to 'image'.", + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", ) parser.add_argument( "--caption_column", @@ -598,7 +600,7 @@ def parse_args(input_args=None): type=float, default=None, help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " - "uses the value of square root of beta2. Ignored if optimizer is adamW", + "uses the value of square root of beta2. Ignored if optimizer is adamW", ) parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") @@ -629,7 +631,7 @@ def parse_args(input_args=None): type=bool, default=True, help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. 
" - "Ignored if optimizer is adamW", + "Ignored if optimizer is adamW", ) parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") @@ -736,17 +738,17 @@ class DreamBoothDataset(Dataset): """ def __init__( - self, - instance_data_root, - instance_prompt, - class_prompt, - class_data_root=None, - class_num=None, - size=1024, - repeats=1, - center_crop=False, - buckets=[(1024,1024),(768,1360),(1360, 768),(880, 1168),(1168, 880), (1248, 832), (832, 1248)], - # buckets=[(1024, 1024)], + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, + buckets=[(1024, 1024), (768, 1360), (1360, 768), (880, 1168), (1168, 880), (1248, 832), (832, 1248)], + # buckets=[(1024, 1024)], ): # self.size = (size, size) self.center_crop = center_crop @@ -930,11 +932,9 @@ def collate_fn(examples, with_prior_preservation=False): class BucketBatchSampler(BatchSampler): def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False): if not isinstance(batch_size, int) or batch_size <= 0: - raise ValueError("batch_size should be a positive integer value, " - "but got batch_size={}".format(batch_size)) + raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) if not isinstance(drop_last, bool): - raise ValueError("drop_last should be a boolean value, but got " - "drop_last={}".format(drop_last)) + raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last)) self.dataset = dataset self.batch_size = batch_size @@ -954,7 +954,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool random.shuffle(indices_in_bucket) # Create batches for i in range(0, len(indices_in_bucket), self.batch_size): - batch = indices_in_bucket[i:i + self.batch_size] + batch = indices_in_bucket[i : i + self.batch_size] if len(batch) < self.batch_size and self.drop_last: continue # Skip partial batch if drop_last is True self.batches.append(batch) @@ -1064,7 +1064,7 @@ def main(args): pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): images = pipeline(example["prompt"]).images @@ -1278,7 +1278,7 @@ def load_model_hook(models, input_dir): if args.scale_lr: args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Make sure the trainable params are in float32. 
@@ -1368,10 +1368,7 @@ def load_model_hook(models, input_dir): repeats=args.repeats, center_crop=args.center_crop, ) - batch_sampler = BucketBatchSampler( - train_dataset, - batch_size=args.train_batch_size, - drop_last=False) + batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False) train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler, diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 36471729bf9a..7aa834f97bb2 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -3,9 +3,10 @@ import gc import math import random +import re import warnings from typing import Any, Dict, Iterable, List, Optional, Tuple, Union -import re + import numpy as np import torch @@ -308,12 +309,13 @@ def free_memory(): elif hasattr(torch, "xpu") and torch.xpu.is_available(): torch.xpu.empty_cache() + def parse_buckets_string(buckets_str): - """ Parses a string defining buckets into a list of (height, width) tuples. """ + """Parses a string defining buckets into a list of (height, width) tuples.""" if not buckets_str: raise ValueError("Bucket string cannot be empty.") - bucket_pairs = buckets_str.strip().split(';') + bucket_pairs = buckets_str.strip().split(";") parsed_buckets = [] for pair_str in bucket_pairs: match = re.match(r"^\s*(\d+)\s*,\s*(\d+)\s*$", pair_str) @@ -335,9 +337,10 @@ def parse_buckets_string(buckets_str): return parsed_buckets + def find_nearest_bucket(h, w, bucket_options): - """ Finds the closest bucket to the given height and width. """ - min_metric = float('inf') + """Finds the closest bucket to the given height and width.""" + min_metric = float("inf") best_bucket_idx = None for bucket_idx, (bucket_h, bucket_w) in enumerate(bucket_options): metric = abs(h * bucket_w - w * bucket_h) @@ -346,6 +349,7 @@ def find_nearest_bucket(h, w, bucket_options): best_bucket_idx = bucket_idx return best_bucket_idx + # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 class EMAModel: """ From d442e5a01503abffdefaeb500c9dec1da3701f02 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Mon, 5 May 2025 11:42:52 +0300 Subject: [PATCH 12/12] cleanup --- examples/dreambooth/train_dreambooth_lora_flux.py | 13 ++++--------- .../dreambooth/train_dreambooth_lora_hidream.py | 14 ++++---------- 2 files changed, 8 insertions(+), 19 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 5a8451be3cc7..9af3e3fa6b7a 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -404,7 +404,7 @@ def parse_args(input_args=None): help=( "Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. " "e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'" - "Images will be resized and cropped to fit the nearest bucket." + "Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored."
), ) parser.add_argument( @@ -724,13 +724,10 @@ def __init__( class_prompt, class_data_root=None, class_num=None, - size=1024, repeats=1, center_crop=False, - buckets=[(1024, 1024), (768, 1360), (1360, 768), (880, 1168), (1168, 880), (1248, 832), (832, 1248)], - # buckets=[(1024, 1024)], + buckets=None, ): - # self.size = (size, size) self.center_crop = center_crop self.instance_prompt = instance_prompt @@ -1493,11 +1490,12 @@ def load_model_hook(models, input_dir): safeguard_warmup=args.prodigy_safeguard_warmup, ) - if args.aspect_ratio_buckets: + if args.aspect_ratio_buckets is not None: buckets = parse_buckets_string(args.aspect_ratio_buckets) else: buckets = [(args.resolution, args.resolution)] logger.info(f"Using parsed aspect ratio buckets: {buckets}") + # Dataset and DataLoaders creation: train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, @@ -1505,7 +1503,6 @@ def load_model_hook(models, input_dir): class_prompt=args.class_prompt, class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_num=args.num_class_images, - size=args.resolution, buckets=buckets, repeats=args.repeats, center_crop=args.center_crop, @@ -1514,8 +1511,6 @@ def load_model_hook(models, input_dir): train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler, - # batch_size=args.train_batch_size, - # shuffle=True, collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), num_workers=args.dataloader_num_workers, ) diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index 58ab8c706e0b..ed0725581d08 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -452,7 +452,7 @@ def parse_args(input_args=None): help=( "Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. " "e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'" - "Images will be resized and cropped to fit the nearest bucket." + "Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored." 
), ) parser.add_argument( @@ -744,13 +744,10 @@ def __init__( class_prompt, class_data_root=None, class_num=None, - size=1024, repeats=1, center_crop=False, - buckets=[(1024, 1024), (768, 1360), (1360, 768), (880, 1168), (1168, 880), (1248, 832), (832, 1248)], - # buckets=[(1024, 1024)], + buckets=None, ): - # self.size = (size, size) self.center_crop = center_crop self.instance_prompt = instance_prompt @@ -823,7 +820,6 @@ def __init__( self.instance_images.extend(itertools.repeat(img, repeats)) self.pixel_values = [] - # self.aspect_assignments = [] for image in self.instance_images: image = exif_transpose(image) if not image.mode == "RGB": @@ -1351,11 +1347,12 @@ def load_model_hook(models, input_dir): safeguard_warmup=args.prodigy_safeguard_warmup, ) - if args.aspect_ratio_buckets: + if args.aspect_ratio_buckets is not None: buckets = parse_buckets_string(args.aspect_ratio_buckets) else: buckets = [(args.resolution, args.resolution)] logger.info(f"Using parsed aspect ratio buckets: {buckets}") + # Dataset and DataLoaders creation: train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, @@ -1363,7 +1360,6 @@ def load_model_hook(models, input_dir): class_prompt=args.class_prompt, class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_num=args.num_class_images, - size=args.resolution, buckets=buckets, repeats=args.repeats, center_crop=args.center_crop, @@ -1372,8 +1368,6 @@ def load_model_hook(models, input_dir): train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_sampler=batch_sampler, - # batch_size=args.train_batch_size, - # shuffle=True, collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), num_workers=args.dataloader_num_workers, )
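
Finally, the reason batches must never mix buckets is worth making concrete: the collate step stacks the `instance_images` tensors with `torch.stack` (as in the stock DreamBooth scripts), which rejects tensors of unequal shape. A quick sanity check:

```python
import torch

square = torch.zeros(3, 1024, 1024)
portrait = torch.zeros(3, 1360, 768)

# Same-shape tensors stack into a batch dimension without trouble.
print(torch.stack([square, square]).shape)  # torch.Size([2, 3, 1024, 1024])

# Mixing resolutions in one batch is exactly what BucketBatchSampler prevents.
try:
    torch.stack([square, portrait])
except RuntimeError as err:
    print("mixed-resolution batch fails:", err)
```
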