diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py
index da5c6b40..f4e7628c 100644
--- a/fms_fsdp/utils/config_utils.py
+++ b/fms_fsdp/utils/config_utils.py
@@ -24,7 +24,7 @@ def update_config(config, **kwargs):
 
 def get_model_config(model_variant):
     if model_variant == "llama2_70b":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             emb_dim=8192,
             multiple_of=4096,
             nheads=64,
@@ -33,7 +33,7 @@ def get_model_config(model_variant):
             hidden_grow_factor=28672 / 8192,
         )
     elif model_variant == "llama2_34b":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             emb_dim=8192,
             nheads=64,
             kvheads=8,
@@ -43,19 +43,19 @@ def get_model_config(model_variant):
             rope_theta=1000000.0,
         )
     elif model_variant == "llama2_13b":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             emb_dim=5120,
             nheads=40,
             nlayers=40,
             hidden_grow_factor=13824 / 5120,
         )
     elif model_variant == "llama2_7b":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             hidden_grow_factor=11008 / 4096,
             kvheads=32,
         )
     elif model_variant == "llama2_1.4b":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             emb_dim=2048,
             nheads=16,
             nlayers=24,
@@ -63,7 +63,7 @@ def get_model_config(model_variant):
             kvheads=4,
         )
     elif model_variant == "llama3_8b":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             src_vocab_size=128256,
             emb_dim=4096,
             nheads=32,
@@ -74,7 +74,7 @@ def get_model_config(model_variant):
             rope_theta=500000.0,
         )
     elif model_variant == "llama3_8b_4k":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             src_vocab_size=128256,
             emb_dim=4096,
             nheads=32,
@@ -85,7 +85,7 @@ def get_model_config(model_variant):
             rope_theta=500000.0,
         )
     elif model_variant == "llama3_1.8b":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             src_vocab_size=128256,
             emb_dim=2048,
             nheads=16,
@@ -96,7 +96,7 @@ def get_model_config(model_variant):
             rope_theta=500000.0,
         )
     elif model_variant == "llama3_1.8b_4k":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             src_vocab_size=128256,
             emb_dim=2048,
             nheads=16,
@@ -107,7 +107,7 @@ def get_model_config(model_variant):
             rope_theta=500000.0,
         )
     elif model_variant == "llama3_3.2b":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             src_vocab_size=128256,
             emb_dim=3072,
             nheads=24,
@@ -118,7 +118,7 @@ def get_model_config(model_variant):
             rope_theta=500000.0,
         )
     elif model_variant == "llama3_3.2b_4k":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             src_vocab_size=128256,
             emb_dim=3072,
             nheads=24,
@@ -129,7 +129,7 @@ def get_model_config(model_variant):
             rope_theta=500000.0,
         )
     elif model_variant == "llama3_70b":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             src_vocab_size=128256,
             emb_dim=8192,
             nheads=64,
@@ -140,7 +140,7 @@ def get_model_config(model_variant):
             rope_theta=500000.0,
         )
     elif model_variant == "llama3_70b_4k":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             src_vocab_size=128256,
             emb_dim=8192,
             nheads=64,
@@ -151,7 +151,7 @@ def get_model_config(model_variant):
             rope_theta=500000.0,
         )
     elif model_variant == "llama3_194m_4k":
-        llama_config = LLaMAConfig(
+        model_config = LLaMAConfig(
             src_vocab_size=128256,
             emb_dim=1024,
             nheads=8,
@@ -159,7 +159,31 @@ def get_model_config(model_variant):
             max_expected_seq_len=4096,
             rope_theta=500000.0,
         )
+    elif model_variant == "mamba_9.8b":
+        model_config = {
+            "d_model": 4096,
+            "d_intermediate": 14336,
+            "n_layer": 32,
+            "vocab_size": 128256,
+            "ssm_cfg": {"layer": "Mamba2"},
+            "attn_layer_idx": [9, 18, 27],
+            "attn_cfg": {
+                "causal": True,
+                "d_conv": 0,
+                "head_dim": 128,
+                "num_heads": 32,
+                "num_heads_kv": 8,
+                "out_proj_bias": False,
+                "qkv_proj_bias": False,
+                "rotary_emb_dim": 64,
+            },
+            "rms_norm": True,
+            "residual_in_fp32": True,
+            "fused_add_norm": True,
+            "pad_vocab_size_multiple": 16,
+            "tie_embeddings": False,
+        }
     else:
         raise ValueError(f"model variant {model_variant} not supported.")
 
-    return llama_config
+    return model_config
diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py
index 2faeffb7..4b811d6d 100644
--- a/fms_fsdp/utils/dataloader_utils.py
+++ b/fms_fsdp/utils/dataloader_utils.py
@@ -2,6 +2,7 @@
 
 from fms_fsdp.utils.dataset_utils import (
     ArrowHandler,
+    AutoHandler,
     BufferDataset,
     CheckpointDataset,
     ParquetHandler,
@@ -16,6 +17,7 @@
 _handler_map = {
     "arrow": ArrowHandler,
     "hf_parquet": ParquetHandler,
+    "auto": AutoHandler,
 }
 
 
@@ -84,10 +86,10 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
     assert (
         cfg.file_type in _handler_map
     ), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})"
-    if cfg.file_type == "hf_parquet":
-        filehandler = ParquetHandler(cfg.tokenizer_path, cfg.col_name)
+    if cfg.file_type == "hf_parquet" or cfg.file_type == "auto":
+        filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cfg.col_name)
     else:
-        filehandler = _handler_map[cfg.file_type](cfg.col_name)
+        filehandler = _handler_map[cfg.file_type]
     # Base reader layer
     data = StreamingDocDataset(
         cfg.data_path,
diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index d1d442d7..aedc5862 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -357,11 +357,11 @@ def length(self, path: str):
 
     def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set):
         doc = reader.get_batch(index)[self.col_name]
-        if len(doc) > 0:
-            if doc[0].as_py() in drop_tokens:
-                doc = doc.slice(1, len(doc) - 1)
-            if doc[-1].as_py() in drop_tokens:
-                doc = doc.slice(0, len(doc) - 1)
+        if len(doc) > 0 and doc[0].as_py() in drop_tokens:
+            doc = doc.slice(1, len(doc) - 1)
+        # Recheck len for edge case where doc=[eos]
+        if len(doc) > 0 and doc[-1].as_py() in drop_tokens:
+            doc = doc.slice(0, len(doc) - 1)
         return doc
 
     def slice(self, doc: pa.UInt32Array, index: int, n_pull: int) -> List:
@@ -384,24 +384,79 @@ def is_legal(self, filepath: str):
         return "parquet" in os.path.splitext(filepath)[1]
 
     def open(self, path: str):
-        return pq.read_pandas(path, columns=[self.col_name])[self.col_name]
+        return pq.read_pandas(path, columns=[self.col_name], partitioning=None)[
+            self.col_name
+        ]
 
     def length(self, path: str):
-        return pq.read_pandas(path, columns=[]).num_rows
+        return pq.read_metadata(path).num_rows
 
     def get(self, reader, index: int, drop_tokens: Set):
         doc = self.tokenizer(str(reader[index]))["input_ids"]
-        if len(doc) > 0:
-            if doc[0] in drop_tokens:
-                doc = doc[1:]
-            if doc[-1] in drop_tokens:
-                doc = doc[:-1]
+        if len(doc) > 0 and doc[0] in drop_tokens:
+            doc = doc[1:]
+        # Recheck len for edge case where doc=[eos]
+        if len(doc) > 0 and doc[-1] in drop_tokens:
+            doc = doc[:-1]
         return doc
 
     def slice(self, doc: List, index: int, n_pull: int) -> List:
         return doc[index : index + n_pull]
 
+
+class AutoHandler(_ShardFileHandler):
+    def __init__(self, tokenizer_path: str, col_name: str = "text"):
+        self.PHandler = ParquetHandler(tokenizer_path, col_name)
+        self.AHandler = ArrowHandler()
+        self.current = _ShardFileHandler()
+
+    def is_legal(self, filepath: str):
+        return (
+            "parquet" in os.path.splitext(filepath)[1]
+            or "arrow" in os.path.splitext(filepath)[1]
+        )
+
+    def open(self, path: str):
+        """
+        Open the file, to be indexed via self.get() method.
+        Avoid reading entire multi-Gb files when possible!
+        """
+        if "arrow" in os.path.splitext(path)[1]:
+            self.current = self.AHandler
+        else:
+            self.current = self.PHandler
+        return self.current.open(path)
+
+    def length(self, path: str):
+        """
+        Calculate the number of documents in the given file.
+        Avoid reading entire multi-Gb files when possible!
+        """
+        if "arrow" in os.path.splitext(path)[1]:
+            return self.AHandler.length(path)
+        else:
+            return self.PHandler.length(path)
+
+    def get(self, reader, index: int, drop_tokens: Set):
+        """
+        Given the output of self.open() and an index, return the document at that index.
+        Then, remove the first and/or last items if they appear in drop_tokens.
+        Try to avoid reading entire documents at a time in case of long documents,
+        but this is less important than avoiding reading entire files as above.
+        Output must support len().
+        """
+        return self.current.get(reader, index, drop_tokens)
+
+    def slice(self, doc, index: int, n_pull: int) -> List:
+        """
+        Given a long document, retrieve n_pull consecutive items starting from index.
+        Again, try to be memory-efficient when doing so, but efficiency in self.get()
+        and self.open() is far more important.
+        Must return a python list.
+        """
+        return self.current.slice(doc, index, n_pull)
+
+
 #### ------------------------- PIPELINE LAYERS ------------------------- ####
diff --git a/fms_to_hf.py b/fms_to_hf_llama.py
similarity index 99%
rename from fms_to_hf.py
rename to fms_to_hf_llama.py
index d03582a3..76042eee 100644
--- a/fms_to_hf.py
+++ b/fms_to_hf_llama.py
@@ -1,6 +1,6 @@
 import fire
 import torch
-from fms.models.hf import to_hf_api
+from fms.models.hf.utils import to_hf_api
 from fms.models.llama import LLaMA
 from torch.distributed._shard.checkpoint import FileSystemReader, load_state_dict
 from transformers import LlamaConfig, LlamaForCausalLM
diff --git a/fms_to_hf_mamba.py b/fms_to_hf_mamba.py
new file mode 100644
index 00000000..a3fdfc87
--- /dev/null
+++ b/fms_to_hf_mamba.py
@@ -0,0 +1,37 @@
+import fire
+from mamba_ssm.models.config_mamba import MambaConfig
+from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+from torch.distributed._shard.checkpoint import FileSystemReader, load_state_dict
+
+from fms_fsdp.utils.config_utils import get_model_config
+
+
+def main(model_variant, load_path, save_path, tokenizer_name_or_path):
+    print("Initializing model...")
+    config_data = get_model_config(model_variant)
+    mamba_config = MambaConfig(**config_data)
+    model = MambaLMHeadModel(mamba_config)
+
+    print(f"Reading state dict from {load_path}")
+    state_dict = {"model_state": model.state_dict()}
+    load_state_dict(
+        state_dict=state_dict, storage_reader=FileSystemReader(load_path), no_dist=True
+    )
+
+    print("Loading state dict into the model...")
+    model.load_state_dict(state_dict["model_state"])
+
+    print("Saving model to HF-compatible format...")
+    model.save_pretrained(save_path)
+
+    print("Copying tokenizer...")
+    from transformers import AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
+    tokenizer.save_pretrained(save_path)
+
+    print(f"Model saving at {save_path}")
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
diff --git a/main_training.py b/main_training_llama.py
similarity index 100%
rename from main_training.py
rename to main_training_llama.py
diff --git a/main_training_mamba.py b/main_training_mamba.py
new file mode 100644
index 00000000..3619ea25
--- /dev/null
+++ b/main_training_mamba.py
@@ -0,0 +1,177 @@
+import math
+import os
+from pathlib import Path
+
+import fire
+import torch
+import torch.optim as optim
+from mamba_ssm.models.config_mamba import MambaConfig
+from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+from mamba_ssm.modules.block import Block
+from torch import distributed as dist
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.optim.lr_scheduler import LambdaLR
+
+from fms_fsdp import config
+from fms_fsdp.utils.checkpointing_utils import Checkpointer
+from fms_fsdp.utils.config_utils import get_model_config, update_config
+from fms_fsdp.utils.dataloader_utils import get_data_loader, get_dummy_loader
+from fms_fsdp.utils.train_utils import (
+    get_policies,
+    get_profiler,
+    setup,
+    setup_environ_flags,
+    train,
+)
+
+
+def main(**kwargs):
+    # get configs
+    cfg = config.train_config()
+    update_config(cfg, **kwargs)
+
+    # ensure reproducibility
+    torch.cuda.manual_seed(cfg.seed)
+    torch.manual_seed(cfg.seed)
+
+    # torchrun specific
+    local_rank = int(os.environ["LOCAL_RANK"])
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+
+    if rank == 0:
+        print(f"--> running with these configs {cfg}")
+
+    # some setups
+    setup()
+    torch.cuda.set_device(local_rank)
+    torch.cuda.empty_cache()
+    setup_environ_flags()
+    os.environ["TRITON_CACHE_DIR"] = os.path.join(
+        Path.home(), ".triton", "cache", str(local_rank)
+    )
+
+    # get policy
+    block = Block
+    (
+        mixed_precision_policy,
+        wrapping_policy,
+        sharding_strategy_policy,
+        apply_selective_ac,
+        param_init_fn,
+    ) = get_policies(cfg, rank, block)
+
+    # get model
+    config_data = get_model_config(cfg.model_variant)
+    mamba_config = MambaConfig(**config_data)
+    model = MambaLMHeadModel(mamba_config)
+
+    if rank == 0:
+        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        print(f"\n--> model has {total_params / 1e6} Million params\n")
+
+    # get data loader
+    if rank == 0:
+        print("Constructing datasets...")
+    if not cfg.use_dummy_dataset:
+        train_loader = get_data_loader(cfg, rank, world_size)
+    else:
+        train_loader = get_dummy_loader(cfg, rank, world_size)
+    if rank == 0:
+        print("Datasets constructed!")
+
+    # FSDP
+    model = FSDP(
+        model,
+        auto_wrap_policy=wrapping_policy,
+        mixed_precision=mixed_precision_policy,
+        sharding_strategy=sharding_strategy_policy,
+        use_orig_params=cfg.use_torch_compile,
+        device_id=torch.cuda.current_device(),
+        limit_all_gathers=True,
+        param_init_fn=param_init_fn,
+    )
+
+    # fsdp activation checkpointing
+    if cfg.fsdp_activation_checkpointing:
+        if rank == 0:
+            print(f"--> applying FSDP activation checkpointing...")
+        apply_selective_ac(model, p=cfg.selective_checkpointing)
+
+    # torch compile
+    if cfg.use_torch_compile:
+        if rank == 0:
+            print(f"--> enabling torch compile...")
+        # the default accumulated_cache_size_limit=64 is not enough for 70b model, so we make it 128 here
+        torch._dynamo.config.accumulated_cache_size_limit = 128
+        model = torch.compile(model)
+
+    # Optimizer
+    optimizer = optim.AdamW(
+        model.parameters(), lr=cfg.learning_rate, betas=(0.9, 0.95), weight_decay=0.1
+    )
+
+    # optionally load from checkpoint (when continue pretraining)
+    checkpointer = Checkpointer(
+        cfg.ckpt_save_path, 1000, cfg.sharding_strategy, rank, local_rank
+    )
+    model, optimizer, _, start_step, tokens_seen, is_resuming = checkpointer.load(
+        model,
+        optimizer,
+        None,
+        path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
+        if not os.path.isfile(cfg.ckpt_load_path)
+        else cfg.ckpt_load_path,
+        strict=False,
+    )
+    if not is_resuming:
+        start_step = 0
+        # Override loaded optim hyperparams with the current values
+        for g in optimizer.param_groups:
+            g["initial_lr"] = cfg.learning_rate
+
+    # LR schedule
+    # linear decay for annealing
+    if cfg.training_stage == "annealing":
+        schedule = lambda x: 1 - x / cfg.num_steps
+    else:
+        # cosine decay
+        warmup_interval = min(2000, cfg.num_steps // 20)
+        schedule = lambda x: min(
+            1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2,
+            0.1
+            + 0.5
+            * (1 - 0.1)
+            * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)),
+        )
+
+    scheduler = LambdaLR(optimizer, lambda x: schedule(x + start_step))
+
+    # profiler
+    profiler = get_profiler(cfg, rank)
+
+    # Train
+    if rank == 0:
+        print(f"Training for {cfg.num_steps} steps")
+    train(
+        cfg,
+        model,
+        local_rank,
+        rank,
+        train_loader,
+        optimizer,
+        scheduler,
+        profiler,
+        checkpointer,
+        start_step,
+        tokens_seen,
+    )
+
+    checkpointer.save_single_file(cfg.num_steps, model)
+
+    dist.barrier()
+    dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
diff --git a/speculator/train_speculator_utils.py b/speculator/train_speculator_utils.py
index 0a265a63..0e437051 100644
--- a/speculator/train_speculator_utils.py
+++ b/speculator/train_speculator_utils.py
@@ -10,11 +10,11 @@
 from fms.models import register_model
 from fms.models.gpt_bigcode import GPTBigCode
 from fms.models.gpt_bigcode import _20b_config as _gpt_bigcode_20b_config
-from fms.models.gpt_bigcode import _hf_sd_to_fms_sd as _gptbigcode_hf_sd_to_fms_sd
+from fms.models.gpt_bigcode import _hf_to_fms_names as _gptbigcode_hf_sd_to_fms_sd
 from fms.models.llama import LLaMA
-from fms.models.llama import _hf_sd_to_fms_sd as _llama_hf_sd_to_fms_sd
+from fms.models.llama import _hf_to_fms_names as _llama_hf_sd_to_fms_sd
 from fms.models.mixtral import Mixtral, MixtralConfig
-from fms.models.mixtral import _hf_sd_to_fms_sd as _mixtral_hf_sd_to_fms_sd
+from fms.models.mixtral import _hf_to_fms_names as _mixtral_hf_sd_to_fms_sd
 from fms.utils import serialization, tokenizers
 from fms.utils.generation import _make_cache_contiguous
 from torch.nn import CrossEntropyLoss
@@ -554,7 +554,10 @@ def factory(**kwargs):
 register_model(
     "embedgpt_bigcode", "20b", _gpt_bigcode_factory_factory(_gpt_bigcode_20b_config)
 )
-serialization.register_adapter("embedgpt_bigcode", "hf", _gptbigcode_hf_sd_to_fms_sd)
+serialization.register_adapter_step(
+    "embedgpt_bigcode", "hf_to_fms", _gptbigcode_hf_sd_to_fms_sd
+)
+serialization.register_adapter("embedgpt_bigcode", "hf", ["hf_to_fms"])
 register_model(
     "embedllama", "7b", _llama_factory_factory(get_model_config("llama2_7b"))
 )
@@ -562,7 +565,11 @@ def factory(**kwargs):
 register_model(
     "embedllama", "8b", _llama_factory_factory(get_model_config("llama3_8b"))
 )
-serialization.register_adapter("embedllama", "hf", _llama_hf_sd_to_fms_sd)
+serialization.register_adapter_step("embedllama", "hf_to_fms", _llama_hf_sd_to_fms_sd)
+serialization.register_adapter("embedllama", "hf", ["hf_to_fms"])
 
 register_model("embedmixtral", "8x7b", _mixtral_factory_factory(MixtralConfig()))
-serialization.register_adapter("embedmixtral", "hf", _mixtral_hf_sd_to_fms_sd)
+serialization.register_adapter_step(
+    "embedmixtral", "hf_to_fms", _mixtral_hf_sd_to_fms_sd
+)
+serialization.register_adapter("embedmixtral", "hf", ["hf_to_fms"])
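
Note for reviewers (not part of the patch): a minimal sketch of how the new pieces are meant to compose, assuming `fms_fsdp` and `mamba_ssm` are installed; the tokenizer path and shard file names below are placeholders, not real artifacts.

```python
# Hypothetical usage sketch; paths and shard names are illustrative only.
from mamba_ssm.models.config_mamba import MambaConfig

from fms_fsdp.utils.config_utils import get_model_config
from fms_fsdp.utils.dataset_utils import AutoHandler

# Mamba variants now return a plain dict that is unpacked into MambaConfig
# (see fms_to_hf_mamba.py and main_training_mamba.py above); llama variants
# still return an LLaMAConfig instance.
mamba_config = MambaConfig(**get_model_config("mamba_9.8b"))

# With cfg.file_type == "auto", get_data_loader builds an AutoHandler, which
# dispatches to ArrowHandler or ParquetHandler per shard based on extension.
handler = AutoHandler("/path/to/tokenizer", col_name="text")  # placeholder path
print(handler.is_legal("shard-00000.arrow"))    # True
print(handler.is_legal("shard-00000.parquet"))  # True
print(handler.is_legal("shard-00000.jsonl"))    # False
```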