-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathconfigs.py
65 lines (48 loc) · 1.88 KB
/
configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from dataclasses import dataclass
from .base_configs import (
BaseDecoderConfig,
BaseEncoderDecoderConfig,
BaseRACEConfig,
BaseSeq2SeqConfig,
)
@dataclass
class DistilGPT2Config(BaseDecoderConfig):
    """Default settings for a decoder-only setup built on ``"distilgpt2"``.

    Both tokenizers and the decoder are pointed at the same ``"distilgpt2"``
    identifier. Declaration order can affect the positional order of the
    generated ``__init__``, so keep the field order stable.
    """

    # Tokenizer for the diff text (name or path).
    diff_tokenizer_name_or_path: str = "distilgpt2"
    # Tokenizer for the message text; identical to the diff tokenizer here.
    msg_tokenizer_name_or_path: str = "distilgpt2"
    # Maximum context lengths, presumably in tokens — TODO confirm. For a
    # decoder-only config the "encoder" length presumably bounds the input
    # (diff) part of the sequence; verify against the consuming code.
    encoder_context_max_len: int = 512
    decoder_context_max_len: int = 512
    # Identifier (name or path) the decoder is created from.
    decoder_name_or_path: str = "distilgpt2"
@dataclass
class RandomTransformerConfig(BaseEncoderDecoderConfig):
    """Settings for a small encoder-decoder described by model types and depths.

    No pretrained checkpoint identifier is given. Instead, the encoder and decoder
    are each specified by a model-type string and a layer count. The "Random" in
    the name suggests the weights are randomly initialized.
    NOTE(review): confirm this in the model-building code.
    """

    # Both tokenizers share one location. It looks like a local path to a
    # byte-level tokenizer rather than a hub identifier — TODO confirm.
    diff_tokenizer_name_or_path: str = "raw_data/multilang/byte_level"
    msg_tokenizer_name_or_path: str = "raw_data/multilang/byte_level"
    # Maximum context lengths, presumably in tokens. Note that the decoder
    # limit (256) is half the encoder limit (512).
    encoder_context_max_len: int = 512
    decoder_context_max_len: int = 256
    # Encoder: 2 layers of the "roberta" model type.
    num_layers_encoder: int = 2
    encoder_model_type: str = "roberta"
    # Decoder: 2 layers of the "gpt2" model type.
    num_layers_decoder: int = 2
    decoder_model_type: str = "gpt2"
    # Both flags are off by default. They presumably control weight tying
    # between encoder/decoder and of word embeddings — confirm where consumed.
    tie_encoder_decoder: bool = False
    tie_word_embeddings: bool = False
@dataclass
class CodeT5Config(BaseSeq2SeqConfig):
    """Settings for a sequence-to-sequence model based on ``"Salesforce/codet5-base"``.

    The model and both tokenizers all use the same identifier. Declaration order
    can affect the positional order of the generated ``__init__``, so keep the
    field order stable.
    """

    # Tokenizers for the diff and message texts (name or path).
    diff_tokenizer_name_or_path: str = "Salesforce/codet5-base"
    msg_tokenizer_name_or_path: str = "Salesforce/codet5-base"
    # Maximum encoder/decoder sequence lengths, presumably in tokens — TODO confirm.
    encoder_context_max_len: int = 512
    decoder_context_max_len: int = 512
    # Identifier (name or path) the seq2seq model is created from.
    name_or_path: str = "Salesforce/codet5-base"
@dataclass
class CodeReviewerConfig(BaseSeq2SeqConfig):
    """Settings for a sequence-to-sequence model based on ``"microsoft/codereviewer"``.

    The model and both tokenizers all use the same identifier. This config also
    selects a named preprocessor configuration.
    """

    # Name of the preprocessing variant to use. It presumably selects a
    # CodeReviewer-specific input format — confirm in the preprocessor code.
    preprocessor_configuration: str = "codereviewer"
    # Tokenizers for the diff and message texts (name or path).
    diff_tokenizer_name_or_path: str = "microsoft/codereviewer"
    msg_tokenizer_name_or_path: str = "microsoft/codereviewer"
    # Maximum encoder/decoder sequence lengths, presumably in tokens — TODO confirm.
    encoder_context_max_len: int = 512
    decoder_context_max_len: int = 512
    # Identifier (name or path) the seq2seq model is created from.
    name_or_path: str = "microsoft/codereviewer"
@dataclass
class RACEConfig(BaseRACEConfig):
    """Settings for the ``BaseRACEConfig`` model family using ``"Salesforce/codet5-base"``.

    The field values match ``CodeT5Config``; only the base class differs. Any
    RACE-specific fields come from ``BaseRACEConfig``, which is defined elsewhere.
    """

    # Tokenizers for the diff and message texts (name or path).
    diff_tokenizer_name_or_path: str = "Salesforce/codet5-base"
    msg_tokenizer_name_or_path: str = "Salesforce/codet5-base"
    # Maximum encoder/decoder sequence lengths, presumably in tokens — TODO confirm.
    encoder_context_max_len: int = 512
    decoder_context_max_len: int = 512
    # Identifier (name or path) the underlying model is created from.
    name_or_path: str = "Salesforce/codet5-base"