Skip to content

Commit

Permalink
Integrate LLVM at llvm/llvm-project@aa65f93b71de
Browse files Browse the repository at this point in the history
  • Loading branch information
abhigunj committed Jan 29, 2025
1 parent 8993ef7 commit 5ce85df
Show file tree
Hide file tree
Showing 48 changed files with 4,014 additions and 58 deletions.
2 changes: 1 addition & 1 deletion BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1547,7 +1547,7 @@ gentbl_cc_library(
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "stablehlo/dialect/VhloAttrs.td",
td_file = "stablehlo/dialect/VhloEnums.td",
deps = [
":vhlo_ops_td_files",
],
Expand Down
4 changes: 2 additions & 2 deletions WORKSPACE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ workspace(name = "stablehlo")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

LLVM_COMMIT = "e2402615a5a76d46a433dfcc1de10b38a1263c9d"
LLVM_COMMIT = "aa65f93b71dee8cacb22be1957673c8be6a3ec24"

LLVM_SHA256 = "9c22349e1d38555b2f223e49951655f60c04c0c3467e0150aaf6c9f50484cc9f"
LLVM_SHA256 = "0a6046edb6a9834d5b912ec0e705dec91d39ee1b7b2fbb5930955d83d2090ff5"

http_archive(
name = "llvm-raw",
Expand Down
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
e2402615a5a76d46a433dfcc1de10b38a1263c9d
aa65f93b71dee8cacb22be1957673c8be6a3ec24
23 changes: 23 additions & 0 deletions stablehlo/dialect/AssemblyFormat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,29 @@ ParseResult parseCustomCallTarget(AsmParser& parser, StringAttr& target) {
return parser.parseSymbolName(target);
}

void printResultAccuracyAttr(AsmPrinter& odsPrinter, APFloat atol, APFloat rtol,
int64_t ulps, Attribute mode) {
odsPrinter << "<";
if (!atol.isZero()) {
odsPrinter << "atol = ";
odsPrinter.printFloat(atol);
odsPrinter << ", ";
}
if (!rtol.isZero()) {
odsPrinter << "rtol = ";
odsPrinter.printFloat(rtol);
odsPrinter << ", ";
}
if (ulps != 0) {
odsPrinter << "ulps = ";
odsPrinter << ulps;
odsPrinter << ", ";
}
odsPrinter << "mode = ";
odsPrinter.printAttribute(mode);
odsPrinter << ">";
}

void printTypeExtensions(BoundedAttrInterface attr, DialectAsmPrinter& os) {
os << "bounds<";
llvm::interleaveComma(attr.getBounds(), os,
Expand Down
59 changes: 59 additions & 0 deletions stablehlo/dialect/AssemblyFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,65 @@ ParseResult parseDotDimensionNumbers(AsmParser& parser, AttrTy& target) {
return success();
}

// ResultAccuracyAttr - Custom printing and parsing for ResultAccuracyAttr.
//
// ResultAccuractAttr ::= `<` OptAtolAccuracy OptRtolAccuracy
// OptUlpAccuracy ModeAccuracy `>`
// OptAtolAccuracy ::= `atol` `=` APFloat `, ` | eps
// OptRtolAccuracy ::= `rtol` `=` APFloat `, ` | eps
// OptUlpAccuracy ::= `ulps` `=` int64_t `, ` | eps
// ModeAccuracy ::= `mode` `=` ResultAccuracyModeAttr
void printResultAccuracyAttr(AsmPrinter& odsPrinter, APFloat atol, APFloat rtol,
int64_t ulps, Attribute mode);

template <typename AttrTy, typename ModeTy>
Attribute parseResultAccuracyAttr(AsmParser& parser, Type type) {
APFloat resultAtol = APFloat::getZero(APFloat::IEEEdouble());
APFloat resultRtol = APFloat::getZero(APFloat::IEEEdouble());
int64_t resultUlps = 0;

// Parse literal '<'
if (parser.parseLess()) return {};

// OptAtolAccuracy
if (succeeded(parser.parseOptionalKeyword("atol"))) {
double value;
if (parser.parseEqual() || parser.parseFloat(value) || parser.parseComma())
return {};
resultAtol = APFloat(value);
}

// OptRtolAccuracy
if (succeeded(parser.parseOptionalKeyword("rtol"))) {
double value;
if (parser.parseEqual() || parser.parseFloat(value) || parser.parseComma())
return {};
resultRtol = APFloat(value);
}

// OptUlpAccuracy
if (succeeded(parser.parseOptionalKeyword("ulps"))) {
int64_t value;
if (parser.parseEqual() || parser.parseInteger(value) ||
parser.parseComma())
return {};
resultUlps = value;
}

// ModeAccuracy
ModeTy modeAttr;
if (parser.parseKeyword("mode") || parser.parseEqual() ||
parser.parseAttribute(modeAttr)) {
return {};
}

// Parse literal '>'
if (parser.parseGreater()) return {};
return parser.getChecked<AttrTy>(
parser.getCurrentLocation(), parser.getContext(), resultAtol, resultRtol,
resultUlps, modeAttr);
}

} // namespace hlo
} // namespace mlir

Expand Down
17 changes: 17 additions & 0 deletions stablehlo/dialect/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -780,5 +780,22 @@ bool isValidQuantizedDimension(Type type) {
numScales == rankedType.getDimSize(quantDim));
}

bool hasSingleBoundedDimension(Type type) {
RankedTensorType rankedType = dyn_cast<RankedTensorType>(type);
auto boundedAttr =
dyn_cast_or_null<BoundedAttrInterface>(rankedType.getEncoding());
if (!boundedAttr) return false;

// Count if bounded attr size is not kDynamic
int64_t numBoundedDims = llvm::count_if(
boundedAttr.getBounds(),
[](int64_t bound) { return !ShapedType::isDynamic(bound); });
// Also check that there are only bounded dims and no unbounded dims.
int64_t numDynamicDims = llvm::count_if(
rankedType.getShape(),
[](int64_t bound) { return ShapedType::isDynamic(bound); });
return numBoundedDims == 1 && numDynamicDims == 1;
}

} // namespace hlo
} // namespace mlir
3 changes: 3 additions & 0 deletions stablehlo/dialect/Base.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ bool isValidStablehloQuantizedElementType(Type elementType);
// mentioned in the StableHLO specification.
bool isValidQuantizedDimension(Type type);

