mamba_config.py
from dataclasses import dataclass, field
from typing import Callable, List
import torch
import torch.nn.functional as F
from utils import init_method_normal, scaled_init_method_normal


@dataclass
class MambaConfig:
    """Configuration for Mamba-based models: architecture hyperparameters,
    attention/memory-block settings, initialization, mixed-precision and fusion
    flags, and activation recomputation options."""

    # model architecture
    base_model_type: str = "mamba"
    num_layers: int = 0
    hidden_size: int = 0
    state_size: int = 0
    mamba_headdim: int = 64
    mamba_ngroups: int = 1
    expansion_factor: int = 2
    conv_dimension: int = 0
    conv_bias: bool = True
    bias: bool = True
    use_fast_path: bool = True
    dt_rank: str = "auto"
    dt_min: float = 0.001
    dt_max: float = 0.1
    dt_init: str = "random"
    dt_scale: float = 1.0
    dt_init_floor: float = 1e-4
    use_module_layernorm: bool = True
    rms_norm: bool = False
    fused_add_norm: bool = False
    residual_in_fp32: bool = False
    hidden_dropout: float = 0.0
    ffn_hidden_size: int = None
    gated_linear_unit: bool = False
    kv_channels: int = None
    kv_mem_channels: int = None
    num_attention_heads: int = 0
    num_query_groups: int = None
    num_mem_query_groups: int = None
    attention_dropout: float = 0.1
    num_mem_heads: int = 0
    use_mem_mlp: bool = False
    window_size: int = None
    gateconv_expansion_factor: int = 2
    layer_mapping: List[str] = field(default_factory=lambda: [""])
    vocab_size: int = 0
    device: str = "cuda"
    use_mem_rope: bool = False
    use_shared_block_lora: bool = False
    lora_rank: int = 16
    num_mem_blocks: int = 1
    use_low_rank_mamba_proj: bool = False
    num_shared_mamba_proj: int = 1
    mamba_lora_rank: int = 1
    use_shared_attention_lora: bool = False
    rope_theta: int = 10000
    fp32_residual_connection: bool = False
    layernorm_epsilon: float = 1e-5
    layernorm_zero_centered_gamma: bool = False
    add_bias_linear: bool = False
    activation_func: Callable = F.gelu
    num_moe_experts: int = None
    # initialization
    init_method: Callable = None
    output_layer_init_method: Callable = None
    init_method_std: float = 0.02
    # mixed-precision
    apply_query_key_layer_scaling: bool = True
    attention_softmax_in_fp32: bool = True
    # kernel fusions
    bias_gelu_fusion: bool = False
    masked_softmax_fusion: bool = False
    persist_layer_norm: bool = False
    bias_dropout_fusion: bool = False
    # activation recomputation
    recompute_granularity: str = None
    recompute_method: str = None
    recompute_num_layers: int = None
    distribute_saved_activations: bool = None
    def __post_init__(self):
        # Query-key layer scaling requires the attention softmax in fp32.
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        # Default the MLP hidden size to the conventional 4x expansion.
        if self.ffn_hidden_size is None:
            self.ffn_hidden_size = 4 * self.hidden_size
        # Derive per-head channel counts when they are not set explicitly.
        if self.kv_channels is None and self.num_attention_heads is not None:
            self.kv_channels = self.hidden_size // self.num_attention_heads
        if self.kv_mem_channels is None and self.num_mem_heads > 0:
            self.kv_mem_channels = self.hidden_size // self.num_mem_heads
        # Without grouped-query attention, each head is its own query group.
        if self.num_query_groups is None:
            self.num_query_groups = self.num_attention_heads
        if self.num_mem_query_groups is None:
            self.num_mem_query_groups = self.num_mem_heads
        # The fused bias-GELU kernel assumes biased linear layers and a GELU activation.
        if self.bias_gelu_fusion:
            if not self.add_bias_linear:
                raise ValueError(
                    "When bias_gelu_fusion is True, add_bias_linear must also be True."
                )
            if self.activation_func != F.gelu:
                raise ValueError("When bias_gelu_fusion is True, activation_func must be F.gelu.")
        # Fall back to normal initialization scaled by init_method_std.
        if self.init_method is None:
            self.init_method = init_method_normal(self.init_method_std)
        if self.output_layer_init_method is None:
            self.output_layer_init_method = scaled_init_method_normal(
                self.init_method_std, self.num_layers
            )
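# -----------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the upstream file): shows how
# MambaConfig might be instantiated and which dependent fields __post_init__
# fills in. The concrete values below are hypothetical and chosen only to
# exercise the derivation logic; a real model would set them from its own
# hyperparameters.
if __name__ == "__main__":
    cfg = MambaConfig(
        num_layers=24,
        hidden_size=1024,
        state_size=16,
        conv_dimension=4,
        num_attention_heads=16,
        vocab_size=50280,
    )
    # Derived in __post_init__ from hidden_size and num_attention_heads.
    print(cfg.ffn_hidden_size)   # 4 * hidden_size -> 4096
    print(cfg.kv_channels)       # hidden_size // num_attention_heads -> 64
    print(cfg.num_query_groups)  # defaults to num_attention_heads -> 16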