Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions gpt_oss/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,18 @@
import torch
import torch.distributed as dist


def suppress_output(rank):
    """Suppress printing on the current device. Force printing with `force=True`."""
    import builtins
    # Keep a handle to the real print so the wrapper can delegate to it.
    original_print = builtins.print

    def _filtered_print(*args, **kwargs):
        # `force=True` always prints, prefixed with the calling rank;
        # otherwise only rank 0 is allowed to write.
        if kwargs.pop('force', False):
            original_print("rank #%d:" % rank, *args, **kwargs)
        elif rank == 0:
            original_print(*args, **kwargs)

    # Globally swap the builtin so every module's print() goes through the filter.
    builtins.print = _filtered_print


# NOTE(review): this span is a rendered diff, not runnable source. Lines that
# were removed by the change (the unguarded set_device/device lines and the
# unguarded synchronize) appear alongside the added CUDA-availability guards.
# The lines that compute `world_size` and `rank` are collapsed behind the
# "Expand All" marker below and are not visible here.
def init_distributed() -> torch.device:
"""Initialize the model for distributed inference."""
# Initialize distributed inference
Expand All @@ -27,14 +23,19 @@ def init_distributed() -> torch.device:
dist.init_process_group(
backend="nccl", init_method="env://", world_size=world_size, rank=rank
)
# (removed by this change) pre-change code set the CUDA device unconditionally:
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")


# Check if CUDA is available before setting device
if torch.cuda.is_available():
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
else:
device = torch.device("cpu")
# NOTE(review): backend is hard-coded to "nccl" in init_process_group above,
# but NCCL requires CUDA. On this new CPU fallback path the warm-up
# all_reduce below would fail — presumably the backend should be "gloo"
# when CUDA is unavailable; confirm intended behavior with the author.

# Warm up NCCL to avoid first-time latency
if world_size > 1:
x = torch.ones(1, device=device)
dist.all_reduce(x)
# (removed by this change) pre-change code synchronized unconditionally:
torch.cuda.synchronize(device)

# (added) synchronize only when CUDA is actually in use
if torch.cuda.is_available():
torch.cuda.synchronize(device)
suppress_output(rank)
return device