Commit e9669a4

feat(server): do not use device_map auto on single GPU (#362)
1 parent: cfaa858

File tree

2 files changed: 8 additions & 2 deletions

server/text_generation_server/models/causal_lm.py

Lines changed: 4 additions & 1 deletion
@@ -468,9 +468,12 @@ def __init__(
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() else None,
+            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
             load_in_8bit=quantize == "bitsandbytes",
         )
+        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+            model = model.cuda()
+
         tokenizer.pad_token_id = (
             model.config.pad_token_id
             if model.config.pad_token_id is not None
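
Taken together, the hunk swaps accelerate's multi-device dispatch for a plain .cuda() move whenever exactly one GPU is present. Below is a minimal, self-contained sketch of the resulting loading logic, assuming transformers' AutoModelForCausalLM and a hypothetical model_id; the real constructor also forwards revision and load_in_8bit:

import torch
from transformers import AutoModelForCausalLM

model_id = "gpt2"  # hypothetical stand-in for the server's model_id argument
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Use accelerate's device_map="auto" sharding only when more than one GPU
# is visible; a single GPU gains nothing from the dispatch machinery.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=dtype,
    device_map="auto"
    if torch.cuda.is_available() and torch.cuda.device_count() > 1
    else None,
)

# With exactly one GPU, place the whole model on it explicitly instead.
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
    model = model.cuda()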

server/text_generation_server/models/seq2seq_lm.py

Lines changed: 4 additions & 1 deletion
@@ -518,9 +518,12 @@ def __init__(
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() else None,
+            device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
             load_in_8bit=quantize == "bitsandbytes",
         )
+        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+            model = model.cuda()
+
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )
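
The seq2seq constructor receives the identical change, differing only in the model class it loads. As a quick sanity check (a sketch, assuming the model variable from the loading sketch above), the parameter devices can be inspected after loading:

# Expect {device(type='cuda', index=0)} on a single-GPU machine after the
# explicit .cuda() call, and {device(type='cpu')} when no GPU is present.
print({p.device for p in model.parameters()})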
