18
18
#
19
19
# ===---------------------------------------------------------------------------
20
20
21
- from typing import Dict , Tuple , List
21
+ from typing import Dict , Tuple
22
22
23
23
import mlir .ir as ir
24
24
from mlir .dialects import tosa , linalg , arith , tensor , math
25
- import copy , array , sys
25
+ import copy , array
26
26
import numpy
27
- import functools
28
27
29
28
from ..graph import *
30
29
from ..graph .graph import TensorDType
@@ -2032,13 +2031,8 @@ def gt_op(node: GtOp, symbol_table):
2032
2031
- symbol_table: A mapping of tensor names to their corresponding MLIR objects.
2033
2032
2034
2033
Returns:
2035
- <<<<<<< HEAD
2036
2034
- cmp_op: A comparison operation result indicating where the input tensor's elements
2037
2035
are greater than the scalar.
2038
- =======
2039
- - cmp_op: A comparison operation result indicating where
2040
- the input tensor's elements are greater than the scalar.
2041
- >>>>>>> 93de87f (feat: Add conv2d implemented in linalg)
2042
2036
"""
2043
2037
input_tensor = symbol_table .get ((str (node .args [0 ]), 0 ), node .args [0 ])
2044
2038
input_dtype = ir .RankedTensorType (input_tensor .type ).element_type
@@ -2566,7 +2560,10 @@ def convolution2d_op(
2566
2560
out_shape = node .tensor_meta ["shape" ]
2567
2561
strides_attr = ir ._denseI64ArrayAttr (strides , None )
2568
2562
dilations_attr = ir ._denseI64ArrayAttr (dilations , None )
2569
- conv2d_result = tensor .EmptyOp (out_shape , result_element_type )
2563
+ conv2d_result = tensor .EmptyOp (out_shape , result_element_type ).result
2564
+ f32 = ir .F32Type .get ()
2565
+ zero = arith .ConstantOp (value = ir .FloatAttr .get (f32 , 0.0 ), result = f32 ).result
2566
+ conv2d_result = linalg .fill (zero , outs = [conv2d_result ])
2570
2567
conv2d_nchw_op = linalg .conv_2d_nchw_fchw (
2571
2568
input_val ,
2572
2569
filter_val ,
@@ -2591,7 +2588,6 @@ def convolution2d_op(
2591
2588
def maxpool2d_op (
2592
2589
node : Conv2dOp , symbol_table : Dict [Tuple [str , int ], ir .Operation ]
2593
2590
):
2594
- # print(node.kwargs, node.args)
2595
2591
input_ = node .args [0 ]
2596
2592
kernel_size = node .args [1 ]
2597
2593
strides = node .args [2 ]
@@ -2602,22 +2598,34 @@ def maxpool2d_op(
2602
2598
input_value = symbol_table .get ((str (input_ ), 0 ))
2603
2599
kernel_size_value = tensor .EmptyOp (kernel_size , result_element_type )
2604
2600
2605
- if len (node .args ) > 3 :
2606
- dilations = node .args [4 ]
2607
- else :
2608
- dilations = [1 , 1 ]
2609
-
2610
2601
strides_attr = ir ._denseI64ArrayAttr (strides , None )
2611
- dilations_attr = ir ._denseI64ArrayAttr (dilations , None )
2612
2602
2613
2603
result = tensor .EmptyOp (result_shape , result_element_type )
2614
- op = linalg .pooling_nchw_max (
2615
- input_value ,
2616
- kernel_size_value ,
2617
- outs = [result ],
2618
- strides = strides_attr ,
2619
- dilations = dilations_attr ,
2604
+ f32 = ir .F32Type .get ()
2605
+
2606
+ # FIXME: fix this magic value!
2607
+ largest = arith .ConstantOp (
2608
+ value = ir .FloatAttr .get (f32 , numpy .finfo (numpy .float32 ).min ), result = f32
2620
2609
)
2610
+ result = linalg .fill (largest , outs = [result ])
2611
+
2612
+ if len (node .args ) > 3 :
2613
+ dilations = node .args [3 ]
2614
+ dilations_attr = ir ._denseI64ArrayAttr (dilations , None )
2615
+ op = linalg .pooling_nchw_max (
2616
+ input_value ,
2617
+ kernel_size_value ,
2618
+ outs = [result ],
2619
+ strides = strides_attr ,
2620
+ dilations = dilations_attr ,
2621
+ )
2622
+ else :
2623
+ op = linalg .pooling_nchw_max (
2624
+ input_value ,
2625
+ kernel_size_value ,
2626
+ outs = [result ],
2627
+ strides = strides_attr ,
2628
+ )
2621
2629
2622
2630
return op
2623
2631
0 commit comments