 
 import torch
 
-from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.utils.deprecation import deprecate_kwarg
 
-from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
-def lce_forward_deprecated(
-    self,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    labels: Optional[torch.LongTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-    cache_position: Optional[torch.LongTensor] = None,
-    skip_logits: Optional[bool] = None,
-) -> Union[Tuple, CausalLMOutputWithPast]:
-    r"""
-    Copy-paste of the Phi3 forward from transformers v4.44.2, with torch cross entropy replaced by Liger fused linear cross entropy.
-
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-    Returns:
-
-    Example:
-
-    ```python
-    >>> from transformers import AutoTokenizer, Phi3ForCausalLM
-
-    >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
-    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
-
-    >>> prompt = "This is an example script ."
-    >>> inputs = tokenizer(prompt, return_tensors="pt")
-
-    >>> # Generate
-    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
-    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
-    ```"""
-
-    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-    output_hidden_states = (
-        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-    )
-    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-    outputs = self.model(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-    )
-
-    hidden_states = outputs[0]
-
-    loss = None
-    logits = None
-
-    if skip_logits and labels is None:
-        raise ValueError("skip_logits is True, but labels is None")
-
-    if skip_logits is None:
-        # By default, if in training mode, don't materialize logits
-        skip_logits = self.training and labels is not None
-
-    if skip_logits:
-        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-
-        # flatten tokens
-        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
-        shift_labels = shift_labels.view(-1)
-
-        lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
-    else:
-        logits = self.lm_head(hidden_states)
-
-        loss = None
-        if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            shift_logits = shift_logits.view(-1, self.config.vocab_size)
-            shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
-            shift_labels = shift_labels.to(shift_logits.device)
-            loss = loss_fct(shift_logits, shift_labels)
-
-    if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
-
-    return CausalLMOutputWithPast(
-        loss=loss,
-        logits=logits,
-        past_key_values=outputs.past_key_values,
-        hidden_states=outputs.hidden_states,
-        attentions=outputs.attentions,
-    )
-
-
-@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -148,36 +29,21 @@ def lce_forward(
     **kwargs,
 ) -> Union[Tuple, CausalLMOutputWithPast]:
     r"""
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-        logits_to_keep (`int` or `torch.Tensor`, *optional*):
-            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
-            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
-            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
-            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
-            This is useful when using packed tensor format (single dimension for batch and sequence length).
-
-    Returns:
-
     Example:
 
     ```python
     >>> from transformers import AutoTokenizer, Phi3ForCausalLM
 
-    >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
-    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
+    >>> model = Phi3ForCausalLM.from_pretrained("meta-phi3/Phi3-2-7b-hf")
+    >>> tokenizer = AutoTokenizer.from_pretrained("meta-phi3/Phi3-2-7b-hf")
 
-    >>> prompt = "This is an example script ."
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
     >>> inputs = tokenizer(prompt, return_tensors="pt")
 
     >>> # Generate
     >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
     ```"""
 
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -186,21 +52,18 @@ def lce_forward(
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-    outputs = self.model(
+    outputs: BaseModelOutputWithPast = self.model(
         input_ids=input_ids,
         attention_mask=attention_mask,
         position_ids=position_ids,
         past_key_values=past_key_values,
         inputs_embeds=inputs_embeds,
         use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
+        cache_position=cache_position,
         **kwargs,
     )
 
-    hidden_states = outputs[0]
+    hidden_states = outputs.last_hidden_state
     # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
     slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
     kept_hidden_states = hidden_states[:, slice_indices, :]
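Both sides of this diff make the same substitution: instead of materializing the full `(tokens, vocab_size)` logits tensor and handing it to `torch.nn.CrossEntropyLoss`, the Liger kernel receives the `lm_head` weight and the shifted hidden states directly, so the logits never exist in memory. Below is a minimal sketch of that idea, assuming a CUDA device (Liger kernels are Triton-based); the toy linear head and tensor shapes are illustrative, and only the `LigerFusedLinearCrossEntropyLoss(weight, hidden, labels)` call pattern comes from the code above.

```python
import torch
from torch.nn import CrossEntropyLoss

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

# Toy stand-ins for a real causal LM (shapes are made up for illustration).
batch, seq_len, hidden_size, vocab_size = 2, 16, 64, 1000
hidden_states = torch.randn(batch, seq_len, hidden_size, device="cuda")
labels = torch.randint(0, vocab_size, (batch, seq_len), device="cuda")
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False, device="cuda")

# Shift so that tokens < n predict n, then flatten -- same as the forward above.
shift_hidden = hidden_states[..., :-1, :].contiguous().view(-1, hidden_size)
shift_labels = labels[..., 1:].contiguous().view(-1)

# Standard path: materialize logits of shape (batch * (seq_len - 1), vocab_size).
logits = lm_head(shift_hidden)
loss_ref = CrossEntropyLoss()(logits.float(), shift_labels)

# Fused path: pass the lm_head weight and the hidden states; both paths compute
# the same mean cross entropy, but the fused kernel skips the logits tensor.
loss_fused = LigerFusedLinearCrossEntropyLoss()(lm_head.weight, shift_hidden, shift_labels)
```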