Port ESMC and ESMFold2 to Transformers#46419
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
e1ed451 to
1af5750
Compare
|
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=46419&sha=464744 |
0a716e4 to
976468d
Compare
vasqu
left a comment
There was a problem hiding this comment.
Deep dive, I think esmc is alright. Unsure about esmfold2 it seems very much all over the place where I cannot really read the flow too easily and a lot of stuff is going at once
| [`biohub/ESMC-300M`](https://huggingface.co/biohub/ESMC-300M), | ||
| [`biohub/ESMC-600M`](https://huggingface.co/biohub/ESMC-600M) and | ||
| [`biohub/ESMC-6B`](https://huggingface.co/biohub/ESMC-6B). |
There was a problem hiding this comment.
| [`biohub/ESMC-300M`](https://huggingface.co/biohub/ESMC-300M), | |
| [`biohub/ESMC-600M`](https://huggingface.co/biohub/ESMC-600M) and | |
| [`biohub/ESMC-6B`](https://huggingface.co/biohub/ESMC-6B). | |
| - [`biohub/ESMC-300M`](https://huggingface.co/biohub/ESMC-300M) | |
| - [`biohub/ESMC-600M`](https://huggingface.co/biohub/ESMC-600M) | |
| - [`biohub/ESMC-6B`](https://huggingface.co/biohub/ESMC-6B) |
nit
|
|
||
| ```python | ||
| import torch | ||
| from transformers import AutoTokenizer, ESMCModel |
There was a problem hiding this comment.
I guess we don't have any auto model for this?
There was a problem hiding this comment.
Replaced with AutoModel! (ESMC is fairly normal, it's just a masked LM)
| "ESMFold2TriangleMultiplication": { | ||
| "cuda": { | ||
| Mode.INFERENCE: LayerRepository( | ||
| repo_id="Rocketknight1/esmfold2-trimul-kernel", |
There was a problem hiding this comment.
Would it be possible to port to kernels community? Lemme nudge internally if you need help check the kernels channel #kernels
There was a problem hiding this comment.
Yeah, this was a placeholder location while I was working on the PR! We should definitely move this before merging
| from ...utils import _LazyModule # type: ignore[import] | ||
| from ...utils.import_utils import define_import_structure # type: ignore[import] | ||
|
|
||
|
|
||
| if TYPE_CHECKING: | ||
| from .configuration_esmc import * # noqa: F403 | ||
| from .modeling_esmc import * # noqa: F403 | ||
| from .tokenization_esmc import * # noqa: F403 |
There was a problem hiding this comment.
claued going crazy? We shouldnt need all these linter ignore stuff
|
|
||
| from huggingface_hub.dataclasses import strict | ||
|
|
||
| from ...configuration_utils import PreTrainedConfig # type: ignore[import] |
There was a problem hiding this comment.
ok last mention since it might happen more often
|
|
||
|
|
||
| _CONFIDENCE_EPS = 1e-6 | ||
| _NONPOLYMER_ID = 4 |
| denom = pair_m.sum(dim=(-1, -2)) + _CONFIDENCE_EPS | ||
| pair_chains_iptm[:, c1, c2] = (tm_expected * pair_m).sum(dim=(-1, -2)) / denom | ||
|
|
||
| return { |
| # precision drives the diffusion conditioning; keep them fp32 even under dtype=bf16. | ||
| _keep_in_fp32_modules_strict = ["fourier"] | ||
| _supports_sdpa = True | ||
| _supports_flash_attn = True |
There was a problem hiding this comment.
I would avoid focusing on this for now; it really doesnt seem to compatible
| init.zeros_(module.base_z_combine) | ||
|
|
||
|
|
||
| NUM_RES_TYPES = 33 |
| token_bonds_encoding=token_bonds_encoding.detach(), | ||
| ) | ||
|
|
||
| return ESMFold2Output( |
There was a problem hiding this comment.
same way too much going on at once in one forward, it's hard
vasqu
left a comment
There was a problem hiding this comment.
Part2 for esmc, I think that portion looks solid now 🫡 only minor details
|
|
||
| tokenizer = AutoTokenizer.from_pretrained("biohub/ESMC-300M") | ||
| # ESMC is registered with the auto classes (AutoModel, AutoModelForMaskedLM, | ||
| # AutoModelForSequenceClassification, AutoModelForTokenClassification). |
There was a problem hiding this comment.
Imo if you want to you can use hf options and show each variant instead, wdyt?
| """ | ||
|
|
||
| model_type = "esmc" | ||
| default_theta = 10000.0 |
There was a problem hiding this comment.
I think this is alreadythe default could be removed I think
| default_theta = 10000.0 |
| attribute_map = { | ||
| "d_model": "hidden_size", | ||
| "n_heads": "num_attention_heads", | ||
| "n_layers": "num_hidden_layers", | ||
| } |
There was a problem hiding this comment.
Yea dw, it would have been nice to avoid but it's also fine this way imo
| attention_dropout: float | None = 0.0 | ||
| qk_layernorm: bool | None = True | ||
| scale_residue: bool | None = True | ||
| tie_word_embeddings: bool | None = False |
There was a problem hiding this comment.
the attributes are fairly standard, would it make sense to move this to modular and inherit from e.g. llama config? You can always override / add (attr = ...) or delete unnecessary attributes (attr = AttributeError())
There was a problem hiding this comment.
Yeah, inherited from LlamaConfig in modular and overrode to restore the ESMC default values
| self._tokenizer = Tokenizer(BPE(token_to_id, merges=[], unk_token=unk_token)) | ||
| self._tokenizer.add_special_tokens([cls_token, pad_token, mask_token, eos_token, chain_break_token]) | ||
|
|
||
| # Automatically wrap every encoded sequence with <cls> … <eos>. |
There was a problem hiding this comment.
a bit weird, iirc wasnt bert using cls / sep, this is kind of a mix here. but if that's already their convention it's nothing to change, just curious
There was a problem hiding this comment.
Yeah, this is the convention in the model so I left it alone!
| for layer in self.layers: | ||
| hidden_states = layer( | ||
| hidden_states, | ||
| attn_bias, |
There was a problem hiding this comment.
would still call it attention mask tbh
There was a problem hiding this comment.
especially since sdpa will use a bool mask 😬
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| class ESMCMaskedLMHead(nn.Module): |
There was a problem hiding this comment.
Then let's restructure a bit: It is a CLIP mlp with an additional layer norm in between, can we make it follow the naming a bit closer to that then?
|
|
||
|
|
||
| @auto_docstring | ||
| class ESMCForMaskedLM(ESMCPreTrainedModel): |
There was a problem hiding this comment.
should be unrelated no? If you do modular here it only cares about the namings but not the underlying implementation (unless you didn't override then it will try to create its version based on the original parent's)
| with torch.no_grad(): | ||
| logits = model(**inputs).logits | ||
| self.assertEqual(logits.shape, (1, inputs["input_ids"].shape[1], model.config.vocab_size)) | ||
| self.assertTrue(torch.isfinite(logits).all()) |
There was a problem hiding this comment.
can you check nomic bert for integration tests --> would be nice on logits
There was a problem hiding this comment.
Integration tests in!
|
|
||
|
|
||
| @require_tokenizers | ||
| class ESMCTokenizationTest(unittest.TestCase): |
…dapted) Moves the purely-additive model code from the Biohub fork (github.com/Biohub/transformers @ f9a5a37, based on v4.57.6) onto a branch off current main (v5.10.0.dev0). This is the verbatim fork code as a starting point; v5 convention adaptation (attention interface, modular, __all__, nested-config round-trip) is follow-up work per the port plan. Contents: - src/transformers/models/esmc/ (6 files: config, sae config, modeling, sae modeling, tokenizer) — imports and exports cleanly under v5. - src/transformers/models/esmfold2/ (24 files incl. deferred kernels/, distributed/, experimental) — config + modeling import; ESMFold2Model is defined but not yet exported (no __all__ — known adaptation item). Auto-registration adapted to the v5 layout (NOT a verbatim copy of the fork's 4 hook diffs): - models/__init__.py: from .esmc/.esmfold2 import * - auto/auto_mappings.py: CONFIG_MAPPING_NAMES + SPECIAL_MODEL_TYPE_TO_MODULE_NAME (these moved out of configuration_auto.py in v5; MODEL_NAMES_MAPPING was dropped in v5 so the fork's hunk for it has no equivalent). - auto/modeling_auto.py: base + masked-LM + seq-cls + token-cls maps. - auto/tokenization_auto.py: flat ("esmc", "ESMCTokenizer") form (v5 dropped the (slow, fast) tuple format). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…old2 Transformers loads fused kernels from the Hub via the `kernels` library (@use_kernel_forward_from_hub / register_kernel_mapping), not by vendoring Triton in a model dir. Remove the fork's bespoke acceleration stack and leave a single pure-PyTorch path; fused ops can return later via a Hub kernels repo as an opt-in follow-up. Removed: - src/transformers/models/esmfold2/kernels/ (8 Triton files, ~3.3k LOC). - The shared `set_kernel_backend` selector across the module tree, which drove BOTH backends: "fused" (vendored Triton) and "cuequivariance" (external lib). Both gone — the cueq import block and BACKEND_*/ _VALID_BACKENDS/_fused_active/_cueq_active helpers with them. - Per-module fused/cueq branches and now-dead helpers (_can_use_*, _fused_trimul_with_residual, split_kernel_weights, _kernel_flow_direction, Transition._swiglu_pre_w3/_addmm_residual, DropoutResidual fused impl). - The vestigial no-op set_kernel_backend hooks in distributed/utils.py and modeling_esmfold2_experimental.py. Kept intact: the independent `set_chunk_size` memory knob, and the optional flash-attn / transformer_engine guards. Verified: both modeling modules import; CPU smoke test runs AttentionPairBias, TriangleMultiplicativeUpdate (both orientations), Transition (chunked == unchunked), DropoutResidual, and FoldingTrunk end-to-end on the pure-PyTorch path. No new ruff errors introduced (9 pre-existing fork lint items remain for the later `make style` pass). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The distributed/ package is a NVIDIA/MIT-licensed 2D context-parallel implementation of the folding trunk (DTensor + DeviceMesh + NCCL) for multi-GPU 6B inference. It is dropped from the port because: - It is not imported by any model code (core, config, __init__, or the experimental file) — fully inert in the package. - It is broken on import: all 7 files import from `projects.huggingface.transformers.models.esmfold2...` (the fork's internal monorepo path), so `import transformers.models.esmfold2.distributed` raises ModuleNotFoundError. It never worked in the standalone layout. - It is NVIDIA/MIT-licensed, unlike the Apache/Biohub model code. - Transformers expresses parallelism declaratively via `base_model_tp_plan` / `tp_plan="auto"`, not a vendored per-model DTensor/NCCL stack. Nothing unique is lost: the math it shards already exists as the pure-PyTorch reference in modeling_esmfold2_common.py. If multi-GPU inference is needed later, author a tp_plan on ESMFold2Model fresh. Verified: nothing references distributed/; `import transformers` and the esmfold2 modeling module still import cleanly. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Replace ESMC's bespoke attention dispatch with the standard v5 interface, mirroring models/esm (a bidirectional encoder). Behaviour is preserved bit-exactly at all real-token positions. Before: a hand-rolled `_scaled_dot_product_attention` choosing xformers -> flash-attn-2 -> SDPA, a `_FlashMultiHeadAttention` subclass with manual unpad_input/pad_input in ESMCModel.forward, a `_TritonRotaryEmbedding`, and a `seq_id`-threaded chain mask. After: - Module-level `eager_attention_forward`; `MultiHeadAttention` dispatches via `ALL_ATTENTION_FUNCTIONS.get_interface(config._attn_implementation, ...)` with q/k/v shaped (B, H, S, Dh) and `scaling=head_dim**-0.5` (RoPE is rotation-invariant to scaling, so it stays in the module). output_attentions forces the eager interface so probabilities remain observable. - ESMCModel.forward builds the 4D mask once: `create_bidirectional_mask` for the padding case (works for eager/sdpa/flash), and a block-diagonal additive bias for multi-chain `sequence_id` (eager/sdpa; flash multi-chain still raises). Removes the unpad/pad varlen path. - Drops the xformers / flash-attn / triton-rotary import machinery and their warnings. transformer_engine LN/MLP fusion is unchanged. The `_supports_sdpa/_supports_flash_attn/_supports_attention_backend` flags (already declared) are now actually honoured. flash-attn-2 is still supported — now via the standard attn_implementation backend rather than bespoke dispatch. Verified: loading identical weights into the refactored model reproduces the pre-refactor outputs to 0.000e+00 at every non-padding position (plain, padding-mask, and multi-chain cases); the only differences are at padding positions, which are masked out downstream. eager and sdpa agree bit-exactly on valid-token logits; MaskedLM/SequenceClassification/TokenClassification and output_attentions all work. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
No other Transformers model depends on transformer_engine; drop it from ESMC. TE provided fused fp32-reduction LayerNorm+Linear / LayerNorm+MLP kernels, but the model already shipped pure-PyTorch fallbacks whose parameter names (`layer_norm_weight`, `fc1_weight`, `fc2_weight`, ...) match the published-checkpoint layout. Make those the only path. - Drop the `transformer_engine` import + `_te_available` guard + the "TE not installed" warning. - `_swiglu_ln_ffn` / `_make_attn_layernorm_qkv` / `_make_attn_out_proj` now unconditionally return the pure-PyTorch modules (`_PyTorchLayerNormMLP`, `_PyTorchLayerNormLinear`, `nn.Linear`). - Remove dead `_SwiGLU` class (the MLP fallback inlines silu(x1)*x2). If exact TE numerics (fp32-reduction LayerNorm) are ever required, that belongs in a Hub kernel via the `kernels` library, not a hard dependency. Verified: strict state_dict load from the pre-change baseline succeeds (parameter names unchanged -> published checkpoints still load), and last_hidden_state is bit-identical (0.0) at all valid positions for the plain, padding-mask, and multi-chain cases. Locally TE was never installed, so this is the exact path that already ran. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…mb convention Replace the flash-attn-style cache-based rotary with the standard Transformers convention used by esm/llama, as the last code-shaping step before authoring modular_esmc.py. Before: a stateful `RotaryEmbedding` (per-attention-module) caching cos/sin, `forward(q, k)` returning rotated tensors, a custom `_apply` override to keep `inv_freq` fp32 across device casts, plus `_rotate_half` / `_apply_rotary_emb_torch`. After: - `rotate_half` + `apply_rotary_pos_emb` (identical to esm/llama). - `ESMCRotaryEmbedding(config)` -> `(cos, sin)`, computed once in `ESMCModel.forward` and threaded down (position_embeddings) through the stack/block to attention, mirroring esm. `inv_freq` is fp32 and non-persistent (matches the old behaviour: no rotary tensors in the checkpoint), and cos/sin are built in fp32 then cast. - Add `config.rope_theta` (default 10000.0, the previously-hardcoded base). - `_init_weights` recomputes `inv_freq` for `ESMCRotaryEmbedding` (meta-init safe). Verified: strict state_dict load from the saved baseline succeeds (rotary buffers are non-persistent, so keys are unchanged -> published checkpoints still load), and last_hidden_state is bit-identical (0.0) at all valid positions for plain, padding-mask, and multi-chain. The fp32 matmul-based freqs equal the old `outer(t, inv_freq)`; same RoPE math, idiomatic shape. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
ESMC now follows the modular convention. modular_esmc.py is the source of truth; modeling_esmc.py is generated by utils/modular_model_converter.py and carries the auto-generated header. Reuse from esm (the natural parent — also a bidirectional protein encoder): `eager_attention_forward`, `rotate_half`, and `apply_rotary_pos_emb` are now imported from ..esm.modeling_esm and inlined into the generated file with `# Copied from` headers (so they stay in sync). `rotate_half` is pulled in transitively as a dependency of `apply_rotary_pos_emb`, matching the qwen3 pattern. Everything else stays ESMC-specific and is defined in the modular file: the SAE-integrated ESMCModel + ForMaskedLM/SequenceClassification/ TokenClassification, the fused-LN MultiHeadAttention, SwiGLU FFN, TransformerStack, ESMCRotaryEmbedding, and the SAE-carrying output dataclasses. As expected for this architecture the dedup is modest; the win is convention compliance + auto-sync of the shared functions. The modular file was ruff-fixed/formatted (Optional[X] -> X | None, import order) before regeneration, so both files are now ruff-clean. Verified: `check_modular_conversion.py` passes (files in sync); `transformers` imports; and loading identical weights reproduces the pre-conversion last_hidden_state bit-for-bit (0.0) at all valid positions for plain, padding-mask, and multi-chain inputs. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…rk on ESMC-6B) Two bugs surfaced only when loading the real biohub/ESMC-6B checkpoint and running an unpadded sequence (prior parity tests all used padded inputs): 1. _init_weights re-initialized loaded nn.Linear weights. It used `module.weight.data.normal_()`, which writes through `.data` and bypasses the `_is_hf_initialized` flag that transformers sets on loaded params. So `from_pretrained` clobbered out_proj / lm_head with random init (silently — no missing-keys warning), while custom/LayerNorm/Embedding modules survived only because _init_weights had no branch for them. Switch to the flag-respecting `transformers.initialization` helpers (init.normal_, init.zeros_, init.copy_), matching esm/base. (This bug is latent in the v4 fork too; it only bites under the v5 init-after-load flow.) 2. Attention defaulted to causal on unpadded inputs. ESMC is a bidirectional encoder, but MultiHeadAttention never set `is_causal`, so the sdpa/flash interface fell back to `getattr(module, "is_causal", True)` and applied causal masking whenever `attention_mask is None`. Set `self.is_causal = False`. (Introduced by the attention-interface refactor; missed because every earlier parity test passed a padded 4D mask.) Source of truth is modular_esmc.py; modeling_esmc.py regenerated. Verified: loading biohub/ESMC-6B (and ESMC-300M) into the refactored model and into the original fork code yields BIT-IDENTICAL logits (max|Δ| = 0.0) and identical argmax predictions on an 80-residue sequence. State-dict loads with no remapping (240 TE `_extra_state` keys ignored as before). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The ESMCTokenizer behaves correctly under v5 (verified against the published biohub/ESMC-6B tokenizer.json), but its docstring doctest example was wrong: it listed 21 ids for a 20-residue sequence with two residues mis-ordered/ dropped. Correct it to the actual output (22 ids, <cls> ... <eos>). Add tests/models/esmc/test_tokenization_esmc.py (fast-only tokenizer, so it mirrors models/esm's plain-TestCase style rather than the slow vocab-file setup): documented example, character-level tokenize, <cls>/<eos> wrapping, special-token ids (incl. bos==cls alias), batch padding + attention_mask, chain-break token, mask token, unknown-residue -> <unk>, decode round-trip, save/load round-trip, and a @slow integration test asserting AutoTokenizer resolves to ESMCTokenizer and the hub tokenizer matches the code-built one. Verified: 10 passed, 1 slow-skipped (the slow test passes with RUN_SLOW=1). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
ESMCTokenizer is fast-only (built from a tokenizer object); it never reads or writes a slow `vocab.txt`. Declare only `tokenizer_file: tokenizer.json`. save_pretrained already emits just tokenizer.json + tokenizer_config.json; tokenizer tests still pass (10 passed, 1 slow-skipped). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Add tests/models/esmc/test_modeling_esmc.py (ModelTesterMixin + PipelineTesterMixin), mirroring models/esm: ESMCModelTester + create_and_check for ESMCModel / ForMaskedLM / ForSequenceClassification / ForTokenClassification, plus a @slow masked-LM integration test on biohub/ESMC-300M. ESMCModel has no pooler, so the base-model check asserts only last_hidden_state. Fix surfaced by test_can_init_all_missing_weights: _init_weights only handled nn.Linear, so the fused-LN modules, nn.LayerNorm and nn.Embedding were left uninitialized on a from-scratch / meta-device init. Handle the custom _PyTorchLayerNormLinear / _PyTorchLayerNormMLP explicitly (their `weight` is a Linear weight, not a norm — the base initializer matches norms by class-name substring and would wrongly set it to ones), the rotary buffer explicitly, and delegate nn.Linear / nn.Embedding / nn.LayerNorm to super()._init_weights. Skip test_retain_grad_hidden_states_attentions: ESMC returns `hidden_states` as one stacked tensor (consumed by the SAE feature), not the live per-layer tensors, so grad cannot flow back to the returned copy — intentional. Verified: full model test file = 92 passed, 102 skipped, 0 failed; the @slow integration test passes; and from_pretrained still loads biohub/ESMC weights bit-exactly (the flag-respecting init helpers don't clobber loaded params). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ion heads
Add docs/source/en/model_doc/esmc.md (short: overview, links to the
biohub/ESMC-{300M,600M,6B} checkpoints and the evolutionaryscale/esm repo, a
usage snippet, and [[autodoc]] for ESMCConfig, ESMCSAEConfig, ESMCTokenizer,
ESMCModel, ESMCFor{MaskedLM,SequenceClassification,TokenClassification},
ESMCSAEModel) and register it in _toctree.yml (alphabetically after ESM).
ESMCForSequenceClassification/ForTokenClassification have no Examples block in
their forward docstrings, so auto_docstring could not find a checkpoint and
errored on import. Add the standard checkpoint sentence with a Hub link
([Biohub/ESMC-600M-2024-12](...)) to the ESMCConfig docstring; auto_docstring
falls back to the config's checkpoint, resolving it for all model classes at
once (no per-class decorator needed).
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The sparse-autoencoder is architecturally distinct from ESMC (a linear encoder + top-k sparsity + linear decoder operating on frozen hidden states, not a transformer) and has a non-standard per-layer from_pretrained/ save_pretrained path. Keep it out of the core ESMC PR; it will land as its own follow-up (recoverable from this commit's parent). Removed: - modeling_esmc_sae.py and configuration_esmc_sae.py. - esmc_sae auto-registration (CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES, and the esmc_sae -> esmc SPECIAL_MODEL_TYPE_TO_MODULE_NAME mapping). - The SAE integration woven through ESMCModel and the heads: add_sae_models, _get_sae_outputs / _get_sae_layer_num_requested / _validate_sae_inputs / _SAE_KEY_RE / _sae_models, the compute_sae/normalize_sae params, and the sae_outputs field on all four output dataclasses. The forward simplifies accordingly (layers_to_collect now only serves output_hidden_states; the SAE-only bool_mask is gone). - ESMCSAEConfig/ESMCSAEModel autodoc from the model doc. Source of truth is modular_esmc.py; modeling_esmc.py regenerated. Verified: ESMCSAE* no longer importable and esmc_sae deregistered; ruff clean; model + tokenizer tests pass (102 passed, 103 skipped); and fork-vs-new logits on biohub/ESMC-300M remain bit-identical (max|Δ| = 0.0) — the removal is output-neutral for the core model. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
`make fix-repo`/check_docstrings normalizations on the ESMC public objects: - ESMCTokenizer: default values use single backticks (`"<unk>"`) per the canonical "*optional*, defaults to ..." format. - ESMCModel / ESMCForMaskedLM / ESMCForSequenceClassification: reorder the forward docstring argument entries to match the signature order (output_attentions before labels) and add the missing blank line before the Examples block. Docstring-only (no logic change); modular and generated modeling stay in sync. ESMC now passes check_repo, check_copies, check_inits, check_dummies, check_modular_conversion, and check_docstrings (remaining repo-wide failures are pre-existing Qwen output-doc errors and Phase-B esmfold2, neither ESMC). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
- Add `__all__ = ["ESMFold2Model"]` to modeling_esmfold2.py so the model is exported (`from transformers import ESMFold2Model` now works; previously the class was defined but the module had no `__all__`). - Replace the hard cross-model import `from ..esmc.modeling_esmc import ESMCModel` in `load_esmc` (both modeling_esmfold2.py and the experimental file) with `AutoModel.from_pretrained(...)`. ESMC is a shared, frozen 6B backbone loaded separately from its own repo (`config.esmc_id`); resolving it through the Auto registry (model_type "esmc" -> ESMCModel) keeps esmc and esmfold2 as separate model directories without a runtime cross-dir import. Verified: ESMFold2Model + ESMFold2Config export; core + experimental modules import; no `..esmc` imports remain in the esmfold2 dir; AutoModel.from_config resolves ESMCConfig -> ESMCModel. (Also removed the empty kernels/ and distributed/ dirs left behind by their earlier git rm.) Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…-only) Scope the core ESMFold2 PR to the release model. The experimental (legacy/dev binder-design) variant is deferred like the ESMC SAE. - Delete modeling_esmfold2_experimental.py and remove it from the package __init__ and from the ESMFold2Model.from_pretrained `config.type == "experimental"` routing (from_pretrained now always builds ESMFold2Model). - ESMFold2Config: keep the `type` field for checkpoint compat but accept only "release"; fix the docstring example to use ESMFold2Model. Update the two stale `ESMFold2ExperimentalModel` references in comments/docstrings. - Remove the `esmfold2_v2 -> esmfold2` SPECIAL_MODEL_TYPE_TO_MODULE_NAME entry (no v2 variant ships). Verified: ESMFold2Model + ESMFold2Config import; ESMFold2ExperimentalModel gone; no experimental/v2 references remain. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
ESMFold2's only TransformerEngine dependency was the optional fp8 quantization of the (now TE-free) ESMC backbone. Drop the import guard, the dead _convert_te_modules_to_fp8_inplace walker, the "fp8" precision option, and the fp8 padding/autocast plumbing. _lm_precision_context is now a plain bf16 autocast; ESMC loads at bf16 (default) or fp32. No other Transformers model depends on transformer_engine. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
SWA3DRoPEAttention's plain softmax(QKᵀ)V core now dispatches through ALL_ATTENTION_FUNCTIONS / a local eager_attention_forward, keyed on config._attn_implementation, with the sliding window expressed as an additive attention mask. The custom flash-attention path (native bidirectional window_size + varlen for packed inputs) is kept as an opt-in backend, now gated on _attn_implementation == "flash_attention_2" instead of auto-selecting whenever flash-attn is importable — so the default is sdpa (matching the fork's SDPA fallback bit-for-bit) and flash is opt-in, per v5 conventions. ESMFold2Model declares _supports_sdpa / _supports_flash_attn / _supports_attention_backend and, after construction, attaches its shared config to every SWA3DRoPEAttention (the atom encoders/decoders build them from explicit dims), so dispatch stays live under set_attn_implementation. is_causal=False guards against the interface defaulting to causal when no mask is passed. Pair-bias (AttentionPairBias) and triangular math are left untouched. Validated on CPU vs the pre-refactor forward (random weights): sdpa max|Δ|=0.0 (bit-exact), eager max|Δ|=1.3e-3 (bf16 softmax precision). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The config was one of ~4 holdouts still using the old `def __init__(**kwargs)`
+ hand-rolled `to_dict` style; 460/463 configs use `@strict` dataclass
`PreTrainedConfig`. Rewrite it to match, mirroring `models/esm` (ESMFold v1):
- Every sub-config (MSAEncoder/Parcae/LMEncoder/AtomAttention/FoldingTrunk/
InputsEmbedder/DiffusionModule/DiffusionStructureHead/ConfidenceHead) is now a
`@strict PreTrainedConfig` with typed fields + defaults, instead of a plain
`@dataclass`.
- Nesting is declared via `sub_configs = {...}` on each parent + a `__post_init__`
that turns `dict -> SubConfig(**dict)`. This deletes the brittle hand-rolled
`ESMFold2Config.to_dict()` (base class now serializes via `sub_configs`) and
gives recursive `_attn_implementation` propagation to sub-configs for free.
- Top config switches to `@auto_docstring(checkpoint="biohub/ESMFold2")`, which
resolves the outstanding `check_docstrings` failure for ESMFold2Config.
- `__all__` is reduced to `["ESMFold2Config"]` (sub-configs are implementation
detail, matching ESM which exports only `EsmConfig`).
check_config_attributes: top config keeps a precise 4-item allow-list (type +
three training/experimental recipe knobs not read by the core inference path);
the 9 sub-configs are allow-listed wholesale because their fields are threaded
into submodules as explicit dims (e.g. `d_atom=cfg.inputs.atom_encoder.d_atom`),
which the checker's `config.<attr>` heuristic cannot trace.
Verified: default + nested-dict construction, save/load round-trip (to_dict
identical, sub-config types preserved), recursive attn-impl propagation, type
validation, unknown-kwarg tolerance; the tiny ESMFold2Model still builds and the
SWA attention equivalence is unchanged. ruff + check_docstrings +
check_config_attributes all clean.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
ESMFold2 is an all-atom structure predictor whose forward takes ~18 structural feature tensors and returns a plain dict (not a ModelOutput), so it doesn't fit ModelTesterMixin. Following the sanctioned pattern for such models, register the test file in check_repo's TEST_FILES_WITH_NO_COMMON_TESTS and provide focused coverage: - ESMFold2ConfigTest: ConfigTester common tests (incl. composite sub-config save/load), `type` validation, nested round-trip, attn-impl propagation. - ESMFold2ModelTest (CPU): full pure-PyTorch forward via infer_protein with no ESMC backbone (LM conditioning skipped) under both sdpa and eager; SWA config dispatch; weight-level save/load fidelity + a usable reloaded model. - ESMFold2IntegrationTest: @slow real-weight fold on biohub/ESMFold2 (GPU-gated). To make ConfigTester's composite test pass, each sub-config now declares a unique `model_type` (e.g. "esmfold2_inputs_embedder") — the CLIP pattern — so that `SubConfig.from_pretrained(<composite dir>)` extracts the matching nested dict (configuration_utils keys this off model_type). ESM dodges this test because its sub-config is None by default; ESMFold2's are always present. check_repo: `modeling_esmfold2_common` (shared building blocks, no public model) is added to get_model_modules' _ignore_modules. The remaining check_repo item for esmfold2 is the model-doc page, which is a separate pending task. The tiny test config encodes two real sizing constraints discovered via the forward smoke: 3D RoPE needs 3*n_spatial + n_uid <= head_dim//2, and inputs.d_inputs == 67 + d_token//2 == diffusion_module.c_s_inputs. 7 non-slow tests pass; ruff + check_config_attributes + check_docstrings clean. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Short overview page (mirrors esmc.md): describes the all-atom structure predictor and its separately-loaded ESMC backbone, a usage snippet (infer_protein_as_pdb), and autodoc for ESMFold2Config + ESMFold2Model. Registered in _toctree.yml after ESMC. Resolves the final check_repo "objects documented" item for esmfold2. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
119142e to
425ba1b
Compare
|
@vasqu writing this in one big comment rather than breaking it up to every individual location, because there was a lot of repeated stuff! I just did a big pass for ESMFold2, now that ESMC is mostly okay. I dropped the flashattention path for now, inlined a bunch of methods, broke RoPE out into a single class, tried to merge as many SwiGLUs as I could (not all were possible, some have little embedded norms). A lot of stuff got renamed and moved to configs instead of just being constants, and I tried to pull out private methods from the big long forward methods. There's a few tricky bits still:
Overall I tried to make it more readable, but it's a very complex model, so there's only so much I can do! 😅 |
|
@Rocketknight1 Thanks a lot! Yes the model is super complex so I don't expect perfect stuff. I think the esmc model should be fairly good now. I will try to take a look at esmfold2 in the coming days. Re
|
|
On 4 I'm unsure - we're not doing modular for ESMFold2 because the code is so weird compared to other models. I'd leave that one for a future refactor, if there are kernels that turn out to really justify it, or if we get a lot of similar folding models in the codebase. 1 is done, though, and I managed to drop some dead args in the process! |
|
[For maintainers] Suggested jobs to run (before merge) run-slow: auto, esmc, esmfold2 |
CI recapDashboard: View test results in Grafana |
This is just about ready at this point! I've deferred the experimental ESMFold2 models and the SAEs to a follow-up PR so we can launch this as soon as possible. It's got a lot of very verbose Claude comments and docstrings too, but those are forgivable and we can always clean them up later 😅
The main thing for reviewers to know is that there are two model classes here, ESMC (a masked protein LM) and ESMFold2 (a protein folding model that uses ESMC as a backbone model). ESMC is mostly just a 'normal' masked LM, but I wasn't able to find exact modular matches for a lot of methods. Still, it should be easy enough to review. ESMFold2 is much more challenging and is not like any other model in the codebase, even the original ESMFold. I have some tests confirming that we get the same output as BioHub's original code, so I think reviewers should probably just trust that for correctness and focus on the API, although you can try going through all of the diffusion code if you want!