Skip to content

Port ESMC and ESMFold2 to Transformers#46419

Open
Rocketknight1 wants to merge 70 commits into
mainfrom
port-esmc-esmfold2
Open

Port ESMC and ESMFold2 to Transformers#46419
Rocketknight1 wants to merge 70 commits into
mainfrom
port-esmc-esmfold2

Conversation

@Rocketknight1

@Rocketknight1 Rocketknight1 commented Jun 4, 2026

Copy link
Copy Markdown
Member

CI

This is just about ready at this point! I've deferred the experimental ESMFold2 models and the SAEs to a follow-up PR so we can launch this as soon as possible. It's got a lot of very verbose Claude comments and docstrings too, but those are forgivable and we can always clean them up later 😅

The main thing for reviewers to know is that there are two model classes here, ESMC (a masked protein LM) and ESMFold2 (a protein folding model that uses ESMC as a backbone model). ESMC is mostly just a 'normal' masked LM, but I wasn't able to find exact modular matches for a lot of methods. Still, it should be easy enough to review. ESMFold2 is much more challenging and is not like any other model in the codebase, even the original ESMFold. I have some tests confirming that we get the same output as BioHub's original code, so I think reviewers should probably just trust that for correctness and focus on the API, although you can try going through all of the diffusion code if you want!

@Rocketknight1 Rocketknight1 marked this pull request as ready for review June 4, 2026 14:16
@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Rocketknight1 Rocketknight1 force-pushed the port-esmc-esmfold2 branch 2 times, most recently from e1ed451 to 1af5750 Compare June 12, 2026 15:47
@github-actions

Copy link
Copy Markdown
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=46419&sha=464744

@Rocketknight1 Rocketknight1 force-pushed the port-esmc-esmfold2 branch 2 times, most recently from 0a716e4 to 976468d Compare June 22, 2026 17:15

@vasqu vasqu left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deep dive, I think esmc is alright. Unsure about esmfold2 it seems very much all over the place where I cannot really read the flow too easily and a lot of stuff is going at once

