-
Notifications
You must be signed in to change notification settings - Fork 292
[WIP] [SmolLM3] Add Backbone and CausalLM #2327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
DavidLandup0
wants to merge
31
commits into
keras-team:master
Choose a base branch
from
DavidLandup0:smollm3
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
31 commits
Select commit
Hold shift + click to select a range
f2dedc4
add first few utils
DavidLandup0 1d90715
add eager attention forward
DavidLandup0 e5a8f33
Add SmolLM3Attention
DavidLandup0 54191ca
Add SmolLM3MLP
DavidLandup0 1369733
Add SmolLM3DecoderLayer
DavidLandup0 2448d80
remove unnecessary comments
DavidLandup0 598fd74
Add SmolLM3RotaryEmbedding
DavidLandup0 b9e458d
add most of smollm3backbone
DavidLandup0 6a53a7d
Fix calls within causal model
DavidLandup0 81eff73
Move causal mask computation to forward call
DavidLandup0 b0080f2
Add convert_smollm3.py and update preset loader
DavidLandup0 d5767c1
Fix causal mask call
DavidLandup0 186eaf8
Fix conversion weight names
DavidLandup0 6ab2e5c
remove unnecessary arg
DavidLandup0 6819fd1
Build all layers
DavidLandup0 e126938
Remove k and q norms
DavidLandup0 26511b2
add causal attn mask, a few fixes
DavidLandup0 d81e831
add softmax op
DavidLandup0 e07e848
fix build cache shape?
DavidLandup0 e25fcdd
fix shape positioning in cache update
DavidLandup0 5a49ed6
Remove position ids as input
DavidLandup0 89391d9
use sampler's max length
DavidLandup0 7a9d99c
format
DavidLandup0 e3067a5
add logs
DavidLandup0 7622315
switch order or value heads and max length
DavidLandup0 982a546
oh god please
DavidLandup0 7319f48
oh god please
DavidLandup0 3c3d7fb
oh god please
DavidLandup0 8046d4b
oh god please
DavidLandup0 2d4a3b5
oh god please
DavidLandup0 53efb59
god has answered my prayers
DavidLandup0 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
import keras | ||
from keras import ops | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.layers.modeling.reversible_embedding import ( | ||
ReversibleEmbedding, | ||
) | ||
from keras_hub.src.models.backbone import Backbone | ||
from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer | ||
from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding | ||
|
||
|
||
@keras_hub_export( | ||
[ | ||
"keras_hub.models.SmolLM3Backbone", | ||
"keras_hub.models.SmolLMBackbone", | ||
] | ||
) | ||
class SmolLM3Backbone(Backbone): | ||
""" | ||
The SmolLM Transformer core architecture with hyperparameters. | ||
|
||
This network implements a Transformer-based decoder network, | ||
SmolLM3, as described in the SmolLM3 model architecture. | ||
It includes the embedding lookups and transformer layers. | ||
|
||
The default constructor gives a fully customizable, randomly initialized | ||
SmolLM3 model with any number of layers, heads, and embedding | ||
dimensions. To load preset architectures and weights, use the `from_preset` | ||
constructor. | ||
|
||
Args: | ||
|
||
|
||
Examples: | ||
|
||
```python | ||
input_data = { | ||
"token_ids": np.ones(shape=(1, 12), dtype="int32"), | ||
"padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), | ||
} | ||
|
||
# Pretrained SmolLM decoder. | ||
model = keras_hub.models.SmolLM3Backbone.from_preset("...") | ||
model(input_data) | ||
|
||
# Randomly initialized SmolLM3 decoder with custom config. | ||
model = keras_hub.models.SmolLM3Backbone( | ||
... | ||
) | ||
model(input_data) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
vocabulary_size, | ||
hidden_dim, | ||
intermediate_dim, | ||
num_layers, | ||
num_attention_heads, | ||
num_key_value_heads, | ||
attention_bias, | ||
attention_dropout, | ||
rope_layer_enabled_list, | ||
layer_types, | ||
mlp_bias, | ||
layer_norm_epsilon, | ||
max_position_embeddings, | ||
rope_theta, | ||
partial_rotary_factor, | ||
**kwargs, | ||
): | ||
# === Layers === | ||
self.token_embedding = ReversibleEmbedding( | ||
input_dim=vocabulary_size, | ||
output_dim=hidden_dim, | ||
name="token_embedding", | ||
) | ||
self.transformer_layers = [] | ||
|
||
for i in range(num_layers): | ||
layer = SmolLM3DecoderLayer( | ||
hidden_size=hidden_dim, | ||
num_attention_heads=num_attention_heads, | ||
num_key_value_heads=num_key_value_heads, | ||
attention_bias=attention_bias, | ||
attention_dropout=attention_dropout, | ||
rope_layer_enabled_list=rope_layer_enabled_list, | ||
layer_types=layer_types, | ||
layer_idx=i, | ||
intermediate_size=intermediate_dim, | ||
mlp_bias=mlp_bias, | ||
layer_norm_epsilon=layer_norm_epsilon, | ||
name=f"transformer_layer_{i}", | ||
) | ||
self.transformer_layers.append(layer) | ||
|
||
self.norm = keras.layers.RMSNormalization( | ||
epsilon=layer_norm_epsilon, | ||
name="sequence_output_layernorm", | ||
) | ||
|
||
self.rotary_embedding = SmolLM3RotaryEmbedding( | ||
hidden_size=hidden_dim, | ||
num_attention_heads=num_attention_heads, | ||
max_position_embeddings=max_position_embeddings, | ||
rope_theta=rope_theta, | ||
partial_rotary_factor=partial_rotary_factor, | ||
) | ||
|
||
# === Functional Model === | ||
token_id_input = keras.Input( | ||
shape=(None,), dtype="int32", name="token_ids" | ||
) | ||
|
||
padding_mask_input = keras.Input( | ||
shape=(None,), dtype="int32", name="padding_mask" | ||
) | ||
|
||
# Infer position IDs from the shape of token IDs. | ||
seq_len = ops.shape(token_id_input)[1] | ||
position_ids = ops.arange(0, seq_len, dtype="int32") | ||
# Add a batch dimension to broadcast. | ||
position_ids = ops.expand_dims(position_ids, axis=0) | ||
|
||
hidden_states = self.token_embedding(token_id_input) | ||
position_embeddings = self.rotary_embedding(hidden_states, position_ids) | ||
|
||
for decoder_layer in self.transformer_layers[:num_layers]: | ||
hidden_states = decoder_layer( | ||
hidden_states, | ||
position_embeddings=position_embeddings, | ||
decoder_padding_mask=padding_mask_input, | ||
**kwargs, | ||
) | ||
|
||
sequence_output = self.norm(hidden_states) | ||
super().__init__( | ||
inputs={ | ||
"token_ids": token_id_input, | ||
"padding_mask": padding_mask_input, | ||
}, | ||
outputs=sequence_output, | ||
**kwargs, | ||
) | ||
|
||
# === Config === | ||
self.vocabulary_size = vocabulary_size | ||
self.hidden_dim = hidden_dim | ||
self.intermediate_dim = intermediate_dim | ||
self.num_layers = num_layers | ||
self.num_attention_heads = num_attention_heads | ||
self.num_key_value_heads = num_key_value_heads | ||
self.attention_bias = attention_bias | ||
self.attention_dropout = attention_dropout | ||
self.rope_layer_enabled_list = rope_layer_enabled_list | ||
self.layer_types = layer_types | ||
self.mlp_bias = mlp_bias | ||
self.layer_norm_epsilon = layer_norm_epsilon | ||
self.max_position_embeddings = max_position_embeddings | ||
self.rope_theta = rope_theta | ||
self.partial_rotary_factor = partial_rotary_factor | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"vocabulary_size": self.vocabulary_size, | ||
"hidden_dim": self.hidden_dim, | ||
"intermediate_dim": self.intermediate_dim, | ||
"num_layers": self.num_layers, | ||
"num_attention_heads": self.num_attention_heads, | ||
"num_key_value_heads": self.num_key_value_heads, | ||
"attention_bias": self.attention_bias, | ||
"attention_dropout": self.attention_dropout, | ||
"rope_layer_enabled_list": self.rope_layer_enabled_list, | ||
"layer_types": self.layer_types, | ||
"mlp_bias": self.mlp_bias, | ||
"layer_norm_epsilon": self.layer_norm_epsilon, | ||
"max_position_embeddings": self.max_position_embeddings, | ||
"rope_theta": self.rope_theta, | ||
"partial_rotary_factor": self.partial_rotary_factor, | ||
} | ||
) | ||
return config |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Usually there's some of these terms (like the epsilon's and rope theta) that have a consistent value across all the presets we care about, and we give them defaults here. Not super important, just for people that wanted an easier time making a custom small version of the arch or something like that.