Skip to content

Commit d45fd33

Browse files
committed
[Frontend] Achieve norm optimization by scf dialect and intergrate into E2E pass
1 parent 5cb2360 commit d45fd33

File tree

6 files changed

+410
-3
lines changed

6 files changed

+410
-3
lines changed

examples/BuddyNext/makefile

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ next-positional-encoding-aot:
329329
./next-positional-encoding.out || true
330330

331331
next-norm-aot-omp:
332-
@${MLIR_OPT} ./next-norm.mlir \
332+
@${MLIR_OPT} ./next-norm-opt.mlir \
333333
-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \
334334
${BUDDY_OPT} \
335335
-eliminate-empty-tensors \
@@ -371,6 +371,48 @@ next-norm-aot-omp:
371371
-o next-norm.out
372372
./next-norm.out || true
373373

374+
next-norm-opt-aot:
375+
@${MLIR_OPT} ./next-norm-opt.mlir \
376+
-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \
377+
${BUDDY_OPT} \
378+
-eliminate-empty-tensors \
379+
-empty-tensor-to-alloc-tensor \
380+
-convert-elementwise-to-linalg \
381+
-one-shot-bufferize="bufferize-function-boundaries" \
382+
-expand-strided-metadata \
383+
-ownership-based-buffer-deallocation \
384+
-buffer-deallocation-simplification \
385+
-bufferization-lower-deallocations \
386+
-matmul-parallel-vectorization-optimize \
387+
-batchmatmul-optimize \
388+
-convert-linalg-to-affine-loops \
389+
-affine-loop-fusion \
390+
-affine-parallelize \
391+
-convert-vector-to-scf \
392+
-lower-affine \
393+
-func-bufferize-dynamic-offset \
394+
-cse \
395+
-memref-expand \
396+
-arith-expand \
397+
-convert-vector-to-llvm \
398+
-convert-arith-to-llvm \
399+
-finalize-memref-to-llvm \
400+
-convert-scf-to-cf \
401+
-convert-cf-to-llvm \
402+
-convert-openmp-to-llvm \
403+
-convert-arith-to-llvm \
404+
-convert-math-to-llvm \
405+
-convert-math-to-libm \
406+
-convert-func-to-llvm \
407+
-reconcile-unrealized-casts | \
408+
${MLIR_TRANSLATE} -mlir-to-llvmir | \
409+
${CLANG} -x ir - \
410+
${MARCH_FLAG} -O3 \
411+
-L${MLIR_LIB} -lmlir_runner_utils -lmlir_c_runner_utils -lomp -lm \
412+
-Wl,-rpath,${MLIR_LIB} \
413+
-o next-norm.out
414+
./next-norm.out || true
415+
374416
next-norm-aot:
375417
@${MLIR_OPT} ./next-norm.mlir \
376418
-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// RUN: buddy-opt %s \
2+
// RUN: -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" \
3+
// RUN: | buddy-opt \
4+
// RUN: -eliminate-empty-tensors \
5+
// RUN: -empty-tensor-to-alloc-tensor \
6+
// RUN: -convert-elementwise-to-linalg \
7+
// RUN: -one-shot-bufferize="unknown-type-conversion=identity-layout-map function-boundary-type-conversion=identity-layout-map bufferize-function-boundaries" \
8+
// RUN: -expand-strided-metadata \
9+
// RUN: -ownership-based-buffer-deallocation \
10+
// RUN: -buffer-deallocation-simplification \
11+
// RUN: -bufferization-lower-deallocations \
12+
// RUN: -convert-bufferization-to-memref \
13+
// RUN: -matmul-parallel-vectorization-optimize \
14+
// RUN: -batchmatmul-optimize \
15+
// RUN: -convert-linalg-to-affine-loops \
16+
// RUN: -affine-loop-fusion \
17+
// RUN: -affine-parallelize \
18+
// RUN: -convert-vector-to-scf \
19+
// RUN: -lower-affine \
20+
// RUN: -convert-scf-to-openmp \
21+
// RUN: -func-bufferize-dynamic-offset \
22+
// RUN: -cse \
23+
// RUN: -memref-expand \
24+
// RUN: -arith-expand \
25+
// RUN: -convert-vector-to-llvm \
26+
// RUN: -convert-arith-to-llvm \
27+
// RUN: -finalize-memref-to-llvm \
28+
// RUN: -convert-scf-to-cf \
29+
// RUN: -convert-cf-to-llvm \
30+
// RUN: -convert-openmp-to-llvm \
31+
// RUN: -convert-arith-to-llvm \
32+
// RUN: -convert-math-to-llvm \
33+
// RUN: -convert-math-to-libm \
34+
// RUN: -convert-func-to-llvm \
35+
// RUN: -reconcile-unrealized-casts \
36+
// RUN: | mlir-runner -e main -entry-point-result=void \
37+
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
38+
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
39+
// RUN: -shared-libs=%mlir_runner_utils_dir/libomp%shlibext \
40+
// RUN: | FileCheck %s
41+
42+
func.func private @rtclock() -> f64
43+
func.func private @printMemrefF32(%ptr : tensor<*xf32>)
44+
45+
#map = affine_map<(d0, d1, d2) -> (d1)>
46+
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
47+
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
48+
#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
49+
#map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
50+
51+
func.func @kernel(%arg0: tensor<1x1024x1536xf32>, %arg1: tensor<1536xf32>) -> tensor<1x1024x1536xf32> {
52+
%t_start = call @rtclock() : () -> f64
53+
54+
%eps = arith.constant 9.99999997E-7 : f32
55+
%zero = arith.constant 0.0 : f32
56+
%c0 = arith.constant 0 : index
57+
%c1 = arith.constant 1 : index
58+
%c1024 = arith.constant 1024 : index
59+
%c1536 = arith.constant 1536 : index
60+
%dim = arith.constant 1.536000e+03 : f32
61+
62+
%x_memref = bufferization.to_memref %arg0 : tensor<1x1024x1536xf32> to memref<1x1024x1536xf32>
63+
%g_memref = bufferization.to_memref %arg1 : tensor<1536xf32> to memref<1536xf32>
64+
%y_memref = memref.alloc() : memref<1x1024x1536xf32>
65+
66+
scf.for %b = %c0 to %c1024 step %c1 {
67+
%acc= scf.parallel (%i) = (%c0) to (%c1536) step (%c1) init(%zero) -> (f32) {
68+
%x = memref.load %x_memref[%c0, %b, %i] : memref<1x1024x1536xf32>
69+
%x_sq = arith.mulf %x, %x : f32
70+
scf.reduce(%x_sq : f32) {
71+
^bb0(%lhs : f32, %rhs: f32):
72+
%res = arith.addf %lhs, %rhs : f32
73+
scf.reduce.return %res : f32
74+
}
75+
}
76+
%mean = arith.divf %acc, %dim : f32
77+
%m_eps = arith.addf %mean, %eps : f32
78+
%inv_rms = math.rsqrt %m_eps : f32
79+
scf.for %i = %c0 to %c1536 step %c1 {
80+
%x = memref.load %x_memref[%c0, %b, %i] : memref<1x1024x1536xf32>
81+
%g = memref.load %g_memref[%i] : memref<1536xf32>
82+
%x_norm = arith.mulf %x, %inv_rms : f32
83+
%y = arith.mulf %x_norm, %g : f32
84+
memref.store %y, %y_memref[%c0, %b, %i] : memref<1x1024x1536xf32>
85+
}
86+
}
87+
88+
%out = bufferization.to_tensor %y_memref restrict : memref<1x1024x1536xf32> to tensor<1x1024x1536xf32>
89+
90+
%t_end = call @rtclock() : () -> f64
91+
%time = arith.subf %t_end, %t_start : f64
92+
93+
// Print timings.
94+
vector.print %time : f64
95+
// CHECK: {{[0-9]+\.[0-9]+}}
96+
97+
return %out : tensor<1x1024x1536xf32>
98+
}
99+
100+
func.func @main() {
101+
102+
%cst_3 = arith.constant 3.0 : f32
103+
%empty_0 = tensor.empty() : tensor<1x1024x1536xf32>
104+
%c0 = linalg.fill ins(%cst_3 : f32) outs(%empty_0 : tensor<1x1024x1536xf32>) -> tensor<1x1024x1536xf32>
105+
106+
%cst_2 = arith.constant 2.0 : f32
107+
%empty_1 = tensor.empty() : tensor<1536xf32>
108+
%c1 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<1536xf32>) -> tensor<1536xf32>
109+
110+
%res = call @kernel(%c0, %c1) : (tensor<1x1024x1536xf32>, tensor<1536xf32>) -> tensor<1x1024x1536xf32>
111+
112+
%tensor_unranked = tensor.cast %res : tensor<1x1024x1536xf32> to tensor<*xf32>
113+
// Print results.
114+
// call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> ()
115+
116+
return
117+
}

