Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Oct 14, 2024
1 parent a1ba50e commit a189795
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 13 deletions.
6 changes: 3 additions & 3 deletions src/scripts/train/OLMo-13B.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import logging

from olmo_core.config import DType
from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType
from olmo_core.distributed.parallel import DataParallelType
from olmo_core.internal.experiment import CommonComponents, main
from olmo_core.nn.transformer import TransformerConfig
from olmo_core.nn.transformer import TransformerConfig, TransformerDataParallelConfig
from olmo_core.optim import AdamWConfig, OptimGroupOverride
from olmo_core.train import TrainerConfig
from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback
Expand All @@ -19,7 +19,7 @@ def build_model_config(common: CommonComponents) -> TransformerConfig:
return TransformerConfig.olmo_13B(
vocab_size=common.tokenizer.padded_vocab_size(),
compile=True,
dp_config=DataParallelConfig(
dp_config=TransformerDataParallelConfig(
name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32
),
)
Expand Down
6 changes: 3 additions & 3 deletions src/scripts/train/OLMo-1B.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
"""

from olmo_core.config import DType
from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType
from olmo_core.distributed.parallel import DataParallelType
from olmo_core.internal.experiment import CommonComponents, main
from olmo_core.nn.transformer import TransformerConfig
from olmo_core.nn.transformer import TransformerConfig, TransformerDataParallelConfig
from olmo_core.optim import AdamWConfig, OptimGroupOverride
from olmo_core.train import TrainerConfig
from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback
Expand All @@ -15,7 +15,7 @@ def build_model_config(common: CommonComponents) -> TransformerConfig:
return TransformerConfig.olmo_1B(
vocab_size=common.tokenizer.padded_vocab_size(),
compile=True,
dp_config=DataParallelConfig(
dp_config=TransformerDataParallelConfig(
name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32
),
)
Expand Down
6 changes: 3 additions & 3 deletions src/scripts/train/OLMo-7B.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import logging

from olmo_core.config import DType
from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType
from olmo_core.distributed.parallel import DataParallelType
from olmo_core.internal.experiment import CommonComponents, main
from olmo_core.nn.transformer import TransformerConfig
from olmo_core.nn.transformer import TransformerConfig, TransformerDataParallelConfig
from olmo_core.optim import AdamWConfig, OptimGroupOverride
from olmo_core.train import TrainerConfig
from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback
Expand All @@ -19,7 +19,7 @@ def build_model_config(common: CommonComponents) -> TransformerConfig:
return TransformerConfig.olmo_7B(
vocab_size=common.tokenizer.padded_vocab_size(),
compile=True,
dp_config=DataParallelConfig(
dp_config=TransformerDataParallelConfig(
name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32
),
)
Expand Down
16 changes: 12 additions & 4 deletions src/scripts/train/OLMoE-1B-7B.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
"""

from olmo_core.config import DType
from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType
from olmo_core.distributed.parallel import DataParallelType
from olmo_core.internal.experiment import CommonComponents, main
from olmo_core.nn.moe import MoEActivationFn, MoEConfig, MoEMLPImplementation, MoEType
from olmo_core.nn.transformer import TransformerBlockType, TransformerConfig
from olmo_core.nn.transformer import (
TransformerBlockType,
TransformerConfig,
TransformerDataParallelConfig,
TransformerDataParallelWrappingStrategy,
)
from olmo_core.optim import AdamWConfig, OptimGroupOverride
from olmo_core.train import TrainerConfig
from olmo_core.train.callbacks import (
Expand All @@ -26,8 +31,11 @@ def build_model_config(common: CommonComponents) -> TransformerConfig:
compile=True,
fused_ops=False,
block_name=TransformerBlockType.moe_reordered_norm,
dp_config=DataParallelConfig(
name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32
dp_config=TransformerDataParallelConfig(
name=DataParallelType.fsdp,
param_dtype=DType.bfloat16,
reduce_dtype=DType.float32,
wrapping_strategy=TransformerDataParallelWrappingStrategy.full,
),
)
model_config.block.feed_forward = None
Expand Down

0 comments on commit a189795

Please sign in to comment.