
JAX Layers


A reusable collection of high-performance neural network layers and models for JAX, aiming to match and exceed the capabilities available in the PyTorch ecosystem.

Motivation

JAX Layers was created to provide the JAX ecosystem with a comprehensive library of well-documented, thoroughly tested, and numerically accurate implementations of neural network layers and models. The project aims to:

  • Provide both functional APIs and Flax NNX wrappers for maximum flexibility
  • Ensure seamless integration with the broader JAX ecosystem, especially Flax
  • Facilitate easy upstreaming of implementations to core libraries
  • Maintain rigorous testing and documentation standards
  • Match or exceed the performance of equivalent PyTorch implementations

The project started within the ML GDE group with a high-performance MultiHeadAttention implementation supporting multiple attention backends, with plans to expand to more layers and models.

Features

  • MultiHeadAttention: A Flax NNX-compatible implementation with support for different attention backends.
    • Supports JAX's native Flash Attention implementation through cuDNN
    • Seamlessly integrates with Flax NNX's module system
    • Provides a simple interface for switching between attention implementations

Installation

# Install from source
git clone https://github.com/ml-gde/jax-layers.git
cd jax-layers
pip install -e .
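
The cuDNN Flash Attention backend requires a CUDA-enabled JAX build, which is not installed by default. On a supported Linux/WSL setup, one way to add it (assuming CUDA 12) is:

# Optional: JAX with NVIDIA GPU support (CUDA 12 wheels)
pip install -U "jax[cuda12]"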

Usage

MultiHeadAttention Module (Flax NNX)

import jax
import jax.numpy as jnp
import flax.nnx as nnx
from jax_layers.attention import MultiHeadAttention

# Create a MultiHeadAttention module with Flash Attention support
attention = MultiHeadAttention(
    num_heads=8,
    in_features=512,
    implementation="cudnn",  # Use cuDNN's Flash Attention if available
    rngs=nnx.Rngs(0),
)

# Create input data
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (2, 128, 512))  # (batch, seq_length, hidden_dim)

# Create a causal attention mask
mask = jnp.tril(jnp.ones((2, 1, 128, 128)))  # (batch, 1, q_len, kv_len)

# Apply the model
output = attention(x, mask=mask)
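
Because the module follows Flax NNX conventions, it can be composed inside a larger nnx.Module like any other layer. The block below is only an illustrative sketch: the TransformerBlock class and its hyperparameters are not part of jax_layers, and it assumes the single-input self-attention call shown above.

class TransformerBlock(nnx.Module):
    """Minimal pre-norm transformer block built around MultiHeadAttention."""

    def __init__(self, hidden_dim: int, num_heads: int, *, rngs: nnx.Rngs):
        self.norm1 = nnx.LayerNorm(hidden_dim, rngs=rngs)
        self.norm2 = nnx.LayerNorm(hidden_dim, rngs=rngs)
        self.attention = MultiHeadAttention(
            num_heads=num_heads,
            in_features=hidden_dim,
            implementation="cudnn",
            rngs=rngs,
        )
        self.mlp = nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)

    def __call__(self, x, mask=None):
        # Attention sub-layer with a residual connection
        x = x + self.attention(self.norm1(x), mask=mask)
        # Feed-forward sub-layer with a residual connection
        return x + self.mlp(self.norm2(x))

block = TransformerBlock(hidden_dim=512, num_heads=8, rngs=nnx.Rngs(0))
y = block(x, mask=mask)  # reuses x and mask from the example above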

Functional API

Dot Product Attention with Implementation Selection

import jax
import jax.numpy as jnp
from jax_layers.functional import dot_product_attention

# Create random query, key, value tensors (use separate PRNG keys so they differ)
key = jax.random.PRNGKey(0)
q_key, k_key, v_key = jax.random.split(key, 3)
query = jax.random.normal(q_key, (2, 128, 8, 64))  # (batch, seq_len, heads, head_dim)
key_tensor = jax.random.normal(k_key, (2, 128, 8, 64))
value = jax.random.normal(v_key, (2, 128, 8, 64))

# Create a causal attention mask
mask = jnp.tril(jnp.ones((2, 1, 128, 128)))  # (batch, 1, q_len, kv_len)

# Apply dot product attention with Flash Attention implementation
output = dot_product_attention(
    query=query,
    key=key_tensor,
    value=value,
    mask=mask,
    implementation="cudnn",  # Use cuDNN's Flash Attention implementation
)
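
The functional form composes with standard JAX transformations. As a minimal sketch (assuming the keyword interface shown above), the call can be wrapped in jax.jit with the implementation choice bound ahead of time:

from functools import partial

# Bind the implementation choice, then JIT-compile the attention call
attention_fn = jax.jit(partial(dot_product_attention, implementation="cudnn"))
output = attention_fn(query=query, key=key_tensor, value=value, mask=mask)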

Development

Setup

  1. Fork the repository to your account.
  2. Clone your fork and install the development dependencies:
# Clone your fork
git clone https://github.com/yourusername/jax-layers.git
cd jax-layers

# Install development dependencies
pip install -e ".[dev]"

Testing

The project maintains a comprehensive test suite to ensure correctness and numerical accuracy:

# Run all tests
pytest

# Run tests with coverage
pytest tests/ --cov=jax_layers

# Run specific test file
pytest tests/test_multi_head_attention.py
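
As an illustration of the kind of numerical-accuracy check such a suite typically contains, a test might compare the library's functional attention against JAX's reference implementation. This is only a sketch: the test name, shapes, and tolerance are assumptions, and it uses jax.nn.dot_product_attention (available in recent JAX releases) as the reference.

import jax
import jax.numpy as jnp
from jax_layers.functional import dot_product_attention

def test_matches_jax_reference():
    # Small shapes keep the test fast: (batch, seq_len, heads, head_dim)
    q_key, k_key, v_key = jax.random.split(jax.random.PRNGKey(0), 3)
    query = jax.random.normal(q_key, (2, 16, 4, 8))
    key = jax.random.normal(k_key, (2, 16, 4, 8))
    value = jax.random.normal(v_key, (2, 16, 4, 8))

    # Default (non-cuDNN) path compared against JAX's built-in reference
    expected = jax.nn.dot_product_attention(query, key, value)
    actual = dot_product_attention(query=query, key=key, value=value)

    assert jnp.allclose(actual, expected, atol=1e-5)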

Code Quality

We maintain high code quality standards through automated checks:

# Run linting
ruff check .

# Run type checking
mypy jax_layers

# Run tests
pytest

Documentation

Documentation is automatically generated from docstrings:

# Build documentation
cd docs
make html
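
Since the reference pages are built from docstrings, new layers should ship with complete docstrings. A sketch of the expected shape, assuming Google-style formatting (the scale_query helper below is illustrative only, not part of the library):

import jax
import jax.numpy as jnp

def scale_query(query: jax.Array, head_dim: int) -> jax.Array:
    """Scale the query by 1/sqrt(head_dim) before computing attention scores.

    Args:
        query: Query tensor of shape (batch, seq_len, heads, head_dim).
        head_dim: Size of each attention head.

    Returns:
        The scaled query tensor with the same shape as the input.
    """
    return query / jnp.sqrt(head_dim)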

Development Container (for Windows users)

Since JAX doesn't support CUDA on Windows natively, we provide a development container configuration:

  1. Install Docker Desktop with WSL 2 backend
  2. Install NVIDIA Container Toolkit
  3. Install Visual Studio Code with the Remote - Containers extension
  4. Open the project in VS Code
  5. Click the green icon in the bottom-left corner and select "Reopen in Container"

The container provides:

  • Python 3.10
  • CUDA 12.4 with cuDNN 9
  • JAX with CUDA support
  • All dependencies from the project's pyproject.toml

See .devcontainer/README.md for more details.

Contributing

Contributions are more than welcome, whether that means:

  • Adding new layer implementations
  • Improving documentation
  • Adding tests
  • Reporting bugs
  • Suggesting improvements

Please feel free to open issues and pull requests.

License

This project is licensed under the MIT License - see the LICENSE file for details.

Acknowledgements

  • Thanks to the JAX and Flax teams for their excellent libraries.
  • Special thanks to the ML GDE group for initiating this project.
