From 2142eb1b9c39962b1c764cc1610831b961703ebc Mon Sep 17 00:00:00 2001
From: Pavel Belevich
Date: Thu, 22 May 2025 22:07:28 +0000
Subject: [PATCH] DeepSeek V3/R1/Prover-V2 671B SFT with LoRA

---
 3.test_cases/pytorch/colossalai/README.md | 35 ++
 .../pytorch/colossalai/build_docker.sh | 22 +
 .../pytorch/colossalai/colossalai.Dockerfile | 179 ++++
 .../deepseek-lora-finetune/.gitignore | 2 +
 .../deepseek-lora-finetune/README.md | 122 ++++
 .../deepseek-lora-finetune/lora_eval.py | 118 ++++
 .../deepseek-lora-finetune/lora_finetune.py | 532 ++++++++++++++++++
 .../lora_finetune.sbatch | 72 +++
 .../colossalai/gather_state_dict_fast.patch | 30 +
 9 files changed, 1112 insertions(+)
 create mode 100644 3.test_cases/pytorch/colossalai/README.md
 create mode 100755 3.test_cases/pytorch/colossalai/build_docker.sh
 create mode 100644 3.test_cases/pytorch/colossalai/colossalai.Dockerfile
 create mode 100644 3.test_cases/pytorch/colossalai/deepseek-lora-finetune/.gitignore
 create mode 100644 3.test_cases/pytorch/colossalai/deepseek-lora-finetune/README.md
 create mode 100644 3.test_cases/pytorch/colossalai/deepseek-lora-finetune/lora_eval.py
 create mode 100644 3.test_cases/pytorch/colossalai/deepseek-lora-finetune/lora_finetune.py
 create mode 100644 3.test_cases/pytorch/colossalai/deepseek-lora-finetune/lora_finetune.sbatch
 create mode 100644 3.test_cases/pytorch/colossalai/gather_state_dict_fast.patch

diff --git a/3.test_cases/pytorch/colossalai/README.md b/3.test_cases/pytorch/colossalai/README.md
new file mode 100644
index 000000000..030789df6
--- /dev/null
+++ b/3.test_cases/pytorch/colossalai/README.md
@@ -0,0 +1,35 @@
+# Colossal-AI
+
+## Dependencies
+
+As of the Apr 18th 2025 [commit](https://github.com/hpcaitech/ColossalAI/tree/46ed5d856b16b074325091a88e761544b3d4f9f0), Colossal-AI requires PyTorch 2.5.1, whose official builds use CUDA 12.4. We use `nvidia/cuda:12.4.1-devel-ubuntu22.04` as the base image and install all dependencies on top of it in [colossalai.Dockerfile](colossalai.Dockerfile).
+
+## Build Docker Image
+
+Building Colossal-AI from scratch requires GPU support, so you need to set the NVIDIA Docker runtime as the default when running `docker build`. We launch the build job on a GPU node:
+
+Log in to AWS ECR:
+```bash
+export AWS_ACCESS_KEY_ID=...
+export AWS_SECRET_ACCESS_KEY=...
+
+aws ecr get-login-password ...
+```
+
+Build the docker image on the GPU node and push it to the docker repo:
+```bash
+export DOCKER_REPO=159553542841.dkr.ecr.ap-northeast-1.amazonaws.com/belevich/colossalai
+srun ./build_docker.sh
+```
+
+Pull the docker image from the docker repo:
+```bash
+docker pull $DOCKER_REPO:latest
+```
+
+Import the docker image into an enroot container (if one was created previously, remove it first with `rm ./colossalai.sqsh`):
+```bash
+enroot import -o ./colossalai.sqsh dockerd://$DOCKER_REPO:latest
+```
+
+
diff --git a/3.test_cases/pytorch/colossalai/build_docker.sh b/3.test_cases/pytorch/colossalai/build_docker.sh
new file mode 100755
index 000000000..185c88794
--- /dev/null
+++ b/3.test_cases/pytorch/colossalai/build_docker.sh
@@ -0,0 +1,22 @@
+#! /bin/bash
+
+if [ -z "$SLURM_JOB_ID" ]; then
+    echo "Run with slurm: srun ./build_docker.sh"
+    exit 1
+fi
+
+docker build --progress=plain -f colossalai.Dockerfile -t colossalai:latest .
+
+if [ $? 
-ne 0 ]; then + echo "Failed to build docker image" + exit 1 +fi + +if [ -z "$DOCKER_REPO" ]; then + echo "DOCKER_REPO is not set" + exit 1 +fi + +docker tag colossalai:latest $DOCKER_REPO:latest + +docker push $DOCKER_REPO:latest diff --git a/3.test_cases/pytorch/colossalai/colossalai.Dockerfile b/3.test_cases/pytorch/colossalai/colossalai.Dockerfile new file mode 100644 index 000000000..8dde17c73 --- /dev/null +++ b/3.test_cases/pytorch/colossalai/colossalai.Dockerfile @@ -0,0 +1,179 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 +FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 + +ARG GDRCOPY_VERSION=v2.4.4 +ARG EFA_INSTALLER_VERSION=1.38.1 +ARG AWS_OFI_NCCL_VERSION=v1.14.0 +ARG NCCL_VERSION=v2.26.2-1 +ARG NCCL_TESTS_VERSION=v2.14.1 + +RUN apt-get update -y && apt-get upgrade -y +RUN apt-get remove -y --allow-change-held-packages \ + ibverbs-utils \ + libibverbs-dev \ + libibverbs1 \ + libmlx5-1 \ + libnccl2 \ + libnccl-dev + +RUN rm -rf /opt/hpcx \ + && rm -rf /usr/local/mpi \ + && rm -f /etc/ld.so.conf.d/hpcx.conf \ + && ldconfig + +ENV OPAL_PREFIX= + +RUN DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ + apt-utils \ + autoconf \ + automake \ + build-essential \ + check \ + cmake \ + curl \ + debhelper \ + devscripts \ + git \ + gcc \ + gdb \ + kmod \ + libsubunit-dev \ + libtool \ + openssh-client \ + openssh-server \ + pkg-config \ + python3-distutils \ + vim +RUN apt-get purge -y cuda-compat-* + +RUN mkdir -p /var/run/sshd +RUN sed -i 's/[ #]\(.*StrictHostKeyChecking \).*/ \1no/g' /etc/ssh/ssh_config && \ + echo " UserKnownHostsFile /dev/null" >> /etc/ssh/ssh_config && \ + sed -i 's/#\(StrictModes \).*/\1no/g' /etc/ssh/sshd_config + +ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:/opt/amazon/openmpi/lib:/opt/nccl/build/lib:/opt/amazon/efa/lib:/opt/aws-ofi-nccl/install/lib:/usr/local/lib:$LD_LIBRARY_PATH +ENV PATH /opt/amazon/openmpi/bin/:/opt/amazon/efa/bin:/usr/bin:/usr/local/bin:$PATH + +RUN curl https://bootstrap.pypa.io/get-pip.py -o /tmp/get-pip.py \ + && python3 /tmp/get-pip.py \ + && pip3 install awscli pynvml + +################################################# +## Install NVIDIA GDRCopy +## +## NOTE: if `nccl-tests` or `/opt/gdrcopy/bin/sanity -v` crashes with incompatible version, ensure +## that the cuda-compat-xx-x package is the latest. 
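+##
+## (Optional, an assumption rather than part of this example: once the image is built,
+## a quick check on a GPU host could be
+##   docker run --rm --gpus all colossalai:latest /opt/gdrcopy/bin/sanity -v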
+RUN git clone -b ${GDRCOPY_VERSION} https://github.com/NVIDIA/gdrcopy.git /tmp/gdrcopy \ + && cd /tmp/gdrcopy \ + && make prefix=/opt/gdrcopy install + +ENV LD_LIBRARY_PATH /opt/gdrcopy/lib:$LD_LIBRARY_PATH +ENV LIBRARY_PATH /opt/gdrcopy/lib:$LIBRARY_PATH +ENV CPATH /opt/gdrcopy/include:$CPATH +ENV PATH /opt/gdrcopy/bin:$PATH + +################################################# +## Install EFA installer +RUN cd $HOME \ + && curl -O https://efa-installer.amazonaws.com/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz \ + && tar -xf $HOME/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz \ + && cd aws-efa-installer \ + && ./efa_installer.sh -y -g -d --skip-kmod --skip-limit-conf --no-verify \ + && rm -rf $HOME/aws-efa-installer + +################################################### +## Install NCCL +RUN git clone -b ${NCCL_VERSION} https://github.com/NVIDIA/nccl.git /opt/nccl \ + && cd /opt/nccl \ + && make -j $(nproc) src.build CUDA_HOME=/usr/local/cuda \ + NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_89,code=sm_89 -gencode=arch=compute_90,code=sm_90" + +################################################### +## Install AWS-OFI-NCCL plugin +RUN DEBIAN_FRONTEND=noninteractive apt-get install -y libhwloc-dev +#Switch from sh to bash to allow parameter expansion +SHELL ["/bin/bash", "-c"] +RUN curl -OL https://github.com/aws/aws-ofi-nccl/releases/download/${AWS_OFI_NCCL_VERSION}/aws-ofi-nccl-${AWS_OFI_NCCL_VERSION//v}.tar.gz \ + && tar -xf aws-ofi-nccl-${AWS_OFI_NCCL_VERSION//v}.tar.gz \ + && cd aws-ofi-nccl-${AWS_OFI_NCCL_VERSION//v} \ + && ./configure --prefix=/opt/aws-ofi-nccl/install \ + --with-mpi=/opt/amazon/openmpi \ + --with-libfabric=/opt/amazon/efa \ + --with-cuda=/usr/local/cuda \ + --enable-platform-aws \ + && make -j $(nproc) \ + && make install \ + && cd .. \ + && rm -rf aws-ofi-nccl-${AWS_OFI_NCCL_VERSION//v} \ + && rm aws-ofi-nccl-${AWS_OFI_NCCL_VERSION//v}.tar.gz + +SHELL ["/bin/sh", "-c"] + +################################################### +## Install NCCL-tests +RUN git clone -b ${NCCL_TESTS_VERSION} https://github.com/NVIDIA/nccl-tests.git /opt/nccl-tests \ + && cd /opt/nccl-tests \ + && make -j $(nproc) \ + MPI=1 \ + MPI_HOME=/opt/amazon/openmpi/ \ + CUDA_HOME=/usr/local/cuda \ + NCCL_HOME=/opt/nccl/build \ + NVCC_GENCODE="-gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_89,code=sm_89 -gencode=arch=compute_90,code=sm_90" + +RUN rm -rf /var/lib/apt/lists/* + +## Set Open MPI variables to exclude network interface and conduit. 
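+## (^ucx rules out the UCX PML; the tcp/self BTLs plus the interface exclude list keep
+## Open MPI traffic off loopback/docker/veth devices, and NCCL_SOCKET_IFNAME does the
+## same for NCCL's bootstrap sockets. NCCL data traffic goes over EFA via the
+## aws-ofi-nccl plugin built above.)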
+ENV OMPI_MCA_pml=^ucx \
+    OMPI_MCA_btl=tcp,self \
+    OMPI_MCA_btl_tcp_if_exclude=lo,docker0,veth_def_agent \
+    OPAL_PREFIX=/opt/amazon/openmpi \
+    NCCL_SOCKET_IFNAME=^docker,lo,veth
+
+## Turn off PMIx Error https://github.com/open-mpi/ompi/issues/7516
+ENV PMIX_MCA_gds=hash
+
+## Set LD_PRELOAD for NCCL library
+ENV LD_PRELOAD /opt/nccl/build/lib/libnccl.so
+
+# Install Miniconda so we do not depend on the base image's Python
+RUN mkdir -p /opt/miniconda3 \
+    && curl -L https://repo.anaconda.com/miniconda/Miniconda3-py312_25.3.1-1-Linux-x86_64.sh -o /tmp/Miniconda3-py312_25.3.1-1-Linux-x86_64.sh \
+    && bash /tmp/Miniconda3-py312_25.3.1-1-Linux-x86_64.sh -b -f -p /opt/miniconda3 \
+    && rm /tmp/Miniconda3-py312_25.3.1-1-Linux-x86_64.sh \
+    && /opt/miniconda3/bin/conda init bash
+
+ENV PATH="/opt/miniconda3/bin:${PATH}"
+
+COPY gather_state_dict_fast.patch /tmp/gather_state_dict_fast.patch
+
+RUN git clone https://github.com/hpcaitech/ColossalAI.git /tmp/colossalai && \
+    cd /tmp/colossalai && \
+    git checkout 46ed5d856b16b074325091a88e761544b3d4f9f0 && \
+    git apply /tmp/gather_state_dict_fast.patch && \
+    # BUILD_EXT=1 FORCE_CUDA=1
+    pip install . && \
+    cd applications/ColossalChat && \
+    pip install .
+
+ENV TORCH_CUDA_ARCH_LIST="9.0a"
+
+# Pin numpy below 2.0 because of https://discuss.huggingface.co/t/valueerror-unable-to-avoid-copy-while-creating-an-array-as-requested/93584/5
+RUN pip install "numpy<2.0"
+
+# Install tensornvme from GitHub because the PyPI version is totally outdated
+RUN apt update -y && apt install -y libaio-dev && pip install -v git+https://github.com/hpcaitech/TensorNVMe.git
+
+# To use the fused RMSNorm kernel, Colossal-AI needs Apex built from source:
+RUN git clone https://github.com/NVIDIA/apex /tmp/apex && \
+    cd /tmp/apex && \
+    NVCC_APPEND_FLAGS="--threads 4" \
+    pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext --cuda_ext --parallel 8" ./
+
+# Build flash-attn
+RUN MAX_JOBS=48 pip install flash-attn==2.7.4.post1 --no-cache-dir --no-deps --no-build-isolation --verbose --force-reinstall
+
+# Install transformers==4.52.4 for better support of DeepSeek
+RUN pip install transformers==4.52.4
+
+RUN pip install math-verify==0.7.0 tqdm
diff --git a/3.test_cases/pytorch/colossalai/deepseek-lora-finetune/.gitignore b/3.test_cases/pytorch/colossalai/deepseek-lora-finetune/.gitignore
new file mode 100644
index 000000000..b9af70222
--- /dev/null
+++ b/3.test_cases/pytorch/colossalai/deepseek-lora-finetune/.gitignore
@@ -0,0 +1,2 @@
+DeepSeek-V3
+logs
\ No newline at end of file
diff --git a/3.test_cases/pytorch/colossalai/deepseek-lora-finetune/README.md b/3.test_cases/pytorch/colossalai/deepseek-lora-finetune/README.md
new file mode 100644
index 000000000..096a54114
--- /dev/null
+++ b/3.test_cases/pytorch/colossalai/deepseek-lora-finetune/README.md
@@ -0,0 +1,122 @@
+# DeepSeek V3/R1/Prover-V2 671B SFT with LoRA
+
+This example uses the Colossal-AI container from the parent directory.
+
+## Download model weights
+
+```bash
+pip install -U "huggingface_hub[cli]"
+```
+
+Choose the model you want to finetune:
+
+ - deepseek-ai/DeepSeek-V3
+ - deepseek-ai/DeepSeek-V3-0324
+ - deepseek-ai/DeepSeek-R1
+ - deepseek-ai/DeepSeek-Prover-V2-671B
+
+and define the model name environment variable, for example:
+```bash
+export MODEL_NAME="deepseek-ai/DeepSeek-R1"
+```
+
+Download the model weights from Hugging Face and find the model path:
+```bash
+huggingface-cli download $MODEL_NAME
+export MODEL_PATH=`python -c "from pathlib import Path; from huggingface_hub import hf_hub_download; print(Path(hf_hub_download('$MODEL_NAME', filename='config.json')).parent)"`
+export HF_HOME=${HF_HOME:-$(python -c "from pathlib import Path; from huggingface_hub import hf_hub_download; print(Path(hf_hub_download('$MODEL_NAME', filename='config.json')).parent.parent.parent.parent.parent)")}
+```
+
+## Convert fp8 weights to bf16
+
+Since the model weights are fp8 and SFT requires bf16 weights, we use the `inference/fp8_cast_bf16.py` script from the `DeepSeek-V3` repo to convert the weights to bf16.
+
+Clone the DeepSeek-V3 repo:
+```bash
+git clone https://github.com/deepseek-ai/DeepSeek-V3.git
+```
+Launch the conversion job on a GPU node:
+```bash
+srun \
+    --container-image ../colossalai.sqsh \
+    --container-mounts ./:/workdir,$HF_HOME:$HF_HOME \
+    python /workdir/DeepSeek-V3/inference/fp8_cast_bf16.py \
+        --input-fp8-hf-path $MODEL_PATH \
+        --output-bf16-hf-path /workdir/$MODEL_NAME-bf16
+```
+
+Copy the model config and tokenizer files to the output directory:
+```bash
+cp -L $MODEL_PATH/*.json ./$MODEL_NAME-bf16/
+cp -L $MODEL_PATH/*.py ./$MODEL_NAME-bf16/
+```
+
+## Launch LoRA finetuning
+
+```bash
+sbatch lora_finetune.sbatch $MODEL_NAME AI-MO/NuminaMath-TIR train
+```
+Check the logs:
+```bash
+tail -f -n +0 slurm-XXX.out
+```
+Example output on 15 p5en nodes (DP5, PP3, EP8):
+```
+Step: 3%|▎ | 6/224 [22:39<5:48:26, 95.90s/it, loss=0.794, grad_norm=0.161]
+Step: 5%|▌ | 12/224 [25:24<2:14:47, 37.97s/it, loss=0.506, grad_norm=0.108]
+Step: 8%|▊ | 17/224 [28:09<1:38:50, 28.65s/it, loss=0.442, grad_norm=0.124]
+Step: 10%|█ | 23/224 [30:54<1:32:21, 27.57s/it, loss=0.429, grad_norm=0.0904]
+Step: 13%|█▎ | 29/224 [33:16<1:34:26, 29.06s/it, loss=0.411, grad_norm=0.0404]
+Step: 16%|█▌ | 34/224 [36:55<1:43:15, 32.61s/it, loss=0.383, grad_norm=0.0298]
+Step: 18%|█▊ | 40/224 [40:09<1:33:32, 30.50s/it, loss=0.368, grad_norm=0.0255]
+Step: 21%|██ | 46/224 [42:27<1:22:43, 27.89s/it, loss=0.367, grad_norm=0.0252]
+Step: 23%|██▎ | 51/224 [45:13<1:19:52, 27.70s/it, loss=0.354, grad_norm=0.0262]
+Step: 25%|██▌ | 57/224 [47:31<1:16:52, 27.62s/it, loss=0.346, grad_norm=0.0232]
+Step: 28%|██▊ | 62/224 [50:16<1:14:09, 27.47s/it, loss=0.355, grad_norm=0.0211]
+Step: 30%|███ | 68/224 [52:34<1:11:46, 27.61s/it, loss=0.336, grad_norm=0.0214]
+Step: 33%|███▎ | 73/224 [55:36<1:14:31, 29.61s/it, loss=0.34, grad_norm=0.021]
+Step: 35%|███▌ | 79/224 [57:57<1:07:50, 28.01s/it, loss=0.339, grad_norm=0.0212]
+Step: 38%|███▊ | 84/224 [1:00:27<1:13:01, 31.30s/it, loss=0.325, grad_norm=0.0224]
+Step: 40%|███▉ | 89/224 [1:03:35<1:07:18, 29.92s/it, loss=0.324, grad_norm=0.0206]
+Step: 42%|████▏ | 95/224 [1:05:52<1:00:10, 27.78s/it, loss=0.338, grad_norm=0.0224]
+Step: 45%|████▍ | 100/224 [1:08:08<56:34, 27.37s/it, loss=0.325, grad_norm=0.0213]
+Step: 47%|████▋ | 105/224 [1:10:53<54:21, 27.41s/it, loss=0.318, grad_norm=0.0206]
+Step: 49%|████▉ | 110/224 [1:13:11<51:55, 27.33s/it, loss=0.342, grad_norm=0.0208]
+Step: 52%|█████▏ | 116/224 [1:15:40<51:53, 28.56s/it, loss=0.334, grad_norm=0.0214]
+Step: 54%|█████▍ | 121/224 [1:18:04<48:21, 28.62s/it, loss=0.336, grad_norm=0.02]
+Step: 56%|█████▋ | 126/224 [1:20:21<44:45, 27.40s/it, loss=0.33, grad_norm=0.0211]
+Step: 58%|█████▊ | 131/224 [1:22:38<42:28, 27.41s/it, loss=0.326, grad_norm=0.022]
+Step: 61%|██████ | 136/224 [1:25:29<46:47, 31.90s/it, loss=0.344, grad_norm=0.0233]
+Step: 63%|██████▎ | 141/224 [1:28:45<38:47, 28.05s/it, loss=0.328, grad_norm=0.0218]
+Step: 65%|██████▌ | 146/224 [1:30:29<35:42, 27.47s/it, loss=0.329, grad_norm=0.0218]
+Step: 67%|██████▋ | 
151/224 [1:32:47<33:11, 27.28s/it, loss=0.33, grad_norm=0.0208] +Step: 70%|██████▉ | 156/224 [1:35:16<32:35, 28.75s/it, loss=0.322, grad_norm=0.0216] +Step: 72%|███████▏ | 161/224 [1:37:37<30:45, 29.29s/it, loss=0.328, grad_norm=0.0238] +Step: 74%|███████▍ | 166/224 [1:39:26<26:40, 27.60s/it, loss=0.313, grad_norm=0.0236] +Step: 76%|███████▋ | 171/224 [1:41:43<24:12, 27.40s/it, loss=0.337, grad_norm=0.0435] +Step: 79%|███████▊ | 176/224 [1:44:00<21:54, 27.39s/it, loss=0.328, grad_norm=0.0222] +Step: 81%|████████ | 181/224 [1:46:24<21:01, 29.35s/it, loss=0.332, grad_norm=0.0226] +Step: 83%|████████▎ | 186/224 [1:49:02<18:43, 29.57s/it, loss=0.329, grad_norm=0.0215] +Step: 85%|████████▌ | 191/224 [1:51:18<15:45, 27.81s/it, loss=0.325, grad_norm=0.0217] +Step: 88%|████████▋ | 195/224 [1:53:42<14:02, 29.06s/it, loss=0.331, grad_norm=0.0221] +Step: 89%|████████▉ | 200/224 [1:55:59<11:04, 27.69s/it, loss=0.32, grad_norm=0.0207] +Step: 92%|█████████▏| 205/224 [1:58:00<09:17, 29.32s/it, loss=0.311, grad_norm=0.0224] +Step: 94%|█████████▍| 210/224 [2:00:16<06:27, 27.64s/it, loss=0.327, grad_norm=0.023] +Step: 96%|█████████▌| 214/224 [2:02:34<04:35, 27.57s/it, loss=0.315, grad_norm=0.0273] +Step: 98%|█████████▊| 219/224 [2:04:50<02:16, 27.38s/it, loss=0.348, grad_norm=0.0217] +Step: 100%|██████████| 224/224 [2:06:39<00:00, 27.30s/it, loss=0.325, grad_norm=0.0236] + +Start saving final model checkpoint to /workdir/deepseek-ai/DeepSeek-R1-bf16-lora +Saved final model checkpoint at epoch 0 at folder /workdir/deepseek-ai/DeepSeek-R1-bf16-lora in 63.06 seconds + +``` + +## Launch LoRA evaluation + +```bash +srun \ + --mpi=pmix --cpu-bind=none \ + --container-image ../colossalai.sqsh \ + --container-mounts ./:/workdir,$HF_HOME:$HF_HOME \ + python /workdir/lora_eval.py -m deepseek-ai/DeepSeek-R1 -d AI-MO/NuminaMath-TIR +``` diff --git a/3.test_cases/pytorch/colossalai/deepseek-lora-finetune/lora_eval.py b/3.test_cases/pytorch/colossalai/deepseek-lora-finetune/lora_eval.py new file mode 100644 index 000000000..d8150fcd0 --- /dev/null +++ b/3.test_cases/pytorch/colossalai/deepseek-lora-finetune/lora_eval.py @@ -0,0 +1,118 @@ +import argparse +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache +from peft import PeftModel, LoraModel +from datasets import load_dataset +from coati.dataset.loader import apply_chat_template_and_mask +from typing import Optional +from math_verify import parse, verify +from tqdm import tqdm + +# https://github.com/huggingface/transformers/issues/38710 +class DynamicCacheWithGetMaxLength(DynamicCache): + def get_max_length(self) -> Optional[int]: + return self.get_max_cache_shape() + +def eval(args): + ###### + # How to Load lora Model + ###### + # 1.Load base model + base_model = AutoModelForCausalLM.from_pretrained( + args.pretrained, + torch_dtype=torch.bfloat16, + device_map="auto", + trust_remote_code=True + ) + + # 2.Load lora model + if args.lora_adapter is not None: + peft_model: LoraModel = PeftModel.from_pretrained( + base_model, + args.lora_adapter, + torch_dtype=torch.bfloat16 + ) + + # 3.Merge lora model + merged_model = peft_model.merge_and_unload() + else: + merged_model = base_model + + # 4.Load tokenizer + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained, + trust_remote_code=True, + pad_token="<|endoftext|>" + ) + + # 5.Save merged lora model + if args.merged_model_path is not None: + merged_model.save_pretrained( + args.merged_model_path, + safe_serialization=True + ) + 
tokenizer.save_pretrained(args.merged_model_path) + + print(f"Load dataset: {args.dataset}") + dataset = load_dataset(args.dataset, split=args.dataset_split) + + all = correct = 0 + for sample in tqdm(dataset): + problem = sample["problem"] + print(f"{problem=}") + + solution = sample["solution"] + print(f"{solution=}") + + inputs = tokenizer.apply_chat_template([{"role": "user", "content": problem}], tokenize=True, return_dict=True, return_tensors="pt") + inputs = inputs.to(merged_model.device) + print(f"{inputs=}") + + outputs = merged_model.generate(**inputs, max_new_tokens=2048, past_key_values=DynamicCacheWithGetMaxLength()) + print(f"{outputs=}") + + completion = tokenizer.decode(outputs[0], skip_special_tokens=True) + print(f"{completion=}") + + parsed_gold_solution = parse(solution) + parsed_completion = parse(completion) + result = verify(parsed_gold_solution, parsed_completion) + print(f"{parsed_gold_solution=}") + print(f"{parsed_completion=}") + print(f"{result=}") + all += 1 + if result: + correct += 1 + print("="*100) + + print(f"{all=}") + print(f"{correct=}") + print(f"{correct/all=}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Basic evaluation information. + parser.add_argument( + "-m", + "--pretrained", + type=str, + required=True, + help="Path or name of the pre-trained model", + ) + parser.add_argument( + "--lora_adapter", + type=str, + required=False, + help="Path of the LoRA adapter", + ) + parser.add_argument( + "--merged_model_path", + type=str, + required=False, + help="Path to save the merged model", + ) + parser.add_argument("-d", "--dataset", type=str, required=False, help="Dataset for training.") + parser.add_argument("--dataset_split", type=str, default="test", help="Dataset split to use.") + args = parser.parse_args() + eval(args) diff --git a/3.test_cases/pytorch/colossalai/deepseek-lora-finetune/lora_finetune.py b/3.test_cases/pytorch/colossalai/deepseek-lora-finetune/lora_finetune.py new file mode 100644 index 000000000..8f2b1a3f5 --- /dev/null +++ b/3.test_cases/pytorch/colossalai/deepseek-lora-finetune/lora_finetune.py @@ -0,0 +1,532 @@ +#!/usr/bin/env python3 + +# Copyright ColossalAI +# Modified from https://github.com/hpcaitech/ColossalAI/blob/main/applications/ColossalChat/examples/training_scripts/lora_finetune.py + +# -*- coding: utf-8 -*- +""" +Supervised fine-tuning of MoE models like Deepseek V3/R1 on a downstream task. 
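+
+A multi-node launch mirroring lora_finetune.sbatch looks roughly like this (a sketch;
+the checkpoint path and rendezvous endpoint are placeholders):
+
+    torchrun --nproc_per_node=8 --nnodes=$SLURM_JOB_NUM_NODES \
+        --rdzv_backend=c10d --rdzv_endpoint=$HEAD_NODE_IP \
+        lora_finetune.py -m ./DeepSeek-R1-bf16 -d AI-MO/NuminaMath-TIR \
+        --dataset_split train --plugin moe --ep 8 --pp 3 --batch_size 8 \
+        --lora_rank 8 --lora_alpha 16 -g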
+""" + +import argparse +import json +import os +import resource +import time +from contextlib import nullcontext +from types import MethodType +import torch.distributed.distributed_c10d as d10d +from datetime import timedelta +from typing import Dict, List, Optional + +import torch +import torch.distributed as dist +from coati.dataset.loader import apply_chat_template_and_mask +from peft import LoraConfig +from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer +from datasets import load_dataset + +import colossalai +from colossalai.accelerator import get_accelerator +from colossalai.booster import Booster +from colossalai.booster.plugin import ( + GeminiPlugin, + HybridParallelPlugin, + LowLevelZeroPlugin, + MoeHybridParallelPlugin, + Plugin, + TorchDDPPlugin, +) +from colossalai.cluster import DistCoordinator +from colossalai.lazy import LazyInitContext +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device + + +def all_reduce_mean(loss: torch.Tensor, plugin: Plugin) -> torch.Tensor: + loss = loss.data + group = getattr(plugin, "dp_group", None) + dist.all_reduce(loss, group=group) + return loss / dist.get_world_size(group) + + +def apply_chat_template( + tokenizer: PreTrainedTokenizer, + chat: List[Dict[str, str]], +): + tokens = [] + assistant_mask = [] + for i, msg in enumerate(chat): + msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True) + # remove unexpected bos token + if i > 0 and msg_tokens[0] == tokenizer.bos_token_id: + msg_tokens = msg_tokens[1:] + tokens.extend(msg_tokens) + if msg["role"] == "assistant": + assistant_mask.extend([True] * len(msg_tokens)) + else: + assistant_mask.extend([False] * len(msg_tokens)) + return { + "tokens": tokens, + "assistant_mask": assistant_mask, + } + + +def pad( + tokenizer: PreTrainedTokenizer, + tokens: List[int], + assistant_mask: List[bool], + max_length: Optional[int] = None, + padding: bool = True, + truncation: bool = True, + ignore_idx: int = -100, +) -> Dict[str, torch.Tensor]: + attention_mask = [1] * len(tokens) + if max_length is not None: + if padding and len(tokens) < max_length: + to_pad = max_length - len(tokens) + if tokenizer.padding_side == "right": + tokens.extend([tokenizer.pad_token_id] * to_pad) + assistant_mask.extend([False] * to_pad) + attention_mask.extend([0] * to_pad) + else: + tokens = [tokenizer.pad_token_id] * to_pad + tokens + assistant_mask = [False] * to_pad + assistant_mask + attention_mask = [0] * to_pad + attention_mask + if truncation and len(tokens) > max_length: + tokens = tokens[:max_length] + assistant_mask = assistant_mask[:max_length] + attention_mask = attention_mask[:max_length] + input_ids = torch.tensor(tokens, dtype=torch.long) + attention_mask = torch.tensor(attention_mask, dtype=torch.long) + labels = input_ids.clone() + labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + } + + +def train(args) -> None: + # ============================== + # Initialize Distributed Training + # ============================== + d10d.default_pg_nccl_timeout = timedelta(hours=1) + colossalai.launch_from_torch() + accelerator = get_accelerator() + coordinator = DistCoordinator() + + # ============================== + # Initialize Booster + # ============================== + if args.plugin == "ddp": + plugin = 
TorchDDPPlugin(find_unused_parameters=True if args.use_grad_checkpoint is False else False) + elif args.plugin == "gemini": + plugin = GeminiPlugin( + precision=args.mixed_precision, + initial_scale=2**16, + max_norm=args.grad_clip, + enable_gradient_accumulation=(args.accumulation_steps > 1), + enable_fused_normalization=get_accelerator().is_available(), + enable_flash_attention=args.use_flash_attn, + ) + elif args.plugin == "gemini_auto": + plugin = GeminiPlugin( + precision=args.mixed_precision, + placement_policy="auto", + initial_scale=2**16, + max_norm=args.grad_clip, + enable_gradient_accumulation=(args.accumulation_steps > 1), + enable_fused_normalization=get_accelerator().is_available(), + enable_flash_attention=args.use_flash_attn, + ) + elif args.plugin == "zero2": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + max_norm=args.grad_clip, + ) + elif args.plugin == "zero2_cpu": + plugin = LowLevelZeroPlugin( + stage=2, + precision=args.mixed_precision, + initial_scale=2**16, + cpu_offload=True, + max_norm=args.grad_clip, + ) + elif args.plugin == "3d": + plugin = HybridParallelPlugin( + tp_size=args.tp, + pp_size=args.pp, + sp_size=args.sp, + sequence_parallelism_mode=args.sp_mode, + zero_stage=args.zero_stage, + enable_flash_attention=args.use_flash_attn, + enable_fused_normalization=get_accelerator().is_available(), + enable_sequence_parallelism=args.enable_sequence_parallelism, + cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False, + max_norm=args.grad_clip, + precision=args.mixed_precision, + microbatch_size=args.microbatch_size, + ) + elif args.plugin == "moe": + plugin = MoeHybridParallelPlugin( + ep_size=args.ep, + tp_size=args.tp, + pp_size=args.pp, + zero_stage=args.zero_stage, + sp_size=args.sp, + sequence_parallelism_mode=args.sp_mode, + enable_sequence_parallelism=args.sp > 1, + enable_fused_normalization=get_accelerator().is_available(), + enable_flash_attention=args.use_flash_attn, + max_norm=args.grad_clip, + precision=args.mixed_precision, + microbatch_size=args.microbatch_size, + ) + else: + raise ValueError(f"Unknown plugin {args.plugin}") + + booster = Booster(plugin=plugin) + + def is_master(): + if isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1: + return coordinator.rank == coordinator.world_size - 1 + return coordinator.is_master() + + # ============================== + # Initialize Tensorboard and Save Config + # ============================== + if is_master(): + if args.tensorboard_dir is not None: + from torch.utils.tensorboard import SummaryWriter + + os.makedirs(args.tensorboard_dir, exist_ok=True) + writer = SummaryWriter(args.tensorboard_dir) + + with open(args.config_file, "w") as f: + json.dump(args.__dict__, f, indent=4) + + # ====================================================== + # Initialize Tokenizer, Dataset, Collator and Dataloader + # ====================================================== + tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True) + + coordinator.print_on_master( + f"Training Info:\nConfig file: {args.config_file} \nTensorboard logs: {args.tensorboard_dir} \nModel checkpoint: {args.save_dir}" + ) + + coordinator.print_on_master(f"Load dataset: {args.dataset} split: {args.dataset_split}") + dataset = load_dataset(args.dataset, split=args.dataset_split) + # dataset = dataset.map(lambda sample: apply_chat_template_and_mask(tokenizer, sample["messages"], args.max_length), remove_columns=['problem', 'solution', 
'messages']) + dataset = dataset.map(lambda sample: apply_chat_template(tokenizer, sample["messages"]), remove_columns=['problem', 'solution', 'messages']) + dataset = dataset.filter(lambda sample: len(sample["tokens"]) <= args.max_length) + dataset = dataset.map(lambda sample: pad(tokenizer, **sample, max_length=args.max_length)) + dataset.set_format(type="torch", columns=["input_ids", "labels", "attention_mask"]) + + dataloader = plugin.prepare_dataloader( + dataset=dataset, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + ) + + coordinator.print_on_master( + f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) + + # ====================================================== + # Initialize Model, Objective, Optimizer and LR Scheduler + # ====================================================== + # When training the ChatGLM model, LoRA and gradient checkpointing are incompatible. + init_ctx = ( + LazyInitContext(default_device=get_current_device()) + if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) + else nullcontext() + ) + attn_impl = "eager" if get_accelerator().name == "npu" else "flash_attention_2" + + config = AutoConfig.from_pretrained(args.pretrained, trust_remote_code=True) + + with init_ctx: + # from_pretrained is not compatible with LoRA, we load pretrained weights later. + # model = AutoModelForCausalLM.from_pretrained( + # args.pretrained, + # torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + # trust_remote_code=True, + # attn_implementation=attn_impl, + # ) + model = AutoModelForCausalLM.from_config( + config, + trust_remote_code=True, + attn_implementation=attn_impl, + torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16, + ) + + if args.lora_rank > 0: + if model.__class__.__name__.startswith("DeepseekV3"): + lora_config = LoraConfig( + task_type="CAUSAL_LM", + r=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=["gate_proj", "up_proj", "down_proj"], + ) + else: + lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=args.lora_alpha) + model = booster.enable_lora(model, lora_config=lora_config) + + # this is essential, otherwise the grad checkpoint will not work. + model.train() + + if args.use_grad_checkpoint: + model.gradient_checkpointing_enable() + coordinator.print_on_master(msg="Gradient checkpointing enabled successfully") + if model.config.__class__.__name__.startswith("DeepseekV3"): + model.config.use_cache = False + model.eval() + # enable grad for moe layers + for m in model.modules(): + if m.__class__.__name__ == "DeepseekV3MoE": + m.moe_infer = MethodType(m.moe_infer.__wrapped__, m) + + model_numel = sum(p.numel() for p in model.parameters()) + coordinator.print_on_master(f"Model params: {model_numel / 1e9:.2f} B") + + optimizer = HybridAdam( + model_params=model.parameters(), + lr=args.lr, + betas=(0.9, 0.95), + weight_decay=args.weight_decay, + adamw_mode=True, + ) + + if args.warmup_steps is None: + args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps)) + coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}") + + lr_scheduler = CosineAnnealingWarmupLR( + optimizer=optimizer, + total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps), + warmup_steps=args.warmup_steps, + eta_min=0.1 * args.lr, + ) + + # Flash attention will be disabled because it does NOT support fp32. 
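+    # The default dtype is switched to the training precision only around booster.boost()
+    # below, so tensors materialized while the plugin wraps and shards the model are
+    # created in bf16/fp16; it is reset to fp32 right after boosting.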
+ default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16 + torch.set_default_dtype(default_dtype) + model, optimizer, _, dataloader, lr_scheduler = booster.boost( + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + dataloader=dataloader, + ) + + torch.set_default_dtype(torch.float) + coordinator.print_on_master("Model loading started") + start_time = time.time() + booster.load_model(model, args.pretrained, low_cpu_mem_mode=False, num_threads=8) + coordinator.print_on_master(f"Model loaded in {time.time() - start_time:.2f} seconds") + + coordinator.print_on_master( + f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB" + ) + coordinator.print_on_master( + f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB" + ) + + start_epoch = 0 + start_step = 0 + + num_steps_per_epoch = len(dataloader) // args.accumulation_steps + + for epoch in range(start_epoch, args.num_epochs): + dataloader.sampler.set_epoch(epoch=epoch) + if isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1: + data_iter = iter(dataloader) + step_bar = tqdm( + range(len(dataloader)), + desc="Step", + disable=not is_master(), + ) + for step in step_bar: + outputs = booster.execute_pipeline( + data_iter, + model, + criterion=lambda outputs, inputs: outputs[0], + optimizer=optimizer, + return_loss=True, + ) + loss = outputs["loss"] + if booster.plugin.stage_manager.is_last_stage(): + global_loss = all_reduce_mean(loss, plugin) + + optimizer.step() + + if booster.plugin.stage_manager.is_last_stage(): + grad_norm = optimizer.get_grad_norm() + step_bar.set_postfix({"loss": global_loss.item(), "grad_norm": grad_norm}) + + if args.tensorboard_dir is not None and is_master(): + global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps + writer.add_scalar(tag="Loss", scalar_value=global_loss.item(), global_step=global_step) + writer.add_scalar( + tag="Learning Rate", + scalar_value=lr_scheduler.get_last_lr()[0], + global_step=global_step, + ) + writer.add_scalar(tag="Grad Norm", scalar_value=grad_norm, global_step=global_step) + + lr_scheduler.step() + optimizer.zero_grad() + + else: + pbar = tqdm( + dataloader, + desc=f"Epoch {epoch}", + disable=not is_master(), + initial=start_step // args.accumulation_steps, + ) + total_loss = torch.tensor(0.0, device=get_current_device()) + for step, batch in enumerate(pbar, start=start_step // args.accumulation_steps): + batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)} + + batch_output = model(**batch) + + loss = batch_output.loss / args.accumulation_steps + total_loss.add_(loss.data) + + booster.backward(loss=loss, optimizer=optimizer) + + if (step + 1) % args.accumulation_steps == 0: + all_reduce_mean(total_loss, plugin) + + optimizer.step() + + grad_norm = optimizer.get_grad_norm() + pbar.set_postfix({"loss": total_loss.item(), "grad_norm": grad_norm}) + if args.tensorboard_dir is not None and is_master(): + global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps + writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step) + writer.add_scalar( + tag="Learning Rate", + scalar_value=lr_scheduler.get_last_lr()[0], + global_step=global_step, + ) + writer.add_scalar(tag="Grad Norm", scalar_value=grad_norm, global_step=global_step) + + lr_scheduler.step() + optimizer.zero_grad() + + total_loss.fill_(0.0) + + # Delete cache. 
+ # del batch, batch_labels, batch_output, loss + accelerator.empty_cache() + + # Final save. + coordinator.print_on_master(f"Start saving final model checkpoint to {args.save_dir}") + start_time = time.time() + if args.lora_rank > 0: + booster.save_lora_as_pretrained(model, os.path.join(args.save_dir, "lora")) + else: + booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True) + coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir} in {time.time() - start_time:.2f} seconds") + + coordinator.print_on_master(f"Max device memory usage: {accelerator.max_memory_allocated()/1024**2:.2f} MB") + + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Basic training information. + parser.add_argument( + "-m", + "--pretrained", + type=str, + required=True, + help="Path or name of the pre-trained model", + ) + parser.add_argument("-d", "--dataset", type=str, required=True, help="Dataset for training.") + parser.add_argument("--dataset_split", type=str, default="train", help="Dataset split to use.") + parser.add_argument( + "-p", + "--plugin", + type=str, + default="zero2", + choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp", "moe"], + help="Choose which plugin to use", + ) + parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory") + parser.add_argument("--tensorboard_dir", type=str, default=None, help="Tensorboard directory") + parser.add_argument("--config_file", type=str, default="training_config.json", help="Config file") + # Training parameters + parser.add_argument("-n", "--num_epochs", type=int, default=1, help="Number of training epochs") + parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps") + parser.add_argument("--batch_size", type=int, default=2, help="Global Batch size of each process") + parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate") + parser.add_argument("--max_length", type=int, default=8192, help="Model max length") + parser.add_argument( + "--mixed_precision", + type=str, + default="bf16", + choices=["fp16", "bf16"], + help="Mixed precision", + ) + parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value") + parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay") + parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps") + parser.add_argument( + "-g", + "--use_grad_checkpoint", + action="store_true", + default=False, + help="Use gradient checkpointing", + ) + parser.add_argument( + "-f", + "--use_flash_attn", + action="store_true", + default=False, + help="Use flash-attention", + ) + + # Additional arguments for 3d plugin. 
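+    # (The "moe" plugin consumes these sizes as well; --ep is only used by "moe". With
+    # pp > 1, gradient accumulation must stay at 1, enforced at the bottom of this file.)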
+ parser.add_argument("--tp", type=int, default=1, help="TP size, used for 3d plugin.") + parser.add_argument("--pp", type=int, default=1, help="PP size, used for 3d plugin.") + parser.add_argument("--sp", type=int, default=1, help="SP size, used for 3d plugin.") + parser.add_argument("--ep", type=int, default=1, help="EP size, used for moe plugin.") + parser.add_argument("--zero_stage", type=int, default=1, help="Zero stage, used for 3d plugin.", choices=[0, 1, 2]) + parser.add_argument( + "--sp_mode", + type=str, + default="split_gather", + choices=["split_gather", "ring", "all_to_all"], + help="SP mode, used for 3d plugin.", + ) + parser.add_argument( + "--enable_sequence_parallelism", + default=False, + action="store_true", + help="Whether to enable SP, used for 3d plugin.", + ) + parser.add_argument( + "--zero_cpu_offload", default=False, action="store_true", help="Whether to use offloading, used for 3d plugin." + ) + parser.add_argument( + "--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin." + ) + parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.") + parser.add_argument("--lora_alpha", type=int, default=8, help="lora alpha when using lora to train.") + + args = parser.parse_args() + + if args.plugin in ["3d", "moe"] and args.pp > 1 and args.accumulation_steps > 1: + raise ValueError("Accumulation steps should be 1 when using PP. Please adjust batch size directly.") + + train(args) diff --git a/3.test_cases/pytorch/colossalai/deepseek-lora-finetune/lora_finetune.sbatch b/3.test_cases/pytorch/colossalai/deepseek-lora-finetune/lora_finetune.sbatch new file mode 100644 index 000000000..c6f5e7eaa --- /dev/null +++ b/3.test_cases/pytorch/colossalai/deepseek-lora-finetune/lora_finetune.sbatch @@ -0,0 +1,72 @@ +#!/bin/bash + +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 + +#SBATCH --nodes=15 +#SBATCH --job-name=lora_ft +##SBATCH --output=logs/%x_%j.out +##SBATCH --error=logs/%x_%j.err +#SBATCH --exclusive + +set -ex; + +########################### +###### User Variables ##### +########################### + +GPUS_PER_NODE=8 + +########################### +## Environment Variables ## +########################### + +## Plenty of EFA level variables +## For G4dn and other G5, comment out all +#export FI_LOG_LEVEL=warn +# export NCCL_DEBUG=INFO +export FI_PROVIDER=efa +export FI_EFA_USE_HUGE_PAGE=0 # Set to 0 when you see os.fork() causes OSError: Cannot allocate memory. Disabling huge page causes minor performance hit. 
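+## To confirm traffic goes over EFA, uncomment NCCL_DEBUG=INFO above and check the
+## NET/OFI lines in the job log for the libfabric provider that was selected.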
+## Switching SYNC_MEMOPS to zero can boost throughput with FSDP +## Disables CU_POINTER_ATTRIBUTE_SYNC_MEMOPS +## Reduces memory synchronizations +## https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__UNIFIED.html +export FI_EFA_SET_CUDA_SYNC_MEMOPS=0 + +MODEL_NAME=$1 +MODEL_PATH=/workdir/$MODEL_NAME-bf16 +DATASET=$2 +DATASET_SPLIT=$3 +LOG_DIR=/workdir/$MODEL_NAME-bf16-logs +LORA_DIR=/workdir/$MODEL_NAME-bf16-lora + +export PYTHONFAULTHANDLER=1 +export OMP_NUM_THREADS=8 + +srun -l \ + --mpi=pmix --cpu-bind=none \ + --container-image ../colossalai.sqsh \ + --container-mounts ./:/workdir,$HF_HOME:$HF_HOME \ + torchrun \ + --nproc_per_node=$GPUS_PER_NODE \ + --nnodes=$SLURM_JOB_NUM_NODES \ + --rdzv_id=$SLURM_JOB_ID \ + --rdzv_backend=c10d \ + --rdzv_endpoint=$(hostname) \ + /workdir/lora_finetune.py \ + --pretrained $MODEL_PATH \ + --dataset $DATASET \ + --dataset_split $DATASET_SPLIT \ + --plugin moe \ + --lr 1e-4 \ + --max_length 2048 \ + -g \ + --ep 8 \ + --pp 3 \ + --batch_size 8 \ + --lora_rank 8 \ + --lora_alpha 16 \ + --num_epochs 1 \ + --warmup_steps 8 \ + --tensorboard_dir $LOG_DIR \ + --save_dir $LORA_DIR diff --git a/3.test_cases/pytorch/colossalai/gather_state_dict_fast.patch b/3.test_cases/pytorch/colossalai/gather_state_dict_fast.patch new file mode 100644 index 000000000..927293674 --- /dev/null +++ b/3.test_cases/pytorch/colossalai/gather_state_dict_fast.patch @@ -0,0 +1,30 @@ +diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py +index 4b36dbe0..4fae11c0 100644 +--- a/colossalai/checkpoint_io/utils.py ++++ b/colossalai/checkpoint_io/utils.py +@@ -1132,18 +1132,20 @@ def gather_state_dict_fast( + if rank == dst: + returned_state_dict = state_dict.copy() + dist.gather_object(metadata, all_meta_data, dst=dist.get_global_rank(group, rank), group=group) ++ ks, ops = [], [] + for i, target_metadata in enumerate(all_meta_data): + if i == dst: + continue +- ops = [] + for k, shape, dtype in target_metadata: + buffer = torch.empty(shape, dtype=dtype, device=get_current_device()) + returned_state_dict[k] = buffer ++ ks.append(k) + ops.append(dist.P2POp(dist.irecv, buffer, dist.get_global_rank(group, i), group)) +- reqs = dist.batch_isend_irecv(ops) +- for req, (k, *_) in zip(reqs, target_metadata): +- req.wait() +- returned_state_dict[k] = returned_state_dict[k].to(device) ++ reqs = dist.batch_isend_irecv(ops) ++ for req in reqs: ++ req.wait() ++ for k in ks: ++ returned_state_dict[k] = returned_state_dict[k].to(device) + return returned_state_dict + else: + dist.gather_object(metadata, dst=dist.get_global_rank(group, dst), group=group)