diff --git a/.github/workflows/cpu-torch-latest.yml b/.github/workflows/cpu-torch-latest.yml index 6496d7e35065..75507da4b347 100644 --- a/.github/workflows/cpu-torch-latest.yml +++ b/.github/workflows/cpu-torch-latest.yml @@ -42,7 +42,7 @@ jobs: git clone https://github.com/huggingface/transformers cd transformers # if needed switch to the last known good SHA until transformers@master is fixed - git checkout 981c276 + # git checkout 981c276 git rev-parse --short HEAD pip install . @@ -59,5 +59,5 @@ jobs: run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests - HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.6" - HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.6" + HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.7" + HF_HOME=/tmp/hf_home/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.7" diff --git a/.github/workflows/hpu-gaudi2-nightly.yml b/.github/workflows/hpu-gaudi2-nightly.yml index c0576360cd61..3490567f95b8 100644 --- a/.github/workflows/hpu-gaudi2-nightly.yml +++ b/.github/workflows/hpu-gaudi2-nightly.yml @@ -21,7 +21,7 @@ jobs: # The type of runner that the job will run on runs-on: [self-hosted, intel, gaudi2] container: - image: vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest + image: vault.habana.ai/gaudi-docker/1.20.0/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest ports: - 80 options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice @@ -45,6 +45,8 @@ jobs: test_zero_leaf_module.py test_zero_offloadpp.py test_zero_tiled.py + test_autotp_training.py + test_ulysses.py # Steps represent a sequence of tasks that will be executed as part of the job steps: diff --git a/.github/workflows/hpu-gaudi2.yml b/.github/workflows/hpu-gaudi2.yml index 48730442686c..441b254b4762 100644 --- a/.github/workflows/hpu-gaudi2.yml +++ b/.github/workflows/hpu-gaudi2.yml @@ -39,7 +39,7 @@ jobs: # The type of runner that the job will run on runs-on: [self-hosted, intel, gaudi2] container: - image: vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest + image: vault.habana.ai/gaudi-docker/1.20.0/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest ports: - 80 options: --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice @@ -94,6 +94,8 @@ jobs: test_zero_nesting_init.py test_zeropp.py (test_zero.py and (TestZero3ParamPartitioningLargeParam or TestZero3ParamPartitioningLargeParam)) + (test_linear.py and (TestLoRALinear or TestBasicLinear)) + (test_ctx.py and TestEngine) # Steps represent a sequence of tasks that will be executed as part of the job steps: @@ -112,7 +114,7 @@ jobs: git clone https://github.com/huggingface/transformers cd transformers # if needed switch to the last known good SHA until transformers@master is fixed - git checkout 981c276 + # git checkout 981c276 git rev-parse --short HEAD pip install . diff --git a/.github/workflows/nv-a6000.yml b/.github/workflows/nv-a6000.yml index 77a2661d08a6..fa21785ae188 100644 --- a/.github/workflows/nv-a6000.yml +++ b/.github/workflows/nv-a6000.yml @@ -23,7 +23,7 @@ jobs: unit-tests: runs-on: [self-hosted, nvidia, a6000] container: - image: nvcr.io/nvidia/pytorch:24.09-py3 + image: nvcr.io/nvidia/pytorch:24.12-py3 ports: - 80 options: --gpus all --shm-size "8G" @@ -43,7 +43,7 @@ jobs: git clone https://github.com/huggingface/transformers cd transformers # if you need to use an older transformers version temporarily in case of breakage - git checkout 981c276 + # git checkout 981c276 git rev-parse --short HEAD python -m pip install . - name: Install deepspeed @@ -58,8 +58,8 @@ jobs: run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests - python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2' unit/ --torch_ver="2.5" --cuda_ver="12" - python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2_ops' unit/ --torch_ver="2.5" --cuda_ver="12" + python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2' unit/ --torch_ver="2.6" --cuda_ver="12" + python -m pytest --color=yes --durations=0 --verbose -rF -m 'inference_v2_ops' unit/ --torch_ver="2.6" --cuda_ver="12" - name: MII unit tests run: | BRANCH="main" diff --git a/.github/workflows/nv-flash-attn.yml b/.github/workflows/nv-flash-attn.yml index 591969fbd986..8b3d46dfa4a8 100644 --- a/.github/workflows/nv-flash-attn.yml +++ b/.github/workflows/nv-flash-attn.yml @@ -18,7 +18,7 @@ jobs: unit-tests: runs-on: [self-hosted, nvidia, a6000] container: - image: nvcr.io/nvidia/pytorch:24.09-py3 + image: nvcr.io/nvidia/pytorch:24.12-py3 ports: - 80 options: --gpus all --shm-size "8G" @@ -53,7 +53,7 @@ jobs: run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests - python -m pytest --color=yes --durations=0 --verbose -rF unit/sequence_parallelism/test_ulysses.py --torch_ver="2.5" --cuda_ver="12" + python -m pytest --color=yes --durations=0 --verbose -rF unit/sequence_parallelism/test_ulysses.py --torch_ver="2.6" --cuda_ver="12" - name: Open GitHub issue if nightly CI fails if: ${{ failure() && (github.event_name == 'schedule') }} uses: JasonEtco/create-an-issue@v2 diff --git a/.github/workflows/nv-human-eval.yml b/.github/workflows/nv-human-eval.yml index 3f59c42f697e..56cbfa767126 100644 --- a/.github/workflows/nv-human-eval.yml +++ b/.github/workflows/nv-human-eval.yml @@ -11,7 +11,7 @@ jobs: unit-tests: runs-on: [self-hosted, nvidia, a6000] container: - image: nvcr.io/nvidia/pytorch:24.09-py3 + image: nvcr.io/nvidia/pytorch:24.12-py3 ports: - 80 options: --gpus all --shm-size "8G" @@ -50,4 +50,4 @@ jobs: run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch cd tests - python -m pytest --color=yes --durations=0 --verbose -rF -m 'evaluation' -k "test_human_eval" unit/ --torch_ver="2.5" --cuda_ver="12" + python -m pytest --color=yes --durations=0 --verbose -rF -m 'evaluation' -k "test_human_eval" unit/ --torch_ver="2.6" --cuda_ver="12" diff --git a/.github/workflows/nv-pre-compile-ops.yml b/.github/workflows/nv-pre-compile-ops.yml index fc810bc190d0..53e2aad85a6b 100644 --- a/.github/workflows/nv-pre-compile-ops.yml +++ b/.github/workflows/nv-pre-compile-ops.yml @@ -36,7 +36,7 @@ jobs: #python -c "import torch; print('CUDA available:', torch.cuda.is_available())" - name: Compile DeepSpeed Ops run: | - DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_GDS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install . + DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_GDS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 DS_BUILD_DEEP_COMPILE=0 pip3 install . - name: DS Report run: | ds_report diff --git a/.github/workflows/nv-torch-latest-v100.yml b/.github/workflows/nv-torch-latest-v100.yml index eba35ba7210a..ab6da53acf23 100644 --- a/.github/workflows/nv-torch-latest-v100.yml +++ b/.github/workflows/nv-torch-latest-v100.yml @@ -44,7 +44,7 @@ jobs: - name: Install deepspeed run: | - pip install .[dev,1bit,autotuning] + pip install .[dev,1bit,autotuning,deepcompile] ds_report - name: Python environment diff --git a/.github/workflows/nv-torch-nightly-v100.yml b/.github/workflows/nv-torch-nightly-v100.yml index 0013ed3f276f..34ac3e5ba514 100644 --- a/.github/workflows/nv-torch-nightly-v100.yml +++ b/.github/workflows/nv-torch-nightly-v100.yml @@ -37,7 +37,7 @@ jobs: git clone https://github.com/huggingface/transformers cd transformers # if needed switch to the last known good SHA until transformers@master is fixed - git checkout 981c276 + # git checkout 981c276 git rev-parse --short HEAD pip install . diff --git a/.github/workflows/setup-venv/action.yml b/.github/workflows/setup-venv/action.yml index 9a88e0651860..af7913290b71 100644 --- a/.github/workflows/setup-venv/action.yml +++ b/.github/workflows/setup-venv/action.yml @@ -6,7 +6,9 @@ runs: - id: update-env run: | sudo apt-get update - sudo apt-get install -y libaio-dev + # Temporary disable nvme UTs + # sudo apt-get install -y libaio-dev + sudo apt remove -y libaio-dev python -m pip install --user --upgrade pip python -m pip install --user --upgrade virtualenv shell: bash diff --git a/.github/workflows/xpu-max1100.yml b/.github/workflows/xpu-max1100.yml index 2d84f8f60571..b78091cfaec1 100644 --- a/.github/workflows/xpu-max1100.yml +++ b/.github/workflows/xpu-max1100.yml @@ -36,7 +36,7 @@ jobs: unit-tests: runs-on: [self-hosted, intel, xpu] container: - image: intel/oneapi-basekit:2025.0.1-0-devel-ubuntu24.04 + image: intel/oneapi-basekit:2025.0.2-0-devel-ubuntu22.04 ports: - 80 options: --privileged -it --rm --device /dev/dri:/dev/dri -v /dev/dri/by-path:/dev/dri/by-path --ipc=host --cap-add=ALL @@ -47,20 +47,16 @@ jobs: shell: bash run: | apt-get update - apt-get install clinfo libaio-dev python3-pip python3.12-venv -y - python3 -m venv ~/ds_env - source ~/ds_env/bin/activate - pip install torch==2.5.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/torch/ - pip install intel-extension-for-pytorch==2.5.10+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/intel-extension-for-pytorch/ - pip install oneccl_bind_pt==2.5.0+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/oneccl-bind-pt/ - pip install torchvision==0.20.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/torchvision/ - pip install py-cpuinfo numpy + apt-get install -y python3.11 python3.11-dev python3-pip clinfo libaio-dev + pip install --upgrade pip + pip install py-cpuinfo + pip install torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/xpu + pip install intel-extension-for-pytorch==2.7.10+xpu oneccl_bind_pt==2.7.0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us pip install .[dev,autotuning] - name: Check container state shell: bash run: | - source ~/ds_env/bin/activate ldd --version ds_report python3 -c "import torch; print('torch:', torch.__version__, torch)" @@ -71,8 +67,9 @@ jobs: - name: Unit tests shell: bash run: | - source ~/ds_env/bin/activate cd tests/unit + export FI_PROVIDER="tcp" + export I_MPI_SHM=off pytest --verbose accelerator/* pytest --verbose autotuning/* pytest --verbose checkpoint/test_reshape_checkpoint.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bfc22afb5359..b03a498144a4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -19,6 +19,12 @@ If a formatting test fails, it will fix the modified code in place and abort the `git commit`. After looking over the changes, you can `git add ` and then repeat the previous `git commit` command. +You can also run: +``` +make format +``` +which will do the same as above, and it'll also automatically build a `venv` python environment if you +don't already have one, which will isolate the requirements of this project from requirements of other projects. ## Testing DeepSpeed tracks two types of tests: unit tests and more costly model convergence tests. @@ -38,6 +44,11 @@ You can also provide the `-v` flag to `pytest` to see additional information abo tests. Note that [pytest-forked](https://github.com/pytest-dev/pytest-forked) and the `--forked` flag are required to test CUDA functionality in distributed tests. +You can also run: +``` +make test +``` + ### Model Tests To execute model tests, first [install DeepSpeed](#installation). The [DeepSpeedExamples](https://github.com/deepspeedai/DeepSpeedExamples/) repository is cloned @@ -48,16 +59,15 @@ pytest run_sanity_check.py ``` Note that the `--forked` flag is not necessary for the model tests. -## Contributor License Agreement -This project welcomes contributions and suggestions. Most contributions require you to -agree to a Contributor License Agreement (CLA) declaring that you have the right to, and -actually do, grant us the rights to use your contribution. For details, visit -https://cla.opensource.microsoft.com. +## Developer Certificate of Origin +This project welcomes contributions and suggestions. All contributions to deepspeedai projects +require commits to be signed off with a [Developer Certificate of Origin](https://en.wikipedia.org/wiki/Developer_Certificate_of_Origin) +(DCO) declaring that you have the right to, and actually do, grant us the rights to use your contribution. + +When you submit a pull request, the DCO app will check for the presence of signed commits. +Information about how this check works is here: https://github.com/dcoapp/app?tab=readme-ov-file#how-it-works -When you submit a pull request, a CLA bot will automatically determine whether you need -to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply -follow the instructions provided by the bot. You will only need to do this once across -all repos using our CLA. +To sign commits, you will need to include `-s` when running `git commit`. For example, `git commit -s -m "Commit message"`. One note, creating PRs via the GitHub interface do not appear to include this option. If you forget this, clicking on the failing check in your PR will point you to commands you can run to rebase and sign previous commits. ## Code of Conduct This project has adopted the [Microsoft Open Source Code of diff --git a/Makefile b/Makefile new file mode 100644 index 000000000000..8756897ebedf --- /dev/null +++ b/Makefile @@ -0,0 +1,23 @@ +# usage: make help + +.PHONY: help test format +.DEFAULT_GOAL := help + +help: ## this help + @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[0-9a-zA-Z_-]+:.*?##/ { printf " \033[36m%-22s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) + echo $(MAKEFILE_LIST) + +test: ## run tests + pytest --forked tests/unit/ + +format: ## fix formatting + @if [ ! -d "venv" ]; then \ + python -m venv venv; \ + . venv/bin/activate; \ + pip install pre-commit -U; \ + pre-commit clean; \ + pre-commit uninstall; \ + pre-commit install; \ + deactivate; \ + fi + . venv/bin/activate && pre-commit run --files $$(git diff --name-only master) && deactivate diff --git a/README.md b/README.md index 6922e55b4144..643d22385d13 100755 --- a/README.md +++ b/README.md @@ -15,32 +15,22 @@ ## Latest News DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-chat). - +* [2025/04] [DeepCompile: Unlocking Compiler Optimization for Distributed Training](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepcompile/README.md) +* [2025/03] [DeepSpeed-AutoTP: Automatic Tensor Parallel Training of Hugging Face models](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/huggingface-tp/README.md) * [2024/12] [Ulysses-Offload: Democratizing Long Context LLM Training ](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/ulysses-offload/README.md) * [2024/12] [DeepSpeed-Domino: Communication-Free LLM Training Engine](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-domino/README.md) * [2024/08] [DeepSpeed on Windows](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/windows/08-2024/README.md) [[日本語](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/windows/08-2024/japanese/README.md)] [[中文](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/windows/08-2024/chinese/README.md)] -* [2024/08] [DeepNVMe: Improving DL Applications through I/O Optimizations](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-gds/README.md) [[日本語](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-gds/japanese/README.md)] -* [2024/07] [DeepSpeed Universal Checkpointing: Efficient and Flexible Checkpointing for Large Scale Distributed Training](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-ucp/README.md) [[中文](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-ucp/chinese/README.md)] [[日本語](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-ucp/japanese/README.md)] -* [2024/03] [DeepSpeed-FP6:The power of FP6-Centric Serving for Large Language Models](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024) [[English](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README.md)] [[中文](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README-Chinese.md)] -* [2024/01] [DeepSpeed-FastGen: Introducing Mixtral, Phi-2, and Falcon support with major performance and feature enhancements.](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-fastgen/2024-01-19) -* [2023/11] [Llama 2 Inference on 4th Gen Intel® Xeon® Scalable Processor with DeepSpeed](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/intel-inference) [[Intel version]](https://www.intel.com/content/www/us/en/developer/articles/technical/xllama-2-on-xeon-scalable-processor-with-deepspeed.html) -* [2023/11] [DeepSpeed ZeRO-Offload++: 6x Higher Training Throughput via Collaborative CPU/GPU Twin-Flow](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-offloadpp) -* [2023/11] [DeepSpeed-FastGen: High-throughput Text Generation for LLMs via MII and DeepSpeed-Inference](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-fastgen) [[English](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-fastgen)] [[中文](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-fastgen/chinese/README.md)] [[日本語](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-fastgen/japanese/README.md)] -* [2023/10] [DeepSpeed-VisualChat: Improve Your Chat Experience with Multi-Round Multi-Image Inputs](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md) [[English](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-visualchat/10-03-2023/README.md)] [[中文](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Chinese.md)] [[日本語](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-visualchat/10-03-2023/README-Japanese.md)] -* [2023/09] Announcing the DeepSpeed4Science Initiative: Enabling large-scale scientific discovery through sophisticated AI system technologies [[Tutorials](https://www.deepspeed.ai/deepspeed4science/)] [[White paper](https://arxiv.org/abs/2310.04610)] [[Blog](https://www.microsoft.com/en-us/research/blog/announcing-the-deepspeed4science-initiative-enabling-large-scale-scientific-discovery-through-sophisticated-ai-system-technologies/)] [[中文](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed4science/chinese/README.md)] [[日本語](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed4science/japanese/README.md)] -
More news
diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index de711f73144e..4b3d89e6cd34 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -229,10 +229,17 @@ def is_bf16_supported(self): return True def is_fp16_supported(self): - return False + try: + if torch.ops.mkldnn._is_mkldnn_fp16_supported(): + return True + except: + return False def supported_dtypes(self): - return [torch.float, torch.bfloat16] + supported_dtypes = [torch.float, torch.bfloat16] + if self.is_fp16_supported(): + supported_dtypes.append(torch.float16) + return supported_dtypes # Graph operations def create_graph(self): diff --git a/accelerator/hpu_accelerator.py b/accelerator/hpu_accelerator.py index b46351f8ca43..cae1fa0383cb 100644 --- a/accelerator/hpu_accelerator.py +++ b/accelerator/hpu_accelerator.py @@ -23,6 +23,8 @@ def __init__(self): import habana_frameworks.torch.hpu as hpu self.hpu = hpu torch.use_deterministic_algorithms(True) + # TODO: remove this WA when memory mapping break is resolved. + torch.utils.deterministic.fill_uninitialized_memory = False except ImportError as e: raise ValueError( f"HPU_Accelerator requires habana_frameworks.torch.hpu, which is not installed on this system.") diff --git a/accelerator/real_accelerator.py b/accelerator/real_accelerator.py index 9c1713ca4c6c..7b271739a5a3 100644 --- a/accelerator/real_accelerator.py +++ b/accelerator/real_accelerator.py @@ -131,11 +131,17 @@ def get_accelerator(): if accelerator_name is None: try: import intel_extension_for_pytorch as ipex - if ipex._C._has_xpu(): accelerator_name = "xpu" except ImportError as e: - pass + import torch + if hasattr(torch, 'xpu'): + if torch.xpu.is_available(): + accelerator_name = "xpu" + else: + pass + else: + pass if accelerator_name is None: try: import torch_npu # noqa: F401,F811 # type: ignore diff --git a/accelerator/xpu_accelerator.py b/accelerator/xpu_accelerator.py index ad8a10710bf2..fcce2502c428 100644 --- a/accelerator/xpu_accelerator.py +++ b/accelerator/xpu_accelerator.py @@ -5,19 +5,33 @@ import torch from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator -import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore -import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore import functools - import importlib import inspect +try: + import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore + oneccl_imported_p = True +except ImportError as e: + oneccl_imported_p = False + +try: + import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore + ipex_imported_p = True +except ImportError as e: + ipex_imported_p = False + class XPU_Accelerator(DeepSpeedAccelerator): def __init__(self): self._name = 'xpu' self._communication_backend_name = 'ccl' + if oneccl_imported_p: + self._communication_backend_name = 'ccl' + else: + # changed to xccl if not using torch-CCL on XPU device + self._communication_backend_name = 'xccl' self._compile_backend = "inductor" self.aligned_tensors = [] self.class_dict = None @@ -26,11 +40,14 @@ def is_synchronized_device(self): return False def use_host_timers(self): - # WA XPU event will be consolidated in 2.6 - if ipex.__version__ < '2.6': - return True - else: + if not ipex_imported_p: return self.is_synchronized_device() + else: + # WA XPU event will be consolidated in 2.6 + if ipex.__version__ < '2.6': + return True + else: + return self.is_synchronized_device() def resolves_data_dependency(self): return self.is_synchronized_device() @@ -290,10 +307,13 @@ def get_op_builder(self, class_name): return self.class_dict['NotImplementedBuilder'] def build_extension(self): - try: - from intel_extension_for_pytorch.xpu.cpp_extension import DpcppBuildExtension - except ImportError: - from intel_extension_for_pytorch.xpu.utils import DpcppBuildExtension + if ipex_imported_p: + try: + from intel_extension_for_pytorch.xpu.cpp_extension import DpcppBuildExtension + except ImportError: + from intel_extension_for_pytorch.xpu.utils import DpcppBuildExtension + else: + from torch.utils.cpp_extension import DpcppBuildExtension return DpcppBuildExtension def export_envs(self): diff --git a/blogs/deepcompile/README.md b/blogs/deepcompile/README.md new file mode 100644 index 000000000000..7fca2cba1fa1 --- /dev/null +++ b/blogs/deepcompile/README.md @@ -0,0 +1,174 @@ +
+ +# DeepCompile: Unlocking Compiler Optimization for Distributed Training + +
+ +# Introduction + +
+ + + +
+ +Distributed training has become essential for scaling today’s massive deep learning models. While deep learning compilers like PyTorch compiler dramatically improved single-GPU training performance through optimizations like kernel fusion and operator scheduling, they fall short when it comes to distributed workloads. +Existing distributed training frameworks such as DeepSpeed and FSDP have made large-scale model training feasible through advanced parallelization strategies. While powerful, their optimizations are implemented at the PyTorch framework level, which limits the ability to apply compiler-style techniques like dependency analysis or operator scheduling. + +DeepCompile addresses this gap by enabling compiler-level optimizations for distributed training. It takes a standard single-GPU model implementation and transforms it into an optimized multi-GPU training graph without requiring changes to the model code. Unlike existing approaches, DeepCompile automatically applies parameter sharding, communication scheduling, and memory-aware execution at the compiler IR level, enabling global analysis and optimization that are difficult to express in traditional frameworks. Furthermore, during training, DeepCompile employs profile-guided optimization techniques to dynamically tune these parallelization strategies and improve training performance. + +Our evaluation demonstrates that DeepCompile improves training performance over ZeRO-3 baselines, achieving up to 1.5x speedup when sufficient GPU resources are available, and up to 7x speedup in GPU-constrained settings that require offloading. DeepCompile is available in DeepSpeed versions >= [0.16.6](https://github.com/deepspeedai/DeepSpeed/releases/tag/v0.16.6). As it is under active development, we recommend using the latest version of DeepSpeed or installing from source to access the most recent updates and bug fixes. + +# Design Overview + +DeepCompile extends the capabilities of deep learning compilers to support distributed training. It starts from a standard single-GPU model implementation, such as those available on the Hugging Face model hub, and automatically transforms it by inserting necessary distributed training operations such as parameter sharding and communication primitives. Users are not required to embed any distributed logic into the model code. + +The process begins by compiling the model into an intermediate representation (IR), which forms a computation graph. DeepCompile then applies a sequence of *optimization passes*, each responsible for a specific transformation of the computation graph or a targeted performance improvement, to incrementally introduce distributed behavior and optimize the graph. These include operations such as all-gather for sharded parameters or offloading of optimizer states, all while preserving the original computation semantics (Fig. 1). + +
+ + + +*Figure 1: Workflow of compilation and optimization with DeepCompile.* + +
+ +At its core, DeepCompile builds on two key capabilities: + +- **Automatic parallelization**: DeepCompile allows optimization passes to rewrite the single-GPU computation graph into a distributed multi-GPU version, incorporating strategies such as ZeRO, FSDP, and more. This eliminates the need for manual implementation of distributed training logic, drastically reducing engineering effort. +- **Profile-guided performance tuning**: At runtime, DeepCompile collects profiling data such as operator-level memory usage and execution latency. It uses this information to dynamically schedule computation and communication operators. This enables effects such as an improved overlap between communication and computation, and an avoidance of memory bottlenecks. Fine-grained tuning through these optimization passes often leads to better performance than even manually engineered implementations. + +Figure 2 illustrates the optimization cycle employed by DeepCompile. After the initial computation graph is generated by the compiler, DeepCompile profiles its behavior by measuring operator execution time, communication overhead, and memory usage throughout the forward and backward passes. + +
+ + + +*Figure 2. Optimization cycle.* + +
+ +Based on the collected profiling data, DeepCompile applies a sequence of optimization passes. These passes modify the computation graph by inserting, removing, or reordering operators to improve overall efficiency. The modified graph is then re-profiled, and this cycle of profiling and optimization is repeated. + +Once a stable set of optimizations has been applied, the graph is deployed for the remaining training iterations. During execution, memory usage and other runtime characteristics may change. In such cases, DeepCompile can resume the profiling and optimization cycle according to the predefined schedule of passes, allowing the graph to adapt and maintain high performance. + +# Optimizations + +DeepCompile is designed as a general compiler framework for applying and optimizing a wide range of parallelization strategies. In the following, we describe several optimizations that have been implemented as optimization passes within DeepCompile. + +## ZeRO3 + +As an initial step, we have used DeepCompile to implement and enhance ZeRO-3-style optimizations at the compiler level. ZeRO-3 partitions model parameters, gradients, and optimizer states across devices, reducing memory usage and enabling large-scale training. + +In conventional ZeRO-3 implementations, operations such as all-gather, reduce-scatter, and buffer release are typically inserted using Python hooks at runtime. DeepCompile replaces this approach by injecting these operations directly into the computation graph during compilation. This allows the compiler to determine their placement precisely, guided by both the static structure of the graph and runtime profiling information. + +One of the key optimizations is **proactive prefetching**, which launches all-gather operations earlier in the computation based on memory usage profiling. This reordering increases the overlap between communication and computation thereby improving throughput, while avoiding OOMs. In addition, small communication operations are often fused to reduce launch latency and improve efficiency. + +Another optimization is **selective unsharding**, which keeps certain parameters in an unsharded form during the forward and backward passes when memory conditions permit. This reduces the frequency of all-gather operations and avoids redundant communication, particularly in scenarios where gradient accumulation is enabled. + +## Offloading + +DeepCompile also supports **adaptive offloading**, which offloads optimizer states to reduce GPU memory pressure. Unlike approaches that offload all the optimizer states, adaptive offloading identifies only the portions that exceed the memory limit—such as momentum and variance used by the Adam optimizer—and schedules data transfers to overlap with computation. This selective and asynchronous strategy minimizes overhead and enables efficient training even in memory-constrained environments. + +## ZeRO1 + +ZeRO-1 differs from ZeRO-3 in that it shards only the optimizer states across devices, while keeping parameters and gradients fully replicated. This approach reduces memory usage with minimal changes to computation flow, making it a lightweight alternative for certain training scenarios. +DeepCompile implements ZeRO-1-style optimization by inserting reduce-scatter operations directly into the computation graph. By avoiding Python-level hooks, this graph-level integration reduces overhead and improves execution efficiency. + +# Performance Improvements + +## ZeRO-3 + +We evaluated DeepCompile on Llama-3-70B and Mixtral 8x7B using parameter sharding on top of Hugging Face model implementations. +Figure 3 shows training throughput (TFLOPs/GPU) across different gradient accumulation steps, using 32 H100 GPUs with a sequence length of 1024. +We compare DeepCompile against two DeepSpeed ZeRO-3 baselines: (i) an eager-mode version without compiler support (labelled ZeRO3+Eager), and (ii) a compiled version using PyTorch compiler (labelled ZeRO3+Compile). For DeepCompile, we enabled both proactive prefetching and selective unsharding to demonstrate the combined effect of these optimization passes. + +
+ +*Figure 3. Achieved throughputs for ZeRO3 training of Llama-3 70B and Mixtral 8x7B models.* + +
+Across both models, DeepCompile consistently delivers higher throughput. The benefit becomes more pronounced at higher accumulation steps, where the reduced frequency of parameter updates makes selective unsharding more effective. DeepCompile with proactive prefetching and selective unsharding achieves up to 1.28× speedup over ZeRO-3 on Llama-3-70B and 1.54× on Mixtral 8x7B. + +Meanwhile, enabling the PyTorch compiler with ZeRO-3, i.e., ZeRO3+Compile introduces minor overheads in some settings. This is because ZeRO-3 includes many conditional branches for runtime features such as prefetching. When the compiler encounters branches that cannot be statically resolved, it splits the computation into multiple graph segments. These fragmented segments can reduce optimization opportunities and introduce additional overheads during execution. + +## Offloading + +Training models as large as Llama-3 70B with ZeRO-3 typically requires 32 GPUs with 80GB of memory. +DeepSpeed addresses this challenge by offering offloading capabilities, which transfer optimizer states and optionally model parameters to CPU memory to reduce GPU memory usage. DeepCompile also supports offloading through a dedicated optimization pass, but with a few key differences in design. + +Unlike the traditional approach of offloading both optimizer computation and memory, DeepCompile offloads only optimizer memory (e.g., momentum, variance, and master weights of Adam optimizer) while the optimizer computation remains on GPU. DeepCompile profiles memory usage during both forward and backward passes to identify when offloading is necessary, and transfers only the required data. This fine-grained approach avoids unnecessary overhead and helps maintain high computational throughput. +Furthermore, DeepCompile overlaps data transfers with computation whenever possible, dynamically adjusting the timing based on observed memory usage patterns. This asynchronous behavior is a crucial aspect of DeepCompile’s offloading strategy, allowing it to reduce GPU memory pressure without stalling execution. + +We evaluated DeepCompile's offloading using Llama-3 70B on 16xH100-80GB (half the required GPU counts) and present the results in Figure 4. + +
+ + + +*Figure 4. Achieved throughput of optimizer offloading for Llama-3 70B on 16x80GB GPUs* + +
+ +We compare against two ZeRO-3 offloading baselines: (i) an eager-mode version without compiler support (ZeRO3+Eager), and (ii) a compiled version using PyTorch compiler (ZeRO3+Compile). As shown by the results, DeepCompile significantly improves offloading efficiency and provides up to 7× speedup over ZeRO3+Eager. In contrast, we see that ZeRO3+Compile achieves similar performance as ZeRO3+Eager. + + +## ZeRO-1 + +We also evaluated DeepCompile with ZeRO-1 using the Llama-3-8B model. We compare DeepCompile against two ZeRO-1 baselines: (i) an eager-mode version without compiler support (ZeRO1+Eager), and (ii) a compiled version using PyTorch compiler (ZeRO1+Compile). In our experiment with 8 GPUs and a batch size of 2, DeepCompile achieved consistent throughput improvements across different sequence lengths, as shown in Figure 5. + +
+ + + +*Figure 5. Achieved throughput of ZeRO-1 training of Llama-3 8B* + +
+ +The most significant speedup was observed with batch size 1 and sequence length 512, where DeepCompile outperformed ZeRO1+Eager by up to 1.9×, and ZeRO1+Compile by up to 2.5×. + +While compiler-based approaches can be effective for large batch sizes and long sequences by replacing suboptimal operations with more efficient kernels, they may also introduce overheads in ZeRO-1-style training in the form of *graph breaks* around the communication operations. These overheads become more pronounced with smaller batch sizes and sequence lengths, thus hurting performance compared to the non-compiled execution. In contrast, DeepCompile inserts communication operators directly into the computation graph during compilation, avoiding graph fragmentation and minimizing associated overhead. This makes DeepCompile more robust to small-scale workloads, while still benefiting from compiler-level optimizations. + +## Additional Results and Analysis + +Please refer to our [arXiv paper](https://arxiv.org/abs/2504.09983) for additional results, such as detailed comparisons across different batch sizes, sequence lengths, and memory usage. + +# Looking Ahead + +DeepCompile brings the power of compiler-based optimizations to distributed deep learning. By transforming computation graphs and applying profile-guided optimization passes, it enables more efficient training without requiring changes to model code. + +This release is just the beginning. We’re actively working on expanding the set of optimization passes and improving integration with a broader range of distributed training strategies. Future directions include automated parallelization (sequence/tensor parallelisms), smarter memory management, and dynamic adaptation to runtime behavior. + +We invite the community to try DeepCompile, explore its capabilities, and contribute to its evolution. Let’s build the next generation of scalable deep learning together. + +# Acknowledgments + +We would like to thank everyone who supported this project. + +This project would not have been possible without the PyTorch Compiler—a platform that is not only powerful and flexible, but also a pleasure to work with. We are especially grateful to the developers and researchers behind PyTorch Compiler for making such an excellent foundation available to the community. + +# Contributors + +This project is the result of a close collaboration between Microsoft and the University of Virginia. The contributors are: Masahiro Tanaka, Du Li, and Umesh Chand, Olatunji Ruwase (Microsoft); and Ali Zafar and Haiying Shen (University of Virginia). + +# Appendix + +## Examples and Benchmarks + +Our DeepSpeedExamples repository provides [example code](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/benchmarks/deepcompile) to enable DeepCompile. + +## Optimization Passes + +The following optimization passes are currently available in DeepCompile: + +- All-gather & reduce-scatter insertion (ZeRO3) +- Proactive prefetching (ZeRO3) +- Selective unsharding (ZeRO3) +- Reduce-scatter insertion (ZeRO1) +- Adaptive offloading + +We used the following combinations of passes in the experiments presented above: + +- Improved communication scheduling for ZeRO-3: All-gather & reduce-scatter → Proactive prefetching → Selective unsharding +- Offloading optimizer states for ZeRO3: Adding all-gather & reduce-scatter → Adaptive offloading +- Reduced overhead and improved overlap for ZeRO-1: Adding reduce-scatter diff --git a/blogs/deepcompile/media/opt_loop.png b/blogs/deepcompile/media/opt_loop.png new file mode 100644 index 000000000000..a3a4ca33a684 Binary files /dev/null and b/blogs/deepcompile/media/opt_loop.png differ diff --git a/blogs/deepcompile/media/perf_offload.png b/blogs/deepcompile/media/perf_offload.png new file mode 100644 index 000000000000..1506f20bc133 Binary files /dev/null and b/blogs/deepcompile/media/perf_offload.png differ diff --git a/blogs/deepcompile/media/perf_summary.png b/blogs/deepcompile/media/perf_summary.png new file mode 100644 index 000000000000..798ff54acb7d Binary files /dev/null and b/blogs/deepcompile/media/perf_summary.png differ diff --git a/blogs/deepcompile/media/perf_zero1.png b/blogs/deepcompile/media/perf_zero1.png new file mode 100644 index 000000000000..a7256919f9a5 Binary files /dev/null and b/blogs/deepcompile/media/perf_zero1.png differ diff --git a/blogs/deepcompile/media/perf_zero3.png b/blogs/deepcompile/media/perf_zero3.png new file mode 100644 index 000000000000..a93e929312a3 Binary files /dev/null and b/blogs/deepcompile/media/perf_zero3.png differ diff --git a/blogs/deepcompile/media/workflow.png b/blogs/deepcompile/media/workflow.png new file mode 100644 index 000000000000..72a358408099 Binary files /dev/null and b/blogs/deepcompile/media/workflow.png differ diff --git a/blogs/deepspeed-gds/README.md b/blogs/deepspeed-gds/README.md index 536b6f984af0..29bfdd842ee5 100644 --- a/blogs/deepspeed-gds/README.md +++ b/blogs/deepspeed-gds/README.md @@ -47,7 +47,7 @@ We used three benchmarking tools for our evaluations. The first is fio, the popu ## High-Performance I/O with CPU Buffers via NVMe Scaling -Our first set of microbenchmark evaluations used fio and ds\_io to measure the performance of transferring 1GB data between NVMe and CPU memory. We configure fio to use the libaio backend for these experiments1. The results are summarized in Figure 1, from which we make two observations. First, DeepNVMe demonstrates high performance as it roughly matches fio, despite being more representative of DL applications. Second, DeepNVMe scales I/O performance almost linearly with available NVMe bandwidth, achieving rates of 10GB/sec reads and 5GB/sec writes. +Our first set of microbenchmark evaluations used fio and ds\_io to measure the performance of transferring 1GB data between NVMe and CPU memory. We configure fio to use the libaio backend for these experiments. The results are summarized in Figure 1, from which we make two observations. First, DeepNVMe demonstrates high performance as it roughly matches fio, despite being more representative of DL applications. Second, DeepNVMe scales I/O performance almost linearly with available NVMe bandwidth, achieving rates of 10GB/sec reads and 5GB/sec writes. @@ -85,4 +85,4 @@ In this blog post, we introduced DeepNVMe, an I/O optimization technology create # Acknowlegements -This work is the result of a deep collaboration between Microsoft and NVIDIA. The contributors include Joe Mayer, Martin Cai, and Olatunji Ruwase from Microsoft; Kiran Modukuri, Vahid Noormofidi, Sourab Gupta, and Sandeep Joshi from Nivida. +This work is the result of a deep collaboration between Microsoft and NVIDIA. The contributors include Joe Mayer, Martin Cai, and Olatunji Ruwase from Microsoft; Kiran Modukuri, Vahid Noormofidi, Sourab Gupta, and Sandeep Joshi from Nvidia. diff --git a/blogs/huggingface-tp/README.md b/blogs/huggingface-tp/README.md new file mode 100644 index 000000000000..44469f1818c3 --- /dev/null +++ b/blogs/huggingface-tp/README.md @@ -0,0 +1,242 @@ +
+ +# Automatic Tensor Parallel (AutoTP) Training of Hugging Face models + +
+ + +# Introduction + +Tensor parallelism (TP) is an important memory optimization for training large-scale deep learning models. Despite the popularity of training Hugging Face (HF) [models](https://huggingface.co/models), the model scaling options for **[HF trainer](https://huggingface.co/docs/transformers/main_classes/trainer)** was previously limited to sharded data parallelism through [ZeRO](https://huggingface.co/docs/accelerate/usage_guides/deepspeed)/[FSDP](https://huggingface.co/docs/accelerate/usage_guides/fsdp). While ZeRO3 offers superior memory efficiency, it incurs significant communication costs. ZeRO (1/2) has lower communication overhead, but in the case of very large models, it cannot be used directly due to memory limitations. Therefore, combining TP with ZeRO (1/2) offers more balanced options for memory and performance. Moreover, through TP, we can alleviate the batch scaling limitations imposed by ZeRO/FSDP. + +We are pleased to announce that DeepSpeed now provides native automatic tensor parallel training for Hugging Face (HF) transformers. This new feature builds on DeepSpeed's [AutoTP](https://www.deepspeed.ai/tutorials/automatic-tensor-parallelism/) mechanism, which was previously restricted to inference. AutoTP training can be combined with ZeRO to unlock unprecented efficiency benefits for HF model post-training, including: + +**1**. Model scaling with lower communication costs than FSDP/ZeRO3 (e.g., use AutoTP + ZeRO1 to achieve ZeRO3 memory savings). + +**2**. Batch size scaling for faster training and increased throughput. + +**3**. Context length scaling to enable new application scenarios. + +We have integrated AutoTP training with ZeRO1 & ZeRO2, with ZeRO3 integration on the way. AutoTP training is available in DeepSpeed versions >= 0.16.4 + +# Batch Scaling with AutoTP Training + ZeRO +The following is a batch scaling experiment of Llama3 8B training conducted on [Gaudi2 Accelerator](https://www.intel.com/content/www/us/en/products/details/processors/ai-accelerators/gaudi.html). + + +
+ + + + +*Figure 1. Batch scaling experiment on Gaudi2, showing throughput performance improvements from 2 to 4 cards by combining AutoTP and ZeRO. The used mbs is the max possible value with the given config. A higher speedup indicates better performance.* + +
+ + + +
+ + + + +*Figure 2. Model training with AutoTP + ZeRO* + +
+ + +Figure 2 illustrates the basic flowchart, The division of TP and ZeRO is implemented through the AutoTP parser and ZeRO Wrapper in [Accelerate](https://github.com/huggingface/accelerate.git). Besides, The TP-based dataloader and save mechanism are both supported in DeepSpeed and Accelerate. + +# Usage + + + +Although we evaluated AutoTP training with Llama2 & Llama3 models in this blog, we expect compatibility with other Hugging Face models, especially [those](https://www.deepspeed.ai/tutorials/automatic-tensor-parallelism/) previously validated with AutoTP inference. + +**Requirements** +- `deepspeed >= 0.16.4` +- `transformers >= 4.50.1` +- `accelerate >= 1.6.0` + + **Enable TP training** + +Similar to ZeRO, AutoTP training is enabled using the [deepspeed configuration file](https://www.deepspeed.ai/docs/config-json/) by specifying ```[tensor_parallel][autotp_size]```. +``` + "ZeRO_optimization": { + "stage": 1, + "gather_16bit_weights_on_model_save": true, + ... + }, + "tensor_parallel":{ + "autotp_size": 4 + }, +``` + +The parallel configuration follows this logic: + + +``` +tp_size = auto_tp_size +dp_size = num_gpus / tp_size +``` + +Note that the global_batch_size (gbs) changes with different TP settings: +``` +gbs (only dp) = per_device_batch_size * n_gpus * gradient_accumulation_steps + +gbs (dp with tp) = per_device_batch_size * n_gpus / tp_size * gradient_accumulation_steps +``` + + + + + + + + **Save Model** + + + + +Saving checkpoints and model files is fully compatible with HF transformers. The [trainer.save_model()](https://huggingface.co/docs/transformers/v4.49.0/en/main_classes/trainer#transformers.Trainer.save_model) method saves the original model. Ensure ```gather_16bit_weights_on_model_save``` is set to ```true```in the [deepspeed configuration file](https://www.deepspeed.ai/docs/config-json/). +```gather_16bit_weights_on_model_save=true in config. + "ZeRO_optimization": { + ... + "gather_16bit_weights_on_model_save": true, + }, +``` + +``` +trainer.save_model(your_saved_path) +``` +Models saved this way can be directly used for HF format inference without intermediate transformations. + + + + **Saving Checkpoints and Resuming** + + + +Saving Checkpoints remains compatible with HF transformers. Use [trainer.save_state()](https://huggingface.co/docs/transformers/v4.49.0/en/main_classes/trainer#transformers.Trainer.save_state) or set the save interval for automatic saving, which can be used to resume training. +``` +trainer.train(resume_from_checkpoint="your_saved_path/checkpoint-1200") +``` + +# Example +We validated AutoTP training using supervised finetune training (SFT) task: [stanford_alpaca](https://github.com/tatsu-lab/stanford_alpaca). The original benchmark model used in this project is Llama2-7B. The example code is also available [here](https://github.com/deepspeedai/DeepSpeedExamples/tree/master/training/tensor_parallel) + + +**Training Loss curve** + + + +The following loss curves depict SFT training, where gbs is uniformly set to 32, and other configurations match the default experiment settings from ([stanford_alpaca](https://github.com/tatsu-lab/stanford_alpaca)). The loss curves are largely consistent across the following setups: + + - ZeRO3 + - TP + disable ZeRO + - ZeRO1 and ZeRO1 + AutoTP + - ZeRO2 and ZeRO2 + AutoTP + + + + + +
+ + + + +*Figure 3. Loss curve of ZeRO3 stage training (gbs=32, dp=8)* + +
+
+ + + +*Figure 4. Loss curve of AutoTP training (gbs=32, tp=8)* +
+ +
+ + + +*Figure 5. Loss curve of AutoTP + ZeRO1 training (gbs=32, dp=2, tp=4)* +
+ + +
+ + + +*Figure 6. Loss curve of AutoTP + ZeRO2 training (gbs=32, dp=2, tp=4)* + + +
+ + + **Resuming Training** + + + We tested recovery training curves from step 1200 in AutoTP + ZeRO1 and AutoTP + ZeRO2, which align with the original training curves. + +
+ + + +*Figure 7. AutoTP + ZeRO1 resuming training* + + + +*Figure 8. AutoTP + ZeRO2 resuming training* + +
+ + + + **Model Evaluation** + + + We conducted inference evaluations for the [MMLU task](https://github.com/EleutherAI/lm-evaluation-harness). + In MMLU, the scores for AutoTP + ZeRO1 and ZeRO1, as well as AutoTP + ZeRO2 and ZeRO2, are consistent, showing a fixed improvement over the pre-training model before SFT. + + +
+ + +| Groups | Version | Filter | n-shot | Metric | Model before SFT | ZeRO1 DP8 training | ZeRO1 TP4 DP2 training | ZeRO2 DP8 training | ZeRO2 TP4DP2 training | +|--------|---------|--------|--------|--------|-----------------------|--------------------|------------------------|--------------------|------------------------| +| mmlu | 2 | none | | acc | 0.4185 ± 0.0041 | 0.4472 ± 0.0041 | 0.4444 ± 0.0041 | 0.4543 ± 0.0041 | 0.4529 ± 0.0041 | +| - humanities | 2 | none | | acc | 0.3979 ± 0.0069 | 0.4185 ± 0.0070 | 0.4145 ± 0.0069 | 0.4274 ± 0.0070 | 0.4272 ± 0.0070 | +| - other | 2 | none | | acc | 0.4712 ± 0.0089 | 0.5249 ± 0.0087 | 0.5182 ± 0.0088 | 0.5282 ± 0.0087 | 0.5269 ± 0.0087 | +| - social sciences | 2 | none | | acc | 0.4742 ± 0.0089 | 0.5070 ± 0.0089 | 0.5083 ± 0.0088 | 0.5151 ± 0.0088 | 0.5115 ± 0.0089 | +| - stem | 2 | none | | acc | 0.3428 ± 0.0084 | 0.3549 ± 0.0084 | 0.3539 ± 0.0084 | 0.3622 ± 0.0084 | 0.3609 ± 0.0084 | + +*Table 1. MMLU score with Llama2-7B inference* + +
+ + + + + +# Miscellaneous + +If users define their own dataloader, please ensure data consistency within ```deepspeed.utils.groups.get_tensor_model_parallel_group()```. DeepSpeed provides basic validation functions to assist with this. + +Furthermore, if users are not using transformers library, you can replace the ```TensorParallel_Layer``` layer and its subclasses as needed. See ```prepare_tp_model``` function in ```unit/model_parallelism/test_autotp_training.py```. Users can also define different shard and gather for subclasses of ```TensorParallel_Layer.``` + + + + + +# Ongoing Work +- **Optimization**: Communication/Activation optimization. +- **Usability**: Support [Transformers TP plan](https://github.com/huggingface/transformers/blob/336dc69d63d56f232a183a3e7f52790429b871ef/src/transformers/models/llama/configuration_llama.py#L145), decouple AutoTP parser and more model testing, + + +Theoretically, features supported by ZeRO should also be supported, though extensive testing is pending. + +Welcome bug reports, enhancement, and additional model training examples. + + +# Contributors +This work was made possible through a deep collaboration between Intel and Microsoft. The contributors include Mingzhi Liu, Guokai Ma, Kiefer Kuah, Yejing Lai, Kurt Chen, Yejun Guo, Guangxin Xu, Xiaofei Feng, and Yang Wang from Intel; Guanhua Wang and Olatunji Ruwase from Microsoft. diff --git a/blogs/huggingface-tp/media/batchscale.png b/blogs/huggingface-tp/media/batchscale.png new file mode 100644 index 000000000000..37a6eeeade9e Binary files /dev/null and b/blogs/huggingface-tp/media/batchscale.png differ diff --git a/blogs/huggingface-tp/media/flowchart.png b/blogs/huggingface-tp/media/flowchart.png new file mode 100644 index 000000000000..b7115df8c213 Binary files /dev/null and b/blogs/huggingface-tp/media/flowchart.png differ diff --git a/blogs/huggingface-tp/media/tp8.png b/blogs/huggingface-tp/media/tp8.png new file mode 100644 index 000000000000..0ae6e925eef1 Binary files /dev/null and b/blogs/huggingface-tp/media/tp8.png differ diff --git a/blogs/huggingface-tp/media/tpzero1.png b/blogs/huggingface-tp/media/tpzero1.png new file mode 100644 index 000000000000..b7f21a9e2a5f Binary files /dev/null and b/blogs/huggingface-tp/media/tpzero1.png differ diff --git a/blogs/huggingface-tp/media/tpzero2.png b/blogs/huggingface-tp/media/tpzero2.png new file mode 100644 index 000000000000..7a3333990835 Binary files /dev/null and b/blogs/huggingface-tp/media/tpzero2.png differ diff --git a/blogs/huggingface-tp/media/zero1tpload.png b/blogs/huggingface-tp/media/zero1tpload.png new file mode 100644 index 000000000000..9af5622f908d Binary files /dev/null and b/blogs/huggingface-tp/media/zero1tpload.png differ diff --git a/blogs/huggingface-tp/media/zero2tpload.png b/blogs/huggingface-tp/media/zero2tpload.png new file mode 100644 index 000000000000..69f002abf474 Binary files /dev/null and b/blogs/huggingface-tp/media/zero2tpload.png differ diff --git a/blogs/huggingface-tp/media/zero3.png b/blogs/huggingface-tp/media/zero3.png new file mode 100644 index 000000000000..62e6eb712151 Binary files /dev/null and b/blogs/huggingface-tp/media/zero3.png differ diff --git a/blogs/windows/08-2024/chinese/README.md b/blogs/windows/08-2024/chinese/README.md index 5d62705df3ae..78b9b6213d89 100644 --- a/blogs/windows/08-2024/chinese/README.md +++ b/blogs/windows/08-2024/chinese/README.md @@ -97,7 +97,7 @@ DeepSpeed可以通过两种方式在Windows系统上安装。较为简单的方 # 总结 -使得DeepSpeed,一个流行的深度学习框架,能够原生运行在最流行的操作系统 Windows 上,是让每个人和组织从当前的人工智能革命中受益的重要一步。在这篇博客中,我们分享了我们为实现这一目标所取得的早期成果。尽管 DeepSpeed 对 Windows 的支持仍在继续开发中,我们希望上述结果已经能够对我们的用户有实用价值,并且鼓舞他们。我们接下来的工作计划涵盖多GPU支持、权重量化以及性能优化。 +让流行的深度学习框架 DeepSpeed 能够在最流行的操作系统 Windows 上原生运行,是让每个人和每个组织都能从正在进行的人工智能革命中受益的关键一步。在这篇博客中,我们分享了我们为实现这一目标所取得的早期成果。尽管 DeepSpeed 对 Windows 的支持仍在继续开发中,我们希望上述结果已经能够对我们的用户有实用价值,并且鼓舞他们。我们接下来的工作计划涵盖多GPU支持、权重量化以及性能优化。 # 致谢 这给项目的完成得益于现任和前任 DeepSpeed 成员的大力合作,包括 Costin Eseanu、Logan Adams、Elton Zheng、Reza Yazdani Aminabadi、Martin Cai 和 Olatunji Ruwase。我们还要感谢那些及时提出此项需求、提供关键的临时解决方法、部分解决方案和建设性反馈的 DeepSpeed 用户,最重要的是,他们始终与我们同行. diff --git a/build_win.bat b/build_win.bat index 627694dbe8a0..64ba99633d50 100644 --- a/build_win.bat +++ b/build_win.bat @@ -10,6 +10,7 @@ set DS_BUILD_FP_QUANTIZER=0 set DS_BUILD_GDS=0 set DS_BUILD_RAGGED_DEVICE_OPS=0 set DS_BUILD_SPARSE_ATTN=0 +set DS_BUILD_DEEP_COMPILE=0 python -m build --wheel --no-isolation diff --git a/csrc/compile/deepcompile.cpp b/csrc/compile/deepcompile.cpp new file mode 100644 index 000000000000..2eca0a33262e --- /dev/null +++ b/csrc/compile/deepcompile.cpp @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "deepcompile.h" + +#define USE_C10D_NCCL + +namespace dc { + +std::shared_ptr param_registry; +std::unordered_map> executors; +std::shared_ptr reduce_buckets = nullptr; + +c10::intrusive_ptr process_group = nullptr; +c10::intrusive_ptr symm_mem = nullptr; +ncclComm_t nccl_comm; +bool use_symm_mem; +bool clone_custom_op_output; +bool profile = false; +bool pre_div_reduce = true; + +bool sync_before_reduce; // for debugging +bool sync_after_reduce; // for debugging +bool sync_before_allgather; // for debugging +bool sync_after_allgather; // for debugging + +std::vector sizes_to_int_vector(at::IntArrayRef sizes) +{ + std::vector result; + for (int i = 0; i < sizes.size(); i++) { result.push_back(sizes[i]); } + return result; +} + +void enable_profiling(bool enable) { profile = enable; } + +bool is_profiling() { return profile; } + +c10::intrusive_ptr getSymmMemWorkspace(int64_t size) +{ + c10::Device device = c10::Device(c10::kCUDA, c10::cuda::current_device()); + std::vector sizes = {size}; + std::vector strides = {1}; + at::Tensor sym_mem_ws = c10d::symmetric_memory::empty_strided_p2p( + {size}, {1}, c10::ScalarType::Byte, device, process_group->getGroupName(), std::nullopt); + return c10d::symmetric_memory::rendezvous(sym_mem_ws); +} + +void lazy_init_symm_memory() +{ + if (use_symm_mem && !symm_mem) { + int64_t max_param_size = 0; + for (const auto& it : param_registry->getParams()) { + int64_t size = it.second.getDSTensor().numel() * it.second.getDSTensor().element_size(); + if (size > max_param_size) { max_param_size = size; } + } + symm_mem = getSymmMemWorkspace(max_param_size); + } +} + +ncclDataType_t get_nccl_data_type(at::ScalarType scalar_type) +{ + switch (scalar_type) { + case at::kFloat: return ncclFloat; + case at::kHalf: return ncclHalf; + case at::kDouble: return ncclDouble; + case at::kBFloat16: return ncclBfloat16; + case at::kLong: return ncclInt64; + case at::kInt: return ncclInt; + case at::kChar: return ncclInt8; + default: throw std::runtime_error("Unsupported scalar type"); + } +} + +void reset() +{ + executors.clear(); + // We keep the buckets for memory estimation + // reduce_buckets->clear(); +} + +void cleanup() +{ + reset(); + + ncclCommDestroy(nccl_comm); + process_group = nullptr; + symm_mem = nullptr; +} + +at::Tensor reduce_grad(at::Tensor grad_tensor, long graph_id, long ds_id) +{ + if (sync_before_reduce) { c10::cuda::device_synchronize(); } + + assert(hasKey(executors, graph_id)); + if (!profile) { executors[graph_id]->reduceGrad(grad_tensor, ds_id); } + + if (sync_after_reduce) { c10::cuda::device_synchronize(); } + + return at::Tensor(); +} + +at::Tensor reduce_grad_meta(at::Tensor grad_tensor, long graph_id, long ds_id) +{ + return at::Tensor(); +} + +void free_tensors(std::vector tensors) +{ + int64_t THRESHOLD = 10 * 1024 * 1024; + + if (!profile) { + for (auto& tensor : tensors) { + if (tensor.is_cuda() && tensor.numel() > THRESHOLD) { + tensor.record_stream(at::cuda::getCurrentCUDAStream()); + tensor.set_data(torch::empty({0}, tensor.options())); + } + } + } +} + +void free_tensors_meta(std::vector tensors) {} + +void init(c10::intrusive_ptr pg, + int64_t initial_reduce_bucket_size, + bool enable_double_buffer, + bool _use_symm_mem, + bool _clone_custom_op_output, + bool _sync_before_reduce, + bool _sync_after_reduce, + bool _sync_before_allgather, + bool _sync_after_allgather) +{ + process_group = pg; + + ncclUniqueId ncclID; + ncclGetUniqueId(&ncclID); + + // ProcessGroup doesn't have an API to get the CUDA stream for comm calls. + // So we create a NCCL communicator and call NCCL APIs directly. + auto vec = std::vector(reinterpret_cast(&ncclID), + reinterpret_cast(&ncclID) + NCCL_UNIQUE_ID_BYTES); + auto device = torch::Device(torch::kCUDA); + at::Tensor tensor = torch::from_blob(vec.data(), {static_cast(vec.size())}, torch::kUInt8) + .to(torch::Device(torch::kCUDA)); + std::vector bcast_input = {tensor}; + + process_group->broadcast(bcast_input, c10d::BroadcastOptions())->wait(); + + // create a new nccl communicator + std::memcpy(&ncclID, tensor.to(torch::Device(torch::kCPU)).data_ptr(), NCCL_UNIQUE_ID_BYTES); + ncclCommInitRank(&nccl_comm, process_group->getSize(), ncclID, process_group->getRank()); + + param_registry = std::make_shared(); + reduce_buckets = std::make_shared(initial_reduce_bucket_size, + enable_double_buffer); + use_symm_mem = _use_symm_mem; + clone_custom_op_output = _clone_custom_op_output; + + sync_before_reduce = _sync_before_reduce; + sync_after_reduce = _sync_after_reduce; + sync_before_allgather = _sync_before_allgather; + sync_after_allgather = _sync_after_allgather; +} + +void start_forward() +{ + lazy_init_symm_memory(); + for (auto& it : executors) { it.second->startForward(); } +} + +void end_forward() +{ + for (auto& it : executors) { it.second->endForward(); } +} + +void start_backward(bool update) +{ + for (auto& it : executors) { it.second->startBackward(update); } +} + +// We don't call this +// void end_backward(bool update) +// { +// } + +} // namespace dc diff --git a/csrc/compile/init.cpp b/csrc/compile/init.cpp new file mode 100644 index 000000000000..ca2538b5f2b1 --- /dev/null +++ b/csrc/compile/init.cpp @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "deepcompile.h" +#include "z1.h" +#include "z3.h" + +TORCH_LIBRARY(dc, m) +{ + m.def("allgather_param(Tensor a, int graph_id, int id) -> Tensor"); + m.def("prefetch_params_fused(int graph_id, Tensor[] params, int[] ids) -> ()"); + m.def("wait_allgather(Tensor a, int graph_id, int id) -> Tensor"); + m.def("release_param(Tensor a, int graph_id, int id, int n_users) -> Tensor"); + m.def("reduce_grad(Tensor a, int graph_id, int id) -> Tensor"); + m.def("free_tensors(Tensor[] a) -> ()"); + m.def("offload_tensor(Tensor a, int id, int id) -> Tensor"); + m.def("reload_tensor(Tensor a, int id, int id) -> Tensor"); + m.def("wait_offload(Tensor a, int id, int id) -> Tensor"); + m.def("wait_reload(Tensor a, int id, int id) -> Tensor"); + m.def("offload_parameter(Tensor a, int id, int id) -> ()"); + m.def("reload_parameter(Tensor a, int id, int id) -> ()"); + + m.def("test_call(Tensor a) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(dc, CPU, m) +{ + m.impl("allgather_param", &dc::allgather_param); + m.impl("prefetch_params_fused", &dc::prefetch_params_fused); + m.impl("wait_allgather", &dc::wait_allgather); + m.impl("release_param", &dc::release_param); + m.impl("reduce_grad", &dc::reduce_grad); + m.impl("free_tensors", &dc::free_tensors); + m.impl("offload_tensor", &dc::offload_tensor); + m.impl("reload_tensor", &dc::reload_tensor); + m.impl("wait_offload", &dc::wait_offload); + m.impl("wait_reload", &dc::wait_reload); + m.impl("offload_parameter", &dc::offload_parameter); + m.impl("reload_parameter", &dc::reload_parameter); + + m.impl("test_call", &dc::test_call); +} + +TORCH_LIBRARY_IMPL(dc, CUDA, m) +{ + m.impl("allgather_param", &dc::allgather_param); + m.impl("prefetch_params_fused", &dc::prefetch_params_fused); + m.impl("wait_allgather", &dc::wait_allgather); + m.impl("release_param", &dc::release_param); + m.impl("reduce_grad", &dc::reduce_grad); + m.impl("free_tensors", &dc::free_tensors); + m.impl("offload_tensor", &dc::offload_tensor); + m.impl("reload_tensor", &dc::reload_tensor); + m.impl("wait_offload", &dc::wait_offload); + m.impl("wait_reload", &dc::wait_reload); + m.impl("offload_parameter", &dc::offload_parameter); + m.impl("reload_parameter", &dc::reload_parameter); + + m.impl("test_call", &dc::test_call); +} + +TORCH_LIBRARY_IMPL(dc, Meta, m) +{ + m.impl("allgather_param", &dc::allgather_param_meta); + m.impl("prefetch_params_fused", &dc::prefetch_params_fused_meta); + m.impl("release_param", &dc::release_param_meta); + m.impl("wait_allgather", &dc::wait_allgather_meta); + m.impl("reduce_grad", &dc::reduce_grad_meta); + m.impl("free_tensors", &dc::free_tensors_meta); + m.impl("reload_parameter", &dc::reload_parameter_meta); + m.impl("offload_parameter", &dc::offload_parameter_meta); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("set_persistent", &dc::set_persistent, "Set persistent flag for a parameter"); + m.def("enable_profiling", &dc::enable_profiling, "Enable profiling"); + m.def("is_profiling", &dc::is_profiling, "Check if profiling is enabled"); + m.def("init", &dc::init, "Set the process group"); + m.def("cleanup", &dc::cleanup, "Cleanup the process group"); + m.def("register_z1_param", &dc::register_z1_param, "Register a parameter"); + m.def("register_graph_z1", + &dc::register_graph_z1, + "Register graph with a list of ds parameter ids"); + m.def("register_z3_param", &dc::register_z3_param, "Register a parameter"); + m.def("register_graph_z3", + &dc::register_graph_z3, + "Register graph with a list of ds parameter ids"); + m.def("start_forward", &dc::start_forward, "Start forward pass"); + m.def("end_forward", &dc::end_forward, "End forward pass"); + m.def("start_backward", &dc::start_backward, "Start backward pass"); + // m.def("end_backward", &dc::end_backward, "End backward pass"); + m.def("cleanup", &dc::cleanup, "Clean up DeepCompile"); + m.def("reset", &dc::reset, "Reset the state"); + m.def("invalidate_gathered_param", &dc::invalidate_gathered_param, "Invalidate gathered param"); + m.def("clear_all_gathered_params", &dc::clear_all_gathered_params, "Clear all gathered params"); +} diff --git a/csrc/compile/util.cpp b/csrc/compile/util.cpp new file mode 100644 index 000000000000..948338028059 --- /dev/null +++ b/csrc/compile/util.cpp @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "deepcompile.h" + +#include + +namespace dc { + +std::string tensorToString(const at::Tensor& t, size_t max_elem, size_t max_str_len) +{ + auto t_cpu = t.flatten() + .slice(0, 0, std::min((int64_t)max_elem, t.numel())) + .to(c10::Device(c10::kCPU), false, true); + + size_t size = std::min(max_elem, productDim(t.sizes())); + + if (t.scalar_type() == c10::ScalarType::Half || t.scalar_type() == c10::ScalarType::BFloat16) { + auto float_ten = t_cpu.to(c10::ScalarType::Float, false, true).contiguous(); + return tensorPtrToString((float*)float_ten.data_ptr(), size, max_str_len); + } else if (t.scalar_type() == c10::ScalarType::Float) { + return tensorPtrToString((float*)t_cpu.data_ptr(), size, max_str_len); + } else if (t.scalar_type() == c10::ScalarType::Double) { + return tensorPtrToString((double*)t_cpu.data_ptr(), size, max_str_len); + } else if (t.scalar_type() == c10::ScalarType::Int) { + int* ptr = static_cast(t_cpu.data_ptr()); + return tensorPtrToString(ptr, size, max_str_len); + } else if (t.scalar_type() == c10::ScalarType::Long) { + long* ptr = static_cast(t_cpu.data_ptr()); + return tensorPtrToString(ptr, size, max_str_len); + } else if (t.scalar_type() == c10::ScalarType::Byte) { + unsigned char* ptr = static_cast(t_cpu.data_ptr()); + std::vector vec; + vec.reserve(size); + for (size_t i = 0; i < size; i++) { + vec.push_back(*ptr); + ptr++; + } + return tensorPtrToString(&vec[0], size, max_str_len); + } else if (t.scalar_type() == c10::ScalarType::Bool) { + bool* ptr = static_cast(t_cpu.data_ptr()); + std::vector vec; + vec.reserve(size); + for (size_t i = 0; i < size; i++) { + vec.push_back(*ptr); + ptr++; + } + return tensorPtrToString(&vec[0], size, max_str_len); + } + std::stringstream ss; + ss << "Failed to convert tensor to string. Invalid type of tensor: " + << toString(t.scalar_type()); + throw std::invalid_argument(ss.str()); +} + +std::string tensorPtrToString(void* ptr, + size_t size, + c10::ScalarType datatype, + size_t max_elem, + size_t max_str_len) +{ + int64_t elem_size = std::min((size_t)max_elem, size); + + if (datatype == c10::ScalarType::Long) { + return tensorPtrToString(static_cast(ptr), elem_size, max_str_len); + } else if (datatype == c10::ScalarType::Int) { + return tensorPtrToString(static_cast(ptr), elem_size, max_str_len); + } else if (datatype == c10::ScalarType::Double) { + return tensorPtrToString(static_cast(ptr), elem_size, max_str_len); + } else if (datatype == c10::ScalarType::Float) { + return tensorPtrToString(static_cast(ptr), elem_size, max_str_len); + } else if (datatype == c10::ScalarType::Half || datatype == c10::ScalarType::BFloat16) { + const auto ten = torch::from_blob(ptr, {(int64_t)elem_size}, datatype); + auto float_ten = ten.to(c10::ScalarType::Float, false, true).contiguous(); + return tensorPtrToString((float*)float_ten.data_ptr(), elem_size, max_str_len); + } + std::stringstream ss; + ss << "Failed to convert tensor ptr to string. Invalid type of tensor: " << toString(datatype); + throw std::invalid_argument(ss.str()); +} + +std::string tensorDimToString(const at::Tensor& t) +{ + const auto dim = t.sizes(); + return join_as_str(dim); +} +} // namespace dc diff --git a/csrc/compile/z1.cpp b/csrc/compile/z1.cpp new file mode 100644 index 000000000000..1fc90839862d --- /dev/null +++ b/csrc/compile/z1.cpp @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "z1.h" +#include "deepcompile.h" + +#define USE_C10D_NCCL + +#include +#include +#include +#include +#include +#include + +#include + +namespace dc { + +class Z1CustomOpExecutor : public CustomOpExecutor { +public: + Z1CustomOpExecutor(c10::intrusive_ptr process_group, + std::shared_ptr param_registry, + std::shared_ptr reduce_buckets, + std::vector ds_ids, + ncclComm_t nccl_comm, + at::cuda::CUDAStream rs_stream, + at::cuda::CUDAStream copy_stream, + bool pre_div_reduce) + : CustomOpExecutor(process_group, + param_registry, + reduce_buckets, + ds_ids, + nccl_comm, + rs_stream, + copy_stream, + pre_div_reduce) + { + } + ~Z1CustomOpExecutor() {} + + void endBackward() override + { + if (param_updated_) { + for (auto& it : has_acc_grad_) { it.second = false; } + } + } + + void flushReduceBucket(at::ScalarType scalar_type) override + { + int rank = process_group_->getRank(); + + if (!hasKey(reduce_tasks_, scalar_type)) { return; } + + int64_t tmp_recv_numel = 0; + for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) { + auto copy_done_event = rs_copy_done_events_.at(t.getDSId()); + copy_done_event->block(rs_stream_); + } + + ncclGroupStart(); + for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) { + ncclRedOp_t op = pre_div_reduce_ ? ncclSum : ncclAvg; + if (pre_div_reduce_) { + at::cuda::CUDAStreamGuard guard(rs_stream_); + t.getSendBuf().div_(process_group_->getSize()); + } + + // inplace + ncclResult_t result = ncclAllReduce(t.getSendBuf().data_ptr(), + t.getSendBuf().data_ptr(), + t.getSendBuf().numel(), + get_nccl_data_type(scalar_type), + op, + nccl_comm_, + rs_stream_); + if (result != ncclSuccess) { throw std::runtime_error("NCCL AllReduce failed"); } + } + ncclGroupEnd(); + + { + at::cuda::CUDAStreamGuard guard(rs_stream_); + for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) { + bool acc_grad = has_acc_grad_.at(t.getDSId()); + auto param = param_registry_->getParam(t.getDSId()); + auto grad_buf = param.getGradBuffer().flatten(); + + if (grad_buf.numel() == 0) { continue; } + + int64_t offset = param.getOffset(); + auto recv_buf = t.getSendBuf().flatten().index( + {torch::indexing::Slice(offset, offset + grad_buf.numel())}); + if (acc_grad) { + grad_buf.add_(recv_buf); + } else { + grad_buf.copy_(recv_buf); + } + has_acc_grad_[t.getDSId()] = true; + } + } + + reduce_buckets_->swap(scalar_type, rs_stream_, copy_stream_); + + // Not very sure if this is necessary + // Want to prevent grad tensor from being released before the copy is done + auto comp_stream = at::cuda::getCurrentCUDAStream(); + for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) { + auto copy_done_event = rs_copy_done_events_.at(t.getDSId()); + copy_done_event->block(comp_stream); + } + reduce_tasks_[scalar_type].clear(); + } +}; + +static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true); +static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true); + +void register_graph_z1(long graph_id, const std::vector& ds_ids) +{ + executors[graph_id] = std::make_shared(process_group, + param_registry, + reduce_buckets, + ds_ids, + nccl_comm, + rs_stream, + copy_stream, + pre_div_reduce); +} + +void register_z1_param(long ds_id, + const std::vector& ds_shape, + at::Tensor ds_tensor, + at::Tensor grad_buffer, + int64_t offset) +{ + param_registry->registerParam(ds_id, ds_shape, ds_tensor, grad_buffer, false, offset, false); +} + +} // namespace dc diff --git a/csrc/compile/z1.h b/csrc/compile/z1.h new file mode 100644 index 000000000000..a2f100565eba --- /dev/null +++ b/csrc/compile/z1.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "deepcompile.h" + +#pragma once + +namespace dc { + +void register_graph_z1(long graph_id, const std::vector& ds_ids); +void register_z1_param(long ds_id, + const std::vector& ds_shape, + at::Tensor ds_tensor, + at::Tensor grad_buffer, + int64_t offset); +} // namespace dc diff --git a/csrc/compile/z3.cpp b/csrc/compile/z3.cpp new file mode 100644 index 000000000000..523bcf2c04b4 --- /dev/null +++ b/csrc/compile/z3.cpp @@ -0,0 +1,544 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "z3.h" +#include "deepcompile.h" + +#define USE_C10D_NCCL + +#include +#include +#include +#include +#include +#include + +#include + +namespace dc { + +const size_t TIMEOUT_SYMMETRIC_MEMORY_BARRIER = 60000; + +class Z3CustomOpExecutor : public CustomOpExecutor { +public: + Z3CustomOpExecutor(c10::intrusive_ptr process_group, + std::shared_ptr param_registry, + std::shared_ptr reduce_buckets, + std::vector ds_ids, + ncclComm_t nccl_comm, + at::cuda::CUDAStream ag_stream, + at::cuda::CUDAStream rs_stream, + at::cuda::CUDAStream copy_stream, + at::cuda::CUDAStream offload_stream, + at::cuda::CUDAStream reload_stream, + bool pre_div_reduce) + : CustomOpExecutor(process_group, + param_registry, + reduce_buckets, + ds_ids, + nccl_comm, + rs_stream, + copy_stream, + pre_div_reduce), + ag_stream_(ag_stream), + offload_stream_(offload_stream), + reload_stream_(reload_stream) + { + for (long ds_id : ds_ids_) { + ag_comm_done_events_[ds_id] = + std::make_shared(cudaEventDisableTiming); + ag_comp_done_events_[ds_id] = + std::make_shared(cudaEventDisableTiming); + + param_use_count_[ds_id] = 0; + } + } + ~Z3CustomOpExecutor() {} + + void endBackward() override + { + if (param_updated_) { + for (auto& it : has_acc_grad_) { + it.second = false; + param_registry_->setValid(it.first, false); + } + } + + for (auto& it : reload_buffers_) { + it.second.record_stream(at::cuda::getCurrentCUDAStream()); + } + reload_buffers_.clear(); + } + + void launchAllGather(at::Tensor output_buf, + long ds_id, + c10::intrusive_ptr symm_mem) + { + const DSParam& param = param_registry_->getParam(ds_id); + const at::Tensor& ds_tensor = param.getDSTensor(); + + if (symm_mem == nullptr) { + ncclResult_t result = ncclAllGather(ds_tensor.contiguous().data_ptr(), + output_buf.data_ptr(), + ds_tensor.numel(), + get_nccl_data_type(ds_tensor.scalar_type()), + nccl_comm_, + ag_stream_); + + if (result != ncclSuccess) { throw std::runtime_error("NCCL AllGather failed"); } + } else { + at::cuda::CUDAStreamGuard guard(ag_stream_); + int world_size = process_group_->getSize(); + int rank = process_group_->getRank(); + + at::Tensor local_buf = + symm_mem->get_buffer(rank, ds_tensor.sizes(), ds_tensor.scalar_type(), 0); + local_buf.copy_(ds_tensor, true); + + symm_mem->barrier(0, TIMEOUT_SYMMETRIC_MEMORY_BARRIER); + auto chunks = output_buf.flatten().chunk(world_size); + for (int step = 0; step < world_size; step++) { + int remote_rank = (rank - step + world_size) % world_size; + auto src_buf = symm_mem->get_buffer( + remote_rank, ds_tensor.sizes(), ds_tensor.scalar_type(), 0); + chunks[remote_rank].copy_(src_buf.flatten(), true); + } + symm_mem->barrier(0, TIMEOUT_SYMMETRIC_MEMORY_BARRIER); + } + + param_registry_->registerGatheredParam(ds_id, output_buf); + param_registry_->setValid(ds_id, true); + } + + at::Tensor allgatherParam(long ds_id, + c10::intrusive_ptr symm_mem) + { + if (param_registry_->isValid(ds_id)) { return param_registry_->getGatheredParam(ds_id); } + + const DSParam& param = param_registry_->getParam(ds_id); + const at::Tensor& ds_tensor = param.getDSTensor(); + at::Tensor output_buf = param_registry_->hasGatheredParam(ds_id) + ? param_registry_->getGatheredParam(ds_id) + : torch::empty(param.getShape(), ds_tensor.options()); + + assert(hasKey(ag_comp_done_events_, ds_id)); + ag_comp_done_events_[ds_id]->record(); + ag_comp_done_events_[ds_id]->block(ag_stream_); + + launchAllGather(output_buf, ds_id, symm_mem); + + ag_comm_done_events_[ds_id]->record(ag_stream_); + return output_buf; + } + + void prefetchParamsFused(std::vector ds_ids, + c10::intrusive_ptr symm_mem) + { + std::vector invalid_ds_ids; + for (const auto& ds_id : ds_ids) { + if (!param_registry_->isValid(ds_id)) { invalid_ds_ids.push_back(ds_id); } + } + + std::unordered_map output_bufs; + for (long ds_id : invalid_ds_ids) { + const DSParam& param = param_registry_->getParam(ds_id); + if (param_registry_->hasGatheredParam(ds_id)) { + output_bufs[ds_id] = param_registry_->getGatheredParam(ds_id); + } else { + output_bufs[ds_id] = torch::empty(param.getShape(), param.getDSTensor().options()); + } + } + + for (long ds_id : invalid_ds_ids) { + ag_comp_done_events_[ds_id]->record(); + ag_comp_done_events_[ds_id]->block(ag_stream_); + } + + ncclGroupStart(); + for (long ds_id : invalid_ds_ids) { + assert(hasKey(output_bufs, ds_id)); + launchAllGather(output_bufs.at(ds_id), ds_id, symm_mem); + } + ncclGroupEnd(); + + for (long ds_id : invalid_ds_ids) { ag_comm_done_events_[ds_id]->record(ag_stream_); } + } + + void releaseParam(long ds_id, long n_users) + { + const DSParam& param = param_registry_->getParam(ds_id); + + assert(hasKey(param_use_count_, ds_id)); + if (param_use_count_[ds_id] == 0) { param_use_count_[ds_id] = n_users; } + param_use_count_[ds_id]--; + + if (param_use_count_[ds_id] == 0 && !param.isPersistent()) { + at::Tensor gathered_param = param_registry_->getGatheredParam(ds_id); + + if (gathered_param.defined()) { // gathered param is undefined while profiling + const auto options = gathered_param.options(); + at::Tensor empty_buffer = torch::empty({0}, options); + gathered_param.set_data(empty_buffer); + } + + param_registry_->unregisterGatheredParam(ds_id); + } + } + + at::Tensor waitAllgather(at::Tensor v, long ds_id) + { + assert(hasKey(ag_comm_done_events_, ds_id)); + ag_comm_done_events_[ds_id]->block(at::cuda::getCurrentCUDAStream()); + return v; + } + + void flushReduceBucket(at::ScalarType scalar_type) override + { + if (!hasKey(reduce_tasks_, scalar_type)) { return; } + + int64_t tmp_recv_numel = 0; + for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) { + auto copy_done_event = rs_copy_done_events_.at(t.getDSId()); + copy_done_event->block(rs_stream_); + + if (has_acc_grad_.at(t.getDSId())) { + tmp_recv_numel += param_registry_->getParam(t.getDSId()).getGradBuffer().numel(); + } + } + + at::Tensor tmp_recv_buf = at::Tensor(); + if (tmp_recv_numel > 0) { + at::cuda::CUDAStreamGuard guard(rs_stream_); + tmp_recv_buf = torch::empty({tmp_recv_numel}, + at::TensorOptions().dtype(scalar_type).device(at::kCUDA)); + } + + ncclGroupStart(); + int64_t offset = 0; + for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) { + auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer(); + + bool acc_grad = has_acc_grad_.at(t.getDSId()); + + if (acc_grad) { + recv_buf = + tmp_recv_buf.index({torch::indexing::Slice(offset, offset + recv_buf.numel())}); + } + + ncclRedOp_t op = pre_div_reduce_ ? ncclSum : ncclAvg; + if (pre_div_reduce_) { + at::cuda::CUDAStreamGuard guard(rs_stream_); + t.getSendBuf().div_(process_group_->getSize()); + } + ncclResult_t result = ncclReduceScatter(t.getSendBuf().data_ptr(), + recv_buf.data_ptr(), + recv_buf.numel(), + get_nccl_data_type(scalar_type), + op, + nccl_comm_, + rs_stream_); + if (result != ncclSuccess) { throw std::runtime_error("NCCL ReduceScatter failed"); } + + if (acc_grad) { offset += recv_buf.numel(); } + } + ncclGroupEnd(); + + { + at::cuda::CUDAStreamGuard guard(rs_stream_); + int64_t offset = 0; + for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) { + bool acc_grad = has_acc_grad_.at(t.getDSId()); + + if (acc_grad) { + auto recv_buf = param_registry_->getParam(t.getDSId()).getGradBuffer(); + recv_buf.add_(tmp_recv_buf.index( + {torch::indexing::Slice(offset, offset + recv_buf.numel())})); + offset += recv_buf.numel(); + } + has_acc_grad_[t.getDSId()] = true; + } + } + + reduce_buckets_->swap(scalar_type, rs_stream_, copy_stream_); + + // Not very sure if this is necessary + // Want to prevent grad tensor from being released before the copy is done + auto comp_stream = at::cuda::getCurrentCUDAStream(); + for (const ReduceTask& t : reduce_tasks_.at(scalar_type)) { + auto copy_done_event = rs_copy_done_events_.at(t.getDSId()); + copy_done_event->block(comp_stream); + } + reduce_tasks_[scalar_type].clear(); + + if (tmp_recv_numel > 0) { tmp_recv_buf.record_stream(rs_stream_); } + } + + at::Tensor offloadTensor(at::Tensor tensor, long id) + { + if (!hasKey(offload_events_, id)) { + offload_events_[id] = std::make_shared(cudaEventDisableTiming); + offload_comp_done_events_[id] = + std::make_shared(cudaEventDisableTiming); + + const auto options = at::TensorOptions().pinned_memory(true).device(torch::kCPU); + offload_buffers_[id] = at::empty_like(tensor, options); + } + + offload_comp_done_events_[id]->record(); + offload_comp_done_events_[id]->block(offload_stream_); + { + at::cuda::CUDAStreamGuard guard(offload_stream_); + offload_buffers_.at(id).copy_(tensor, true); + } + + tensor.record_stream(offload_stream_); + + offload_events_[id]->record(offload_stream_); + assert(hasKey(offload_buffers_, id)); + return offload_buffers_.at(id); + } + + at::Tensor reloadTensor(at::Tensor tensor, long id) + { + if (!hasKey(reload_events_, id)) { + reload_events_[id] = std::make_shared(cudaEventDisableTiming); + } + + assert(hasKey(offload_buffers_, id)); + offload_events_[id]->block(reload_stream_); + + at::Tensor ten; + { + at::cuda::CUDAStreamGuard guard(reload_stream_); + + assert(hasKey(offload_buffers_, id)); + at::Tensor buf = offload_buffers_.at(id); + const auto options = at::TensorOptions().device(torch::kCUDA); + ten = at::empty_like(buf, options); + ten.copy_(buf, true); + + reload_buffers_[id] = ten; + } + + reload_events_[id]->record(reload_stream_); + return ten; + } + + at::Tensor waitOffload(at::Tensor tensor, long id) + { + assert(hasKey(offload_events_, id)); + offload_events_[id]->block(at::cuda::getCurrentCUDAStream()); + + assert(hasKey(offload_buffers_, id)); + return offload_buffers_.at(id); + } + + at::Tensor waitReload(at::Tensor tensor, long id) + { + assert(hasKey(reload_events_, id)); + reload_events_[id]->block(at::cuda::getCurrentCUDAStream()); + + assert(hasKey(reload_buffers_, id)); + auto ten = reload_buffers_.at(id); + + // We can't release here because the tensor is still being used + // We will need "freeReloadedTensor" after the last user of the tensor to call + // ".record_stream". As it is a bit complicated, we clear the buffer and do at the end of + // the backward pass for now. reload_buffers_.erase(id); + return ten; + } + + void offloadParameter(at::Tensor tensor, long ds_id) { param_registry_->offload(ds_id); } + void reloadParameter(at::Tensor tensor, long ds_id) { param_registry_->reload(ds_id); } + + bool hasReloadBuffer(long id) { return hasKey(reload_buffers_, id); } + + bool hasParam(long ds_id) const { return hasKey(has_acc_grad_, ds_id); } + +private: + at::cuda::CUDAStream ag_stream_; + at::cuda::CUDAStream offload_stream_; + at::cuda::CUDAStream reload_stream_; + + std::unordered_map> ag_comp_done_events_; + std::unordered_map> ag_comm_done_events_; + + std::unordered_map> offload_events_; + std::unordered_map> offload_comp_done_events_; + std::unordered_map> reload_events_; + std::unordered_map offload_buffers_; + std::unordered_map reload_buffers_; + + std::unordered_map param_use_count_; +}; + +static at::cuda::CUDAStream ag_stream = at::cuda::getStreamFromPool(true); +static at::cuda::CUDAStream rs_stream = at::cuda::getStreamFromPool(true); +static at::cuda::CUDAStream copy_stream = at::cuda::getStreamFromPool(true); +static at::cuda::CUDAStream offload_stream = at::cuda::getStreamFromPool(true); +static at::cuda::CUDAStream reload_stream = at::cuda::getStreamFromPool(true); + +void register_graph_z3(long graph_id, const std::vector& ds_ids) +{ + executors[graph_id] = std::make_shared(process_group, + param_registry, + reduce_buckets, + ds_ids, + nccl_comm, + ag_stream, + rs_stream, + copy_stream, + offload_stream, + reload_stream, + pre_div_reduce); +} + +void register_z3_param(long ds_id, + const std::vector& ds_shape, + at::Tensor ds_tensor, + at::Tensor grad_buffer, + bool persistent) +{ + param_registry->registerParam(ds_id, ds_shape, ds_tensor, grad_buffer, true, 0, persistent); + if (persistent) { param_registry->registerGatheredParam(ds_id, ds_tensor); } +} + +at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id) +{ + auto executor = getExecutor(graph_id, executors); + + if (sync_before_allgather) { c10::cuda::device_synchronize(); } + auto ret = executor->allgatherParam(ds_id, symm_mem); + if (sync_after_allgather) { c10::cuda::device_synchronize(); } + return ret; +} + +void set_persistent(long ds_id) +{ + param_registry->setPersistent(ds_id, true); + + // Allocate buffer here + // Memory fragmentation will be more severe if we allocate in forward/backward + for (auto& it : executors) { + if (it.second->hasParam(ds_id)) { + auto executor = getExecutor(it.first, executors); + executor->allgatherParam(ds_id, symm_mem); + } + } +} + +void prefetch_params_fused(long graph_id, + const std::vector params, + const std::vector& ds_ids) +{ + auto executor = getExecutor(graph_id, executors); + executor->prefetchParamsFused(ds_ids, symm_mem); +} + +void prefetch_params_fused_meta(long graph_id, + const std::vector params, + const std::vector& ds_ids) +{ +} + +// for profiling +void invalidate_gathered_param(long ds_id) +{ + const DSParam& param = param_registry->getParam(ds_id); + if (param.isPersistent()) { return; } + + param_registry->unregisterGatheredParam(ds_id); + param_registry->registerGatheredParam(ds_id, at::Tensor()); +} + +void clear_all_gathered_params() +{ + for (const auto& it : param_registry->getParams()) { + long ds_id = it.first; + const DSParam& param = param_registry->getParam(ds_id); + if (param.isPersistent()) { continue; } + if (param_registry->hasGatheredParam(ds_id)) { + param_registry->unregisterGatheredParam(ds_id); + } + } +} + +at::Tensor allgather_param_meta(at::Tensor param_tensor, long graph_id, long ds_id) +{ + const DSParam& param = param_registry->getParam(ds_id); + auto options = param.getDSTensor().options().device(c10::kMeta); + at::Tensor output_buf = torch::empty(param.getShape(), options); + return output_buf; +} + +at::Tensor release_param(at::Tensor dummy, long graph_id, long ds_id, long n_users) +{ + auto executor = getExecutor(graph_id, executors); + executor->releaseParam(ds_id, n_users); + + if (clone_custom_op_output) { return dummy.clone(); } + return dummy; +} + +at::Tensor release_param_meta(at::Tensor dummy, long graph_id, long ds_id, long n_users) +{ + return dummy; +} + +at::Tensor wait_allgather(at::Tensor v, long graph_id, long ds_id) +{ + auto executor = getExecutor(graph_id, executors); + executor->waitAllgather(v, ds_id); + return v; +} + +at::Tensor wait_allgather_meta(at::Tensor v, long graph_id, long ds_id) { return v; } + +at::Tensor offload_tensor(at::Tensor tensor, long graph_id, long id) +{ + auto executor = getExecutor(graph_id, executors); + return executor->offloadTensor(tensor, id); +} + +at::Tensor reload_tensor(at::Tensor tensor, long graph_id, long id) +{ + auto executor = getExecutor(graph_id, executors); + return executor->reloadTensor(tensor, id); +} + +at::Tensor wait_offload(at::Tensor tensor, long graph_id, long id) +{ + auto executor = getExecutor(graph_id, executors); + return executor->waitOffload(tensor, id); +} + +at::Tensor wait_reload(at::Tensor tensor, long graph_id, long id) +{ + auto executor = getExecutor(graph_id, executors); + if (profile && !executor->hasReloadBuffer(id)) { return tensor; } + return executor->waitReload(tensor, id); +} + +at::Tensor test_call(at::Tensor a) +{ + std::cout << "test_call" << std::endl; + return a; +} + +void reload_parameter(at::Tensor tensor, long graph_id, long ds_id) +{ + auto executor = getExecutor(graph_id, executors); + executor->reloadParameter(tensor, ds_id); +} + +void offload_parameter(at::Tensor tensor, long graph_id, long ds_id) +{ + auto executor = getExecutor(graph_id, executors); + executor->offloadParameter(tensor, ds_id); +} +void reload_parameter_meta(at::Tensor param_tensor, long graph_id, long ds_id) {} +void offload_parameter_meta(at::Tensor tensor, long graph_id, long ds_id) {} + +} // namespace dc diff --git a/csrc/compile/z3.h b/csrc/compile/z3.h new file mode 100644 index 000000000000..1031f0c84f7c --- /dev/null +++ b/csrc/compile/z3.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "deepcompile.h" + +#pragma once + +namespace dc { + +void register_graph_z3(long graph_id, const std::vector& ds_ids); +void register_graph_ops_z3(long graph_id, + const std::vector& op_names, + const std::vector& n_args); +void register_bwd_graph_ops_z3(long graph_id, + const std::vector& op_names, + const std::vector& n_args); +void register_z3_param(long ds_id, + const std::vector& ds_shape, + at::Tensor ds_tensor, + at::Tensor grad_buffer, + bool persistent); +at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id); +void set_persistent(long ds_id); +void prefetch_params_fused(long graph_id, + const std::vector params, + const std::vector& ds_ids); +void prefetch_params_fused_meta(long graph_id, + const std::vector params, + const std::vector& ds_ids); +// for profiling +void invalidate_gathered_param(long ds_id); +void clear_all_gathered_params(); +at::Tensor allgather_param_meta(at::Tensor param_tensor, long graph_id, long ds_id); +at::Tensor release_param(at::Tensor dummy, long graph_id, long ds_id, long n_users); +at::Tensor release_param_meta(at::Tensor dummy, long graph_id, long ds_id, long n_users); +at::Tensor wait_allgather(at::Tensor v, long graph_id, const long ds_id); +at::Tensor wait_allgather_meta(at::Tensor v, long graph_id, long ds_id); +at::Tensor offload_tensor(at::Tensor tensor, long graph_id, long id); +at::Tensor reload_tensor(at::Tensor tensor, long graph_id, long id); +at::Tensor wait_offload(at::Tensor tensor, long graph_id, long id); +at::Tensor wait_reload(at::Tensor tensor, long graph_id, long id); +void reload_parameter(at::Tensor tensor, long graph_id, long id); +void offload_parameter(at::Tensor tensor, long graph_id, long id); +void reload_parameter_meta(at::Tensor tensor, long graph_id, long id); +void offload_parameter_meta(at::Tensor tensor, long graph_id, long id); +} // namespace dc diff --git a/csrc/fp_quantizer/fp_quantize.cpp b/csrc/fp_quantizer/fp_quantize.cpp index 903d84270d32..1a887b50e1a3 100644 --- a/csrc/fp_quantizer/fp_quantize.cpp +++ b/csrc/fp_quantizer/fp_quantize.cpp @@ -24,7 +24,6 @@ at::Tensor quantize(torch::Tensor& out, torch::Tensor& val, - torch::Tensor& scale, int group_size, int stochastic_rounding, int q_bits, @@ -60,7 +59,6 @@ at::Tensor quantize(torch::Tensor& out, void dequantize(torch::Tensor& val, torch::Tensor& val_q, - torch::Tensor& scale, int group_size, int q_mantisa_bits, int q_exponent_bits) diff --git a/csrc/includes/deepcompile.h b/csrc/includes/deepcompile.h new file mode 100644 index 000000000000..57c8fb89104e --- /dev/null +++ b/csrc/includes/deepcompile.h @@ -0,0 +1,576 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#define NOMINMAX // Windows idiosyncrasy + // https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c + +#define USE_C10D_NCCL + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace dc { + +template +static bool hasKey(const std::unordered_map& map, const K& key) +{ + return map.find(key) != map.end(); +} + +template +inline std::string to_string(const T& v) +{ + std::stringstream ss; + ss << v; + return ss.str(); +} + +template +size_t productDim(const L& dim) +{ + size_t prod = 1; + for (auto d : dim) { prod *= d; } + return prod; +} + +template +std::string join_as_str(const T& v, const char* delim = ",", const size_t maxlen = 0) +{ + std::stringstream ss; + + if (!v.empty()) { + auto it = v.begin(); + ss << to_string(*it); + it++; + for (; it != v.end(); ++it) { + if (delim) ss << delim; + ss << to_string(*it); + } + } + + std::string s = ss.str(); + if (maxlen > 0 && s.length() > maxlen) { s = s.substr(0, maxlen) + " ..."; } + + return "[" + s + "]"; +} + +template +std::string tensorPtrToString(T* ptr, size_t size, size_t str_len = 100) +{ + std::vector vals; + for (size_t i = 0; i < size; i++) { + vals.push_back(*ptr); + ptr++; + } + return join_as_str(vals, ",", str_len); +} + +std::string tensorPtrToString(void* ptr, + size_t size, + c10::ScalarType datatype, + size_t max_elem = 20, + size_t max_str_len = 100); + +std::string tensorToString(const at::Tensor& t, size_t max_elem = 20, size_t max_str_len = 100); + +std::string tensorDimToString(const at::Tensor& t); + +at::Tensor test_call(at::Tensor param); + +extern c10::intrusive_ptr process_group; +extern c10::intrusive_ptr symm_mem; +extern ncclComm_t nccl_comm; +extern bool use_symm_mem; +extern bool clone_custom_op_output; +extern bool profile; +extern bool pre_div_reduce; + +extern bool sync_before_reduce; // for debugging +extern bool sync_after_reduce; // for debugging +extern bool sync_before_allgather; // for debugging +extern bool sync_after_allgather; // for debugging + +std::vector sizes_to_int_vector(at::IntArrayRef sizes); +void enable_profiling(bool enable); +bool is_profiling(); + +c10::intrusive_ptr getSymmMemWorkspace(int64_t size); +void lazy_init_symm_memory(); +ncclDataType_t get_nccl_data_type(at::ScalarType scalar_type); +void cleanup(); + +class ReduceTask { +public: + ReduceTask(long ds_id, at::Tensor grad, at::Tensor send_buf) + : ds_id_(ds_id), grad_(std::move(grad)), send_buf_(std::move(send_buf)) + { + } + + long getDSId() const { return ds_id_; } + at::Tensor getSendBuf() const { return send_buf_; } + +private: + long ds_id_; + at::Tensor grad_; + at::Tensor send_buf_; +}; + +class ReduceBucket { +public: + ReduceBucket(int64_t size, at::ScalarType scalar_type) : size_(size), scalar_type_(scalar_type) + { + buffer_ = torch::empty({size}, at::TensorOptions().dtype(scalar_type).device(at::kCUDA)); + offset_ = 0; + } + + int64_t getSize() const { return size_; } + int64_t getOffset() const { return offset_; } + at::Tensor getBuffer() const { return buffer_; } + at::ScalarType getScalarType() const { return scalar_type_; } + + void reserve(int64_t size) + { + if (size > size_) { + buffer_ = + torch::empty({size}, at::TensorOptions().dtype(scalar_type_).device(at::kCUDA)); + size_ = size; + } + } + + at::Tensor allocate(int64_t numel) + { + if (offset_ + numel > size_) { + throw std::runtime_error("Buffer size exceeds the reduce bucket size"); + } + + at::Tensor result = buffer_.index({torch::indexing::Slice(offset_, offset_ + numel)}); + offset_ += numel; + return result; + } + + bool shouldFlush(int64_t numel) { return offset_ > 0 && offset_ + numel > size_; } + + void reset() { offset_ = 0; } + +private: + int64_t size_; + int64_t offset_; + at::Tensor buffer_; + at::ScalarType scalar_type_; +}; + +class DoubleBufferedReduceBucket { +public: + DoubleBufferedReduceBucket(int64_t initial_bucket_size, bool enable_double_buffer) + : initial_bucket_size_(initial_bucket_size), enable_double_buffer_(enable_double_buffer) + { + } + + void swap(at::ScalarType scalar_type, + at::cuda::CUDAStream rs_stream, + at::cuda::CUDAStream copy_stream) + { + assert(hasKey(current_buffer_, scalar_type)); + assert(hasKey(current_buffer_events_, scalar_type)); + + current_buffer_.at(scalar_type)->reset(); + current_buffer_events_.at(scalar_type)->record(rs_stream); + + if (enable_double_buffer_) { + assert(hasKey(shadow_buffer_, scalar_type)); + assert(hasKey(shadow_buffer_events_, scalar_type)); + + auto tmp = current_buffer_.at(scalar_type); + current_buffer_[scalar_type] = shadow_buffer_.at(scalar_type); + shadow_buffer_[scalar_type] = tmp; + + auto tmp_event = current_buffer_events_.at(scalar_type); + current_buffer_events_[scalar_type] = shadow_buffer_events_.at(scalar_type); + shadow_buffer_events_[scalar_type] = tmp_event; + } + } + + std::shared_ptr getBuffer(at::ScalarType scalar_type) + { + if (!hasKey(current_buffer_, scalar_type)) { + current_buffer_[scalar_type] = + std::make_shared(initial_bucket_size_, scalar_type); + current_buffer_events_[scalar_type] = + std::make_shared(cudaEventDisableTiming); + + if (enable_double_buffer_) { + shadow_buffer_[scalar_type] = + std::make_shared(initial_bucket_size_, scalar_type); + shadow_buffer_events_[scalar_type] = + std::make_shared(cudaEventDisableTiming); + } + } + + return current_buffer_.at(scalar_type); + } + + std::shared_ptr getEvent(at::ScalarType scalar_type) + { + assert(hasKey(current_buffer_events_, scalar_type)); + return current_buffer_events_.at(scalar_type); + } + + void clear() + { + current_buffer_.clear(); + shadow_buffer_.clear(); + current_buffer_events_.clear(); + shadow_buffer_events_.clear(); + } + +private: + int64_t initial_bucket_size_; + bool enable_double_buffer_; + std::unordered_map> current_buffer_; + std::unordered_map> shadow_buffer_; + std::unordered_map> current_buffer_events_; + std::unordered_map> shadow_buffer_events_; +}; + +class DSParam { +public: + DSParam(long id, + std::vector ds_shape, + at::Tensor ds_tensor, + at::Tensor grad_buffer, + bool partitioned, + int64_t offset, // for Z1 + bool persistent // for Z3 + ) + : id_(id), + shape_(std::move(ds_shape)), + ds_tensor_(ds_tensor), + grad_buffer_(grad_buffer), + partitioned_(partitioned), + offset_(offset), + persistent_(persistent), + offload_stream_(at::cuda::getStreamFromPool()), + reload_stream_(at::cuda::getStreamFromPool()) + { + } + + long getId() const { return id_; } + std::vector getShape() const { return shape_; } + at::Tensor getDSTensor() const + { + // If the reload event exists and is complete, return the reloaded tensor (if defined) + if (reload_done_event_) { + if (!reload_done_event_->query()) { + reload_done_event_->block(at::cuda::getCurrentCUDAStream()); + } + if (ds_reload_tensor_.defined()) { return ds_reload_tensor_; } + } + // Otherwise, if an offload event exists, wait for it to complete + if (offload_done_event_) { + if (!offload_done_event_->query()) { + offload_done_event_->block(at::cuda::getCurrentCUDAStream()); + } + } + return ds_tensor_; + } + at::Tensor getGradBuffer() const { return grad_buffer_; } + bool isPartitioned() const { return partitioned_; } + int64_t getOffset() const { return offset_; } + void setPersistent(bool persistent) { persistent_ = persistent; } + bool isPersistent() const { return persistent_; } + + void offload() + { + // If a reloaded tensor exists, offload its data back to ds_tensor_ + if (ds_reload_tensor_.defined()) { + auto comp_stream = at::cuda::getCurrentCUDAStream(); + comp_done_event_ = std::make_shared(cudaEventDisableTiming); + // Record completion and wait on the offload stream + comp_done_event_->record(comp_stream); + comp_done_event_->block(offload_stream_); + offload_done_event_ = std::make_shared(cudaEventDisableTiming); + + { + at::cuda::CUDAStreamGuard guard(offload_stream_); + ds_tensor_.copy_(ds_reload_tensor_, /*non_blocking=*/true); + ds_reload_tensor_.reset(); // Clear the reloaded tensor + offload_done_event_->record(offload_stream_); + } + // Reset the reload event to indicate that no valid reload is present. + if (reload_done_event_) { reload_done_event_.reset(); } + } + } + + void reload() + { + // Reload only if the current ds_tensor_ is on CPU + if (ds_tensor_.device().is_cpu()) { + auto comp_stream = at::cuda::getCurrentCUDAStream(); + comp_done_event_ = std::make_shared(cudaEventDisableTiming); + // Record and wait on the reload stream + comp_done_event_->record(comp_stream); + comp_done_event_->block(reload_stream_); + reload_done_event_ = std::make_shared(cudaEventDisableTiming); + + { + at::cuda::CUDAStreamGuard guard(reload_stream_); + ds_reload_tensor_ = + at::empty_like(ds_tensor_, ds_tensor_.options().device(torch::kCUDA)); + ds_reload_tensor_.copy_(ds_tensor_, /*non_blocking=*/true); + reload_done_event_->record(reload_stream_); + } + // Reset offload_done_event if it exists to clear any stale offload state. + if (offload_done_event_) { offload_done_event_.reset(); } + } + } + +private: + long id_; + std::vector shape_; + at::Tensor ds_tensor_; + at::Tensor ds_reload_tensor_; + at::Tensor grad_buffer_; + bool partitioned_; + int64_t offset_; // for Z1 + bool persistent_; // for Z3 + mutable bool is_reloaded = false; + + at::cuda::CUDAStream offload_stream_; + at::cuda::CUDAStream reload_stream_; + std::shared_ptr comp_done_event_; + std::shared_ptr offload_done_event_; + std::shared_ptr reload_done_event_; +}; + +class DSParamRegistry { +public: + DSParamRegistry() {} + ~DSParamRegistry() {} + + void registerParam(long ds_id, + const std::vector& ds_shape, + at::Tensor ds_tensor, + at::Tensor grad_buffer, + bool partitioned, + int64_t offset, // for Z1 + bool persistent // for Z3 + ) + { + grad_buffer.zero_(); + params_.emplace( + ds_id, + DSParam(ds_id, ds_shape, ds_tensor, grad_buffer, partitioned, offset, persistent)); + valid_[ds_id] = false; + } + + void registerGatheredParam(long ds_id, at::Tensor ds_tensor) + { + gathered_params_.emplace(ds_id, ds_tensor); + } + + void unregisterGatheredParam(long ds_id) + { + assert(hasKey(gathered_params_, ds_id)); + gathered_params_.erase(ds_id); + valid_[ds_id] = false; + } + + const std::unordered_map& getParams() const { return params_; } + + const DSParam& getParam(long ds_id) const { return params_.at(ds_id); } + const size_t getNumParams() const { return params_.size(); } + const at::Tensor& getGatheredParam(long ds_id) const + { + assert(hasKey(gathered_params_, ds_id)); + return gathered_params_.at(ds_id); + } + bool hasGatheredParam(long ds_id) const { return hasKey(gathered_params_, ds_id); } + void setPersistent(long ds_id, bool persistent) { params_.at(ds_id).setPersistent(persistent); } + void offload(long ds_id) { params_.at(ds_id).offload(); } + void reload(long ds_id) { params_.at(ds_id).reload(); } + + void setValid(long ds_id, bool valid) { valid_[ds_id] = valid; } + bool isValid(long ds_id) const + { + assert(hasKey(valid_, ds_id)); + return valid_.at(ds_id); + } + +private: + std::unordered_map params_; + std::unordered_map gathered_params_; + std::unordered_map valid_; +}; + +class CustomOpExecutor { +public: + CustomOpExecutor(c10::intrusive_ptr process_group, + std::shared_ptr param_registry, + std::shared_ptr reduce_buckets, + std::vector ds_ids, + ncclComm_t nccl_comm, + at::cuda::CUDAStream rs_stream, + at::cuda::CUDAStream copy_stream, + bool pre_div_reduce) + : process_group_(process_group), + param_registry_(std::move(param_registry)), + reduce_buckets_(std::move(reduce_buckets)), + ds_ids_(std::move(ds_ids)), + nccl_comm_(nccl_comm), + rs_stream_(rs_stream), + copy_stream_(copy_stream), + pre_div_reduce_(pre_div_reduce) + { + for (long ds_id : ds_ids_) { + has_acc_grad_[ds_id] = false; + + rs_comp_done_events_[ds_id] = + std::make_shared(cudaEventDisableTiming); + rs_copy_done_events_[ds_id] = + std::make_shared(cudaEventDisableTiming); + } + reduce_counter_ = ds_ids_.size(); + } + ~CustomOpExecutor() {} + + virtual void startForward() {} + + virtual void endForward() {} + + virtual void startBackward(bool update) { param_updated_ = update; } + + virtual void endBackward() {} + + at::Tensor reduceGrad(at::Tensor grad_tensor, long ds_id) + { + int world_size = process_group_->getSize(); + const DSParam& param = param_registry_->getParam(ds_id); + const auto scalar_type = grad_tensor.scalar_type(); + std::shared_ptr reduce_bucket = reduce_buckets_->getBuffer(scalar_type); + + auto comp_stream = at::cuda::getCurrentCUDAStream(); + + if (reduce_bucket->shouldFlush(grad_tensor.numel())) { + int rank = process_group_->getRank(); + + flushReduceBucket(scalar_type); + + // reduce_bucket is swapped in flushReduceBucket if double buffering is enabled + reduce_bucket = reduce_buckets_->getBuffer(scalar_type); + } + + if (grad_tensor.numel() > reduce_bucket->getSize()) { + // extend buckets + at::cuda::stream_synchronize(rs_stream_); + reduce_bucket->reserve(grad_tensor.numel()); + } + + at::Tensor reduce_in_buffer = reduce_bucket->allocate(grad_tensor.numel()); + + // This ensures the order of reduce_scatter -> copy + // Without this block, copy may start while reduce_scatter is still running + reduce_buckets_->getEvent(scalar_type)->block(comp_stream); + auto copy_src = grad_tensor.contiguous().view({-1}).detach(); + // keep references to copy src + reduce_tasks_[scalar_type].emplace_back(ds_id, copy_src, reduce_in_buffer); + + // computation must be done before copy + rs_comp_done_events_[ds_id]->record(comp_stream); + rs_comp_done_events_[ds_id]->block(copy_stream_); + { + at::cuda::CUDAStreamGuard guard(copy_stream_); + reduce_in_buffer.copy_(copy_src, true); + rs_copy_done_events_[ds_id]->record(copy_stream_); + } + + reduce_counter_--; + + if (reduce_counter_ == 0) { + flushAllReduceBuckets(); + + reduce_counter_ = ds_ids_.size(); + + // This synchronization ensures all of reduce calls are done before optimizer's step. + at::cuda::stream_synchronize(rs_stream_); + + endBackward(); + } + + return at::Tensor(); + } + + bool hasParam(long ds_id) const { return hasKey(has_acc_grad_, ds_id); } + +protected: + c10::intrusive_ptr process_group_; + std::shared_ptr param_registry_; + std::shared_ptr reduce_buckets_; + std::vector ds_ids_; + ncclComm_t nccl_comm_; + at::cuda::CUDAStream rs_stream_; + at::cuda::CUDAStream copy_stream_; + + std::unordered_map> rs_comp_done_events_; + std::unordered_map> rs_copy_done_events_; + + size_t reduce_counter_ = 0; + bool param_updated_ = false; + std::unordered_map> reduce_tasks_; + std::unordered_map has_acc_grad_; + bool pre_div_reduce_; + + virtual void flushReduceBucket(at::ScalarType scalar_type) = 0; + + void flushAllReduceBuckets() + { + for (const auto& it : reduce_tasks_) { flushReduceBucket(it.first); } + } +}; + +template +std::shared_ptr getExecutor(long graph_id, + const std::unordered_map>& executors) +{ + assert(hasKey(executors, graph_id)); + if (auto executor = std::dynamic_pointer_cast(executors.at(graph_id))) { return executor; } + throw std::runtime_error("Invalid executor type"); +} + +extern std::shared_ptr param_registry; +extern std::unordered_map> executors; +extern std::shared_ptr reduce_buckets; + +at::Tensor reduce_grad(at::Tensor grad_tensor, long graph_id, long ds_id); +at::Tensor reduce_grad_meta(at::Tensor grad_tensor, long graph_id, long ds_id); +void free_tensors(std::vector tensors); +void free_tensors_meta(std::vector tensors); + +void init(c10::intrusive_ptr pg, + int64_t initial_reduce_bucket_size, + bool enable_double_buffer, + bool _use_symm_mem, + bool _clone_custom_op_output, + bool _sync_before_reduce, + bool _sync_after_reduce, + bool _sync_before_allgather, + bool _sync_after_allgather); +void reset(); +void cleanup(); + +void start_forward(); +void end_forward(); +void start_backward(bool update); + +} // namespace dc diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index dc7ff4d1e7c0..97857bc3f70b 100644 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -13,8 +13,14 @@ namespace cg = cooperative_groups; // only used to avoid compilation error due to lack of definition. #ifndef BF16_AVAILABLE +#if defined(__CUDA_BF16_H__) +static_assert(sizeof(__nv_bfloat162) == sizeof(__half2), + "CUDA's __nv_bfloat162 doesn't match __half2 size"); +#else +// Fallback to simple typedef only if CUDA doesn't provide it using __nv_bfloat162 = __half2; #endif +#endif inline __device__ float gelu(const float x) { diff --git a/csrc/transformer/inference/csrc/transform.cu b/csrc/transformer/inference/csrc/transform.cu index 8bc5a94e16ee..e7624363021e 100644 --- a/csrc/transformer/inference/csrc/transform.cu +++ b/csrc/transformer/inference/csrc/transform.cu @@ -12,8 +12,14 @@ namespace cg = cooperative_groups; // only used to avoid compilation error due to lack of definition. #ifndef BF16_AVAILABLE +#if defined(__CUDA_BF16_H__) +static_assert(sizeof(__nv_bfloat162) == sizeof(__half2), + "CUDA's __nv_bfloat162 doesn't match __half2 size"); +#else +// Fallback to simple typedef only if CUDA doesn't provide it using __nv_bfloat162 = __half2; #endif +#endif // Bias add diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index fd1f421b8954..c7e01a0ac969 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -366,7 +366,7 @@ def init_inference(model, config=None, **kwargs): return engine -def tp_model_init(model, tp_size, dtype): +def tp_model_init(model, tp_size, dtype, config=None, **kwargs): """ Initialize the model for tensor parallelism. @@ -379,8 +379,9 @@ def tp_model_init(model, tp_size, dtype): torch.nn.Module: The initialized model with tensor parallelism. """ # avoid re-entry - assert not hasattr( - model, 'ds_autotp_parsed'), "ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed." + if hasattr(model, 'ds_autotp_parsed'): + logger.warning("ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed.") + return set_autotp_mode(training=True) diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py index 800bc6078c3a..5a5fc6c42442 100755 --- a/deepspeed/comm/comm.py +++ b/deepspeed/comm/comm.py @@ -621,6 +621,17 @@ def initialize_mesh_device(mesh_shape, mesh_dim_names): return mesh_device +def enable_symm_mem_for_group(group_name: str): + global cdb + assert cdb is not None and cdb.is_initialized( + ), 'DeepSpeed backend not set, please initialize it using init_process_group()' + + if hasattr(cdb, 'enable_symm_mem_for_group'): + cdb.enable_symm_mem_for_group(group_name) + else: + raise RuntimeError(f"Backend {cdb.name} does not support symmetric memory initialization") + + # Main DeepSpeed Comms. public API. def init_distributed(dist_backend=None, auto_mpi_discovery=True, diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py index efa0640fb87b..d955d56f5e2f 100755 --- a/deepspeed/comm/torch.py +++ b/deepspeed/comm/torch.py @@ -409,6 +409,13 @@ def init_device_mesh(self, mesh_shape, mesh_dim_names): mesh_shape, mesh_dim_names=mesh_dim_names) + def enable_symm_mem_for_group(self, group_name): + if not required_torch_version(min_version=2.5): + raise RuntimeError(f"Torch version must be 2.5 or higher to use symmetric memory. " + f"Current version: {torch.__version__}") + from torch.distributed._symmetric_memory import enable_symm_mem_for_group + return enable_symm_mem_for_group(group_name) + # This will become a light-weight wrapper around torch.distributed functions # TODO: create some example to show how this wrapper can help profile communication diff --git a/deepspeed/compile/__init__.py b/deepspeed/compile/__init__.py new file mode 100644 index 000000000000..208299fb8c50 --- /dev/null +++ b/deepspeed/compile/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team diff --git a/deepspeed/compile/backend.py b/deepspeed/compile/backend.py new file mode 100644 index 000000000000..ee33447aaf4a --- /dev/null +++ b/deepspeed/compile/backend.py @@ -0,0 +1,279 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Dict, List, Callable +import time +import gc + +import torch +from torch.fx import Graph, GraphModule + +try: + import torch.utils._pytree as pytree + import torch._dynamo + import torch._inductor.scheduler + from functorch.compile import make_boxed_func + from torch._functorch.aot_autograd import aot_module_simplified + from torch._subclasses.fake_tensor import unset_fake_temporarily +except ImportError: + pass + +from deepspeed.accelerator import get_accelerator + +from .fx import add_free_activations +from .graph_param import DSGraphParamManager +from .profilers import ProfilingResult +from .profilers.graph_profile import MemoryProfilingInterpreter +from .patch_compiled_func import patch_compiled_func, unpatch_compiled_func, get_backward_inputs +from .util import get_input_nodes, get_activation_node_names, get_index_by_graph_id, get_deepcompile_handle, log_rank0 +from .partitioner import get_wrapped_partitioner +from .inductor import register_custom_ops, patch_create_aot_dispatcher_function + +remaining_schedule = None +next_pass_step = -1 +next_passes = None +current_passes = None + +param_manager: Dict[int, DSGraphParamManager] = {} +graph_order = [] +profiling_results: Dict[int, ProfilingResult] = {} +opt_pass_times = [] + +opt_passes = {} + +fwd_real_inputs = [] +remaining_bwd_compile_count = 0 + + +def register_compile_pass(name: str, opt_pass_fn): + opt_passes[name] = opt_pass_fn + + +def init_schedule(schedule): + + assert isinstance(schedule, list), f"schedule should be a list, but got {type(schedule)}" + + for step, passes in schedule: + assert isinstance(step, int), f"Each step in schedule should be an integer, but got {type(step)}" + assert isinstance(passes, list), f"Passes at a certain step should be a list, but got {type(passes)}" + + global remaining_schedule + remaining_schedule = schedule + + +def launch_compile_passes(global_steps: int): + global next_pass_step, next_passes + + if len(remaining_schedule) > 0 and global_steps == remaining_schedule[0][0]: + _, next_passes = remaining_schedule.pop(0) + log_rank0(f"Launching compile passes: global_steps={global_steps} passes={next_passes}", True) + + torch._dynamo.reset() + get_deepcompile_handle().reset() + patch_compiled_func() + graph_order.clear() + profiling_results.clear() + param_manager.clear() + + +def set_time_and_tensor_size(graph_id, graph: Graph, mem, bwd, profiling_results): + node_time = [] + tensor_sizes = [] + + for n in graph.nodes: + node_time.append((n.name, n.meta["device_time"] if "device_time" in n.meta else 0.0, + n.meta["wall_time"] if "wall_time" in n.meta else 0.0)) + tensor_sizes.append((n.name, n.meta["tensor_size"] if "tensor_size" in n.meta else 0)) + + if bwd: + profiling_results[graph_id].bwd_graph = graph + profiling_results[graph_id].bwd_time = node_time + profiling_results[graph_id].bwd_tensor_sizes = tensor_sizes + profiling_results[graph_id].bwd_mem = mem + else: + profiling_results[graph_id].fwd_graph = graph + profiling_results[graph_id].fwd_time = node_time + profiling_results[graph_id].fwd_tensor_sizes = tensor_sizes + profiling_results[graph_id].fwd_mem = mem + + +def run_opt_passes(opt_passes: List[Callable], + gm: GraphModule, + graph_id: int, + graph_order: List[int], + profiling_results, + create_inputs_fn, + mem_budget: float, + param_manager, + bwd: bool, + debug_log=False) -> None: + + with unset_fake_temporarily(): + get_accelerator().synchronize() + gc.collect() + get_accelerator().empty_cache() + + for i, opt_pass_fn in enumerate(opt_passes): + log_rank0(f"Running opt pass {i} for graph {graph_id}. bwd={bwd}", enable=debug_log) + + gm_new = opt_pass_fn(gm, graph_id, graph_order, profiling_results, create_inputs_fn, mem_budget, param_manager, + bwd) + if gm_new is not None: + gm = gm_new + gm.graph.lint() + gm.recompile() + + mem_prof = MemoryProfilingInterpreter(gm, debug_log=debug_log) + mem_prof.run(*create_inputs_fn()) + mem = [(name, current_alloc, delta, peak) for name, current_alloc, delta, peak in mem_prof.mem_record] + + set_time_and_tensor_size(graph_id, gm.graph, mem, bwd, profiling_results) + + with unset_fake_temporarily(): + get_accelerator().synchronize() + gc.collect() + get_accelerator().empty_cache() + + +def make_backend(backend, compile_kwargs={}, free_activation=False, debug_log=False): + + register_custom_ops() + + def backend_fn(gm: GraphModule, real_inputs): + graph_id = id(gm.graph) + needs_backward = pytree.tree_any(lambda x: x.requires_grad if torch.is_tensor(x) else False, real_inputs) + + global graph_order + graph_order.append((graph_id, needs_backward)) + + z3_partition = any(hasattr(v, "ds_id") for v in real_inputs) + if z3_partition: + param_indices = [(i, input_val.ds_id, input_val.ds_shape) for i, input_val in enumerate(real_inputs) + if isinstance(input_val, torch.nn.Parameter)] + else: + assert all(hasattr(v, "param_id") for v in real_inputs + if isinstance(v, torch.nn.Parameter)), "All param inputs should have param_id" + param_indices = [(i, input_val.param_id, input_val.shape) for i, input_val in enumerate(real_inputs) + if isinstance(input_val, torch.nn.Parameter)] + + global fwd_real_inputs + fwd_real_inputs.append(real_inputs) + + global profiling_results + if graph_id not in profiling_results: + profiling_results[graph_id] = ProfilingResult() + profiling_results[graph_id].param_indices = param_indices + profiling_results[graph_id].needs_backward = needs_backward + + def make_fw_graph(gm, sample_inputs): + time_start = time.time() + graph_index = len(graph_order) - 1 + real_inputs = fwd_real_inputs.pop(0) + + param_manager[graph_id] = DSGraphParamManager(gm.graph, real_inputs, param_indices) + + real_inputs_with_rng = real_inputs + sample_inputs[len(real_inputs):] + run_opt_passes( + opt_passes=next_passes, + gm=gm, + graph_id=graph_id, + graph_order=graph_order, + profiling_results=profiling_results, + create_inputs_fn=lambda: real_inputs_with_rng, + mem_budget=.0, # unused + param_manager=param_manager, + bwd=False, + debug_log=debug_log) + + if needs_backward: + global remaining_bwd_compile_count + remaining_bwd_compile_count += 1 + + opt_pass_times.append(("fwd", graph_index, graph_id, time.time() - time_start)) + + log_rank0( + f"Fwd end {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()} graph={gm.graph}", + enable=debug_log) + + return gm.graph + + def make_bw_graph(gm, sample_inputs): + time_start = time.time() + + graph_index = get_index_by_graph_id(graph_order, graph_id) + log_rank0( + f"Bwd start {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()} graph={gm.graph}", + enable=debug_log) + + bwd_inputs_stack = get_backward_inputs() + + if len(bwd_inputs_stack) == 0: + # dynamo calls bw compiler ahead of time when symints are saved for backward. See the details for aot_dispatch_autograd in jit_compile_runtime_wrappers. + # As we currently use actually bwd input values in bw compiler, we return None to skip the compilation there. + # This would need be handled properly in the future. + return None + + bwd_real_inputs = bwd_inputs_stack.pop() + run_opt_passes( + opt_passes=next_passes, + gm=gm, + graph_id=graph_id, + graph_order=graph_order, + profiling_results=profiling_results, + create_inputs_fn=lambda: tuple(bwd_real_inputs), + mem_budget=.0, # unused + param_manager=param_manager, + bwd=True, + debug_log=debug_log) + + # assert graph_id in param_manager, f"Graph {graph_id} not found in param_manager" + + if free_activation: + param_nodes_bw, _ = param_manager[graph_id].get_bwd_mapping(gm.graph) + param_names = [n.name for n in param_nodes_bw] + non_param_input_names = [n.name for n in get_input_nodes(gm.graph) if n.name not in param_names] + add_free_activations(graph_id, gm.graph, + get_activation_node_names(gm.graph, param_nodes_bw, non_param_input_names)) + + global remaining_bwd_compile_count + remaining_bwd_compile_count -= 1 + if remaining_bwd_compile_count == 0: + unpatch_compiled_func() + + log_rank0( + f"Bwd end {graph_index} graph_id={graph_id} alloc_mem={get_accelerator().memory_allocated()} graph={gm.graph}", + enable=debug_log) + + gm.recompile() + opt_pass_times.append(("bwd", graph_index, graph_id, time.time() - time_start)) + + return gm.graph + + if backend == "eager": + + def make_compiler_fn(make_graph_fn): + + def compiler_fn(gm, sample_inputs): + return None if make_graph_fn(gm, sample_inputs) is None else make_boxed_func(gm.forward) + + return compiler_fn + + aot_mod = aot_module_simplified(gm, + real_inputs, + fw_compiler=make_compiler_fn(make_fw_graph), + bw_compiler=make_compiler_fn(make_bw_graph), + partition_fn=get_wrapped_partitioner(param_indices)) + return torch._dynamo.optimize(**compile_kwargs)(aot_mod) + elif backend == "inductor": + patch_create_aot_dispatcher_function(graph_id, z3_partition, make_fw_graph, make_bw_graph, real_inputs, + param_indices, param_manager) + from .partitioner import get_wrapped_choose_saved_values_set + torch._functorch.partitioners.choose_saved_values_set = get_wrapped_choose_saved_values_set(param_indices) + + return torch._inductor.compile(gm, real_inputs) + + raise ValueError(f"Unsupported backend {backend}") + + return backend_fn diff --git a/deepspeed/compile/config.py b/deepspeed/compile/config.py new file mode 100644 index 000000000000..d88458fc594e --- /dev/null +++ b/deepspeed/compile/config.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from deepspeed.runtime.config_utils import DeepSpeedConfigModel + + +class CompileConfig(DeepSpeedConfigModel): + """ Configure compile settings """ + + deepcompile: bool = False + """ Turn on/off the DeepCompile mode """ + + free_activation: bool = False + """ Turn on/off the free activation mode """ + + offload_activation: bool = False + """ Turn on/off the activation offloading """ + + offload_opt_states: bool = False + """ Turn on/off the optimizer states offloading """ + + double_buffer: bool = True + """ Turn on/off the double buffering """ + + symmetric_memory: bool = False + """ Turn on/off the symmetric memory """ + + debug_log: bool = False + """ Turn on/off the graph dumping """ + + offload_parameters: bool = False + """ Turn on/off the parameter offloading """ + + sync_before_reduce: bool = False + """ Turn on/off the sync before reduce """ + + sync_after_reduce: bool = False + """ Turn on/off the sync after reduce """ + + sync_before_allgather: bool = False + """ Turn on/off the sync before allgather """ + + sync_after_allgather: bool = False + """ Turn on/off the sync after allgather """ diff --git a/deepspeed/compile/fx.py b/deepspeed/compile/fx.py new file mode 100644 index 000000000000..3506aef7062b --- /dev/null +++ b/deepspeed/compile/fx.py @@ -0,0 +1,139 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Callable, Any, List +from collections import defaultdict + +import torch +from torch.fx import Node, Graph + +from .util import get_last_uses + + +def get_output_node(graph: Graph): + for v in graph.nodes: + if v.target == "output": + return v + raise ValueError("No output node found") + + +def move_primals_to_head(graph: Graph): + + # Move primals to the head of the graph + primals = [n for n in graph.nodes if n.op == "placeholder"] + non_primals = [n for n in graph.nodes if n.op != "placeholder"] + all_nodes = primals + non_primals + + new_graph = Graph() + env = {} + for node in all_nodes: + new_node = new_graph.node_copy(node, lambda n: env[n.name]) + env[node.name] = new_node + new_graph.lint() + + return new_graph + + +def add_args_process(graph: Graph, + node: Node, + fn: Callable[..., Any], + extra_args: List[int] = [], + name=None, + meta={}) -> List[Node]: + # Apply fn to all args of node + new_nodes = [] + with graph.inserting_before(node): + target_args = [arg for arg in node.args if isinstance(arg, Node)] + + for arg in target_args: + new_node = graph.create_node('call_function', fn, (arg, ) + tuple(extra_args), name=name) + for k, v in meta.items(): + new_node.meta[k] = v + node.replace_input_with(arg, new_node) + new_nodes.append(new_node) + + return new_nodes + + +def add_postprocess(graph: Graph, + node: Node, + fn: Callable[..., Any], + extra_args: List[int] = [], + name=None, + meta={}) -> Node: + # https://github.com/pytorch/examples/blob/main/fx/wrap_output_dynamically.py + with graph.inserting_after(node): + args = (node, ) + for a in extra_args: # To add ds_id + args += (a, ) + + node_users = node.users.keys() + new_node = graph.create_node('call_function', fn, args, {}, name=name) + users = {} + for u in node_users: + if u != new_node: + users[u] = (node, new_node) + for u, (old_in, new_in) in users.items(): + u.replace_input_with(old_in, new_in) + + for k, v in meta.items(): + new_node.meta[k] = v + + return new_node + + +def _make_node_meta(node: Node, ds_id: int, comm: bool): + meta = {"param_name": node.name, "ds_id": ds_id, "comm": comm} + if "tensor_meta" in node.meta: + meta["tensor_meta"] = node.meta["tensor_meta"] + return meta + + +def add_free_activations(graph_id: int, graph: Graph, activation_node_names: List[str]): + node_to_last_use, _ = get_last_uses(graph) + activation_nodes_set = set([n for n in graph.nodes if n.op == "placeholder" and n.name in activation_node_names]) + + offload_id_to_node = {} + node_to_wait_reload = {} + for node in graph.nodes: + if node.target == torch.ops.dc.reload_tensor.default: + offload_act = node.args[0] + # node_to_offload_id[offload_act] = node.args[2] + offload_id_to_node[node.args[2]] = offload_act + elif node.target == torch.ops.dc.wait_reload.default: + offload_id = node.args[2] + node_to_wait_reload[offload_id_to_node[offload_id]] = node + + activation_nodes_set = set(node_to_wait_reload[n] if n in node_to_wait_reload else n for n in activation_nodes_set) + + last_user_to_uses = defaultdict(list) + for node, last_user in node_to_last_use.items(): + last_user_to_uses[last_user].append(node) + + def _should_free(node: Node) -> bool: + if not hasattr(node, "meta"): + return False + if not "tensor_meta" in node.meta: + return False + return True + + def free_tensors(tensors: List[torch.Tensor]): + for a in tensors: + if a.numel() > 10_000_000: + a.data = torch.empty([0], device=a.device, dtype=a.dtype) + + for last_user, used_nodes in last_user_to_uses.items(): + activation_args = [an for an in used_nodes if an in activation_nodes_set and _should_free(an)] + + if len(activation_args) == 0: + continue + + node_name = f"free_activations_{[n.name for n in used_nodes]}" + with graph.inserting_after(last_user): + args = (activation_args, ) + graph.create_node('call_function', torch.ops.dc.free_tensors.default, args, {}, name=node_name) + + # Python version for debugging + # graph.create_node('call_function', free_tensors, args, {}, name=node_name) diff --git a/deepspeed/compile/graph_param.py b/deepspeed/compile/graph_param.py new file mode 100644 index 000000000000..445af97374f0 --- /dev/null +++ b/deepspeed/compile/graph_param.py @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Tuple +from functools import reduce + +import torch +from torch.fx import Graph, Node + +from .fx import get_output_node +from .util import get_param_nodes + + +@dataclass +class DSGraphParam: + name: str + shape: torch.Size + dtype: torch.dtype + device: torch.device + node: Node + allgather_node: Node + release_node: Node + param: torch.Tensor + numel: int = field(init=False) + + def __post_init__(self): + self.numel = reduce(lambda x, y: x * y, self.shape) + + +class DSGraphParamManager: + + def __init__(self, fw_graph: Graph, sample_inputs: Any, index_to_ds_ids: List[Tuple[int, int, int]]): + self._fw_graph = fw_graph + self._bw_graph = None + self._params: Dict[str, DSGraphParam] = {} + self._param_name_to_grad: Dict[str, Node] = {} + self._ds_ids: Dict[str, int] = {} + + param_nodes = get_param_nodes(fw_graph, index_to_ds_ids) + self._param_names = [pn.name for pn in param_nodes] + self._param_indices = [i for i, _, _ in index_to_ds_ids] + + param_inputs = [sample_inputs[i] for i, _, _ in index_to_ds_ids] + ds_ids = [ds_id for _, ds_id, _ in index_to_ds_ids] + ds_shapes = [ds_shape for _, _, ds_shape in index_to_ds_ids] + + for pn, pi, ds_id, ds_shape in zip(param_nodes, param_inputs, ds_ids, ds_shapes): + self._params[pn.name] = DSGraphParam(name=pn.name, + shape=ds_shape, + dtype=pi.dtype, + device=pi.device, + node=pn, + allgather_node=None, + release_node=None, + param=pi) + self._ds_ids[pn.name] = ds_id + + def get_bwd_mapping(self, bw_graph: Graph): + self._bw_graph = bw_graph + + output_node = get_output_node(bw_graph) + param_nodes_bw = [n for n in self._bw_graph.nodes if n.name in self.param_names] + grad_outputs = [output_node.args[0][i] for i in self._param_indices] + param_name_to_grad = {param_name: grad for param_name, grad in zip(self.param_names, grad_outputs)} + return param_nodes_bw, param_name_to_grad + + @property + def param_names(self) -> List[str]: + return self._param_names + + @property + def params(self) -> Dict[str, DSGraphParam]: + return self._params + + @property + def ds_ids(self) -> Dict[str, int]: + return self._ds_ids + + def get_grad_name(self, param_name) -> str: + assert self._param_name_to_grad is not None, "Backward graph is not added yet" + return self._param_name_to_grad[param_name] diff --git a/deepspeed/compile/inductor.py b/deepspeed/compile/inductor.py new file mode 100644 index 000000000000..7ba1c17ee43b --- /dev/null +++ b/deepspeed/compile/inductor.py @@ -0,0 +1,214 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +try: + import torch.utils._pytree as pytree + from torch._functorch.aot_autograd import create_aot_dispatcher_function + from torch._inductor.lowering import register_lowering, fallbacks, add_needs_realized_inputs + from torch._inductor.ir import TensorBox, FallbackKernel, Layout, IRNode + from torch._inductor.virtualized import V + from torch._inductor.scheduler import Scheduler + + original_create_aot_dispatcher_function = create_aot_dispatcher_function +except ImportError: + pass + +from .util import get_input_nodes +from .graph_param import DSGraphParamManager + + +def patch_compiler(original_compiler, dc_compiler, z3_partition: bool, graph_id, graph_param_manager, bwd: bool): + + def wrapped_compiler(gm, fake_inputs): + mod_graph = dc_compiler(gm, fake_inputs) + + # For symint case + if mod_graph is None: + return None + + if z3_partition: + # Inductor validates input size estimated by the first trace, where ds tensor is materialized. + # We need to patch the input tensors to avoid the validation error. + patched_inputs = [] + if bwd: + param_nodes_bw, _ = graph_param_manager[graph_id].get_bwd_mapping(gm.graph) + param_names = [n.name for n in param_nodes_bw] + else: + param_names = graph_param_manager[graph_id].param_names + input_nodes = get_input_nodes(gm.graph) + + for in_node, in_v in zip(input_nodes, fake_inputs): + ds_param = in_node.name in param_names + if ds_param: + from torch._subclasses.fake_tensor import is_fake + from torch._dynamo.utils import to_fake_tensor + assert is_fake(in_v), f"Input {in_v} should be fake tensor" + patched_inputs.append( + to_fake_tensor(torch.empty([0], dtype=in_v.dtype, device=in_v.device), in_v.fake_mode)) + else: + patched_inputs.append(in_v) + + patched_inputs = tuple(patched_inputs) + else: + patched_inputs = fake_inputs + + return original_compiler(gm, patched_inputs) + + return wrapped_compiler + + +def wrap_partition_fn(partition_fn, real_inputs, param_indices): + + def wrapped_partition_fn(*args, **kwargs): + + fw_module, bw_module = partition_fn(*args, **kwargs) + + # get parameter names + pm = DSGraphParamManager(fw_module.graph, real_inputs, param_indices) + + def fix_placeholder_meta(graph): + for n in graph.nodes: + if n.op == "placeholder" and n.name in pm.param_names: + n.meta["val"] = torch.empty([0], dtype=n.meta["val"].dtype, device=n.meta["val"].device) + + fix_placeholder_meta(fw_module.graph) + fix_placeholder_meta(bw_module.graph) + + return fw_module, bw_module + + return wrapped_partition_fn + + +def patch_create_aot_dispatcher_function(graph_id: int, z3_partition: bool, make_fw_graph, make_bw_graph, real_inputs, + param_indices, param_manager): + + from torch._dynamo.backends.common import AotAutograd + import functools + + def patch_aotautograd(): + # Unpatch if it was already patched + if hasattr(AotAutograd, "__original_init"): + AotAutograd.__init__ = AotAutograd.__original_init + + original_init = AotAutograd.__init__ + + @functools.wraps(original_init) + def patched_init(self, **kwargs): + kwargs["fw_compiler"] = patch_compiler(kwargs["fw_compiler"], + make_fw_graph, + z3_partition, + graph_id, + param_manager, + bwd=False) + kwargs["bw_compiler"] = patch_compiler(kwargs["bw_compiler"], + make_bw_graph, + z3_partition, + graph_id, + param_manager, + bwd=True) + kwargs["inference_compiler"] = kwargs["fw_compiler"] + + if z3_partition: + kwargs["partition_fn"] = wrap_partition_fn(kwargs["partition_fn"], real_inputs, param_indices) + + original_init(self, **kwargs) + + AotAutograd.__original_init = original_init + AotAutograd.__init__ = patched_init + + patch_aotautograd() + + +def register_custom_ops(): + + def fallback_handler_no_reuse(kernel, + never_reuse_input, + never_reuse_output, + force_free_input, + add_to_fallback_set=True): + if add_to_fallback_set: + fallbacks.add(kernel) + + def handler(*args, **kwargs): + + def wrap_tensors(x): + out = TensorBox.create(x) if isinstance(x, torch._inductor.ir.IRNode) else x + if out is not None and never_reuse_output: + V.graph.never_reuse_buffers.add(out.get_name()) + return out + + class CustomDCKernel(FallbackKernel): + + def __init__(self, op, *args, **kwargs): + super().__init__(op, *args, **kwargs) + + def add_to_never_reuse(x): + if isinstance(x, IRNode): + assert hasattr(x, "get_name"), f"x doesn't have get_name {x.__class__}" + V.graph.never_reuse_buffers.add(x.get_name()) + + if never_reuse_input: + pytree.tree_map(add_to_never_reuse, args) + + def get_var_name_for_arg(self, arg: str): + if arg.isidentifier(): + return arg + + import re + match = re.match(r"reinterpret_tensor\((\w+),", arg) + if match: + return match.group(1) + return None + + def codegen(self, wrapper): + if not force_free_input: + return super().codegen(wrapper) + + kernel = self.op_overload + self.codegen_comment(wrapper) + args = [*self.codegen_args(), *self.codegen_kwargs()] + + V.graph.wrapper_code.generate_fallback_kernel(self, args) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) + + var_name = self.get_var_name_for_arg(args[0]) + if var_name: + wrapper.writeline(f"{var_name} = None") + + self.codegen_unbacked_symbol_defs(wrapper) + + kernel_cls = CustomDCKernel if force_free_input else FallbackKernel + return pytree.tree_map(wrap_tensors, kernel_cls.create(kernel, *args, **kwargs)) + + return handler + + def register_fallback_no_reuse(op_overload, + never_reuse_input=False, + never_reuse_output=False, + force_free_input=False): + add_needs_realized_inputs(op_overload) + return register_lowering(op_overload, type_promotion_kind=None)(fallback_handler_no_reuse( + op_overload, + never_reuse_input=never_reuse_input, + never_reuse_output=never_reuse_output, + force_free_input=force_free_input)) + + # Inductor tries to reuse output buffer when possible. We need to disable this behavior for some custom ops. + # -> It seems that memory region is still reused in some cases. So we clone the inputs for some ops. + register_fallback_no_reuse(torch.ops.dc.allgather_param.default, never_reuse_input=False, never_reuse_output=True) + register_fallback_no_reuse(torch.ops.dc.wait_allgather.default, never_reuse_input=True, never_reuse_output=True) + register_fallback_no_reuse(torch.ops.dc.release_param.default, never_reuse_input=True, never_reuse_output=False) + register_fallback_no_reuse(torch.ops.dc.reduce_grad.default, + never_reuse_input=True, + never_reuse_output=True, + force_free_input=True) + register_fallback_no_reuse(torch.ops.dc.free_tensors.default, never_reuse_input=True, never_reuse_output=True) + + if not hasattr(Scheduler, "is_dc_patched") or not Scheduler.is_dc_patched: + Scheduler.is_dc_patched = True + Scheduler.dead_node_elimination = lambda _: None diff --git a/deepspeed/compile/init_z1.py b/deepspeed/compile/init_z1.py new file mode 100644 index 000000000000..2591e9db8e01 --- /dev/null +++ b/deepspeed/compile/init_z1.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import copy + +import torch + +from deepspeed.accelerator import get_accelerator +from .passes import zero1_compile, zero3_compile +from .backend import make_backend, launch_compile_passes, init_schedule +from .util import get_deepcompile_handle, add_pre_backward_hook, is_backend_inductor + +WARMUP = 5 + + +def init_z1(engine, backend, compile_config, compile_kwargs, schedule=None): + + optimizer = engine.optimizer + optimizer.contiguous_gradients = False # Avoid creating unnecessary buffer + for hook in optimizer._grad_acc_hooks: + hook.remove() + optimizer._grad_acc_hooks.clear() + + dc = get_deepcompile_handle() + dc.init(engine.data_parallel_group, + engine.zero_reduce_bucket_size(), compile_config.double_buffer, compile_config.symmetric_memory, + is_backend_inductor(backend), compile_config.sync_before_reduce, compile_config.sync_after_reduce, False, + False) + + grad_buffer = {} + + for i, group in enumerate(optimizer.bit16_groups): + + grad_buffer[i] = optimizer.get_flat_partition(optimizer.params_in_partition[i], + optimizer.first_offset[i], + optimizer.partition_size[i], + dtype=optimizer.gradient_accumulation_dtype, + device=get_accelerator().current_device_name(), + return_tensor_list=True) + grad_buffer[i] = [p.clone().detach() for p in grad_buffer[i]] # Maybe not necessary + + index_in_partition = 0 + first_in_partition = True + for p in group: + param_id = optimizer.get_param_id(p) + p.param_id = param_id + in_partition = optimizer.is_param_in_current_partition[param_id] + + if in_partition: + buf = grad_buffer[i][index_in_partition] + offset = optimizer.first_offset[i] if first_in_partition else 0 + # print(f"[r{dist.get_rank()}] Registering group {i} param {param_id} in_partition={in_partition} p={p.shape} buf={buf.shape} partition_offset={offset}") + dc.register_z1_param(p.param_id, p.shape, p, buf, int(offset)) + index_in_partition += 1 + first_in_partition = False + else: + # print(f"[r{dist.get_rank()}] Registering group {i} param {param_id} in_partition={in_partition} p={p.shape} buf=None") + dc.register_z1_param(p.param_id, p.shape, p, torch.empty([0], dtype=p.dtype, device=p.device), 0) + + def set_grad_buffer(): + optimizer.averaged_gradients = copy.copy(grad_buffer) + + add_pre_backward_hook(set_grad_buffer) + + if schedule is None: + schedule = [] + schedule.append((0, [zero1_compile.add_z1_reduce])) + else: + for opt in schedule: + # avoid typical misconfiguration + if zero3_compile.add_z3_gather_release in opt[1]: + raise ValueError("A pass for ZeRO3 is not specified though ZeRO1 is enabled") + + init_schedule(schedule) + + engine.launch_compile_passes = launch_compile_passes + return make_backend(backend, + compile_kwargs=compile_kwargs, + free_activation=False, + debug_log=compile_config.debug_log) diff --git a/deepspeed/compile/init_z3.py b/deepspeed/compile/init_z3.py new file mode 100644 index 000000000000..f05de840de03 --- /dev/null +++ b/deepspeed/compile/init_z3.py @@ -0,0 +1,94 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from deepspeed import comm as dist +from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.zero.partition_parameters import InsertPostInitMethodToModuleSubClasses + +from .passes import zero3_compile, prefetch, selective_gather, offload_parameters +from .backend import make_backend, launch_compile_passes, init_schedule +from .patch_fake_tensor import patch_fake_tensor +from .util import get_deepcompile_handle, add_pre_backward_hook, is_backend_inductor + +WARMUP = 5 + + +def init_z3(engine, backend, compile_config, compile_kwargs, schedule=None): + + optimizer = engine.optimizer + if optimizer is not None and hasattr(optimizer, '_DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer'): + optimizer._DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer = None + get_accelerator().empty_cache() + + dc = get_deepcompile_handle() + dc.init(engine.data_parallel_group, + engine.zero_reduce_bucket_size(), compile_config.double_buffer, compile_config.symmetric_memory, + is_backend_inductor(backend), compile_config.sync_before_reduce, compile_config.sync_after_reduce, + compile_config.sync_before_allgather, compile_config.sync_after_allgather) + + # Unset hooks + for m in engine.module.modules(): + m._parameters = m._original_parameters + optimizer.parameter_offload._remove_module_hooks() + + for hook in optimizer._grad_acc_hooks: + hook.remove() + optimizer._grad_acc_hooks.clear() + + # Unpatch linear + if hasattr(InsertPostInitMethodToModuleSubClasses, "linear_bk"): + torch.nn.functional.linear = InsertPostInitMethodToModuleSubClasses.linear_bk + + if compile_config.symmetric_memory: + group_name = engine.data_parallel_group.group_name + dist.enable_symm_mem_for_group(group_name) + + for p in engine.module.parameters(): + grad_buffer = optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[p.ds_id] + + # Disable persistent param + p.ds_persist = False + dc.register_z3_param(p.ds_id, p.ds_shape, p.ds_tensor, grad_buffer, p.ds_persist) + + def set_grad_buffer(): + for i, sub_group in enumerate(optimizer.fp16_groups): + optimizer.averaged_gradients[i] = [ + optimizer._DeepSpeedZeroOptimizer_Stage3__param_id_to_grad_partition[param.ds_id] + if param.requires_grad else torch.zeros_like(param.ds_tensor) for param in sub_group + ] + + add_pre_backward_hook(set_grad_buffer) + + if schedule is None: + schedule = [] + if (compile_config.offload_parameters): + schedule.append((0, [zero3_compile.add_z3_gather_release, offload_parameters.offload_parameter_fwd])) + else: + schedule.append((0, [zero3_compile.add_z3_gather_release])) + schedule.append( + (WARMUP, + [zero3_compile.add_z3_gather_release, prefetch.schedule_prefetch, selective_gather.selective_gather])) + + init_schedule(schedule) + + # offloading opt states need additional setup + from .passes.offload_adam_states import move_opt_states, move_opt_states_sync, init_offload_opt_states + for _, passes in schedule: + if move_opt_states in passes or move_opt_states_sync in passes: + init_offload_opt_states(optimizer, dc) + + engine.launch_compile_passes = launch_compile_passes + + patch_fake_tensor() + free_activation = compile_config.free_activation and not is_backend_inductor(backend) + + torch._inductor.config.size_asserts = False + + return make_backend(backend, + compile_kwargs=compile_kwargs, + free_activation=free_activation, + debug_log=compile_config.debug_log) diff --git a/deepspeed/compile/list_schedule.py b/deepspeed/compile/list_schedule.py new file mode 100644 index 000000000000..8df84498a581 --- /dev/null +++ b/deepspeed/compile/list_schedule.py @@ -0,0 +1,431 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from collections import defaultdict +from typing import List, Dict +from copy import copy +from dataclasses import dataclass + +import torch +from torch.fx import Graph, Node +from torch.fx.node import map_arg + +try: + from torch.utils._pytree import tree_iter +except ImportError: + pass + +from .util import get_last_uses, is_release_node +from .fx import get_output_node + + +def make_graph_from_schedule(scheduled: List[Node]): + new_graph = Graph() + env = {} + for node in scheduled: + new_node = new_graph.node_copy(node, lambda n: env[n.name]) + env[node.name] = new_node + + return new_graph + + +def get_original_args_num(node: Node): + if node.name.startswith("allgather_ds_param") \ + or node.name.startswith("release_ds_param") \ + or node.name.startswith("wait_allgather_ds_param") \ + or node.name.startswith("reduce_ds_param"): + return 1 + + return len(node.args) + + +def flat_nodes_in_args(args: List[Node]): + return [a for a in tree_iter(args) if isinstance(a, Node)] + + +def filter_args(node: Node): + args = node.args[:get_original_args_num(node)] + return flat_nodes_in_args(args) + + +def init_schedule(graph: Graph): + mem_table = create_mem_table(graph) + remaining_users = defaultdict(set) + user_to_producer = {} + + scheduled = [] + unscheduled = [] + edges = defaultdict(list) + for node in graph.nodes: + filtered_args = filter_args(node) + # print(f"Node: {node} args: {node.args}") + if len(filtered_args) == 0: + scheduled.append(node) + + remaining_users[node] = set(node.users.keys()) + for user in node.users.keys(): + user_to_producer[user] = node + else: + unscheduled.append(node) + for a in filtered_args: + for elem_a in tree_iter(a): + if isinstance(elem_a, Node): + if node not in edges[elem_a]: + edges[elem_a].append(node) + + return scheduled, unscheduled, edges, mem_table, remaining_users, user_to_producer + + +def get_runnable_nodes(scheduled: List[Node], unscheduled: List[Node]): + scheduled = set(scheduled) + return [node for node in unscheduled if all(arg in scheduled for arg in filter_args(node))] + + +def choose_next_node(scheduled: List[Node], unscheduled: List[Node], mem_table: Dict[str, int]): + runnable_nodes = get_runnable_nodes(scheduled, unscheduled) + + # sort by memory usage + runnable_nodes = sorted(runnable_nodes, key=lambda n: mem_table[n.name]) + return runnable_nodes[0] + + +def create_mem_table(graph: Graph) -> Dict[str, int]: + mem_table = {} + for node in graph.nodes: + if node.name.startswith("allgather_ds_param"): + mem_table[node.name] = node.meta["tensor_size"] + elif node.name.startswith("release_ds_param") or node.name.startswith("reduce_ds_param"): + mem_table[node.name] = -node.meta["tensor_size"] + else: + mem_table[node.name] = 0 + + return mem_table + + +def list_schedule(graph: Graph) -> Graph: + + scheduled, unscheduled, mem_table = init_schedule(graph) + + while len(unscheduled) > 0: + next_node = choose_next_node(scheduled, unscheduled, mem_table) + scheduled.append(next_node) + unscheduled.remove(next_node) + + return make_graph_from_schedule(scheduled) + + +############################### + + +def get_new_runnable_nodes_with(scheduled: List[Node], edges: Dict[Node, List[Node]], new_scheduled: Node): + scheduled = set(scheduled) + new_runnables = [] + for node in edges[new_scheduled]: + if all(arg in scheduled for arg in filter_args(node) if arg != new_scheduled): + new_runnables.append(node) + + return new_runnables + + +def _do_schedule_without_allgather(scheduled: List[Node], unscheduled: List[Node], edges: Dict[Node, List[Node]], + non_ag_runnable: List[Node]): + + while len(non_ag_runnable) > 0: + next_node = non_ag_runnable.pop() + + new_runnables = get_new_runnable_nodes_with(scheduled, edges, next_node) + non_ag_runnable += [n for n in new_runnables if not n.name.startswith("allgather_ds_param")] + + scheduled.append(next_node) + unscheduled.remove(next_node) + + return scheduled, unscheduled + + +def schedule_without_allgather(scheduled: List[Node], unscheduled: List[Node], edges: Dict[Node, List[Node]]): + runnable = get_runnable_nodes(scheduled, unscheduled) + non_ag_runnable = [n for n in runnable if not n.name.startswith("allgather_ds_param")] + + tmp_scheduled = copy(scheduled) + tmp_unscheduled = copy(unscheduled) + + return _do_schedule_without_allgather(tmp_scheduled, tmp_unscheduled, edges, non_ag_runnable) + + +def try_schedule_with_new_allgather(scheduled: List[Node], unscheduled: List[Node], edges: Dict[Node, List[Node]], + new_scheduled: Node): + new_runnables = get_new_runnable_nodes_with(scheduled, edges, new_scheduled) + non_ag_runnable = [n for n in new_runnables if not n.name.startswith("allgather_ds_param")] + + tmp_scheduled = copy(scheduled) + tmp_unscheduled = copy(unscheduled) + + tmp_scheduled.append(new_scheduled) + tmp_unscheduled.remove(new_scheduled) + + return _do_schedule_without_allgather(tmp_scheduled, tmp_unscheduled, edges, non_ag_runnable) + + +def simple_prefetch(graph: Graph, available_mem: int, output_size: int, debug_log: bool) -> Graph: + + scheduled, unscheduled, edges, mem_table, remaining_users, user_to_producer = init_schedule(graph) + tmp_scheduled, tmp_unscheduled = schedule_without_allgather(scheduled, unscheduled, edges) + + while len(tmp_unscheduled) > 0: + + runnable = get_runnable_nodes(tmp_scheduled, tmp_unscheduled) + ag_with_unblock_time = [] + + for ag_node in runnable: + ag_scheduled, ag_unscheduled = try_schedule_with_new_allgather(tmp_scheduled, tmp_unscheduled, edges, + ag_node) + unblock_time = sum(n.meta["device_time"] for n in ag_scheduled[len(tmp_scheduled) + 1:]) + ag_with_unblock_time.append((ag_node, unblock_time, ag_scheduled, ag_unscheduled)) + + ag_with_unblock_time = sorted(ag_with_unblock_time, key=lambda x: x[1], reverse=True) + best_ag_node = ag_with_unblock_time[0][0] + best_ag_scheduled = ag_with_unblock_time[0][2] + + no_ag_runnables = tmp_scheduled[len(scheduled):] + after_ag_runnables = best_ag_scheduled[len(tmp_scheduled) + 1:] + + scheduled.append(best_ag_node) + unscheduled.remove(best_ag_node) + for n in no_ag_runnables: + scheduled.append(n) + unscheduled.remove(n) + + tmp_scheduled = copy(scheduled) + tmp_unscheduled = copy(unscheduled) + for n in after_ag_runnables: + tmp_scheduled.append(n) + tmp_unscheduled.remove(n) + + return make_graph_from_schedule(tmp_scheduled) + + +############################### + + +def init_schedule_with_placeholders(graph: Graph): + mem_table = create_mem_table(graph) + remaining_users = defaultdict(set) + user_to_producer = {} + + scheduled = [] + unscheduled = [] + edges = defaultdict(list) + for node in graph.nodes: + if node.op == 'placeholder': + scheduled.append(node) + + remaining_users[node] = set(node.users.keys()) + for user in node.users.keys(): + user_to_producer[user] = node + else: + unscheduled.append(node) + + return scheduled, unscheduled, edges, mem_table, remaining_users, user_to_producer + + +def get_node_requirements(target_node: Node, scheduled: List[Node]): + scheduled = set(scheduled) + visited = set() + ordered_nodes = [] + + def dfs(node: Node): + if node in scheduled: + return + if node in visited: + return + visited.add(node) + + args = [] + + def register_arg(n: Node): + args.append(n) + + map_arg(node.args, register_arg) + + for arg in args: + dfs(arg) + ordered_nodes.append(node) + + dfs(target_node) + + return ordered_nodes + + +@dataclass +class AllgatherTask: + node: Node + allgather_cost: float + free_cost: float + allgathered_mem: int + allgather_acc_mem: int + free_acc_mem: int + last_use: Node + n_scheduled_ags: int + schedule_until_ag: List[Node] + schedule_until_free: List[Node] + + +def fast_free_schedule(graph: Graph, available_mem: int, output_size: int, debug_log: bool) -> Graph: + node_to_last_use, user_to_last_uses = get_last_uses(graph) + + # check tensor size + for node in graph.nodes: + if "tensor_size" not in node.meta: + # Our profiler may not visit all nodes because of the control flow. + node.meta["tensor_size"] = 0 + + scheduled, unscheduled, edges, mem_table, remaining_users, user_to_producer = init_schedule_with_placeholders( + graph) + + unscheduled_ags = [n for n in unscheduled if n.target == torch.ops.dc.allgather_param.default] + + release_nodes = defaultdict(list) + for n in unscheduled: + if is_release_node(n): + release_nodes[n.args[2]].append(n) + + ag_nodes_in_path = {} + for ag_node in unscheduled_ags: + last_use = node_to_last_use[ag_node] + required_nodes = get_node_requirements(last_use, scheduled) + ag_nodes_in_path[ag_node] = set(n for n in required_nodes if n.target == torch.ops.dc.allgather_param.default) + + reduce_nodes = [n for n in unscheduled if n.target == torch.ops.dc.reduce_grad.default] + ag_nodes_in_path_to_reduce_nodes = {} + for reduce_node in reduce_nodes: + ag_nodes_in_path_to_reduce_nodes[reduce_node] = set(n for n in get_node_requirements(reduce_node, scheduled) + if n.target == torch.ops.dc.allgather_param.default) + + output_nodes = [ + n for n in get_output_node(graph).args[0] + if isinstance(n, Node) and n.target != torch.ops.dc.reduce_grad.default + ] + ag_nodes_in_path_to_output_nodes = {} + for output_node in output_nodes: + ag_nodes_in_path_to_output_nodes[output_node] = set(n for n in get_node_requirements(output_node, scheduled) + if n.target == torch.ops.dc.allgather_param.default) + + while len(unscheduled_ags) > 0: + + ag_nodes_count = {ag_node: len(nodes) for ag_node, nodes in ag_nodes_in_path.items()} + count_list = sorted(set(ag_nodes_count.values())) + + runnable_ags = [] + for ag_count in count_list: + + target_unscheduled_ags = [ag for ag in unscheduled_ags if ag_nodes_count[ag] == ag_count] + + for node in target_unscheduled_ags: + ds_id = node.args[2] + + schedule_until_ag = get_node_requirements(node, scheduled) + if schedule_until_ag is None: + continue + + last_use = node_to_last_use[node] + + diff_required_nodes = get_node_requirements(last_use, scheduled + schedule_until_ag) + + allgather_cost = sum(n.meta["device_time"] for n in schedule_until_ag) + free_cost = sum(n.meta["device_time"] for n in diff_required_nodes) + allgathered_mem = node.meta["tensor_size"] + allgather_acc_mem = sum(n.meta["tensor_size"] for n in schedule_until_ag + if n.target == torch.ops.dc.allgather_param.default) + free_acc_mem = sum(n.meta["tensor_size"] for n in diff_required_nodes + if n.target == torch.ops.dc.allgather_param.default) + + schedule_until_free = schedule_until_ag + diff_required_nodes + for release_node in release_nodes[ds_id]: + if release_node not in schedule_until_free: + schedule_until_free.append(release_node) + + n_scheduled_ags = len( + [n for n in schedule_until_free if n.target == torch.ops.dc.allgather_param.default]) + + task = AllgatherTask(node, allgather_cost, free_cost, allgathered_mem, allgather_acc_mem, free_acc_mem, + last_use, n_scheduled_ags, schedule_until_ag, schedule_until_free) + + # print(f" ag_count {ag_count} allgather runnable {i}: {node} last_use: {node_to_last_use[node]} t: {t2-t1:.2f}") + runnable_ags.append(task) + + if len(runnable_ags) > 0: + break + + assert len(runnable_ags) > 0, "No runnable allgather nodes" + + # Criteria of the choice: + # We want to choose allgather that does not require additional allgather until releasing the param. + # When we can find such a node, free_acc_mem will be zero. In that case, we choose the one with the smallest cost until free to minimize the period of occupying memory for the gathered param. + # If there is no such node, we choose the one with the smallest free_cost to minimize the period of occupying memory for the gathered param. + ags_with_no_additional_ag = [ag for ag in runnable_ags if ag.free_acc_mem == 0] + if len(ags_with_no_additional_ag) > 0: + sorted_ags = sorted(runnable_ags, key=lambda x: x.free_cost) + next_ag = sorted_ags[0] + nodes_to_schedule = next_ag.schedule_until_free + else: + # sorted_ags = sorted(runnable_ags, key=lambda x: x.allgathered_mem) + sorted_ags = sorted(runnable_ags, key=lambda x: x.free_acc_mem) + next_ag = sorted_ags[0] + nodes_to_schedule = next_ag.schedule_until_ag + + # print(f" next_ag {next_ag}") + for n in nodes_to_schedule: + scheduled.append(n) + unscheduled.remove(n) + + unscheduled_ags.remove(next_ag.node) + + ag_nodes_in_path.pop(next_ag.node) + for ag_node, nodes in ag_nodes_in_path.items(): + if next_ag.node in nodes: + nodes.remove(next_ag.node) + + # Schedule reduce nodes when possible to free memory earlier + reduces_to_schedule = [] + for reduce_node in reduce_nodes: + if next_ag.node in ag_nodes_in_path_to_reduce_nodes[reduce_node]: + ag_nodes_in_path_to_reduce_nodes[reduce_node].remove(next_ag.node) + if len(ag_nodes_in_path_to_reduce_nodes[reduce_node]) == 0: + reduces_to_schedule.append(reduce_node) + + for n in reduces_to_schedule: + need_to_schedule = get_node_requirements(n, scheduled) + for nn in need_to_schedule: + scheduled.append(nn) + unscheduled.remove(nn) + + # Do the same for output nodes + outputs_to_schedule = [] + for output_node in output_nodes: + if next_ag.node in ag_nodes_in_path_to_output_nodes[output_node]: + ag_nodes_in_path_to_output_nodes[output_node].remove(next_ag.node) + if len(ag_nodes_in_path_to_output_nodes[output_node]) == 0: + outputs_to_schedule.append(output_node) + + for n in outputs_to_schedule: + need_to_schedule = get_node_requirements(n, scheduled) + for nn in need_to_schedule: + scheduled.append(nn) + unscheduled.remove(nn) + + # print(f"After ag scheduled: scheduled: {scheduled}") + + scheduled_set = set(scheduled) + for node in graph.nodes: + if node in scheduled_set: + continue + scheduled.append(node) + unscheduled.remove(node) + + assert len(unscheduled) == 0, f"There are unscheduled nodes: {unscheduled}" + + ret_graph = make_graph_from_schedule(scheduled) + ret_graph.lint() + return ret_graph diff --git a/deepspeed/compile/partitioner.py b/deepspeed/compile/partitioner.py new file mode 100644 index 000000000000..f60170d0b56b --- /dev/null +++ b/deepspeed/compile/partitioner.py @@ -0,0 +1,158 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# This file was copied from PyTorch and modified for DeepSpeed. + +from typing import Tuple, List +import operator + +import torch +from torch.fx import GraphModule, Graph, Node + +try: + from torch._functorch.partitioners import is_sym_node, _is_primal, _is_fwd_seed_offset, _extract_fwd_bwd_outputs, _extract_graph_with_inputs_outputs, _extract_fwd_bwd_modules, has_recomputable_ops, min_cut_rematerialization_partition, choose_saved_values_set +except ImportError: + pass + +from .util import get_no_copy_ops + +_recompute_ops = {torch.ops.aten.t.default} + + +def _find_recompute_nodes(graph: Graph, ds_param_node: Node) -> List[Node]: + """ + Given a graph and a node that represents a parameter that was allgathered, + find all nodes that use the parameter and require recomputation. + """ + no_copy_ops = get_no_copy_ops() + recompute_nodes = set() + for node in graph.nodes: + if node.target in no_copy_ops: + if ds_param_node in node.args: + recompute_nodes.add(node) + if any(a in recompute_nodes for a in node.args): + recompute_nodes.add(node) + + return recompute_nodes + + +def _get_values_from_ds_params(joint_graph, param_indices): + primal_inputs = list(filter(_is_primal, joint_graph.nodes)) + ds_param_inputs = [primal_inputs[arg_idx] for arg_idx, _, _ in param_indices] + + no_copy_ops = get_no_copy_ops() + + ds_param_inputs = set(ds_param_inputs) + ds_param_users = {} + + for node in joint_graph.nodes: + if node.target in no_copy_ops and any((a in ds_param_inputs or a in ds_param_users) for a in node.args): + for a in node.args: + if a in ds_param_inputs: + ds_param_users[node] = a + elif a in ds_param_users: + ds_param_users[node] = ds_param_users[a] + + return ds_param_users + + +def get_wrapped_choose_saved_values_set(param_indices: List[Tuple[int, int, torch.Size]]): + + def ds_choose_saved_values_set(joint_graph: torch.fx.Graph, node_info, memory_budget=1) -> List[Node]: + saved_values = choose_saved_values_set(joint_graph, node_info, memory_budget) + ds_param_users = _get_values_from_ds_params(joint_graph, param_indices) + + new_saved_values = [] + for v in saved_values: + if v in ds_param_users: + ds_val = ds_param_users[v] + if ds_val not in new_saved_values: + new_saved_values.append(ds_val) + else: + new_saved_values.append(v) + + return new_saved_values + + return ds_choose_saved_values_set + + +def get_wrapped_partitioner(param_indices: List[Tuple[int, int, torch.Size]]): + + def partition_recompute_ds_params(joint_module: GraphModule, _joint_inputs, *, + num_fwd_outputs) -> Tuple[GraphModule, GraphModule]: + """ + This is basically the same as the default_partition function, but + it doesn't save the gathered params and values computed from them. + """ + if has_recomputable_ops(joint_module): + return min_cut_rematerialization_partition(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs) + + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) + inputs = primal_inputs + fwd_seed_offset_inputs + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) + forward_only_graph = _extract_graph_with_inputs_outputs(joint_module.graph, inputs, fwd_outputs, "forward") + forward_node_names = {node.name for node in forward_only_graph.nodes if node.op != "output"} + saved_values = [] + saved_sym_nodes = [] + + fwd_inputs = list(filter(_is_primal, forward_only_graph.nodes)) + ds_param_inputs = [fwd_inputs[arg_idx] for arg_idx, _, _ in param_indices] + ds_param_input_names = {node.name for node in ds_param_inputs} + + ds_param_recompute_nodes = set() + + for node in joint_module.graph.nodes: + if node.name not in forward_node_names: + continue + + if is_sym_node(node): + # Symints must be kept separate from tensors so that PythonFunction only calls + # save_for_backward on tensors and stashes symints in autograd .ctx + saved_sym_nodes.append(node) + elif "tensor_meta" not in node.meta and node.op == "call_function": + # Since we can't save tuple of tensor values, we need to flatten out what we're saving + users = node.users + assert all(user.target == operator.getitem for user in users) + saved_values.extend(users) + else: + backward_usages = [n for n in node.users if n.name not in forward_node_names] + + if "tensor_meta" in node.meta and all(is_sym_node(n) for n in backward_usages): + # If we have a tensor in the forward, where only its sizes/strides are needed in the backward, + # and not the actual tensor data, + # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor. + # + # Note that saving the tensor could also cause compilation problems: + # If the user mutated an input in the forward and uses its sizes/strides in the backward, + # then we would be obligated to clone the input before saving it to appease autograd. + # (This is how we originally found this bug). + saved_sym_nodes.extend(backward_usages) + + if node.name in ds_param_input_names: + saved_values.append(node) + recompute_nodes = _find_recompute_nodes(joint_module.graph, node) + recompute_nodes = [n for n in recompute_nodes if n.name in forward_node_names] + for recompute_node in recompute_nodes: + ds_param_recompute_nodes.add(recompute_node) + + if len(recompute_nodes) > 0: + saved_values.append(node) + else: + if node not in ds_param_recompute_nodes: + saved_values.append(node) + saved_values = list(dict.fromkeys(saved_values).keys()) + saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys()) + + f_gm, b_gm = _extract_fwd_bwd_modules( + joint_module, + saved_values, + saved_sym_nodes=saved_sym_nodes, + num_fwd_outputs=num_fwd_outputs, + ) + + return f_gm, b_gm + + return partition_recompute_ds_params diff --git a/deepspeed/compile/passes/__init__.py b/deepspeed/compile/passes/__init__.py new file mode 100644 index 000000000000..620e99147647 --- /dev/null +++ b/deepspeed/compile/passes/__init__.py @@ -0,0 +1,48 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from ..profilers.graph_profile import MemoryProfilingInterpreter + +import deepspeed.comm as dist + + +def run_opt_passes(nz3, + graph_index, + graph_id, + gm, + create_inputs_fn, + opt_passes, + graph_order, + profiling_results, + param_manager, + bwd, + debug_log=False): + profile = profiling_results[graph_id] + rank = dist.get_rank() + + for i, opt_pass in enumerate(opt_passes): + + opt_pass_fn, mem_budget = opt_pass + + graph = opt_pass_fn(gm.graph, graph_id, graph_order, profiling_results, mem_budget, param_manager, bwd) + graph.lint() + gm.graph = graph + gm.recompile() + + if debug_log: + print(f"Prefetching enabled for {'bwd' if bwd else 'fwd'} graph_id={graph_id} {graph}") + + mem_prof = MemoryProfilingInterpreter(nz3, gm) + mem_prof.run(*create_inputs_fn()) + if debug_log and rank == 0: + mem_prof.dump(f"mem_prof_r{rank}_{'bwd' if bwd else 'fwd'}_{graph_index}_{graph_id}_pass_{i}.csv") + + mem = [(name, current_alloc, delta, peak) for name, current_alloc, delta, peak in mem_prof.mem_record] + if bwd: + profile.bwd_mem = mem + else: + profile.fwd_mem = mem + + return gm diff --git a/deepspeed/compile/passes/offload_activation.py b/deepspeed/compile/passes/offload_activation.py new file mode 100644 index 000000000000..496e7351f218 --- /dev/null +++ b/deepspeed/compile/passes/offload_activation.py @@ -0,0 +1,116 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List, Dict, Set, Tuple +import random +from collections import defaultdict + +import torch +from torch.fx import Graph, Node + +from ..fx import get_output_node, move_primals_to_head +from ..graph_param import DSGraphParamManager + +value_to_id: Dict[int, Dict[str, int]] = defaultdict(dict) +used_ids: Set[int] = set() + + +def get_random_id() -> int: + + def _gen(): + # generate random int + return random.randint(10000, 2**31) + + global used_ids + v = _gen() + while v in used_ids: + v = _gen() + used_ids.add(v) + return v + + +def _should_offload(node: Node) -> bool: + if not hasattr(node, "meta"): + return False + if not "tensor_meta" in node.meta: + return False + + return True + + +def offload_activation_fwd(graph: Graph, graph_id: int, nodes_to_offload_with_names: List[Tuple[str, Node]], + graph_order: List[int], mem_budget: float, param_manager: DSGraphParamManager) -> Graph: + param_names = set(param_manager.param_names) + + import copy + cl_graph = copy.deepcopy(graph) + cl_graph.erase_node(get_output_node(cl_graph)) + + global value_to_id + for name, node in nodes_to_offload_with_names: + if node.name in param_names: + continue + + if not _should_offload(node): + continue + + val_id = get_random_id() + with graph.inserting_after(node): + offload_node = graph.create_node('call_function', + torch.ops.dc.offload_tensor.default, (node, graph_id, val_id), {}, + name=f"offload_{node.name}_{val_id}") + with graph.inserting_after(offload_node): + wait_node = graph.create_node('call_function', + torch.ops.dc.wait_offload.default, (offload_node, graph_id, val_id), {}, + name=f"wait_copy_{node.name}_{val_id}") + + output_node = get_output_node(graph) + output_node.replace_input_with(node, wait_node) + + value_to_id[graph_id][name] = val_id + + graph = move_primals_to_head(graph) + + graph.lint() + return graph + + +def reload_activation_bwd(graph: Graph, graph_id: int, graph_order: List[int], mem_budget: float, + param_manager: DSGraphParamManager) -> Graph: + + graph_value_to_id = value_to_id[graph_id] + name_to_node = {n.name: n for n in graph.nodes} + act_nodes = [name_to_node[n] for n in graph_value_to_id.keys()] + + node_to_first_user = {} + for act in act_nodes: + for node in graph.nodes: + if act in node.args: + node_to_first_user[act] = node + break + + for node in act_nodes: + val_id = graph_value_to_id[node.name] + + with graph.inserting_before(node_to_first_user[node]): + reload_node = graph.create_node('call_function', + torch.ops.dc.reload_tensor.default, (node, graph_id, val_id), {}, + name=f"reload_{node.name}_{val_id}") + with graph.inserting_after(reload_node): + wait_node = graph.create_node('call_function', + torch.ops.dc.wait_reload.default, (reload_node, graph_id, val_id), {}, + name=f"wait_copy_{node.name}_{val_id}") + + # replace all uses of node with wait_node + users = {} + for u in node.users.keys(): + if u != reload_node: + users[u] = (node, wait_node) + for u, (old_in, new_in) in users.items(): + u.replace_input_with(old_in, new_in) + + graph = move_primals_to_head(graph) + graph.lint() + return graph diff --git a/deepspeed/compile/passes/offload_adam_states.py b/deepspeed/compile/passes/offload_adam_states.py new file mode 100644 index 000000000000..458d07f39d16 --- /dev/null +++ b/deepspeed/compile/passes/offload_adam_states.py @@ -0,0 +1,546 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import copy +from typing import List + +import torch +from torch.fx import Graph, GraphModule + +from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.zero.offload_states import _make_offload_state_key + +try: + from torch._subclasses.fake_tensor import unset_fake_temporarily +except ImportError: + # Unsupported torch version + pass + +from ..profilers import ProfilingResult +from ..graph_param import DSGraphParamManager +from ..fx import move_primals_to_head + +import deepspeed.comm as dist + +NAME = "offload_adam_states" + + +def print_r0(msg): + if dist.get_rank() == 0: + print(msg) + + +MARGIN = 0.2 + +copy_stream = None +offload_event = None +reload_event = None + +offload_key_events = {} +reload_key_events = {} + +max_memory = 0 + + +def lazy_init(): + global copy_stream + global offload_event + global reload_event + + if copy_stream is None: + + copy_stream = get_accelerator().Stream() + offload_event = get_accelerator().Event() + reload_event = get_accelerator().Event() + + +optimizer = None +device = None +nz3 = None + + +def move_key(state, key, key_event=None): + offload_buf_key = _make_offload_state_key(key) + if offload_buf_key not in state: + state[offload_buf_key] = get_accelerator().pin_memory(torch.empty_like(state[key], device="cpu")) + + if key not in state: + return + + with get_accelerator().stream(copy_stream): + state[offload_buf_key].copy_(state[key], non_blocking=True) + + if key_event is None: + offload_event.record(stream=copy_stream) + else: + key_event.record(stream=copy_stream) + + +def move_back_key(state, key, key_event=None): + + with get_accelerator().stream(copy_stream): + state[key] = torch.empty_like(state[_make_offload_state_key(key)], device=device) + state[key].copy_(state[_make_offload_state_key(key)], non_blocking=True) + + if key_event is None: + reload_event.record(stream=copy_stream) + else: + key_event.record(stream=copy_stream) + + +def move_hp_param(src_tensor, dest_buf, key_event=None): + with get_accelerator().stream(copy_stream): + dest_buf.copy_(src_tensor, non_blocking=True) + src_tensor.data = dest_buf + + if key_event is None: + reload_event.record(stream=copy_stream) + else: + key_event.record(stream=copy_stream) + + +def move_back_hp_param(src_tensor, dest_buf, key_event=None): + with get_accelerator().stream(copy_stream): + dest_buf.data = torch.empty_like(src_tensor, device=device) + dest_buf.copy_(src_tensor, non_blocking=True) + + if key_event is None: + reload_event.record(stream=copy_stream) + else: + key_event.record(stream=copy_stream) + + +def offload_adam_states_sync(): + + with unset_fake_temporarily(): + + if not hasattr(optimizer, "hp_params_pin_buffers"): + optimizer.hp_params_pin_buffers = [ + get_accelerator().pin_memory(torch.empty_like(t, device="cpu")) + for t in optimizer.fp32_partitioned_groups_flat + ] + + for i, (k, state) in enumerate(optimizer.state.items()): + if "exp_avg" in state: + move_key(state, "exp_avg") + if "exp_avg_sq" in state: + move_key(state, "exp_avg_sq") + + for _, state in optimizer.state.items(): + if "exp_avg" in state: + del state["exp_avg"] + if "exp_avg_sq" in state: + del state["exp_avg_sq"] + + for src_tensor, dest_buf in zip(optimizer.fp32_partitioned_groups_flat, optimizer.hp_params_pin_buffers): + move_hp_param(src_tensor, dest_buf) + + get_accelerator().synchronize() + + +def reload_adam_states_sync(): + + with unset_fake_temporarily(): + # print_r0("Reloading Adam states") + + for _, state in optimizer.state.items(): + if _make_offload_state_key("exp_avg") in state: + move_back_key(state, "exp_avg") + if _make_offload_state_key("exp_avg_sq") in state: + move_back_key(state, "exp_avg_sq") + + for src, dest in zip(optimizer.hp_params_pin_buffers, optimizer.fp32_partitioned_groups_flat): + move_back_hp_param(src, dest) + + get_accelerator().synchronize() + + +def sync_offload_states(event=None): + if nz3.is_profiling(): + offload_adam_states_sync() + else: + if event is None: + offload_event.wait(copy_stream) + else: + event.wait(copy_stream) + + +def sync_reload_states(event=None): + if nz3.is_profiling(): + reload_adam_states_sync() + else: + if event is None: + reload_event.wait(copy_stream) + else: + event.wait(copy_stream) + + +def make_offload_task(task): + + def run_offload_task(): + # if not nz3.is_profiling(): + # print_r0(f"run_offload_task {task[0]} {task[2]} {task[3]} {task[4]}") + + if offload_key_events.get(task[1]) is None: + offload_key_events[task[1]] = get_accelerator().Event() + + if task[2] == "hp_param": + move_hp_param(task[1][0], task[1][1], offload_key_events[task[1][0]]) + else: + assert task[1] in optimizer.state, f"State {task[1]} not found in optimizer" + state = optimizer.state[task[1]] + # if offload_key_events.get(task[1]) is None: + # offload_key_events[task[1]] = get_accelerator().Event() + move_key(state, task[2], offload_key_events[task[1]]) + + return run_offload_task + + +def make_offload_sync(task): + + def run_offload_sync(): + # if not nz3.is_profiling(): + event = offload_key_events[task[1]] + event.synchronize() + + if task[2] != "hp_param": + state = optimizer.state[task[1]] + key = task[2] + if key in state: + del state[key] + # print_r0(f"run_offload_sync {task[0]} {task[2]} alloc_mem={get_accelerator().memory_allocated()}") + + return run_offload_sync + + +def make_reload_task(task): + + def run_reload_task(): + if not nz3.is_profiling(): + if reload_key_events.get(task[1]) is None: + reload_key_events[task[1]] = get_accelerator().Event() + + if task[2] == "hp_param": + move_back_hp_param(task[1][1], task[1][0], reload_key_events[task[1]]) + else: + state = optimizer.state[task[1]] + # print_r0(f"run_reload_task {task[0]} {task[2]} {task[3]} {task[4]}") + move_back_key(state, task[2], reload_key_events[task[1]]) + + return run_reload_task + + +def update_max_memory(name): + + global max_memory + mem = get_accelerator().max_memory_allocated() + max_memory = max(max_memory, mem) + + +def empty_cache(): + get_accelerator().empty_cache() + + +offload_tasks = [] +offload_tasks_remaining = [] +offload_tasks_scheduled = [] +reload_task_remaining = [] +total_reload_mem = 0 + + +def offload_opt_states_inc(graph: Graph, graph_id: int, graph_order: List[int], profiling_results: ProfilingResult, + mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> Graph: + + to_remove = [] + for node in graph.nodes: + if node.op == 'call_function' and \ + node.target in [offload_adam_states_sync, sync_offload_states, reload_adam_states_sync, sync_reload_states, update_max_memory]: + to_remove.append(node) + + for node in to_remove: + graph.erase_node(node) + + accelerator = get_accelerator() + total_mem = accelerator.total_memory() * (1 - MARGIN) + print_r0(f"offload_opt_states_inc start graph {graph_id} bwd={bwd} max_memory={max_memory} total_mem={total_mem}") + + mem = profiling_results[graph_id].bwd_mem if bwd else profiling_results[graph_id].fwd_mem + mem_dict = {name: peak for name, alloc_mem, delta, peak in mem} + + current_peak_mem = 0 + peak_mem = {} + + ordered_node = reversed(graph.nodes) if bwd else graph.nodes + for node in ordered_node: + # print(f"Node: {node.name} mem: {mem_dict[node.name]}") + if mem_dict[node.name] > current_peak_mem: + current_peak_mem = mem_dict[node.name] + peak_mem[node.name] = current_peak_mem + + # fwd_max_mem = max(m[3] for m in prof.fwd_mem) + # bwd_max_mem = max(m[3] for m in prof.bwd_mem) if len(prof.bwd_mem) > 0 else 0 + # peak_mem = max(peak_mem, fwd_max_mem, bwd_max_mem) + + global offload_tasks_remaining, reload_tasks_remaining, offload_tasks_scheduled + + if not bwd: + is_first_graph = graph_id == graph_order[0][0] + # print_r0( + # f"offload_opt_states_inc start graph {graph_id} graph_order {graph_order} fwd is_first_graph {is_first_graph}" + # ) + + # At the beginning of the first graph, we schedule offload tasks to launch all offloading + if is_first_graph: + # print_r0( + # f"offload_opt_states_inc fwd before reload graph {graph_id} allocated_mem={get_accelerator().memory_allocated()}" + # ) + + with unset_fake_temporarily(): + offload_adam_states_sync() + reload_adam_states_sync() + sync_reload_states() + + reload_size = 0 + + for i, ((k, state), hp_param, hp_param_cpu) in enumerate( + zip(optimizer.state.items(), optimizer.fp32_partitioned_groups_flat, + optimizer.hp_params_pin_buffers)): + # print_r0( + # f"Checking key for offloading {i} {k.shape} has_key {_make_offload_state_key('exp_avg') in state}") + + if _make_offload_state_key("exp_avg") in state: + key = _make_offload_state_key("exp_avg") + size = state[key].numel() * state[key].element_size() + + # if total_mem < max_memory + reload_size + size: + offload_tasks.append( + (i, k, "exp_avg", state[key].numel() * state[key].element_size(), state[key].dtype)) + # print_r0( + # f"Offloading task {i} exp_avg reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}" + # ) + + if _make_offload_state_key("exp_avg_sq") in state: + key = _make_offload_state_key("exp_avg_sq") + size = state[key].numel() * state[key].element_size() + + # if total_mem < max_memory + reload_size + size: + offload_tasks.append( + (i, k, "exp_avg_sq", state[key].numel() * state[key].element_size(), state[key].dtype)) + # print_r0( + # f"Offloading task {i} exp_avg_sq reload_size={reload_size} size={size} estimated_mem={max_memory + reload_size + size}" + # ) + + hp_param_size = hp_param.numel() * hp_param.element_size() + # if total_mem < max_memory + reload_size + hp_param_size: + offload_tasks.append((i, (hp_param, hp_param_cpu), "hp_param", + hp_param.numel() * hp_param.element_size(), hp_param.dtype)) + # print_r0( + # f"Offloading task {i} hp_param reload_size={reload_size} size={hp_param_size} estimated_mem={max_memory + reload_size + hp_param_size}" + # ) + + # print_r0(f"offload_opt_states_inc fwd graph {graph_id} allocated_mem={get_accelerator().memory_allocated()}") + + for node in graph.nodes: + # print_r0(f"checking sync node insert node: {node.name}") + + if node.name not in peak_mem \ + or node.op == 'placeholder' \ + or "offload_opt_" in node.name: + continue + + to_offload = [] + optim_size = sum([task[3] for task in offload_tasks]) + + # print_r0( + # f" optim_size: {optim_size} total_mem: {total_mem} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks)}" + # ) + while total_mem - peak_mem[node.name] - optim_size < 0: + if len(offload_tasks) == 0: + break + + task = offload_tasks.pop(0) + to_offload.append(task) + optim_size = sum([task[3] for task in offload_tasks]) + # print_r0( + # f" scheduled task {task[0]} {task[2]} {task[3]} optim_size: {optim_size} peak_mem: {peak_mem[node.name]} available: {total_mem - peak_mem[node.name] - optim_size} #tasks={len(offload_tasks)}" + # ) + + for task in to_offload: + with graph.inserting_before(node): + graph.create_node('call_function', + make_offload_sync(task), (), {}, + name=f"offload_opt_sync_{task[0]}_{task[2]}") + print_r0(f"Inserting fwd offload_opt_sync_{task[0]}_{task[2]}") + offload_tasks_scheduled.append(task) + + for node in graph.nodes: + # print(f"Node: {node.name} mem: {mem_dict[node.name]}") + if node.op != 'placeholder': + print_r0(f"Inserting all offload tasks before {node.name}") + for task in offload_tasks_scheduled: + name = f"offload_opt_{task[0]}_{task[2]}" + with graph.inserting_before(node): + offload_node = graph.create_node('call_function', make_offload_task(task), (), {}, name=name) + break + + # print_r0(f"offload_opt_states_inc finish graph {graph_id} fwd graph {graph}") + print_r0(f"offload_opt_states_inc finish graph {graph_id}") + else: + + graph_order_with_backward = [g[0] for g in graph_order if g[1]] + is_first_graph = graph_id == graph_order_with_backward[-1] + is_last_graph = graph_id == graph_order_with_backward[0] + + # print_r0( + # f"offload_opt_states_inc bwd graph {graph_id} graph_order_with_backward {graph_order_with_backward} is_first_graph {is_first_graph} is_last_graph {is_last_graph}" + # ) + + if is_first_graph: + inserted_sync = False + for node in graph.nodes: + if node.op != 'placeholder' and not inserted_sync: + # print(f"Inserting offload_sync before {node.name}") + with graph.inserting_before(node): + graph.create_node('call_function', empty_cache, (), {}, name="empty_cache") + + inserted_sync = True + reload_tasks_remaining = copy.copy(offload_tasks_scheduled) + + global total_reload_mem + for node in graph.nodes: + if node.name not in peak_mem \ + or node.op == 'placeholder' \ + or node.op == 'output' \ + or "offload_opt_sync_" in node.name: + continue + + if len(reload_tasks_remaining) > 0: + task = reload_tasks_remaining[0] + next_reload_mem = task[3] + + insert_pos = node + while total_mem > peak_mem[node.name] + total_reload_mem + next_reload_mem: + expected_mem = peak_mem[node.name] + total_reload_mem + print_r0( + f" Inserting reload_opt reload_opt_{task[0]}_{task[2]} after {insert_pos.name} next_inc={next_reload_mem} peak_mem[{node.name}]={peak_mem[node.name]} inc_total={total_reload_mem} expected_mem={expected_mem}" + ) + + with graph.inserting_after(insert_pos): + insert_pos = graph.create_node('call_function', + make_reload_task(task), (), {}, + name=f"reload_opt_{task[0]}_{task[2]}") + + total_reload_mem += next_reload_mem + reload_tasks_remaining.pop(0) + if len(reload_tasks_remaining) == 0: + break + + task = reload_tasks_remaining[0] + next_reload_mem = task[3] + + # prev_node = node + + if is_last_graph: + for node in graph.nodes: + # print(f"Node: {node.name} mem: {mem_dict[node.name]}") + if node.op == 'output': + for task in reload_tasks_remaining: + with graph.inserting_before(node): + graph.create_node('call_function', + make_reload_task(task), (), {}, + name=f"reload_opt_{task[0]}_{task[2]}") + + sync_fn = lambda: copy_stream.synchronize() + with graph.inserting_before(node): + graph.create_node('call_function', sync_fn, (), {}, name="sync_offload_copy_stream") + + print_r0( + f"offload_opt_states_inc graph {graph_id} graph_order {graph_order} bwd is_first_graph {is_first_graph} is_last_graph {is_last_graph}" + ) + + return graph + + +def add_record_max_mem_nodes(graph: Graph): + + nodes = list(graph.nodes) + for node in nodes: + if node.op == "output" or node.op == "placeholder": + continue + + with graph.inserting_after(node): + name = f"update_max_memory_{node.name}" + graph.create_node('call_function', update_max_memory, (name, ), {}, name=name) + + +def insert_offload_opt_states(graph: Graph, graph_id: int, graph_order: List[int], profiling_results: ProfilingResult, + mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> Graph: + + if bwd: + graph_order_with_backward = [g[0] for g in graph_order if g[1]] + is_last_graph = graph_id == graph_order_with_backward[0] + + inserted_reload = False + for node in graph.nodes: + # print(f"Node: {node.name} mem: {mem_dict[node.name]}") + if node.op == 'output' and not inserted_reload and is_last_graph: + # print(f"Inserting reload_opt before {node.name}") + with graph.inserting_before(node): + graph.create_node('call_function', reload_adam_states_sync, (), {}, name="reload_opt") + inserted_reload = True + + # add_record_max_mem_nodes(graph) + + else: + is_first_graph = graph_id == graph_order[0][0] + + graph = move_primals_to_head(graph) + + inserted_offload = False + for node in graph.nodes: + # print(f"Node: {node.name} mem: {mem_dict[node.name]}") + if node.op != 'placeholder' and not inserted_offload and is_first_graph: + print(f"Inserting offload_opt before {node.name}") + with graph.inserting_before(node): + graph.create_node('call_function', offload_adam_states_sync, (), {}, name="offload_opt") + inserted_offload = True + + add_record_max_mem_nodes(graph) + + return graph + + +def move_opt_states(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn, + mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule: + gm.graph = offload_opt_states_inc(gm.graph, graph_id, graph_order, profiling_results, mem_budget, param_manager, + bwd) + return gm + + +def move_opt_states_sync(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn, + mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule: + gm.graph = insert_offload_opt_states(gm.graph, graph_id, graph_order, profiling_results, mem_budget, param_manager, + bwd) + return gm + + +def offload_adam_states_for_init(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, + create_inputs_fn, mem_budget: float, param_manager: DSGraphParamManager, + bwd: bool) -> GraphModule: + if not bwd and graph_id == graph_order[0][0]: + with unset_fake_temporarily(): + offload_adam_states_sync() + # returns None, and profiling will be skipped + + +def init_offload_opt_states(adam_optimizer, _nz3): + lazy_init() + + global optimizer + optimizer = adam_optimizer + global device + device = torch.device(get_accelerator().current_device()) + global nz3 + nz3 = _nz3 diff --git a/deepspeed/compile/passes/offload_parameters.py b/deepspeed/compile/passes/offload_parameters.py new file mode 100644 index 000000000000..29468f4970d9 --- /dev/null +++ b/deepspeed/compile/passes/offload_parameters.py @@ -0,0 +1,54 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List + +import torch +from torch.fx import Node, GraphModule +from deepspeed.compile.util import get_last_uses +from ..graph_param import DSGraphParamManager + + +def add_offload_parameter(graph_id: int, gm: GraphModule, node: Node, ds_id: int): + new_node = None + with gm.graph.inserting_after(node): + args = (node, ) + for a in [graph_id, ds_id]: # To add ds_id + args += (a, ) + new_node = gm.graph.create_node('call_function', + torch.ops.dc.offload_parameter.default, + args, {}, + name="offload_parameter") + + return new_node + + +def add_reload_parameter(graph_id: int, gm: GraphModule, node: Node, ds_id: int): + new_node = None + with gm.graph.inserting_after(node): + args = (node, ) + for a in [graph_id, ds_id]: # To add ds_id + args += (a, ) + new_node = gm.graph.create_node('call_function', + torch.ops.dc.reload_parameter.default, + args, {}, + name=f"reload_parameter") + return new_node + + +def get_ds_id(node: Node): + assert node.target == torch.ops.dc.allgather_param.default + return node.args[2] + + +def offload_parameter_fwd(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn, + mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule: + node_to_last_use, user_to_last_uses = get_last_uses(gm.graph) + for node in gm.graph.nodes: + if (isinstance(node, Node) and node.target == torch.ops.dc.allgather_param.default): + add_reload_parameter(graph_id, gm, node.args[0], get_ds_id(node)) + add_offload_parameter(graph_id, gm, node_to_last_use[node], get_ds_id(node)) + gm.graph.lint() + return gm diff --git a/deepspeed/compile/passes/prefetch.py b/deepspeed/compile/passes/prefetch.py new file mode 100644 index 000000000000..ce0d721f8d58 --- /dev/null +++ b/deepspeed/compile/passes/prefetch.py @@ -0,0 +1,174 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List + +import torch +from torch.fx import Graph, Node, GraphModule + +from deepspeed.accelerator import get_accelerator +import deepspeed.comm as dist + +from ..profilers.comm_profile import create_predictor +from ..graph_param import DSGraphParamManager + +NAME = "prefetch" + +FUSE_FACTOR = 0.8 +MARGIN = 0.1 +MAX_FUSE_SIZE = 1e9 +MAX_BUFFERED_SIZE = 4e9 + +run_prefetch_pass = False + + +def print_rank_0(message): + if dist.get_rank() == 0: + print(message) + + +def get_ds_id(node: Node): + assert node.target == torch.ops.dc.allgather_param.default + return node.args[2] + + +def schedule_prefetch(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn, + mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule: + + max_mem = get_accelerator().total_memory() * (1 - MARGIN) + vals_to_bcast = torch.tensor([max_mem], device=torch.device(get_accelerator().current_device())) + dist.all_reduce(vals_to_bcast, dist.ReduceOp.MIN) + max_mem = vals_to_bcast[0].item() + + mem = profiling_results[graph_id].bwd_mem if bwd else profiling_results[graph_id].fwd_mem + op_time = profiling_results[graph_id].bwd_time if bwd else profiling_results[graph_id].fwd_time + tensor_sizes = profiling_results[graph_id].bwd_tensor_sizes if bwd else profiling_results[graph_id].fwd_tensor_sizes + + mem_dict = {name: (alloc_mem, peak) for name, alloc_mem, delta, peak in mem} + time_dict = {name: (device_time, wall_time) for name, device_time, wall_time in op_time} + tensor_size_dict = {name: size for name, size in tensor_sizes} + + graph = gm.graph + total_param_size = sum( + [tensor_size_dict[n.name] for n in graph.nodes if n.target == torch.ops.dc.allgather_param.default]) + + print_rank_0( + f"schedule_prefetch graph_id={graph_id} max_mem={max_mem} available_memory={get_accelerator().available_memory()} memory_allocated={get_accelerator().memory_allocated()} max_allocated={get_accelerator().max_memory_allocated()} total_param_size={total_param_size} margin={MARGIN}" + ) + + # Fill missing values + prev_mem = 0 + prev_peak = 0 + for node in graph.nodes: + if node.name in mem_dict: + prev_mem = mem_dict[node.name][0] + prev_peak = mem_dict[node.name][1] + else: + print_rank_0(f"node {node.name} not in mem_dict") + mem_dict[node.name] = (prev_mem, prev_peak) + + comm_predictor = create_predictor() + + order_rev = list(reversed(graph.nodes)) + new_order_rev = [] + prefetch_ags = [] + prefetch_ag_groups = [] + ag_tensor_size_sum = 0 + for i, node in enumerate(order_rev): + # print_rank_0( + # f"Checking node reverse order {node.name} {node.target} ag_tensor_size_sum={ag_tensor_size_sum} max_mem={max_mem}" + # ) + + if node.op != "placeholder": + assert i < len(order_rev) - 1 + assert node.name in mem_dict + next_node = order_rev[i + 1] + next_alloc_mem, next_peak = mem_dict[next_node.name] + + # Free up memory + while next_peak + ag_tensor_size_sum > max_mem or ag_tensor_size_sum > MAX_BUFFERED_SIZE: + if len(prefetch_ag_groups) > 0: + # launch prefetch + fused_ag_nodes = prefetch_ag_groups.pop(0) + total_ag_tensor_size = sum([tensor_size_dict[ag_node.name] for ag_node in fused_ag_nodes]) + ag_tensor_size_sum -= total_ag_tensor_size + new_order_rev.append(fused_ag_nodes) + assert len(fused_ag_nodes) > 0 + # print_rank_0( + # f"Free up memory fused_ag_nodes={fused_ag_nodes} next_alloc_mem={next_alloc_mem} total_ag_tensor_size={total_ag_tensor_size} ag_tensor_size_sum={ag_tensor_size_sum} max_mem={max_mem}" + # ) + elif len(prefetch_ags) > 0: + prefetch_ag_groups.append(prefetch_ags) + prefetch_ags = [] + # print_rank_0( + # f"Free up memory prefetch_ags={prefetch_ag_groups} next_alloc_mem={next_alloc_mem} ag_tensor_size_sum={ag_tensor_size_sum} max_mem={max_mem}" + # ) + else: + break + + if node.target == torch.ops.dc.allgather_param.default: + + current_ag_size = sum([tensor_size_dict[ag_node.name] for ag_node in prefetch_ags]) + pred_time_current = comm_predictor(current_ag_size) + pred_time_next = comm_predictor(tensor_size_dict[node.name]) + pred_time_fused = comm_predictor(current_ag_size + tensor_size_dict[node.name]) + + do_fuse = max(pred_time_current, pred_time_next) * 1.2 > pred_time_fused and ( + current_ag_size + tensor_size_dict[node.name]) < MAX_FUSE_SIZE + # print_rank_0( + # f"found allgather_param do_fuse={do_fuse} current_ag_size={current_ag_size} tensor_size_dict[node.name]={tensor_size_dict[node.name]} pred_time_current={pred_time_current} pred_time_next={pred_time_next} pred_time_fused={pred_time_fused}" + # ) + + if len(prefetch_ags) > 0 and not do_fuse: + # stop fusing here + prefetch_ag_groups.append(prefetch_ags) + prefetch_ags = [] + # print_rank_0( + # f"stop fusing prefetch_ags={prefetch_ag_groups} ag_tensor_size_sum={ag_tensor_size_sum}") + # else: + # print_rank_0( + # f"continue fusing ag_tensor_size_sum={ag_tensor_size_sum} ag_size={tensor_size_dict[node.name]} prefetch_ags={prefetch_ags} prefetch_ag_groups={prefetch_ag_groups}" + # ) + prefetch_ags.append(node) + ag_tensor_size_sum += tensor_size_dict[node.name] + + new_order_rev.append(node) + + if (node.op != "placeholder" + and node.target != torch.ops.dc.reload_parameter) and order_rev[i + 1].op == "placeholder": + for ag_group in prefetch_ag_groups: + assert len(ag_group) > 0 + new_order_rev.append(ag_group) + total_ag_tensor_size = sum([tensor_size_dict[ag_node.name] for ag_node in ag_group]) + ag_tensor_size_sum -= total_ag_tensor_size + if len(prefetch_ags) > 0: + new_order_rev.append(prefetch_ags) + ag_tensor_size_sum -= sum([tensor_size_dict[ag_node.name] for ag_node in prefetch_ags]) + assert ag_tensor_size_sum == 0 + + # print_rank_0( + # f"node={node} next_alloc_mem={next_alloc_mem} pending_ags={len(prefetch_ags)} ag_tensor_size_sum={ag_tensor_size_sum}" + # ) + + assert ag_tensor_size_sum >= 0 + + new_graph = Graph() + env = {} + for node in reversed(new_order_rev): + if isinstance(node, Node): + #print(f"reconstruct {node.name} {node.target}") + new_node = new_graph.node_copy(node, lambda n: env[n.name]) + env[node.name] = new_node + else: + param_nodes = [ag_node.args[0] for ag_node in node] + param_nodes_copy = [env[param_node.name] for param_node in param_nodes] + + ds_ids = [get_ds_id(ag_node) for ag_node in node] + new_graph.call_function(torch.ops.dc.prefetch_params_fused.default, + args=(graph_id, param_nodes_copy, ds_ids)) + new_graph.lint() + gm.graph = new_graph + + return gm diff --git a/deepspeed/compile/passes/selective_gather.py b/deepspeed/compile/passes/selective_gather.py new file mode 100644 index 000000000000..83306872ce06 --- /dev/null +++ b/deepspeed/compile/passes/selective_gather.py @@ -0,0 +1,146 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from collections import defaultdict +from typing import List + +import torch +from torch.fx import GraphModule + +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator + +from ..util import get_deepcompile_handle +from ..graph_param import DSGraphParamManager + +NAME = "selective_gather" + +max_alloc_mem = 0 +last_optimize_step = 0 + + +def selective_gather(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn, + mem_budget: float, param_manager: DSGraphParamManager, bwd: bool) -> GraphModule: + + if not bwd: + return gm + + last_backward_graph_id = None + for g_id, needs_bwd in graph_order: + if needs_bwd: + last_backward_graph_id = g_id + break + + # Run only on the last backward graph + if last_backward_graph_id is None or graph_id != last_backward_graph_id: + return gm + + peak_mem = 0 + for graph_id, prof in profiling_results.items(): + # Use peak memory + fwd_max_mem = max(m[3] for m in prof.fwd_mem) + bwd_max_mem = max(m[3] for m in prof.bwd_mem) if len(prof.bwd_mem) > 0 else 0 + peak_mem = max(peak_mem, fwd_max_mem, bwd_max_mem) + if dist.get_rank() == 0: + print( + f"selective_gather graph_id={graph_id} max_mem={peak_mem} fwd_max_mem={fwd_max_mem} bwd_max_mem={bwd_max_mem}" + ) + + persistent_ds_ids = set() + for graph_id, pm in param_manager.items(): + for name, ds_param in pm.params.items(): + if ds_param.param.ds_persist: + persistent_ds_ids.add(pm.ds_ids[name]) + + ds_id_to_size = {} + ds_id_to_time = defaultdict(float) + ds_id_to_prof_dtime = defaultdict(float) + ds_id_to_prof_wtime = defaultdict(float) + + for graph_id, pm in param_manager.items(): + params = pm.params + for param_name, param in params.items(): + ds_id = pm.ds_ids[param_name] + ds_id_to_size[ds_id] = param.numel * param.dtype.itemsize + + profile = profiling_results[graph_id] + for n in profile.fwd_graph.nodes: + if n.target == torch.ops.dc.allgather_param.default: + assert "tensor_size" in n.meta + ds_id_to_size[n.args[2]] = n.meta["tensor_size"] + assert "device_time" in n.meta + ds_id_to_time[n.args[2]] += n.meta["device_time"] + + ds_id_to_prof_dtime[n.args[2]] = n.meta["device_time"] + ds_id_to_prof_wtime[n.args[2]] = n.meta["wall_time"] + + if profile.bwd_graph is not None: + for n in profile.bwd_graph.nodes: + if n.target == torch.ops.dc.allgather_param.default: + assert "tensor_size" in n.meta + ds_id_to_size[n.args[2]] = n.meta["tensor_size"] + assert "device_time" in n.meta + ds_id_to_time[n.args[2]] += n.meta["device_time"] + + ds_ids = [ds_id for ds_id in ds_id_to_size if ds_id not in persistent_ds_ids] + ds_ids.sort(key=lambda ds_id: ds_id_to_time[ds_id] / ds_id_to_size[ds_id], reverse=True) + + # print(f"ds_id_to_size={ds_id_to_size}") + # print(f"ds_id_to_time={ds_id_to_time}") + + # if dist.get_rank() == 0: + # for ds_id in ds_ids: + # dtime_in_sec = ds_id_to_prof_dtime[ds_id] + # wtime_in_sec = ds_id_to_prof_wtime[ds_id] + # size_in_mb = ds_id_to_size[ds_id] / 1024 / 1024 + # print( + # f"ds_id={ds_id} time_per_size={ds_id_to_time[ds_id] / ds_id_to_size[ds_id]:.5f} dtime={dtime_in_sec:.3f} wtime={wtime_in_sec:.3f} size={size_in_mb:.2f}MB bw={size_in_mb/dtime_in_sec:.2f}MB/s" + # ) + + sorted_ds_ids = {ds_id: ds_id_to_size[ds_id] for ds_id in ds_ids} + + accelerator = get_accelerator() + total_mem = accelerator.total_memory() + vals_to_bcast = torch.tensor([total_mem], device=torch.device(get_accelerator().current_device())) + dist.all_reduce(vals_to_bcast, dist.ReduceOp.MIN) + total_mem = vals_to_bcast[0].item() + + MEM_MARGIN = 0.1 + available_mem = total_mem * (1 - MEM_MARGIN) - peak_mem + + if dist.get_rank() == 0: + print( + f"selective_gather max_mem={peak_mem} total_mem={total_mem} MEM_MARGIN={MEM_MARGIN} available_mem={available_mem}" + ) + + ds_id_to_param = {} + for g_id, g_pm in param_manager.items(): + for name, ds_param in g_pm.params.items(): + ds_id_to_param[g_pm.ds_ids[name]] = ds_param.param + + persistent_mem = 0 + nz3 = get_deepcompile_handle() + for ds_id, size in sorted_ds_ids.items(): + if persistent_mem + size > available_mem: + break + persistent_mem += size + + param_obj = ds_id_to_param[ds_id] + + nz3.set_persistent(ds_id) + if dist.get_rank() == 0: + print(f"Set persistent: {ds_id} size: {size} persistent_mem: {persistent_mem} shape: {param_obj.ds_shape}") + + return gm + + +# def make_selective_gather(z3_optimizer, nz3): + +# def selective_gather_wrapper(graph: Graph, graph_id: int, graph_order: List[int], profiling_results, +# mem_budget: float, param_manager, bwd: bool) -> Graph: +# return selective_gather(graph, graph_id, graph_order, profiling_results, mem_budget, param_manager, bwd, +# z3_optimizer, nz3) + +# return selective_gather_wrapper diff --git a/deepspeed/compile/passes/zero1_compile.py b/deepspeed/compile/passes/zero1_compile.py new file mode 100644 index 000000000000..fb331cc6bca3 --- /dev/null +++ b/deepspeed/compile/passes/zero1_compile.py @@ -0,0 +1,55 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List + +import torch +from torch.fx import GraphModule + +from ..util import get_deepcompile_handle +from ..fx import add_postprocess, move_primals_to_head, _make_node_meta + +NAME = "zero1_compile" + + +def add_z1_reduce_fw(gm: GraphModule, graph_id: int, profiling_results, param_manager) -> GraphModule: + + dc = get_deepcompile_handle() + param_indices = profiling_results[graph_id].param_indices + dc.register_graph_z1(graph_id, [v[1] for v in param_indices]) # Need this before profiling + + return gm + + +def add_z1_reduce_bw(gm: GraphModule, graph_id: int, param_manager) -> GraphModule: + + graph = gm.graph + pm = param_manager[graph_id] + _, param_name_to_grad = pm.get_bwd_mapping(graph) + + for param_name in pm.param_names: + + grad_node = param_name_to_grad[param_name] + + assert param_name in pm.ds_ids, f"param_name={param_name} not in ds_ids" + ds_id = pm.ds_ids[param_name] + + new_node = add_postprocess(graph, + grad_node, + torch.ops.dc.reduce_grad.default, + extra_args=[graph_id, ds_id], + name=f"reduce_param_{param_name}", + meta=_make_node_meta(grad_node, param_name, True)) + new_node.meta["val"] = None + + gm.graph = move_primals_to_head(graph) + return gm + + +def add_z1_reduce(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn, + mem_budget: float, param_manager, bwd: bool) -> GraphModule: + if bwd: + return add_z1_reduce_bw(gm, graph_id, param_manager) + return add_z1_reduce_fw(gm, graph_id, profiling_results, param_manager) diff --git a/deepspeed/compile/passes/zero3_compile.py b/deepspeed/compile/passes/zero3_compile.py new file mode 100644 index 000000000000..1fe420081bd0 --- /dev/null +++ b/deepspeed/compile/passes/zero3_compile.py @@ -0,0 +1,191 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import gc +from typing import List, Dict + +import torch +from torch.fx import Graph, Node, GraphModule + +from ..util import get_input_nodes, get_param_nodes, get_index_by_graph_id, get_deepcompile_handle, get_real_uses +from ..fx import add_postprocess, _make_node_meta, get_output_node, move_primals_to_head +from ..profilers.graph_profile import ProfilingInterpreter +from ..list_schedule import fast_free_schedule + +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator + +NAME = "zero3_compile" + + +def add_allgather(graph_id: int, graph: Graph, node: Node, ds_id: int): + new_ag_node = add_postprocess(graph, + node, + torch.ops.dc.allgather_param.default, + extra_args=[graph_id, ds_id], + name=f"allgather_ds_param_{node.target}_{ds_id}", + meta=_make_node_meta(node, ds_id, True)) + new_ag_node.meta["val"] = node.meta["val"] + + # Set the previous node back to output + # We don't want to change the output node to allgather + output_node = get_output_node(graph) + output_node.replace_input_with(new_ag_node, node) + + # Add wait as well + new_wait_node = add_postprocess(graph, + new_ag_node, + torch.ops.dc.wait_allgather.default, + extra_args=[graph_id, ds_id], + name=f"wait_allgather_ds_param__{node.target}_{ds_id}", + meta=_make_node_meta(node, ds_id, False)) + new_wait_node.meta["val"] = node.meta["val"] + + return new_ag_node + + +def add_release(graph_id: int, graph: Graph, node: Node, release_node: Node, ds_id: int, n_users: int): + new_node = add_postprocess(graph, + node, + torch.ops.dc.release_param.default, + extra_args=[graph_id, ds_id, n_users], + name=f"release_ds_param_{release_node.target}_{node.name}_{ds_id}", + meta=_make_node_meta(node, ds_id, False)) + new_node.meta["val"] = None + + +def add_reduce(graph_id: int, graph: Graph, grad_node: Node, param_name: str, ds_id: int): + new_node = add_postprocess(graph, + grad_node, + torch.ops.dc.reduce_grad.default, + extra_args=[graph_id, ds_id], + name=f"reduce_ds_param_{param_name}", + meta=_make_node_meta(grad_node, ds_id, True)) + new_node.meta["val"] = None + + +def add_gather_and_release(graph_id: int, graph: Graph, param_manager, param_nodes: List[Node]) -> Graph: + + node_to_uses = get_real_uses(graph) + for pn in param_nodes: + add_allgather(graph_id, graph, pn, param_manager.ds_ids[pn.name]) + ds_id = param_manager.ds_ids[pn.name] + users = node_to_uses[pn] + for user in users: + add_release(graph_id, graph, user, pn, ds_id, len(users)) + + return move_primals_to_head(graph) + + +def add_gather_and_reduce(graph_id: int, graph: Graph, param_manager, param_nodes_bw: List[Node], + param_name_to_grad: Dict[str, Node]) -> Graph: + + add_gather_and_release(graph_id, graph, param_manager, param_nodes_bw) + + for param_name in param_manager.param_names: + add_reduce(graph_id, graph, param_name_to_grad[param_name], param_name, param_manager.ds_ids[param_name]) + + return move_primals_to_head(graph) + + +def add_z3_gather_release_fw(gm: GraphModule, + graph_id: int, + graph_order: List[int], + profiling_results, + create_inputs_fn, + param_manager, + debug_log=False) -> GraphModule: + + nz3 = get_deepcompile_handle() + + real_inputs = create_inputs_fn() + param_indices = profiling_results[graph_id].param_indices + + gm.graph = add_gather_and_release(graph_id, gm.graph, param_manager[graph_id], + get_param_nodes(gm.graph, param_indices)) + + nz3.register_graph_z3(graph_id, [v[1] for v in param_indices]) # Need this before profiling + + profiler = ProfilingInterpreter(gm, debug_log=debug_log) + profiler.run(*real_inputs) + del profiler + gc.collect() + get_accelerator().empty_cache() + + rank = dist.get_rank() + graph_index = get_index_by_graph_id(graph_order, graph_id) + if rank == 0 and debug_log: + print(f"Fwd before scheduling graph {graph_index} graph_id={graph_id} {gm.graph}") + + for n in gm.graph.nodes: + is_ds_param = n.name in param_manager[graph_id].ds_ids + if "val" in n.meta and is_ds_param: + # Used for Inductor's validation + n.meta["val"] = torch.empty([0], dtype=n.meta['val'].dtype, device=n.meta['val'].device) + + gm.graph = fast_free_schedule( + gm.graph, + get_accelerator().available_memory(), + 0, # unused + debug_log=debug_log) + + if rank == 0 and debug_log: + print(f"Fwd after scheduling graph {graph_index} graph_id={graph_id} {gm.graph}") + + return gm + + +def add_z3_gather_release_bw(gm: GraphModule, + graph_id: int, + graph_order: List[int], + profiling_results, + create_inputs_fn, + param_manager, + debug_log=False) -> GraphModule: + + param_nodes_bw, param_name_to_grad = param_manager[graph_id].get_bwd_mapping(gm.graph) + gm.graph = add_gather_and_reduce(graph_id, gm.graph, param_manager[graph_id], param_nodes_bw, param_name_to_grad) + + input_nodes = get_input_nodes(gm.graph) + real_inputs = create_inputs_fn() + assert len(input_nodes) == len(real_inputs), f"Expected {len(real_inputs)} inputs, got {len(input_nodes)}" + + real_outputs = ProfilingInterpreter(gm, debug_log=debug_log).run(*real_inputs) + + del real_outputs + gc.collect() + get_accelerator().empty_cache() + + rank = dist.get_rank() + graph_index = get_index_by_graph_id(graph_order, graph_id) + if rank == 0 and debug_log: + print(f"Bwd before scheduling graph {graph_index} graph_id={graph_id} {gm.graph}") + + gm.graph = fast_free_schedule( + gm.graph, + get_accelerator().available_memory(), + 0, # unused + debug_log=debug_log) + + return gm + + +def add_z3_gather_release(gm: GraphModule, graph_id: int, graph_order: List[int], profiling_results, create_inputs_fn, + mem_budget: float, param_manager, bwd: bool) -> GraphModule: + if bwd: + return add_z3_gather_release_bw(gm, + graph_id, + graph_order, + profiling_results, + create_inputs_fn, + param_manager, + debug_log=False) + return add_z3_gather_release_fw(gm, + graph_id, + graph_order, + profiling_results, + create_inputs_fn, + param_manager, + debug_log=False) diff --git a/deepspeed/compile/patch_compiled_func.py b/deepspeed/compile/patch_compiled_func.py new file mode 100644 index 000000000000..c77d529a64ac --- /dev/null +++ b/deepspeed/compile/patch_compiled_func.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from deepspeed.utils.torch import required_torch_version + +backward_inputs = [] + +enabled_patched_func = False +original_grad_fn = None +base_meta = type(torch.autograd.Function) + +if required_torch_version(min_version=2.7): + + class FunctionMeta(base_meta): + + def __new__(cls, name, bases, dct): + if name == "CompiledFunction": + original_backward_impl = dct.get("_backward_impl") + + def wrapped_backward_impl(ctx, all_args): + assert original_backward_impl is not None + + if enabled_patched_func: + backward_inputs.append(all_args) + wrapped_backward_impl.owner_class.compiled_bw = None + + return original_backward_impl(ctx, all_args) + + wrapped_backward_impl.owner_class = None + dct["_backward_impl"] = staticmethod(wrapped_backward_impl) + new_class = super().__new__(cls, name, bases, dct) + wrapped_backward_impl.owner_class = new_class + + return new_class + + return super().__new__(cls, name, bases, dct) + +elif required_torch_version(min_version=2.6): + + class FunctionMeta(base_meta): + + def __new__(cls, name, bases, dct): + if name == "CompiledFunction": + original_backward_prologue = dct.get("_backward_prologue") + + def wrapped_backward_prologue(ctx, *grad_outputs): + assert original_backward_prologue is not None + + all_args = original_backward_prologue(ctx, *grad_outputs) + if enabled_patched_func: + backward_inputs.append(all_args) + wrapped_backward_prologue.owner_class.compiled_bw = None + + return all_args + + wrapped_backward_prologue.owner_class = None + dct["_backward_prologue"] = staticmethod(wrapped_backward_prologue) + new_class = super().__new__(cls, name, bases, dct) + wrapped_backward_prologue.owner_class = new_class + + return new_class + + return super().__new__(cls, name, bases, dct) + + +def patch_compiled_func(): + + global enabled_patched_func + enabled_patched_func = True + + class PatchedFunction(torch.autograd.Function, metaclass=FunctionMeta): + pass + + global original_grad_fn + original_grad_fn = torch.autograd.Function + torch.autograd.Function = PatchedFunction + + return backward_inputs + + +def unpatch_compiled_func(): + global enabled_patched_func + enabled_patched_func = False + + global original_grad_fn + torch.autograd.Function = original_grad_fn + + +def get_backward_inputs(): + return backward_inputs diff --git a/deepspeed/compile/patch_fake_tensor.py b/deepspeed/compile/patch_fake_tensor.py new file mode 100644 index 000000000000..1924e8861921 --- /dev/null +++ b/deepspeed/compile/patch_fake_tensor.py @@ -0,0 +1,53 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +try: + from torch._subclasses import FakeTensorMode + from torch._subclasses.fake_tensor import unset_fake_temporarily + from torch._dynamo.variables.builder import wrap_to_fake_tensor_and_record +except ImportError: + # Unsupported torch version + pass + + +def wrap_if_ds_param(t): + if hasattr(t, 'ds_id'): + data = torch.rand(t.ds_shape, + dtype=t.dtype, + layout=t.layout, + device=t.device, + pin_memory=t.is_pinned(), + requires_grad=t.requires_grad) + if isinstance(t, torch.nn.Parameter): + t = torch.nn.Parameter(data, requires_grad=t.requires_grad) + else: + t = data + return t + + +def patch_fake_tensor(): + # dynamo tracer uses wrap_to_fake_tensor_and_record + # Wrapping FakeTensorMode.from_tensor is not sufficient as dynamo generates SymbolicContext before calling from_tensor + original_wrap_to_fake_tensor_and_record = wrap_to_fake_tensor_and_record + + def wrap_to_fake_tensor_and_record_wrapper(t, *args, **kwargs): + dummy_tensor = wrap_if_ds_param(t) + ret = original_wrap_to_fake_tensor_and_record(dummy_tensor, *args, **kwargs) + if tracing_context := torch._guards.TracingContext.try_get(): + tracing_context.tensor_to_context[t] = tracing_context.tensor_to_context.pop(dummy_tensor) + return ret + + torch._dynamo.variables.builder.wrap_to_fake_tensor_and_record = wrap_to_fake_tensor_and_record_wrapper + + # aot_module_simplified uses fake_mode.from_tensor to process inputs + original_from_tensor = FakeTensorMode.from_tensor + + def from_tensor_wrapper(self, t, *args, **kwargs): + with unset_fake_temporarily(): + return original_from_tensor(self, wrap_if_ds_param(t), *args, **kwargs) + + FakeTensorMode.from_tensor = from_tensor_wrapper diff --git a/deepspeed/compile/profilers/__init__.py b/deepspeed/compile/profilers/__init__.py new file mode 100644 index 000000000000..7adb54f11872 --- /dev/null +++ b/deepspeed/compile/profilers/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import List, Tuple +from dataclasses import dataclass, field + +from torch.fx import Graph + + +@dataclass +class ProfilingResult: + fwd_graph: Graph = None + bwd_graph: Graph = None + needs_backward: bool = False + fwd_mem: List[Tuple[str, int, int, int]] = field(default_factory=list) # name, current_alloc, delta, peak + bwd_mem: List[Tuple[str, int, int, int]] = field(default_factory=list) + fwd_time: List[Tuple[str, int, int]] = field(default_factory=list) # name, device_time, wall_time + bwd_time: List[Tuple[str, int, int]] = field(default_factory=list) + fwd_tensor_sizes: List[Tuple[str, int]] = field(default_factory=list) # name, size + bwd_tensor_sizes: List[Tuple[str, int]] = field(default_factory=list) + param_indices: List[Tuple[int, int, Tuple[int, ...]]] = field(default_factory=list) # index, ds_id, ds_shape diff --git a/deepspeed/compile/profilers/comm_profile.py b/deepspeed/compile/profilers/comm_profile.py new file mode 100644 index 000000000000..18bd517c1e8f --- /dev/null +++ b/deepspeed/compile/profilers/comm_profile.py @@ -0,0 +1,171 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import os +import torch + +try: + from torch._subclasses.fake_tensor import unset_fake_temporarily +except ImportError: + # Unsupported torch version + pass + +import deepspeed +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator + + +def sync_all(): + get_accelerator().synchronize() + dist.barrier() + + +def get_bw(comm_op, size, duration): + n = dist.get_world_size() + tput = 0 + busbw = 0 + + if duration == 0: + raise ValueError("Error. Duration is 0.") + + if comm_op == "all_to_all": + tput = (size / duration) + busbw = (size / duration) * ((n - 1) / n) + elif comm_op == "all_gather": + size *= n + tput = (size / duration) + busbw = (size / duration) * ((n - 1) / n) + elif comm_op == "all_reduce": + tput = (size * 2 / duration) + busbw = (size / duration) * (2 * (n - 1) / n) + elif comm_op == "pt2pt" or comm_op == "broadcast": + tput = (size / duration) + busbw = tput + else: + raise ValueError("wrong comm_op specified") + + return tput, busbw + + +# Run all_gather and print metrics +def timed_all_gather(device, input, output, start_event, end_event, warmup, trials, async_op): + sync_all() + # Warmups, establish connections, etc. + for i in range(warmup): + dist.all_gather_into_tensor(output, input, async_op=async_op) + sync_all() + + # time the actual comm op trials times and average it + start_event.record() + for i in range(trials): + dist.all_gather_into_tensor(output, input, async_op=async_op) + end_event.record() + sync_all() + duration = start_event.elapsed_time(end_event) / 1000 + + # maintain and clean performance data + avg_duration = duration / trials + size = input.element_size() * input.nelement() * dist.get_world_size() + # tput, busbw = get_bw('all_gather', size, avg_duration) + + avg_duration_ten = torch.tensor([avg_duration], device=device) + if dist.get_world_size() > 1: + dist.all_reduce(avg_duration_ten, dist.ReduceOp.AVG) + + return size, avg_duration_ten.item() + + +def run_all_gather(device, dtype, maxsize, warmup=5, trials=10, async_op=False): + + # Prepare benchmark header + global_rank = dist.get_rank() + world_size = dist.get_world_size() + + start_event = get_accelerator().Event(enable_timing=True) + end_event = get_accelerator().Event(enable_timing=True) + + # Create list of message sizes + M_LIST = [] + for x in (2**p for p in range(1, maxsize)): + m = x // world_size + if m > 0: + M_LIST.append(m) + + results = [(0, 0)] + sync_all() + # loop over various tensor sizes + for M in M_LIST: + global_rank = dist.get_rank() + try: + mat = torch.ones(M, dtype=dtype, device=device) + sync_all() + input = ((mat.mul_(float(global_rank))).view(-1)) + # Delete original mat to avoid OOM + del mat + get_accelerator().empty_cache() + output = torch.zeros(input.nelement() * world_size, dtype=dtype, device=device) + except RuntimeError as e: + if 'out of memory' in str(e): + if dist.get_rank() == 0: + print('WARNING: Ran out of GPU memory. Exiting comm op.') + sync_all() + break + else: + raise e + sync_all() + results.append(timed_all_gather(device, input, output, start_event, end_event, warmup, trials, async_op)) + + return results + + +profile_results = None + + +def create_predictor(): + global profile_results + if profile_results is None: + with unset_fake_temporarily(): + device = get_accelerator().current_device() + profile_results = run_all_gather(device, torch.bfloat16, 31) + if dist.get_rank() == 0: + for size, avg_duration in profile_results: + print(f"size: {size}, avg_duration: {avg_duration}") + + # Extract size and avg_duration from results + sizes = [result[0] for result in profile_results] + durations = [result[1] for result in profile_results] + + try: + from scipy.interpolate import interp1d + except ImportError: + raise RuntimeError("Please install scipy to use communication profiler in DeepCompile") + + predictor = interp1d(sizes, durations, kind='linear', fill_value="extrapolate") + + def f(size): + if size == 0: + return 0 + return predictor(size) + + # Create an interpolation function + return f + + +if __name__ == "__main__": + local_rank = int(os.environ['LOCAL_RANK']) + get_accelerator().set_device(local_rank) + print(f"local_rank={local_rank}") + + deepspeed.init_distributed(dist_backend='nccl') + + # Create predictor function + predictor = create_predictor() + + # Predict time for a specific data size + example_size = 1e9 + predicted_time = predictor(example_size) + print(f"Predicted time for size {example_size}: {predicted_time:.6f} seconds") + + dist.destroy_process_group() diff --git a/deepspeed/compile/profilers/graph_profile.py b/deepspeed/compile/profilers/graph_profile.py new file mode 100644 index 000000000000..6cb3d83e485a --- /dev/null +++ b/deepspeed/compile/profilers/graph_profile.py @@ -0,0 +1,295 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import time +from typing import Any, Tuple, Dict +import statistics + +import torch +from torch.fx import GraphModule, Interpreter +from torch.fx.node import map_aggregate + +try: + from torch.utils._pytree import tree_all, tree_leaves + from torch._subclasses.fake_tensor import unset_fake_temporarily, is_fake +except ImportError: + # Unsupported torch version + pass + +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator +from ..util import is_comm_op, is_release_node, get_deepcompile_handle + + +def _all_real_if_tensor(args): + return tree_all(lambda x: not torch.is_tensor(x) or not is_fake(x), args) + + +def _to(v, device): + if torch.is_tensor(v): + with unset_fake_temporarily(): + return v.to(device) + return v + + +def _args_to_key(v): + + def _tensor_to_key(v) -> str: + if torch.is_tensor(v): + if v.numel() == 1: + return f"{v.dtype}{v.device}{v.item()}" + else: + return f"{v.dtype}{v.device}{v.shape}" + return str(v) + + return map_aggregate(v, _tensor_to_key) + + +def _node_size(out): + return sum([v.element_size() * v.numel() for v in tree_leaves(out) if torch.is_tensor(v)]) + + +def _get_mem_usage_out_of_torch(): + + adjust = 0 + try: + import pynvml + pynvml.nvmlInit() + + current_dev_id = get_accelerator().current_device() + handle = pynvml.nvmlDeviceGetHandleByIndex(current_dev_id) + info = pynvml.nvmlDeviceGetMemoryInfo(handle) + + torch_alloc = get_accelerator().memory_allocated() + adjust = info.used - torch_alloc + except: + # pynvml not available + pass + + return adjust + + +# https://pytorch.org/tutorials/intermediate/fx_profiling_tutorial.html +class ProfilingInterpreter(Interpreter): + + def __init__(self, gm: GraphModule, iteration: int = 10, warmup: int = 5, debug_log=False): + super().__init__(gm) + + self.nz3 = get_deepcompile_handle() + + assert iteration > 0 + assert warmup >= 0 + self.iteration = iteration + self.warmup = warmup + self.device = torch.device(get_accelerator().current_device()) + self.cache: Dict[Tuple, Any] = {} + self.distributed = dist.is_initialized() + self.allgather_mem: Dict[int, int] = {} + self.debug_log = debug_log + self.mem_usage_out_of_torch = 0 + + def run(self, *args) -> Any: + """Run the graph with profiling enabled. + + args: inputs to the graph. Tensors in the inpusts must be real tensors, not fake tensors. args can contain ds parameters. + returns: The output of the graph. Tensor in the output is real tensors. + """ + try: + assert _all_real_if_tensor(args), "Inputs must be real tensors" + self.nz3.enable_profiling(True) + + with unset_fake_temporarily(): + with get_accelerator().random().fork_rng(devices=[self.device]): + self.mem_usage_out_of_torch = _get_mem_usage_out_of_torch() + return_val = super().run(*args) + except Exception as e: + msg = e.msg if "msg" in dir(e) else str(e) + print(f"Profiling error {msg}") + finally: + self.nz3.clear_all_gathered_params() + self.nz3.enable_profiling(False) + return return_val + + def run_node(self, n: torch.fx.Node) -> Any: + + if n.op in {"placeholder", "output"}: + n.meta["device_time"] = 0.0 + n.meta["wall_time"] = 0.0 + n.meta["alloc_mem"] = 0 + n.meta["max_memory"] = 0 + n.meta["tensor_size"] = _node_size(n) + return super().run_node(n) + + args, kwargs = self.fetch_args_kwargs_from_env(n) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + def rebuild_param_if_necessary(v): + if hasattr(v, "ds_id"): + v.all_gather(param_list=[v]) + return v + + args = map_aggregate(args, lambda x: rebuild_param_if_necessary(x)) + + args = map_aggregate(args, lambda x: _to(x, self.device)) + kwargs = map_aggregate(kwargs, lambda x: _to(x, self.device)) + + cache_key = (n.target, _args_to_key(args), _args_to_key(kwargs)) + cache_hit = cache_key in self.cache + + cache_hit_flag = torch.tensor([0 if cache_hit else 1], device=self.device, dtype=torch.int) + if self.distributed: + dist.all_reduce(cache_hit_flag, dist.ReduceOp.SUM) + cache_hit = cache_hit_flag.item() == 0 + + if cache_hit: + device_time, wall_time, alloc_mem, max_mem, tensor_size = self.cache[cache_key] + n.meta["device_time"] = device_time + n.meta["wall_time"] = wall_time + n.meta["alloc_mem"] = alloc_mem + n.meta["max_mem"] = max_mem + n.meta["tensor_size"] = tensor_size + + is_release_op = is_release_node(n) + run_only_once = cache_hit or is_release_op + iteration = 1 if run_only_once else self.iteration + accelerator = get_accelerator() + start_events = [accelerator.Event(enable_timing=True) for _ in range(iteration)] + end_events = [accelerator.Event(enable_timing=True) for _ in range(iteration)] + + get_accelerator().reset_peak_memory_stats() + alloc_mem_start = get_accelerator().memory_allocated() + max_mem_start = get_accelerator().max_memory_allocated() + + if not run_only_once: + for i in range(self.warmup): + out = getattr(self, n.op)(n.target, args, kwargs) + + if is_comm_op(n): + assert self.distributed, f"Distributed environment is not initialized but comm operator {n.name} {n.target} is used." + dist.barrier() + + start = time.time() + for i in range(iteration): + start_events[i].record() + out = getattr(self, n.op)(n.target, args, kwargs) + end_events[i].record() + accelerator.synchronize() + walltime_sum = time.time() - start + + if is_comm_op(n): + dist.barrier() + + alloc_mem = get_accelerator().memory_allocated() - alloc_mem_start + self.mem_usage_out_of_torch + max_memory = get_accelerator().max_memory_allocated() - max_mem_start + self.mem_usage_out_of_torch + tensor_size = _node_size(out) + + def partition_param_if_necessary(v): + if hasattr(v, "ds_id") and not v.ds_persist: + v.partition(param_list=[v], has_been_updated=False) + return v + + args = map_aggregate(args, lambda x: partition_param_if_necessary(x)) + + if not cache_hit: + device_time = statistics.mean([s.elapsed_time(e) for s, e in zip(start_events, end_events)]) + wall_time = walltime_sum / iteration * 1000 + + with unset_fake_temporarily(): + vals_to_bcast = torch.tensor([device_time, wall_time, alloc_mem, max_memory, tensor_size], + device=self.device) + if self.distributed: + dist.all_reduce(vals_to_bcast, dist.ReduceOp.AVG) + n.meta["device_time"] = vals_to_bcast[0].item() + n.meta["wall_time"] = vals_to_bcast[1].item() + n.meta["alloc_mem"] = int(vals_to_bcast[2].item()) + n.meta["max_mem"] = int(vals_to_bcast[3].item()) + n.meta["tensor_size"] = int(vals_to_bcast[4].item()) + self.cache[cache_key] = (n.meta["device_time"], n.meta["wall_time"], n.meta["alloc_mem"], + n.meta["max_mem"], n.meta["tensor_size"]) + + if is_release_op: + n.meta["alloc_mem"] = -self.allgather_mem.get(args[2], 0) + + if dist.get_rank() == 0 and self.debug_log: + print( + f"{n.target} {n.meta['device_time']:.2f}ms {n.meta['wall_time']:.2f}ms alloc_mem={n.meta['alloc_mem'] / 1024 / 1024:.2f}MB max_mem={n.meta['max_mem'] / 1024 / 1024:.2f}MB tensor_size={n.meta['tensor_size']}" + ) + + if n.target == torch.ops.dc.allgather_param.default: + out = args[0] + assert hasattr(out, "ds_id") + if not out.ds_persist: + self.nz3.invalidate_gathered_param(args[2]) + self.allgather_mem[out.ds_id] = n.meta["alloc_mem"] + + return out + + +class MemoryProfilingInterpreter(Interpreter): + + def __init__(self, gm: GraphModule, debug_log=False): + super().__init__(gm) + self.nz3 = get_deepcompile_handle() + self.device = torch.device(get_accelerator().current_device()) + self.mem_record = [] + self.last_alloc = get_accelerator().memory_allocated() + + self.node_counter = 0 + self.node_num = len(gm.graph.nodes) + self.debug_log = debug_log + + def run(self, *args) -> Any: + try: + assert _all_real_if_tensor(args), "Inputs must be real tensors" + self.nz3.enable_profiling(True) + self.mem_usage_out_of_torch = _get_mem_usage_out_of_torch() + + with unset_fake_temporarily(): + with get_accelerator().random().fork_rng(devices=[self.device]): + return_val = super().run(*args) + except Exception as e: + print(f"MemoryProfiling error {e}") + finally: + self.nz3.enable_profiling(False) + + return return_val + + def run_node(self, n: torch.fx.Node) -> Any: + get_accelerator().reset_peak_memory_stats() + + if n.op in {"placeholder", "output"}: + ret = super().run_node(n) + else: + args, kwargs = self.fetch_args_kwargs_from_env(n) + args = map_aggregate(args, lambda x: _to(x, self.device)) + kwargs = map_aggregate(kwargs, lambda x: _to(x, self.device)) + ret = getattr(self, n.op)(n.target, args, kwargs) + + del args, kwargs + + current_alloc = get_accelerator().memory_allocated() + self.mem_usage_out_of_torch + max_alloc = get_accelerator().max_memory_allocated() + self.mem_usage_out_of_torch + vals_to_bcast = torch.tensor([current_alloc, max_alloc], device=self.device) + dist.all_reduce(vals_to_bcast, dist.ReduceOp.MAX) + current_alloc = vals_to_bcast[0].item() + max_alloc = vals_to_bcast[1].item() + + self.mem_record.append((n.name, current_alloc, current_alloc - self.last_alloc, max_alloc)) + + self.node_counter += 1 + if self.debug_log and dist.get_rank() == 0: + print( + f"Mem prof Node {self.node_counter}/{self.node_num} {n.name} memory {current_alloc / 1024 / 1024:.2f}MB delta {(current_alloc - self.last_alloc) / 1024 / 1024:.2f}MB" + ) + + self.last_alloc = current_alloc + + return ret + + def dump(self, path): + import pandas as pd + df = pd.DataFrame(self.mem_record, columns=["node", "memory", "delta", "max_mem"]) + df.to_csv(path, index=False) diff --git a/deepspeed/compile/util.py b/deepspeed/compile/util.py new file mode 100644 index 000000000000..fdeb5f4347a9 --- /dev/null +++ b/deepspeed/compile/util.py @@ -0,0 +1,429 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import functools +import operator +from typing import List, Tuple, Dict +from collections import defaultdict + +import torch +from torch.fx import Node, Graph +from torch.fx.node import map_aggregate, Argument, map_arg + +try: + from torch._subclasses.fake_tensor import unset_fake_temporarily +except ImportError: + # Unsupported torch version + pass + +import deepspeed.comm as dist +from deepspeed.accelerator import get_accelerator +from deepspeed.utils.torch import required_torch_version +from deepspeed.ops.op_builder.dc import DeepCompileBuilder + + +def is_deepcompile_supported() -> bool: + return required_torch_version(min_version=2.6, max_version=2.7) and get_accelerator().device_name() == "cuda" + + +dc_handle = None + +if is_deepcompile_supported(): + sym_size_ops = { + operator.ge, + operator.le, + operator.eq, + operator.ne, + operator.gt, + operator.lt, + torch.ops.aten.sym_size.int, + operator.getitem, + } + + +def get_deepcompile_handle(): + global dc_handle + if dc_handle is None: + dc_handle = DeepCompileBuilder().load() + return dc_handle + + +def is_backend_inductor(backend): + return backend == "inductor" + + +backward_started = False +pre_backward_hooks = [] + + +def add_pre_backward_hook(hook): + pre_backward_hooks.append(hook) + + +def deepcompile_backward_prologue(is_gradient_accumulation_boundary): + + for hook in pre_backward_hooks: + hook() + + dc = get_deepcompile_handle() + dc.start_backward(is_gradient_accumulation_boundary) + + +def log_rank0(msg: str, enable: bool = False): + if dist.get_rank() == 0 and enable: + print(msg) + + +def get_no_copy_ops(): + # Need to compile custom ops + get_deepcompile_handle() + return { + torch.ops.aten.t.default, torch.ops.aten.view.default, torch.ops.aten.detach.default, + torch.ops.aten.permute.default, torch.ops.dc.wait_allgather.default + } + + +def get_input_nodes(graph: Graph) -> List[Node]: + return [n for n in graph.nodes if n.op == "placeholder"] + + +def get_param_nodes(graph: Graph, index_to_ds_ids: List[Tuple[int, int]]) -> List[Node]: + all_input_nodes = get_input_nodes(graph) + return [all_input_nodes[i] for i, _, _ in index_to_ds_ids] + + +def is_comm_op(node: Node) -> bool: + return "comm" in node.meta and node.meta["comm"] + + +def exclude_from_act_offload(node: Node) -> bool: + return node.target in sym_size_ops + + +def dtype_to_elem_size(dtype: torch.dtype) -> int: + if dtype == torch.float32: + elem_size = 4 + elif dtype == torch.float64: + elem_size = 8 + elif dtype == torch.float16: + elem_size = 2 + else: + raise ValueError(f"Unsupported dtype: {dtype}") + return elem_size + + +def tensor_meta_size(tensor_meta) -> int: + numel = 1 if len(tensor_meta.shape) == 0 else functools.reduce(operator.mul, tensor_meta.shape) + + dtype = tensor_meta.dtype + if dtype == torch.float32: + elem_size = 4 + elif dtype == torch.float64 or dtype == torch.int64: + elem_size = 8 + elif dtype == torch.float16 or dtype == torch.bfloat16: + elem_size = 2 + elif dtype == torch.bool: + elem_size = 1 + else: + raise ValueError(f"Unsupported dtype: {dtype}") + + return numel * elem_size + + +class NodeValueOffloadHelper: + + def __init__(self, device): + self.device = device + self.env_values: Dict[str, Argument] = {} + self.original_device: Dict[torch.Tensor, torch.device] = {} + + def _to_cpu(self, v): + if torch.is_tensor(v): + with unset_fake_temporarily(): + device = v.device + offloaded = v.to('cpu').detach() + self.original_device[offloaded] = device + return offloaded + return v + + def _from_cpu(self, v): + if torch.is_tensor(v) and v in self.original_device: + return v.to(self.original_device[v]) + return v + + def save(self, name: str, v: Argument, offload) -> None: + self.env_values[name] = map_aggregate(v, lambda x: self._to_cpu(x) if offload else x) + + def load(self, name: str) -> Argument: + return map_aggregate(self.env_values[name], lambda x: self._from_cpu(x)) + + def get_offloaded_value(self, name: str) -> Argument: + return self.env_values[name] + + def has_value(self, name: str) -> bool: + return name in self.env_values + + def clear(self) -> None: + self.env_values.clear() + self.original_device.clear() + + +def materialize_fake(v, device=None): + from torch._subclasses.fake_tensor import is_fake + + def convert(t): + if is_fake(t): + with unset_fake_temporarily(): + if t.is_floating_point(): + return torch.randn(t.shape, + dtype=t.dtype, + device=t.device if device is None else device, + layout=t.layout, + requires_grad=t.requires_grad, + pin_memory=t.is_pinned()) + else: + return torch.zeros(t.shape, + dtype=t.dtype, + device=t.device if device is None else device, + requires_grad=t.requires_grad) + + return t + + return map_aggregate(v, lambda x: convert(x)) + + +def get_last_uses(graph: Graph): + position = {node: i for i, node in enumerate(graph.nodes)} + + node_to_last_use: Dict[Node, Node] = {} + user_to_last_uses: Dict[Node, List[Node]] = {} + no_copy_ops = get_no_copy_ops() + + def register_last_uses(n: Node, user: Node): + update = False + known_last_use = None + + if user.target in no_copy_ops and n in node_to_last_use: + last_user = node_to_last_use[user] + last_use_position = position[last_user] + + known_last_use = node_to_last_use[n] + known_last_use_position = position[known_last_use] + update = last_use_position > known_last_use_position + + if n not in node_to_last_use or update: + if user.target in no_copy_ops: + user = node_to_last_use[user] + + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + if known_last_use: + user_to_last_uses[known_last_use].remove(n) + + for node in reversed(graph.nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + return node_to_last_use, user_to_last_uses + + +def get_real_uses(graph: Graph): + node_to_uses: Dict[Node, List[Node]] = defaultdict(list) + no_copy_ops = get_no_copy_ops() + + def register_last_uses(n: Node, user: Node): + if user.target == "output": + return + + if user.target in no_copy_ops: + users = node_to_uses[user] + node_to_uses[n].extend(users) + else: + node_to_uses[n].append(user) + + for node in reversed(graph.nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + return node_to_uses + + +def count_inflight_values(graph: Graph, file_path: str): + position = {node: i for i, node in enumerate(graph.nodes)} + + node_to_last_use, user_to_last_uses = get_last_uses(graph) + + max_inflight_size = 0 + inflight_values = set() + + # Output csv. + csv_filename = file_path + csv_data = [] + header = [ + 'Node', 'tensor_size', 'inflight_size', 'inflight_size_in_output', 'args', 'users', 'node_to_last_use', + 'lifetime', 'user_to_last_uses', 'inflight_values' + ] + csv_data.append(header) + + from .fx import get_output_node + output_node = get_output_node(graph) + values_in_output = set([n for n in output_node.args[0] if isinstance(n, Node)]) + + for node in graph.nodes: + inflight_values.add(node) + if node in user_to_last_uses: + for to_delete in user_to_last_uses[node]: + inflight_values.remove(to_delete) + + assert "tensor_size" in node.meta, f"Node {node} does not have tensor_size" + inflight_size = sum(n.meta["tensor_size"] for n in inflight_values) + inflight_size_in_output = sum(n.meta["tensor_size"] for n in inflight_values if n in values_in_output) + + lifetime = position[node_to_last_use[node]] - position[node] if node in node_to_last_use else 0 + + row = [ + node.name, node.meta["tensor_size"], inflight_size, inflight_size_in_output, + [a.name for a in node.args if isinstance(a, Node)], + list(node.users.keys()), node_to_last_use[node] if node in node_to_last_use else 'NA', lifetime, + user_to_last_uses[node] if node in user_to_last_uses else 'NA', + list(inflight_values) + ] + csv_data.append(row) + + # print( + # f"Node: {node.name} users: {list(node.users.keys())} node_to_last_use: {node_to_last_use[node] if node in node_to_last_use else 'NA'} user_to_last_uses: {user_to_last_uses[node] if node in user_to_last_uses else 'NA'} inflight_values: {inflight_values} inflight_size: {inflight_size}" + # ) + max_inflight_size = max(max_inflight_size, inflight_size) + + import csv + with open(csv_filename, mode='w', newline='') as file: + writer = csv.writer(file) + writer.writerows(csv_data) + + print(f"Max inflight size: {max_inflight_size}") + print(f"Data successfully written to {csv_filename}") + + +def get_activation_node_names(graph: Graph, param_nodes_bw: List[Node], fwd_output_names: List[str]): + + input_nodes = get_input_nodes(graph) + param_node_names = set([n.name for n in param_nodes_bw]) + + activation_node_names = [] + for in_node in input_nodes: + if in_node.name in fwd_output_names: + if in_node.name not in param_node_names: + activation_node_names.append(in_node.name) + + return activation_node_names + + +class TensorOffloadHelper(): + + def __init__(self): + self.devices = {} + self.base_tensors = {} + self.views = {} + self.arg_list = [] + self.offloaded = {} + self.non_tensor = {} + + def offload(self, argument): + + def is_base_tensor(tensor): + return torch.is_tensor(a) and not a._is_view() and not hasattr(tensor, "ds_id") + + base_tensor_ids = set() + for a in argument: + if is_base_tensor(a): + base_tensor_ids.add(id(a)) + + for a in argument: + a_id = id(a) + + if is_base_tensor(a): + # Base tensor + self.devices[a_id] = a.device + self.base_tensors[a_id] = a + # elif torch.is_tensor(a) and not hasattr(a, "ds_id") and id(a._base) in base_tensor_ids: + # # View + # self.views[a_id] = { + # "base_id": id(a._base), + # "size": a.size(), + # "stride": a.stride(), + # "offset": a.storage_offset(), + # } + else: + # other types or ds tensor + self.non_tensor[a_id] = a + + self.arg_list.append(a_id) + + for a in argument: + if is_base_tensor(a): + a.data = a.data.to("cpu") + + def reload(self, in_place): + + loaded_base_tensors = {} + for a_id in self.arg_list: + if a_id in self.base_tensors: + device = self.devices[a_id] + + if in_place: + self.base_tensors[a_id].data = self.base_tensors[a_id].to(device) + loaded_base_tensors[a_id] = self.base_tensors[a_id] + else: + loaded_base_tensors[a_id] = self.base_tensors[a_id].to(device) + + results = [] + for a_id in self.arg_list: + if a_id in self.base_tensors: + results.append(loaded_base_tensors[a_id]) + + # elif a_id in self.views: + # view_info = self.views[a_id] + # # print(f"load_args loading view {a_id} base_id={view_info['base_id']} size={view_info['size']} stride={view_info['stride']} offset={view_info['offset']}") + # base_tensor = loaded_base_tensors[view_info["base_id"]] + # view_tensor = base_tensor.as_strided( + # view_info["size"], view_info["stride"], view_info["offset"] + # ) + # results.append(view_tensor) + + elif a_id in self.non_tensor: + results.append(self.non_tensor[a_id]) + + return results + + +def add_mem_profile_nodes(graph: Graph, prefix: str): + + def show_memory(label: str): + if dist.get_rank() == 0: + print( + f"{prefix} {label} alloc_mem={get_accelerator().memory_allocated()} max_mem={get_accelerator().max_memory_allocated()}" + ) + + nodes = list(graph.nodes) + for node in nodes: + if node.op == "output": + continue + + with graph.inserting_after(node): + msg = f"Mem {node.name}" + name = f"show_memory_{node.name}" + graph.create_node('call_function', show_memory, (msg, ), {}, name=name) + + +def is_release_node(n: Node) -> bool: + return n.target == torch.ops.dc.release_param.default + + +def get_index_by_graph_id(graph_order, target_graph_id): + for index, (graph_id, _) in enumerate(graph_order): + if graph_id == target_graph_id: + return index + return -1 diff --git a/deepspeed/launcher/multinode_runner.py b/deepspeed/launcher/multinode_runner.py index 3e0c9d8a1652..218fb02be766 100644 --- a/deepspeed/launcher/multinode_runner.py +++ b/deepspeed/launcher/multinode_runner.py @@ -8,6 +8,7 @@ import shutil import subprocess import warnings +import re from shlex import split from abc import ABC, abstractmethod from deepspeed.accelerator import get_accelerator @@ -34,7 +35,10 @@ def get_cmd(self, environment, active_resources): """Return the command to execute on node""" def add_export(self, key, var): - self.exports[key.strip()] = f"\"{var.strip()}\"" + var = var.strip() + if re.search(r'[^\w@%+=:,./-]', var): + var = f"\"{var}\"" + self.exports[key.strip()] = var def parse_user_args(self): return self.args.user_args diff --git a/deepspeed/launcher/runner.py b/deepspeed/launcher/runner.py index bea1e14fa51f..2f7daed4ef43 100755 --- a/deepspeed/launcher/runner.py +++ b/deepspeed/launcher/runner.py @@ -462,10 +462,8 @@ def main(args=None): if multi_node_exec and not args.no_ssh_check and not args.no_ssh: first_host = list(active_resources.keys())[0] try: - ssh_check_cmd = "ssh -o PasswordAuthentication=no " - if args.ssh_port is not None: - ssh_check_cmd += f"-p {args.ssh_port} " - ssh_check_cmd += f"{first_host} hostname" + ssh_check_cmd = ("ssh -o PasswordAuthentication=no " + + (f"-p {args.ssh_port} " if args.ssh_port is not None else "") + f"{first_host} hostname") safe_ssh_cmd = shlex.split(ssh_check_cmd) subprocess.check_call(safe_ssh_cmd, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL) except subprocess.CalledProcessError: @@ -485,7 +483,7 @@ def main(args=None): result = subprocess.check_output(hostname_cmd) except subprocess.CalledProcessError as err: logger.error( - "Unable to detect suitable master address via `hostname -I`, please manually specify one via --master_addr" + "Unable to detect suitable master address via 'hostname -I', please manually specify one via --master_addr" ) raise err args.master_addr = result.decode('utf-8').split()[0] diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index b089ec420d47..274f56e92487 100755 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -11,7 +11,7 @@ from typing import Optional import torch from deepspeed import comm as dist -from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, Yuan_LinearAllreduce, Yuan_LinearLayer, GateUpPack_LinearLayer, Conv_LinearALlreduce, fused_LinearLayer, conv_LinearLayer +from .layers import * from deepspeed.accelerator import get_accelerator from .fusedqkv_utils import require_tp_fused_qkvw from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list @@ -211,7 +211,7 @@ def __init__(self, self.orig_layer_impl = orig_layer_impl self.linear_policies = None self.conv_linear_layer = False - self.keep_module_on_host = keep_module_on_host + TensorParallel_Layer.set_keep_module_on_host(keep_module_on_host) def in_module_list(module, module_list): for item in module_list: @@ -350,14 +350,11 @@ def _replace(self, child, name, conv_linear_layer): # and avoid any complex shard-related logic. if getattr(child, "replaced", False) == True: return - device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name() - # keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some - # cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy. - return_new_copy = not self.keep_module_on_host + weight_shape = child.weight.shape mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group) # For TP layer skip, e.g., MoE gate, deepseek low rank layer skip - if "q_a_proj" in name or "kv_a_proj_with_mqa" in name or name == "block_sparse_moe.gate" or ( + if "mlp.gate" == name or "q_a_proj" in name or "kv_a_proj_with_mqa" in name or name == "block_sparse_moe.gate" or ( ('mlp.shared_expert_gate' == name or 'mlp.gate' == name) and 'qwen2_moe' in str(type(self.module))): return child # For Yuan model diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py index 0c673225a732..60c87bcb460a 100644 --- a/deepspeed/module_inject/layers.py +++ b/deepspeed/module_inject/layers.py @@ -17,6 +17,11 @@ from copy import deepcopy from typing import Union +__all__ = [ + "TensorParallel_Layer", "LinearAllreduce", "LinearLayer", "LmHeadLinearAllreduce", "Yuan_LinearAllreduce", + "Yuan_LinearLayer", "GateUpPack_LinearLayer", "Conv_LinearALlreduce", "fused_LinearLayer", "conv_LinearLayer" +] + DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE DS_IS_REPLACED_MODULE = 'ds_is_replaced_module' DS_TENSOR_MODEL_PARALLEL = 'tensor_model_parallel' @@ -43,26 +48,6 @@ def set_autotp_mode(training=False): DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE -def move(tensor, device): - # TODO: consider the timing of deletion - # to save host resources when DP > 1。 - - if tensor.is_meta: - # Keep tensor in meta device if tensor is meta. - return tensor - else: - # Using new tensors help in freeing memory (after split for example) was done before by calling clone(). - # Using copy=True instead of clone() will help in case of cpu --> cpu. - # Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced. - cloned_tensor = tensor.to(device, copy=True) - - # free the memory of the original tensor to reduce memory peak - # Equivalent to directly deleting the tensor reference outside the function. - # see https://github.com/microsoft/DeepSpeed/pull/4353 - tensor.data = torch.empty(0, device=tensor.device) - return cloned_tensor - - class RowParallel(torch.autograd.Function): """ A custom autograd function for performing row-wise parallelism. @@ -95,6 +80,35 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, N return None, grad_output, None +class AsyncColumnParallel(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, group: dist.ProcessGroup, input: torch.Tensor, weight, bias) -> torch.Tensor: + """ + Forward pass. + """ + ctx.use_bias = bias is not None + ctx.group = group + output = torch.matmul(input, weight.transpose(-1, -2)) + if bias is not None: + output += bias + + ctx.save_for_backward(input, weight) + + return output + + @staticmethod + def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]: + + input, weight = ctx.saved_tensors + grad_input = grad_output.matmul(weight) + handle = dist.all_reduce(grad_input.contiguous(), group=ctx.group, async_op=True) + grad_weight = grad_output.view(-1, grad_output.shape[-1]).t().matmul(input.view(-1, input.shape[-1])) + grad_bias = grad_output.sum(0) if ctx.use_bias else None + handle.wait() + return None, grad_input, grad_weight, grad_bias + + class ColumnParallel(torch.autograd.Function): """ Custom autograd function for column-wise parallelism. @@ -139,6 +153,16 @@ class TensorParallel_Layer(nn.Module, ABC): support_training (bool): Flag indicating whether the layer supports training (default: False). name (Optional[str]): The name of the layer, if provided. """ + ##### Initialize Parameter List ##### + + # keep_module_on_host determines whether to keep the module on the host. + # Checkpoints are first loaded to the host (sometimes directly from disk to avoid filling host memory), + # so an additional copy is unnecessary. + keep_module_on_host: bool = False + + ##### Runtime Parameter List ##### + tp_overlap_comm: bool = False + """ Whether to overlap communication with computation. Currently, only allreduce supports overlap. """ def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any): """ @@ -163,6 +187,16 @@ def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any): if kwargs.get('name') is not None: self.name = kwargs.get('name') # Set the layer name if provided. + @classmethod + def set_keep_module_on_host(cls, value: bool): + """ + Set the static variable keep_module_on_host. + + Args: + value (bool): The new value for keep_module_on_host. + """ + cls.keep_module_on_host = value + @abstractmethod def forward(self, input): """ @@ -235,6 +269,38 @@ def extra_repr(self): in_features, out_features, self.bias is not None, dtype) return extra_repr_str + def move(self, tensor): + # TODO: consider the timing of deletion + # to save host resources when DP > 1。 + + # keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some + # cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy. + if tensor.is_meta: + # Keep tensor in meta device if tensor is meta. + return tensor + else: + device = 'cpu' if self.__class__.keep_module_on_host else get_accelerator().current_device_name() + return_new_copy = not self.__class__.keep_module_on_host + + # Using new tensors help in freeing memory (after split for example) was done before by calling clone(). + # Using copy=True instead of clone() will help in case of cpu --> cpu. + # Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced. + cloned_tensor = tensor.to(device, copy=return_new_copy) + + if return_new_copy: + # free the memory of the original tensor to reduce memory peak + # Equivalent to directly deleting the tensor reference outside the function. + # see https://github.com/microsoft/DeepSpeed/pull/4353 + tensor.data = torch.empty(0, device=tensor.device) + return cloned_tensor + + +def configure_tensor_parallel_runtime(config): + runtime_keys = ['tp_overlap_comm'] + for key in runtime_keys: + if hasattr(config, key): + setattr(TensorParallel_Layer, key, getattr(config, key)) + class GatherReplacedLayerParams: """ @@ -349,7 +415,7 @@ def _tp_partition(self, params_list): return _partition = torch.chunk(param, self.tp_world_size, dim=-1)[self.tp_index] - _partition = move(_partition, get_accelerator().current_device_name()).detach() + _partition = self.move(_partition).detach() params_list[idx].data = _partition @@ -363,7 +429,7 @@ def uneven_partition(self, params_list): self.name), dim=1)[self.tp_index] - _partition = move(_partition, get_accelerator().current_device_name()).detach() + _partition = self.move(_partition).detach() params_list[idx].data = _partition @@ -382,11 +448,15 @@ def __init__(self, module, mp_group=None, skip_partition=False, **kwargs): self.config_tp_params(self.bias) def forward(self, input): - if getattr(self, 'mp_group', None) is not None: - input = ColumnParallel.apply(self.mp_group, input) - output = torch.matmul(input, self.weight.transpose(-1, -2)) - if self.bias is not None: - output += self.bias + if not self.__class__.tp_overlap_comm: + if getattr(self, 'mp_group', None) is not None: + input = ColumnParallel.apply(self.mp_group, input) + output = torch.matmul(input, self.weight.transpose(-1, -2)) + if self.bias is not None: + output += self.bias + else: + output = AsyncColumnParallel.apply(self.mp_group, input, self.weight, self.bias) + return output @torch.no_grad() @@ -414,7 +484,7 @@ def _tp_partition(self, params_list): #split bias if provide _partition = torch.chunk(param, self.tp_world_size, dim=0)[self.tp_index] - _partition = move(_partition, get_accelerator().current_device_name()).detach() + _partition = self.move(_partition).detach() params_list[idx].data = _partition @@ -429,7 +499,7 @@ def uneven_partition(self, params_list): self.name), dim=0)[self.tp_index] - _partition = move(_partition, get_accelerator().current_device_name()).detach() + _partition = self.move(_partition).detach() params_list[idx].data = _partition @@ -475,7 +545,7 @@ def _tp_partition(self, params_list): _partition = prepare_tp_fused_qkvw(self.fused_module.module, param, self.tp_world_size, self.tp_index) - _partition = move(_partition, get_accelerator().current_device_name()).detach() + _partition = self.move(_partition).detach() params_list[idx].data = _partition @@ -492,13 +562,13 @@ def _tp_partition(self, params_list): weight, bias = params_list[0], params_list[1] _partition = weight.data.split(get_shard_size_list(weight.shape[0], self.tp_world_size, self.name), dim=1)[self.tp_index] - _partition = move(_partition, get_accelerator().current_device_name()).detach() + _partition = self.move(_partition).detach() weight.data = _partition if bias is not None: _partition = bias.data.split(get_shard_size_list(weight.shape[1], self.tp_world_size, self.name), dim=0)[self.tp_index] - _partition = move(_partition, get_accelerator().current_device_name()).detach() + _partition = self.move(_partition).detach() bias.data = _partition @@ -522,9 +592,9 @@ class Yuan_LinearLayer(LinearLayer): def _tp_partition(self, params_list): weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index, self.tp_world_size, True) - params_list[0].data = move(weight, get_accelerator().current_device_name()).detach() + params_list[0].data = self.move(weight).detach() if bias is not None: - params_list[1].data = move(bias, get_accelerator().current_device_name()).detach() + params_list[1].data = self.move(bias).detach() class GateUpPack_LinearLayer(LinearLayer): @@ -532,9 +602,9 @@ class GateUpPack_LinearLayer(LinearLayer): @torch.no_grad() def _tp_partition(self, params_list): weight, bias = shard_chunk_mlp(params_list[0].data, params_list[1], self.tp_index, self.tp_world_size) - params_list[0].data = move(weight, device=get_accelerator().current_device_name()).detach() + params_list[0].data = self.move(weight).detach() if bias is not None: - params_list[1].data = move(bias, device=get_accelerator().current_device_name()).detach() + params_list[1].data = self.move(bias).detach() class Conv_LinearALlreduce(LinearAllreduce): @@ -549,7 +619,7 @@ def _tp_partition(self, params_list): _partition = param.split(get_shard_size_list(param.shape[0], self.tp_world_size, self.name), dim=1)[self.tp_index] - _partition = move(_partition, get_accelerator().current_device_name()).detach() + _partition = self.move(_partition).detach() params_list[idx].data = _partition diff --git a/deepspeed/ops/compile/__init__.py b/deepspeed/ops/compile/__init__.py new file mode 100755 index 000000000000..e38d56359fea --- /dev/null +++ b/deepspeed/ops/compile/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from ..op_builder import DeepCompileBuilder diff --git a/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py b/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py index 746e217d4194..086525cc6442 100644 --- a/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py +++ b/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py @@ -39,25 +39,19 @@ def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, str weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( (pid_n * BLOCK_SIZE_N) // quantization_group_size) - weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) - scale = tl.load(scale_ptr + weight_ptrs_offset) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) + scale = tl.load(scale_ptr + weight_ptrs_offset + ((k * BLOCK_SIZE_K * stride_bk) // quantization_group_size)) # Dequantize weight (fp8 -> bf16) - w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16) + w = (weight & 0x80).to(tl.uint16) << 8 + w = w | ((weight & 0x7f).to(tl.uint16) << 4) w = (w + 0x3C00).to(tl.uint16) - w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16) + w = (w.to(tl.bfloat16, bitcast=True).to(tl.float32) * scale).to(tl.bfloat16) inp_data += BLOCK_SIZE_K * stride_ak weight_data += BLOCK_SIZE_K * stride_bk - weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K - weight = tl.load(weight_data, mask=weight_mask, other=0.0) - scale = tl.load(scale_ptr + (weight_ptrs_offset + - (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)), - mask=weight_mask, - other=0.0) accumulator += tl.dot(inp, w) diff --git a/deepspeed/ops/fp_quantizer/quantize.py b/deepspeed/ops/fp_quantizer/quantize.py index 69c21eaf693b..47b3b08c7e03 100644 --- a/deepspeed/ops/fp_quantizer/quantize.py +++ b/deepspeed/ops/fp_quantizer/quantize.py @@ -79,27 +79,15 @@ def quantize(self, else: assert (0), \ f"Missing {q_bits}-quantization, please add the template arguments for the kernel to support this precision!" - - # Adding (group_size - 1) is for padding - self.num_groups = (input.numel() + self.q_config.group_size - 1) // self.q_config.group_size - # group_size should be the minimal number between the defined group size and number of elements in tensor. - group_size = int(min(self.q_config.group_size, input.numel()) * q_bits) // 8 - # CUDA quantization kernel saves the scale as (fp32) inside the quantized tensor for each group - if self.cuda_impl: - group_size += 4 - # CUDA quantization kernel allocates tensors as uint8, but handles them as fp8 inside the kernel. - self.input_q = torch.ones(self.num_groups, group_size, dtype=self.q_config.q_dtype, device=input.device) - # CUDA quantization kernel attaches scales to quantized result, in python implementation it can't be done - # because they are of different types. - self.scale = torch.ones(self.num_groups, 1, device=input.device) - out = fp_quant_module.quantize(self.input_q, input, self.scale, group_size, stochastic_mode, q_bits, - q_mantisa_bits) + self.num_groups = input.numel() // self.group_size + self.input_q = torch.ones(self.num_groups, + int(self.group_size * q_bits) // 8 + 4, + dtype=torch.uint8, + device=input.device) + out = fp_quant_module.quantize(self.input_q, input, self.group_size, stochastic_mode, q_bits, q_mantisa_bits) if return_meta_tensor: - if self.cuda_impl: - data, self.scale = out.split(group_size, dim=-1) - data = data.contiguous().reshape(input.shape) - else: - data = out.contiguous().reshape(input.shape) + data, self.scale = out.split(self.group_size, dim=-1) + data = data.contiguous().reshape(input.shape) self.scale = self.scale.contiguous() del self.input_q del out @@ -111,9 +99,9 @@ def quantize(self, def to(self, *args, **kwargs): # Intermediate tensors may need to be moved to different devices - if hasattr(self, 'input_q') and self.input_q is not None: + if hasattr(self, 'input_q'): self.input_q = self.input_q.to(*args, **kwargs) - if hasattr(self, 'scale') and self.scale is not None: + if hasattr(self, 'scale'): self.scale = self.scale.to(*args, **kwargs) def get_scales(self): @@ -136,16 +124,11 @@ def dequantize(self, input_q, fp_out=None, q_bits=8, q_mantisa_bits=3, scale=Non assert (0), \ f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!" - if scale is not None and self.cuda_impl: + if scale is not None: assert input_q.numel() == fp_out.numel(), \ f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!' - input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous() - elif scale is not None and not self.cuda_impl: - group_size = int(min(self.q_config.group_size, input_q.numel()) * q_bits) // 8 - input_q = input_q.reshape(-1, group_size) - - fp_quant_module.dequantize(fp_out, input_q, self.scale, self.q_config.group_size, q_mantisa_bits, - q_bits - q_mantisa_bits - 1) + input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous() + fp_quant_module.dequantize(fp_out, input_q, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1) return fp_out def selective_dequantize(self, @@ -174,11 +157,11 @@ def selective_dequantize(self, assert (0), \ f"Missing {q_bits}-dequantization, please add the template arguments for the kernel to support this precision!" - if scale is not None and self.cuda_impl: + if scale is not None: assert input_q.numel() == fp_out.numel(), \ f'[De-quantization Error]: quantized data should have the same size as original tensor when scale is not None!' - input_q = torch.cat([input_q.reshape(-1, self.q_config.group_size), scale], dim=-1).contiguous() + input_q = torch.cat([input_q.reshape(-1, self.group_size), scale], dim=-1).contiguous() - fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.q_config.group_size, q_mantisa_bits, + fp_quant_module.selective_dequantize(fp_out, input_q, indexes, self.group_size, q_mantisa_bits, q_bits - q_mantisa_bits - 1) return fp_out diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 78895e70df03..c9337a795aea 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -18,6 +18,7 @@ from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups from deepspeed.moe.utils import is_moe_param, is_moe_param_group from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank +from deepspeed.utils.torch import register_grad_hook from deepspeed.checkpoint import enable_universal_checkpoint from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE, SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS, @@ -44,7 +45,7 @@ def __init__(self, timers=None, grad_acc_dtype=None, graph_harvesting=False, - immediate_grad_update=False, + immediate_grad_update=True, has_moe_layers=False): super().__init__() see_memory_usage('begin bf16_optimizer', force=True) @@ -313,7 +314,7 @@ def step(self, closure=None): self.clear_hp_grads() - def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs): + def backward(self, loss, retain_graph=False, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs): """Perform a backward pass and copy the low-precision gradients to the high-precision copy. @@ -323,7 +324,7 @@ def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwarg The low-precision grads are deallocated during this procedure. """ self.clear_lp_grads() - loss.backward(**bwd_kwargs) + loss.backward(retain_graph=retain_graph, **bwd_kwargs) if update_hp_grads: self.update_hp_grads(clear_lp_grads=clear_lp_grads) @@ -425,9 +426,6 @@ def update_lp_params(self): fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)): partition_id = dist.get_rank(group=self.real_dp_process_group[i]) bf16_partitions[partition_id].data.copy_(fp32_partition.data) - # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True) - # if i == 0: - # print_rank_0(f'{fp32_partition[:10]=}', force=True) all_gather_dp_groups(groups_flat=self.bf16_groups_flat, partitioned_param_groups=self.bf16_partitioned_groups, @@ -442,10 +440,12 @@ def clear_hp_grads(self): for i, group in enumerate(self.fp32_groups_gradients): self.fp32_groups_has_gradients[i] = [False] * len(group) - def clear_lp_grads(self): + def clear_lp_grads(self, set_to_none=False): # using zero_() fixed memory address for graph replay - set_to_none = False if self.graph_harvesting else True + if self.graph_harvesting: + assert not set_to_none, "graph harvesting is incompatible with setting lp grads to None" + zero_grads_list = [] for group in self.bf16_groups: for param in group: @@ -458,6 +458,10 @@ def clear_lp_grads(self): if not set_to_none and len(zero_grads_list) > 0: torch._foreach_zero_(zero_grads_list) + def zero_grad(self, set_to_none=True): + self.clear_lp_grads(set_to_none) + self.clear_hp_grads() + def state_dict(self): state_dict = {} state_dict[CLIP_GRAD] = self.clip_grad @@ -537,20 +541,16 @@ def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx): self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=False) def create_grad_acc_hooks(self): - self.grad_accs = [] for i, param_group in enumerate(self.bf16_groups): for j, param in enumerate(param_group): if param.requires_grad: def wrapper(param, i, j): - param_tmp = param.expand_as(param) - grad_acc = param_tmp.grad_fn.next_functions[0][0] def accumulate_hp_grads_and_remove_lp(*notneeded): self.accumulate_hp_grads_and_remove_lp(param, i, j) - self._grad_acc_hooks.append(grad_acc.register_hook(accumulate_hp_grads_and_remove_lp)) - self.grad_accs.append(grad_acc) + self._grad_acc_hooks.append(register_grad_hook(param, accumulate_hp_grads_and_remove_lp)) wrapper(param, i, j) diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index b6dabc161e8c..56b2e9863cad 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -31,6 +31,7 @@ from ..comm.config import DeepSpeedCommsConfig from ..monitor.config import get_monitor_config from ..inference.config import WeightQuantConfig +from ..compile.config import CompileConfig from deepspeed import comm as dist from deepspeed.runtime.config_utils import DeepSpeedConfigModel @@ -801,7 +802,6 @@ def __init__(self, config: Union[str, dict], mpu=None, mesh_device=None): def _initialize_params(self, param_dict): self.train_batch_size = get_train_batch_size(param_dict) - #print(f"beginning get_train_batch_size = {get_train_batch_size}") self.train_micro_batch_size_per_gpu = get_train_micro_batch_size_per_gpu(param_dict) self.gradient_accumulation_steps = get_gradient_accumulation_steps(param_dict) self.steps_per_print = get_steps_per_print(param_dict) @@ -913,6 +913,8 @@ def _initialize_params(self, param_dict): self.weight_quantization_config = WeightQuantConfig( **param_dict['weight_quantization']) if 'weight_quantization' in param_dict else None + self.compile_config = CompileConfig(**param_dict.get('compile', {})) + self.timers_config = get_timers_config(param_dict) self.tensor_parallel_config = get_tensor_parallel_config(param_dict) diff --git a/deepspeed/runtime/config_utils.py b/deepspeed/runtime/config_utils.py index d5c3a1548360..54cf813fd7fd 100755 --- a/deepspeed/runtime/config_utils.py +++ b/deepspeed/runtime/config_utils.py @@ -61,7 +61,7 @@ def _process_deprecated_field(self, dep_field): # Get information about the deprecated field pydantic_config = self fields_set = pydantic_config.model_fields_set - kwargs = pydantic_config.model_fields[dep_field].json_schema_extra + kwargs = type(pydantic_config).model_fields[dep_field].json_schema_extra new_param_fn = kwargs.get("new_param_fn", lambda x: x) param_value = new_param_fn(getattr(pydantic_config, dep_field)) new_field = kwargs.get("new_param", "") diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index 55cfa8f59c91..fa7e9cad73b8 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -128,7 +128,7 @@ # BFLOAT16 optimizer immediate gradient update BFLOAT16_IMMEDIATE_GRAD_UPDATE = "immediate_grad_update" -BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = False +BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = True ######################################### # FP16 support @@ -249,7 +249,7 @@ Optional comm data type for seq paralleism should be set as: "seq_parallel_communication_data_type": "fp32" ''' -SEQ_PARALLEL_COMMUNICATION_DATA_TYPE = "seq_parallel_comm_data_type" +SEQ_PARALLEL_COMMUNICATION_DATA_TYPE = "seq_parallel_communication_data_type" SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_DEFAULT = "fp32" ######################################### diff --git a/deepspeed/runtime/data_pipeline/config.py b/deepspeed/runtime/data_pipeline/config.py index 623480518925..690ce97034e4 100644 --- a/deepspeed/runtime/data_pipeline/config.py +++ b/deepspeed/runtime/data_pipeline/config.py @@ -20,7 +20,6 @@ def get_data_efficiency_config(param_dict): sub_param_dict = param_dict[DATA_EFFICIENCY] output[DATA_SAMPLING] = get_data_sampling(sub_param_dict) output[DATA_ROUTING] = get_data_routing(sub_param_dict) - return output @@ -39,15 +38,14 @@ def get_data_efficiency_seed(param_dict): def get_data_sampling(param_dict): - output = {} + sub_param_dict = param_dict.get(DATA_SAMPLING, {}) + output = copy.copy(sub_param_dict) output[DATA_SAMPLING_ENABLED] = get_data_sampling_enabled(param_dict) output[DATA_SAMPLING_NUM_EPOCHS] = get_data_sampling_num_epochs(param_dict) output[DATA_SAMPLING_NUM_WORKERS] = get_data_sampling_num_workers(param_dict) - if DATA_SAMPLING not in param_dict.keys(): - param_dict[DATA_SAMPLING] = {} - sub_param_dict = param_dict[DATA_SAMPLING] + output[DATA_SAMPLING_PIN_MEMORY] = get_data_sampling_pin_memory(param_dict) output[CURRICULUM_LEARNING] = get_curriculum_learning(sub_param_dict) - + output[DYNAMIC_BATCHING] = get_dynamic_batching(sub_param_dict) return output @@ -73,6 +71,13 @@ def get_data_sampling_num_workers(param_dict): return DATA_SAMPLING_NUM_WORKERS_DEFAULT +def get_data_sampling_pin_memory(param_dict): + if DATA_SAMPLING in param_dict.keys(): + return get_scalar_param(param_dict[DATA_SAMPLING], DATA_SAMPLING_PIN_MEMORY, DATA_SAMPLING_PIN_MEMORY_DEFAULT) + else: + return DATA_SAMPLING_PIN_MEMORY_DEFAULT + + def get_curriculum_learning(param_dict): output = {} output[CURRICULUM_LEARNING_ENABLED] = get_curriculum_learning_enabled(param_dict) @@ -87,6 +92,26 @@ def get_curriculum_learning(param_dict): return output +def get_dynamic_batching(param_dict): + output = copy.copy(param_dict.get(DYNAMIC_BATCHING, {})) + output[DYNAMIC_BATCHING_ENABLED] = bool(output.get(DYNAMIC_BATCHING_ENABLED, DYNAMIC_BATCHING_ENABLED_DEFAULT)) + output[DYNAMIC_BATCHING_LR_SCALING_METHOD] = str( + output.get(DYNAMIC_BATCHING_LR_SCALING_METHOD, DYNAMIC_BATCHING_LR_SCALING_METHOD_DEFAULT)) + output[DYNAMIC_BATCHING_MIN_BATCH_SIZE] = int( + output.get(DYNAMIC_BATCHING_MIN_BATCH_SIZE, DYNAMIC_BATCHING_MIN_BATCH_SIZE_DEFAULT)) + output[DYNAMIC_BATCHING_MAX_BATCH_SIZE] = int(output[DYNAMIC_BATCHING_MAX_BATCH_SIZE]) \ + if DYNAMIC_BATCHING_MAX_BATCH_SIZE in output.keys() \ + else DYNAMIC_BATCHING_MAX_BATCH_SIZE_DEFAULT + output[DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER] = str( + output.get(DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER, DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER_DEFAULT)) + if output[DYNAMIC_BATCHING_ENABLED]: + assert DYNAMIC_BATCHING_MAX_TOKENS in output.keys( + ), f"Dynamic batching is enabled, so {DYNAMIC_BATCHING_MAX_TOKENS} must be specified" + output[DYNAMIC_BATCHING_MAX_TOKENS] = int(output[DYNAMIC_BATCHING_MAX_TOKENS]) + output[DYNAMIC_BATCHING_VERBOSE] = bool(output.get(DYNAMIC_BATCHING_VERBOSE, False)) + return output + + def get_curriculum_learning_enabled(param_dict): if CURRICULUM_LEARNING in param_dict.keys(): return get_scalar_param(param_dict[CURRICULUM_LEARNING], CURRICULUM_LEARNING_ENABLED, diff --git a/deepspeed/runtime/data_pipeline/constants.py b/deepspeed/runtime/data_pipeline/constants.py index 1ade640e38d9..73cc69c1f606 100644 --- a/deepspeed/runtime/data_pipeline/constants.py +++ b/deepspeed/runtime/data_pipeline/constants.py @@ -22,6 +22,8 @@ DATA_SAMPLING_NUM_EPOCHS_DEFAULT = 1000 DATA_SAMPLING_NUM_WORKERS = "num_workers" DATA_SAMPLING_NUM_WORKERS_DEFAULT = 0 +DATA_SAMPLING_PIN_MEMORY = "pin_memory" +DATA_SAMPLING_PIN_MEMORY_DEFAULT = False ######################################### # Data efficiency - Data Sampling - Curriculum Learning @@ -62,6 +64,24 @@ CURRICULUM_LEARNING_DATA_CLUSTER_CURRENT_POSITION = "data_cluster_current_position" CURRICULUM_LEARNING_NP_RNG_STATE = "np_rng_state" +######################################### +# Data efficiency - Dynamic batching and LR scaling +######################################### +DYNAMIC_BATCHING = "dynamic_batching" +DYNAMIC_BATCHING_ENABLED = "enabled" +DYNAMIC_BATCHING_ENABLED_DEFAULT = False +DYNAMIC_BATCHING_METRICS_PATH = "metrics_path" +DYNAMIC_BATCHING_LR_SCALING_METHOD = "lr_scaling_method" # "linear" / "sqrt" / "none" +DYNAMIC_BATCHING_LR_SCALING_METHOD_DEFAULT = "linear" +DYNAMIC_BATCHING_MIN_BATCH_SIZE = "min_batch_size" +DYNAMIC_BATCHING_MIN_BATCH_SIZE_DEFAULT = 1 +DYNAMIC_BATCHING_MAX_BATCH_SIZE = "max_batch_size" +DYNAMIC_BATCHING_MAX_BATCH_SIZE_DEFAULT = None +DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER = "sequence_picking_order" # "random" / "seqlen" / "dataloader" +DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER_DEFAULT = "dataloader" # "random" / "seqlen" / "dataloader" +DYNAMIC_BATCHING_MAX_TOKENS = "max_tokens" +DYNAMIC_BATCHING_VERBOSE = "verbose" + ######################################### # Curriculum Learning legacy implementation ######################################### diff --git a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py index 93d351169834..f82c684ec6e2 100644 --- a/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py +++ b/deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py @@ -862,8 +862,13 @@ def test_compare_both_data_analyzers(dataset): for path in output_paths: with open(os.path.join(da.save_path, path), 'rb') as f1, \ open(os.path.join(dda.save_path, path), 'rb') as f2: - if f1.read() != f2.read(): + # if files have suffix .bin, they should be identical + if path.endswith(".bin"): + assert f1.read() == f2.read(), f"files {path} are not identical." + elif f1.read() != f2.read(): print(f"files {path} are not identical.") + dist.barrier() + dist.destroy_process_group() if __name__ == "__main__": diff --git a/deepspeed/runtime/data_pipeline/data_sampling/variable_batch_size_and_lr.py b/deepspeed/runtime/data_pipeline/data_sampling/variable_batch_size_and_lr.py new file mode 100644 index 000000000000..c9a39bbc53b5 --- /dev/null +++ b/deepspeed/runtime/data_pipeline/data_sampling/variable_batch_size_and_lr.py @@ -0,0 +1,492 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# support/questions/maintenance: github user @brunomaga or @deepspeedai/deepspeed + +import random +import torch +import os +import numpy as np +from torch.optim.lr_scheduler import LRScheduler +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader, DistributedSampler +from deepspeed.utils import logger +from deepspeed.runtime.pipe.engine import PipelineEngine +from deepspeed.runtime.data_pipeline.constants import * +from deepspeed.runtime.data_pipeline.data_sampling.indexed_dataset import MMapIndexedDataset +from deepspeed.runtime.data_pipeline.data_sampling.data_analyzer import DistributedDataAnalyzer +import pathlib + + +def batch_by_seqlens( + seqlens, + max_tokens, + sequence_ids_per_mb=None, + min_batch_size=1, + max_batch_size=None, + sequence_picking_order="dataloader", + effective_batch_size=1, + required_microbatches_of_same_size=False, + verbose=False, + seed=None, +): + """ + Yield mini-batches of indices bucketed by size. Batches may contain sequences of different lengths. + Similar to "Attention is all you need", Section 5.1: + "sequence pairs were batched together by approximate sequence length. Each training batch + contained a set of sequence pairs containing approximately X source tokens and X target tokens" + + Arguments: + - `seqlens`: a list of difficulties (metric values) for every sample in the dataset; + - `max_tokens`: maximum cap in total difficulty in a batch; + - `min_batch_size`: smallest allowed size of a batch; + - `min_batch_size`: largest allowed size of a batch; + - `sequence_picking_order`: order in which to process samples: "dataloader" (default), "random" or "seqlen" (ascending) + - `effective_batch_size`: effective batch size; + - `required_microbatches_of_same_size`: enable if each mini-batch (in a total of `batch_size_multiple` + micro-batches per batch), should have all micro-batches with the same batch size ie the same + number of sequences. + - `verbose`: print debug information; + - `seed`: random seed for reproducibility; + + Returns: + - `microbatch_ids`: list of tuple of batch id and samples ids per microbatch + - `batch_sizes`: the effective batch size of each batch, used for to compute the scaled LR + - `batch_max_seqlens`: the max seqlen across all microbatches in a batch + """ + + assert sequence_picking_order in ["random", "seqlen", "dataloader"] + if sequence_ids_per_mb is None: + metrics = list(zip(seqlens, range(len(seqlens)))) # use all samples + else: + metrics = list(zip(np.array(seqlens)[sequence_ids_per_mb], sequence_ids_per_mb)) + + if sequence_picking_order == 'random': + metric_random = random.Random(seed) + metric_random.shuffle(metrics) + if sequence_picking_order == 'seqlen': + metrics = sorted(metrics) + + # go through metrics, warn user, and filter samples that alone exceed the max batch threshold + long_ids = [idx for val, idx in metrics if val > max_tokens] + if len(long_ids) > 0: + logger.warning(f"Data indices {long_ids} ignored as metrics exceed {max_tokens}.") + logger.info(f"Original dataset length: {len(metrics)}. New dataset length: {len(long_ids)}") + metrics = [m for m in metrics if m[1] not in long_ids] + + def is_microbatch_valid(metrics): + if min_batch_size and len(metrics) < min_batch_size: return False # insufficient sample count + if max_batch_size and len(metrics) > max_batch_size: return False # too many samples + if sum([m[0] for m in metrics]) > max_tokens: return False # exceeds max + return True + + # go through all samples and pack then in microbatches of metric sums below the threshold + # `required_microbatches_of_same_size` means all minibatches in a batch must be of equal size + equal_size_multiple = effective_batch_size if required_microbatches_of_same_size else 1 + microbatches = [] + batch_init = 0 + while batch_init < len(metrics): + + # we iterate over possible effective batch sizes (groups of microbatches of same size) + valid_batch_end = batch_init + for batch_end in range(batch_init + equal_size_multiple, len(metrics), equal_size_multiple): + + # attempt effective batch + batch = metrics[batch_init:batch_end] + + # pick interleaved samples for each microbatch to help with load balancing + # (in the ordered use case), and to replicate what the distributed sampler does. + mbs = [batch[b::equal_size_multiple] for b in range(equal_size_multiple)] + + # if they are all valid micro-batches, keep them until you find longer mbatches, if any + is_batch_valid = all([is_microbatch_valid(mb) for mb in mbs]) + if is_batch_valid: + valid_batch_end = batch_end + + if batch_init == valid_batch_end: break # last batch is not valid (size zero), so we are done + batch = metrics[batch_init:valid_batch_end] + mbs = [batch[b::equal_size_multiple] for b in range(equal_size_multiple)] + batch_init += sum([len(l) for l in mbs]) + microbatches += mbs + + # make sure we give the same number of (micro-)batches to each dataloader by trimming the dataset + assert len(microbatches) >= effective_batch_size, "not enough datapoints to create a single sample per dataloader" + microbatches = microbatches[:len(microbatches) - len(microbatches) % effective_batch_size] + + #compute the effective batch size for each microbatch. + batch_sizes, batch_max_seqlens, microbatch_ids = [], [], [] + for rank in range(0, len(microbatches), effective_batch_size): + batch_id = rank // effective_batch_size + mbs = microbatches[rank:rank + effective_batch_size] + # compute the number of samples (not tokens) in this batch (not microbatch) + n_sequences = sum([len(mb) for mb in mbs]) + # compute the longest sequence (as number of tokens) in this batch (not microbatch) + sequence_ids_per_mb = [[m[1] for m in metrics] for metrics in mbs] + sequence_lens_per_mb = [[m[0] for m in metrics] for metrics in mbs] + batch_max_seqlen = max([max(seqlens) for seqlens in sequence_lens_per_mb]) + batch_and_mb_ids = zip([batch_id] * effective_batch_size, sequence_ids_per_mb) + batch_sizes.append(n_sequences) + batch_max_seqlens.append(batch_max_seqlen) + microbatch_ids += batch_and_mb_ids + if verbose: + n_tokens_per_mb = [sum([m[0] for m in mb]) for mb in mbs] + n_sequences_per_mb = [len(mb) for mb in mbs] + assert all([n <= max_tokens for n in n_tokens_per_mb]), "size of microbatch exceeds max tokens" + logger.info( + f"Batch id {batch_id} contains in total {len(mbs)} microbatches or {n_sequences} sequences. "\ + f"n_sequences per microbatch {n_sequences_per_mb}. "\ + f"n_tokens per microbatch {n_tokens_per_mb}. "\ + f"sequence ids per microbatch: {sequence_ids_per_mb}. "\ + f"sequence lengths per microbatch: {sequence_lens_per_mb}.") + + # return the sample ids of each microbatch, and the batch sizes + assert len(batch_sizes) == len(microbatch_ids) // effective_batch_size + return microbatch_ids, batch_sizes, batch_max_seqlens + + +def scale_lr(base_batch_size, batch_size, base_lr=1, method="linear"): + """ given a reference lr and batch_size, compute the new LR for a given batch size """ + if method == "linear": + # Linear Scaling Rule: "When the minibatch size is multiplied by k, multiply the learning + # rate by k" (Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour, Goyal et al) + return base_lr * batch_size / base_batch_size + if method == "sqrt": + # Square Root scaling: "when multiplying the batch size by k, multiply the learning rate + # by √k, to keep the variance in the gradient expectation constant" + # (A. Krizhevsky. One weird trick for parallelizing convolutional neural networks) + return base_lr * torch.sqrt(batch_size / base_batch_size) + elif method == None or method.upper() == "NONE": + return base_lr + raise ValueError("Unknown scaling method: {}".format(method)) + + +def dataloader_for_variable_batch_size( + dataset, + microbatch_ids, + batch_max_seqlens, + dataloader_rank=0, + dataloader_batch_size=1, + dataloader_num_replicas=1, + dataloader_collate_fn=None, + dataloader_num_workers=2, + dataloader_pin_memory=False, + required_microbatches_of_same_seqlen=False, + sample_padding_fn=None, +): + + # equidistantly distribute the microbatches across the replicas in an interleaved fashion. + sampler = DistributedSampler( + dataset=microbatch_ids, + num_replicas=dataloader_num_replicas, + rank=dataloader_rank, + shuffle=False, + drop_last=False, + ) + + # collate function wraps user-defined collate function to the variable batch data + def collate_fn_wrapper(list_microbatch_ids): + # each batch is a list of sample ids that fill up to the max tokens per batch + # we return the collated batch of all dataset samples of all input batches. + batch = [] + for batch_id, microbatch_ids in list_microbatch_ids: + batch_data = [dataset[idx] for idx in microbatch_ids] + if required_microbatches_of_same_seqlen: + assert sample_padding_fn is not None, \ + "padding dataloader_padding_fn must be provided if required_microbatches_of_same_seqlen is True" + max_seqlen = batch_max_seqlens[batch_id] + assert all([len(sample) <= max_seqlen for sample in batch_data]), \ + "some samples are longer than the computed max seqlen for the batch those samples belong to" + batch_data = [sample_padding_fn(sample, max_seqlen) for sample in batch_data] + batch += batch_data + return dataloader_collate_fn(batch) if dataloader_collate_fn else batch + + dataloader = DataLoader( + dataset=microbatch_ids, + batch_size=dataloader_batch_size, + sampler=sampler, + num_workers=dataloader_num_workers, + collate_fn=collate_fn_wrapper, + pin_memory=dataloader_pin_memory, + ) + + deepspeed_io_kwargs = dict( + dataset=microbatch_ids, + batch_size=dataloader_batch_size, + pin_memory=dataloader_pin_memory, + data_sampler=sampler, + collate_fn=collate_fn_wrapper, + num_local_io_workers=dataloader_num_workers, + ) + + return dataloader, deepspeed_io_kwargs + + +class VariableBatchSizeLR(LRScheduler): + """ an LR scheduler that scales the LR of a given scheduler's LR """ + + @property + def optimizer(self): + return self.base_lr_scheduler.optimizer + + def __init__(self, + lr_scheduler, + base_batch_size, + batch_sizes, + dataloader, + lr_scaling_method="linear", + last_epoch=-1, + verbose=False): + self.batch_sizes = batch_sizes + self.base_batch_size = base_batch_size + self.lr_scaling_method = lr_scaling_method + self.dataloader = dataloader + self.base_lr_scheduler = lr_scheduler + # the following exist in LRScheduler but not in DeepSpeed's LRScheduler so we redefine them here + self.base_lrs = self.base_lr_scheduler.get_lr() + self.last_epoch = last_epoch + self.verbose = verbose + self.step(0) # scale LR for first sample in the dataloader + + def state_dict(self): + return { + 'base_lr_scheduler': self.base_lr_scheduler.state_dict() + } | { + 'base_batch_size': self.base_batch_size, + 'lr_scaling_method': self.lr_scaling_method, + 'batch_sizes': self.batch_sizes, + 'base_lrs': self.base_lrs, + 'last_epoch': self.last_epoch, + 'verbose': self.verbose, + } + + def load_state_dict(self, state_dict): + self.base_lr_scheduler.load_state_dict(state_dict['base_lr_scheduler']) + self.base_batch_size = state_dict['base_batch_size'] + self.lr_scaling_method = state_dict['lr_scaling_method'] + self.batch_sizes = state_dict['batch_sizes'] + self.base_lrs = state_dict['base_lrs'] + self.last_epoch = state_dict['last_epoch'] + self.verbose = state_dict['verbose'] + + def get_last_lr(self): + return self.base_lr_scheduler._last_lr + + def get_lr(self): + return [group['lr'] for group in self.base_lr_scheduler.optimizer.param_groups] + + def step(self, epoch=None): + # call the base scheduler's step method to get LR for next epoch + # Note: optimizer.step precedes lr_scheduler.step(), so the stepping workflow is: + # init: lr_scheduler.step(0) --> set LR for epoch 0 + # epoch 0: optimizer.step(); lr_scheduler.step(1) --> set LR for epoch 1 + # epoch 1: optimizer.step(); lr_scheduler.step(2) --> set LR for epoch 2 + + # reset unscaled LRs (to the original scheduler's one) to be able to step the base LR scheduler + # Note: epoch==0: reset LR scheduler; epoch==None: scale LR for next epoch; + unscaled_lrs = self.base_lrs if epoch == 0 else self.get_last_lr() + for group, lr in zip(self.base_lr_scheduler.optimizer.param_groups, unscaled_lrs): + group['lr'] = lr + + self.base_lr_scheduler.step(epoch) # set unscaled lr, _step_count, last_epoch, _last_lr for new epoch + + # scale the learning rate for the the next iteration for each parameter group. + self.last_epoch = self.last_epoch + 1 if epoch is None else epoch + # batch sizes are precomputed and stored in batch_sizes se we loop around to get the next one + batch_size = self.batch_sizes[self.last_epoch % len(self.batch_sizes)] + for group in self.base_lr_scheduler.optimizer.param_groups: + group['lr'] = scale_lr(self.base_batch_size, batch_size, group['lr'], self.lr_scaling_method) + + if self.verbose: + logger.info( + f"Next batch id {self.last_epoch}. "\ + f"Reference batch_size {self.base_batch_size} and lr {unscaled_lrs}. "\ + f"Scaled batch_size {batch_size} and lr {self.get_lr()}.") + + +def lr_scheduler_for_variable_batch_size(base_batch_size, + batch_sizes, + dataloader, + lr_scheduler_or_optimizer, + lr_scaling_method='linear', + verbose=False): + """ + returns a class that provides an LR scheduler that scales the learning rate at every + iteration taking into account the batch size of that iteration. + If learning rate is constant, ie no LR scheduler, then the base LR will be taken from the + constant LR values in the optimizer param groups. Otherwise from the scheduler's LR. + + Arguments: + - `base_batch_size`: the batch size that the base LR in the optimizer or scheduler refers to; + - `lr_scaling_method`: method to use to scale LR - see `scale_lr()`; + - `lr_scheduler_or_optimizer`: one instance of `LRScheduler` or `Optimizer` to be used as base; + - `batch_sizes`: the effective batch size of each batch in the dataloader; + + Returns the new LRScheduler + """ + + class StubLRScheduler(LRScheduler): + """ a stub LR scheduler that does not change the LR, keeps it constant """ + + def get_lr(self) -> float: + return self.base_lrs + + if isinstance(lr_scheduler_or_optimizer, Optimizer): + lr_scheduler = StubLRScheduler(lr_scheduler_or_optimizer) + elif hasattr(lr_scheduler_or_optimizer, 'optimizer'): #LRScheduler or DeepSpeed 'object' schedulers + assert isinstance(lr_scheduler_or_optimizer.optimizer, Optimizer) + lr_scheduler = lr_scheduler_or_optimizer + else: + raise ValueError("Unknown type for lr_scheduler_or_optimizer: {}".format(type(lr_scheduler_or_optimizer))) + + return VariableBatchSizeLR(lr_scheduler=lr_scheduler, + base_batch_size=base_batch_size, + batch_sizes=batch_sizes, + dataloader=dataloader, + lr_scaling_method=lr_scaling_method, + verbose=verbose) + + +def get_dataloader_and_lr_scheduler_for_variable_batch_size_deepspeed(dataset, + engine, + dataset_seqlens=None, + dataset_filter_ids=None, + dataloader_collate_fn=None, + sample_padding_fn=None, + batch_seqlens_fn=None): + """ + a simplified call to get_dataloader_and_lr_scheduler_for_variable_batch_size for the deepspeed runtime. + Needs the seqlens of every sample. It will try three alternatives: + - if `dataset_seqlens` is provided by user, use that. + - otherwise, looks for the seqlen metric path (in the connfig) that contains the output of the Data Analyzer + - otherwise, use the user-provided function `batch_seqlens_fn` and call Data Analyzer to output seqlen metric + See `batch_by_seqlens()` for arguments and more documentation. + """ + data_efficiency_config = engine._config.data_efficiency_config + data_sampling_config = data_efficiency_config[DATA_SAMPLING] + batching_config = data_sampling_config[DYNAMIC_BATCHING] + assert batching_config[DYNAMIC_BATCHING_ENABLED], "Dynamic batching is not enabled in the config" + + if dataset_seqlens is None: + # In seqlen provided by user, look for the seqlen metric that was output by the Data Analyzer + # (see the main in deepspeed/runtime/data_pipeline/data_sampling/data_analyzer.py for an example) + metrics_path = batching_config[DYNAMIC_BATCHING_METRICS_PATH] + sample_to_seqlen_path = os.path.join(metrics_path, "seqlen/seqlen_sample_to_metric") + if not (os.path.exists(f"{sample_to_seqlen_path}.bin") and os.path.exists(f"{sample_to_seqlen_path}.idx")): + # if the metric files are not found, we run the DataAnalyzer to write the metric files + msg = f"Cannot find metric files for sequence length in {sample_to_seqlen_path}.idx or *.bin." + msg += " We will run data analyzer to generated them..." + logger.warning(msg) + + if batch_seqlens_fn is None: + raise ValueError("sample_seqlen_fn must be provided if dataset_seqlens is not provided") + + DistributedDataAnalyzer( + dataset=dataset, + metric_functions=[batch_seqlens_fn], + collate_fn=dataloader_collate_fn, + batch_size=2**10, # batch size for map-reduce, not training + num_workers=engine.world_size, + worker_id=engine.global_rank, + save_path=pathlib.Path(metrics_path), + metric_types=['single_value_per_sample'], + metric_names=["seqlen"], + device=engine.device, + ).run_map_reduce() + + dataset_seqlens = MMapIndexedDataset(sample_to_seqlen_path, skip_warmup=True) + assert len(dataset_seqlens) == len(dataset), \ + "Seqlens size does not match the input dataset size. If you changed the dataset, delete the metrics_path folder." + + # TODO we are copying all seqlens into memory, we should adapt the code to use an iterative streamer + # and use the other files output by DataAnalyzer that returns an ordered dictionary of seqlen to sample ids + dataset_seqlens = np.array(list(dataset_seqlens), dtype=np.int64).flatten() # from Nx1 to N + + dataloader, lr_scheduler, deepspeed_io_kwargs = get_dataloader_and_lr_scheduler_for_variable_batch_size( + dataset=dataset, + dataset_filter_ids=dataset_filter_ids, + dataset_seqlens=dataset_seqlens, + effective_batch_size=engine.train_batch_size(), + max_tokens=batching_config[DYNAMIC_BATCHING_MAX_TOKENS], + lr_scaling_method=batching_config[DYNAMIC_BATCHING_LR_SCALING_METHOD], + sequence_picking_order=batching_config[DYNAMIC_BATCHING_SEQUENCE_PICKING_ORDER], + min_batch_size=batching_config[DYNAMIC_BATCHING_MIN_BATCH_SIZE], + max_batch_size=batching_config[DYNAMIC_BATCHING_MAX_BATCH_SIZE], + dataloader_batch_size=engine.train_micro_batch_size_per_gpu(), + dataloader_rank=engine.data_parallel_group.rank(), + dataloader_num_replicas=engine.data_parallel_group.size(), + dataloader_num_workers=data_sampling_config[DATA_SAMPLING_NUM_WORKERS], + dataloader_collate_fn=dataloader_collate_fn, + dataloader_pin_memory=data_sampling_config[DATA_SAMPLING_PIN_MEMORY], + sample_padding_fn=sample_padding_fn, + lr_scheduler_or_optimizer=engine.lr_scheduler or engine.optimizer, + required_microbatches_of_same_size=isinstance(engine, PipelineEngine), + required_microbatches_of_same_seqlen=isinstance(engine, PipelineEngine), + verbose=batching_config[DYNAMIC_BATCHING_VERBOSE], + seed=data_efficiency_config[DATA_EFFICIENCY_SEED], + ) + return dataloader, lr_scheduler, deepspeed_io_kwargs + + +def get_dataloader_and_lr_scheduler_for_variable_batch_size( + dataset, + dataset_seqlens, + max_tokens, + effective_batch_size, + dataset_filter_ids=None, + lr_scaling_method="linear", + min_batch_size=1, + max_batch_size=None, + sequence_picking_order="dataloader", + dataloader_batch_size=1, + dataloader_rank=0, + dataloader_num_replicas=1, + dataloader_num_workers=0, + dataloader_collate_fn=None, + dataloader_pin_memory=False, + lr_scheduler_or_optimizer=None, + required_microbatches_of_same_size=False, + required_microbatches_of_same_seqlen=False, + sample_padding_fn=None, + verbose=False, + seed=None, +): + """ returns a dataloader and LR scheduler for the variable batch size. see `batch_by_seqlens()` for details. """ + + # effective_batch_size = train_micro_batch_size_per_gpu * gradient_accumulation_steps * number of dataloaders + microbatch_ids, batch_sizes, batch_max_seqlens = batch_by_seqlens( + seqlens=dataset_seqlens, + max_tokens=max_tokens, + sequence_ids_per_mb=dataset_filter_ids, + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + sequence_picking_order=sequence_picking_order, + effective_batch_size=effective_batch_size, + required_microbatches_of_same_size=required_microbatches_of_same_size, + verbose=verbose, + seed=seed, + ) + + dataloader, deepspeed_io_kwargs = dataloader_for_variable_batch_size( + dataset=dataset, + microbatch_ids=microbatch_ids, + batch_max_seqlens=batch_max_seqlens, + dataloader_rank=dataloader_rank, + dataloader_num_replicas=dataloader_num_replicas, + dataloader_batch_size=dataloader_batch_size, + dataloader_collate_fn=dataloader_collate_fn, + dataloader_num_workers=dataloader_num_workers, + dataloader_pin_memory=dataloader_pin_memory, + required_microbatches_of_same_seqlen=required_microbatches_of_same_seqlen, + sample_padding_fn=sample_padding_fn, + ) + + lr_scheduler = lr_scheduler_for_variable_batch_size(base_batch_size=effective_batch_size, + batch_sizes=batch_sizes, + lr_scaling_method=lr_scaling_method, + lr_scheduler_or_optimizer=lr_scheduler_or_optimizer, + dataloader=dataloader, + verbose=verbose) + + return dataloader, lr_scheduler, deepspeed_io_kwargs diff --git a/deepspeed/runtime/domino/async_linear.py b/deepspeed/runtime/domino/async_linear.py new file mode 100644 index 000000000000..8e01da500409 --- /dev/null +++ b/deepspeed/runtime/domino/async_linear.py @@ -0,0 +1,140 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/23.08/megatron/core/tensor_parallel/layers.py + +import torch +from torch.nn.parameter import Parameter +import torch.nn.functional as F +from deepspeed.accelerator import get_accelerator +import deepspeed.comm as dist +from typing import Callable + +TP_group = None + + +class DominoAsyncColumnParallelLinearImpl(torch.autograd.Function): + + @staticmethod + def forward(ctx, inp, weight, bias, handle_dic, h_id): # inp: (b, s, k), weight: (m, k), bias (m) + ctx.save_for_backward(inp, weight, bias) + ctx.handle_dic = handle_dic + ctx.h_id = h_id + output = torch.matmul(inp, weight.t()) # (b, s, k) @ (k, m) -> (b, s, m) + if bias is not None: # bias (m) + output = output + bias + return output + + @staticmethod + def backward(ctx, grad_output): + inp, weight, bias = ctx.saved_tensors + grad_input = grad_weight = grad_bias = None + grad_input = torch.matmul(grad_output, weight) # (b, s, m) @ (m, k) -> (b, s, k) + handle = dist.all_reduce(grad_input, group=TP_group, async_op=True) + ctx.handle_dic[ctx.h_id] = handle + grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]) # (b*s, m) + + inp = inp.view(inp.shape[0] * inp.shape[1], inp.shape[2]) # (b*s, k) + grad_weight = torch.matmul(grad_output.t(), inp) # (m, b*s) @ (b*s, k) -> (m, k) + + if bias is not None: + grad_bias = grad_output.sum(dim=0) # (b*s, m) -> (m) + return grad_input, grad_weight, grad_bias, None, None + + +class DominoAsyncColumnParallelLinear(torch.nn.Module): + + def __init__(self, + input_size, + output_size, + _tp_group, + config, + init_method: Callable, + bias=True, + skip_bias_add=False): + super(DominoAsyncColumnParallelLinear, self).__init__() + + self.skip_bias_add = skip_bias_add + + global TP_group + if TP_group == None: + TP_group = _tp_group + + self.weight = Parameter( + torch.empty( + output_size, + input_size, + device=get_accelerator().current_device_name(), + dtype=config.params_dtype, + )) + if config.perform_initialization: + init_method(self.weight) + + if bias: + self.bias = Parameter( + torch.empty(output_size, device=get_accelerator().current_device_name(), dtype=config.params_dtype)) + + if config.perform_initialization: + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + def forward(self, input_: torch.Tensor, handle_dic, h_id): + + bias = self.bias if not self.skip_bias_add else None + + output = DominoAsyncColumnParallelLinearImpl.apply(input_, self.weight, bias, handle_dic, h_id) + + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class RowParallelLinearNoComm(torch.nn.Module): + + def __init__( + self, + input_size: int, + output_size: int, + config, + init_method: Callable, + bias: bool = True, + stride: int = 1, + skip_bias_add: bool = False, + ): + super(RowParallelLinearNoComm, self).__init__() + + self.skip_bias_add = skip_bias_add + + self.weight = Parameter( + torch.empty( + output_size, + input_size, + device=get_accelerator().current_device_name(), + dtype=config.params_dtype, + )) + if config.perform_initialization: + init_method(self.weight) + if bias: + self.bias = Parameter( + torch.empty( + output_size, + device=get_accelerator().current_device_name(), + dtype=config.params_dtype, + )) + + if config.perform_initialization: + with torch.no_grad(): + self.bias.zero_() + else: + self.register_parameter('bias', None) + + def forward(self, input_): + bias = self.bias if not self.skip_bias_add else None + + output = F.linear(input_, self.weight, bias) + + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias diff --git a/deepspeed/runtime/domino/transformer.py b/deepspeed/runtime/domino/transformer.py index 88c5494c8147..3dfb133373b5 100644 --- a/deepspeed/runtime/domino/transformer.py +++ b/deepspeed/runtime/domino/transformer.py @@ -5,24 +5,10 @@ import torch import torch.nn.functional as F -from torch.nn.parameter import Parameter +import enum import deepspeed.comm as dist -from deepspeed.accelerator import get_accelerator - -def is_rank_0(): - if dist.get_rank() == 0: - return True - - -class DominoModule(torch.nn.Module): - """extensions of torch Module.""" - - def __init__(self, ): - super(DominoModule, self).__init__() - - -import enum +from .async_linear import DominoAsyncColumnParallelLinear, RowParallelLinearNoComm class LayerType(enum.Enum): @@ -45,10 +31,23 @@ class ModelType(enum.Enum): encoder_and_decoder = 2 -handle_dic = {} +class DominoUtil: + + BATCH_0 = "BATCH0" + BATCH_1 = "BATCH1" -def no_oper(input_, dic_, h_id): + HANDLE_DIC = {"BATCH0": None, "BATCH1": None} + + +class DominoModule(torch.nn.Module): + """extensions of torch Module.""" + + def __init__(self, ): + super(DominoModule, self).__init__() + + +def _Wait_bwd_comm(input_, dic_, h_id): return NoOper.apply(input_, dic_, h_id) @@ -71,55 +70,27 @@ def backward(ctx, grad_output): return grad_output, None, None -def copy_to_tensor_model_parallel_region_a(mpu, input_, dic_, h_id): - return _CopyToModelParallelRegionA.apply(mpu, input_, dic_, h_id) - - -class _CopyToModelParallelRegionA(torch.autograd.Function): - """Pass the input to the model parallel region.""" - - @staticmethod - def symbolic(graph, mpu, input_, handle_dic, h_id): - return input_ - - @staticmethod - def forward(ctx, mpu, input_, handle_dic, h_id): - ctx.mpu = mpu - ctx.handle_dic = handle_dic - ctx.h_id = h_id - return input_ - - @staticmethod - def backward(ctx, grad_output): - # Bypass the function if we are using only 1 GPU. - if ctx.mpu.get_tensor_model_parallel_world_size() == 1: - return grad_output - - # Async All-reduce. - handle = dist.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True) - ctx.handle_dic[ctx.h_id] = handle - return None, grad_output, None, None - - class CoreAttention(DominoModule): - def __init__(self, config, layer_number, mpu, attn_mask_type=AttnMaskType.causal): + def __init__(self, config, tp_world_size, attn_mask_type=AttnMaskType.causal): super(CoreAttention, self).__init__() - self.layer_number = max(1, layer_number) - self.att_dropout_p = config.attention_dropout - self.is_causal = True + self.attn_mask_type = attn_mask_type + projection_size = config.kv_channels * config.num_attention_heads - world_size = mpu.get_tensor_model_parallel_world_size() - self.hidden_size_per_partition = projection_size // world_size + + # Per attention head and per partition values. + assert projection_size % tp_world_size == 0, f"projection size {projection_size} should be multiple of TP world size {tp_world_size}" + self.hidden_size_per_partition = projection_size // tp_world_size + self.attention_dropout_rate = config.attention_dropout def forward(self, query_layer, key_layer, value_layer, attention_mask): - # attn_mask is None when is_causal=True + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, attn_mask=None, - dropout_p=self.att_dropout_p, + dropout_p=self.attention_dropout_rate, is_causal=True, scale=None) @@ -136,19 +107,20 @@ def forward(self, query_layer, key_layer, value_layer, attention_mask): class ShardedAttention(DominoModule): """Sharded self-attention layer class. - Only support self attention and causal attention mask + Only support self attention and causal attention mask for now. """ def __init__(self, config, - layer_number, mpu, - ColumnParallelLinear, - RowParallelLinearNoComm, apply_rotary_pos_emb, + layer_number, attention_type=AttnType.self_attn, attn_mask_type=AttnMaskType.causal): super(ShardedAttention, self).__init__() + + assert attention_type == AttnType.self_attn, "Only support self_attn for now!" + self.layer_number = max(1, layer_number) self.attention_type = attention_type self.attn_mask_type = attn_mask_type @@ -158,56 +130,54 @@ def __init__(self, query_projection_size = config.kv_channels * config.num_attention_heads kv_projection_size = config.kv_channels * config.num_attention_heads - # Per attention head and per partition values. - world_size = mpu.get_tensor_model_parallel_world_size() + tp_world_size = mpu.get_tensor_model_parallel_world_size() self.hidden_size_per_attention_head = query_projection_size // config.num_attention_heads - self.num_attention_heads_per_partition = config.num_attention_heads // world_size + self.num_attention_heads_per_partition = config.num_attention_heads // tp_world_size + + qkv_projection_per_partition = (query_projection_size + 2 * kv_projection_size) // tp_world_size - self.query_key_value = ColumnParallelLinear(config.hidden_size, - query_projection_size + 2 * kv_projection_size, - config=config, - init_method=config.init_method, - bias=config.add_bias_linear, - gather_output=False) + self.query_key_value = DominoAsyncColumnParallelLinear(config.hidden_size, + qkv_projection_per_partition, + mpu.get_tensor_model_parallel_group(), + config=config, + init_method=config.init_method, + bias=config.add_bias_linear) - self.core_attention = CoreAttention(config, self.layer_number, mpu, self.attn_mask_type) + self.core_attention = CoreAttention(config, tp_world_size, self.attn_mask_type) - self.dense = RowParallelLinearNoComm(query_projection_size, + query_projection_size_per_partition = query_projection_size // tp_world_size + + # Output. + self.dense = RowParallelLinearNoComm(query_projection_size_per_partition, config.hidden_size, config=config, init_method=config.output_layer_init_method, bias=config.add_bias_linear, - input_is_parallel=True, skip_bias_add=True) - def forward(self, hidden_states, attention_mask, rotary_pos_emb=None): - # hidden_states: [s, b, h] + def forward(self, hidden_states, attention_mask, micro_batch_num, rotary_pos_emb=None): + # hidden_states: [sq, b, h] - # Query, Key, and Value - # Attention heads [s, b, h] --> [s, b, np * 3 * hn)] - mixed_x_layer, _ = self.query_key_value(hidden_states) + mixed_x_layer, _ = self.query_key_value(hidden_states, DominoUtil.HANDLE_DIC, micro_batch_num) - # [s, b, np * 3 * hn] --> [s, b, np, 3 * hn] new_tensor_shape = mixed_x_layer.size()[:-1] + ( self.num_attention_heads_per_partition, 3 * self.hidden_size_per_attention_head, ) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - # [s, b, np, 3 * hn] -> [b, np, s, 3*hn] mixed_x_layer = mixed_x_layer.permute(1, 2, 0, 3).contiguous() - # [s, b, np, 3 * hn] --> [s, b, np, hn], [s, b, np, hn], [s, b, np, hn] (query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, [ self.hidden_size_per_attention_head, self.hidden_size_per_attention_head, self.hidden_size_per_attention_head ], dim=3) - # [s, b, np, np * hn] -> [s, b, np, hn] + query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head) - # apply rotary embedding if rotary_pos_emb is not None: if isinstance(rotary_pos_emb, tuple): rotary_pos_emb = rotary_pos_emb @@ -219,11 +189,63 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None): context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) - # Output. [s, b, h] output, bias = self.dense(context_layer) - return output, bias + def domino_core_attention_forward(self, mixed_x_layer, attention_mask, rotary_pos_emb=None): + # hidden_states: [sq, b, h] + + # To illustrate the difference between intra-layer overlap and inter-layer overlap + # mixed_x_layer, _ = self.query_key_value(hidden_states, handle_dic, micro_batch_num) + + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + mixed_x_layer = mixed_x_layer.permute(1, 2, 0, 3).contiguous() + + (query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, [ + self.hidden_size_per_attention_head, self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head + ], + dim=3) + + query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, + self.hidden_size_per_attention_head) + + if rotary_pos_emb is not None: + if isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = rotary_pos_emb + else: + rotary_pos_emb = ((rotary_pos_emb, ) * 2) + q_pos_emb, k_pos_emb = rotary_pos_emb + query_layer = self.apply_rotary_pos_emb(query_layer, q_pos_emb) + key_layer = self.apply_rotary_pos_emb(key_layer, k_pos_emb) + + context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + + # output, bias = self.dense(context_layer) + # return output, bias + + return context_layer + + +class bias_dropout_add(torch.nn.Module): + + def __init__(self, prob: float): + super(bias_dropout_add, self).__init__() + self.dropout = torch.nn.Dropout(prob) + + def forward(self, x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: + if bias is not None: + x = x + bias + out = self.dropout(x) + out = out + residual + return out + class DominoTransformerLayer(DominoModule): """A domino single transformer layer. @@ -232,222 +254,158 @@ class DominoTransformerLayer(DominoModule): def __init__(self, config, - layer_number, mpu, - fused_layer_norm, - _initialize_affine_weight_gpu, - ColumnParallelLinear, - RowParallelLinearNoComm, apply_rotary_pos_emb, - bias_dropout_add_fused_train, - bias_dropout_add_fused_inference, - skip_bias_add=True, + layer_number, layer_type=LayerType.encoder, self_attn_mask_type=AttnMaskType.causal, - drop_path_rate=0., - output_bias=None): - super(DominoTransformerLayer, self).__init__() - - if not dist.is_initialized(): - dist.init_distributed() - assert dist.is_initialized(), "deepspeed.comm is not initialized!" + drop_path_rate=0.): - self.llama_model = config.llama_model + super(DominoTransformerLayer, self).__init__() self.layer_number = layer_number self.layer_type = layer_type - self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm - self.bias_dropout_add_fused_train = bias_dropout_add_fused_train - self.bias_dropout_add_fused_inference = bias_dropout_add_fused_inference - self.mpu = mpu - self.output_bias = output_bias - # Layernorm on the input data. - self.input_layernorm = fused_layer_norm(config.hidden_size, - eps=config.layernorm_epsilon, - no_persist_layer_norm=config.no_persist_layer_norm) + self.apply_residual_connection_post_layernorm \ + = config.apply_residual_connection_post_layernorm + + self.llama_model = False + + self.input_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layernorm_epsilon) # Self attention. self.self_attention = ShardedAttention(config, - layer_number, mpu, - ColumnParallelLinear, - RowParallelLinearNoComm, apply_rotary_pos_emb, + layer_number, attention_type=AttnType.self_attn, attn_mask_type=self_attn_mask_type) self.hidden_dropout = config.hidden_dropout - # Layernorm on the attention output - self.post_attention_layernorm = fused_layer_norm(config.hidden_size, - eps=config.layernorm_epsilon, - no_persist_layer_norm=config.no_persist_layer_norm) - - # ------------ init mlp start ------------ - init_method = config.init_method - output_layer_init_method = config.output_layer_init_method - self.add_bias = config.add_bias_linear - self.skip_bias_add = skip_bias_add + self.post_attention_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layernorm_epsilon) + # MLP ffn_hidden_size = config.ffn_hidden_size if config.gated_linear_unit: ffn_hidden_size *= 2 + self.output_size_c = config.ffn_hidden_size self.input_size_c = config.hidden_size self.input_size_r = config.ffn_hidden_size self.output_size_r = self.input_size_c - world_size = mpu.get_tensor_model_parallel_world_size() - self.output_size_per_partition = self.output_size_c // world_size - self.input_size_per_partition = self.input_size_r // world_size - - # Initialize weight. - self.weight_c = Parameter( - torch.empty(self.output_size_per_partition, - self.input_size_c, - device=get_accelerator().current_device_name(), - dtype=config.params_dtype)) - self.weight_r = Parameter( - torch.empty(self.output_size_r, - self.input_size_per_partition, - device=get_accelerator().current_device_name(), - dtype=config.params_dtype)) - - if config.perform_initialization: - _initialize_affine_weight_gpu(self.weight_c, init_method, partition_dim=0, stride=1) - - _initialize_affine_weight_gpu(self.weight_r, output_layer_init_method, partition_dim=1, stride=1) - - if self.add_bias: - self.bias_c = Parameter( - torch.empty(self.output_size_per_partition, - device=get_accelerator().current_device_name(), - dtype=config.params_dtype)) - self.bias_r = Parameter( - torch.empty(self.output_size_r, - device=get_accelerator().current_device_name(), - dtype=config.params_dtype)) - if config.perform_initialization: - with torch.no_grad(): - self.bias_c.zero_() - self.bias_r.zero_() - else: - self.register_parameter('bias_c', None) - self.register_parameter('bias_r', None) + tp_world_size = mpu.get_tensor_model_parallel_world_size() + self.TP_group = mpu.get_tensor_model_parallel_group() + self.output_size_per_partition = self.output_size_c // tp_world_size + self.input_size_per_partition = self.input_size_r // tp_world_size - if config.swiglu: + self.linear_fc1 = DominoAsyncColumnParallelLinear(self.input_size_c, + self.output_size_per_partition, + mpu.get_tensor_model_parallel_group(), + config=config, + init_method=config.init_method, + bias=config.add_bias_linear) - def swiglu(x): - x = torch.chunk(x, 2, dim=-1) - return F.silu(x[0]) * x[1] + self.mlp_activation_func = F.gelu - self.mlp_activation_func = swiglu - else: - self.mlp_activation_func = F.gelu - # ------------ init mlp end ------------ + self.linear_fc2 = RowParallelLinearNoComm(self.input_size_per_partition, + self.output_size_r, + config=config, + init_method=config.output_layer_init_method, + bias=config.add_bias_linear, + skip_bias_add=True) + + self.bias_dropout_add_func = bias_dropout_add(self.hidden_dropout) def forward(self, hidden_states, attention_mask, rotary_pos_emb=None): - # hidden_states: [s, b, h] + hidden_states0, hidden_states1 = hidden_states layernorm_output0 = self.input_layernorm(hidden_states0) + layernorm_output0 = _Wait_bwd_comm(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0) + + # Micro batch 0: attention + attention_output0, attention_bias0 = self.self_attention(layernorm_output0, + attention_mask, + DominoUtil.BATCH_0, + rotary_pos_emb=rotary_pos_emb) + + fwd_handle0 = dist.all_reduce(attention_output0, group=self.TP_group, async_op=True) + # End of Micro batch 0: attention + + # Micro batch 1: attention layernorm_output1 = self.input_layernorm(hidden_states1) + layernorm_output1 = _Wait_bwd_comm(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1) + + attention_output1, attention_bias1 = self.self_attention(layernorm_output1, + attention_mask, + DominoUtil.BATCH_1, + rotary_pos_emb=rotary_pos_emb) + fwd_handle1 = dist.all_reduce(attention_output1, group=self.TP_group, async_op=True) - if not self.llama_model: - rotary_pos_emb = None - - attention_output0, attention_bias0 = \ - self.self_attention( - layernorm_output0, - attention_mask, - rotary_pos_emb=rotary_pos_emb) - handle0 = dist.all_reduce(attention_output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True) - - attention_output1, attention_bias1 = \ - self.self_attention( - layernorm_output1, - attention_mask, - rotary_pos_emb=rotary_pos_emb) - handle1 = dist.all_reduce(attention_output1, group=self.mpu.get_tensor_model_parallel_group(), async_op=True) - handle0.wait() - - # Residual0 connection. + # Micro batch 0: Residual connection. + fwd_handle0.wait() if self.apply_residual_connection_post_layernorm: residual0 = layernorm_output0 else: residual0 = hidden_states0 - if self.training: - bias_dropout_add_func = self.bias_dropout_add_fused_train - else: - bias_dropout_add_func = self.bias_dropout_add_fused_inference - if attention_bias0 is not None: - attention_bias0 = attention_bias0.expand_as(residual0) - layernorm_input0 = bias_dropout_add_func(attention_output0, attention_bias0, residual0, self.hidden_dropout) + layernorm_input0 = self.bias_dropout_add_func(attention_output0, attention_bias0, residual0) layernorm_output0 = self.post_attention_layernorm(layernorm_input0) - layernorm_output0 = no_oper(layernorm_output0, handle_dic, f'{self.layer_number}_0') + layernorm_output0 = _Wait_bwd_comm(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0) + + if self.apply_residual_connection_post_layernorm: + residual0 = layernorm_output0 + else: + residual0 = layernorm_input0 + # End of Micro batch 0: Residual connection. - # Residual1 connection. + # ------------ MLP ------------ + # Micro batch 0: MLP + output0, _ = self.linear_fc1(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0) + output0 = self.mlp_activation_func(output0) + + # Micro batch 1: Residual connection. + fwd_handle1.wait() if self.apply_residual_connection_post_layernorm: residual1 = layernorm_output1 else: residual1 = hidden_states1 - if attention_bias1 is not None: - attention_bias1 = attention_bias1.expand_as(residual1) - layernorm_input1 = bias_dropout_add_func(attention_output1, attention_bias1, residual1, self.hidden_dropout) + layernorm_input1 = self.bias_dropout_add_func(attention_output1, attention_bias1, residual1) layernorm_output1 = self.post_attention_layernorm(layernorm_input1) - layernorm_output1 = no_oper(layernorm_output1, handle_dic, f'{self.layer_number}_1') - - # ------------ explicit mlp start ------------ - bias_c = self.bias_c if not self.skip_bias_add else None + layernorm_output1 = _Wait_bwd_comm(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1) - input0 = copy_to_tensor_model_parallel_region_a(self.mpu, layernorm_output0, handle_dic, - f'{self.layer_number}_0') - # Batch0 Matrix multiply. - output0 = torch.matmul(input0, self.weight_c.t()) - if bias_c is not None: - output0 = output0 + bias_c - output0 = self.mlp_activation_func(output0) - output0 = torch.matmul(output0, self.weight_r.t()) - handle2 = dist.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True) + if self.apply_residual_connection_post_layernorm: + residual1 = layernorm_output1 + else: + residual1 = layernorm_input1 + # End of Micro batch 1: Residual connection. - handle1.wait() + hidden_states0, last_mlp_bias = self.linear_fc2(output0) + fwd_handle0 = dist.all_reduce(hidden_states0, group=self.TP_group, async_op=True) + # End of Micro batch 0: MLP - input1 = copy_to_tensor_model_parallel_region_a(self.mpu, layernorm_output1, handle_dic, - f'{self.layer_number}_1') - # Batch1 Matrix multiply. - output1 = torch.matmul(input1, self.weight_c.t()) + # Micro batch 1: MLP + output1, _ = self.linear_fc1(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1) output1 = self.mlp_activation_func(output1) - if bias_c is not None: - output1 = output1 + bias_c - output1 = torch.matmul(output1, self.weight_r.t()) - dist.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group()) - handle2.wait() + hidden_states1, last_mlp_bias = self.linear_fc2(output1) - output0 = output0 + self.bias_r if self.bias_r is not None else output0 - output1 = output1 + self.bias_r if self.bias_r is not None else output1 - output_bias = self.output_bias - mlp_output0, mlp_output1, mlp_bias0, mlp_bias1 = output0, output1, output_bias, output_bias - # ------------ explicit mlp end ------------ + fwd_handle1 = dist.all_reduce(hidden_states1, group=self.TP_group, async_op=True) + # End of Micro batch 1: MLP - if self.apply_residual_connection_post_layernorm: - residual0 = layernorm_output0 - residual1 = layernorm_output1 - else: - residual0 = layernorm_input0 - residual1 = layernorm_input1 + # ------------ End of MLP ------------ + + fwd_handle0.wait() + hidden_states0 = self.bias_dropout_add_func(hidden_states0, last_mlp_bias, residual0) - if mlp_bias0 is not None: - mlp_bias0 = mlp_bias0.expand_as(residual0) - mlp_bias1 = mlp_bias1.expand_as(residual1) - output0 = bias_dropout_add_func(mlp_output0, mlp_bias0, residual0, self.hidden_dropout) - output1 = bias_dropout_add_func(mlp_output1, mlp_bias1, residual1, self.hidden_dropout) + fwd_handle1.wait() + hidden_states1 = self.bias_dropout_add_func(hidden_states1, last_mlp_bias, residual1) - return output0, output1 + return hidden_states0, hidden_states1 class DominoTransformer(DominoModule): @@ -455,56 +413,185 @@ class DominoTransformer(DominoModule): def __init__(self, config, - model_type, mpu, - fused_layer_norm, - _initialize_affine_weight_gpu, - ColumnParallelLinear, - RowParallelLinearNoComm, apply_rotary_pos_emb, - bias_dropout_add_fused_train, - bias_dropout_add_fused_inference, + model_type, layer_type=LayerType.encoder, self_attn_mask_type=AttnMaskType.causal, + post_layer_norm=True, pre_process=True, post_process=True, - post_layer_norm=True, drop_path_rate=0.0): super(DominoTransformer, self).__init__() self.layer_type = layer_type self.model_type = model_type - self.post_process = post_process self.post_layer_norm = post_layer_norm - self.num_layers = config.num_layers + self.post_process = post_process + self.input_tensor = None self.drop_path_rate = drop_path_rate + self.TP_group = mpu.get_tensor_model_parallel_group() + + if not dist.is_initialized(): + dist.init_distributed() + assert dist.is_initialized(), "deepspeed.comm failed to initialize!" + + self.num_layers = config.num_layers + self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, config.num_layers)] def build_layer(layer_number): + + current_layer_type = layer_type return DominoTransformerLayer(config, - layer_number, mpu, - fused_layer_norm, - _initialize_affine_weight_gpu, - ColumnParallelLinear, - RowParallelLinearNoComm, apply_rotary_pos_emb, - bias_dropout_add_fused_train, - bias_dropout_add_fused_inference, - layer_type=layer_type, + layer_number, + layer_type=current_layer_type, self_attn_mask_type=self_attn_mask_type, drop_path_rate=self.drop_path_rates[layer_number - 1]) self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) if self.post_process and self.post_layer_norm: - self.final_layernorm = fused_layer_norm(config.hidden_size, - eps=config.layernorm_epsilon, - no_persist_layer_norm=config.no_persist_layer_norm) + self.final_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layernorm_epsilon) + + self._forward_impl = self.inter_layer_overlap_forward + if config.domino_intra_layer_overlap: + self._forward_impl = self.intra_layer_overlap_forward def forward(self, hidden_states, attention_mask, rotary_pos_emb=None): + + return self._forward_impl(hidden_states, attention_mask, rotary_pos_emb) + + def inter_layer_overlap_forward(self, hidden_states, attention_mask, rotary_pos_emb=None): # hidden_states: [s, b, h] + hidden_states0, hidden_states1 = torch.chunk(hidden_states, chunks=2, dim=1) + + last_mlp_bias = None + fwd_handle0, fwd_handle1 = None, None + residual0, residual1 = None, None + + layernorm_output0 = self.layers[0].input_layernorm(hidden_states0) + layernorm_output0 = _Wait_bwd_comm(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0) + + for index in range(self.num_layers): + + # Micro batch 0: attention + attention_output0, _ = self.layers[index].self_attention.query_key_value( + layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0) + attention_output0 = self.layers[index].self_attention.domino_core_attention_forward( + attention_output0, attention_mask, rotary_pos_emb=rotary_pos_emb) + + # Micro batch 1: Residual connection + if index > 0: + fwd_handle1.wait() + hidden_states1 = self.layers[index - 1].bias_dropout_add_func(hidden_states1, last_mlp_bias, residual1) + + layernorm_output1 = self.layers[index].input_layernorm(hidden_states1) + layernorm_output1 = _Wait_bwd_comm(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1) + # End of Micro batch 1: Residual connection + + attention_output0, attention_bias0 = self.layers[index].self_attention.dense(attention_output0) + + fwd_handle0 = dist.all_reduce(attention_output0, group=self.TP_group, async_op=True) + # End of Micro batch 0: attention + + # Micro batch 1: attention + attention_output1, _ = self.layers[index].self_attention.query_key_value( + layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1) + attention_output1 = self.layers[index].self_attention.domino_core_attention_forward( + attention_output1, attention_mask, rotary_pos_emb=rotary_pos_emb) + + # Micro batch 0: Residual connection. + fwd_handle0.wait() + if self.layers[index].apply_residual_connection_post_layernorm: + residual0 = layernorm_output0 + else: + residual0 = hidden_states0 + + layernorm_input0 = self.layers[index].bias_dropout_add_func(attention_output0, attention_bias0, residual0) + + layernorm_output0 = self.layers[index].post_attention_layernorm(layernorm_input0) + layernorm_output0 = _Wait_bwd_comm(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0) + + if self.layers[index].apply_residual_connection_post_layernorm: + residual0 = layernorm_output0 + else: + residual0 = layernorm_input0 + # End of Micro batch 0: Residual connection. + + attention_output1, attention_bias1 = self.layers[index].self_attention.dense(attention_output1) + fwd_handle1 = dist.all_reduce(attention_output1, group=self.TP_group, async_op=True) + # End of Micro batch 1: attention + + # ------------ MLP ------------ + # Micro batch 0: MLP + output0, _ = self.layers[index].linear_fc1(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0) + output0 = self.layers[index].mlp_activation_func(output0) + + # Micro batch 1: Residual connection. + fwd_handle1.wait() + if self.layers[index].apply_residual_connection_post_layernorm: + residual1 = layernorm_output1 + else: + residual1 = hidden_states1 + + layernorm_input1 = self.layers[index].bias_dropout_add_func(attention_output1, attention_bias1, residual1) + + layernorm_output1 = self.layers[index].post_attention_layernorm(layernorm_input1) + layernorm_output1 = _Wait_bwd_comm(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1) + + if self.layers[index].apply_residual_connection_post_layernorm: + residual1 = layernorm_output1 + else: + residual1 = layernorm_input1 + # End of Micro batch 1: Residual connection. + + hidden_states0, last_mlp_bias = self.layers[index].linear_fc2(output0) + fwd_handle0 = dist.all_reduce(hidden_states0, group=self.TP_group, async_op=True) + # End of Micro batch 0: MLP + + # Micro batch 1: MLP + output1, _ = self.layers[index].linear_fc1(layernorm_output1, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_1) + output1 = self.layers[index].mlp_activation_func(output1) + + # Micro batch 0: Residual connection. + fwd_handle0.wait() + hidden_states0 = self.layers[index].bias_dropout_add_func(hidden_states0, last_mlp_bias, residual0) + + if index < self.num_layers - 1: + layernorm_output0 = self.layers[index + 1].input_layernorm(hidden_states0) + layernorm_output0 = _Wait_bwd_comm(layernorm_output0, DominoUtil.HANDLE_DIC, DominoUtil.BATCH_0) + # End of Micro batch 0: Residual connection. + + hidden_states1, last_mlp_bias = self.layers[index].linear_fc2(output1) + + fwd_handle1 = dist.all_reduce(hidden_states1, group=self.TP_group, async_op=True) + # End of Micro batch 1: MLP + + # ------------ End of MLP ------------ + + if self.post_process and self.post_layer_norm: + hidden_states0 = self.final_layernorm(hidden_states0) + + index = self.num_layers - 1 + + fwd_handle1.wait() + hidden_states1 = self.layers[index].bias_dropout_add_func(hidden_states1, last_mlp_bias, residual1) + + if self.post_process and self.post_layer_norm: + hidden_states1 = self.final_layernorm(hidden_states1) + + hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) + + return hidden_states + + def intra_layer_overlap_forward(self, hidden_states, attention_mask, rotary_pos_emb=None): + + hidden_states = torch.chunk(hidden_states, chunks=2, dim=1) + for index in range(self.num_layers): layer = self.layers[index] hidden_states = layer(hidden_states, attention_mask, rotary_pos_emb) @@ -514,4 +601,5 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None): hidden_states0 = self.final_layernorm(hidden_states0) hidden_states1 = self.final_layernorm(hidden_states1) - return hidden_states0, hidden_states1 + hidden_states = torch.cat([hidden_states0, hidden_states1], dim=1) + return hidden_states diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index df6d286494de..428fc0baf43a 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -37,8 +37,7 @@ from deepspeed.runtime.bf16_optimizer import BF16_Optimizer from deepspeed.linear.optimized_linear import LoRAOptimizedLinear -from deepspeed.module_inject.layers import GatherReplacedLayerParams - +from deepspeed.module_inject.layers import GatherReplacedLayerParams, configure_tensor_parallel_runtime from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \ ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \ TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \ @@ -108,6 +107,12 @@ from deepspeed.runtime.config import DtypeEnum +from deepspeed.compile.util import is_deepcompile_supported, get_deepcompile_handle, deepcompile_backward_prologue +from deepspeed.compile.backend import register_compile_pass, opt_passes +from deepspeed.compile.passes import zero3_compile, prefetch, selective_gather, offload_adam_states +from deepspeed.compile.init_z1 import init_z1 +from deepspeed.compile.init_z3 import init_z3 + MEMORY_OPT_ALLREDUCE_SIZE = 500000000 DeepSpeedOptimizerCallable = \ @@ -248,7 +253,7 @@ def __init__(self, self._configure_with_arguments(args, mpu) self._do_sanity_check() if self.autotp_size() > 1: - self._configure_tensor_parallel_states(model) + self._configure_tensor_parallel(model, self.tensor_parallel_config()) see_memory_usage(f"DeepSpeed Engine: After args sanity test", force=self.memory_breakdown()) if mpu is not None: if self.elasticity_enabled(): @@ -272,6 +277,10 @@ def __init__(self, # Configure distributed model self._configure_distributed_model(model) + if not self.is_deepcompile_enabled(): + self.module_forward_pre_hook = self._create_module_forward_pre_hook() + self.module_forward_post_hook = self._create_module_forward_post_hook() + # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict self.param_names = {param: name for name, param in model.named_parameters()} @@ -375,6 +384,12 @@ def __init__(self, self.unflatten = _unflatten_dense_tensors self._is_compiled = False + if is_deepcompile_supported(): + # Predefined compile passes + self.register_compile_pass(zero3_compile.NAME, zero3_compile.add_z3_gather_release) + self.register_compile_pass(prefetch.NAME, prefetch.schedule_prefetch) + self.register_compile_pass(selective_gather.NAME, selective_gather.selective_gather) + self.register_compile_pass(offload_adam_states.NAME, offload_adam_states.move_opt_states) def _optimized_linear_offload_setup(self): self.optimized_linear_base_weight_sharding = False @@ -413,6 +428,10 @@ def _optimized_linear_offload_setup(self): else: p.ds_offload = False + def _configure_tensor_parallel(self, model, tp_config): + self._configure_tensor_parallel_states(model) + configure_tensor_parallel_runtime(tp_config) + def _configure_tensor_parallel_states(self, model): """ Configures the tensor parallel states for the model. @@ -420,7 +439,6 @@ def _configure_tensor_parallel_states(self, model): and registering a pre-hook to ensure that the Dataloader inputs are consistent across ranks. """ self._set_client_model(model) - # sanity check # currently, the compatibility between 'autotp' and 'zero > 1' has not been validated assert self.zero_optimization_stage( @@ -481,6 +499,8 @@ def broadcast_and_check(args, bcast_rank, bcast_group): def destroy(self): if self.optimizer is not None and hasattr(self.optimizer, 'destroy'): self.optimizer.destroy() + if self.is_deepcompile_enabled(): + get_deepcompile_handle().cleanup() debug_clear_module_and_param_names() def _get_model_parameters(self): @@ -899,6 +919,9 @@ def zero_legacy_stage1(self): def zero_ignore_unused_parameters(self): return self._config.zero_config.ignore_unused_parameters + def tensor_parallel_config(self): + return self._config.tensor_parallel_config + def autotp_size(self): return self._config.tensor_parallel_config.autotp_size @@ -1875,7 +1898,6 @@ def deepspeed_io(self, GLOBAL_RANK: self.global_rank, DATA_SAMPLING_NUM_WORKERS: self.data_sampling_config()[DATA_SAMPLING_NUM_WORKERS] } - return DeepSpeedDataLoader(dataset=dataset, batch_size=batch_size, pin_memory=pin_memory, @@ -1922,17 +1944,24 @@ def _scale_loss_by_gas(self, prescaled_loss, eval_micro_batches=None): return scaled_loss - @instrument_w_nvtx - def forward(self, *inputs, **kwargs): - r"""Execute forward propagation - Arguments: - *inputs: Variable length input list - **kwargs: variable length keyword arguments - """ + def _create_module_forward_pre_hook(self): - if self.autotuning_profile_model_info(): - ma = get_ma_status() - else: + def _module_forward_pre_hook(module, inputs, kwargs): + return self._forward_prologue(inputs, kwargs) + + return self.module.register_forward_pre_hook(_module_forward_pre_hook, prepend=False, with_kwargs=True) + + def _create_module_forward_post_hook(self): + + def _module_forward_post_hook(module, input, output): + self._forward_epilogue() + + return self.module.register_forward_hook(_module_forward_post_hook) + + def _forward_prologue(self, inputs, kwargs): + return_modified = False + + if not self.autotuning_profile_model_info(): see_memory_usage("Engine before forward", force=self.memory_breakdown()) flops_profiler_active = (self.flops_profiler_enabled() @@ -1951,41 +1980,47 @@ def forward(self, *inputs, **kwargs): self.eigenvalue_enabled(), None, ) + return_modified = True if flops_profiler_active: self.flops_profiler.start_profile(ignore_list=None) - if self.module.training: - if self.progressive_layer_drop: - kwargs.update(self.progressive_layer_drop.get_state()) + if kwargs is not None: + if self.module.training: + if self.progressive_layer_drop: + kwargs.update(self.progressive_layer_drop.get_state()) - if self.__class__.__name__ != "PipelineEngine": - # TODO: The above if condition is a HACK since for PipelineEngine - # it's difficult to inject argument in forward pass. - if self.module.training and self.curriculum_enabled_legacy(): - self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1) - if self.curriculum_params_legacy()["curriculum_type"] == "seqlen": - kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()}) + if self.__class__.__name__ != "PipelineEngine": + # TODO: The above if condition is a HACK since for PipelineEngine + # it's difficult to inject argument in forward pass. + if self.module.training and self.curriculum_enabled_legacy(): + self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1) + if self.curriculum_params_legacy()["curriculum_type"] == "seqlen": + kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()}) + return_modified = True if self.module.training and self.random_ltd_enabled(): self.random_ltd_scheduler.update_seq(self.global_steps) + if self.training_dataloader is None: + self.tput_timer.start() + + self._start_timers(self.engine_timers.forward_timers) + if self.zero_optimization_partition_weights(): # Enable automated discovery of external parameters by indicating that # we are in a forward pass. for module in self.module.modules(): module._parameters._in_forward = True - self._start_timers(self.engine_timers.forward_timers) - - if self.training_dataloader is None: - self.tput_timer.start() - if self.fp16_auto_cast(): inputs = self._cast_inputs_half(inputs) + return_modified = True - loss = self.module(*inputs, **kwargs) + if return_modified: + return inputs, kwargs + def _forward_epilogue(self): if self.zero_optimization_partition_weights(): # Disable automated discovery of external parameters for module in self.module.modules(): @@ -1993,16 +2028,37 @@ def forward(self, *inputs, **kwargs): self._stop_timers(self.engine_timers.forward_timers) + flops_profiler_active = (self.flops_profiler_enabled() + and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0) + if flops_profiler_active: self.flops_profiler.stop_profile() + if not self.autotuning_profile_model_info(): + see_memory_usage("Engine after forward", force=self.memory_breakdown()) + + @instrument_w_nvtx + def forward(self, *inputs, **kwargs): + r"""Execute forward propagation + Arguments: + *inputs: Variable length input list + **kwargs: variable length keyword arguments + """ + if self.autotuning_profile_model_info(): + ma = get_ma_status() + + if self.is_deepcompile_enabled() and hasattr(self, "launch_compile_passes"): + # We can't have this in forward prologue as the compiler compiles hooks including the forward prologue. + self.launch_compile_passes(self.global_steps) + + loss = self.module(*inputs, **kwargs) + if self.autotuning_profile_model_info(): activation_mem = get_ma_status() - ma self.autotuning_model_info["activation_mem_per_gpu"] = activation_mem print_json_dist(self.autotuning_model_info, [0], path=self.autotuning_model_info_path()) exit() - else: - see_memory_usage("Engine after forward", force=self.memory_breakdown()) + return loss def _cast_inputs_half(self, inputs): @@ -2061,43 +2117,14 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): grads = None self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size) - @contextmanager - def no_sync(self): - r""" - Context manager to disable gradient reduction during backward pass. - This context manager has the following effects on other DeepSpeed features. - 1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning. - 2. It is illegal to call engine.step() within the context manager. - 3. Tracking of gradient accumulation steps is disabled. - """ - assert not self.zero_optimization_partition_gradients(), \ - f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}" - - assert not self.inside_no_sync_ctxt, f"no_sync context manager reentry is unsupported" - - self.inside_no_sync_ctxt = True - try: - yield - finally: - self.inside_no_sync_ctxt = False - - @instrument_w_nvtx - def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=True): - r"""Execute backward pass on the loss - Arguments: - loss: Torch tensor on which to execute backward propagation - retain_graph: bool, default: false - forward on user defined choice of retain_graph - """ - + def _backward_prologue(self, loss, scale_wrt_gas=True): see_memory_usage("Engine before backward", force=self.memory_breakdown()) - if self.scale_wrt_gas is not None: scale_wrt_gas = self.scale_wrt_gas - do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt - # scale loss w.r.t. gradient accumulation if reduction is not disabled + do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt and not self.is_deepcompile_enabled( + ) if do_gradient_reduction and self.gradient_accumulation_steps() > 1 and scale_wrt_gas: loss = self._scale_loss_by_gas(loss.float()) @@ -2114,13 +2141,22 @@ def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=T )] self.monitor.write_events(self.summary_events) - self._start_timers(self.engine_timers.backward_timers) + if self.is_deepcompile_enabled(): + deepcompile_backward_prologue(self.is_gradient_accumulation_boundary()) - assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \ - "must provide optimizer during init in order to use backward" + return loss - self._start_timers(self.engine_timers.backward_inner_timers) + def _backward_epilogue(self): + self._start_timers(self.engine_timers.backward_reduce_timers) + if self.enable_backward_allreduce and not self.inside_no_sync_ctxt: + # Traditional code path that allreduces the module parameter grads + self.allreduce_gradients() + self._stop_timers(self.engine_timers.backward_reduce_timers) + see_memory_usage("Engine after backward", force=self.memory_breakdown()) + + def _do_optimizer_backward(self, loss, retain_graph): + self._start_timers(self.engine_timers.backward_inner_timers) if self.zero_optimization(): self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary() self.optimizer.backward(loss, retain_graph=retain_graph) @@ -2136,30 +2172,50 @@ def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=T else: self.optimizer.backward(loss, retain_graph=retain_graph) elif self.bfloat16_enabled(): - self.optimizer.backward(loss) + self.optimizer.backward(loss, retain_graph=retain_graph) else: if self.eigenvalue_enabled(): loss.backward(create_graph=True, retain_graph=True) else: loss.backward(retain_graph=retain_graph) - self._stop_timers(self.engine_timers.backward_inner_timers) - self._start_timers(self.engine_timers.backward_reduce_timers) - - if do_gradient_reduction: - # Traditional code path that allreduces the module parameter grads - self.allreduce_gradients() + @contextmanager + def no_sync(self): + r""" + Context manager to disable gradient reduction during backward pass. + This context manager has the following effects on other DeepSpeed features: + 1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning. + 2. It is illegal to call engine.step() within the context manager. + 3. Tracking of gradient accumulation steps is disabled. + """ + assert not self.zero_optimization_partition_gradients(), \ + f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}" - self._stop_timers(self.engine_timers.backward_reduce_timers) + assert not self.inside_no_sync_ctxt, f"no_sync context manager reentry is unsupported" - self._stop_timers(self.engine_timers.backward_timers) + self.inside_no_sync_ctxt = True + try: + yield + finally: + self.inside_no_sync_ctxt = False - if release_loss: - # loss.data = None - pass + @instrument_w_nvtx + def backward(self, loss, retain_graph=False, scale_wrt_gas=True): + r"""Execute backward pass on the loss + Arguments: + loss: Torch tensor on which to execute backward propagation + retain_graph: bool, default: false + forward on user defined choice of retain_graph + """ + assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \ + "must provide optimizer during init in order to use backward" - see_memory_usage("Engine after backward", force=self.memory_breakdown()) + self._start_timers(self.engine_timers.backward_timers) + loss = self._backward_prologue(loss, scale_wrt_gas) + self._do_optimizer_backward(loss, retain_graph) + self._backward_epilogue() + self._stop_timers(self.engine_timers.backward_timers) return loss @@ -3014,7 +3070,7 @@ def _load_checkpoint(self, optim_checkpoint = None if load_module_only: deepspeed_states = ['module'] - if self.optimizer is not None: + if self.optimizer is not None and hasattr(self.optimizer, 'refresh_fp32_params'): self.optimizer.refresh_fp32_params() else: has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled() @@ -3817,7 +3873,7 @@ def empty_partition_cache(self): gc.collect() get_accelerator().empty_cache() - def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}) -> None: + def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwargs={}, schedule=None) -> None: """Compile the module using the specified backend and kwargs. If a compiler_fn is set, it will be used instead of torch.compile(). """ @@ -3833,10 +3889,53 @@ def compile(self, backend=get_accelerator().get_compile_backend(), compile_kwarg if 'backend' in compile_kwargs: logger.warning("The `backend` in `compile_kwargs` will be overridden. Use the `backend` argument instead.") + print(f"Compiling deepcompile={self.is_deepcompile_enabled()} backend={backend}") + + if self.is_deepcompile_enabled(): + assert self.zero_optimization_stage() == ZeroStageEnum.optimizer_states \ + or self.zero_optimization_stage() == ZeroStageEnum.weights \ + , "Currently DeepCompile supports stage 1 or 3 only." + + assert not isinstance(self.optimizer, + DeepSpeedZeRoOffload), "Currently DeepCompile is not supported without an optimizer." + + if schedule is not None: + + def passes_name_to_fn(passes): + for p in passes: + assert callable(p) or p in opt_passes, f"Unknown pass {p}" + return [p if callable(p) else opt_passes[p] for p in passes] + + schedule = [(step, passes_name_to_fn(passes)) for step, passes in schedule] + + assert backend in ['inductor', 'eager'], f"Backend {backend} is not supported for DeepCompile." + + compile_config = self._config.compile_config + if (("zero_optimization" in self.config and "offload_optimizer" in self.config["zero_optimization"] + and "offload_param" in self.config["zero_optimization"]) + and self._config.zero_config.offload_param.device == "cpu" + and self._config.zero_config.offload_optimizer.device == "cpu"): + compile_config.offload_parameters = True + if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states: + backend = init_z1(self, backend, compile_config, compile_kwargs, schedule) + elif self.zero_optimization_stage() == ZeroStageEnum.weights: + backend = init_z3(self, backend, compile_config, compile_kwargs, schedule) + # create new dict to avoid modifying original dict self.module.compile(**{**compile_kwargs, 'backend': backend}) + self._is_compiled = True + def get_compile_time(self): + from deepspeed.compile.backend import opt_pass_times + return opt_pass_times + + def register_compile_pass(self, pass_name: str, pass_fn: Callable) -> None: + register_compile_pass(pass_name, pass_fn) + + def is_deepcompile_enabled(self): + return self._config.compile_config.deepcompile + @property def is_compiled(self) -> bool: return self._is_compiled @@ -3862,7 +3961,9 @@ def offload_states(self, param_offload_config = self.zero_offload_param() assert param_offload_config is None or param_offload_config.device == OffloadDeviceEnum.none, "Moving states across devices is not supported for offloaded parameters." - assert not self.zero_offload_param(), "Moving states across devices is not supported for offloaded parameters." + assert not isinstance( + self.optimizer, + DeepSpeedZeRoOffload), "Moving states across devices is not supported without an optimizer." if device == OffloadDeviceEnum.none: logger.warning("No device specified for offloading states.") @@ -3881,4 +3982,9 @@ def reload_states(self, non_blocking: bool = False) -> None: """ assert self.zero_optimization_stage( ) == ZeroStageEnum.weights, "Moving buffers back is supported only for ZeRO stage 3." + + assert not isinstance( + self.optimizer, + DeepSpeedZeRoOffload), "Moving states across devices is not supported without an optimizer." + self.optimizer.reload_states(non_blocking=non_blocking) diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index deb44c2e71eb..3068247796ef 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -393,7 +393,7 @@ def train_batch(self, data_iter=None): self.timers(TRAIN_BATCH_TIMER).stop() - if self.global_steps % self.steps_per_print() == 0: + if self.steps_per_print() is not None and self.global_steps % self.steps_per_print() == 0: if self.global_rank == 0: elapsed = self.timers(TRAIN_BATCH_TIMER).elapsed(reset=True) / 1000.0 iter_time = elapsed / self.steps_per_print() @@ -413,7 +413,8 @@ def train_batch(self, data_iter=None): self.global_samples)] self.monitor.write_events(self.summary_events) - if self.wall_clock_breakdown() and self.global_steps % self.steps_per_print() == 0: + if self.steps_per_print() is not None and self.wall_clock_breakdown( + ) and self.global_steps % self.steps_per_print() == 0: self.timers.log([ PIPE_SEND_OUTPUT_TIMER, PIPE_SEND_GRAD_TIMER, diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py index 49fa2807c355..2bc0c37bffb7 100644 --- a/deepspeed/runtime/pipe/module.py +++ b/deepspeed/runtime/pipe/module.py @@ -443,18 +443,26 @@ def _partition_layers(self, method='uniform'): self._set_bounds(start=self.parts[stage_id], stop=self.parts[stage_id + 1]) + @staticmethod + def _recursive_getattr(module: torch.nn.Module, attr_name: str) -> torch.Tensor: + '''Allow getting an attribute like "linear.weight"''' + weight = module + for item in attr_name.split("."): + weight = getattr(weight, item) + return weight + def allreduce_tied_weight_gradients(self): '''All reduce the gradients of the tied weights between tied stages''' for key, comm in self.tied_comms.items(): for attr_name in comm['weight_attr']: - weight = getattr(self.tied_modules[key], attr_name) + weight = self._recursive_getattr(self.tied_modules[key], attr_name) dist.all_reduce(weight.grad, group=comm['group']) def get_tied_weights_and_groups(self): weight_group_list = [] for key, comm in self.tied_comms.items(): for attr_name in comm['weight_attr']: - weight = getattr(self.tied_modules[key], attr_name) + weight = self._recursive_getattr(self.tied_modules[key], attr_name) weight_group_list.append((weight, comm['group'])) return weight_group_list @@ -462,7 +470,7 @@ def _synchronize_tied_weights(self): for key, comm in self.tied_comms.items(): for attr_name in comm['weight_attr']: dist.broadcast( - getattr(comm['module'], attr_name), + self._recursive_getattr(comm['module'], attr_name), src=min(comm['ranks']), group=comm['group'], ) @@ -475,7 +483,10 @@ def _index_tied_modules(self): specs = self._layer_specs tie_keys = set(s.key for s in specs if isinstance(s, TiedLayerSpec)) - for key in tie_keys: + # Since Python 3.7, "Dictionary order is guaranteed to be insertion order." + # Sort tie_keys here so that orders of self.tied_comms.items() are consistent + # among ranks. + for key in sorted(tie_keys): # Find the layers that the tied module appears in tied_layers = [] for idx, layer in enumerate(specs): diff --git a/deepspeed/runtime/tensor_parallel/config.py b/deepspeed/runtime/tensor_parallel/config.py index 1300bf9323cd..957984e9f8b3 100644 --- a/deepspeed/runtime/tensor_parallel/config.py +++ b/deepspeed/runtime/tensor_parallel/config.py @@ -47,6 +47,9 @@ class TPTrainingConfig(DeepSpeedConfigModel): In automatic tensor-parallelism training, 'tensor_parallel_size' When set to 0, indicates that it is disabled. """ + tp_overlap_comm: bool = False + """ Whether to overlap communication with computation. Currently, only allreduce supports overlap. """ + tensor_parallel: TPConfig = Field({}, alias="tp") """ Configuration for tensor parallelism used to split the model across several diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 9fd7a65a53ba..16a14776f884 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -823,6 +823,14 @@ def get_only_unique_item(items): return unique_item +def mask_nan_or_inf_with_val_inplace(input, device=None, val=-1.): + norm_is_inf = input.isinf() + norm_is_nan = input.isnan() + inf_or_nan = norm_is_nan.logical_or(norm_is_inf) + err = torch.tensor(-1.0, device=device, dtype=torch.float) + input.masked_fill_(inf_or_nan, err) + + def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False, moe_ep_group=None): """Get norm of an iterable of tensors. @@ -897,8 +905,7 @@ def _norm_tensors(tensor_list, _compute_buffer, _norm_type): dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=moe_ep_group) total_norm = device_total_norm.to(input_tensors[0].device).pow(1. / norm_type) - inf_or_nan = total_norm.isinf().logical_or(total_norm.isnan()) - total_norm.masked_fill_(inf_or_nan, -1) + mask_nan_or_inf_with_val_inplace(total_norm, device=total_norm.device) return total_norm diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 2706d4474515..a39c86d65881 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -73,6 +73,8 @@ def __getitem__(self, key): def _inject_parameters(module, cls): for module in module.modules(): + module._original_parameters = module._parameters + if cls == ZeROOrderedDict: new_param = cls(parent_module=module) else: @@ -80,6 +82,7 @@ def _inject_parameters(module, cls): for key, param in module._parameters.items(): new_param[key] = param + module._parameters = new_param @@ -232,6 +235,8 @@ def _remove_module_hooks(self): for hook in self.backward_hooks: hook.remove() + self.fwd_pre_hook.remove() + print_rank_0(f'Deleted module hooks: forward = {num_forward_hooks}, backward = {num_backward_hooks}', force=False) @@ -244,7 +249,7 @@ def _start_of_forward_hook(module, *args): self.get_param_coordinator().reset_step() - self.module.register_forward_pre_hook(_start_of_forward_hook) + self.fwd_pre_hook = self.module.register_forward_pre_hook(_start_of_forward_hook) #likely one of them should be enough but just to be safe self._register_deepspeed_module(self.module) @@ -287,7 +292,7 @@ def _register_deepspeed_module(self, module, count=[0]): count[0] = count[0] + 1 self._register_deepspeed_module(child, count=count) - @instrument_w_nvtx + @torch.compiler.disable def _pre_forward_module_hook(module, *args): self.pre_sub_module_forward_function(module) @@ -365,9 +370,9 @@ def _run_before_forward_function(input): return _apply_forward_and_backward_to_tensors_only(module, _run_before_forward_function, _run_after_backward_hook, inputs) + @torch.compiler.disable def _post_backward_module_hook(module, inputs): - if not hasattr(module, "ds_grads_remaining"): - module.ds_grads_remaining = 0 + module.ds_grads_remaining = 0 return apply_to_tensors_only(module.post_bwd_fn.apply, inputs, diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index db03a4b86134..e6bb14870fb6 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -556,8 +556,9 @@ def _init_subclass(cls, **kwargs): print_rank_0( "nn.functional.linear has been overridden with a more memory efficient version. This will persist unless manually reset.", force=False) - self.linear_bk = torch.nn.functional.linear - torch.nn.functional.linear = zero3_linear_wrap + if not hasattr(InsertPostInitMethodToModuleSubClasses, "linear_bk"): + InsertPostInitMethodToModuleSubClasses.linear_bk = torch.nn.functional.linear + torch.nn.functional.linear = zero3_linear_wrap if self.quantized_initialization: print_rank_0("nn.functional.linear has been overridden with quantized linear version.", force=False) @@ -1899,7 +1900,6 @@ def _allgather_params(self, param_list, hierarchy=0): tensor_size = partition_size * self.num_partitions flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, device=self.local_device) - flat_tensor.requires_grad = False partitions = [] for i in range(self.num_partitions): start = partition_size * i diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index 3417080b1bea..437b8981eecc 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -66,7 +66,7 @@ class PartitionedParameterCoordinator: FORWARD_PREFETCH_SUBMIT = 'forward_prefetch_submit' BACKWARD_FETCH_SUBMIT = 'backward_fetch_submit' BACKWARD_FETCH_WAIT = 'backward_fetch_wait' - BACKWARD_PREFETCH_SUBMIT = 'backward_prefetch_wait' + BACKWARD_PREFETCH_SUBMIT = 'backward_prefetch_submit' FORWARD_ALL_GATHER = 'forward_all_gather' BACKWARD_ALL_GATHER = 'backward_all_gather' """Handles partitioning and gathering of parameters.""" @@ -261,6 +261,7 @@ def reset_step(self) -> None: self.__most_recent_step_id_param_fetched_for = collections.defaultdict(lambda: int(-1e10)) self.__step_id_module_fetched_for = collections.defaultdict(lambda: collections.deque()) self.__step_id = 0 + self.__n_available_params = 0 self.__profiler.reset_events() def _dump_params(self, tag, sub_module, params, step_id=None): @@ -451,7 +452,7 @@ def release_and_reset_all(self, module: Module) -> None: # there's a hook execution issue param.ds_active_sub_modules.clear() self.__release_param(param) - self.__n_available_params = 0 + for param in iter_params(module, recurse=True): if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: raise RuntimeError(f"{param.ds_summary()} expected to be released") diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index ee97b6278d9e..4296cf46c1bb 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -19,7 +19,7 @@ from deepspeed.utils.torch import register_grad_hook from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce, all_to_all_loco_quant_reduce -from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item +from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item, mask_nan_or_inf_with_val_inplace from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum @@ -448,7 +448,7 @@ def destroy(self): for hook in self._leaf_module_hooks: hook.remove() print_rank_0("Removed grad acc hooks", force=False) - del self.__ipg_bucket_flat_buffer + self._release_ipg_buffers() def initialize_ds_offload( self, @@ -967,7 +967,7 @@ def _create_fp16_sub_groups(self, params_group): def _release_ipg_buffers(self): if self.contiguous_gradients: - self.ipg_buffer = None + self.__ipg_bucket_flat_buffer = None def _optimizer_step(self, sub_group_id): param_group_id = self.sub_group_to_group_id[sub_group_id] @@ -1453,12 +1453,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): total_norm = total_norm_cuda[0]**(1. / norm_type) - norm_is_inf = total_norm.isinf() - norm_is_nan = total_norm.isnan() - inf_or_nan = norm_is_nan.logical_or(norm_is_inf) - - err = torch.tensor(-1.0, device=inf_or_nan.device, dtype=torch.float) - total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm + mask_nan_or_inf_with_val_inplace(total_norm, device=total_norm.device) return total_norm.cpu() @@ -1516,7 +1511,8 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L # free the gradient if not get_accelerator().is_synchronized_device(): - param.grad.record_stream(get_accelerator().current_stream()) + if param.grad is not None: + param.grad.record_stream(get_accelerator().current_stream()) param.grad = None if self.offload_optimizer and self.swap_optimizer: @@ -1815,7 +1811,7 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): inf_or_nan = norm_is_nan.logical_or(norm_is_inf) err = torch.tensor(-1.0, device=self.device, dtype=torch.float) - total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm + total_norm = torch.where(inf_or_nan, err, total_norm) return total_norm @@ -2919,7 +2915,8 @@ def needs_offload(target): # contiguous bucket if needs_offload(OffloadStateTypeEnum.contiguous_grad_buffer): - if hasattr(self, "_DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer"): + if hasattr(self, "_DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer" + ) and self.__ipg_bucket_flat_buffer is not None: # Record properties like shape, strides, etc. as a meta tensor self.grad_buffer_meta = self.__ipg_bucket_flat_buffer.to("meta") self.__ipg_bucket_flat_buffer = None diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 2bece09bffc4..861f7d23c9c2 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -12,11 +12,12 @@ from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.utils import (empty_cache, see_memory_usage, inf, is_model_parallel_parameter, - align_dense_tensors, all_gather_dp_groups) + align_dense_tensors, all_gather_dp_groups, mask_nan_or_inf_with_val_inplace) from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum from deepspeed.ops.adam import DeepSpeedCPUAdam from deepspeed.utils import logger +from deepspeed.utils.torch import register_grad_hook from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank from deepspeed.moe.utils import is_moe_param from deepspeed.git_version_info import version @@ -522,11 +523,17 @@ def __init__(self, # resets the data structure value for the next backward propagation self.reset_partition_gradient_structures() - # creates backward hooks for gradient partitioning + # creates backward hooks for the following special handling of gradients + # 1. upcasting for fp32 gradient accumulation + # 2. gradient partitioning + # 3. overlapping backward and reduction self._grad_acc_hooks = [] - if self.partition_gradients or self.overlap_comm: - self.create_reduce_and_remove_grad_hooks() + if (self.partition_gradients or self.overlap_comm or self.use_grad_accum_attribute + or self.contiguous_gradients): + self.create_gradient_handling_hooks() + + self.ready_for_gradients = False self.custom_loss_scaler = False self.external_loss_scale = None @@ -678,6 +685,7 @@ def _release_ipg_buffers(self): self.ipg_buffer = None self.grads_in_partition = None self.grads_in_partition_offset = 0 + self.ready_for_gradients = False def initialize_optimizer_states(self): @@ -874,16 +882,18 @@ def increment_value(dictionary, key): def overlapping_partition_gradients_reduce_epilogue(self): self.independent_gradient_partition_epilogue() + def _fill_param_grad_accum_attribute(self, param): + if param.grad is not None: + if param.grad_accum is None: + param.grad_accum = param.grad.to(self.gradient_accumulation_dtype) + else: + param.grad_accum.add_(param.grad.to(self.gradient_accumulation_dtype).view(param.grad_accum.shape)) + param.grad = None + def fill_grad_accum_attribute(self): for group in self.bit16_groups: for param in group: - if param.grad is not None: - if param.grad_accum is None: - param.grad_accum = param.grad.to(self.gradient_accumulation_dtype) - else: - param.grad_accum.add_( - param.grad.to(self.gradient_accumulation_dtype).view(param.grad_accum.shape)) - param.grad = None + self._fill_param_grad_accum_attribute(param) def get_gradient_for_reduction(self, param): if self.use_grad_accum_attribute: @@ -901,21 +911,17 @@ def clear_grad_attribute(self, param): else: param.grad = None - def create_reduce_and_remove_grad_hooks(self): - self.grad_accs = [] + def create_gradient_handling_hooks(self): for i, param_group in enumerate(self.bit16_groups): for param in param_group: if param.requires_grad: def wrapper(param, i): - param_tmp = param.expand_as(param) - grad_acc = param_tmp.grad_fn.next_functions[0][0] - def reduce_partition_and_remove_grads(*notneeded): - self.reduce_ready_partitions_and_remove_grads(param, i) + def grad_handling_hook(*notneeded): + self.process_gradients(param, i) - self._grad_acc_hooks.append(grad_acc.register_hook(reduce_partition_and_remove_grads)) - self.grad_accs.append(grad_acc) + self._grad_acc_hooks.append(register_grad_hook(param, grad_handling_hook)) wrapper(param, i) @@ -1294,7 +1300,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): param_id = self.get_param_id(p) # as some model have trainable parameters but skipped in training, - # their backward hooks in self.create_reduce_and_remove_grad_hooks() will not run, + # their backward hooks in self.create_gradient_handling_hooks() will not run, # so they have no norm_for_param_grads if param_id in self.norm_for_param_grads: param_norm = self.norm_for_param_grads[param_id] @@ -1421,6 +1427,13 @@ def reduce_ipg_grads(self): self.elements_in_ipg_bucket = 0 ##################################################################### + def process_gradients(self, param, i): + self.backward_prologue() + if self.use_grad_accum_attribute: + self._fill_param_grad_accum_attribute(param) + if self.partition_gradients or self.overlap_comm: + self.reduce_ready_partitions_and_remove_grads(param, i) + def reduce_ready_partitions_and_remove_grads(self, param, i): if self.partition_gradients or self.is_gradient_accumulation_boundary: self.reduce_independent_p_g_buckets_and_remove_grads(param, i) @@ -1709,12 +1722,8 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): total_norm = total_norm.pow(1. / norm_type) - norm_is_inf = total_norm.isinf() - norm_is_nan = total_norm.isnan() - inf_or_nan = norm_is_nan.logical_or(norm_is_inf) + mask_nan_or_inf_with_val_inplace(total_norm, device=self.device) - err = torch.tensor(-1.0, device=self.device, dtype=torch.float) - total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm return total_norm # creates a flat fused tensor from the tensor list starting at the first_offset @@ -1949,9 +1958,7 @@ def update_lp_params(self): zip(self.parallel_partitioned_bit16_groups, self.single_partition_of_fp32_groups)): partition_id = dist.get_rank(group=self.real_dp_process_group[i]) bit16_partitions[partition_id].data.copy_(fp32_partition.data) - # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True) - # if i == 0: - # print_rank_0(f'{fp32_partition[:10]=}', force=True) + all_gather_dp_groups(groups_flat=self.bit16_groups_flat, partitioned_param_groups=self.parallel_partitioned_bit16_groups, dp_process_group=self.real_dp_process_group, @@ -2035,6 +2042,30 @@ def _has_inf_or_nan(x, j=None): inf_or_nan = nan.logical_or(inf) return inf_or_nan.float().max() + def backward_prologue(self): + if not self.ready_for_gradients: + self.micro_step_id += 1 + if self.contiguous_gradients and self.ipg_buffer is None: + self.ipg_buffer = [] + buf_0 = torch.empty(int(self.reduce_bucket_size), + dtype=self.dtype, + device=get_accelerator().current_device_name()) + self.ipg_buffer.append(buf_0) + + # Use double buffers to avoid data access conflict when overlap_comm is enabled. + if self.overlap_comm: + buf_1 = torch.empty(int(self.reduce_bucket_size), + dtype=self.dtype, + device=get_accelerator().current_device_name()) + self.ipg_buffer.append(buf_1) + self.ipg_index = 0 + self.ready_for_gradients = True + + def backward_epilogue(self): + # Only for Stage 1, Mode 2 + if self.use_grad_accum_attribute: + self.fill_grad_accum_attribute() + def backward(self, loss, retain_graph=False): """ :attr:`backward` performs the following steps: @@ -2043,32 +2074,13 @@ def backward(self, loss, retain_graph=False): 2. scaled_loss = fp32_loss*loss_scale 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves """ - self.micro_step_id += 1 - - if self.contiguous_gradients: - self.ipg_buffer = [] - buf_0 = torch.empty(int(self.reduce_bucket_size), - dtype=self.dtype, - device=get_accelerator().current_device_name()) - self.ipg_buffer.append(buf_0) - - # Use double buffers to avoid data access conflict when overlap_comm is enabled. - if self.overlap_comm: - buf_1 = torch.empty(int(self.reduce_bucket_size), - dtype=self.dtype, - device=get_accelerator().current_device_name()) - self.ipg_buffer.append(buf_1) - self.ipg_index = 0 - + self.backward_prologue() if self.custom_loss_scaler: scaled_loss = self.external_loss_scale * loss - scaled_loss.backward() + scaled_loss.backward(retain_graph=retain_graph) else: self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) - - # Only for Stage 1, Mode 2 - if self.use_grad_accum_attribute: - self.fill_grad_accum_attribute() + self.backward_epilogue() def check_overflow(self, partition_gradients=True): self._check_overflow(partition_gradients) diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index ba5e596e0d6d..dc47da072e5c 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -338,11 +338,11 @@ def __init__( if sp_stream is not None: self.overlap_handles = {} self.sp_overlap_comm = True - self.dafult_stream = get_accelerator().default_stream() + self.default_stream = get_accelerator().default_stream() def layer_sync(self, layer): if self.sp_overlap_comm and hasattr(layer, 'done_event'): - self.dafult_stream.wait_event(layer.done_event) + self.default_stream.wait_event(layer.done_event) def forward(self, query: Tensor, @@ -374,7 +374,7 @@ def bwd_hook(layer_type): def pre_hook_fun(grad): type = 'd' + layer_type self.overlap_handles[type + '_work'].wait() - self.sp_stream.wait_stream(self.dafult_stream) + self.sp_stream.wait_stream(self.default_stream) all2all_output = self.overlap_handles[type + '_grad'] grad = list(grad) grad[0] = self.overlap_handles[type + '_post_all2all_func'](all2all_output) @@ -389,7 +389,7 @@ def pre_hook_fun(grad): key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx, batch_dim_idx, None, self.overlap_handles, 'k') if self.sp_overlap_comm: - self.dafult_stream.wait_stream(self.sp_stream) + self.default_stream.wait_stream(self.sp_stream) value_layer = _SeqAllToAll.apply(self.spg, value, self.scatter_idx, self.gather_idx, batch_dim_idx, None, self.overlap_handles, 'v') diff --git a/deepspeed/utils/groups.py b/deepspeed/utils/groups.py index 6dc750035061..48dbed4d161b 100755 --- a/deepspeed/utils/groups.py +++ b/deepspeed/utils/groups.py @@ -589,7 +589,7 @@ def _get_data_parallel_rank(): def _get_sequence_parallel_world_size(): - """Return world size for the model parallel group.""" + """Return world size for the sequence parallel group.""" global mpu if mesh_device is not None: return dist.get_world_size(mesh_device.get_group(mesh_dim="sequence_parallel")) @@ -599,7 +599,7 @@ def _get_sequence_parallel_world_size(): def _get_sequence_parallel_rank(): - """Return my rank for the data parallel group.""" + """Return my rank for the sequence parallel group.""" global mpu if mpu is not None and hasattr(mpu, 'get_sequence_parallel_rank'): return mpu.get_sequence_parallel_rank() diff --git a/docs/code-docs/source/training.rst b/docs/code-docs/source/training.rst index e3a7029aae50..b721f7619d9d 100644 --- a/docs/code-docs/source/training.rst +++ b/docs/code-docs/source/training.rst @@ -1,5 +1,5 @@ Training API -============ +############ :func:`deepspeed.initialize` returns a *training engine* in its first argument of type :class:`DeepSpeedEngine`. This engine is used to progress training: @@ -39,3 +39,55 @@ Model Saving Additionally when a DeepSpeed checkpoint is created, a script ``zero_to_fp32.py`` is added there which can be used to reconstruct fp32 master weights into a single pytorch ``state_dict`` file. + + +Training Multiple Models +------------------------ +DeepSpeed supports training multiple models, which is a useful feature in `scenarios `_ such as knowledge distillation and post-training RLHF. +The core approach is to create individual DeepSpeedEngines for each model. + + +Training Independent Models +=========================== + +The following code snippet illustrates independently training multiple models on the same dataset. + +.. code-block:: python + + model_engines = [engine for engine, _, _, _ in [deepspeed.initialize(m, ...,) for m in models]] + for batch in data_loader: + losses = [engine(batch) for engine in model_engines] + for engine, loss in zip(model_engines, losses): + engine.backward(loss) + + +The above is similar to typical DeepSpeed usage except for the creation of multiple DeepSpeedEngines (one for each model). + + +Jointly Training Models With Shared Loss +======================================== + +The following code snippet illustrates jointly training multiple models on a shared loss value. + +.. code-block:: python + + model_engines = [engine for engine, _, _, _ in [deepspeed.initialize(m, ...,) for m in models]] + for batch in data_loader: + losses = [engine(batch[0], batch[1]) for engine in model_engines] + loss = sum(l / (i + 1) for i, l in enumerate(losses)) + loss.backward() + + for engine in model_engines: + engine._backward_epilogue() + + for engine in model_engines: + engine.step() + + for engine in model_engines: + engine.optimizer.zero_grad() + +Besides the use of multiple DeepSpeedEngines, the above differs from typical usage in two key ways: + +#. The **backward** call is made using the common loss value rather on individual model engines. + +#. **_backward_epilogue** is called on model engine, after the **loss.backward()**. diff --git a/docs/index.md b/docs/index.md index 6ce7e6251cbf..9c3a8b309ed0 100755 --- a/docs/index.md +++ b/docs/index.md @@ -7,24 +7,33 @@ title: "Latest News" --- DeepSpeed empowers ChatGPT-like model training with a single click, offering 15x speedup over SOTA RLHF systems with unprecedented cost reduction at all scales; [learn how](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-chat). + + +* [2025/04] [DeepCompile: Unlocking Compiler Optimization for Distributed Training](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepcompile/README.md) + +* [2025/03] [DeepSpeed AutoTP: Automatic Tensor Parallel Training of Hugging Face models](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/huggingface-tp/README.md) + +* [2024/12] [Ulysses-Offload: Democratizing Long Context LLM Training ](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/ulysses-offload/README.md) + * [2024/12] [DeepSpeed Domino: Communication-Free LLM Training Engine](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-domino/README.md) * [2024/08] [DeepSpeed on Windows](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/windows/08-2024/README.md)[[日本語](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/windows/08-2024/japanese/README.md)] [[中文](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/windows/08-2024/chinese/README.md)] -* [2024/08] [DeepNVMe: Improving DL Applications through I/O Optimizations](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-gds/README.md)[[日本語](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-gds/japanese/README.md)] [[中文](https://github.com/deepspeedai/DeepSpeed/blob/master/blogs/deepspeed-gds/chinese/README.md)] -* [2024/07] [DeepSpeed Universal Checkpointing: Efficient and Flexible Checkpointing for Large Scale Distributed Training](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-ucp/README.md)[[日本語](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-ucp/japanese/README.md)] -* [2024/03] [DeepSpeed-FP6: The Power of FP6-Centric Serving for Large Language Models](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README.md) [[English](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README.md)] [[中文](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-fp6/03-05-2024/README-Chinese.md)] -
+ More news + diff --git a/op_builder/builder.py b/op_builder/builder.py index 9b721e110fcc..f31870a1e4ce 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -590,6 +590,7 @@ def jit_load(self, verbose=True): extra_cflags=cxx_args, extra_cuda_cflags=nvcc_args, extra_ldflags=self.strip_empty_entries(self.extra_ldflags()), + with_cuda=True if (isinstance(self, CUDAOpBuilder) and not self.build_for_cpu) else None, verbose=verbose) build_duration = time.time() - start_build diff --git a/op_builder/dc.py b/op_builder/dc.py new file mode 100644 index 000000000000..d05210b8a2b4 --- /dev/null +++ b/op_builder/dc.py @@ -0,0 +1,40 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .builder import TorchCPUOpBuilder + + +class DeepCompileBuilder(TorchCPUOpBuilder): + BUILD_VAR = "DS_BUILD_DEEP_COMPILE" + NAME = "dc" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + return f'deepspeed.ops.{self.NAME}_op' + + def sources(self): + return [ + 'csrc/compile/deepcompile.cpp', 'csrc/compile/init.cpp', 'csrc/compile/z1.cpp', 'csrc/compile/z3.cpp', + 'csrc/compile/util.cpp' + ] + + def libraries_args(self): + args = super().libraries_args() + return args + + def include_paths(self): + import os + import torch + if self.build_for_cpu: + CUDA_INCLUDE = [] + elif not self.is_rocm_pytorch(): + CUDA_INCLUDE = [os.path.join(torch.utils.cpp_extension.CUDA_HOME, "include")] + else: + CUDA_INCLUDE = [ + os.path.join(torch.utils.cpp_extension.ROCM_HOME, "include"), + ] + return ['csrc/includes'] + CUDA_INCLUDE diff --git a/op_builder/fp_quantizer.py b/op_builder/fp_quantizer.py index df4d967ea09a..2b962ac2c1fe 100644 --- a/op_builder/fp_quantizer.py +++ b/op_builder/fp_quantizer.py @@ -54,7 +54,7 @@ def is_compatible(self, verbose=False): return False # triton 2.3.{0,1} and 3.0.0 are ok. - allowed_versions = ("2.3", "3.0") + allowed_versions = ("2.3", "3.0", "3.1", "3.2") if pkg_version: allowed = (pkg_version.parse(v) for v in allowed_versions) installed_triton = pkg_version.parse(triton.__version__) diff --git a/requirements/requirements-deepcompile.txt b/requirements/requirements-deepcompile.txt new file mode 100644 index 000000000000..9a635b910d93 --- /dev/null +++ b/requirements/requirements-deepcompile.txt @@ -0,0 +1 @@ +scipy diff --git a/scripts/check-extraindexurl.py b/scripts/check-extraindexurl.py index 01b506dc939d..017939af95ac 100755 --- a/scripts/check-extraindexurl.py +++ b/scripts/check-extraindexurl.py @@ -27,8 +27,12 @@ def err(s: str) -> None: # - we can reasonably assume it's available on all machines # - unlike plain grep, which is slower and has different flags on MacOS versus # Linux, git grep is always the same. +excluded_file = ".github/workflows/xpu-max1100.yml" res = subprocess.run( - ["git", "grep", "-Hn", "--no-index", "-e", r"--extra-index-url", *sys.argv[1:]], + [ + "git", "grep", "-Hn", "--no-index", "-e", r"--extra-index-url", "--", f":(exclude){excluded_file}", + *sys.argv[1:] + ], capture_output=True, ) if res.returncode == 0: diff --git a/setup.py b/setup.py index 0ad54bb99403..7ae9c7a9c421 100755 --- a/setup.py +++ b/setup.py @@ -91,6 +91,7 @@ def get_env_if_set(key, default: typing.Any = ""): 'inf': fetch_requirements('requirements/requirements-inf.txt'), 'sd': fetch_requirements('requirements/requirements-sd.txt'), 'triton': fetch_requirements('requirements/requirements-triton.txt'), + 'deepcompile': fetch_requirements('requirements/requirements-deepcompile.txt'), } # Only install pynvml on nvidia gpus. diff --git a/tests/unit/inference/quantization/test_intX_quantization.py b/tests/unit/inference/quantization/test_intX_quantization.py index 77b51fcd5814..8169912ae487 100644 --- a/tests/unit/inference/quantization/test_intX_quantization.py +++ b/tests/unit/inference/quantization/test_intX_quantization.py @@ -17,6 +17,7 @@ import pytest from collections import OrderedDict from typing import Dict +from deepspeed.ops.aio import AsyncIOBuilder device = get_accelerator().device_name() if get_accelerator().is_available() else 'cpu' @@ -57,6 +58,9 @@ def zero3_post_init_quantization_test_helper(cpu_offload: bool, nvme_offload: bo import deepspeed from transformers.integrations.deepspeed import HfDeepSpeedConfig + if nvme_offload and not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]: + pytest.skip('Skip tests since async-io is not compatible') + def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: bool, bits: int) -> Dict: GB = 1 << 30 @@ -174,6 +178,9 @@ def zero3_quantized_initialization_test_helper(cpu_offload: bool, nvme_offload: import deepspeed from transformers.integrations.deepspeed import HfDeepSpeedConfig + if nvme_offload and not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]: + pytest.skip('Skip tests since async-io is not compatible') + def get_zero3_ds_config(hf_config: OPTConfig, cpu_offload: bool, nvme_offload: bool, bits: int) -> Dict: GB = 1 << 30 diff --git a/tests/unit/launcher/test_user_args.py b/tests/unit/launcher/test_user_args.py index b86be4dfe74c..fd1489803812 100644 --- a/tests/unit/launcher/test_user_args.py +++ b/tests/unit/launcher/test_user_args.py @@ -6,7 +6,20 @@ import pytest import subprocess +from types import SimpleNamespace + from deepspeed.accelerator import get_accelerator +from deepspeed.launcher.multinode_runner import MultiNodeRunner + + +class DummyRunner(MultiNodeRunner): + + def backend_exists(self): + return True + + def get_cmd(self, environment, active_resources): + return [] + if not get_accelerator().is_available(): pytest.skip("only supported in accelerator environments.", allow_module_level=True) @@ -38,6 +51,12 @@ def cmd(user_script_fp, prompt, multi_node): return cmd +@pytest.fixture +def dummy_runner(): + args = SimpleNamespace(user_args=[], user_script="dummy_script.py") + return DummyRunner(args, "dummy_world_info") + + @pytest.mark.parametrize("prompt", [ '''"I am 6' tall"''', """'I am 72" tall'""", """'"translate English to Romanian: "'""", '''I'm going to tell them "DeepSpeed is the best"''' @@ -64,3 +83,20 @@ def test_bash_string_args(tmpdir, user_script_fp): p = subprocess.Popen(["bash", bash_fp], stdout=subprocess.PIPE, stderr=subprocess.PIPE) out, err = p.communicate() assert "ARG PARSE SUCCESS" in out.decode("utf-8"), f"User args not parsed correctly: {err.decode('utf-8')}" + + +def test_add_export_with_special_characters(dummy_runner): + """ + Values with special characters (e.g., 64(x2)) must be quoted to avoid bash syntax errors. + """ + dummy_runner.add_export("SLURM_JOB_CPUS_PER_NODE", "64(x2)") + assert dummy_runner.exports["SLURM_JOB_CPUS_PER_NODE"] == "\"64(x2)\"" + + +def test_add_export_no_special_characters(dummy_runner): + """ + Values without special characters should remain unquoted (e.g., PYTHONPATH). + This avoids issues where unnecessary quotes break module imports. + """ + dummy_runner.add_export("PYTHONPATH", "/usr/local/lib/python3.9/site-packages") + assert dummy_runner.exports["PYTHONPATH"] == "/usr/local/lib/python3.9/site-packages" diff --git a/tests/unit/model_parallelism/test_autotp_training.py b/tests/unit/model_parallelism/test_autotp_training.py index 7680b28ce6b5..db9c454761b5 100644 --- a/tests/unit/model_parallelism/test_autotp_training.py +++ b/tests/unit/model_parallelism/test_autotp_training.py @@ -163,11 +163,12 @@ def process_linear_layer(hidden_dim, input): @pytest.mark.sequential @pytest.mark.parametrize("tp_size", [2, 4]) +@pytest.mark.parametrize("tp_overlap_comm", [True, False]) class TestTpLayerFwdBwd(DistributedTest): world_size = 4 reuse_dist_env = True - def testRowParallel(self, tp_size: int): + def testRowParallel(self, tp_size: int, tp_overlap_comm: bool): skip_on_device() hidden_dim = 128 batch_size_per_device = 1 @@ -182,7 +183,8 @@ def testRowParallel(self, tp_size: int): } }, "tensor_parallel": { - "autotp_size": tp_size + "autotp_size": tp_size, + "tp_overlap_comm": tp_overlap_comm }, "zero_optimization": { "stage": 0, @@ -214,9 +216,9 @@ def testRowParallel(self, tp_size: int): torch_grad = torch.chunk(torch_linear.weight.grad, tp_size, dim=1)[groups.get_tensor_model_parallel_rank()] assert torch.allclose(linear.weight.grad, torch_grad.to(get_accelerator().current_device()), atol=1e-3) - assert torch.allclose(out, torch_out.to(get_accelerator().current_device()), atol=1e-3) + assert torch.allclose(out, torch_out.to(get_accelerator().current_device()), atol=1e-2) - def testColumnParallel(self, tp_size: int): + def testColumnParallel(self, tp_size: int, tp_overlap_comm: bool): skip_on_device() hidden_dim = 128 batch_size_per_device = 1 @@ -231,7 +233,8 @@ def testColumnParallel(self, tp_size: int): } }, "tensor_parallel": { - "autotp_size": tp_size + "autotp_size": tp_size, + "tp_overlap_comm": tp_overlap_comm }, "zero_optimization": { "stage": 0, @@ -266,7 +269,7 @@ def testColumnParallel(self, tp_size: int): assert torch.allclose(linear.weight.grad, torch_grad.to(get_accelerator().current_device()), atol=1e-3) assert torch.allclose(cur_device_out.to(get_accelerator().current_device()).contiguous(), out.contiguous(), - atol=1e-3) + atol=1e-2) @pytest.mark.sequential diff --git a/tests/unit/runtime/compile/test_compile_zero.py b/tests/unit/runtime/compile/test_compile_zero.py index ca80eef8b31e..ca5b7e74b64c 100644 --- a/tests/unit/runtime/compile/test_compile_zero.py +++ b/tests/unit/runtime/compile/test_compile_zero.py @@ -13,6 +13,8 @@ from unit.runtime.compile.util import compare_loss from unit.common import DistributedTest from unit.util import bf16_required_version_check, skip_on_arch +import deepspeed +from deepspeed.ops.aio import AsyncIOBuilder pytestmark = pytest.mark.skipif(not required_torch_version(min_version=2.1), reason="Compile tests requires Pytorch version 2.1 or above") @@ -36,6 +38,8 @@ def test_compile_zero(self, tmpdir, zero_stage, dtype, offload_device): pytest.skip("CPU does not support this test yet") if offload_device == OffloadDeviceEnum.nvme: + if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]: + pytest.skip('Skip tests since async-io is not compatible') if zero_stage != 3: pytest.skip(f"Nvme offload not supported for zero stage {zero_stage}") @@ -66,3 +70,49 @@ def test_compile_zero(self, tmpdir, zero_stage, dtype, offload_device): config_dict["bf16"] = {"enabled": True} compare_loss(self, config_dict, dtype) + + +class TestDeepCompile(DistributedTest): + world_size = 2 + non_daemonic_procs = True + + @pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32]) + @pytest.mark.parametrize('zero_stage', [1, 3]) + @pytest.mark.parametrize('deepcompile', [True]) # deepcompile==False is included in test_compile_zero + def test(self, zero_stage, dtype, deepcompile): + if not required_torch_version(min_version=2.6): + pytest.skip("DeepCompile requires PyTorch >= v2.6") + + if dtype == torch.bfloat16: + skip_on_arch(min_arch=8) + if dtype == torch.bfloat16 and not bf16_required_version_check(): + pytest.skip( + "DeepSpeed BFloat16 tests need NCCL >= 2.10.3, CUDA >=11.0, and HW support for BFloat16 to run correctly" + ) + if get_accelerator().device_name() == "cpu": + pytest.skip("CPU does not support this test yet") + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "zero_optimization": { + "stage": zero_stage, + }, + "compile": { + "deepcompile": deepcompile + } + } + + if dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif dtype == torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + # Need warmup steps + compare_loss(self, config_dict, dtype, iteration=10) diff --git a/tests/unit/runtime/compile/util.py b/tests/unit/runtime/compile/util.py index d53886a81429..f30895d68bb3 100644 --- a/tests/unit/runtime/compile/util.py +++ b/tests/unit/runtime/compile/util.py @@ -70,8 +70,7 @@ def wrapper(*args: Any, **kwargs: Any): @enable_determinism(123) -def compare_loss(self, config, dtype): - iteration = 5 +def compare_loss(self, config, dtype, iteration=5): hidden_dim = 10 RTOL = 5e-1 ATOL = 1e-2 @@ -116,9 +115,12 @@ def compare_loss(self, config, dtype): baseline_engine.backward(baseline_loss) target_engine.backward(target_loss) - baseline_optimizer.step() - target_optimizer.step() + baseline_engine.step() + target_engine.step() with GatheredParameters(target_engine.parameters()): for p1, p2 in zip(baseline_engine.parameters(), target_engine.parameters()): assert torch.allclose(p1.to(dtype), p2, rtol=RTOL, atol=ATOL) + + baseline_engine.destroy() + target_engine.destroy() diff --git a/tests/unit/runtime/half_precision/test_fp16.py b/tests/unit/runtime/half_precision/test_fp16.py index dba15a969459..d19cdf146294 100644 --- a/tests/unit/runtime/half_precision/test_fp16.py +++ b/tests/unit/runtime/half_precision/test_fp16.py @@ -357,6 +357,8 @@ def test(self, zero_stage, use_cpu_offload): model.backward(loss) model.step() + model.destroy() + @pytest.mark.parametrize("zero_stage", [1, 2, 3]) @pytest.mark.parametrize("use_cpu_offload", [True, False]) @@ -402,6 +404,8 @@ def test(self, zero_stage, use_cpu_offload, hidden_dim=4): model.backward(loss) model.step() + model.destroy() + @pytest.mark.parametrize("zero_stage", [1, 2, 3]) @pytest.mark.parametrize("use_cpu_offload", [True, False]) @@ -436,6 +440,7 @@ def test(self, zero_stage, use_cpu_offload): model=model, optimizer=optimizer, model_parameters=model.parameters()) + model.destroy() @pytest.mark.parametrize("zero_stage", [1, 2, 3]) @@ -486,6 +491,8 @@ def test(self, zero_stage, use_cpu_offload): model.backward(loss) model.step() + model.destroy() + @amp_available class TestAmp(DistributedTest): @@ -615,6 +622,7 @@ def test(self, zero_stage, optimizer_constructor): model = SimpleModel(hidden_dim) client_optimizer = optimizer_constructor(params=model.parameters()) model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, optimizer=client_optimizer) + model.destroy() class TestZero2ReduceScatterOff(DistributedTest): @@ -727,6 +735,8 @@ def test(self): model.backward(loss) model.step() + model.destroy() + @pytest.mark.parametrize('stage', [1, 2, 3]) class TestZeroEmptyGrad(DistributedTest): @@ -755,3 +765,5 @@ def test(self, stage): loss = model(batch[0], batch[1]) model.backward(loss) model.step() + + model.destroy() diff --git a/tests/unit/runtime/test_data_efficiency.py b/tests/unit/runtime/test_data_efficiency.py index 87fb49aad830..a52ca2982b9a 100644 --- a/tests/unit/runtime/test_data_efficiency.py +++ b/tests/unit/runtime/test_data_efficiency.py @@ -50,12 +50,16 @@ def get_model_parallel_group(self): return self.tp_group +@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16]) class TestDataEfficiency(DistributedTest): world_size = 2 - def test_curriculum_learning(self): + def test_curriculum_learning(self, dtype): if get_accelerator().device_name() == "cpu": pytest.skip("CPU accelerator does not support this test yet") + if not dtype in get_accelerator().supported_dtypes(): + pytest.skip(f"This test does not support {dtype=}.") + config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -96,9 +100,10 @@ def test_curriculum_learning(self): } } } - if get_accelerator().is_fp16_supported(): - config_dict["fp16"] = {"enabled": True, "loss_scale": 0, "initial_scale_power": 16} - elif get_accelerator().is_bf16_supported(): + + if dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "loss_scale": 0, "initial_scale_power": 8} + else: config_dict["bf16"] = {"enabled": True} def data_post_process(data, data_sampler_state_dict): @@ -107,7 +112,7 @@ def data_post_process(data, data_sampler_state_dict): hidden_dim = 10 model = SimpleModel(hidden_dim) - dataset = random_dataset(20, hidden_dim, torch.device('cpu')) + dataset = random_dataset(20, hidden_dim, torch.device('cpu'), dtype=dtype) model, _, data_loader, _ = deepspeed.initialize(config=config_dict, model=model, training_data=dataset, @@ -126,12 +131,16 @@ def data_post_process(data, data_sampler_state_dict): break +@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16]) class TestLegacyCurriculumScheduler(DistributedTest): world_size = 2 - def test_fixed_discrete(self): + def test_fixed_discrete(self, dtype): if get_accelerator().device_name() == "cpu": pytest.skip("CPU accelerator does not support this test yet") + if not dtype in get_accelerator().supported_dtypes(): + pytest.skip(f"This test does not support {dtype=}.") + config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -155,16 +164,20 @@ def test_fixed_discrete(self): } } } - if get_accelerator().is_fp16_supported(): - config_dict["fp16"] = {"enabled": True, "loss_scale": 0, "initial_scale_power": 16} - elif get_accelerator().is_bf16_supported(): + if dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "loss_scale": 0, "initial_scale_power": 8} + else: config_dict["bf16"] = {"enabled": True} hidden_dim = 10 ground_truths = {1: 1, 2: 1, 3: 2, 4: 2, 5: 3, 6: 3, 7: 4, 8: 4} model = Curriculum_SimpleModel(hidden_dim) model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, total_samples=20, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=20, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) for n, batch in enumerate(data_loader): loss, seqlen = model(batch[0], batch[1]) model.backward(loss) @@ -172,11 +185,14 @@ def test_fixed_discrete(self): true_seqlen = 5 if n + 1 in ground_truths: true_seqlen = ground_truths[n + 1] - assert seqlen == true_seqlen, f"Incorrect curriculum schedule" + assert seqlen == true_seqlen, f"Incorrect curriculum schedule {n=}, {seqlen=}, {true_seqlen=}" - def test_fixed_linear(self): + def test_fixed_linear(self, dtype): if get_accelerator().device_name() == "cpu": pytest.skip("CPU accelerator does not support this test yet") + if not dtype in get_accelerator().supported_dtypes(): + pytest.skip(f"This test does not support {dtype=}.") + config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -200,16 +216,20 @@ def test_fixed_linear(self): } } } - if get_accelerator().is_fp16_supported(): - config_dict["fp16"] = {"enabled": True, "loss_scale": 0, "initial_scale_power": 16} - elif get_accelerator().is_bf16_supported(): + if dtype == torch.float16: + config_dict["fp16"] = {"enabled": True, "loss_scale": 0, "initial_scale_power": 8} + else: config_dict["bf16"] = {"enabled": True} hidden_dim = 10 ground_truths = {1: 2, 2: 4, 3: 4, 4: 6, 5: 6, 6: 8, 7: 8, 8: 10, 9: 10, 10: 10} model = Curriculum_SimpleModel(hidden_dim) model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, total_samples=20, hidden_dim=hidden_dim, device=model.device) + data_loader = random_dataloader(model=model, + total_samples=20, + hidden_dim=hidden_dim, + device=model.device, + dtype=dtype) for n, batch in enumerate(data_loader): loss, seqlen = model(batch[0], batch[1]) model.backward(loss) diff --git a/tests/unit/runtime/test_multiple_models.py b/tests/unit/runtime/test_multiple_models.py new file mode 100644 index 000000000000..ba9aab69700e --- /dev/null +++ b/tests/unit/runtime/test_multiple_models.py @@ -0,0 +1,131 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import deepspeed +import deepspeed.comm as dist +import torch +from unit.common import DistributedTest +from unit.simple_model import SimpleModel, random_dataloader + + +def create_model(config_dict): + hidden_dim = 64 + model = SimpleModel(hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + return model + + +def train_shared_loss(num_models, config_dict, dtype): + hidden_dim = 64 + + models = [create_model(config_dict) for _ in range(num_models)] + data_loader = random_dataloader(model=models[0], + total_samples=4, + hidden_dim=hidden_dim, + device=models[0].device, + dtype=dtype) + dist.barrier() + for _, batch in enumerate(data_loader): + losses = [m.module(batch[0], batch[1]) for m in models] + loss = sum(l / (i + 1) for i, l in enumerate(losses)) + loss.backward() + + for m in models: + m._backward_epilogue() + + for m in models: + m.step() + + for m in models: + m.optimizer.zero_grad() + + for m in models: + m.destroy() + + +def train_independent_loss(num_models, config_dict, dtype): + hidden_dim = 64 + + models = [create_model(config_dict) for _ in range(num_models)] + data_loader = random_dataloader(model=models[0], + total_samples=4, + hidden_dim=hidden_dim, + device=models[0].device, + dtype=dtype) + dist.barrier() + for _, batch in enumerate(data_loader): + losses = [m.module(batch[0], batch[1]) for m in models] + for m, loss in zip(models, losses): + m.backward(loss) + m.step() + + for m in models: + m.destroy() + + +@pytest.mark.parametrize('num_models', [1, 2, 3]) +class TestMultipleModels(DistributedTest): + world_size = 2 + reuse_dist_env = True + + @pytest.mark.parametrize('shared_loss', [False, True]) + @pytest.mark.parametrize('zero_stage', [1, 2, 3]) + @pytest.mark.parametrize('fp32_grad_accum', [False, True]) + @pytest.mark.parametrize('contiguous_gradients', [False, True]) + @pytest.mark.parametrize('overlap_comm', [False, True]) + def test_zero_optimizer(self, num_models, shared_loss, zero_stage, fp32_grad_accum, contiguous_gradients, + overlap_comm): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, + "zero_optimization": { + "stage": zero_stage, + "contiguous_gradients": contiguous_gradients, + "overlap_comm": overlap_comm, + }, + "fp16": { + "initial_scale_power": 8, + "enabled": True + }, + } + if fp32_grad_accum: + config_dict["data_types"] = {"grad_accum_dtype": "fp32"} + + if shared_loss: + train_shared_loss(num_models=num_models, config_dict=config_dict, dtype=torch.float16) + else: + train_independent_loss(num_models=num_models, config_dict=config_dict, dtype=torch.float16) + + @pytest.mark.parametrize('shared_loss', [False, True]) + def test_bf16_optimizer(self, num_models, shared_loss): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, + "zero_optimization": { + "stage": 1, + }, + "bf16": { + "enabled": True + }, + "data_types": { + "grad_accum_dtype": "fp32" + } + } + + if shared_loss: + train_shared_loss(num_models=num_models, config_dict=config_dict, dtype=torch.bfloat16) + else: + train_independent_loss(num_models=num_models, config_dict=config_dict, dtype=torch.bfloat16) diff --git a/tests/unit/runtime/zero/test_nvme_checkpointing.py b/tests/unit/runtime/zero/test_nvme_checkpointing.py index 01a75aa64b4e..5b0c9d2a0d34 100644 --- a/tests/unit/runtime/zero/test_nvme_checkpointing.py +++ b/tests/unit/runtime/zero/test_nvme_checkpointing.py @@ -18,6 +18,7 @@ from deepspeed.accelerator import get_accelerator +@pytest.mark.sequential class TestNVMeCheckpointing(DistributedTest): world_size = 1 diff --git a/tests/unit/runtime/zero/test_zero.py b/tests/unit/runtime/zero/test_zero.py index 2ae2755086f8..73580a01c514 100644 --- a/tests/unit/runtime/zero/test_zero.py +++ b/tests/unit/runtime/zero/test_zero.py @@ -83,6 +83,7 @@ def test(self, zero_stage): data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device) run_unbalanced_gradients(model, data_loader) + model.destroy() # testing the fix https://github.com/deepspeedai/DeepSpeed/pull/1227 @@ -143,6 +144,8 @@ def forward(self, x, y): model.backward(loss) model.step() + model.destroy() + # testing the fix https://github.com/deepspeedai/DeepSpeed/pull/1227 # also reproduces the https://github.com/deepspeedai/DeepSpeed/pull/1372 @@ -243,6 +246,8 @@ def forward(self, x, y): # float() workaround for torch<1.6 assert torch.allclose(orig_state_dict[name].float(), fp32_state_dict[name].float()) + model.destroy() + def test_2_param_groups(self, tmpdir, zero_stage, freeze_params): # TODO: # - need to test with multiple param groups @@ -348,6 +353,8 @@ def forward(self, x, y): # float() workaround for torch<1.6 assert torch.allclose(orig_state_dict[name].float(), fp32_state_dict[name].float()) + model.destroy() + @pytest.mark.parametrize("allgather_bucket_size", [1000, 1001]) class TestIncorectAllgatherBucketSize(DistributedTest): @@ -821,6 +828,8 @@ def create_tensor(vals, dtype: torch.dtype = None) -> Tensor: _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) assert not math.isclose(ds_engine.optimizer._global_grad_norm, 0.0) + ds_engine.destroy() + @pytest.mark.parametrize("init_context_manager", [True, False]) @pytest.mark.parametrize("reduce_scatter", [True, False]) @@ -893,6 +902,8 @@ def forward(self, x: Tensor) -> Tensor: assert torch.allclose(weight_gradient, expected_weight_gradient) + ds_engine.destroy() + @pytest.mark.parametrize("init_context_manager", [True, False]) class TestZero3ParamPartitioningManyParams(DistributedTest): @@ -977,6 +988,8 @@ def forward(self, x: Tensor) -> Tensor: for layer_num, activation in enumerate(weight_gradients): pass + ds_engine.destroy() + class TestZero3InitForParentWeightInitialization(DistributedTest): world_size = 4 @@ -1197,6 +1210,8 @@ def create_tensor(vals): ds_engine.optimizer.step() _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) + ds_engine.destroy() + class TestParamPartitioningSkipInit(DistributedTest): world_size = 2 @@ -1274,6 +1289,8 @@ def forward(self, x, y): model.backward(loss) model.step() + model.destroy() + class TestZeroOffloadStage1(DistributedTest): world_size = 2 @@ -1311,6 +1328,8 @@ def test(self): model.backward(loss) model.step() + model.destroy() + @pytest.mark.parametrize("return_type", [tuple, list, dict]) class TestZero3DictFwd(DistributedTest): @@ -1373,6 +1392,8 @@ def forward(self, x, y): model.backward(loss) model.step() + model.destroy() + @pytest.mark.parametrize("zero_stage", [1, 2, 3]) class TestZeroAdamOptimizerStepCount(DistributedTest): @@ -1439,6 +1460,8 @@ def test(self, zero_stage): assert all(step == step_counts[0] for step in step_counts) assert model.global_steps == step_counts[0] + model.destroy() + @pytest.mark.parametrize("zero_stage", [1, 2, 3]) class TestZeroFrozenWeights(DistributedTest): @@ -1497,6 +1520,8 @@ def forward(self, x, y): model.backward(loss) model.step() + model.destroy() + @pytest.mark.parametrize("force_ds_optim", [True, False]) class TestZeroOffloadOptim(DistributedTest): @@ -1577,6 +1602,8 @@ def test_training_partition_cache(self, training): model.empty_partition_cache() assert sum([p.numel() for p in model.parameters()]) == 0 + model.destroy() + @pytest.mark.parametrize("use_client_optimizer", [True, False]) @pytest.mark.parametrize("empty_weight_group", [True, False]) @@ -1629,6 +1656,8 @@ def test_empty_param_groups(self, dtype, use_client_optimizer, empty_weight_grou config=config_dict, ) + model.destroy() + class TestZero3SwitchModes(DistributedTest): world_size = 2 @@ -1674,6 +1703,8 @@ def test(self, prefetch_ratio, zero_stage=3): for batch in data_loader: loss = model(batch[0], batch[1]) + model.destroy() + # Avoid overwriting client module id # https://github.com/deepspeedai/DeepSpeed/issues/6772 @@ -1707,3 +1738,4 @@ def forward(self, x): model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) post_init_m_id = model.id assert pre_init_m_id == post_init_m_id + model.destroy() diff --git a/tests/unit/runtime/zero/test_zero_multiple_run.py b/tests/unit/runtime/zero/test_zero_multiple_run.py deleted file mode 100644 index d4eb3a578cc9..000000000000 --- a/tests/unit/runtime/zero/test_zero_multiple_run.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# SPDX-License-Identifier: Apache-2.0 - -# DeepSpeed Team - -import deepspeed -import torch -from unit.common import DistributedTest, preferred_dtype -from unit.simple_model import SimpleModel, random_dataloader - - -class TestZ3MultipleModelCall(DistributedTest): - world_size = 1 - - def test_z3_multiple_model_call(self): - config_dict = { - "train_micro_batch_size_per_gpu": 1, - "gradient_accumulation_steps": 1, - "steps_per_print": 1, - "zero_optimization": { - "stage": 3 - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 1e-3 - } - }, - } - if preferred_dtype() is torch.float16: - config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} - elif preferred_dtype() is torch.bfloat16: - config_dict["bf16"] = {"enabled": True} - hidden_dim, nlayers = 2048, 3 - model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayers) - model_engine, _, _, _ = deepspeed.initialize(config=config_dict, - model=model, - model_parameters=model.parameters()) - data_loader = iter( - random_dataloader(model=model_engine, total_samples=10, hidden_dim=hidden_dim, device=model_engine.device)) - - for n, batch in enumerate(data_loader): - loss1 = model_engine(batch[0], batch[1]) - with torch.no_grad(): - loss2 = model_engine(batch[0], batch[1]) - loss = loss1 + loss2 - model_engine.backward(loss) - for name, submodule in model_engine.module.linears._modules.items(): - assert hasattr(submodule, "ds_grads_remaining"), \ - f"linears.{name} does not have variable ds_grads_remaining" - assert submodule.ds_grads_remaining == 0, \ - f"ds_grads_remaining of linears.{name} is not 0 ({submodule.ds_grads_remaining})" - model_engine.step() diff --git a/version.txt b/version.txt index 19270385eaf7..ce62dc55bf66 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.16.5 +0.16.9