Skip to content

Commit f70ba2e

Browse files
authored
Merge branch 'main' into fix/tpu7x-chip-counting
2 parents e92f0a2 + 739d6f6 commit f70ba2e

File tree

3 files changed

+107
-10
lines changed

3 files changed

+107
-10
lines changed

README.md

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,12 @@
1111

1212
---
1313

14-
_Upcoming Events_ 🔥
15-
16-
- Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) in San Francisco!
17-
- Join us at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
18-
- Join us at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
19-
2014
_Latest News_ 🔥
2115

16+
- [PyTorch Conference](https://pytorchconference.sched.com/event/27QCh/sponsored-session-everything-everywhere-all-at-once-vllm-hardware-optionality-with-spotify-and-google-brittany-rockwell-google-shireen-kheradpey-spotify): Learn how Spotify uses vLLM with both GPUs and TPUs to drive down costs and improve user experience.
17+
- Check back soon for a recording of our session at [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco!
18+
- Check back soon for a recording of our session at [JAX DevLab on November 18th](https://rsvp.withgoogle.com/events/devlab-fall-2025) in Sunnyvale!
19+
2220
- [2025/10] [vLLM TPU: A New Unified Backend Supporting PyTorch and JAX on TPU](https://blog.vllm.ai/2025/10/16/vllm-tpu.html)
2321

2422
<details>

tests/test_envs.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,77 @@ def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
8787
assert envs.USE_MOE_EP_KERNEL is True
8888

8989

90+
def test_boolean_env_vars_string_values(monkeypatch: pytest.MonkeyPatch):
    """Boolean env vars understand 'true'/'false' spellings in any casing."""

    # (variable name, raw environment value, expected parsed result).
    # Rows are applied in order, so each variable is overwritten in place
    # exactly as a sequence of setenv/assert pairs would do.
    scenarios = (
        ("NEW_MODEL_DESIGN", "True", True),
        ("NEW_MODEL_DESIGN", "true", True),
        ("NEW_MODEL_DESIGN", "False", False),
        ("NEW_MODEL_DESIGN", "false", False),
        ("SKIP_JAX_PRECOMPILE", "True", True),
        ("SKIP_JAX_PRECOMPILE", "false", False),
        ("VLLM_XLA_CHECK_RECOMPILATION", "TRUE", True),
        ("VLLM_XLA_CHECK_RECOMPILATION", "FALSE", False),
        ("USE_MOE_EP_KERNEL", "true", True),
        ("USE_MOE_EP_KERNEL", "False", False),
    )

    for var_name, raw_value, expected in scenarios:
        monkeypatch.setenv(var_name, raw_value)
        # Identity check: the parsed value must be a real bool, not merely
        # something truthy or falsy.
        assert getattr(envs, var_name) is expected
126+
127+
128+
def test_boolean_env_vars_invalid_values(monkeypatch: pytest.MonkeyPatch):
    """Unrecognised boolean spellings must surface as a ValueError."""

    # (variable name, rejected raw value, regex the error message must match)
    rejected = (
        ("NEW_MODEL_DESIGN", "yes",
         "Invalid boolean value 'yes' for NEW_MODEL_DESIGN"),
        ("NEW_MODEL_DESIGN", "2",
         "Invalid boolean value '2' for NEW_MODEL_DESIGN"),
        ("SKIP_JAX_PRECOMPILE", "invalid",
         "Invalid boolean value 'invalid' for SKIP_JAX_PRECOMPILE"),
    )

    for var_name, raw_value, message_pattern in rejected:
        monkeypatch.setenv(var_name, raw_value)
        with pytest.raises(ValueError, match=message_pattern):
            # Reading the attribute is what triggers parsing of the variable.
            getattr(envs, var_name)
149+
150+
151+
def test_boolean_env_vars_empty_string(monkeypatch: pytest.MonkeyPatch):
    """An empty value is treated as unset and yields the default (False)."""

    for var_name in ("NEW_MODEL_DESIGN", "SKIP_JAX_PRECOMPILE"):
        monkeypatch.setenv(var_name, "")
        # Both variables are expected to fall back to False here.
        assert getattr(envs, var_name) is False
159+
160+
90161
def test_integer_env_vars(monkeypatch: pytest.MonkeyPatch):
91162
# Ensure clean environment for integer vars by setting to defaults
92163
monkeypatch.setenv("PYTHON_TRACER_LEVEL", "1")

tpu_inference/envs.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,34 @@ def _get_validated_env() -> str | None:
6969
return _get_validated_env
7070

7171

72+
def env_bool(env_name: str, default: bool = False) -> Callable[[], bool]:
    """
    Create a lazy getter that parses a boolean environment variable.

    Accepts both numeric strings ("0", "1") and boolean strings
    ("true", "false", "True", "False"). Matching is case-insensitive and
    ignores surrounding whitespace, so " 1 " and "TRUE\\n" are accepted too.

    Args:
        env_name: Name of the environment variable
        default: Default boolean value if the variable is not set, is empty,
            or contains only whitespace

    Returns:
        A zero-argument callable that reads ``env_name`` each time it is
        called and returns the parsed boolean.

    Raises:
        ValueError: Raised by the returned callable (not by ``env_bool``
            itself) when the variable holds an unrecognised value.
    """

    def _get_bool_env() -> bool:
        value = os.getenv(env_name)
        if value is None:
            return default

        # Strip padding before matching: int()-style parsing accepts values
        # such as " 1 ", and stray whitespace/newlines are easy to pick up
        # from shell scripts or config files.
        normalized = value.strip().lower()
        if not normalized:
            return default

        if normalized in ("true", "1"):
            return True
        if normalized in ("false", "0"):
            return False

        # Report the raw value so the user sees exactly what was set.
        raise ValueError(
            f"Invalid boolean value '{value}' for {env_name}. "
            f"Valid options: '0', '1', 'true', 'false', 'True', 'False'.")

    return _get_bool_env
98+
99+
72100
environment_variables: dict[str, Callable[[], Any]] = {
73101
# JAX platform selection (e.g., "tpu", "cpu", "proxy")
74102
"JAX_PLATFORMS":
@@ -93,17 +121,17 @@ def _get_validated_env() -> str | None:
93121
lambda: os.getenv("DECODE_SLICES", ""),
94122
# Skip JAX precompilation step during initialization
95123
"SKIP_JAX_PRECOMPILE":
96-
lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE") or "0")),
124+
env_bool("SKIP_JAX_PRECOMPILE", default=False),
97125
# Check for XLA recompilation during execution
98126
"VLLM_XLA_CHECK_RECOMPILATION":
99-
lambda: bool(int(os.getenv("VLLM_XLA_CHECK_RECOMPILATION") or "0")),
127+
env_bool("VLLM_XLA_CHECK_RECOMPILATION", default=False),
100128
# Model implementation type (e.g., "flax_nnx")
101129
"MODEL_IMPL_TYPE":
102130
env_with_choices("MODEL_IMPL_TYPE", "flax_nnx",
103131
["vllm", "flax_nnx", "jetpack"]),
104132
# Enable new experimental model design
105133
"NEW_MODEL_DESIGN":
106-
lambda: bool(int(os.getenv("NEW_MODEL_DESIGN") or "0")),
134+
env_bool("NEW_MODEL_DESIGN", default=False),
107135
# Directory to store phased profiling output
108136
"PHASED_PROFILING_DIR":
109137
lambda: os.getenv("PHASED_PROFILING_DIR", ""),
@@ -112,7 +140,7 @@ def _get_validated_env() -> str | None:
112140
lambda: int(os.getenv("PYTHON_TRACER_LEVEL") or "1"),
113141
# Use custom expert-parallel kernel for MoE (Mixture of Experts)
114142
"USE_MOE_EP_KERNEL":
115-
lambda: bool(int(os.getenv("USE_MOE_EP_KERNEL") or "0")),
143+
env_bool("USE_MOE_EP_KERNEL", default=False),
116144
# Number of TPU slices for multi-slice mesh
117145
"NUM_SLICES":
118146
lambda: int(os.getenv("NUM_SLICES") or "1"),

0 commit comments

Comments
 (0)