Skip to content

Commit

Permalink
perf compile
Browse files — browse the repository at this point in the history
  • Loading branch information
strint committed Jul 15, 2024
1 parent 957d678 commit a8b3502
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 14 deletions.
30 changes: 16 additions & 14 deletions onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from diffusers import StableDiffusion3Pipeline
from onediffx import compile_pipe, quantize_pipe

torch._logging.set_logs(fusion=True)


def parse_args():
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -184,18 +182,22 @@ def main():
"negative_prompt": args.negative_prompt,
}

sd3.warmup(gen_args)

for prompt in prompt_list:
gen_args["prompt"] = prompt
print(f"Processing prompt of length {len(prompt)} characters.")
image, inference_time = sd3.generate(gen_args)
assert inference_time < 20, "Prompt inference took too long"
print(
f"Generated image saved to {args.saved_image} in {inference_time:.2f} seconds."
)
cuda_mem_after_used = torch.cuda.max_memory_allocated() / (1024**3)
print(f"Max used CUDA memory : {cuda_mem_after_used:.3f} GiB")
with torch.profiler.profile() as prof:
with torch.profiler.record_function("warmup compile"):
sd3.warmup(gen_args)

with torch.profiler.record_function("sd3 compiled"):
for prompt in prompt_list:
gen_args["prompt"] = prompt
print(f"Processing prompt of length {len(prompt)} characters.")
image, inference_time = sd3.generate(gen_args)
assert inference_time < 20, "Prompt inference took too long"
print(
f"Generated image saved to {args.saved_image} in {inference_time:.2f} seconds."
)
cuda_mem_after_used = torch.cuda.max_memory_allocated() / (1024**3)
print(f"Max used CUDA memory : {cuda_mem_after_used:.3f} GiB")
prof.export_chrome_trace("sd3_compile_cache.json")

if args.run_multiple_resolutions:
gen_args["prompt"] = args.prompt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def _recursive_setattr(obj, attr, value):
"vae.decoder",
"vae.encoder",
]
_PARTS = [
"transformer", # for Transformer-based DiffusionPipeline such as DiTPipeline and PixArtAlphaPipeline
]


def _filter_parts(ignores=()):
Expand Down

0 comments on commit a8b3502

Please sign in to comment.