Skip to content

[mlir][spirv] Add 8-bit float type emulation #148811

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

mshahneo
Copy link
Contributor

8-bit floats are not supported in SPIR-V. They are emulated as 8-bit integer during conversion.

Copy link

github-actions bot commented Jul 15, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

mshahneo added 5 commits July 15, 2025 09:38
SPIR-V does not support any 8-bit floats.
Threfore, 8-bit floats are emulated as 8-bit integers.
Handles scalar and vector.
This approach minimizes the code modification.
@mshahneo mshahneo force-pushed the spirv_data_emulation branch from ab237ea to 3755267 Compare July 15, 2025 10:02
@mshahneo mshahneo changed the title [SPIRV] 8-bit float type emulation [MLIR][SPIRV] 8-bit float type emulation Jul 15, 2025
@mshahneo mshahneo changed the title [MLIR][SPIRV] 8-bit float type emulation [mlir][spirv] 8-bit float type emulation Jul 15, 2025
@mshahneo mshahneo changed the title [mlir][spirv] 8-bit float type emulation [mlir][spirv] Add 8-bit float type emulation Jul 15, 2025
@mshahneo mshahneo marked this pull request as ready for review July 15, 2025 20:03
@llvmbot
Copy link
Member

llvmbot commented Jul 15, 2025

@llvm/pr-subscribers-mlir-spirv

Author: Md Abdullah Shahneous Bari (mshahneo)

Changes

8-bit floats are not supported in SPIR-V. They are emulated as 8-bit integer during conversion.


Full diff: https://github.com/llvm/llvm-project/pull/148811.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (+15-3)
  • (modified) mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h (+4)
  • (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+31-4)
  • (modified) mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp (+1)
  • (modified) mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp (+1)
  • (modified) mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp (+1)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+56)
  • (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir (+11)
  • (modified) mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir (+54)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 50c67da91a4af..0eb9720351027 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -196,6 +196,9 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> {
            "bool", /*default=*/"true",
            "Emulate narrower scalar types with 32-bit ones if not supported by "
            "the target">,
+    Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+    "bool", /*default=*/"true",
+    "Emulate unsupported float types by emulating them with integer types of same bit width">
   ];
 }
 
@@ -404,7 +407,10 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> {
     Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
            "bool", /*default=*/"true",
            "Emulate narrower scalar types with 32-bit ones if not supported by"
-           " the target">
+           " the target">,
+    Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+    "bool", /*default=*/"true",
+    "Emulate unsupported float types by emulating them with integer types of same bit width">
   ];
 }
 
@@ -488,7 +494,10 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> {
     Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
            "bool", /*default=*/"true",
            "Emulate narrower scalar types with 32-bit ones if not supported by"
-           " the target">
+           " the target">,
+    Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+    "bool", /*default=*/"true",
+    "Emulate unsupported float types by emulating them with integer types of same bit width">
   ];
 }
 
@@ -1151,7 +1160,10 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> {
     Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
            "bool", /*default=*/"true",
            "Emulate narrower scalar types with 32-bit ones if not supported by"
-           " the target">
+           " the target">,
+    Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+    "bool", /*default=*/"true",
+    "Emulate unsupported float types by emulating them with integer types of same bit width">
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 3d22ec918f4c5..03ae54a8ae30a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -39,6 +39,10 @@ struct SPIRVConversionOptions {
   /// The number of bits to store a boolean value.
   unsigned boolNumBits{8};
 
+  /// Whether to emulate unsupported floats with integer types of same bit
+  /// width.
+  bool emulateUnsupportedFloatTypes{true};
+
   /// How sub-byte values are storaged in memory.
   SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed};
 
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 434d7df853a5e..f42f779a69d33 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -99,6 +99,14 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
   return builder.getF32FloatAttr(dstVal.convertToFloat());
 }
 
