def get_tensor_device(tensor: Union[torch.Tensor, Dict, List]):
    """Get the device of a tensor or of the first tensor in a nested structure.

    Dicts (values only), lists and tuples are searched depth-first in their
    natural order, and the device of the first ``torch.Tensor`` encountered
    is returned. Non-tensor leaves (e.g. ``None`` placeholders or Python
    scalars) and empty nested containers are skipped rather than treated as
    errors, so a structure such as ``{"mask": None, "logits": t}`` resolves
    to ``t.device``.

    Args:
        tensor: A tensor, or an arbitrarily nested dict/list/tuple that
            contains at least one tensor.

    Returns:
        The ``torch.device`` of the first tensor found.

    Raises:
        RuntimeError: If ``tensor`` is neither a tensor nor a dict/list/tuple,
            if it is an empty container, or if it contains no tensor at all.
    """
    if isinstance(tensor, torch.Tensor):
        return tensor.device

    if not isinstance(tensor, (dict, list, tuple)):
        raise RuntimeError(f"Unsupported type in get_tensor_device: {type(tensor)}")

    if not tensor:
        if isinstance(tensor, dict):
            container_kind = "dict"
        elif isinstance(tensor, list):
            container_kind = "list"
        else:
            container_kind = "tuple"
        raise RuntimeError(f"Empty {container_kind} in get_tensor_device")

    # Explicit stack instead of recursion: children are pushed in reverse so
    # they are popped (visited) in their original order, which keeps the
    # "first tensor wins" semantics of a pre-order traversal.
    pending = [tensor]
    while pending:
        item = pending.pop()
        if isinstance(item, torch.Tensor):
            return item.device
        if isinstance(item, dict):
            pending.extend(reversed(list(item.values())))
        elif isinstance(item, (list, tuple)):
            pending.extend(reversed(item))
        # Any other leaf (None, int, str, ...) carries no device; skip it.

    raise RuntimeError(f"No tensor found in get_tensor_device: {type(tensor)}")