Describe the bug
A transformer model loaded via GGUF can't be torch.compile(d) and raises:
torch._dynamo.exc.Unsupported: call_method SetVariable() __setitem__ (UserDefinedObjectVariable(GGUFParameter), ConstantVariable(NoneType: None)) {}
A 'normal' model loaded from the HF Hub for the same pipeline can be torch.compile(d) just fine.
Reproduction
If I load the pipeline from the HF Hub model, i.e.
import torch
from diffusers import AuraFlowPipeline
torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
pipeline = AuraFlowPipeline.from_pretrained(
"fal/AuraFlow-v0.3",
torch_dtype=torch.bfloat16,
).to("cuda")
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
pipeline("A cute pony", width=512, height=512, num_inference_steps=5)
I can torch.compile it (and observe better performance).
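(For what it's worth, the speedup is easy to see with a rough timing loop; this is just a sketch reusing the pipeline above, and the first call pays the compilation cost.)
import time

# First call triggers compilation; later calls show the steady-state latency.
for i in range(3):
    start = time.perf_counter()
    pipeline("A cute pony", width=512, height=512, num_inference_steps=5)
    print(f"run {i}: {time.perf_counter() - start:.2f}s")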
If I try to load the transformer part from GGUF
import torch
from diffusers import (
AuraFlowPipeline,
GGUFQuantizationConfig,
AuraFlowTransformer2DModel,
)
torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
transformer = AuraFlowTransformer2DModel.from_single_file(
"https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf",
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16,
)
pipeline = AuraFlowPipeline.from_pretrained(
"fal/AuraFlow-v0.3",
torch_dtype=torch.bfloat16,
transformer=transformer,
).to("cuda")
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
pipeline("A cute pony", width=512, height=512, num_inference_steps=5)
it raises an exception (see the logs below).
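As a possible stopgap (an untested sketch, not a fix), dropping fullgraph=True should let Dynamo graph-break around the unsupported set operation instead of erroring out, at the cost of no longer compiling the transformer as a single graph:
# Allow graph breaks so Dynamo can fall back to eager around the parameters() call.
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=False)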
I am still learning how torch.compile/dynamo work, so it's unclear to me whether this is just some basic confusion caused by GGUFParameter wrapping torch.nn.Parameter, whether diffusers needs to do anything special, or whether this is something torch must handle better. I've only tested AuraFlow, but this should affect any code path that uses GGUF loading. Happy to continue debugging or to raise an issue with the torch devs, but I would appreciate it if someone more knowledgeable could take a look.
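For reference, this is how I've been inspecting the parameter types (a small diagnostic sketch; transformer is the GGUF-loaded model from the snippet above):
# List the parameter classes present on the GGUF-loaded transformer.
param_types = {type(p).__name__ for p in transformer.parameters()}
print(param_types)  # expecting something like {'Parameter', 'GGUFParameter'}

# Check whether the quantized parameters are still nn.Parameter/torch.Tensor subclasses,
# i.e. whether Dynamo could in principle treat them as tensors rather than opaque objects.
gguf_params = [p for p in transformer.parameters() if type(p).__name__ == "GGUFParameter"]
print(all(isinstance(p, torch.nn.Parameter) for p in gguf_params))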
Logs
Traceback (most recent call last):
File "test_diff_gguf.py", line 27, in <module>
pipeline("A cute pony", width=512, height=512, num_inference_steps=5)
File "/env/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/env/diffusers/pipelines/aura_flow/pipeline_aura_flow.py", line 555, in __call__
noise_pred = self.transformer(
File "/env/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/env/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/env/torch/_dynamo/eval_frame.py", line 574, in _fn
return fn(*args, **kwargs)
File "/env/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/env/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/env/torch/_dynamo/convert_frame.py", line 1380, in __call__
return self._torchdynamo_orig_callable(
File "/env/torch/_dynamo/convert_frame.py", line 547, in __call__
return _compile(
File "/env/torch/_dynamo/convert_frame.py", line 986, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/env/torch/_dynamo/convert_frame.py", line 715, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
File "/env/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
File "/env/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
out_code = transform_code_object(code, transform)
File "/env/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
transformations(instructions, code_options)
File "/env/torch/_dynamo/convert_frame.py", line 231, in _fn
return fn(*args, **kwargs)
File "/env/torch/_dynamo/convert_frame.py", line 662, in transform
tracer.run()
File "/env/torch/_dynamo/symbolic_convert.py", line 2868, in run
super().run()
File "/env/torch/_dynamo/symbolic_convert.py", line 1052, in run
while self.step():
File "/env/torch/_dynamo/symbolic_convert.py", line 962, in step
self.dispatch_table[inst.opcode](self, inst)
File "/env/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
return inner_fn(self, inst)
File "/env/torch/_dynamo/symbolic_convert.py", line 1658, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/env/torch/_dynamo/symbolic_convert.py", line 897, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
File "/env/torch/_dynamo/variables/functions.py", line 378, in call_function
return super().call_function(tx, args, kwargs)
File "/env/torch/_dynamo/variables/functions.py", line 317, in call_function
return super().call_function(tx, args, kwargs)
File "/env/torch/_dynamo/variables/functions.py", line 118, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File "/env/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/env/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/env/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
tracer.run()
File "/env/torch/_dynamo/symbolic_convert.py", line 1052, in run
while self.step():
File "/env/torch/_dynamo/symbolic_convert.py", line 962, in step
self.dispatch_table[inst.opcode](self, inst)
File "/env/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
return inner_fn(self, inst)
File "/env/torch/_dynamo/symbolic_convert.py", line 1748, in CALL_FUNCTION_KW
self.call_function(fn, args, kwargs)
File "/env/torch/_dynamo/symbolic_convert.py", line 897, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
File "/env/torch/_dynamo/variables/functions.py", line 378, in call_function
return super().call_function(tx, args, kwargs)
File "/env/torch/_dynamo/variables/functions.py", line 317, in call_function
return super().call_function(tx, args, kwargs)
File "/env/torch/_dynamo/variables/functions.py", line 118, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File "/env/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/env/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/env/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
tracer.run()
File "/env/torch/_dynamo/symbolic_convert.py", line 1052, in run
while self.step():
File "/env/torch/_dynamo/symbolic_convert.py", line 962, in step
self.dispatch_table[inst.opcode](self, inst)
File "/env/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
return inner_fn(self, inst)
File "/env/torch/_dynamo/symbolic_convert.py", line 1748, in CALL_FUNCTION_KW
self.call_function(fn, args, kwargs)
File "/env/torch/_dynamo/symbolic_convert.py", line 897, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
File "/env/torch/_dynamo/variables/functions.py", line 378, in call_function
return super().call_function(tx, args, kwargs)
File "/env/torch/_dynamo/variables/functions.py", line 317, in call_function
return super().call_function(tx, args, kwargs)
File "/env/torch/_dynamo/variables/functions.py", line 118, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File "/env/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/env/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/env/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
tracer.run()
File "/env/torch/_dynamo/symbolic_convert.py", line 1052, in run
while self.step():
File "/env/torch/_dynamo/symbolic_convert.py", line 962, in step
self.dispatch_table[inst.opcode](self, inst)
File "/env/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
return inner_fn(self, inst)
File "/env/torch/_dynamo/symbolic_convert.py", line 1658, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/env/torch/_dynamo/symbolic_convert.py", line 897, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
File "/env/torch/_dynamo/variables/misc.py", line 1022, in call_function
return self.obj.call_method(tx, self.name, args, kwargs)
File "/env/torch/_dynamo/variables/dicts.py", line 566, in call_method
return super().call_method(tx, name, args, kwargs)
File "/env/torch/_dynamo/variables/dicts.py", line 396, in call_method
return super().call_method(tx, name, args, kwargs)
File "/env/torch/_dynamo/variables/base.py", line 414, in call_method
unimplemented(f"call_method {self} {name} {args} {kwargs}")
File "/env/torch/_dynamo/exc.py", line 317, in unimplemented
raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: call_method SetVariable() __setitem__ (UserDefinedObjectVariable(GGUFParameter), ConstantVariable(NoneType: None)) {}
from user code:
File "/env/diffusers/models/transformers/auraflow_transformer_2d.py", line 458, in forward
temb = self.time_step_embed(timestep).to(dtype=next(self.parameters()).dtype)
File "/env/torch/nn/modules/module.py", line 2630, in parameters
for _name, param in self.named_parameters(recurse=recurse):
File "/env/torch/nn/modules/module.py", line 2657, in named_parameters
gen = self._named_members(
File "/env/torch/nn/modules/module.py", line 2604, in _named_members
memo.add(v)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
System Info
- 🤗 Diffusers version: 0.33.0.dev0
- Platform: Linux-6.8.0-52-generic-x86_64-with-glibc2.39
- Running on Google Colab?: No
- Python version: 3.10.16
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.28.1
- Transformers version: 4.48.2
- Accelerate version: 1.3.0
- PEFT version: not installed
- Bitsandbytes version: not installed
- Safetensors version: 0.5.2
- xFormers version: not installed
- Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
- Using GPU in script?: no
- Using distributed or parallel set-up in script?: no