Skip to content

Commit 2dddc28

Browse files
authored
[Feature] Use uv for package management and installation (#485)
* [Feature] Use `uv` for package management and installation
1 parent 9d3666d commit 2dddc28

File tree

9 files changed

+1202
-142
lines changed

9 files changed

+1202
-142
lines changed

.pre-commit-config.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,14 @@ repos:
6868
pass_filenames: false
6969
always_run: false
7070
require_serial: true
71+
72+
# Generate requirements.txt with uv
73+
- repo: https://github.com/astral-sh/uv-pre-commit
74+
# uv version.
75+
rev: 0.9.5
76+
hooks:
77+
# Compile requirements
78+
- id: pip-compile
79+
name: Compile requirements.txt using uv
80+
files: ^requirements\.txt|pyproject\.toml$
81+
args: [pyproject.toml, -o, requirements.txt]

MANIFEST.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
recursive-include evaluation/latex2sympy *

areal/tests/test_model_utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323
class InvalidConfig(BaseExperimentConfig):
2424
"""Invalid config for testing - missing actor/model attribute."""
2525

26-
pass
27-
2826

2927
def _create_base_cluster_config(fileroot="/tmp/areal_test"):
3028
"""Helper to create a cluster config with common test values."""

docs/tutorial/installation.md

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,19 +65,30 @@ conda create -n areal python=3.12
6565
conda activate areal
6666
```
6767

68-
3. Install pip dependencies:
68+
3. Install pip dependencies using uv:
6969

7070
```bash
7171
git clone https://github.com/inclusionAI/AReaL
7272
cd AReaL
73-
bash examples/env/setup-pip-deps.sh
73+
pip install uv
74+
uv pip install -e .[all]
7475
```
7576

76-
**Note**: Installing with `examples/env/setup-pip-deps.sh` will install
77-
`flash-attn==2.8.3` since it does not require compilation with torch version 2.8.0.
78-
However, `flash-attn==2.8.3` is not compatible with Megatron training backend. If you
79-
want to use Megatron training backend, please compile and install `flash-attn==2.8.1` in
80-
your custom environment, or use docker installation instead.
77+
**Note**: Directly installing with `uv` and `pip` will install `flash-attn==2.8.3` since it
78+
does not require compilation with torch version 2.8.0. However, `flash-attn==2.8.3` is
79+
not compatible with Megatron training backend. If you want to use Megatron training
80+
backend, please compile and install `flash-attn==2.8.1` in your custom environment, or
81+
use docker installation instead.
82+
83+
4. Validate your AReaL installation:
84+
85+
We provide a script to validate the AReaL installation. Simply run:
86+
87+
```bash
88+
python3 examples/env/validate_installation.py
89+
```
90+
91+
Once the installation validation has passed, you are good to go!
8192

8293
## (Optional) Launch Ray Cluster for Distributed Training
8394

examples/env/setup-eval-pip-deps.sh

Lines changed: 0 additions & 9 deletions
This file was deleted.

examples/env/setup-pip-deps.sh

Lines changed: 0 additions & 17 deletions
This file was deleted.

examples/env/validate_installation.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import importlib
1111
import sys
1212
from importlib.metadata import version as get_version
13-
from typing import Optional
1413

1514
from packaging.version import Version
1615

@@ -25,7 +24,7 @@ def test_import(
2524
self,
2625
module_name: str,
2726
required: bool = True,
28-
test_func: Optional[callable] = None,
27+
test_func: callable | None = None,
2928
) -> bool:
3029
"""Test importing a module and optionally run additional tests."""
3130
try:
@@ -78,9 +77,9 @@ def test_vllm_functionality(self, vllm_module):
7877
"""Test vLLM basic functionality."""
7978
from vllm import LLM, SamplingParams # noqa
8079

81-
assert Version(get_version("vllm")) == Version(
82-
"0.10.2"
83-
), f"vLLM version should be 0.10.2, found {Version(get_version('vllm'))}"
80+
assert Version(get_version("vllm")) == Version("0.10.2"), (
81+
f"vLLM version should be 0.10.2, found {Version(get_version('vllm'))}"
82+
)
8483

8584
print(" - vLLM core classes imported successfully")
8685

@@ -96,15 +95,15 @@ def test_sglang_functionality(self, sglang_module):
9695
)
9796
from sglang import Engine, launch_server # noqa
9897

99-
assert Version(get_version("sglang")) >= Version(
100-
"v0.5.2"
101-
), "SGLang version should be >= v0.5.2"
98+
assert Version(get_version("sglang")) >= Version("v0.5.2"), (
99+
"SGLang version should be >= v0.5.2"
100+
)
102101
print(" - SGLang imported successfully")
103102

104103
def test_transformers(self, transformers_module):
105-
assert Version(get_version("transformers")) == Version(
106-
"4.56.1"
107-
), "transformers version should be 4.56.1"
104+
assert Version(get_version("transformers")) == Version("4.56.1"), (
105+
"transformers version should be 4.56.1"
106+
)
108107
print(" - transformers imported successfully")
109108

110109
def validate_critical_dependencies(self):
@@ -121,7 +120,6 @@ def validate_critical_dependencies(self):
121120
self.test_import(
122121
"flash_attn", required=True, test_func=self.test_flash_attn_functionality
123122
)
124-
self.test_import("cugae", required=True)
125123
# Inference engines
126124
self.test_import(
127125
"sglang", required=True, test_func=self.test_sglang_functionality
@@ -175,14 +173,15 @@ def validate_optional_dependencies(self):
175173
self.test_import("seaborn", required=False)
176174
self.test_import("numba", required=False)
177175
self.test_import("nltk", required=False)
176+
self.test_import("cugae", required=False)
178177

179178
def test_te_functionality(self, _):
180179
try:
181180
import torch
182181

183-
assert Version(get_version("transformer_engine")) >= Version(
184-
"2.3.0"
185-
), "transformer_engine version must be larger than 2.3.0"
182+
assert Version(get_version("transformer_engine")) >= Version("2.3.0"), (
183+
"transformer_engine version must be larger than 2.3.0"
184+
)
186185

187186
if torch.cuda.is_available():
188187
import transformer_engine.pytorch as te
@@ -307,14 +306,14 @@ def run_validation(self):
307306

308307
# Determine overall result
309308
if self.critical_failures:
310-
print(f"\n❌ INSTALLATION VALIDATION FAILED")
309+
print("\n❌ INSTALLATION VALIDATION FAILED")
311310
print("Please check the critical failures above and ensure all required")
312311
print(
313312
"dependencies are properly installed according to the installation guide."
314313
)
315314
return False
316315
else:
317-
print(f"\n✅ INSTALLATION VALIDATION PASSED")
316+
print("\n✅ INSTALLATION VALIDATION PASSED")
318317
if self.warnings:
319318
print("Note: Some optional dependencies failed but this won't affect")
320319
print("core functionality.")

pyproject.toml

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,20 @@ classifiers = [
4141

4242
dependencies = [
4343
# Core ML/AI libraries
44-
"torch>2.0.0",
44+
"torch==2.8.0",
45+
"torchaudio",
46+
"torchvision",
47+
4548
"huggingface_hub",
4649
"datasets",
4750
"accelerate",
4851
"transformers==4.56.1",
4952
"megatron-core==0.13.1",
5053
"mbridge==0.13.0",
54+
"flashinfer-python==0.3.1",
55+
"sglang[all]==0.5.2",
56+
"flash-attn==2.8.3",
57+
"vllm==0.10.2",
5158
"peft",
5259
"qwen_agent",
5360

@@ -68,6 +75,7 @@ dependencies = [
6875
"orjson>=3.10.16",
6976
"pydantic",
7077
"PyYAML",
78+
"omegaconf==2.4.0.dev2",
7179
"hydra-core==1.4.0.dev1",
7280
"packaging",
7381
"lark",
@@ -78,6 +86,7 @@ dependencies = [
7886
"tensordict",
7987
"pybase64",
8088
"msgspec",
89+
"latex2sympy2",
8190
"openai==1.99.6",
8291
"dotenv",
8392
"json5",
@@ -89,7 +98,9 @@ dependencies = [
8998
"colorlog",
9099
"psutil",
91100
"pynvml",
92-
"swanlab[dashboard]",
101+
"nvidia-ml-py",
102+
"swanboard==0.1.9b1",
103+
"swanlab[dashboard]==0.6.12",
93104

94105
# Performance and compression
95106
"ninja",
@@ -174,6 +185,19 @@ where = ["."]
174185
include = ["areal*"]
175186
exclude = ["tests*", "realhf*", "docs*", "examples*", "evaluation*", "benchmark*"]
176187

188+
[tool.uv]
189+
override-dependencies=[
190+
"openai==1.99.6",
191+
"xgrammar==0.1.24",
192+
"outlines-core==0.1.26",
193+
]
194+
195+
[tool.uv.sources]
196+
latex2sympy2 = { path = "./evaluation/latex2sympy" }
197+
198+
[tool.uv.extra-build-dependencies]
199+
flash-attn = ["torch==2.8.0"]
200+
177201
[tool.pytest.ini_options]
178202
pythonpath = ["."]
179203
filterwarnings = [

0 commit comments

Comments
 (0)