44 changes: 43 additions & 1 deletion examples/BuddyNext/makefile
@@ -329,7 +329,7 @@ next-positional-encoding-aot:
	./next-positional-encoding.out || true

 next-norm-aot-omp:
-	@${MLIR_OPT} ./next-norm.mlir \
+	@${MLIR_OPT} ./next-norm-opt.mlir \
	-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \
	${BUDDY_OPT} \
	-eliminate-empty-tensors \
@@ -371,6 +371,48 @@ next-norm-aot-omp:
	-o next-norm.out
	./next-norm.out || true

next-norm-opt-aot:
	@${MLIR_OPT} ./next-norm-opt.mlir \
	-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \
	${BUDDY_OPT} \
	-eliminate-empty-tensors \
	-empty-tensor-to-alloc-tensor \
	-convert-elementwise-to-linalg \
	-one-shot-bufferize="bufferize-function-boundaries" \
	-expand-strided-metadata \
	-ownership-based-buffer-deallocation \
	-buffer-deallocation-simplification \
	-bufferization-lower-deallocations \
	-matmul-parallel-vectorization-optimize \
	-batchmatmul-optimize \
	-convert-linalg-to-affine-loops \
	-affine-loop-fusion \
	-affine-parallelize \
	-convert-vector-to-scf \
	-lower-affine \
	-func-bufferize-dynamic-offset \
	-cse \
	-memref-expand \
	-arith-expand \
	-convert-vector-to-llvm \
	-convert-arith-to-llvm \
	-finalize-memref-to-llvm \
	-convert-scf-to-cf \
	-convert-cf-to-llvm \
	-convert-openmp-to-llvm \
	-convert-arith-to-llvm \
	-convert-math-to-llvm \
	-convert-math-to-libm \
	-convert-func-to-llvm \
	-reconcile-unrealized-casts | \
	${MLIR_TRANSLATE} -mlir-to-llvmir | \
	${CLANG} -x ir - \
	${MARCH_FLAG} -O3 \
	-L${MLIR_LIB} -lmlir_runner_utils -lmlir_c_runner_utils -lomp -lm \
	-Wl,-rpath,${MLIR_LIB} \
	-o next-norm.out
	./next-norm.out || true

next-norm-aot:
	@${MLIR_OPT} ./next-norm.mlir \
	-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \
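The new next-norm-opt-aot target mirrors next-norm-aot but feeds the handwritten next-norm-opt.mlir kernel through the same buddy-opt / mlir-translate / clang AOT pipeline. Assuming the usual buddy-mlir build variables (MLIR_OPT, BUDDY_OPT, MLIR_TRANSLATE, CLANG, MLIR_LIB, MARCH_FLAG) are set by the example environment, a typical invocation from examples/BuddyNext would be:

	make next-norm-opt-aot

Note that the target reuses the next-norm.out output name, so it overwrites the binary produced by next-norm-aot.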
117 changes: 117 additions & 0 deletions examples/BuddyNext/next-norm-opt.mlir
@@ -0,0 +1,117 @@
// RUN: buddy-opt %s \
// RUN: -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" \
// RUN: | buddy-opt \
// RUN: -eliminate-empty-tensors \
// RUN: -empty-tensor-to-alloc-tensor \
// RUN: -convert-elementwise-to-linalg \
// RUN: -one-shot-bufferize="unknown-type-conversion=identity-layout-map function-boundary-type-conversion=identity-layout-map bufferize-function-boundaries" \
// RUN: -expand-strided-metadata \
// RUN: -ownership-based-buffer-deallocation \
// RUN: -buffer-deallocation-simplification \
// RUN: -bufferization-lower-deallocations \
// RUN: -convert-bufferization-to-memref \
// RUN: -matmul-parallel-vectorization-optimize \
// RUN: -batchmatmul-optimize \
// RUN: -convert-linalg-to-affine-loops \
// RUN: -affine-loop-fusion \
// RUN: -affine-parallelize \
// RUN: -convert-vector-to-scf \
// RUN: -lower-affine \
// RUN: -convert-scf-to-openmp \
// RUN: -func-bufferize-dynamic-offset \
// RUN: -cse \
// RUN: -memref-expand \
// RUN: -arith-expand \
// RUN: -convert-vector-to-llvm \
// RUN: -convert-arith-to-llvm \
// RUN: -finalize-memref-to-llvm \
// RUN: -convert-scf-to-cf \
// RUN: -convert-cf-to-llvm \
// RUN: -convert-openmp-to-llvm \
// RUN: -convert-arith-to-llvm \
// RUN: -convert-math-to-llvm \
// RUN: -convert-math-to-libm \
// RUN: -convert-func-to-llvm \
// RUN: -reconcile-unrealized-casts \
// RUN: | mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: -shared-libs=%mlir_runner_utils_dir/libomp%shlibext \
// RUN: | FileCheck %s

func.func private @rtclock() -> f64
func.func private @printMemrefF32(%ptr : tensor<*xf32>)

#map = affine_map<(d0, d1, d2) -> (d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

func.func @kernel(%arg0: tensor<1x1024x1536xf32>, %arg1: tensor<1536xf32>) -> tensor<1x1024x1536xf32> {
  %t_start = call @rtclock() : () -> f64

  %eps = arith.constant 9.99999997E-7 : f32
  %zero = arith.constant 0.0 : f32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  %c1024 = arith.constant 1024 : index
  %c1536 = arith.constant 1536 : index
  %dim = arith.constant 1.536000e+03 : f32

  %x_memref = bufferization.to_memref %arg0 : tensor<1x1024x1536xf32> to memref<1x1024x1536xf32>
  %g_memref = bufferization.to_memref %arg1 : tensor<1536xf32> to memref<1536xf32>
  %y_memref = memref.alloc() : memref<1x1024x1536xf32>

  scf.for %b = %c0 to %c1024 step %c1 {
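    // First pass: accumulate the row's sum of squares in 16-lane vector chunks.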
    %acc = scf.for %i = %c0 to %c1536 step %c16 iter_args(%acc_iter = %zero) -> (f32) {
      %x_vec = vector.load %x_memref[%c0, %b, %i] : memref<1x1024x1536xf32>, vector<16xf32>
      %x_sq_vec = arith.mulf %x_vec, %x_vec : vector<16xf32>
      %partial = vector.reduction <add>, %x_sq_vec : vector<16xf32> into f32
      %acc_new = arith.addf %acc_iter, %partial : f32
      scf.yield %acc_new : f32
    }
    %mean = arith.divf %acc, %dim : f32
    %m_eps = arith.addf %mean, %eps : f32
    %inv_rms = math.rsqrt %m_eps : f32
    %inv_vec = vector.splat %inv_rms : vector<16xf32>
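    // Second pass: scale the row by 1/rms and by the gain vector %g_memref.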
    scf.for %i = %c0 to %c1536 step %c16 {
      %x_vec = vector.load %x_memref[%c0, %b, %i] : memref<1x1024x1536xf32>, vector<16xf32>
      %g_vec = vector.load %g_memref[%i] : memref<1536xf32>, vector<16xf32>
      %x_norm_vec = arith.mulf %x_vec, %inv_vec : vector<16xf32>
      %y_vec = arith.mulf %x_norm_vec, %g_vec : vector<16xf32>
      vector.store %y_vec, %y_memref[%c0, %b, %i] : memref<1x1024x1536xf32>, vector<16xf32>
    }
  }

  %out = bufferization.to_tensor %y_memref restrict : memref<1x1024x1536xf32> to tensor<1x1024x1536xf32>

  %t_end = call @rtclock() : () -> f64
  %time = arith.subf %t_end, %t_start : f64

  // Print timings.
  vector.print %time : f64
  // CHECK: {{[0-9]+\.[0-9]+}}

  return %out : tensor<1x1024x1536xf32>
}

func.func @main() {

  %cst_3 = arith.constant 3.0 : f32
  %empty_0 = tensor.empty() : tensor<1x1024x1536xf32>
  %c0 = linalg.fill ins(%cst_3 : f32) outs(%empty_0 : tensor<1x1024x1536xf32>) -> tensor<1x1024x1536xf32>

  %cst_2 = arith.constant 2.0 : f32
  %empty_1 = tensor.empty() : tensor<1536xf32>
  %c1 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<1536xf32>) -> tensor<1536xf32>

  %res = call @kernel(%c0, %c1) : (tensor<1x1024x1536xf32>, tensor<1536xf32>) -> tensor<1x1024x1536xf32>

  %tensor_unranked = tensor.cast %res : tensor<1x1024x1536xf32> to tensor<*xf32>
  // Print results.
  // call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> ()

  return
}
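In math terms, the handwritten kernel above computes RMSNorm over the last axis: with row length $d = 1536$, gain vector $g$ (%arg1), and $\epsilon$ taken from the %eps constant (about 1e-6), each output element is

$$ y_{b,i} = g_i \cdot \frac{x_{b,i}}{\sqrt{\tfrac{1}{d} \sum_{j=1}^{d} x_{b,j}^{2} + \epsilon}} $$

The first scf.for realizes the squared-sum reduction in 16-lane chunks; the second applies the math.rsqrt scale and the gain.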
7 changes: 7 additions & 0 deletions frontend/Python/graph/operation.py
@@ -661,3 +661,10 @@ class AsStridedOp(Op):
    def __init__(self) -> None:
        super().__init__()
        self._op_type = OpType.ElementwiseType


class NormOp(Op):
    def __init__(self) -> None:
        super().__init__()
        self._op_type = OpType.ElementwiseType
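NormOp itself is only a type tag; the semantics live in the fusion pass. As a grounding sketch (illustrative, not part of this file; the epsilon value and last-axis reduction are read off the next-norm-opt.mlir kernel above), the computation a fused NormOp node stands for:

import torch

def norm_op_reference(x: torch.Tensor, weight: torch.Tensor,
                      eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm: the Pow -> Mean -> Add -> Rsqrt -> Mul -> Mul chain
    # that fuse_ops.py collapses into a single NormOp.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight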

84 changes: 83 additions & 1 deletion frontend/Python/graph/transform/fuse_ops.py
@@ -23,7 +23,9 @@
from .. import DeviceType
from torch.fx.immutable_collections import immutable_list

-classicfuse_register = {"transpose_matmul_fusion": TransposeMatmulFusedOp}
+classicfuse_register = {"transpose_matmul_fusion": TransposeMatmulFusedOp,
+                        "norm_fusion": NormOp,
+                        }

# TODO: classify op type for op fusion
# OP_TYPE_FUSABLE = [OpType.BroadcastType, OpType.ElementwiseType, OpType.ReshapeType]
@@ -56,6 +58,39 @@ def classic_fuse_check(graph: Graph):
            transpose_matmul_fusion(
                graph, op, pattern[0], pattern[1], pattern[2]
            )
        # === RMSNorm pattern ===
        if isinstance(op, PowOp):
            # Check the RMSNorm chain: pow -> mean -> add -> rsqrt -> mul -> mul.
            if not op._children:
                continue
            mean_node = graph.node_table.get(op._children[0], None)
            if not isinstance(mean_node, MeanOp):
                continue

            if not mean_node._children:
                continue
            add_node = graph.node_table.get(mean_node._children[0], None)
            if not isinstance(add_node, AddOp):
                continue

            if not add_node._children:
                continue
            rsqrt_node = graph.node_table.get(add_node._children[0], None)
            if not isinstance(rsqrt_node, RsqrtOp):
                continue

            if not rsqrt_node._children:
                continue
            mul_node = graph.node_table.get(rsqrt_node._children[0], None)
            if not isinstance(mul_node, MulOp):
                continue

            if not mul_node._children:
                continue
            mul_2_node = graph.node_table.get(mul_node._children[0], None)
            if not isinstance(mul_2_node, MulOp):
                continue
            norm_fusion(
                graph, op, mean_node, add_node,
                rsqrt_node, mul_node, mul_2_node, "norm_fusion"
            )


def transpose_matmul_fusion(
@@ -91,6 +126,53 @@ def transpose_matmul_fusion(
    graph.delete_node(target, targets_parent)


def norm_fusion(
    graph: Graph,
    pow_node: Op,
    mean_node: Op,
    add_node: Op,
    rsqrt_node: Op,
    mul_node: Op,
    mul_2_node: Op,
    pattern: str,
):
    """
    Fuse the RMSNorm subgraph (Pow -> Mean -> Add -> Rsqrt -> Mul -> Mul)
    into a single fused NormOp node.
    """
    fused_cls = classicfuse_register.get(pattern)

    fused_op = fused_cls()

    fused_op.name = "NormOp"

    graph.displace_node(mul_2_node, fused_op)

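    # Rewire inputs: drop the intermediate Mul operand and inherit the
    # hidden-state input(s) of the original Pow node instead.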
    fused_op.args.pop(fused_op.args.index(mul_node.name))

    fused_op._parents.pop(fused_op._parents.index(mul_node.name))
    for parent_name in pow_node._parents:
        fused_op._parents.append(parent_name)
        fused_op.args.append(parent_name)

    mul_node._children.clear()

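    # Tear down the now-dead chain back-to-front, skipping any node that
    # still has other users.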
    if graph.check_delete_node(mul_node):
        graph.delete_node(
            mul_node,
            [
                graph.node_table.get(mul_node._parents[0], None),
                graph.node_table.get(mul_node._parents[1], None),
            ],
        )

    if graph.check_delete_node(rsqrt_node):
        graph.delete_node(rsqrt_node, [add_node])

    if graph.check_delete_node(add_node):
        graph.delete_node(add_node, [mean_node])

    if graph.check_delete_node(mean_node):
        graph.delete_node(mean_node, [pow_node])

    if graph.check_delete_node(pow_node):
        graph.delete_node(
            pow_node, [graph.node_table.get(mul_node._parents[0], None)]
        )


def apply_classic_fusion(graph: Graph):
    """
    Function to fuse some typical operations into one operation and fuse
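For context, a minimal sketch of a PyTorch module whose traced graph decomposes into exactly the chain classic_fuse_check matches (hypothetical, modeled on common LLaMA-style RMSNorm implementations):

import torch

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Traces as: pow -> mean -> add -> rsqrt -> mul (by x) -> mul (by weight).
        variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + self.eps) * self.weight

Since the matcher follows only the first child at each step, it targets this canonical single-consumer chain; the check_delete_node guards keep intermediate nodes alive if other users still reference them.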