Skip to content

Commit 5cde869

Browse files
committed
Add test case & make arith-to-spirv use emulation flag.
1 parent b1f826c commit 5cde869

File tree

4 files changed

+74
-4
lines changed

4 files changed

+74
-4
lines changed

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,9 @@ struct ConstantCompositeOpPattern final
306306
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
307307
Attribute dstAttr = nullptr;
308308
// Handle 8-bit float conversion to 8-bit integer.
309-
if (srcElemType.getIntOrFloatBitWidth() == 8 &&
309+
auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
310+
if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
311+
srcElemType.getIntOrFloatBitWidth() == 8 &&
310312
isa<IntegerType>(dstElemType)) {
311313
dstAttr =
312314
getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
@@ -381,7 +383,9 @@ struct ConstantScalarOpPattern final
381383

382384
// Floating-point types not supported in the target environment are all
383385
// converted to float type.
384-
if (srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
386+
auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
387+
if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
388+
srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
385389
dstType.getIntOrFloatBitWidth() == 8) {
386390
// If the source is an 8-bit float, convert it to a 8-bit integer.
387391
dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
@@ -1373,6 +1377,7 @@ struct ConvertArithToSPIRVPass
13731377

13741378
SPIRVConversionOptions options;
13751379
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
1380+
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
13761381
SPIRVTypeConverter typeConverter(targetAttr, options);
13771382

13781383
// Use UnrealizedConversionCast as the bridge so that we don't need to pull

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,12 +345,12 @@ static Type convert8BitFloatType(const SPIRVConversionOptions &options,
345345

346346
/// Returns a type with the same shape but with any 8-bit float element type
347347
/// converted to the same bit width integer type. This is a noop when the
348-
/// element type is not the 8-bit float type.
348+
/// element type is not the 8-bit float type or emulation flag is set to false.
349349
static ShapedType
350350
convertShaped8BitFloatType(ShapedType type,
351351
const SPIRVConversionOptions &options) {
352352
if (!options.emulateUnsupportedFloatTypes)
353-
return nullptr;
353+
return type;
354354
auto srcElementType = type.getElementType();
355355
Type convertedElementType = nullptr;
356356
// F8 types are converted to integer types with the same bit width.

mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,17 @@ func.func @constant() {
559559
return
560560
}
561561

562+
// CHECK-LABEL: @constant_8bit_float
563+
func.func @constant_8bit_float() {
564+
// CHECK: spirv.Constant 56 : i8
565+
%cst = arith.constant 1.0 : f8E4M3
566+
// CHECK: spirv.Constant dense<56> : vector<4xi8>
567+
%cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3>
568+
// CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
569+
%cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2>
570+
return
571+
}
572+
562573
// CHECK-LABEL: @constant_16bit
563574
func.func @constant_16bit() {
564575
// CHECK: spirv.Constant 4 : i16

mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s
22
// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \
33
// RUN: FileCheck %s --check-prefix=NOEMU
4+
// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-unsupported-float-types=false" %s | \
5+
// RUN: FileCheck %s --check-prefix=UNSUPPORTED_FLOAT
46

57
//===----------------------------------------------------------------------===//
68
// Integer types
@@ -944,3 +946,55 @@ func.func @unranked_tensor(%arg0: tensor<*xi32>) { return }
944946
func.func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return }
945947

946948
} // end module
949+
950+
951+
// -----
952+
953+
// Check that 8-bit float types are emulated as i8.
954+
module attributes {
955+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8], []>, #spirv.resource_limits<>>
956+
} {
957+
958+
// CHECK: spirv.func @float8_to_integer8
959+
// CHECK-SAME: (%arg0: i8
960+
// CHECK-SAME: %arg1: i8
961+
// CHECK-SAME: %arg2: i8
962+
// CHECK-SAME: %arg3: i8
963+
// CHECK-SAME: %arg4: i8
964+
// CHECK-SAME: %arg5: i8
965+
// CHECK-SAME: %arg6: i8
966+
// CHECK-SAME: %arg7: i8
967+
// CHECK-SAME: %arg8: vector<4xi8>
968+
// CHECK-SAME: %arg9: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i8, stride=1> [0])>, StorageBuffer>
969+
// CHECK-SAME: %arg10: !spirv.array<4 x i8>
970+
// UNSUPPORTED_FLOAT-LABEL: func.func @float8_to_integer8
971+
// UNSUPPORTED_FLOAT-SAME: (%arg0: f8E5M2
972+
// UNSUPPORTED_FLOAT-SAME: %arg1: f8E4M3
973+
// UNSUPPORTED_FLOAT-SAME: %arg2: f8E4M3FN
974+
// UNSUPPORTED_FLOAT-SAME: %arg3: f8E5M2FNUZ
975+
// UNSUPPORTED_FLOAT-SAME: %arg4: f8E4M3FNUZ
976+
// UNSUPPORTED_FLOAT-SAME: %arg5: f8E4M3B11FNUZ
977+
// UNSUPPORTED_FLOAT-SAME: %arg6: f8E3M4
978+
// UNSUPPORTED_FLOAT-SAME: %arg7: f8E8M0FNU
979+
// UNSUPPORTED_FLOAT-SAME: %arg8: vector<4xf8E4M3B11FNUZ>
980+
// UNSUPPORTED_FLOAT-SAME: %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>
981+
// UNSUPPORTED_FLOAT-SAME: %arg10: tensor<4xf8E5M2>
982+
// UNSUPPORTED_FLOAT-SAME: ) {
983+
984+
func.func @float8_to_integer8(
985+
%arg0: f8E5M2, // CHECK-NOT: f8E5M2
986+
%arg1: f8E4M3, // CHECK-NOT: f8E4M3
987+
%arg2: f8E4M3FN, // CHECK-NOT: f8E4M3FN
988+
%arg3: f8E5M2FNUZ, // CHECK-NOT: f8E5M2FNUZ
989+
%arg4: f8E4M3FNUZ, // CHECK-NOT: f8E4M3FNUZ
990+
%arg5: f8E4M3B11FNUZ, // CHECK-NOT: f8E4M3B11FNUZ
991+
%arg6: f8E3M4, // CHECK-NOT: f8E3M4
992+
%arg7: f8E8M0FNU, // CHECK-NOT: f8E8M0FNU
993+
%arg8: vector<4xf8E4M3B11FNUZ>, // CHECK-NOT: vector<4xf8E4M3B11FNUZ>
994+
%arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>, // CHECK-NOT: memref
995+
%arg10: tensor<4xf8E5M2> // CHECK-NOT: tensor
996+
) {
997+
// CHECK: spirv.Return
998+
return
999+
}
1000+
}

0 commit comments

Comments
 (0)