Skip to content

Commit 27e1a5e

Browse files
committed
fix: Add initialization for maxpool_2d and conv2d.
1 parent a7ff76f commit 27e1a5e

File tree

4 files changed

+33
-28
lines changed

4 files changed

+33
-28
lines changed

examples/BuddyLeNet/buddy-lenet-import.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
"--output-dir",
3838
type=str,
3939
default="./",
40-
help="Directory to save output files."
40+
help="Directory to save output files.",
4141
)
4242
args = parser.parse_args()
4343

@@ -54,8 +54,7 @@
5454

5555
# Initialize Dynamo Compiler with specific configurations as an importer.
5656
dynamo_compiler = DynamoCompiler(
57-
primary_registry=linalg.ops_registry,
58-
verbose=True
57+
primary_registry=linalg.ops_registry, verbose=True
5958
)
6059

6160
data = torch.randn([1, 1, 28, 28])

examples/BuddyLeNet/buddy-lenet-main.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@
2121
#include <cstdlib>
2222
#include <filesystem>
2323
#include <fstream>
24-
#include <limits>
2524
#include <string>
26-
#include <utility>
2725
#include <vector>
2826

2927
constexpr size_t ParamsSize = 44426;

frontend/Python/graph/type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,4 +95,4 @@ class DeviceType(Enum):
9595

9696
CPU = "cpu"
9797
GPU = "gpu"
98-
UNKNOWN = "unknow"
98+
UNKNOWN = "unknown"

frontend/Python/ops/linalg.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,12 @@
1818
#
1919
# ===---------------------------------------------------------------------------
2020

21-
from typing import Dict, Tuple, List
21+
from typing import Dict, Tuple
2222

2323
import mlir.ir as ir
2424
from mlir.dialects import tosa, linalg, arith, tensor, math
25-
import copy, array, sys
25+
import copy, array
2626
import numpy
27-
import functools
2827

2928
from ..graph import *
3029
from ..graph.graph import TensorDType
@@ -2032,13 +2031,8 @@ def gt_op(node: GtOp, symbol_table):
20322031
- symbol_table: A mapping of tensor names to their corresponding MLIR objects.
20332032
20342033
Returns:
2035-
<<<<<<< HEAD
20362034
- cmp_op: A comparison operation result indicating where the input tensor's elements
20372035
are greater than the scalar.
2038-
=======
2039-
- cmp_op: A comparison operation result indicating where
2040-
the input tensor's elements are greater than the scalar.
2041-
>>>>>>> 93de87f (feat: Add conv2d implemented in linalg)
20422036
"""
20432037
input_tensor = symbol_table.get((str(node.args[0]), 0), node.args[0])
20442038
input_dtype = ir.RankedTensorType(input_tensor.type).element_type
@@ -2566,7 +2560,10 @@ def convolution2d_op(
25662560
out_shape = node.tensor_meta["shape"]
25672561
strides_attr = ir._denseI64ArrayAttr(strides, None)
25682562
dilations_attr = ir._denseI64ArrayAttr(dilations, None)
2569-
conv2d_result = tensor.EmptyOp(out_shape, result_element_type)
2563+
conv2d_result = tensor.EmptyOp(out_shape, result_element_type).result
2564+
f32 = ir.F32Type.get()
2565+
zero = arith.ConstantOp(value=ir.FloatAttr.get(f32, 0.0), result=f32).result
2566+
conv2d_result = linalg.fill(zero, outs=[conv2d_result])
25702567
conv2d_nchw_op = linalg.conv_2d_nchw_fchw(
25712568
input_val,
25722569
filter_val,
@@ -2591,7 +2588,6 @@ def convolution2d_op(
25912588
def maxpool2d_op(
25922589
node: Conv2dOp, symbol_table: Dict[Tuple[str, int], ir.Operation]
25932590
):
2594-
# print(node.kwargs, node.args)
25952591
input_ = node.args[0]
25962592
kernel_size = node.args[1]
25972593
strides = node.args[2]
@@ -2602,22 +2598,34 @@ def maxpool2d_op(
26022598
input_value = symbol_table.get((str(input_), 0))
26032599
kernel_size_value = tensor.EmptyOp(kernel_size, result_element_type)
26042600

2605-
if len(node.args) > 3:
2606-
dilations = node.args[4]
2607-
else:
2608-
dilations = [1, 1]
2609-
26102601
strides_attr = ir._denseI64ArrayAttr(strides, None)
2611-
dilations_attr = ir._denseI64ArrayAttr(dilations, None)
26122602

26132603
result = tensor.EmptyOp(result_shape, result_element_type)
2614-
op = linalg.pooling_nchw_max(
2615-
input_value,
2616-
kernel_size_value,
2617-
outs=[result],
2618-
strides=strides_attr,
2619-
dilations=dilations_attr,
2604+
f32 = ir.F32Type.get()
2605+
2606+
# FIXME: fix this magic value!
2607+
largest = arith.ConstantOp(
2608+
value=ir.FloatAttr.get(f32, numpy.finfo(numpy.float32).min), result=f32
26202609
)
2610+
result = linalg.fill(largest, outs=[result])
2611+
2612+
if len(node.args) > 3:
2613+
dilations = node.args[3]
2614+
dilations_attr = ir._denseI64ArrayAttr(dilations, None)
2615+
op = linalg.pooling_nchw_max(
2616+
input_value,
2617+
kernel_size_value,
2618+
outs=[result],
2619+
strides=strides_attr,
2620+
dilations=dilations_attr,
2621+
)
2622+
else:
2623+
op = linalg.pooling_nchw_max(
2624+
input_value,
2625+
kernel_size_value,
2626+
outs=[result],
2627+
strides=strides_attr,
2628+
)
26212629

26222630
return op
26232631

0 commit comments

Comments (0)