+// Get IntegerAttr from FloatAttr.
+IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
+                                        ConversionPatternRewriter &rewriter) {
+  APFloat floatVal = floatAttr.getValue();
+  APInt intVal = floatVal.bitcastToAPInt();
+  return rewriter.getIntegerAttr(dstType, intVal);
+}
+
 /// Returns true if the given `type` is a boolean scalar or vector type.
 static bool isBoolScalarOrVector(Type type) {
   assert(type && "Not a valid type");
@@ -296,8 +304,18 @@ struct ConstantCompositeOpPattern final
       SmallVector<Attribute, 8> elements;
       if (isa<FloatType>(srcElemType)) {
         for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
-          FloatAttr dstAttr =
-              convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
+          Attribute dstAttr = nullptr;
+          // Handle 8-bit float conversion to 8-bit integer.
+          auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+          if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+              srcElemType.getIntOrFloatBitWidth() == 8 &&
+              isa<IntegerType>(dstElemType)) {
+            dstAttr =
+                getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
+          } else {
+            dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
+                                       rewriter);
+          }
           if (!dstAttr)
             return failure();
           elements.push_back(dstAttr);
@@ -361,11 +379,19 @@ struct ConstantScalarOpPattern final
     // Floating-point types.
     if (isa<FloatType>(srcType)) {
       auto srcAttr = cast<FloatAttr>(cstAttr);
-      auto dstAttr = srcAttr;
+      Attribute dstAttr = srcAttr;
 
       // Floating-point types not supported in the target environment are all
       // converted to float type.
-      if (srcType != dstType) {
+      auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+      if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+          srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
+          dstType.getIntOrFloatBitWidth() == 8) {
+        // If the source is an 8-bit float, convert it to a 8-bit integer.
+        dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
+        if (!dstAttr)
+          return failure();
+      } else if (srcType != dstType) {
         dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
         if (!dstAttr)
           return failure();
@@ -1351,6 +1377,7 @@ struct ConvertArithToSPIRVPass
 
     SPIRVConversionOptions options;
     options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+    options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
     SPIRVTypeConverter typeConverter(targetAttr, options);
 
     // Use UnrealizedConversionCast as the bridge so that we don't need to pull
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index 03f4bf4df4912..56b6181018153 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
 
   SPIRVConversionOptions options;
   options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+  options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
   SPIRVTypeConverter typeConverter(targetAttr, options);
 
   // TODO: We should also take care of block argument type conversion.
diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
index 8ed9f659afb10..c0439a4033eac 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
@@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
 
   SPIRVConversionOptions options;
   options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+  options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
   SPIRVTypeConverter typeConverter(targetAttr, options);
 
   RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
index f07386ea80124..8cd650e649008 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
@@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass
 
     SPIRVConversionOptions options;
     options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+    options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
     SPIRVTypeConverter typeConverter(targetAttr, options);
 
     RewritePatternSet patterns(context);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index f70b3325f8725..6b2580b6541f2 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -169,6 +169,7 @@ static spirv::ScalarType getIndexType(MLIRContext *ctx,
 // SPIR-V dialect. Keeping it local till the use case arises.
 static std::optional<int64_t>
 getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
+
   if (isa<spirv::ScalarType>(type)) {
     auto bitWidth = type.getIntOrFloatBitWidth();
     // According to the SPIR-V spec:
@@ -182,6 +183,15 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
     return bitWidth / 8;
   }
 
+  // Handle 8-bit floats.
+  if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
+    auto bitWidth = type.getIntOrFloatBitWidth();
+    if (bitWidth == 8)
+      return bitWidth / 8;
+    else
+      return std::nullopt;
+  }
+
   if (auto complexType = dyn_cast<ComplexType>(type)) {
     auto elementSize = getTypeNumBytes(options, complexType.getElementType());
     if (!elementSize)
@@ -318,6 +328,44 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
                           type.getSignedness());
 }
 
+/// Converts 8-bit float types to integer types with the same bit width.
+/// Returns a nullptr for unsupported 8-bit float types.
+static Type convert8BitFloatType(const SPIRVConversionOptions &options,
+                                 FloatType type) {
+  if (!options.emulateUnsupportedFloatTypes)
+    return nullptr;
+  // F8 types are converted to integer types with the same bit width.
+  if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+          Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+          Float8E8M0FNUType>(type))
+    return IntegerType::get(type.getContext(), type.getWidth());
+  LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type\n");
+  return nullptr;
+}
+
+/// Returns a type with the same shape but with any 8-bit float element type
+/// converted to the same bit width integer type. This is a noop when the
+/// element type is not the 8-bit float type or emulation flag is set to false.
+static ShapedType
+convertShaped8BitFloatType(ShapedType type,
+                           const SPIRVConversionOptions &options) {
+  if (!options.emulateUnsupportedFloatTypes)
+    return type;
+  auto srcElementType = type.getElementType();
+  Type convertedElementType = nullptr;
+  // F8 types are converted to integer types with the same bit width.
+  if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+          Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+          Float8E8M0FNUType>(srcElementType))
+    convertedElementType = IntegerType::get(
+        type.getContext(), srcElementType.getIntOrFloatBitWidth());
+
+  if (!convertedElementType)
+    return type;
+
+  return type.clone(convertedElementType);
+}
+
 /// Returns a type with the same shape but with any index element type converted
 /// to the matching integer type. This is a noop when the element type is not
 /// the index type.
