@@ -231,86 +231,87 @@ def forward(
             ),
             torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
         )
+
         for _ in range(num_images_per_prompt):
-            with torch.no_grad():
-                prompts.append(
-                    prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip()
-                )
+            prompts.append(
+                prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip()
+            )
 
-        # prepare text feature
-        if not self.config.text_encoder.chi_prompt:
-            max_length_all = self.config.text_encoder.model_max_length
-            prompts_all = prompts
-        else:
-            chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
-            prompts_all = [chi_prompt + prompt for prompt in prompts]
-            num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
-            max_length_all = (
-                num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
-            )  # magic number 2: [bos], [_]
-
-        caption_token = self.tokenizer(
-            prompts_all,
-            max_length=max_length_all,
-            padding="max_length",
-            truncation=True,
-            return_tensors="pt",
-        ).to(device=self.device)
-        select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
-        caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
-            :, :, select_index
-        ].to(self.weight_dtype)
-        emb_masks = caption_token.attention_mask[:, select_index]
-        null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)
-
-        n = len(prompts)
-        if latents is None:
-            z = torch.randn(
-                n,
-                self.config.vae.vae_latent_dim,
-                self.latent_size_h,
-                self.latent_size_w,
-                generator=generator,
-                device=self.device,
-                dtype=self.weight_dtype,
-            )
-        else:
-            z = latents.to(self.weight_dtype).to(self.device)
-        model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
-        if self.vis_sampler == "flow_euler":
-            flow_solver = FlowEuler(
-                self.model,
-                condition=caption_embs,
-                uncondition=null_y,
-                cfg_scale=guidance_scale,
-                model_kwargs=model_kwargs,
-            )
-            sample = flow_solver.sample(
-                z,
-                steps=num_inference_steps,
-            )
-        elif self.vis_sampler == "flow_dpm-solver":
-            scheduler = DPMS(
-                self.model,
-                condition=caption_embs,
-                uncondition=null_y,
-                guidance_type=self.guidance_type,
-                cfg_scale=guidance_scale,
-                pag_scale=pag_guidance_scale,
-                pag_applied_layers=self.config.model.pag_applied_layers,
-                model_type="flow",
-                model_kwargs=model_kwargs,
-                schedule="FLOW",
-            )
-            scheduler.register_progress_bar(self.progress_fn)
-            sample = scheduler.sample(
-                z,
-                steps=num_inference_steps,
-                order=2,
-                skip_type="time_uniform_flow",
-                method="multistep",
-                flow_shift=self.flow_shift,
-            )
+        with torch.no_grad():
+            # prepare text feature
+            if not self.config.text_encoder.chi_prompt:
+                max_length_all = self.config.text_encoder.model_max_length
+                prompts_all = prompts
+            else:
+                chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
+                prompts_all = [chi_prompt + prompt for prompt in prompts]
+                num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
+                max_length_all = (
+                    num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
+                )  # magic number 2: [bos], [_]
+
+            caption_token = self.tokenizer(
+                prompts_all,
+                max_length=max_length_all,
+                padding="max_length",
+                truncation=True,
+                return_tensors="pt",
+            ).to(device=self.device)
+            select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
+            caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
+                :, :, select_index
+            ].to(self.weight_dtype)
+            emb_masks = caption_token.attention_mask[:, select_index]
+            null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)
+
+            n = len(prompts)
+            if latents is None:
+                z = torch.randn(
+                    n,
+                    self.config.vae.vae_latent_dim,
+                    self.latent_size_h,
+                    self.latent_size_w,
+                    generator=generator,
+                    device=self.device,
+                    dtype=self.weight_dtype,
+                )
+            else:
+                z = latents.to(self.weight_dtype).to(self.device)
+            model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
+            if self.vis_sampler == "flow_euler":
+                flow_solver = FlowEuler(
+                    self.model,
+                    condition=caption_embs,
+                    uncondition=null_y,
+                    cfg_scale=guidance_scale,
+                    model_kwargs=model_kwargs,
+                )
+                sample = flow_solver.sample(
+                    z,
+                    steps=num_inference_steps,
+                )
+            elif self.vis_sampler == "flow_dpm-solver":
+                scheduler = DPMS(
+                    self.model,
+                    condition=caption_embs,
+                    uncondition=null_y,
+                    guidance_type=self.guidance_type,
+                    cfg_scale=guidance_scale,
+                    pag_scale=pag_guidance_scale,
+                    pag_applied_layers=self.config.model.pag_applied_layers,
+                    model_type="flow",
+                    model_kwargs=model_kwargs,
+                    schedule="FLOW",
+                )
+                scheduler.register_progress_bar(self.progress_fn)
+                sample = scheduler.sample(
+                    z,
+                    steps=num_inference_steps,
+                    order=2,
+                    skip_type="time_uniform_flow",
+                    method="multistep",
+                    flow_shift=self.flow_shift,
+                )
 
         sample = sample.to(self.weight_dtype)
         with torch.no_grad():
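
For context: this change lifts the torch.no_grad() guard off the prompt-string loop, which builds no tensors and so gains nothing from it, and wraps it around the text encoding and sampling path instead, where it actually matters by skipping autograd bookkeeping during inference. A minimal sketch of the pattern follows; the stand-in modules are hypothetical, not the pipeline's real text_encoder/DPMS objects:

import torch
from torch import nn

# Hypothetical stand-ins for the pipeline's text encoder and denoiser.
text_encoder = nn.Linear(16, 32)
denoiser = nn.Linear(32, 32)

# Pure string preparation: no tensors are created, so no guard is needed.
prompts = ["a cat", "a dog"]

with torch.no_grad():  # tensor work below is inference-only
    tokens = torch.randn(len(prompts), 16)  # stand-in for tokenizer output
    caption_embs = text_encoder(tokens)     # no autograd graph is recorded
    sample = denoiser(caption_embs)         # same for the sampling loop

assert not sample.requires_grad  # activations carry no gradient history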