Add support for MoE models with megablocks (#60)
epwalsh authored Oct 21, 2024
1 parent 6e32043 commit 4d3b231
Showing 35 changed files with 1,281 additions and 205 deletions.
35 changes: 24 additions & 11 deletions .github/workflows/docker.yml
@@ -8,6 +8,12 @@ on:
   pull_request:
     branches:
       - main
+    paths:
+      - 'Makefile'
+      - 'pyproject.toml'
+      - 'src/olmo_core/version.py'
+      - 'src/Dockerfile'
+      - '.github/workflows/docker.yml'
   push:
     branches:
       - main
@@ -16,15 +22,11 @@ on:

jobs:
beaker:
-    name: Beaker image (${{ matrix.version }})
-    runs-on: ubuntu-latest
-    timeout-minutes: 20
+    name: Beaker images
+    runs-on: ubuntu-latest-m
+    timeout-minutes: 60
     env:
       BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }}
-    strategy:
-      fail-fast: false
-      matrix:
-        version: [nightly, stable]
steps:
- uses: actions/checkout@v3

@@ -35,17 +37,28 @@ jobs:
run: |
echo "BEAKER_WORKSPACE=$(make get-beaker-workspace)" >> $GITHUB_ENV
-      - name: Build
+      - name: Build stable image
        run: |
-          make ${{ matrix.version }}-image
+          make stable-image
+      - name: Build nightly image
+        run: |
+          make nightly-image
- uses: allenai/setup-beaker@v2
if: env.BEAKER_TOKEN != ''
with:
token: ${{ env.BEAKER_TOKEN }}
workspace: ${{ env.BEAKER_WORKSPACE }}

-      - name: Push
+      - name: Push stable image
        if: env.BEAKER_TOKEN != '' && startsWith(github.ref, 'refs/tags/')
        run: |
-          make beaker-image-${{ matrix.version }}
+          rm -rf /opt/hostedtoolcache # clear up some disk space
+          make beaker-image-stable
+      - name: Push nightly image
+        if: env.BEAKER_TOKEN != '' && startsWith(github.ref, 'refs/tags/')
+        run: |
+          rm -rf /opt/hostedtoolcache # clear up some disk space
+          make beaker-image-nightly
25 changes: 21 additions & 4 deletions .github/workflows/main.yml
@@ -109,10 +109,27 @@ jobs:
matrix:
task:
          - name: Test (GPU)
-           run: pytest -v --color=yes --durations=3 -m gpu src/test/ --ignore-glob='src/test/distributed/checkpoint*'
+           image: olmo-core
+           gpus: 2
+           run: |
+             pytest -v --color=yes --durations=3 -m gpu \
+               --ignore-glob='src/test/distributed/checkpoint*' \
+               --ignore-glob='src/test/nn/moe*' \
+               src/test/
          - name: Test checkpoint (GPU)
-           run: pytest -v --color=yes --durations=3 -m gpu src/test/distributed/checkpoint*
+           image: olmo-core
+           gpus: 2
+           run: |
+             pytest -v --color=yes --durations=3 -m gpu \
+               src/test/distributed/checkpoint*
+         - name: Test MoE (GPU)
+           image: olmo-core-nightly
+           gpus: 1
+           run: |
+             pytest -v --color=yes --durations=3 -m gpu \
+               src/test/nn/moe*
steps:
- uses: actions/checkout@v3

@@ -142,7 +159,7 @@ jobs:
- name: Get full image name
if: env.BEAKER_TOKEN != ''
run:
echo "BEAKER_IMAGE=$(make get-full-beaker-image-name)" >> $GITHUB_ENV
echo "BEAKER_IMAGE=$(make get-full-beaker-image-name IMAGE_NAME=${{ matrix.task.image }})" >> $GITHUB_ENV

- name: GPU Tests
uses: allenai/[email protected]
@@ -160,7 +177,7 @@
priority: low
preemptible: true
resources:
-        gpuCount: 2
+        gpuCount: ${{ matrix.task.gpus }}
constraints:
cluster:
- ai2/allennlp-cirrascale
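A note on the `-m gpu` selection used in the test commands above: below is a minimal, hypothetical sketch of a test that such a run would pick up. The marker name comes from the workflow; the test name and skip logic are illustrative assumptions, not code from `src/test/`.

```python
# Hypothetical sketch: a test that opts into the custom `gpu` marker, which the
# workflow selects with `pytest -m gpu`. Names here are illustrative only.
import pytest
import torch


@pytest.mark.gpu
def test_identity_matmul_on_cuda():
    if not torch.cuda.is_available():
        pytest.skip("requires a CUDA device")
    x = torch.randn(8, 8, device="cuda")
    assert torch.allclose(x @ torch.eye(8, device="cuda"), x)
```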
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -12,8 +12,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added Google Cloud support for `list_directory()` and `clear_directory()`.
- Added `CometCallback` for logging training runs to Comet.ml.
- Added `DataMixBase` class, to allow extending to new data mix groups.
- Added support for MoE-based models.
- Added method `DataLoaderBase.get_mock_batch()`.
- Trainer now starts with a dry-run of a fake batch created by `DataLoaderBase.get_mock_batch()`.
- Added `Callback.pre_backward()`, `.pre_eval_batch()`, and `.post_eval_batch()` methods.
- Added `Trainer.model_forward()`, `.get_losses()`, and `.eval_batch()` methods.

### Changed

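To make the new hook points listed above concrete, here is a minimal, hypothetical callback sketch. Only the method names come from the CHANGELOG entries; the import path, argument lists, and behavior are assumptions for illustration, not the actual `olmo_core` API.

```python
# Hypothetical sketch only: method names are taken from the CHANGELOG above,
# but the import path and signatures are assumed.
from olmo_core.train.callbacks import Callback  # assumed import path


class BatchInspector(Callback):
    def pre_backward(self, *args, **kwargs):
        # Runs after the forward pass, before the loss is backpropagated.
        ...

    def pre_eval_batch(self, *args, **kwargs):
        # Runs just before each evaluation batch.
        ...

    def post_eval_batch(self, *args, **kwargs):
        # Runs just after each evaluation batch.
        ...
```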
27 changes: 19 additions & 8 deletions Makefile
@@ -1,8 +1,11 @@
BASE_IMAGE = ghcr.io/allenai/pytorch:2.4.1-cuda12.1-python3.11

+# NOTE: when upgrading the nightly version you also need to upgrade the torch version specification
+# in 'pyproject.toml' to include that nightly version.
 NIGHTLY_VERSION = "2.6.0.dev20241009+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121"
-TORCHAO_VERSION = "0.5.0 --extra-index-url https://download.pytorch.org/whl/cu121"
+TORCHAO_VERSION = "torchao==0.5.0 --extra-index-url https://download.pytorch.org/whl/cu121"
+MEGABLOCKS_VERSION = "megablocks[gg] @ git+https://[email protected]/epwalsh/megablocks.git@epwalsh/deps"
+CUDA_TOOLKIT_VERSION = 12.1.0

VERSION = $(shell python src/olmo_core/version.py)
VERSION_SHORT = $(shell python src/olmo_core/version.py short)
@@ -45,25 +48,33 @@ stable-image :
docker build -f src/Dockerfile \
--build-arg BUILDKIT_INLINE_CACHE=1 \
--build-arg BASE=$(BASE_IMAGE) \
--build-arg CUDA_TOOLKIT_VERSION=$(CUDA_TOOLKIT_VERSION) \
--build-arg MEGABLOCKS_VERSION=$(MEGABLOCKS_VERSION) \
--build-arg TORCHAO_VERSION=$(TORCHAO_VERSION) \
--target stable \
--progress plain \
-t $(IMAGE_BASENAME) .

.PHONY : beaker-image-stable
beaker-image-stable : stable-image
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME) $(BEAKER_WORKSPACE)
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME)-v$(VERSION_SHORT) $(BEAKER_WORKSPACE)
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME)-v$(VERSION) $(BEAKER_WORKSPACE)
echo "Built image '$(IMAGE_BASENAME)', size: $$(docker inspect -f '{{ .Size }}' $(IMAGE_BASENAME) | numfmt --to=si)"

.PHONY : nightly-image
nightly-image :
docker build -f src/Dockerfile \
--build-arg BUILDKIT_INLINE_CACHE=1 \
--build-arg BASE=$(BASE_IMAGE) \
--build-arg CUDA_TOOLKIT_VERSION=$(CUDA_TOOLKIT_VERSION) \
--build-arg MEGABLOCKS_VERSION=$(MEGABLOCKS_VERSION) \
--build-arg TORCHAO_VERSION=$(TORCHAO_VERSION) \
--build-arg NIGHTLY_VERSION=$(NIGHTLY_VERSION) \
--target nightly \
--progress plain \
-t $(IMAGE_BASENAME)-nightly .
echo "Built image '$(IMAGE_BASENAME)-nightly', size: $$(docker inspect -f '{{ .Size }}' $(IMAGE_BASENAME)-nightly | numfmt --to=si)"

.PHONY : beaker-image-stable
beaker-image-stable : stable-image
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME) $(BEAKER_WORKSPACE)
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME)-v$(VERSION_SHORT) $(BEAKER_WORKSPACE)
./src/scripts/beaker/create_beaker_image.sh $(IMAGE_BASENAME) $(IMAGE_BASENAME)-v$(VERSION) $(BEAKER_WORKSPACE)

.PHONY : beaker-image-nightly
beaker-image-nightly : nightly-image
@@ -77,4 +88,4 @@ get-beaker-workspace :

.PHONY : get-full-beaker-image-name
get-full-beaker-image-name :
-	@./src/scripts/beaker/get_full_image_name.sh $(IMAGE_BASENAME) $(BEAKER_WORKSPACE)
+	@./src/scripts/beaker/get_full_image_name.sh $(IMAGE_NAME) $(BEAKER_WORKSPACE)
6 changes: 6 additions & 0 deletions README.md
@@ -19,6 +19,12 @@ First install [PyTorch](https://pytorch.org) according to the instructions speci
pip install ai2-olmo-core
```

## API stability

Even though this library is under rapid development we are trying hard to adhere to [Semantic Versioning](https://semver.org/spec/v2.0.0.html) with every release except for features that are explicitly marked as beta features. Those features will be tagged like this in the [API docs](https://olmo-core.readthedocs.io/en/latest/):

![image](https://github.com/user-attachments/assets/c666686d-3ae6-4c88-8381-befd698d3fd0)

## Official training scripts

Official training scripts for various model sizes can be found in [`src/scripts/train/`](https://github.com/allenai/OLMo-core/tree/main/src/scripts/train).
6 changes: 6 additions & 0 deletions docs/source/nn/attention.rst
@@ -0,0 +1,6 @@
``nn.attention``
================

.. automodule:: olmo_core.nn.attention
:members:
:member-order: bysource
6 changes: 6 additions & 0 deletions docs/source/nn/feed_forward.rst
@@ -0,0 +1,6 @@
``nn.feed_forward``
===================

.. automodule:: olmo_core.nn.feed_forward
:members:
:member-order: bysource
34 changes: 5 additions & 29 deletions docs/source/nn/index.rst
@@ -3,38 +3,14 @@

.. automodule:: olmo_core.nn

-Attention
----------
-
-.. automodule:: olmo_core.nn.attention
-   :members:
-   :member-order: bysource
-
-FeedForward
------------
-
-.. automodule:: olmo_core.nn.feed_forward
-   :members:
-   :member-order: bysource
-
-RoPE
-----
-
-.. automodule:: olmo_core.nn.rope
-   :members:
-   :member-order: bysource
-
-LayerNorms
-----------
-
-.. automodule:: olmo_core.nn.layer_norm
-   :members:
-   :member-order: bysource
-
 .. toctree::
    :maxdepth: 2
    :caption: Submodules
    :hidden:

+   attention
+   feed_forward
    functional
+   layer_norm
+   moe
+   rope
    transformer
6 changes: 6 additions & 0 deletions docs/source/nn/layer_norm.rst
@@ -0,0 +1,6 @@
``nn.layer_norm``
=================

.. automodule:: olmo_core.nn.layer_norm
:members:
:member-order: bysource
6 changes: 6 additions & 0 deletions docs/source/nn/moe.rst
@@ -0,0 +1,6 @@
``nn.moe``
==========

.. automodule:: olmo_core.nn.moe
:members:
:member-order: bysource
6 changes: 6 additions & 0 deletions docs/source/nn/rope.rst
@@ -0,0 +1,6 @@
``nn.rope``
===========

.. automodule:: olmo_core.nn.rope
:members:
:member-order: bysource
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -13,7 +13,7 @@ authors = [
requires-python = ">=3.9"
license = { file = "LICENSE" }
dependencies = [
"numpy",
"numpy<2.0",
"torch>=2.4,<=2.6.0.dev20241009",
"cached-path",
"requests",
38 changes: 35 additions & 3 deletions src/Dockerfile
@@ -1,12 +1,40 @@
# Base image comes with PyTorch, numpy, flash-attn
ARG BASE

#########################################################################
# Build image
#########################################################################

FROM ${BASE} as build

WORKDIR /app/build

# Install CUDA toolkit.
ARG CUDA_TOOLKIT_VERSION
RUN conda install -y -c nvidia cuda-toolkit==${CUDA_TOOLKIT_VERSION}

# Build megablocks and grouped-gemm.
ENV TORCH_CUDA_ARCH_LIST="8.0 9.0"
ENV GROUPED_GEMM_CUTLASS=1
ARG MEGABLOCKS_VERSION
RUN pip wheel --no-build-isolation --no-cache-dir "${MEGABLOCKS_VERSION}" \
&& rm -rf torch-*.whl numpy-*.whl triton-*.whl

#########################################################################
# Stable image
#########################################################################

FROM ${BASE} as stable

-# Install torchao
+# Install torchao.
 ARG TORCHAO_VERSION
-RUN pip install --no-cache-dir torchao==${TORCHAO_VERSION}
+RUN pip install --no-cache-dir ${TORCHAO_VERSION}

-# Install other dependencies, but not the source code.
+# Copy and install wheels from build image.
+COPY --from=build /app/build /app/build
+RUN pip install --no-cache-dir /app/build/*
+
+# Install direct dependencies, but not source code.
COPY pyproject.toml .
COPY src/olmo_core/__init__.py src/olmo_core/__init__.py
COPY src/olmo_core/version.py src/olmo_core/version.py
@@ -16,6 +44,10 @@ RUN pip install --no-cache-dir '.[all]' && \

WORKDIR /app/olmo-core

#########################################################################
# Nightly image
#########################################################################

FROM stable as nightly

ARG NIGHTLY_VERSION
6 changes: 3 additions & 3 deletions src/examples/train.py
@@ -17,9 +17,9 @@
NumpyDatasetType,
TokenizerConfig,
)
-from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType
+from olmo_core.distributed.parallel import DataParallelType
 from olmo_core.distributed.utils import init_hybrid_shard_mesh
-from olmo_core.nn.transformer import TransformerConfig
+from olmo_core.nn.transformer import TransformerConfig, TransformerDataParallelConfig
from olmo_core.optim import AdamWConfig, CosWithWarmup, OptimGroupOverride
from olmo_core.train import (
Duration,
@@ -58,7 +58,7 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig:
model_config = TransformerConfig.llama2_271M(
vocab_size=tokenizer_config.padded_vocab_size(), # a little bigger than actual vocab size to make it a multiple of 128
compile=True,
-        dp_config=DataParallelConfig(
+        dp_config=TransformerDataParallelConfig(
name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32
),
)
9 changes: 9 additions & 0 deletions src/olmo_core/config.py
@@ -217,5 +217,14 @@ class DType(StrEnum):
     float32 = "float32"
     bfloat16 = "bfloat16"

+    @classmethod
+    def from_pt(cls, dtype: torch.dtype) -> "DType":
+        if dtype == torch.float32:
+            return DType.float32
+        elif dtype == torch.bfloat16:
+            return DType.bfloat16
+        else:
+            raise NotImplementedError(dtype)
+
     def as_pt(self) -> torch.dtype:
         return getattr(torch, self)
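A quick usage sketch of the new `DType.from_pt()` helper alongside the existing `as_pt()`, assuming `DType` is imported from `olmo_core.config`, the module this diff touches:

```python
import torch

from olmo_core.config import DType

# Round-trip between torch dtypes and the serializable DType enum.
assert DType.from_pt(torch.bfloat16) is DType.bfloat16
assert DType.bfloat16.as_pt() is torch.bfloat16

# Dtypes the enum doesn't cover raise instead of being silently coerced.
try:
    DType.from_pt(torch.float16)
except NotImplementedError as err:
    print(f"unsupported dtype: {err}")
```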
22 changes: 22 additions & 0 deletions src/olmo_core/doc_utils.py
@@ -0,0 +1,22 @@
from typing import TypeVar

T = TypeVar("T")


def beta_feature(f: T) -> T:
    """
    Mark a class or function as a beta feature.
    """
    if f.__doc__ is None:
        f.__doc__ = ""

    f.__doc__ += """
    .. warning::
        This is a beta feature! The API is subject to change even with minor and patch releases.
        If you choose to use this feature please read the `CHANGELOG <https://github.com/allenai/OLMo-core/blob/main/CHANGELOG.md>`_
        before upgrading your version of this library.
    """

    return f
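And a short sketch of applying the decorator; the decorated class below is made up purely for illustration:

```python
from olmo_core.doc_utils import beta_feature


@beta_feature
class ExperimentalWidget:
    """A made-up class, used only to show how the decorator is applied."""


# The warning admonition is appended to __doc__, which is how it ends up
# rendered in the Sphinx API docs.
print(ExperimentalWidget.__doc__)
```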