diff --git a/onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py b/onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py index 5d25356c5..ee8a00c71 100644 --- a/onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py +++ b/onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py @@ -6,8 +6,6 @@ from diffusers import StableDiffusion3Pipeline from onediffx import compile_pipe, quantize_pipe -torch._logging.set_logs(fusion=True) - def parse_args(): parser = argparse.ArgumentParser( @@ -184,18 +182,22 @@ def main(): "negative_prompt": args.negative_prompt, } - sd3.warmup(gen_args) - - for prompt in prompt_list: - gen_args["prompt"] = prompt - print(f"Processing prompt of length {len(prompt)} characters.") - image, inference_time = sd3.generate(gen_args) - assert inference_time < 20, "Prompt inference took too long" - print( - f"Generated image saved to {args.saved_image} in {inference_time:.2f} seconds." - ) - cuda_mem_after_used = torch.cuda.max_memory_allocated() / (1024**3) - print(f"Max used CUDA memory : {cuda_mem_after_used:.3f} GiB") + with torch.profiler.profile() as prof: + with torch.profiler.record_function("warmup compile"): + sd3.warmup(gen_args) + + with torch.profiler.record_function("sd3 compiled"): + for prompt in prompt_list: + gen_args["prompt"] = prompt + print(f"Processing prompt of length {len(prompt)} characters.") + image, inference_time = sd3.generate(gen_args) + assert inference_time < 20, "Prompt inference took too long" + print( + f"Generated image saved to {args.saved_image} in {inference_time:.2f} seconds." + ) + cuda_mem_after_used = torch.cuda.max_memory_allocated() / (1024**3) + print(f"Max used CUDA memory : {cuda_mem_after_used:.3f} GiB") + prof.export_chrome_trace("sd3_compile_cache.json") if args.run_multiple_resolutions: gen_args["prompt"] = args.prompt diff --git a/onediff_diffusers_extensions/onediffx/compilers/diffusion_pipeline_compiler.py b/onediff_diffusers_extensions/onediffx/compilers/diffusion_pipeline_compiler.py index 192820c69..665e87dd3 100644 --- a/onediff_diffusers_extensions/onediffx/compilers/diffusion_pipeline_compiler.py +++ b/onediff_diffusers_extensions/onediffx/compilers/diffusion_pipeline_compiler.py @@ -38,6 +38,9 @@ def _recursive_setattr(obj, attr, value): "vae.decoder", "vae.encoder", ] +_PARTS = [ + "transformer", # for Transformer-based DiffusionPipeline such as DiTPipeline and PixArtAlphaPipeline +] def _filter_parts(ignores=()):