@@ -295,8 +295,9 @@ HWY_NOINLINE void Attention(
   constexpr size_t kHeads = TConfig::kHeads;
   constexpr size_t kKVHeads = TConfig::kKVHeads;
   constexpr size_t kSeqLen = TConfig::kSeqLen;
-  GEMMA_CONSTEXPR_SQRT const float kQueryScale =
+  GEMMA_CONSTEXPR_SQRT float kQueryScale =
       1.0f / Sqrt(static_cast<float>(kQKVDim));
+
   constexpr bool kIsMHA = TActivations::kIsMHA;  // Multi-Head Attention
   const size_t batch_start = batch_and_query_start / num_queries;
   const size_t num_tokens_and_queries = num_tokens * num_queries;
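For context, the constant computed here is the standard attention scaling factor 1/sqrt(d). A minimal standalone sketch, assuming an illustrative kQKVDim of 256 (not taken from this diff):

```cpp
// Illustrative only: recomputes the query scale with std::sqrt at runtime;
// the diff uses the library's Sqrt() so the value can be constexpr where
// the compiler allows it.
#include <cmath>
#include <cstddef>
#include <cstdio>

int main() {
  constexpr std::size_t kQKVDim = 256;  // assumed per-head dimension
  const float kQueryScale = 1.0f / std::sqrt(static_cast<float>(kQKVDim));
  // Queries are multiplied by this factor before the Q.K dot products so the
  // logits' variance does not grow with the head dimension.
  std::printf("kQueryScale = %g\n", kQueryScale);  // prints 0.0625
  return 0;
}
```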
@@ -350,7 +351,9 @@ HWY_NOINLINE void Attention(
         // Skip past the Q part of `q`, and copy KV to `kv`.
         memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float));
       }
-      Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+      if (TConfig::kPostQK == PostQKType::Rope) {
+        Rope(kv, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+      }
     });

   static_assert((kHeads % kKVHeads) == 0,
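This hunk replaces the unconditional RoPE on the copied KV with a switch on TConfig::kPostQK. A minimal sketch of the gating pattern, assuming a config enum of this shape (only PostQKType::Rope appears in the diff; the other enumerator, the config struct, and the Rope stand-in are hypothetical):

```cpp
// Sketch only: the real PostQKType and model configs live in the project's
// config headers; this hypothetical ExampleConfig just shows the gating.
enum class PostQKType { Rope, None };  // "None" is an assumed alternative

struct ExampleConfig {
  static constexpr PostQKType kPostQK = PostQKType::Rope;
  static constexpr bool kUseHalfRope = false;
  static constexpr int kQKVDim = 256;  // assumed per-head dimension
};

// Stand-in for the library's Rope(): rotates element pairs by angles that
// depend on `pos`; body omitted because only the call site matters here.
void Rope(float* /*x*/, int /*dim*/, int /*pos*/) {}

template <class TConfig>
void PostQK(float* kv, int pos) {
  // Same pattern as the diff: apply RoPE only when the config asks for it.
  if (TConfig::kPostQK == PostQKType::Rope) {
    Rope(kv, TConfig::kUseHalfRope ? TConfig::kQKVDim / 2 : TConfig::kQKVDim,
         pos);
  }
}
```

A plain `if` on a compile-time constant is enough here; the untaken branch is dead-code-eliminated, so no `if constexpr` is needed.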
@@ -373,7 +376,10 @@ HWY_NOINLINE void Attention(
             activations.att.data() + head * kSeqLen
             + batch_and_query_idx * kHeads * kSeqLen;

-        Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+        if (TConfig::kPostQK == PostQKType::Rope) {
+          Rope(q, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos);
+        }
+
         MulByConst(kQueryScale, q, kQKVDim);

         // Compute Q dot K scores
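After the optional RoPE, the query is scaled in place by kQueryScale and then dotted against each cached key. A scalar sketch of that scoring step, with a simplified key layout (the real code uses the SIMD Dot() helper and the project's KV-cache indexing, which this does not reproduce):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar illustration of "Compute Q dot K scores"; names and layout here are
// simplified assumptions, not the gemma.cpp cache layout.
void QDotKScores(const float* q, std::size_t qkv_dim,
                 const std::vector<std::vector<float>>& keys,  // one key per position
                 float* head_att) {
  const float query_scale = 1.0f / std::sqrt(static_cast<float>(qkv_dim));
  for (std::size_t pos = 0; pos < keys.size(); ++pos) {
    float score = 0.0f;
    for (std::size_t d = 0; d < qkv_dim; ++d) {
      // The diff scales q once up front via MulByConst; folding the scale
      // into the product here is equivalent.
      score += (q[d] * query_scale) * keys[pos][d];
    }
    head_att[pos] = score;  // later turned into weights by a softmax
  }
}
```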
@@ -465,10 +471,12 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
   namespace hn = hwy::HWY_NAMESPACE;
   using DF = hn::ScalableTag<float>;
   using VF = hn::Vec<DF>;
-  hn::Transform1(DF(), activations.C1.data(), kFFHiddenDim * num_tokens,
-                 activations.C2.data(), [](DF df, VF v, VF mul) HWY_ATTR {
-                   return hn::Mul(mul, Gelu(df, v));
-                 });
+  if (TConfig::kActivation == ActivationType::Gelu) {
+    hn::Transform1(DF(), activations.C1.data(), kFFHiddenDim * num_tokens,
+                   activations.C2.data(), [](DF df, VF v, VF mul) HWY_ATTR {
+                     return hn::Mul(mul, Gelu(df, v));
+                   });
+  }

   MatMul_4x4_Batch<kFFHiddenDim, kModelDim>(num_tokens, activations.C1.data(),
                                             layer_weights->linear_w.data(),
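The Transform1 call computes C1[i] = Gelu(C1[i]) * C2[i] over the hidden activations, i.e. a GELU-gated feed-forward unit, and is now guarded by TConfig::kActivation. A scalar sketch of the same element-wise operation; the tanh-based GELU below is the common approximation and may not match the library's Gelu() bit for bit:

```cpp
#include <cmath>
#include <cstddef>

// Common tanh approximation of GELU (an assumption; the ops library defines
// its own Gelu()).
float GeluApprox(float x) {
  const float kSqrt2OverPi = 0.7978845608f;
  return 0.5f * x *
         (1.0f + std::tanh(kSqrt2OverPi * (x + 0.044715f * x * x * x)));
}

// Scalar equivalent of the hn::Transform1 call: C1 <- Gelu(C1) * C2,
// applied element-wise across kFFHiddenDim * num_tokens values.
void GatedGelu(float* c1, const float* c2, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    c1[i] = GeluApprox(c1[i]) * c2[i];
  }
}
```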
@@ -560,29 +568,34 @@ HWY_NOINLINE void TransformerLayer(
                 layer_weights, kv_caches, pool);
     }
   }
-  if (TConfig::kPostNormScale) {
+
+  if (TConfig::kPostNorm == PostNormType::Scale) {
     RMSNormInplaceBatched<kBatchSize * kQueryBatchSize>(
         num_tokens_and_queries,
         layer_weights->post_attention_norm_scale.data(),
         activations.att_post2.data(), kModelDim);
   }
-  AddFromBatched<kBatchSize * kQueryBatchSize>(num_tokens_and_queries,
-                                               activations.att_post2.data(),
-                                               activations.x.data(), kModelDim);
+  if (TConfig::kResidual == ResidualType::Add) {
+    AddFromBatched<kBatchSize * kQueryBatchSize>(
+        num_tokens_and_queries, activations.att_post2.data(),
+        activations.x.data(), kModelDim);
+  }
   RMSNormBatched<kBatchSize * kQueryBatchSize>(
       num_tokens_and_queries, activations.x.data(),
       layer_weights->pre_ffw_norm_scale.data(),
       activations.bf_pre_ffw_rms_out.data(), kModelDim);
   FFW<TConfig, kBatchSize * kQueryBatchSize>(
       activations, num_tokens_and_queries, layer_weights, pool);
-  if (TConfig::kPostNormScale) {
+  if (TConfig::kPostNorm == PostNormType::Scale) {
     RMSNormInplaceBatched<kBatchSize * kQueryBatchSize>(
         num_tokens_and_queries, layer_weights->post_ffw_norm_scale.data(),
         activations.ffw_out.data(), kModelDim);
   }
-  AddFromBatched<kBatchSize * kQueryBatchSize>(
-      num_tokens_and_queries, activations.ffw_out.data(),
-      activations.x.data(), kModelDim);
+  if (TConfig::kResidual == ResidualType::Add) {
+    AddFromBatched<kBatchSize * kQueryBatchSize>(
+        num_tokens_and_queries, activations.ffw_out.data(),
+        activations.x.data(), kModelDim);
+  }
 }

 template <class TConfig, size_t kBatchSize, size_t kQueryBatchSize>
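TransformerLayer now branches on PostNormType and ResidualType instead of a single boolean, which lets a config skip the post-norm, the residual add, or both. A minimal sketch of that gating, assuming enum shapes like the ones below (only Scale and Add are known from the diff; the None enumerators, the config struct, and the scalar RMSNorm stand-in are hypothetical):

```cpp
#include <cmath>
#include <cstddef>

// Only PostNormType::Scale and ResidualType::Add appear in the diff; the
// rest of these definitions are illustrative assumptions.
enum class PostNormType { None, Scale };
enum class ResidualType { None, Add };

struct ExampleLayerConfig {
  static constexpr PostNormType kPostNorm = PostNormType::Scale;
  static constexpr ResidualType kResidual = ResidualType::Add;
};

// Scalar stand-in for the batched RMSNorm helpers used in the diff.
void RmsNormInplace(const float* scale, float* x, std::size_t dim) {
  float ss = 0.0f;
  for (std::size_t i = 0; i < dim; ++i) ss += x[i] * x[i];
  const float inv_rms = 1.0f / std::sqrt(ss / static_cast<float>(dim) + 1e-6f);
  for (std::size_t i = 0; i < dim; ++i) {
    x[i] = x[i] * inv_rms * (1.0f + scale[i]);
  }
}

template <class TConfig>
void FinishBlock(float* block_out, float* x, const float* post_norm_scale,
                 std::size_t dim) {
  // Optional post-normalization of the block output (attention or FFW).
  if (TConfig::kPostNorm == PostNormType::Scale) {
    RmsNormInplace(post_norm_scale, block_out, dim);
  }
  // Optional residual connection back into the transformer stream x.
  if (TConfig::kResidual == ResidualType::Add) {
    for (std::size_t i = 0; i < dim; ++i) x[i] += block_out[i];
  }
}
```

As in the diff, both checks compare compile-time constants, so disabled branches cost nothing at runtime.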