@@ -162,9 +162,12 @@ struct MHAToFlashAttention
162
162
rewriter.create <linalg::FillOp>(loc, minusInf, maxSlice).getResult (0 );
163
163
Value sumSliceFilled =
164
164
rewriter.create <linalg::FillOp>(loc, zero, sumSlice).getResult (0 );
165
+ Value collapsedOSliceFilled =
166
+ rewriter.create <linalg::FillOp>(loc, zero, collapsedOSlice)
167
+ .getResult (0 );
165
168
// create the innermost for loop for columnBlock
166
169
SmallVector<Value> innermostDestinationTensors{
167
- collapsedOSlice , maxSliceFilled, sumSliceFilled};
170
+ collapsedOSliceFilled , maxSliceFilled, sumSliceFilled};
168
171
auto columnBlockLoop = rewriter.create <scf::ForOp>(
169
172
loc,
170
173
getValueOrCreateConstantIndexOp (
@@ -241,9 +244,9 @@ struct MHAToFlashAttention
241
244
ValueRange args) {
242
245
Value constant = nestedBuilder.create <arith::ConstantOp>(
243
246
loc, nestedBuilder.getFloatAttr (dtype, rsqrtHead));
244
- Value added = nestedBuilder.create <arith::MulFOp>(
247
+ Value scaled = nestedBuilder.create <arith::MulFOp>(
245
248
loc, args[0 ], constant);
246
- nestedBuilder.create <linalg::YieldOp>(nestedLoc, added );
249
+ nestedBuilder.create <linalg::YieldOp>(nestedLoc, scaled );
247
250
})
248
251
.getResult (0 );
249
252
Value add = rewriter
@@ -338,22 +341,32 @@ struct MHAToFlashAttention
338
341
ValueRange{PSlice, collapsedVSlice},
339
342
ValueRange{matmulVOutFilled})
340
343
.getResult (0 );
341
- Value expMaxDiffRecip =
342
- rewriter
343
- .create <linalg::ReciprocalOp>(loc, reducedShapeOut.getType (),
344
- ValueRange{expMaxDiff},
345
- ValueRange{reducedShapeOut})
346
- .getResult (0 );
347
- Value expMaxDiffRecipBroadcasted =
344
+ Value expMaxDiffBroadcasted =
348
345
rewriter
349
- .create <linalg::BroadcastOp>(loc, expMaxDiffRecip , VShapeOut,
346
+ .create <linalg::BroadcastOp>(loc, expMaxDiff , VShapeOut,
350
347
SmallVector<int64_t >{1 })
351
348
.getResults ()[0 ];
349
+ Value expMaxDiffBroadcastedEps =
350
+ rewriter
351
+ .create <linalg::GenericOp>(
352
+ loc, VShapeOut.getType (), ValueRange{expMaxDiffBroadcasted},
353
+ ValueRange{VShapeOut}, indexingMaps,
354
+ SmallVector<utils::IteratorType>(2 ,
355
+ utils::IteratorType::parallel),
356
+ [&](OpBuilder &nestedBuilder, Location nestedLoc,
357
+ ValueRange args) {
358
+ Value eps = nestedBuilder.create <arith::ConstantOp>(
359
+ loc, nestedBuilder.getFloatAttr (dtype, 1e-9 ));
360
+ Value added =
361
+ nestedBuilder.create <arith::AddFOp>(loc, args[0 ], eps);
362
+ nestedBuilder.create <linalg::YieldOp>(nestedLoc, added);
363
+ })
364
+ .getResult (0 );
352
365
Value rescaledOSlice =
353
366
rewriter
354
- .create <linalg::MulOp >(
367
+ .create <linalg::DivOp >(
355
368
loc, VShapeOut.getType (),
356
- ValueRange{prevOSlice, expMaxDiffRecipBroadcasted },
369
+ ValueRange{prevOSlice, expMaxDiffBroadcastedEps },
357
370
ValueRange{VShapeOut})
358
371
.getResult (0 );
359
372
Value newOSlice =
@@ -372,25 +385,19 @@ struct MHAToFlashAttention
372
385
sumSliceFinal = innermostLoopResults[2 ];
373
386
Value sliceShapeOut =
374
387
rewriter.create <tensor::EmptyOp>(loc, reducedShape, dtype);
375
- Value sumSliceFinalRecip =
376
- rewriter
377
- .create <linalg::ReciprocalOp>(loc, sliceShapeOut.getType (),
378
- ValueRange{sumSliceFinal},
379
- ValueRange{sliceShapeOut})
380
- .getResult (0 );
381
388
Value broadcastedSliceShapeOut =
382
389
rewriter.create <tensor::EmptyOp>(loc, VShape, dtype);
383
- Value sumSliceFinalRecipBroadcasted =
390
+ Value sumSliceFinalBroadcasted =
384
391
rewriter
385
- .create <linalg::BroadcastOp>(loc, sumSliceFinalRecip ,
392
+ .create <linalg::BroadcastOp>(loc, sumSliceFinal ,
386
393
broadcastedSliceShapeOut,
387
394
SmallVector<int64_t >{1 })
388
395
.getResults ()[0 ];
389
396
Value rescaledOSliceFinal =
390
397
rewriter
391
- .create <linalg::MulOp >(
398
+ .create <linalg::DivOp >(
392
399
loc, broadcastedSliceShapeOut.getType (),
393
- ValueRange{sumSliceFinalRecipBroadcasted, OSliceFinal },
400
+ ValueRange{OSliceFinal, sumSliceFinalBroadcasted },
394
401
ValueRange{broadcastedSliceShapeOut})
395
402
.getResult (0 );
396
403
SmallVector<OpFoldResult> outputOffsets;
0 commit comments