diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 0f00b8b0ef..baa2837960 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -73,7 +73,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): qkv_formats = ["bshd", "sbhd", "thd"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: - configs = ["cp_1_0", "cp_2_1", "cp_3_2", "cp_3_3"] + configs = ["cp_1_0", "cp_1_4", "cp_2_0", "cp_2_1", "cp_3_2", "cp_3_3"] model_configs_flash_attn = {k: model_configs_flash_attn[k] for k in configs} dtypes = ["bf16"] qkv_formats = ["sbhd", "thd"] @@ -224,10 +224,14 @@ def test_cp_with_fused_attention( if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": pytest.skip("THD format does not support post_scale_bias yet!") - if qkv_format == "thd" and cp_comm_type == "all_gather": - pytest.skip("CP implementation with KV all-gather does not support THD format yet!") - if qkv_format == "thd" and "a2a" in cp_comm_type: - pytest.skip("CP implementation with QKVO A2A does not support THD format yet!") + if qkv_format == "thd": + if cp_comm_type == "all_gather": + pytest.skip("CP implementation with KV all-gather does not support THD format yet!") + if cp_comm_type == "a2a+p2p": + pytest.skip( + "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" + " yet!" + ) if dtype == "fp8" and cp_comm_type == "all_gather": pytest.skip( "CP implementation with KV all-gather does not support FP8 + context parallelism yet!" @@ -281,6 +285,14 @@ def test_cp_with_fused_attention( ) dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} + + if qkv_format == "thd": + print(f"config.attn_mask_type: {config.attn_mask_type}") + if "causal" in config.attn_mask_type: + config.attn_mask_type = "padding_causal" + else: + config.attn_mask_type = "padding" + fp8_meta = {} fp8_meta["recipe"] = None fp8_meta["local_recipes"] = [] diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index d1374e949e..dd91fd5c9a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -4,6 +4,7 @@ """Context Parallelism.""" import os +import itertools from typing import List, Union, Tuple import torch import transformer_engine_torch as tex @@ -258,6 +259,142 @@ def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size return x +def reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size, seq_dim=0): + """ + Reorder sequence chunks for A2A communication that happens after attention + compute. + + Args: + x: The input tensor to be reordered. + cu_seqlens: The cumulative sequence lengths of the input tensor. + cp_size: The number of ranks participating in context parallelism. + seq_dim: The dimension in which to reorder. + + Returns: + The reordered tensor. + + Example: + x: [ 0., 1., 2., 3., 4., 5., 6., 7., 0., 1., 2., 3., 4., 5., + 6., 7., 0., 1., 2., 3., 4., 5., 6., 7., 0., 1., 2., 3., + 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.] 
+        cu_seqlens: [ 0, 8, 16, 24, 40]
+        cp_size: 4
+
+        Returns: [ 0., 7., 0., 7., 0., 7., 0., 1., 14., 15., 1., 6., 1., 6.,
+                   1., 6., 2., 3., 12., 13., 2., 5., 2., 5., 2., 5., 4., 5.,
+                   10., 11., 3., 4., 3., 4., 3., 4., 6., 7., 8., 9.]
+
+    This logic is similar to how dual chunking splits the sequence across ranks.
+    Here, the chunk indices for all the ranks are concatenated together, so the
+    returned tensor looks as if the chunks from all the ranks were concatenated
+    together.
+
+    e.g. [
+        0., 7., 0., 7., 0., 7., 0., 1., 14., 15.,  # chunk on rank 0
+        1., 6., 1., 6., 1., 6., 2., 3., 12., 13.,  # chunk on rank 1
+        2., 5., 2., 5., 2., 5., 4., 5., 10., 11.,  # chunk on rank 2
+        3., 4., 3., 4., 3., 4., 6., 7., 8., 9.     # chunk on rank 3
+    ]
+    """
+    total_slices_of_any_sequence = 2 * cp_size
+    slice_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]) // total_slices_of_any_sequence
+
+    indices = [
+        (
+            # 1st segment
+            torch.arange(
+                seq_start + (cp_rank * slice_size), seq_start + ((cp_rank + 1) * slice_size)
+            ),
+            # 2nd segment
+            torch.arange(
+                seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size),
+                seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size),
+            ),
+        )
+        for cp_rank in range(cp_size)
+        for slice_size, seq_start in zip(slice_sizes, cu_seqlens[:-1])
+    ]
+
+    # flatten the list of tuples to a list
+    indices = list(itertools.chain(*indices))
+    indices = torch.cat(indices).to(device=cu_seqlens.device)
+    return x.index_select(seq_dim, indices)
+
+
+def reorder_seq_chunks_after_a2a_before_attn_thd(x, cu_seqlens, seq_chunk_ids, cp_size, seq_dim=0):
+    """
+    Reorder sequence chunks for A2A communication that happens before attention
+    compute.
+
+    Args:
+        x: The input tensor to be reordered.
+        cu_seqlens: The cumulative sequence lengths of the input tensor.
+        seq_chunk_ids: The sequence chunk ids of the input `x` that is to be reordered.
+        cp_size: The number of ranks participating in context parallelism.
+        seq_dim: The dimension in which to reorder.
+
+    Returns:
+        The reordered tensor.
+
+    Example:
+        x: [ 0., 7., 0., 7., 0., 7., 0., 1., 14., 15., 1., 6., 1., 6.,
+             1., 6., 2., 3., 12., 13., 2., 5., 2., 5., 2., 5., 4., 5.,
+             10., 11., 3., 4., 3., 4., 3., 4., 6., 7., 8., 9.]
+        cu_seqlens: [ 0, 8, 16, 24, 40]
+        seq_chunk_ids: [ 0, 2, 4, 6, 7, 5, 3, 1]
+        cp_size: 4
+
+        Returns: [ 0., 1., 2., 3., 4., 5., 6., 7., 0., 1., 2., 3., 4., 5.,
+                   6., 7., 0., 1., 2., 3., 4., 5., 6., 7., 0., 1., 2., 3.,
+                   4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.]
+
+    Note that the input sequences (x) are arranged after the A2A communication as if the
+    dual-chunked chunks from all the ranks were concatenated together in the `seq_dim`.
+
+    e.g. [
+        0., 7., 0., 7., 0., 7., 0., 1., 14., 15.,  # chunk on rank 0
+        1., 6., 1., 6., 1., 6., 2., 3., 12., 13.,  # chunk on rank 1
+        2., 5., 2., 5., 2., 5., 4., 5., 10., 11.,  # chunk on rank 2
+        3., 4., 3., 4., 3., 4., 6., 7., 8., 9.     # chunk on rank 3
+    ]
+
+    The logic to serialize the sequences is then:
+    1. For every sequence segment on any rank (denoted by `start` and `end`):
+       1a. For every chunk id in `seq_chunk_ids` (there are twice as many chunks as CP ranks):
+           1aa. The first `cp_size` chunks form the first half of the whole sequence. Get those indices.
+           1ab. The second `cp_size` chunks form the second half of the whole sequence. Get those indices.
+       1b. Concatenate the indices of the first half and the second half.
+    2. Reorder the entire input tensor by those indices.
+ """ + + max_cum_seqlen_per_cp_rank = cu_seqlens[-1] // cp_size + cu_seqlens_on_any_cp_rank = cu_seqlens // cp_size + + # Go through all the sequence segments (the sizes should be the same from all the ranks) + indices = [ + torch.arange( + # Calculate 'left' boundary + ( + start + max_cum_seqlen_per_cp_rank * (chunk_id // 2) + if loc < cp_size + else (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2) + ), + # Calculate 'right' boundary + ( + (start + end) // 2 + max_cum_seqlen_per_cp_rank * (chunk_id // 2) + if loc < cp_size + else end + max_cum_seqlen_per_cp_rank * (chunk_id // 2) + ), + ) + for start, end in zip(cu_seqlens_on_any_cp_rank[:-1], cu_seqlens_on_any_cp_rank[1:]) + for loc, chunk_id in enumerate(seq_chunk_ids) + ] + + indices = torch.cat(indices).to(device=cu_seqlens.device) + return x.index_select(seq_dim, indices) + + def flash_attn_a2a_communicate( a2a_inputs: Union[torch.Tensor, List[torch.Tensor]], chunk_ids_for_a2a: torch.Tensor, @@ -266,8 +403,12 @@ def flash_attn_a2a_communicate( cp_group: dist_group_type, cp_stream: torch.cuda.Stream, before_attn: bool, + qkv_format: str = "bshd", + cu_seqlens: torch.Tensor = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: """A2A communication for context parallelism.""" + + assert qkv_format != "thd" or cu_seqlens is not None, "cu_seqlens is required for THD format!" a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs a2a_outputs, a2a_reqs = [None] * len(a2a_inputs), [None] * len(a2a_inputs) if before_attn: @@ -281,20 +422,33 @@ def flash_attn_a2a_communicate( with torch.cuda.stream(cp_stream): a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] - # reorder the sequence chunks - x = reorder_seq_chunks_for_a2a_before_attn( - x, chunk_ids_for_a2a, seq_dim, cp_size - ) - # [b, cp*2, s//2, h//cp, d] -> [b, cp*s, h//cp, d] - # or [cp*2, s//2, b, h//cp, d] -> [cp*s, b, h//cp, d] - a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) + if qkv_format in ["bshd", "sbhd"]: + # reorder the sequence chunks + x = reorder_seq_chunks_for_a2a_before_attn( + x, chunk_ids_for_a2a, seq_dim, cp_size + ) + # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] + # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] + a2a_outputs[i - 2] = x.view( + *x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :] + ) + else: # qkv_format == "thd" + # [cp, t, np//cp, hn] -> [cp*t, np//cp, hn] + x = x.view(-1, *x.shape[2:]) + # reorder the sequence chunks + a2a_outputs[i - 2] = reorder_seq_chunks_after_a2a_before_attn_thd( + x, cu_seqlens, chunk_ids_for_a2a, cp_size + ) + if i < len(a2a_inputs): x = a2a_inputs[i] - # [b, s, h, d] -> [b, s, cp, h//cp, d] - # or [s, b, h, d] -> [s, b, cp, h//cp, d] + # [b, s, np, hn] -> [b, s, cp, np//cp, hn] + # or [s, b, np, hn] -> [s, b, cp, np//cp, hn] + # or [t, np, hn] -> [t, cp, np//cp, hn] x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1]) - # [b, s, cp, h//cp, d] -> [cp, b, s, h//cp, d] - # or [s, b, cp, h//cp, d] -> [cp, s, b, h//cp, d] + # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] + # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] + # or [t, cp, np//cp, hn] -> [cp, t, np//cp, hn] a2a_inputs[i] = x.movedim(-3, 0).contiguous() else: for i in range(len(a2a_inputs) + 2): @@ -305,22 +459,30 @@ def flash_attn_a2a_communicate( ) if i < len(a2a_inputs): x = a2a_inputs[i] - # [b, cp*s, h//cp, d] -> [b, cp*2, s//2, h//cp, d] - # or [cp*s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] - x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, 
*x.shape[(seq_dim + 1) :]) - # reorder the sequence chunks - a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn( - x, chunk_ids_for_a2a, seq_dim, cp_size - ) + if qkv_format in ["bshd", "sbhd"]: + # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] + # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) + # reorder the sequence chunks + a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn( + x, chunk_ids_for_a2a, seq_dim, cp_size + ) + else: # qkv_format == "thd" + # reorder the sequence chunks + x = reorder_seq_chunks_before_a2a_after_attn_thd(x, cu_seqlens, cp_size) + # [cp*t, np//cp, hn] -> [cp, t, np//cp, hn] + a2a_inputs[i] = x.view(cp_size, -1, *x.shape[-2:]) if i > 1: with torch.cuda.stream(cp_stream): a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] - # [cp, 2, b, s//2, h//cp, d] -> [b, 2, s//2, cp, h//cp, d] - # or [cp, 2, s//2, b, h//cp, d] -> [2, s//2, b, cp, h//cp, d] + # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] + # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] + # or [cp, t, np//cp, hn] -> [t, cp, np//cp, hn] x = x.movedim(0, -3).movedim(0, seq_dim).contiguous() - # [b, 2, s//2, cp, h//cp, d] -> [b*s, h, d] - # or [2, s//2, b, cp, h//cp, d] -> [s*b, h, d] + # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] + # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] + # or [t, cp, np//cp, hn] -> [t, np, hn] a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) torch.cuda.current_stream().wait_stream(cp_stream) return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs @@ -3085,7 +3247,9 @@ def forward( causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type - assert not padding, f"{attn_mask_type} mask type is not supported!" + assert ( + not padding or qkv_format == "thd" + ), f"{attn_mask_type} mask type is not supported for BSHD and SBHD!" assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" assert ( @@ -3136,11 +3300,14 @@ def forward( q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0 ), "The number of attention heads needs to be divisible by CP size!" - assert qkv_format != "thd", f"{qkv_format} format is not supported!" qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - batch_dim = qkv_format.index("b") - seq_dim = qkv_format.index("s") + if qkv_format in ["bshd", "sbhd"]: + batch_dim = qkv_format.index("b") + seq_dim = qkv_format.index("s") + else: # qkv_format == "thd" + batch_dim = seq_dim = qkv_format.index("t") + assert ( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 ), "Sequence length per GPU needs to be divisible by 2!" 
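The index arithmetic in the two new THD reordering helpers is easy to sanity-check in isolation. The following is a minimal standalone sketch (hypothetical names reorder_before_a2a and reorder_after_a2a, no Transformer Engine imports) that mirrors the docstring logic above and verifies, on the docstring example, that grouping tokens by destination rank and then undoing that reorder restores the serialized token order. It assumes, as in the docstrings, that cu_seqlens holds the full padded cumulative lengths and that every sequence length is divisible by 2 * cp_size.

import itertools
import torch


def reorder_before_a2a(x, cu_seqlens, cp_size, seq_dim=0):
    """Group tokens by destination CP rank (mirrors the before-A2A / after-attention helper)."""
    n_slices = 2 * cp_size
    sizes = ((cu_seqlens[1:] - cu_seqlens[:-1]) // n_slices).tolist()  # slice size per sequence
    starts = cu_seqlens[:-1].tolist()
    indices = [
        (
            # front slice of this sequence owned by `rank`
            torch.arange(start + rank * size, start + (rank + 1) * size),
            # mirrored back slice of this sequence owned by `rank`
            torch.arange(start + (n_slices - rank - 1) * size, start + (n_slices - rank) * size),
        )
        for rank in range(cp_size)
        for size, start in zip(sizes, starts)
    ]
    return x.index_select(seq_dim, torch.cat(list(itertools.chain(*indices))))


def reorder_after_a2a(x, cu_seqlens, seq_chunk_ids, cp_size, seq_dim=0):
    """Restore the serialized token order (mirrors the after-A2A / before-attention helper)."""
    t_per_rank = int(cu_seqlens[-1]) // cp_size
    cu_local = (cu_seqlens // cp_size).tolist()
    indices = [
        torch.arange(
            (start if loc < cp_size else (start + end) // 2) + t_per_rank * (cid // 2),
            ((start + end) // 2 if loc < cp_size else end) + t_per_rank * (cid // 2),
        )
        for start, end in zip(cu_local[:-1], cu_local[1:])
        for loc, cid in enumerate(seq_chunk_ids)
    ]
    return x.index_select(seq_dim, torch.cat(indices))


cp_size = 4
cu_seqlens = torch.tensor([0, 8, 16, 24, 40])
# Four sequences of lengths 8, 8, 8, 16, each numbered 0..len-1, as in the docstrings.
x = torch.cat([torch.arange(n, dtype=torch.float32) for n in (8, 8, 8, 16)])
# Chunk order from the docstring example, i.e. what
# get_seq_chunk_ids_for_reordering_before_attn(4, ...) is expected to return.
seq_chunk_ids = [0, 2, 4, 6, 7, 5, 3, 1]

grouped = reorder_before_a2a(x, cu_seqlens, cp_size)  # chunks grouped per destination rank
restored = reorder_after_a2a(grouped, cu_seqlens, seq_chunk_ids, cp_size)
assert torch.equal(restored, x)  # the round trip recovers the serialized sequences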
@@ -3185,7 +3352,15 @@ def forward( chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, q.device) q, k, v = flash_attn_a2a_communicate( - [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True + [q, k, v], + chunk_ids_for_a2a, + seq_dim, + cp_size, + cp_group, + cp_stream, + before_attn=True, + qkv_format=qkv_format, + cu_seqlens=cu_seqlens_q_padded, ) if softmax_type != "vanilla": softmax_offset = flash_attn_a2a_communicate_softmax_offset( @@ -3274,7 +3449,15 @@ def forward( chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out_.device) out_ = flash_attn_a2a_communicate( - out_, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False + out_, + chunk_ids_for_a2a, + seq_dim, + cp_size, + cp_group, + cp_stream, + before_attn=False, + qkv_format=qkv_format, + cu_seqlens=cu_seqlens_q_padded, ) if use_fused_attention: @@ -3385,9 +3568,15 @@ def backward(ctx, dout): cu_seqlens_kv_padded, *aux_ctx_tensors, ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + + qkv_format = ctx.qkv_format + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format causal = "causal" in ctx.attn_mask_type - seq_dim = ctx.qkv_format.index("s") + + if qkv_format in ["bshd", "sbhd"]: + seq_dim = qkv_format.index("s") + else: # qkv_format == "thd" + seq_dim = qkv_format.index("t") bwd_nominal_dtype = ctx.fwd_nominal_dtype dqkv_te_dtype = None @@ -3417,14 +3606,23 @@ def backward(ctx, dout): fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] if not ctx.use_fused_attention: - out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - dout = dout.view(ctx.batch_size, -1, *dout.shape[-2:]) + if qkv_format in ["bshd", "sbhd"]: + out = out.view(ctx.batch_size, -1, *out.shape[-2:]) + dout = dout.view(ctx.batch_size, -1, *dout.shape[-2:]) else: dout = dout.view(*ctx.out_shape) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, dout.device) dout = flash_attn_a2a_communicate( - dout, chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True + dout, + chunk_ids_for_a2a, + seq_dim, + cp_size, + ctx.cp_group, + ctx.cp_stream, + before_attn=True, + qkv_format=qkv_format, + cu_seqlens=cu_seqlens_q_padded, ) flash_attn_bwd = None @@ -3441,7 +3639,7 @@ def backward(ctx, dout): fa_backward_kwargs["window_size"] = ctx.window_size fa_backward_kwargs["deterministic"] = ctx.deterministic else: - if ctx.qkv_format == "thd": + if qkv_format == "thd": from transformer_engine.pytorch.attention.dot_product_attention.backends import ( _flash_attn_varlen_bwd, ) @@ -3509,7 +3707,7 @@ def backward(ctx, dout): fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, - ctx.qkv_format, + qkv_format, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=ctx.max_seqlen_q, @@ -3534,12 +3732,20 @@ def backward(ctx, dout): chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dq.device) dq, dk, dv = flash_attn_a2a_communicate( - [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False + [dq, dk, dv], + chunk_ids_for_a2a, + seq_dim, + cp_size, + ctx.cp_group, + ctx.cp_stream, + before_attn=False, + qkv_format=qkv_format, + cu_seqlens=cu_seqlens_q_padded, ) - if ctx.qkv_format == "bshd": + if qkv_format == "bshd": dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] - elif ctx.qkv_format == "sbhd": + elif qkv_format == "sbhd": dq, dk, dv = [x.view(-1, 
ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]]
        d_bias = None
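As a shape-level companion to the THD branches added in flash_attn_a2a_communicate, the short single-process sketch below walks one tensor through the head-sharding view/movedim applied before the all-to-all and the inverse merge applied after it, matching the [t, np, hn] <-> [cp, t, np//cp, hn] comments in the diff. The collective exchange itself is skipped; a real run would exchange dim 0 across ranks (e.g. with torch.distributed.all_to_all_single), so here the single tensor stands in for what one rank would send and receive.

import torch

cp_size, t, num_heads, head_dim = 4, 40, 8, 64
x = torch.randn(t, num_heads, head_dim)  # THD layout: [t, np, hn]

# Before attention: shard heads so each rank ends up with np//cp heads for all t tokens.
# [t, np, hn] -> [t, cp, np//cp, hn]
x_split = x.view(t, cp_size, num_heads // cp_size, head_dim)
# [t, cp, np//cp, hn] -> [cp, t, np//cp, hn]; dim 0 is what an all-to-all would exchange.
a2a_input = x_split.movedim(-3, 0).contiguous()
assert a2a_input.shape == (cp_size, t, num_heads // cp_size, head_dim)

# (a real A2A would exchange dim 0 across ranks here; single-process stand-in)
a2a_output = a2a_input

# After attention: undo the head shard.
# [cp, t, np//cp, hn] -> [t, cp, np//cp, hn]
y = a2a_output.movedim(0, -3).contiguous()
# [t, cp, np//cp, hn] -> [t, np, hn]
y = y.view(t, num_heads, head_dim)
assert torch.equal(y, x)  # pure reshape round trip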