Commit a300f23

fix tp
1 parent 8d0c61f commit a300f23

1 file changed: +1 -1 lines changed

src/nanotron/parallel/tensor_parallel/functional.py

@@ -414,7 +414,7 @@ def backward(ctx, grad_output: torch.Tensor):
         grad_weight = grad_output.T @ total_input
         grad_input = grad_output @ weight
         if group.size() == 1:
-            sub_grad_input = grad_input
+            sub_grad_input = grad_input.reshape(input_size) # [s*b, h_in] -> [s, b, h_in]
         else:
             # Seems that `reduce_scatter` need contiguous tensors: https://github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305
             # We set grad_input to be contiguous in case it isn't already.
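
The shape issue this change addresses can be reproduced outside nanotron. The sketch below is not the repository's code. It assumes, based on the inline comment in the diff, that the backward pass works on a flattened `[s*b, h_in]` view of the input and that `input_size` holds the original `[s, b, h_in]` shape. The sizes are hypothetical, and `torch.nn.functional.linear` serves only as an autograd reference to compare against.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes, chosen only for illustration.
s, b, h_in, h_out = 2, 3, 4, 5

inp = torch.randn(s, b, h_in, dtype=torch.float64, requires_grad=True)
weight = torch.randn(h_out, h_in, dtype=torch.float64)
input_size = inp.shape  # assumed meaning of `input_size`: the original [s, b, h_in]

# Assumption: the matmuls in backward operate on tensors flattened to 2-D.
total_input = inp.detach().reshape(s * b, h_in)           # [s*b, h_in]
grad_output = torch.randn(s * b, h_out, dtype=torch.float64)

grad_weight = grad_output.T @ total_input                 # [h_out, h_in]
grad_input = grad_output @ weight                         # [s*b, h_in]

# Before the fix: the single-rank branch passes the 2-D tensor through as is.
shape_before_fix = grad_input.shape

# After the fix: restore the shape of the forward input.
sub_grad_input = grad_input.reshape(input_size)           # [s*b, h_in] -> [s, b, h_in]

# Reference gradient from autograd on an ordinary linear layer.
out = F.linear(inp, weight)                               # [s, b, h_out]
out.backward(grad_output.reshape(s, b, h_out))
matches_reference = torch.allclose(inp.grad, sub_grad_input, atol=1e-9)

print(tuple(shape_before_fix), tuple(sub_grad_input.shape), matches_reference)
```

The values are the same before and after the reshape; only the shape differs. With a group of size 1 there is no `reduce_scatter` step, so the reshape in that branch is what gives the returned gradient the same shape as the input.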
