
Conversation

marlinfiggins (Collaborator)

Currently working on the run_model.py

jenriver (Member) commented Aug 8, 2025

Thanks Marlin, let's wait for run_model.py and the config test grid before we merge the PR.

jenriver (Member) left a comment

Thanks for the contribution! Added some comments.

return logits, state


out_tokens, state_final = modeling.generate(
jenriver (Member):

Why are there 3 different modeling.generate runs with different gen/block lengths?

Also, is this a warmup run? If so, one step would suffice.

Having the step in an explicit for loop, as in Qwen3, would make this much simpler.

jenriver (Member):

RE: Also, is this a warmup run? If so, one step would suffice.

To be more specific, a jit-compiled function that takes steps as a regular (non-static_argnums) argument should suffice, with a warmup run of steps=1, as long as the function is not compiled again. This is another reason it would be better to hoist steps out of this generate function into an explicit for loop.
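
Roughly the shape being suggested, as a sketch (decode_one_step and its placeholder body are hypothetical, not this PR's code):

import jax
import jax.numpy as jnp

@jax.jit
def decode_one_step(tokens: jax.Array, state: jax.Array):
    # Placeholder body; the real step would run one block of the model.
    return tokens + 1, state

def generate(tokens: jax.Array, state: jax.Array, steps: int):
    # steps stays a plain Python int outside the traced function, so varying
    # it never triggers recompilation of decode_one_step.
    for _ in range(steps):
        tokens, state = decode_one_step(tokens, state)
    return tokens, state

# Warmup: steps=1 compiles decode_one_step once; later calls with any steps
# reuse that same compilation.
tokens, state = generate(jnp.zeros((1, 8), jnp.int32), jnp.zeros((1, 8)), steps=1)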



class BlockType(str, Enum):
"""
jenriver (Member):

Please remove the unnecessary whitespace. Also, let's stick with either enum.auto() or str values to keep the enums consistent throughout the entire file.
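
For illustration, the two options side by side (a sketch; the member names are made up, not this PR's):

from enum import Enum, auto

class BlockTypeStr(str, Enum):   # str-valued style: members compare equal to strings
    SEQUENTIAL = "sequential"

class BlockTypeAuto(Enum):       # enum.auto() style: opaque integer values
    SEQUENTIAL = auto()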


if ln_type == LayerNormType.GEMMA_RMS:
return GemmaRMSNorm(
dim,
jenriver (Member):

This can fit on one line within the 120-char limit; please make full use of the 120-char limit. Ditto everywhere else applicable, including before the ValueError.

return GemmaRMSNorm(dim, epsilon=eps, use_bias=use_bias, rngs=rngs)

Please also remove the unnecessary whitespace that follows, wherever applicable.

# Tokenization helper
def tokenize(tokenizer, inputs: list[str]):
pad_id = tokenizer.pad_token_id
if pad_id is None:
jenriver (Member):

How about: pad_id = pad_id or tokenizer.eos_token_id  # fall back to eos
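
Applied to the helper above, a minimal sketch (the padding details are assumptions about a Hugging Face tokenizer, not this PR's exact code):

def tokenize(tokenizer, inputs: list[str]):
    # note: `or` also falls through if pad_token_id happens to be 0
    pad_id = tokenizer.pad_token_id or tokenizer.eos_token_id  # fall back to eos
    tokenizer.pad_token_id = pad_id  # ensure padding works if no pad token is set
    batch = tokenizer(inputs, return_tensors="np", padding=True)
    return batch["input_ids"]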

# Model + params
cfg = modeling.ModelConfig.llada_8b_instruct()

# Load weights into a constructed model
jenriver (Member):

Remove self-explanatory comments; ditto everywhere else applicable.

graphdef, state = nnx.split(model)


def model_step(tokens, state):
jenriver (Member):

Can we just have a single jitted forward that holds all the key logic in generate? Similar to what we had in Qwen3.
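
A minimal runnable sketch of that shape, with a toy nnx.Linear standing in for the model (the split/merge pattern mirrors the quoted nnx.split call; this is not the PR's actual code):

import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(8, 8, rngs=nnx.Rngs(0))  # stand-in for the real model
graphdef, state = nnx.split(model)

@jax.jit
def forward(tokens, state):
    # Single jitted forward holding the key logic; generate would call this
    # from an explicit loop.
    mdl = nnx.merge(graphdef, state)
    logits = mdl(tokens)
    return logits, nnx.state(mdl)

logits, state = forward(jnp.ones((1, 8)), state)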

],
)
def generate(
model_step,
jenriver (Member):

Please add type hints everywhere applicable. That said, we want to factor the step out of here, as noted in the other comments.

return (x, state, rng), None

(x, final_state, _), _ = lax.scan(do_block, (x, init_state, rng), xs=jnp.arange(num_blocks, dtype=jnp.int32))
return x, final_state
jenriver (Member):

Can we refactor the generate steps out into an explicit for loop? Something like what we have in Qwen3.

The current implementation makes it difficult to benchmark performance per step and to separate static argnums for jax.
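
As a self-contained sketch of the scan-to-loop change (toy do_block here; the PR's real do_block runs one generation block):

import jax.numpy as jnp
from jax import lax

def do_block(carry, block_idx):
    x, state = carry
    return (x + block_idx, state), None

x, state = jnp.zeros(4), jnp.zeros(4)
num_blocks = 3

# scan version, as in the quoted code:
(x_scan, _), _ = lax.scan(do_block, (x, state), jnp.arange(num_blocks, dtype=jnp.int32))

# explicit-loop version; each block can now be timed or profiled on its own:
x_loop, s = x, state
for i in range(num_blocks):
    (x_loop, s), _ = do_block((x_loop, s), jnp.int32(i))

assert jnp.allclose(x_scan, x_loop)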


# Demo queries
queries = [
jenriver (Member):

Could we test with multiple queries of varying sizes to ensure batching is correct?

ex:
queries = ["Why is the sky blue instead of any other color like purple?", "Who am I?"]

num_traced_runs=1,
num_profiled_runs=1,
)
print("Generated token IDs:", out_tokens)
jenriver (Member):

Could we decode the output, since this is an example model run, to check whether the output makes sense?

Example:
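
Something along these lines (a sketch, not the reviewer's original example; tokenizer, out_tokens, and queries are the script's existing names, and batch_decode assumes a Hugging Face tokenizer):

import numpy as np

texts = tokenizer.batch_decode(np.asarray(out_tokens), skip_special_tokens=True)
for query, text in zip(queries, texts):
    print(f"Q: {query}\nA: {text}\n")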


class LLaDAOutput(NamedTuple):
logits: jax.Array
"""
marlinfiggins (Collaborator):

Can simplify comments in this class.

jenriver (Member) commented Oct 6, 2025

Can you also add a quality logits checker (ex: test_outputs.py) to verify that the resulting logits are correct?
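
A rough shape for such a checker (a sketch only: the HF reference load via trust_remote_code and the tolerances are assumptions, and model_step/state refer to this PR's run_model.py names):

import jax.numpy as jnp
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

HF_REPO = "GSAI-ML/LLaDA-8B-Instruct"

def test_logits_match():
    tokenizer = AutoTokenizer.from_pretrained(HF_REPO, trust_remote_code=True)
    ref = AutoModel.from_pretrained(HF_REPO, trust_remote_code=True)
    ids = tokenizer("Why is the sky blue?", return_tensors="pt").input_ids
    with torch.no_grad():
        ref_logits = ref(ids).logits.numpy()
    # model_step / state come from run_model.py in this PR.
    jax_logits, _ = model_step(jnp.asarray(ids.numpy()), state)
    np.testing.assert_allclose(np.asarray(jax_logits), ref_logits, atol=1e-2, rtol=1e-2)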

jenriver (Member)

Hello, please make sure to add a test_outputs.py verifying that the implementations above are correct.

Ref:

HF_REPO = "GSAI-ML/LLaDA-8B-Instruct"


def _tokenize_chat_batch(tokenizer, prompts):
jenriver (Member):

Please keep this consistent with tokenize in run_model.py, as discussed offline.

Otherwise this PR looks ready.

jenriver (Member) left a comment

Thanks for this implementation!

Looks mostly good; minor stylistic changes can be addressed in follow-up CLs.

jenriver merged commit 983f756 into main on Oct 20, 2025
1 check passed