17 | 17 | namespace mlir {
18 | 18 | namespace onednn_graph {
19 | 19 |
20 |    | -// https://github.com/onnx/onnx/blob/main/docs/Broadcasting.md
21 |    | -template <typename ShapeRange>
22 |    | -static LogicalResult inferBroadcastShape(
23 |    | -    ShapeRange operands, SmallVector<int64_t> &outShape,
24 |    | -    const std::function<ShapeAdaptor(ShapeRange, size_t)> &getShapeIdx) {
25 |    | -  int64_t outRank = 0;
26 |    | -  for (size_t i = 0; i < operands.size(); i++) {
27 |    | -    auto shape = getShapeIdx(operands, i);
28 |    | -    if (!shape.hasRank()) {
29 |    | -      return failure();
30 |    | -    }
31 |    | -    outRank = std::max(outRank, shape.getRank());
32 |    | -  }
33 |    | -  // Start with all 1 dim
34 |    | -  outShape.clear();
35 |    | -  outShape.resize(outRank, 1);
36 |    | -  // Scan each shape for match dims
37 |    | -  for (size_t i = 0; i < operands.size(); i++) {
38 |    | -    auto shape = getShapeIdx(operands, i);
39 |    | -    auto diff = outShape.size() - shape.getRank();
40 |    | -    for (int64_t j = 0; j < shape.getRank(); j++) {
41 |    | -      auto dim1 = outShape[diff + j];
42 |    | -      auto dim2 = shape.getDimSize(j);
43 |    | -      auto resolvedDim = dim1;
44 |    | -
45 |    | -      if (dim1 == 1) {
46 |    | -        resolvedDim = dim2;
47 |    | -      } else if (dim2 == 1) {
48 |    | -        resolvedDim = dim1;
49 |    | -      } else if (dim1 != dim2) {
50 |    | -        return failure();
51 |    | -      }
52 |    | -      outShape[diff + j] = resolvedDim;
53 |    | -    }
54 |    | -  }
55 |    | -  return success();
56 |    | -}
57 |    | -
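Note: the hand-rolled broadcast inference removed above is superseded by MLIR's upstream helper OpTrait::util::getBroadcastedShape (declared in mlir/Dialect/Traits.h), which implements the same NumPy/ONNX broadcasting rule. A minimal sketch of the helper's contract, with illustrative shapes not taken from this patch:

#include "mlir/Dialect/Traits.h"
#include "llvm/ADT/SmallVector.h"

// Broadcasting {4, 1, 3} against {2, 3} yields {4, 2, 3}; the helper
// returns false on a dim conflict (e.g. 2 vs. 3), which corresponds to
// the failure() paths of the removed function.
static bool broadcastSketch(llvm::SmallVector<int64_t> &outShape) {
  return mlir::OpTrait::util::getBroadcastedShape(
      /*shape1=*/{4, 1, 3}, /*shape2=*/{2, 3}, outShape);
}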
58 | 20 | LogicalResult onednn_graph::AddOp::inferReturnTypeComponents(
59 | 21 |     MLIRContext *context, ::std::optional<Location> location,
60 | 22 |     ValueShapeRange operands, DictionaryAttr attributes,
61 | 23 |     OpaqueProperties properties, RegionRange regions,
62 | 24 |     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
63 | 25 |   llvm::SmallVector<int64_t> outShape;
64 |    | -  auto resultTy = dyn_cast<TensorType>(operands.front().getType());
65 |    | -  auto getShapeIdx = [](ValueShapeRange operands, size_t i) {
66 |    | -    return operands.getShape(i);
   | 26 | +  auto resultTy = dyn_cast<ShapedType>(operands.front().getType());
   | 27 | +  auto getShapeIdx = [&operands](size_t i) {
   | 28 | +    return operands.getTypes()[i].dyn_cast<ShapedType>().getShape();
67 | 29 |   };
68 |    | -  auto ret =
69 |    | -      inferBroadcastShape<ValueShapeRange>(operands, outShape, getShapeIdx);
   | 30 | +
   | 31 | +  auto ret = OpTrait::util::getBroadcastedShape(getShapeIdx(0), getShapeIdx(1),
   | 32 | +                                                outShape);
70 | 33 |   inferredReturnShapes.push_back(
71 | 34 |       ShapedTypeComponents(outShape, resultTy.getElementType()));
72 |    | -  return ret;
   | 35 | +  return LogicalResult::success(ret);
73 | 36 | }
74 | 37 |
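Two notes on the AddOp rewrite above: getBroadcastedShape returns bool rather than LogicalResult, hence the LogicalResult::success(ret) bridge, and ShapedType::getShape() asserts on unranked types, where the removed helper returned failure() instead. A condensed sketch of the new flow with an explicit rank guard; the helper name and the lhs/rhs values are hypothetical:

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

// Fetch ranked operand shapes, broadcast them, and convert the bool
// result into the LogicalResult the inference interface expects.
static mlir::LogicalResult inferAddShape(mlir::Value lhs, mlir::Value rhs,
                                         llvm::SmallVector<int64_t> &outShape) {
  auto lhsTy = llvm::dyn_cast<mlir::ShapedType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<mlir::ShapedType>(rhs.getType());
  // getShape() asserts on unranked shapes, so guard before calling it.
  if (!lhsTy || !rhsTy || !lhsTy.hasRank() || !rhsTy.hasRank())
    return mlir::failure();
  bool ok = mlir::OpTrait::util::getBroadcastedShape(
      lhsTy.getShape(), rhsTy.getShape(), outShape);
  return mlir::LogicalResult::success(ok);
}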
75 | 38 | LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(

@@ -158,22 +121,21 @@ LogicalResult onednn_graph::MatMulOp::inferReturnTypeComponents(

158 | 121 |     // Not supported
159 | 122 |     return failure();
160 | 123 |   }
161 |     | -  auto getShapeIdx = [](ArrayRef<ShapeAdaptor> operands, size_t i) {
162 |     | -    return operands[i];
163 |     | -  };
164 | 124 |   // final shape
165 | 125 |   auto retShape = ShapedTypeComponents(outShape, lhsShape.getElementType());
166 | 126 |   inferredReturnShapes.push_back(retShape);
167 | 127 |   // check for bias broadcasting
168 | 128 |   if (adaptor.getBias()) {
169 |     | -    ShapeAdaptor biasShape(adaptor.getBias().getType());
170 |     | -    ShapeAdaptor matShape(retShape);
    | 129 | +    auto biasType = adaptor.getBias().getType();
    | 130 | +    ShapeAdaptor biasShape(biasType);
    | 131 | +
171 | 132 |     bool biasRankMatch = biasShape.getRank() == 1 ||
172 | 133 |                          biasShape.getRank() == (int64_t)outShape.size();
173 |     | -    SmallVector<int64_t> bcastShape;
    | 134 | +    SmallVector<int64_t> resultShape;
174 | 135 |     if (!biasRankMatch ||
175 |     | -        failed(inferBroadcastShape<ArrayRef<ShapeAdaptor>>(
176 |     | -            {matShape, biasShape}, bcastShape, getShapeIdx))) {
    | 136 | +        !OpTrait::util::getBroadcastedShape(
    | 137 | +            retShape.getDims(), biasType.dyn_cast<ShapedType>().getShape(),
    | 138 | +            resultShape)) {
177 | 139 |       return failure();
178 | 140 |     }
179 | 141 |   }
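The bias path above validates broadcast compatibility only; resultShape is scratch output that is never stored. A small sketch of what the check accepts, with hypothetical dimensions:

#include "mlir/Dialect/Traits.h"
#include "llvm/ADT/SmallVector.h"

// A rank-1 bias {N} (or a full-rank bias {M, N}) passes as long as it
// broadcasts against the matmul output shape.
static bool biasCheckSketch() {
  const int64_t matDims[] = {16, 32}; // hypothetical M x N output
  const int64_t biasDims[] = {32};    // hypothetical rank-1 bias
  llvm::SmallVector<int64_t> resultShape;
  return mlir::OpTrait::util::getBroadcastedShape(matDims, biasDims,
                                                  resultShape);
}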