Comment thread docs/source/en/model_doc/esmc.md Outdated
Comment on lines +30 to +32
[`biohub/ESMC-300M`](https://huggingface.co/biohub/ESMC-300M),
[`biohub/ESMC-600M`](https://huggingface.co/biohub/ESMC-600M) and
[`biohub/ESMC-6B`](https://huggingface.co/biohub/ESMC-6B).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
[`biohub/ESMC-300M`](https://huggingface.co/biohub/ESMC-300M),
[`biohub/ESMC-600M`](https://huggingface.co/biohub/ESMC-600M) and
[`biohub/ESMC-6B`](https://huggingface.co/biohub/ESMC-6B).
- [`biohub/ESMC-300M`](https://huggingface.co/biohub/ESMC-300M)
- [`biohub/ESMC-600M`](https://huggingface.co/biohub/ESMC-600M)
- [`biohub/ESMC-6B`](https://huggingface.co/biohub/ESMC-6B)

nit

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

Comment thread docs/source/en/model_doc/esmc.md Outdated

```python
import torch
from transformers import AutoTokenizer, ESMCModel

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we don't have any auto model for this?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replaced with AutoModel! (ESMC is fairly normal, it's just a masked LM)

"ESMFold2TriangleMultiplication": {
"cuda": {
Mode.INFERENCE: LayerRepository(
repo_id="Rocketknight1/esmfold2-trimul-kernel",

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to port to kernels community? Lemme nudge internally if you need help check the kernels channel #kernels

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this was a placeholder location while I was working on the PR! We should definitely move this before merging

Comment on lines +16 to +23
from ...utils import _LazyModule # type: ignore[import]
from ...utils.import_utils import define_import_structure # type: ignore[import]


if TYPE_CHECKING:
from .configuration_esmc import * # noqa: F403
from .modeling_esmc import * # noqa: F403
from .tokenization_esmc import * # noqa: F403

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

claued going crazy? We shouldnt need all these linter ignore stuff

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed!


from huggingface_hub.dataclasses import strict

from ...configuration_utils import PreTrainedConfig # type: ignore[import]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok last mention since it might happen more often

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also removed!



_CONFIDENCE_EPS = 1e-6
_NONPOLYMER_ID = 4

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

config

denom = pair_m.sum(dim=(-1, -2)) + _CONFIDENCE_EPS
pair_chains_iptm[:, c1, c2] = (tm_expected * pair_m).sum(dim=(-1, -2)) / denom

return {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

way too big forward

# precision drives the diffusion conditioning; keep them fp32 even under dtype=bf16.
_keep_in_fp32_modules_strict = ["fourier"]
_supports_sdpa = True
_supports_flash_attn = True

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would avoid focusing on this for now; it really doesnt seem to compatible

init.zeros_(module.base_z_combine)


NUM_RES_TYPES = 33

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

config

token_bonds_encoding=token_bonds_encoding.detach(),
)

return ESMFold2Output(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same way too much going on at once in one forward, it's hard

@vasqu vasqu left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Part2 for esmc, I think that portion looks solid now 🫡 only minor details

Comment thread docs/source/en/model_doc/esmc.md Outdated

tokenizer = AutoTokenizer.from_pretrained("biohub/ESMC-300M")
# ESMC is registered with the auto classes (AutoModel, AutoModelForMaskedLM,
# AutoModelForSequenceClassification, AutoModelForTokenClassification).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imo if you want to you can use hf options and show each variant instead, wdyt?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

"""

model_type = "esmc"
default_theta = 10000.0

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is alreadythe default could be removed I think

Suggested change
default_theta = 10000.0

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

Comment on lines +64 to +68
attribute_map = {
"d_model": "hidden_size",
"n_heads": "num_attention_heads",
"n_layers": "num_hidden_layers",
}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea dw, it would have been nice to avoid but it's also fine this way imo

attention_dropout: float | None = 0.0
qk_layernorm: bool | None = True
scale_residue: bool | None = True
tie_word_embeddings: bool | None = False

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the attributes are fairly standard, would it make sense to move this to modular and inherit from e.g. llama config? You can always override / add (attr = ...) or delete unnecessary attributes (attr = AttributeError())

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, inherited from LlamaConfig in modular and overrode to restore the ESMC default values

self._tokenizer = Tokenizer(BPE(token_to_id, merges=[], unk_token=unk_token))
self._tokenizer.add_special_tokens([cls_token, pad_token, mask_token, eos_token, chain_break_token])

# Automatically wrap every encoded sequence with <cls> … <eos>.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a bit weird, iirc wasnt bert using cls / sep, this is kind of a mix here. but if that's already their convention it's nothing to change, just curious

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is the convention in the model so I left it alone!

for layer in self.layers:
hidden_states = layer(
hidden_states,
attn_bias,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would still call it attention mask tbh

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

especially since sdpa will use a bool mask 😬

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

# ---------------------------------------------------------------------------


class ESMCMaskedLMHead(nn.Module):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then let's restructure a bit: It is a CLIP mlp with an additional layer norm in between, can we make it follow the naming a bit closer to that then?



@auto_docstring
class ESMCForMaskedLM(ESMCPreTrainedModel):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be unrelated no? If you do modular here it only cares about the namings but not the underlying implementation (unless you didn't override then it will try to create its version based on the original parent's)

with torch.no_grad():
logits = model(**inputs).logits
self.assertEqual(logits.shape, (1, inputs["input_ids"].shape[1], model.config.vocab_size))
self.assertTrue(torch.isfinite(logits).all())

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you check nomic bert for integration tests --> would be nice on logits

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Integration tests in!



@require_tokenizers
class ESMCTokenizationTest(unittest.TestCase):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TokenizerTesterMixin?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

Rocketknight1 and others added 21 commits June 30, 2026 17:43
…dapted)

Moves the purely-additive model code from the Biohub fork
(github.com/Biohub/transformers @ f9a5a37, based on v4.57.6) onto a
branch off current main (v5.10.0.dev0). This is the verbatim fork code as
a starting point; v5 convention adaptation (attention interface, modular,
__all__, nested-config round-trip) is follow-up work per the port plan.

Contents:
- src/transformers/models/esmc/    (6 files: config, sae config, modeling,
  sae modeling, tokenizer) — imports and exports cleanly under v5.
- src/transformers/models/esmfold2/ (24 files incl. deferred kernels/,
  distributed/, experimental) — config + modeling import; ESMFold2Model is
  defined but not yet exported (no __all__ — known adaptation item).

Auto-registration adapted to the v5 layout (NOT a verbatim copy of the
fork's 4 hook diffs):
- models/__init__.py: from .esmc/.esmfold2 import *
- auto/auto_mappings.py: CONFIG_MAPPING_NAMES + SPECIAL_MODEL_TYPE_TO_MODULE_NAME
  (these moved out of configuration_auto.py in v5; MODEL_NAMES_MAPPING was
  dropped in v5 so the fork's hunk for it has no equivalent).
- auto/modeling_auto.py: base + masked-LM + seq-cls + token-cls maps.
- auto/tokenization_auto.py: flat ("esmc", "ESMCTokenizer") form (v5 dropped
  the (slow, fast) tuple format).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…old2

Transformers loads fused kernels from the Hub via the `kernels` library
(@use_kernel_forward_from_hub / register_kernel_mapping), not by vendoring
Triton in a model dir. Remove the fork's bespoke acceleration stack and
leave a single pure-PyTorch path; fused ops can return later via a Hub
kernels repo as an opt-in follow-up.

Removed:
- src/transformers/models/esmfold2/kernels/ (8 Triton files, ~3.3k LOC).
- The shared `set_kernel_backend` selector across the module tree, which
  drove BOTH backends: "fused" (vendored Triton) and "cuequivariance"
  (external lib). Both gone — the cueq import block and BACKEND_*/
  _VALID_BACKENDS/_fused_active/_cueq_active helpers with them.
- Per-module fused/cueq branches and now-dead helpers (_can_use_*,
  _fused_trimul_with_residual, split_kernel_weights, _kernel_flow_direction,
  Transition._swiglu_pre_w3/_addmm_residual, DropoutResidual fused impl).
- The vestigial no-op set_kernel_backend hooks in distributed/utils.py and
  modeling_esmfold2_experimental.py.

Kept intact: the independent `set_chunk_size` memory knob, and the optional
flash-attn / transformer_engine guards.

Verified: both modeling modules import; CPU smoke test runs AttentionPairBias,
TriangleMultiplicativeUpdate (both orientations), Transition (chunked ==
unchunked), DropoutResidual, and FoldingTrunk end-to-end on the pure-PyTorch
path. No new ruff errors introduced (9 pre-existing fork lint items remain for
the later `make style` pass).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The distributed/ package is a NVIDIA/MIT-licensed 2D context-parallel
implementation of the folding trunk (DTensor + DeviceMesh + NCCL) for
multi-GPU 6B inference. It is dropped from the port because:

- It is not imported by any model code (core, config, __init__, or the
  experimental file) — fully inert in the package.
- It is broken on import: all 7 files import from
  `projects.huggingface.transformers.models.esmfold2...` (the fork's
  internal monorepo path), so `import transformers.models.esmfold2.distributed`
  raises ModuleNotFoundError. It never worked in the standalone layout.
- It is NVIDIA/MIT-licensed, unlike the Apache/Biohub model code.
- Transformers expresses parallelism declaratively via `base_model_tp_plan`
  / `tp_plan="auto"`, not a vendored per-model DTensor/NCCL stack.

Nothing unique is lost: the math it shards already exists as the
pure-PyTorch reference in modeling_esmfold2_common.py. If multi-GPU
inference is needed later, author a tp_plan on ESMFold2Model fresh.

Verified: nothing references distributed/; `import transformers` and the
esmfold2 modeling module still import cleanly.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Replace ESMC's bespoke attention dispatch with the standard v5 interface,
mirroring models/esm (a bidirectional encoder). Behaviour is preserved
bit-exactly at all real-token positions.

Before: a hand-rolled `_scaled_dot_product_attention` choosing xformers ->
flash-attn-2 -> SDPA, a `_FlashMultiHeadAttention` subclass with manual
unpad_input/pad_input in ESMCModel.forward, a `_TritonRotaryEmbedding`, and
a `seq_id`-threaded chain mask.

After:
- Module-level `eager_attention_forward`; `MultiHeadAttention` dispatches via
  `ALL_ATTENTION_FUNCTIONS.get_interface(config._attn_implementation, ...)`
  with q/k/v shaped (B, H, S, Dh) and `scaling=head_dim**-0.5` (RoPE is
  rotation-invariant to scaling, so it stays in the module). output_attentions
  forces the eager interface so probabilities remain observable.
- ESMCModel.forward builds the 4D mask once: `create_bidirectional_mask` for
  the padding case (works for eager/sdpa/flash), and a block-diagonal additive
  bias for multi-chain `sequence_id` (eager/sdpa; flash multi-chain still
  raises). Removes the unpad/pad varlen path.
- Drops the xformers / flash-attn / triton-rotary import machinery and their
  warnings. transformer_engine LN/MLP fusion is unchanged. The
  `_supports_sdpa/_supports_flash_attn/_supports_attention_backend` flags
  (already declared) are now actually honoured.

flash-attn-2 is still supported — now via the standard attn_implementation
backend rather than bespoke dispatch.

Verified: loading identical weights into the refactored model reproduces the
pre-refactor outputs to 0.000e+00 at every non-padding position (plain,
padding-mask, and multi-chain cases); the only differences are at padding
positions, which are masked out downstream. eager and sdpa agree bit-exactly
on valid-token logits; MaskedLM/SequenceClassification/TokenClassification
and output_attentions all work.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
No other Transformers model depends on transformer_engine; drop it from
ESMC. TE provided fused fp32-reduction LayerNorm+Linear / LayerNorm+MLP
kernels, but the model already shipped pure-PyTorch fallbacks whose
parameter names (`layer_norm_weight`, `fc1_weight`, `fc2_weight`, ...) match
the published-checkpoint layout. Make those the only path.

- Drop the `transformer_engine` import + `_te_available` guard + the
  "TE not installed" warning.
- `_swiglu_ln_ffn` / `_make_attn_layernorm_qkv` / `_make_attn_out_proj` now
  unconditionally return the pure-PyTorch modules (`_PyTorchLayerNormMLP`,
  `_PyTorchLayerNormLinear`, `nn.Linear`).
- Remove dead `_SwiGLU` class (the MLP fallback inlines silu(x1)*x2).

If exact TE numerics (fp32-reduction LayerNorm) are ever required, that
belongs in a Hub kernel via the `kernels` library, not a hard dependency.

Verified: strict state_dict load from the pre-change baseline succeeds
(parameter names unchanged -> published checkpoints still load), and
last_hidden_state is bit-identical (0.0) at all valid positions for the
plain, padding-mask, and multi-chain cases. Locally TE was never installed,
so this is the exact path that already ran.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…mb convention

Replace the flash-attn-style cache-based rotary with the standard
Transformers convention used by esm/llama, as the last code-shaping step
before authoring modular_esmc.py.

Before: a stateful `RotaryEmbedding` (per-attention-module) caching cos/sin,
`forward(q, k)` returning rotated tensors, a custom `_apply` override to keep
`inv_freq` fp32 across device casts, plus `_rotate_half` / `_apply_rotary_emb_torch`.

After:
- `rotate_half` + `apply_rotary_pos_emb` (identical to esm/llama).
- `ESMCRotaryEmbedding(config)` -> `(cos, sin)`, computed once in
  `ESMCModel.forward` and threaded down (position_embeddings) through the
  stack/block to attention, mirroring esm. `inv_freq` is fp32 and
  non-persistent (matches the old behaviour: no rotary tensors in the
  checkpoint), and cos/sin are built in fp32 then cast.
- Add `config.rope_theta` (default 10000.0, the previously-hardcoded base).
- `_init_weights` recomputes `inv_freq` for `ESMCRotaryEmbedding` (meta-init safe).

Verified: strict state_dict load from the saved baseline succeeds (rotary
buffers are non-persistent, so keys are unchanged -> published checkpoints
still load), and last_hidden_state is bit-identical (0.0) at all valid
positions for plain, padding-mask, and multi-chain. The fp32 matmul-based
freqs equal the old `outer(t, inv_freq)`; same RoPE math, idiomatic shape.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
ESMC now follows the modular convention. modular_esmc.py is the source of
truth; modeling_esmc.py is generated by utils/modular_model_converter.py and
carries the auto-generated header.

Reuse from esm (the natural parent — also a bidirectional protein encoder):
`eager_attention_forward`, `rotate_half`, and `apply_rotary_pos_emb` are now
imported from ..esm.modeling_esm and inlined into the generated file with
`# Copied from` headers (so they stay in sync). `rotate_half` is pulled in
transitively as a dependency of `apply_rotary_pos_emb`, matching the qwen3
pattern.

Everything else stays ESMC-specific and is defined in the modular file: the
SAE-integrated ESMCModel + ForMaskedLM/SequenceClassification/
TokenClassification, the fused-LN MultiHeadAttention, SwiGLU FFN,
TransformerStack, ESMCRotaryEmbedding, and the SAE-carrying output
dataclasses. As expected for this architecture the dedup is modest; the win
is convention compliance + auto-sync of the shared functions.

The modular file was ruff-fixed/formatted (Optional[X] -> X | None, import
order) before regeneration, so both files are now ruff-clean.

Verified: `check_modular_conversion.py` passes (files in sync); `transformers`
imports; and loading identical weights reproduces the pre-conversion
last_hidden_state bit-for-bit (0.0) at all valid positions for plain,
padding-mask, and multi-chain inputs.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…rk on ESMC-6B)

Two bugs surfaced only when loading the real biohub/ESMC-6B checkpoint and
running an unpadded sequence (prior parity tests all used padded inputs):

1. _init_weights re-initialized loaded nn.Linear weights. It used
   `module.weight.data.normal_()`, which writes through `.data` and bypasses
   the `_is_hf_initialized` flag that transformers sets on loaded params. So
   `from_pretrained` clobbered out_proj / lm_head with random init (silently —
   no missing-keys warning), while custom/LayerNorm/Embedding modules survived
   only because _init_weights had no branch for them. Switch to the
   flag-respecting `transformers.initialization` helpers (init.normal_,
   init.zeros_, init.copy_), matching esm/base. (This bug is latent in the
   v4 fork too; it only bites under the v5 init-after-load flow.)

2. Attention defaulted to causal on unpadded inputs. ESMC is a bidirectional
   encoder, but MultiHeadAttention never set `is_causal`, so the sdpa/flash
   interface fell back to `getattr(module, "is_causal", True)` and applied
   causal masking whenever `attention_mask is None`. Set `self.is_causal =
   False`. (Introduced by the attention-interface refactor; missed because
   every earlier parity test passed a padded 4D mask.)

Source of truth is modular_esmc.py; modeling_esmc.py regenerated.

Verified: loading biohub/ESMC-6B (and ESMC-300M) into the refactored model
and into the original fork code yields BIT-IDENTICAL logits (max|Δ| = 0.0)
and identical argmax predictions on an 80-residue sequence. State-dict loads
with no remapping (240 TE `_extra_state` keys ignored as before).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The ESMCTokenizer behaves correctly under v5 (verified against the published
biohub/ESMC-6B tokenizer.json), but its docstring doctest example was wrong:
it listed 21 ids for a 20-residue sequence with two residues mis-ordered/
dropped. Correct it to the actual output (22 ids, <cls> ... <eos>).

Add tests/models/esmc/test_tokenization_esmc.py (fast-only tokenizer, so it
mirrors models/esm's plain-TestCase style rather than the slow vocab-file
setup): documented example, character-level tokenize, <cls>/<eos> wrapping,
special-token ids (incl. bos==cls alias), batch padding + attention_mask,
chain-break token, mask token, unknown-residue -> <unk>, decode round-trip,
save/load round-trip, and a @slow integration test asserting AutoTokenizer
resolves to ESMCTokenizer and the hub tokenizer matches the code-built one.

Verified: 10 passed, 1 slow-skipped (the slow test passes with RUN_SLOW=1).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
ESMCTokenizer is fast-only (built from a tokenizer object); it never reads or
writes a slow `vocab.txt`. Declare only `tokenizer_file: tokenizer.json`.
save_pretrained already emits just tokenizer.json + tokenizer_config.json;
tokenizer tests still pass (10 passed, 1 slow-skipped).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Add tests/models/esmc/test_modeling_esmc.py (ModelTesterMixin +
PipelineTesterMixin), mirroring models/esm: ESMCModelTester +
create_and_check for ESMCModel / ForMaskedLM / ForSequenceClassification /
ForTokenClassification, plus a @slow masked-LM integration test on
biohub/ESMC-300M. ESMCModel has no pooler, so the base-model check asserts
only last_hidden_state.

Fix surfaced by test_can_init_all_missing_weights: _init_weights only handled
nn.Linear, so the fused-LN modules, nn.LayerNorm and nn.Embedding were left
uninitialized on a from-scratch / meta-device init. Handle the custom
_PyTorchLayerNormLinear / _PyTorchLayerNormMLP explicitly (their `weight` is a
Linear weight, not a norm — the base initializer matches norms by class-name
substring and would wrongly set it to ones), the rotary buffer explicitly,
and delegate nn.Linear / nn.Embedding / nn.LayerNorm to super()._init_weights.

Skip test_retain_grad_hidden_states_attentions: ESMC returns `hidden_states`
as one stacked tensor (consumed by the SAE feature), not the live per-layer
tensors, so grad cannot flow back to the returned copy — intentional.

Verified: full model test file = 92 passed, 102 skipped, 0 failed; the @slow
integration test passes; and from_pretrained still loads biohub/ESMC weights
bit-exactly (the flag-respecting init helpers don't clobber loaded params).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ion heads

Add docs/source/en/model_doc/esmc.md (short: overview, links to the
biohub/ESMC-{300M,600M,6B} checkpoints and the evolutionaryscale/esm repo, a
usage snippet, and [[autodoc]] for ESMCConfig, ESMCSAEConfig, ESMCTokenizer,
ESMCModel, ESMCFor{MaskedLM,SequenceClassification,TokenClassification},
ESMCSAEModel) and register it in _toctree.yml (alphabetically after ESM).

ESMCForSequenceClassification/ForTokenClassification have no Examples block in
their forward docstrings, so auto_docstring could not find a checkpoint and
errored on import. Add the standard checkpoint sentence with a Hub link
([Biohub/ESMC-600M-2024-12](...)) to the ESMCConfig docstring; auto_docstring
falls back to the config's checkpoint, resolving it for all model classes at
once (no per-class decorator needed).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The sparse-autoencoder is architecturally distinct from ESMC (a linear
encoder + top-k sparsity + linear decoder operating on frozen hidden states,
not a transformer) and has a non-standard per-layer from_pretrained/
save_pretrained path. Keep it out of the core ESMC PR; it will land as its
own follow-up (recoverable from this commit's parent).

Removed:
- modeling_esmc_sae.py and configuration_esmc_sae.py.
- esmc_sae auto-registration (CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES, and
  the esmc_sae -> esmc SPECIAL_MODEL_TYPE_TO_MODULE_NAME mapping).
- The SAE integration woven through ESMCModel and the heads: add_sae_models,
  _get_sae_outputs / _get_sae_layer_num_requested / _validate_sae_inputs /
  _SAE_KEY_RE / _sae_models, the compute_sae/normalize_sae params, and the
  sae_outputs field on all four output dataclasses. The forward simplifies
  accordingly (layers_to_collect now only serves output_hidden_states; the
  SAE-only bool_mask is gone).
- ESMCSAEConfig/ESMCSAEModel autodoc from the model doc.

Source of truth is modular_esmc.py; modeling_esmc.py regenerated.

Verified: ESMCSAE* no longer importable and esmc_sae deregistered; ruff clean;
model + tokenizer tests pass (102 passed, 103 skipped); and fork-vs-new logits
on biohub/ESMC-300M remain bit-identical (max|Δ| = 0.0) — the removal is
output-neutral for the core model.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
`make fix-repo`/check_docstrings normalizations on the ESMC public objects:
- ESMCTokenizer: default values use single backticks (`"<unk>"`) per the
  canonical "*optional*, defaults to ..." format.
- ESMCModel / ESMCForMaskedLM / ESMCForSequenceClassification: reorder the
  forward docstring argument entries to match the signature order
  (output_attentions before labels) and add the missing blank line before the
  Examples block.

Docstring-only (no logic change); modular and generated modeling stay in sync.
ESMC now passes check_repo, check_copies, check_inits, check_dummies,
check_modular_conversion, and check_docstrings (remaining repo-wide failures
are pre-existing Qwen output-doc errors and Phase-B esmfold2, neither ESMC).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
- Add `__all__ = ["ESMFold2Model"]` to modeling_esmfold2.py so the model is
  exported (`from transformers import ESMFold2Model` now works; previously the
  class was defined but the module had no `__all__`).
- Replace the hard cross-model import `from ..esmc.modeling_esmc import
  ESMCModel` in `load_esmc` (both modeling_esmfold2.py and the experimental
  file) with `AutoModel.from_pretrained(...)`. ESMC is a shared, frozen 6B
  backbone loaded separately from its own repo (`config.esmc_id`); resolving it
  through the Auto registry (model_type "esmc" -> ESMCModel) keeps esmc and
  esmfold2 as separate model directories without a runtime cross-dir import.

Verified: ESMFold2Model + ESMFold2Config export; core + experimental modules
import; no `..esmc` imports remain in the esmfold2 dir; AutoModel.from_config
resolves ESMCConfig -> ESMCModel. (Also removed the empty kernels/ and
distributed/ dirs left behind by their earlier git rm.)

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…-only)

Scope the core ESMFold2 PR to the release model. The experimental
(legacy/dev binder-design) variant is deferred like the ESMC SAE.

- Delete modeling_esmfold2_experimental.py and remove it from the package
  __init__ and from the ESMFold2Model.from_pretrained `config.type ==
  "experimental"` routing (from_pretrained now always builds ESMFold2Model).
- ESMFold2Config: keep the `type` field for checkpoint compat but accept only
  "release"; fix the docstring example to use ESMFold2Model. Update the two
  stale `ESMFold2ExperimentalModel` references in comments/docstrings.
- Remove the `esmfold2_v2 -> esmfold2` SPECIAL_MODEL_TYPE_TO_MODULE_NAME entry
  (no v2 variant ships).

Verified: ESMFold2Model + ESMFold2Config import; ESMFold2ExperimentalModel
gone; no experimental/v2 references remain.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
ESMFold2's only TransformerEngine dependency was the optional fp8
quantization of the (now TE-free) ESMC backbone. Drop the import guard,
the dead _convert_te_modules_to_fp8_inplace walker, the "fp8" precision
option, and the fp8 padding/autocast plumbing. _lm_precision_context is
now a plain bf16 autocast; ESMC loads at bf16 (default) or fp32.

No other Transformers model depends on transformer_engine.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
SWA3DRoPEAttention's plain softmax(QKᵀ)V core now dispatches through
ALL_ATTENTION_FUNCTIONS / a local eager_attention_forward, keyed on
config._attn_implementation, with the sliding window expressed as an
additive attention mask. The custom flash-attention path (native
bidirectional window_size + varlen for packed inputs) is kept as an
opt-in backend, now gated on _attn_implementation == "flash_attention_2"
instead of auto-selecting whenever flash-attn is importable — so the
default is sdpa (matching the fork's SDPA fallback bit-for-bit) and
flash is opt-in, per v5 conventions.

ESMFold2Model declares _supports_sdpa / _supports_flash_attn /
_supports_attention_backend and, after construction, attaches its shared
config to every SWA3DRoPEAttention (the atom encoders/decoders build them
from explicit dims), so dispatch stays live under set_attn_implementation.
is_causal=False guards against the interface defaulting to causal when no
mask is passed. Pair-bias (AttentionPairBias) and triangular math are left
untouched.

Validated on CPU vs the pre-refactor forward (random weights): sdpa
max|Δ|=0.0 (bit-exact), eager max|Δ|=1.3e-3 (bf16 softmax precision).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The config was one of ~4 holdouts still using the old `def __init__(**kwargs)`
+ hand-rolled `to_dict` style; 460/463 configs use `@strict` dataclass
`PreTrainedConfig`. Rewrite it to match, mirroring `models/esm` (ESMFold v1):

- Every sub-config (MSAEncoder/Parcae/LMEncoder/AtomAttention/FoldingTrunk/
  InputsEmbedder/DiffusionModule/DiffusionStructureHead/ConfidenceHead) is now a
  `@strict PreTrainedConfig` with typed fields + defaults, instead of a plain
  `@dataclass`.
- Nesting is declared via `sub_configs = {...}` on each parent + a `__post_init__`
  that turns `dict -> SubConfig(**dict)`. This deletes the brittle hand-rolled
  `ESMFold2Config.to_dict()` (base class now serializes via `sub_configs`) and
  gives recursive `_attn_implementation` propagation to sub-configs for free.
- Top config switches to `@auto_docstring(checkpoint="biohub/ESMFold2")`, which
  resolves the outstanding `check_docstrings` failure for ESMFold2Config.
- `__all__` is reduced to `["ESMFold2Config"]` (sub-configs are implementation
  detail, matching ESM which exports only `EsmConfig`).

check_config_attributes: top config keeps a precise 4-item allow-list (type +
three training/experimental recipe knobs not read by the core inference path);
the 9 sub-configs are allow-listed wholesale because their fields are threaded
into submodules as explicit dims (e.g. `d_atom=cfg.inputs.atom_encoder.d_atom`),
which the checker's `config.<attr>` heuristic cannot trace.

Verified: default + nested-dict construction, save/load round-trip (to_dict
identical, sub-config types preserved), recursive attn-impl propagation, type
validation, unknown-kwarg tolerance; the tiny ESMFold2Model still builds and the
SWA attention equivalence is unchanged. ruff + check_docstrings +
check_config_attributes all clean.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
ESMFold2 is an all-atom structure predictor whose forward takes ~18 structural
feature tensors and returns a plain dict (not a ModelOutput), so it doesn't fit
ModelTesterMixin. Following the sanctioned pattern for such models, register the
test file in check_repo's TEST_FILES_WITH_NO_COMMON_TESTS and provide focused
coverage:

- ESMFold2ConfigTest: ConfigTester common tests (incl. composite sub-config
  save/load), `type` validation, nested round-trip, attn-impl propagation.
- ESMFold2ModelTest (CPU): full pure-PyTorch forward via infer_protein with no
  ESMC backbone (LM conditioning skipped) under both sdpa and eager; SWA config
  dispatch; weight-level save/load fidelity + a usable reloaded model.
- ESMFold2IntegrationTest: @slow real-weight fold on biohub/ESMFold2 (GPU-gated).

To make ConfigTester's composite test pass, each sub-config now declares a
unique `model_type` (e.g. "esmfold2_inputs_embedder") — the CLIP pattern — so
that `SubConfig.from_pretrained(<composite dir>)` extracts the matching nested
dict (configuration_utils keys this off model_type). ESM dodges this test
because its sub-config is None by default; ESMFold2's are always present.

check_repo: `modeling_esmfold2_common` (shared building blocks, no public model)
is added to get_model_modules' _ignore_modules. The remaining check_repo item
for esmfold2 is the model-doc page, which is a separate pending task.

The tiny test config encodes two real sizing constraints discovered via the
forward smoke: 3D RoPE needs 3*n_spatial + n_uid <= head_dim//2, and
inputs.d_inputs == 67 + d_token//2 == diffusion_module.c_s_inputs.

7 non-slow tests pass; ruff + check_config_attributes + check_docstrings clean.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Short overview page (mirrors esmc.md): describes the all-atom structure
predictor and its separately-loaded ESMC backbone, a usage snippet
(infer_protein_as_pdb), and autodoc for ESMFold2Config + ESMFold2Model. Registered
in _toctree.yml after ESMC. Resolves the final check_repo "objects documented"
item for esmfold2.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@Rocketknight1

Rocketknight1 commented Jul 1, 2026

Copy link
Copy Markdown
Member Author

@vasqu writing this in one big comment rather than breaking it up to every individual location, because there was a lot of repeated stuff! I just did a big pass for ESMFold2, now that ESMC is mostly okay.

I dropped the flashattention path for now, inlined a bunch of methods, broke RoPE out into a single class, tried to merge as many SwiGLUs as I could (not all were possible, some have little embedded norms). A lot of stuff got renamed and moved to configs instead of just being constants, and I tried to pull out private methods from the big long forward methods.

There's a few tricky bits still:

  1. DiffusionModule still has massive signatures with 20 args. Want me to bundle that into one big object we just pass around and access attributes on?
  2. Renamed a lot of things and broke up a lot of nn.Sequential() and single-letter vars. This required lots of weight renaming and attribute_map. We'll need to regenerate the checkpoint anyway (the existing checkpoints don't contain the ESMC backbone weights), so we could rename everything in the checkpoints and then delete the rename lines before merging. Marked those with TODO.
  3. Some bits I left untouched. In particular, some single-letter names were kept because they're kind of standard, like splitting a quaternion into r, i, j, k, and some q and v keys. I can rename those if you want! Other stuff like plddt comes from AlphaFold and people doing protein folding know what it means even if it's quite weird for most ML engineers!
  4. Some other bits I left include an F.rms_norm() call with no affine scaling and so no weights. You suggested nanochat there but I'm not sure why we need a module at all! It works fine as one functional line.

Overall I tried to make it more readable, but it's a very complex model, so there's only so much I can do! 😅

@vasqu

vasqu commented Jul 1, 2026

Copy link
Copy Markdown
Collaborator

@Rocketknight1 Thanks a lot! Yes the model is super complex so I don't expect perfect stuff. I think the esmc model should be fairly good now. I will try to take a look at esmfold2 in the coming days.

Re

  1. Yea that might be more reasonable, having too long signatures is really hard to read imo - similar to how we pass configs around and only access the portions we need. Just a suggestion and I'm not 100% sure it's a good one, would not put it high prio
  2. Gotcha, sounds good. Yea we have a few models that have a long list of renames as well dw.
  3. Yea, hard to catch the abbreviations sometimes. Sometimes it's me quickly glancing, sometimes my missing background for biology etc. It's not too bad if there really is meaning; I just would avoid too many where possible
  4. This is mainly for inheritance with modular + potentially applying kernels on them. Won't be possible if we rely on an F. portion of torch

@Rocketknight1

Copy link
Copy Markdown
Member Author

On 4 I'm unsure - we're not doing modular for ESMFold2 because the code is so weird compared to other models. I'd leave that one for a future refactor, if there are kernels that turn out to really justify it, or if we get a lot of similar folding models in the codebase.

1 is done, though, and I managed to drop some dead args in the process!

@github-actions

github-actions Bot commented Jul 3, 2026

Copy link
Copy Markdown
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, esmc, esmfold2

@github-actions

github-actions Bot commented Jul 3, 2026

Copy link
Copy Markdown
Contributor

CI recap

Dashboard: View test results in Grafana
Latest run: 28664416136:1
Result: success | Jobs: 13 | Tests: 170,342 | Failures: 0 | Duration: 24h 40m

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants