
Commit e853be9

Eugene Burmako authored and copybara-github committed
Also allow one item in precision_config
We recently added a verifier for precision_config to check that it contains either 0 or 2 elements, in accordance with the spec: https://github.com/openxla/stablehlo/blob/main/docs/spec.md. It turns out that some producers create precision configs with 1 element. I've opened a ticket to better understand this (openxla/stablehlo#879), but in the meantime this CL relaxes the verifier.

PiperOrigin-RevId: 499935917
1 parent f6d7860 commit e853be9
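
As a rough illustration of the relaxed rule, here is a minimal standalone C++ sketch. This is not the actual MLIR verifier in TypeInference.cpp; the helper name and the string-vector stand-in for the precision_config attribute are hypothetical. It only shows the accepted sizes: absent, empty, or at most two elements, where the previous check required exactly zero or two.

// Hypothetical helper mirroring the relaxed check: a precision config is
// accepted when the attribute is absent, empty, or has at most two elements.
#include <iostream>
#include <optional>
#include <string>
#include <vector>

bool isValidPrecisionConfig(
    const std::optional<std::vector<std::string>>& config) {
  if (!config.has_value()) return true;  // attribute not present at all
  return config->size() <= 2;            // previously: size() == 0 || size() == 2
}

int main() {
  std::vector<std::string> one = {"DEFAULT"};
  std::vector<std::string> two = {"HIGH", "HIGHEST"};
  std::vector<std::string> three = {"DEFAULT", "DEFAULT", "DEFAULT"};
  std::cout << isValidPrecisionConfig(std::nullopt) << "\n";  // 1: absent config
  std::cout << isValidPrecisionConfig(one) << "\n";           // 1: newly allowed by this change
  std::cout << isValidPrecisionConfig(two) << "\n";           // 1: two elements, per the spec
  std::cout << isValidPrecisionConfig(three) << "\n";         // 0: still rejected
  return 0;
}

The corresponding verifier change and the updated tests are in the diffs below.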

File tree (3 files changed: +85 −22 lines changed)

  third_party/stablehlo/temporary.patch
  xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
  xla/mlir_hlo/tests/Dialect/mhlo/verifier_conv_op.mlir

third_party/stablehlo/temporary.patch

Lines changed: 66 additions & 19 deletions
@@ -33,7 +33,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo
 if (failed(verifyCompatibleShape(type1, type2))) return false;
 return tensorsHaveSameElType(type1.cast<ShapedType>(),
 type2.cast<ShapedType>(), ignoreFpPrecision);
-@@ -785,6 +769,17 @@
+@@ -785,6 +769,19 @@
 return success();
 }

@@ -42,16 +42,18 @@ diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo
 + Optional<ArrayAttr> maybeArrayAttr) {
 + if (!maybeArrayAttr.has_value()) return success();
 + auto arrayAttr = maybeArrayAttr.value();
-+ return !arrayAttr || arrayAttr.size() == 2 || arrayAttr.empty()
++ if (!arrayAttr) return success();
++ return arrayAttr.size() <= 2
 + ? success()
-+ : emitOptionalError(
-+ loc, "expects precision config to be null or of size 2.");
++ : emitOptionalError(loc,
++ "expects precision config to be empty or have "
++ "<= 2 elements.");
 +}
 +
 // Verifies the following properties:
 // P1. The input, kernel, and output spatial-dimentions are valid.
 // P2. Given,
-@@ -805,14 +800,16 @@
+@@ -805,14 +802,16 @@
 // dim(lhs, f) / fgc = dim(rhs, i)
 // * dim(rhs, o) (or dim(output, f')) % bgc == 0 and
 // dim(rhs, o) (or dim(output, f')) % fgc == 0
@@ -71,7 +73,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo
 // P1.
 if (failed(isSpatialDimensionsValid(
 lhs, inputBatchDimension, inputFeatureDimension,
-@@ -892,18 +889,11 @@
+@@ -892,18 +891,11 @@
 "batch_group_count. Got batch_group_count = ",
 batchGroupCount, ".");

@@ -95,15 +97,15 @@ diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo
 }

 LogicalResult inferDotShape(RankedTensorType lhs, RankedTensorType rhs,
-@@ -2804,7 +2794,6 @@
+@@ -2804,7 +2796,6 @@
 * P3. Verify and collect the window atributes.
 * P4. Verify precision_config attribute.
 * P5. Verify the return shape.
 - * TODO(b/232574102): Verify the element-type of return-value.
 */
 LogicalResult verifyConvolutionOp(
 Optional<Location> location, Value lhs, Value rhs,
-@@ -2839,11 +2828,11 @@
+@@ -2839,11 +2830,11 @@

 // P2.
 if (failed(verifyConvolutionAttributes(
@@ -117,7 +119,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo
 return failure();

 if ((size_t)numDims != inputSpatialDimensions.size() + 2)
-@@ -2878,11 +2867,7 @@
+@@ -2878,11 +2869,7 @@
 *rhsDilationOrErr, *windowReversalOrErr, location);
 if (failed(windowOrErr)) return failure();

@@ -130,7 +132,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo
 SmallVector<int64_t> outputDimensions(lhsType.getShape().size(),
 ShapedType::kDynamic);

-@@ -2920,7 +2905,8 @@
+@@ -2920,7 +2907,8 @@
 }

 LogicalResult verifyDotOp(Optional<Location> location, Value lhs, Value rhs,
@@ -140,7 +142,7 @@ diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo
 auto lhsType = lhs.getType().dyn_cast<RankedTensorType>();
 auto rhsType = rhs.getType().dyn_cast<RankedTensorType>();
 auto resultType = result.getType().dyn_cast<RankedTensorType>();
-@@ -2937,7 +2923,7 @@
+@@ -2937,7 +2925,7 @@
 location, "inferred shape '", dimSizesToString(inferredShape), "' ",
 "is incompatible with return type of operation ", resultType, "");

@@ -239,19 +241,64 @@ diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.h b/stablehlo/stablehlo/d
 diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/stablehlo/tests/ops_stablehlo.mlir
 --- stablehlo/stablehlo/tests/ops_stablehlo.mlir
 +++ stablehlo/stablehlo/tests/ops_stablehlo.mlir
-@@ -1485,14 +1485,6 @@
- // CHECK-LABEL: func @dot_precision_config
- func.func @dot_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> {
- %0 = "stablehlo.dot"(%arg0, %arg1) {precision_config = [#stablehlo<precision HIGH>, #stablehlo<precision HIGHEST>]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+@@ -1490,14 +1490,6 @@
+
+ // -----
+
+-func.func @dot_precision_invalid_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> {
+- // expected-error@+1 {{expects precision config to be null or of size 2.}}
+- %0 = "stablehlo.dot"(%arg0, %arg1) {precision_config = [#stablehlo<precision HIGH>]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
 - func.return %0: tensor<2x2xi32>
 -}
 -
 -// -----
 -
--func.func @dot_precision_invalid_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> {
-- // expected-error@+1 {{expects precision config to be null or of size 2.}}
-- %0 = "stablehlo.dot"(%arg0, %arg1) {precision_config = [#stablehlo<precision HIGH>]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
- func.return %0: tensor<2x2xi32>
+ func.func @dot_bad_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> {
+ // expected-error@+1 {{'precision_config' failed to satisfy constraint}}
+ %0 = "stablehlo.dot"(%arg0, %arg1) {precision_config = ["FOO", #stablehlo<precision HIGHEST>]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
+@@ -3332,8 +3324,7 @@
+
+ // -----
+
+-func.func @dot_general_invalid_precision_config(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5xf32>) -> tensor<2x4x5xf32> {
+- // expected-error@+1 {{expects precision config to be null or of size 2}}
++func.func @dot_general_one_element_precision_config(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5xf32>) -> tensor<2x4x5xf32> {
+ %0 = "stablehlo.dot_general"(%arg0, %arg1) {
+ dot_dimension_numbers = #stablehlo.dot<
+ lhs_batching_dimensions = [0],
+@@ -3342,6 +3333,22 @@
+ rhs_contracting_dimensions = [1]
+ >,
+ precision_config = [#stablehlo<precision DEFAULT>]
++ } : (tensor<2x3x4xf32>, tensor<2x3x5xf32>) -> tensor<2x4x5xf32>
++ func.return %0 : tensor<2x4x5xf32>
++}
++
++// -----
++
++func.func @dot_general_three_element_precision_config(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5xf32>) -> tensor<2x4x5xf32> {
++ // expected-error@+1 {{expects precision config to be empty or have <= 2 elements}}
++ %0 = "stablehlo.dot_general"(%arg0, %arg1) {
++ dot_dimension_numbers = #stablehlo.dot<
++ lhs_batching_dimensions = [0],
++ rhs_batching_dimensions = [0],
++ lhs_contracting_dimensions = [1],
++ rhs_contracting_dimensions = [1]
++ >,
++ precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]
+ } : (tensor<2x3x4xf32>, tensor<2x3x5xf32>) -> tensor<2x4x5xf32>
+ func.return %0 : tensor<2x4x5xf32>
 }
+diff --ruN a/stablehlo/stablehlo/tests/verify_conv.mlir b/stablehlo/stablehlo/tests/verify_conv.mlir
+--- stablehlo/stablehlo/tests/verify_conv.mlir
++++ stablehlo/stablehlo/tests/verify_conv.mlir
+@@ -963,7 +963,7 @@

+ func.func @conv_invalid_precision_config(%arg0: tensor<3x2xf16>,
+ %arg1: tensor<2x2xf16>) -> tuple<tensor<3x2xf16>> {
+- // expected-error@+1{{expects precision config to be null or of size 2.}}
++ // expected-error@+1 {{expects precision config to be empty or have <= 2 elements}}
+ %0 = stablehlo.convolution(%arg0, %arg1)
+ dim_numbers = [b, f]x[i, o]->[b, f],
+ window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [],

xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir

Lines changed: 18 additions & 2 deletions
@@ -3224,10 +3224,10 @@ func.func @dot_general(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5xf32>) -> te
 func.return %0 : tensor<2x4x6xf32>
 }

+
 // -----

-func.func @dot_general_invalid_precision_config(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5xf32>) -> tensor<2x4x5xf32> {
-// expected-error@+1 {{expects precision config to be null or of size 2}}
+func.func @dot_general_one_element_precision_config(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5xf32>) -> tensor<2x4x5xf32> {
 %0 = "mhlo.dot_general"(%arg0, %arg1) {
 dot_dimension_numbers = #mhlo.dot<
 lhs_batching_dimensions = [0],
@@ -3242,6 +3242,22 @@ func.func @dot_general_invalid_precision_config(%arg0: tensor<2x3x4xf32>, %arg1:

 // -----

+func.func @dot_general_three_element_precision_config(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x5xf32>) -> tensor<2x4x5xf32> {
+// expected-error@+1 {{expects precision config to be empty or have <= 2 elements}}
+%0 = "mhlo.dot_general"(%arg0, %arg1) {
+dot_dimension_numbers = #mhlo.dot<
+lhs_batching_dimensions = [0],
+rhs_batching_dimensions = [0],
+lhs_contracting_dimensions = [1],
+rhs_contracting_dimensions = [1]
+>,
+precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
+} : (tensor<2x3x4xf32>, tensor<2x3x5xf32>) -> tensor<2x4x5xf32>
+func.return %0 : tensor<2x4x5xf32>
+}
+
+// -----
+
 func.func @compatible_shapes(%arg0: tensor<?xf32>, %shape: tensor<2xindex>) -> tensor<?x?xf32> {
 %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
 func.return %0 : tensor<?x?xf32>

xla/mlir_hlo/tests/Dialect/mhlo/verifier_conv_op.mlir

Lines changed: 1 addition & 1 deletion
@@ -938,7 +938,7 @@ func.func @convolution(%arg0: tensor<2x2x3x4xf32>, %arg1: tensor<3x5x5x3xf32>) -

 func.func @conv_invalid_precision_config(%arg0: tensor<3x2xf16>,
 %arg1: tensor<2x2xf16>) -> tuple<tensor<3x2xf16>> {
-// expected-error@+1{{expects precision config to be null or of size 2.}}
+// expected-error@+1 {{expects precision config to be empty or have <= 2 elements}}
 %0 = mhlo.convolution(%arg0, %arg1)
 dim_numbers = [b, f]x[i, o]->[b, f],
 window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [],

0 commit comments
