11import asyncio
22import os
3+ import torch
34
45from grpc import aio
56from loguru import logger
@@ -19,6 +20,10 @@ def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
1920 self .cache = cache
2021 self .model = model
2122 self .server_urls = server_urls
23+ # For some reason, inference_mode does not work well with GLOO which we use on CPU
24+ if model .device .type == "cuda" :
25+ # Force inference mode for the lifetime of TextGenerationService
26+ self ._inference_mode_raii_guard = torch ._C ._InferenceMode (True )
2227
async def ServiceDiscovery(self, request, context):
    """Reply to a service-discovery call with the URLs stored on this instance.

    Neither ``request`` nor ``context`` is read; the response is built
    solely from ``self.server_urls``.
    """
    known_urls = self.server_urls
    reply = generate_pb2.ServiceDiscoveryResponse(urls=known_urls)
    return reply
@@ -89,7 +94,11 @@ async def serve_inner(
8994 local_url = unix_socket_template .format (uds_path , 0 )
9095 server_urls = [local_url ]
9196
92- model = get_model (model_id , revision , sharded , quantize )
97+ try :
98+ model = get_model (model_id , revision , sharded , quantize )
99+ except Exception :
100+ logger .exception ("Error when initializing model" )
101+ raise
93102
94103 server = aio .server (interceptors = [ExceptionInterceptor ()])
95104 generate_pb2_grpc .add_TextGenerationServiceServicer_to_server (
@@ -101,8 +110,11 @@ async def serve_inner(
101110 )
102111 reflection .enable_server_reflection (SERVICE_NAMES , server )
103112 server .add_insecure_port (local_url )
113+
104114 await server .start ()
115+
105116 logger .info ("Server started at {}" .format (local_url ))
117+
106118 try :
107119 await server .wait_for_termination ()
108120 except KeyboardInterrupt :
0 commit comments