Merged
train.py: 33 changes (28 additions & 5 deletions)
@@ -161,6 +161,10 @@ def train(
         description="Regular expression to match specific layers to optimize. Optimizing fewer layers results in shorter training times, but can also result in a weaker LoRA. For example, to target layers 7, 12, 16, 20 which seems to create good likeness with faster training (as discovered by lux in the Ostris discord, inspired by The Last Ben), use `transformer.single_transformer_blocks.(7|12|16|20).proj_out`.",
         default=None,
     ),
+    gradient_checkpointing: bool = Input(
+        description="Turn on gradient checkpointing; saves memory at the cost of training speed. Automatically enabled when the GPU has less than 100 GB of memory, when batch size is greater than 1, or when training resolution is above 1024.",
+        default=False,
+    ),
     hf_repo_id: str = Input(
         description="Hugging Face repository ID, if you'd like to upload the trained LoRA to Hugging Face. For example, lucataco/flux-dev-lora. If the given repo does not exist, a new public repo will be created.",
         default=None,
@@ -223,11 +227,32 @@ def train(
             f"The regex '{layers_to_optimize_regex}' didn't match any layers. These layers can be optimized:\n"
             + "\n".join(available_layers_to_optimize)
         )
+    quantize = False
+    resolutions = [int(res) for res in resolution.split(",")]
 
     sample_prompts = []
     if wandb_sample_prompts:
         sample_prompts = [p.strip() for p in wandb_sample_prompts.split("\n")]
 
+    if not gradient_checkpointing:
+        if (
+            torch.cuda.get_device_properties(0).total_memory
+            < 1024 * 1024 * 1024 * 100  # memory < 100 GB?
+        ):
+            print(
+                "Turning gradient checkpointing on and quantizing base model, GPU has less than 100 GB of memory"
+            )
+            gradient_checkpointing = True
+            quantize = True
+        elif batch_size > 1:
+            print("Turning gradient checkpointing on automatically for batch size > 1")
+            gradient_checkpointing = True
+        elif max(resolutions) > 1024:
+            print(
+                "Turning gradient checkpointing on; training resolution greater than 1024x1024"
+            )
+            gradient_checkpointing = True
+
     train_config = OrderedDict(
         {
             "job": "custom_job",
@@ -260,9 +285,7 @@ def train(
                                 # TODO: Do we need to cache to disk? It's faster not to.
                                 "cache_latents_to_disk": cache_latents_to_disk,
                                 "cache_latents": True,
-                                "resolution": [
-                                    int(res) for res in resolution.split(",")
-                                ],
+                                "resolution": resolutions,
                             }
                         ],
                         "train": {
@@ -272,7 +295,7 @@
                             "train_unet": True,
                             "train_text_encoder": False,
                             "content_or_style": "balanced",
-                            "gradient_checkpointing": True,
+                            "gradient_checkpointing": gradient_checkpointing,
                             "noise_scheduler": "flowmatch",
                             "optimizer": optimizer,
                             "lr": learning_rate,
@@ -282,7 +305,7 @@
                         "model": {
                             "name_or_path": str(WEIGHTS_PATH),
                             "is_flux": True,
-                            "quantize": True,
+                            "quantize": quantize,
                         },
                         "sample": {
                             "sampler": "flowmatch",
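
Net effect on the generated config: the two computed values are threaded into the ai-toolkit job dict instead of being hard-coded to True. A trimmed, self-contained sketch of just the affected fragments follows; all other keys are omitted and the placeholder values stand in for the variables set in train().

# Illustration only: where the two computed flags land in the config.
# The placeholder values mimic a run with batch_size > 1 on a large GPU.
gradient_checkpointing, quantize = True, False

config_fragment = {
    "train": {
        "gradient_checkpointing": gradient_checkpointing,  # previously hard-coded to True
    },
    "model": {
        "quantize": quantize,  # previously hard-coded to True
    },
}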