23 changes: 23 additions & 0 deletions bonsai/models/vit/README.md
@@ -0,0 +1,23 @@
## ViT in JAX

This directory contains a pure JAX implementation of the [ViT](https://huggingface.co/google/vit-base-patch16-224) model, using the Flax NNX API.

## Tested on

| Model Name | Config | CPU | GPU A100 (1x) | GPU H100 (1x) | GPU A100 (8x) | GPU H100 (8x) | TPU v2 (8x) | TPU v5e (1x) |
| :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- | :--- |
| [ViT](https://huggingface.co/google/vit-base-patch16-224) | ✅ Supported | ✅ Runs | ❔ Needs check | ❔ Needs check | ❔ Needs check | ❔ Needs check | ❔ Needs check | ❔ Needs check |

### Running this model

Run ViT model inference:

```sh
python3 -m bonsai.models.vit.tests.run_model
```
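
Alternatively, here is a minimal Python inference sketch (assuming the checkpoint has already been downloaded locally, e.g. by `tests/run_model.py`):

```python
import jax.numpy as jnp

from bonsai.models.vit import modeling as model_lib
from bonsai.models.vit import params

# Load the HF safetensors weights into the Flax NNX ViT-Base/16 model.
# This path is where tests/run_model.py places the snapshot_download output.
model = params.create_vit_from_pretrained("/tmp/models-bonsai/vit-base-patch16-224/model.safetensors")

# Images are expected in NHWC layout: (batch, height, width, channels).
images = jnp.ones((1, 224, 224, 3), dtype=jnp.float32)
logits = model_lib.forward(model, images)  # shape (1, 1000)
predicted_class = jnp.argmax(logits, axis=-1)
```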

## How to contribute to this model

We welcome contributions! You can contribute to this model via the following:
* Add a model config variant from the above `🟡 Not started` list to `class ModelConfig` in [modeling.py](modeling.py). Make sure your code runs on at least one hardware platform before creating a PR.
* Got some hardware? Run [run_model.py](tests/run_model.py) with the existing configs above on hardware marked `❔ Needs check`, then update the table with `✅ Runs` or `⛔️ Not supported`.
99 changes: 99 additions & 0 deletions bonsai/models/vit/modeling.py
@@ -0,0 +1,99 @@
from functools import partial

import jax
import jax.numpy as jnp
from flax import nnx


class Embeddings(nnx.Module):
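    """Patch + position embeddings with a prepended [CLS] token, as in ViT."""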
def __init__(
self,
image_size: tuple[int, int],
patch_size: tuple[int, int],
num_channels: int,
hidden_dim: int,
dropout_prob: float,
*,
rngs: nnx.Rngs,
):
num_patches = (image_size[0] // patch_size[0]) * (image_size[1] // patch_size[1])

self.projection = nnx.Conv(num_channels, hidden_dim, kernel_size=patch_size, strides=patch_size, rngs=rngs)
self.cls_token = nnx.Variable(jax.random.normal(rngs.params(), (1, 1, hidden_dim)))
self.pos_embeddings = nnx.Variable(jax.random.normal(rngs.params(), (1, num_patches + 1, hidden_dim)))
self.dropout = nnx.Dropout(dropout_prob, rngs=rngs)

def __call__(self, pixel_values: jnp.ndarray, deterministic: bool = False) -> jnp.ndarray:
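        # pixel_values is expected in NHWC layout: (batch, height, width, channels).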
embeddings = self.projection(pixel_values)
b, h, w, c = embeddings.shape
embeddings = embeddings.reshape(b, h * w, c)
cls_tokens = jnp.tile(self.cls_token.value, (b, 1, 1))
embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1)
embeddings = embeddings + self.pos_embeddings.value
embeddings = self.dropout(embeddings, deterministic=deterministic)
return embeddings


# Encoder
class TransformerEncoder(nnx.Module):
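    """Pre-LayerNorm transformer encoder block: self-attention and MLP, each with a residual connection."""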
def __init__(self, num_heads: int, attn_dim: int, mlp_dim: int, dropout_prob: float, eps: float, *, rngs: nnx.Rngs):
self.attention = nnx.MultiHeadAttention(num_heads=num_heads, in_features=attn_dim, decode=False, rngs=rngs)
self.linear1 = nnx.Linear(attn_dim, mlp_dim, rngs=rngs)
self.linear2 = nnx.Linear(mlp_dim, attn_dim, rngs=rngs)
self.dropout = nnx.Dropout(dropout_prob, rngs=rngs)
self.layernorm_before = nnx.LayerNorm(attn_dim, epsilon=eps, rngs=rngs)
self.layernorm_after = nnx.LayerNorm(attn_dim, epsilon=eps, rngs=rngs)

def __call__(self, hidden_states, head_mask=None, deterministic: bool = False):
hidden_states_norm = self.layernorm_before(hidden_states)
        attention_output = self.attention(hidden_states_norm, mask=head_mask, deterministic=deterministic)
hidden_states = attention_output + hidden_states
layer_output = self.layernorm_after(hidden_states)
layer_output = jax.nn.gelu(self.linear1(layer_output))
layer_output = self.linear2(layer_output)
layer_output = self.dropout(layer_output, deterministic=deterministic)
layer_output += hidden_states
return layer_output


class ViTClassificationModel(nnx.Module):
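    """ViT image classifier: patch/position embeddings, a stack of transformer encoder blocks, and a linear head on the [CLS] token."""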
def __init__(
self,
image_size: tuple[int, int],
patch_size: tuple[int, int],
num_channels: int,
hidden_dim: int,
dropout_prob: float,
num_heads: int,
mlp_dim: int,
eps: float,
num_layers: int,
num_labels: int,
*,
rngs: nnx.Rngs,
):
self.pos_embeddings = Embeddings(image_size, patch_size, num_channels, hidden_dim, dropout_prob, rngs=rngs)
self.layers = nnx.Sequential(
*[
TransformerEncoder(num_heads, hidden_dim, mlp_dim, dropout_prob, eps, rngs=rngs)
for _ in range(num_layers)
]
)
self.ln = nnx.LayerNorm(hidden_dim, epsilon=eps, rngs=rngs)
self.classifier = nnx.Linear(hidden_dim, num_labels, rngs=rngs)

def __call__(self, x):
x = self.pos_embeddings(x)
x = self.layers(x)
x = self.ln(x)
x = self.classifier(x[:, 0, :])
return x


def ViT(*, rngs: nnx.Rngs):
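    """ViT-Base/16: 224x224 input, 16x16 patches, 768 hidden dim, 12 heads, 3072 MLP dim, 12 layers, 1000 classes."""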
return ViTClassificationModel((224, 224), (16, 16), 3, 768, 0.0, 12, 3072, 1e-12, 12, 1000, rngs=rngs)


@partial(jax.jit, static_argnames=["model"])
def forward(model, x):
return model(x)
156 changes: 156 additions & 0 deletions bonsai/models/vit/params.py
@@ -0,0 +1,156 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re

import jax
import safetensors.flax as safetensors
from flax import nnx

from bonsai.models.vit import modeling as model_lib


def _get_key_and_transform_mapping():
# Mapping st_keys -> (nnx_keys, (permute_rule, reshape_rule)).
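    # Example: HF "vit.encoder.layer.0.attention.attention.query.weight" of shape (768, 768) maps to
    # nnx "layers.layers.0.attention.query.kernel": transposed via (1, 0), then reshaped to (768, 12, 64).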
embed_dim, num_heads, head_dim = 768, 12, 64
return {
# classifier
r"^classifier.bias$": (r"classifier.bias", None),
r"^classifier.weight$": (r"classifier.kernel", ((1, 0), None)),
# embeddings
r"^vit.embeddings.cls_token$": (r"pos_embeddings.cls_token", None),
r"^vit.embeddings.patch_embeddings.projection.bias$": (r"pos_embeddings.projection.bias", None),
r"^vit.embeddings.patch_embeddings.projection.weight$": (
r"pos_embeddings.projection.kernel",
((2, 3, 1, 0), None),
),
r"^vit.embeddings.position_embeddings$": (r"pos_embeddings.pos_embeddings", None),
# layers
r"^vit.encoder.layer.([0-9]+).attention.attention.key.bias$": (
r"layers.layers.\1.attention.key.bias",
(None, (num_heads, head_dim)),
),
r"^vit.encoder.layer.([0-9]+).attention.attention.key.weight$": (
r"layers.layers.\1.attention.key.kernel",
((1, 0), (embed_dim, num_heads, head_dim)),
),
r"^vit.encoder.layer.([0-9]+).attention.attention.query.bias$": (
r"layers.layers.\1.attention.query.bias",
(None, (num_heads, head_dim)),
),
r"^vit.encoder.layer.([0-9]+).attention.attention.query.weight$": (
r"layers.layers.\1.attention.query.kernel",
((1, 0), (embed_dim, num_heads, head_dim)),
),
r"^vit.encoder.layer.([0-9]+).attention.attention.value.bias$": (
r"layers.layers.\1.attention.value.bias",
(None, (num_heads, head_dim)),
),
r"^vit.encoder.layer.([0-9]+).attention.attention.value.weight$": (
r"layers.layers.\1.attention.value.kernel",
((1, 0), (embed_dim, num_heads, head_dim)),
),
r"^vit.encoder.layer.([0-9]+).attention.output.dense.bias$": (r"layers.layers.\1.attention.out.bias", None),
r"^vit.encoder.layer.([0-9]+).attention.output.dense.weight$": (
r"layers.layers.\1.attention.out.kernel",
((1, 0), (num_heads, head_dim, embed_dim)),
),
r"^vit.encoder.layer.([0-9]+).intermediate.dense.bias$": (r"layers.layers.\1.linear1.bias", None),
r"^vit.encoder.layer.([0-9]+).intermediate.dense.weight$": (r"layers.layers.\1.linear1.kernel", ((1, 0), None)),
r"^vit.encoder.layer.([0-9]+).layernorm_after.bias$": (r"layers.layers.\1.layernorm_after.bias", None),
r"^vit.encoder.layer.([0-9]+).layernorm_after.weight$": (r"layers.layers.\1.layernorm_after.scale", None),
r"^vit.encoder.layer.([0-9]+).layernorm_before.bias$": (r"layers.layers.\1.layernorm_before.bias", None),
r"^vit.encoder.layer.([0-9]+).layernorm_before.weight$": (r"layers.layers.\1.layernorm_before.scale", None),
r"^vit.encoder.layer.([0-9]+).output.dense.bias$": (r"layers.layers.\1.linear2.bias", None),
r"^vit.encoder.layer.([0-9]+).output.dense.weight$": (r"layers.layers.\1.linear2.kernel", ((1, 0), None)),
# layernorm
r"^vit.layernorm.bias$": (r"ln.bias", None),
r"^vit.layernorm.weight$": (r"ln.scale", None),
}


def _st_key_to_jax_key(mapping, source_key):
"""Map a safetensors key to exactly one JAX key & transform, else warn/error."""
subs = [
(re.sub(pat, repl, source_key), transform)
for pat, (repl, transform) in mapping.items()
if re.match(pat, source_key)
]
if not subs:
logging.warning(f"No mapping found for key: {source_key!r}")
return None, None
if len(subs) > 1:
keys = [s for s, _ in subs]
raise ValueError(f"Multiple mappings found for {source_key!r}: {keys}")
return subs[0]


def _assign_weights(keys, tensor, state_dict, st_key, transform):
"""Recursively descend into state_dict and assign the (possibly permuted/reshaped) tensor."""
key, *rest = keys
if not rest:
if transform is not None:
permute, reshape = transform
if permute:
tensor = tensor.transpose(permute)
if reshape:
tensor = tensor.reshape(reshape)
if tensor.shape != state_dict[key].shape:
raise ValueError(f"Shape mismatch for {st_key}: {tensor.shape} vs {state_dict[key].shape}")
state_dict[key] = tensor
else:
_assign_weights(rest, tensor, state_dict[key], st_key, transform)


def _stoi(s):
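    """Convert numeric path segments (e.g. encoder layer indices) to ints so they match the integer keys in the nnx state."""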
try:
return int(s)
except ValueError:
return s


def create_vit_from_pretrained(
file_path: str,
num_classes: int = 1000,
*,
mesh: jax.sharding.Mesh | None = None,
):
"""
Load safetensor weights from a file, then convert & merge into a flax.nnx ViT model.

Returns:
A flax.nnx.Model instance with loaded parameters.
"""
state_dict = safetensors.load_file(file_path)

vit = nnx.eval_shape(lambda: model_lib.ViT(rngs=nnx.Rngs(0)))
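    # eval_shape builds an abstract (shape-only) model; split it into its graph definition and a
    # pure state dict that the checkpoint tensors are written into below.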
graph_def, abs_state = nnx.split(vit)
jax_state = abs_state.to_pure_dict()

mapping = _get_key_and_transform_mapping()
for st_key, tensor in state_dict.items():
jax_key, transform = _st_key_to_jax_key(mapping, st_key)
if jax_key is None:
continue
keys = [_stoi(k) for k in jax_key.split(".")]
_assign_weights(keys, tensor, jax_state, st_key, transform)

# if mesh is not None:
# sharding = nnx.get_named_sharding(abs_state, mesh).to_pure_dict()
# jax_state = jax.device_put(jax_state, sharding)
# else:
# jax_state = jax.device_put(jax_state, jax.devices()[0])

return nnx.merge(graph_def, jax_state)
71 changes: 71 additions & 0 deletions bonsai/models/vit/tests/run_model.py
@@ -0,0 +1,71 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time

import jax
import jax.numpy as jnp
from huggingface_hub import snapshot_download

from bonsai.models.vit import modeling as model_lib
from bonsai.models.vit import params


def run_model(MODEL_CP_PATH=None):
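    """Download the HF ViT checkpoint if needed, load it into the Flax NNX model, and benchmark jitted inference."""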
# 1. Download safetensors file
model_name = "google/vit-base-patch16-224"
if MODEL_CP_PATH is None:
MODEL_CP_PATH = "/tmp/models-bonsai/" + model_name.split("/")[1]

if not os.path.isdir(MODEL_CP_PATH):
snapshot_download(model_name, local_dir=MODEL_CP_PATH)

safetensors_path = os.path.join(MODEL_CP_PATH, "model.safetensors")

# 2. Load pretrained model
model = params.create_vit_from_pretrained(safetensors_path)

# 3. Prepare dummy input
batch_size, channels, image_size = 8, 3, 224
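    # NHWC layout: (batch, height, width, channels).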
dummy_input = jnp.ones((batch_size, image_size, image_size, channels), dtype=jnp.float32)

# 4. Warmup + profiling
# Warmup (triggers compilation)
_ = model_lib.forward(model, dummy_input)
jax.block_until_ready(_)

# Profile a few steps
jax.profiler.start_trace("/tmp/profile-vit")
for _ in range(5):
logits = model_lib.forward(model, dummy_input)
jax.block_until_ready(logits)
jax.profiler.stop_trace()

# 5. Timed execution
t0 = time.perf_counter()
for _ in range(10):
logits = model_lib.forward(model, dummy_input)
jax.block_until_ready(logits)
print(f"10 runs took {time.perf_counter() - t0:.4f} s")

# 6. Show top-1 predicted class
pred = jnp.argmax(logits, axis=-1)
print("Predicted classes:", pred)


__all__ = ["run_model"]


if __name__ == "__main__":
    run_model()