SGL-JAX is a high-performance, JAX-based inference engine for Large Language Models (LLMs), specifically optimized for Google TPUs. It is engineered from the ground up to deliver exceptional throughput and low latency for the most demanding LLM serving workloads.
The engine incorporates state-of-the-art techniques to maximize hardware utilization and serving efficiency, making it ideal for deploying large-scale models in production on TPUs.
- High-Throughput Continuous Batching: Implements a sophisticated scheduler that dynamically batches incoming requests, maximizing TPU utilization and overall throughput.
- Optimized KV Cache with Radix Tree: Manages the KV cache with a radix tree keyed on token prefixes, so requests that share a prompt prefix reuse the same cached KV entries instead of recomputing them. This reduces both memory use and prefill computation for prompts with common prefixes.
- FlashAttention Integration: Leverages a high-performance FlashAttention kernel for faster and more memory-efficient attention calculations, crucial for long sequences.
- Tensor Parallelism: Natively supports tensor parallelism to distribute large models across multiple TPU devices, enabling inference for models that exceed the memory of a single accelerator.
- OpenAI-Compatible API: Provides a drop-in replacement for the OpenAI API, allowing for seamless integration with a wide range of existing clients, SDKs, and tools (e.g., LangChain, LlamaIndex).
- Native Qwen Support: Includes first-class, optimized support for the Qwen model family, including recent Mixture-of-Experts (MoE) variants.
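The prefix-sharing idea behind the radix-tree KV cache can be illustrated with a toy radix tree that maps token-id prefixes to KV-cache slot indices. This is a minimal sketch under stated assumptions, not SGL-JAX's implementation: the `RadixPrefixCache` class, the `Node` type, and the integer "KV slot" handles are hypothetical, and a real cache would also handle eviction, reference counting, and device memory.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class Node:
    # Edge label leading into this node: a run of token ids.
    key: Tuple[int, ...] = ()
    # Hypothetical KV-cache slot handles, one per token in `key`.
    kv: Tuple[int, ...] = ()
    # Children indexed by the first token of their edge label.
    children: Dict[int, "Node"] = field(default_factory=dict)


def _common_len(a, b) -> int:
    """Length of the shared prefix of two token sequences."""
    n = 0
    while n < len(a) and n < len(b) and a[n] == b[n]:
        n += 1
    return n


class RadixPrefixCache:
    """Toy radix tree from token prefixes to KV slots (hypothetical)."""

    def __init__(self) -> None:
        self.root = Node()

    def match_prefix(self, tokens: List[int]) -> List[int]:
        """Return the KV slots covering the longest cached prefix of `tokens`."""
        node, i, slots = self.root, 0, []
        while i < len(tokens) and tokens[i] in node.children:
            child = node.children[tokens[i]]
            n = _common_len(child.key, tokens[i:])
            slots.extend(child.kv[:n])
            if n < len(child.key):
                break  # diverged in the middle of an edge
            node, i = child, i + n
        return slots

    def insert(self, tokens: List[int], kv: List[int]) -> None:
        """Insert a sequence with its KV slots, splitting edges where paths diverge.

        Tokens already in the tree keep their existing (shared) KV slots;
        only the new suffix stores the slots passed in `kv`.
        """
        node, i = self.root, 0
        while i < len(tokens):
            first = tokens[i]
            if first not in node.children:
                node.children[first] = Node(tuple(tokens[i:]), tuple(kv[i:]))
                return
            child = node.children[first]
            n = _common_len(child.key, tokens[i:])
            if n < len(child.key):
                # Split the edge: `mid` holds the shared part, `child` keeps the rest.
                mid = Node(child.key[:n], child.kv[:n])
                child.key, child.kv = child.key[n:], child.kv[n:]
                mid.children[child.key[0]] = child
                node.children[first] = mid
                child = mid
            node, i = child, i + n


if __name__ == "__main__":
    cache = RadixPrefixCache()
    cache.insert([1, 2, 3, 4], [10, 11, 12, 13])  # first request's prompt
    cache.insert([1, 2, 5], [10, 11, 20])         # shares the prefix [1, 2]
    print(cache.match_prefix([1, 2, 3, 9]))       # [10, 11, 12]
```

A new request whose prompt starts with `[1, 2, 3]` only needs prefill for the tokens after the matched prefix; the KV entries for the shared part are reused rather than recomputed.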
SGL-JAX operates on a distributed architecture designed for scalability and performance:
- HTTP Server: The entry point for all requests, compatible with the OpenAI API standard.
- Scheduler: The core of the engine. It receives requests, manages their prompts, and schedules token generation, grouping requests into efficient batches for the model executor.
- TP Worker (Tensor Parallel Worker): A set of workers that host the model weights, sharded across devices via tensor parallelism, and execute the model's forward pass.
- Model Runner: Manages the actual JAX-based model execution, including the forward pass, attention computation, and KV cache operations.
- Radix Cache: A global, memory-efficient KV cache that is shared across all requests, enabling prefix reuse and reducing the memory footprint.
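The Scheduler's continuous batching can be sketched as a loop that refills free batch slots from a waiting queue at every step, instead of waiting for the whole batch to finish. This is a simplified, hypothetical model: the `Request` type, the `schedule` function, and the `max_batch_size` limit are illustrative names, and a production scheduler would typically also account for prefill versus decode work and KV-cache memory.

```python
from collections import deque
from dataclasses import dataclass
from typing import List


@dataclass
class Request:
    rid: str
    max_new_tokens: int  # assumed >= 1
    generated: int = 0


def schedule(requests: List[Request], max_batch_size: int) -> List[List[str]]:
    """Simulate continuous batching; return the request ids in each step's batch."""
    assert max_batch_size >= 1
    waiting = deque(requests)
    running: List[Request] = []
    trace = []
    while waiting or running:
        # Admit waiting requests as soon as slots free up (the "continuous" part).
        while waiting and len(running) < max_batch_size:
            running.append(waiting.popleft())
        trace.append([r.rid for r in running])
        # One forward step: every running request produces one token.
        for r in running:
            r.generated += 1
        # Retire finished requests immediately, freeing their batch slots.
        running = [r for r in running if r.generated < r.max_new_tokens]
    return trace


if __name__ == "__main__":
    reqs = [Request("a", 2), Request("b", 1), Request("c", 2)]
    print(schedule(reqs, max_batch_size=2))  # [['a', 'b'], ['a', 'c'], ['c']]
```

Request `c` joins the batch in step 2, as soon as `b` finishes, rather than waiting for `a` to complete as it would under static batching.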
For more features and usage details, please read the documents in the docs/ directory.
SGL-JAX is designed for easy extension to new model architectures. It currently provides first-class, optimized support for:
- Qwen
- Qwen 3
- Qwen 3 MoE
For detailed performance evaluation and to run the benchmarks yourself, please see the scripts in the benchmark/ and python/sgl_jax/ directories (e.g., bench_serving.py).
The project includes a comprehensive test suite to ensure correctness and stability. To run the full suite of tests:
```bash
cd test/srt
python run_suite.py
```
Contributions are welcome! If you would like to contribute, please feel free to open an issue to discuss your ideas or submit a pull request.