@@ -1762,6 +1762,78 @@ func.func @torch.aten.to.dtype$no_fold$unk_dtype(%arg0: !torch.tensor) -> !torch
1762
1762
return %0 : !torch.tensor
1763
1763
}
1764
1764
1765
// Verifies constant folding of torch.aten.to.dtype on splat vtensor literals
// across int<->float, widening, and truncating dtype conversions.
// CHECK-LABEL: @torch.aten.to.dtype$fold_splat(
func.func @torch.aten.to.dtype$fold_splat() -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[4,4],si32>, !torch.vtensor<[10],si32>, !torch.vtensor<[5,5],f64>, !torch.vtensor<[3,3],f16>, !torch.vtensor<[2,2],bf16>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si16>) {
  %false = torch.constant.bool false
  %none = torch.constant.none

  // int32 splat -> float32
  %int_splat = torch.vtensor.literal(dense<42> : tensor<2x3xsi32>) : !torch.vtensor<[2,3],si32>
  %int6 = torch.constant.int 6 // torch.float32
  // CHECK: %[[R1:.*]] = torch.vtensor.literal({{.*}} : tensor<2x3xf32>) : !torch.vtensor<[2,3],f32>
  %result1 = torch.aten.to.dtype %int_splat, %int6, %false, %false, %none
      : !torch.vtensor<[2,3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[2,3],f32>

  // float32 splat -> int32 (rmTowardZero)
  %float_splat = torch.vtensor.literal(dense<3.14159> : tensor<4x4xf32>) : !torch.vtensor<[4,4],f32>
  %int3 = torch.constant.int 3 // torch.int32
  // CHECK: %[[R2:.*]] = torch.vtensor.literal(dense<3> : tensor<4x4xsi32>) : !torch.vtensor<[4,4],si32>
  %result2 = torch.aten.to.dtype %float_splat, %int3, %false, %false, %none
      : !torch.vtensor<[4,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[4,4],si32>

  // int64 splat (max int32) -> int32 (trunc)
  %int64_splat = torch.vtensor.literal(dense<2147483647> : tensor<10xsi64>) : !torch.vtensor<[10],si64>
  // CHECK: %[[R3:.*]] = torch.vtensor.literal(dense<2147483647> : tensor<10xsi32>) : !torch.vtensor<[10],si32>
  %result3 = torch.aten.to.dtype %int64_splat, %int3, %false, %false, %none
      : !torch.vtensor<[10],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[10],si32>

  // float32 splat -> float64
  %float32_splat = torch.vtensor.literal(dense<2.71828> : tensor<5x5xf32>) : !torch.vtensor<[5,5],f32>
  %int7 = torch.constant.int 7 // torch.float64
  // CHECK: %[[R4:.*]] = torch.vtensor.literal({{.*}} : tensor<5x5xf64>) : !torch.vtensor<[5,5],f64>
  %result4 = torch.aten.to.dtype %float32_splat, %int7, %false, %false, %none
      : !torch.vtensor<[5,5],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[5,5],f64>

  // float64 splat -> float16
  %float64_splat = torch.vtensor.literal(dense<1.23456> : tensor<3x3xf64>) : !torch.vtensor<[3,3],f64>
  %int5 = torch.constant.int 5 // torch.float16
  // CHECK: %[[R5:.*]] = torch.vtensor.literal({{.*}} : tensor<3x3xf16>) : !torch.vtensor<[3,3],f16>
  %result5 = torch.aten.to.dtype %float64_splat, %int5, %false, %false, %none
      : !torch.vtensor<[3,3],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[3,3],f16>

  // float32 splat -> bfloat16
  %float32_bf16 = torch.vtensor.literal(dense<-0.5> : tensor<2x2xf32>) : !torch.vtensor<[2,2],f32>
  %int15 = torch.constant.int 15 // torch.bfloat16
  // CHECK: %[[R6:.*]] = torch.vtensor.literal({{.*}} : tensor<2x2xbf16>) : !torch.vtensor<[2,2],bf16>
  %result6 = torch.aten.to.dtype %float32_bf16, %int15, %false, %false, %none
      : !torch.vtensor<[2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[2,2],bf16>

  // int32 splat -> int64 (sign-extend)
  %int32_ext = torch.vtensor.literal(dense<-1000> : tensor<4xsi32>) : !torch.vtensor<[4],si32>
  %int4 = torch.constant.int 4 // torch.int64
  // CHECK: %[[R7:.*]] = torch.vtensor.literal(dense<-1000> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
  %result7 = torch.aten.to.dtype %int32_ext, %int4, %false, %false, %none
      : !torch.vtensor<[4],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[4],si64>

  // int32 splat -> int16 (trunc)
  %int32_trunc = torch.vtensor.literal(dense<32000> : tensor<3xsi32>) : !torch.vtensor<[3],si32>
  %int2 = torch.constant.int 2 // torch.int16
  // CHECK: %[[R8:.*]] = torch.vtensor.literal(dense<32000> : tensor<3xsi16>) : !torch.vtensor<[3],si16>
  %result8 = torch.aten.to.dtype %int32_trunc, %int2, %false, %false, %none
      : !torch.vtensor<[3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none
      -> !torch.vtensor<[3],si16>

  return %result1, %result2, %result3, %result4, %result5, %result6, %result7, %result8
      : !torch.vtensor<[2,3],f32>, !torch.vtensor<[4,4],si32>, !torch.vtensor<[10],si32>, !torch.vtensor<[5,5],f64>, !torch.vtensor<[3,3],f16>, !torch.vtensor<[2,2],bf16>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si16>
}

1765
1837
// CHECK-LABEL: func.func @torch.aten.to.other$basic(
1766
1838
// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {
1767
1839
// CHECK: %[[NONE:.*]] = torch.constant.none
0 commit comments