// Returns true if the given type has a single bounded dimension.
bool hasSingleBoundedDimension(Type type);

// TODO(zhouxin) Move type inference related methods to TypeInference.cpp

std::pair<int64_t, int64_t> inferConcatenatedDimAndBound(int64_t leftSize,
Expand Down
17 changes: 17 additions & 0 deletions stablehlo/dialect/Base.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@ def I32RankedTensor : RankedTensorOf<[I32]>;

def UI32RankedTensor : RankedTensorOf<[UI32]>;

//===----------------------------------------------------------------------===//
// HLO type constraints.
//===----------------------------------------------------------------------===//

// Note: Bounded dynamisms is largely unspecced and this feature needs more
// thoguht as it is adopted to modern frameworks. The current support is
// designed to allow existing TF programs to be representable in StableHLO and
// is subject to change as a formal design for boudned dynamism is developed.
def HLO_HasSingleBoundedDimensionPred
: CPred<"mlir::hlo::hasSingleBoundedDimension($_self)">;

def HLO_HasStaticOrSingleBoundedShapePred
: Or<[HasStaticShapePred, HLO_HasSingleBoundedDimensionPred]>;

//===----------------------------------------------------------------------===//
// HLO type definitions.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -267,6 +281,9 @@ def HLO_StaticShapeTensor : StaticShapeTensorOf<[
def HLO_StaticShapeTensorOrPerAxisQuantizedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt],
[IsValidQuantizedDimension, HasStaticShapePred], "statically shaped tensor">;

def HLO_StaticShapeTensorPerAxisQuantizedTensorOrBoundedTensor : RankedTensorOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt, HLO_PerAxisQuantizedInt],
[IsValidQuantizedDimension, HLO_HasStaticOrSingleBoundedShapePred], "statically shaped or single bounded dimension tensor">;

def HLO_StaticShapeTensorOrPerAxisQuantizedTensorOrToken : AnyTypeOf<[HLO_StaticShapeTensor, HLO_StaticShapeTensorOrPerAxisQuantizedTensor, HLO_Token]>;

def HLO_StaticShapeIntOrFpTensor : StaticShapeTensorOf<[HLO_Int, HLO_Float]>;
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ mlir_tablegen(VhloEnums.cpp.inc -gen-enum-defs)
set(LLVM_TARGET_DEFINITIONS VhloOps.td)
mlir_tablegen(VhloAttrs.h.inc -gen-attrdef-decls)
mlir_tablegen(VhloAttrs.cpp.inc -gen-attrdef-defs)
set(LLVM_TARGET_DEFINITIONS VhloAttrs.td)
set(LLVM_TARGET_DEFINITIONS VhloEnums.td)
mlir_tablegen(VhloAttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(VhloAttrInterfaces.cpp.inc -gen-attr-interface-defs)
set(LLVM_TARGET_DEFINITIONS VhloTypes.td)
Expand Down
15 changes: 15 additions & 0 deletions stablehlo/dialect/StablehloAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.

include "mlir/IR/OpBase.td"
include "mlir/IR/TensorEncoding.td"
include "stablehlo/dialect/StablehloTypes.td"

def StableHLO_Dims : ArrayRefParameter<"int64_t", "Dimension"> {
let parser = "parseDimSizes($_parser)";
Expand Down Expand Up @@ -209,4 +210,18 @@ def StableHLO_ConvDimensionNumbers : AttrDef<StableHLO_Dialect, "ConvDimensionNu
let hasCustomAssemblyFormat = 1;
}

def StableHLO_ResultAccuracyAttr : AttrDef<StableHLO_Dialect, "ResultAccuracy"> {
let mnemonic = "result_accuracy";
let summary = "The requested accuracy for transcendental unary ops.";
let parameters = (ins
"APFloat":$atol,
"APFloat":$rtol,
"int64_t":$ulps,
StableHLO_ResultAccuracyModeAttr:$mode
);
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;
let constBuilderCall = "ResultAccuracyAttr::get($_builder.getContext(), APFloat(0.0), APFloat(0.0), 0, ResultAccuracyModeAttr::get($_builder.getContext(), $0))";
}

#endif // STABLEHLO_DIALECT_STABLEHLO_ATTRS
86 changes: 79 additions & 7 deletions stablehlo/dialect/StablehloBytecode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <cstdint>
#include <memory>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
Expand Down Expand Up @@ -180,6 +181,18 @@ enum AttributeCode {
/// allowImpreciseAccumulation : svarint
/// }
kDotAlgorithmAttr = 15,

// ResultAccuracyModeAttr {
// mode: varint (encoded enum)
// }
kResultAccuracyModeAttr = 16,

// ResultAccuracyAttr {
// atol: APFloat
// rtol: APFloat
// ulps: svarint
// }
kResultAccuracyAttr = 17,
};

/// This enum contains marker codes used to indicate which type is
Expand Down Expand Up @@ -241,6 +254,10 @@ class StablehloBytecodeInterface : public BytecodeDialectInterface {
OutputOperandAliasAttr readOutputOperandAliasAttr(
DialectBytecodeReader &reader) const;
PrecisionAttr readPrecisionAttr(DialectBytecodeReader &reader) const;
ResultAccuracyAttr readResultAccuracyAttr(
DialectBytecodeReader &reader) const;
ResultAccuracyModeAttr readResultAccuracyModeAttr(
DialectBytecodeReader &reader) const;
RngAlgorithmAttr readRngAlgorithmAttr(DialectBytecodeReader &reader) const;
RngDistributionAttr readRngDistributionAttr(
DialectBytecodeReader &reader) const;
Expand All @@ -264,6 +281,8 @@ class StablehloBytecodeInterface : public BytecodeDialectInterface {
DialectBytecodeWriter &writer) const;
void write(OutputOperandAliasAttr attr, DialectBytecodeWriter &writer) const;
void write(PrecisionAttr attr, DialectBytecodeWriter &writer) const;
void write(ResultAccuracyAttr attr, DialectBytecodeWriter &writer) const;
void write(ResultAccuracyModeAttr attr, DialectBytecodeWriter &writer) const;
void write(RngAlgorithmAttr attr, DialectBytecodeWriter &writer) const;
void write(RngDistributionAttr attr, DialectBytecodeWriter &writer) const;
void write(ScatterDimensionNumbersAttr attr,
Expand Down Expand Up @@ -327,6 +346,10 @@ Attribute StablehloBytecodeInterface::readAttribute(
return readOutputOperandAliasAttr(reader);
case stablehlo_encoding::kPrecisionAttr:
return readPrecisionAttr(reader);
case stablehlo_encoding::kResultAccuracyAttr:
return readResultAccuracyAttr(reader);
case stablehlo_encoding::kResultAccuracyModeAttr:
return readResultAccuracyModeAttr(reader);
case stablehlo_encoding::kRngAlgorithmAttr:
return readRngAlgorithmAttr(reader);
case stablehlo_encoding::kRngDistributionAttr:
Expand All @@ -352,13 +375,13 @@ LogicalResult StablehloBytecodeInterface::writeAttribute(
.Case<ChannelHandleAttr, ComparisonDirectionAttr, ComparisonTypeAttr,
ConvDimensionNumbersAttr, DotAlgorithmAttr, DotDimensionNumbersAttr,
FftTypeAttr, GatherDimensionNumbersAttr, OutputOperandAliasAttr,
PrecisionAttr, RngAlgorithmAttr, RngDistributionAttr,
ScatterDimensionNumbersAttr, TransposeAttr, TypeExtensionsAttr>(
[&](auto attr) {
LOG_WRITE_CALL;
write(attr, writer);
return success();
})
PrecisionAttr, ResultAccuracyAttr, ResultAccuracyModeAttr,
RngAlgorithmAttr, RngDistributionAttr, ScatterDimensionNumbersAttr,
TransposeAttr, TypeExtensionsAttr>([&](auto attr) {
LOG_WRITE_CALL;
write(attr, writer);
return success();
})
.Default([&](Attribute) {
LOG_NOT_IMPLEMENTED;
return failure();
Expand Down Expand Up @@ -806,6 +829,55 @@ void StablehloBytecodeInterface::writeVersion(
}
}

//===----------------------------------------------------------------------===//
// ResultAccuracyModeAttr

ResultAccuracyModeAttr StablehloBytecodeInterface::readResultAccuracyModeAttr(
DialectBytecodeReader &reader) const {
LOG_READ_CALL;
return hlo::bytecode::readEnumAttribute<ResultAccuracyModeAttr>(
reader, getContext(),
[](uint32_t val) { return symbolizeResultAccuracyMode(val); });
}

void StablehloBytecodeInterface::write(ResultAccuracyModeAttr attr,
DialectBytecodeWriter &writer) const {
writer.writeVarInt(stablehlo_encoding::kResultAccuracyModeAttr);
hlo::bytecode::writeEnumAttribute<ResultAccuracyMode>(attr, writer);
}

//===----------------------------------------------------------------------===//
// ResultAccuracyAttr

ResultAccuracyAttr StablehloBytecodeInterface::readResultAccuracyAttr(
DialectBytecodeReader &reader) const {
LOG_READ_CALL;
FailureOr<APFloat> atol;
FailureOr<APFloat> rtol;
int64_t ulps;
ResultAccuracyModeAttr mode;
if (failed(atol =
reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) ||
failed(rtol =
reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) ||
failed(reader.readSignedVarInt(ulps)) ||
failed(reader.readAttribute(mode))) {
mlir::emitWarning(mlir::UnknownLoc::get(getContext()))
<< "failed to read APFloat for atol";
return ResultAccuracyAttr();
}
return ResultAccuracyAttr::get(getContext(), *atol, *rtol, ulps, mode);
}

void StablehloBytecodeInterface::write(ResultAccuracyAttr attr,
DialectBytecodeWriter &writer) const {
writer.writeVarInt(stablehlo_encoding::kResultAccuracyAttr);
writer.writeAPFloatWithKnownSemantics(attr.getAtol());
writer.writeAPFloatWithKnownSemantics(attr.getRtol());
writer.writeSignedVarInt(attr.getUlps());
writer.writeAttribute(attr.getMode());
}

} // namespace

void addBytecodeInterface(StablehloDialect *dialect) {
Expand Down
23 changes: 23 additions & 0 deletions stablehlo/dialect/StablehloEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,29 @@ def StableHLO_PrecisionAttr : EnumAttr<StableHLO_Dialect, StableHLO_Precision, "
def StableHLO_PrecisionConfigAttr:
TypedArrayAttrBase<StableHLO_PrecisionAttr, "Precision Config attribute">;

//===----------------------------------------------------------------------===//
// Result Accuracy enum definitions.
//===----------------------------------------------------------------------===//

def STABLEHLO_RESULT_ACCURACY_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>;
def STABLEHLO_RESULT_ACCURACY_HIGHEST : I32EnumAttrCase<"HIGHEST", 1>;
def STABLEHLO_RESULT_ACCURACY_TOLERANCE: I32EnumAttrCase<"TOLERANCE", 2>;

def StableHLO_ResultAccuracyMode : I32EnumAttr<"ResultAccuracyMode",
"XLA result accuracy mode.",
[
STABLEHLO_RESULT_ACCURACY_DEFAULT,
STABLEHLO_RESULT_ACCURACY_HIGHEST,
STABLEHLO_RESULT_ACCURACY_TOLERANCE
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::stablehlo";
}

def StableHLO_ResultAccuracyModeAttr : EnumAttr<StableHLO_Dialect, StableHLO_ResultAccuracyMode, "result_accuracy_mode"> {
let assemblyFormat = "`<` $value `>`";
}

//===----------------------------------------------------------------------===//
// Fast Fourier Transform Type enum definitions.
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit 5ce85df

Please sign in to comment.