This issue is both a bug report (silent failure in the official conversion script) and a feature request (LoRA-aware conversion). I have a working implementation available and would like to know whether a PR would be welcome before preparing it.
Environment
- OS: Ubuntu 24.04
- Python: 3.10 (PyTorch worker) and 3.11 (main runtime)
- openpi commit: [c23745b]
- The bug is silent (no exception, no warning), so there is no traceback.
Summary
The official conversion script examples/convert_jax_model_to_pytorch.py does not handle LoRA-finetuned JAX checkpoints. All LoRA adapter weights (lora_a, lora_b) are silently dropped during the subsequent load_state_dict(..., strict=False) call. The resulting PyTorch model loads without errors but produces outputs that diverge from the JAX original, because the finetuning delta is lost.
This bug likely explains the symptoms reported in several existing open issues for users who finetuned with the official LoRA variants:
In each of these cases, users hit the symptom (model behaves differently after conversion) but the root cause—silent LoRA adapter drop—was not identified.
Reproduction
The bug appears when:
- A pi0 or pi05 model is finetuned in JAX with the official LoRA variants:
  - paligemma_variant="gemma_2b_lora"
  - action_expert_variant="gemma_300m_lora"
- The resulting checkpoint is converted with the official script:
  ```shell
  python examples/convert_jax_model_to_pytorch.py \
      --checkpoint_dir <path_to_lora_checkpoint> \
      --output_path <output_path> \
      --config_name <lora_config_name>
  ```
- The script completes successfully with no warning.
- Inspect the load result:
  ```python
  # model and state_dict as built inside the conversion script
  load_result = model.load_state_dict(state_dict, strict=False)
  print("unexpected:", len(load_result.unexpected_keys))
  print("missing:", load_result.missing_keys)
  ```
  Observed on our reproduction:
  - unexpected_keys: 20 entries (all lora_a / lora_b keys)
  - missing_keys: 2 entries (only lm_head.weight keys, which are tied/unused)
- Run inference and compare against the JAX original on the same observation. The outputs diverge.
Numerical impact
Same-input comparison (prompt: "pick up the blue cube", saved observation), JAX CPU vs PyTorch eager:
| Stage | max_abs | mean_abs |
|---|---|---|
| Official converter (LoRA silently dropped) | 0.0309 | 0.0025 |
| LoRA-aware merge applied | 0.0252 | 0.0021 |
| LoRA merge + attn_vec_einsum fix | ~0.017 | ~0.0011 |
| LoRA merge + attn_vec_einsum fix + float32 storage | 0.0017 | 0.0001 |
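The issue does not define the two columns. Assuming max_abs and mean_abs are the maximum and the mean of the elementwise absolute difference between the JAX and PyTorch outputs for the same observation, they can be computed with a small helper like the one below. The helper name divergence is hypothetical, and the arrays are made up; they are not the data behind the table.

```python
import numpy as np


def divergence(reference, candidate):
    """Return (max_abs, mean_abs) of the elementwise absolute difference."""
    a = np.asarray(reference, dtype=np.float64)
    b = np.asarray(candidate, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    diff = np.abs(a - b)
    return float(diff.max()), float(diff.mean())


# Made-up action chunks: identical except for one entry.
jax_actions = np.zeros((10, 8), dtype=np.float32)
torch_actions = jax_actions.copy()
torch_actions[0, 0] = 0.02
max_abs, mean_abs = divergence(jax_actions, torch_actions)
print(f"max_abs={max_abs:.4f} mean_abs={mean_abs:.6f}")
```

Computing both in float64 keeps the metric itself from adding rounding noise to the comparison.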
Root cause
The base examples/convert_jax_model_to_pytorch.py slice flow has no step that merges LoRA adapter weights into the base weights before producing the PyTorch state_dict. The lora_a / lora_b keys then remain in the state_dict and are silently discarded by the strict=False load, with no warning to the user.
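Until the converter merges adapters, a cheap guard is to scan the state_dict keys before the strict=False load and fail loudly. The sketch below assumes that adapter keys contain the substring lora_a or lora_b, as in the unexpected_keys observed above. find_lora_keys and assert_no_lora_keys are hypothetical helper names.

```python
from collections.abc import Iterable

# Assumption: LoRA adapter keys contain one of these substrings.
LORA_MARKERS = ("lora_a", "lora_b")


def find_lora_keys(keys: Iterable[str]) -> list[str]:
    """Return the sorted keys that look like LoRA adapter weights."""
    return sorted(k for k in keys if any(m in k for m in LORA_MARKERS))


def assert_no_lora_keys(keys: Iterable[str]) -> None:
    """Raise instead of letting load_state_dict(strict=False) drop adapters silently."""
    lora_keys = find_lora_keys(keys)
    if lora_keys:
        raise RuntimeError(
            f"{len(lora_keys)} LoRA adapter key(s) would be silently dropped, "
            f"e.g. {lora_keys[0]!r}; merge them into the base weights first."
        )
```

Calling assert_no_lora_keys(state_dict.keys()) right before load_state_dict turns the silent drop into an explicit error.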
Implementation notes
Implementing this correctly requires handling two non-obvious quirks in openpi's runtime LoRA path. A naive standard-LoRA merge formula will produce incorrect results.
- attn_vec_einsum LoRA: the second einsum of the runtime LoRA path in openpi.models.gemma.Attention also sums over the head dimension (N) of lora_b, so the equivalent merged weight uses sum_N(lora_b) rather than a per-head outer product.
- MLP FeedForward._dot(): adds the LoRA delta to the base output without applying the alpha/rank scaling factor that standard LoRA implementations use. Merged MLP weights must mirror this behavior (i.e., merge without the scaling).
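Both quirks can be checked numerically. The NumPy sketch below works on a single layer. The shapes and einsum equations in the docstrings are assumed to mirror openpi's runtime LoRA path as described above; they are not copied from openpi. merge_attn_vec_lora and merge_mlp_lora are hypothetical names. The sketch compares the merged weights against the two-step runtime computation and against a naive per-head merge.

```python
import numpy as np


def merge_attn_vec_lora(w, lora_a, lora_b, scale=1.0):
    """Merge LoRA into one layer's attn_vec_einsum weight of shape (N, H, D).

    Assumed runtime path:
        lora = einsum("BTNH,NHL->BTL", x, lora_a)            # lora_a: (N, H, r)
        out += scale * einsum("BTL,NLD->BTD", lora, lora_b)  # lora_b: (N, r, D)
    The second einsum also sums over N, so lora_b enters as sum_N(lora_b).
    scale is the attention LoRA scaling factor, if any (check the variant config).
    """
    b_summed = lora_b.sum(axis=0)  # (r, D)
    return w + scale * np.einsum("nhr,rd->nhd", lora_a, b_summed)


def merge_mlp_lora(w, lora_a, lora_b):
    """Merge LoRA into an MLP weight as w + a @ b, with no alpha/rank scaling.

    Mirrors base + (x @ a) @ b. np.matmul broadcasts over a leading axis, so this
    also covers a stacked gating weight of shape (2, F, Hd).
    """
    return w + np.matmul(lora_a, lora_b)


# Numerical check on random data (arbitrary shapes, single layer).
rng = np.random.default_rng(0)
B, T, N, H, D, r = 2, 3, 4, 5, 6, 2
x = rng.standard_normal((B, T, N, H))
w = rng.standard_normal((N, H, D))
a = rng.standard_normal((N, H, r))
b = rng.standard_normal((N, r, D))

runtime = np.einsum("BTNH,NHD->BTD", x, w) + np.einsum(
    "BTL,NLD->BTD", np.einsum("BTNH,NHL->BTL", x, a), b
)
merged = np.einsum("BTNH,NHD->BTD", x, merge_attn_vec_lora(w, a, b))
# Naive per-head merge, as a standard LoRA merge would do it.
naive = np.einsum("BTNH,NHD->BTD", x, w + np.einsum("nhr,nrd->nhd", a, b))
print("sum_N merge matches runtime:", np.allclose(merged, runtime))
print("per-head merge matches runtime:", np.allclose(naive, runtime))

F, Hd = 6, 8
xm = rng.standard_normal((3, F))
wm, am, bm = (rng.standard_normal(s) for s in [(F, Hd), (F, r), (r, Hd)])
mlp_ok = np.allclose(xm @ merge_mlp_lora(wm, am, bm), xm @ wm + (xm @ am) @ bm)
print("MLP merge matches _dot:", mlp_ok)
```

With the layer-stacked arrays in the checkpoint, the same merges would be applied per layer index.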
Additional finding
Storing the merged weights in bfloat16 causes additional numerical drift (post-merge max_abs ~0.017 in bf16 vs ~0.0017 in float32 on the same observation; see the table above). The merged checkpoint should be stored in float32 even when downstream inference runs in bfloat16. This is consistent with the precision-related observations in #810.
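One plausible mechanism for the bf16 drift is that a small LoRA delta folded into a base weight can round away entirely. Near 1.0, the spacing between adjacent bfloat16 values is 2**-7 (about 0.0078). NumPy has no native bfloat16 dtype, so this sketch emulates the float32 to bfloat16 cast by rounding the float32 bit pattern (round-to-nearest-even, finite inputs only). The helper round_to_bf16 is hypothetical, and the values are illustrative, not taken from the checkpoint.

```python
import numpy as np


def round_to_bf16(x):
    """Emulate float32 -> bfloat16 -> float32 (round-to-nearest-even, finite inputs only)."""
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    # Round on the 16 low bits that bfloat16 drops: add 0x7FFF plus the tie-break bit.
    tie_bit = (bits >> np.uint32(16)) & np.uint32(1)
    rounded = (bits + np.uint32(0x7FFF) + tie_bit) & np.uint32(0xFFFF0000)
    return rounded.view(np.float32)


base = np.array([1.0, 0.5, -2.0], dtype=np.float32)
delta = np.full(3, 1e-3, dtype=np.float32)  # small, made-up LoRA-style update
merged_f32 = base + delta
merged_bf16 = round_to_bf16(merged_f32)
print("delta kept in float32:", merged_f32 - base)
print("delta kept in bf16:   ", merged_bf16 - base)
```

In this toy case, the bf16 copy of every merged weight equals the base weight, so the finetuning delta is gone. The float32 copy keeps it.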
Proposed direction
I have a working implementation that:
- Merges LoRA adapters into base weights before applying the existing slice_paligemma_state_dict / slice_gemma_state_dict flow.
- Handles the two runtime quirks above.
- Defaults to float32 precision for the saved merged checkpoint.
- Validates the load (no unexpected_keys; missing_keys limited to the known-tied lm_head).
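The load validation can be a few lines on top of the NamedTuple that torch.nn.Module.load_state_dict returns, which has missing_keys and unexpected_keys fields. The sketch below assumes that the tied weights are identified by the lm_head.weight suffix. validate_load_result is a hypothetical name.

```python
from collections.abc import Sequence

# Assumption: the only keys allowed to be missing are the tied lm_head weights.
ALLOWED_MISSING_SUFFIXES = ("lm_head.weight",)


def validate_load_result(missing_keys: Sequence[str], unexpected_keys: Sequence[str]) -> None:
    """Raise if a strict=False load ignored keys or left unexpected gaps."""
    if unexpected_keys:
        raise RuntimeError(f"unexpected keys (would be silently ignored): {sorted(unexpected_keys)}")
    bad_missing = [k for k in missing_keys if not k.endswith(ALLOWED_MISSING_SUFFIXES)]
    if bad_missing:
        raise RuntimeError(f"missing keys beyond the tied lm_head: {sorted(bad_missing)}")
```

Usage would be result = model.load_state_dict(state_dict, strict=False), followed by validate_load_result(result.missing_keys, result.unexpected_keys).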
Repository (WIP, includes domain-specific scaffolding to be cleaned before PR): https://github.com/Ret1ehS/OpenPi-Auboi5/tree/main/tools
Two options for the PR shape:
- (a) Add LoRA support to examples/convert_jax_model_to_pytorch.py as an additive flag (auto-detected from the config, or a --merge-lora option). Minimally invasive to the existing script.
- (b) Add a separate examples/convert_jax_lora_to_pytorch.py script. Cleaner separation but slightly more code duplication.
I'm happy to prepare a PR in either form, including a numerical correctness test against JAX inference. Would such a PR be welcome, and if so, which form do maintainers prefer?