initial vit implementation #33
Merged
7 commits
cb97128 initial vit implementation (chapman20j)
793688c Merge branch 'jax-ml:main' into vit (chapman20j)
64d822f Simplify, remove comments, update param loading to avoid cache miss (chapman20j)
18704f7 remove comment (chapman20j)
4d63f20 comments (chapman20j)
fb8e3d6 Update test name (chapman20j)
c488c6e Merge branch 'main' into vit (jenriver)
File: bonsai/models/vit/README.md
| @@ -0,0 +1,23 @@ | ||
| ## ViT in JAX | ||
|
|
||
| This directory contains a pure JAX implementation of the [ViT](https://huggingface.co/google/vit-base-patch16-224) model, using the Flax NNX API. | ||
|
|
||
| ## Tested on | ||
|
|
||
| | Model Name | Config | CPU | GPU A100 (1x) | GPU H100 (1x) | GPU A100 (8x) | GPU H100 (8x) | TPU v2 (8x) | TPU v5e (1x) | | ||
| | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | | ||
| | [ViT](https://huggingface.co/google/vit-base-patch16-224) | ✅ Supported | ✅ Runs | ❔ Needs check | ❔ Needs check | ❔ Needs check | ❔ Needs check | ❔ Needs check | ❔ Needs check | | ||
|
|
||
| ### Running this model | ||
|
|
||
| Run ViT model inference in action: | ||
|
|
||
| ```sh | ||
| python3 -m bonsai.models.vit.tests.run_model | ||
| ``` | ||
|
|
||
| ## How to contribute to this model | ||
|
|
||
| We welcome contributions! You can contribute to this model via the following: | ||
| * Add a model config variant from the above `🟡 Not started` entries to `class ModelConfig` in [modeling.py](modeling.py). Make sure your code runs on at least one hardware configuration before creating a PR. | ||
| * Got some hardware? Run [run_model.py](tests/run_model.py) with the existing configs above on hardware marked `❔ Needs check`, then mark it as `✅ Runs` or `⛔️ Not supported`. |
File: bonsai/models/vit/modeling.py
| @@ -0,0 +1,99 @@ | ||
| from functools import partial | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
| from flax import nnx | ||
|
|
||
|
|
||
| class Embeddings(nnx.Module): | ||
| def __init__( | ||
| self, | ||
| image_size: tuple[int, int], | ||
| patch_size: tuple[int, int], | ||
| num_channels: int, | ||
| hidden_dim: int, | ||
| dropout_prob: float, | ||
| *, | ||
| rngs: nnx.Rngs, | ||
| ): | ||
| num_patches = (image_size[0] // patch_size[0]) * (image_size[1] // patch_size[1]) | ||
|
|
||
| self.projection = nnx.Conv(num_channels, hidden_dim, kernel_size=patch_size, strides=patch_size, rngs=rngs) | ||
| self.cls_token = nnx.Variable(jax.random.normal(rngs.params(), (1, 1, hidden_dim))) | ||
| self.pos_embeddings = nnx.Variable(jax.random.normal(rngs.params(), (1, num_patches + 1, hidden_dim))) | ||
| self.dropout = nnx.Dropout(dropout_prob, rngs=rngs) | ||
|
|
||
| def __call__(self, pixel_values: jnp.ndarray, deterministic: bool = False) -> jnp.ndarray: | ||
| embeddings = self.projection(pixel_values) | ||
| b, h, w, c = embeddings.shape | ||
| embeddings = embeddings.reshape(b, h * w, c) | ||
| cls_tokens = jnp.tile(self.cls_token.value, (b, 1, 1)) | ||
| embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1) | ||
| embeddings = embeddings + self.pos_embeddings.value | ||
| embeddings = self.dropout(embeddings, deterministic=deterministic) | ||
| return embeddings | ||
|
|
||
|
|
||
| # Transformer encoder block | ||
| class TransformerEncoder(nnx.Module): | ||
| def __init__(self, num_heads: int, attn_dim: int, mlp_dim: int, dropout_prob: float, eps: float, *, rngs: nnx.Rngs): | ||
| self.attention = nnx.MultiHeadAttention(num_heads=num_heads, in_features=attn_dim, decode=False, rngs=rngs) | ||
| self.linear1 = nnx.Linear(attn_dim, mlp_dim, rngs=rngs) | ||
| self.linear2 = nnx.Linear(mlp_dim, attn_dim, rngs=rngs) | ||
| self.dropout = nnx.Dropout(dropout_prob, rngs=rngs) | ||
| self.layernorm_before = nnx.LayerNorm(attn_dim, epsilon=eps, rngs=rngs) | ||
| self.layernorm_after = nnx.LayerNorm(attn_dim, epsilon=eps, rngs=rngs) | ||
|
|
||
| def __call__(self, hidden_states, head_mask=None, deterministic: bool = False): | ||
| hidden_states_norm = self.layernorm_before(hidden_states) | ||
| attention_output = self.attention(hidden_states_norm, head_mask, deterministic=deterministic) | ||
| hidden_states = attention_output + hidden_states | ||
| layer_output = self.layernorm_after(hidden_states) | ||
| layer_output = jax.nn.gelu(self.linear1(layer_output)) | ||
| layer_output = self.linear2(layer_output) | ||
| layer_output = self.dropout(layer_output, deterministic=deterministic) | ||
| layer_output += hidden_states | ||
| return layer_output | ||
|
|
||
|
|
||
| class ViTClassificationModel(nnx.Module): | ||
| def __init__( | ||
| self, | ||
| image_size: tuple[int, int], | ||
| patch_size: tuple[int, int], | ||
| num_channels: int, | ||
| hidden_dim: int, | ||
| dropout_prob: float, | ||
| num_heads: int, | ||
| mlp_dim: int, | ||
| eps: float, | ||
| num_layers: int, | ||
| num_labels: int, | ||
| *, | ||
| rngs: nnx.Rngs, | ||
| ): | ||
| self.pos_embeddings = Embeddings(image_size, patch_size, num_channels, hidden_dim, dropout_prob, rngs=rngs) | ||
| self.layers = nnx.Sequential( | ||
| *[ | ||
| TransformerEncoder(num_heads, hidden_dim, mlp_dim, dropout_prob, eps, rngs=rngs) | ||
| for _ in range(num_layers) | ||
| ] | ||
| ) | ||
| self.ln = nnx.LayerNorm(hidden_dim, epsilon=eps, rngs=rngs) | ||
| self.classifier = nnx.Linear(hidden_dim, num_labels, rngs=rngs) | ||
|
|
||
| def __call__(self, x): | ||
| x = self.pos_embeddings(x) | ||
| x = self.layers(x) | ||
| x = self.ln(x) | ||
| x = self.classifier(x[:, 0, :]) | ||
| return x | ||
|
|
||
|
|
||
| def ViT(*, rngs: nnx.Rngs): | ||
| return ViTClassificationModel((224, 224), (16, 16), 3, 768, 0.0, 12, 3072, 1e-12, 12, 1000, rngs=rngs) | ||
|
|
||
|
|
||
| @partial(jax.jit, static_argnames=["model"]) | ||
| def forward(model, x): | ||
| return model(x) |
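For reviewers who want to exercise the module above outside the test script, here is a minimal usage sketch. It is not part of the diff; it assumes the `bonsai.models.vit` package layout introduced in this PR and uses a randomly initialized model, so the output is only good for shape checks.

```python
# Illustrative usage of modeling.py above (not from the PR).
import jax.numpy as jnp
from flax import nnx

from bonsai.models.vit import modeling as model_lib

model = model_lib.ViT(rngs=nnx.Rngs(0))           # ViT-Base/16: 12 layers, hidden size 768, 1000 classes
images = jnp.ones((2, 224, 224, 3), jnp.float32)  # NHWC layout, matching the dummy batch in tests/run_model.py
logits = model_lib.forward(model, images)         # jitted forward pass defined at the bottom of the file
print(logits.shape)                               # (2, 1000) class logits taken from the CLS token position
```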
File: bonsai/models/vit/params.py
| @@ -0,0 +1,156 @@ | ||
| # Copyright 2025 The JAX Authors. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import logging | ||
| import re | ||
|
|
||
| import jax | ||
| import safetensors.flax as safetensors | ||
| from flax import nnx | ||
|
|
||
| from bonsai.models.vit import modeling as model_lib | ||
|
|
||
|
|
||
| def _get_key_and_transform_mapping(): | ||
| # Mapping st_keys -> (nnx_keys, (permute_rule, reshape_rule)). | ||
| embed_dim, num_heads, head_dim = 768, 12, 64 | ||
| return { | ||
| # classifier | ||
| r"^classifier.bias$": (r"classifier.bias", None), | ||
| r"^classifier.weight$": (r"classifier.kernel", ((1, 0), None)), | ||
| # embeddings | ||
| r"^vit.embeddings.cls_token$": (r"pos_embeddings.cls_token", None), | ||
| r"^vit.embeddings.patch_embeddings.projection.bias$": (r"pos_embeddings.projection.bias", None), | ||
| r"^vit.embeddings.patch_embeddings.projection.weight$": ( | ||
| r"pos_embeddings.projection.kernel", | ||
| ((2, 3, 1, 0), None), | ||
| ), | ||
| r"^vit.embeddings.position_embeddings$": (r"pos_embeddings.pos_embeddings", None), | ||
| # layers | ||
| r"^vit.encoder.layer.([0-9]+).attention.attention.key.bias$": ( | ||
| r"layers.layers.\1.attention.key.bias", | ||
| (None, (num_heads, head_dim)), | ||
| ), | ||
| r"^vit.encoder.layer.([0-9]+).attention.attention.key.weight$": ( | ||
| r"layers.layers.\1.attention.key.kernel", | ||
| ((1, 0), (embed_dim, num_heads, head_dim)), | ||
| ), | ||
| r"^vit.encoder.layer.([0-9]+).attention.attention.query.bias$": ( | ||
| r"layers.layers.\1.attention.query.bias", | ||
| (None, (num_heads, head_dim)), | ||
| ), | ||
| r"^vit.encoder.layer.([0-9]+).attention.attention.query.weight$": ( | ||
| r"layers.layers.\1.attention.query.kernel", | ||
| ((1, 0), (embed_dim, num_heads, head_dim)), | ||
| ), | ||
| r"^vit.encoder.layer.([0-9]+).attention.attention.value.bias$": ( | ||
| r"layers.layers.\1.attention.value.bias", | ||
| (None, (num_heads, head_dim)), | ||
| ), | ||
| r"^vit.encoder.layer.([0-9]+).attention.attention.value.weight$": ( | ||
| r"layers.layers.\1.attention.value.kernel", | ||
| ((1, 0), (embed_dim, num_heads, head_dim)), | ||
| ), | ||
| r"^vit.encoder.layer.([0-9]+).attention.output.dense.bias$": (r"layers.layers.\1.attention.out.bias", None), | ||
| r"^vit.encoder.layer.([0-9]+).attention.output.dense.weight$": ( | ||
| r"layers.layers.\1.attention.out.kernel", | ||
| ((1, 0), (num_heads, head_dim, embed_dim)), | ||
| ), | ||
| r"^vit.encoder.layer.([0-9]+).intermediate.dense.bias$": (r"layers.layers.\1.linear1.bias", None), | ||
| r"^vit.encoder.layer.([0-9]+).intermediate.dense.weight$": (r"layers.layers.\1.linear1.kernel", ((1, 0), None)), | ||
| r"^vit.encoder.layer.([0-9]+).layernorm_after.bias$": (r"layers.layers.\1.layernorm_after.bias", None), | ||
| r"^vit.encoder.layer.([0-9]+).layernorm_after.weight$": (r"layers.layers.\1.layernorm_after.scale", None), | ||
| r"^vit.encoder.layer.([0-9]+).layernorm_before.bias$": (r"layers.layers.\1.layernorm_before.bias", None), | ||
| r"^vit.encoder.layer.([0-9]+).layernorm_before.weight$": (r"layers.layers.\1.layernorm_before.scale", None), | ||
| r"^vit.encoder.layer.([0-9]+).output.dense.bias$": (r"layers.layers.\1.linear2.bias", None), | ||
| r"^vit.encoder.layer.([0-9]+).output.dense.weight$": (r"layers.layers.\1.linear2.kernel", ((1, 0), None)), | ||
| # layernorm | ||
| r"^vit.layernorm.bias$": (r"ln.bias", None), | ||
| r"^vit.layernorm.weight$": (r"ln.scale", None), | ||
| } | ||
|
|
||
|
|
||
| def _st_key_to_jax_key(mapping, source_key): | ||
| """Map a safetensors key to exactly one JAX key & transform, else warn/error.""" | ||
| subs = [ | ||
| (re.sub(pat, repl, source_key), transform) | ||
| for pat, (repl, transform) in mapping.items() | ||
| if re.match(pat, source_key) | ||
| ] | ||
| if not subs: | ||
| logging.warning(f"No mapping found for key: {source_key!r}") | ||
| return None, None | ||
| if len(subs) > 1: | ||
| keys = [s for s, _ in subs] | ||
| raise ValueError(f"Multiple mappings found for {source_key!r}: {keys}") | ||
| return subs[0] | ||
|
|
||
|
|
||
| def _assign_weights(keys, tensor, state_dict, st_key, transform): | ||
| """Recursively descend into state_dict and assign the (possibly permuted/reshaped) tensor.""" | ||
| key, *rest = keys | ||
| if not rest: | ||
| if transform is not None: | ||
| permute, reshape = transform | ||
| if permute: | ||
| tensor = tensor.transpose(permute) | ||
| if reshape: | ||
| tensor = tensor.reshape(reshape) | ||
| if tensor.shape != state_dict[key].shape: | ||
| raise ValueError(f"Shape mismatch for {st_key}: {tensor.shape} vs {state_dict[key].shape}") | ||
| state_dict[key] = tensor | ||
| else: | ||
| _assign_weights(rest, tensor, state_dict[key], st_key, transform) | ||
|
|
||
|
|
||
| def _stoi(s): | ||
| try: | ||
| return int(s) | ||
| except ValueError: | ||
| return s | ||
|
|
||
|
|
||
| def create_vit_from_pretrained( | ||
| file_path: str, | ||
| num_classes: int = 1000, | ||
| *, | ||
| mesh: jax.sharding.Mesh | None = None, | ||
| ): | ||
| """ | ||
| Load safetensor weights from a file, then convert & merge into a flax.nnx ViT model. | ||
|
|
||
| Returns: | ||
| A flax.nnx ViTClassificationModel with the pretrained parameters loaded. | ||
| """ | ||
| state_dict = safetensors.load_file(file_path) | ||
|
|
||
| vit = nnx.eval_shape(lambda: model_lib.ViT(rngs=nnx.Rngs(0))) | ||
| graph_def, abs_state = nnx.split(vit) | ||
| jax_state = abs_state.to_pure_dict() | ||
|
|
||
| mapping = _get_key_and_transform_mapping() | ||
| for st_key, tensor in state_dict.items(): | ||
| jax_key, transform = _st_key_to_jax_key(mapping, st_key) | ||
| if jax_key is None: | ||
| continue | ||
| keys = [_stoi(k) for k in jax_key.split(".")] | ||
| _assign_weights(keys, tensor, jax_state, st_key, transform) | ||
|
|
||
| # if mesh is not None: | ||
| # sharding = nnx.get_named_sharding(abs_state, mesh).to_pure_dict() | ||
| # jax_state = jax.device_put(jax_state, sharding) | ||
| # else: | ||
| # jax_state = jax.device_put(jax_state, jax.devices()[0]) | ||
|
|
||
| return nnx.merge(graph_def, jax_state) |
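To make the `(permute_rule, reshape_rule)` entries in `_get_key_and_transform_mapping` concrete, the sketch below walks one attention weight through the same transform the loader applies. It is illustrative only and not part of the diff; the `(768, 768)` starting shape follows the usual Hugging Face / PyTorch `(out_features, in_features)` convention for `nn.Linear` weights.

```python
# Illustrative only: what rule ((1, 0), (embed_dim, num_heads, head_dim)) does to
# vit.encoder.layer.0.attention.attention.query.weight before assignment.
import jax.numpy as jnp

embed_dim, num_heads, head_dim = 768, 12, 64

hf_query_weight = jnp.zeros((embed_dim, embed_dim))        # (out_features, in_features) in the checkpoint
kernel = hf_query_weight.transpose((1, 0))                 # permute -> (in, out), the Flax kernel convention
kernel = kernel.reshape((embed_dim, num_heads, head_dim))  # reshape -> per-head kernel for nnx.MultiHeadAttention
print(kernel.shape)                                        # (768, 12, 64), i.e. layers.layers.0.attention.query.kernel
```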
File: bonsai/models/vit/tests/run_model.py
| @@ -0,0 +1,71 @@ | ||
| # Copyright 2025 The JAX Authors. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import os | ||
| import time | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
| from huggingface_hub import snapshot_download | ||
|
|
||
| from bonsai.models.vit import modeling as model_lib | ||
| from bonsai.models.vit import params | ||
|
|
||
|
|
||
| def run_model(MODEL_CP_PATH=None): | ||
| # 1. Download safetensors file | ||
| model_name = "google/vit-base-patch16-224" | ||
| if MODEL_CP_PATH is None: | ||
| MODEL_CP_PATH = "/tmp/models-bonsai/" + model_name.split("/")[1] | ||
|
|
||
| if not os.path.isdir(MODEL_CP_PATH): | ||
| snapshot_download(model_name, local_dir=MODEL_CP_PATH) | ||
|
|
||
| safetensors_path = os.path.join(MODEL_CP_PATH, "model.safetensors") | ||
|
|
||
| # 2. Load pretrained model | ||
| model = params.create_vit_from_pretrained(safetensors_path) | ||
|
|
||
| # 3. Prepare dummy input | ||
| batch_size, channels, image_size = 8, 3, 224 | ||
| dummy_input = jnp.ones((batch_size, image_size, image_size, channels), dtype=jnp.float32) | ||
|
|
||
| # 4. Warmup + profiling | ||
|
||
| # Warmup (triggers compilation) | ||
| _ = model_lib.forward(model, dummy_input) | ||
| jax.block_until_ready(_) | ||
|
|
||
| # Profile a few steps | ||
| jax.profiler.start_trace("/tmp/profile-vit") | ||
| for _ in range(5): | ||
| logits = model_lib.forward(model, dummy_input) | ||
| jax.block_until_ready(logits) | ||
| jax.profiler.stop_trace() | ||
|
|
||
| # 5. Timed execution | ||
| t0 = time.perf_counter() | ||
| for _ in range(10): | ||
| logits = model_lib.forward(model, dummy_input) | ||
| jax.block_until_ready(logits) | ||
| print(f"10 runs took {time.perf_counter() - t0:.4f} s") | ||
|
||
|
|
||
| # 6. Show top-1 predicted class | ||
| pred = jnp.argmax(logits, axis=-1) | ||
| print("Predicted classes:", pred) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| run_model() | ||
|
|
||
| __all__ = ["run_model"] | ||
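If human-readable predictions are wanted after `run_model()` finishes, a small follow-up helper like the one below could be used. It is hypothetical and not part of the PR; it assumes the checkpoint's `config.json`, which `snapshot_download` places next to `model.safetensors`, contains the usual `id2label` mapping for the 1000 ImageNet classes.

```python
# Hypothetical helper, not in the PR: turn argmax indices into label strings.
import json
import os

import jax.numpy as jnp


def decode_predictions(pred: jnp.ndarray, model_cp_path: str = "/tmp/models-bonsai/vit-base-patch16-224"):
    """Look up label names for predicted class indices using the checkpoint's config.json."""
    with open(os.path.join(model_cp_path, "config.json")) as f:
        id2label = json.load(f)["id2label"]  # JSON keys are strings: "0", "1", ...
    return [id2label[str(int(i))] for i in pred]
```

For example, `decode_predictions(jnp.argmax(logits, axis=-1))` after the timed loop would return class names instead of raw indices.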