
[WIP] [SmolLM3] Add Backbone and CausalLM #2327

Open · wants to merge 31 commits into base: master

Commits (31)
f2dedc4
add first few utils
DavidLandup0 Jul 14, 2025
1d90715
add eager attention forward
DavidLandup0 Jul 14, 2025
e5a8f33
Add SmolLM3Attention
DavidLandup0 Jul 14, 2025
54191ca
Add SmolLM3MLP
DavidLandup0 Jul 14, 2025
1369733
Add SmolLM3DecoderLayer
DavidLandup0 Jul 14, 2025
2448d80
remove unnecessary comments
DavidLandup0 Jul 14, 2025
598fd74
Add SmolLM3RotaryEmbedding
DavidLandup0 Jul 14, 2025
b9e458d
add most of smollm3backbone
DavidLandup0 Jul 14, 2025
6a53a7d
Fix calls within causal model
DavidLandup0 Jul 16, 2025
81eff73
Move causal mask computation to forward call
DavidLandup0 Jul 16, 2025
b0080f2
Add convert_smollm3.py and update preset loader
DavidLandup0 Jul 16, 2025
d5767c1
Fix causal mask call
DavidLandup0 Jul 16, 2025
186eaf8
Fix conversion weight names
DavidLandup0 Jul 16, 2025
6ab2e5c
remove unnecessary arg
DavidLandup0 Jul 16, 2025
6819fd1
Build all layers
DavidLandup0 Jul 16, 2025
e126938
Remove k and q norms
DavidLandup0 Jul 16, 2025
26511b2
add causal attn mask, a few fixes
DavidLandup0 Jul 16, 2025
d81e831
add softmax op
DavidLandup0 Jul 26, 2025
e07e848
fix build cache shape?
DavidLandup0 Jul 26, 2025
e25fcdd
fix shape positioning in cache update
DavidLandup0 Jul 26, 2025
5a49ed6
Remove position ids as input
DavidLandup0 Jul 26, 2025
89391d9
use sampler's max length
DavidLandup0 Jul 26, 2025
7a9d99c
format
DavidLandup0 Jul 26, 2025
e3067a5
add logs
DavidLandup0 Jul 26, 2025
7622315
switch order of value heads and max length
DavidLandup0 Jul 26, 2025
982a546
oh god please
DavidLandup0 Jul 26, 2025
7319f48
oh god please
DavidLandup0 Jul 26, 2025
3c3d7fb
oh god please
DavidLandup0 Jul 26, 2025
8046d4b
oh god please
DavidLandup0 Jul 26, 2025
2d4a3b5
oh god please
DavidLandup0 Jul 26, 2025
53efb59
god has answered my prayers
DavidLandup0 Jul 26, 2025
24 changes: 24 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -576,6 +576,30 @@
from keras_hub.src.models.siglip.siglip_vision_encoder import (
SigLIPVisionEncoder as SigLIPVisionEncoder,
)
from keras_hub.src.models.smollm3.smollm3_backbone import (
SmolLM3Backbone as SmolLM3Backbone,
)
from keras_hub.src.models.smollm3.smollm3_backbone import (
SmolLM3Backbone as SmolLMBackbone,
)
from keras_hub.src.models.smollm3.smollm3_causal_lm import (
SmolLM3CausalLM as SmolLM3CausalLM,
)
from keras_hub.src.models.smollm3.smollm3_causal_lm import (
SmolLM3CausalLM as SmolLMCausalLM,
)
from keras_hub.src.models.smollm3.smollm3_causal_lm_preprocessor import (
SmolLM3CausalLMPreprocessor as SmolLM3CausalLMPreprocessor,
)
from keras_hub.src.models.smollm3.smollm3_causal_lm_preprocessor import (
SmolLM3CausalLMPreprocessor as SmolLMCausalLMPreprocessor,
)
from keras_hub.src.models.smollm3.smollm3_tokenizer import (
SmolLM3Tokenizer as SmolLM3Tokenizer,
)
from keras_hub.src.models.smollm3.smollm3_tokenizer import (
SmolLM3Tokenizer as SmolLMTokenizer,
)
from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
StableDiffusion3Backbone as StableDiffusion3Backbone,
)
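
For reference (outside the diff): the hunk above registers each new class under both a `SmolLM3*` name and a shorter `SmolLM*` alias. A quick sanity check, assuming `keras_hub` is installed from this branch:

```python
import keras_hub

# Both export names should resolve to the same class objects.
assert keras_hub.models.SmolLM3Backbone is keras_hub.models.SmolLMBackbone
assert keras_hub.models.SmolLM3CausalLM is keras_hub.models.SmolLMCausalLM
assert (
    keras_hub.models.SmolLM3CausalLMPreprocessor
    is keras_hub.models.SmolLMCausalLMPreprocessor
)
```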
6 changes: 6 additions & 0 deletions keras_hub/api/tokenizers/__init__.py
@@ -86,6 +86,12 @@
from keras_hub.src.models.siglip.siglip_tokenizer import (
SigLIPTokenizer as SigLIPTokenizer,
)
from keras_hub.src.models.smollm3.smollm3_tokenizer import (
SmolLM3Tokenizer as SmolLM3Tokenizer,
)
from keras_hub.src.models.smollm3.smollm3_tokenizer import (
SmolLM3Tokenizer as SmolLMTokenizer,
)
from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer
from keras_hub.src.models.whisper.whisper_tokenizer import (
WhisperTokenizer as WhisperTokenizer,
186 changes: 186 additions & 0 deletions keras_hub/src/models/smollm3/smollm3_backbone.py
@@ -0,0 +1,186 @@
import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.modeling.reversible_embedding import (
ReversibleEmbedding,
)
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer
from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding


@keras_hub_export(
[
"keras_hub.models.SmolLM3Backbone",
"keras_hub.models.SmolLMBackbone",
]
)
class SmolLM3Backbone(Backbone):
"""
The SmolLM3 Transformer core architecture with hyperparameters.

This network implements a Transformer-based decoder network, SmolLM3.
It includes the embedding lookups and transformer layers.

The default constructor gives a fully customizable, randomly initialized
SmolLM3 model with any number of layers, heads, and embedding
dimensions. To load preset architectures and weights, use the `from_preset`
constructor.

Args:
    vocabulary_size: int. The size of the token vocabulary.
    hidden_dim: int. The hidden size of the transformer layers.
    intermediate_dim: int. The output dimension of the first Dense layer
        in the MLP of each transformer layer.
    num_layers: int. The number of transformer decoder layers.
    num_attention_heads: int. The number of query attention heads in
        each transformer layer.
    num_key_value_heads: int. The number of key/value heads used for
        grouped-query attention.
    attention_bias: bool. Whether to use a bias term in the attention
        projections.
    attention_dropout: float. The dropout rate applied to attention
        scores.
    rope_layer_enabled_list: list of bool. Per-layer flags for whether
        rotary position embeddings are applied.
    layer_types: list of str. Per-layer attention type identifiers.
    mlp_bias: bool. Whether to use a bias term in the MLP projections.
    layer_norm_epsilon: float. Epsilon for the RMS normalization layers.
    max_position_embeddings: int. The maximum sequence length supported
        by the rotary position embedding.
    rope_theta: float. The base frequency for rotary position embeddings.
    partial_rotary_factor: float. The fraction of the attention head
        dimension that receives rotary embeddings.

Examples:

```python
input_data = {
"token_ids": np.ones(shape=(1, 12), dtype="int32"),
"padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
}

# Pretrained SmolLM3 decoder.
model = keras_hub.models.SmolLM3Backbone.from_preset("...")
model(input_data)

# Randomly initialized SmolLM3 decoder with custom config.
model = keras_hub.models.SmolLM3Backbone(
...
)
model(input_data)
```
"""

def __init__(
self,
vocabulary_size,
hidden_dim,
intermediate_dim,
num_layers,
num_attention_heads,
num_key_value_heads,
attention_bias,
attention_dropout,
rope_layer_enabled_list,
layer_types,
mlp_bias,
layer_norm_epsilon,
Review comment (Member): Usually there are some of these terms (like the epsilons and rope theta) that have a consistent value across all the presets we care about, and we give them defaults here. Not super important, just for people that want an easier time making a custom small version of the arch or something like that.
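
A minimal sketch of that suggestion, purely illustrative: the helper name and every default value below are assumptions, not taken from a released SmolLM3 preset.

```python
# Hypothetical sketch: give the preset-stable hyperparameters keyword
# defaults so a custom config only overrides what actually differs.
# The concrete values are illustrative assumptions.
def smollm3_stable_defaults(
    layer_norm_epsilon=1e-6,    # assumed RMSNorm epsilon
    rope_theta=10_000.0,        # assumed rotary base frequency
    partial_rotary_factor=1.0,  # assumed: rotate the full head dimension
):
    return {
        "layer_norm_epsilon": layer_norm_epsilon,
        "rope_theta": rope_theta,
        "partial_rotary_factor": partial_rotary_factor,
    }

# Example: only override the term that differs.
custom_kwargs = smollm3_stable_defaults(rope_theta=1_000_000.0)
```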

max_position_embeddings,
rope_theta,
partial_rotary_factor,
**kwargs,
):
# === Layers ===
self.token_embedding = ReversibleEmbedding(
input_dim=vocabulary_size,
output_dim=hidden_dim,
name="token_embedding",
)
self.transformer_layers = []

for i in range(num_layers):
layer = SmolLM3DecoderLayer(
hidden_size=hidden_dim,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
rope_layer_enabled_list=rope_layer_enabled_list,
layer_types=layer_types,
layer_idx=i,
intermediate_size=intermediate_dim,
mlp_bias=mlp_bias,
layer_norm_epsilon=layer_norm_epsilon,
name=f"transformer_layer_{i}",
)
self.transformer_layers.append(layer)

self.norm = keras.layers.RMSNormalization(
epsilon=layer_norm_epsilon,
name="sequence_output_layernorm",
)

self.rotary_embedding = SmolLM3RotaryEmbedding(
hidden_size=hidden_dim,
num_attention_heads=num_attention_heads,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
partial_rotary_factor=partial_rotary_factor,
)

# === Functional Model ===
token_id_input = keras.Input(
shape=(None,), dtype="int32", name="token_ids"
)

padding_mask_input = keras.Input(
shape=(None,), dtype="int32", name="padding_mask"
)

# Infer position IDs from the shape of token IDs.
seq_len = ops.shape(token_id_input)[1]
position_ids = ops.arange(0, seq_len, dtype="int32")
# Add a batch dimension to broadcast.
position_ids = ops.expand_dims(position_ids, axis=0)

hidden_states = self.token_embedding(token_id_input)
position_embeddings = self.rotary_embedding(hidden_states, position_ids)

for decoder_layer in self.transformer_layers[:num_layers]:
hidden_states = decoder_layer(
hidden_states,
position_embeddings=position_embeddings,
decoder_padding_mask=padding_mask_input,
)

sequence_output = self.norm(hidden_states)
super().__init__(
inputs={
"token_ids": token_id_input,
"padding_mask": padding_mask_input,
},
outputs=sequence_output,
**kwargs,
)

# === Config ===
self.vocabulary_size = vocabulary_size
self.hidden_dim = hidden_dim
self.intermediate_dim = intermediate_dim
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.rope_layer_enabled_list = rope_layer_enabled_list
self.layer_types = layer_types
self.mlp_bias = mlp_bias
self.layer_norm_epsilon = layer_norm_epsilon
self.max_position_embeddings = max_position_embeddings
self.rope_theta = rope_theta
self.partial_rotary_factor = partial_rotary_factor

def get_config(self):
config = super().get_config()
config.update(
{
"vocabulary_size": self.vocabulary_size,
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
"num_layers": self.num_layers,
"num_attention_heads": self.num_attention_heads,
"num_key_value_heads": self.num_key_value_heads,
"attention_bias": self.attention_bias,
"attention_dropout": self.attention_dropout,
"rope_layer_enabled_list": self.rope_layer_enabled_list,
"layer_types": self.layer_types,
"mlp_bias": self.mlp_bias,
"layer_norm_epsilon": self.layer_norm_epsilon,
"max_position_embeddings": self.max_position_embeddings,
"rope_theta": self.rope_theta,
"partial_rotary_factor": self.partial_rotary_factor,
}
)
return config
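
For reference (outside the diff), a minimal sketch of constructing and calling a small, randomly initialized backbone with the constructor above. Every hyperparameter value is an illustrative assumption rather than a released SmolLM3 configuration, and the exact strings expected in `layer_types` are assumed:

```python
import numpy as np

import keras_hub

# Tiny, made-up configuration purely for illustration.
backbone = keras_hub.models.SmolLM3Backbone(
    vocabulary_size=1000,
    hidden_dim=64,
    intermediate_dim=128,
    num_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    attention_bias=False,
    attention_dropout=0.0,
    rope_layer_enabled_list=[True, True],  # one flag per layer
    layer_types=["attention", "attention"],  # one identifier per layer (assumed)
    mlp_bias=False,
    layer_norm_epsilon=1e-6,
    max_position_embeddings=128,
    rope_theta=10_000.0,
    partial_rotary_factor=1.0,
)

input_data = {
    "token_ids": np.ones((1, 12), dtype="int32"),
    "padding_mask": np.ones((1, 12), dtype="int32"),
}
sequence_output = backbone(input_data)  # expected shape: (1, 12, 64)
```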