Skip to content

Commit

Permalink
fix example
Browse files — browse the repository at this point in the history
  • Loading branch information
epwalsh committed Oct 14, 2024
1 parent a189795 commit b949489
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/examples/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,9 +17,9 @@
NumpyDatasetType,
TokenizerConfig,
)
from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType
from olmo_core.distributed.parallel import DataParallelType
from olmo_core.distributed.utils import init_hybrid_shard_mesh
from olmo_core.nn.transformer import TransformerConfig
from olmo_core.nn.transformer import TransformerConfig, TransformerDataParallelConfig
from olmo_core.optim import AdamWConfig, CosWithWarmup, OptimGroupOverride
from olmo_core.train import (
Duration,
Expand Down Expand Up @@ -58,7 +58,7 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig:
model_config = TransformerConfig.llama2_271M(
vocab_size=tokenizer_config.padded_vocab_size(), # a little bigger than actual vocab size to make it a multiple of 128
compile=True,
dp_config=DataParallelConfig(
dp_config=TransformerDataParallelConfig(
name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32
),
)
Expand Down

0 comments on commit b949489

Please sign in to comment.