Skip to content

Commit 1830342

Browse files
authored
Add http server to JetStream (#115)
* Add http server to JetStream * Add generate api and cleanup * Add unit tests * format & deps * type & lint * Merge refactor * fix refactor
1 parent bd6d013 commit 1830342

File tree

12 files changed

+450
-38
lines changed

12 files changed

+450
-38
lines changed

jetstream/core/server_lib.py

Lines changed: 57 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -93,40 +93,25 @@ def wait_for_termination(self) -> None:
9393
self.stop()
9494

9595

96-
def run(
97-
port: int,
96+
def create_driver(
9897
config: Type[config_lib.ServerConfig],
9998
devices: Any,
100-
credentials: Any = grpc.insecure_server_credentials(),
101-
threads: int | None = None,
10299
jax_padding: bool = True,
103-
metrics_server_config: config_lib.MetricsServerConfig | None = None,
104-
enable_jax_profiler: bool = False,
105-
jax_profiler_port: int = 9999,
100+
metrics_collector: JetstreamMetricsCollector | None = None,
106101
enable_model_warmup: bool = False,
107-
) -> JetStreamServer:
108-
"""Runs a server with a specified config.
102+
):
103+
"""Creates a driver with a specified config.
109104
110105
Args:
111-
port: Port on which the server will be made available.
112106
config: A ServerConfig to config engine, model, device slices, etc.
113107
devices: Device objects, will be used to get engine with proper slicing.
114-
credentials: Should use grpc credentials by default.
115-
threads: Number of RPC handlers worker threads. This should be at least
116-
equal to the decoding batch size to fully saturate the decoding queue.
117108
jax_padding: The flag to enable JAX padding during tokenization.
118-
metrics_server_config: The config to enable Prometheus metric server.
119-
enable_jax_profiler: The flag to enable JAX profiler server.
120-
jax_profiler_port: The port JAX profiler server (default to 9999).
109+
metrics_collector: The JetStream Prometheus metric collector.
121110
enable_model_warmup: The flag to enable model server warmup with AOT.
122111
123112
Returns:
124-
JetStreamServer that wraps the grpc server and orchestrator driver.
113+
An orchestrator driver.
125114
"""
126-
127-
server_start_time = time.time()
128-
129-
logging.info("Kicking off gRPC server.")
130115
engines = config_lib.get_engines(config, devices=devices)
131116
prefill_params = [pe.load_params() for pe in engines.prefill_engines]
132117
generate_params = [ge.load_params() for ge in engines.generate_engines]
@@ -136,19 +121,6 @@ def run(
136121
len(config.prefill_slices) + len(config.generate_slices) == 0
137122
)
138123

139-
# Setup Prometheus server
140-
metrics_collector: JetstreamMetricsCollector = None
141-
if metrics_server_config and metrics_server_config.port:
142-
logging.info(
143-
"Starting Prometheus server on port %d", metrics_server_config.port
144-
)
145-
start_http_server(metrics_server_config.port)
146-
metrics_collector = JetstreamMetricsCollector()
147-
else:
148-
logging.info(
149-
"Not starting Prometheus server: --prometheus_port flag not set"
150-
)
151-
152124
prefill_engines = engines.prefill_engines + engines.interleaved_engines
153125
generate_engines = engines.generate_engines + engines.interleaved_engines
154126
prefill_params = prefill_params + shared_params
@@ -182,7 +154,7 @@ def run(
182154
traceback.print_exc()
183155
os.kill(os.getpid(), signal.SIGKILL)
184156

185-
driver = orchestrator.Driver(
157+
return orchestrator.Driver(
186158
prefill_engines=prefill_engines,
187159
generate_engines=generate_engines,
188160
prefill_params=prefill_params,
@@ -192,6 +164,56 @@ def run(
192164
metrics_collector=metrics_collector,
193165
is_ray_backend=config.is_ray_backend,
194166
)
167+
168+
169+
def run(
170+
port: int,
171+
config: Type[config_lib.ServerConfig],
172+
devices: Any,
173+
credentials: Any = grpc.insecure_server_credentials(),
174+
threads: int | None = None,
175+
jax_padding: bool = True,
176+
metrics_server_config: config_lib.MetricsServerConfig | None = None,
177+
enable_jax_profiler: bool = False,
178+
jax_profiler_port: int = 9999,
179+
enable_model_warmup: bool = False,
180+
) -> JetStreamServer:
181+
"""Runs a server with a specified config.
182+
183+
Args:
184+
port: Port on which the server will be made available.
185+
config: A ServerConfig to config engine, model, device slices, etc.
186+
devices: Device objects, will be used to get engine with proper slicing.
187+
credentials: Should use grpc credentials by default.
188+
threads: Number of RPC handlers worker threads. This should be at least
189+
equal to the decoding batch size to fully saturate the decoding queue.
190+
jax_padding: The flag to enable JAX padding during tokenization.
191+
metrics_server_config: The config to enable Prometheus metric server.
192+
enable_jax_profiler: The flag to enable JAX profiler server.
193+
jax_profiler_port: The port JAX profiler server (default to 9999).
194+
enable_model_warmup: The flag to enable model server warmup with AOT.
195+
196+
Returns:
197+
JetStreamServer that wraps the grpc server and orchestrator driver.
198+
"""
199+
server_start_time = time.time()
200+
logging.info("Kicking off gRPC server.")
201+
# Setup Prometheus server
202+
metrics_collector: JetstreamMetricsCollector = None
203+
if metrics_server_config and metrics_server_config.port:
204+
logging.info(
205+
"Starting Prometheus server on port %d", metrics_server_config.port
206+
)
207+
start_http_server(metrics_server_config.port)
208+
metrics_collector = JetstreamMetricsCollector()
209+
else:
210+
logging.info(
211+
"Not starting Prometheus server: --prometheus_port flag not set"
212+
)
213+
214+
driver = create_driver(
215+
config, devices, jax_padding, metrics_collector, enable_model_warmup
216+
)
195217
# We default threads to the total number of concurrent allowed decodes,
196218
# to make sure we can fully saturate the model. Set default minimum to 64.
197219
threads = threads or max(driver.get_total_concurrent_requests(), 64)

jetstream/entrypoints/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

jetstream/entrypoints/config.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Config for JetStream Server (including engine init)."""
16+
17+
from typing import Type
18+
19+
from jetstream.core import config_lib
20+
21+
22+
def get_server_config(
23+
config_str: str,
24+
) -> config_lib.ServerConfig | Type[config_lib.ServerConfig]:
25+
match config_str:
26+
case "InterleavedCPUTestServer":
27+
server_config = config_lib.InterleavedCPUTestServer
28+
case "CPUTestServer":
29+
server_config = config_lib.CPUTestServer
30+
case _:
31+
raise NotImplementedError
32+
return server_config
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""JetStream Http API server."""
16+
17+
import json
18+
import logging
19+
from typing import Sequence
20+
from absl import app as abslapp
21+
from absl import flags
22+
from fastapi import APIRouter, Response
23+
import fastapi
24+
from fastapi.responses import StreamingResponse
25+
from prometheus_client import start_http_server
26+
import uvicorn
27+
from google.protobuf.json_format import Parse
28+
29+
from jetstream.core import config_lib, orchestrator, server_lib
30+
from jetstream.core.metrics.prometheus import JetstreamMetricsCollector
31+
from jetstream.core.proto import jetstream_pb2
32+
from jetstream.entrypoints.config import get_server_config
33+
from jetstream.entrypoints.http.protocol import DecodeRequest
34+
from jetstream.entrypoints.http.utils import proto_to_json_generator
35+
36+
# Command-line flags for the HTTP server entrypoint.
flags.DEFINE_string("host", "0.0.0.0", "server host address")
flags.DEFINE_integer("port", 8080, "http server port")
flags.DEFINE_string(
    "config",
    "InterleavedCPUTestServer",
    "available servers",
)
flags.DEFINE_integer(
    "prometheus_port",
    9988,
    "prometheus_port",
)

# Module-level orchestrator handle; assigned in server() before uvicorn
# starts, then read by the endpoint handlers below.
llm_orchestrator: orchestrator.LLMOrchestrator

# Define Fast API endpoints (use llm_orchestrator to handle).
router = APIRouter()
53+
54+
55+
@router.get("/")
def root():
  """Serve the root path with a JSON greeting for the JetStream server."""
  body = json.dumps({"message": "JetStream HTTP Server"}, indent=4)
  return Response(content=body, media_type="application/json")
62+
63+
64+
@router.post("/v1/generate")
async def generate(request: DecodeRequest):
  """Decode endpoint: streams generated tokens as server-sent events."""
  # Convert the pydantic request into the orchestrator's proto request.
  decode_request = Parse(request.json(), jetstream_pb2.DecodeRequest())
  token_stream = llm_orchestrator.Decode(decode_request)
  return StreamingResponse(
      content=proto_to_json_generator(token_stream),
      media_type="text/event-stream",
  )
71+
72+
73+
@router.get("/v1/health")
async def health() -> Response:
  """Report orchestrator liveness as a JSON health-check response."""
  check = await llm_orchestrator.HealthCheck(jetstream_pb2.HealthCheckRequest())
  payload = json.dumps({"is_live": str(check.is_live)}, indent=4)
  return Response(
      content=payload,
      media_type="application/json",
      status_code=200,
  )
84+
85+
86+
def server(argv: Sequence[str]):
  """Configures and runs the JetStream HTTP server.

  Builds the FastAPI app, optionally starts a Prometheus metrics server,
  constructs the orchestrator driver, and then blocks serving HTTP with
  uvicorn.

  Args:
    argv: Unused positional args; configuration comes from absl flags.
  """
  del argv  # Unused; all configuration is read from flags.FLAGS.

  # Init Fast API.
  app = fastapi.FastAPI()
  app.include_router(router)

  # Init LLMOrchestrator which would be the main handler in the api endpoints.
  devices = server_lib.get_devices()
  logging.info("devices: %s", devices)
  server_config = get_server_config(flags.FLAGS.config)
  logging.info("server_config: %s", server_config)

  # Setup Prometheus server.
  metrics_server_config: config_lib.MetricsServerConfig | None = None
  # Fix: annotate as Optional — the collector stays None when metrics are off.
  metrics_collector: JetstreamMetricsCollector | None = None
  if flags.FLAGS.prometheus_port != 0:
    metrics_server_config = config_lib.MetricsServerConfig(
        port=flags.FLAGS.prometheus_port
    )
    logging.info(
        "Starting Prometheus server on port %d", metrics_server_config.port
    )
    start_http_server(metrics_server_config.port)
    metrics_collector = JetstreamMetricsCollector()
  else:
    logging.info(
        "Not starting Prometheus server: --prometheus_port flag not set"
    )

  # Publish the orchestrator for the endpoint handlers defined above.
  global llm_orchestrator
  llm_orchestrator = orchestrator.LLMOrchestrator(
      driver=server_lib.create_driver(
          config=server_config,
          devices=devices,
          metrics_collector=metrics_collector,
      )
  )

  # Start uvicorn http server (blocks until shutdown).
  uvicorn.run(
      app, host=flags.FLAGS.host, port=flags.FLAGS.port, log_level="info"
  )
128+
129+
130+
if __name__ == "__main__":
  # Parse command-line flags and dispatch to server() via the Abseil runner.
  abslapp.run(server)
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Http API server protocol."""
16+
17+
from pydantic import BaseModel # type: ignore
18+
19+
20+
class TextContent(BaseModel):
  """Decode request content supplied as plain text."""

  # Raw prompt text.
  text: str
22+
23+
24+
class TokenContent(BaseModel):
  """Decode request content supplied as pre-tokenized ids."""

  # Token ids of the prompt.
  token_ids: list[int]
26+
27+
28+
class DecodeRequest(BaseModel):
  """Body of the /v1/generate HTTP request.

  By convention exactly one of text_content or token_content should be set.
  NOTE(review): nothing in this model actually enforces that oneof
  constraint at runtime — the Config below only rejects unknown fields and
  strips whitespace; confirm whether a validator should be added.
  """

  # Maximum number of tokens to generate.
  max_tokens: int
  # Prompt as raw text (use instead of token_content).
  text_content: TextContent | None = None
  # Prompt as token ids (use instead of text_content).
  token_content: TokenContent | None = None

  # Pydantic model config: reject unknown fields and strip whitespace from
  # string values. (Does not by itself enforce the oneof behavior.)
  class Config:
    extra = "forbid"  # Prevent extra fields.
    anystr_strip_whitespace = True

jetstream/entrypoints/http/utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Http API server utilities."""
16+
17+
from google.protobuf.json_format import MessageToJson
18+
19+
20+
async def proto_to_json_generator(proto_generator):
  """Adapts an async stream of protobuf messages into JSON strings.

  Yields one JSON-serialized string per message received from
  proto_generator.
  """
  async for message in proto_generator:
    yield MessageToJson(message)

0 commit comments

Comments
 (0)