Skip to content

[LoRA training] add aspect ratio bucketing #11438

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
2da7c20
initial commit
linoytsaban Apr 25, 2025
6517a70
initial commit
linoytsaban Apr 25, 2025
873fe89
fix best bucket
linoytsaban Apr 27, 2025
fa4765c
fix best bucket
linoytsaban Apr 27, 2025
9ad4b61
fix best bucket
linoytsaban Apr 28, 2025
5211ffa
Merge branch 'huggingface:main' into aspect_ratio_bucketing
linoytsaban Apr 28, 2025
4130560
make it configurable
linoytsaban Apr 28, 2025
bd7a8b8
Merge branch 'main' into aspect_ratio_bucketing
linoytsaban Apr 28, 2025
50782b7
Merge branch 'main' into aspect_ratio_bucketing
linoytsaban Apr 30, 2025
314cbdb
Merge branch 'main' into aspect_ratio_bucketing
linoytsaban Apr 30, 2025
1571961
Merge branch 'main' into aspect_ratio_bucketing
linoytsaban May 2, 2025
b817ca1
move `find_nearest_bucket`, `parse_buckets_string` to training_utils.py
linoytsaban May 2, 2025
b0d77ee
move `find_nearest_bucket`, `parse_buckets_string` to training_utils.py
linoytsaban May 2, 2025
f5636c6
fix import
linoytsaban May 2, 2025
0df0ea1
fix flip
linoytsaban May 2, 2025
4646c60
Apply style fixes
github-actions[bot] May 2, 2025
ad28907
Merge branch 'huggingface:main' into aspect_ratio_bucketing
linoytsaban May 5, 2025
d442e5a
cleanup
linoytsaban May 5, 2025
b6d180b
Merge branch 'main' into aspect_ratio_bucketing
linoytsaban May 5, 2025
1f235b5
Merge branch 'main' into aspect_ratio_bucketing
linoytsaban May 5, 2025
bcd6a60
Merge branch 'main' into aspect_ratio_bucketing
linoytsaban May 8, 2025
fd8bccf
Merge branch 'main' into aspect_ratio_bucketing
linoytsaban May 22, 2025
eb59238
Merge branch 'main' into aspect_ratio_bucketing
linoytsaban May 27, 2025
84d5b20
Merge branch 'main' into aspect_ratio_bucketing
linoytsaban May 29, 2025
4ec0872
Merge branch 'main' into aspect_ratio_bucketing
linoytsaban Jun 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 100 additions & 29 deletions examples/dreambooth/train_dreambooth_lora_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler
from torchvision import transforms
from torchvision.transforms.functional import crop
from tqdm.auto import tqdm
Expand All @@ -57,14 +58,17 @@
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
find_nearest_bucket,
free_memory,
parse_buckets_string,
)
from diffusers.utils import (
check_min_version,
convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module


Expand All @@ -76,6 +80,9 @@

logger = get_logger(__name__)

# When an NPU backend is reported as available, switch off the
# `allow_internal_format` option of the NPU config.
# NOTE(review): presumably this keeps tensors in standard (non device-private)
# memory layouts on NPU hardware — confirm against the torch_npu docs.
if is_torch_npu_available():
    torch.npu.config.allow_internal_format = False


def save_model_card(
repo_id: str,
Expand Down Expand Up @@ -398,6 +405,16 @@ def parse_args(input_args=None):
" resolution"
),
)
# CLI option selecting the aspect-ratio buckets used for training.
# The value stays a raw string here (default: None, i.e. option not given).
parser.add_argument(
    "--aspect_ratio_buckets",
    type=str,
    default=None,
    # Adjacent string literals are concatenated verbatim, so every fragment
    # has to end with its own separator or the sentences run together in --help.
    help=(
        "Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. "
        "e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'. "
        "Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored."
    ),
)
parser.add_argument(
"--center_crop",
default=False,
Expand Down Expand Up @@ -715,17 +732,18 @@ def __init__(
class_prompt,
class_data_root=None,
class_num=None,
size=1024,
repeats=1,
center_crop=False,
buckets=None,
):
self.size = size
self.center_crop = center_crop

self.instance_prompt = instance_prompt
self.custom_instance_prompts = None
self.class_prompt = class_prompt

self.buckets = buckets

# if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
# we load the training data using load_dataset
if args.dataset_name is not None:
Expand Down Expand Up @@ -790,32 +808,38 @@ def __init__(
self.instance_images.extend(itertools.repeat(img, repeats))

self.pixel_values = []
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
for image in self.instance_images:
image = exif_transpose(image)
if not image.mode == "RGB":
image = image.convert("RGB")

width, height = image.size

# Find the closest bucket
bucket_idx = find_nearest_bucket(height, width, self.buckets)
target_height, target_width = self.buckets[bucket_idx]
self.size = (target_height, target_width)

# based on the bucket assignment, define the transformations
train_resize = transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR)
train_crop = transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
image = train_resize(image)
if args.random_flip and random.random() < 0.5:
# flip
image = train_flip(image)
if args.center_crop:
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
image = train_crop(image)
else:
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
y1, x1, h, w = train_crop.get_params(image, self.size)
image = crop(image, y1, x1, h, w)
if args.random_flip and random.random() < 0.5:
image = train_flip(image)
image = train_transforms(image)
self.pixel_values.append(image)
self.pixel_values.append((image, bucket_idx))

self.num_instance_images = len(self.instance_images)
self._length = self.num_instance_images
Expand All @@ -834,8 +858,8 @@ def __init__(

self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
Expand All @@ -846,8 +870,9 @@ def __len__(self):

def __getitem__(self, index):
example = {}
instance_image = self.pixel_values[index % self.num_instance_images]
instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
example["instance_images"] = instance_image
example["bucket_idx"] = bucket_idx

if self.custom_instance_prompts:
caption = self.custom_instance_prompts[index % self.num_instance_images]
Expand Down Expand Up @@ -888,6 +913,47 @@ def collate_fn(examples, with_prior_preservation=False):
return batch


class BucketBatchSampler(BatchSampler):
    """Batch sampler that only puts samples from the same bucket into a batch.

    The sampler reads two attributes of ``dataset``:

    * ``dataset.buckets`` - a sized collection; only its length is used, to
      know how many buckets exist.
    * ``dataset.pixel_values`` - an iterable of ``(item, bucket_idx)`` pairs;
      the position of a pair is the dataset index that gets yielded.

    Every iteration (i.e. every epoch) re-shuffles the indices inside each
    bucket, cuts them into batches of ``batch_size`` and then shuffles the
    order of the batches, so both the composition and the order of batches
    change between epochs. The number of batches is constant across epochs.

    Args:
        dataset: Dataset exposing ``buckets`` and ``pixel_values`` as described above.
        batch_size: Maximum number of indices per batch; must be a positive ``int``.
        drop_last: If ``True``, a trailing batch of a bucket that has fewer than
            ``batch_size`` indices is dropped instead of being yielded.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer or ``drop_last``
            is not a boolean.
    """

    def __init__(self, dataset: "DreamBoothDataset", batch_size: int, drop_last: bool = False):
        # `bool` is a subclass of `int`, so it has to be rejected explicitly;
        # otherwise `batch_size=True` would silently behave like `batch_size=1`.
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or batch_size <= 0:
            raise ValueError(f"batch_size should be a positive integer value, but got batch_size={batch_size}")
        if not isinstance(drop_last, bool):
            raise ValueError(f"drop_last should be a boolean value, but got drop_last={drop_last}")

        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        # Group dataset indices by the bucket each sample was assigned to.
        self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
        for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values):
            self.bucket_indices[bucket_idx].append(idx)

        # Build one set of batches up front so `batches` and the length are
        # available before the first iteration.
        self.batches = self._build_batches()
        self.sampler_len = len(self.batches)

    def _build_batches(self):
        """Shuffle the indices inside every bucket and cut them into batches."""
        batches = []
        for indices_in_bucket in self.bucket_indices:
            random.shuffle(indices_in_bucket)
            for start in range(0, len(indices_in_bucket), self.batch_size):
                batch = indices_in_bucket[start : start + self.batch_size]
                if self.drop_last and len(batch) < self.batch_size:
                    continue  # skip the partial batch of this bucket
                batches.append(batch)
        return batches

    def __iter__(self):
        # Rebuild the batches so that samples of one bucket are not grouped
        # with the same partners in every epoch, then shuffle the batch order
        # so buckets are interleaved.
        batches = self._build_batches()
        random.shuffle(batches)
        self.batches = batches
        yield from batches

    def __len__(self):
        # Number of batches per epoch; depends only on bucket sizes,
        # `batch_size` and `drop_last`, so it is stable across epochs.
        return self.sampler_len


class PromptDataset(Dataset):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."

Expand Down Expand Up @@ -1142,8 +1208,7 @@ def main(args):
image.save(image_filename)

del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
free_memory()

# Handle the repository creation
if accelerator.is_main_process:
Expand Down Expand Up @@ -1438,22 +1503,27 @@ def load_model_hook(models, input_dir):
safeguard_warmup=args.prodigy_safeguard_warmup,
)

if args.aspect_ratio_buckets is not None:
buckets = parse_buckets_string(args.aspect_ratio_buckets)
else:
buckets = [(args.resolution, args.resolution)]
logger.info(f"Using parsed aspect ratio buckets: {buckets}")

# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_prompt=args.class_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_num=args.num_class_images,
size=args.resolution,
buckets=buckets,
repeats=args.repeats,
center_crop=args.center_crop,
)

batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious: why is `drop_last` not `True` here? In this other PR it is set to `True` so that it doesn't error when the batch size is > 1 and the last batch does not have enough images.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing it out! Replied in #11921.

train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
batch_sampler=batch_sampler,
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
num_workers=args.dataloader_num_workers,
)
Expand Down Expand Up @@ -1892,7 +1962,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
free_memory()

images = None
del pipeline
free_memory()

# Save the lora layers
accelerator.wait_for_everyone()
Expand Down Expand Up @@ -1944,6 +2014,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
is_final_validation=True,
torch_dtype=weight_dtype,
)
del pipeline
free_memory()

if args.push_to_hub:
save_model_card(
Expand All @@ -1963,7 +2035,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
)

images = None
del pipeline

accelerator.end_training()

Expand Down
Loading
Loading