
Commit 9da8550

Fix batch
1 parent c6de237 · commit 9da8550

File tree: 1 file changed (+79 −78)


app/sana_pipeline.py (+79 −78)
@@ -231,86 +231,87 @@ def forward(
             ),
             torch.tensor([[1.0]], device=self.device).repeat(num_images_per_prompt, 1),
         )
+
         for _ in range(num_images_per_prompt):
-            with torch.no_grad():
-                prompts.append(
-                    prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip()
-                )
+            prompts.append(
+                prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip()
+            )
 
-        # prepare text feature
-        if not self.config.text_encoder.chi_prompt:
-            max_length_all = self.config.text_encoder.model_max_length
-            prompts_all = prompts
-        else:
-            chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
-            prompts_all = [chi_prompt + prompt for prompt in prompts]
-            num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
-            max_length_all = (
-                num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
-            )  # magic number 2: [bos], [_]
-
-        caption_token = self.tokenizer(
-            prompts_all,
-            max_length=max_length_all,
-            padding="max_length",
-            truncation=True,
-            return_tensors="pt",
-        ).to(device=self.device)
-        select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
-        caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
-            :, :, select_index
-        ].to(self.weight_dtype)
-        emb_masks = caption_token.attention_mask[:, select_index]
-        null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)
-
-        n = len(prompts)
-        if latents is None:
-            z = torch.randn(
-                n,
-                self.config.vae.vae_latent_dim,
-                self.latent_size_h,
-                self.latent_size_w,
-                generator=generator,
-                device=self.device,
-                dtype=self.weight_dtype,
-            )
-        else:
-            z = latents.to(self.weight_dtype).to(self.device)
-        model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
-        if self.vis_sampler == "flow_euler":
-            flow_solver = FlowEuler(
-                self.model,
-                condition=caption_embs,
-                uncondition=null_y,
-                cfg_scale=guidance_scale,
-                model_kwargs=model_kwargs,
-            )
-            sample = flow_solver.sample(
-                z,
-                steps=num_inference_steps,
-            )
-        elif self.vis_sampler == "flow_dpm-solver":
-            scheduler = DPMS(
-                self.model,
-                condition=caption_embs,
-                uncondition=null_y,
-                guidance_type=self.guidance_type,
-                cfg_scale=guidance_scale,
-                pag_scale=pag_guidance_scale,
-                pag_applied_layers=self.config.model.pag_applied_layers,
-                model_type="flow",
-                model_kwargs=model_kwargs,
-                schedule="FLOW",
-            )
-            scheduler.register_progress_bar(self.progress_fn)
-            sample = scheduler.sample(
-                z,
-                steps=num_inference_steps,
-                order=2,
-                skip_type="time_uniform_flow",
-                method="multistep",
-                flow_shift=self.flow_shift,
-            )
+        with torch.no_grad():
+            # prepare text feature
+            if not self.config.text_encoder.chi_prompt:
+                max_length_all = self.config.text_encoder.model_max_length
+                prompts_all = prompts
+            else:
+                chi_prompt = "\n".join(self.config.text_encoder.chi_prompt)
+                prompts_all = [chi_prompt + prompt for prompt in prompts]
+                num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
+                max_length_all = (
+                    num_chi_prompt_tokens + self.config.text_encoder.model_max_length - 2
+                )  # magic number 2: [bos], [_]
+
+            caption_token = self.tokenizer(
+                prompts_all,
+                max_length=max_length_all,
+                padding="max_length",
+                truncation=True,
+                return_tensors="pt",
+            ).to(device=self.device)
+            select_index = [0] + list(range(-self.config.text_encoder.model_max_length + 1, 0))
+            caption_embs = self.text_encoder(caption_token.input_ids, caption_token.attention_mask)[0][:, None][
+                :, :, select_index
+            ].to(self.weight_dtype)
+            emb_masks = caption_token.attention_mask[:, select_index]
+            null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None].to(self.weight_dtype)
+
+            n = len(prompts)
+            if latents is None:
+                z = torch.randn(
+                    n,
+                    self.config.vae.vae_latent_dim,
+                    self.latent_size_h,
+                    self.latent_size_w,
+                    generator=generator,
+                    device=self.device,
+                    dtype=self.weight_dtype,
+                )
+            else:
+                z = latents.to(self.weight_dtype).to(self.device)
+            model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks)
+            if self.vis_sampler == "flow_euler":
+                flow_solver = FlowEuler(
+                    self.model,
+                    condition=caption_embs,
+                    uncondition=null_y,
+                    cfg_scale=guidance_scale,
+                    model_kwargs=model_kwargs,
+                )
+                sample = flow_solver.sample(
+                    z,
+                    steps=num_inference_steps,
+                )
+            elif self.vis_sampler == "flow_dpm-solver":
+                scheduler = DPMS(
+                    self.model,
+                    condition=caption_embs,
+                    uncondition=null_y,
+                    guidance_type=self.guidance_type,
+                    cfg_scale=guidance_scale,
+                    pag_scale=pag_guidance_scale,
+                    pag_applied_layers=self.config.model.pag_applied_layers,
+                    model_type="flow",
+                    model_kwargs=model_kwargs,
+                    schedule="FLOW",
+                )
+                scheduler.register_progress_bar(self.progress_fn)
+                sample = scheduler.sample(
+                    z,
+                    steps=num_inference_steps,
+                    order=2,
+                    skip_type="time_uniform_flow",
+                    method="multistep",
+                    flow_shift=self.flow_shift,
+                )
 
         sample = sample.to(self.weight_dtype)
         with torch.no_grad():
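What the change does, as read from the diff: previously, torch.no_grad() wrapped only the per-image prepare_prompt_ar(...)[0].strip() call inside the loop, which produces a string, so disabling gradient tracking there had essentially no effect; meanwhile the expensive text-encoder forward pass and the sampler ran with autograd enabled. The commit moves the no_grad scope to cover the batched text-feature preparation and sampling instead. A minimal sketch of the pattern, with a toy encoder standing in for the real one (the names below are illustrative, not the SANA API):

import torch

# Hypothetical stand-ins for the tokenizer output and the text encoder.
text_encoder = torch.nn.Linear(8, 8)
token_batch = torch.randn(4, 8)  # one row per image in the batch

# Before the fix: the encoder call ran outside no_grad, so an autograd
# graph was recorded for every prompt in the batch.
caption_embs = text_encoder(token_batch)
print(caption_embs.requires_grad)  # True: activations retained for backward

# After the fix: one no_grad scope around the whole batched encode
# (and, in the pipeline, the sampling loop as well).
with torch.no_grad():
    caption_embs = text_encoder(token_batch)
print(caption_embs.requires_grad)  # False: no graph, less memory at inference

The practical effect is lower memory use and slightly less overhead per batch; the generated samples themselves are unchanged, since no backward pass is ever taken at inference time.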
