diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py
index fbe06951..82b58fbc 100644
--- a/fast_llm/models/gpt/config.py
+++ b/fast_llm/models/gpt/config.py
@@ -42,6 +42,10 @@ class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
     name: typing.ClassVar[str] = "mixtral"
 
 
+class OlmoeGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
+    name: typing.ClassVar[str] = "olmoe"
+
+
 @config_class()
 class GPTArchitectureConfig(LanguageModelArchitectureConfig):
     _abstract = False
@@ -98,6 +102,7 @@ class GPTModelConfig(FastLLMModelConfig):
         LlamaGPTHuggingfaceCheckpointFormat,
         MistralGPTHuggingfaceCheckpointFormat,
         MixtralGPTHuggingfaceCheckpointFormat,
+        OlmoeGPTHuggingfaceCheckpointFormat,
     )
 
     @classmethod
diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py
index eea98b8b..5333426b 100644
--- a/fast_llm/models/gpt/conversion.py
+++ b/fast_llm/models/gpt/conversion.py
@@ -28,6 +28,7 @@
     LlamaGPTHuggingfaceCheckpointFormat,
     MistralGPTHuggingfaceCheckpointFormat,
     MixtralGPTHuggingfaceCheckpointFormat,
+    OlmoeGPTHuggingfaceCheckpointFormat,
     Starcoder2GPTHuggingfaceCheckpointFormat,
 )
 from fast_llm.models.gpt.model import GPTModel
@@ -103,7 +104,7 @@ class CommonHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler):
     _model: GPTModel
     _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig
     """
-    Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral)
+    Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral, olmoe)
     """
 
     @abc.abstractmethod
@@ -336,6 +337,42 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str):
         ]
 
 
+class OlmoeHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler):
+    format: typing.ClassVar[type[CheckpointFormat]] = OlmoeGPTHuggingfaceCheckpointFormat
+
+    @classmethod
+    def _create_config_converters(cls) -> list[ParamConverter]:
+        return super()._create_config_converters() + [
+            ConstantExportParamConverter(None, "architectures", ["OlmoeForCausalLM"]),
+            ConstantImportParamConverter(("transformer", "expert_routing_type"), None, RoutingType.topk),
+            ParamConverter(("transformer", "num_experts"), "num_experts"),
+            ParamConverter(("transformer", "num_experts_per_token"), "num_experts_per_tok"),
+            # TODO: change this once fast-llm supports normalized topk probs
+            ConstantExportParamConverter(None, "norm_topk_prob", True),
+            # TODO: change this once fast-llm supports qk normalization
+            ConstantExportParamConverter(None, "qk_norm", False),
+        ]
+
+    def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str):
+        num_experts = self._model.base_model_config.transformer.num_experts
+        return [
+            WeightConverter(f"{fast_llm_prefix}.mlp.router.weight", f"{hf_prefix}.mlp.gate.weight"),
+            SplitWeightConverter(
+                f"{fast_llm_prefix}.mlp.layer_1.weight",
+                tuple(
+                    f"{hf_prefix}.mlp.experts.{i}.{w}.weight"
+                    for i in range(num_experts)
+                    for w in ("gate_proj", "up_proj")
+                ),
+            ),
+            MLPLayer2Converter(
+                f"{fast_llm_prefix}.mlp.layer_2.weight",
+                tuple(f"{hf_prefix}.mlp.experts.{i}.down_proj.weight" for i in range(num_experts)),
+                self._model.base_model_config,
+            ),
+        ]
+
+
 class AutoGPTHuggingfaceCheckpointHandler(
     AutoStateDictCheckpointHandler, HuggingfaceStateDictCheckpointHandler, abc.ABC
 ):
@@ -345,4 +382,5 @@ class AutoGPTHuggingfaceCheckpointHandler(
         LlamaGPTHuggingfaceCheckpointFormat.name: LlamaHuggingfaceCheckpointHandler,
         MistralGPTHuggingfaceCheckpointFormat.name: MistralHuggingfaceCheckpointHandler,
         MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler,
+        OlmoeGPTHuggingfaceCheckpointFormat.name: OlmoeHuggingfaceCheckpointHandler,
     }
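
For reference, a minimal sketch of what the new converters imply on the Hugging Face side. The concrete values below (64 experts, top-8 routing, as in OLMoE-1B-7B) are illustrative assumptions, not read from this diff; the key names and the weight-name mapping come from the converters added above.

```python
# Illustrative only: the HF config keys the Olmoe converters read/write, with
# example values (OLMoE-1B-7B uses 64 experts and 8 active per token).
expected_hf_config_fragment = {
    "architectures": ["OlmoeForCausalLM"],  # ConstantExportParamConverter
    "num_experts": 64,                      # <-> ("transformer", "num_experts")
    "num_experts_per_tok": 8,               # <-> ("transformer", "num_experts_per_token")
    "norm_topk_prob": True,                 # constant until fast-llm supports normalized top-k probs
    "qk_norm": False,                       # constant until fast-llm supports QK normalization
}

# Per-layer weight mapping produced by _get_mlp_converters (prefixes are placeholders):
#   {fast_llm_prefix}.mlp.router.weight  <-> {hf_prefix}.mlp.gate.weight
#   {fast_llm_prefix}.mlp.layer_1.weight <-> {hf_prefix}.mlp.experts.{i}.{gate_proj,up_proj}.weight, split per expert
#   {fast_llm_prefix}.mlp.layer_2.weight <-> {hf_prefix}.mlp.experts.{i}.down_proj.weight, handled by MLPLayer2Converter
```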