Skip to content

Commit 78063c0

Browse files
fix(server): remove position_ids from galactica forward (#82)
closes #80
1 parent 17bc841 commit 78063c0

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

server/text_generation/models/galactica.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33
import torch.distributed
44

5-
from typing import List, Optional, Type
5+
from typing import List, Optional, Type, Tuple
66

77
from accelerate import init_empty_weights
88
from safetensors import safe_open
@@ -145,6 +145,20 @@ def decode(self, generated_ids: List[int]) -> str:
145145
generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
146146
)
147147

148+
def forward(
149+
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
150+
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
151+
"""Overwrite forward to ignore position_ids"""
152+
153+
# Model Forward
154+
outputs = self.model.forward(
155+
input_ids=input_ids,
156+
attention_mask=attention_mask,
157+
past_key_values=past_key_values,
158+
use_cache=True,
159+
)
160+
return outputs.logits, outputs.past_key_values
161+
148162

149163
class GalacticaSharded(Galactica):
150164
def __init__(
@@ -322,7 +336,6 @@ def forward(
322336
outputs = self.model.forward(
323337
input_ids=input_ids,
324338
attention_mask=attention_mask,
325-
position_ids=position_ids,
326339
past_key_values=past_key_values,
327340
use_cache=True,
328341
)

0 commit comments

Comments
 (0)