Commit 4acc42a

fix(server): better handling of inference mode (#57)
1 parent e114d87 commit 4acc42a

File tree

6 files changed (+35, -25 lines)


launcher/src/main.rs

Lines changed: 2 additions & 2 deletions
@@ -38,9 +38,9 @@ struct Args {
     port: u16,
     #[clap(default_value = "/tmp/text-generation-server", long, env)]
     shard_uds_path: String,
-    #[clap(default_value = "localhost", long, env)]
+    #[clap(default_value = "0.0.0.0", long, env)]
     master_addr: String,
-    #[clap(default_value = "29500", long, env)]
+    #[clap(default_value = "6000", long, env)]
     master_port: usize,
     #[clap(long, env)]
     json_output: bool,

server/text_generation/models/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -28,6 +28,9 @@
 # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
 torch.backends.cudnn.allow_tf32 = True

+# Disable gradients
+torch.set_grad_enabled(False)
+

 def get_model(
     model_id: str, revision: Optional[str], sharded: bool, quantize: bool
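
Note: `torch.set_grad_enabled(False)` switches autograd off once, at import time of the models package, so every forward pass below runs without building a graph. A minimal sketch of the effect (the tensors are made up for illustration):

import torch

torch.set_grad_enabled(False)  # same call as the line added above

w = torch.randn(3, 4, requires_grad=True)
y = torch.randn(2, 3) @ w
print(y.requires_grad)          # False: no autograd graph is recorded
print(torch.is_grad_enabled())  # False until explicitly re-enabled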

server/text_generation/models/causal_lm.py

Lines changed: 5 additions & 10 deletions
@@ -289,17 +289,12 @@ def forward(
     def generate_token(
         self, batch: CausalLMBatch
     ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
-        # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        context_manager = (
-            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
+        logits, past = self.forward(
+            batch.input_ids,
+            batch.attention_mask,
+            batch.position_ids,
+            batch.past_key_values,
         )
-        with context_manager():
-            logits, past = self.forward(
-                batch.input_ids,
-                batch.attention_mask,
-                batch.position_ids,
-                batch.past_key_values,
-            )

         # List of indices to cache
         next_batch_keep_indices = []
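
The conditional `torch.no_grad` / `torch.inference_mode` wrapper is removed here (and in seq2seq_lm.py below) because no-grad behaviour is now established once instead of per call: gradients are disabled globally in models/__init__.py and inference mode is forced at the service level in server.py. A hedged before/after sketch with a stand-in `model` callable (not the real CausalLM.forward signature):

import torch

# Before this commit: each generate_token call chose a context manager,
# because inference_mode misbehaves with the GLOO backend used on CPU.
def forward_old(model, inputs, device_type):
    ctx = torch.no_grad if device_type == "cpu" else torch.inference_mode
    with ctx():
        return model(inputs)

# After: the process-wide settings make the wrapper unnecessary, so the
# model method calls forward directly.
def forward_new(model, inputs):
    return model(inputs)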

server/text_generation/models/seq2seq_lm.py

Lines changed: 7 additions & 12 deletions
@@ -364,19 +364,14 @@ def forward(
     def generate_token(
         self, batch: Seq2SeqLMBatch
     ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
-        # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        context_manager = (
-            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
+        logits, encoder_last_hidden_state, past = self.forward(
+            batch.input_ids,
+            batch.attention_mask,
+            batch.decoder_input_ids,
+            batch.decoder_attention_mask,
+            batch.encoder_last_hidden_state,
+            batch.past_key_values,
         )
-        with context_manager():
-            logits, encoder_last_hidden_state, past = self.forward(
-                batch.input_ids,
-                batch.attention_mask,
-                batch.decoder_input_ids,
-                batch.decoder_attention_mask,
-                batch.encoder_last_hidden_state,
-                batch.past_key_values,
-            )

         # List of indices to cache
         next_batch_keep_indices = []

server/text_generation/server.py

Lines changed: 13 additions & 1 deletion
@@ -1,5 +1,6 @@
 import asyncio
 import os
+import torch

 from grpc import aio
 from loguru import logger
@@ -19,6 +20,10 @@ def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
         self.cache = cache
         self.model = model
         self.server_urls = server_urls
+        # For some reason, inference_mode does not work well with GLOO which we use on CPU
+        if model.device.type == "cuda":
+            # Force inference mode for the lifetime of TextGenerationService
+            self._inference_mode_raii_guard = torch._C._InferenceMode(True)

     async def ServiceDiscovery(self, request, context):
         return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
@@ -89,7 +94,11 @@ async def serve_inner(
         local_url = unix_socket_template.format(uds_path, 0)
         server_urls = [local_url]

-        model = get_model(model_id, revision, sharded, quantize)
+        try:
+            model = get_model(model_id, revision, sharded, quantize)
+        except Exception:
+            logger.exception("Error when initializing model")
+            raise

         server = aio.server(interceptors=[ExceptionInterceptor()])
         generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
@@ -101,8 +110,11 @@ async def serve_inner(
         )
         reflection.enable_server_reflection(SERVICE_NAMES, server)
         server.add_insecure_port(local_url)
+
         await server.start()
+
         logger.info("Server started at {}".format(local_url))
+
         try:
             await server.wait_for_termination()
         except KeyboardInterrupt:
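
`torch._C._InferenceMode(True)` is the internal RAII guard behind `torch.inference_mode`: inference mode stays enabled for as long as the guard object is alive, so storing it on `self` keeps it active for the lifetime of the servicer. It is only created on CUDA because, per the comment in the diff, inference mode does not work well with the GLOO backend used on CPU. A rough sketch of the lifetime behaviour (the `Holder` class is invented for illustration):

import torch

class Holder:
    def __init__(self, device_type: str):
        # Mirrors the servicer: the guard is only created on CUDA, and the
        # attribute reference keeps it (and inference mode) alive.
        if device_type == "cuda":
            self._guard = torch._C._InferenceMode(True)

h = Holder("cuda")
print(torch.is_inference_mode_enabled())  # True while `h` is alive
del h
print(torch.is_inference_mode_enabled())  # expected False once the guard is released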

server/text_generation/utils.py

Lines changed: 5 additions & 0 deletions
@@ -171,9 +171,14 @@ def initialize_torch_distributed():
     else:
         backend = "gloo"

+    master_ip = os.getenv("MASTER_ADDR", "0.0.0.0")
+    master_port = os.getenv("MASTER_PORT", "6000")
+    init_method = f"tcp://{master_ip}:{master_port}"
+
     # Call the init process.
     torch.distributed.init_process_group(
         backend=backend,
+        init_method=init_method,
         world_size=world_size,
         rank=rank,
         timeout=timedelta(seconds=60),
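
The rendezvous is now an explicit TCP `init_method` built from MASTER_ADDR / MASTER_PORT, whose defaults (0.0.0.0 and 6000) match the new launcher defaults in launcher/src/main.rs; the launcher exports both variables to every shard. A hedged single-process sketch of the same setup (local address and world size 1 are chosen so the example runs standalone):

import os
from datetime import timedelta

import torch.distributed as dist

# The launcher normally exports MASTER_ADDR / MASTER_PORT for each shard;
# 127.0.0.1 is used here only so the example works in isolation.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "6000")

init_method = f'tcp://{os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT"]}'
dist.init_process_group(
    backend="gloo",              # CPU-friendly backend, as in utils.py
    init_method=init_method,
    world_size=1,
    rank=0,
    timeout=timedelta(seconds=60),
)
print(dist.get_rank(), dist.get_world_size())  # 0 1
dist.destroy_process_group()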
