StarDoc model training #5

Draft: wants to merge 8 commits into main
78 changes: 78 additions & 0 deletions examples/stardoc_config.yaml
@@ -0,0 +1,78 @@
training:
  train_iters: 1000
  num_workers: 2
  logs:
    interval: 10
  checkpoint:
    interval: 1000
    keep: 10
  export:
    interval: 1000
  validation:
    iterations: null
  test_iters: 0
pretrained:
  path: ".../stardoc_checkpoint"
  format: huggingface
batch:
  sequence_length: 8192
  micro_batch_size: 1
  batch_size: 8
data:
  split: [0.9, 0.1, 0]
  path: ".../stardoc_data_config.json"
  tokenizer:
    format: TokenizerFromFile
    path: ".../Mistral-7B-v0.3/tokenizer.json"
    special_tokens:
      eos_token: "</s>"
      bos_token: "<s>"
      pad_token: "[control_8]"
      image_placeholder_token: "[control_9]"
optimizer:
  learning_rate:
    base: 1.0e-05
    decay_style: constant
    warmup_iterations: 0
  weight_decay: 0.1
  beta_1: 0.9
  beta_2: 0.95
model:
  base_model:
    transformer:
      normalization:
        type: rms_norm
        epsilon: 1.0e-05
      num_layers: 32
      hidden_size: 4096
      ffn_hidden_size: 14336
      num_attention_heads: 32
      head_groups: 8
      add_linear_biases: false
      use_rotary_embeddings: true
      gated: true
      activation_type: silu
      triton_rotary: true
      kv_channels: 128
      rotary_embedding_scale: -9.210340371976184
      window_size: 4096
      init_method_std: 0.009021
      attention_dropout: 0.0
      hidden_dropout: 0.0
    multimodal_model:
      image_encoder_hidden_size: 1024
      num_image_tokens: 256
      max_num_images: 10
      image_resolution: 448
      image_encoder_type: clip
    vocab_size: 32000
    tie_word_embeddings: false
  multi_stage:
    zero_stage: 3
  distributed:
    training_dtype: bf16
    distributed_timeout: 3600
    seed: 984059

run:
  experiment_dir: stardoc
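
A quick sanity check on the batch settings above: with a global batch_size of 8, a micro_batch_size of 1, and the 8-GPU launch used in train_stardoc.sh, no gradient accumulation is needed. A minimal sketch of that arithmetic (the script below is illustrative and not part of this PR; it assumes batch_size is global and micro_batch_size is per GPU):

import yaml

# Load the example config and check that the global batch divides evenly
# across GPUs at the given micro-batch size (assumed semantics).
with open("examples/stardoc_config.yaml") as f:
    config = yaml.safe_load(f)

batch = config["batch"]
num_gpus = 8  # matches the run_local invocation in examples/train_stardoc.sh
accumulation_steps = batch["batch_size"] // (batch["micro_batch_size"] * num_gpus)
print(accumulation_steps)  # 1, i.e. no gradient accumulation on 8 GPUs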
124 changes: 124 additions & 0 deletions examples/train_stardoc.sh
@@ -0,0 +1,124 @@
# Required or optional environment variables
# export PROJECT_DIR=
# export PROJECT_NAME=
# export PROJECT_VERSION=
# export DATA_PATH=
# export PRETRAINED_STARDOC_PATH=
# export TOKENIZER_PATH=

# export HF_HOME=
# export HF_TOKEN=

export CMD_ARGS="fast-llm train stardoc"

export MODEL_ARGS_PRETRAINED="\
--pretrained_checkpoint_type=huggingface \
--pretrained_checkpoint_path=$PRETRAINED_STARDOC_PATH \
--use_pretrained_config=1 \
"

export MODEL_ARGS_ARCHITECTURE="\
--num_layers=32 \
--hidden_size=4096 \
--vocab_size=32000 \
--num_attention_heads=32 \
--head_groups=8 \
--add_linear_biases=0 \
--ffn_hidden_size=14336 \
--kv_channels=128 \
--use_rotary_embeddings=1 \
--rotary_embedding_scale=-9.210340371976184 \
--gated=1 \
--activation_type=silu \
--normalization_type=rms_norm \
--tie_word_embeddings=0 \
--window_size=8192 \
"

export MULTIMODAL_ARGS="\
--image_encoder_hidden_size=1024 \
--num_image_tokens=256 \
--max_num_images=10 \
--image_encoder_type=clip \
"

export DATA_ARGS="\
--split=9998,2,0 \
--dataset_type=stardoc \
--dataset_source=multimodal \
--data_path=$DATA_PATH \
--tokenizer_type=PreTrainedTokenizer \
--tokenizer_path=$TOKENIZER_PATH \
"

export TRAINING_ARGS="\
--batch_size=8 \
--sequence_length=8192 \
--train_iters=500000 \
--weight_decay=0.1 \
--adam_beta1=0.9 \
--adam_beta2=0.95 \
--clip_grad=1.0 \
--lr=0.0001 \
--lr_warmup_iters=1000 \
--lr_decay_style=cosine \
--lr_decay_iters=500000 \
--min_lr=0.000003 \
"

export PERFORMANCE_ARGS="\
--micro_batch_size=1 \
--training_dtype=bf16 \
--zero_stage=3 \
--num_workers=8 \
"

export MONITORING_ARGS="\
--validation_iters=25 \
--validation_interval=1000 \
--log_interval=10 \
--log_offset=0 \
--checkpoint_interval=500 \
--max_checkpoints=5 \
--export_interval=25000 \
--wandb_status_interval=25000 \
--wandb_entity_name=$WANDB_ENTITY_NAME \
--wandb_project_name=$PROJECT_NAME \
--wandb_group_name=$PROJECT_VERSION \
"

export ALL_ARGS="\
$CMD_ARGS \
$MODEL_ARGS_PRETRAINED \
$MODEL_ARGS_ARCHITECTURE \
$MULTIMODAL_ARGS \
$DATA_ARGS \
$TRAINING_ARGS \
$PERFORMANCE_ARGS \
$MONITORING_ARGS \
"

export PROFILE_ARGS="\
--profile_cuda=1 \
--profile_skip=10 \
--profile_wait=95 \
--profile_warmup=2 \
--profile_cycles=3 \
--profile_export=1 \
"

run_local () { # run_local(name, num_gpus, base_cmd)
  echo "$1" "$2" "$3"
  export TORCHRUN="torchrun --nproc-per-node=$2 --nnodes=1 --no-python"
  $TORCHRUN $3 --experiment_dir=$PROJECT_DIR/${PROJECT_NAME}_$PROJECT_VERSION/$1
}

run_c10d () { # run_c10d(name, num_nodes, base_cmd)
  echo "$1" "$2" "$3"
  export TORCHRUN="torchrun --nproc-per-node=8 --nnodes=$2 --no-python --rdzv-backend=c10d --rdzv-endpoint=$HOST_NODE_ADDR"
  $TORCHRUN $3 --experiment_dir=$PROJECT_DIR/${PROJECT_NAME}_$PROJECT_VERSION/$1
}

run_local stardoc_example 8 "$ALL_ARGS"
# run_c10d stardoc_example 16 "$ALL_ARGS"
# run_c10d stardoc_example 16 "$ALL_ARGS $MIXTRAL_ARGS --train_iters=50"
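
For reference, a sketch of the learning-rate schedule that TRAINING_ARGS describes, assuming the conventional linear-warmup-then-cosine-decay semantics for these flags (an assumption; Fast-LLM's exact implementation is not shown in this diff):

import math

def lr_at(step, lr=1e-4, min_lr=3e-6, warmup=1000, decay_iters=500_000):
    """Cosine decay with linear warmup, mirroring the flags above (assumed semantics)."""
    if step < warmup:
        return lr * step / warmup  # linear warmup from 0 to lr
    progress = min((step - warmup) / (decay_iters - warmup), 1.0)
    return min_lr + 0.5 * (lr - min_lr) * (1 + math.cos(math.pi * progress))

print(lr_at(500))      # 5e-05, halfway through warmup
print(lr_at(500_000))  # 3e-06, fully decayed to min_lr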
47 changes: 45 additions & 2 deletions fast_llm/data/config.py
@@ -23,6 +23,7 @@ class DatasetSource(str, enum.Enum):
    file = "file"
    sample = "sample"
    random = "random"
    multimodal = "multimodal"


class MultiprocessingContext(str, enum.Enum):
@@ -99,10 +100,48 @@ def _validate(self):
        Assert.in_range_incl(self.rate, 0, 1)


EOD = "<|endoftext|>"
TokenizerFromFile = "TokenizerFromFile"


@config_class()
class SpecialTokensConfig(Config):
    """
    Define special tokens like EOS, BOS, PAD and image_placeholder tokens.
    """

    bos_token: str | None = Field(
        default=None,
        desc="Beginning of sequence token",
        hint=FieldHint.core,
    )
    eos_token: str | None = Field(
        default="<|endoftext|>",
        desc="End of sequence token",
        hint=FieldHint.core,
    )
    pad_token: str | None = Field(
        default=None,
        desc="Pad token",
        hint=FieldHint.core,
    )
    image_placeholder_token: str | None = Field(
        default=None,
        desc="Placeholder token for images. Used only in multi-modal models",
        hint=FieldHint.core,
    )

    def get_special_tokens(self):
        special_tokens = [
            self.bos_token,
            self.eos_token,
            self.pad_token,
            self.image_placeholder_token,
        ]

        # Only return special tokens that are set
        return [token for token in special_tokens if token is not None]


@config_class()
class TokenizerConfig(Config):
    """
@@ -114,13 +153,17 @@ class TokenizerConfig(Config):
default="TokenizerFromFile",
desc="Unused.",
hint=FieldHint.deprecated,
valid=check_field(Assert.eq, TokenizerFromFile),
)
path: str | None = Field(
default=None,
desc="Path to the tokenizer file.",
hint=FieldHint.core,
)
special_tokens: SpecialTokensConfig = Field(
default_factory=SpecialTokensConfig,
desc="Define special tokens.",
hint=FieldHint.core,
)


@config_class
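
A usage sketch for the new SpecialTokensConfig, mirroring the values in examples/stardoc_config.yaml (this assumes Config subclasses accept keyword arguments, which this diff does not show):

special_tokens = SpecialTokensConfig(
    bos_token="<s>",
    eos_token="</s>",
    pad_token="[control_8]",
    image_placeholder_token="[control_9]",
)
# get_special_tokens() drops unset (None) entries, so all four are returned here.
print(special_tokens.get_special_tokens())
# ['<s>', '</s>', '[control_8]', '[control_9]']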
2 changes: 1 addition & 1 deletion fast_llm/data/gpt/data.py
@@ -112,7 +112,7 @@ def __init__(
            for name, prefix in zip(dataset_names, dataset_prefixes)
        }
        self._dataset_weights = {name: weight for name, weight in zip(dataset_names, dataset_weights)}

    def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int]):
        """
        Load the datasets, and prepare or load the samplings.
9 changes: 9 additions & 0 deletions fast_llm/data/stardoc_data_utils/constants.py
@@ -0,0 +1,9 @@
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "./demo_logs"

# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<|image|>"
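
IGNORE_INDEX and IMAGE_TOKEN_INDEX follow the LLaVA convention: the image placeholder gets an out-of-vocabulary id in the token stream, and loss is masked wherever labels equal IGNORE_INDEX (the default ignore_index of PyTorch's cross_entropy is -100). A sketch of that typical use (illustrative; this PR's preprocessing code is not shown):

import torch

IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200

# Token stream where -200 marks the position to splice in image embeddings.
input_ids = torch.tensor([1, 523, IMAGE_TOKEN_INDEX, 1049, 2])
labels = input_ids.clone()
# Never compute loss on the image placeholder position.
labels[input_ids == IMAGE_TOKEN_INDEX] = IGNORE_INDEX
print(labels)  # tensor([   1,  523, -100, 1049,    2])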