-
Notifications
You must be signed in to change notification settings - Fork 367
Expand file tree
/
Copy pathvllm_manager.py
More file actions
315 lines (253 loc) · 9.21 KB
/
vllm_manager.py
File metadata and controls
315 lines (253 loc) · 9.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
"""
vLLM process management for GRPO trainer.
Handles launching, monitoring, and terminating vLLM server processes
for legacy mode training.
"""
import atexit
import os
import signal
import socket
import subprocess
import sys
import time
from typing import Optional

import requests

from .config import TrainingConfig
# Module-level handle to the vLLM server subprocess tracked by this module.
# It is None whenever no process is being tracked; the functions below read
# and rebind it via ``global _vllm_process``.
_vllm_process: Optional[subprocess.Popen] = None
def is_port_in_use(port: int) -> bool:
    """
    Report whether a TCP connection to ``localhost:port`` succeeds.

    Args:
        port: TCP port number to probe on localhost

    Returns:
        True if ``connect_ex`` reports success (0), False for any other
        error indicator
    """
    target = ("localhost", port)
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        connect_status = probe.connect_ex(target)
    # connect_ex returns 0 on success and an errno value otherwise.
    return connect_status == 0
def kill_process_on_port(port: int, timeout: float = 5.0) -> bool:
    """
    Kill the process(es) listening on the specified TCP port.

    Listener PIDs are looked up with ``lsof`` and sent SIGTERM. If the port is
    still in use after ``timeout`` seconds, the same PIDs are sent SIGKILL.
    When ``lsof`` is not installed, ``fuser -k`` is tried instead.

    Args:
        port: TCP port to free (probed on localhost)
        timeout: Seconds to wait after SIGTERM before escalating to SIGKILL

    Returns:
        True if no process was running or the port is free after the kill
        attempt, False otherwise
    """
    if not is_port_in_use(port):
        return True
    print(f" Port {port} is in use, attempting to kill existing process...")
    try:
        # Restrict the lookup to TCP listeners. A bare "-i :port" also matches
        # client sockets connected to the port (possibly including this very
        # process), and those must not be killed.
        result = subprocess.run(
            ["lsof", "-t", f"-iTCP:{port}", "-sTCP:LISTEN"],
            capture_output=True,
            text=True,
            timeout=5,
        )
    except FileNotFoundError:
        # lsof not available, try fuser (Linux)
        try:
            subprocess.run(["fuser", "-k", f"{port}/tcp"], timeout=5)
        except (FileNotFoundError, subprocess.TimeoutExpired):
            pass
        else:
            time.sleep(1)
            return not is_port_in_use(port)
    except subprocess.TimeoutExpired:
        pass
    else:
        own_pid = os.getpid()
        # Ignore malformed output tokens and never signal ourselves.
        pids = [
            int(token)
            for token in result.stdout.split()
            if token.isdecimal() and int(token) != own_pid
        ]
        if pids:
            print(f" Killing {len(pids)} processes on port {port}...")
            _signal_pids(pids, signal.SIGTERM)
            # Wait for port to be free; monotonic clock is immune to
            # wall-clock adjustments.
            deadline = time.monotonic() + timeout
            while time.monotonic() < deadline:
                if not is_port_in_use(port):
                    print(f" Port {port} is now free")
                    return True
                time.sleep(0.5)
            # Force kill if still running
            killed_count = _signal_pids(pids, signal.SIGKILL)
            if killed_count > 0:
                print(f" Force killed {killed_count} stubborn processes")
            time.sleep(1)
            return not is_port_in_use(port)
    print(f" WARNING: Could not kill process on port {port}")
    return False


def _signal_pids(pids: list, sig: int) -> int:
    """Send ``sig`` to each PID in ``pids``; return how many were delivered."""
    delivered = 0
    for pid in pids:
        try:
            os.kill(pid, sig)
        except ProcessLookupError:
            # Process already exited between lookup and signal.
            continue
        except PermissionError:
            # Owned by another user: report it instead of crashing the caller.
            print(f" WARNING: No permission to signal PID {pid}")
            continue
        delivered += 1
    return delivered
