First pass at LLaDA-8b implementation #19
Conversation
Thanks Marlin, let's wait for the …
Thanks for the contribution! Added some comments.
return logits, state

out_tokens, state_final = modeling.generate(
Why are there 3 different modeling.generate runs with different gen/block lengths?
Also, is this a warmup run? If so, one step would suffice.
Having step in an explicit for loop like Qwen3 would make this much simpler.
RE: Also, is this a warmup run? If so, one step would suffice.
To be more specific, a jit-compiled function in which steps is not a static argnum should suffice with a warmup run of steps=1, as long as the function is not recompiled afterwards. This is another reason it would be better to move steps outside of this generate function into an explicit for loop.
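For illustration, a rough sketch of that structure (a toy nnx module stands in for the actual LLaDA model, and the per-step masking/sampling logic is omitted): jit only the per-step function, warm it up with a single call, and drive steps from plain Python.

import jax
import jax.numpy as jnp
from flax import nnx

class ToyModel(nnx.Module):
    def __init__(self, vocab: int, dim: int, rngs: nnx.Rngs):
        self.embed = nnx.Embed(vocab, dim, rngs=rngs)
        self.out = nnx.Linear(dim, vocab, rngs=rngs)

    def __call__(self, tokens):
        return self.out(self.embed(tokens))

model = ToyModel(vocab=32, dim=8, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)

@jax.jit
def model_step(tokens, state):
    mdl = nnx.merge(graphdef, state)
    logits = mdl(tokens)
    return logits, nnx.state(mdl)

tokens = jnp.zeros((1, 16), dtype=jnp.int32)
_ = model_step(tokens, state)  # warmup: compiles model_step exactly once

steps = 4  # plain Python int; it never becomes a jit argument, so changing it never recompiles
for _ in range(steps):
    logits, state = model_step(tokens, state)  # reuses the compiled executable every iteration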
class BlockType(str, Enum):
    """
Please remove unnecessary whitespace. Also let's stick with either enum.auto() or str to keep consistency throughout the entire file.
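For reference, the two styles side by side (the member names and values here are illustrative; the ask is to pick one style and use it for every enum in the file):

import enum
from enum import Enum

# Style A: string-valued members
class BlockType(str, Enum):
    SEQUENTIAL = "sequential"
    PARALLEL = "parallel"

# Style B: enum.auto() members
class LayerNormType(Enum):
    DEFAULT = enum.auto()
    GEMMA_RMS = enum.auto()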
bonsai/models/llada_8b/modeling.py (outdated)
if ln_type == LayerNormType.GEMMA_RMS:
    return GemmaRMSNorm(
        dim,
This can fit on one line within the 120-character limit; please make full use of the 120-character limit. Ditto everywhere else applicable, including before the ValueError:
return GemmaRMSNorm(dim, epsilon=eps, use_bias=use_bias, rngs=rngs)
Please also remove the following unnecessary whitespace wherever applicable.
# Tokenization helper
def tokenize(tokenizer, inputs: list[str]):
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
how about pad_id = pad_id or tokenizer.eos_token_id # fall back to eos
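A minimal sketch of how that fallback could sit inside the helper (the padding call and return value are assumptions; only the eos fallback line comes from this suggestion):

def tokenize(tokenizer, inputs: list[str]):
    pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id  # fall back to eos
    batch = tokenizer(inputs, padding=True, return_tensors="np")
    return batch["input_ids"], pad_id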
# Model + params
cfg = modeling.ModelConfig.llada_8b_instruct()

# Load weights into a constructed model
Remove self-explanatory comments; ditto everywhere else applicable.
graphdef, state = nnx.split(model)


def model_step(tokens, state):
Can we just have a single forward which is jitted, and keep all the key logic in generate? Similar to what we had in Qwen3.
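Roughly the shape being suggested (the names and details here are assumptions, not the PR's code):

import jax
from flax import nnx

# nnx.jit wraps a single forward pass over the module; generate() stays plain Python around it.
@nnx.jit
def forward(model: nnx.Module, tokens: jax.Array) -> jax.Array:
    return model(tokens).logits  # the model returns an LLaDAOutput per the diff

def generate(model: nnx.Module, tokens: jax.Array, steps: int) -> jax.Array:
    for _ in range(steps):
        logits = forward(model, tokens)
        # ... unmasking / sampling to update `tokens` would go here ...
    return tokens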
bonsai/models/llada_8b/modeling.py (outdated)
    ],
)
def generate(
    model_step,
Please add type hints everywhere applicable. However, we want to factor the step out here, as noted in the other comments.
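A hypothetical illustration of the kind of type hints being asked for (the parameter list is a placeholder; as noted above, model_step should eventually be factored out rather than passed in):

from typing import Callable
import jax
from flax import nnx

StepFn = Callable[[jax.Array, nnx.State], tuple[jax.Array, nnx.State]]

def generate(model_step: StepFn, tokens: jax.Array, steps: int, gen_length: int, block_length: int) -> tuple[jax.Array, nnx.State]:
    """Signature-only sketch; body omitted."""
    raise NotImplementedError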
bonsai/models/llada_8b/modeling.py (outdated)
        return (x, state, rng), None

    (x, final_state, _), _ = lax.scan(do_block, (x, init_state, rng), xs=jnp.arange(num_blocks, dtype=jnp.int32))
    return x, final_state
Can we refactor the generate steps out into an explicit for loop? Something like what we have in Qwen3.
The current implementation makes it difficult to benchmark performance per step and to separate static argnums from jax.
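As a sketch of that refactor (block_step is hypothetical here: a jitted function applying one block's worth of diffusion steps), a plain Python loop makes per-block timing trivial and keeps num_blocks out of the traced computation entirely:

import time

def generate_blocks(block_step, x, state, num_blocks: int):
    for block_idx in range(num_blocks):
        t0 = time.perf_counter()
        x, state = block_step(x, state)
        x.block_until_ready()  # wait for the device so the timing is meaningful
        print(f"block {block_idx}: {time.perf_counter() - t0:.3f}s")
    return x, state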
jax.profiler.stop_trace()

# Profile a single run, print FLOPs, memory
jax.profiler.profile(
Is this a valid method? I don't think jax.profiler.profile exists here:
https://github.com/jax-ml/jax/blob/52383198d74ffdc3b5c3bef86cfd64f02ab9dbb5/jax/profiler.py#L17
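For reference, the entry points that do exist are start_trace/stop_trace and the trace() context manager; a minimal sketch with an illustrative log directory and a stand-in computation in place of the generate call:

import jax
import jax.numpy as jnp

with jax.profiler.trace("/tmp/llada-profile"):
    y = jnp.dot(jnp.ones((512, 512)), jnp.ones((512, 512)))  # stand-in for the profiled run
    y.block_until_ready()  # make sure the device work finishes inside the trace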
# Demo queries
queries = [
Could we test with multiple queries of varying sizes to ensure batching is correct?
For example:
queries = ["Why is the sky blue instead of any other color like purple?", "Who am I?"]
    num_traced_runs=1,
    num_profiled_runs=1,
)
print("Generated token IDs:", out_tokens)
Could we decode the output, since this is an example model run, to test whether the output makes sense?
Example:
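(The original example snippet isn't preserved in this thread; a minimal sketch using the HF tokenizer API, with tokenizer, queries and out_tokens as used in run_model.py, might look like this.)

def print_decoded(tokenizer, queries, out_tokens) -> None:
    # Decode generated ids back to text so the example run can be sanity-checked by eye.
    texts = tokenizer.batch_decode(out_tokens, skip_special_tokens=True)
    for query, text in zip(queries, texts):
        print(f"Q: {query}\nA: {text}\n")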
bonsai/models/llada_8b/modeling.py (outdated)
class LLaDAOutput(NamedTuple):
    logits: jax.Array
    """
Can simplify comments in this class.
Renaming
Can you also add a quality logits checker (e.g. test_outputs.py) to verify that the resulting logits are correct?
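A hedged sketch of what such a check could look like (the reference source, e.g. the HF PyTorch LLaDA implementation run on the same prompt, and the tolerances are assumptions, not something this PR specifies):

import numpy as np

def check_logits_match_reference(bonsai_logits: np.ndarray, reference_logits: np.ndarray) -> None:
    # Compare the JAX port's logits against a trusted reference on the same inputs.
    assert bonsai_logits.shape == reference_logits.shape
    np.testing.assert_allclose(bonsai_logits, reference_logits, rtol=1e-3, atol=1e-2)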
Hello, please make sure to have a Ref:
d6abf89 to 175959a (Compare)
HF_REPO = "GSAI-ML/LLaDA-8B-Instruct"


def _tokenize_chat_batch(tokenizer, prompts):
Please keep consistency with tokenize in run_model.py, as discussed offline.
Otherwise this PR looks ready.
Thanks for this implementation!
Looks mostly good; minor stylistic changes can be addressed in follow-up CLs.
Currently working on the run_model.py