diff --git a/src/accelerate/utils/operations.py b/src/accelerate/utils/operations.py index 162009e76b6..5f737344b56 100644 --- a/src/accelerate/utils/operations.py +++ b/src/accelerate/utils/operations.py @@ -652,8 +652,11 @@ def _pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False): CannotPadNestedTensorWarning, ) return tensor - if dim >= len(tensor.shape): + if dim >= len(tensor.shape) or dim < -len(tensor.shape): return tensor + # Convert negative dimensions to non-negative + if dim < 0: + dim += len(tensor.shape) # Gather all sizes size = torch.tensor(tensor.shape, device=tensor.device)[None] diff --git a/tests/test_utils.py b/tests/test_utils.py index ed4481ed92c..cabdb55a1a6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -304,6 +304,15 @@ def test_pad_across_processes(self): nt2 = pad_across_processes(nt) assert nt is nt2 + # Basic functionality + tensor = torch.randn(4, 3, 100) + padded_tensor = pad_across_processes(tensor, dim=-1) + assert padded_tensor.shape[-1] == 100 + + # dim = -4 is out of bounds + padded_tensor = pad_across_processes(tensor, dim=-4) + assert padded_tensor is tensor + def test_slice_and_concatenate(self): # First base case: 2 processes, batch size of 1 num_processes = 2