diff --git a/examples/BuddyDeepSeekR1/.gitignore b/examples/BuddyDeepSeekR1/.gitignore
new file mode 100644
index 0000000000..8d6276ca46
--- /dev/null
+++ b/examples/BuddyDeepSeekR1/.gitignore
@@ -0,0 +1,2 @@
+*.data
+*.mlir
\ No newline at end of file
diff --git a/examples/BuddyNext/.gitignore b/examples/BuddyNext/.gitignore
index 80a243fa81..56216edde9 100644
--- a/examples/BuddyNext/.gitignore
+++ b/examples/BuddyNext/.gitignore
@@ -1 +1,2 @@
 log.*
+compare_outputs.sh
diff --git a/examples/BuddyNext/makefile b/examples/BuddyNext/makefile
index c7f75e2307..a49563387f 100644
--- a/examples/BuddyNext/makefile
+++ b/examples/BuddyNext/makefile
@@ -381,6 +381,256 @@ next-transpose-vec-manual-run:
 	-shared-libs=${MLIR_RUNNER_UTILS} \
 	-shared-libs=${MLIR_C_RUNNER_UTILS}
 
+next-transpose-vec-autoopt-run:
+	@${MLIR_OPT} ./log-transpose-optimized.mlir \
+	-convert-linalg-to-affine-loops \
+	-affine-loop-fusion \
+	-lower-affine \
+	-convert-vector-to-scf \
+	-expand-strided-metadata \
+	-convert-vector-to-llvm \
+	-memref-expand \
+	-arith-expand \
+	-convert-arith-to-llvm \
+	-finalize-memref-to-llvm \
+	-convert-scf-to-cf \
+	-convert-openmp-to-llvm \
+	-convert-arith-to-llvm \
+	-convert-math-to-llvm \
+	-convert-math-to-libm \
+	-convert-func-to-llvm \
+	-lower-affine \
+	-convert-arith-to-llvm \
+	-reconcile-unrealized-casts | \
+	${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
+	-shared-libs=${MLIR_RUNNER_UTILS} \
+	-shared-libs=${MLIR_C_RUNNER_UTILS}
+
+next-transpose-vec-auto-run:
+	@${BUDDY_OPT} next-transpose.mlir \
+	-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \
+	${MLIR_OPT} \
+	-arith-expand \
+	-eliminate-empty-tensors \
+	-empty-tensor-to-alloc-tensor \
+	-one-shot-bufferize \
+	-func-bufferize \
+	-arith-bufferize | \
+	${BUDDY_OPT} \
+	-genericOp-transpose-vectorization="vector-size=16" \
+	-func-bufferize \
+	-arith-bufferize \
+	-affine-loop-fusion \
+	-lower-affine \
+	-convert-vector-to-scf \
+	-expand-strided-metadata \
+	-convert-vector-to-llvm \
+	-memref-expand \
+	-arith-expand \
+	-convert-arith-to-llvm \
+	-finalize-memref-to-llvm \
+	-convert-scf-to-cf \
+	-convert-openmp-to-llvm \
+	-convert-arith-to-llvm \
+	-convert-math-to-llvm \
+	-convert-math-to-libm \
+	-convert-func-to-llvm \
+	-lower-affine \
+	-convert-arith-to-llvm \
+	-reconcile-unrealized-casts | \
+	${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
+	-shared-libs=${MLIR_RUNNER_UTILS} \
+	-shared-libs=${MLIR_C_RUNNER_UTILS}
+
+
+next-reduce-sum-lower:
+	@${MLIR_OPT} ./next-reduce_sum1.mlir \
+	-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \
+	${MLIR_OPT} \
+	-arith-expand \
+	-eliminate-empty-tensors \
+	-empty-tensor-to-alloc-tensor \
+	-one-shot-bufferize \
+	-func-bufferize \
+	-arith-bufferize \
+	-o next-log1.mlir
+
+
+next-reduce-sum-run:
+	@${MLIR_OPT} ./next-reduce_sum.mlir \
+	-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \
+	${MLIR_OPT} \
+	-arith-expand \
+	-eliminate-empty-tensors \
+	-empty-tensor-to-alloc-tensor \
+	-one-shot-bufferize \
+	-func-bufferize \
+	-arith-bufferize \
+	-convert-linalg-to-affine-loops \
+	-affine-loop-fusion \
+	-lower-affine \
+	-convert-vector-to-scf \
+	-expand-strided-metadata \
+	-convert-vector-to-llvm \
+	-memref-expand \
+	-arith-expand \
+	-convert-arith-to-llvm \
+	-finalize-memref-to-llvm \
+	-convert-scf-to-cf \
+	-convert-openmp-to-llvm \
+	-convert-arith-to-llvm \
+	-convert-math-to-llvm \
+	-convert-math-to-libm \
+	-convert-func-to-llvm \
+	-reconcile-unrealized-casts | \
+	${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
+	-shared-libs=${MLIR_RUNNER_UTILS} \
+	-shared-libs=${MLIR_C_RUNNER_UTILS}
+
+next-reduce-sum1-run:
+	@${MLIR_OPT} ./next-reduce_sum1.mlir \
+	-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \
+	${MLIR_OPT} \
+	-arith-expand \
+	-eliminate-empty-tensors \
+	-empty-tensor-to-alloc-tensor \
+	-one-shot-bufferize \
+	-func-bufferize \
+	-arith-bufferize \
+	-convert-linalg-to-affine-loops \
+	-affine-loop-fusion \
+	-lower-affine \
+	-convert-vector-to-scf \
+	-expand-strided-metadata \
+	-convert-vector-to-llvm \
+	-memref-expand \
+	-arith-expand \
+	-convert-arith-to-llvm \
+	-finalize-memref-to-llvm \
+	-convert-scf-to-cf \
+	-convert-openmp-to-llvm \
+	-convert-arith-to-llvm \
+	-convert-math-to-llvm \
+	-convert-math-to-libm \
+	-convert-func-to-llvm \
+	-reconcile-unrealized-casts | \
+	${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
+	-shared-libs=${MLIR_RUNNER_UTILS} \
+	-shared-libs=${MLIR_C_RUNNER_UTILS}
+
+next-reduce-sum-vec-manual1-run:
+	@${MLIR_OPT} ./next-reduce_sum-vec-manual1.mlir \
+	-convert-linalg-to-affine-loops \
+	-affine-loop-fusion \
+	-lower-affine \
+	-convert-vector-to-scf \
+	-expand-strided-metadata \
+	-convert-vector-to-llvm \
+	-memref-expand \
+	-arith-expand \
+	-convert-arith-to-llvm \
+	-finalize-memref-to-llvm \
+	-convert-scf-to-cf \
+	-convert-openmp-to-llvm \
+	-convert-arith-to-llvm \
+	-convert-math-to-llvm \
+	-convert-math-to-libm \
+	-convert-func-to-llvm \
+	-reconcile-unrealized-casts | \
+	${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
+	-shared-libs=${MLIR_RUNNER_UTILS} \
+	-shared-libs=${MLIR_C_RUNNER_UTILS}
+
+next-reduce-sum-vec-auto-run:
+	@${MLIR_OPT} ./next-reduce_sum.mlir \
+	-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \
+	${MLIR_OPT} \
+	-arith-expand \
+	-eliminate-empty-tensors \
+	-empty-tensor-to-alloc-tensor \
+	-one-shot-bufferize | \
+	${BUDDY_OPT} \
+	-reduce-sum-vectorization-3d="vector-size=16" \
+	-func-bufferize \
+	-arith-bufferize \
+	-convert-linalg-to-affine-loops \
+	-affine-loop-fusion \
+	-lower-affine \
+	-convert-vector-to-scf \
+	-expand-strided-metadata \
+	-convert-vector-to-llvm \
+	-memref-expand \
+	-arith-expand \
+	-convert-arith-to-llvm \
+	-finalize-memref-to-llvm \
+	-convert-scf-to-cf \
+	-convert-openmp-to-llvm \
+	-convert-arith-to-llvm \
+	-convert-math-to-llvm \
+	-convert-math-to-libm \
+	-convert-func-to-llvm \
+	-reconcile-unrealized-casts | \
+	${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
+	-shared-libs=${MLIR_RUNNER_UTILS} \
+	-shared-libs=${MLIR_C_RUNNER_UTILS}
+
+next-reduce-sum1-vec-auto-run:
+	@${MLIR_OPT} ./next-reduce_sum1.mlir \
+	-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \
+	${MLIR_OPT} \
+	-arith-expand \
+	-eliminate-empty-tensors \
+	-empty-tensor-to-alloc-tensor \
+	-one-shot-bufferize | \
+	${BUDDY_OPT} \
+	-reduce-sum-vectorization-3d="vector-size=16" \
+	-func-bufferize \
+	-arith-bufferize \
+	-convert-linalg-to-affine-loops \
+	-affine-loop-fusion \
+	-lower-affine \
+	-convert-vector-to-scf \
+	-expand-strided-metadata \
+	-convert-vector-to-llvm \
+	-memref-expand \
+	-arith-expand \
+	-convert-arith-to-llvm \
+	-finalize-memref-to-llvm \
+	-convert-scf-to-cf \
+	-convert-openmp-to-llvm \
+	-convert-arith-to-llvm \
+	-convert-math-to-llvm \
+	-convert-math-to-libm \
+	-convert-func-to-llvm \
+	-reconcile-unrealized-casts | \
+	${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
+	-shared-libs=${MLIR_RUNNER_UTILS} \
+	-shared-libs=${MLIR_C_RUNNER_UTILS}
+
+
+next-reduce-sum-vec-manual-run:
+	@${MLIR_OPT} ./next-reduce_sum-vec-manual.mlir \
+	-convert-linalg-to-affine-loops \
+	-affine-loop-fusion \
+	-lower-affine \
+	-convert-vector-to-scf \
+	-expand-strided-metadata \
+	-convert-vector-to-llvm \
+	-memref-expand \
+	-arith-expand \
+	-convert-arith-to-llvm \
+	-finalize-memref-to-llvm \
+	-convert-scf-to-cf \
+	-convert-openmp-to-llvm \
+	-convert-arith-to-llvm \
+	-convert-math-to-llvm \
+	-convert-math-to-libm \
+	-convert-func-to-llvm \
+	-reconcile-unrealized-casts | \
+	${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
+	-shared-libs=${MLIR_RUNNER_UTILS} \
+	-shared-libs=${MLIR_C_RUNNER_UTILS}
+
 next-embedding-lower:
 	@${MLIR_OPT} ./next-embedding.mlir \
 	-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \
diff --git a/examples/BuddyNext/next-reduce-sum-12x40x40.mlir b/examples/BuddyNext/next-reduce-sum-12x40x40.mlir
new file mode 100644
index 0000000000..cc3ccf788d
--- /dev/null
+++ b/examples/BuddyNext/next-reduce-sum-12x40x40.mlir
@@ -0,0 +1,69 @@
+// RUN: buddy-opt %s \
+// RUN:   -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" \
+// RUN: | buddy-opt \
+// RUN:   -arith-expand \
+// RUN:   -eliminate-empty-tensors \
+// RUN:   -empty-tensor-to-alloc-tensor \
+// RUN:   -one-shot-bufferize \
+// RUN: | buddy-opt \
+// RUN:   -reduce-sum-vectorization-3d="vector-size=16" \
+// RUN:   -func-bufferize \
+// RUN:   -arith-bufferize \
+// RUN:   -convert-linalg-to-affine-loops \
+// RUN:   -affine-loop-fusion \
+// RUN:   -lower-affine \
+// RUN:   -convert-vector-to-scf \
+// RUN:   -expand-strided-metadata \
+// RUN:   -convert-vector-to-llvm \
+// RUN:   -memref-expand \
+// RUN:   -arith-expand \
+// RUN:   -convert-arith-to-llvm \
+// RUN:   -finalize-memref-to-llvm \
+// RUN:   -convert-scf-to-cf \
+// RUN:   -convert-openmp-to-llvm \
+// RUN:   -convert-arith-to-llvm \
+// RUN:   -convert-math-to-llvm \
+// RUN:   -convert-math-to-libm \
+// RUN:   -convert-func-to-llvm \
+// RUN:   -reconcile-unrealized-casts \
+// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
+// RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func.func private @rtclock() -> f64
+func.func private @printMemrefF32(%ptr : tensor<*xf32>)
+
+func.func @kernel(%t0 : tensor<12x40x40xf32>) {
+  %t_start = call @rtclock() : () -> f64
+
+  // Perform reduce_sum along axis=2
+  %t1 = tosa.reduce_sum %t0 {axis = 2 : i32} : (tensor<12x40x40xf32>) -> tensor<12x40x1xf32>
+
+  %t_end = call @rtclock() : () -> f64
+  %time = arith.subf %t_end, %t_start : f64
+
+  %tensor_unranked = tensor.cast %t1 : tensor<12x40x1xf32> to tensor<*xf32>
+
+  // All the elements of the MemRef are the same,
+  // only check the first line to verify the correctness.
+  // CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [12, 40, 1] strides = [40, 1, 1] data =
+  // CHECK-NEXT: [
+  // CHECK-SAME:  [
+  // CHECK-SAME:   [120{{(, 120)*}}],
+
+  // Print results
+  call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> ()
+  // Print timings
+  vector.print %time : f64
+
+  return
+}
+
+func.func @main() {
+  // Create a tensor filled with 3.0
+  %c0 = arith.constant dense<3.0> : tensor<12x40x40xf32>
+  call @kernel(%c0) : (tensor<12x40x40xf32>) -> ()
+
+  return
+}
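In the RUN pipeline above, `tosa-to-linalg` turns `tosa.reduce_sum` into a `linalg.reduce` over tensors before bufferization. The exact IR depends on the MLIR version; roughly, for the 12x40x40 case, it is along these lines (a sketch, not verbatim compiler output):

    %init   = tensor.empty() : tensor<12x40xf32>
    %zero   = arith.constant 0.000000e+00 : f32
    %filled = linalg.fill ins(%zero : f32) outs(%init : tensor<12x40xf32>) -> tensor<12x40xf32>
    %sum    = linalg.reduce ins(%t0 : tensor<12x40x40xf32>)
                            outs(%filled : tensor<12x40xf32>) dimensions = [2]
      (%in : f32, %acc : f32) {
        %s = arith.addf %in, %acc : f32
        linalg.yield %s : f32
      }
    // The lowering then reshapes tensor<12x40xf32> back to tensor<12x40x1xf32>.

It is this `linalg.reduce` that `-reduce-sum-vectorization-3d` rewrites once `-one-shot-bufferize` has converted the operands to memrefs.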
diff --git a/examples/BuddyNext/next-reduce-sum-1x40x1536.mlir b/examples/BuddyNext/next-reduce-sum-1x40x1536.mlir
new file mode 100644
index 0000000000..b8687d1f6f
--- /dev/null
+++ b/examples/BuddyNext/next-reduce-sum-1x40x1536.mlir
@@ -0,0 +1,69 @@
+// RUN: buddy-opt %s \
+// RUN:   -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" \
+// RUN: | buddy-opt \
+// RUN:   -arith-expand \
+// RUN:   -eliminate-empty-tensors \
+// RUN:   -empty-tensor-to-alloc-tensor \
+// RUN:   -one-shot-bufferize \
+// RUN: | buddy-opt \
+// RUN:   -reduce-sum-vectorization-3d="vector-size=32" \
+// RUN:   -func-bufferize \
+// RUN:   -arith-bufferize \
+// RUN:   -convert-linalg-to-affine-loops \
+// RUN:   -affine-loop-fusion \
+// RUN:   -lower-affine \
+// RUN:   -convert-vector-to-scf \
+// RUN:   -expand-strided-metadata \
+// RUN:   -convert-vector-to-llvm \
+// RUN:   -memref-expand \
+// RUN:   -arith-expand \
+// RUN:   -convert-arith-to-llvm \
+// RUN:   -finalize-memref-to-llvm \
+// RUN:   -convert-scf-to-cf \
+// RUN:   -convert-openmp-to-llvm \
+// RUN:   -convert-arith-to-llvm \
+// RUN:   -convert-math-to-llvm \
+// RUN:   -convert-math-to-libm \
+// RUN:   -convert-func-to-llvm \
+// RUN:   -reconcile-unrealized-casts \
+// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
+// RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func.func private @rtclock() -> f64
+func.func private @printMemrefF32(%ptr : tensor<*xf32>)
+
+func.func @kernel(%t0 : tensor<1x40x1536xf32>) {
+  %t_start = call @rtclock() : () -> f64
+
+  // Perform reduce_sum along axis=2
+  %t1 = tosa.reduce_sum %t0 {axis = 2 : i32} : (tensor<1x40x1536xf32>) -> tensor<1x40x1xf32>
+
+  %t_end = call @rtclock() : () -> f64
+  %time = arith.subf %t_end, %t_start : f64
+
+  %tensor_unranked = tensor.cast %t1 : tensor<1x40x1xf32> to tensor<*xf32>
+
+  // All the elements of the MemRef are the same,
+  // only check the first line to verify the correctness.
+  // CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 40, 1] strides = [40, 1, 1] data =
+  // CHECK-NEXT: [
+  // CHECK-SAME:  [
+  // CHECK-SAME:   [4608{{(, 4608)*}}],
+
+  // Print results
+  call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> ()
+  // Print timings
+  vector.print %time : f64
+
+  return
+}
+
+func.func @main() {
+  // Create a tensor filled with 3.0
+  %c0 = arith.constant dense<3.0> : tensor<1x40x1536xf32>
+  call @kernel(%c0) : (tensor<1x40x1536xf32>) -> ()
+
+  return
+}
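The two tests deliberately pick different vector sizes. For 12x40x40 with vector-size=16, the reduced extent (40 = 2 * 16 + 8) leaves an 8-element tail, so the generated mask is partial on the last iteration; for 1x40x1536 with vector-size=32, 1536 is an exact multiple of 32 and the mask is always all-true. A minimal sketch of the masked tail read (value names here are illustrative, not taken from the generated IR):

    %vl   = arith.constant 8 : index                 // 40 - 32 = 8 remaining lanes
    %mask = vector.create_mask %vl : vector<16xi1>   // first 8 lanes true, rest false
    %vec  = vector.transfer_read %a[%i, %j, %k], %pad, %mask
            : memref<12x40x40xf32>, vector<16xf32>   // masked-off lanes yield %pad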
diff --git a/examples/BuddyNext/next-reduce-sum-vec-manual-12x40x40.mlir b/examples/BuddyNext/next-reduce-sum-vec-manual-12x40x40.mlir
new file mode 100644
index 0000000000..706a811ee1
--- /dev/null
+++ b/examples/BuddyNext/next-reduce-sum-vec-manual-12x40x40.mlir
@@ -0,0 +1,116 @@
+// RUN: buddy-opt %s \
+// RUN:   -convert-linalg-to-affine-loops \
+// RUN:   -affine-loop-fusion \
+// RUN:   -lower-affine \
+// RUN:   -convert-vector-to-scf \
+// RUN:   -expand-strided-metadata \
+// RUN:   -convert-vector-to-llvm \
+// RUN:   -memref-expand \
+// RUN:   -arith-expand \
+// RUN:   -convert-arith-to-llvm \
+// RUN:   -finalize-memref-to-llvm \
+// RUN:   -convert-scf-to-cf \
+// RUN:   -convert-openmp-to-llvm \
+// RUN:   -convert-arith-to-llvm \
+// RUN:   -convert-math-to-llvm \
+// RUN:   -convert-math-to-libm \
+// RUN:   -convert-func-to-llvm \
+// RUN:   -reconcile-unrealized-casts \
+// RUN: | mlir-cpu-runner -O0 -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func.func private @rtclock() -> f64
+func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
+
+// Create a 12x40x40 input tensor
+memref.global "private" @A : memref<12x40x40xf32> = dense<3.0>
+
+func.func @kernel(%a : memref<12x40x40xf32>) {
+  %t_start = call @rtclock() : () -> f64
+
+  %b = memref.alloc() : memref<12x40xf32> // Output tensor
+
+  // Initialize constants
+  %c0 = arith.constant 0.0 : f32
+  %c16 = arith.constant 16 : index
+  %c12 = arith.constant 12 : index
+  %c40 = arith.constant 40 : index
+  %c0_idx = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c8 = arith.constant 8 : index
+
+  // Use outer loop with step 1 and 8x8 blocking
+  affine.for %i0 = 0 to 12 step 1 {
+    affine.for %j0 = 0 to 40 step 8 {
+      // Use 1D parallel processing
+      affine.parallel (%idx) = (0) to (8) {
+        // Compute j1
+        %j1 = arith.remui %idx, %c8 : index
+
+        %j = affine.apply affine_map<(d0, d1) -> (d0 + d1)> (%j0, %j1)
+
+        // Check if within valid range
+        %j_in_range = arith.cmpi slt, %j, %c40 : index
+
+        // Only compute within valid range
+        scf.if %j_in_range {
+          // Initialize accumulator
+          %init_acc = arith.constant 0.0 : f32
+
+          // Vectorize along k dimension with 16 elements
+          %result_acc = affine.for %k = 0 to 40 step 16 iter_args(%acc = %init_acc) -> f32 {
+            // Prefetch next data block
+            %next_k = arith.addi %k, %c16 : index
+            %next_valid = arith.cmpi slt, %next_k, %c40 : index
+            scf.if %next_valid {
+              memref.prefetch %a[%i0, %j, %next_k], read, locality<3>, data : memref<12x40x40xf32>
+            }
+
+            // Compute current block size and mask
+            %remaining = arith.subi %c40, %k : index
+            %vl = arith.minsi %remaining, %c16 : index
+            %mask = vector.create_mask %vl : vector<16xi1>
+
+            // Vectorized data read
+            %vec = vector.transfer_read %a[%i0, %j, %k], %c0, %mask : memref<12x40x40xf32>, vector<16xf32>
+
+            // Vector reduction sum
+            %block_sum = vector.reduction <add>, %vec : vector<16xf32> into f32
+            %next_acc = arith.addf %acc, %block_sum : f32
+            affine.yield %next_acc : f32
+          }
+
+          // Write result
+          memref.store %result_acc, %b[%i0, %j] : memref<12x40xf32>
+        }
+      }
+    }
+  }
+
+  %t_end = call @rtclock() : () -> f64
+  %time = arith.subf %t_end, %t_start : f64
+
+  // Print result
+  %printed_b = memref.cast %b : memref<12x40xf32> to memref<*xf32>
+
+  // All the elements of the MemRef are the same,
+  // only check the first line to verify the correctness.
+  // CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [12, 40] strides = [40, 1] data =
+  // CHECK-NEXT: [
+  // CHECK-SAME:  [120{{(, 120)*}}]
+
+  call @printMemrefF32(%printed_b) : (memref<*xf32>) -> ()
+
+  // Print timings
+  vector.print %time : f64
+
+  memref.dealloc %b : memref<12x40xf32>
+  return
+}
+
+func.func @main() {
+  %a = memref.get_global @A : memref<12x40x40xf32>
+  call @kernel(%a) : (memref<12x40x40xf32>) -> ()
+  return
+}
\ No newline at end of file
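For comparison, without the manual rewrite, `-convert-linalg-to-affine-loops` would lower the same reduction to a scalar nest along these lines (a sketch under that assumption, not the verbatim output):

    affine.for %i = 0 to 12 {
      affine.for %j = 0 to 40 {
        %zero = arith.constant 0.0 : f32
        %sum = affine.for %k = 0 to 40 iter_args(%acc = %zero) -> (f32) {
          %x = memref.load %a[%i, %j, %k] : memref<12x40x40xf32>  // one element per iteration
          %s = arith.addf %acc, %x : f32
          affine.yield %s : f32
        }
        memref.store %sum, %b[%i, %j] : memref<12x40xf32>
      }
    }

The manual kernel above replaces the innermost scalar loop with 16-wide masked vector loads, a `vector.reduction`, and software prefetching.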
diff --git a/examples/BuddyNext/next-reduce-sum-vec-manual-1x40x1536.mlir b/examples/BuddyNext/next-reduce-sum-vec-manual-1x40x1536.mlir
new file mode 100644
index 0000000000..da6ee6b1ff
--- /dev/null
+++ b/examples/BuddyNext/next-reduce-sum-vec-manual-1x40x1536.mlir
@@ -0,0 +1,112 @@
+// RUN: buddy-opt %s \
+// RUN:   -convert-linalg-to-affine-loops \
+// RUN:   -affine-loop-fusion \
+// RUN:   -lower-affine \
+// RUN:   -convert-vector-to-scf \
+// RUN:   -expand-strided-metadata \
+// RUN:   -convert-vector-to-llvm \
+// RUN:   -memref-expand \
+// RUN:   -arith-expand \
+// RUN:   -convert-arith-to-llvm \
+// RUN:   -finalize-memref-to-llvm \
+// RUN:   -convert-scf-to-cf \
+// RUN:   -convert-openmp-to-llvm \
+// RUN:   -convert-arith-to-llvm \
+// RUN:   -convert-math-to-llvm \
+// RUN:   -convert-math-to-libm \
+// RUN:   -convert-func-to-llvm \
+// RUN:   -reconcile-unrealized-casts \
+// RUN: | mlir-cpu-runner -O0 -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+func.func private @rtclock() -> f64
+func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
+
+// Create a 1x40x1536 input tensor
+memref.global "private" @A : memref<1x40x1536xf32> = dense<3.0>
+
+func.func @kernel(%a : memref<1x40x1536xf32>) {
+  %t_start = call @rtclock() : () -> f64
+
+  %b = memref.alloc() : memref<1x40xf32> // Output tensor
+
+  // Initialize constants
+  %c0 = arith.constant 0.0 : f32
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+  %c40 = arith.constant 40 : index
+  %c1536 = arith.constant 1536 : index
+  %c0_idx = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+
+  // Use blocking and vectorization
+  affine.for %j0 = 0 to 40 step 8 {
+    // Process 8 elements at a time
+    affine.for %j1 = 0 to 8 {
+      %j = affine.apply affine_map<(d0, d1) -> (d0 + d1)> (%j0, %j1)
+
+      // Check if within valid range
+      %j_in_range = arith.cmpi slt, %j, %c40 : index
+
+      // Only compute within valid range
+      scf.if %j_in_range {
+        // Initialize accumulator
+        %init_acc = arith.constant 0.0 : f32
+
+        // Vectorize along k dimension with 32 elements
+        %result_acc = affine.for %k = 0 to 1536 step 32 iter_args(%acc = %init_acc) -> f32 {
+          // Prefetch next data block
+          %next_k = arith.addi %k, %c32 : index
+          %next_valid = arith.cmpi slt, %next_k, %c1536 : index
+          scf.if %next_valid {
+            memref.prefetch %a[%c0_idx, %j, %next_k], read, locality<3>, data : memref<1x40x1536xf32>
+          }
+
+          // Compute current block size and mask
+          %remaining = arith.subi %c1536, %k : index
+          %vl = arith.minsi %remaining, %c32 : index
+          %mask = vector.create_mask %vl : vector<32xi1>
+
+          // Vectorized data read
+          %vec = vector.transfer_read %a[%c0_idx, %j, %k], %c0, %mask : memref<1x40x1536xf32>, vector<32xf32>
+
+          // Vector reduction sum
+          %block_sum = vector.reduction <add>, %vec : vector<32xf32> into f32
+          %next_acc = arith.addf %acc, %block_sum : f32
+          affine.yield %next_acc : f32
+        }
+
+        // Write result
+        memref.store %result_acc, %b[%c0_idx, %j] : memref<1x40xf32>
+      }
+    }
+  }
+
+  %t_end = call @rtclock() : () -> f64
+  %time = arith.subf %t_end, %t_start : f64
+
+  // Print result
+  %printed_b = memref.cast %b : memref<1x40xf32> to memref<*xf32>
+
+  // All the elements of the MemRef are the same,
+  // only check the first line to verify the correctness.
+  // CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [1, 40] strides = [40, 1] data =
+  // CHECK-NEXT: [
+  // CHECK-SAME:  [4608{{(, 4608)*}}]
+
+  call @printMemrefF32(%printed_b) : (memref<*xf32>) -> ()
+
+  // Print time
+  vector.print %time : f64
+
+  memref.dealloc %b : memref<1x40xf32>
+  return
+}
+
+func.func @main() {
+  %a = memref.get_global @A : memref<1x40x1536xf32>
+  call @kernel(%a) : (memref<1x40x1536xf32>) -> ()
+  return
+}
+
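Since 1536 is an exact multiple of the 32-element step, the mask in this kernel is always all-true; an equivalent unmasked, in-bounds read would be (sketch):

    %vec = vector.transfer_read %a[%c0_idx, %j, %k], %c0 {in_bounds = [true]}
           : memref<1x40x1536xf32>, vector<32xf32>
    %block_sum = vector.reduction <add>, %vec : vector<32xf32> into f32

The masked form is kept so the same kernel shape also works when the reduced extent is not a multiple of the vector size.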
diff --git a/midend/lib/CMakeLists.txt b/midend/lib/CMakeLists.txt
index cae54478c3..b8b2d18fd4 100644
--- a/midend/lib/CMakeLists.txt
+++ b/midend/lib/CMakeLists.txt
@@ -27,6 +27,7 @@ set(LinkedLibs
   MatMulParallelVectorization
   SchedulingOnDevices
   TransposeOptimization
+  TosaVectorization
 )
diff --git a/midend/lib/Conversion/CMakeLists.txt b/midend/lib/Conversion/CMakeLists.txt
index c3c2fa2ddd..1d9e9b63d4 100644
--- a/midend/lib/Conversion/CMakeLists.txt
+++ b/midend/lib/Conversion/CMakeLists.txt
@@ -16,3 +16,4 @@ add_subdirectory(LowerSche)
 add_subdirectory(FuncBufferize)
 add_subdirectory(DepthwiseConvOptimization)
 add_subdirectory(MLIRGPU)
+add_subdirectory(TosaVectorization)
diff --git a/midend/lib/Conversion/TosaVectorization/CMakeLists.txt b/midend/lib/Conversion/TosaVectorization/CMakeLists.txt
new file mode 100644
index 0000000000..fead1acafb
--- /dev/null
+++ b/midend/lib/Conversion/TosaVectorization/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_library(TosaVectorization
+  ReduceSumVectorization3D.cpp
+
+  LINK_LIBS PUBLIC
+  BuddyUtils
+)
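The pass implemented below matches a `linalg.reduce` whose operands are already memrefs (rank 3 in, rank 2 out) and rewrites it into a blocked, masked, prefetching loop nest. A skeleton of the IR it emits, for vector-size=16 (names and the dynamic shapes are illustrative):

    affine.for %i0 = 0 to %dim0 {
      affine.for %j0 = 0 to %dim1 step 8 {
        affine.parallel (%idx) = (0) to (8) {       // 8 lanes of the j block
          %j = arith.addi %j0, %idx : index
          %in_range = arith.cmpi slt, %j, %dim1 : index
          scf.if %in_range {
            %sum = affine.for %k = 0 to %dim2 step 16 iter_args(%acc = %zero) -> (f32) {
              %rem  = arith.subi %dim2, %k : index
              %vl   = arith.minsi %rem, %c16 : index
              %mask = vector.create_mask %vl : vector<16xi1>
              %vec  = vector.transfer_read %in[%i0, %j, %k], %zero, %mask
                      : memref<?x?x?xf32>, vector<16xf32>
              %part = vector.reduction <add>, %vec : vector<16xf32> into f32
              %next = arith.addf %acc, %part : f32
              affine.yield %next : f32
            }
            memref.store %sum, %out[%i0, %j] : memref<?x?xf32>
          }
        }
      }
    }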
diff --git a/midend/lib/Conversion/TosaVectorization/ReduceSumVectorization3D.cpp b/midend/lib/Conversion/TosaVectorization/ReduceSumVectorization3D.cpp
new file mode 100644
index 0000000000..4fea724255
--- /dev/null
+++ b/midend/lib/Conversion/TosaVectorization/ReduceSumVectorization3D.cpp
@@ -0,0 +1,324 @@
+//===- ReduceSumVectorization3D.cpp ---------------------------------------===//
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the reduce sum vectorization for 3D tensors.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include <cstdint>
+#include <mlir/Dialect/Affine/Analysis/AffineAnalysis.h>
+#include <mlir/Dialect/Affine/IR/AffineOps.h>
+#include <mlir/Dialect/Func/IR/FuncOps.h>
+#include <mlir/Dialect/Linalg/IR/Linalg.h>
+#include <mlir/Pass/Pass.h>
+#include <mlir/Transforms/DialectConversion.h>
+
+using namespace mlir;
+using namespace vector;
+using namespace affine;
+
+//===----------------------------------------------------------------------===//
+// Rewrite Pattern
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+class ReduceSumVectorization3DPattern : public ConversionPattern {
+public:
+  explicit ReduceSumVectorization3DPattern(MLIRContext *context,
+                                           int64_t affineVectorSizeParam)
+      : ConversionPattern(linalg::ReduceOp::getOperationName(), 1, context),
+        affineVectorSize(affineVectorSizeParam) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto reduceOp = cast<linalg::ReduceOp>(op);
+
+    // Check if it's a 3D to 2D reduction
+    if (!reduceOp.getOperand(0).getType().isa<MemRefType>() ||
+        !reduceOp.getOperand(1).getType().isa<MemRefType>())
+      return failure();
+
+    auto inputType = reduceOp.getOperand(0).getType().cast<MemRefType>();
+    auto outputType = reduceOp.getOperand(1).getType().cast<MemRefType>();
+
+    // Verify dimensions
+    if (inputType.getRank() != 3 || outputType.getRank() != 2)
+      return failure();
+
+    // Get input and output
+    auto input = reduceOp.getOperand(0);
+    auto output = reduceOp.getOperand(1);
+    auto loc = op->getLoc();
+
+    // Get element type of input tensor
+    Type elementType = inputType.getElementType();
+
+    // Define constants
+    const Value index0 =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
+    const Value indexVecSize = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIndexAttr(affineVectorSize));
+    // const Value c8 =
+    //     rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(8));
+    // const Value c1 =
+    //     rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
+    const Value zeroFloat = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(elementType));
+
+    // Get input tensor dimensions
+    Value dim0 = rewriter.create<memref::DimOp>(loc, input, 0);
+    Value dim1 = rewriter.create<memref::DimOp>(loc, input, 1);
+    Value dim2 = rewriter.create<memref::DimOp>(loc, input, 2);
+
+    // Outer loop - first dimension
+    affine::buildAffineLoopNest(
+        rewriter, loc, {index0}, {dim0}, 1,
+        [&](OpBuilder &builder, Location loc, ValueRange ivRange) {
+          Value i0 = ivRange.front();
+
+          // Middle loop - second dimension, step 8
+          affine::buildAffineLoopNest(
+              builder, loc, {index0}, {dim1}, 8,
+              [&](OpBuilder &builder, Location loc, ValueRange ivRange) {
+                Value j0 = ivRange.front();
+
+                // Create parallel op to process 8 blocks
+                SmallVector<Value, 4> reducedValues =
+                    llvm::to_vector<4>(llvm::map_range(
+                        ArrayRef<LoopReduction>{},
+                        [](const LoopReduction &red) { return red.value; }));
+
+                AffineParallelOp parallelOp =
+                    builder.create<AffineParallelOp>(
+                        loc, ValueRange(reducedValues).getTypes(), ValueRange{},
+                        ArrayRef<NamedAttribute>{
+                            builder.getNamedAttr("lowerBoundsGroups",
+                                                 builder.getI32TensorAttr({1})),
+                            builder.getNamedAttr("upperBoundsGroups",
+                                                 builder.getI32TensorAttr({1})),
+                            builder.getNamedAttr(
+                                "lowerBoundsMap",
+                                AffineMapAttr::get(AffineMap::get(
+                                    0, 0, {builder.getAffineConstantExpr(0)},
+                                    builder.getContext()))),
+                            builder.getNamedAttr(
+                                "upperBoundsMap",
+                                AffineMapAttr::get(AffineMap::get(
+                                    0, 0, {builder.getAffineConstantExpr(8)},
+                                    builder.getContext()))),
+                            builder.getNamedAttr("steps",
+                                                 builder.getI64ArrayAttr({1})),
+                            builder.getNamedAttr("reductions",
+                                                 builder.getArrayAttr({}))});
+
+                // Create parallel block body
+                Block *parallelBody = new Block();
+                builder.setInsertionPointToStart(parallelBody);
+                parallelBody->addArgument(builder.getIndexType(), loc);
+                Value idx = parallelBody->getArguments()[0];
+
+                // Calculate actual j index
+                Value j = builder.create<arith::AddIOp>(loc, j0, idx);
+
+                // Check if j is within valid range
+                Value j_in_range = builder.create<arith::CmpIOp>(
+                    loc, arith::CmpIPredicate::slt, j, dim1);
+
+                builder.create<scf::IfOp>(
+                    loc, j_in_range, [&](OpBuilder &builder, Location loc) {
+                      // Initialize accumulator
+                      Value acc = builder.create<arith::ConstantOp>(
+                          loc, builder.getZeroAttr(elementType));
+
+                      // Vectorized reduction in the innermost dimension
+                      auto lbMap = AffineMap::get(
+                          /*dimCount=*/0, /*symbolCount=*/0,
+                          builder.getAffineConstantExpr(0),
+                          builder.getContext());
+                      auto ubMap = AffineMap::get(
+                          /*dimCount=*/1, /*symbolCount=*/0,
+                          builder.getAffineDimExpr(0), builder.getContext());
+
+                      affine::AffineForOp reductionLoop = builder.create<
+                          affine::AffineForOp>(
+                          loc,
+                          /*lbOperands=*/ValueRange{},
+                          /*lbMap=*/lbMap,
+                          /*ubOperands=*/ValueRange{dim2},
+                          /*ubMap=*/ubMap,
+                          /*step=*/affineVectorSize,
+                          /*iterArgs=*/ValueRange{acc},
+                          [&](OpBuilder &builder, Location loc, Value iv,
+                              ValueRange iterArgs) {
+                            Value curr_acc = iterArgs[0];
+
+                            // Prefetch next data block
+                            Value next_k = builder.create<arith::AddIOp>(
+                                loc, iv, indexVecSize);
+                            Value next_valid = builder.create<arith::CmpIOp>(
+                                loc, arith::CmpIPredicate::slt, next_k, dim2);
+
+                            builder.create<scf::IfOp>(
+                                loc, next_valid,
+                                [&](OpBuilder &builder, Location loc) {
+                                  builder.create<memref::PrefetchOp>(
+                                      loc, input, ValueRange{i0, j, next_k},
+                                      /*isWrite=*/false,
+                                      /*locality=*/3,
+                                      /*isDataCache=*/true);
+                                  builder.create<scf::YieldOp>(loc);
+                                });
+
+                            // Calculate current block size and mask
+                            Value remaining =
+                                builder.create<arith::SubIOp>(loc, dim2, iv);
+                            Value vl = builder.create<arith::MinSIOp>(
+                                loc, remaining, indexVecSize);
+                            Value mask = builder.create<vector::CreateMaskOp>(
+                                loc,
+                                VectorType::get({(int64_t)affineVectorSize},
+                                                builder.getI1Type()),
+                                ValueRange{vl});
+
+                            // Vectorized read
+                            auto vecType = VectorType::get(
+                                {(int64_t)affineVectorSize}, elementType);
+                            auto map = AffineMap::get(
+                                /*dimCount=*/3, // 3-D input
+                                /*symbolCount=*/0,
+                                {rewriter.getAffineDimExpr(2)}, // Only map the k dimension
+                                rewriter.getContext());
+                            Value vec = builder.create<vector::TransferReadOp>(
+                                loc, vecType, input, ValueRange{i0, j, iv}, map,
+                                zeroFloat, mask,
+                                ArrayAttr::get(builder.getContext(),
+                                               {builder.getBoolAttr(false)}));
+
+                            // Vector reduction sum
+                            Value block_sum =
+                                builder.create<vector::ReductionOp>(
+                                    loc, vector::CombiningKind::ADD, vec);
+
+                            // Update accumulator
+                            Value next_acc = builder.create<arith::AddFOp>(
+                                loc, curr_acc, block_sum);
+
+                            builder.create<affine::AffineYieldOp>(loc,
+                                                                  next_acc);
+                          });
+
+                      // Store result
+                      builder.create<memref::StoreOp>(
+                          loc, reductionLoop.getResult(0), output,
+                          ValueRange{i0, j});
+
+                      builder.create<scf::YieldOp>(loc);
+                    });
+
+                builder.create<affine::AffineYieldOp>(loc);
+                parallelOp.getRegion().push_back(parallelBody);
+              });
+        });
+
+    // Remove original operation
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+private:
+  int64_t affineVectorSize;
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// ReduceSumVectorizationPass
+//===----------------------------------------------------------------------===//
+
+namespace {
+class ReduceSumVectorizationPass
+    : public PassWrapper<ReduceSumVectorizationPass,
+                         OperationPass<ModuleOp>> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReduceSumVectorizationPass)
+
+  StringRef getArgument() const final { return "reduce-sum-vectorization-3d"; }
+
+  StringRef getDescription() const final {
+    return "Reduce Sum Vectorization for 3D tensors.";
+  }
+
+  ReduceSumVectorizationPass() = default;
+
+  ReduceSumVectorizationPass(const ReduceSumVectorizationPass &) {}
+
+  explicit ReduceSumVectorizationPass(int64_t affineVectorSizeParam) {
+    affineVectorSize = affineVectorSizeParam;
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ModuleOp module = getOperation();
+    ConversionTarget target(*context);
+    target.addLegalDialect<arith::ArithDialect, affine::AffineDialect,
+                           scf::SCFDialect, memref::MemRefDialect,
+                           vector::VectorDialect, func::FuncDialect>();
+    target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
+    RewritePatternSet patterns(context);
+    patterns.add<ReduceSumVectorization3DPattern>(context, affineVectorSize);
+    if (failed(applyPartialConversion(module, target, std::move(patterns))))
+      signalPassFailure();
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect, affine::AffineDialect,
+                    vector::VectorDialect, memref::MemRefDialect,
+                    scf::SCFDialect, func::FuncDialect>();
+  }
+
+  Option<int64_t> affineVectorSize{*this, "vector-size",
+                                   llvm::cl::desc("Affine Vector size."),
+                                   llvm::cl::init(16)};
+};
+} // namespace
+
+namespace mlir {
+namespace buddy {
+void registerReduceSumVectorizationPass() {
+  PassRegistration<ReduceSumVectorizationPass>();
+}
+} // namespace buddy
+} // namespace mlir
diff --git a/midend/lib/InitAll.cpp b/midend/lib/InitAll.cpp
index d6cad2bc1e..f40ffac4ef 100644
--- a/midend/lib/InitAll.cpp
+++ b/midend/lib/InitAll.cpp
@@ -48,6 +48,7 @@ void registerMatMulParallelVectorizationPass();
 void registerMatMulVectorizationPass();
 void registerDeviceSchedulePass();
 void registerTransposeOptimizationPass();
+void registerReduceSumVectorizationPass();
 } // namespace buddy
 } // namespace mlir
 
@@ -80,4 +81,5 @@ void mlir::buddy::registerAllPasses() {
   mlir::buddy::registerMatMulVectorizationPass();
   mlir::buddy::registerDeviceSchedulePass();
   mlir::buddy::registerTransposeOptimizationPass();
+  mlir::buddy::registerReduceSumVectorizationPass();
 }
diff --git a/tools/buddy-opt/CMakeLists.txt b/tools/buddy-opt/CMakeLists.txt
index 0abb857fad..bce971dae6 100644
--- a/tools/buddy-opt/CMakeLists.txt
+++ b/tools/buddy-opt/CMakeLists.txt
@@ -28,6 +28,7 @@ target_link_libraries(buddy-opt
   BatchMatMulOptimization
   MatMulParallelVectorization
   TransposeOptimization
+  TosaVectorization
   ConvOptimization
   DepthwiseConvOptimization
   VectorExp
diff --git a/tools/buddy-opt/buddy-opt.cpp b/tools/buddy-opt/buddy-opt.cpp
index 61a0958c72..9e7035edf8 100644
--- a/tools/buddy-opt/buddy-opt.cpp
+++ b/tools/buddy-opt/buddy-opt.cpp
@@ -71,6 +71,7 @@ void registerMatMulOptimizePass();
 void registerMatMulVectorizationPass();
 void registerMatMulParallelVectorizationPass();
 void registerTransposeOptimizationPass();
+void registerReduceSumVectorizationPass();
 void registerConvOptimizePass();
 void registerConvNhwcFhwcOptimizePass();
 void registerConvNhwcFhwcTileOptimizePass();
@@ -118,6 +119,7 @@ int main(int argc, char **argv) {
   mlir::buddy::registerMatMulVectorizationPass();
   mlir::buddy::registerMatMulParallelVectorizationPass();
   mlir::buddy::registerTransposeOptimizationPass();
+  mlir::buddy::registerReduceSumVectorizationPass();
   mlir::buddy::registerConvOptimizePass();
   mlir::buddy::registerConvNhwcFhwcOptimizePass();
   mlir::buddy::registerConvNhwcFhwcTileOptimizePass();
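A minimal way to exercise the newly registered pass in isolation is a hand-written, already-bufferized module (file name and shapes here are hypothetical; the input must contain a memref-based linalg.reduce, matching what the RUN pipelines produce after -one-shot-bufferize):

    // reduce3d.mlir (hypothetical), run as:
    //   buddy-opt reduce3d.mlir -reduce-sum-vectorization-3d="vector-size=16"
    module {
      func.func @reduce(%in : memref<4x8x40xf32>, %out : memref<4x8xf32>) {
        linalg.reduce ins(%in : memref<4x8x40xf32>)
                      outs(%out : memref<4x8xf32>) dimensions = [2]
          (%x : f32, %acc : f32) {
            %s = arith.addf %x, %acc : f32
            linalg.yield %s : f32
          }
        return
      }
    }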