@@ -33,7 +33,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo
33
33
if (failed(verifyCompatibleShape(type1, type2))) return false;
34
34
return tensorsHaveSameElType(type1.cast<ShapedType>(),
35
35
type2.cast<ShapedType>(), ignoreFpPrecision);
36
- @@ -785,6 +769,17 @@
36
+ @@ -785,6 +769,19 @@
37
37
return success();
38
38
}
39
39
@@ -42,16 +42,18 @@ diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo
42
42
+ Optional<ArrayAttr> maybeArrayAttr) {
43
43
+ if (!maybeArrayAttr.has_value()) return success();
44
44
+ auto arrayAttr = maybeArrayAttr.value();
45
- + return !arrayAttr || arrayAttr.size() == 2 || arrayAttr.empty()
45
+ + if (!arrayAttr) return success();
46
+ + return arrayAttr.size() <= 2
46
47
+ ? success()
47
- + : emitOptionalError(
48
- + loc, "expects precision config to be null or of size 2.");
48
+ + : emitOptionalError(loc,
49
+ + "expects precision config to be empty or have "
50
+ + "<= 2 elements.");
49
51
+ }
50
52
+
51
53
// Verifies the following properties:
52
54
// P1. The input, kernel, and output spatial-dimentions are valid.
53
55
// P2. Given,
54
- @@ -805,14 +800 ,16 @@
56
+ @@ -805,14 +802 ,16 @@
55
57
// dim(lhs, f) / fgc = dim(rhs, i)
56
58
// * dim(rhs, o) (or dim(output, f')) % bgc == 0 and
57
59
// dim(rhs, o) (or dim(output, f')) % fgc == 0
@@ -71,7 +73,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo
71
73
// P1.
72
74
if (failed(isSpatialDimensionsValid(
73
75
lhs, inputBatchDimension, inputFeatureDimension,
74
- @@ -892,18 +889 ,11 @@
76
+ @@ -892,18 +891 ,11 @@
75
77
"batch_group_count. Got batch_group_count = ",
76
78
batchGroupCount, ".");
77
79
@@ -95,15 +97,15 @@ diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo
95
97
}
96
98
97
99
LogicalResult inferDotShape(RankedTensorType lhs, RankedTensorType rhs,
98
- @@ -2804,7 +2794 ,6 @@
100
+ @@ -2804,7 +2796 ,6 @@
99
101
* P3. Verify and collect the window atributes.
100
102
* P4. Verify precision_config attribute.
101
103
* P5. Verify the return shape.
102
104
- * TODO(b/232574102): Verify the element-type of return-value.
103
105
*/
104
106
LogicalResult verifyConvolutionOp(
105
107
Optional<Location> location, Value lhs, Value rhs,
106
- @@ -2839,11 +2828 ,11 @@
108
+ @@ -2839,11 +2830 ,11 @@
107
109
108
110
// P2.
109
111
if (failed(verifyConvolutionAttributes(
@@ -117,7 +119,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo
117
119
return failure();
118
120
119
121
if ((size_t)numDims != inputSpatialDimensions.size() + 2)
120
- @@ -2878,11 +2867 ,7 @@
122
+ @@ -2878,11 +2869 ,7 @@
121
123
*rhsDilationOrErr, *windowReversalOrErr, location);
122
124
if (failed(windowOrErr)) return failure();
123
125
@@ -130,7 +132,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo
130
132
SmallVector<int64_t> outputDimensions(lhsType.getShape().size(),
131
133
ShapedType::kDynamic);
132
134
133
- @@ -2920,7 +2905 ,8 @@
135
+ @@ -2920,7 +2907 ,8 @@
134
136
}
135
137
136
138
LogicalResult verifyDotOp(Optional<Location> location, Value lhs, Value rhs,
@@ -140,7 +142,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo
140
142
auto lhsType = lhs.getType().dyn_cast<RankedTensorType>();
141
143
auto rhsType = rhs.getType().dyn_cast<RankedTensorType>();
142
144
auto resultType = result.getType().dyn_cast<RankedTensorType>();
143
- @@ -2937,7 +2923 ,7 @@
145
+ @@ -2937,7 +2925 ,7 @@
144
146
location, "inferred shape '", dimSizesToString(inferredShape), "' ",
145
147
"is incompatible with return type of operation ", resultType, "");
146
148
@@ -239,19 +241,64 @@ diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.h b/stablehlo/stablehlo/d
239
241
diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/stablehlo/tests/ops_stablehlo.mlir
240
242
--- stablehlo/stablehlo/tests/ops_stablehlo.mlir
241
243
+++ stablehlo/stablehlo/tests/ops_stablehlo.mlir
242
- @@ -1485,14 +1485,6 @@
243
- // CHECK-LABEL: func @dot_precision_config
244
- func.func @dot_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> {
245
- %0 = "stablehlo.dot"(%arg0, %arg1) {precision_config = [#stablehlo<precision HIGH>, #stablehlo<precision HIGHEST>]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
244
+ @@ -1490,14 +1490,6 @@
245
+
246
+ // -----
247
+
248
+ - func.func @dot_precision_invalid_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> {
249
+ - // expected-error@+1 {{expects precision config to be null or of size 2.}}
250
+ - %0 = "stablehlo.dot"(%arg0, %arg1) {precision_config = [#stablehlo<precision HIGH>]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
246
251
- func.return %0: tensor<2x2xi32>
247
252
- }
248
253
-
249
254
- // -----
250
255
-
251
- - func.func @dot_precision_invalid_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> {
252
- - // expected-error@+1 {{expects precision config to be null or of size 2.}}
253
- - %0 = "stablehlo.dot"(%arg0, %arg1) {precision_config = [#stablehlo<precision HIGH>]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
254
- func.return %0: tensor<2x2xi32>
256
+ func.func @dot_bad_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> {
257
+ // expected-error@+1 {{'precision_config' failed to satisfy constraint}}
258
+ %0 = "stablehlo.dot"(%arg0, %arg1) {precision_config = ["FOO", #stablehlo<precision HIGHEST>]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
259
+ @@ -3332,8 +3324,7 @@
260
+
261
+ // -----
262
+
263
+ - func.func @dot_general_invalid_precision_config(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5xf32>) -> tensor<2x4x5xf32> {
264
+ - // expected-error@+1 {{expects precision config to be null or of size 2}}
265
+ + func.func @dot_general_one_element_precision_config(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5xf32>) -> tensor<2x4x5xf32> {
266
+ %0 = "stablehlo.dot_general"(%arg0, %arg1) {
267
+ dot_dimension_numbers = #stablehlo.dot<
268
+ lhs_batching_dimensions = [0],
269
+ @@ -3342,6 +3333,22 @@
270
+ rhs_contracting_dimensions = [1]
271
+ >,
272
+ precision_config = [#stablehlo<precision DEFAULT>]
273
+ + } : (tensor<2x3x4xf32>, tensor<2x3x5xf32>) -> tensor<2x4x5xf32>
274
+ + func.return %0 : tensor<2x4x5xf32>
275
+ + }
276
+ +
277
+ + // -----
278
+ +
279
+ + func.func @dot_general_three_element_precision_config(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5xf32>) -> tensor<2x4x5xf32> {
280
+ + // expected-error@+1 {{expects precision config to be empty or have <= 2 elements}}
281
+ + %0 = "stablehlo.dot_general"(%arg0, %arg1) {
282
+ + dot_dimension_numbers = #stablehlo.dot<
283
+ + lhs_batching_dimensions = [0],
284
+ + rhs_batching_dimensions = [0],
285
+ + lhs_contracting_dimensions = [1],
286
+ + rhs_contracting_dimensions = [1]
287
+ + >,
288
+ + precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
289
+ } : (tensor<2x3x4xf32>, tensor<2x3x5xf32>) -> tensor<2x4x5xf32>
290
+ func.return %0 : tensor<2x4x5xf32>
255
291
}
292
+ diff --ruN a/stablehlo/stablehlo/tests/verify_conv.mlir b/stablehlo/stablehlo/tests/verify_conv.mlir
293
+ --- stablehlo/stablehlo/tests/verify_conv.mlir
294
+ +++ stablehlo/stablehlo/tests/verify_conv.mlir
295
+ @@ -963,7 +963,7 @@
256
296
297
+ func.func @conv_invalid_precision_config(%arg0: tensor<3x2xf16>,
298
+ %arg1: tensor<2x2xf16>) -> tuple<tensor<3x2xf16>> {
299
+ - // expected-error@+1{{expects precision config to be null or of size 2.}}
300
+ + // expected-error@+1 {{expects precision config to be empty or have <= 2 elements}}
301
+ %0 = stablehlo.convolution(%arg0, %arg1)
302
+ dim_numbers = [b, f]x[i, o]->[b, f],
303
+ window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [],
257
304
0 commit comments