Add support for context parallelism #1299

Open · wants to merge 2 commits into base: main
megatron/data/data_utils.py (8 additions, 1 deletion)

@@ -532,11 +532,12 @@ def build_train_valid_test_data_loaders(neox_args):
     else:
         pipe_load = True

-    # Data loader only on rank 0 of each model parallel group.
+    # Data loader only on rank 0 of each model and context parallel group.
     if (
         pipe_load
         and (neox_args.dataset_impl == "online")
         and (mpu.get_model_parallel_rank() == 0)
+        and (mpu.get_context_parallel_rank() == 0)
     ):
         # Can skip most of the work...
         train_iters = neox_args.train_iters

@@ -721,11 +722,17 @@ def build_train_valid_test_data_loaders(neox_args):
         # broadcast globally instead of just the model parallel group.
         torch.distributed.broadcast(flags, src=0)
     else:
+        # The same data should be used for the model parallel and context parallel groups
         torch.distributed.broadcast(
             flags,
             mpu.get_model_parallel_src_rank(),
             group=mpu.get_model_parallel_group(),
         )
+        torch.distributed.broadcast(
+            flags,
+            mpu.get_context_parallel_src_rank(),
+            group=mpu.get_context_parallel_group(),
+        )
     neox_args.do_train = flags[0].item()
     neox_args.do_valid = flags[1].item()
     neox_args.do_test = flags[2].item()
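One broadcast per orthogonal group is enough to propagate the flags from the single loader rank to every model-parallel and context-parallel peer. A minimal sketch of that pattern, assuming the two process groups are already initialized (the helper name and signature below are illustrative, not part of this PR):

import torch.distributed as dist


def sync_data_flags(flags, mp_src, mp_group, cp_src, cp_group):
    # `flags` is the tensor of do_train/do_valid/do_test values built on the
    # rank that constructed the data loaders; relaying it through the model
    # parallel group and then the context parallel group reaches all peers.
    dist.broadcast(flags, src=mp_src, group=mp_group)
    dist.broadcast(flags, src=cp_src, group=cp_group)
    return flags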
megatron/initialize.py (11 additions, 5 deletions)

@@ -158,16 +158,20 @@ def _initialize_distributed(neox_args):
     # Setup 3D topology.
     pp = neox_args.pipe_parallel_size if neox_args.pipe_parallel_size >= 1 else 1
     mp = neox_args.model_parallel_size if neox_args.model_parallel_size >= 1 else 1
-    assert (
-        neox_args.world_size % (pp * mp) == 0
-    ), f"world_size={neox_args.world_size}, pp={pp}, mp={mp}"
+    cp = neox_args.context_parallel_size if neox_args.context_parallel_size >= 1 else 1
+    assert (
+        neox_args.world_size % (pp * mp * cp) == 0
+    ), f"world_size={neox_args.world_size}, pp={pp}, mp={mp}, cp={cp}"
+    # The data parallel ranks will be used for context parallel
+    # to piggy back the gradient all reduce
     dp = neox_args.world_size // (pp * mp)
+    assert dp % cp == 0
+    from deepspeed.runtime.pipe.topology import ProcessTopology

-    from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology

     # this does pipe on the most outside, then data, then model.
     # PipeModelDataParallelTopology is just a wrapper over ProcessTopology that predefines this order.
-    topo = PipeModelDataParallelTopology(num_pp=pp, num_mp=mp, num_dp=dp)
+    topo = ProcessTopology(axes=["pipe", "data", "model"], dims=[pp, dp, mp])

     # Offset base seeds for the interior pipeline stages.
     # TODO: adjust last stage too once IO is improved.

@@ -186,6 +190,8 @@
     else:
         mpu.initialize_model_parallel(
             neox_args.model_parallel_size,
+            neox_args.pipe_parallel_size,
+            neox_args.context_parallel_size,
             topology=topo,
             fp32_allreduce=neox_args.fp32_allreduce,
         )
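To make the ["pipe", "data", "model"] ordering concrete, here is a small self-contained sketch of the rank layout it implies. The helpers are hypothetical, and the assumption that context-parallel groups are contiguous blocks of size cp along the data axis (consistent with the `assert dp % cp == 0`) is mine, since `mpu.initialize_model_parallel` itself is not part of this diff:

def coords(rank, pp, dp, mp):
    # pipe is the outermost axis, then data, then model (innermost).
    model = rank % mp
    data = (rank // mp) % dp
    pipe = rank // (mp * dp)
    return pipe, data, model


def context_parallel_peers(rank, pp, dp, mp, cp):
    # Ranks sharing a context-parallel group: same pipe and model coordinate,
    # data coordinates within the same contiguous block of size cp.
    pipe, data, model = coords(rank, pp, dp, mp)
    block = data // cp
    return [pipe * dp * mp + d * mp + model for d in range(block * cp, (block + 1) * cp)]


# Example with world_size=16, pp=2, dp=4, mp=2, cp=2:
# coords(5, 2, 4, 2) == (0, 2, 1) and context_parallel_peers(5, 2, 4, 2, 2) == [5, 7]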
megatron/model/fused_layer_norm.py (4 additions, 4 deletions)

@@ -37,7 +37,7 @@ def __init__(
         normalized_shape,
         eps=1e-5,
         no_persist_layer_norm=True,
-        sequence_parallel=False,
+        context_parallel=False,
         apply_layernorm_1p=False,
         mem_efficient_ln=True,
     ):

Review comment (Member) on the renamed parameter:
@bclyang -- Shouldn't these remain sequence_parallel to match our previous support for megatron-style sequence parallelism (essentially just TP applied to layernorm and dropout)?

@@ -92,11 +92,11 @@ def __init__(
         self.bias = Parameter(torch.Tensor(*normalized_shape))
         self.reset_parameters()
         self.no_persist_layer_norm = no_persist_layer_norm
-        self.sequence_parallel = sequence_parallel
+        self.context_parallel = context_parallel

         # set sequence parallelism flag on weight and bias parameters
-        setattr(self.weight, "sequence_parallel", self.sequence_parallel)
-        setattr(self.bias, "sequence_parallel", self.sequence_parallel)
+        setattr(self.weight, "context_parallel", self.context_parallel)
+        setattr(self.bias, "context_parallel", self.context_parallel)

     def reset_parameters(self):
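The setattr calls tag the layernorm weight and bias so downstream code can identify parameters whose gradients are computed from a sequence shard. As a loosely hedged sketch of one possible consumer of that tag (a hypothetical helper; whether these gradients get an extra reduction like this or simply ride on the piggybacked data-parallel all-reduce mentioned in megatron/initialize.py is not shown in this diff):

import torch
import torch.distributed as dist


def allreduce_tagged_grads(module: torch.nn.Module, cp_group) -> None:
    # Sum gradients of parameters flagged as `context_parallel` across the
    # context-parallel group. Illustrative only.
    for p in module.parameters():
        if getattr(p, "context_parallel", False) and p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=cp_group)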
megatron/model/gpt2_model.py (24 additions, 1 deletion)

@@ -74,7 +74,30 @@ def cross_entropy(output, labels, _fp16=False):
     else:
         losses = mpu.vocab_parallel_cross_entropy(output.float().contiguous(), labels)
     loss_mask = loss_mask.view(-1)
-    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
+    loss_mask_sum = loss_mask.sum()
+    if mpu.get_context_parallel_world_size() > 1:
+        dt = loss_mask_sum.dtype
+        if dt == torch.bfloat16 and mpu.initialize.get_fp32_allreduce():
+            loss_mask_sum = loss_mask_sum.float()
+        torch.distributed.all_reduce(
+            loss_mask_sum,
+            op=torch.distributed.ReduceOp.SUM,
+            group=mpu.get_context_parallel_group(),
+        )
+        if dt == torch.bfloat16 and mpu.initialize.get_fp32_allreduce():
+            loss_mask_sum = loss_mask_sum.bfloat16()
+        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask_sum
+        if dt == torch.bfloat16 and mpu.initialize.get_fp32_allreduce():
+            loss = loss.float()
+        torch.distributed.all_reduce(
+            loss,
+            op=torch.distributed.ReduceOp.SUM,
+            group=mpu.get_context_parallel_group(),
+        )
+        if dt == torch.bfloat16 and mpu.initialize.get_fp32_allreduce():
+            loss = loss.bfloat16()
+    else:
+        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask_sum
     return loss
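The arithmetic behind that branch can be checked without any distributed setup: with the sequence sharded across cp ranks, each rank divides its local masked-loss sum by the globally reduced mask count, and summing those partial losses over the group reproduces the full masked mean. A small sketch, with plain sums standing in for the all-reduces:

import torch

cp = 4
losses = torch.randn(128)               # per-token losses for the full sequence
mask = (torch.rand(128) > 0.1).float()  # loss mask for the full sequence

full_loss = (losses * mask).sum() / mask.sum()

# what each rank sees after the sequence is sharded
loss_shards, mask_shards = losses.chunk(cp), mask.chunk(cp)
global_mask_count = sum(m.sum() for m in mask_shards)   # the all-reduced loss_mask_sum
partials = [(l * m).sum() / global_mask_count for l, m in zip(loss_shards, mask_shards)]
cp_loss = sum(partials)                                  # the all-reduced loss

assert torch.allclose(full_loss, cp_loss, atol=1e-5)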
megatron/model/positional_embeddings.py (22 additions, 1 deletion)

@@ -14,6 +14,7 @@

 import torch
 import math
+import megatron.mpu as mpu


 class SinusoidalPositionalEmbedding(torch.nn.Module):

@@ -37,7 +38,13 @@ def forward(self, x, seq_dim=1):

 class RotaryEmbedding(torch.nn.Module):
     def __init__(
-        self, dim, max_seq_len, base=10000, precision=torch.half, save_inv_freqs=False
+        self,
+        dim,
+        max_seq_len,
+        base=10000,
+        precision=torch.half,
+        save_inv_freqs=False,
+        zigzag=True,
     ):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

@@ -49,6 +56,7 @@ def __init__(
         self.max_seq_len = max_seq_len
         self.base = base
         self.dim = dim
+        self.zigzag = zigzag  # seq parallel zigzag

         # precompute cos_cached, sin_cached in fp32
         cos_cached, sin_cached, inv_freq = self._prepare_cache(

@@ -64,6 +72,19 @@ def _prepare_cache(self, seq_len, precision, base):
         inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float() / self.dim))

         t = torch.arange(seq_len).type_as(inv_freq)
+        if mpu.get_context_parallel_world_size() > 1:
+            if not self.zigzag:
+                t_chunks = torch.chunk(t, mpu.get_context_parallel_world_size())
+                t = t_chunks[mpu.get_context_parallel_rank()].contiguous()
+            else:
+                t_chunks = torch.chunk(t, 2 * mpu.get_context_parallel_world_size())
+                t = torch.cat(
+                    (
+                        t_chunks[mpu.get_context_parallel_rank()],
+                        t_chunks[-(mpu.get_context_parallel_rank() + 1)],
+                    ),
+                    dim=0,
+                ).contiguous()
         freqs = torch.einsum("i,j->ij", t, inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1)
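The zigzag branch pairs an early chunk with its mirror-image late chunk on each rank, which balances causal-attention work across the context-parallel group (rank 0 would otherwise hold only the cheapest prefix positions). A standalone sketch of the resulting position layout:

import torch


def zigzag_positions(seq_len, cp_world_size, rank):
    chunks = torch.chunk(torch.arange(seq_len), 2 * cp_world_size)
    return torch.cat((chunks[rank], chunks[-(rank + 1)]), dim=0)


# seq_len=16, cp_world_size=2:
#   rank 0 -> [0, 1, 2, 3, 12, 13, 14, 15]
#   rank 1 -> [4, 5, 6, 7,  8,  9, 10, 11]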
megatron/model/transformer.py (101 additions, 2 deletions)

@@ -459,6 +459,7 @@ def __init__(
         self.rope_fusion = neox_args.rope_fusion
         self.attention_type = neox_args.attention_config[layer_number]
         self.use_flash_attention = self.attention_type == "flash"
+        self.use_ring_attention = self.attention_type == "ring"
         self.use_triton = (
             self.use_flash_attention
             and self.pos_emb == "alibi"

@@ -467,7 +468,7 @@
                 >= packaging.version.Version("2.4.0.post1")
             )
         )
-        self.sparse = self.attention_type not in ("global", "flash")
+        self.sparse = self.attention_type not in ("global", "flash", "ring")

         if self.gqa:
             assert not self.sparse

@@ -496,6 +497,12 @@ def __init__(
             self.flash_triton_fn = flash_attn_unpadded_unpacked_func_triton
             self.flash_qkv_fn = flash_attn_func
             self.flash_varlen_qkv_fn = flash_attn_varlen_func
+        elif self.use_ring_attention:
+            from ring_flash_attn.zigzag_ring_flash_attn import (
+                zigzag_ring_flash_attn_func,
+            )
+
+            self.ring_attn_fn = zigzag_ring_flash_attn_func
         else:
             self.scale_mask_softmax = FusedScaleMaskSoftmax(
                 input_in_fp16=self.fp16,

@@ -743,6 +750,96 @@ def flash_attention(self, query_layer, key_layer, value_layer):

         return matmul_result

+    def ring_attention(self, query_layer, key_layer, value_layer):
+        # [b, np, sq, sk]
+        output_size = (
+            query_layer.size(1),
+            query_layer.size(2),
+            query_layer.size(0),
+            key_layer.size(0),
+        )
+
+        # [sk, b, np, hn] -> [b, sk, np, hn] -> [b * sk, 1, np, hn]
+        key_layer = key_layer.transpose(0, 1).reshape(
+            output_size[0], output_size[3], self.num_kv_heads_per_partition, -1
+        )
+        value_layer = value_layer.transpose(0, 1).reshape(
+            output_size[0], output_size[3], self.num_kv_heads_per_partition, -1
+        )
+
+        # [sq, b, np, hn] -> [b, sq, np, hn]
+        query_layer = query_layer.transpose(0, 1).reshape(
+            output_size[0], output_size[2], output_size[1], -1
+        )
+
+        # only pass in window_size or alibi_slopes kwarg
+        # if we use Sliding Window Attention / AliBi.
+        # Flash attn defaults to (-1,-1), or
+        # does not have this kwarg prior to v2.3.0
+        extra_kwargs = (
+            {"window_size": (self.sliding_window_width, -1)}
+            if self.sliding_window_width is not None
+            else {}
+        )
+        if self.pos_emb == "alibi":
+            extra_kwargs["alibi_slopes"] = self.alibi_embed.slopes.to(
+                query_layer.device
+            ).to(torch.float32)
+
+        if not self.training:
+            batch_size = output_size[0]
+            max_seqlen_q = output_size[2]
+            max_seqlen_k = output_size[3]
+
+            cu_seqlens_q = torch.arange(
+                0,
+                (batch_size + 1) * max_seqlen_q,
+                step=max_seqlen_q,
+                dtype=torch.int32,
+                device=query_layer.device,
+            )
+
+            cu_seqlens_k = torch.arange(
+                0,
+                (batch_size + 1) * max_seqlen_k,
+                step=max_seqlen_k,
+                dtype=torch.int32,
+                device=key_layer.device,
+            )
+
+            q_shape = query_layer.shape
+            k_shape = key_layer.shape
+            v_shape = value_layer.shape
+            is_causal = max_seqlen_q == max_seqlen_k
+            output = self.ring_attn_fn(
+                query_layer,
+                key_layer,
+                value_layer,
+                0.0,
+                softmax_scale=None,
+                causal=is_causal,
+                group=mpu.get_context_parallel_group(),
+                **extra_kwargs,
+            )
+            output = output.reshape(q_shape)
+        else:
+            output = self.ring_attn_fn(
+                query_layer,
+                key_layer,
+                value_layer,
+                self.dropout_p if self.training else 0.0,
+                softmax_scale=None,
+                causal=True,
+                group=mpu.get_context_parallel_group(),
+                **extra_kwargs,
+            )
+
+        matmul_result = output
+        # [b, sq, np, hn] -> [b, np, sq, hn]
+        matmul_result = matmul_result.transpose(1, 2)
+
+        return matmul_result
+
     def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask):
         # TODO: sparse attn dropout?
         # TODO: pad to block size

@@ -818,7 +915,7 @@ def gqa_project(self, hidden_states, attention_mask, layer_past=None):
         value_layer = value_layer.view(*new_kv_shape)

         # if not using Flash attention, we repeat K/V heads to match Q head counts
-        if not self.use_flash_attention:
+        if not (self.use_flash_attention or self.use_ring_attention):
             key_layer = torch.repeat_interleave(
                 key_layer,
                 repeats=int(

@@ -929,6 +1026,8 @@ def forward(self, hidden_states, attention_mask, layer_past=None):

         if self.use_flash_attention:
             context_layer = self.flash_attention(query_layer, key_layer, value_layer)
+        elif self.use_ring_attention:
+            context_layer = self.ring_attention(query_layer, key_layer, value_layer)
         elif not self.sparse:
             context_layer = self.attention(
                 query_layer, key_layer, value_layer, layer_past, attention_mask
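Two layout details in ring_attention are easy to miss: NeoX carries activations as [sq, b, np, hn] while the flash/ring kernels expect batch-first [b, sq, np, hn], and the cu_seqlens tensors for a batch of equal-length sequences are just multiples of the sequence length. A short sketch of both, with arbitrarily chosen shapes:

import torch

sq, b, np_, hn = 8, 2, 4, 16
query = torch.randn(sq, b, np_, hn)       # [sq, b, np, hn], sequence-first
query_bf = query.transpose(0, 1)          # [b, sq, np, hn], what the kernel expects
out = query_bf                            # stand-in for the kernel output
out = out.transpose(1, 2)                 # [b, np, sq, hn], handed back to the caller
assert out.shape == (b, np_, sq, hn)

batch_size, max_seqlen = b, sq
cu_seqlens = torch.arange(0, (batch_size + 1) * max_seqlen, step=max_seqlen, dtype=torch.int32)
# tensor([ 0,  8, 16], dtype=torch.int32)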
megatron/model/utils.py (2 additions, 2 deletions)

@@ -373,14 +373,14 @@ def reduce_weight_grads_from_model_parallel_region(input_):

     # Bf16 convert
     dt = input_.dtype
-    if dt == torch.bfloat16 and mpu.get_fp32_allreduce():
+    if dt == torch.bfloat16 and mpu.initialize.get_fp32_allreduce():
         input_ = input_.float()

     # All-reduce.
     dist.all_reduce(input_, group=mpu.get_model_parallel_group())

     # Bf16 convert
-    if dt == torch.bfloat16 and mpu.get_fp32_allreduce():
+    if dt == torch.bfloat16 and mpu.initialize.get_fp32_allreduce():
         input_ = input_.bfloat16()

     return input_
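The bf16-to-fp32 round trip around the all-reduce appears in several places in this PR; factored into a standalone helper it would look like the sketch below (a hypothetical refactor, not something the PR adds):

import torch
import torch.distributed as dist


def all_reduce_maybe_fp32(tensor, group, fp32_allreduce):
    # Upcast bf16 tensors before the all-reduce and cast back afterwards,
    # mirroring the inline pattern in reduce_weight_grads_from_model_parallel_region.
    dt = tensor.dtype
    if dt == torch.bfloat16 and fp32_allreduce:
        tensor = tensor.float()
    dist.all_reduce(tensor, group=group)
    if dt == torch.bfloat16 and fp32_allreduce:
        tensor = tensor.bfloat16()
    return tensor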
megatron/mpu/__init__.py (7 additions, 0 deletions)

@@ -57,3 +57,10 @@

 from .utils import divide
 from .utils import split_tensor_along_last_dim
+from .data import zigzag_data
+from .initialize import (
+    get_context_parallel_group,
+    get_context_parallel_rank,
+    get_context_parallel_world_size,
+    get_context_parallel_src_rank,
+)