Skip to content

Commit 9d7b57a

Browse files
committed
init sdpa op and flash attention pass
1 parent d69856f commit 9d7b57a

File tree

5 files changed

+145
-1
lines changed

5 files changed

+145
-1
lines changed

include/gc/Dialect/Linalgx/LinalgxStructuredOps.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
2323
include "mlir/Interfaces/SideEffectInterfaces.td"
2424
include "mlir/IR/OpAsmInterface.td"
2525

26+
// Base class for ops declared directly in the Linalgx dialect (as opposed to
// ops generated through the Linalg structured-op base classes below).
class Linalgx_Op<string mnemonic, list<Trait> traits = []> :
    Op<LinalgxDialect, mnemonic, traits>;
28+
2629
// Base Tablegen class for Linalg ops.
2730
// Linalg ops that correspond to library calls operate on ShapedType as their
2831
// first operands. These may be optionally followed by non-view operands
@@ -312,4 +315,22 @@ def Linalgx_MultiBatchMatmulOp : LinalgxStructuredBase_Op<"multi_batch_matmul",
312315
}];
313316
}
314317

318+
// Aggregate attention op. It carries AggregatedOpInterface so it can be
// decomposed into primitive linalg ops via decomposeOperation().
def Linalgx_ScaledDotProductAttentionOp
    : Linalgx_Op<"scaled_dot_product_attention",
        [AttrSizedOperandSegments,
         DeclareOpInterfaceMethods<AggregatedOpInterface,
                                   ["decomposeOperation"]>]> {
  let summary = "Attention structure.";
  let description = [{
    Scaled dot product attention over the inputs Q (query), K (key),
    V (value) and attention_mask:
      Output = SoftMax(Q @ K.transpose(-2, -1) + attention_mask) @ V.
  }];
  let arguments = (ins
    Variadic<TensorOrMemref>:$inputs,
    Variadic<TensorOrMemref>:$outputs);
  let results = (outs Variadic<TensorOrMemref>:$results);
  let regions = (region AnyRegion:$region);

  let hasVerifier = 1;
}
335+
315336
#endif // LINALGX_STRUCTURED_OPS

include/gc/Transforms/Passes.td

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,18 @@ def DeepTileContractionNamedOp
3434
];
3535
}
3636

