
Commit c0f86f0 ("fix tp")
1 parent: 8d0c61f

1 file changed: +4 -1 lines

src/nanotron/parallel/tensor_parallel/functional.py (+4 -1)
@@ -389,6 +389,9 @@ def forward(
 
     @staticmethod
     def backward(ctx, grad_output: torch.Tensor):
+        import debugpy
+
+        debugpy.breakpoint()
         # Either allgather the inputs again or get them from context.
         group = ctx.group
         tp_recompute_allgather = ctx.tp_recompute_allgather
@@ -414,7 +417,7 @@ def backward(ctx, grad_output: torch.Tensor):
         grad_weight = grad_output.T @ total_input
         grad_input = grad_output @ weight
         if group.size() == 1:
-            sub_grad_input = grad_input
+            sub_grad_input = grad_input.reshape(input_size)  # [s*b, h_in] -> [s, b, h_in]
         else:
             # Seems that `reduce_scatter` need contiguous tensors: https://github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305
             # We set grad_input to be contiguous in case it isn't already.
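Below is a minimal standalone sketch (not part of the commit) of why the reshape in the `group.size() == 1` branch is needed, assuming the forward pass flattens a `[seq, batch, hidden]` input to `[seq*batch, hidden]` before the matmul. The shapes and the variables `s`, `b`, `h_in`, `h_out`, and `x` are illustrative assumptions; `weight`, `grad_output`, `grad_input`, `sub_grad_input`, and `input_size` only mirror the names used in the diff.

```python
import torch

# Illustrative shapes (assumption): sequence length 4, batch 2, hidden sizes 8 -> 6.
s, b, h_in, h_out = 4, 2, 8, 6
x = torch.randn(s, b, h_in)
weight = torch.randn(h_out, h_in)

# Forward-style flattening: [s, b, h_in] -> [s*b, h_in] before the matmul.
input_size = x.shape                      # saved so backward can restore the shape
grad_output = torch.randn(s * b, h_out)   # gradient w.r.t. the flattened output

# Backward math as in the diff: grad_input is produced in flattened form.
grad_input = grad_output @ weight         # [s*b, h_out] @ [h_out, h_in] -> [s*b, h_in]

# Without the fix, sub_grad_input would keep the flattened [s*b, h_in] shape and
# no longer match the original input; the reshape restores [s, b, h_in].
sub_grad_input = grad_input.reshape(input_size)
assert sub_grad_input.shape == x.shape
```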
