A reusable collection of high-performance neural network layers and models for JAX, aiming to match or exceed the capabilities available in the PyTorch ecosystem.
JAX Layers was created to provide the JAX ecosystem with a comprehensive library of well-documented, thoroughly tested, and numerically accurate implementations of neural network layers and models. The project aims to:
- Provide both functional APIs and Flax NNX wrappers for maximum flexibility
- Ensure seamless integration with the broader JAX ecosystem, especially Flax
- Facilitate easy upstreaming of implementations to core libraries
- Maintain rigorous testing and documentation standards
- Match or exceed the performance of equivalent PyTorch implementations
The project was started within the ML GDE group and begins with a high-performance MultiHeadAttention implementation that supports multiple attention backends, with plans to expand to more layers and models.
- MultiHeadAttention: A Flax NNX-compatible implementation with support for different attention backends.
- Supports JAX's native Flash Attention implementation through cuDNN
- Seamlessly integrates with Flax NNX's module system
- Provides a simple interface for switching between attention implementations (a minimal sketch follows below)
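As a minimal sketch of that switching interface: the backend is chosen per module instance with a single argument. The "cudnn" string appears in the Quick Start below; "xla" as a portable fallback is an assumption about the supported backend names.
import flax.nnx as nnx
from jax_layers.attention import MultiHeadAttention
# Same module, different backend: only the implementation string changes.
mha_cudnn = MultiHeadAttention(num_heads=8, in_features=512, implementation="cudnn", rngs=nnx.Rngs(0))
mha_xla = MultiHeadAttention(num_heads=8, in_features=512, implementation="xla", rngs=nnx.Rngs(0))  # "xla" backend name assumed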
# Install from source
git clone https://github.com/ml-gde/jax-layers.git
cd jax-layers
pip install -e .
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from jax_layers.attention import MultiHeadAttention
# Create a MultiHeadAttention module with Flash Attention support
attention = MultiHeadAttention(
    num_heads=8,
    in_features=512,
    implementation="cudnn",  # Use cuDNN's Flash Attention if available
    rngs=nnx.Rngs(0),
)
# Create input data
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (2, 128, 512)) # (batch, seq_length, hidden_dim)
# Create a causal attention mask
mask = jnp.tril(jnp.ones((2, 1, 128, 128))) # (batch, 1, q_len, kv_len)
# Apply the model
output = attention(x, mask=mask)
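Because the module integrates with Flax NNX, the forward pass can also be compiled with nnx.jit. The following is a small sketch of that usage (not part of the library itself), reusing the attention, x, and mask objects created above:
# Compile the attention forward pass; the NNX module is passed as a regular argument.
@nnx.jit
def run_attention(model, x, mask):
    return model(x, mask=mask)
output = run_attention(attention, x, mask)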
import jax
import jax.numpy as jnp
from jax_layers.functional import dot_product_attention
# Create random query, key, and value tensors (split the PRNG key so they differ)
key = jax.random.PRNGKey(0)
q_key, k_key, v_key = jax.random.split(key, 3)
query = jax.random.normal(q_key, (2, 128, 8, 64))  # (batch, seq_len, heads, head_dim)
key_tensor = jax.random.normal(k_key, (2, 128, 8, 64))
value = jax.random.normal(v_key, (2, 128, 8, 64))
# Create a causal attention mask
mask = jnp.tril(jnp.ones((2, 1, 128, 128))) # (batch, 1, q_len, kv_len)
# Apply dot product attention with Flash Attention implementation
output = dot_product_attention(
    query=query,
    key=key_tensor,
    value=value,
    mask=mask,
    implementation="cudnn",  # Use cuDNN's Flash Attention implementation
)
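The functional form composes with jax.jit as well. Here is a small sketch (reusing the tensors defined above) that closes over the implementation string rather than passing it as a traced argument:
# JIT-compile the attention call; the backend choice is fixed at trace time.
attend = jax.jit(
    lambda q, k, v, m: dot_product_attention(
        query=q, key=k, value=v, mask=m, implementation="cudnn"
    )
)
output = attend(query, key_tensor, value, mask)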
- Fork the repository to your own account first.
- Then follow the instructions below.
# Clone the repository
git clone https://github.com/yourusername/jax-layers.git
cd jax-layers
# Install development dependencies
pip install -e ".[dev]"
The project maintains a comprehensive test suite to ensure correctness and numerical accuracy (a sketch of such a check follows the commands below):
# Run all tests
pytest
# Run tests with coverage
pytest tests/ --cov=jax_layers
# Run specific test file
pytest tests/test_multi_head_attention.py
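As a rough illustration of the kind of numerical check involved (a sketch, not a test taken from the suite, and assuming an "xla" backend name), the functional API can be compared against jax.nn.dot_product_attention:
import jax
import jax.numpy as jnp
from jax_layers.functional import dot_product_attention
key = jax.random.PRNGKey(0)
q_key, k_key, v_key = jax.random.split(key, 3)
q = jax.random.normal(q_key, (2, 16, 4, 32))  # (batch, seq_len, heads, head_dim)
k = jax.random.normal(k_key, (2, 16, 4, 32))
v = jax.random.normal(v_key, (2, 16, 4, 32))
# Compare against JAX's reference attention; the tolerance is illustrative.
ours = dot_product_attention(query=q, key=k, value=v, implementation="xla")  # "xla" backend assumed
reference = jax.nn.dot_product_attention(q, k, v)
assert jnp.allclose(ours, reference, atol=1e-5)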
We maintain high code quality standards through automated checks:
# Run linting
ruff check .
# Run type checking
mypy jax_layers
# Run tests
pytest
Documentation is automatically generated from docstrings:
# Build documentation
cd docs
make html
Since JAX doesn't support CUDA on Windows natively, we provide a development container configuration:
- Install Docker Desktop with WSL 2 backend
- Install NVIDIA Container Toolkit
- Install Visual Studio Code with the Remote - Containers extension
- Open the project in VS Code
- Click the green icon in the bottom-left corner and select "Reopen in Container"
The container provides:
- Python 3.10
- CUDA 12.4 with cuDNN 9
- JAX with CUDA support
- All project dependencies from pyproject.toml
See .devcontainer/README.md for more details.
Contributions are more than welcome! Whether it's:
- Adding new layer implementations
- Improving documentation
- Adding tests
- Reporting bugs
- Suggesting improvements
Please feel free to open issues and pull requests.
This project is licensed under the MIT License - see the LICENSE file for details.
- Thanks to the JAX and Flax teams for their excellent libraries.
- Special thanks to the ML GDE group for initiating this project.