jax-ml

Pushing back the limits on numerical computing.

Pinned

  1. jax Public

    Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

    Python · 33.7k stars · 3.2k forks

  2. jax-llm-examples Public

    Minimal yet performant LLM examples in pure JAX

    Python · 185 stars · 23 forks

  3. jax-triton Public

    jax-triton contains integrations between JAX and OpenAI Triton

    Python · 428 stars · 51 forks

  4. scaling-book Public

    Home for "How To Scale Your Model", a short blog-style textbook about scaling LLMs on TPUs

    HTML · 654 stars · 94 forks

  5. ml_dtypes Public

    A stand-alone implementation of several NumPy dtype extensions used in machine learning.

    C++ · 301 stars · 48 forks

Repositories

Showing 10 of 14 repositories
  • jax Public

    Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

    Python · 33,699 stars · Apache-2.0 · 3,197 forks · 1,611 issues (5 need help) · 570 pull requests · Updated Oct 15, 2025
  • jax-tpu-embedding Public
    Python · 23 stars · Apache-2.0 · 3 forks · 1 issue · 13 pull requests · Updated Oct 14, 2025
  • jax-triton Public

    jax-triton contains integrations between JAX and OpenAI Triton

    Python · 428 stars · Apache-2.0 · 51 forks · 9 issues · 25 pull requests · Updated Oct 14, 2025
  • bonsai Public

    Minimal, lightweight JAX implementations of popular models.

    Jupyter Notebook · 112 stars · Apache-2.0 · 16 forks · 11 issues · 8 pull requests · Updated Oct 13, 2025
  • jax-ai-stack Public
    Python · 222 stars · Apache-2.0 · 35 forks · 7 issues · 16 pull requests · Updated Oct 10, 2025
  • ml_dtypes Public

    A stand-alone implementation of several NumPy dtype extensions used in machine learning.

    C++ · 301 stars · Apache-2.0 · 48 forks · 28 issues · 9 pull requests · Updated Oct 7, 2025
  • oryx Public

    Oryx is a library for probabilistic programming and deep learning built on top of JAX.

    Python · 279 stars · Apache-2.0 · 11 forks · 17 issues (1 needs help) · 2 pull requests · Updated Oct 7, 2025
  • scaling-book Public

    Home for "How To Scale Your Model", a short blog-style textbook about scaling LLMs on TPUs

    HTML · 654 stars · MIT · 94 forks · 2 issues · 2 pull requests · Updated Oct 6, 2025
  • jax-llm-examples Public

    Minimal yet performant LLM examples in pure JAX

    Python · 185 stars · Apache-2.0 · 23 forks · 7 issues · 3 pull requests · Updated Sep 23, 2025
  • coix Public

    Inference Combinators in JAX

    Jupyter Notebook · 51 stars · Apache-2.0 · 3 forks · 9 issues · 0 pull requests · Updated May 16, 2025