Can't torch.compile transformer models that load GGUF via from_single_file #10795

@AstraliteHeart

Describe the bug

A transformer model loaded from a GGUF file via from_single_file cannot be compiled with torch.compile. The first forward call raises torch._dynamo.exc.Unsupported: call_method SetVariable() __setitem__ (UserDefinedObjectVariable(GGUFParameter), ConstantVariable(NoneType: None)) {}

The 'normal' model loaded from HF via from_pretrained for the same pipeline compiles with torch.compile just fine.

Reproduction

If I load the whole pipeline from the HF model, i.e.

import torch
from diffusers import AuraFlowPipeline

torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

pipeline = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow-v0.3",
    torch_dtype=torch.bfloat16,
).to("cuda")

pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
pipeline("A cute pony", width=512, height=512, num_inference_steps=5)

I can torch.compile it (and observe better performance).

If I instead load the transformer part from a GGUF file:

import torch
from diffusers import (
    AuraFlowPipeline,
    GGUFQuantizationConfig,
    AuraFlowTransformer2DModel,
)

torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

transformer = AuraFlowTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipeline = AuraFlowPipeline.from_pretrained(
    "fal/AuraFlow-v0.3",
    torch_dtype=torch.bfloat16,
    transformer=transformer,
).to("cuda")

pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
pipeline("A cute pony", width=512, height=512, num_inference_steps=5)

the pipeline call raises an exception (see the log below).

I am still learning how torch.compile/Dynamo works, so it's unclear to me whether this is just some basic confusion caused by GGUFParameter wrapping torch.nn.Parameter, whether diffusers needs to do anything special, or whether this is something torch should handle better. I've only tested AuraFlow, but this should be the same for any code using GGUF loading. Happy to continue debugging or raise an issue with the torch devs, but I would appreciate it if someone more knowledgeable could have a look at this.
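For discussion, here is one direction that might sidestep the problem on the diffusers side. The user-code frames in the log point at next(self.parameters()) being traced inside forward. The sketch below uses a hypothetical TinyBlock module (not diffusers code) to contrast that pattern with reading the dtype from an input tensor instead. It assumes the input dtype matches the dtype the parameters would have reported, and I have not verified that it resolves the GGUF case.

```python
import torch
from torch import nn


class TinyBlock(nn.Module):
    """Hypothetical stand-in for a transformer block; not diffusers code."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.time_step_embed = nn.Linear(1, dim)
        self.proj = nn.Linear(dim, dim)

    def forward_original(self, hidden_states, timestep):
        # Pattern from the traceback: iterate parameters() inside forward.
        dtype = next(self.parameters()).dtype
        temb = self.time_step_embed(timestep).to(dtype=dtype)
        return self.proj(hidden_states) + temb

    def forward(self, hidden_states, timestep):
        # Assumed alternative: take the dtype from an input tensor, so the
        # traced code never walks the parameter generator.
        temb = self.time_step_embed(timestep).to(dtype=hidden_states.dtype)
        return self.proj(hidden_states) + temb


torch.manual_seed(0)
block = TinyBlock()
hidden_states = torch.randn(2, 8)
timestep = torch.rand(2, 1)

# In eager mode both variants agree when input and parameter dtypes match.
out_original = block.forward_original(hidden_states, timestep)
out_alternative = block(hidden_states, timestep)
```

The two variants are only interchangeable when the input dtype equals the dtype of the first parameter, which may not hold for quantized weights, so treat this as a starting point rather than a fix.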

Logs

Traceback (most recent call last):
  File "test_diff_gguf.py", line 27, in <module>
    pipeline("A cute pony", width=512, height=512, num_inference_steps=5)
  File "/env/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/env/diffusers/pipelines/aura_flow/pipeline_aura_flow.py", line 555, in __call__
    noise_pred = self.transformer(
  File "/env/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/env/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/env/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
  File "/env/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/env/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/env/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
  File "/env/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
  File "/env/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/env/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/env/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
  File "/env/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/env/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
  File "/env/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
  File "/env/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
  File "/env/torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
  File "/env/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
  File "/env/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/env/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/env/torch/_dynamo/symbolic_convert.py", line 1658, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/env/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/env/torch/_dynamo/variables/functions.py", line 378, in call_function
    return super().call_function(tx, args, kwargs)
  File "/env/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
  File "/env/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
  File "/env/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
  File "/env/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/env/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/env/torch/_dynamo/symbolic_convert.py", line 1748, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/env/torch/_dynamo/variables/functions.py", line 378, in call_function
    return super().call_function(tx, args, kwargs)
  File "/env/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
  File "/env/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
  File "/env/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
  File "/env/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/env/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/env/torch/_dynamo/symbolic_convert.py", line 1748, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/env/torch/_dynamo/variables/functions.py", line 378, in call_function
    return super().call_function(tx, args, kwargs)
  File "/env/torch/_dynamo/variables/functions.py", line 317, in call_function
    return super().call_function(tx, args, kwargs)
  File "/env/torch/_dynamo/variables/functions.py", line 118, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/env/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
    tracer.run()
  File "/env/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
  File "/env/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/env/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
    return inner_fn(self, inst)
  File "/env/torch/_dynamo/symbolic_convert.py", line 1658, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/env/torch/_dynamo/symbolic_convert.py", line 897, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/env/torch/_dynamo/variables/misc.py", line 1022, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
  File "/env/torch/_dynamo/variables/dicts.py", line 566, in call_method
    return super().call_method(tx, name, args, kwargs)
  File "/env/torch/_dynamo/variables/dicts.py", line 396, in call_method
    return super().call_method(tx, name, args, kwargs)
  File "/env/torch/_dynamo/variables/base.py", line 414, in call_method
    unimplemented(f"call_method {self} {name} {args} {kwargs}")
  File "/env/torch/_dynamo/exc.py", line 317, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: call_method SetVariable() __setitem__ (UserDefinedObjectVariable(GGUFParameter), ConstantVariable(NoneType: None)) {}

from user code:
   File "/env/diffusers/models/transformers/auraflow_transformer_2d.py", line 458, in forward
    temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype)
  File "/env/torch/nn/modules/module.py", line 2630, in parameters
    for _name, param in self.named_parameters(recurse=recurse):
  File "/env/torch/nn/modules/module.py", line 2657, in named_parameters
    gen = self._named_members(
  File "/env/torch/nn/modules/module.py", line 2604, in _named_members
    memo.add(v)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
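
To make the user-code part of the log easier to follow, here is a simplified pure-Python sketch of the deduplication that parameters() performs through a memo set. The Module and Param classes are stand-ins written for illustration, not the torch source. The line to look at is memo.add(value): judging by the error message, Dynamo appears to model that call as a __setitem__ on the set with a None value and rejects the GGUFParameter key, but that reading is my assumption.

```python
class Param:
    """Stand-in for a parameter; hashable by identity like a plain object."""

    def __init__(self, name):
        self.name = name


class Module:
    """Stand-in module that mimics the memo-based walk, heavily simplified."""

    def __init__(self):
        self._parameters = {}
        self._modules = {}

    def named_modules(self, prefix=""):
        yield prefix, self
        for name, child in self._modules.items():
            child_prefix = f"{prefix}.{name}" if prefix else name
            yield from child.named_modules(child_prefix)

    def named_parameters(self):
        memo = set()
        for module_prefix, module in self.named_modules():
            for key, value in module._parameters.items():
                if value is None or value in memo:
                    continue  # skip empty slots and tied parameters already seen
                memo.add(value)  # the call the traceback ends on
                yield (f"{module_prefix}.{key}" if module_prefix else key), value

    def parameters(self):
        for _name, param in self.named_parameters():
            yield param


# A root module with one child that reuses (ties) the root's parameter.
shared = Param("w")
root = Module()
root._parameters["w"] = shared
child = Module()
child._parameters["tied"] = shared
child._parameters["b"] = Param("b")
root._modules["child"] = child

names = [name for name, _ in root.named_parameters()]
first = next(root.parameters())
```

In eager Python this only needs the parameter object to be hashable. Under torch.compile the same set operation has to be understood by the tracer, which is where the failure in the log occurs.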

System Info

  • 🤗 Diffusers version: 0.33.0.dev0
  • Platform: Linux-6.8.0-52-generic-x86_64-with-glibc2.39
  • Running on Google Colab?: No
  • Python version: 3.10.16
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.28.1
  • Transformers version: 4.48.2
  • Accelerate version: 1.3.0
  • PEFT version: not installed
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.2
  • xFormers version: not installed
  • Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help?

@DN6 @hlky @stevhliu

Labels

bug (Something isn't working)