Commit 6437c35

[OneDNN Graph Dialect] Use Broadcast Trait and organize data types (#81)

Parent: 51ac048

4 files changed: +46 -70 lines
include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.h (+1)

@@ -16,6 +16,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Dialect/Traits.h"
 
 #define GET_OP_CLASSES
 #include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.h.inc"
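For context: mlir/Dialect/Traits.h is the upstream MLIR header that declares the ResultsBroadcastableShape trait and the broadcast-shape utility used by the rest of this commit. Abridged from that header (the exact form may differ across MLIR revisions):

namespace mlir {
namespace OpTrait {
namespace util {
/// Computes the broadcasted shape of `shape1` and `shape2` into
/// `resultShape`; returns false if the shapes are not broadcast-compatible.
bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
                         SmallVectorImpl<int64_t> &resultShape);
} // namespace util
} // namespace OpTrait
} // namespace mlir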

include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.td (+13 -12)

@@ -25,19 +25,20 @@ class OneDNNGraph_Op<string mnemonic, list<Trait> traits = []> :
     Op<OneDNNGraphDialect, mnemonic, traits>;
 
 class OneDNNGraph_ElemwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
-    OneDNNGraph_Op<mnemonic, traits # [SameOperandsAndResultElementType, InferTensorType]> {
-  let arguments = (ins OneDNNGraph_LogicalTensor:$input_0,
-                       OneDNNGraph_LogicalTensor:$input_1);
-  let results = (outs OneDNNGraph_LogicalTensor:$result);
+    OneDNNGraph_Op<mnemonic, traits # [SameOperandsAndResultElementType, InferTensorType,
+                                       ResultsBroadcastableShape]> {
+  let arguments = (ins OneDNNGraph_FloatTensor:$input_0,
+                       OneDNNGraph_FloatTensor:$input_1);
+  let results = (outs OneDNNGraph_FloatTensor:$result);
 
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";
 }
 
 class OneDNNGraph_ElemwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
     OneDNNGraph_Op<mnemonic, traits # [SameOperandsAndResultType]> {
-  let arguments = (ins OneDNNGraph_LogicalTensor:$operand);
-  let results = (outs OneDNNGraph_LogicalTensor:$result);
+  let arguments = (ins OneDNNGraph_FloatTensor:$operand);
+  let results = (outs OneDNNGraph_FloatTensor:$result);
 
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";

@@ -51,15 +52,15 @@ def OneDNNGraph_MatMulOp :
     OneDNNGraph_Op<"matmul", [SameOperandsAndResultElementType, InferTensorTypeAdaptor]> {
   let summary = "Generalized matrix multiplication";
   let description = [{
-    `https://spec.oneapi.io/onednn-graph/latest/ops/matrix/MatMul_1.html`
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_matmul.html`
   }];
 
-  let arguments = (ins OneDNNGraph_LogicalTensor:$input_a,
-                       OneDNNGraph_LogicalTensor:$input_b,
+  let arguments = (ins OneDNNGraph_FloatTensor:$input_a,
+                       OneDNNGraph_FloatTensor:$input_b,
                        Optional<OneDNNGraph_LogicalTensor>:$bias,
                        DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
                        DefaultValuedAttr<BoolAttr, "false">:$transpose_b);
-  let results = (outs OneDNNGraph_LogicalTensor:$result);
+  let results = (outs OneDNNGraph_FloatTensor:$result);
 
   let assemblyFormat =
       "operands attr-dict `:` functional-type(operands, results)";

@@ -68,14 +69,14 @@ def OneDNNGraph_MatMulOp :
 def OneDNNGraph_ReLUOp : OneDNNGraph_ElemwiseUnaryOp<"relu"> {
   let summary = "element-wise relu";
   let description = [{
-    `https://spec.oneapi.io/onednn-graph/latest/ops/activation/ReLU_1.html`
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_relu.html`
   }];
 }
 
 def OneDNNGraph_AddOp : OneDNNGraph_ElemwiseBinaryOp<"add", [Commutative]> {
   let summary = "element-wise addition with multi-directional broadcast";
   let description = [{
-    `https://spec.oneapi.io/onednn-graph/latest/ops/arithmetic/Add_1.html`
+    `https://oneapi-src.github.io/oneDNN/dev_guide_op_add.html`
   }];
 }
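With ResultsBroadcastableShape attached, ops derived from OneDNNGraph_ElemwiseBinaryOp (such as onednn_graph.add above) pick up upstream MLIR's broadcast verification instead of a hand-rolled check. A minimal sketch of the rule, calling the upstream utility directly (assumes an MLIR build to compile against; the shapes are made up for illustration):

#include "mlir/Dialect/Traits.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  // Multi-directional (NumPy-style) broadcast: shapes are right-aligned,
  // size-1 dims stretch, equal dims pass through, anything else fails.
  llvm::SmallVector<int64_t> lhs = {8, 1, 32};
  llvm::SmallVector<int64_t> rhs = {16, 32};
  llvm::SmallVector<int64_t> out;
  if (mlir::OpTrait::util::getBroadcastedShape(lhs, rhs, out)) {
    for (int64_t d : out) // prints "8 16 32"
      llvm::outs() << d << ' ';
    llvm::outs() << '\n';
  }
  // 3 vs 5 in the same position: not broadcastable, so an add with these
  // operand shapes would now fail verification.
  llvm::SmallVector<int64_t> lhs2 = {8, 3}, rhs2 = {8, 5};
  if (!mlir::OpTrait::util::getBroadcastedShape(lhs2, rhs2, out))
    llvm::outs() << "not broadcastable\n";
  return 0;
}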

include/gc/Dialect/OneDNNGraph/OneDNNGraphTypes.td (+18 -6)

@@ -17,14 +17,26 @@ include "OneDNNGraphDialect.td"
 // OneDNNGraph type definitions
 //===----------------------------------------------------------------------===//
 
+//===----------------------------------------------------------------------===//
+// Floating-point types.
+//===----------------------------------------------------------------------===//
+def OneDNNGraph_Float : AnyTypeOf<[F32,
+                                   F16,
+                                   BF16]>;
+
+//===----------------------------------------------------------------------===//
+// Integer types.
+//===----------------------------------------------------------------------===//
+
+def OneDNNGraph_Int : AnyTypeOf<[SI<8>,
+                                 UI<8>]>;
+
 def OneDNNGraph_DataType : AnyTypeOf<[
-    F16,
-    BF16,
-    F32,
-    SI<32>,
-    SI<8>,
-    UI<8>]>;
+    OneDNNGraph_Float,
+    OneDNNGraph_Int
+]>;
 
 def OneDNNGraph_LogicalTensor : TensorOf<[OneDNNGraph_DataType]>;
+def OneDNNGraph_FloatTensor : TensorOf<[OneDNNGraph_Float]>;
 
 #endif // ONEDNNGRAPH_TYPES
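For readers less familiar with ODS: TensorOf<[...]> expands to a type predicate over the element type. A rough C++ restatement of what the new OneDNNGraph_FloatTensor constraint accepts (illustrative only, not code from this commit):

#include "mlir/IR/BuiltinTypes.h"

// Illustrative equivalent of the OneDNNGraph_FloatTensor constraint:
// a tensor whose element type is f32, f16, or bf16.
static bool isOneDNNGraphFloatTensor(mlir::Type type) {
  auto tensorTy = llvm::dyn_cast<mlir::TensorType>(type);
  if (!tensorTy)
    return false;
  mlir::Type elemTy = tensorTy.getElementType();
  return elemTy.isF32() || elemTy.isF16() || elemTy.isBF16();
}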

lib/gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp (+14 -52)

@@ -17,59 +17,22 @@
 namespace mlir {
 namespace onednn_graph {
 
-// https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
-template <typename ShapeRange>
-static LogicalResult inferBroadcastShape(
-    ShapeRange operands, SmallVector<int64_t> &outShape,
-    const std::function<ShapeAdaptor(ShapeRange, size_t)> &getShapeIdx) {
-  int64_t outRank = 0;
-  for (size_t i = 0; i < operands.size(); i++) {
-    auto shape = getShapeIdx(operands, i);
-    if (!shape.hasRank()) {
-      return failure();
-    }
-    outRank = std::max(outRank, shape.getRank());
-  }
-  // Start with all 1 dim
-  outShape.clear();
-  outShape.resize(outRank, 1);
-  // Scan each shape for match dims
-  for (size_t i = 0; i < operands.size(); i++) {
-    auto shape = getShapeIdx(operands, i);
-    auto diff = outShape.size() - shape.getRank();
-    for (int64_t j = 0; j < shape.getRank(); j++) {
-      auto dim1 = outShape[diff + j];
-      auto dim2 = shape.getDimSize(j);
-      auto resolvedDim = dim1;
-
-      if (dim1 == 1) {
-        resolvedDim = dim2;
-      } else if (dim2 == 1) {
-        resolvedDim = dim1;
-      } else if (dim1 != dim2) {
-        return failure();
-      }
-      outShape[diff + j] = resolvedDim;
-    }
-  }
-  return success();
-}
-
 LogicalResult onednn_graph::AddOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes,
     OpaqueProperties properties, RegionRange regions,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<int64_t> outShape;
-  auto resultTy = dyn_cast<TensorType>(operands.front().getType());
-  auto getShapeIdx = [](ValueShapeRange operands, size_t i) {
-    return operands.getShape(i);
+  auto resultTy = dyn_cast<ShapedType>(operands.front().getType());
+  auto getShapeIdx = [&operands](size_t i) {
+    return operands.getTypes()[i].dyn_cast<ShapedType>().getShape();
   };
-  auto ret =
-      inferBroadcastShape<ValueShapeRange>(operands, outShape, getShapeIdx);
+
+  auto ret = OpTrait::util::getBroadcastedShape(getShapeIdx(0), getShapeIdx(1),
+                                                outShape);
   inferredReturnShapes.push_back(
       ShapedTypeComponents(outShape, resultTy.getElementType()));
-  return ret;
+  return LogicalResult::success(ret);
 }
 
 LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(

@@ -158,22 +121,21 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(
     // Not supported
     return failure();
   }
-  auto getShapeIdx = [](ArrayRef<ShapeAdaptor> operands, size_t i) {
-    return operands[i];
-  };
   // final shape
   auto retShape = ShapedTypeComponents(outShape, lhsShape.getElementType());
   inferredReturnShapes.push_back(retShape);
   // check for bias broadcasting
   if (adaptor.getBias()) {
-    ShapeAdaptor biasShape(adaptor.getBias().getType());
-    ShapeAdaptor matShape(retShape);
+    auto biasType = adaptor.getBias().getType();
+    ShapeAdaptor biasShape(biasType);
+
     bool biasRankMatch = biasShape.getRank() == 1 ||
                          biasShape.getRank() == (int64_t)outShape.size();
-    SmallVector<int64_t> bcastShape;
+    SmallVector<int64_t> resultShape;
     if (!biasRankMatch ||
-        failed(inferBroadcastShape<ArrayRef<ShapeAdaptor>>(
-            {matShape, biasShape}, bcastShape, getShapeIdx))) {
+        !OpTrait::util::getBroadcastedShape(
+            retShape.getDims(), biasType.dyn_cast<ShapedType>().getShape(),
+            resultShape)) {
       return failure();
     }
   }
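The deleted inferBroadcastShape helper and the upstream getBroadcastedShape that replaces it implement the same ONNX/NumPy multi-directional broadcast rule. For readers without an MLIR checkout, a standalone sketch of that rule in plain C++ (illustrative; it only handles static dims, unlike the MLIR utility):

#include <algorithm>
#include <cstdint>
#include <optional>
#include <vector>

// Multi-directional broadcast of two static shapes, per
// https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md:
// right-align the shapes, then per dimension a 1 stretches to match,
// equal sizes pass through, and any other mismatch is an error.
std::optional<std::vector<int64_t>>
broadcastShapes(const std::vector<int64_t> &a, const std::vector<int64_t> &b) {
  size_t rank = std::max(a.size(), b.size());
  std::vector<int64_t> out(rank, 1);
  for (size_t i = 0; i < rank; ++i) {
    // Read dims right-aligned; missing leading dims count as 1.
    int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1)
      return std::nullopt; // incompatible
    out[rank - 1 - i] = std::max(da, db);
  }
  return out;
}

// broadcastShapes({8, 1, 32}, {16, 32}) yields {8, 16, 32};
// broadcastShapes({8, 3}, {8, 5}) yields std::nullopt.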