@@ -337,6 +385,7 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
                   const SPIRVConversionOptions &options, VectorType type,
                   std::optional<spirv::StorageClass> storageClass = {}) {
   type = cast<VectorType>(convertIndexElementType(type, options));
+  type = cast<VectorType>(convertShaped8BitFloatType(type, options));
   auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
   if (!scalarType) {
     // If this is not a spec allowed scalar type, try to handle sub-byte integer
@@ -433,6 +482,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
   }
 
   type = cast<TensorType>(convertIndexElementType(type, options));
+  type = cast<TensorType>(convertShaped8BitFloatType(type, options));
   auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
   if (!scalarType) {
     LLVM_DEBUG(llvm::dbgs()
@@ -596,6 +646,10 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
   } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
     type = cast<MemRefType>(convertIndexElementType(type, options));
     arrayElemType = type.getElementType();
+  } else if (auto floatType = dyn_cast<FloatType>(elementType)) {
+    // Hnadle 8 bit float types.
+    type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
+    arrayElemType = type.getElementType();
   } else {
     LLVM_DEBUG(
         llvm::dbgs()
@@ -1439,6 +1493,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
   addConversion([this](FloatType floatType) -> std::optional<Type> {
     if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
       return convertScalarType(this->targetEnv, this->options, scalarType);
+    if (floatType.getWidth() == 8)
+      return convert8BitFloatType(this->options, floatType);
     return Type();
   });
 
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 1abe0fd2ec468..751e727534efe 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -559,6 +559,17 @@ func.func @constant() {
   return
 }
 
+// CHECK-LABEL: @constant_8bit_float
+func.func @constant_8bit_float() {
+  // CHECK: spirv.Constant 56 : i8
+  %cst = arith.constant 1.0 : f8E4M3
+  // CHECK: spirv.Constant dense<56> : vector<4xi8>
+  %cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3>
+  // CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
+  %cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2>
+  return
+}
+
 // CHECK-LABEL: @constant_16bit
 func.func @constant_16bit() {
   // CHECK: spirv.Constant 4 : i16
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 1737f4a906bf8..0c77c88334572 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -1,6 +1,8 @@
 // RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s
 // RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \
 // RUN:   FileCheck %s --check-prefix=NOEMU
+// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-unsupported-float-types=false" %s | \
+// RUN:   FileCheck %s --check-prefix=UNSUPPORTED_FLOAT
 
 //===----------------------------------------------------------------------===//
 // Integer types
@@ -944,3 +946,55 @@ func.func @unranked_tensor(%arg0: tensor<*xi32>) { return }
 func.func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return }
 
 } // end module
+
+
+// -----
+
+// Check that 8-bit float types are emulated as i8.
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8], []>, #spirv.resource_limits<>>
+} {
+
+  // CHECK: spirv.func @float8_to_integer8
+  // CHECK-SAME: (%arg0: i8
+  // CHECK-SAME: %arg1: i8
+  // CHECK-SAME: %arg2: i8
+  // CHECK-SAME: %arg3: i8
+  // CHECK-SAME: %arg4: i8
+  // CHECK-SAME: %arg5: i8
+  // CHECK-SAME: %arg6: i8
+  // CHECK-SAME: %arg7: i8
+  // CHECK-SAME: %arg8: vector<4xi8>
+  // CHECK-SAME: %arg9: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i8, stride=1> [0])>, StorageBuffer>
+  // CHECK-SAME: %arg10: !spirv.array<4 x i8>
+  // UNSUPPORTED_FLOAT-LABEL: func.func @float8_to_integer8
+  // UNSUPPORTED_FLOAT-SAME: (%arg0: f8E5M2
+  // UNSUPPORTED_FLOAT-SAME: %arg1: f8E4M3
+  // UNSUPPORTED_FLOAT-SAME: %arg2: f8E4M3FN
+  // UNSUPPORTED_FLOAT-SAME: %arg3: f8E5M2FNUZ
+  // UNSUPPORTED_FLOAT-SAME: %arg4: f8E4M3FNUZ
+  // UNSUPPORTED_FLOAT-SAME: %arg5: f8E4M3B11FNUZ
+  // UNSUPPORTED_FLOAT-SAME: %arg6: f8E3M4
+  // UNSUPPORTED_FLOAT-SAME: %arg7: f8E8M0FNU
+  // UNSUPPORTED_FLOAT-SAME: %arg8: vector<4xf8E4M3B11FNUZ>
+  // UNSUPPORTED_FLOAT-SAME: %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>
+  // UNSUPPORTED_FLOAT-SAME: %arg10: tensor<4xf8E5M2>
+  // UNSUPPORTED_FLOAT-SAME: ) {
+
+  func.func @float8_to_integer8(
+    %arg0: f8E5M2,                   // CHECK-NOT: f8E5M2
+    %arg1: f8E4M3,                   // CHECK-NOT: f8E4M3
+    %arg2: f8E4M3FN,                // CHECK-NOT: f8E4M3FN
+    %arg3: f8E5M2FNUZ,              // CHECK-NOT: f8E5M2FNUZ
+    %arg4: f8E4M3FNUZ,              // CHECK-NOT: f8E4M3FNUZ
+    %arg5: f8E4M3B11FNUZ,           // CHECK-NOT: f8E4M3B11FNUZ
+    %arg6: f8E3M4,                  // CHECK-NOT: f8E3M4
+    %arg7: f8E8M0FNU,               // CHECK-NOT: f8E8M0FNU
+    %arg8: vector<4xf8E4M3B11FNUZ>, // CHECK-NOT: vector<4xf8E4M3B11FNUZ>
+    %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>, // CHECK-NOT: memref
+    %arg10: tensor<4xf8E5M2>        // CHECK-NOT: tensor
+  ) {
+    // CHECK: spirv.Return
+    return
+  }
+}

@llvmbot
Copy link
Member

llvmbot commented Jul 15, 2025

@llvm/pr-subscribers-mlir

Author: Md Abdullah Shahneous Bari (mshahneo)

Changes

8-bit floats are not supported in SPIR-V. They are emulated as 8-bit integer during conversion.


Full diff: https://github.com/llvm/llvm-project/pull/148811.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (+15-3)
  • (modified) mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h (+4)
  • (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+31-4)
  • (modified) mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp (+1)
  • (modified) mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp (+1)
  • (modified) mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp (+1)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+56)
  • (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir (+11)
  • (modified) mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir (+54)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 50c67da91a4af..0eb9720351027 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -196,6 +196,9 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> {
            "bool", /*default=*/"true",
            "Emulate narrower scalar types with 32-bit ones if not supported by "
            "the target">,
+    Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+    "bool", /*default=*/"true",
+    "Emulate unsupported float types by emulating them with integer types of same bit width">
   ];
 }
 
@@ -404,7 +407,10 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> {
     Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
            "bool", /*default=*/"true",
            "Emulate narrower scalar types with 32-bit ones if not supported by"
-           " the target">
+           " the target">,
+    Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+    "bool", /*default=*/"true",
+    "Emulate unsupported float types by emulating them with integer types of same bit width">
   ];
 }
 
