170 changes: 138 additions & 32 deletions tests/pytorch/test_fused_rope.py
@@ -58,10 +58,6 @@ def test_fused_rope(
# are with the maximum length of the rope embeddings.
pytest.skip("Skipping test with margin=0 and start_positions=True")

if start_positions == True and cp_size > 1:
# `start_positions` is only supported for `cp_size=1` and inference.
pytest.skip("Skipping test with cp_size>1 and start_positions=True")

device = torch.device("cuda:0")
batch_size, head_num = 2, 64
t = torch.rand(
@@ -102,11 +98,8 @@ def test_fused_rope(
cp_rank=cp_rank,
).to(dtype)
loss_unfused = loss_func(output_unfused)

if not isinstance(start_positions, torch.Tensor):
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()

loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None

# fused
@@ -121,17 +114,12 @@ def test_fused_rope(
cp_rank=cp_rank,
)
loss_fused = loss_func(output_fused)

if not isinstance(start_positions, torch.Tensor):
loss_fused.backward()
grad_fused = t.grad.detach().clone()
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None

torch.testing.assert_close(output_fused, output_unfused)

if not isinstance(start_positions, torch.Tensor):
torch.testing.assert_close(grad_fused, grad_unfused)

torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous()
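
The guards dropped above existed because the fused backward pass previously ignored `start_positions`; with the kernel change below, gradients are now checked unconditionally. For reference, a minimal sketch of the rotate-half RoPE that the unfused path computes when `start_positions` shifts each batch's positions; `rope_reference_sbhd` is a hypothetical helper (assuming sbhd layout, full rotary dimension, non-interleaved), not the test's actual code:

```python
import torch

def rope_reference_sbhd(t: torch.Tensor, freqs: torch.Tensor,
                        start_positions: torch.Tensor) -> torch.Tensor:
    """Rotate-half RoPE on an sbhd tensor, shifting each batch's
    positions by start_positions[b]. Hypothetical reference."""
    s, b, h, d = t.shape
    out = torch.empty_like(t)
    for bi in range(b):
        off = int(start_positions[bi])
        f = freqs[off : off + s]          # [s, 1, 1, d] window for this batch
        cos, sin = f.cos(), f.sin()
        x1, x2 = t[:, bi : bi + 1].chunk(2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        out[:, bi : bi + 1] = t[:, bi : bi + 1] * cos + rotated * sin
    return out
```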


@@ -156,10 +144,6 @@ def test_fused_rope_thd(
margin: int,
) -> None:

if start_positions == True and cp_size > 1:
# `start_positions` is only supported for `cp_size=1` and inference.
pytest.skip("Skipping test with cp_size>1 and start_positions=True")

device = torch.device("cuda:0")
batch_size, head_num = 2, 64
cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048]
@@ -214,10 +198,8 @@ def test_fused_rope_thd(
cp_rank=cp_rank,
).to(dtype)
loss_unfused = loss_func(output_unfused)

if not isinstance(start_positions, torch.Tensor):
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None

# fused
@@ -233,18 +215,142 @@ def test_fused_rope_thd(
cp_rank=cp_rank,
)
loss_fused = loss_func(output_fused)

if not isinstance(start_positions, torch.Tensor):
loss_fused.backward()
grad_fused = t.grad.detach().clone()
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None

torch.testing.assert_close(output_fused, output_unfused)
torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous()

if not isinstance(start_positions, torch.Tensor):
torch.testing.assert_close(grad_fused, grad_unfused)

assert output_fused.is_contiguous()
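
The new cross-format test below builds a packed THD batch, which relies on the usual cumulative-lengths convention: rows `cu_seqlens[i]:cu_seqlens[i+1]` hold sequence `i`. A quick standalone illustration (values taken from the thd test above):

```python
import torch

cu_seqlens = torch.tensor([0, 400, 542, 711], dtype=torch.int32)
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]  # tensor([400, 142, 169], dtype=torch.int32)
total_tokens = int(cu_seqlens[-1])           # first dimension of the packed THD tensor
```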
@pytest.mark.parametrize("start_positions", [False, True])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [1.0])
@pytest.mark.parametrize("loss_func", [_overlapping_grad])
@pytest.mark.parametrize("cp_size", [2])
@pytest.mark.parametrize("interleaved", [False, True])
def test_unfused_rope_thd_vs_bshd(
dtype: torch.dtype,
hidden_size: int,
rotary_percent: float,
loss_func: Callable,
cp_size: int,
interleaved: bool,
start_positions: bool,
) -> None:
"""
This is just a sanity check to ensure that the unfused RoPE in THD/SBHD/BSHD
formats are the same.
"""
device = torch.device("cuda:0")
seqlen, max_seqlen = 16, 2048
batch_size, head_num = 4, 256

# NOTE: dtype=torch.int32 is important; otherwise the cumsum is computed in
# int64, which downstream code does not expect.
seq_lens = torch.tensor([seqlen for _ in range(batch_size)], dtype=torch.int32)

cu_seqlens = torch.cumsum(torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens]), dim=0).to(
device=device, dtype=torch.int32
)

# Create a tensor in THD format
thd = torch.rand(
(cu_seqlens[-1] // cp_size, head_num, hidden_size),
dtype=dtype,
device=device,
)
thd.requires_grad = True

# Clone the tensor to create a tensor in BSHD format
bshd = thd.view(batch_size, -1, head_num, hidden_size).clone().detach()
bshd = bshd.to(dtype=dtype, device=device)
bshd.requires_grad = True

# Clone the tensor to create a tensor in SBHD format
sbhd = bshd.transpose(1, 0).clone().detach()
sbhd = sbhd.to(dtype=dtype, device=device)
sbhd.requires_grad = True

rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
emb = rotary_pos_emb(max_seqlen)
assert emb.is_contiguous()

start_positions = cu_seqlens[:-1] if start_positions else None

for cp_rank in range(cp_size):
# unfused bshd
output_unfused_bshd = apply_rotary_pos_emb(
bshd.float(),
emb,
start_positions=start_positions,
interleaved=interleaved,
fused=False,
tensor_format="bshd",
cu_seqlens=cu_seqlens,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)
loss_unfused_bshd = loss_func(output_unfused_bshd)
loss_unfused_bshd.backward()
grad_unfused_bshd = bshd.grad.detach().clone()
bshd.grad = None

# unfused sbhd
output_unfused_sbhd = apply_rotary_pos_emb(
sbhd.float(),
emb,
start_positions=start_positions,
interleaved=interleaved,
fused=False,
tensor_format="sbhd",
cu_seqlens=cu_seqlens,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)

loss_unfused_sbhd = loss_func(output_unfused_sbhd)
loss_unfused_sbhd.backward()
grad_unfused_sbhd = sbhd.grad.detach().clone()
sbhd.grad = None

# unfused thd
output_unfused_thd = apply_rotary_pos_emb(
thd.float(),
emb,
start_positions=start_positions,
tensor_format="thd",
interleaved=interleaved,
fused=False,
cu_seqlens=cu_seqlens,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)

loss_unfused_thd = loss_func(output_unfused_thd)
loss_unfused_thd.backward()
grad_unfused_thd = thd.grad.detach().clone()
thd.grad = None

torch.testing.assert_close(
output_unfused_bshd.reshape(*output_unfused_thd.shape), output_unfused_thd
)
torch.testing.assert_close(
output_unfused_sbhd.transpose(1, 0).reshape(*output_unfused_thd.shape),
output_unfused_thd,
)
torch.testing.assert_close(
grad_unfused_bshd.reshape(*grad_unfused_thd.shape), grad_unfused_thd
)
torch.testing.assert_close(
grad_unfused_sbhd.transpose(1, 0).reshape(*grad_unfused_thd.shape), grad_unfused_thd
)

assert output_unfused_thd.is_contiguous()
assert output_unfused_bshd.is_contiguous()
assert output_unfused_sbhd.is_contiguous()
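
The reshape/transpose round-trips asserted above are pure layout identities: with equal-length sequences, THD is the flattened BSHD tensor, and SBHD is its transpose. A standalone check with arbitrary shapes:

```python
import torch

b, s, h, d = 4, 16, 2, 8
bshd = torch.randn(b, s, h, d)
thd = bshd.reshape(b * s, h, d)            # pack batch and sequence dims
sbhd = bshd.transpose(0, 1).contiguous()   # sequence-major layout

assert torch.equal(bshd.reshape(*thd.shape), thd)
assert torch.equal(sbhd.transpose(1, 0).reshape(*thd.shape), thd)
```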


@pytest.mark.parametrize("start_positions", [True, False])
85 changes: 44 additions & 41 deletions transformer_engine/common/fused_rope/fused_rope.cu
@@ -155,18 +155,18 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq
cur_seqlens = s;
}

int s_id_for_freqs;
// Offset the RoPE embedding by start_positions if provided.
int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id];
int s_id_for_freqs = s_id + begin_offset;

// If CP_SIZE > 1, offset the RoPE embedding by cp_rank based on the dual-chunk order.
if (cp_size > 1) {
assert(cur_seqlens % 2 == 0);
if (s_id < cur_seqlens / 2) {
s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
s_id_for_freqs += cp_rank * cur_seqlens / 2;
} else {
s_id_for_freqs =
cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
s_id_for_freqs += cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2;
}
} else {
int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id];
s_id_for_freqs = s_id + begin_offset;
}

fused_rope_block_forward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block,
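
The index arithmetic above (and its mirror in the backward kernel below) implements the dual-chunk ordering used by context parallelism: each rank holds chunk `cp_rank` from the front of the global sequence plus the mirrored chunk from the back, and the new `begin_offset` now shifts both branches uniformly. A Python sketch of the mapping; `dual_chunk_position` is a hypothetical helper mirroring the kernel's arithmetic:

```python
def dual_chunk_position(s_id: int, seqlen: int, cp_size: int, cp_rank: int,
                        begin_offset: int = 0) -> int:
    """Global RoPE position of local index s_id under dual-chunk CP order."""
    pos = s_id + begin_offset
    half = seqlen // 2
    if cp_size > 1:
        if s_id < half:
            pos += cp_rank * half                                  # front chunk
        else:
            pos += seqlen * cp_size - (cp_rank + 1) * half - half  # back chunk
    return pos

# With cp_size=2 and a per-rank seqlen of 4 (global length 8):
# rank 0 sees global positions [0, 1, 6, 7], rank 1 sees [2, 3, 4, 5].
assert [dual_chunk_position(i, 4, 2, 0) for i in range(4)] == [0, 1, 6, 7]
assert [dual_chunk_position(i, 4, 2, 1) for i in range(4)] == [2, 3, 4, 5]
```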
@@ -175,11 +175,11 @@

template <typename scalar_t>
__global__ void fused_rope_backward_kernel(
const scalar_t *src, const int *cu_seqlens, const float *freqs, scalar_t *dst,
const bool interleaved, const int cp_size, const int cp_rank, const int s, const int h,
const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s_or_t, const int o_stride_b, const int o_stride_h,
const int o_stride_d) {
const scalar_t *src, const int *cu_seqlens, const float *freqs, const int *start_positions,
scalar_t *dst, const bool interleaved, const int cp_size, const int cp_rank, const int s,
const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b,
const int stride_h, const int stride_d, const int o_stride_s_or_t, const int o_stride_b,
const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block, offset_block_dst;
int cur_seqlens;
Expand All @@ -197,17 +197,18 @@ __global__ void fused_rope_backward_kernel(
cur_seqlens = s;
}

int s_id_for_freqs;
// Offset the RoPE embedding by start_positions if provided.
int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id];
int s_id_for_freqs = s_id + begin_offset;

// If CP_SIZE > 1, offset the RoPE embedding by cp_rank based on the dual-chunk order.
if (cp_size > 1) {
assert(cur_seqlens % 2 == 0);
if (s_id < cur_seqlens / 2) {
s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
s_id_for_freqs += cp_rank * cur_seqlens / 2;
} else {
s_id_for_freqs =
cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
s_id_for_freqs += cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2;
}
} else {
s_id_for_freqs = s_id;
}

fused_rope_block_backward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block,
@@ -495,12 +496,12 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c

template <typename scalar_t>
void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens,
const float *freqs, scalar_t *input_grads,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream) {
const float *freqs, const int *start_positions,
scalar_t *input_grads, const NVTE_QKV_Format qkv_format,
const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2,
const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
Expand All @@ -521,9 +522,9 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se
const int o_stride_d = 1;

fused_rope_backward_kernel<<<blocks, threads, shared_mem_size, stream>>>(
output_grads, cu_seqlens, freqs, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2,
stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h,
o_stride_d);
output_grads, cu_seqlens, freqs, start_positions, input_grads, interleaved, cp_size, cp_rank,
s, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b,
o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}

@@ -590,16 +591,18 @@ void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Ten
}

void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, const Tensor &freqs,
Tensor *input_grads, const NVTE_QKV_Format qkv_format,
const bool interleaved, const int cp_size, const int cp_rank, const int s,
const int b, const int h, const int d, const int d2,
const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, cudaStream_t stream) {
const Tensor &start_positions, Tensor *input_grads,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
output_grads.data.dtype, scalar_t,
fused_rope_backward_launcher(reinterpret_cast<const scalar_t *>(output_grads.data.dptr),
reinterpret_cast<const int *>(cu_seqlens.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<const int *>(start_positions.data.dptr),
reinterpret_cast<scalar_t *>(input_grads->data.dptr), qkv_format,
interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
stride_b, stride_h, stride_d, stream););
Expand Down Expand Up @@ -663,18 +666,18 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens
}

void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream) {
const NVTETensor freqs, const NVTETensor start_positions,
NVTETensor input_grads, const NVTE_QKV_Format qkv_format,
const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2,
const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_backward);
using namespace transformer_engine;
fused_rope_backward(*convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(cu_seqlens),
*convertNVTETensorCheck(freqs), convertNVTETensorCheck(input_grads),
qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
stride_b, stride_h, stride_d, stream);
*convertNVTETensorCheck(freqs), *convertNVTETensorCheck(start_positions),
convertNVTETensorCheck(input_grads), qkv_format, interleaved, cp_size,
cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream);
}

void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs,
transformer_engine/common/include/transformer_engine/fused_rope.h
@@ -51,6 +51,7 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* (Required for the thd format, empty tensor for other formats)
* \param[in] freqs The freqs tensor.
* \param[in] start_positions The beginning offsets for applying RoPE embeddings.
* \param[out] input_grads Input gradient tensor to calculate.
* \param[in] qkv_format QKV format.
* \param[in] interleaved Whether to use interleaved rotary position embedding.
@@ -68,12 +69,12 @@
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream);
const NVTETensor freqs, const NVTETensor start_positions,
NVTETensor input_grads, const NVTE_QKV_Format qkv_format,
const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2,
const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, cudaStream_t stream);

/*! \brief Apply rotary positional embedding to the combined QKV input tensor.
*