-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdecoder_wrapper.py
42 lines (33 loc) · 1.44 KB
/
decoder_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from transformers import AutoModelForCausalLM, PreTrainedTokenizerFast
from src.model.configurations.base_model import BaseModel
from src.utils import Batch, BatchTest
class DecoderWrapper(BaseModel):
    """GPT-2 wrapper for the commit message completion task.

    Delegates training forward passes and autoregressive generation to a
    pretrained causal LM loaded via HuggingFace `transformers`.

    Args:
        tokenizer: tokenizer for target sequences (messages)
        decoder_name_or_path: name or path for pretrained GPT-2 checkpoint
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerFast,
        decoder_name_or_path: str,
        **kwargs,
    ):
        super().__init__()
        self._tokenizer = tokenizer
        self.model = AutoModelForCausalLM.from_pretrained(decoder_name_or_path)
        # The tokenizer may define extra special tokens, so the embedding
        # matrix must be resized to match the (possibly grown) vocabulary.
        self.model.resize_token_embeddings(len(self._tokenizer))  # type: ignore[arg-type]

    def forward(self, batch: Batch):
        """Teacher-forced forward pass; passing labels makes the model compute the LM loss."""
        model_inputs = {
            "input_ids": batch.decoder_input_ids,
            "attention_mask": batch.decoder_attention_mask,
            "labels": batch.labels,
        }
        return self.model(**model_inputs)

    def generate(self, batch: BatchTest, **generation_kwargs):
        """Autoregressively generate completions for the given batch.

        Extra keyword arguments are forwarded to `transformers` `generate`
        (e.g. beam size, max length).
        """
        model_inputs = {
            "input_ids": batch.decoder_input_ids,
            "attention_mask": batch.decoder_attention_mask,
        }
        return self.model.generate(**model_inputs, **generation_kwargs)

    def num_parameters(self, exclude_embeddings: bool):
        """Return the parameter count of the underlying model."""
        return self.model.num_parameters(exclude_embeddings=exclude_embeddings)

    def save_pretrained(self, path: str) -> None:
        """Serialize the underlying model checkpoint to `path`."""
        self.model.save_pretrained(path)