def cleanup_vllm():
    """
    Stop the tracked vLLM process, if there is one, and clear the handle.

    Sends terminate, waits up to 5 seconds, and falls back to kill if the
    process has not exited by then. Does nothing when no process is tracked.
    """
    global _vllm_process
    proc = _vllm_process
    if not proc:
        return
    print("\nTerminating vLLM process...")
    proc.terminate()
    try:
        proc.wait(timeout=5)
    except subprocess.TimeoutExpired:
        print("vLLM process did not terminate gracefully, killing.")
        proc.kill()
        proc.wait()
        print("vLLM process killed.")
    else:
        print("vLLM process terminated.")
    _vllm_process = None
# Register cleanup on module load: importing this module hooks cleanup_vllm
# into atexit so a tracked vLLM process is terminated at interpreter exit.
atexit.register(cleanup_vllm)
def launch_vllm_server(
    config: TrainingConfig,
    model_path: str,
) -> Optional[subprocess.Popen]:
    """
    Launch a vLLM server process using our custom vllm_api_server.py.

    Uses the custom server instead of standard vLLM because:
    - Streamlined API: Only /generate endpoint (provides logprobs)
    - Weight bridge support: /bridge/* endpoints for shared memory mode
    - LoRA hot-swap: /lora/* endpoints for adapter loading/unloading

    If the configured port is already in use, an attempt is made to kill the
    process holding it before launching. On success the new process is also
    stored in the module-level ``_vllm_process`` handle.

    Args:
        config: Training configuration (reads ``vllm_port``,
            ``vllm_gpu_memory_utilization`` and ``model_name``)
        model_path: Path to model checkpoint

    Returns:
        Popen process object, or None if the port could not be freed, the
        process could not be spawned, or it exited during the 2-second
        startup probe
    """
    global _vllm_process
    port = config.vllm_port
    # Check if port is in use and try to kill existing process
    if is_port_in_use(port):
        print(f" WARNING: Port {port} is already in use!")
        if not kill_process_on_port(port):
            print(
                f" ERROR: Could not free port {port}. Please manually kill the process."
            )
            print(f" Try: lsof -i :{port} | grep LISTEN")
            print(f" Or: pkill -f 'vllm.*{port}'")
            return None
        print(f" Successfully freed port {port}")
    # Use our custom vllm_api_server.py
    script_dir = os.path.dirname(os.path.abspath(__file__))
    custom_server_path = os.path.join(script_dir, "vllm_api_server.py")
    vllm_command = [
        # Run the server with the interpreter executing this trainer, so it
        # sees the same environment; a bare "python" is resolved via PATH and
        # may be a different installation.
        sys.executable or "python",
        custom_server_path,
        "--model",
        model_path,
        "--port",
        str(port),
        "--gpu-memory-utilization",
        str(config.vllm_gpu_memory_utilization),
    ]
    # Add served-model-name if using checkpoint path
    if model_path != config.model_name:
        vllm_command.extend(["--served-model-name", config.model_name])
    print(f" Launching vLLM: {' '.join(vllm_command)}")
    try:
        proc = subprocess.Popen(vllm_command)
        print(f" vLLM launched with PID: {proc.pid}")
        # Check for immediate startup errors. A server process is expected to
        # keep running, so ANY exit inside the probe window (even with exit
        # code 0) means there is no server to hand back.
        try:
            exit_code = proc.wait(timeout=2)
        except subprocess.TimeoutExpired:
            print(" vLLM process started (check logs for details)")
        else:
            print(f" WARNING: vLLM failed to start (exit code: {exit_code})")
            return None
        _vllm_process = proc
        return proc
    except FileNotFoundError:
        print(" ERROR: vLLM not found. Is it installed?")
        return None
    except Exception as e:
        # Best-effort launcher: report and signal failure via None.
        print(f" ERROR launching vLLM: {e}")
        return None
def terminate_vllm_process() -> None:
    """
    Stop the tracked vLLM process and clear the module-level handle.

    Sends terminate, waits up to 5 seconds, then kills the process if it is
    still alive. A no-op when no process is tracked.
    """
    global _vllm_process
    proc = _vllm_process
    if proc is not None:
        print(" Terminating vLLM process...")
        proc.terminate()
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            print(" vLLM did not terminate gracefully, killing...")
            proc.kill()
            proc.wait()
        _vllm_process = None
