|
29 | 29 |
|
30 | 30 | import ml_dtypes
|
31 | 31 | import torch
|
32 |
| -from bench import py_timeit_bench |
33 |
| -from enhanced_np_to_memref import ranked_memref_to_numpy |
34 | 32 | from utils import get_mlir_args
|
35 | 33 |
|
| 34 | +# Example of a simple validation run: JIT-compile and invoke a small bf16 linalg.matmul module, then print gc_res. |
36 | 35 | if __name__ == "__main__":
|
37 | 36 | with ir.Context() as ctx:
|
| 37 | + ctx.enable_multithreading(False) |
38 | 38 | module = ir.Module.parse(
|
39 | 39 | """
|
40 |
| - module { |
41 |
| - func.func @main_entry(%arg0:tensor<10x10xbf16>, %arg1:tensor<10x10xbf16>) -> tensor<10x10xbf16> attributes {llvm.emit_c_interface} { |
42 |
| - %0 = onednn_graph.matmul %arg0, %arg1: (tensor<10x10xbf16>, tensor<10x10xbf16>) -> tensor<10x10xbf16> |
43 |
| - return %0:tensor<10x10xbf16> |
44 |
| - } |
45 |
| - } |
| 40 | + module { |
| 41 | + func.func @main_entry(%arg0: tensor<10x10xbf16>, %arg1: tensor<10x10xbf16>) -> tensor<10x10xbf16> attributes {llvm.emit_c_interface} { |
| 42 | + %cst = arith.constant 0.000000e+00 : bf16 |
| 43 | + %0 = tensor.empty() : tensor<10x10xbf16> |
| 44 | + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<10x10xbf16>) -> tensor<10x10xbf16> |
| 45 | + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<10x10xbf16>, tensor<10x10xbf16>) outs(%1 : tensor<10x10xbf16>) -> tensor<10x10xbf16> |
| 46 | + return %2 : tensor<10x10xbf16> |
| 47 | + } |
| 48 | + } |
46 | 49 | """
|
47 | 50 | )
|
48 | 51 | torch_arg0 = torch.full((10, 10), 1.0, dtype=torch.bfloat16)
|
|
59 | 62 |
|
60 | 63 | # Build the GraphCompiler, compile and JIT the module, then invoke the entry point.
|
61 | 64 | compiler = GraphCompiler(passes)
|
62 |
| - engine = compiler.compile_and_jit(module) |
| 65 | + engine = compiler.compile_and_jit(module, ir_printing=True) |
63 | 66 | engine.invoke(entry, *mlir_args)
|
64 | 67 |
|
65 | 68 | print(gc_res)
|
|
0 commit comments