diff --git a/bonsai/models/vit/README.md b/bonsai/models/vit/README.md
new file mode 100644
index 0000000..8193f0f
--- /dev/null
+++ b/bonsai/models/vit/README.md
@@ -0,0 +1,23 @@
+## ViT in JAX
+
+This directory contains a pure JAX implementation of the [ViT](https://huggingface.co/google/vit-base-patch16-224) model, using the Flax NNX API.
+
+## Tested on
+
+| Model Name | Config | CPU | GPU A100 (1x) | GPU H100 (1x) | GPU A100 (8x) | GPU H100 (8x) | TPU v2 (8x) | TPU v5e (1x) |
+| :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- |
+| [ViT](https://huggingface.co/google/vit-base-patch16-224) | ✅ Supported | ✅ Runs | ❔ Needs check | ❔ Needs check | ❔ Needs check | ❔ Needs check | ❔ Needs check | ❔ Needs check |
+
+## Running this model
+
+Run ViT model inference in action:
+
+```sh
+python3 -m bonsai.models.vit.tests.run_model
+```
+
+## How to contribute to this model
+
+We welcome contributions! You can contribute to this model in the following ways:
+* Add a new model config variant to [modeling.py](modeling.py). Make sure your code runs on at least one hardware configuration before creating a PR.
+* Got some hardware? Run [run_model.py](tests/run_model.py) with the existing configs on hardware marked `❔ Needs check`, then update the table above to `✅ Runs` or `⛔️ Not supported`.
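
For a Python-level counterpart to the shell command above, here is a minimal sketch (not part of this diff) that loads the pretrained checkpoint with the `params` helper introduced later in this PR and calls the jitted `forward` from `modeling.py`. The 1000-class head and the `snapshot_download` call mirror the defaults used in `tests/run_model.py`:

```python
# Minimal inference sketch using the APIs introduced in this PR
# (assumes the bonsai package is importable and the checkpoint download succeeds).
import jax.numpy as jnp
from flax import nnx
from huggingface_hub import snapshot_download

from bonsai.models.vit import modeling as model_lib
from bonsai.models.vit import params

ckpt_dir = snapshot_download("google/vit-base-patch16-224")            # directory with *.safetensors
model = params.create_vit_from_pretrained(ckpt_dir, num_classes=1000)

graphdef, state = nnx.split(model)                                     # split so the jitted helper can re-merge
images = jnp.ones((2, 224, 224, 3), dtype=jnp.float32)                 # NHWC input, as the model expects
logits = model_lib.forward(graphdef, state, images)                    # shape (2, 1000)
print(jnp.argmax(logits, axis=-1))
```
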
diff --git a/bonsai/models/vit/modeling.py b/bonsai/models/vit/modeling.py
new file mode 100644
index 0000000..49509c4
--- /dev/null
+++ b/bonsai/models/vit/modeling.py
@@ -0,0 +1,97 @@
+import jax
+import jax.numpy as jnp
+from flax import nnx
+
+
+class Embeddings(nnx.Module):
+    def __init__(
+        self,
+        image_size: tuple[int, int],
+        patch_size: tuple[int, int],
+        num_channels: int,
+        hidden_dim: int,
+        dropout_prob: float,
+        *,
+        rngs: nnx.Rngs,
+    ):
+        num_patches = (image_size[0] // patch_size[0]) * (image_size[1] // patch_size[1])
+
+        self.projection = nnx.Conv(num_channels, hidden_dim, kernel_size=patch_size, strides=patch_size, rngs=rngs)
+        self.cls_token = nnx.Variable(jax.random.normal(rngs.params(), (1, 1, hidden_dim)))
+        self.pos_embeddings = nnx.Variable(jax.random.normal(rngs.params(), (1, num_patches + 1, hidden_dim)))
+        self.dropout = nnx.Dropout(dropout_prob, rngs=rngs)
+
+    def __call__(self, pixel_values: jnp.ndarray) -> jnp.ndarray:
+        embeddings = self.projection(pixel_values)
+        b, h, w, c = embeddings.shape
+        embeddings = embeddings.reshape(b, h * w, c)
+        cls_tokens = jnp.tile(self.cls_token.value, (b, 1, 1))
+        embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1)
+        embeddings = embeddings + self.pos_embeddings.value
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class TransformerEncoder(nnx.Module):
+    def __init__(self, num_heads: int, attn_dim: int, mlp_dim: int, dropout_prob: float, eps: float, *, rngs: nnx.Rngs):
+        self.attention = nnx.MultiHeadAttention(num_heads=num_heads, in_features=attn_dim, decode=False, rngs=rngs)
+        self.linear1 = nnx.Linear(attn_dim, mlp_dim, rngs=rngs)
+        self.linear2 = nnx.Linear(mlp_dim, attn_dim, rngs=rngs)
+        self.dropout = nnx.Dropout(dropout_prob, rngs=rngs)
+        self.layernorm_before = nnx.LayerNorm(attn_dim, epsilon=eps, rngs=rngs)
+        self.layernorm_after = nnx.LayerNorm(attn_dim, epsilon=eps, rngs=rngs)
+
+    def __call__(self, hidden_states, head_mask=None):
+        hidden_states_norm = self.layernorm_before(hidden_states)
+        attention_output = self.attention(hidden_states_norm, mask=head_mask)
+        hidden_states = attention_output + hidden_states
+        layer_output = self.layernorm_after(hidden_states)
+        layer_output = jax.nn.gelu(self.linear1(layer_output))
+        layer_output = self.linear2(layer_output)
+        layer_output = self.dropout(layer_output)
+        layer_output += hidden_states
+        return layer_output
+
+
+class ViTClassificationModel(nnx.Module):
+    def __init__(
+        self,
+        image_size: tuple[int, int],
+        patch_size: tuple[int, int],
+        num_channels: int,
+        hidden_dim: int,
+        dropout_prob: float,
+        num_heads: int,
+        mlp_dim: int,
+        eps: float,
+        num_layers: int,
+        num_labels: int,
+        *,
+        rngs: nnx.Rngs,
+    ):
+        self.pos_embeddings = Embeddings(image_size, patch_size, num_channels, hidden_dim, dropout_prob, rngs=rngs)
+        self.layers = nnx.Sequential(
+            *[
+                TransformerEncoder(num_heads, hidden_dim, mlp_dim, dropout_prob, eps, rngs=rngs)
+                for _ in range(num_layers)
+            ]
+        )
+        self.ln = nnx.LayerNorm(hidden_dim, epsilon=eps, rngs=rngs)
+        self.classifier = nnx.Linear(hidden_dim, num_labels, rngs=rngs)
+
+    def __call__(self, x):
+        x = self.pos_embeddings(x)
+        x = self.layers(x)
+        x = self.ln(x)
+        x = self.classifier(x[:, 0, :])
+        return x
+
+
+def ViT(num_classes: int, *, rngs: nnx.Rngs):
+    return ViTClassificationModel((224, 224), (16, 16), 3, 768, 0.0, 12, 3072, 1e-12, 12, num_classes, rngs=rngs)
+
+
+@jax.jit
+def forward(graphdef: nnx.GraphDef[nnx.Module], state: nnx.State, x: jax.Array) -> jax.Array:
+    model = nnx.merge(graphdef, state)
+    return model(x)
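
As a quick sanity check on the shapes flowing through `Embeddings` above: a 224×224 image cut into 16×16 patches gives 14×14 = 196 patch tokens, plus one class token, for a sequence of 197 vectors of width 768. A standalone sketch (not part of the diff):

```python
# Shape walk-through for the patch-embedding path in modeling.py.
import jax.numpy as jnp
from flax import nnx

from bonsai.models.vit import modeling as model_lib

emb = model_lib.Embeddings(
    image_size=(224, 224),
    patch_size=(16, 16),
    num_channels=3,
    hidden_dim=768,
    dropout_prob=0.0,
    rngs=nnx.Rngs(0),
)
tokens = emb(jnp.ones((1, 224, 224, 3)))  # Conv patchify -> flatten -> prepend CLS -> add position embeddings
assert tokens.shape == (1, 197, 768)      # 14 * 14 patches + 1 class token, hidden size 768
```
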
diff --git a/bonsai/models/vit/params.py b/bonsai/models/vit/params.py
new file mode 100644
index 0000000..ffd4f80
--- /dev/null
+++ b/bonsai/models/vit/params.py
@@ -0,0 +1,153 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+import jax
+import safetensors.flax as safetensors
+from etils import epath
+from flax import nnx
+
+from bonsai.models.vit import modeling as model_lib
+
+
+def _get_key_and_transform_mapping():
+    # Mapping st_keys -> (nnx_keys, (permute_rule, reshape_rule)).
+    embed_dim, num_heads, head_dim = 768, 12, 64
+    return {
+        r"^classifier.bias$": (r"classifier.bias", None),
+        r"^classifier.weight$": (r"classifier.kernel", ((1, 0), None)),
+        r"^vit.embeddings.cls_token$": (r"pos_embeddings.cls_token", None),
+        r"^vit.embeddings.patch_embeddings.projection.bias$": (r"pos_embeddings.projection.bias", None),
+        r"^vit.embeddings.patch_embeddings.projection.weight$": (
+            r"pos_embeddings.projection.kernel",
+            ((2, 3, 1, 0), None),
+        ),
+        r"^vit.embeddings.position_embeddings$": (r"pos_embeddings.pos_embeddings", None),
+        r"^vit.encoder.layer.([0-9]+).attention.attention.key.bias$": (
+            r"layers.layers.\1.attention.key.bias",
+            (None, (num_heads, head_dim)),
+        ),
+        r"^vit.encoder.layer.([0-9]+).attention.attention.key.weight$": (
+            r"layers.layers.\1.attention.key.kernel",
+            ((1, 0), (embed_dim, num_heads, head_dim)),
+        ),
+        r"^vit.encoder.layer.([0-9]+).attention.attention.query.bias$": (
+            r"layers.layers.\1.attention.query.bias",
+            (None, (num_heads, head_dim)),
+        ),
+        r"^vit.encoder.layer.([0-9]+).attention.attention.query.weight$": (
+            r"layers.layers.\1.attention.query.kernel",
+            ((1, 0), (embed_dim, num_heads, head_dim)),
+        ),
+        r"^vit.encoder.layer.([0-9]+).attention.attention.value.bias$": (
+            r"layers.layers.\1.attention.value.bias",
+            (None, (num_heads, head_dim)),
+        ),
+        r"^vit.encoder.layer.([0-9]+).attention.attention.value.weight$": (
+            r"layers.layers.\1.attention.value.kernel",
+            ((1, 0), (embed_dim, num_heads, head_dim)),
+        ),
+        r"^vit.encoder.layer.([0-9]+).attention.output.dense.bias$": (r"layers.layers.\1.attention.out.bias", None),
+        r"^vit.encoder.layer.([0-9]+).attention.output.dense.weight$": (
+            r"layers.layers.\1.attention.out.kernel",
+            ((1, 0), (num_heads, head_dim, embed_dim)),
+        ),
+        r"^vit.encoder.layer.([0-9]+).intermediate.dense.bias$": (r"layers.layers.\1.linear1.bias", None),
+        r"^vit.encoder.layer.([0-9]+).intermediate.dense.weight$": (r"layers.layers.\1.linear1.kernel", ((1, 0), None)),
+        r"^vit.encoder.layer.([0-9]+).layernorm_after.bias$": (r"layers.layers.\1.layernorm_after.bias", None),
+        r"^vit.encoder.layer.([0-9]+).layernorm_after.weight$": (r"layers.layers.\1.layernorm_after.scale", None),
+        r"^vit.encoder.layer.([0-9]+).layernorm_before.bias$": (r"layers.layers.\1.layernorm_before.bias", None),
+        r"^vit.encoder.layer.([0-9]+).layernorm_before.weight$": (r"layers.layers.\1.layernorm_before.scale", None),
+        r"^vit.encoder.layer.([0-9]+).output.dense.bias$": (r"layers.layers.\1.linear2.bias", None),
+        r"^vit.encoder.layer.([0-9]+).output.dense.weight$": (r"layers.layers.\1.linear2.kernel", ((1, 0), None)),
+        r"^vit.layernorm.bias$": (r"ln.bias", None),
+        r"^vit.layernorm.weight$": (r"ln.scale", None),
+    }
+
+
+def _st_key_to_jax_key(mapping, source_key):
+    """Map a safetensors key to exactly one JAX key & transform, else warn/error."""
+    subs = [
+        (re.sub(pat, repl, source_key), transform)
+        for pat, (repl, transform) in mapping.items()
+        if re.match(pat, source_key)
+    ]
+    if not subs:
+        logging.warning(f"No mapping found for key: {source_key!r}")
+        return None, None
+    if len(subs) > 1:
+        keys = [s for s, _ in subs]
+        raise ValueError(f"Multiple mappings found for {source_key!r}: {keys}")
+    return subs[0]
+
+
+def _assign_weights(keys, tensor, state_dict, st_key, transform):
+    """Recursively descend into state_dict and assign the (possibly permuted/reshaped) tensor."""
+    key, *rest = keys
+    if not rest:
+        if transform is not None:
+            permute, reshape = transform
+            if permute:
+                tensor = tensor.transpose(permute)
+            if reshape:
+                tensor = tensor.reshape(reshape)
+        if tensor.shape != state_dict[key].shape:
+            raise ValueError(f"Shape mismatch for {st_key}: {tensor.shape} vs {state_dict[key].shape}")
+        state_dict[key] = tensor
+    else:
+        _assign_weights(rest, tensor, state_dict[key], st_key, transform)
+
+
+def _stoi(s):
+    try:
+        return int(s)
+    except ValueError:
+        return s
+
+
+def create_vit_from_pretrained(
+    file_dir: str,
+    num_classes: int = 1000,
+    *,
+    mesh: jax.sharding.Mesh | None = None,
+):
+    """
+    Load safetensors weights from a directory, then convert & merge them into a flax.nnx ViT model.
+
+    Returns:
+        A ViTClassificationModel (flax.nnx.Module) with the pretrained parameters loaded.
+    """
+    files = list(epath.Path(file_dir).expanduser().glob("*.safetensors"))
+    if not files:
+        raise ValueError(f"No safetensors found in {file_dir}")
+
+    tensor_dict = {}
+    for f in files:
+        tensor_dict |= safetensors.load_file(f)
+
+    vit = model_lib.ViT(num_classes=num_classes, rngs=nnx.Rngs(0))
+    graph_def, abs_state = nnx.split(vit)
+    jax_state = abs_state.to_pure_dict()
+
+    mapping = _get_key_and_transform_mapping()
+    for st_key, tensor in tensor_dict.items():
+        jax_key, transform = _st_key_to_jax_key(mapping, st_key)
+        if jax_key is None:
+            continue
+        keys = [_stoi(k) for k in jax_key.split(".")]
+        _assign_weights(keys, tensor, jax_state, st_key, transform)
+
+    return nnx.merge(graph_def, jax_state)
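
To make the `(permute_rule, reshape_rule)` convention in `_get_key_and_transform_mapping` concrete: a Hugging Face attention `query.weight` is stored as a `(768, 768)` `(out_features, in_features)` matrix, while `nnx.MultiHeadAttention` keeps its query kernel as `(embed_dim, num_heads, head_dim) = (768, 12, 64)`. The sketch below (random stand-in values, not real checkpoint weights) applies the same transform the loader uses:

```python
# Standalone illustration of one mapping entry in params.py:
#   ".query.weight" -> ".attention.query.kernel" with transform ((1, 0), (768, 12, 64)).
import jax
import jax.numpy as jnp

embed_dim, num_heads, head_dim = 768, 12, 64

hf_query_weight = jax.random.normal(jax.random.key(0), (embed_dim, embed_dim))  # HF layout: (out, in)

kernel = hf_query_weight.transpose((1, 0))               # permute_rule (1, 0): now (in_features, out_features)
kernel = kernel.reshape(embed_dim, num_heads, head_dim)  # reshape_rule: split heads -> (768, 12, 64)
assert kernel.shape == (embed_dim, num_heads, head_dim)

hf_query_bias = jnp.zeros((embed_dim,))                  # biases only need the reshape_rule: (768,) -> (12, 64)
assert hf_query_bias.reshape(num_heads, head_dim).shape == (num_heads, head_dim)
```
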
diff --git a/bonsai/models/vit/tests/run_model.py b/bonsai/models/vit/tests/run_model.py
new file mode 100644
index 0000000..210df3d
--- /dev/null
+++ b/bonsai/models/vit/tests/run_model.py
@@ -0,0 +1,65 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+
+import jax
+import jax.numpy as jnp
+from flax import nnx
+from huggingface_hub import snapshot_download
+
+from bonsai.models.vit import modeling as model_lib
+from bonsai.models.vit import params
+
+__all__ = ["run_model"]
+
+
+def run_model():
+    # Download safetensors file
+    model_name = "google/vit-base-patch16-224"
+    model_ckpt_path = snapshot_download(model_name)
+
+    # Load pretrained model
+    model = params.create_vit_from_pretrained(model_ckpt_path)
+    graphdef, state = nnx.split(model)
+
+    # Prepare dummy input
+    batch_size, channels, image_size = 8, 3, 224
+    dummy_input = jnp.ones((batch_size, image_size, image_size, channels), dtype=jnp.float32)
+
+    # Warmup (triggers compilation)
+    _ = model_lib.forward(graphdef, state, dummy_input).block_until_ready()
+
+    # Profile a few steps
+    jax.profiler.start_trace("/tmp/profile-vit")
+    for _ in range(5):
+        logits = model_lib.forward(graphdef, state, dummy_input)
+        jax.block_until_ready(logits)
+    jax.profiler.stop_trace()
+
+    # Timed execution
+    t0 = time.perf_counter()
+    for _ in range(10):
+        logits = model_lib.forward(graphdef, state, dummy_input).block_until_ready()
+    print(f"Step time: {(time.perf_counter() - t0) / 10:.4f} s")
+
+    # Show top-1 predicted class
+    pred = jnp.argmax(logits, axis=-1)
+    print("Predicted classes:", pred)
+
+
+if __name__ == "__main__":
+    run_model()
diff --git a/bonsai/models/vit/tests/test_outputs.py b/bonsai/models/vit/tests/test_outputs.py
new file mode 100644
index 0000000..2058a37
--- /dev/null
+++ b/bonsai/models/vit/tests/test_outputs.py
@@ -0,0 +1,63 @@
+import jax
+import jax.numpy as jnp
+import torch
+from absl.testing import absltest
+from huggingface_hub import snapshot_download
+from transformers import ViTForImageClassification
+
+from bonsai.models.vit import params
+
+
+class TestModuleForwardPasses(absltest.TestCase):
+    def setUp(self):
+        super().setUp()
+        model_name = "google/vit-base-patch16-224"
+        model_ckpt_path = snapshot_download(model_name)
+        self.bonsai_model = params.create_vit_from_pretrained(model_ckpt_path)
+        self.baseline_model = ViTForImageClassification.from_pretrained(model_name)
+        self.bonsai_model.eval()
+        self.baseline_model.eval()
+
+        self.batch_size = 32
+        self.image_shape = (self.batch_size, 224, 224, 3)
+
+    def test_embeddings(self):
+        torch_emb = self.baseline_model.vit.embeddings
+        nnx_emb = self.bonsai_model.pos_embeddings
+
+        jx = jax.random.normal(jax.random.key(0), self.image_shape, dtype=jnp.float32)
+        tx = torch.tensor(jx).permute(0, 3, 1, 2)
+
+        with torch.no_grad():
+            ty = torch_emb(tx)
+        jy = nnx_emb(jx)
+
+        torch.testing.assert_close(torch.tensor(jy), ty, rtol=1e-5, atol=1e-5)
+
+    def test_first_layer(self):
+        torch_layer = self.baseline_model.vit.encoder.layer[0]
+        nnx_layer = self.bonsai_model.layers.layers[0]
+
+        hidden_shape = (self.batch_size, 197, 768)
+        jx = jax.random.normal(jax.random.key(0), hidden_shape, dtype=jnp.float32)
+        tx = torch.tensor(jx)
+
+        with torch.no_grad():
+            ty = torch_layer(tx)[0]
+        jy = nnx_layer(jx)
+
+        torch.testing.assert_close(torch.tensor(jy), ty, rtol=1e-5, atol=1e-2)
+
+    def test_full(self):
+        jx = jax.random.normal(jax.random.key(0), self.image_shape, dtype=jnp.float32)
+        tx = torch.tensor(jx).permute(0, 3, 1, 2)
+
+        with torch.no_grad():
+            ty = self.baseline_model(tx).logits
+        jy = self.bonsai_model(jx)
+
+        torch.testing.assert_close(torch.tensor(jy), ty, rtol=1e-5, atol=5e-2)
+
+
+if __name__ == "__main__":
+    absltest.main()
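
Beyond the dummy-input benchmark in `run_model.py` and the parity tests above, an end-to-end prediction on a real image looks roughly like the sketch below (not part of this diff). It assumes `transformers` and `pillow` are installed, and `cat.jpg` is a placeholder for a local image; `ViTImageProcessor` and `id2label` come from the same Hugging Face checkpoint the tests already download:

```python
# End-to-end sketch: preprocess a real image with the HF image processor, classify it with the
# converted Flax NNX model, and map the argmax index back to a human-readable label.
import jax.numpy as jnp
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import ViTConfig, ViTImageProcessor

from bonsai.models.vit import params

model_name = "google/vit-base-patch16-224"
model = params.create_vit_from_pretrained(snapshot_download(model_name))

processor = ViTImageProcessor.from_pretrained(model_name)
image = Image.open("cat.jpg")                                         # placeholder path
pixel_values = processor(image, return_tensors="np")["pixel_values"]  # (1, 3, 224, 224), resized + normalized
logits = model(jnp.asarray(pixel_values).transpose((0, 2, 3, 1)))     # NCHW -> NHWC for the Flax model

id2label = ViTConfig.from_pretrained(model_name).id2label
print(id2label[int(jnp.argmax(logits, axis=-1)[0])])
```
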