
Commit edc6cc0

fix correctness issue

1 parent: 0457df5

2 files changed, +47 -40 lines changed


lib/gc/Transforms/FlashAttentionConversion.cpp

Lines changed: 30 additions & 23 deletions
@@ -162,9 +162,12 @@ struct MHAToFlashAttention
         rewriter.create<linalg::FillOp>(loc, minusInf, maxSlice).getResult(0);
     Value sumSliceFilled =
         rewriter.create<linalg::FillOp>(loc, zero, sumSlice).getResult(0);
+    Value collapsedOSliceFilled =
+        rewriter.create<linalg::FillOp>(loc, zero, collapsedOSlice)
+            .getResult(0);
     // create the innermost for loop for columnBlock
     SmallVector<Value> innermostDestinationTensors{
-        collapsedOSlice, maxSliceFilled, sumSliceFilled};
+        collapsedOSliceFilled, maxSliceFilled, sumSliceFilled};
     auto columnBlockLoop = rewriter.create<scf::ForOp>(
         loc,
         getValueOrCreateConstantIndexOp(
@@ -241,9 +244,9 @@ struct MHAToFlashAttention
                     ValueRange args) {
                   Value constant = nestedBuilder.create<arith::ConstantOp>(
                       loc, nestedBuilder.getFloatAttr(dtype, rsqrtHead));
-                  Value added = nestedBuilder.create<arith::MulFOp>(
+                  Value scaled = nestedBuilder.create<arith::MulFOp>(
                       loc, args[0], constant);
-                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
+                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, scaled);
                 })
             .getResult(0);
     Value add = rewriter
@@ -338,22 +341,32 @@ struct MHAToFlashAttention
                                    ValueRange{PSlice, collapsedVSlice},
                                    ValueRange{matmulVOutFilled})
             .getResult(0);
-    Value expMaxDiffRecip =
-        rewriter
-            .create<linalg::ReciprocalOp>(loc, reducedShapeOut.getType(),
-                                          ValueRange{expMaxDiff},
-                                          ValueRange{reducedShapeOut})
-            .getResult(0);
-    Value expMaxDiffRecipBroadcasted =
+    Value expMaxDiffBroadcasted =
         rewriter
-            .create<linalg::BroadcastOp>(loc, expMaxDiffRecip, VShapeOut,
+            .create<linalg::BroadcastOp>(loc, expMaxDiff, VShapeOut,
                                          SmallVector<int64_t>{1})
             .getResults()[0];
+    Value expMaxDiffBroadcastedEps =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, VShapeOut.getType(), ValueRange{expMaxDiffBroadcasted},
+                ValueRange{VShapeOut}, indexingMaps,
+                SmallVector<utils::IteratorType>(2,
+                                                 utils::IteratorType::parallel),
+                [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                    ValueRange args) {
+                  Value eps = nestedBuilder.create<arith::ConstantOp>(
+                      loc, nestedBuilder.getFloatAttr(dtype, 1e-9));
+                  Value added =
+                      nestedBuilder.create<arith::AddFOp>(loc, args[0], eps);
+                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
+                })
+            .getResult(0);
     Value rescaledOSlice =
         rewriter
-            .create<linalg::MulOp>(
+            .create<linalg::DivOp>(
                 loc, VShapeOut.getType(),
-                ValueRange{prevOSlice, expMaxDiffRecipBroadcasted},
+                ValueRange{prevOSlice, expMaxDiffBroadcastedEps},
                 ValueRange{VShapeOut})
             .getResult(0);
     Value newOSlice =
@@ -372,25 +385,19 @@ struct MHAToFlashAttention
     sumSliceFinal = innermostLoopResults[2];
     Value sliceShapeOut =
         rewriter.create<tensor::EmptyOp>(loc, reducedShape, dtype);
-    Value sumSliceFinalRecip =
-        rewriter
-            .create<linalg::ReciprocalOp>(loc, sliceShapeOut.getType(),
-                                          ValueRange{sumSliceFinal},
-                                          ValueRange{sliceShapeOut})
-            .getResult(0);
     Value broadcastedSliceShapeOut =
         rewriter.create<tensor::EmptyOp>(loc, VShape, dtype);
-    Value sumSliceFinalRecipBroadcasted =
+    Value sumSliceFinalBroadcasted =
         rewriter
-            .create<linalg::BroadcastOp>(loc, sumSliceFinalRecip,
+            .create<linalg::BroadcastOp>(loc, sumSliceFinal,
                                          broadcastedSliceShapeOut,
                                          SmallVector<int64_t>{1})
             .getResults()[0];
     Value rescaledOSliceFinal =
         rewriter
-            .create<linalg::MulOp>(
+            .create<linalg::DivOp>(
                 loc, broadcastedSliceShapeOut.getType(),
-                ValueRange{sumSliceFinalRecipBroadcasted, OSliceFinal},
+                ValueRange{OSliceFinal, sumSliceFinalBroadcasted},
                 ValueRange{broadcastedSliceShapeOut})
             .getResult(0);
     SmallVector<OpFoldResult> outputOffsets;
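
For orientation, the hunks above all sit inside the streaming (online-softmax) recurrence that flash attention evaluates per row block, and they touch its two rescaling points: the per-iteration correction of the running output when the row maximum grows, and the final normalization by the row sum. A minimal sketch of that recurrence, in generic flash-attention notation rather than names from the pass (m is the running row maximum, l the running row sum, O the output accumulator, S_j the scaled and masked scores of column block j):

% standard online-softmax / flash-attention recurrence over column blocks j
\begin{aligned}
m_j &= \max\bigl(m_{j-1},\ \operatorname{rowmax}(S_j)\bigr), & P_j &= \exp(S_j - m_j),\\
l_j &= e^{\,m_{j-1}-m_j}\, l_{j-1} + \operatorname{rowsum}(P_j), & O_j &= e^{\,m_{j-1}-m_j}\, O_{j-1} + P_j V_j,\\
O_{\mathrm{final}} &= \operatorname{diag}(l_N)^{-1} O_N, && \text{with } m_0 = -\infty,\ l_0 = 0,\ O_0 = 0.
\end{aligned}

Read against that recurrence, the hunks zero-fill the O accumulator before the column-block loop (the O_0 = 0 initial state), turn both reciprocal-then-multiply rescalings into direct linalg::DivOp divisions, and add a 1e-9 epsilon to the per-iteration exp(.) divisor before dividing. The commit message does not say which of these was the observed correctness issue, so treat this mapping as a reading of the diff rather than a statement from the authors.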

test/gc/Transform/flashAttention.mlir

Lines changed: 17 additions & 17 deletions
@@ -1,31 +1,31 @@
 // RUN: gc-opt --split-input-file --flash-attention-conversion --gc-cpu-pipeline %s | gc-cpu-runner -e main -entry-point-result=void
+// | FileCheck --allow-empty
 
-func.func @flash_attention(%arg0: tensor<1x16x384x64xf32>, %arg1: tensor<1x16x384x64xf32>, %arg2: tensor<1x16x384x64xf32>, %arg3: tensor<1x16x384x384xf32>) -> tensor<1x16x384x64xf32> {
-  %0 = tensor.empty() : tensor<1x16x384x64xf32>
-  %1 = linalgx.scaled_dot_product_attention ins(%arg0, %arg1, %arg2, %arg3: tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x384xf32>) outs(%0 : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>
-  return %1 : tensor<1x16x384x64xf32>
+func.func @flash_attention(%arg0: tensor<4x4x384x64xf32>, %arg1: tensor<4x4x384x64xf32>, %arg2: tensor<4x4x384x64xf32>, %arg3: tensor<4x4x384x384xf32>) -> tensor<4x4x384x64xf32> {
+  %0 = tensor.empty() : tensor<4x4x384x64xf32>
+  %1 = linalgx.scaled_dot_product_attention ins(%arg0, %arg1, %arg2, %arg3: tensor<4x4x384x64xf32>, tensor<4x4x384x64xf32>, tensor<4x4x384x64xf32>, tensor<4x4x384x384xf32>) outs(%0 : tensor<4x4x384x64xf32>) -> tensor<4x4x384x64xf32>
+  return %1 : tensor<4x4x384x64xf32>
 }
 
 func.func @main() {
-  %cst = arith.constant 1.000000e+00 : f32
+  %cst = arith.constant 4.000000e+00 : f32
 
-  %QKVShape = tensor.empty() : tensor<1x16x384x64xf32>
-  %maskShape = tensor.empty() : tensor<1x16x384x384xf32>
+  %QKVShape = tensor.empty() : tensor<4x4x384x64xf32>
+  %maskShape = tensor.empty() : tensor<4x4x384x384xf32>
 
-  %Q = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>
-  %K = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>
-  %V = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>
-  %mask = linalg.fill ins(%cst : f32) outs(%maskShape : tensor<1x16x384x384xf32>) -> tensor<1x16x384x384xf32>
+  %Q = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<4x4x384x64xf32>) -> tensor<4x4x384x64xf32>
+  %K = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<4x4x384x64xf32>) -> tensor<4x4x384x64xf32>
+  %V = linalg.fill ins(%cst : f32) outs(%QKVShape : tensor<4x4x384x64xf32>) -> tensor<4x4x384x64xf32>
+  %mask = linalg.fill ins(%cst : f32) outs(%maskShape : tensor<4x4x384x384xf32>) -> tensor<4x4x384x384xf32>
 
   %out = func.call @flash_attention(%Q, %K, %V, %mask) :
-    (tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x384xf32>)
-    -> (tensor<1x16x384x64xf32>)
+    (tensor<4x4x384x64xf32>, tensor<4x4x384x64xf32>, tensor<4x4x384x64xf32>, tensor<4x4x384x384xf32>)
+    -> (tensor<4x4x384x64xf32>)
 
   %idx = arith.constant 0 : index
-  %val = tensor.extract %out[%idx, %idx, %idx, %idx] : tensor<1x16x384x64xf32>
-  cpuruntime.printf "output[0, 0, 0]: %f\n" %val : f32
+  %val = tensor.extract %out[%idx, %idx, %idx, %idx] : tensor<4x4x384x64xf32>
+  cpuruntime.printf "output[0, 0, 0, 0]: %f\n" %val : f32
 
   return
 }
-// CHECK: output[0, 0, 0]: 1.0
-
+// CHECK: output[0, 0, 0]: 4.0
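
The expected value in the test follows from the constant inputs alone: each softmax row is a set of convex weights summing to 1, so when V is filled with a single constant every output element equals that constant, whatever Q, K, and the mask hold. A one-line check for the updated 4x4x384x64 case (w_ij below is generic notation for the attention weights, not a symbol from the test):

\mathrm{out}[b,h,i,d] \;=\; \sum_{j} w_{ij}\, V[b,h,j,d] \;=\; \Bigl(\sum_{j} w_{ij}\Bigr)\cdot 4.0 \;=\; 4.0,
\qquad w_{ij} = \operatorname{softmax}(\text{scores} + \text{mask})_{ij}.

The same argument produced the previous expectation of 1.0 for the 1x16x384x64 shapes; only the fill constant and the batch/head sizes changed.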
