diff --git a/README.md b/README.md index c378b40b4..662aef02e 100644 --- a/README.md +++ b/README.md @@ -255,6 +255,8 @@ loss.backward() | Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | | Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | +| DeepSeekv2 | `liger_kernel.transformers.apply_liger_kernel_to_deepseek_v2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | + ## Low-level APIs diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index cbf330cc2..2fb1450e0 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -7,6 +7,7 @@ from liger_kernel.transformers.layer_norm import LigerLayerNorm # noqa: F401 from liger_kernel.transformers.monkey_patch import _apply_liger_kernel # noqa: F401 from liger_kernel.transformers.monkey_patch import _apply_liger_kernel_to_instance # noqa: F401 +from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_deepseek_v2 # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401 diff --git a/src/liger_kernel/transformers/model/deepseekv2.py b/src/liger_kernel/transformers/model/deepseekv2.py new file mode 100644 index 000000000..573e36f48 --- /dev/null +++ b/src/liger_kernel/transformers/model/deepseekv2.py @@ -0,0 +1,176 @@ +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from torch.nn import CrossEntropyLoss +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.utils import add_start_docstrings_to_model_forward +from transformers.utils import replace_return_docstrings + +from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss + +# This docstring is ported from the DeepSeek V2 model source code. +# Source: https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/main/modeling_deepseek.py +DeepseekV2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +_CONFIG_FOR_DOC = "DeepseekV2Config" + + +@add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) +@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Copy paste deepseekv2 forward but replace torch cross entropy with liger fused linear cross entropy + + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. + Returns: + Example: + ```python + >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM + >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + loss = None + logits = None + + if self.training and labels is not None: + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # flatten tokens + shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) + shift_labels = shift_labels.view(-1) + + lce = LigerFusedLinearCrossEntropyLoss() + loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) + else: + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index eafce145e..2cf5f9227 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -13,6 +13,7 @@ from liger_kernel.transformers.functional import liger_cross_entropy from liger_kernel.transformers.geglu import LigerGEGLUMLP from liger_kernel.transformers.layer_norm import LigerLayerNorm +from liger_kernel.transformers.model.deepseekv2 import lce_forward as deepseek_v2_lce_forward from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward @@ -735,6 +736,80 @@ def apply_liger_kernel_to_phi3( _patch_rms_norm_module(decoder_layer.post_attention_layernorm) +def apply_liger_kernel_to_deepseek_v2( + rope: bool = True, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + swiglu: bool = True, + model: PreTrainedModel = None, +) -> None: + """ + Apply Liger kernels to replace original implementations in DeepSeekv2 models. + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is True. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized, enhancing memory efficiency. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True. + model (PreTrainedModel): The model instance to apply Liger kernels to, if already loaded. Default is None. + """ + assert not ( + cross_entropy and fused_linear_cross_entropy + ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + + import sys + + # Ensure the model is a DeepSeek model + if "deepseek" not in model.__class__.__module__: + raise ValueError("The provided model is not a DeepSeek model") + + modeling_mod = sys.modules[model.__class__.__module__] + + if rope: + pass + # modeling_mod.apply_rotary_pos_emb = liger_rotary_pos_emb + if rms_norm: + modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm + if swiglu: + modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward + if cross_entropy: + if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION): + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + else: + logger.warning(TRANSFORMER_DEPRECATION_WARNING) + modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss + if fused_linear_cross_entropy: + modeling_mod.DeepseekV2ForCausalLM.forward = deepseek_v2_lce_forward + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + + # get the base model from the model instance + base_model = getattr(model, model.base_model_prefix, model) + + if rms_norm: + _patch_rms_norm_module(base_model.norm) + + for i, decoder_layer in enumerate(base_model.layers): + if swiglu: + if i == 0: + _bind_method_to_module(decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward) + else: + for expert in decoder_layer.mlp.experts: + _bind_method_to_module(expert, "forward", LigerSwiGLUMLP.forward) + if rms_norm: + _patch_rms_norm_module(decoder_layer.self_attn.kv_a_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) + + # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py MODEL_TYPE_TO_APPLY_LIGER_FN = { "gemma": apply_liger_kernel_to_gemma, @@ -747,6 +822,7 @@ def apply_liger_kernel_to_phi3( "qwen2": apply_liger_kernel_to_qwen2, "qwen2_vl": apply_liger_kernel_to_qwen2_vl, "phi3": apply_liger_kernel_to_phi3, + "deepseek_v2": apply_liger_kernel_to_deepseek_v2, } diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py index 8566558e7..3f9818821 100644 --- a/test/convergence/test_mini_models.py +++ b/test/convergence/test_mini_models.py @@ -3,6 +3,8 @@ from datasets import load_from_disk from torch.utils.data import DataLoader +from transformers import AutoConfig +from transformers import AutoModelForCausalLM from transformers.models.gemma import GemmaConfig from transformers.models.gemma import GemmaForCausalLM from transformers.models.gemma2 import Gemma2Config @@ -18,6 +20,7 @@ from transformers.models.qwen2 import Qwen2Config from transformers.models.qwen2 import Qwen2ForCausalLM +from liger_kernel.transformers import apply_liger_kernel_to_deepseek_v2 from liger_kernel.transformers import apply_liger_kernel_to_gemma from liger_kernel.transformers import apply_liger_kernel_to_gemma2 from liger_kernel.transformers import apply_liger_kernel_to_llama @@ -29,7 +32,9 @@ from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl from test.utils import DEFAULT_DATASET_PATH from test.utils import MiniModelConfig +from test.utils import RemoteMiniModelConfig from test.utils import assert_verbose_allclose +from test.utils import revert_liger_kernel_to_deepseek_v2 from test.utils import revert_liger_kernel_to_gemma from test.utils import revert_liger_kernel_to_gemma2 from test.utils import revert_liger_kernel_to_llama @@ -292,6 +297,31 @@ attn_implementation="eager", ), ), + "remote_mini_deepseek_v2": RemoteMiniModelConfig( + remote_model_path="deepseek-ai/DeepSeek-Coder-V2-Lite-Base", + liger_kernel_patch_func=apply_liger_kernel_to_deepseek_v2, + liger_kernel_patch_revert_func=revert_liger_kernel_to_deepseek_v2, + mini_model_config={ + "attention_dropout": 0.0, + "bos_token_id": 1, # 100000 + "eos_token_id": 2, # 100001 + "hidden_act": "silu", + "hidden_size": 896, # 2048 + "initializer_range": 0.02, + "intermediate_size": 4864, # 10944 + "max_position_embeddings": 4096, # 163840 + "num_attention_heads": 8, # 16 + "num_hidden_layers": 4, # 27 + "num_key_value_heads": None, # defaults to num_attention_heads + "rms_norm_eps": 1e-6, + "rope_theta": 10000.0, + "sliding_window": None, + "tie_word_embeddings": False, + "use_cache": True, + "vocab_size": 32064, # 102400 + "attn_implementation": "eager", + }, + ), } if MLLAMA_AVAILABLE: @@ -387,9 +417,16 @@ def create_model(model_name="mini_llama3"): Create a mini version model The commented values are the original values """ - model_config = MINI_MODEL_SETUPS[model_name].mini_model_config - model_class = MINI_MODEL_SETUPS[model_name].model_class - return model_class(model_config) + if model_name[:6] == "remote": + config = AutoConfig.from_pretrained(MINI_MODEL_SETUPS[model_name].remote_model_path, trust_remote_code=True) + config.update(MINI_MODEL_SETUPS[model_name].mini_model_config) + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + MINI_MODEL_SETUPS[model_name].remote_model_module = model.__class__.__module__ + return model + else: + model_config = MINI_MODEL_SETUPS[model_name].mini_model_config + model_class = MINI_MODEL_SETUPS[model_name].model_class + return model_class(model_config) def run_mini_model( @@ -410,6 +447,11 @@ def run_mini_model( if "mllama" in model_name: revert_kwargs["model_type"] = "causal_lm" + if model_name[:6] == "remote": + revert_kwargs["remote_model_module"] = MINI_MODEL_SETUPS[model_name].remote_model_module + + model = create_model(model_name).to(dtype).to(device) + if with_liger is True: kwargs = { "rope": True, @@ -427,13 +469,11 @@ def run_mini_model( kwargs["fused_linear_cross_entropy"] = True kwargs["cross_entropy"] = False - + kwargs["model"] = model MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) else: MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) - model = create_model(model_name).to(dtype).to(device) - train_dataset = load_from_disk(DEFAULT_DATASET_PATH) loader = DataLoader(train_dataset, batch_size=16, shuffle=False, collate_fn=simple_collate_fn) loader_iter = iter(loader) @@ -646,6 +686,20 @@ def run_mini_model( # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" # ), # ), + ("remote_mini_deepseek_v2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), + pytest.param( + "remote_mini_deepseek_v2", + 32, + 1e-4, + torch.bfloat16, + 1e-3, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + ), ], ) def test_mini_model( diff --git a/test/convergence/test_mini_models_with_logits.py b/test/convergence/test_mini_models_with_logits.py index 9abed2bd9..43a11ddbf 100644 --- a/test/convergence/test_mini_models_with_logits.py +++ b/test/convergence/test_mini_models_with_logits.py @@ -3,6 +3,8 @@ from datasets import load_from_disk from torch.utils.data import DataLoader +from transformers import AutoConfig +from transformers import AutoModelForCausalLM from transformers.models.gemma import GemmaConfig from transformers.models.gemma import GemmaForCausalLM from transformers.models.gemma2 import Gemma2Config @@ -18,6 +20,7 @@ from transformers.models.qwen2 import Qwen2Config from transformers.models.qwen2 import Qwen2ForCausalLM +from liger_kernel.transformers import apply_liger_kernel_to_deepseek_v2 from liger_kernel.transformers import apply_liger_kernel_to_gemma from liger_kernel.transformers import apply_liger_kernel_to_gemma2 from liger_kernel.transformers import apply_liger_kernel_to_llama @@ -29,7 +32,9 @@ from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl from test.utils import DEFAULT_DATASET_PATH from test.utils import MiniModelConfig +from test.utils import RemoteMiniModelConfig from test.utils import assert_verbose_allclose +from test.utils import revert_liger_kernel_to_deepseek_v2 from test.utils import revert_liger_kernel_to_gemma from test.utils import revert_liger_kernel_to_gemma2 from test.utils import revert_liger_kernel_to_llama @@ -292,6 +297,31 @@ attn_implementation="eager", ), ), + "remote_mini_deepseek_v2": RemoteMiniModelConfig( + remote_model_path="deepseek-ai/DeepSeek-Coder-V2-Lite-Base", + liger_kernel_patch_func=apply_liger_kernel_to_deepseek_v2, + liger_kernel_patch_revert_func=revert_liger_kernel_to_deepseek_v2, + mini_model_config={ + "attention_dropout": 0.0, + "bos_token_id": 1, # 100000 + "eos_token_id": 2, # 100001 + "hidden_act": "silu", + "hidden_size": 896, # 2048 + "initializer_range": 0.02, + "intermediate_size": 4864, # 10944 + "max_position_embeddings": 4096, # 163840 + "num_attention_heads": 8, # 16 + "num_hidden_layers": 4, # 27 + "num_key_value_heads": None, # defaults to num_attention_heads + "rms_norm_eps": 1e-6, + "rope_theta": 10000.0, + "sliding_window": None, + "tie_word_embeddings": False, + "use_cache": True, + "vocab_size": 32064, # 102400 + "attn_implementation": "eager", + }, + ), } if MLLAMA_AVAILABLE: @@ -387,9 +417,16 @@ def create_model(model_name="mini_llama3"): Create a mini version model The commented values are the original values """ - model_config = MINI_MODEL_SETUPS[model_name].mini_model_config - model_class = MINI_MODEL_SETUPS[model_name].model_class - return model_class(model_config) + if model_name[:6] == "remote": + config = AutoConfig.from_pretrained(MINI_MODEL_SETUPS[model_name].remote_model_path, trust_remote_code=True) + config.update(MINI_MODEL_SETUPS[model_name].mini_model_config) + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + MINI_MODEL_SETUPS[model_name].remote_model_module = model.__class__.__module__ + return model + else: + model_config = MINI_MODEL_SETUPS[model_name].mini_model_config + model_class = MINI_MODEL_SETUPS[model_name].model_class + return model_class(model_config) def run_mini_model( @@ -410,6 +447,11 @@ def run_mini_model( if "mllama" in model_name: revert_kwargs["model_type"] = "causal_lm" + if model_name[:6] == "remote": + revert_kwargs["remote_model_module"] = MINI_MODEL_SETUPS[model_name].remote_model_module + + model = create_model(model_name).to(dtype).to(device) + if with_liger is True: kwargs = { "rope": True, @@ -427,12 +469,11 @@ def run_mini_model( kwargs["fused_linear_cross_entropy"] = False kwargs["cross_entropy"] = True - + kwargs["model"] = model MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) else: MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) - model = create_model(model_name).to(dtype).to(device) train_dataset = load_from_disk(DEFAULT_DATASET_PATH) loader = DataLoader(train_dataset, batch_size=16, shuffle=False, collate_fn=simple_collate_fn) loader_iter = iter(loader) @@ -645,6 +686,20 @@ def run_mini_model( # not supports_bfloat16(), reason="bfloat16 not supported on this GPU" # ), # ), + ("remote_mini_deepseek_v2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5), + pytest.param( + "remote_mini_deepseek_v2", + 32, + 1e-4, + torch.bfloat16, + 1e-3, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + ), ], ) def test_mini_model( diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py index 811cd74cc..5a35b8c31 100644 --- a/test/transformers/test_monkey_patch.py +++ b/test/transformers/test_monkey_patch.py @@ -1,4 +1,5 @@ import inspect +import sys from inspect import signature from unittest.mock import MagicMock @@ -9,6 +10,7 @@ import torch import transformers +from transformers import AutoConfig from transformers import AutoModelForCausalLM from transformers import PretrainedConfig from transformers import PreTrainedModel @@ -47,6 +49,7 @@ def is_qwen2_vl_available(): def test_import_from_root(): try: from liger_kernel.transformers import AutoLigerKernelForCausalLM # noqa: F401 + from liger_kernel.transformers import apply_liger_kernel_to_deepseek_v2 # noqa: F401 from liger_kernel.transformers import apply_liger_kernel_to_gemma # noqa: F401 from liger_kernel.transformers import apply_liger_kernel_to_gemma2 # noqa: F401 from liger_kernel.transformers import apply_liger_kernel_to_llama # noqa: F401 @@ -695,3 +698,48 @@ def test_apply_liger_kernel_to_instance_for_phi3(): print(dummy_model_instance) except Exception as e: pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}") + + +def test_apply_liger_kernel_to_deepseek_v2(): + config = AutoConfig.from_pretrained("deepseek-ai/DeepSeek-Coder-V2-Lite-Base", trust_remote_code=True) + + config.torch_dtype = torch.bfloat16 + config.rms_norm_eps = 1e-5 + config.hidden_size = 32 + config.intermediate_size = 64 + config.hidden_act = "silu" + config.num_hidden_layers = 2 + + dummy_model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + modeling_mod_name = dummy_model.__class__.__module__ + + with patch.dict(sys.modules, {modeling_mod_name: Mock()}): + # Check that model instance variables are not yet patched with Liger modules + assert inspect.getsource(dummy_model.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) + for i, layer in enumerate(dummy_model.model.layers): + if i == 0: + assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + else: + for expert in layer.mlp.experts: + assert inspect.getsource(expert.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + + # Test applying kernels to the model instance + _apply_liger_kernel_to_instance(model=dummy_model) + + # Check that the model's instance variables were correctly patched with Liger modules + assert inspect.getsource(dummy_model.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) + for i, layer in enumerate(dummy_model.model.layers): + if i == 0: + assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + else: + for expert in layer.mlp.experts: + assert inspect.getsource(expert.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.self_attn.kv_a_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + + try: + print(dummy_model) + except Exception as e: + pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}") diff --git a/test/utils.py b/test/utils.py index 31294cc09..22bf61c58 100644 --- a/test/utils.py +++ b/test/utils.py @@ -134,6 +134,15 @@ class MiniModelConfig: mini_model_config: PretrainedConfig +@dataclass +class RemoteMiniModelConfig: + remote_model_path: str + liger_kernel_patch_func: callable + liger_kernel_patch_revert_func: callable + mini_model_config: Dict + remote_model_module: str = None + + def simple_collate_fn(data: List[Dict[str, Any]]): """A basic collate function to use for DataLoader""" @@ -379,6 +388,23 @@ def revert_liger_kernel_to_phi3(model_config: MiniModelConfig): print("Liger kernel patches have been reverted.") +def revert_liger_kernel_to_deepseek_v2(model_config: RemoteMiniModelConfig, remote_model_module: str): + """ + Revert all Liger kernel patches applied to deepseek v2 model. + """ + import sys + + if remote_model_module: + parent_mod_path = remote_model_module.rpartition(".")[0] + + if parent_mod_path not in sys.modules: + __import__(parent_mod_path) + + importlib.reload(sys.modules[remote_model_module]) + + print("Liger kernel patches have been reverted.") + + class HFAlignmentLoss: def __init__( self,