@@ -488,7 +494,10 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> {
     Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
            "bool", /*default=*/"true",
            "Emulate narrower scalar types with 32-bit ones if not supported by"
-           " the target">
+           " the target">,
+    Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+    "bool", /*default=*/"true",
+    "Emulate unsupported float types by emulating them with integer types of same bit width">
   ];
 }
 
@@ -1151,7 +1160,10 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> {
     Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
            "bool", /*default=*/"true",
            "Emulate narrower scalar types with 32-bit ones if not supported by"
-           " the target">
+           " the target">,
+    Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
+    "bool", /*default=*/"true",
+    "Emulate unsupported float types by emulating them with integer types of same bit width">
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 3d22ec918f4c5..03ae54a8ae30a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -39,6 +39,10 @@ struct SPIRVConversionOptions {
   /// The number of bits to store a boolean value.
   unsigned boolNumBits{8};
 
+  /// Whether to emulate unsupported floats with integer types of same bit
+  /// width.
+  bool emulateUnsupportedFloatTypes{true};
+
   /// How sub-byte values are storaged in memory.
   SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed};
 
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 434d7df853a5e..f42f779a69d33 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -99,6 +99,14 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
   return builder.getF32FloatAttr(dstVal.convertToFloat());
 }
 
+// Get IntegerAttr from FloatAttr.
+IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
+                                        ConversionPatternRewriter &rewriter) {
+  APFloat floatVal = floatAttr.getValue();
+  APInt intVal = floatVal.bitcastToAPInt();
+  return rewriter.getIntegerAttr(dstType, intVal);
+}
+
 /// Returns true if the given `type` is a boolean scalar or vector type.
 static bool isBoolScalarOrVector(Type type) {
   assert(type && "Not a valid type");
@@ -296,8 +304,18 @@ struct ConstantCompositeOpPattern final
       SmallVector<Attribute, 8> elements;
       if (isa<FloatType>(srcElemType)) {
         for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
-          FloatAttr dstAttr =
-              convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
+          Attribute dstAttr = nullptr;
+          // Handle 8-bit float conversion to 8-bit integer.
+          auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+          if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+              srcElemType.getIntOrFloatBitWidth() == 8 &&
+              isa<IntegerType>(dstElemType)) {
+            dstAttr =
+                getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
+          } else {
+            dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
+                                       rewriter);
+          }
           if (!dstAttr)
             return failure();
           elements.push_back(dstAttr);
