Skip to content

Commit a7ff76f

Browse files
committed
feat: add maxpool_2d implementation using linalg
1 parent 1e46aa2 commit a7ff76f

File tree

2 files changed

+43
-10
lines changed

2 files changed

+43
-10
lines changed

frontend/Python/graph/type.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class TensorDType(Enum):
3535
- Bool: str
3636
Represents the boolean data type.
3737
"""
38-
38+
3939
Int8 = "int8"
4040
Int32 = "int32"
4141
Int64 = "int64"
@@ -47,7 +47,7 @@ class TensorDType(Enum):
4747

4848
class TensorMeta:
4949
"""
50-
Store tensor metadata, including shape and data type, while overlooking raw
50+
Store tensor metadata, including shape and data type, while overlooking raw
5151
data.
5252
5353
Attributes:
@@ -58,7 +58,7 @@ class TensorMeta:
5858
5959
Methods:
6060
- __init__(shape: tuple, dtype: str) -> None:
61-
Initializes a new instance of the TensorMeta class with the specified
61+
Initializes a new instance of the TensorMeta class with the specified
6262
shape and data type.
6363
6464
Example:
@@ -79,6 +79,7 @@ def __init__(self, shape, dtype) -> None:
7979
self.shape = shape
8080
self.dtype = dtype
8181

82+
8283
class DeviceType(Enum):
8384
"""
8485
Enumeration class representing different types of devices.
@@ -91,6 +92,7 @@ class DeviceType(Enum):
9192
Each attribute represents a specific device type and is associated with a
9293
string value.
9394
"""
94-
CPU = 'cpu'
95-
GPU = 'gpu'
96-
UNKNOW = 'unknow'
95+
96+
CPU = "cpu"
97+
GPU = "gpu"
98+
UNKNOWN = "unknow"

frontend/Python/ops/linalg.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2560,13 +2560,10 @@ def convolution2d_op(
25602560
dilations = node.args[5]
25612561

25622562
input_val = symbol_table.get((str(input_), 0))
2563-
input_shape = list(ir.RankedTensorType(input_val.type).shape)
25642563
filter_val = symbol_table.get((str(filter_), 0))
2565-
filter_shape = list(ir.RankedTensorType(filter_val.type).shape)
25662564
dtype = node.tensor_meta["dtype"]
25672565
result_element_type = mlir_element_type_get(dtype)
25682566
out_shape = node.tensor_meta["shape"]
2569-
result_tensor_type = ir.RankedTensorType.get(out_shape, result_element_type)
25702567
strides_attr = ir._denseI64ArrayAttr(strides, None)
25712568
dilations_attr = ir._denseI64ArrayAttr(dilations, None)
25722569
conv2d_result = tensor.EmptyOp(out_shape, result_element_type)
@@ -2582,7 +2579,6 @@ def convolution2d_op(
25822579
if len(node._parents) > 2:
25832580
bias_tensor = symbol_table.get((str(bias), 0))
25842581
init = tensor.EmptyOp(out_shape, result_element_type)
2585-
print(f"{bias_tensor}, {init}, {out_shape}")
25862582
broadcasted = linalg.broadcast(
25872583
bias_tensor, outs=[init], dimensions=[0, 2, 3]
25882584
)
@@ -2592,6 +2588,40 @@ def convolution2d_op(
25922588
return op_to_return
25932589

25942590

2591+
def maxpool2d_op(
2592+
node: Conv2dOp, symbol_table: Dict[Tuple[str, int], ir.Operation]
2593+
):
2594+
# print(node.kwargs, node.args)
2595+
input_ = node.args[0]
2596+
kernel_size = node.args[1]
2597+
strides = node.args[2]
2598+
dtype = node.tensor_meta["dtype"]
2599+
result_element_type = mlir_element_type_get(dtype)
2600+
result_shape = node.tensor_meta["shape"]
2601+
2602+
input_value = symbol_table.get((str(input_), 0))
2603+
kernel_size_value = tensor.EmptyOp(kernel_size, result_element_type)
2604+
2605+
if len(node.args) > 3:
2606+
dilations = node.args[4]
2607+
else:
2608+
dilations = [1, 1]
2609+
2610+
strides_attr = ir._denseI64ArrayAttr(strides, None)
2611+
dilations_attr = ir._denseI64ArrayAttr(dilations, None)
2612+
2613+
result = tensor.EmptyOp(result_shape, result_element_type)
2614+
op = linalg.pooling_nchw_max(
2615+
input_value,
2616+
kernel_size_value,
2617+
outs=[result],
2618+
strides=strides_attr,
2619+
dilations=dilations_attr,
2620+
)
2621+
2622+
return op
2623+
2624+
25952625
ops_registry = {
25962626
"MatmulOp": matmul_op,
25972627
"TransposeMatmulFusedOp": matmul_transpose_b_op,
@@ -2635,4 +2665,5 @@ def convolution2d_op(
26352665
"CopyOp": copy_op,
26362666
"SliceScatterOp": slice_scatter_op,
26372667
"Conv2dOp": convolution2d_op,
2668+
"MaxPool2dOp": maxpool2d_op,
26382669
}

0 commit comments

Comments
 (0)