def check_vllm_process_health() -> None:
    """
    Detect a tracked vLLM process that has exited and drop the handle.

    When the tracked process has a return code, a warning including that code
    is printed and the module-level handle is reset to None. Nothing happens
    if no process is tracked or it is still running.
    """
    global _vllm_process
    proc = _vllm_process
    if proc is None:
        return
    if proc.poll() is None:
        # Still running.
        return
    print(
        f" WARNING: vLLM terminated unexpectedly (code: {proc.returncode})"
    )
    _vllm_process = None
def get_vllm_process() -> Optional[subprocess.Popen]:
    """
    Return the vLLM process handle currently tracked by this module.

    Returns:
        The tracked Popen object, or None if no process is tracked
    """
    tracked = _vllm_process
    return tracked
def set_vllm_process(proc: Optional[subprocess.Popen]) -> None:
    """
    Replace the tracked vLLM process handle (for external management).

    Args:
        proc: Process object to track from now on, or None to stop tracking.
            The previously tracked process, if any, is not terminated here.
    """
    global _vllm_process
    _vllm_process = proc
def check_vllm_health(port: int) -> bool:
    """
    Probe the vLLM server's ``/health`` endpoint on localhost.

    Args:
        port: Port the vLLM server is running on

    Returns:
        True if the endpoint answers with HTTP 200 within 5 seconds, False on
        any other status or on any error raised while probing
    """
    try:
        health_url = f"http://localhost:{port}/health"
        reply = requests.get(health_url, timeout=5)
        healthy = reply.status_code == 200
    except Exception:
        # Best-effort probe: every failure simply means "not healthy".
        healthy = False
    return healthy
def wait_for_vllm_ready(
    port: int,
    timeout: float = 120.0,
    *,
    poll_interval: float = 2.0,
) -> bool:
    """
    Wait for vLLM server to be ready.

    Polls ``check_vllm_health`` until it reports healthy or ``timeout``
    seconds have elapsed. If ``timeout`` is zero or negative, no health
    check is performed and False is returned.

    Args:
        port: Port the vLLM server is running on
        timeout: Maximum time to wait in seconds
        poll_interval: Seconds to sleep between health checks

    Returns:
        True if server is ready, False if timeout
    """
    print(f" Waiting for vLLM to be ready (port {port})...")
    # Monotonic clock: unaffected by wall-clock adjustments during the wait.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if check_vllm_health(port):
            print(" vLLM is ready!")
            return True
        # Never sleep past the deadline, and never pass a negative duration.
        remaining = deadline - time.monotonic()
        time.sleep(max(0.0, min(poll_interval, remaining)))
    print(f" WARNING: vLLM not ready after {timeout}s")
    return False
def hotswap_lora_adapter(
    adapter_name: str,
    adapter_path: str,
    port: int,
) -> bool:
    """
    Ask a running vLLM server to load a LoRA adapter without a restart.

    POSTs the adapter name and path as JSON (``lora_name`` / ``lora_path``)
    to ``/v1/load_lora_adapter`` on localhost, with a 30-second timeout.

    Args:
        adapter_name: Name to identify the adapter
        adapter_path: Path to the adapter checkpoint
        port: vLLM server port

    Returns:
        True if the server answered with HTTP 200, False on any other status
        or on any error while sending the request
    """
    try:
        endpoint = f"http://localhost:{port}/v1/load_lora_adapter"
        payload = {
            "lora_name": adapter_name,
            "lora_path": adapter_path,
        }
        response = requests.post(endpoint, json=payload, timeout=30)
        if response.status_code != 200:
            print(
                f" [LORA] ✗ Hot-swap failed: {response.status_code} - {response.text}"
            )
            return False
        print(f" [LORA] ✓ Hot-swapped adapter: {adapter_name} ({adapter_path})")
        return True
    except requests.exceptions.ConnectionError:
        print(f" [LORA] ✗ Cannot connect to vLLM at port {port}")
        return False
    except Exception as exc:
        print(f" [LORA] ✗ Error during hot-swap: {exc}")
        return False