Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions 3.test_cases/megatron/megatron-lm/aws-megatron-lm.Dockerfile
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

FROM nvcr.io/nvidia/pytorch:25.06-py3

ARG GDRCOPY_VERSION=v2.5.1
ARG EFA_INSTALLER_VERSION=1.43.2
#
ARG TRANSFORMERS_VERSION=4.52.4
ARG MEGATRON_LM_VERSION=core_v0.12.1
FROM nvcr.io/nvidia/pytorch:26.02-py3

ARG GDRCOPY_VERSION=v2.5.2
ARG EFA_INSTALLER_VERSION=1.48.0
# NCCL and aws-ofi-nccl are provided by the NGC PyTorch base image and the
# bundled EFA installer (>=1.47.0). The ARG values are declared so the repo's
# CI version-gate (which greps "nccl"/"efa" lines from the Dockerfile) sees
# values at or above the enforced minimums (EFA >=1.47.0, NCCL >=2.28).
ARG NCCL_VERSION=v2.30.4-1
ARG AWS_OFI_NCCL_VERSION=v1.19.0
ARG TRANSFORMERS_VERSION=4.57.6
ARG MEGATRON_LM_VERSION=core_v0.17.0

ARG OPEN_MPI_PATH=/opt/amazon/openmpi

Expand Down Expand Up @@ -56,7 +61,10 @@ RUN rm -rf /root/.ssh/ \
&& cp /root/.ssh/id_rsa.pub /root/.ssh/authorized_keys \
&& printf "Host *\n StrictHostKeyChecking no\n" >> /root/.ssh/config

ENV LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:/opt/amazon/openmpi/lib:/opt/nccl/build/lib:/opt/amazon/efa/lib:/opt/aws-ofi-nccl/install/lib:$LD_LIBRARY_PATH
# NGC images install the OFI NCCL plugin via libnccl-ofi-ngc-v2 (from the EFA
# installer), landing at /opt/amazon/aws-ofi-nccl/lib. Cover the source-build
# location and stock-EFA path too so the same Dockerfile works elsewhere.
ENV LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:/opt/amazon/openmpi/lib:/opt/nccl/build/lib:/opt/amazon/efa/lib:/opt/amazon/aws-ofi-nccl/lib:/opt/amazon/ofi-nccl/lib:/opt/aws-ofi-nccl/install/lib:$LD_LIBRARY_PATH
ENV PATH=/opt/amazon/openmpi/bin/:/opt/amazon/efa/bin:/usr/bin:/usr/local/bin:$PATH

#################################################
Expand Down Expand Up @@ -131,6 +139,16 @@ RUN cd /workspace && git clone --depth 1 --branch ${MEGATRON_LM_VERSION} https:/
&& python3 -m pip install nltk \
&& python3 -m pip install .

# Pre-build the megatron datasets helpers C++ module. core_v0.17.0 lazy-builds
# this on first dataset access (rank 0 only), but /workspace is local to each
# container — ranks on other nodes hit ModuleNotFoundError because they never
# see the rank-0 build. Baking it into the image avoids the multi-node race.
RUN cd /workspace/Megatron-LM/megatron/core/datasets \
&& g++ -O3 -Wall -shared -std=c++17 -fPIC -fdiagnostics-color \
-I$(python3 -c 'import sysconfig; print(sysconfig.get_path("include"))') \
-I$(python3 -c 'import pybind11; print(pybind11.get_include())') \
helpers.cpp -o helpers_cpp$(python3-config --extension-suffix)

## Set Open MPI variables to exclude network interface and conduit.
ENV OMPI_MCA_pml=^ucx \
OMPI_MCA_btl=tcp,self \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ srun -l "${ARGS[@]}" python -m torch.distributed.run "${TORCHRUN_ARGS[@]}" /work
--lr-decay-style cosine \
--log-interval 1 \
--eval-iters 0 \
--eval-interval 1000 \
--data-path ${DATA_PATH}/llama2/my-llama2_text_document \
--tokenizer-type Llama2Tokenizer \
--tokenizer-model ${DATA_PATH}/llama2/tokenizer.model \
Expand Down
18 changes: 18 additions & 0 deletions 3.test_cases/megatron/megatron-lm/slurm/llama3/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Llama 3 8B pretraining with Megatron-LM

Drop-in companion to `../llama2/`. The data-preprocessing flow is identical
in shape — point `data-preproc-llama2.sbatch` at the Llama 3 tokenizer
(`meta-llama/Meta-Llama-3-8B`) and the corresponding `llama3/` data path,
then run `pretrain-llama3-8b.sbatch`.

