[JAX] Add tutorial for integrating TE/JAX quantization into an existing framework #2423
Conversation
Signed-off-by: Jeremy Berchtold <[email protected]>
Greptile Overview

Greptile Summary

This PR adds a comprehensive tutorial notebook for integrating TransformerEngine/JAX quantization into existing model frameworks, along with supporting API improvements and bug fixes to the quickstart tutorial.

The implementation is clean and well-documented. The tutorial provides clear examples of using different quantization recipes (DelayedScaling, NVFP4BlockScaling, etc.) and explains important considerations like checkpoint policies for TE GEMMs.

Confidence Score: 5/5
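The checkpoint-policy point is worth unpacking: under jax.remat, GEMM outputs are recomputed in the backward pass unless a policy marks them saveable. Below is a minimal sketch with plain JAX; whether this particular policy is what the tutorial recommends for TE GEMMs is an assumption, the notebook itself is authoritative.

```python
import jax
import jax.numpy as jnp

def mlp(params, x):
    # Two GEMMs; under rematerialization, the policy below controls whether
    # their outputs are saved for the backward pass or recomputed.
    h = jax.nn.gelu(jnp.dot(x, params["w1"]))
    return jnp.dot(h, params["w2"])

# Save dot-product (GEMM) results instead of recomputing them in backprop.
remat_mlp = jax.checkpoint(
    mlp, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable
)

params = {"w1": jnp.ones((64, 256)), "w2": jnp.ones((256, 64))}
loss = lambda p, x: jnp.sum(remat_mlp(p, x) ** 2)
grads = jax.grad(loss)(params, jnp.ones((8, 64)))
```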
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant Recipe as Quantization Recipe
    participant MakeDotGeneral as make_dot_general_cls
    participant TEWrapper as TEWrapper Module
    participant FlaxDense as nn.Dense
    participant TEDense as te.dense.dense
    participant Quantizer as QuantizerSet
    User->>Recipe: Create recipe (e.g., DelayedScaling())
    User->>MakeDotGeneral: make_dot_general_cls(recipe)
    MakeDotGeneral->>TEWrapper: wrap_function_in_te_state_module(te_dot_general)
    TEWrapper-->>MakeDotGeneral: Return TEWrapper class
    MakeDotGeneral-->>User: Return te_dot_general_cls
    User->>FlaxDense: nn.Dense(..., dot_general=te_dot_general_cls())
    Note over FlaxDense,TEDense: Forward Pass
    FlaxDense->>TEWrapper: Call dot_general(x, kernel, dims)
    TEWrapper->>TEWrapper: generate_quantizer_set()
    TEWrapper->>Quantizer: Create quantizer with recipe state
    Quantizer-->>TEWrapper: Return quantizer_set
    TEWrapper->>TEDense: dense(x, kernel, contracting_dims, quantizer_set)
    TEDense->>TEDense: Quantize inputs using quantizer_set
    TEDense->>TEDense: Execute quantized GEMM
    TEDense-->>TEWrapper: Return output
    TEWrapper-->>FlaxDense: Return output
    FlaxDense-->>User: Forward result
    Note over TEWrapper,Quantizer: Recipe State Management
    TEWrapper->>TEWrapper: Store amax_history, scales in Flax variables
    User->>User: Pass full var_collect (params + state) to next step
```
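Read as code, the diagram corresponds to roughly the following usage. This is a sketch, not the notebook's exact code: make_dot_general_cls is the helper this PR introduces (the import path below is illustrative), and the mutable-collection handling is simplified.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from transformer_engine.common import recipe

# Helper introduced by this PR's tutorial code; import path is illustrative.
from quickstart_jax_utils import make_dot_general_cls

te_recipe = recipe.DelayedScaling()                   # 1. pick a quantization recipe
te_dot_general_cls = make_dot_general_cls(te_recipe)  # 2. build the TE-backed dot_general

# 3. Drop it into an otherwise unmodified Flax layer.
model = nn.Dense(features=128, dot_general=te_dot_general_cls())

x = jnp.ones((8, 64), jnp.bfloat16)
variables = model.init(jax.random.PRNGKey(0), x)      # params + recipe state

# 4. Recipe state (amax histories, scales) lives outside 'params', so it must
#    be marked mutable and the updated state carried into the next step.
y, updated_state = model.apply(variables, x, mutable=True)
```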
1 file reviewed, 2 comments
| "utils.speedometer(\n", | ||
| " model_apply_fn=flax_transformer.apply,\n", | ||
| " variables=params,\n", |
style: Unresolved TODO comment about fixing sr_rng setup for NVFP4. Should this be resolved before publishing the tutorial?
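For context on this TODO: NVFP4 stochastic rounding needs an RNG at apply time, and in Flax named RNG streams are supplied via the rngs argument. A sketch reusing model and variables from the integration sketch earlier; the stream name 'sr_rng' comes from the TODO and may not match the final API.

```python
import jax

# `model` and `variables` as in the integration sketch above; the RNG stream
# name 'sr_rng' is taken from the TODO and is an assumption.
sr_key = jax.random.PRNGKey(42)
y, updated_state = model.apply(
    variables, x, rngs={"sr_rng": sr_key}, mutable=True
)
```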
| "cell_type": "code", | ||
| "execution_count": 3, | ||
| "id": "8b44649d", | ||
| "metadata": {}, |
syntax: Typo in comment: signma should be sigma
| "metadata": {}, | |
| # grad_target = derivative of L (loss fn) over y (output) = sigma(L)/sigma(y) |
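To make the suggested comment concrete: grad_target is the gradient of the loss with respect to the model output, i.e. the cotangent fed into the backward pass. A toy, standalone illustration:

```python
import jax
import jax.numpy as jnp

def loss_fn(y):
    # A stand-in loss over the model output y.
    return jnp.mean(y ** 2)

y = jnp.ones((4, 8))
grad_target = jax.grad(loss_fn)(y)  # dL/dy, same shape as y
```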
Signed-off-by: Jeremy Berchtold <[email protected]>
1 file reviewed, no comments
…oat16 cast Signed-off-by: Jeremy Berchtold <[email protected]>
5 files reviewed, 2 comments
    This method does a couple things:
    1. Wraps the given function in a context that specifies MaxText's physical mesh axes to TransformerEngine. This ensures our collective operations in TransformerEngine are using the correct axes.
style: Docstring incorrectly mentions MaxText-specific behavior. This is generic TE functionality, not MaxText-specific.
Suggested change:
- 1. Wraps the given function in a context that specifies MaxText's physical mesh axes to TransformerEngine. This ensures our collective operations in TransformerEngine are using the correct axes.
+ 1. Wraps the given function in a Flax linen module context to support TransformerEngine quantization operations.
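The mesh-axes context the original docstring refers to looks roughly like TE/JAX's global_shard_guard with a MeshResource; the axis names below are illustrative, not taken from this PR.

```python
from transformer_engine.jax.sharding import MeshResource, global_shard_guard

# Map TE's logical resources onto the physical mesh axis names
# (axis names "data"/"tensor" are assumptions for this sketch).
mesh_resource = MeshResource(dp_resource="data", tp_resource="tensor")

with global_shard_guard(mesh_resource):
    ...  # TE collectives inside this context use the named mesh axes
```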
    def make_dot_general_cls(quantization_recipe):
        """Placeholder for dot_general implementation in subclasses."""
style: Misleading docstring - says "Placeholder" but this is a complete implementation, not a placeholder.
| """Placeholder for dot_general implementation in subclasses.""" | |
| """Creates a dot_general_cls that wraps JAX dense operations with TransformerEngine quantization.""" |
Signed-off-by: Jeremy Berchtold <[email protected]>
Additional Comments (1)
- docs/examples/quickstart_jax_utils.py, line 79: syntax: Typo: signma should be sigma
5 files reviewed, 1 comment
/te-ci jax
…e-quant-integration-tutorial Signed-off-by: Jeremy Berchtold <[email protected]>
Signed-off-by: Jeremy Berchtold <[email protected]>
Signed-off-by: Jeremy Berchtold <[email protected]>
for more information, see https://pre-commit.ci
docs/examples/quickstart_jax.ipynb
Outdated
| " self_attn_bias_type='no_bias',\n", | ||
| " hidden_dropout=0.0\n", | ||
| " hidden_dropout=0.0,\n", | ||
| " dtype=jnp.bfloat16,\n", |
FIXME: This dtype=jnp.bfloat16 may make the Dense layers' model params bfloat16 instead of just attention. We first need to fix attention defaulting to the fp32 unfused backend even when no bias is provided.
Do you need to also change this in the TEFusedTransformerLayer?
TEFusedTransformerLayer is okay as is because I directly set dtype=bf16 on TE's DotProductAttention, which does not contain any Dense layers. The issue with the full TransformerLayer is that setting the dtype on it will likely set the dtype of both its inner DotProductAttention and any Dense layers (QKV proj, Out proj, MLP ffn1 and ffn2).
I realized the dtype on DotProductAttention is unnecessary because it doesn't allocate any model params itself. It just performs attention on the input Q, K, and V (and optional bias), which are already in bf16. So I'll probably try to merge that PR first, and then I can remove any usages of dtype=bf16 here altogether.
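A quick way to check the concern in this thread is to inspect parameter dtypes after init. A sketch with standard JAX/Flax; model and x here are stand-ins for the tutorial's transformer module and a sample batch.

```python
import jax

# `model` and `x` are hypothetical stand-ins for the tutorial's transformer
# module and a sample input batch.
variables = model.init(jax.random.PRNGKey(0), x)
param_dtypes = jax.tree_util.tree_map(lambda leaf: leaf.dtype, variables["params"])
print(param_dtypes)  # shows whether QKV/Out/MLP kernels became bf16 or stayed fp32
```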
/te-ci jax
Additional Comments (1)
- docs/examples/quickstart_jax_utils.py, line 79: syntax: Typo in comment: signma should be sigma
7 files reviewed, 1 comment
Signed-off-by: Jeremy Berchtold <[email protected]>
7 files reviewed, no comments
Signed-off-by: Jeremy Berchtold <[email protected]>
/te-ci
8 files reviewed, no comments
Signed-off-by: Jeremy Berchtold <[email protected]>
/te-ci
8 files reviewed, no comments
Signed-off-by: Jeremy Berchtold <[email protected]>
9 files reviewed, no comments
Force-pushed from d1f3dcd to e107f66
Signed-off-by: Jeremy Berchtold <[email protected]>
Force-pushed from e107f66 to 7d1f2fb
9 files reviewed, no comments
Signed-off-by: Jeremy Berchtold <[email protected]>
Force-pushed from acdc365 to 3550b96
/te-ci jax L1
for more information, see https://pre-commit.ci
9 files reviewed, no comments
Description
Adds a new notebook tutorial for integrating TE/JAX quantization into an existing model framework.