frontend/Python/graph/operation.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,3 +661,10 @@ class AsStridedOp(Op):
661661
def __init__(self) -> None:
662662
super().__init__()
663663
self._op_type = OpType.ElementwiseType
664+
665+
666+
class NormOp(Op):
667+
def __init__(self) -> None:
668+
super().__init__()
669+
self._op_type = OpType.ElementwiseType
670+

frontend/Python/graph/transform/fuse_ops.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
from .. import DeviceType
2424
from torch.fx.immutable_collections import immutable_list
2525

26-
classicfuse_register = {"transpose_matmul_fusion": TransposeMatmulFusedOp}
26+
classicfuse_register = {"transpose_matmul_fusion": TransposeMatmulFusedOp,
27+
"norm_fusion": NormOp
28+
}
2729

2830
# TODO: classify op type for op fusion
2931
# OP_TYPE_FUSABLE = [OpType.BroadcastType, OpType.ElementwiseType, OpType.ReshapeType]
@@ -56,6 +58,39 @@ def classic_fuse_check(graph: Graph):
5658
transpose_matmul_fusion(
5759
graph, op, pattern[0], pattern[1], pattern[2]
5860
)
61+
# === LayerNorm pattern ===
62+
if isinstance(op, PowOp):
63+
# check LayerNorm pattern: pow -> mean -> add -> rsqrt -> mul -> mul
64+
if not op._children:
65+
continue
66+
mean_node = graph.node_table.get(op._children[0], None)
67+
if not isinstance(mean_node, MeanOp):
68+
continue
69+
70+
if not mean_node._children:
71+
continue
72+
add_node = graph.node_table.get(mean_node._children[0], None)
73+
if not isinstance(add_node, AddOp):
74+
continue
75+
76+
if not add_node._children:
77+
continue
78+
rsqrt_node = graph.node_table.get(add_node._children[0], None)
79+
if not isinstance(rsqrt_node, RsqrtOp):
80+
continue
81+
82+
if not rsqrt_node._children:
83+
continue
84+
mul_node = graph.node_table.get(rsqrt_node._children[0], None)
85+
if not isinstance(mul_node, MulOp):
86+
continue
87+
88+
if not mul_node._children:
89+
continue
90+
mul_2_node = graph.node_table.get(mul_node._children[0], None)
91+
if not isinstance(mul_2_node, MulOp):
92+
continue
93+
norm_fusion(graph, op, mean_node, add_node, rsqrt_node, mul_node, mul_2_node, "norm_fusion")
5994

