
Commit e6de786

Tcc0403, lancerts, and shimizust authored
fix(phi3): update monkey patch for Phi3ForCausalLM (#837)
## Summary

Liger backward support for transformers has been updated to 4.49.0, so we can remove most of the deprecated functions that were added for `SUPPORTED_TRANSFORMER_VERSION=4.46.1` (we should probably bump it for the next release too). This PR merges most of the patching logic into a single path that works with both 4.49.0 and the latest version (4.54.0).

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Tcc0403 <[email protected]>
Co-authored-by: Shao Tang <[email protected]>
Co-authored-by: Steven Shimizu <[email protected]>
1 parent 10b6cde commit e6de786
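For context, a hedged usage sketch of the patch entry point this commit simplifies. The `fused_linear_cross_entropy`, `cross_entropy`, and `model` arguments mirror the branches of `apply_liger_kernel_to_phi3` visible in the monkey_patch.py diff below; the import path and the checkpoint id (reused from the old docstring example) are assumptions for illustration, not something this commit changes.

```python
# Hedged usage sketch, not part of this commit.
from transformers import AutoTokenizer, Phi3ForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_phi3

# Class-level patch: Phi3ForCausalLM.forward is replaced with phi3_lce_forward
# before any model is instantiated.
apply_liger_kernel_to_phi3(fused_linear_cross_entropy=True, cross_entropy=False)

model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")

# Instance-level patch: passing an existing model binds the new forward to that
# instance only (the MethodType branch in the monkey_patch.py diff).
apply_liger_kernel_to_phi3(model=model, fused_linear_cross_entropy=True, cross_entropy=False)
```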

File tree

2 files changed: +14 -163 lines changed

src/liger_kernel/transformers/model/phi3.py

Lines changed: 8 additions & 145 deletions
````diff
@@ -5,131 +5,12 @@
 
 import torch
 
-from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.utils.deprecation import deprecate_kwarg
 
-from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
-def lce_forward_deprecated(
-    self,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    labels: Optional[torch.LongTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-    cache_position: Optional[torch.LongTensor] = None,
-    skip_logits: Optional[bool] = None,
-) -> Union[Tuple, CausalLMOutputWithPast]:
-    r"""
-    Copy paste phi3 forward from transfomers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
-
-
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-    Returns:
-
-    Example:
-
-    ```python
-    >>> from transformers import AutoTokenizer, Phi3ForCausalLM
-
-    >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
-    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
-
-    >>> prompt = "This is an example script ."
-    >>> inputs = tokenizer(prompt, return_tensors="pt")
-
-    >>> # Generate
-    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
-    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
-    ```"""
-
-    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-    output_hidden_states = (
-        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-    )
-    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-    outputs = self.model(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-    )
-
-    hidden_states = outputs[0]
-
-    loss = None
-    logits = None
-
-    if skip_logits and labels is None:
-        raise ValueError("skip_logits is True, but labels is None")
-
-    if skip_logits is None:
-        # By default, if in training mode, don't materialize logits
-        skip_logits = self.training and labels is not None
-
-    if skip_logits:
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-    else:
-        logits = self.lm_head(hidden_states)
-
-        loss = None
-        if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
-
-    if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
-
-    return CausalLMOutputWithPast(
-        loss=loss,
-        logits=logits,
-        past_key_values=outputs.past_key_values,
-        hidden_states=outputs.hidden_states,
-        attentions=outputs.attentions,
-    )
-
-
-@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -148,36 +29,21 @@ def lce_forward(
     **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-        logits_to_keep (`int` or `torch.Tensor`, *optional*):
-            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
-            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
-            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
-            This is useful when using packed tensor format (single dimension for batch and sequence length).
-
-    Returns:
-
     Example:
 
     ```python
     >>> from transformers import AutoTokenizer, Phi3ForCausalLM
 
-    >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
-    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
+    >>> model = Phi3ForCausalLM.from_pretrained("meta-phi3/Phi3-2-7b-hf")
+    >>> tokenizer = AutoTokenizer.from_pretrained("meta-phi3/Phi3-2-7b-hf")
 
-    >>> prompt = "This is an example script ."
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
     >>> inputs = tokenizer(prompt, return_tensors="pt")
 
     >>> # Generate
     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
 
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -186,21 +52,18 @@ def lce_forward(
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-    outputs = self.model(
+    outputs: BaseModelOutputWithPast = self.model(
         input_ids=input_ids,
         attention_mask=attention_mask,
         position_ids=position_ids,
         past_key_values=past_key_values,
         inputs_embeds=inputs_embeds,
         use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
+        cache_position=cache_position,
         **kwargs,
     )
 
-    hidden_states = outputs[0]
+    hidden_states = outputs.last_hidden_state
     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
````
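The memory win behind both the removed `lce_forward_deprecated` and the retained `lce_forward` is the same: when a loss is being computed, avoid materializing the full `(tokens, vocab)` logits tensor. The sketch below is plain PyTorch, not the fused Triton kernel behind `LigerFusedLinearCrossEntropyLoss` / `LigerForCausalLMLoss`; names and shapes are illustrative only.

```python
# Conceptual sketch only: the baseline allocates an (N, V) logits tensor,
# while the chunked variant never holds more than (chunk, V) at once.
import torch
import torch.nn.functional as F


def baseline_loss(hidden, lm_head_weight, labels):
    # hidden: (N, H), lm_head_weight: (V, H), labels: (N,) with -100 for ignored tokens
    logits = hidden @ lm_head_weight.T  # materializes the full (N, V) tensor
    return F.cross_entropy(logits.float(), labels, ignore_index=-100)


def chunked_loss(hidden, lm_head_weight, labels, chunk=1024):
    # Same math, but logits exist only one chunk at a time.
    total = hidden.new_zeros((), dtype=torch.float32)
    n_valid = 0
    for h, y in zip(hidden.split(chunk), labels.split(chunk)):
        logits = (h @ lm_head_weight.T).float()  # at most (chunk, V)
        total = total + F.cross_entropy(logits, y, ignore_index=-100, reduction="sum")
        n_valid += int((y != -100).sum())
    return total / max(n_valid, 1)


if __name__ == "__main__":
    N, H, V = 64, 32, 100
    hidden, weight = torch.randn(N, H), torch.randn(V, H)
    labels = torch.randint(0, V, (N,))
    assert torch.allclose(
        baseline_loss(hidden, weight, labels),
        chunked_loss(hidden, weight, labels, chunk=16),
        atol=1e-5,
    )
```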

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 6 additions & 18 deletions
```diff
@@ -26,7 +26,6 @@
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
-from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
 from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
 from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
@@ -1677,25 +1676,14 @@ def apply_liger_kernel_to_phi3(
     if swiglu:
         modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
     if cross_entropy:
-        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            from transformers.loss.loss_utils import nn
+        from transformers.loss.loss_utils import nn
 
-            nn.functional.cross_entropy = liger_cross_entropy
-        else:
-            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
+        nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
-        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
-            if model is not None:
-                model.forward = MethodType(phi3_lce_forward, model)
-            else:
-                modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
-        else:  # if version < 4.46.1
-            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-            if model is not None:
-                model.forward = MethodType(phi3_lce_forward_deprecated, model)
-            else:
-                modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward_deprecated
+        if model is not None:
+            model.forward = MethodType(phi3_lce_forward, model)
+        else:
+            modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
 
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
```
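The surviving branch relies on the standard instance-versus-class monkey-patching distinction. A small, self-contained illustration of that pattern follows; the classes are stand-ins, not transformers or Liger code.

```python
# Minimal illustration of the two patching modes kept in apply_liger_kernel_to_phi3.
from types import MethodType


class Model:
    def forward(self, x):
        return f"original({x})"


def patched_forward(self, x):
    return f"patched({x})"


a, b = Model(), Model()

# Instance-level patch (the `model is not None` path): only `a` changes.
a.forward = MethodType(patched_forward, a)
assert a.forward(1) == "patched(1)"
assert b.forward(1) == "original(1)"

# Class-level patch (the `else` path): every instance, existing or future, changes.
Model.forward = patched_forward
assert b.forward(1) == "patched(1)"
```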
