diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 8183f355795a9..c3c2c8d18b77c 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -196,6 +196,10 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> { "bool", /*default=*/"true", "Emulate narrower scalar types with 32-bit ones if not supported by " "the target">, + Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", + "bool", /*default=*/"true", + "Emulate unsupported float types by representing them with integer " + "types of same bit width"> ]; } @@ -416,7 +420,11 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> { Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types", "bool", /*default=*/"true", "Emulate narrower scalar types with 32-bit ones if not supported by" - " the target"> + " the target">, + Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", + "bool", /*default=*/"true", + "Emulate unsupported float types by representing them with integer " + "types of same bit width"> ]; } @@ -500,7 +508,11 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> { Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types", "bool", /*default=*/"true", "Emulate narrower scalar types with 32-bit ones if not supported by" - " the target"> + " the target">, + Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", + "bool", /*default=*/"true", + "Emulate unsupported float types by representing them with integer " + "types of same bit width"> ]; } @@ -1163,7 +1175,11 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> { Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types", "bool", /*default=*/"true", "Emulate narrower scalar types with 32-bit ones if not supported by" - " the target"> + " the target">, + Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types", + "bool", /*default=*/"true", + "Emulate unsupported float types by representing them with integer " + "types of same bit width"> ]; } diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h index 3d22ec918f4c5..03ae54a8ae30a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -39,6 +39,10 @@ struct SPIRVConversionOptions { /// The number of bits to store a boolean value. unsigned boolNumBits{8}; + /// Whether to emulate unsupported floats with integer types of same bit + /// width. + bool emulateUnsupportedFloatTypes{true}; + /// How sub-byte values are storaged in memory. SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed}; diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index d43e6816641cb..265293b83f84c 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -99,6 +99,17 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, return builder.getF32FloatAttr(dstVal.convertToFloat()); } +// Get in IntegerAttr from FloatAttr while preserving the bits. +// Useful for converting float constants to integer constants while preserving +// the bits. +static IntegerAttr +getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, + ConversionPatternRewriter &rewriter) { + APFloat floatVal = floatAttr.getValue(); + APInt intVal = floatVal.bitcastToAPInt(); + return rewriter.getIntegerAttr(dstType, intVal); +} + /// Returns true if the given `type` is a boolean scalar or vector type. static bool isBoolScalarOrVector(Type type) { assert(type && "Not a valid type"); @@ -296,8 +307,18 @@ struct ConstantCompositeOpPattern final SmallVector elements; if (isa(srcElemType)) { for (FloatAttr srcAttr : dstElementsAttr.getValues()) { - FloatAttr dstAttr = - convertFloatAttr(srcAttr, cast(dstElemType), rewriter); + Attribute dstAttr = nullptr; + // Handle 8-bit float conversion to 8-bit integer. + auto *typeConverter = getTypeConverter(); + if (typeConverter->getOptions().emulateUnsupportedFloatTypes && + srcElemType.getIntOrFloatBitWidth() == 8 && + isa(dstElemType)) { + dstAttr = + getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter); + } else { + dstAttr = convertFloatAttr(srcAttr, cast(dstElemType), + rewriter); + } if (!dstAttr) return failure(); elements.push_back(dstAttr); @@ -361,11 +382,19 @@ struct ConstantScalarOpPattern final // Floating-point types. if (isa(srcType)) { auto srcAttr = cast(cstAttr); - auto dstAttr = srcAttr; + Attribute dstAttr = srcAttr; // Floating-point types not supported in the target environment are all // converted to float type. - if (srcType != dstType) { + auto *typeConverter = getTypeConverter(); + if (typeConverter->getOptions().emulateUnsupportedFloatTypes && + srcType.getIntOrFloatBitWidth() == 8 && isa(dstType) && + dstType.getIntOrFloatBitWidth() == 8) { + // If the source is an 8-bit float, convert it to a 8-bit integer. + dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter); + if (!dstAttr) + return failure(); + } else if (srcType != dstType) { dstAttr = convertFloatAttr(srcAttr, cast(dstType), rewriter); if (!dstAttr) return failure(); @@ -1352,6 +1381,7 @@ struct ConvertArithToSPIRVPass SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); // Use UnrealizedConversionCast as the bridge so that we don't need to pull diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp index 03f4bf4df4912..56b6181018153 100644 --- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp +++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp @@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() { SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); // TODO: We should also take care of block argument type conversion. diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp index 8ed9f659afb10..c0439a4033eac 100644 --- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp +++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp @@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() { SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp index f07386ea80124..8cd650e649008 100644 --- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp +++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp @@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); RewritePatternSet patterns(context); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 35ec0190b5a61..8f4c4cc027798 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -182,6 +182,14 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) { return bitWidth / 8; } + // Handle 8-bit floats. + if (options.emulateUnsupportedFloatTypes && isa(type)) { + auto bitWidth = type.getIntOrFloatBitWidth(); + if (bitWidth == 8) + return bitWidth / 8; + return std::nullopt; + } + if (auto complexType = dyn_cast(type)) { auto elementSize = getTypeNumBytes(options, complexType.getElementType()); if (!elementSize) @@ -318,6 +326,44 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options, type.getSignedness()); } +/// Converts 8-bit float types to integer types with the same bit width. +/// Returns a nullptr for unsupported 8-bit float types. +static Type convert8BitFloatType(const SPIRVConversionOptions &options, + FloatType type) { + if (!options.emulateUnsupportedFloatTypes) + return nullptr; + // F8 types are converted to integer types with the same bit width. + if (isa(type)) + return IntegerType::get(type.getContext(), type.getWidth()); + LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n"); + return nullptr; +} + +/// Returns a type with the same shape but with any 8-bit float element type +/// converted to the same bit width integer type. This is a noop when the +/// element type is not the 8-bit float type or emulation flag is set to false. +static ShapedType +convertShaped8BitFloatType(ShapedType type, + const SPIRVConversionOptions &options) { + if (!options.emulateUnsupportedFloatTypes) + return type; + Type srcElementType = type.getElementType(); + Type convertedElementType = nullptr; + // F8 types are converted to integer types with the same bit width. + if (isa(srcElementType)) + convertedElementType = IntegerType::get( + type.getContext(), srcElementType.getIntOrFloatBitWidth()); + + if (!convertedElementType) + return type; + + return type.clone(convertedElementType); +} + /// Returns a type with the same shape but with any index element type converted /// to the matching integer type. This is a noop when the element type is not /// the index type. @@ -337,6 +383,7 @@ convertVectorType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, VectorType type, std::optional storageClass = {}) { type = cast(convertIndexElementType(type, options)); + type = cast(convertShaped8BitFloatType(type, options)); auto scalarType = dyn_cast_or_null(type.getElementType()); if (!scalarType) { // If this is not a spec allowed scalar type, try to handle sub-byte integer @@ -433,6 +480,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv, } type = cast(convertIndexElementType(type, options)); + type = cast(convertShaped8BitFloatType(type, options)); auto scalarType = dyn_cast_or_null(type.getElementType()); if (!scalarType) { LLVM_DEBUG(llvm::dbgs() @@ -596,6 +644,10 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv, } else if (auto indexType = dyn_cast(elementType)) { type = cast(convertIndexElementType(type, options)); arrayElemType = type.getElementType(); + } else if (auto floatType = dyn_cast(elementType)) { + // Hnadle 8 bit float types. + type = cast(convertShaped8BitFloatType(type, options)); + arrayElemType = type.getElementType(); } else { LLVM_DEBUG( llvm::dbgs() @@ -1444,6 +1496,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, addConversion([this](FloatType floatType) -> std::optional { if (auto scalarType = dyn_cast(floatType)) return convertScalarType(this->targetEnv, this->options, scalarType); + if (floatType.getWidth() == 8) + return convert8BitFloatType(this->options, floatType); return Type(); }); diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index 1abe0fd2ec468..6e2352e706acc 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -559,6 +559,23 @@ func.func @constant() { return } +// CHECK-LABEL: @constant_8bit_float +func.func @constant_8bit_float() { + // CHECK: spirv.Constant 56 : i8 + %cst = arith.constant 1.0 : f8E4M3 + // CHECK: spirv.Constant 56 : i8 + %cst_i8 = arith.bitcast %cst : f8E4M3 to i8 + // CHECK: spirv.Constant dense<56> : vector<4xi8> + %cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3> + // CHECK: spirv.Constant dense<56> : vector<4xi8> + %cst_vector_i8 = arith.bitcast %cst_vector : vector<4xf8E4M3> to vector<4xi8> + // CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8> + %cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2> + // CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8> + %cst_tensor_i8 = arith.bitcast %cst_tensor : tensor<4xf8E5M2> to tensor<4xi8> + return +} + // CHECK-LABEL: @constant_16bit func.func @constant_16bit() { // CHECK: spirv.Constant 4 : i16 diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir index 1737f4a906bf8..0c77c88334572 100644 --- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir +++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir @@ -1,6 +1,8 @@ // RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s // RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \ // RUN: FileCheck %s --check-prefix=NOEMU +// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-unsupported-float-types=false" %s | \ +// RUN: FileCheck %s --check-prefix=UNSUPPORTED_FLOAT //===----------------------------------------------------------------------===// // Integer types @@ -944,3 +946,55 @@ func.func @unranked_tensor(%arg0: tensor<*xi32>) { return } func.func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return } } // end module + + +// ----- + +// Check that 8-bit float types are emulated as i8. +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + + // CHECK: spirv.func @float8_to_integer8 + // CHECK-SAME: (%arg0: i8 + // CHECK-SAME: %arg1: i8 + // CHECK-SAME: %arg2: i8 + // CHECK-SAME: %arg3: i8 + // CHECK-SAME: %arg4: i8 + // CHECK-SAME: %arg5: i8 + // CHECK-SAME: %arg6: i8 + // CHECK-SAME: %arg7: i8 + // CHECK-SAME: %arg8: vector<4xi8> + // CHECK-SAME: %arg9: !spirv.ptr [0])>, StorageBuffer> + // CHECK-SAME: %arg10: !spirv.array<4 x i8> + // UNSUPPORTED_FLOAT-LABEL: func.func @float8_to_integer8 + // UNSUPPORTED_FLOAT-SAME: (%arg0: f8E5M2 + // UNSUPPORTED_FLOAT-SAME: %arg1: f8E4M3 + // UNSUPPORTED_FLOAT-SAME: %arg2: f8E4M3FN + // UNSUPPORTED_FLOAT-SAME: %arg3: f8E5M2FNUZ + // UNSUPPORTED_FLOAT-SAME: %arg4: f8E4M3FNUZ + // UNSUPPORTED_FLOAT-SAME: %arg5: f8E4M3B11FNUZ + // UNSUPPORTED_FLOAT-SAME: %arg6: f8E3M4 + // UNSUPPORTED_FLOAT-SAME: %arg7: f8E8M0FNU + // UNSUPPORTED_FLOAT-SAME: %arg8: vector<4xf8E4M3B11FNUZ> + // UNSUPPORTED_FLOAT-SAME: %arg9: memref<8xf8E4M3, #spirv.storage_class> + // UNSUPPORTED_FLOAT-SAME: %arg10: tensor<4xf8E5M2> + // UNSUPPORTED_FLOAT-SAME: ) { + + func.func @float8_to_integer8( + %arg0: f8E5M2, // CHECK-NOT: f8E5M2 + %arg1: f8E4M3, // CHECK-NOT: f8E4M3 + %arg2: f8E4M3FN, // CHECK-NOT: f8E4M3FN + %arg3: f8E5M2FNUZ, // CHECK-NOT: f8E5M2FNUZ + %arg4: f8E4M3FNUZ, // CHECK-NOT: f8E4M3FNUZ + %arg5: f8E4M3B11FNUZ, // CHECK-NOT: f8E4M3B11FNUZ + %arg6: f8E3M4, // CHECK-NOT: f8E3M4 + %arg7: f8E8M0FNU, // CHECK-NOT: f8E8M0FNU + %arg8: vector<4xf8E4M3B11FNUZ>, // CHECK-NOT: vector<4xf8E4M3B11FNUZ> + %arg9: memref<8xf8E4M3, #spirv.storage_class>, // CHECK-NOT: memref + %arg10: tensor<4xf8E5M2> // CHECK-NOT: tensor + ) { + // CHECK: spirv.Return + return + } +}