
Commit cca75c5

sherlockwu authored and copybara-github committed
Add configurables for norm/rope/activation/scale/residual connection.
PiperOrigin-RevId: 648971168
1 parent 7e4b204 commit cca75c5

9 files changed: +85 -34 lines changed

backprop/backward-inl.h  (+1 -1)

@@ -355,7 +355,7 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
   static constexpr size_t kLayers = TConfig::kLayers;
   const float kEmbScaling = EmbeddingScaling<TConfig>();
   static_assert(!TConfig::kAbsolutePE);
-  static_assert(!TConfig::kPostNormScale);
+  static_assert(TConfig::kPostNorm == PostNormType::None);
   static_assert(TConfig::kKVHeads == 1);

   HWY_DASSERT(prompt.context_size > 0);
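Note: the assert changes shape because kPostNorm is now a scoped enum rather than a bool, and scoped enumerators do not convert to bool. A minimal standalone sketch of the pattern (DemoConfig is hypothetical, not a config from this commit):

// Sketch only; DemoConfig is made up for illustration.
enum class PostNormType { None, Scale };

struct DemoConfig {
  static constexpr PostNormType kPostNorm = PostNormType::None;
};

// `static_assert(!DemoConfig::kPostNorm)` would no longer compile for a
// scoped enum; the "no post-norm" requirement becomes an explicit comparison.
static_assert(DemoConfig::kPostNorm == PostNormType::None,
              "backward pass does not support post-norm scaling");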

backprop/backward_scalar_test.cc  (+1 -1)

@@ -388,7 +388,7 @@ struct TestConfig : ConfigCapNoSSM {
       FixedLayerConfig<2>(LayerAttentionType::kGemma);
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = false;
+  static constexpr PostNormType kPostNorm = PostNormType::None;

   static constexpr int kKVHeads = 1;
   static constexpr int kGemmaLayers = kLayers;

backprop/backward_test.cc  (+1 -1)

@@ -193,7 +193,7 @@ struct TestConfig : public ConfigCapNoSSM {
       FixedLayerConfig<2>(LayerAttentionType::kGemma);
   static constexpr int kLayers = kLayerConfig.size();
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = false;
+  static constexpr PostNormType kPostNorm = PostNormType::None;

   static constexpr int kKVHeads = 1;
   static constexpr int kGemmaLayers = kLayers;

backprop/forward-inl.h  (+1 -1)

@@ -233,7 +233,7 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
   static constexpr size_t kLayers = TConfig::kLayers;
   const float kEmbScaling = EmbeddingScaling<TConfig>();
   static_assert(!TConfig::kAbsolutePE);
-  static_assert(!TConfig::kPostNormScale);
+  static_assert(TConfig::kPostNorm == PostNormType::None);
   static_assert(TConfig::kKVHeads == 1);

   HWY_DASSERT(context_size > 0);

gemma/configs.h  (+40 -6)

@@ -52,6 +52,32 @@ enum class LayerAttentionType {
   kGriffinRecurrentBlock,
 };

+// Post attention and ffw normalization type.
+enum class PostNormType {
+  None,
+  Scale,
+};
+
+// Post qk projection operation type.
+enum class PostQKType {
+  Rope,
+};
+
+// FFW activation function.
+enum class ActivationType {
+  Gelu,
+};
+
+// Attention query scale.
+enum class QueryScaleType {
+  Sqrt,
+};
+
+// Residual connection type.
+enum class ResidualType {
+  Add,
+};
+
 template <size_t kNum>
 constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
     LayerAttentionType type) {
@@ -107,6 +133,11 @@ struct ConfigNoSSM {
   static constexpr bool kUseLocalAttention = false;
   static constexpr bool kInterleaveQKV = true;
   static constexpr int kNumTensorScales = 0;
+
+  static constexpr PostQKType kPostQK = PostQKType::Rope;
+  static constexpr ActivationType kActivation = ActivationType::Gelu;
+  static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt;
+  static constexpr ResidualType kResidual = ResidualType::Add;
 };

 struct ConfigNoCapNoSSM : ConfigNoSSM {
@@ -143,7 +174,7 @@ struct ConfigGemma27B : public ConfigCapNoSSM {
   static constexpr int kQKVDim = 128;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = true;
+  static constexpr PostNormType kPostNorm = PostNormType::Scale;
 };

 template <typename TWeight>
@@ -169,7 +200,7 @@ struct ConfigGemma9B : public ConfigCapNoSSM {
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = true;
+  static constexpr PostNormType kPostNorm = PostNormType::Scale;
 };

 template <typename TWeight>
@@ -191,7 +222,7 @@ struct ConfigGemma7B : public ConfigNoCapNoSSM {
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = false;
+  static constexpr PostNormType kPostNorm = PostNormType::None;
 };

 template <typename TWeight>
@@ -213,7 +244,7 @@ struct ConfigGemma2B : public ConfigNoCapNoSSM {
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = false;
+  static constexpr PostNormType kPostNorm = PostNormType::None;
 };

 template <typename TWeight>
@@ -235,7 +266,7 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
   static constexpr int kQKVDim = 16;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = false;
+  static constexpr PostNormType kPostNorm = PostNormType::None;

   static constexpr float kAttCap = 0.0f;
   // This is required for optimize_test to pass.
@@ -294,7 +325,7 @@ struct ConfigGriffin2B {
   static constexpr int kQKVDim = 256;  // query size == key size == value size
   static constexpr int kTopK = gcpp::kTopK;
   static constexpr bool kAbsolutePE = false;
-  static constexpr bool kPostNormScale = false;
+  static constexpr PostNormType kPostNorm = PostNormType::None;

   // No SoftCap.
   static constexpr float kAttCap = 0.0f;
@@ -308,6 +339,9 @@ struct ConfigGriffin2B {
   static constexpr bool kUseLocalAttention = true;
   static constexpr bool kInterleaveQKV = false;
   static constexpr int kNumTensorScales = 140;
+  static constexpr PostQKType kPostQK = PostQKType::Rope;
+  static constexpr QueryScaleType kQueryScale = QueryScaleType::Sqrt;
+  static constexpr ResidualType kResidual = ResidualType::Add;
 };

 }  // namespace gcpp
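Note: the new knobs default to the existing behavior (ConfigNoSSM picks Rope, Gelu, Sqrt query scaling, and additive residuals; ConfigGriffin2B declares its own set), so a model config only has to state its post-norm choice. A hypothetical config written on top of this commit could look like the sketch below; ConfigMyModel and its dimensions are made up, and fields not relevant to this change are omitted.

// Hypothetical example of a config using the new enums; only kPostNorm
// differs from the defaults inherited via ConfigCapNoSSM -> ConfigNoSSM.
template <typename TWeight>
struct ConfigMyModel : public ConfigCapNoSSM {
  using Weight = TWeight;  // assumed; how TWeight is consumed is outside this diff
  static constexpr int kModelDim = 2048;  // made-up dimensions
  static constexpr int kHeads = 8;
  static constexpr int kKVHeads = 4;
  static constexpr int kQKVDim = 256;
  static constexpr int kTopK = gcpp::kTopK;
  static constexpr bool kAbsolutePE = false;
  static constexpr PostNormType kPostNorm = PostNormType::Scale;
  // kPostQK, kActivation, kQueryScale, kResidual are inherited from
  // ConfigNoSSM: Rope, Gelu, Sqrt, Add.
};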

gemma/gemma.cc  (+28 -15)

@@ -295,8 +295,9 @@ HWY_NOINLINE void Attention(
   constexpr size_t kHeads = TConfig::kHeads;
   constexpr size_t kKVHeads = TConfig::kKVHeads;
   constexpr size_t kSeqLen = TConfig::kSeqLen;
-  GEMMA_CONSTEXPR_SQRT const float kQueryScale =
+  GEMMA_CONSTEXPR_SQRT float kQueryScale =
       1.0f / Sqrt(static_cast<float>(kQKVDim));
+
   constexpr bool kIsMHA = TActivations::kIsMHA;  // Multi-Head Attention
   const size_t batch_start = batch_and_query_start / num_queries;
   const size_t num_tokens_and_queries = num_tokens * num_queries;
@@ -350,7 +351,9 @@ HWY_NOINLINE void Attention(
           // Skip past the Q part of `q`, and copy KV to `kv`.
           memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
         }
-        Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+        if (TConfig::kPostQK == PostQKType::Rope) {
+          Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+        }
       });

   static_assert((kHeads % kKVHeads) == 0,
@@ -373,7 +376,10 @@ HWY_NOINLINE void Attention(
             activations.att.data() + head * kSeqLen
             + batch_and_query_idx * kHeads * kSeqLen;

-        Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+        if (TConfig::kPostQK == PostQKType::Rope) {
+          Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+        }
+
         MulByConst(kQueryScale, q, kQKVDim);

         // Compute Q dot K scores
@@ -465,10 +471,12 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
   namespace hn = hwy::HWY_NAMESPACE;
   using DF = hn::ScalableTag<float>;
   using VF = hn::Vec<DF>;
-  hn::Transform1(DF(), activations.C1.data(), kFFHiddenDim * num_tokens,
-                 activations.C2.data(), [](DF df, VF v, VF mul) HWY_ATTR {
-                   return hn::Mul(mul, Gelu(df, v));
-                 });
+  if (TConfig::kActivation == ActivationType::Gelu) {
+    hn::Transform1(DF(), activations.C1.data(), kFFHiddenDim * num_tokens,
+                   activations.C2.data(), [](DF df, VF v, VF mul) HWY_ATTR {
+                     return hn::Mul(mul, Gelu(df, v));
+                   });
+  }

   MatMul_4x4_Batch<kFFHiddenDim, kModelDim>(num_tokens, activations.C1.data(),
                                             layer_weights->linear_w.data(),
@@ -560,29 +568,34 @@ HWY_NOINLINE void TransformerLayer(
                                          layer_weights, kv_caches, pool);
     }
   }
-  if (TConfig::kPostNormScale) {
+
+  if (TConfig::kPostNorm == PostNormType::Scale) {
     RMSNormInplaceBatched<kBatchSize * kQueryBatchSize>(
         num_tokens_and_queries,
         layer_weights->post_attention_norm_scale.data(),
         activations.att_post2.data(), kModelDim);
   }
-  AddFromBatched<kBatchSize * kQueryBatchSize>(num_tokens_and_queries,
-                                               activations.att_post2.data(),
-                                               activations.x.data(), kModelDim);
+  if (TConfig::kResidual == ResidualType::Add) {
+    AddFromBatched<kBatchSize * kQueryBatchSize>(
+        num_tokens_and_queries, activations.att_post2.data(),
+        activations.x.data(), kModelDim);
+  }
   RMSNormBatched<kBatchSize * kQueryBatchSize>(
       num_tokens_and_queries, activations.x.data(),
       layer_weights->pre_ffw_norm_scale.data(),
       activations.bf_pre_ffw_rms_out.data(), kModelDim);
   FFW<TConfig, kBatchSize * kQueryBatchSize>(
       activations, num_tokens_and_queries, layer_weights, pool);
-  if (TConfig::kPostNormScale) {
+  if (TConfig::kPostNorm == PostNormType::Scale) {
     RMSNormInplaceBatched<kBatchSize * kQueryBatchSize>(
         num_tokens_and_queries, layer_weights->post_ffw_norm_scale.data(),
         activations.ffw_out.data(), kModelDim);
   }
-  AddFromBatched<kBatchSize * kQueryBatchSize>(
-      num_tokens_and_queries, activations.ffw_out.data(),
-      activations.x.data(), kModelDim);
+  if (TConfig::kResidual == ResidualType::Add) {
+    AddFromBatched<kBatchSize * kQueryBatchSize>(
+        num_tokens_and_queries, activations.ffw_out.data(),
+        activations.x.data(), kModelDim);
+  }
 }

 template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
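Note: every new branch tests a static constexpr member of TConfig, so the comparison is a compile-time constant, the untaken side is dead code the optimizer drops, and a config that keeps the defaults generates the same code as before. A self-contained sketch of the dispatch pattern follows; ApplyPostQK, DemoRope, DemoConfig, and the two-member PostQKType are illustrative only (the real enum currently has just Rope).

// Standalone sketch of compile-time dispatch on a config enum.
#include <cmath>
#include <cstddef>

enum class PostQKType { Rope, None };

struct DemoConfig {
  static constexpr PostQKType kPostQK = PostQKType::Rope;
};

// Toy rotary embedding over (x[i], x[i + dim/2]) pairs, for illustration only.
void DemoRope(float* x, size_t dim, size_t pos) {
  const size_t half = dim / 2;
  for (size_t i = 0; i < half; ++i) {
    const float freq = std::pow(10000.0f, -static_cast<float>(2 * i) / dim);
    const float c = std::cos(pos * freq), s = std::sin(pos * freq);
    const float x0 = x[i], x1 = x[i + half];
    x[i] = x0 * c - x1 * s;
    x[i + half] = x0 * s + x1 * c;
  }
}

template <class TConfig>
void ApplyPostQK(float* q, size_t dim, size_t pos) {
  // kPostQK is a compile-time constant; the compiler folds this branch, so a
  // config selecting None pays nothing at runtime. Unlike `if constexpr`,
  // both branches must still compile for every config.
  if (TConfig::kPostQK == PostQKType::Rope) {
    DemoRope(q, dim, pos);
  }
}

The same pattern covers the Gelu activation guard in FFW and the post-norm and residual guards in TransformerLayer above.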

gemma/weights.h  (+6 -5)

@@ -50,7 +50,7 @@ struct CompressedLayer {
   static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
   static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
   static constexpr bool kFFBiases = TConfig::kFFBiases;
-  static constexpr bool kPostNormScale = TConfig::kPostNormScale;
+  static constexpr PostNormType kPostNorm = TConfig::kPostNorm;
   static constexpr size_t kAOBiasDim =
       TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
   static constexpr size_t kGriffinDim =
@@ -86,9 +86,10 @@ struct CompressedLayer {
   // We don't yet have an RMSNorm that accepts all Weight.
   ArrayT<WeightF32OrBF16, kModelDim> pre_attention_norm_scale;
   ArrayT<WeightF32OrBF16, kModelDim> pre_ffw_norm_scale;
-  ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0>
+  ArrayT<WeightF32OrBF16, kPostNorm == PostNormType::Scale ? kModelDim : 0>
       post_attention_norm_scale;
-  ArrayT<WeightF32OrBF16, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
+  ArrayT<WeightF32OrBF16, kPostNorm == PostNormType::Scale ? kModelDim : 0>
+      post_ffw_norm_scale;

   ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
   ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
@@ -267,7 +268,7 @@ void ForEachTensor(RawWeightsPtr raw_weights,
     GEMMA_CALL_FUNC("gr_a", griffin.a);
   }
   GEMMA_CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
-  if (TConfig::kPostNormScale) {
+  if (TConfig::kPostNorm == PostNormType::Scale) {
     GEMMA_CALL_FUNC("post_att_ns", post_attention_norm_scale);
     GEMMA_CALL_FUNC("post_ff_ns", post_ffw_norm_scale);
   }
@@ -331,7 +332,7 @@ void ForEachTensor(RawWeightsPtr raw_weights,
   GEMMA_CALL_LAYER_FUNC ## N("gating_ein", gating_einsum_w);              \
   GEMMA_CALL_LAYER_FUNC ## N("linear_w", linear_w);                       \
   GEMMA_CALL_LAYER_FUNC ## N("pre_att_ns", pre_attention_norm_scale);     \
-  if (TConfig::kPostNormScale) {                                          \
+  if (TConfig::kPostNorm == PostNormType::Scale) {                        \
   GEMMA_CALL_LAYER_FUNC ## N("post_att_ns", post_attention_norm_scale);   \
   GEMMA_CALL_LAYER_FUNC ## N("post_ff_ns", post_ffw_norm_scale);          \
   } \
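Note: the post-norm scale tensors keep the conditional zero-size trick, now keyed on the enum, so a config with PostNormType::None still declares the members but gives them length 0. A minimal sketch of that idea, with plain std::array standing in for ArrayT; LayerSketch and NoPostNormCfg are made up for illustration.

// Minimal sketch of conditional member sizing.
#include <array>
#include <cstddef>

enum class PostNormType { None, Scale };

template <class TConfig>
struct LayerSketch {
  static constexpr size_t kModelDim = TConfig::kModelDim;
  static constexpr PostNormType kPostNorm = TConfig::kPostNorm;
  // Length collapses to 0 when post-norm is disabled, so those configs
  // carry (next to) no storage for the unused scales.
  std::array<float, kPostNorm == PostNormType::Scale ? kModelDim : 0>
      post_attention_norm_scale;
  std::array<float, kPostNorm == PostNormType::Scale ? kModelDim : 0>
      post_ffw_norm_scale;
};

struct NoPostNormCfg {
  static constexpr size_t kModelDim = 1024;  // made-up dimension
  static constexpr PostNormType kPostNorm = PostNormType::None;
};

static_assert(
    std::tuple_size<
        decltype(LayerSketch<NoPostNormCfg>::post_ffw_norm_scale)>::value == 0,
    "disabled post-norm yields zero-length arrays");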

gemma/weights_raw.h  (+6 -3)

@@ -25,6 +25,7 @@
 #include <random>

 #include "gemma/common.h"
+#include "gemma/configs.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -46,7 +47,7 @@ struct Layer {
   static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim;
   static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
   static constexpr bool kFFBiases = TConfig::kFFBiases;
-  static constexpr bool kPostNormScale = TConfig::kPostNormScale;
+  static constexpr PostNormType kPostNorm = TConfig::kPostNorm;
   static constexpr size_t kAOBiasDim =
       TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0;
   static constexpr size_t kGriffinDim =
@@ -78,8 +79,10 @@ struct Layer {
   std::array<T, kModelDim * kFFHiddenDim> linear_w;
   std::array<T, kModelDim> pre_attention_norm_scale;
   std::array<T, kModelDim> pre_ffw_norm_scale;
-  std::array<T, kPostNormScale ? kModelDim : 0> post_attention_norm_scale;
-  std::array<T, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
+  std::array<T, kPostNorm == PostNormType::Scale ? kModelDim : 0>
+      post_attention_norm_scale;
+  std::array<T, kPostNorm == PostNormType::Scale ? kModelDim : 0>
+      post_ffw_norm_scale;

   std::array<T, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
   std::array<T, kFFBiases ? kModelDim : 0> ffw_output_biases;

util/compress_weights.cc  (+1 -1)

@@ -159,7 +159,7 @@ struct LoadRawWeightsT {
     SCALE_WEIGHTS(linear_w);
     READ_WEIGHTS(pre_attention_norm_scale);
     READ_WEIGHTS(pre_ffw_norm_scale);
-    if (TConfig::kPostNormScale) {
+    if (TConfig::kPostNorm == PostNormType::Scale) {
       READ_WEIGHTS(post_attention_norm_scale);
       READ_WEIGHTS(post_ffw_norm_scale);
     }
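Note: the loader has to make the same compile-time decision as whatever produced the weights; if it read the post-norm scales for a PostNormType::None config (or skipped them for Scale), every later tensor offset would shift. A hedged sketch of that invariant, with a hypothetical helper in place of the READ_WEIGHTS macro and a raw std::istream standing in for the real reader:

// Hypothetical loader fragment; LoadPostNormScales is illustrative only.
#include <cstddef>
#include <istream>

enum class PostNormType { None, Scale };

template <class TConfig>
void LoadPostNormScales(std::istream& in, float* post_att_ns,
                        float* post_ffw_ns) {
  // Mirrors the writer: the tensors are only present in the stream when the
  // config enables post-norm scaling.
  if (TConfig::kPostNorm == PostNormType::Scale) {
    const std::streamsize bytes = TConfig::kModelDim * sizeof(float);
    in.read(reinterpret_cast<char*>(post_att_ns), bytes);
    in.read(reinterpret_cast<char*>(post_ffw_ns), bytes);
  }
}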
