Skip to content

[mlir][spirv] Add 8-bit float type emulation #148811

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> {
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by "
"the target">,
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
"bool", /*default=*/"true",
"Emulate unsupported float types by emulating them with integer types of same bit width">
];
}

Expand Down Expand Up @@ -404,7 +407,10 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by"
" the target">
" the target">,
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
"bool", /*default=*/"true",
"Emulate unsupported float types by emulating them with integer types of same bit width">
];
}

Expand Down Expand Up @@ -488,7 +494,10 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by"
" the target">
" the target">,
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
"bool", /*default=*/"true",
"Emulate unsupported float types by emulating them with integer types of same bit width">
];
}

Expand Down Expand Up @@ -1151,7 +1160,10 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by"
" the target">
" the target">,
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
"bool", /*default=*/"true",
"Emulate unsupported float types by emulating them with integer types of same bit width">
];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ struct SPIRVConversionOptions {
/// The number of bits to store a boolean value.
unsigned boolNumBits{8};

/// Whether to emulate unsupported floats with integer types of same bit
/// width.
bool emulateUnsupportedFloatTypes{true};

/// How sub-byte values are storaged in memory.
SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed};

Expand Down
35 changes: 31 additions & 4 deletions mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
return builder.getF32FloatAttr(dstVal.convertToFloat());
}

// Get IntegerAttr from FloatAttr.
IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
ConversionPatternRewriter &rewriter) {
APFloat floatVal = floatAttr.getValue();
APInt intVal = floatVal.bitcastToAPInt();
return rewriter.getIntegerAttr(dstType, intVal);
}

/// Returns true if the given `type` is a boolean scalar or vector type.
static bool isBoolScalarOrVector(Type type) {
assert(type && "Not a valid type");
Expand Down Expand Up @@ -296,8 +304,18 @@ struct ConstantCompositeOpPattern final
SmallVector<Attribute, 8> elements;
if (isa<FloatType>(srcElemType)) {
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
FloatAttr dstAttr =
convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
Attribute dstAttr = nullptr;
// Handle 8-bit float conversion to 8-bit integer.
auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
srcElemType.getIntOrFloatBitWidth() == 8 &&
isa<IntegerType>(dstElemType)) {
dstAttr =
getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
} else {
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
rewriter);
}
if (!dstAttr)
return failure();
elements.push_back(dstAttr);
Expand Down Expand Up @@ -361,11 +379,19 @@ struct ConstantScalarOpPattern final
// Floating-point types.
if (isa<FloatType>(srcType)) {
auto srcAttr = cast<FloatAttr>(cstAttr);
auto dstAttr = srcAttr;
Attribute dstAttr = srcAttr;

// Floating-point types not supported in the target environment are all
// converted to float type.
if (srcType != dstType) {
auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
dstType.getIntOrFloatBitWidth() == 8) {
// If the source is an 8-bit float, convert it to a 8-bit integer.
dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
if (!dstAttr)
return failure();
} else if (srcType != dstType) {
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
if (!dstAttr)
return failure();
Expand Down Expand Up @@ -1351,6 +1377,7 @@ struct ConvertArithToSPIRVPass

SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);

// Use UnrealizedConversionCast as the bridge so that we don't need to pull
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {

SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);

// TODO: We should also take care of block argument type conversion.
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() {

SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);

RewritePatternSet patterns(context);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass

SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);

RewritePatternSet patterns(context);
Expand Down
56 changes: 56 additions & 0 deletions mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ static spirv::ScalarType getIndexType(MLIRContext *ctx,
// SPIR-V dialect. Keeping it local till the use case arises.
static std::optional<int64_t>
getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {

if (isa<spirv::ScalarType>(type)) {
auto bitWidth = type.getIntOrFloatBitWidth();
// According to the SPIR-V spec:
Expand All @@ -182,6 +183,15 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
return bitWidth / 8;
}

// Handle 8-bit floats.
if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
auto bitWidth = type.getIntOrFloatBitWidth();
if (bitWidth == 8)
return bitWidth / 8;
else
return std::nullopt;
}

if (auto complexType = dyn_cast<ComplexType>(type)) {
auto elementSize = getTypeNumBytes(options, complexType.getElementType());
if (!elementSize)
Expand Down Expand Up @@ -318,6 +328,44 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
type.getSignedness());
}

/// Converts 8-bit float types to integer types with the same bit width.
/// Returns a nullptr for unsupported 8-bit float types.
static Type convert8BitFloatType(const SPIRVConversionOptions &options,
FloatType type) {
if (!options.emulateUnsupportedFloatTypes)
return nullptr;
// F8 types are converted to integer types with the same bit width.
if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
Float8E8M0FNUType>(type))
return IntegerType::get(type.getContext(), type.getWidth());
LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type\n");
return nullptr;
}

/// Returns a type with the same shape but with any 8-bit float element type
/// converted to the same bit width integer type. This is a noop when the
/// element type is not the 8-bit float type or emulation flag is set to false.
static ShapedType
convertShaped8BitFloatType(ShapedType type,
const SPIRVConversionOptions &options) {
if (!options.emulateUnsupportedFloatTypes)
return type;
auto srcElementType = type.getElementType();
Type convertedElementType = nullptr;
// F8 types are converted to integer types with the same bit width.
if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
Float8E8M0FNUType>(srcElementType))
convertedElementType = IntegerType::get(
type.getContext(), srcElementType.getIntOrFloatBitWidth());

if (!convertedElementType)
return type;

return type.clone(convertedElementType);
}

