def get_tensor_device(tensor: Union[torch.Tensor, Dict, List, Tuple]):
    """Get the device of a tensor or a nested structure (dict/list/tuple) of tensors.

    Recurses into the *first* element of each container and assumes all
    leaves live on the same device; only one leaf is ever inspected.

    Args:
        tensor: A ``torch.Tensor``, or an arbitrarily nested dict/list/tuple
            whose leaves are tensors.

    Returns:
        torch.device: The device of the first tensor leaf found.

    Raises:
        RuntimeError: If an empty container is encountered, or a value of an
            unsupported type is reached during recursion.
    """
    if isinstance(tensor, torch.Tensor):
        return tensor.device

    elif isinstance(tensor, dict):
        if not tensor:
            raise RuntimeError("Empty dict in get_tensor_device")
        # Dicts are assumed homogeneous w.r.t. device; inspect one value.
        return get_tensor_device(next(iter(tensor.values())))

    elif isinstance(tensor, (list, tuple)):
        if not tensor:
            # For a list this produces the same message as before
            # ("Empty list in get_tensor_device").
            raise RuntimeError(f"Empty {type(tensor).__name__} in get_tensor_device")
        return get_tensor_device(tensor[0])

    else:
        raise RuntimeError(f"Unsupported type in get_tensor_device: {type(tensor)}")