Add LoRA Hotswap Compatibility with torch.compile() #11409


Open · Tir25 wants to merge 2 commits into main

Conversation

@Tir25 commented on Apr 24, 2025

Description

This PR adds support for LoRA state preservation when using torch.compile(). Currently, when using LoRA with compiled models, the adapter state is lost after compilation, making it impossible to switch LoRA weights at runtime. This PR fixes the issue by introducing a CompileSafeLORAMixin class that properly handles LoRA state during and after compilation.

Changes

  1. Added a CompileSafeLORAMixin class in src/diffusers/loaders/lora_base.py (sketched after this list):

    • Preserves LoRA state before compilation
    • Restores that state after compilation
    • Handles adapter switching in compiled models
  2. Updated FluxPipeline to support LoRA compilation:

    • Added an enable_lora_hotswap method with compilation support
    • Added a _compile_model method for proper LoRA handling
  3. Added comprehensive tests in tests/pipelines/flux/test_flux_lora.py

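To make the intended mechanism concrete, here is a minimal sketch of the state-preservation idea behind CompileSafeLORAMixin. The method names, the name-based "lora" parameter filter, and the compile flow below are illustrative assumptions, not the exact implementation in this PR.

import torch

class CompileSafeLORAMixin:
    # Sketch only: snapshot LoRA parameters before torch.compile() and
    # copy them back afterwards so adapters stay swappable at runtime.

    def _snapshot_lora_state(self, module: torch.nn.Module) -> dict:
        # Clone every LoRA parameter (identified here by name) so the
        # tensors survive compilation.
        return {
            name: param.detach().clone()
            for name, param in module.named_parameters()
            if "lora" in name
        }

    def _restore_lora_state(self, module: torch.nn.Module, state: dict) -> None:
        # Copy the saved tensors back in place so any compiled graph
        # keeps pointing at the same parameter storage.
        with torch.no_grad():
            for name, param in module.named_parameters():
                if name in state:
                    param.copy_(state[name])

    def _compile_model(self, module: torch.nn.Module) -> torch.nn.Module:
        state = self._snapshot_lora_state(module)
        compiled = torch.compile(module)
        # torch.compile() wraps the module and shares its parameter
        # storage, so restoring onto the original module also updates
        # the compiled wrapper.
        self._restore_lora_state(module, state)
        return compiled
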
Example Usage

from diffusers import FluxPipeline
import torch

# Create pipeline
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# Enable LoRA hotswap
pipe.enable_lora_hotswap(target_rank=256)

# Load initial LoRA weights
pipe.load_lora_weights("path/to/lora1", adapter_name="lora1")
pipe.set_adapters(["lora1"], adapter_weights=[1.0])

# Compile model
pipe.transformer = torch.compile(pipe.transformer)

# Switch LoRA weights at runtime (now works!)
pipe.load_lora_weights("path/to/lora2", adapter_name="lora2")
pipe.set_adapters(["lora2"], adapter_weights=[0.5])
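
Note that the ordering matters: the first adapter is loaded before torch.compile() so the LoRA layers are part of the compiled graph, and later load_lora_weights/set_adapters calls update the adapter weights in place rather than changing the module structure, which would otherwise trigger recompilation.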

Testing

The implementation includes three test cases:

  1. test_enable_lora_hotswap: Verifies proper LoRA setup
  2. test_compile_with_lora: Ensures LoRA state is preserved during compilation
  3. test_lora_switch_after_compile: Validates LoRA weight switching after compilation (sketched below)

All three tests pass.
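
For illustration, here is a minimal sketch of the third test, assuming a pytest-style pipe fixture and two local adapter paths (pipe, lora1_path, and lora2_path are placeholders):

import torch

def test_lora_switch_after_compile(pipe, lora1_path, lora2_path):
    pipe.enable_lora_hotswap(target_rank=256)
    pipe.load_lora_weights(lora1_path, adapter_name="lora1")
    pipe.transformer = torch.compile(pipe.transformer)

    # Swapping adapters after compilation should succeed without
    # raising or falling back to eager execution.
    pipe.load_lora_weights(lora2_path, adapter_name="lora2")
    pipe.set_adapters(["lora2"], adapter_weights=[0.5])
    assert "lora2" in pipe.get_active_adapters()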

Related Issue

Fixes #ISSUE_NUMBER (Please reference the issue number)

Additional Notes

  • This implementation maintains backward compatibility
  • No breaking changes to existing functionality
  • Minimal overhead added to compilation process
  • Follows project's coding style and conventions

Checklist

  • Code follows the project's coding guidelines
  • Added tests that prove the fix is effective
  • Updated documentation to reflect changes
  • All tests pass locally
