Skip to content

Commit fd1ad8b

Browse files
committed
remove unused import
1 parent 95b3cc8 commit fd1ad8b

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

tools/example/simple_test.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,23 @@
2929

3030
import ml_dtypes
3131
import torch
32-
from bench import py_timeit_bench
33-
from enhanced_np_to_memref import ranked_memref_to_numpy
3432
from utils import get_mlir_args
3533

34+
# an example of simple validation
3635
if __name__ == "__main__":
3736
with ir.Context() as ctx:
37+
ctx.enable_multithreading(False)
3838
module = ir.Module.parse(
3939
"""
40-
module {
41-
func.func @main_entry(%arg0:tensor<10x10xbf16>, %arg1:tensor<10x10xbf16>) -> tensor<10x10xbf16> attributes {llvm.emit_c_interface} {
42-
%0 = onednn_graph.matmul %arg0, %arg1: (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xbf16>
43-
return %0:tensor<10x10xbf16>
44-
}
45-
}
40+
module {
41+
func.func @main_entry(%arg0: tensor<10x10xbf16>, %arg1: tensor<10x10xbf16>) -> tensor<10x10xbf16> attributes {llvm.emit_c_interface} {
42+
%cst = arith.constant 0.000000e+00 : bf16
43+
%0 = tensor.empty() : tensor<10x10xbf16>
44+
%1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<10x10xbf16>) -> tensor<10x10xbf16>
45+
%2 = linalg.matmul ins(%arg0, %arg1 : tensor<10x10xbf16>, tensor<10x10xbf16>) outs(%1 : tensor<10x10xbf16>) -> tensor<10x10xbf16>
46+
return %2 : tensor<10x10xbf16>
47+
}
48+
}
4649
"""
4750
)
4851
torch_arg0 = torch.full((10, 10), 1.0, dtype=torch.bfloat16)
@@ -59,7 +62,7 @@
5962

6063
# just run
6164
compiler = GraphCompiler(passes)
62-
engine = compiler.compile_and_jit(module)
65+
engine = compiler.compile_and_jit(module, ir_printing=True)
6366
engine.invoke(entry, *mlir_args)
6467

6568
print(gc_res)

0 commit comments

Comments
 (0)