From df8767cefe515f95a2b31527ad0c88b78e148875 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 18 Sep 2025 13:43:24 -0700 Subject: [PATCH 1/5] enable applying rope offsets in backwared Signed-off-by: Sudhakar Singh --- .../common/fused_rope/fused_rope.cu | 56 ++++++++++--------- .../include/transformer_engine/fused_rope.h | 13 +++-- transformer_engine/pytorch/attention/rope.py | 7 ++- transformer_engine/pytorch/csrc/extensions.h | 1 + .../pytorch/csrc/extensions/apply_rope.cpp | 11 +++- 5 files changed, 51 insertions(+), 37 deletions(-) diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index ccd0bc44c5..ef14d28f54 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -155,18 +155,19 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq cur_seqlens = s; } - int s_id_for_freqs; + // Offset the RoPE embedding by start_positions if provided. + int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id]; + int s_id_for_freqs = s_id + begin_offset; + + // If CP_SIZE > 1, offset the RoPE embedding by cp_rank based on the dual-chunk order. if (cp_size > 1) { assert(cur_seqlens % 2 == 0); if (s_id < cur_seqlens / 2) { - s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; + s_id_for_freqs += cp_rank * cur_seqlens / 2; } else { - s_id_for_freqs = - cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; + s_id_for_freqs += + cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2; } - } else { - int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id]; - s_id_for_freqs = s_id + begin_offset; } fused_rope_block_forward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block, @@ -175,7 +176,7 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq template __global__ void fused_rope_backward_kernel( - const scalar_t *src, const int *cu_seqlens, const float *freqs, scalar_t *dst, + const scalar_t *src, const int *cu_seqlens, const float *freqs, const int *start_positions, scalar_t *dst, const bool interleaved, const int cp_size, const int cp_rank, const int s, const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, const int stride_d, const int o_stride_s_or_t, const int o_stride_b, const int o_stride_h, @@ -197,17 +198,19 @@ __global__ void fused_rope_backward_kernel( cur_seqlens = s; } - int s_id_for_freqs; + // Offset the RoPE embedding by start_positions if provided. + int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id]; + int s_id_for_freqs = s_id + begin_offset; + + // If CP_SIZE > 1, offset the RoPE embedding by cp_rank based on the dual-chunk order. 
if (cp_size > 1) { assert(cur_seqlens % 2 == 0); if (s_id < cur_seqlens / 2) { - s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; + s_id_for_freqs += cp_rank * cur_seqlens / 2; } else { - s_id_for_freqs = - cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; + s_id_for_freqs += + cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2; } - } else { - s_id_for_freqs = s_id; } fused_rope_block_backward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block, @@ -495,7 +498,7 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c template void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens, - const float *freqs, scalar_t *input_grads, + const float *freqs, const int *start_positions, scalar_t *input_grads, const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank, const int s, const int b, const int h, const int d, const int d2, const int stride_s_or_t, @@ -521,7 +524,7 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se const int o_stride_d = 1; fused_rope_backward_kernel<<>>( - output_grads, cu_seqlens, freqs, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2, + output_grads, cu_seqlens, freqs, start_positions, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); @@ -590,7 +593,7 @@ void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Ten } void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, const Tensor &freqs, - Tensor *input_grads, const NVTE_QKV_Format qkv_format, + const Tensor &start_positions, Tensor *input_grads, const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank, const int s, const int b, const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, @@ -600,6 +603,7 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, c fused_rope_backward_launcher(reinterpret_cast(output_grads.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), reinterpret_cast(freqs.data.dptr), + reinterpret_cast(start_positions.data.dptr), reinterpret_cast(input_grads->data.dptr), qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream);); @@ -663,18 +667,18 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens } void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, - const NVTE_QKV_Format qkv_format, const bool interleaved, - const int cp_size, const int cp_rank, const int s, const int b, - const int h, const int d, const int d2, const int stride_s_or_t, - const int stride_b, const int stride_h, const int stride_d, - cudaStream_t stream) { + const NVTETensor freqs, const NVTETensor start_positions, + NVTETensor input_grads, const NVTE_QKV_Format qkv_format, + const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, + const int stride_s_or_t, const int stride_b, const int stride_h, + const int stride_d, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_backward); using namespace transformer_engine; fused_rope_backward(*convertNVTETensorCheck(output_grads), 
*convertNVTETensorCheck(cu_seqlens), - *convertNVTETensorCheck(freqs), convertNVTETensorCheck(input_grads), - qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, - stride_b, stride_h, stride_d, stream); + *convertNVTETensorCheck(freqs), *convertNVTETensorCheck(start_positions), + convertNVTETensorCheck(input_grads), qkv_format, interleaved, cp_size, cp_rank, + s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream); } void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs, diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h index 610868f932..19047f463b 100644 --- a/transformer_engine/common/include/transformer_engine/fused_rope.h +++ b/transformer_engine/common/include/transformer_engine/fused_rope.h @@ -51,6 +51,7 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. * (Required for the thd format, empty tensor for other formats) * \param[in] freqs The freqs tensor. + * \param[in] start_positions The beginning offsets for applying RoPE embeddings. * \param[out] input_grads Input gradient tensor to calculate. * \param[in] qkv_format QKV format. * \param[in] interleaved Whether to use interleaved rotary position embedding. @@ -68,12 +69,12 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens * \param[in] stream CUDA stream used for the operation. */ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, - const NVTE_QKV_Format qkv_format, const bool interleaved, - const int cp_size, const int cp_rank, const int s, const int b, - const int h, const int d, const int d2, const int stride_s_or_t, - const int stride_b, const int stride_h, const int stride_d, - cudaStream_t stream); + const NVTETensor freqs, const NVTETensor start_positions, + NVTETensor input_grads, const NVTE_QKV_Format qkv_format, + const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, + const int stride_s_or_t, const int stride_b, const int stride_h, + const int stride_d, cudaStream_t stream); /*! \brief Apply rotary positional embedding to the combined QKV input tensor. 
* diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index 139381f2dd..48efdbd085 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -145,7 +145,7 @@ def forward( cp_size, cp_rank, ) - ctx.save_for_backward(freqs, cu_seqlens) + ctx.save_for_backward(freqs, cu_seqlens, start_positions) ctx.tensor_format = tensor_format ctx.cp_size = cp_size ctx.cp_rank = cp_rank @@ -156,10 +156,11 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: """Fused RoPE backward.""" - freqs, cu_seqlens = ctx.saved_tensors + freqs, cu_seqlens, start_positions = ctx.saved_tensors grad_input = tex.fused_rope_backward( grad_output, freqs, + start_positions, QKVFormat[ctx.tensor_format], ctx.interleaved, cu_seqlens, @@ -167,7 +168,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ctx.cp_rank, ) - return grad_input, None, None, None, None, None, None, None + return grad_input, None, None, None, None, None, None, None, None class FusedQKVRoPEFunc(torch.autograd.Function): diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 4cb05725bc..9dd154ca1a 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -334,6 +334,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, const int cp_rank); at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, + const std::optional start_positions, const NVTE_QKV_Format qkv_format, const bool interleaved, const std::optional cu_seqlens, const int cp_size, const int cp_rank); diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index 064da8a670..82f8ac6956 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -163,6 +163,7 @@ std::tuple fused_qkv_rope_forward( } at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, + const std::optional start_positions, const NVTE_QKV_Format qkv_format, const bool interleaved, const std::optional cu_seqlens, const int cp_size, const int cp_rank) { @@ -180,6 +181,12 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor auto freqs_cu = makeTransformerEngineTensor(freqs); auto input_grads_cu = makeTransformerEngineTensor(input_grads); + auto start_positions_cu = TensorWrapper(); // empty start_positions tensor + if (start_positions) { + start_positions_cu = makeTransformerEngineTensor(start_positions.value()); + TORCH_CHECK(start_positions_cu.ndim() == 1, "expected 1D tensor"); + } + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor"); @@ -208,7 +215,7 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value()); nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, + start_positions_cu.data(), input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, max_s, b, h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d, 
at::cuda::getCurrentCUDAStream()); @@ -246,7 +253,7 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor auto cu_seqlens_cu = TensorWrapper(); // empty cu_seqlens tensor nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b, + start_positions_cu.data(), input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, at::cuda::getCurrentCUDAStream()); From 472142c7796473284dc257c36a85281d77878893 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 18 Sep 2025 20:55:17 -0700 Subject: [PATCH 2/5] add tests for rope offsets for thd/bshd/sbhd formats Signed-off-by: Sudhakar Singh --- tests/pytorch/test_fused_rope.py | 159 +++++++++++++++---- transformer_engine/pytorch/attention/rope.py | 75 ++++----- 2 files changed, 155 insertions(+), 79 deletions(-) diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index 62d80b5529..7c775c5e5d 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -58,10 +58,6 @@ def test_fused_rope( # are with the maximum length of the rope embeddings. pytest.skip("Skipping test with margin=0 and start_positions=True") - if start_positions == True and cp_size > 1: - # `start_positions` is only supported for `cp_size=1` and inference. - pytest.skip("Skipping test with cp_size>1 and start_positions=True") - device = torch.device("cuda:0") batch_size, head_num = 2, 64 t = torch.rand( @@ -102,11 +98,8 @@ def test_fused_rope( cp_rank=cp_rank, ).to(dtype) loss_unfused = loss_func(output_unfused) - - if not isinstance(start_positions, torch.Tensor): - loss_unfused.backward() - grad_unfused = t.grad.detach().clone() - + loss_unfused.backward() + grad_unfused = t.grad.detach().clone() t.grad = None # fused @@ -121,17 +114,12 @@ def test_fused_rope( cp_rank=cp_rank, ) loss_fused = loss_func(output_fused) - - if not isinstance(start_positions, torch.Tensor): - loss_fused.backward() - grad_fused = t.grad.detach().clone() + loss_fused.backward() + grad_fused = t.grad.detach().clone() t.grad = None torch.testing.assert_close(output_fused, output_unfused) - - if not isinstance(start_positions, torch.Tensor): - torch.testing.assert_close(grad_fused, grad_unfused) - + torch.testing.assert_close(grad_fused, grad_unfused) assert output_fused.is_contiguous() @@ -156,10 +144,6 @@ def test_fused_rope_thd( margin: int, ) -> None: - if start_positions == True and cp_size > 1: - # `start_positions` is only supported for `cp_size=1` and inference. 
- pytest.skip("Skipping test with cp_size>1 and start_positions=True") - device = torch.device("cuda:0") batch_size, head_num = 2, 64 cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048] @@ -214,10 +198,8 @@ def test_fused_rope_thd( cp_rank=cp_rank, ).to(dtype) loss_unfused = loss_func(output_unfused) - - if not isinstance(start_positions, torch.Tensor): - loss_unfused.backward() - grad_unfused = t.grad.detach().clone() + loss_unfused.backward() + grad_unfused = t.grad.detach().clone() t.grad = None # fused @@ -233,19 +215,132 @@ def test_fused_rope_thd( cp_rank=cp_rank, ) loss_fused = loss_func(output_fused) - - if not isinstance(start_positions, torch.Tensor): - loss_fused.backward() - grad_fused = t.grad.detach().clone() + loss_fused.backward() + grad_fused = t.grad.detach().clone() t.grad = None torch.testing.assert_close(output_fused, output_unfused) + torch.testing.assert_close(grad_fused, grad_unfused) + assert output_fused.is_contiguous() - if not isinstance(start_positions, torch.Tensor): - torch.testing.assert_close(grad_fused, grad_unfused) +@pytest.mark.parametrize("start_positions", [False, True]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("hidden_size", [128,256]) +@pytest.mark.parametrize("loss_func", [_overlapping_grad]) +@pytest.mark.parametrize("cp_size", [2]) +@pytest.mark.parametrize("interleaved", [False, True]) +def test_unfused_rope_thd_vs_bshd( + dtype: torch.dtype, + hidden_size: int, + loss_func: Callable, + cp_size: int, + interleaved: bool, + start_positions: bool, +) -> None: + """ + This is just a sanity check to ensure that the unfused RoPE in THD/SBHD/BSHD + formats are the same. + """ + device = torch.device("cuda:0") + seqlen, max_seqlen = 16, 2048 + batch_size, head_num = 4, 256 - assert output_fused.is_contiguous() + # NOTE: dtype=torch.int32 is important, otherwise the cumsum will be in int64 and + # that causes unexpected issues. 
+ seq_lens = torch.tensor([seqlen for _ in range(batch_size)], dtype = torch.int32) + + cu_seqlens = torch.cumsum( + torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens]), dim=0 + ).to(device=device, dtype=torch.int32) + + # Create a tensor in THD format + thd = torch.rand( + (cu_seqlens[-1] // cp_size, head_num, hidden_size), + dtype=dtype, + device=device, + ) + thd.requires_grad = True + + # Clone the tensor to create a tensor in BSHD format + bshd = thd.view(batch_size, -1, head_num, hidden_size).clone().detach() + bshd = bshd.to(dtype=dtype, device=device) + bshd.requires_grad = True + + # Clone the tensor to create a tensor in SBHD format + sbhd = bshd.transpose(1, 0).clone().detach() + sbhd = sbhd.to(dtype=dtype, device=device) + sbhd.requires_grad = True + + # rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) + # emb = rotary_pos_emb(max_seqlen) + emb = torch.randn(max_seqlen, 1, 1, hidden_size).cuda() + assert emb.is_contiguous() + + start_positions = cu_seqlens[:-1] if start_positions else None + + for cp_rank in range(cp_size): + # unfused bshd + output_unfused_bshd = apply_rotary_pos_emb( + bshd.float(), + emb, + start_positions=start_positions, + interleaved=interleaved, + fused=False, + tensor_format="bshd", + cu_seqlens=cu_seqlens, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + loss_unfused_bshd = loss_func(output_unfused_bshd) + loss_unfused_bshd.backward() + grad_unfused_bshd = bshd.grad.detach().clone() + bshd.grad = None + + # assert torch.allclose(sbhd.transpose(1, 0), bshd) + # unfused sbhd + output_unfused_sbhd = apply_rotary_pos_emb( + sbhd.float(), + emb, + start_positions=start_positions, + interleaved=interleaved, + fused=False, + tensor_format="sbhd", + cu_seqlens=cu_seqlens, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + + loss_unfused_sbhd = loss_func(output_unfused_sbhd) + loss_unfused_sbhd.backward() + grad_unfused_sbhd = sbhd.grad.detach().clone() + sbhd.grad = None + + # unfused thd + output_unfused_thd = apply_rotary_pos_emb( + thd.float(), + emb, + start_positions=start_positions, + tensor_format="thd", + interleaved=interleaved, + fused=False, + cu_seqlens=cu_seqlens, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + + loss_unfused_thd = loss_func(output_unfused_thd) + loss_unfused_thd.backward() + grad_unfused_thd = thd.grad.detach().clone() + thd.grad = None + + torch.testing.assert_close(output_unfused_bshd.reshape(*output_unfused_thd.shape), output_unfused_thd) + torch.testing.assert_close(output_unfused_sbhd.transpose(1, 0).reshape(*output_unfused_thd.shape), output_unfused_thd) + torch.testing.assert_close(grad_unfused_bshd.reshape(*grad_unfused_thd.shape), grad_unfused_thd) + torch.testing.assert_close(grad_unfused_sbhd.transpose(1, 0).reshape(*grad_unfused_thd.shape), grad_unfused_thd) + assert output_unfused_thd.is_contiguous() + assert output_unfused_bshd.is_contiguous() + assert output_unfused_sbhd.is_contiguous() @pytest.mark.parametrize("start_positions", [True, False]) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index 48efdbd085..fdc464c312 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -275,7 +275,6 @@ def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor: def _apply_rotary_pos_emb_base( t: torch.Tensor, freqs: torch.Tensor, - start_positions: torch.Tensor 
= None, tensor_format: str = "sbhd", interleaved: bool = False, ) -> torch.Tensor: @@ -288,45 +287,19 @@ def _apply_rotary_pos_emb_base( Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional embedding will be applied. freqs: torch.Tensor - Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float', - with `s2 >= s` and `d2 <= d`. - start_positions: torch.Tensor, default = None. - Tokens in a sequence `i` should be applied with position encoding offset by - `start_positions[i]`. If `start_positions=None`, there's no offset. + Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` or `[s2, b, 1, d2]` + and dtype 'float', with `s2 >= s` and `d2 <= d`. tensor_format: {'sbhd', 'bshd'}, default = 'sbhd' Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape `[seq, bs, ...]`. interleaved: bool, default = False Whether to use interleaved rotary position embedding. """ - max_seq_len = freqs.shape[0] - cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0] - - # In case `start_positions` are provided, create a staggered `freqs` tensor - # offset by the values in `start_positions`. - # `start_positions` is only supported for `cp_size=1` and inference. - if start_positions is not None: - max_offset = torch.max(start_positions) - assert ( - max_offset + cur_seq_len <= max_seq_len - ), f"Rotary Embeddings only suppported up to {max_seq_len} sequence length!" - - # Stack staggered rope embeddings along the batch dimension - freqs = torch.concatenate([freqs[i : i + cur_seq_len] for i in start_positions], dim=1) - - # Note that from this point, `freqs` has a shape `(s,b,1,d)`. - - # Only apply the rotary embeddings up to the sequence length of the running - # input. - assert ( - cur_seq_len <= max_seq_len - ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" - freqs = freqs[:cur_seq_len] - # [seq, 1, 1, dim] -> [1, seq, 1, dim] or # [seq, b, 1, dim] -> [b, seq, 1, dim] if tensor_format == "bshd": freqs = freqs.transpose(0, 1) + # cos/sin first then dtype conversion for better precision cos_ = torch.cos(freqs).to(t.dtype) sin_ = torch.sin(freqs).to(t.dtype) @@ -363,7 +336,7 @@ def _get_freqs_on_this_cp_rank( ) # cp_size == 1 - return freqs + return freqs[:seqlen] def apply_rotary_pos_emb( @@ -391,7 +364,7 @@ def apply_rotary_pos_emb( qkv_formats: "thd", "bshd", "sbhd" context parallelism: no start_positions: yes - interleaving: yes + interleaving: yes Parameters ---------- @@ -420,22 +393,17 @@ def apply_rotary_pos_emb( cp_rank: int, default = 0. Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True. """ - - # `start_positions` is only supported for `cp_size=1` and inference. - assert not ( - cp_size > 1 and start_positions is not None - ), """start_positions != None with CP SIZE > 1 is not supported!""" - assert ( tensor_format != "thd" or cu_seqlens is not None ), "cu_seqlens must not be None when tensor_format is 'thd'." + # Fused apply rope logic for THD/BSHD/SBHD formats if fused: return FusedRoPEFunc.apply( t, freqs, start_positions, tensor_format, interleaved, cu_seqlens, cp_size, cp_rank ) - # Unfused THD format + # Unfused apply rope logic for THD format if tensor_format == "thd": cu_seqlens = cu_seqlens // cp_size seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() @@ -444,15 +412,17 @@ def apply_rotary_pos_emb( # `s1hd` tensors (for each sequence) and applies rotary embedding to # those sequences individually. 
# Note that if `start_positions` is not `None`, then for each sequence, - # it's corresponding rope offset is also supplied from `start_positions` - # individually. + # the freqs supplied are offset by the corresponding `start_positions` value. return torch.cat( [ _apply_rotary_pos_emb_base( x.unsqueeze(1), - _get_freqs_on_this_cp_rank(freqs, x.size(0), cp_size, cp_rank), - start_positions=( - start_positions[idx : idx + 1] if start_positions is not None else None + _get_freqs_on_this_cp_rank( + freqs[start_positions[idx]:] \ + if start_positions is not None else freqs, # offset the freqs + x.size(0), + cp_size, + cp_rank, ), interleaved=interleaved, ) @@ -460,17 +430,28 @@ def apply_rotary_pos_emb( ] ).squeeze(1) - # Unfused SBHD/BSHD format + # Unfused apply rope logic for SBHD/BSHD format follows ... + if tensor_format == "sbhd": seqlen = t.size(0) elif tensor_format == "bshd": seqlen = t.size(1) else: raise ValueError(f"Unsupported tensor_format: {tensor_format}.") + + if start_positions is not None: + max_offset = torch.max(start_positions) + assert ( + max_offset + seqlen * cp_size <= freqs.shape[0] + ), f"Rotary Embeddings only suppported up to {freqs.shape[0]} sequence length!" + + # Stack staggered rope embeddings along the batch dimension + freqs = torch.concatenate([freqs[i : i + seqlen * cp_size] for i in start_positions], dim=1) + # Note that from this point, `freqs` has a shape `(s,b,1,d)`. + return _apply_rotary_pos_emb_base( t, _get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank), - start_positions, tensor_format, interleaved=interleaved, ) @@ -502,7 +483,7 @@ def apply_fused_qkv_rotary_pos_emb( qkv_formats: "bshd", "sbhd" context parallelism: no start_positions: yes - interleaving: yes + interleaving: yes Parameters ---------- From 81f40e64de2ae9a8dfe66831133e68dd7c2a4889 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Sep 2025 04:06:45 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_fused_rope.py | 29 +++++++---- .../common/fused_rope/fused_rope.cu | 49 +++++++++---------- transformer_engine/pytorch/attention/rope.py | 5 +- .../pytorch/csrc/extensions/apply_rope.cpp | 10 ++-- 4 files changed, 52 insertions(+), 41 deletions(-) diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index 7c775c5e5d..fce8b8e136 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -223,9 +223,10 @@ def test_fused_rope_thd( torch.testing.assert_close(grad_fused, grad_unfused) assert output_fused.is_contiguous() + @pytest.mark.parametrize("start_positions", [False, True]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [128,256]) +@pytest.mark.parametrize("hidden_size", [128, 256]) @pytest.mark.parametrize("loss_func", [_overlapping_grad]) @pytest.mark.parametrize("cp_size", [2]) @pytest.mark.parametrize("interleaved", [False, True]) @@ -247,11 +248,11 @@ def test_unfused_rope_thd_vs_bshd( # NOTE: dtype=torch.int32 is important, otherwise the cumsum will be in int64 and # that causes unexpected issues. 
- seq_lens = torch.tensor([seqlen for _ in range(batch_size)], dtype = torch.int32) + seq_lens = torch.tensor([seqlen for _ in range(batch_size)], dtype=torch.int32) - cu_seqlens = torch.cumsum( - torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens]), dim=0 - ).to(device=device, dtype=torch.int32) + cu_seqlens = torch.cumsum(torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens]), dim=0).to( + device=device, dtype=torch.int32 + ) # Create a tensor in THD format thd = torch.rand( @@ -333,15 +334,25 @@ def test_unfused_rope_thd_vs_bshd( grad_unfused_thd = thd.grad.detach().clone() thd.grad = None - torch.testing.assert_close(output_unfused_bshd.reshape(*output_unfused_thd.shape), output_unfused_thd) - torch.testing.assert_close(output_unfused_sbhd.transpose(1, 0).reshape(*output_unfused_thd.shape), output_unfused_thd) - torch.testing.assert_close(grad_unfused_bshd.reshape(*grad_unfused_thd.shape), grad_unfused_thd) - torch.testing.assert_close(grad_unfused_sbhd.transpose(1, 0).reshape(*grad_unfused_thd.shape), grad_unfused_thd) + torch.testing.assert_close( + output_unfused_bshd.reshape(*output_unfused_thd.shape), output_unfused_thd + ) + torch.testing.assert_close( + output_unfused_sbhd.transpose(1, 0).reshape(*output_unfused_thd.shape), + output_unfused_thd, + ) + torch.testing.assert_close( + grad_unfused_bshd.reshape(*grad_unfused_thd.shape), grad_unfused_thd + ) + torch.testing.assert_close( + grad_unfused_sbhd.transpose(1, 0).reshape(*grad_unfused_thd.shape), grad_unfused_thd + ) assert output_unfused_thd.is_contiguous() assert output_unfused_bshd.is_contiguous() assert output_unfused_sbhd.is_contiguous() + @pytest.mark.parametrize("start_positions", [True, False]) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) @pytest.mark.parametrize("seq_length", [2, 8, 2048, 4096]) diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index ef14d28f54..597a5d3c29 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -165,8 +165,7 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq if (s_id < cur_seqlens / 2) { s_id_for_freqs += cp_rank * cur_seqlens / 2; } else { - s_id_for_freqs += - cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2; + s_id_for_freqs += cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2; } } @@ -176,11 +175,11 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq template __global__ void fused_rope_backward_kernel( - const scalar_t *src, const int *cu_seqlens, const float *freqs, const int *start_positions, scalar_t *dst, - const bool interleaved, const int cp_size, const int cp_rank, const int s, const int h, - const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, - const int stride_d, const int o_stride_s_or_t, const int o_stride_b, const int o_stride_h, - const int o_stride_d) { + const scalar_t *src, const int *cu_seqlens, const float *freqs, const int *start_positions, + scalar_t *dst, const bool interleaved, const int cp_size, const int cp_rank, const int s, + const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, + const int stride_h, const int stride_d, const int o_stride_s_or_t, const int o_stride_b, + const int o_stride_h, const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; int offset_block, offset_block_dst; int 
cur_seqlens; @@ -208,8 +207,7 @@ __global__ void fused_rope_backward_kernel( if (s_id < cur_seqlens / 2) { s_id_for_freqs += cp_rank * cur_seqlens / 2; } else { - s_id_for_freqs += - cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2; + s_id_for_freqs += cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2; } } @@ -498,12 +496,12 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c template void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens, - const float *freqs, const int *start_positions, scalar_t *input_grads, - const NVTE_QKV_Format qkv_format, const bool interleaved, - const int cp_size, const int cp_rank, const int s, const int b, - const int h, const int d, const int d2, const int stride_s_or_t, - const int stride_b, const int stride_h, const int stride_d, - cudaStream_t stream) { + const float *freqs, const int *start_positions, + scalar_t *input_grads, const NVTE_QKV_Format qkv_format, + const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, + const int stride_s_or_t, const int stride_b, const int stride_h, + const int stride_d, cudaStream_t stream) { int warps_per_block = h < 16 ? 4 : 8; dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); @@ -524,9 +522,9 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se const int o_stride_d = 1; fused_rope_backward_kernel<<>>( - output_grads, cu_seqlens, freqs, start_positions, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2, - stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, - o_stride_d); + output_grads, cu_seqlens, freqs, start_positions, input_grads, interleaved, cp_size, cp_rank, + s, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, + o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -593,11 +591,12 @@ void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Ten } void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, const Tensor &freqs, - const Tensor &start_positions, Tensor *input_grads, const NVTE_QKV_Format qkv_format, - const bool interleaved, const int cp_size, const int cp_rank, const int s, - const int b, const int h, const int d, const int d2, - const int stride_s_or_t, const int stride_b, const int stride_h, - const int stride_d, cudaStream_t stream) { + const Tensor &start_positions, Tensor *input_grads, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, const int stride_s_or_t, + const int stride_b, const int stride_h, const int stride_d, + cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( output_grads.data.dtype, scalar_t, fused_rope_backward_launcher(reinterpret_cast(output_grads.data.dptr), @@ -677,8 +676,8 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu using namespace transformer_engine; fused_rope_backward(*convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(cu_seqlens), *convertNVTETensorCheck(freqs), *convertNVTETensorCheck(start_positions), - convertNVTETensorCheck(input_grads), qkv_format, interleaved, cp_size, cp_rank, - s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream); + convertNVTETensorCheck(input_grads), qkv_format, interleaved, cp_size, + cp_rank, s, 
b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream); } void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs, diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index fdc464c312..ea4ac80a61 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -418,8 +418,9 @@ def apply_rotary_pos_emb( _apply_rotary_pos_emb_base( x.unsqueeze(1), _get_freqs_on_this_cp_rank( - freqs[start_positions[idx]:] \ - if start_positions is not None else freqs, # offset the freqs + ( + freqs[start_positions[idx] :] if start_positions is not None else freqs + ), # offset the freqs x.size(0), cp_size, cp_rank, diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index 82f8ac6956..d1dcf68c3d 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -215,8 +215,8 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value()); nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - start_positions_cu.data(), input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, - max_s, b, h, d, d2, stride_t, + start_positions_cu.data(), input_grads_cu.data(), qkv_format, + interleaved, cp_size, cp_rank, max_s, b, h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d, at::cuda::getCurrentCUDAStream()); return input_grads; @@ -253,9 +253,9 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor auto cu_seqlens_cu = TensorWrapper(); // empty cu_seqlens tensor nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - start_positions_cu.data(), input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b, - h, d, d2, stride_s, stride_b, stride_h, stride_d, - at::cuda::getCurrentCUDAStream()); + start_positions_cu.data(), input_grads_cu.data(), qkv_format, + interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s, stride_b, + stride_h, stride_d, at::cuda::getCurrentCUDAStream()); return input_grads; } From 79bc5d60ef2f87c17584e7e37b1c7ef3b2c0e43e Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 18 Sep 2025 21:10:04 -0700 Subject: [PATCH 4/5] minor fixes Signed-off-by: Sudhakar Singh --- tests/pytorch/test_fused_rope.py | 8 ++++---- transformer_engine/pytorch/attention/rope.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index 7c775c5e5d..b1c8a1801b 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -226,12 +226,14 @@ def test_fused_rope_thd( @pytest.mark.parametrize("start_positions", [False, True]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [128,256]) +@pytest.mark.parametrize("rotary_percent", [1.0]) @pytest.mark.parametrize("loss_func", [_overlapping_grad]) @pytest.mark.parametrize("cp_size", [2]) @pytest.mark.parametrize("interleaved", [False, True]) def test_unfused_rope_thd_vs_bshd( dtype: torch.dtype, hidden_size: int, + rotary_percent: float, loss_func: Callable, cp_size: int, interleaved: bool, @@ -271,9 +273,8 @@ def test_unfused_rope_thd_vs_bshd( sbhd = sbhd.to(dtype=dtype, device=device) sbhd.requires_grad = True - # rotary_pos_emb = 
RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) - # emb = rotary_pos_emb(max_seqlen) - emb = torch.randn(max_seqlen, 1, 1, hidden_size).cuda() + rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) + emb = rotary_pos_emb(max_seqlen) assert emb.is_contiguous() start_positions = cu_seqlens[:-1] if start_positions else None @@ -296,7 +297,6 @@ def test_unfused_rope_thd_vs_bshd( grad_unfused_bshd = bshd.grad.detach().clone() bshd.grad = None - # assert torch.allclose(sbhd.transpose(1, 0), bshd) # unfused sbhd output_unfused_sbhd = apply_rotary_pos_emb( sbhd.float(), diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index fdc464c312..5f97771ea0 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -358,7 +358,7 @@ def apply_rotary_pos_emb( Training: qkv_formats: "thd", "bshd", "sbhd" context parallel: yes - start_positions: no + start_positions: yes interleaving: yes Inference: qkv_formats: "thd", "bshd", "sbhd" From 3504f61157d3361d77ee0c75ef03bb1c193df09a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Sep 2025 04:11:10 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_fused_rope.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index 8bbfdee32a..1b535f6585 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -226,7 +226,7 @@ def test_fused_rope_thd( @pytest.mark.parametrize("start_positions", [False, True]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [128,256]) +@pytest.mark.parametrize("hidden_size", [128, 256]) @pytest.mark.parametrize("rotary_percent", [1.0]) @pytest.mark.parametrize("loss_func", [_overlapping_grad]) @pytest.mark.parametrize("cp_size", [2])
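
Note: below is a minimal Python sketch of the frequency-index arithmetic that this series makes common to the forward and backward RoPE kernels (PATCH 1/5) and that the new THD/BSHD/SBHD tests exercise (PATCH 2/5). The function name and signature are illustrative only, not part of the Transformer Engine API; it mirrors the computation of s_id_for_freqs in fused_rope.cu, under the assumption that cur_seqlen is the per-rank sequence length.

def freq_index(s_id, b_id, start_positions, cur_seqlen, cp_size, cp_rank):
    """Global RoPE frequency row for local token s_id of sequence b_id."""
    # Offset by the per-sequence start position when provided (0 otherwise),
    # now applied identically in forward and backward.
    begin = 0 if start_positions is None else int(start_positions[b_id])
    idx = s_id + begin
    # With context parallelism, each rank holds two half-chunks of the full
    # sequence in dual-chunk order; shift the local index to its global slot.
    if cp_size > 1:
        assert cur_seqlen % 2 == 0
        half = cur_seqlen // 2
        if s_id < half:
            idx += cp_rank * half
        else:
            idx += cur_seqlen * cp_size - (cp_rank + 1) * half - half
    return idx

For example, with cur_seqlen=16, cp_size=2, cp_rank=1 and no start_positions, local positions 0 and 8 map to global rows 8 and 16, consistent with the dual-chunk layout (rank 0 owns chunks [0, 8) and [24, 32), rank 1 owns [8, 16) and [16, 24)); a nonzero start_positions[b_id] simply shifts every row of that sequence by the same amount.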