@@ -2560,13 +2560,10 @@ def convolution2d_op(
2560
2560
dilations = node .args [5 ]
2561
2561
2562
2562
input_val = symbol_table .get ((str (input_ ), 0 ))
2563
- input_shape = list (ir .RankedTensorType (input_val .type ).shape )
2564
2563
filter_val = symbol_table .get ((str (filter_ ), 0 ))
2565
- filter_shape = list (ir .RankedTensorType (filter_val .type ).shape )
2566
2564
dtype = node .tensor_meta ["dtype" ]
2567
2565
result_element_type = mlir_element_type_get (dtype )
2568
2566
out_shape = node .tensor_meta ["shape" ]
2569
- result_tensor_type = ir .RankedTensorType .get (out_shape , result_element_type )
2570
2567
strides_attr = ir ._denseI64ArrayAttr (strides , None )
2571
2568
dilations_attr = ir ._denseI64ArrayAttr (dilations , None )
2572
2569
conv2d_result = tensor .EmptyOp (out_shape , result_element_type )
@@ -2582,7 +2579,6 @@ def convolution2d_op(
2582
2579
if len (node ._parents ) > 2 :
2583
2580
bias_tensor = symbol_table .get ((str (bias ), 0 ))
2584
2581
init = tensor .EmptyOp (out_shape , result_element_type )
2585
- print (f"{ bias_tensor } , { init } , { out_shape } " )
2586
2582
broadcasted = linalg .broadcast (
2587
2583
bias_tensor , outs = [init ], dimensions = [0 , 2 , 3 ]
2588
2584
)
@@ -2592,6 +2588,40 @@ def convolution2d_op(
2592
2588
return op_to_return
2593
2589
2594
2590
2591
+ def maxpool2d_op (
2592
+ node : Conv2dOp , symbol_table : Dict [Tuple [str , int ], ir .Operation ]
2593
+ ):
2594
+ # print(node.kwargs, node.args)
2595
+ input_ = node .args [0 ]
2596
+ kernel_size = node .args [1 ]
2597
+ strides = node .args [2 ]
2598
+ dtype = node .tensor_meta ["dtype" ]
2599
+ result_element_type = mlir_element_type_get (dtype )
2600
+ result_shape = node .tensor_meta ["shape" ]
2601
+
2602
+ input_value = symbol_table .get ((str (input_ ), 0 ))
2603
+ kernel_size_value = tensor .EmptyOp (kernel_size , result_element_type )
2604
+
2605
+ if len (node .args ) > 3 :
2606
+ dilations = node .args [4 ]
2607
+ else :
2608
+ dilations = [1 , 1 ]
2609
+
2610
+ strides_attr = ir ._denseI64ArrayAttr (strides , None )
2611
+ dilations_attr = ir ._denseI64ArrayAttr (dilations , None )
2612
+
2613
+ result = tensor .EmptyOp (result_shape , result_element_type )
2614
+ op = linalg .pooling_nchw_max (
2615
+ input_value ,
2616
+ kernel_size_value ,
2617
+ outs = [result ],
2618
+ strides = strides_attr ,
2619
+ dilations = dilations_attr ,
2620
+ )
2621
+
2622
+ return op
2623
+
2624
+
2595
2625
ops_registry = {
2596
2626
"MatmulOp" : matmul_op ,
2597
2627
"TransposeMatmulFusedOp" : matmul_transpose_b_op ,
@@ -2635,4 +2665,5 @@ def convolution2d_op(
2635
2665
"CopyOp" : copy_op ,
2636
2666
"SliceScatterOp" : slice_scatter_op ,
2637
2667
"Conv2dOp" : convolution2d_op ,
2668
+ "MaxPool2dOp" : maxpool2d_op ,
2638
2669
}
0 commit comments