Adds mixing loader for FSL datasets #70

Merged Nov 1, 2024 (61 commits)
Changes shown are from 42 of the 61 commits.

Commits
a8b5ec6
WIP: Generate a mixture dataset
undfined Oct 18, 2024
637fee9
WIP: Adds dry run
undfined Oct 18, 2024
346135c
Test cleanup
undfined Oct 18, 2024
53def38
WIP: Make it fast
undfined Oct 19, 2024
8649cc8
WIP: Simple benchmark
undfined Oct 19, 2024
e3d7011
WIP: Refactor
undfined Oct 23, 2024
efe766b
Launch script
undfined Oct 23, 2024
5dff40c
temp changes to test
undfined Oct 23, 2024
2703538
deps for now
undfined Oct 23, 2024
3ee3278
Try with session
undfined Oct 23, 2024
cd1c6d2
Try internal client
undfined Oct 23, 2024
9895b23
Try boto3
undfined Oct 23, 2024
3c15f52
Fixes
undfined Oct 24, 2024
0c9355b
?
undfined Oct 24, 2024
abb362a
Cleanup + session stuff
undfined Oct 24, 2024
82a1af9
Use environ
undfined Oct 24, 2024
d0a80ba
JUST use env vars please boto
undfined Oct 24, 2024
e621f8e
No unions of containers
undfined Oct 24, 2024
0689c42
prepare first
undfined Oct 24, 2024
8ab2e99
Loader handles prepare
undfined Oct 24, 2024
dcfda67
Try recording torch exceptions
undfined Oct 24, 2024
23a0806
Don't need overrides
undfined Oct 24, 2024
8cfa282
Figure out why config/creds are missing
undfined Oct 24, 2024
fd1a508
fmt
undfined Oct 24, 2024
01a40ea
Env not ready yet
undfined Oct 24, 2024
8bde2b3
print beaker user
undfined Oct 24, 2024
ae208f6
uncomment eval file
undfined Oct 24, 2024
ce9d06f
replicate CommonComponents setup
undfined Oct 24, 2024
d1eb4df
Some class init stuff
undfined Oct 24, 2024
dbce279
Some more config logging
undfined Oct 24, 2024
5daa274
Conflict in CHANGELOG
undfined Oct 24, 2024
980e05a
checks cleanup
undfined Oct 24, 2024
f27bd73
Fixes for duplicate paths in mixture
undfined Oct 25, 2024
c69b228
In case there a ton of files
undfined Oct 25, 2024
18efafd
Maybe fix trainer launch
undfined Oct 25, 2024
5ceed46
Match other example
undfined Oct 25, 2024
4c7513e
More tests
undfined Oct 25, 2024
8401580
Try diff gpus
undfined Oct 25, 2024
0d77422
keep fsdp
undfined Oct 25, 2024
d22ed10
checks
undfined Oct 25, 2024
68a4d28
Less tokens
undfined Oct 25, 2024
c35514c
Exclude ai2/allennlp-elanding-a100-40g temp
undfined Oct 25, 2024
02cb49b
Merge branch 'main' of github.com:allenai/OLMo-core into undfined/mix…
undfined Oct 28, 2024
c453e65
Feedback
undfined Oct 28, 2024
a288d9e
Drop examples
undfined Oct 28, 2024
9c49f25
A bit more cleanup
undfined Oct 28, 2024
89504bc
Outdated changelog
undfined Oct 28, 2024
3aa5c35
Unused deps
undfined Oct 28, 2024
8f729dd
One more dep
undfined Oct 28, 2024
a848195
uncomment test assertions
undfined Oct 28, 2024
5322bf1
Drop todo
undfined Oct 28, 2024
5c22665
0 is an invalid token
undfined Oct 28, 2024
87e9168
More feedback
undfined Oct 29, 2024
fe50a32
Randomly sample instances when segmenting
undfined Oct 29, 2024
cffcba3
Memray + limit marker
undfined Oct 30, 2024
293be02
Add dep
undfined Oct 30, 2024
e45d2c3
Lint
undfined Oct 30, 2024
ffe7660
Bigger array is more informative
undfined Oct 30, 2024
19db2a9
Merge branch 'main' into undfined/mixing-loader
undfined Oct 30, 2024
1fdb995
Feedback
undfined Oct 30, 2024
74d4c5a
Merge branch 'main' into undfined/mixing-loader
undfined Nov 1, 2024
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
@@ -181,7 +181,7 @@ jobs:
      constraints:
        cluster:
          - ai2/allennlp-cirrascale
          - ai2/allennlp-elanding-a100-40g
          # - ai2/allennlp-elanding-a100-40g
          - ai2/pluto-cirrascale
          - ai2/jupiter-cirrascale-2
      envVars:
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- Added `SourceMixtureDataset` for composing a training mixture based on ratios of source datasets.
- Added `NumpyFSLDatasetMixture` for constructing a `NumpyDatasetBase` from a `SourceMixtureDataset`. Note this is only supported for FSL datasets.
- Added tests for `SourceMixture*` and `NumpyFSLDatasetMixture`.
- Added example launch script for training a model using a `NumpyFSLDatasetMixture`.

### Changed
- Moved some types into `olmo_core.data.types` to avoid some circular dependencies.

## [v1.5.0](https://github.com/allenai/OLMo-core/releases/tag/v1.5.0) - 2024-10-23

### Added
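To make the new configs above concrete, here is a minimal sketch of how the target ratios resolve to per-source token budgets. The class and field names are taken from the example script added later in this PR; the S3 paths here are placeholders, not real prefixes.

    # Illustrative sketch only; field names come from src/examples/train_with_mixture.py
    # in this PR, and the paths below are placeholders.
    from olmo_core.data.source_mixture import SourceMixtureConfig, SourceMixtureDatasetConfig
    from olmo_core.data.types import NumpyDatasetDType

    mixture = SourceMixtureDatasetConfig(
        max_tokens=100_000_000,  # total token budget for the mixture
        sequence_length=1024,
        source_configs=[
            # 0.8 * 100M = 80M tokens requested from "baseline",
            # 0.2 * 100M = 20M tokens requested from "rewrites".
            SourceMixtureConfig(source_name="baseline", paths=["s3://bucket/baseline/part-0.npy"], target_ratio=0.8),
            SourceMixtureConfig(source_name="rewrites", paths=["s3://bucket/rewrites/part-0.npy"], target_ratio=0.2),
        ],
        processes=1,  # also set in the example script
        dtype=NumpyDatasetDType.uint32,
        seed=42,
    )

A NumpyFSLDatasetMixtureConfig then wraps a config like this to build the dataset, as the full example script below shows.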
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -22,6 +22,9 @@ dependencies = [
    "omegaconf",
    "safetensors",
    "importlib_resources",
    "s3fs",  # REMOVE THIS IN FAVOR OF SOMETHING CONSISTENT ELSEWHERE
"tabulate",
"tqdm",
]

[project.urls]
242 changes: 242 additions & 0 deletions src/examples/train_with_mixture.py
@@ -0,0 +1,242 @@
"""
Example of how to train a transformer language model with a source mixture config.

Launch this with torchrun:

torchrun --nproc-per-node=4 src/examples/train_with_mixture.py run_name
"""

import sys
from dataclasses import dataclass
from typing import cast

import s3fs
from torch.distributed.elastic.multiprocessing.errors import record

from olmo_core.config import Config, DType
from olmo_core.data import (
    NumpyDataLoaderConfig,
    NumpyDatasetConfig,
    NumpyDatasetType,
    NumpyFSLDatasetMixtureConfig,
    TokenizerConfig,
)
from olmo_core.data.source_mixture import (
    SourceMixtureConfig,
    SourceMixtureDatasetConfig,
)
from olmo_core.data.types import NumpyDatasetDType
from olmo_core.distributed.parallel import DataParallelType
from olmo_core.distributed.utils import init_hybrid_shard_mesh
from olmo_core.nn.transformer import TransformerConfig, TransformerDataParallelConfig
from olmo_core.optim import AdamWConfig, CosWithWarmup, OptimGroupOverride
from olmo_core.train import (
    Duration,
    TrainerConfig,
    prepare_training_environment,
    teardown_training_environment,
)
from olmo_core.train.callbacks import (
    CheckpointerCallback,
    CometCallback,
    ConfigSaverCallback,
    GPUMemoryMonitorCallback,
    GradClipperCallback,
    LMEvaluatorCallbackConfig,
    ProfilerCallback,
    SchedulerCallback,
    SequenceLengthSchedulerCallback,
    WandBCallback,
)
from olmo_core.utils import get_default_device, seed_all


@dataclass
class ExperimentConfig(Config):
    model: TransformerConfig
    optim: AdamWConfig
    dataset: NumpyFSLDatasetMixtureConfig
    data_loader: NumpyDataLoaderConfig
    trainer: TrainerConfig
    init_seed: int = 12536


def build_config(run_name: str) -> ExperimentConfig:
    tokenizer_config = TokenizerConfig.dolma2()

    model_config = TransformerConfig.llama2_271M(
        # a little bigger than actual vocab size to make it a multiple of 128
        vocab_size=tokenizer_config.padded_vocab_size(),
        compile=True,
        dp_config=TransformerDataParallelConfig(
            name=DataParallelType.fsdp,
            param_dtype=DType.bfloat16,
            reduce_dtype=DType.float32,
        ),
    )

    optim_config = AdamWConfig(
        lr=1e-3,
        group_overrides=[
            OptimGroupOverride(params=["embeddings.weight"], opts=dict(weight_decay=0.0))
        ],
    )

    s3 = s3fs.S3FileSystem()

    # DCLM docs + rewrites
    baseline = s3.glob(
        "s3://ai2-llm/preprocessed/dclm/samples/src-100b/**/allenai/dolma2-tokenizer/*.npy"
    )
    rewrites = s3.glob(
        "s3://ai2-llm/preprocessed/dclm/samples/rewrite-100b/**/allenai/dolma2-tokenizer/*.npy"
    )

    sequence_length = 1024
    source_config = SourceMixtureDatasetConfig(
        max_tokens=int(10e7),  # 100M tokens
        sequence_length=sequence_length,
        source_configs=[
            SourceMixtureConfig(
                paths=[f"s3://{path}" for path in baseline],
                source_name="baseline",
                max_repetition_ratio=1.0,  # 1.0 is default but here to illustrate options
                target_ratio=0.8,
            ),
            SourceMixtureConfig(
                source_name="rewrites",
                paths=[f"s3://{path}" for path in rewrites],
                target_ratio=0.2,
            ),
        ],
        processes=10,
        dtype=NumpyDatasetDType.uint32,
        seed=42,
    )

    dataset_config = NumpyFSLDatasetMixtureConfig(
        source_mixture_config=source_config,
        sequence_length=sequence_length,
        max_target_sequence_length=8192,
        tokenizer=TokenizerConfig.dolma2(),
        work_dir="/tmp/dataset-cache",
        bust_index_cache=True,
    )

    data_loader_config = NumpyDataLoaderConfig(
        global_batch_size=256 * sequence_length,
        seed=0,
        num_workers=4,
    )

    trainer_config = (
        TrainerConfig(
            save_folder=f"/tmp/{run_name}",
            rank_microbatch_size=16 * sequence_length,
            save_overwrite=True,
            metrics_collect_interval=5,
            cancel_check_interval=5,
        )
        .with_callback("lr_scheduler", SchedulerCallback(scheduler=CosWithWarmup(warmup_steps=100)))
        .with_callback(
            "seq_len_scheduler",
            SequenceLengthSchedulerCallback(
                min_sequence_length=128, warmup_steps=100, enabled=False
            ),
        )
        .with_callback("gpu_monitor", GPUMemoryMonitorCallback())
        .with_callback("grad_clipper", GradClipperCallback(max_grad_norm=1.0))
        .with_callback(
            "checkpointer",
            CheckpointerCallback(
                save_interval=1000,
                ephemeral_save_interval=100,
                save_async=True,
            ),
        )
        .with_callback(
            "comet",
            CometCallback(
                name=run_name,
                cancel_check_interval=10,
                enabled=False,  # change to true to enable
            ),
        )
        .with_callback(
            "wandb",
            WandBCallback(
                name=run_name,
                cancel_check_interval=10,
                enabled=False,  # change to true to enable
            ),
        )
        .with_callback("config_saver", ConfigSaverCallback())
        .with_callback("profiler", ProfilerCallback(enabled=False))
        .with_callback(
            "evaluator",
            LMEvaluatorCallbackConfig(
                eval_dataset=NumpyDatasetConfig(
                    paths=[
                        "s3://ai2-llm/eval-data/perplexity/v3_small_dolma2-tokenizer/c4_en/val/part-0-00000.npy"
                    ],
                    metadata=[{"label": "c4-validation"}],
                    name=NumpyDatasetType.padded_fsl,
                    sequence_length=sequence_length,
                    tokenizer=tokenizer_config,
                    work_dir="/tmp/dataset-cache",
                ),
                eval_interval=250,
                eval_duration=Duration.steps(10),
            ),
        )
    )

    return ExperimentConfig(
        model=model_config,
        optim=optim_config,
        dataset=dataset_config,
        data_loader=data_loader_config,
        trainer=trainer_config,
    )


@record
def main(run_name: str):
    config = build_config(run_name)

    # Set RNG states on all devices.
    seed_all(config.init_seed)

    # Build components.
    model = config.model.build(
        init_device="meta",
        device=get_default_device(),
        dp_mesh=init_hybrid_shard_mesh(num_replicas=2),
    )
    optim = config.optim.build(model)
    dataset = config.dataset.build()
    data_loader = config.data_loader.build(dataset)
    trainer = config.trainer.build(model, optim, data_loader)

    # Save config to W&B and each checkpoint dir.
    config_dict = config.as_config_dict()
    cast(CometCallback, trainer.callbacks["comet"]).config = config_dict
    cast(WandBCallback, trainer.callbacks["wandb"]).config = config_dict
    cast(ConfigSaverCallback, trainer.callbacks["config_saver"]).config = config_dict

    # Train.
    trainer.fit()


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print(f"Usage: python {sys.argv[0]} run_name")
        sys.exit(1)

    run_name = sys.argv[1]

    prepare_training_environment()
    try:
        main(run_name)
    finally:
        teardown_training_environment()
66 changes: 66 additions & 0 deletions src/examples/train_with_mixture_launch.py
@@ -0,0 +1,66 @@
"""
An example of how to launch the training script on Beaker.
Run this with:

python src/examples/train_with_mixture_launch.py run_name [OVERRIDES...]
"""

import sys

from beaker import Beaker

from olmo_core.launch.beaker import BeakerEnvSecret, BeakerLaunchConfig
from olmo_core.utils import generate_uuid, prepare_cli_environment


def build_config(run_name: str) -> BeakerLaunchConfig:
    beaker_user = (Beaker.from_env().account.whoami().name).upper()
    return BeakerLaunchConfig(
        name=f"olmo-core-test-{generate_uuid()[:8]}",
        budget="ai2/oe-training",
        cmd=["src/examples/train_with_mixture.py", run_name],
        task_name="train",
        workspace="ai2/OLMo-core",
        description="Testing OLMo-core launch utilities",
        clusters=["ai2/allennlp-cirrascale"],
        env_secrets=[
            BeakerEnvSecret(name="BEAKER_TOKEN", secret=f"{beaker_user}_BEAKER_TOKEN"),
            BeakerEnvSecret(name="WANDB_API_KEY", secret=f"{beaker_user}_WANDB_API_KEY"),
            BeakerEnvSecret(name="COMET_API_KEY", secret=f"{beaker_user}_COMET_API_KEY"),
            BeakerEnvSecret(name="AWS_CONFIG", secret=f"{beaker_user}_AWS_CONFIG"),
            BeakerEnvSecret(name="AWS_CREDENTIALS", secret=f"{beaker_user}_AWS_CREDENTIALS"),
            BeakerEnvSecret(name="R2_ENDPOINT_URL", secret="R2_ENDPOINT_URL"),
            BeakerEnvSecret(name="WEKA_ENDPOINT_URL", secret="WEKA_ENDPOINT_URL"),
        ],
        setup_steps=[
            # Clone repo.
            'git clone "$REPO_URL" .',
            'git checkout "$GIT_REF"',
            "git submodule update --init --recursive",
            # Setup python environment.
            "conda shell.bash activate base",
            "pip install -e '.[all]'",
            "pip freeze",
            # Move AWS credentials from env to relevant files
            "mkdir -p ~/.aws",
            "printenv AWS_CONFIG > ~/.aws/config",
            "printenv AWS_CREDENTIALS > ~/.aws/credentials",
        ],
        num_nodes=1,
        num_gpus=4,
        shared_filesystem=True,
        nfs=False,
        allow_dirty=True,
    )


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print(f"Usage: python {sys.argv[0]} run_name [OVERRIDES...]")
        sys.exit(1)

    run_name = sys.argv[1]

    prepare_cli_environment()

    build_config(run_name).launch(follow=True)
5 changes: 3 additions & 2 deletions src/olmo_core/data/__init__.py
@@ -24,9 +24,8 @@
from .numpy_dataset import (
    NumpyDatasetBase,
    NumpyDatasetConfig,
    NumpyDatasetDType,
    NumpyDatasetType,
    NumpyFSLDataset,
    NumpyFSLDatasetMixtureConfig,
    NumpyPaddedFSLDataset,
    NumpyVSLDataset,
    VSLCurriculum,
@@ -38,10 +37,12 @@
    VSLNaturalCurriculum,
)
from .tokenizer import TokenizerConfig, TokenizerName
from .types import NumpyDatasetDType, NumpyDatasetType

__all__ = [
    "NumpyDatasetBase",
    "NumpyFSLDataset",
    "NumpyFSLDatasetMixtureConfig",
    "NumpyPaddedFSLDataset",
    "NumpyVSLDataset",
    "VSLCurriculum",
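Worth noting for downstream users: the types that moved to olmo_core.data.types are still re-exported from olmo_core.data (see the import added above), so both import paths below should keep working. This is a sketch of the intent, not part of the diff.

    from olmo_core.data.types import NumpyDatasetDType, NumpyDatasetType  # new canonical location
    from olmo_core.data import NumpyDatasetDType, NumpyDatasetType  # still re-exported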