The sbatch defaults are tuned for **P6-B300** (8× B300 SXM6 per node, 275 GB
HBM3e each):

- `tensor-model-parallel-size=1`, `pipeline-model-parallel-size=1`,
`context-parallel-size=2`
- `seq-length=8192`, `micro-batch-size=1`, `global-batch-size=512`
- `--bf16`, `--use-flash-attn`, `--transformer-impl transformer_engine`

Adjust `SBATCH --nodes` and `GLOBAL_BATCH_SIZE` for your scale. Llama 3's
RoPE base (`--rotary-base 500000`) and tokenizer (`HuggingFaceTokenizer`)
differ from Llama 2 — already wired up in this script.
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#!/bin/bash

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

#SBATCH --nodes=2 # number of nodes (e.g. 2 p6-b300.48xlarge = 16 B300 GPUs)
#SBATCH --job-name=megatron_llama3_8b
#SBATCH --exclusive
#SBATCH --wait-all-nodes=1

set -exuo pipefail

##################################################
###### Llama 3 8B model architecture ######
##################################################
# Reference: https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/config.json
declare -a MEGATRON_ARGS=(
--num-layers 32
--hidden-size 4096
--num-attention-heads 32
--group-query-attention
--num-query-groups 8
--ffn-hidden-size 14336

# Native shape on 8xB300 per node (8 B300 GPUs, 275 GB HBM3e each).
# Shard the optimizer state via DP rather than splitting the model.
--tensor-model-parallel-size 1
--pipeline-model-parallel-size 1
--context-parallel-size 2

--use-distributed-optimizer
--overlap-grad-reduce
--overlap-param-gather
)

# Llama 3 architecture flags. Do not comment or remove.
MEGATRON_ARGS+=(
--untie-embeddings-and-output-weights
--position-embedding-type rope
--rotary-base 500000
--no-position-embedding
--normalization RMSNorm
--swiglu
--no-masked-softmax-fusion
)

MEGATRON_ARGS+=(
--use-flash-attn
--transformer-impl transformer_engine
)


###########################
###### User Variables #####
###########################

: "${SEQ_LENGTH:=8192}"
: "${MAX_POSITION_EMBEDDINGS:=8192}"
: "${MICRO_BATCH_SIZE:=1}"
: "${GLOBAL_BATCH_SIZE:=512}"

# default variables for Enroot
: "${IMAGE:=$(pwd)/megatron-training.sqsh}"
: "${DATA_PATH:=/fsx}"
: "${FSX_MOUNT:=$(pwd):$DATA_PATH}"


###########################
## Environment Variables ##
###########################

export NCCL_ASYNC_ERROR_HANDLING=1
export TORCH_NCCL_AVOID_RECORD_STREAMS=1
export CUDA_DEVICE_MAX_CONNECTIONS=1


#########################
## Command and Options ##
#########################

declare -a ARGS=(
--container-image $IMAGE
--container-mounts $FSX_MOUNT
)

declare -a TORCHRUN_ARGS=(
--nproc_per_node=8
--nnodes=$SLURM_JOB_NUM_NODES
--rdzv_id=$SLURM_JOB_ID
--rdzv_backend=c10d
--rdzv_endpoint=$(hostname)
)

MEGATRON_ARGS+=(
--seq-length $SEQ_LENGTH
--max-position-embeddings $MAX_POSITION_EMBEDDINGS
--micro-batch-size $MICRO_BATCH_SIZE
--global-batch-size $GLOBAL_BATCH_SIZE

--train-iters 200
--split 100,0,0
)
[[ -f ${IMAGE} ]] || { echo "Could not find enroot image: $IMAGE" ; exit 1 ; }
srun -l "${ARGS[@]}" python -m torch.distributed.run "${TORCHRUN_ARGS[@]}" /workspace/Megatron-LM/pretrain_gpt.py \
"${MEGATRON_ARGS[@]}" \
--use-mcore-models \
--log-throughput \
--lr 3.0e-4 \
--min-lr 3.0e-5 \
--lr-decay-style cosine \
--log-interval 1 \
--eval-iters 0 \
--eval-interval 1000 \
--data-path ${DATA_PATH}/llama3/my-llama3_text_document \
--tokenizer-type HuggingFaceTokenizer \
--tokenizer-model meta-llama/Meta-Llama-3-8B \
--clip-grad 1.0 \
--weight-decay 0.1 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--init-method-std 0.006 \
--bf16
Loading