From a9fbfbd7e44e812140cef1c98ddefa575356fdbb Mon Sep 17 00:00:00 2001 From: Keita Watanabe Date: Wed, 29 Apr 2026 13:43:36 +0000 Subject: [PATCH 1/5] megatron-lm: bump to NGC pytorch:26.02 and add Llama 3 8B sbatch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit NGC pytorch:26.02-py3 ships CUDA 13 and a recent NCCL with native sm_103 binaries, which Blackwell Ultra (B300) needs to avoid the PTX-JIT slow path. The ARG bumps line up with the repo's CI version gate (EFA >= 1.47.0, NCCL >= 2.28, CUDA >= 13.0): EFA installer 1.43.2 -> 1.48.0, plus explicit NCCL_VERSION / AWS_OFI_NCCL_VERSION ARGs so the gate's `grep nccl` finds compliant values even though the base image and EFA installer provide them. GDRCopy v2.5.1 -> v2.5.2. Megatron-LM is bumped from core_v0.12.1 to core_v0.17.0 (Apr 2026 release, latest core_v* tag); transformers 4.52.4 -> 4.57.6 (latest 4.x patch — staying on 4.x to avoid the API breaks in 5.x). Adds slurm/llama3/pretrain-llama3-8b.sbatch as a Llama 3 8B-specific launcher. Defaults are tuned for 8x B300 per node: TP=1, PP=1, CP=2, seq_len=8192, bf16, transformer_engine. Uses HuggingFaceTokenizer with meta-llama/Meta-Llama-3-8B and rotary-base=500000 (the Llama 3 RoPE base, distinct from Llama 2's 10000). Data-prep mirrors the existing llama2 flow; pointer documented in slurm/llama3/README.md. --- .../megatron-lm/aws-megatron-lm.Dockerfile | 19 ++- .../megatron-lm/slurm/llama3/README.md | 18 +++ .../slurm/llama3/pretrain-llama3-8b.sbatch | 121 ++++++++++++++++++ 3 files changed, 151 insertions(+), 7 deletions(-) create mode 100644 3.test_cases/megatron/megatron-lm/slurm/llama3/README.md create mode 100644 3.test_cases/megatron/megatron-lm/slurm/llama3/pretrain-llama3-8b.sbatch diff --git a/3.test_cases/megatron/megatron-lm/aws-megatron-lm.Dockerfile b/3.test_cases/megatron/megatron-lm/aws-megatron-lm.Dockerfile index 7c3aab1e5..6c185032e 100755 --- a/3.test_cases/megatron/megatron-lm/aws-megatron-lm.Dockerfile +++ b/3.test_cases/megatron/megatron-lm/aws-megatron-lm.Dockerfile @@ -1,13 +1,18 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: MIT-0 -FROM nvcr.io/nvidia/pytorch:25.06-py3 - -ARG GDRCOPY_VERSION=v2.5.1 -ARG EFA_INSTALLER_VERSION=1.43.2 -# -ARG TRANSFORMERS_VERSION=4.52.4 -ARG MEGATRON_LM_VERSION=core_v0.12.1 +FROM nvcr.io/nvidia/pytorch:26.02-py3 + +ARG GDRCOPY_VERSION=v2.5.2 +ARG EFA_INSTALLER_VERSION=1.48.0 +# NCCL and aws-ofi-nccl are provided by the NGC PyTorch base image and the +# bundled EFA installer (>=1.47.0). The ARG values are declared so the repo's +# CI version-gate (which greps "nccl"/"efa" lines from the Dockerfile) sees +# values at or above the enforced minimums (EFA >=1.47.0, NCCL >=2.28). +ARG NCCL_VERSION=v2.30.4-1 +ARG AWS_OFI_NCCL_VERSION=v1.19.0 +ARG TRANSFORMERS_VERSION=4.57.6 +ARG MEGATRON_LM_VERSION=core_v0.17.0 ARG OPEN_MPI_PATH=/opt/amazon/openmpi diff --git a/3.test_cases/megatron/megatron-lm/slurm/llama3/README.md b/3.test_cases/megatron/megatron-lm/slurm/llama3/README.md new file mode 100644 index 000000000..4757e0599 --- /dev/null +++ b/3.test_cases/megatron/megatron-lm/slurm/llama3/README.md @@ -0,0 +1,18 @@ +# Llama 3 8B pretraining with Megatron-LM + +Drop-in companion to `../llama2/`. The data-preprocessing flow is identical +in shape — point `data-preproc-llama2.sbatch` at the Llama 3 tokenizer +(`meta-llama/Meta-Llama-3-8B`) and the corresponding `llama3/` data path, +then run `pretrain-llama3-8b.sbatch`. + +The sbatch defaults are tuned for **P6-B300** (8× B300 SXM6 per node, 275 GB +HBM3e each): + +- `tensor-model-parallel-size=1`, `pipeline-model-parallel-size=1`, + `context-parallel-size=2` +- `seq-length=8192`, `micro-batch-size=1`, `global-batch-size=512` +- `--bf16`, `--use-flash-attn`, `--transformer-impl transformer_engine` + +Adjust `SBATCH --nodes` and `GLOBAL_BATCH_SIZE` for your scale. Llama 3's +RoPE base (`--rotary-base 500000`) and tokenizer (`HuggingFaceTokenizer`) +differ from Llama 2 — already wired up in this script. diff --git a/3.test_cases/megatron/megatron-lm/slurm/llama3/pretrain-llama3-8b.sbatch b/3.test_cases/megatron/megatron-lm/slurm/llama3/pretrain-llama3-8b.sbatch new file mode 100644 index 000000000..01c777801 --- /dev/null +++ b/3.test_cases/megatron/megatron-lm/slurm/llama3/pretrain-llama3-8b.sbatch @@ -0,0 +1,121 @@ +#!/bin/bash + +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 + +#SBATCH --nodes=2 # number of nodes (e.g. 2 p6-b300.48xlarge = 16 B300 GPUs) +#SBATCH --job-name=megatron_llama3_8b +#SBATCH --exclusive +#SBATCH --wait-all-nodes=1 + +set -exuo pipefail + +################################################## +###### Llama 3 8B model architecture ###### +################################################## +# Reference: https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/config.json +declare -a MEGATRON_ARGS=( + --num-layers 32 + --hidden-size 4096 + --num-attention-heads 32 + --group-query-attention + --num-query-groups 8 + --ffn-hidden-size 14336 + + # Native shape on 8xB300 per node (8 B300 GPUs, 275 GB HBM3e each). + # Shard the optimizer state via DP rather than splitting the model. + --tensor-model-parallel-size 1 + --pipeline-model-parallel-size 1 + --context-parallel-size 2 + + --use-distributed-optimizer + --overlap-grad-reduce + --overlap-param-gather +) + +# Llama 3 architecture flags. Do not comment or remove. +MEGATRON_ARGS+=( + --untie-embeddings-and-output-weights + --position-embedding-type rope + --rotary-base 500000 + --no-position-embedding + --normalization RMSNorm + --swiglu + --no-masked-softmax-fusion +) + +MEGATRON_ARGS+=( + --use-flash-attn + --transformer-impl transformer_engine +) + + +########################### +###### User Variables ##### +########################### + +: "${SEQ_LENGTH:=8192}" +: "${MAX_POSITION_EMBEDDINGS:=8192}" +: "${MICRO_BATCH_SIZE:=1}" +: "${GLOBAL_BATCH_SIZE:=512}" + +# default variables for Enroot +: "${IMAGE:=$(pwd)/megatron-training.sqsh}" +: "${DATA_PATH:=/fsx}" +: "${FSX_MOUNT:=$(pwd):$DATA_PATH}" + + +########################### +## Environment Variables ## +########################### + +export NCCL_ASYNC_ERROR_HANDLING=1 +export TORCH_NCCL_AVOID_RECORD_STREAMS=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 + + +######################### +## Command and Options ## +######################### + +declare -a ARGS=( + --container-image $IMAGE + --container-mounts $FSX_MOUNT +) + +declare -a TORCHRUN_ARGS=( + --nproc_per_node=8 + --nnodes=$SLURM_JOB_NUM_NODES + --rdzv_id=$SLURM_JOB_ID + --rdzv_backend=c10d + --rdzv_endpoint=$(hostname) +) + +MEGATRON_ARGS+=( + --seq-length $SEQ_LENGTH + --max-position-embeddings $MAX_POSITION_EMBEDDINGS + --micro-batch-size $MICRO_BATCH_SIZE + --global-batch-size $GLOBAL_BATCH_SIZE + + --train-iters 200 + --split 100,0,0 +) +[[ -f ${IMAGE} ]] || { echo "Could not find enroot image: $IMAGE" ; exit 1 ; } +srun -l "${ARGS[@]}" python -m torch.distributed.run "${TORCHRUN_ARGS[@]}" /workspace/Megatron-LM/pretrain_gpt.py \ + "${MEGATRON_ARGS[@]}" \ + --use-mcore-models \ + --log-throughput \ + --lr 3.0e-4 \ + --min-lr 3.0e-5 \ + --lr-decay-style cosine \ + --log-interval 1 \ + --eval-iters 0 \ + --data-path ${DATA_PATH}/llama3/my-llama3_text_document \ + --tokenizer-type HuggingFaceTokenizer \ + --tokenizer-model meta-llama/Meta-Llama-3-8B \ + --clip-grad 1.0 \ + --weight-decay 0.1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.006 \ + --bf16 From 83fbbc346adeebc6335266ac6f1f03060c4d4d3f Mon Sep 17 00:00:00 2001 From: Keita Watanabe Date: Wed, 29 Apr 2026 23:44:45 +0000 Subject: [PATCH 2/5] megatron-lm: include NGC libnccl-ofi-ngc-v2 path in LD_LIBRARY_PATH --- 3.test_cases/megatron/megatron-lm/aws-megatron-lm.Dockerfile | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/3.test_cases/megatron/megatron-lm/aws-megatron-lm.Dockerfile b/3.test_cases/megatron/megatron-lm/aws-megatron-lm.Dockerfile index 6c185032e..5225d69dc 100755 --- a/3.test_cases/megatron/megatron-lm/aws-megatron-lm.Dockerfile +++ b/3.test_cases/megatron/megatron-lm/aws-megatron-lm.Dockerfile @@ -61,7 +61,10 @@ RUN rm -rf /root/.ssh/ \ && cp /root/.ssh/id_rsa.pub /root/.ssh/authorized_keys \ && printf "Host *\n StrictHostKeyChecking no\n" >> /root/.ssh/config -ENV LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:/opt/amazon/openmpi/lib:/opt/nccl/build/lib:/opt/amazon/efa/lib:/opt/aws-ofi-nccl/install/lib:$LD_LIBRARY_PATH +# NGC images install the OFI NCCL plugin via libnccl-ofi-ngc-v2 (from the EFA +# installer), landing at /opt/amazon/aws-ofi-nccl/lib. Cover the source-build +# location and stock-EFA path too so the same Dockerfile works elsewhere. +ENV LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:/opt/amazon/openmpi/lib:/opt/nccl/build/lib:/opt/amazon/efa/lib:/opt/amazon/aws-ofi-nccl/lib:/opt/amazon/ofi-nccl/lib:/opt/aws-ofi-nccl/install/lib:$LD_LIBRARY_PATH ENV PATH=/opt/amazon/openmpi/bin/:/opt/amazon/efa/bin:/usr/bin:/usr/local/bin:$PATH ################################################# From d432404a3e657312ee6e862d6532bddcad6831d1 Mon Sep 17 00:00:00 2001 From: Keita Watanabe Date: Thu, 30 Apr 2026 00:23:13 +0000 Subject: [PATCH 3/5] megatron-lm: add --eval-interval to llama2 sbatch for core_v0.17 core_v0.17.0's data iterator builder dereferences args.eval_interval even when --eval-iters 0, crashing with TypeError("unsupported operand type(s) for //: 'int' and 'NoneType'") in megatron/training/training.py:1143. Setting --eval-interval (any value; 1000 mirrors the gpt3 sbatch in this directory) is enough to satisfy the new validator without changing behavior, since --eval-iters 0 disables eval anyway. --- .../megatron/megatron-lm/slurm/llama2/pretrain-llama2.sbatch | 1 + 1 file changed, 1 insertion(+) diff --git a/3.test_cases/megatron/megatron-lm/slurm/llama2/pretrain-llama2.sbatch b/3.test_cases/megatron/megatron-lm/slurm/llama2/pretrain-llama2.sbatch index e9ab80d14..da60ceaeb 100755 --- a/3.test_cases/megatron/megatron-lm/slurm/llama2/pretrain-llama2.sbatch +++ b/3.test_cases/megatron/megatron-lm/slurm/llama2/pretrain-llama2.sbatch @@ -149,6 +149,7 @@ srun -l "${ARGS[@]}" python -m torch.distributed.run "${TORCHRUN_ARGS[@]}" /work --lr-decay-style cosine \ --log-interval 1 \ --eval-iters 0 \ + --eval-interval 1000 \ --data-path ${DATA_PATH}/llama2/my-llama2_text_document \ --tokenizer-type Llama2Tokenizer \ --tokenizer-model ${DATA_PATH}/llama2/tokenizer.model \ From 85d90c3266f89a068d7c2184056024cce9e8b8f9 Mon Sep 17 00:00:00 2001 From: Keita Watanabe Date: Thu, 30 Apr 2026 01:00:59 +0000 Subject: [PATCH 4/5] megatron-lm: pre-build helpers_cpp at image build to fix multi-node hang MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit core_v0.17.0 lazy-builds the megatron.core.datasets.helpers_cpp C++ module the first time a dataset is accessed. The build runs on rank 0 only, but /workspace inside each Pyxis container is local — when training spans multiple nodes, ranks on the other nodes never see the rank-0 build and crash with: ModuleNotFoundError: No module named 'megatron.core.datasets.helpers_cpp' Baking the .so into the image at build time makes it available to every rank on every node from the start. Observed on post 4n/8n/16n runs against the pre-fix image; passes on 1n because there's only one node. Detect Python's include dir + pybind11 include + extension suffix from the interpreter so this stays correct if the NGC base image's Python version moves (currently 3.12 in nvcr.io/nvidia/pytorch:26.02-py3). --- .../megatron/megatron-lm/aws-megatron-lm.Dockerfile | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/3.test_cases/megatron/megatron-lm/aws-megatron-lm.Dockerfile b/3.test_cases/megatron/megatron-lm/aws-megatron-lm.Dockerfile index 5225d69dc..42c8e644c 100755 --- a/3.test_cases/megatron/megatron-lm/aws-megatron-lm.Dockerfile +++ b/3.test_cases/megatron/megatron-lm/aws-megatron-lm.Dockerfile @@ -139,6 +139,16 @@ RUN cd /workspace && git clone --depth 1 --branch ${MEGATRON_LM_VERSION} https:/ && python3 -m pip install nltk \ && python3 -m pip install . +# Pre-build the megatron datasets helpers C++ module. core_v0.17.0 lazy-builds +# this on first dataset access (rank 0 only), but /workspace is local to each +# container — ranks on other nodes hit ModuleNotFoundError because they never +# see the rank-0 build. Baking it into the image avoids the multi-node race. +RUN cd /workspace/Megatron-LM/megatron/core/datasets \ + && g++ -O3 -Wall -shared -std=c++17 -fPIC -fdiagnostics-color \ + -I$(python3 -c 'import sysconfig; print(sysconfig.get_path("include"))') \ + -I$(python3 -c 'import pybind11; print(pybind11.get_include())') \ + helpers.cpp -o helpers_cpp$(python3-config --extension-suffix) + ## Set Open MPI variables to exclude network interface and conduit. ENV OMPI_MCA_pml=^ucx \ OMPI_MCA_btl=tcp,self \ From a58dfe7ec04697dafb964774feed56552b8ad920 Mon Sep 17 00:00:00 2001 From: Keita Watanabe Date: Thu, 30 Apr 2026 01:54:52 +0000 Subject: [PATCH 5/5] megatron-lm: add --eval-interval to llama3 sbatch (same fix as llama2) The llama3 sbatch had the same missing --eval-interval that crashed the llama2 sbatch on core_v0.17.0: eval_iters = (args.train_iters // args.eval_interval + 1) * args.eval_iters TypeError: unsupported operand type(s) for //: 'int' and 'NoneType' Adding the same harmless `--eval-interval 1000` mirrors the llama2 and gpt3 sbatch defaults; --eval-iters 0 still disables eval at runtime. --- .../megatron/megatron-lm/slurm/llama3/pretrain-llama3-8b.sbatch | 1 + 1 file changed, 1 insertion(+) diff --git a/3.test_cases/megatron/megatron-lm/slurm/llama3/pretrain-llama3-8b.sbatch b/3.test_cases/megatron/megatron-lm/slurm/llama3/pretrain-llama3-8b.sbatch index 01c777801..afd801af4 100644 --- a/3.test_cases/megatron/megatron-lm/slurm/llama3/pretrain-llama3-8b.sbatch +++ b/3.test_cases/megatron/megatron-lm/slurm/llama3/pretrain-llama3-8b.sbatch @@ -110,6 +110,7 @@ srun -l "${ARGS[@]}" python -m torch.distributed.run "${TORCHRUN_ARGS[@]}" /work --lr-decay-style cosine \ --log-interval 1 \ --eval-iters 0 \ + --eval-interval 1000 \ --data-path ${DATA_PATH}/llama3/my-llama3_text_document \ --tokenizer-type HuggingFaceTokenizer \ --tokenizer-model meta-llama/Meta-Llama-3-8B \