48
48
from diffusers import (
49
49
AutoencoderKL ,
50
50
FlowMatchEulerDiscreteScheduler ,
51
- Lumina2Text2ImgPipeline ,
51
+ Lumina2Pipeline ,
52
52
Lumina2Transformer2DModel ,
53
53
)
54
54
from diffusers .optimization import get_scheduler
72
72
import wandb
73
73
74
74
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
75
- check_min_version ("0.33 .0.dev0" )
75
+ check_min_version ("0.34 .0.dev0" )
76
76
77
77
logger = get_logger (__name__ )
78
78
@@ -898,7 +898,7 @@ def main(args):
898
898
cur_class_images = len (list (class_images_dir .iterdir ()))
899
899
900
900
if cur_class_images < args .num_class_images :
901
- pipeline = Lumina2Text2ImgPipeline .from_pretrained (
901
+ pipeline = Lumina2Pipeline .from_pretrained (
902
902
args .pretrained_model_name_or_path ,
903
903
torch_dtype = torch .bfloat16 if args .mixed_precision == "bf16" else torch .float16 ,
904
904
revision = args .revision ,
@@ -990,7 +990,7 @@ def main(args):
990
990
text_encoder .to (dtype = torch .bfloat16 )
991
991
992
992
# Initialize a text encoding pipeline and keep it to CPU for now.
993
- text_encoding_pipeline = Lumina2Text2ImgPipeline .from_pretrained (
993
+ text_encoding_pipeline = Lumina2Pipeline .from_pretrained (
994
994
args .pretrained_model_name_or_path ,
995
995
vae = None ,
996
996
transformer = None ,
@@ -1034,7 +1034,7 @@ def save_model_hook(models, weights, output_dir):
1034
1034
# make sure to pop weight so that corresponding model is not saved again
1035
1035
weights .pop ()
1036
1036
1037
- Lumina2Text2ImgPipeline .save_lora_weights (
1037
+ Lumina2Pipeline .save_lora_weights (
1038
1038
output_dir ,
1039
1039
transformer_lora_layers = transformer_lora_layers_to_save ,
1040
1040
)
@@ -1050,7 +1050,7 @@ def load_model_hook(models, input_dir):
1050
1050
else :
1051
1051
raise ValueError (f"unexpected save model: { model .__class__ } " )
1052
1052
1053
- lora_state_dict = Lumina2Text2ImgPipeline .lora_state_dict (input_dir )
1053
+ lora_state_dict = Lumina2Pipeline .lora_state_dict (input_dir )
1054
1054
1055
1055
transformer_state_dict = {
1056
1056
f"{ k .replace ('transformer.' , '' )} " : v for k , v in lora_state_dict .items () if k .startswith ("transformer." )
@@ -1473,7 +1473,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1473
1473
if accelerator .is_main_process :
1474
1474
if args .validation_prompt is not None and epoch % args .validation_epochs == 0 :
1475
1475
# create pipeline
1476
- pipeline = Lumina2Text2ImgPipeline .from_pretrained (
1476
+ pipeline = Lumina2Pipeline .from_pretrained (
1477
1477
args .pretrained_model_name_or_path ,
1478
1478
transformer = accelerator .unwrap_model (transformer ),
1479
1479
revision = args .revision ,
@@ -1503,14 +1503,14 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
1503
1503
transformer = transformer .to (weight_dtype )
1504
1504
transformer_lora_layers = get_peft_model_state_dict (transformer )
1505
1505
1506
- Lumina2Text2ImgPipeline .save_lora_weights (
1506
+ Lumina2Pipeline .save_lora_weights (
1507
1507
save_directory = args .output_dir ,
1508
1508
transformer_lora_layers = transformer_lora_layers ,
1509
1509
)
1510
1510
1511
1511
# Final inference
1512
1512
# Load previous pipeline
1513
- pipeline = Lumina2Text2ImgPipeline .from_pretrained (
1513
+ pipeline = Lumina2Pipeline .from_pretrained (
1514
1514
args .pretrained_model_name_or_path ,
1515
1515
revision = args .revision ,
1516
1516
variant = args .variant ,
0 commit comments