/// Returns a type with the same shape but with any index element type converted
/// to the matching integer type. This is a noop when the element type is not
/// the index type.
Expand All @@ -337,6 +385,7 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
const SPIRVConversionOptions &options, VectorType type,
std::optional<spirv::StorageClass> storageClass = {}) {
type = cast<VectorType>(convertIndexElementType(type, options));
type = cast<VectorType>(convertShaped8BitFloatType(type, options));
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) {
// If this is not a spec allowed scalar type, try to handle sub-byte integer
Expand Down Expand Up @@ -433,6 +482,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
}

type = cast<TensorType>(convertIndexElementType(type, options));
type = cast<TensorType>(convertShaped8BitFloatType(type, options));
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) {
LLVM_DEBUG(llvm::dbgs()
Expand Down Expand Up @@ -596,6 +646,10 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
} else if (auto indexType = dyn_cast<IndexType>(elementType)) {
type = cast<MemRefType>(convertIndexElementType(type, options));
arrayElemType = type.getElementType();
} else if (auto floatType = dyn_cast<FloatType>(elementType)) {
// Hnadle 8 bit float types.
type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
arrayElemType = type.getElementType();
} else {
LLVM_DEBUG(
llvm::dbgs()
Expand Down Expand Up @@ -1439,6 +1493,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
addConversion([this](FloatType floatType) -> std::optional<Type> {
if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
return convertScalarType(this->targetEnv, this->options, scalarType);
if (floatType.getWidth() == 8)
return convert8BitFloatType(this->options, floatType);
return Type();
});

Expand Down
11 changes: 11 additions & 0 deletions mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,17 @@ func.func @constant() {
return
}

// CHECK-LABEL: @constant_8bit_float
func.func @constant_8bit_float() {
// CHECK: spirv.Constant 56 : i8
%cst = arith.constant 1.0 : f8E4M3
// CHECK: spirv.Constant dense<56> : vector<4xi8>
%cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3>
// CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
%cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2>
return
}

// CHECK-LABEL: @constant_16bit
func.func @constant_16bit() {
// CHECK: spirv.Constant 4 : i16
Expand Down
54 changes: 54 additions & 0 deletions mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s
// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \
// RUN: FileCheck %s --check-prefix=NOEMU
// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-unsupported-float-types=false" %s | \
// RUN: FileCheck %s --check-prefix=UNSUPPORTED_FLOAT

//===----------------------------------------------------------------------===//
// Integer types
Expand Down Expand Up @@ -944,3 +946,55 @@ func.func @unranked_tensor(%arg0: tensor<*xi32>) { return }
func.func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return }

} // end module


// -----

// Check that 8-bit float types are emulated as i8.
module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8], []>, #spirv.resource_limits<>>
} {

// CHECK: spirv.func @float8_to_integer8
// CHECK-SAME: (%arg0: i8
// CHECK-SAME: %arg1: i8
// CHECK-SAME: %arg2: i8
// CHECK-SAME: %arg3: i8
// CHECK-SAME: %arg4: i8
// CHECK-SAME: %arg5: i8
// CHECK-SAME: %arg6: i8
// CHECK-SAME: %arg7: i8
// CHECK-SAME: %arg8: vector<4xi8>
// CHECK-SAME: %arg9: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i8, stride=1> [0])>, StorageBuffer>
// CHECK-SAME: %arg10: !spirv.array<4 x i8>
// UNSUPPORTED_FLOAT-LABEL: func.func @float8_to_integer8
// UNSUPPORTED_FLOAT-SAME: (%arg0: f8E5M2
// UNSUPPORTED_FLOAT-SAME: %arg1: f8E4M3
// UNSUPPORTED_FLOAT-SAME: %arg2: f8E4M3FN
// UNSUPPORTED_FLOAT-SAME: %arg3: f8E5M2FNUZ
// UNSUPPORTED_FLOAT-SAME: %arg4: f8E4M3FNUZ
// UNSUPPORTED_FLOAT-SAME: %arg5: f8E4M3B11FNUZ
// UNSUPPORTED_FLOAT-SAME: %arg6: f8E3M4
// UNSUPPORTED_FLOAT-SAME: %arg7: f8E8M0FNU
// UNSUPPORTED_FLOAT-SAME: %arg8: vector<4xf8E4M3B11FNUZ>
// UNSUPPORTED_FLOAT-SAME: %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>
// UNSUPPORTED_FLOAT-SAME: %arg10: tensor<4xf8E5M2>
// UNSUPPORTED_FLOAT-SAME: ) {

func.func @float8_to_integer8(
%arg0: f8E5M2, // CHECK-NOT: f8E5M2
%arg1: f8E4M3, // CHECK-NOT: f8E4M3
%arg2: f8E4M3FN, // CHECK-NOT: f8E4M3FN
%arg3: f8E5M2FNUZ, // CHECK-NOT: f8E5M2FNUZ
%arg4: f8E4M3FNUZ, // CHECK-NOT: f8E4M3FNUZ
%arg5: f8E4M3B11FNUZ, // CHECK-NOT: f8E4M3B11FNUZ
%arg6: f8E3M4, // CHECK-NOT: f8E3M4
%arg7: f8E8M0FNU, // CHECK-NOT: f8E8M0FNU
%arg8: vector<4xf8E4M3B11FNUZ>, // CHECK-NOT: vector<4xf8E4M3B11FNUZ>
%arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>, // CHECK-NOT: memref
%arg10: tensor<4xf8E5M2> // CHECK-NOT: tensor
) {
// CHECK: spirv.Return
return
}
}
Loading