@@ -4,9 +4,12 @@ You can install SGLang-Jax using one of the methods below.
 
 This page is mainly applicable to TPU devices running through JAX.
 
-## Method 1: With pip or uv
+## Method 1: With uv
 
-🚧 **Under Construction** 🚧
+```bash
+uv venv --python 3.12 && source .venv/bin/activate
+uv pip install sglang-jax
+```
 
 ## Method 2: From source
 
@@ -16,11 +19,11 @@ git clone https://github.com/sgl-project/sglang-jax
 cd sglang-jax
 
 # Install the python packages
-pip install --upgrade pip setuptools packaging
-pip install -e "python[all]"
+uv venv --python 3.12 && source .venv/bin/activate
+uv pip install -e python/
 
 # Run Qwen-7B Model
-JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python3 -u -m sgl_jax.launch_server --model-path Qwen/Qwen-7B-Chat --trust-remote-code --dist-init-addr=0.0.0.0:10011 --nnodes=1 --tp-size=4 --device=tpu --random-seed=3 --node-rank=0 --mem-fraction-static=0.8 --max-prefill-tokens=8192 --download-dir=/tmp --dtype=bfloat16 --skip-server-warmup --host 0.0.0.0 --port 30000
+JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache uv run python -u -m sgl_jax.launch_server --model-path Qwen/Qwen-7B-Chat --trust-remote-code --dist-init-addr=0.0.0.0:10011 --nnodes=1 --tp-size=4 --device=tpu --random-seed=3 --node-rank=0 --mem-fraction-static=0.8 --max-prefill-tokens=8192 --download-dir=/tmp --dtype=bfloat16 --skip-server-warmup --host 0.0.0.0 --port 30000
 ```
 
 ## Method 3: Using docker
@@ -31,11 +34,7 @@ JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python3 -u -m sgl_jax.launch_server --m
 
 🚧 **Under Construction** 🚧
 
-## Method 5: Using docker compose
-
-🚧 **Under Construction** 🚧
-
-## Method 6: Run on Cloud TPU with SkyPilot
+## Method 5: Run on Cloud TPU with SkyPilot
 
 <details>
 <summary>More</summary>
@@ -54,25 +53,19 @@ resources:
   accelerator_args:
     tpu_vm: True
     runtime_version: v2-alpha-tpuv6e
-file_mounts:
-  ~/.ssh/id_rsa: ~/.ssh/id_rsa
-setup: |
-  chmod 600 ~/.ssh/id_rsa
-  rm ~/.ssh/config
-  GIT_SSH_COMMAND="ssh -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no" git clone https://github.com/sgl-project/sglang-jax
 run: |
-  cd sglang-jax
-  pip install -e "python[all]"
-  JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python3 -u -m sgl_jax.launch_server --model-path Qwen/Qwen-7B-Chat --trust-remote-code --dist-init-addr=0.0.0.0:10011 --nnodes=1 --tp-size=4 --device=tpu --random-seed=3 --node-rank=0 --mem-fraction-static=0.8 --max-prefill-tokens=8192 --download-dir=/tmp --dtype=bfloat16 --skip-server-warmup --attention-backend=fa --host 0.0.0.0 --port 30000
+  git clone https://github.com/sgl-project/sglang-jax.git
+  cd sglang-jax && git fetch origin $REF:$REF && git checkout $REF
+  uv venv --python 3.12
+  source .venv/bin/activate
+  uv pip install -e python/
 ```
 
 </details>
 
 ```bash
 sky launch -c sglang-jax sglang.yaml --infra=gcp
 
-# Get the HTTP API endpoint
-sky status --endpoint 30000 sglang-jax
 ```
 - For debugging and testing purposes, you can use spot instances to reduce costs by adding the `--use-spot` flag to your SkyPilot commands:
   ```bash
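The launch command in Method 2 binds the server to `--host 0.0.0.0 --port 30000`. A minimal sketch of how one might wait for that port to accept connections before sending requests is shown below. The helper names `port_open` and `wait_for_port` are hypothetical and not part of SGLang-Jax, and the check relies on bash's `/dev/tcp` redirection, so it assumes a bash build with network redirections enabled. It only confirms that something is listening on the port, not that the model has finished loading.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of SGLang-Jax): succeeds if something
# accepts TCP connections on host:port.
port_open() {
  local host="$1" port="$2"
  # /dev/tcp/<host>/<port> is a bash redirection feature; the subshell
  # closes file descriptor 3 automatically when it exits.
  (exec 3<>"/dev/tcp/${host}/${port}") 2>/dev/null
}

# Hypothetical helper: poll once per second until the port is reachable,
# giving up after roughly <timeout> seconds (default 300).
wait_for_port() {
  local host="$1" port="$2" timeout="${3:-300}" waited=0
  until port_open "$host" "$port"; do
    if [ "$waited" -ge "$timeout" ]; then
      return 1
    fi
    sleep 1
    waited=$((waited + 1))
  done
}
```

With these functions sourced, `wait_for_port 127.0.0.1 30000 600 && echo "port 30000 is accepting connections"` would block for up to about ten minutes; the timeout value here is an arbitrary choice, since first-run JIT compilation time is not stated in the document.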