@@ -361,11 +379,19 @@ struct ConstantScalarOpPattern final
     // Floating-point types.
     if (isa<FloatType>(srcType)) {
       auto srcAttr = cast<FloatAttr>(cstAttr);
-      auto dstAttr = srcAttr;
+      Attribute dstAttr = srcAttr;
 
       // Floating-point types not supported in the target environment are all
       // converted to float type.
-      if (srcType != dstType) {
+      auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+      if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+          srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
+          dstType.getIntOrFloatBitWidth() == 8) {
+        // If the source is an 8-bit float, convert it to a 8-bit integer.
+        dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
+        if (!dstAttr)
+          return failure();
+      } else if (srcType != dstType) {
         dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
         if (!dstAttr)
           return failure();
@@ -1351,6 +1377,7 @@ struct ConvertArithToSPIRVPass
 
     SPIRVConversionOptions options;
     options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+    options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
     SPIRVTypeConverter typeConverter(targetAttr, options);
 
     // Use UnrealizedConversionCast as the bridge so that we don't need to pull
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index 03f4bf4df4912..56b6181018153 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
 
   SPIRVConversionOptions options;
   options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+  options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
   SPIRVTypeConverter typeConverter(targetAttr, options);
 
   // TODO: We should also take care of block argument type conversion.
diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
index 8ed9f659afb10..c0439a4033eac 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
@@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
 
   SPIRVConversionOptions options;
   options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+  options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
   SPIRVTypeConverter typeConverter(targetAttr, options);
 
   RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
index f07386ea80124..8cd650e649008 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
@@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass
 
     SPIRVConversionOptions options;
     options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+    options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
     SPIRVTypeConverter typeConverter(targetAttr, options);
 
     RewritePatternSet patterns(context);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index f70b3325f8725..6b2580b6541f2 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -169,6 +169,7 @@ static spirv::ScalarType getIndexType(MLIRContext *ctx,
 // SPIR-V dialect. Keeping it local till the use case arises.
 static std::optional<int64_t>
 getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
+
   if (isa<spirv::ScalarType>(type)) {
     auto bitWidth = type.getIntOrFloatBitWidth();
     // According to the SPIR-V spec:
@@ -182,6 +183,15 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
     return bitWidth / 8;
   }
 
+  // Handle 8-bit floats.
+  if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
+    auto bitWidth = type.getIntOrFloatBitWidth();
+    if (bitWidth == 8)
+      return bitWidth / 8;
+    else
+      return std::nullopt;
+  }
+
   if (auto complexType = dyn_cast<ComplexType>(type)) {
     auto elementSize = getTypeNumBytes(options, complexType.getElementType());
     if (!elementSize)
@@ -318,6 +328,44 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
                           type.getSignedness());
 }
 
+/// Converts 8-bit float types to integer types with the same bit width.
+/// Returns a nullptr for unsupported 8-bit float types.
+static Type convert8BitFloatType(const SPIRVConversionOptions &options,
+                                 FloatType type) {
+  if (!options.emulateUnsupportedFloatTypes)
+    return nullptr;
+  // F8 types are converted to integer types with the same bit width.
+  if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+          Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+          Float8E8M0FNUType>(type))
+    return IntegerType::get(type.getContext(), type.getWidth());
+  LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type\n");
+  return nullptr;
+}
+
+/// Returns a type with the same shape but with any 8-bit float element type
+/// converted to the same bit width integer type. This is a noop when the
+/// element type is not the 8-bit float type or emulation flag is set to false.
+static ShapedType
+convertShaped8BitFloatType(ShapedType type,
+                           const SPIRVConversionOptions &options) {
+  if (!options.emulateUnsupportedFloatTypes)
+    return type;
+  auto srcElementType = type.getElementType();
+  Type convertedElementType = nullptr;
+  // F8 types are converted to integer types with the same bit width.
+  if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+          Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+          Float8E8M0FNUType>(srcElementType))
+    convertedElementType = IntegerType::get(
+        type.getContext(), srcElementType.getIntOrFloatBitWidth());
+
+  if (!convertedElementType)
+    return type;
+
+  return type.clone(convertedElementType);
+}
+
 /// Returns a type with the same shape but with any index element type converted
 /// to the matching integer type. This is a noop when the element type is not
 /// the index type.
