
Commit 7e4b204

sherlockwu authored and copybara-github committed
Add sliding window attention for Gemma 2.
PiperOrigin-RevId: 648778253
1 parent 09a7e75 commit 7e4b204

2 files changed: +34 -3 lines changed

gemma/configs.h

+32 -2
@@ -62,6 +62,16 @@ constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
   return config;
 }
 
+template <size_t kNum>
+constexpr std::array<size_t, kNum> FixedAttentionWindowSizes(
+    size_t window_size) {
+  std::array<size_t, kNum> window_size_configs = {};
+  for (size_t& l : window_size_configs) {
+    l = window_size;
+  }
+  return window_size_configs;
+}
+
 template <size_t kNumLayers>
 constexpr size_t NumLayersOfTypeBefore(
     const std::array<LayerAttentionType, kNumLayers>& layers,
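The new FixedAttentionWindowSizes helper mirrors FixedLayerConfig: it fills one window size into every layer slot at compile time, which the non-Gemma-2 configs below use to keep full-context attention. A minimal standalone sketch of what it evaluates to (the main() harness is illustrative, not part of the commit):

    #include <array>
    #include <cstddef>
    #include <cstdio>

    // Copied from the diff above: one window size for all kNum layers.
    template <size_t kNum>
    constexpr std::array<size_t, kNum> FixedAttentionWindowSizes(
        size_t window_size) {
      std::array<size_t, kNum> window_size_configs = {};
      for (size_t& l : window_size_configs) {
        l = window_size;
      }
      return window_size_configs;
    }

    int main() {
      // Evaluated at compile time; every layer gets the full 8192-token window.
      constexpr auto kWindows = FixedAttentionWindowSizes<4>(8192);
      static_assert(kWindows[0] == 8192 && kWindows[3] == 8192, "uniform");
      for (size_t w : kWindows) std::printf("%zu\n", w);
      return 0;
    }
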
@@ -114,10 +124,16 @@ template <typename TWeight>
 struct ConfigGemma27B : public ConfigCapNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
-  static constexpr int kSeqLen = gcpp::kSeqLen;
+  static constexpr int kSeqLen = 8192;
   static constexpr int kVocabSize = 256000;
   static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
       FixedLayerConfig<46>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 46> kAttentionWindowSizes = {
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen};
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kGemmaLayers = kLayers;
   static constexpr int kModelDim = 4608;
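The 27B table interleaves a 4096-token local window (even layers) with the full kSeqLen window (odd layers), i.e. Gemma 2's alternating local/global attention; the 9B config below repeats the same pattern over 42 layers. A hypothetical constexpr generator, not part of this commit, that would produce the same table:

    #include <array>
    #include <cstddef>

    // Hypothetical helper (not in the commit): alternate a local sliding
    // window with the full sequence length, starting local at layer 0.
    template <size_t kNum>
    constexpr std::array<size_t, kNum> AlternatingWindowSizes(size_t local,
                                                              size_t global) {
      std::array<size_t, kNum> windows = {};
      for (size_t i = 0; i < kNum; ++i) {
        windows[i] = (i % 2 == 0) ? local : global;
      }
      return windows;
    }

    // Matches the hand-written 27B table above: 4096, kSeqLen, 4096, ...
    constexpr size_t kSeqLen = 8192;
    constexpr auto kWindows27B = AlternatingWindowSizes<46>(4096, kSeqLen);
    static_assert(kWindows27B[0] == 4096 && kWindows27B[1] == kSeqLen,
                  "even layers local, odd layers global");
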
@@ -134,10 +150,16 @@ template <typename TWeight>
 struct ConfigGemma9B : public ConfigCapNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
-  static constexpr int kSeqLen = gcpp::kSeqLen;
+  static constexpr int kSeqLen = 8192;
   static constexpr int kVocabSize = 256000;
   static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
       FixedLayerConfig<42>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 42> kAttentionWindowSizes = {
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen, 4096, kSeqLen,
+      4096, kSeqLen};
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kGemmaLayers = kLayers;
   static constexpr int kModelDim = 3584;
@@ -158,6 +180,8 @@ struct ConfigGemma7B : public ConfigNoCapNoSSM {
   static constexpr int kVocabSize = 256000;
   static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
       FixedLayerConfig<28>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 28> kAttentionWindowSizes =
+      FixedAttentionWindowSizes<28>(kSeqLen);
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kGemmaLayers = kLayers;
   static constexpr int kModelDim = 3072;
@@ -178,6 +202,8 @@ struct ConfigGemma2B : public ConfigNoCapNoSSM {
   static constexpr int kVocabSize = 256000;
   static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
       FixedLayerConfig<18>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 18> kAttentionWindowSizes =
+      FixedAttentionWindowSizes<18>(kSeqLen);
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kGemmaLayers = kLayers;
   static constexpr int kModelDim = 2048;
@@ -198,6 +224,8 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
   static constexpr int kVocabSize = 64;
   static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
       FixedLayerConfig<3>(LayerAttentionType::kGemma);
+  static constexpr std::array<size_t, 3> kAttentionWindowSizes =
+      FixedAttentionWindowSizes<3>(kSeqLen);
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kGemmaLayers = kLayers;
   static constexpr int kModelDim = 128;
@@ -250,6 +278,8 @@ struct ConfigGriffin2B {
       LayerAttentionType::kGriffinRecurrentBlock,
       LayerAttentionType::kGriffinRecurrentBlock,
   };
+  static constexpr std::array<size_t, 26> kAttentionWindowSizes =
+      FixedAttentionWindowSizes<26>(kSeqLen);
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr int kGemmaLayers =
       NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
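Griffin's recurrent layers never consult this table, but giving every config a kAttentionWindowSizes member (here all-kSeqLen, i.e. no windowing) lets Attention() in gemma.cc index it unconditionally. A hypothetical compile-time check of that invariant, with a stand-in config, might look like:

    #include <array>
    #include <cstddef>

    // Hypothetical check (not in the commit): a config is well-formed only if
    // it declares one attention window per layer, so that indexing
    // kAttentionWindowSizes[layer] is in bounds for every layer < kLayers.
    template <typename TConfig>
    constexpr bool WindowTableCoversAllLayers() {
      return TConfig::kAttentionWindowSizes.size() ==
             static_cast<size_t>(TConfig::kLayers);
    }

    // Stand-in for a real config such as ConfigGriffin2B.
    struct DummyConfig {
      static constexpr int kLayers = 26;
      static constexpr std::array<size_t, 26> kAttentionWindowSizes = {};
    };
    static_assert(WindowTableCoversAllLayers<DummyConfig>(),
                  "one window size per layer");
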

gemma/gemma.cc

+2 -1
@@ -377,7 +377,8 @@ HWY_NOINLINE void Attention(
   MulByConst(kQueryScale, q, kQKVDim);
 
   // Compute Q dot K scores
-  const size_t start_pos = pos - std::min(kSeqLen - 1, pos);
+  const size_t start_pos =
+      pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos);
   for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) {
     const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize);
     const size_t kv_offset =
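The replaced line used the fixed kSeqLen window for every layer; the new one clamps each layer to its own window, so a layer with window W attends to positions [pos - (W - 1), pos], and the std::min keeps the subtraction from underflowing before the window has filled. A small sketch of the arithmetic in isolation (the StartPos wrapper is illustrative, not part of the commit):

    #include <algorithm>
    #include <cassert>
    #include <cstddef>

    // Same expression as in Attention(), lifted out for illustration.
    constexpr size_t StartPos(size_t pos, size_t window_size) {
      return pos - std::min(window_size - 1, pos);
    }

    int main() {
      // Early in the sequence the window is not yet full: attend from 0.
      assert(StartPos(10, 4096) == 0);
      // Once pos >= window, a 4096-token window covers [905, 5000].
      assert(StartPos(5000, 4096) == 905);
      // A full-context layer (window == kSeqLen == 8192) still sees everything.
      assert(StartPos(5000, 8192) == 0);
      return 0;
    }
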
