@@ -124,45 +124,43 @@ def forward(self, x, y):
124
124
print (m )
125
125
126
126
127
- # Running this test only for the latest torch version since it's generating different IR for older torch versions.
128
- if str (torch .__version__ ) >= "2.6.0" :
129
-
130
- @run
131
- # CHECK-LABEL: test_outer_with_squared_shape
132
- # CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
133
- # CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
134
- # CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
135
- # CHECK: %[[I0:.+]] = torch.constant.int 0
136
- # CHECK: %[[SIZE:.+]] = torch.aten.size.int %[[ARG0]], %[[I0]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
137
- # CHECK: %[[OUTER:.+]] = torch.operator "torch.aten.outer"(%[[ARG0]], %[[ARG0]]) : (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32>
138
- # CHECK: torch.bind_symbolic_shape %[[OUTER]], [%[[S0]]], affine_map<()[s0] -> (s0, s0)> : !torch.vtensor<[?,?],f32>
139
- # CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[SIZE]], %[[SIZE]] : !torch.int, !torch.int -> !torch.int
140
- # CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[MUL]] : (!torch.int) -> !torch.list<int>
141
- # CHECK: %[[VIEW:.+]] = torch.aten.view %[[OUTER]], %[[LIST]] : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
142
- # CHECK: torch.bind_symbolic_shape %[[VIEW]], [%[[S0]]], affine_map<()[s0] -> (s0 * s0)> : !torch.vtensor<[?],f32>
143
- # CHECK: return %[[VIEW]] : !torch.vtensor<[?],f32>
144
- def test_outer_with_squared_shape ():
145
- class OuterWithSquaredShape (torch .nn .Module ):
146
- def __init__ (self ):
147
- super ().__init__ ()
148
-
149
- def forward (self , x : torch .Tensor ) -> torch .Tensor :
150
- return torch .outer (x , x ).flatten ()
151
-
152
- # Sample inputs
153
- x = torch .rand (10 )
154
-
155
- # Dynamic dim constraints
156
- batch = Dim ("batch" , max = 10 )
157
- dynamic_shapes = {"x" : {0 : batch }}
158
-
159
- m = fx .export_and_import (
160
- OuterWithSquaredShape (),
161
- x ,
162
- dynamic_shapes = dynamic_shapes ,
163
- import_symbolic_shape_expressions = True ,
164
- )
165
- print (m )
127
+ @run
128
+ # TODO: Enable these checks once the IR generated is same for both nightly and stable Torch version.
129
+ # C_HECK-LABEL: test_outer_with_squared_shape
130
+ # C_HECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
131
+ # C_HECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
132
+ # C_HECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
133
+ # C_HECK: %[[I0:.+]] = torch.constant.int 0
134
+ # C_HECK: %[[SIZE:.+]] = torch.aten.size.int %[[ARG0]], %[[I0]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
135
+ # C_HECK: %[[OUTER:.+]] = torch.operator "torch.aten.outer"(%[[ARG0]], %[[ARG0]]) : (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32>
136
+ # C_HECK: torch.bind_symbolic_shape %[[OUTER]], [%[[S0]]], affine_map<()[s0] -> (s0, s0)> : !torch.vtensor<[?,?],f32>
137
+ # C_HECK: %[[MUL:.+]] = torch.aten.mul.int %[[SIZE]], %[[SIZE]] : !torch.int, !torch.int -> !torch.int
138
+ # C_HECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[MUL]] : (!torch.int) -> !torch.list<int>
139
+ # C_HECK: %[[VIEW:.+]] = torch.aten.view %[[OUTER]], %[[LIST]] : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
140
+ # C_HECK: torch.bind_symbolic_shape %[[VIEW]], [%[[S0]]], affine_map<()[s0] -> (s0 * s0)> : !torch.vtensor<[?],f32>
141
+ # C_HECK: return %[[VIEW]] : !torch.vtensor<[?],f32>
142
+ def test_outer_with_squared_shape ():
143
+ class OuterWithSquaredShape (torch .nn .Module ):
144
+ def __init__ (self ):
145
+ super ().__init__ ()
146
+
147
+ def forward (self , x : torch .Tensor ) -> torch .Tensor :
148
+ return torch .outer (x , x ).flatten ()
149
+
150
+ # Sample inputs
151
+ x = torch .rand (10 )
152
+
153
+ # Dynamic dim constraints
154
+ batch = Dim ("batch" , max = 10 )
155
+ dynamic_shapes = {"x" : {0 : batch }}
156
+
157
+ m = fx .export_and_import (
158
+ OuterWithSquaredShape (),
159
+ x ,
160
+ dynamic_shapes = dynamic_shapes ,
161
+ import_symbolic_shape_expressions = True ,
162
+ )
163
+ print (m )
166
164
167
165
168
166
@run
0 commit comments