6095

6196
def transpose_matmul_fusion(
@@ -91,6 +126,53 @@ def transpose_matmul_fusion(
91126
graph.delete_node(target, targets_parent)
92127

93128

129+
def norm_fusion(
130+
graph: Graph,
131+
pow_node: Op,
132+
mean_node: Op,
133+
add_node: Op,
134+
rsqrt_node: Op,
135+
mul_node: Op,
136+
mul_2_node: Op,
137+
pattern: str,
138+
):
139+
"""
140+
Fuse LayerNorm subgraph (Pow -> Mean -> Add -> Rsqrt -> Mul -> Mul)
141+
into one LayerNormFusedOp.
142+
"""
143+
fused_cls = classicfuse_register.get(pattern)
144+
145+
fused_op = fused_cls()
146+
147+
fused_op.name = "NormOp"
148+
149+
graph.displace_node(mul_2_node, fused_op)
150+
151+
fused_op.args.pop(fused_op.args.index(mul_node.name))
152+
153+
fused_op._parents.pop(fused_op._parents.index(mul_node.name))
154+
for parent_name in pow_node._parents:
155+
fused_op._parents.append(parent_name)
156+
fused_op.args.append(parent_name)
157+
158+
mul_node._children.clear()
159+
160+
if graph.check_delete_node(mul_node):
161+
graph.delete_node(mul_node,[graph.node_table.get(mul_node._parents[0], None),graph.node_table.get(mul_node._parents[1], None)])
162+
163+
if graph.check_delete_node(rsqrt_node):
164+
graph.delete_node(rsqrt_node,[add_node])
165+
166+
if graph.check_delete_node(add_node):
167+
graph.delete_node(add_node,[mean_node])
168+
169+
if graph.check_delete_node(mean_node):
170+
graph.delete_node(mean_node,[pow_node])
171+
172+
if graph.check_delete_node(pow_node):
173+
graph.delete_node(pow_node,[graph.node_table.get(mul_node._parents[0], None)])
174+
175+
94176
def apply_classic_fusion(graph: Graph):
95177
"""
96178
Function to fuse some typical operations into one operation and fuse

0 commit comments

Comments
 (0)