Skip to content

Commit 0e30e29

Browse files
committed
Merge branch 'main' into amangu-lora
2 parents 31b1407 + 045c9a1 commit 0e30e29

File tree

4 files changed

+1873
-24
lines changed

4 files changed

+1873
-24
lines changed

.github/workflows/unit_tests.yaml

+3-1
Original file line numberDiff line numberDiff line change
@@ -68,5 +68,7 @@ jobs:
6868
run: make install-deps
6969
- name: Run all unit tests in JetStream (jetstream/tests)
7070
run: make unit-tests
71+
- name: Run all py tests in JetStream (jetstream/tests)
72+
run: make py-tests
7173
- name: Create test coverage report
72-
run: make check-test-coverage
74+
run: make check-test-coverage

benchmarks/benchmark_serving.py

+31-23
Original file line numberDiff line numberDiff line change
@@ -625,29 +625,37 @@ async def grpc_async_request(
625625
) -> tuple[list[int], float, float, float]:
626626
"""Send grpc synchronous request since the current grpc server is sync."""
627627
options = [("grpc.keepalive_timeout_ms", 10000)]
628-
async with grpc.aio.insecure_channel(api_url, options=options) as channel:
629-
stub = jetstream_pb2_grpc.OrchestratorStub(channel)
630-
request_start_time = time.perf_counter()
631-
response = stub.Decode(request)
632-
token_list = []
633-
ttft = 0
634-
ttst = 0
635-
stream_resp_cnt = 0
636-
async for resp in response:
637-
stream_resp_cnt += 1
638-
if stream_resp_cnt == 1:
639-
await prefill_quota.inc()
640-
ttft = time.perf_counter() - request_start_time
641-
if ttft > 2.0:
642-
print(datetime.now(), f"slow TTFT {ttft:.2f}", prefill_quota.value())
643-
elif stream_resp_cnt == 2:
644-
ttst = time.perf_counter() - request_start_time
645-
resp_tokens = resp.stream_content.samples[0].token_ids
646-
token_list.extend(resp_tokens)
647-
out_token_cnt.increment(len(resp_tokens))
648-
await active_req_quota.inc()
649-
req_latency = time.perf_counter() - request_start_time
650-
return token_list, ttft, ttst, req_latency
628+
# Retry connection while server is not ready.
629+
while True:
630+
try:
631+
async with grpc.aio.insecure_channel(api_url, options=options) as channel:
632+
stub = jetstream_pb2_grpc.OrchestratorStub(channel)
633+
request_start_time = time.perf_counter()
634+
response = stub.Decode(request)
635+
token_list = []
636+
ttft = 0
637+
ttst = 0
638+
stream_resp_cnt = 0
639+
async for resp in response:
640+
stream_resp_cnt += 1
641+
if stream_resp_cnt == 1:
642+
await prefill_quota.inc()
643+
ttft = time.perf_counter() - request_start_time
644+
if ttft > 2.0:
645+
print(
646+
datetime.now(), f"slow TTFT {ttft:.2f}", prefill_quota.value()
647+
)
648+
elif stream_resp_cnt == 2:
649+
ttst = time.perf_counter() - request_start_time
650+
resp_tokens = resp.stream_content.samples[0].token_ids
651+
token_list.extend(resp_tokens)
652+
out_token_cnt.increment(len(resp_tokens))
653+
await active_req_quota.inc()
654+
req_latency = time.perf_counter() - request_start_time
655+
return token_list, ttft, ttst, req_latency
656+
except grpc.aio.AioRpcError as e:
657+
print(e)
658+
await asyncio.sleep(10)
651659

652660

653661
async def send_request(

0 commit comments

Comments
 (0)