@@ -337,6 +385,7 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
                   const SPIRVConversionOptions &options, VectorType type,
                   std::optional<spirv::StorageClass> storageClass = {}) {
   type = cast<VectorType>(convertIndexElementType(type, options));
+  type = cast<VectorType>(convertShaped8BitFloatType(type, options));
   auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
   if (!scalarType) {
     // If this is not a spec allowed scalar type, try to handle sub-byte integer
@@ -433,6 +482,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
   }
 
   type = cast<TensorType>(convertIndexElementType(type, options));
+  type = cast<TensorType>(convertShaped8BitFloatType(type, options));
   auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
   if (!scalarType) {
     LLVM_DEBUG(llvm::dbgs()
@@ -596,6 +646,10 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
   } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
     type = cast<MemRefType>(convertIndexElementType(type, options));
     arrayElemType = type.getElementType();
+  } else if (auto floatType = dyn_cast<FloatType>(elementType)) {
+    // Hnadle 8 bit float types.
+    type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
+    arrayElemType = type.getElementType();
   } else {
     LLVM_DEBUG(
         llvm::dbgs()
@@ -1439,6 +1493,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
   addConversion([this](FloatType floatType) -> std::optional<Type> {
     if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
       return convertScalarType(this->targetEnv, this->options, scalarType);
+    if (floatType.getWidth() == 8)
+      return convert8BitFloatType(this->options, floatType);
     return Type();
   });
 
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 1abe0fd2ec468..751e727534efe 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -559,6 +559,17 @@ func.func @constant() {
   return
 }
 
+// CHECK-LABEL: @constant_8bit_float
+func.func @constant_8bit_float() {
+  // CHECK: spirv.Constant 56 : i8
+  %cst = arith.constant 1.0 : f8E4M3
+  // CHECK: spirv.Constant dense<56> : vector<4xi8>
+  %cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3>
+  // CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
+  %cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2>
+  return
+}
+
 // CHECK-LABEL: @constant_16bit
 func.func @constant_16bit() {
   // CHECK: spirv.Constant 4 : i16
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 1737f4a906bf8..0c77c88334572 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -1,6 +1,8 @@
 // RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s
 // RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \
 // RUN:   FileCheck %s --check-prefix=NOEMU
+// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-unsupported-float-types=false" %s | \
+// RUN:   FileCheck %s --check-prefix=UNSUPPORTED_FLOAT
 
 //===----------------------------------------------------------------------===//
 // Integer types
@@ -944,3 +946,55 @@ func.func @unranked_tensor(%arg0: tensor<*xi32>) { return }
 func.func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return }
 
 } // end module
+
+
+// -----
+
+// Check that 8-bit float types are emulated as i8.
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8], []>, #spirv.resource_limits<>>
+} {
+
+  // CHECK: spirv.func @float8_to_integer8
+  // CHECK-SAME: (%arg0: i8
+  // CHECK-SAME: %arg1: i8
+  // CHECK-SAME: %arg2: i8
+  // CHECK-SAME: %arg3: i8
+  // CHECK-SAME: %arg4: i8
+  // CHECK-SAME: %arg5: i8
+  // CHECK-SAME: %arg6: i8
+  // CHECK-SAME: %arg7: i8
+  // CHECK-SAME: %arg8: vector<4xi8>
+  // CHECK-SAME: %arg9: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i8, stride=1> [0])>, StorageBuffer>
+  // CHECK-SAME: %arg10: !spirv.array<4 x i8>
+  // UNSUPPORTED_FLOAT-LABEL: func.func @float8_to_integer8
+  // UNSUPPORTED_FLOAT-SAME: (%arg0: f8E5M2
+  // UNSUPPORTED_FLOAT-SAME: %arg1: f8E4M3
+  // UNSUPPORTED_FLOAT-SAME: %arg2: f8E4M3FN
+  // UNSUPPORTED_FLOAT-SAME: %arg3: f8E5M2FNUZ
+  // UNSUPPORTED_FLOAT-SAME: %arg4: f8E4M3FNUZ
+  // UNSUPPORTED_FLOAT-SAME: %arg5: f8E4M3B11FNUZ
+  // UNSUPPORTED_FLOAT-SAME: %arg6: f8E3M4
+  // UNSUPPORTED_FLOAT-SAME: %arg7: f8E8M0FNU
+  // UNSUPPORTED_FLOAT-SAME: %arg8: vector<4xf8E4M3B11FNUZ>
+  // UNSUPPORTED_FLOAT-SAME: %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>
+  // UNSUPPORTED_FLOAT-SAME: %arg10: tensor<4xf8E5M2>
+  // UNSUPPORTED_FLOAT-SAME: ) {
+
+  func.func @float8_to_integer8(
+    %arg0: f8E5M2,                   // CHECK-NOT: f8E5M2
+    %arg1: f8E4M3,                   // CHECK-NOT: f8E4M3
+    %arg2: f8E4M3FN,                // CHECK-NOT: f8E4M3FN
+    %arg3: f8E5M2FNUZ,              // CHECK-NOT: f8E5M2FNUZ
+    %arg4: f8E4M3FNUZ,              // CHECK-NOT: f8E4M3FNUZ
+    %arg5: f8E4M3B11FNUZ,           // CHECK-NOT: f8E4M3B11FNUZ
+    %arg6: f8E3M4,                  // CHECK-NOT: f8E3M4
+    %arg7: f8E8M0FNU,               // CHECK-NOT: f8E8M0FNU
+    %arg8: vector<4xf8E4M3B11FNUZ>, // CHECK-NOT: vector<4xf8E4M3B11FNUZ>
+    %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>, // CHECK-NOT: memref
+    %arg10: tensor<4xf8E5M2>        // CHECK-NOT: tensor
+  ) {
+    // CHECK: spirv.Return
+    return
+  }
+}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants