
Commit 3bf4e4b

[Torch] Fold aten.to.dtype on splat constants.
This commit teaches `AtenToDtypeOp::fold` to constant-fold dtype conversions when the operand is a splat `DenseElementsAttr`. Folding follows Torch's rounding behavior:

* Bool: 0 and -0.0 → false; nonzero/NaN/±Inf → true.
* Float → Int: round toward zero.
* Int → Float: sign-aware, rmNearestTiesToEven.
* Float ↔ Float: convert using the builtin `mlir::FloatType::getFloatSemantics()`.
* Int ↔ Int: `zextOrTrunc` / `sextOrTrunc` based on source signedness.

Folding is only performed when `non_blocking == false`, `copy == false`, and `memory_format` is None.
1 parent 4f572c5 commit 3bf4e4b
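
The rounding rules above map directly onto LLVM's APFloat/APSInt APIs, which are the same calls the fold below relies on. As a rough standalone sketch — not part of this commit, with values, variable names, and printed results chosen purely for illustration — the conversions behave like this:

// Illustrative sketch of the conversion rules, using LLVM's APFloat/APSInt
// directly rather than the folder itself.
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  // Float -> int rounds toward zero: 3.14159 -> 3 (and -2.9 would give -2).
  APFloat pi(3.14159f);
  APSInt toInt(/*BitWidth=*/32, /*isUnsigned=*/false);
  bool isExact = false;
  pi.convertToInteger(toInt, APFloat::rmTowardZero, &isExact);
  outs() << "float->int32: " << toInt.getSExtValue() << "\n"; // 3

  // Int -> float is sign-aware and uses rmNearestTiesToEven.
  APFloat fromInt(APFloat::IEEEdouble());
  fromInt.convertFromAPInt(APInt(32, -1000, /*isSigned=*/true),
                           /*IsSigned=*/true, APFloat::rmNearestTiesToEven);
  outs() << "int32->f64: " << fromInt.convertToDouble() << "\n"; // -1000.0

  // Any -> bool: +/-0.0 is false; any nonzero value (NaN and +/-Inf too) is true.
  outs() << "float->bool: " << !pi.isZero() << "\n"; // 1
  return 0;
}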

3 files changed: 160 additions, 21 deletions

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 83 additions & 10 deletions
@@ -892,26 +892,99 @@ OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) {
   // The non_blocking arg must be `False`.
   if (!matchPattern(getNonBlocking(), m_TorchConstantBool(&nonBlocking)) ||
       nonBlocking)
-    return nullptr;
+    return {};
   // The copy arg must be `False`.
   if (!matchPattern(getCopy(), m_TorchConstantBool(&copyArg)) || copyArg)
-    return nullptr;
+    return {};
   // The memory_format arg must be `none`.
   if (!isa<Torch::NoneType>(getMemoryFormat().getType()))
-    return nullptr;
+    return {};
 
   auto inputType = cast<BaseTensorType>(getSelf().getType());
   auto resType = cast<BaseTensorType>(getType());
-  // If the types aren't equal, then we can't fold.
-  if (inputType != resType)
-    return nullptr;
+
+  // Fold when both the input tensor and result are of the same type.
   // If the type does not have a statically known dtype, then we cannot fold.
   // For example, folding `tensor<*,unk>` to `tensor<*,unk>` would be wrong,
   // since the `unk` could be dynamically different for the operand and result.
-  if (!inputType.hasDtype())
-    return nullptr;
-  // Fold when both the input tensor and result are of the same type.
-  return getOperand(0);
+  if (inputType == resType && inputType.hasDtype())
+    return getOperand(0);
+
+  // Fold conversion of splat values.
+  auto elems = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
+  if (!elems || !elems.isSplat())
+    return {};
+
+  auto outVTy = dyn_cast<ValueTensorType>(getType());
+  if (!outVTy)
+    return {};
+
+  auto outShaped = outVTy.toBuiltinTensor();
+  if (!outShaped.hasStaticShape())
+    return {};
+
+  Type srcEltTy = inputType.getDtype();
+  Type dstEltTy = outVTy.getDtype();
+
+  // Handle integer destination.
+  if (auto dstI = dyn_cast<IntegerType>(dstEltTy)) {
+    // any -> bool(i1).
+    if (dstI.isSignlessInteger(1)) {
+      bool truthy = false;
+      if (isa<mlir::FloatType>(srcEltTy)) {
+        const APFloat &floatVal = elems.getSplatValue<APFloat>();
+        truthy = !floatVal.isZero();
+      } else {
+        const APInt &intVal = elems.getSplatValue<APInt>();
+        truthy = !intVal.isZero();
+      }
+      return DenseElementsAttr::get(outShaped, APInt(/*numBits=*/1, truthy));
+    }
+    // float -> intN
+    if (auto srcF = dyn_cast<mlir::FloatType>(srcEltTy)) {
+      APSInt result(dstI.getWidth(), /*isUnsigned=*/dstI.isUnsignedInteger());
+      bool isExact = false;
+      APFloat f = elems.getSplatValue<APFloat>();
+      APFloat::opStatus st =
+          f.convertToInteger(result, APFloat::rmTowardZero, &isExact);
+      if (st == APFloat::opOK || st == APFloat::opInexact)
+        return DenseElementsAttr::get(outShaped, APInt(result));
+      return {}; // NaN/Inf/out-of-range: preserve runtime semantics.
+    }
+    // intM -> intN
+    const APInt &v = elems.getSplatValue<APInt>();
+    APInt casted = cast<IntegerType>(srcEltTy).isUnsignedInteger()
+                       ? v.zextOrTrunc(dstI.getWidth())
+                       : v.sextOrTrunc(dstI.getWidth());
+    return DenseElementsAttr::get(outShaped, casted);
+  }
+
+  // Handle float destination.
+  if (auto dstF = dyn_cast<mlir::FloatType>(dstEltTy)) {
+    const llvm::fltSemantics &dstSem = dstF.getFloatSemantics();
+
+    // int -> float
+    if (auto srcI = dyn_cast<IntegerType>(srcEltTy)) {
+      APFloat f(dstSem);
+      APFloat::opStatus st = f.convertFromAPInt(
+          elems.getSplatValue<APInt>(),
+          /*isSigned=*/!srcI.isUnsignedInteger(), APFloat::rmNearestTiesToEven);
+      if (st == APFloat::opOK || st == APFloat::opInexact)
+        return DenseElementsAttr::get(outShaped, f);
+      return {};
+    }
+
+    // floatX -> floatY
+    APFloat f = elems.getSplatValue<APFloat>();
+    bool losesInfo = false;
+    APFloat::opStatus st =
+        f.convert(dstSem, APFloat::rmNearestTiesToEven, &losesInfo);
+    if (st == APFloat::opOK || st == APFloat::opInexact)
+      return DenseElementsAttr::get(outShaped, f);
+    return {};
+  }
+
+  return {};
 }
 
 //===----------------------------------------------------------------------===//

test/Dialect/Torch/canonicalize.mlir

Lines changed: 72 additions & 0 deletions
@@ -1762,6 +1762,78 @@ func.func @torch.aten.to.dtype$no_fold$unk_dtype(%arg0: !torch.tensor) -> !torch
   return %0 : !torch.tensor
 }
 
+// CHECK-LABEL: @torch.aten.to.dtype$fold_splat(
+func.func @torch.aten.to.dtype$fold_splat() -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[4,4],si32>, !torch.vtensor<[10],si32>, !torch.vtensor<[5,5],f64>, !torch.vtensor<[3,3],f16>, !torch.vtensor<[2,2],bf16>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si16>) {
+  %false = torch.constant.bool false
+  %none = torch.constant.none
+
+  // int32 splat → float32
+  %int_splat = torch.vtensor.literal(dense<42> : tensor<2x3xsi32>) : !torch.vtensor<[2,3],si32>
+  %int6 = torch.constant.int 6 // torch.float32
+  // CHECK: %[[R1:.*]] = torch.vtensor.literal({{.*}} : tensor<2x3xf32>) : !torch.vtensor<[2,3],f32>
+  %result1 = torch.aten.to.dtype %int_splat, %int6, %false, %false, %none
+      : !torch.vtensor<[2,3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+      -> !torch.vtensor<[2,3],f32>
+
+  // float32 splat → int32 (rmTowardZero)
+  %float_splat = torch.vtensor.literal(dense<3.14159> : tensor<4x4xf32>) : !torch.vtensor<[4,4],f32>
+  %int3 = torch.constant.int 3 // torch.int32
+  // CHECK: %[[R2:.*]] = torch.vtensor.literal(dense<3> : tensor<4x4xsi32>) : !torch.vtensor<[4,4],si32>
+  %result2 = torch.aten.to.dtype %float_splat, %int3, %false, %false, %none
+      : !torch.vtensor<[4,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+      -> !torch.vtensor<[4,4],si32>
+
+  // int64 splat (max int32) → int32 (trunc)
+  %int64_splat = torch.vtensor.literal(dense<2147483647> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
+  // CHECK: %[[R3:.*]] = torch.vtensor.literal(dense<2147483647> : tensor<10xsi32>) : !torch.vtensor<[10],si32>
+  %result3 = torch.aten.to.dtype %int64_splat, %int3, %false, %false, %none
+      : !torch.vtensor<[10],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none
+      -> !torch.vtensor<[10],si32>
+
+  // float32 splat → float64
+  %float32_splat = torch.vtensor.literal(dense<2.71828> : tensor<5x5xf32>) : !torch.vtensor<[5,5],f32>
+  %int7 = torch.constant.int 7 // torch.float64
+  // CHECK: %[[R4:.*]] = torch.vtensor.literal({{.*}} : tensor<5x5xf64>) : !torch.vtensor<[5,5],f64>
+  %result4 = torch.aten.to.dtype %float32_splat, %int7, %false, %false, %none
+      : !torch.vtensor<[5,5],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+      -> !torch.vtensor<[5,5],f64>
+
+  // float64 splat → float16
+  %float64_splat = torch.vtensor.literal(dense<1.23456> : tensor<3x3xf64>) : !torch.vtensor<[3,3],f64>
+  %int5 = torch.constant.int 5 // torch.float16
+  // CHECK: %[[R5:.*]] = torch.vtensor.literal({{.*}} : tensor<3x3xf16>) : !torch.vtensor<[3,3],f16>
+  %result5 = torch.aten.to.dtype %float64_splat, %int5, %false, %false, %none
+      : !torch.vtensor<[3,3],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none
+      -> !torch.vtensor<[3,3],f16>
+
+  // float32 splat → bfloat16
+  %float32_bf16 = torch.vtensor.literal(dense<-0.5> : tensor<2x2xf32>) : !torch.vtensor<[2,2],f32>
+  %int15 = torch.constant.int 15 // torch.bfloat16
+  // CHECK: %[[R6:.*]] = torch.vtensor.literal({{.*}} : tensor<2x2xbf16>) : !torch.vtensor<[2,2],bf16>
+  %result6 = torch.aten.to.dtype %float32_bf16, %int15, %false, %false, %none
+      : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+      -> !torch.vtensor<[2,2],bf16>
+
+  // int32 splat → int64 (sign-extend)
+  %int32_ext = torch.vtensor.literal(dense<-1000> : tensor<4xsi32>) : !torch.vtensor<[4],si32>
+  %int4 = torch.constant.int 4 // torch.int64
+  // CHECK: %[[R7:.*]] = torch.vtensor.literal(dense<-1000> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %result7 = torch.aten.to.dtype %int32_ext, %int4, %false, %false, %none
+      : !torch.vtensor<[4],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+      -> !torch.vtensor<[4],si64>
+
+  // int32 splat → int16 (trunc)
+  %int32_trunc = torch.vtensor.literal(dense<32000> : tensor<3xsi32>) : !torch.vtensor<[3],si32>
+  %int2 = torch.constant.int 2 // torch.int16
+  // CHECK: %[[R8:.*]] = torch.vtensor.literal(dense<32000> : tensor<3xsi16>) : !torch.vtensor<[3],si16>
+  %result8 = torch.aten.to.dtype %int32_trunc, %int2, %false, %false, %none
+      : !torch.vtensor<[3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+      -> !torch.vtensor<[3],si16>
+
+  return %result1, %result2, %result3, %result4, %result5, %result6, %result7, %result8
+      : !torch.vtensor<[2,3],f32>, !torch.vtensor<[4,4],si32>, !torch.vtensor<[10],si32>, !torch.vtensor<[5,5],f64>, !torch.vtensor<[3,3],f16>, !torch.vtensor<[2,2],bf16>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si16>
+}
+
 // CHECK-LABEL: func.func @torch.aten.to.other$basic(
 // CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {
 // CHECK: %[[NONE:.*]] = torch.constant.none

test/Dialect/Torch/decompose-complex-ops.mlir

Lines changed: 5 additions & 11 deletions
@@ -159,21 +159,15 @@ func.func @torch.aten.fmod_int(%arg0: !torch.vtensor<[?],si32>, %arg1: !torch.vt
 
 // CHECK: func.func @torch.aten.fmod_float(%[[ARG0:.+]]: !torch.vtensor<[?],f16>, %[[ARG1:.+]]: !torch.vtensor<[1],f16>) -> !torch.vtensor<[?],f16> {
 // CHECK: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00
-// CHECK: %[[V0:.+]] = torch.vtensor.literal(dense<-1> : tensor<si64>) : !torch.vtensor<[],si64>
-// CHECK: %[[V1:.+]] = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
-// CHECK: %[[NONE:.+]] = torch.constant.none
-// CHECK: %[[FALSE:.+]] = torch.constant.bool false
-// CHECK: %[[INT5:.+]] = torch.constant.int 5
-// CHECK: %[[V2:.+]] = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
+// CHECK: %[[V0:.+]] = torch.vtensor.literal(dense<-1.0{{.*}}> : tensor<f16>) : !torch.vtensor<[],f16>
+// CHECK: %[[V1:.+]] = torch.vtensor.literal(dense<0.0{{.*}}> : tensor<f16>) : !torch.vtensor<[],f16>
+// CHECK: %[[V2:.+]] = torch.vtensor.literal(dense<1.0{{.*}}> : tensor<f16>) : !torch.vtensor<[],f16>
 // CHECK: %[[INT0:.+]] = torch.constant.int 0
 // CHECK: %[[V3:.+]] = torch.aten.div.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[?],f16>, !torch.vtensor<[1],f16> -> !torch.vtensor<[?],f16>
 // CHECK: %[[V4:.+]] = torch.aten.gt.Scalar %[[V3]], %[[INT0]] : !torch.vtensor<[?],f16>, !torch.int -> !torch.vtensor<[?],i1>
 // CHECK: %[[V5:.+]] = torch.aten.lt.Scalar %[[V3]], %[[INT0]] : !torch.vtensor<[?],f16>, !torch.int -> !torch.vtensor<[?],i1>
-// CHECK: %[[V6:.+]] = torch.aten.to.dtype %[[V2]], %[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f16>
-// CHECK: %[[V7:.+]] = torch.aten.to.dtype %[[V1]], %[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f16>
-// CHECK: %[[V8:.+]] = torch.aten.where.self %[[V4]], %[[V6]], %[[V7]] : !torch.vtensor<[?],i1>, !torch.vtensor<[],f16>, !torch.vtensor<[],f16> -> !torch.vtensor<[?],f16>
-// CHECK: %[[V9:.+]] = torch.aten.to.dtype %[[V0]], %[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f16>
-// CHECK: %[[V10:.+]] = torch.aten.where.self %[[V5]], %[[V9]], %[[V8]] : !torch.vtensor<[?],i1>, !torch.vtensor<[],f16>, !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
+// CHECK: %[[V8:.+]] = torch.aten.where.self %[[V4]], %[[V2]], %[[V1]] : !torch.vtensor<[?],i1>, !torch.vtensor<[],f16>, !torch.vtensor<[],f16> -> !torch.vtensor<[?],f16>
+// CHECK: %[[V10:.+]] = torch.aten.where.self %[[V5]], %[[V0]], %[[V8]] : !torch.vtensor<[?],i1>, !torch.vtensor<[],f16>, !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
 // CHECK: %[[V11:.+]] = torch.aten.abs %[[V3]] : !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
 // CHECK: %[[V12:.+]] = torch.aten.floor %[[V11]] : !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
 // CHECK: %[[V13:.+]] = torch.aten.mul.Tensor %[[V10]], %[[V12]] : !torch.vtensor<[?],f16>, !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16>
