Conversation

Collaborator

@jberchtold-nvidia commented Nov 26, 2025

Description

Adds a new tutorial notebook showing how to use TE/JAX quantization in an existing model framework.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add new tutorial notebook for quantization integration into an existing model framework
  • Small fixes to quickstart_jax.ipynb to:
    • Add missing Out Proj to baseline Flax transformer layer
    • Use TE fused attention instead of unfused backend

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

greptile-apps bot commented Nov 26, 2025

Greptile Overview

Greptile Summary

This PR adds a comprehensive tutorial notebook for integrating TransformerEngine/JAX quantization into existing model frameworks, along with supporting API improvements and bug fixes to the quickstart tutorial.

Key Changes:

  • New te_jax_integration.ipynb tutorial demonstrating how to add TE quantization to existing Flax models without restructuring code or changing parameter management
  • New public APIs make_dot_general_cls() and wrap_function_in_te_state_module() that enable drop-in quantization support via nn.Dense(..., dot_general=te_dot_general_cls())
  • Fixed missing output projection in the baseline Flax transformer in quickstart_jax.ipynb
  • Refactored utils to support multiple RNG keys (dropout, sr_rng) via generic rngs dict parameter
  • Improved DotProductAttention dtype handling: removed hardcoded dtype parameter and added runtime dtype validation via _assert_dtypes(), making it infer dtypes from inputs
  • Added is_mesh_available() helper to prevent errors when accessing mesh resources without an active mesh

The implementation is clean and well-documented. The tutorial provides clear examples of using different quantization recipes (DelayedScaling, NVFP4BlockScaling, etc.) and explains important considerations like checkpoint policies for TE GEMMs.
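As a reference for the drop-in pattern described above, here is a minimal sketch assuming make_dot_general_cls() is exposed under transformer_engine.jax.flax and that recipes come from transformer_engine.common.recipe; exact import paths, argument names, and state handling may differ from the final tutorial.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import transformer_engine.jax.flax as te_flax           # assumed home of make_dot_general_cls
from transformer_engine.common import recipe

# Pick a quantization recipe (DelayedScaling, NVFP4BlockScaling, ...).
te_dot_general_cls = te_flax.make_dot_general_cls(recipe.DelayedScaling())

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Existing Flax layers pick up TE quantization through dot_general,
        # without restructuring the model or its parameter pytree.
        x = nn.Dense(4096, dot_general=te_dot_general_cls())(x)
        x = nn.gelu(x)
        return nn.Dense(1024, dot_general=te_dot_general_cls())(x)

x = jnp.ones((8, 1024), dtype=jnp.bfloat16)
var_collect = MLP().init(jax.random.PRNGKey(0), x)      # params plus TE recipe state

# TE updates amax histories/scales during the pass, so apply with mutable
# collections and carry the full var_collect forward to the next step.
y, updated_state = MLP().apply(var_collect, x, mutable=True)
var_collect = {**var_collect, **updated_state}
```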

Confidence Score: 5/5

  • This PR is safe to merge with no issues found
  • All changes are well-implemented documentation and API improvements. The new tutorial is comprehensive and accurate, code changes are clean with proper error handling, and fixes to the quickstart tutorial improve correctness (adding missing output projection). Previous review concerns have been addressed.
  • No files require special attention

Important Files Changed

File Analysis

| Filename | Score | Overview |
| --- | --- | --- |
| docs/examples/te_jax_integration.ipynb | 5/5 | New tutorial notebook demonstrating TE/JAX quantization integration into existing frameworks. Content is clear and well-structured, with working examples. |
| transformer_engine/jax/flax/module.py | 5/5 | Adds wrap_function_in_te_state_module and make_dot_general_cls functions to enable quantization for existing models. Clean implementation. |
| transformer_engine/jax/flax/transformer.py | 5/5 | Removes hardcoded dtype parameter, adds dtype validation via _assert_dtypes, and uses runtime dtypes from inputs instead. Improves flexibility. |
| docs/examples/quickstart_jax.ipynb | 5/5 | Fixes missing output projection in baseline transformer, switches to TE fused attention, and updates to use a generic rngs dict instead of a single dropout_key. |
| docs/examples/quickstart_jax_utils.py | 5/5 | Refactors speedometer to support multiple RNG keys via an rngs dict parameter, adding a _split_step_rngs helper for proper RNG management. |
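To illustrate the rngs-dict refactor mentioned for quickstart_jax_utils.py, a rough sketch follows; the key names and the split_step_rngs helper below are assumptions for illustration and may not match the actual _split_step_rngs in the PR.

```python
import jax

# Hypothetical RNG dict replacing the old single dropout_key; "sr_rng" would
# feed stochastic rounding for recipes such as NVFP4BlockScaling.
rngs = {
    "dropout": jax.random.PRNGKey(1),
    "sr_rng": jax.random.PRNGKey(2),
}

def split_step_rngs(rngs):
    """Split every key so each training step sees fresh, independent RNGs."""
    step_rngs, next_rngs = {}, {}
    for name, key in rngs.items():
        next_rngs[name], step_rngs[name] = jax.random.split(key)
    return step_rngs, next_rngs

step_rngs, rngs = split_step_rngs(rngs)
# step_rngs is what gets passed to model.apply(..., rngs=step_rngs) for this step.
```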

Sequence Diagram

sequenceDiagram
    participant User
    participant Recipe as Quantization Recipe
    participant MakeDotGeneral as make_dot_general_cls
    participant TEWrapper as TEWrapper Module
    participant FlaxDense as nn.Dense
    participant TEDense as te.dense.dense
    participant Quantizer as QuantizerSet

    User->>Recipe: Create recipe (e.g., DelayedScaling())
    User->>MakeDotGeneral: make_dot_general_cls(recipe)
    MakeDotGeneral->>TEWrapper: wrap_function_in_te_state_module(te_dot_general)
    TEWrapper-->>MakeDotGeneral: Return TEWrapper class
    MakeDotGeneral-->>User: Return te_dot_general_cls
    
    User->>FlaxDense: nn.Dense(..., dot_general=te_dot_general_cls())
    
    Note over FlaxDense,TEDense: Forward Pass
    FlaxDense->>TEWrapper: Call dot_general(x, kernel, dims)
    TEWrapper->>TEWrapper: generate_quantizer_set()
    TEWrapper->>Quantizer: Create quantizer with recipe state
    Quantizer-->>TEWrapper: Return quantizer_set
    TEWrapper->>TEDense: dense(x, kernel, contracting_dims, quantizer_set)
    TEDense->>TEDense: Quantize inputs using quantizer_set
    TEDense->>TEDense: Execute quantized GEMM
    TEDense-->>TEWrapper: Return output
    TEWrapper-->>FlaxDense: Return output
    FlaxDense-->>User: Forward result
    
    Note over TEWrapper,Quantizer: Recipe State Management
    TEWrapper->>TEWrapper: Store amax_history, scales in Flax variables
    User->>User: Pass full var_collect (params + state) to next step

greptile-apps bot left a comment

1 file reviewed, 2 comments

Comment on lines +234 to +236
"utils.speedometer(\n",
" model_apply_fn=flax_transformer.apply,\n",
" variables=params,\n",

style: Unresolved TODO comment about fixing sr_rng setup for NVFP4. Should this be resolved before publishing the tutorial?

"cell_type": "code",
"execution_count": 3,
"id": "8b44649d",
"metadata": {},

syntax: Typo in comment: signma should be sigma

Suggested change
"metadata": {},
# grad_target = derivative of L (loss fn) over y (output) = sigma(L)/sigma(y)

Signed-off-by: Jeremy Berchtold <[email protected]>
greptile-apps bot left a comment

1 file reviewed, no comments

greptile-apps bot left a comment

5 files reviewed, 2 comments

This method does a couple things:
1. Wraps the given function in a context that specifies MaxText's physical mesh axes to TransformerEngine. This ensures our collective operations in TransformerEngine are using the correct axes.

style: Docstring incorrectly mentions MaxText-specific behavior. This is generic TE functionality, not MaxText-specific.

Suggested change
1. Wraps the given function in a context that specifies MaxText's physical mesh axes to TransformerEngine. This ensures our collective operations in TransformerEngine are using the correct axes.
1. Wraps the given function in a Flax linen module context to support TransformerEngine quantization operations.
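To make the corrected wording concrete, here is a generic, hypothetical illustration of wrapping a function in a Flax linen module so it can own quantization-style state; this is not the actual wrap_function_in_te_state_module implementation, and the collection and variable names are made up.

```python
import jax.numpy as jnp
import flax.linen as nn

def make_wrapper(fn):
    """Return a linen Module class whose __call__ runs `fn` while owning Flax state."""
    class Wrapper(nn.Module):
        @nn.compact
        def __call__(self, *args):
            # Inside a linen module we can keep persistent state in a variable
            # collection (here a toy amax history, as a quantizer might need).
            amax_history = self.variable(
                "quantization_state", "amax_history", lambda: jnp.zeros((16,))
            )
            out = fn(*args)
            amax_history.value = jnp.roll(amax_history.value, 1).at[0].set(
                jnp.max(jnp.abs(out))
            )
            return out
    return Wrapper
```

Applying a module built this way requires marking the state collection mutable, e.g. apply(variables, x, mutable=["quantization_state"]).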



def make_dot_general_cls(quantization_recipe):
"""Placeholder for dot_general implementation in subclasses."""

style: Misleading docstring - says "Placeholder" but this is a complete implementation, not a placeholder.

Suggested change
"""Placeholder for dot_general implementation in subclasses."""
"""Creates a dot_general_cls that wraps JAX dense operations with TransformerEngine quantization."""

Signed-off-by: Jeremy Berchtold <[email protected]>
greptile-apps bot left a comment

Additional Comments (1)

  1. docs/examples/quickstart_jax_utils.py, line 79

    syntax: Typo: signma should be sigma

5 files reviewed, 1 comment


@jberchtold-nvidia changed the title from "[Draft][JAX] Tutorial for integration TE/JAX quantization into an existing framework" to "[JAX] Tutorial for integration TE/JAX quantization into an existing framework" on Nov 26, 2025
@jberchtold-nvidia changed the title from "[JAX] Tutorial for integration TE/JAX quantization into an existing framework" to "[JAX] Add tutorial for integrating TE/JAX quantization into an existing framework" on Nov 26, 2025
@jberchtold-nvidia
Collaborator Author

/te-ci jax

" self_attn_bias_type='no_bias',\n",
" hidden_dropout=0.0\n",
" hidden_dropout=0.0,\n",
" dtype=jnp.bfloat16,\n",
Collaborator Author

FIXME: This dtype=jnp.bfloat16 may make the Dense layers' model params bfloat16 instead of just attention. We first need to fix attention defaulting to the fp32 unfused backend even when no bias is present.

Collaborator

Do you need to also change this in the TEFusedTransformerLayer?

Collaborator Author

@jberchtold-nvidia commented Dec 5, 2025

TEFusedTransformerLayer is okay as is because I directly set dtype=bf16 on TE's DotProductAttention, which does not contain any Dense layers. The issue with the full TransformerLayer is that setting the dtype on it will likely set the dtype of both its inner DotProductAttention and any Dense layers (QKV proj, Out Proj, MLP ffn1 and ffn2).

I realized the dtype on DotProductAttention is unnecessary because it doesn't allocate any model params itself; it just performs attention on the input Q, K, and V (and optional bias), which are already in bf16. So I'll probably try to merge that PR first, and then I can remove any usage of dtype=bf16 here altogether.

@jberchtold-nvidia
Collaborator Author

/te-ci jax

greptile-apps bot left a comment

Additional Comments (1)

  1. docs/examples/quickstart_jax_utils.py, line 79

    syntax: Typo in comment: signma should be sigma

7 files reviewed, 1 comment


Signed-off-by: Jeremy Berchtold <[email protected]>
greptile-apps bot left a comment

7 files reviewed, no comments

Signed-off-by: Jeremy Berchtold <[email protected]>
@jberchtold-nvidia
Collaborator Author

/te-ci

greptile-apps bot left a comment

8 files reviewed, no comments

Signed-off-by: Jeremy Berchtold <[email protected]>
@jberchtold-nvidia
Collaborator Author

/te-ci

greptile-apps bot left a comment

8 files reviewed, no comments

greptile-apps bot left a comment

9 files reviewed, no comments

@jberchtold-nvidia force-pushed the jberchtold/te-quant-integration-tutorial branch from d1f3dcd to e107f66 on December 8, 2025 17:47
Signed-off-by: Jeremy Berchtold <[email protected]>
@jberchtold-nvidia force-pushed the jberchtold/te-quant-integration-tutorial branch from e107f66 to 7d1f2fb on December 8, 2025 17:48
greptile-apps bot left a comment

9 files reviewed, no comments

Signed-off-by: Jeremy Berchtold <[email protected]>
@jberchtold-nvidia force-pushed the jberchtold/te-quant-integration-tutorial branch from acdc365 to 3550b96 on December 8, 2025 18:30
@jberchtold-nvidia
Collaborator Author

/te-ci jax L1

greptile-apps bot left a comment

9 files reviewed, no comments
