|
2 | 2 | import torch |
3 | 3 | import torch.distributed |
4 | 4 |
|
5 | | -from typing import List, Optional, Type |
| 5 | +from typing import List, Optional, Type, Tuple |
6 | 6 |
|
7 | 7 | from accelerate import init_empty_weights |
8 | 8 | from safetensors import safe_open |
@@ -145,6 +145,20 @@ def decode(self, generated_ids: List[int]) -> str: |
145 | 145 | generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False |
146 | 146 | ) |
147 | 147 |
|
| 148 | + def forward( |
| 149 | + self, input_ids, attention_mask, position_ids, past_key_values: Optional = None |
| 150 | + ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: |
| 151 | + """Overwrite forward to ignore position_ids""" |
| 152 | + |
| 153 | + # Model Forward |
| 154 | + outputs = self.model.forward( |
| 155 | + input_ids=input_ids, |
| 156 | + attention_mask=attention_mask, |
| 157 | + past_key_values=past_key_values, |
| 158 | + use_cache=True, |
| 159 | + ) |
| 160 | + return outputs.logits, outputs.past_key_values |
| 161 | + |
148 | 162 |
|
149 | 163 | class GalacticaSharded(Galactica): |
150 | 164 | def __init__( |
@@ -322,7 +336,6 @@ def forward( |
322 | 336 | outputs = self.model.forward( |
323 | 337 | input_ids=input_ids, |
324 | 338 | attention_mask=attention_mask, |
325 | | - position_ids=position_ids, |
326 | 339 | past_key_values=past_key_values, |
327 | 340 | use_cache=True, |
328 | 341 | ) |
|
0 commit comments