diff --git a/docker/Dockerfile.rocm_MI350-5 b/docker/Dockerfile.rocm_MI350-5 index 016107b92c..bd34fd8ced 100644 --- a/docker/Dockerfile.rocm_MI350-5 +++ b/docker/Dockerfile.rocm_MI350-5 @@ -1,169 +1,156 @@ -#### Use the base image for ROCm 7 / gfx950 (MI355) - -# ===================================================================== -# Docker Image Version Information (Updated: Feb 5, 2026) -# ===================================================================== -# Base image: ROCm 7 with vllm pre-built for gfx950 -# Target GPU: MI355 (gfx950) -# -# Key Dependencies: -# - sglang: sglang-miles branch -# - sgl_kernel: built from selected sglang commit -# - Megatron-LM: radixark/Megatron-LM -# - TransformerEngine: commit 90c04bcdc3c109505b318f40a39680263af55edf -# - aiter: v0.1.10.post3 -# - Ray: 2.47.1 -# -# Patches: amd_patch/sglv0.5.7/ -# - megatron.patch -# - sglang.patch -# ===================================================================== - - -FROM rocm/sgl-dev:rocm7-vllm-20250904 +# 1. rlsys/miles:MI350-355-latest +# build-arg:SGLANG_IMAGE_TAG=v0.5.10-rocm720-mi35x + +ARG SGLANG_IMAGE_TAG=v0.5.10-rocm720-mi35x +FROM lmsysorg/sglang:${SGLANG_IMAGE_TAG} AS sglang SHELL ["/bin/bash", "-ceuxo", "pipefail"] -ARG MAX_JOBS=128 -ARG SGLANG_REPO=sgl-project/sglang +# ======================================== Arguments ============================================= + ARG SGLANG_BRANCH=sglang-miles ARG SGLANG_COMMIT="" + ARG MEGATRON_REPO=radixark/Megatron-LM ARG MEGATRON_BRANCH=miles-main -ARG MEGATRON_COMMIT="" -ENV MAX_JOBS=${MAX_JOBS} -# Set environment variables for gfx950 -ENV GPU_ARCH=gfx950 -ENV PYTORCH_ROCM_ARCH=gfx950 -ENV GPU_ARCH_LIST=gfx950 -ENV AMDGPU_TARGET=gfx950 +ARG MILES_COMMIT=main +ARG GPU_ARCH=gfx950 +ARG MAX_JOBS=128 -########################################### -##############1. Install AITER############# -########################################### -WORKDIR /app +ARG AITER_REPO=https://github.com/ROCm/aiter.git +ARG AITER_COMMIT=v0.1.11.post1 -RUN pip uninstall -y aiter || true -RUN rm -rf aiter -RUN git clone https://github.com/ROCm/aiter.git \ - && cd aiter \ - && git checkout v0.1.10.post3 \ - && curl -fsSL https://patch-diff.githubusercontent.com/raw/ROCm/aiter/pull/2075.patch -o /tmp/aiter-pr2075.patch \ - && git apply --3way /tmp/aiter-pr2075.patch \ - && rm -f /tmp/aiter-pr2075.patch \ - && git submodule sync --recursive \ - && git submodule update --init --recursive \ - && GPU_ARCHS=gfx950 python setup.py develop -########################################### -########################################### -########################################### - - -########################################### -####2. Install TransformerEngine for gfx950 -########################################### -WORKDIR /app - -RUN rm -rf TransformerEngine -RUN git clone https://github.com/ROCm/TransformerEngine.git \ - && cd TransformerEngine \ - && git checkout 90c04bcdc3c109505b318f40a39680263af55edf \ - && git submodule update --init --recursive +ARG RCCL_TESTS_REPO=https://github.com/ROCm/rocm-systems.git +ARG RCCL_TESTS_BRANCH=develop +ARG RCCL_TESTS_PATH=projects/rccl-tests + +ARG TRANSFORMER_ENGINE_REPO=https://github.com/ROCm/TransformerEngine.git +ARG TRANSFORMER_ENGINE_BRANCH=v2.8_rocm + +# ======================================== Setup ============================================= +WORKDIR /root/ + +ENV MAX_JOBS=${MAX_JOBS} + +# Build configuration for MI350 / gfx950. +ENV GPU_ARCH=${GPU_ARCH} +ENV PYTORCH_ROCM_ARCH=${GPU_ARCH} +ENV GPU_ARCH_LIST=${GPU_ARCH} +ENV AMDGPU_TARGET=${GPU_ARCH} + +# Transformer Engine build knobs for the v2.8_rocm branch. ENV NVTE_FRAMEWORK=pytorch -ENV NVTE_ROCM_ARCH=gfx950 +ENV NVTE_ROCM_ARCH=${GPU_ARCH} ENV NVTE_USE_HIPBLASLT=1 ENV NVTE_USE_ROCM=1 -ENV CMAKE_PREFIX_PATH="/opt/rocm:/opt/rocm/hip:/usr/local:/usr" - -RUN cd TransformerEngine && pip install . -v -########################################### -########################################### -########################################### - - -######################################### -####3. Install Megatron-LM -######################################### -WORKDIR /app - -RUN pip install "numpy>=1.21.0,<2.0" --force-reinstall - -RUN pip uninstall -y megatron-core || true -RUN rm -rf Megatron-LM -RUN git clone https://github.com/${MEGATRON_REPO}.git \ - && cd Megatron-LM \ - && git fetch origin ${MEGATRON_BRANCH} \ - && if [ -n "${MEGATRON_COMMIT}" ]; then \ - git checkout ${MEGATRON_COMMIT}; \ - else \ - git checkout FETCH_HEAD; \ - fi \ - && pip install -e . -######################################### -######################################### -######################################### - - -######################################## -############ 4. Install mbridge######### -######################################## -RUN pip install git+https://github.com/ISEEKYAN/mbridge.git --no-deps -######################################## -######################################## -######################################## - - -######################################## -######5. Install Ray#################### -######################################## -RUN pip uninstall ray -y || true -RUN pip install "ray[data,train,tune,serve]==2.47.1" -######################################## -######################################## -######################################## - - -######################################### -###6. Install torch_memory_saver######### -######################################### -RUN pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@64a92e1d7fb822ea4af5579c8cebb162692c531c --no-cache-dir --force-reinstall -######################################### -######################################### - - -####################################### -####7. Install Apex for ROCm########### -####################################### -WORKDIR /app - -RUN pip uninstall -y apex || true -RUN rm -rf apex -RUN git clone https://github.com/ROCm/apex.git \ - && cd apex \ - && python setup.py install -####################################### -####################################### -####################################### - - -######################################## -###8. Install miles agent framework deps -######################################## -RUN pip install pydra_config==0.0.15 -RUN pip install together -RUN pip install google-generativeai -RUN pip install tensorboard -######################################## -######################################## -######################################## - - -######################################## -###9. Set performance environment vars## -######################################## +# Keep the core package enabled and skip the extra fused-attn kernel matrix rebuild. +ENV NVTE_FUSED_ATTN=0 +ENV CMAKE_PREFIX_PATH=/opt/rocm:/opt/rocm/hip:/usr/local:/usr + +# Patch Megatron's fused-kernel init for this toolchain. +COPY docker/amd_patch/latest/megatron.patch /tmp/amd_patch/megatron.patch +COPY requirements.txt /tmp/requirements.txt + +# ======================================== Apt dependencies ============================================= + +RUN apt update +# Install build tools and diagnostics utilities. +RUN apt install -y build-essential cmake dnsutils ethtool git nvtop rsync + +# Build rccl-tests diagnostics binaries. +RUN git clone --depth 1 --branch ${RCCL_TESTS_BRANCH} ${RCCL_TESTS_REPO} /tmp/rocm-systems && \ + make -C /tmp/rocm-systems/${RCCL_TESTS_PATH} -j$(nproc) \ + HIP_HOME=/opt/rocm \ + NCCL_HOME=/opt/rocm \ + GPU_TARGETS=${GPU_ARCH} && \ + cp /tmp/rocm-systems/${RCCL_TESTS_PATH}/build/*_perf /usr/local/bin/ && \ + rm -rf /tmp/rocm-systems + +# ====================================== Python dependencies ============================================ + +# Rebuild AITER at the version paired with SGLang. +RUN pip uninstall -y aiter || true +RUN pip install flydsl==0.0.1.dev95158637 psutil pybind11 +RUN cd /sgl-workspace/aiter && \ + git remote set-url origin ${AITER_REPO} && \ + git checkout ${AITER_COMMIT} && \ + git reset --hard ${AITER_COMMIT} && \ + git clean -fdx && \ + git submodule sync --recursive && \ + git submodule update --init --recursive && \ + # Temporary fixes for the current ROCm 7.2 image/toolchain combination. + sed -i '459 s/if.*:/if False:/' aiter/ops/triton/attention/pa_mqa_logits.py && \ + sed -i '/c1 = torch.empty((M, D, S1 + S3), dtype=dtype, device=x.device)/i\ config = dict(config)' \ + aiter/ops/triton/gemm/fused/fused_gemm_afp4wfp4_split_cat.py && \ + GPU_ARCHS=${GPU_ARCH} pip install -e . + +# Install Transformer Engine from the requested branch. +RUN pip uninstall -y transformer-engine transformer_engine transformer_engine_torch || true +RUN rm -rf /root/TransformerEngine && \ + git clone --recursive --branch ${TRANSFORMER_ENGINE_BRANCH} ${TRANSFORMER_ENGINE_REPO} /root/TransformerEngine && \ + cd /root/TransformerEngine && \ + pip install . --no-build-isolation -v + +RUN pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps + +RUN GPU_ARCHS=${GPU_ARCH} BUILD_TARGET=rocm MAX_JOBS=${MAX_JOBS} \ + pip -v install flash-attn==2.8.3 --no-build-isolation + +RUN pip install flash-linear-attention==0.4.2 + +RUN rm -rf /root/Megatron-LM && \ + git clone --recursive -b ${MEGATRON_BRANCH} https://github.com/${MEGATRON_REPO}.git /root/Megatron-LM && \ + cd /root/Megatron-LM && \ + git apply /tmp/amd_patch/megatron.patch && \ + pip install -e . + +RUN pip uninstall -y sgl_kernel sglang || true +RUN cd /sgl-workspace/sglang && \ + git reset --hard && \ + git clean -fdx && \ + git fetch origin ${SGLANG_BRANCH} && \ + if [ -n "${SGLANG_COMMIT}" ]; then \ + git checkout ${SGLANG_COMMIT}; \ + else \ + git checkout FETCH_HEAD; \ + fi && \ + git submodule sync --recursive && \ + git submodule update --init --recursive && \ + cd sgl-kernel && \ + rm -f pyproject.toml && \ + mv pyproject_rocm.toml pyproject.toml && \ + AMDGPU_TARGET=${GPU_ARCH} python setup_rocm.py install && \ + cd .. && \ + rm -rf python/pyproject.toml && \ + mv python/pyproject_other.toml python/pyproject.toml && \ + pip install -e "python[all_hip]" --no-deps + +RUN python -c "import sglang; import sgl_kernel; print('SGLang + sgl_kernel: OK')" + +RUN pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@d64a639 --no-cache-dir --force-reinstall +RUN pip install git+https://github.com/yushengsu-thu/Megatron-Bridge.git@merged-megatron-0.16.0rc0-miles --no-deps --no-build-isolation +RUN pip install megatron-energon --no-deps +RUN pip install multi-storage-client --no-deps + +RUN rm -rf /usr/lib/python3/dist-packages/jwt /usr/lib/python3/dist-packages/PyJWT* && \ + pip install -r /tmp/requirements.txt + +# Pin numpy 1.x for Megatron compatibility. +RUN pip install "numpy<2" + +# ====================================== Install main package ============================================ + +RUN git clone https://github.com/radixark/miles.git /root/miles && \ + cd /root/miles && \ + git checkout ${MILES_COMMIT} && \ + pip install -e . --no-deps + +# ====================================== Runtime knobs ============================================ + +# Runtime knobs consumed by the current SGLang/PyTorch stack. ENV HIP_FORCE_DEV_KERNARG=1 ENV HSA_NO_SCRATCH_RECLAIM=1 ENV SGLANG_USE_AITER=1 @@ -173,114 +160,11 @@ ENV SGLANG_SET_CPU_AFFINITY=1 ENV SGLANG_ROCM_FUSED_DECODE_MLA=1 ENV SGLANG_USE_ROCM700A=1 ENV NCCL_MIN_NCHANNELS=112 -ENV VLLM_FP8_PADDING=1 -ENV VLLM_FP8_ACT_PADDING=1 -ENV VLLM_FP8_WEIGHT_PADDING=1 -ENV VLLM_FP8_REDUCE_CONV=1 ENV TORCHINDUCTOR_MAX_AUTOTUNE=1 ENV TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1 -######################################## -######################################## -######################################## - -########################################### -##############Install SGLang############### -########################################### -WORKDIR /app - -# Install prerequisites -RUN pip install IPython orjson python-multipart torchao==0.9.0 pybind11 - -# Clone SGLang -RUN pip uninstall -y sgl_kernel sglang || true -RUN rm -rf sglang -RUN git clone https://github.com/${SGLANG_REPO}.git \ - && cd sglang \ - && git fetch origin ${SGLANG_BRANCH} \ - && if [ -n "${SGLANG_COMMIT}" ]; then \ - git checkout ${SGLANG_COMMIT}; \ - else \ - git checkout FETCH_HEAD; \ - fi - -# Build sgl-kernel for gfx950 -RUN cd sglang/sgl-kernel \ - && rm -f pyproject.toml \ - && mv pyproject_rocm.toml pyproject.toml \ - && AMDGPU_TARGET=gfx950 python setup_rocm.py install - -# Install SGLang -RUN cd sglang \ - && rm -rf python/pyproject.toml \ - && mv python/pyproject_other.toml python/pyproject.toml \ - && pip install -e "python[all_hip]" - -# Test SGLang installation -RUN python -c "import sglang; import sgl_kernel; print('SGLang + sgl_kernel: OK')" +RUN rm -rf /root/.cache/pip /root/TransformerEngine /tmp/amd_patch -RUN python -m pip cache purge -########################################### -########################################### -########################################### - - -########################################### -#### APPLY PATCHES (gfx950/MI355) ######### -########################################### - -# Copy patch from miles repo -COPY amd_patch/sglv0.5.7/megatron.patch /app/patch/megatron.patch -COPY amd_patch/sglv0.5.7/sglang.patch /app/patch/sglang.patch - -# Apply Megatron patches -RUN cd /app/Megatron-LM \ - && git apply --3way /app/patch/megatron.patch \ - && if grep -R -n '^<<<<<<< ' .; then \ - echo "Patch failed to apply cleanly. Please resolve conflicts." && \ - exit 1; \ - fi \ - && pip install -e . -v - -# Apply SGLang patch -RUN cd /app/sglang \ - && git apply --3way /app/patch/sglang.patch \ - && if grep -R -n '^<<<<<<< ' .; then \ - echo "SGLang patch failed to apply cleanly. Please resolve conflicts." && \ - exit 1; \ - fi - -# Copy MOE configs for gfx950/MI355 -RUN find /app/sglang/python/sglang/srt/layers/quantization/configs/ \ - /app/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \ - -type f -name '*MI300X*' 2>/dev/null | while read f; do \ - cp "$f" "$(echo $f | sed 's/MI300X/MI300X_VF/')" 2>/dev/null || true; \ - cp "$f" "$(echo $f | sed 's/MI300X/MI355/')" 2>/dev/null || true; \ -done - -########################################### -########################################### -########################################### - - -######################################## -#### Install additional packages######## -######################################## -RUN pip install sglang-router --force-reinstall -######################################## -######################################## -######################################## - - -######################################## -# Fix click/ray incompatibility with Python 3.10 -######################################## -RUN pip install click==8.2.1 -######################################## -######################################## -######################################## - - -WORKDIR /app +WORKDIR /root/ CMD ["/usr/bin/bash"] diff --git a/docker/amd_patch/latest/megatron.patch b/docker/amd_patch/latest/megatron.patch index f6efca346d..acd64149b7 100644 --- a/docker/amd_patch/latest/megatron.patch +++ b/docker/amd_patch/latest/megatron.patch @@ -1,5 +1,4 @@ diff --git a/megatron/legacy/fused_kernels/__init__.py b/megatron/legacy/fused_kernels/__init__.py -index 87cceac3..ac686d74 100644 --- a/megatron/legacy/fused_kernels/__init__.py +++ b/megatron/legacy/fused_kernels/__init__.py @@ -3,6 +3,7 @@ @@ -10,42 +9,12 @@ index 87cceac3..ac686d74 100644 from torch.utils import cpp_extension -@@ -15,23 +16,23 @@ os.environ["TORCH_CUDA_ARCH_LIST"] = "" +@@ -15,6 +16,8 @@ def load(args): -- -- # Check if cuda 11 is installed for compute capability 8.0 -- cc_flag = [] -- _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( -- cpp_extension.CUDA_HOME -- ) -- if int(bare_metal_major) >= 11: -- cc_flag.append('-gencode') -- cc_flag.append('arch=compute_80,code=sm_80') -- if int(bare_metal_minor) >= 8: -+ if torch.cuda.is_available() and torch.version.cuda: -+ # Check if cuda 11 is installed for compute capability 8.0 -+ cc_flag = [] -+ _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( -+ cpp_extension.CUDA_HOME -+ ) -+ if int(bare_metal_major) >= 11: - cc_flag.append('-gencode') -- cc_flag.append('arch=compute_90,code=sm_90') -+ cc_flag.append('arch=compute_80,code=sm_80') -+ if int(bare_metal_minor) >= 8: -+ cc_flag.append('-gencode') -+ cc_flag.append('arch=compute_90,code=sm_90') ++ if not torch.version.cuda: ++ return -- # Build path -- srcpath = pathlib.Path(__file__).parent.absolute() -- buildpath = srcpath / "build" -- _create_build_dir(buildpath) -+ # Build path -+ srcpath = pathlib.Path(__file__).parent.absolute() -+ buildpath = srcpath / "build" -+ _create_build_dir(buildpath) - - # Helper function to build the kernels. - def _cpp_extention_load_helper(name, sources, extra_cuda_flags): + # Check if cuda 11 is installed for compute capability 8.0 + cc_flag = [] diff --git a/docker/amd_patch/latest/sglang.patch b/docker/amd_patch/latest/sglang.patch deleted file mode 100644 index b103263070..0000000000 --- a/docker/amd_patch/latest/sglang.patch +++ /dev/null @@ -1,38 +0,0 @@ -diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py -index 6e7ea07e7..73b512f51 100644 ---- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py -+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py -@@ -64,6 +64,7 @@ class CustomAllreduce: - group: ProcessGroup, - device: Union[int, str, torch.device], - max_size=_MAX_CAR_SIZE, -+ enable_register_for_capturing: bool = True, - ) -> None: - """ - Args: -@@ -410,6 +411,8 @@ class CustomAllreduce: - if self._IS_CAPTURING: - if torch.cuda.is_current_stream_capturing(): - if _is_hip: -+ if self.tms_cudagraph: -+ return self.all_reduce_unreg(input) - return self.all_reduce_reg(input) - else: - return self.all_reduce(input, registered=not self.tms_cudagraph) -diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py -index c3ca1e4f3..2bb763b6a 100644 ---- a/python/sglang/srt/distributed/parallel_state.py -+++ b/python/sglang/srt/distributed/parallel_state.py -@@ -351,10 +351,12 @@ class GroupCoordinator: - if use_custom_allreduce and self.world_size > 1: - # Initialize a custom fast all-reduce implementation. - try: -+ tms_cudagraph = envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get() - CAClass = dispatch_custom_allreduce() - self.ca_comm = CAClass( - group=self.cpu_group, - device=self.device, -+ enable_register_for_capturing=not tms_cudagraph, - ) - except Exception as e: - logger.warning( diff --git a/docker/amd_patch/sglv0.5.0rc0/amd_megatron_fused_kernels_init.patch b/docker/amd_patch/sglv0.5.0rc0/amd_megatron_fused_kernels_init.patch deleted file mode 100644 index f6efca346d..0000000000 --- a/docker/amd_patch/sglv0.5.0rc0/amd_megatron_fused_kernels_init.patch +++ /dev/null @@ -1,51 +0,0 @@ -diff --git a/megatron/legacy/fused_kernels/__init__.py b/megatron/legacy/fused_kernels/__init__.py -index 87cceac3..ac686d74 100644 ---- a/megatron/legacy/fused_kernels/__init__.py -+++ b/megatron/legacy/fused_kernels/__init__.py -@@ -3,6 +3,7 @@ - import os - import pathlib - import subprocess -+import torch - - from torch.utils import cpp_extension - -@@ -15,23 +16,23 @@ os.environ["TORCH_CUDA_ARCH_LIST"] = "" - - - def load(args): -- -- # Check if cuda 11 is installed for compute capability 8.0 -- cc_flag = [] -- _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( -- cpp_extension.CUDA_HOME -- ) -- if int(bare_metal_major) >= 11: -- cc_flag.append('-gencode') -- cc_flag.append('arch=compute_80,code=sm_80') -- if int(bare_metal_minor) >= 8: -+ if torch.cuda.is_available() and torch.version.cuda: -+ # Check if cuda 11 is installed for compute capability 8.0 -+ cc_flag = [] -+ _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( -+ cpp_extension.CUDA_HOME -+ ) -+ if int(bare_metal_major) >= 11: - cc_flag.append('-gencode') -- cc_flag.append('arch=compute_90,code=sm_90') -+ cc_flag.append('arch=compute_80,code=sm_80') -+ if int(bare_metal_minor) >= 8: -+ cc_flag.append('-gencode') -+ cc_flag.append('arch=compute_90,code=sm_90') - -- # Build path -- srcpath = pathlib.Path(__file__).parent.absolute() -- buildpath = srcpath / "build" -- _create_build_dir(buildpath) -+ # Build path -+ srcpath = pathlib.Path(__file__).parent.absolute() -+ buildpath = srcpath / "build" -+ _create_build_dir(buildpath) - - # Helper function to build the kernels. - def _cpp_extention_load_helper(name, sources, extra_cuda_flags): diff --git a/docker/amd_patch/sglv0.5.0rc0/megatron.patch b/docker/amd_patch/sglv0.5.0rc0/megatron.patch deleted file mode 100644 index b129959aff..0000000000 --- a/docker/amd_patch/sglv0.5.0rc0/megatron.patch +++ /dev/null @@ -1,792 +0,0 @@ -diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py -index 41c21d93d..ef80f72d6 100644 ---- a/megatron/core/dist_checkpointing/strategies/common.py -+++ b/megatron/core/dist_checkpointing/strategies/common.py -@@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy): - msc = MultiStorageClientFeature.import_package() - return msc.torch.load(load_path, map_location='cpu') - else: -- return torch.load(load_path, map_location='cpu') -+ return torch.load(load_path, map_location='cpu', weights_only=False) - except FileNotFoundError as e: - err_msg = f'Common file {load_path} does not exist' - if MultiStorageClientFeature.is_enabled(): -diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py -index 5a1ea308d..aa701237f 100644 ---- a/megatron/core/dist_checkpointing/strategies/torch.py -+++ b/megatron/core/dist_checkpointing/strategies/torch.py -@@ -597,10 +597,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): - def _validate_global_shapes(self, metadata, sharded_tensors): - for sh_ten in sharded_tensors: - if sh_ten.key not in metadata.state_dict_metadata: -- raise KeyError( -- f"{sh_ten.key} from model not in state dict:" -- f" {sorted(metadata.state_dict_metadata.keys())}" -- ) -+ # raise KeyError( -+ # f"{sh_ten.key} from model not in state dict:" -+ # f" {sorted(metadata.state_dict_metadata.keys())}" -+ # ) -+ print(f"{sh_ten.key} from model not in state dict, will skip") -+ continue - loaded_shape = metadata.state_dict_metadata[sh_ten.key].size - expected_shape = self._expected_shape(sh_ten) - if loaded_shape != expected_shape: -@@ -630,7 +632,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): - tensor_metadata = self.metadata.state_dict_metadata - metadata_with_sizes = [ - (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) -- for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() -+ for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata - ] - try: - # Temporarily set sizes to expected shapes -@@ -959,6 +961,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): - planner=MCoreLoadPlanner( - shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, - allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, -+ allow_partial_load=True, - ), - ) - -diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py -index fe26e8b43..4451f2776 100644 ---- a/megatron/core/distributed/__init__.py -+++ b/megatron/core/distributed/__init__.py -@@ -11,3 +11,15 @@ from .finalize_model_grads import finalize_model_grads - from .fsdp.mcore_fsdp_adapter import FullyShardedDataParallel - from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel - from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig -+ -+# Backward compatibility patch for FSDP module reorganization -+import sys -+import importlib.util -+ -+spec = importlib.util.find_spec('megatron.core.distributed.fsdp.src.megatron_fsdp') -+if spec: -+ custom_fsdp = importlib.util.module_from_spec(spec) -+ spec.loader.exec_module(custom_fsdp) -+ sys.modules['megatron.core.distributed.custom_fsdp'] = custom_fsdp -+ if hasattr(custom_fsdp, 'MegatronFSDP'): -+ custom_fsdp.FullyShardedDataParallel = custom_fsdp.MegatronFSDP -diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py -index acb93ef78..d239db4ab 100644 ---- a/megatron/core/extensions/transformer_engine.py -+++ b/megatron/core/extensions/transformer_engine.py -@@ -408,6 +408,7 @@ class TELinear(te.pytorch.Linear): - ) - - for param in self.parameters(): -+ setattr(param, "parallel_mode", parallel_mode) - if is_expert: - # Reduce the gradient on the expert_data_parallel group for expert linear layers - setattr(param, "allreduce", not self.expert_parallel) -@@ -1161,6 +1162,61 @@ class TEDotProductAttention(te.pytorch.DotProductAttention): - - - if HAVE_TE and is_te_min_version("1.9.0.dev0"): -+ def ceil_div(x: int, y: int) -> int: -+ return (x + y - 1) // y -+ -+ class _FakeInt4QuantizationSTE(torch.autograd.Function): -+ @staticmethod -+ def forward(ctx, x, group_size): -+ m, n = x.shape -+ block_size_m, block_size_n = 1, group_size -+ -+ -+ m_padded = ceil_div(m, block_size_m) * block_size_m -+ n_padded = ceil_div(n, block_size_n) * block_size_n -+ -+ x_padded = torch.zeros( -+ (m_padded, n_padded), -+ dtype=x.dtype, device=x.device -+ ) -+ x_padded[:m, :n] = x -+ -+ x_view = x_padded.view( -+ m_padded // block_size_m, -+ block_size_m, -+ n_padded // block_size_n, -+ block_size_n -+ ) -+ -+ x_max = x_view.abs().float().amax(dim=(1, 3), keepdim=True) -+ q_max = 7 -+ x_scale = x_max / q_max -+ -+ x_scale = x_scale.clamp(min=1e-5) -+ -+ x_div = x_view / x_scale -+ x_round = torch.round(x_div) -+ -+ x_q_clamped = x_round.clamp(-q_max, q_max) -+ -+ x_dequant_view = x_q_clamped * x_scale -+ -+ x_dequant_full = x_dequant_view.view_as(x_padded) -+ x_out = x_dequant_full[:m, :n].contiguous().to(x.dtype) -+ -+ return x_out -+ -+ @staticmethod -+ def backward(ctx, grad_output): -+ return grad_output, None -+ -+ def fake_int4_quantization_ste(x, group_size): -+ x_out = _FakeInt4QuantizationSTE.apply(x, group_size) -+ -+ if hasattr(x, 'main_grad'): -+ x_out.main_grad = x.main_grad -+ -+ return x_out - - class TEGroupedLinear(te.pytorch.GroupedLinear): - """ -@@ -1351,6 +1407,7 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): - _is_first_microbatch = ( - None if self.disable_parameter_transpose_cache else self.is_first_microbatch - ) -+ - out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch) - self.is_first_microbatch = False - -@@ -1361,6 +1418,20 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): - return out - return out, None - -+ def _get_weight_tensors(self): -+ """Get the weight tensors of the module.""" -+ weight_tensors = super()._get_weight_tensors() -+ -+ if os.getenv("OPEN_TRAINING_INT4_FAKE_QAT_FLAG", "0") == "1": -+ group_size = int(os.getenv("OPEN_TRAINING_INT4_GROUP_SIZE", "128")) -+ -+ weight_tensors = [ -+ fake_int4_quantization_ste(w, group_size) -+ for w in weight_tensors -+ ] -+ -+ return weight_tensors -+ - def _encode_extra_state(self, state): - # TE 2.0 changed the format of extra_state to be a byte tensor - if is_te_min_version("2.0.0"): -diff --git a/megatron/core/fusions/fused_mla_yarn_rope_apply.py b/megatron/core/fusions/fused_mla_yarn_rope_apply.py -index 1fd5dcfae..c9aeef1f0 100644 ---- a/megatron/core/fusions/fused_mla_yarn_rope_apply.py -+++ b/megatron/core/fusions/fused_mla_yarn_rope_apply.py -@@ -385,6 +385,7 @@ def rotary_fwd_kv_kernel( - SIN, - emb_dim: tl.constexpr, - k_dim: tl.constexpr, -+ k_dim_ceil: tl.constexpr, - v_dim: tl.constexpr, - head_num: tl.constexpr, - batch_size, -@@ -434,21 +435,27 @@ def rotary_fwd_kv_kernel( - cos_right = tl.load(COS + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) - sin_right = tl.load(SIN + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) - -- KV_ptr = KV + pid_m * stride_kv_seq + pid_head * BLOCK_H * stride_kv_nheads -- kv_off = tl.arange(0, BLOCK_H)[:, None] * stride_kv_nheads -- mask = kv_off < head_num * stride_kv_nheads -- k_in_off = kv_off + tl.arange(0, k_dim)[None, :] -- v_in_off = kv_off + k_dim + tl.arange(0, v_dim)[None, :] -- k = tl.load(KV_ptr + k_in_off, mask=mask) -- v = tl.load(KV_ptr + v_in_off, mask=mask) -+ KV_ptr = KV + pid_m * stride_kv_seq # + pid_head * BLOCK_H * stride_kv_nheads -+ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H -+ kj_range = tl.arange(0, k_dim_ceil)[None, :] -+ mask_k = (ki_range < head_num) & (kj_range < k_dim) -+ mask_v = ki_range < head_num -+ k_off = ki_range * stride_kv_nheads + kj_range -+ if v_dim > 0: -+ v_off = ki_range * stride_kv_nheads + k_dim + tl.arange(0, v_dim)[None, :] -+ v = tl.load(KV_ptr + v_off, mask=mask_v) -+ else: -+ v = tl.zeros((BLOCK_H, 1), dtype=KV.dtype.element_ty) -+ k = tl.load(KV_ptr + k_off, mask=mask_k) - -- K_ptr = O_KEY + pid_m * stride_k_seq + pid_head * BLOCK_H * stride_k_nheads -- V_ptr = O_VALUE + pid_m * stride_v_seq + pid_head * BLOCK_H * stride_v_nheads -+ K_ptr = O_KEY + pid_m * stride_k_seq # + pid_head * BLOCK_H * stride_k_nheads -+ V_ptr = O_VALUE + pid_m * stride_v_seq # + pid_head * BLOCK_H * stride_v_nheads - -- k_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + tl.arange(0, k_dim)[None, :] -- v_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_v_nheads + tl.arange(0, v_dim)[None, :] -- tl.store(K_ptr + k_out_off, k, mask=mask) -- tl.store(V_ptr + v_out_off, v, mask=mask) -+ k_out_off = ki_range * stride_k_nheads + kj_range -+ tl.store(K_ptr + k_out_off, k, mask=mask_k) -+ if v_dim > 0: -+ v_out_off = ki_range * stride_v_nheads + tl.arange(0, v_dim)[None, :] -+ tl.store(V_ptr + v_out_off, v, mask=mask_v) - - EMB = K_POS_EMB + pid_m * stride_emb_seq - # x1 = t[..., 0::2], x2 = t[..., 1::2] -@@ -460,14 +467,16 @@ def rotary_fwd_kv_kernel( - x_left = x_left.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) - x_right = x_right.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) - -+ x_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H -+ mask_x = x_range < head_num - x_left_off = ( -- tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads -+ x_range * stride_k_nheads - + k_dim - + tl.arange(0, emb_dim // 2)[None, :] - ) - x_right_off = x_left_off + emb_dim // 2 -- tl.store(K_ptr + x_left_off, x_left, mask=mask) -- tl.store(K_ptr + x_right_off, x_right, mask=mask) -+ tl.store(K_ptr + x_left_off, x_left, mask=mask_x) -+ tl.store(K_ptr + x_right_off, x_right, mask=mask_x) - - - @triton.autotune( -@@ -493,6 +502,7 @@ def rotary_bwd_kv_kernel( - SIN, - emb_dim: tl.constexpr, - k_dim: tl.constexpr, -+ k_dim_ceil: tl.constexpr, - v_dim: tl.constexpr, - head_num: tl.constexpr, - batch_size, -@@ -533,27 +543,32 @@ def rotary_bwd_kv_kernel( - else: - token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size) - -- dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads -- dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads -- mask = dkv_off < head_num * stride_dkv_nheads -- dk_out_off = dkv_off + tl.arange(0, k_dim)[None, :] -- dv_out_off = dkv_off + k_dim + tl.arange(0, v_dim)[None, :] -- -- dK_ptr = dK + pid_m * stride_dk_seq + pid_head * BLOCK_H * stride_dk_nheads -- dV_ptr = dV + pid_m * stride_dv_seq + pid_head * BLOCK_H * stride_dv_nheads -- dk_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + tl.arange(0, k_dim)[None, :] -- dv_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dv_nheads + tl.arange(0, v_dim)[None, :] -- dk = tl.load(dK_ptr + dk_in_off, mask=mask) -- dv = tl.load(dV_ptr + dv_in_off, mask=mask) -- tl.store(dKV_ptr + dk_out_off, dk, mask=mask) -- tl.store(dKV_ptr + dv_out_off, dv, mask=mask) -+ dKV_ptr = dKV + pid_m * stride_dkv_seq # + pid_head * BLOCK_H * stride_dkv_nheads -+ ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H -+ kj_range = tl.arange(0, k_dim_ceil)[None, :] -+ mask_k = (ki_range < head_num) & (kj_range < k_dim) -+ mask_v = ki_range < head_num -+ dk_out_off = ki_range * stride_dkv_nheads + kj_range -+ -+ dK_ptr = dK + pid_m * stride_dk_seq # + pid_head * BLOCK_H * stride_dk_nheads -+ dV_ptr = dV + pid_m * stride_dv_seq # + pid_head * BLOCK_H * stride_dv_nheads -+ dk_in_off = ki_range * stride_dk_nheads + kj_range -+ -+ dk = tl.load(dK_ptr + dk_in_off, mask=mask_k) -+ tl.store(dKV_ptr + dk_out_off, dk, mask=mask_k) -+ -+ if v_dim > 0: -+ dv_out_off = ki_range * stride_dkv_nheads + k_dim + tl.arange(0, v_dim)[None, :] -+ dv_in_off = ki_range * stride_dv_nheads + tl.arange(0, v_dim)[None, :] -+ dv = tl.load(dV_ptr + dv_in_off, mask=mask_v) -+ tl.store(dKV_ptr + dv_out_off, dv, mask=mask_v) - - if pid_head == 0: - x_left_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) - x_right_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) - for i in tl.static_range(triton.cdiv(head_num, BLOCK_H)): -- dK_ptr = dK + pid_m * stride_dk_seq + i * BLOCK_H * stride_dk_nheads -- x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim -+ dK_ptr = dK + pid_m * stride_dk_seq # + i * BLOCK_H * stride_dk_nheads -+ x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + i * BLOCK_H * stride_dk_nheads - mask = x_off < head_num * stride_dk_nheads - x_left_off = x_off + tl.arange(0, emb_dim // 2)[None, :] - x_right_off = x_left_off + emb_dim // 2 -@@ -632,6 +647,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): - - o_key = kv.new_empty(total_seqlen, nheads, emb_dim + k_dim) - o_value = kv.new_empty(total_seqlen, nheads, v_dim) -+ k_dim_ceil = triton.next_power_of_2(k_dim) - - grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) - rotary_fwd_kv_kernel[grid]( -@@ -643,6 +659,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): - sin, - emb_dim, - k_dim, -+ k_dim_ceil, - v_dim, - nheads, - batch_size, -@@ -700,6 +717,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): - - d_kv = dk.new_empty(total_seqlen, nheads, ctx.k_dim + ctx.v_dim) - d_emb = dk.new_empty(total_seqlen, 1, ctx.emb_dim) -+ k_dim_ceil = triton.next_power_of_2(ctx.k_dim) - - grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) - rotary_bwd_kv_kernel[grid]( -@@ -711,6 +729,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): - sin, - ctx.emb_dim, - ctx.k_dim, -+ k_dim_ceil, - ctx.v_dim, - nheads, - batch_size, -diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py -index 13d74aa52..060898a7a 100644 ---- a/megatron/core/models/common/language_module/language_module.py -+++ b/megatron/core/models/common/language_module/language_module.py -@@ -184,7 +184,15 @@ class LanguageModule(MegatronModule): - assert ( - column_parallel_linear is not None - ), "column_parallel_linear cannot be None when not using fused linear cross entropy." -- logits, _ = column_parallel_linear(hidden, **col_linear_kwargs) -+ # output -+ output_layer_params = {k: v.detach() for k, v in column_parallel_linear.named_parameters()} -+ output_layer_buffers = dict(column_parallel_linear.named_buffers()) -+ logits, _ = torch.func.functional_call( -+ column_parallel_linear, -+ {**output_layer_params, **output_layer_buffers}, -+ (hidden,), -+ col_linear_kwargs, -+ ) - - return self.compute_language_model_loss(labels, logits) - -diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py -index e21127b87..712793853 100755 ---- a/megatron/core/models/gpt/gpt_layer_specs.py -+++ b/megatron/core/models/gpt/gpt_layer_specs.py -@@ -188,6 +188,8 @@ def get_gpt_layer_with_transformer_engine_spec( - use_kitchen: bool = False, - use_te_activation_func: bool = False, - fallback_to_eager_attn: bool = False, -+ post_self_attn_layernorm: bool = False, -+ post_mlp_layernorm: bool = False, - ) -> ModuleSpec: - """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). - -@@ -260,6 +262,8 @@ def get_gpt_layer_with_transformer_engine_spec( - mlp=mlp, - sharded_state_dict_keys_map=sharded_state_dict_keys_map, - normalization=normalization, -+ post_self_attn_layernorm=post_self_attn_layernorm, -+ post_mlp_layernorm=post_mlp_layernorm, - ) - - -@@ -349,6 +353,8 @@ def get_transformer_layer_spec_for_backend( - mlp: ModuleSpec, - sharded_state_dict_keys_map: Optional[dict] = None, - normalization: Optional[str] = None, -+ post_self_attn_layernorm: bool = False, -+ post_mlp_layernorm: bool = False, - ) -> ModuleSpec: - """Helper function to get module spec for TransformerLayer""" - -@@ -371,9 +377,11 @@ def get_transformer_layer_spec_for_backend( - input_layernorm=input_layernorm, - self_attention=attention, - self_attn_bda=get_bias_dropout_add, -+ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, - pre_mlp_layernorm=pre_mlp_layernorm, - mlp=mlp, - mlp_bda=get_bias_dropout_add, -+ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, - sharded_state_dict_keys_map=sharded_state_dict_keys_map, - ), - ) -diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py -index a1230568c..1fd52f65a 100644 ---- a/megatron/core/models/gpt/gpt_model.py -+++ b/megatron/core/models/gpt/gpt_model.py -@@ -446,6 +446,7 @@ class GPTModel(LanguageModule): - *, - inference_params: Optional[BaseInferenceContext] = None, - loss_mask: Optional[Tensor] = None, -+ mtp_kwargs: Optional[dict] = {}, - ) -> Tensor: - """Forward function of the GPT Model This function passes the input tensors - through the embedding layer, and then the decoder and finally into the post -@@ -508,6 +509,7 @@ class GPTModel(LanguageModule): - runtime_gather_output=runtime_gather_output, - extra_block_kwargs=extra_block_kwargs, - inference_context=inference_context, -+ mtp_kwargs=mtp_kwargs, - ) - - def _postprocess( -@@ -529,6 +531,7 @@ class GPTModel(LanguageModule): - runtime_gather_output=None, - extra_block_kwargs=None, - inference_context=None, -+ mtp_kwargs={}, - ): - """Postprocesses decoder hidden states to generate logits or compute loss. - -@@ -543,7 +546,8 @@ class GPTModel(LanguageModule): - output_weight = None - if self.share_embeddings_and_output_weights: - output_weight = self.shared_embedding_or_output_weight() -- if mtp_in_postprocess: -+ -+ if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None: - hidden_states = self.mtp( - input_ids=input_ids, - position_ids=position_ids, -@@ -563,13 +567,18 @@ class GPTModel(LanguageModule): - return hidden_states - - # Skip when mtp_num_layers is None or 0 -- if self.config.mtp_num_layers: -- mtp_labels = labels.clone() -+ if self.config.mtp_num_layers and mtp_kwargs.get('mtp_labels', None) is not None: -+ mtp_labels = mtp_kwargs['mtp_labels'].clone() -+ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) -+ - hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) - hidden_states = hidden_states_list[0] - if loss_mask is None: - # if loss_mask is not provided, use all ones as loss_mask - loss_mask = torch.ones_like(mtp_labels) -+ else: -+ # Otherwise, roll the loss_mask to keep up with the mtp_labels -+ loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) - for mtp_layer_number in range(self.config.mtp_num_layers): - # Calc loss for the current Multi-Token Prediction (MTP) layers. - mtp_labels, _ = roll_tensor( -@@ -595,7 +604,7 @@ class GPTModel(LanguageModule): - sequence_parallel_enabled=self.output_layer.sequence_parallel, - column_parallel_linear=self.output_layer, - col_linear_kwargs={ -- 'weight': output_weight, -+ 'weight': output_weight.detach() if output_weight else None, - 'runtime_gather_output': runtime_gather_output, - }, - ) -diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py -index 6e093f96f..eac21a3ea 100644 ---- a/megatron/core/optimizer/distrib_optimizer.py -+++ b/megatron/core/optimizer/distrib_optimizer.py -@@ -677,6 +677,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): - # TE FusedAdam will not accumulate step for empty param groups, so we need to - # align the step across param groups. - param_group["step"] = int(step) -+ if "step" in param_group and param_group["step"] is None: -+ del param_group["step"] - - # Grad scaler state. - if self.grad_scaler: -@@ -1646,6 +1648,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): - if key == 'padding': - tensors[key] = LocalNonpersistentObject(tensors[key]) - continue -+ if key == 'step': -+ continue - assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), ( - tensors[key].shape, - gbuf_local_start, -diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py -index a273002b9..4f821cfd5 100644 ---- a/megatron/core/parallel_state.py -+++ b/megatron/core/parallel_state.py -@@ -11,6 +11,7 @@ from typing import Callable, List, Optional - - import numpy as np - import torch -+import torch.distributed as dist - - from .utils import GlobalMemoryBuffer, is_torch_min_version - -diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py -index ac839c21f..f18309217 100644 ---- a/megatron/core/pipeline_parallel/p2p_communication.py -+++ b/megatron/core/pipeline_parallel/p2p_communication.py -@@ -26,22 +26,22 @@ def _batched_p2p_ops( - ops = [] - if tensor_send_prev is not None: - send_prev_op = torch.distributed.P2POp( -- torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group -+ torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, - ) - ops.append(send_prev_op) - if tensor_recv_prev is not None: - recv_prev_op = torch.distributed.P2POp( -- torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group -+ torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, - ) - ops.append(recv_prev_op) - if tensor_send_next is not None: - send_next_op = torch.distributed.P2POp( -- torch.distributed.isend, tensor_send_next, next_pipeline_rank, group -+ torch.distributed.isend, tensor_send_next, next_pipeline_rank, - ) - ops.append(send_next_op) - if tensor_recv_next is not None: - recv_next_op = torch.distributed.P2POp( -- torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group -+ torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, - ) - ops.append(recv_next_op) - if len(ops) > 0: -diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py -index 28cff06f5..48c9c1a25 100644 ---- a/megatron/core/transformer/moe/moe_utils.py -+++ b/megatron/core/transformer/moe/moe_utils.py -@@ -587,6 +587,9 @@ def topk_routing_with_score_function( - else: - return torch.topk(scores, k=topk, dim=1) - -+ from miles.utils.routing_replay import get_routing_replay_compute_topk -+ compute_topk = get_routing_replay_compute_topk(compute_topk) -+ - if score_function == "softmax": - if use_pre_softmax: - scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) -diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py -index 16fc9d9af..3e95858a6 100644 ---- a/megatron/core/transformer/moe/router.py -+++ b/megatron/core/transformer/moe/router.py -@@ -201,6 +201,9 @@ class TopKRouter(Router): - self.global_tokens_per_expert = None - self.ga_steps = None - -+ from miles.utils.routing_replay import register_routing_replay -+ register_routing_replay(self) -+ - def _maintain_float32_expert_bias(self): - """ - Maintain the expert bias in float32. -diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py -index a8f4abfcd..f33f6f05e 100755 ---- a/megatron/core/transformer/multi_token_prediction.py -+++ b/megatron/core/transformer/multi_token_prediction.py -@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union - - import torch - from torch import Tensor -+import warnings - - from megatron.core import InferenceParams, parallel_state, tensor_parallel - from megatron.core.dist_checkpointing.mapping import ShardedStateDict -@@ -714,17 +715,19 @@ class MultiTokenPredictionLayer(MegatronModule): - cp_group=self.cp_group, - packed_seq_params=packed_seq_params, - ) -- position_ids, _ = roll_tensor( -- position_ids, -- shifts=-1, -- dims=-1, -- cp_group=self.cp_group, -- packed_seq_params=packed_seq_params, -- ) -+ if position_ids is not None: -+ position_ids, _ = roll_tensor( -+ position_ids, -+ shifts=-1, -+ dims=-1, -+ cp_group=self.cp_group, -+ packed_seq_params=packed_seq_params, -+ ) - # embedding - decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) -+ decoder_input = decoder_input.detach() - -- hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) -+ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False) - - return input_ids, position_ids, decoder_input, hidden_states - -@@ -826,6 +829,51 @@ class MultiTokenPredictionLayer(MegatronModule): - return hidden_states - - def _checkpointed_forward(self, forward_func, *args, **kwargs): -+ """Wrap `forward_func` with activation checkpointing while only passing tensors. -+ -+ Non-tensor arguments (e.g., configuration objects, None) are captured via closure so -+ that checkpoint implementations never receive them directly, avoiding save_for_backward -+ issues with non-tensor inputs. -+ """ -+ -+ # TODO(jiajun): Is there any better implementation here? -+ positional_specs = [] -+ kw_specs = [] -+ tensor_args: List[torch.Tensor] = [] -+ -+ for arg in args: -+ if torch.is_tensor(arg): -+ positional_specs.append(('tensor', len(tensor_args))) -+ tensor_args.append(arg) -+ else: -+ positional_specs.append(('const', arg)) -+ -+ for key, value in kwargs.items(): -+ if torch.is_tensor(value): -+ kw_specs.append((key, ('tensor', len(tensor_args)))) -+ tensor_args.append(value) -+ else: -+ kw_specs.append((key, ('const', value))) -+ -+ def run(*flat_tensor_args): -+ rebuilt_args = [] -+ for spec_type, payload in positional_specs: -+ if spec_type == 'tensor': -+ rebuilt_args.append(flat_tensor_args[payload]) -+ else: -+ rebuilt_args.append(payload) -+ -+ rebuilt_kwargs = {} -+ for key, (spec_type, payload) in kw_specs: -+ if spec_type == 'tensor': -+ rebuilt_kwargs[key] = flat_tensor_args[payload] -+ else: -+ rebuilt_kwargs[key] = payload -+ -+ return forward_func(*rebuilt_args, **rebuilt_kwargs) -+ -+ tensor_args_tuple = tuple(tensor_args) -+ - def checkpoint_handler(): - """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" - if self.config.fp8: -@@ -836,12 +884,11 @@ class MultiTokenPredictionLayer(MegatronModule): - self.config.distribute_saved_activations, - tensor_parallel.random.get_cuda_rng_tracker, - parallel_state.get_tensor_model_parallel_group(), -- *args, -- **kwargs, -+ *tensor_args_tuple, - ) - else: - return tensor_parallel.checkpoint( -- forward_func, self.config.distribute_saved_activations, *args, *kwargs.values() -+ run, self.config.distribute_saved_activations, *tensor_args_tuple - ) - - if self.config.recompute_method == 'uniform': -diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py -index e2705bd9f..a0aa109b5 100644 ---- a/megatron/core/transformer/transformer_config.py -+++ b/megatron/core/transformer/transformer_config.py -@@ -210,6 +210,9 @@ class TransformerConfig(ModelParallelConfig): - attention_output_gate: bool = False - """Whether to apply output gate to the attention layers.""" - -+ post_self_attn_layernorm: bool = False -+ post_mlp_layernorm: bool = False -+ - test_mode: bool = False - """Whether to run real-time tests.""" - -diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py -index 3ea405770..5a42001b9 100644 ---- a/megatron/core/transformer/transformer_layer.py -+++ b/megatron/core/transformer/transformer_layer.py -@@ -223,6 +223,7 @@ class TransformerLayerSubmodules: - input_layernorm: Union[ModuleSpec, type] = IdentityOp - self_attention: Union[ModuleSpec, type] = IdentityOp - self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp -+ post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp - - pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp - cross_attention: Union[ModuleSpec, type] = IdentityOp -@@ -231,6 +232,7 @@ class TransformerLayerSubmodules: - pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp - mlp: Union[ModuleSpec, type] = IdentityOp - mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp -+ post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp - - # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method - sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) -@@ -310,6 +312,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): - # [Module 3: BiasDropoutFusion] - self.self_attn_bda = build_module(submodules.self_attn_bda) - -+ self.post_self_attn_layernorm = build_module( -+ submodules.post_self_attn_layernorm, -+ config=self.config, -+ hidden_size=self.config.hidden_size, -+ eps=self.config.layernorm_epsilon, -+ ) -+ - # [Module 4: Post SelfAttention] Optional Layernorm after self-attn - self.pre_cross_attn_layernorm = build_module( - submodules.pre_cross_attn_layernorm, -@@ -375,6 +384,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): - - self.is_moe_layer = isinstance(self.mlp, MoELayer) - -+ self.post_mlp_layernorm = build_module( -+ submodules.post_mlp_layernorm, -+ config=self.config, -+ hidden_size=self.config.hidden_size, -+ eps=self.config.layernorm_epsilon -+ ) -+ - self.recompute_input_layernorm = False - self.recompute_pre_mlp_layernorm = False - self.recompute_mlp = False -@@ -551,6 +567,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): - attention_output_with_bias[0] - ) - -+ attention_output, attention_output_bias = attention_output_with_bias -+ attention_output = self.post_self_attn_layernorm(attention_output) -+ attention_output_with_bias = (attention_output, attention_output_bias) -+ - # TODO: could we move `bias_dropout_add_exec_handler` itself - # inside the module provided in the `bias_dropout_add_spec` module? - nvtx_range_push(suffix="self_attn_bda") -@@ -677,6 +697,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): - else: - mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) - -+ mlp_output, mlp_output_bias = mlp_output_with_bias -+ mlp_output = self.post_mlp_layernorm(mlp_output) -+ mlp_output_with_bias = (mlp_output, mlp_output_bias) -+ - if self.recompute_pre_mlp_layernorm: - # discard the output of the pre-mlp layernorm and register the recompute - # as a gradient hook of mlp_output_with_bias[0] -diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py -index b267c8a81..83736acdc 100644 ---- a/megatron/training/arguments.py -+++ b/megatron/training/arguments.py -@@ -1398,6 +1398,9 @@ def core_transformer_config_from_args(args, config_class=None): - - kw_args['inference_sampling_seed'] = args.seed - -+ kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm -+ kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm -+ - # handle quantization config - # NOTE: Kitchen arguments are only added to the namespace when - # Kitchen library is available. -@@ -1764,6 +1767,12 @@ def _add_network_size_args(parser): - action='store_true', - help='If set, use original BERT residula connection ' - 'ordering.') -+ group.add_argument('--post-self-attn-layernorm', action='store_true', -+ help='If set, use post self attention layernorm.') -+ group.add_argument('--post-mlp-layernorm', action='store_true', -+ help='If set, use post MLP layernorm.') -+ group.add_argument('--use-gated-attention', action='store_true', -+ help='If set, use gated attention as in Qwen3Next') - group.add_argument('--openai-gelu', action='store_true', - help='Use OpenAIs GeLU implementation. This option' - 'should not be used unless for backward compatibility' -diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py -index 13b7526ca..6c590f653 100644 ---- a/megatron/training/tokenizer/tokenizer.py -+++ b/megatron/training/tokenizer/tokenizer.py -@@ -136,7 +136,7 @@ class _HuggingFaceTokenizer(MegatronLegacyTokenizer): - # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there - self._tokenizer = transformers.AutoTokenizer.from_pretrained( - pretrained_model_name_or_path=pretrained_model_name_or_path, -- trust_remote_code=trust_remote_code, -+ trust_remote_code=True, - **kwargs, - ) - self._vocab = self._tokenizer.get_vocab() diff --git a/docker/amd_patch/sglv0.5.0rc0/sglang.patch b/docker/amd_patch/sglv0.5.0rc0/sglang.patch deleted file mode 100644 index 990c2e6289..0000000000 --- a/docker/amd_patch/sglv0.5.0rc0/sglang.patch +++ /dev/null @@ -1,203 +0,0 @@ -diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py -index bdb124e51..3edf30ab1 100644 ---- a/python/sglang/srt/configs/model_config.py -+++ b/python/sglang/srt/configs/model_config.py -@@ -454,14 +454,14 @@ class ModelConfig: - ).lower() - - # Detect which checkpoint is it -- for _, method in QUANTIZATION_METHODS.items(): -- quantization_override = method.override_quantization_method( -- quant_cfg, self.quantization -- ) -- if quantization_override: -- quant_method = quantization_override -- self.quantization = quantization_override -- break -+ # for _, method in QUANTIZATION_METHODS.items(): -+ # quantization_override = method.override_quantization_method( -+ # quant_cfg, self.quantization -+ # ) -+ # if quantization_override: -+ # quant_method = quantization_override -+ # self.quantization = quantization_override -+ # break - - # Verify quantization configurations. - if self.quantization is None: -diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py -index 2dd2c75f1..f2adb18f8 100644 ---- a/python/sglang/srt/entrypoints/http_server.py -+++ b/python/sglang/srt/entrypoints/http_server.py -@@ -264,6 +264,10 @@ async def validate_json_request(raw_request: Request): - - - @app.get("/health") -+async def health(request: Request) -> Response: -+ return Response(status_code=200) -+ -+ - @app.get("/health_generate") - async def health_generate(request: Request) -> Response: - """ -diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py -index 372717bf9..40665cc90 100644 ---- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py -+++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py -@@ -190,6 +190,7 @@ class DeepEPBuffer: - f"Consider using --deepep-config to change the behavior." - ) - -+ num_qps_per_rank = 20 - cls._buffer = Buffer( - group, - num_nvl_bytes, -diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py -index 956264fc9..69f729336 100644 ---- a/python/sglang/srt/layers/quantization/fp8.py -+++ b/python/sglang/srt/layers/quantization/fp8.py -@@ -351,10 +351,10 @@ class Fp8LinearMethod(LinearMethodBase): - return - else: - weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data -- layer.weight = torch.nn.Parameter(weight, requires_grad=False) -- layer.weight_scale_inv = torch.nn.Parameter( -- weight_scale, requires_grad=False -- ) -+ # layer.weight = torch.nn.Parameter(weight, requires_grad=False) -+ # layer.weight_scale_inv = torch.nn.Parameter( -+ # weight_scale, requires_grad=False -+ # ) - return - - layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) -diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py -index 95a529c89..758fbfd5f 100644 ---- a/python/sglang/srt/managers/scheduler.py -+++ b/python/sglang/srt/managers/scheduler.py -@@ -1359,7 +1359,7 @@ class Scheduler( - - if memory_leak: - msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}" -- raise ValueError(msg) -+ # raise ValueError(msg) - - if self.disaggregation_mode == DisaggregationMode.DECODE: - req_total_size = ( -@@ -1374,7 +1374,7 @@ class Scheduler( - f"available_size={len(self.req_to_token_pool.free_slots)}, " - f"total_size={self.req_to_token_pool.size}\n" - ) -- raise ValueError(msg) -+ # raise ValueError(msg) - - if ( - self.enable_metrics -@@ -1830,6 +1830,7 @@ class Scheduler( - deepep_mode=DeepEPMode(self.server_args.deepep_mode), - require_mlp_tp_gather=require_mlp_tp_gather(self.server_args), - disable_overlap_schedule=self.server_args.disable_overlap_schedule, -+ offload_tags=self.offload_tags, - ) - - def handle_dp_balance_data(self, local_batch: ScheduleBatch): -@@ -1927,6 +1928,7 @@ class Scheduler( - deepep_mode: DeepEPMode, - require_mlp_tp_gather: bool, - disable_overlap_schedule: bool, -+ offload_tags: set[str], - ): - # Check if other DP workers have running batches - if local_batch is None: -@@ -1957,7 +1959,7 @@ class Scheduler( - ) - - tbo_preparer = TboDPAttentionPreparer() -- if disable_overlap_schedule: -+ if len(offload_tags) == 0 and disable_overlap_schedule: - group = tp_group.device_group - device = tp_group.device - else: -diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py -index 58220b1d6..3c3d081a8 100644 ---- a/python/sglang/srt/managers/tokenizer_manager.py -+++ b/python/sglang/srt/managers/tokenizer_manager.py -@@ -1044,10 +1044,15 @@ class TokenizerManager: - request: Optional[fastapi.Request] = None, - ) -> Tuple[bool, str]: - self.auto_create_handle_loop() -- assert ( -- self.server_args.dp_size == 1 -- ), "dp_size must be 1 for init parameter update group" -- result = (await self.init_weights_update_group_communicator(obj))[0] -+ results = await self.init_weights_update_group_communicator(obj) -+ if self.server_args.dp_size == 1: -+ result = results[0] -+ return result.success, result.message -+ else: -+ all_success = all([r.success for r in results]) -+ all_message = [r.message for r in results] -+ all_message = " | ".join(all_message) -+ return all_success, all_message - return result.success, result.message - - async def update_weights_from_distributed( -@@ -1056,9 +1061,6 @@ class TokenizerManager: - request: Optional[fastapi.Request] = None, - ) -> Tuple[bool, str]: - self.auto_create_handle_loop() -- assert ( -- self.server_args.dp_size == 1 or self.server_args.enable_dp_attention -- ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed" - - if obj.abort_all_requests: - self.abort_request(abort_all=True) -@@ -1066,8 +1068,15 @@ class TokenizerManager: - # This means that weight sync - # cannot run while requests are in progress. - async with self.model_update_lock.writer_lock: -- result = (await self.update_weights_from_distributed_communicator(obj))[0] -- return result.success, result.message -+ results = await self.update_weights_from_distributed_communicator(obj) -+ if self.server_args.dp_size == 1: -+ result = results[0] -+ return result.success, result.message -+ else: -+ all_success = all([r.success for r in results]) -+ all_message = [r.message for r in results] -+ all_message = " | ".join(all_message) -+ return all_success, all_message - - async def update_weights_from_tensor( - self, -diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py -index 5222bff0a..ff0bbc62a 100644 ---- a/python/sglang/srt/model_executor/model_runner.py -+++ b/python/sglang/srt/model_executor/model_runner.py -@@ -22,6 +22,7 @@ import os - import time - from dataclasses import dataclass - from typing import List, Optional, Tuple, Union -+from contextlib import nullcontext - - import torch - import torch.distributed as dist -@@ -675,7 +676,7 @@ class ModelRunner: - monkey_patch_vllm_parallel_state() - monkey_patch_isinstance_for_vllm_base_layer() - -- with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS): -+ with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS) if not self.is_draft_worker else nullcontext(): - self.model = get_model( - model_config=self.model_config, - load_config=self.load_config, -diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py -index e0f0b373d..a18ac10f1 100644 ---- a/python/sglang/srt/models/glm4_moe.py -+++ b/python/sglang/srt/models/glm4_moe.py -@@ -1108,5 +1108,4 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM): - ) - weight_loader(param, loaded_weight) - -- - EntryClass = [Glm4MoeForCausalLM] diff --git a/docker/amd_patch/sglv0.5.10/megatron.patch b/docker/amd_patch/sglv0.5.10/megatron.patch new file mode 100644 index 0000000000..acd64149b7 --- /dev/null +++ b/docker/amd_patch/sglv0.5.10/megatron.patch @@ -0,0 +1,20 @@ +diff --git a/megatron/legacy/fused_kernels/__init__.py b/megatron/legacy/fused_kernels/__init__.py +--- a/megatron/legacy/fused_kernels/__init__.py ++++ b/megatron/legacy/fused_kernels/__init__.py +@@ -3,6 +3,7 @@ + import os + import pathlib + import subprocess ++import torch + + from torch.utils import cpp_extension + +@@ -15,6 +16,8 @@ + + + def load(args): ++ if not torch.version.cuda: ++ return + + # Check if cuda 11 is installed for compute capability 8.0 + cc_flag = [] diff --git a/docker/amd_patch/sglv0.5.7/megatron.patch b/docker/amd_patch/sglv0.5.7/megatron.patch deleted file mode 100644 index f6efca346d..0000000000 --- a/docker/amd_patch/sglv0.5.7/megatron.patch +++ /dev/null @@ -1,51 +0,0 @@ -diff --git a/megatron/legacy/fused_kernels/__init__.py b/megatron/legacy/fused_kernels/__init__.py -index 87cceac3..ac686d74 100644 ---- a/megatron/legacy/fused_kernels/__init__.py -+++ b/megatron/legacy/fused_kernels/__init__.py -@@ -3,6 +3,7 @@ - import os - import pathlib - import subprocess -+import torch - - from torch.utils import cpp_extension - -@@ -15,23 +16,23 @@ os.environ["TORCH_CUDA_ARCH_LIST"] = "" - - - def load(args): -- -- # Check if cuda 11 is installed for compute capability 8.0 -- cc_flag = [] -- _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( -- cpp_extension.CUDA_HOME -- ) -- if int(bare_metal_major) >= 11: -- cc_flag.append('-gencode') -- cc_flag.append('arch=compute_80,code=sm_80') -- if int(bare_metal_minor) >= 8: -+ if torch.cuda.is_available() and torch.version.cuda: -+ # Check if cuda 11 is installed for compute capability 8.0 -+ cc_flag = [] -+ _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( -+ cpp_extension.CUDA_HOME -+ ) -+ if int(bare_metal_major) >= 11: - cc_flag.append('-gencode') -- cc_flag.append('arch=compute_90,code=sm_90') -+ cc_flag.append('arch=compute_80,code=sm_80') -+ if int(bare_metal_minor) >= 8: -+ cc_flag.append('-gencode') -+ cc_flag.append('arch=compute_90,code=sm_90') - -- # Build path -- srcpath = pathlib.Path(__file__).parent.absolute() -- buildpath = srcpath / "build" -- _create_build_dir(buildpath) -+ # Build path -+ srcpath = pathlib.Path(__file__).parent.absolute() -+ buildpath = srcpath / "build" -+ _create_build_dir(buildpath) - - # Helper function to build the kernels. - def _cpp_extention_load_helper(name, sources, extra_cuda_flags): diff --git a/docker/amd_patch/sglv0.5.7/sglang.patch b/docker/amd_patch/sglv0.5.7/sglang.patch deleted file mode 100644 index b103263070..0000000000 --- a/docker/amd_patch/sglv0.5.7/sglang.patch +++ /dev/null @@ -1,38 +0,0 @@ -diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py -index 6e7ea07e7..73b512f51 100644 ---- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py -+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py -@@ -64,6 +64,7 @@ class CustomAllreduce: - group: ProcessGroup, - device: Union[int, str, torch.device], - max_size=_MAX_CAR_SIZE, -+ enable_register_for_capturing: bool = True, - ) -> None: - """ - Args: -@@ -410,6 +411,8 @@ class CustomAllreduce: - if self._IS_CAPTURING: - if torch.cuda.is_current_stream_capturing(): - if _is_hip: -+ if self.tms_cudagraph: -+ return self.all_reduce_unreg(input) - return self.all_reduce_reg(input) - else: - return self.all_reduce(input, registered=not self.tms_cudagraph) -diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py -index c3ca1e4f3..2bb763b6a 100644 ---- a/python/sglang/srt/distributed/parallel_state.py -+++ b/python/sglang/srt/distributed/parallel_state.py -@@ -351,10 +351,12 @@ class GroupCoordinator: - if use_custom_allreduce and self.world_size > 1: - # Initialize a custom fast all-reduce implementation. - try: -+ tms_cudagraph = envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get() - CAClass = dispatch_custom_allreduce() - self.ca_comm = CAClass( - group=self.cpu_group, - device=self.device, -+ enable_register_for_capturing=not tms_cudagraph, - ) - except Exception as e: - logger.warning( diff --git a/scripts/run-qwen3-30B-A3B.sh b/scripts/amd/run-qwen3-4B-amd.sh similarity index 67% rename from scripts/run-qwen3-30B-A3B.sh rename to scripts/amd/run-qwen3-4B-amd.sh index 19bc70927d..bc6d4d40c0 100644 --- a/scripts/run-qwen3-30B-A3B.sh +++ b/scripts/amd/run-qwen3-4B-amd.sh @@ -9,30 +9,34 @@ pkill -9 python sleep 3 pkill -9 ray pkill -9 python -pkill -9 redis set -ex +# keep Ray from blanking HIP/CUDA visibility for the job entrypoint. +export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES:-"1"} +export RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES:-"1"} + # will prevent ray from buffering stdout/stderr export PYTHONBUFFERED=16 -NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) -if [ "$NVLINK_COUNT" -gt 0 ]; then - HAS_NVLINK=1 -else - HAS_NVLINK=0 +if [[ -n "${HIP_VISIBLE_DEVICES:-}" ]]; then + export CUDA_VISIBLE_DEVICES="${HIP_VISIBLE_DEVICES}" +fi + +NUM_GPUS=${NUM_GPUS:-8} +if [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then + IFS=',' read -r -a visible_gpu_ids <<< "${CUDA_VISIBLE_DEVICES}" + NUM_GPUS=${#visible_gpu_ids[@]} fi -echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -source "${SCRIPT_DIR}/models/qwen3-30B-A3B.sh" +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/.." &>/dev/null && pwd)" +source "${SCRIPT_DIR}/models/qwen3-4B.sh" CKPT_ARGS=( - --hf-checkpoint /root/Qwen3-30B-A3B - #--hf-checkpoint /root/Qwen3-30B-A3B-FP8 - --ref-load /root/Qwen3-30B-A3B_torch_dist - --load /root/Qwen3-30B-A3B_miles/ - --save /root/Qwen3-30B-A3B_miles/ + --hf-checkpoint /root/Qwen3-4B + --ref-load /root/Qwen3-4B_torch_dist + --load /root/Qwen3-4B_miles/ + --save /root/Qwen3-4B_miles/ --save-interval 20 ) @@ -48,7 +52,6 @@ ROLLOUT_ARGS=( --n-samples-per-prompt 8 --rollout-max-response-len 8192 --rollout-temperature 1 - --global-batch-size 256 --balance-data ) @@ -62,11 +65,11 @@ EVAL_ARGS=( ) PERF_ARGS=( - --tensor-model-parallel-size 4 + --tensor-model-parallel-size 2 --sequence-parallel --pipeline-model-parallel-size 1 --context-parallel-size 1 - --expert-model-parallel-size 8 + --expert-model-parallel-size 1 --expert-tensor-parallel-size 1 --recompute-granularity full @@ -75,7 +78,7 @@ PERF_ARGS=( # --micro-batch-size 1 --use-dynamic-batch-size - --max-tokens-per-gpu 20480 + --max-tokens-per-gpu 9216 ) GRPO_ARGS=( @@ -95,23 +98,18 @@ OPTIMIZER_ARGS=( --weight-decay 0.1 --adam-beta1 0.9 --adam-beta2 0.98 - - --optimizer-cpu-offload - --overlap-cpu-optimizer-d2h-h2d - --use-precision-aware-optimizer ) WANDB_ARGS=( - #--use-wandb + # --use-wandb # --wandb-project miles-dev - # --wandb-group qwen3-30B-A3B-test + # --wandb-group qwen3-4B-test # --wandb-key ${WANDB_KEY} ) SGLANG_ARGS=( - --rollout-num-gpus-per-engine 8 + --rollout-num-gpus-per-engine 2 --sglang-mem-fraction-static 0.7 - --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256) ) MISC_ARGS=( @@ -127,14 +125,13 @@ MISC_ARGS=( # launch the master node of ray in container export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 +ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 # Build the runtime environment JSON with proper variable substitution RUNTIME_ENV_JSON="{ \"env_vars\": { \"PYTHONPATH\": \"/root/Megatron-LM/\", - \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\", - \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\" + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" } }" @@ -142,7 +139,7 @@ ray job submit --address="http://127.0.0.1:8265" \ --runtime-env-json="${RUNTIME_ENV_JSON}" \ -- python3 train.py \ --actor-num-nodes 1 \ - --actor-num-gpus-per-node 8 \ + --actor-num-gpus-per-node ${NUM_GPUS} \ --colocate \ ${MODEL_ARGS[@]} \ ${CKPT_ARGS[@]} \ diff --git a/scripts/run-llama3.2-3B-Instruct-amd.sh b/scripts/run-llama3.2-3B-Instruct-amd.sh deleted file mode 100644 index eb5d5709ce..0000000000 --- a/scripts/run-llama3.2-3B-Instruct-amd.sh +++ /dev/null @@ -1,180 +0,0 @@ -#!/bin/bash - -# hf download meta-llama/Llama-3.2-3B-Instruct --local-dir /root/Llama-3.2-3B-Instruct - -# for rerun the task -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python - -set -euxo pipefail - -### AMD Support ### -MILES_DIR="${MILES_DIR:-/home/yushensu/projects/miles}" # Default path if not set in environment -export MILES_DIR - -MODEL_DIR="${MODEL_DIR:-/home/yushensu/projects/model}" # Default path if not set in environment -export MODEL_DIR - -DATA_DIR="${DATA_DIR:-/home/yushensu/projects/data}" # Default path if not set in environment -export DATA_DIR - -# For AMD GPU -export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES:-"1"} # Must set to 1 -export HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"} #You can choose which gpus to use -#################### - -# will prevent ray from buffering stdout/stderr -export PYTHONBUFFERED=16 - -# NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) -# if [ "$NVLINK_COUNT" -gt 0 ]; then -# HAS_NVLINK=1 -# else -# HAS_NVLINK=0 -# fi -# echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" - -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -source "${SCRIPT_DIR}/models/llama3.2-3B-Instruct-amd.sh" - -CKPT_ARGS=( - --hf-checkpoint ${MODEL_DIR}/Llama-3.2-3B-Instruct - --ref-load ${MODEL_DIR}/Llama-3.2-3B-Instruct_torch_dist - --load ${MODEL_DIR}/Llama-3.2-3B-Instruct_miles/ - --save ${MODEL_DIR}/Llama-3.2-3B-Instruct_miles/ - --save-interval 20 -) - -ROLLOUT_ARGS=( - --prompt-data ${DATA_DIR}/dapo-math-17k/dapo-math-17k.jsonl - --input-key prompt - --label-key label - --apply-chat-template - --rollout-shuffle - --rm-type math - --num-epoch 1 - --rollout-batch-size 32 - --n-samples-per-prompt 8 - --rollout-max-response-len 16384 - --rollout-temperature 1 - - --global-batch-size 256 - --balance-data -) - -EVAL_ARGS=( - --eval-interval 10 - --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl - --n-samples-per-eval-prompt 8 - --eval-max-response-len 16384 - --eval-top-p 1 -) - -PERF_ARGS=( - --tensor-model-parallel-size 2 - --sequence-parallel - --pipeline-model-parallel-size 1 - --context-parallel-size 1 - --expert-model-parallel-size 1 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - # --micro-batch-size 1 - --use-dynamic-batch-size - --max-tokens-per-gpu 9216 -) - -GRPO_ARGS=( - --advantage-estimator grpo - --use-kl-loss - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --entropy-coef 0.00 - --eps-clip 0.2 - --eps-clip-high 0.28 -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 1e-6 - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 -) - -WANDB_ARGS=( - # --use-wandb - # --wandb-project miles-dev - # --wandb-group llama3.2-3B - # --wandb-key ${WANDB_API_KEY} -) - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 2 - --sglang-mem-fraction-static 0.4 -) - -MISC_ARGS=( - # default dropout in megatron is 0.1 - --attention-dropout 0.0 - --hidden-dropout 0.0 - # should be good for model performance - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - # need to comment this when using model with MLA - --attention-backend flash - ################### -) - -# launch the master node of ray in container -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} - -NUM_GPUS=$(echo ${HIP_VISIBLE_DEVICES} | tr ',' '\n' | wc -l) -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 - -# Build the runtime environment JSON with proper variable substitution -RUNTIME_ENV_JSON="{ - \"env_vars\": { - \"PYTHONPATH\": \"/workspace/Megatron-LM/\", - \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" - } -}" - -ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json="${RUNTIME_ENV_JSON}" \ - -- python3 train.py \ - --actor-num-nodes 1 \ - --actor-num-gpus-per-node 8 \ - --colocate \ - ${MODEL_ARGS[@]} \ - ${CKPT_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ - ${OPTIMIZER_ARGS[@]} \ - ${GRPO_ARGS[@]} \ - ${WANDB_ARGS[@]} \ - ${PERF_ARGS[@]} \ - ${EVAL_ARGS[@]} \ - ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} - - -####clear after training - -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python \ No newline at end of file diff --git a/scripts/run-qwen3-4B-amd.sh b/scripts/run-qwen3-4B-amd.sh deleted file mode 100755 index 44257cc77f..0000000000 --- a/scripts/run-qwen3-4B-amd.sh +++ /dev/null @@ -1,161 +0,0 @@ -#!/bin/bash - -# for rerun the task -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python - - -set -euxo pipefail - - -### AMD Support ### -MILES_DIR="${MILES_DIR:-/root}" # Default path if not set in environment -export MILES_DIR - -MODEL_DIR="${MODEL_DIR:-/root}" # Default path if not set in environment -export MODEL_DIR - -DATA_DIR="${DATA_DIR:-/root}" # Default path if not set in environment -export DATA_DIR - -# For AMD GPU -export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES:-"1"} # Must set to 1 -export HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"} #You can choose which gpus to use -#################### - - -# will prevent ray from buffering stdout/stderr -export PYTHONBUFFERED=16 - -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -source "${SCRIPT_DIR}/models/qwen3-4B.sh" - -CKPT_ARGS=( - --hf-checkpoint ${MODEL_DIR}/Qwen3-4B - --ref-load ${MODEL_DIR}/Qwen3-4B_torch_dist - --load ${MODEL_DIR}/Qwen3-4B_miles/ - --save ${MODEL_DIR}/Qwen3-4B_miles/ - --save-interval 20 -) - -ROLLOUT_ARGS=( - --prompt-data ${DATA_DIR}/dapo-math-17k/dapo-math-17k.jsonl - --input-key prompt - --label-key label - --apply-chat-template - --rollout-shuffle - --rm-type deepscaler - --num-rollout 3000 - --rollout-batch-size 32 - --n-samples-per-prompt 8 - --rollout-max-response-len 8192 - --rollout-temperature 1 - --global-batch-size 256 - --balance-data -) - -EVAL_ARGS=( - --eval-interval 20 - --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl - --n-samples-per-eval-prompt 16 - --eval-max-response-len 16384 - --eval-top-p 1 -) - -PERF_ARGS=( - --tensor-model-parallel-size 2 - --sequence-parallel - --pipeline-model-parallel-size 1 - --context-parallel-size 1 - --expert-model-parallel-size 1 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - # --micro-batch-size 1 - --use-dynamic-batch-size - --max-tokens-per-gpu 9216 -) - -GRPO_ARGS=( - --advantage-estimator grpo - --use-kl-loss - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --entropy-coef 0.00 - --eps-clip 0.2 - --eps-clip-high 0.28 -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 1e-6 - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 -) - -WANDB_ARGS=( - # --use-wandb - # --wandb-project miles-dev - # --wandb-group qwen3-4B-test - # --wandb-key ${WANDB_KEY} -) - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 2 - --sglang-mem-fraction-static 0.7 -) - -MISC_ARGS=( - # default dropout in megatron is 0.1 - --attention-dropout 0.0 - --hidden-dropout 0.0 - # should be good for model performance - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - # need to comment this when using model with MLA - --attention-backend flash - ################### -) - -# launch the master node of ray in container -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} - -NUM_GPUS=$(echo ${HIP_VISIBLE_DEVICES} | tr ',' '\n' | wc -l) -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 - - -# Dynamically detect Megatron-LM installation path -MEGATRON_LM_PATH=$(python3 -c "import megatron; import os; print(os.path.dirname(os.path.dirname(megatron.__file__)))" 2>/dev/null || echo "/app/Megatron-LM") - -ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json="{ - \"env_vars\": { - \"PYTHONPATH\": \"${MEGATRON_LM_PATH}/\", - \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" - } - }" \ - -- python3 train.py \ - --actor-num-nodes 1 \ - --actor-num-gpus-per-node 8 \ - --colocate \ - ${MODEL_ARGS[@]} \ - ${CKPT_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ - ${OPTIMIZER_ARGS[@]} \ - ${GRPO_ARGS[@]} \ - ${WANDB_ARGS[@]} \ - ${PERF_ARGS[@]} \ - ${EVAL_ARGS[@]} \ - ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} diff --git a/scripts/run-qwen3-8B-amd.sh b/scripts/run-qwen3-8B-amd.sh deleted file mode 100644 index 979ffa18e0..0000000000 --- a/scripts/run-qwen3-8B-amd.sh +++ /dev/null @@ -1,194 +0,0 @@ -#!/bin/bash - - -# bash scripts/run-qwen3-4B-amd.sh - - -####clear before training -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python - - -set -euxo pipefail - - -### AMD Support ### -MILES_DIR="${MILES_DIR:-/home/yushensu/projects/miles}" # Default path if not set in environment -export MILES_DIR - -MODEL_DIR="${MODEL_DIR:-/home/yushensu/projects/model}" # Default path if not set in environment -export MODEL_DIR - -DATA_DIR="${DATA_DIR:-/home/yushensu/projects/data}" # Default path if not set in environment -export DATA_DIR - -# For AMD GPU -export RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=${RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES:-"1"} # Must set to 1 -export HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"} #You can choose which gpus to use -#################### - - -# will prevent ray from buffering stdout/stderr -export PYTHONBUFFERED=16 - -# Current Model convert script on AMD GPU has some issue, please download the converted model from here: https://huggingface.co/zyzshishui0627/models - -SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" -source "${SCRIPT_DIR}/models/qwen3-8B.sh" - -CKPT_ARGS=( - --hf-checkpoint ${MODEL_DIR}/Qwen3-8B - #--hf-checkpoint /root/Qwen3-4B-FP8 - --ref-load ${MODEL_DIR}/Qwen3-8B_torch_dist - # --ref-load ${MODEL_DIR}/Qwen3-8B_torch_dist_amd_new - --load ${MODEL_DIR}/Qwen3-8B_miles/ - --save ${MODEL_DIR}/Qwen3-8B_miles/ - --save-interval 20 -) - -ROLLOUT_ARGS=( - --prompt-data ${DATA_DIR}/dapo-math-17k/dapo-math-17k.jsonl - --input-key prompt - --label-key label - --apply-chat-template - --rollout-shuffle - --rm-type deepscaler - --num-rollout 3000 - --rollout-batch-size 32 - --n-samples-per-prompt 8 - --rollout-max-response-len 8192 - --rollout-temperature 1 - - --global-batch-size 256 - --balance-data -) - -EVAL_ARGS=( - --eval-interval 20 - --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl - --n-samples-per-eval-prompt 16 - --eval-max-response-len 16384 - --eval-top-p 1 -) - -PERF_ARGS=( - --tensor-model-parallel-size 2 - --sequence-parallel - --pipeline-model-parallel-size 1 - --context-parallel-size 1 - --expert-model-parallel-size 1 - --expert-tensor-parallel-size 1 - - --recompute-granularity full - --recompute-method uniform - --recompute-num-layers 1 - - # --micro-batch-size 1 - --use-dynamic-batch-size - --max-tokens-per-gpu 9216 -) - -GRPO_ARGS=( - --advantage-estimator grpo - --use-kl-loss - --kl-loss-coef 0.00 - --kl-loss-type low_var_kl - --entropy-coef 0.00 - --eps-clip 0.2 - --eps-clip-high 0.28 -) - -OPTIMIZER_ARGS=( - --optimizer adam - --lr 1e-6 - --lr-decay-style constant - --weight-decay 0.1 - --adam-beta1 0.9 - --adam-beta2 0.98 -) - -WANDB_ARGS=( - #--use-wandb - # --wandb-project miles-dev - # --wandb-group qwen3-4B-test - # --wandb-key ${WANDB_KEY} -) - -SGLANG_ARGS=( - --rollout-num-gpus-per-engine 2 - --sglang-mem-fraction-static 0.7 -) -#################### - - -MISC_ARGS=( - # default dropout in megatron is 0.1 - --attention-dropout 0.0 - --hidden-dropout 0.0 - # should be good for model performance - --accumulate-allreduce-grads-in-fp32 - --attention-softmax-in-fp32 - # need to comment this when using model with MLA - --attention-backend flash -) - -# launch the master node of ray in container -export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} - -NUM_GPUS=$(echo ${HIP_VISIBLE_DEVICES} | tr ',' '\n' | wc -l) -ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus ${NUM_GPUS} --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265 - - -# "PYTHONPATH": "/workspace/Megatron-LM/", -MEGATRON_LM_PATH=$(pip list | grep megatron-core | awk '{print $NF}') - -ray job submit --address="http://127.0.0.1:8265" \ - --runtime-env-json='{ - "env_vars": { - "PYTHONPATH": "/workspace/Megatron-LM/", - "CUDA_DEVICE_MAX_CONNECTIONS": "1" - } - }' \ - -- python3 train.py \ - --actor-num-nodes 1 \ - --actor-num-gpus-per-node 8 \ - --colocate \ - ${MODEL_ARGS[@]} \ - ${CKPT_ARGS[@]} \ - ${ROLLOUT_ARGS[@]} \ - ${OPTIMIZER_ARGS[@]} \ - ${GRPO_ARGS[@]} \ - ${WANDB_ARGS[@]} \ - ${PERF_ARGS[@]} \ - ${EVAL_ARGS[@]} \ - ${SGLANG_ARGS[@]} \ - ${MISC_ARGS[@]} - - - -####clear after training - -pkill -9 sglang -sleep 3 -ray stop --force -pkill -9 ray -pkill -9 python -sleep 3 -pkill -9 ray -pkill -9 python - - - - - - - - - -