37-
def GCCPUPipeline : Pass<"gc-cpu-pipeline"> {
37+
// Function-scoped pass that rewrites multi-head attention into a
// flash-attention implementation.
def FlashAttentionConversion
    : Pass<"flash-attention-conversion", "func::FuncOp"> {
  let summary = "Flash Attention Conversion";
  let description =
      [{The pass converts MHA to a flash-attention implementation.}];
  let dependentDialects = [
    "func::FuncDialect", "linalg::LinalgDialect", "scf::SCFDialect",
    "tensor::TensorDialect"
  ];
}
47+
48+
def GCCPUPipeline: Pass<"gc-cpu-pipeline"> {
3849
let summary = "All-in-one pipeline for GC for CPU";
3950
let dependentDialects = [
4051
"onednn_graph::OneDNNGraphDialect", "tensor::TensorDialect",

lib/gc/Dialect/Linalgx/LinalgxOps.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "gc/Dialect/Linalgx/LinalgxOps.h"
1010
#include "gc/Dialect/Linalgx/LinalgxDialect.h"
1111
#include "mlir/IR/OpImplementation.h"
12+
#include <utility>
1213

1314
//===----------------------------------------------------------------------===//
1415
// Builder helper from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -613,6 +614,58 @@ void MultiBatchMatmulOp::getEffects(
613614
getDpsInits());
614615
}
615616

617+
//===----------------------------------------------------------------------===//
618+
// ScaledDotProductAttentionOp
619+
//===----------------------------------------------------------------------===//
620+
621+
LogicalResult ScaledDotProductAttentionOp::verify() { return success(); }
622+
623+
/// Given an N-dimensional tensor x, this method converts
624+
/// softmax(x) to the following sequence of operations:
625+
///
626+
/// 1. transpose ins[1]
627+
/// 2. matmul ins[0] @ 1
628+
///
629+
FailureOr<SmallVector<Value>>
630+
ScaledDotProductAttentionOp::decomposeOperation(OpBuilder &b) {
631+
OpBuilder::InsertionGuard guard(b);
632+
b.setInsertionPoint(*this);
633+
Location loc = getLoc();
634+
Value query = getInputs()[0], key = getInputs()[1], value = getInputs()[2],
635+
mask = getInputs()[3];
636+
auto dtype = cast<RankedTensorType>(query.getType()).getElementType();
637+
auto shape = cast<RankedTensorType>(query.getType()).getShape();
638+
639+
SmallVector<int64_t> permutation{0, 1, 3, 2};
640+
SmallVector<int64_t> transposeShape{shape[0], shape[1], shape[3], shape[2]};
641+
auto transposeOut = b.create<tensor::EmptyOp>(loc, transposeShape, dtype);
642+
auto transpose = b.create<linalg::TransposeOp>(
643+
/*location=*/loc,
644+
/*inputs=*/key,
645+
/*outputs=*/transposeOut,
646+
/*permutation=*/permutation);
647+
648+
SmallVector<int64_t> matmulQKShape{shape[0], shape[1], shape[2], shape[2]};
649+
auto matmulQKOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
650+
auto matmulQK = b.create<linalgx::MultiBatchMatmulOp>(
651+
/*location=*/loc, matmulQKOut.getResult().getType(),
652+
/*inputs=*/ValueRange{query, transpose->getResult(0)},
653+
/*outputs=*/ValueRange{matmulQKOut.getResult()});
654+
655+
auto addOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
656+
auto add = b.create<linalg::AddOp>(
657+
/*location=*/loc, addOut.getResult().getType(),
658+
/*inputs=*/ValueRange{matmulQK->getResult(0), mask},
659+
/*outputs=*/ValueRange{addOut.getResult()});
660+
661+
auto matmulVOut = b.create<tensor::EmptyOp>(loc, shape, dtype);
662+
auto matmulV = b.create<linalgx::MultiBatchMatmulOp>(
663+
/*location=*/loc, matmulVOut.getResult().getType(),
664+
/*inputs=*/ValueRange{add->getResult(0), value},
665+
/*outputs=*/ValueRange{matmulVOut.getResult()});
666+
return SmallVector<Value>{matmulV.getResults()[0]};
667+
}
668+
616669
/////// Operations corresponding to library calls defined with Tablegen ////////
617670

618671
#define GET_OP_CLASSES

lib/gc/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ add_mlir_library(GCPasses
1111
OneDNNGraphToLinalg.cpp
1212
Pipeline.cpp
1313
DeepTileContractionNamedOp.cpp
14+
FlashAttentionConversion.cpp
1415
Tiling.cpp
1516

1617
ADDITIONAL_HEADER_DIRS
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
//===-- FlashAttentionConversion.cpp ----------------------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
10+
#include "./Tiling.hpp"
11+
#include "gc/Dialect/Arith/Utils/EasyBuild.h"
12+
#include "gc/Dialect/Linalgx/LinalgxOps.h"
13+
#include "gc/IR/EasyBuild.h"
14+
#include "gc/IR/EasyBuildSCF.h"
15+
#include "mlir/AsmParser/AsmParser.h"
16+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
17+
#include "mlir/Dialect/Func/IR/FuncOps.h"
18+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
19+
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
20+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
21+
#include "mlir/Dialect/Linalg/Utils/Utils.h"
22+
#include "mlir/Dialect/SCF/IR/SCF.h"
23+
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
24+
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
25+
#include "mlir/IR/Builders.h"
26+
#include "mlir/IR/Operation.h"
27+
#include "mlir/IR/PatternMatch.h"
28+
#include "mlir/IR/Region.h"
29+
#include "mlir/IR/Visitors.h"
30+
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
31+
#include "mlir/Interfaces/TilingInterface.h"
32+
#include "mlir/Parser/Parser.h"
33+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
34+
#include <iostream>
35+
36+
#include "gc/Transforms/Passes.h"
37+
38+
#include <llvm/Support/Debug.h>
39+
40+
#include <memory>
41+
42+
namespace mlir {
namespace gc {
#define GEN_PASS_DEF_FLASHATTENTIONCONVERSION
#include "gc/Transforms/Passes.h.inc"

namespace {
/// Pass intended to convert MHA (scaled dot product attention) into a
/// flash-attention style implementation.
/// NOTE(review): currently an empty stub — runOnOperation() performs no
/// transformation yet.
struct FlashAttentionConversion
    : public impl::FlashAttentionConversionBase<FlashAttentionConversion> {
public:
  void runOnOperation() final {
    // TODO: implement the MHA -> flash-attention rewrite.
    return;
  }
};

} // namespace
} // namespace gc
} // namespace mlir

0 commit comments

Comments
 (0)