diff --git a/examples/BuddyMatmul/amx-bf16-matmul.mlir b/examples/BuddyMatmul/amx-bf16-matmul.mlir
new file mode 100644
index 0000000000..3cf9a751ea
--- /dev/null
+++ b/examples/BuddyMatmul/amx-bf16-matmul.mlir
@@ -0,0 +1,128 @@
+// NOTE: AMX testing is disabled for automated test suites due to system requirements.
+// AMX requires arch_prctl system calls for permission setup which cannot be
+// performed in JIT environments and may not be available in CI/testing systems.
+//
+// To test manually, use: make amx-bf16-matmul-run
+//
+// RUN:
+//
+// AMX BF16 MatMul (No-Transpose Interface)
+// Requirements:
+// - M, N are multiples of 16; K is a multiple of 32.
+// - A, B are bf16; C is f32.
+// - B must be pre-packed into the pair-interleaved (VNNI) layout consumed by the
+//   bf16 tile multiply: Bpack has shape [K/2 x 2N] with
+//   Bpack[k / 2, 2 * n + k % 2] = B[k, n]. Each logical block
+//   B[k0:k0+32, n0:n0+16] is then the 16 rows x 32 columns bf16 region starting
+//   at Bpack[k0 / 2, 2 * n0] and can be loaded by a single amx.tile_load into a
+//   !amx.tile<16x32xbf16>.
+//   This avoids runtime transposes/gathers and ensures optimal AMX loads.
+//
+// Note:
+// The AMX dialect abstracts the hardware orientation; both lhs and rhs tiles for
+// amx.tile_mulf use the same tile type !amx.tile<16x32xbf16>, and the reduction
+// dimension is K=32 under the hood.
+
+module {
+  // External functions for timing and printing
+  func.func private @rtclock() -> f64
+  func.func private @printMemrefF32(memref<*xf32>)
+
+  // Real AMX BF16 kernel
+  func.func @amx_bf16_matmul(
+      %A: memref<?x?xbf16>, // [M x K], row-major
+      %Bpack: memref<?x?xbf16>, // [K/2 x 2N], B pre-packed for AMX-friendly tile loads
+      %C: memref<?x?xf32>, // [M x N], row-major
+      %M: index, %N: index, %K: index) {
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+
+    scf.for %m = %c0 to %M step %c16 {
+      scf.for %n = %c0 to %N step %c16 {
+        // Each logical column of B owns two adjacent bf16 slots in Bpack, so the
+        // block for columns n..n+15 starts at packed column 2 * n.
+        %nPack = arith.muli %n, %c2 : index
+
+        // Initialize C tile to zero once per (m,n) tile.
+        %zero_tile = amx.tile_zero : !amx.tile<16x16xf32>
+        amx.tile_store %C[%m, %n], %zero_tile
+          : memref<?x?xf32>, !amx.tile<16x16xf32>
+
+        // Reduce across K in chunks of 32. Accumulator is stored/loaded via C tile.
+        scf.for %k0 = %c0 to %K step %c32 {
+          // Load A sub-block: 16x32xbf16 from [%m, %k0]
+          %tA = amx.tile_load %A[%m, %k0]
+            : memref<?x?xbf16> into !amx.tile<16x32xbf16>
+
+          // Each packed row holds two logical K rows, so the block for rows
+          // k0..k0+31 starts at packed row k0 / 2.
+          %kPack = arith.divui %k0, %c2 : index
+
+          // Load B sub-block (pre-packed): 16x32xbf16 holding B[k0:k0+32, n:n+16]
+          %tB = amx.tile_load %Bpack[%kPack, %nPack]
+            : memref<?x?xbf16> into !amx.tile<16x32xbf16>
+
+          // Load current accumulator from C, perform FMA, then store back.
+          %tAcc = amx.tile_load %C[%m, %n]
+            : memref<?x?xf32> into !amx.tile<16x16xf32>
+          %tAcc2 = amx.tile_mulf %tA, %tB, %tAcc
+            : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
+          amx.tile_store %C[%m, %n], %tAcc2
+            : memref<?x?xf32>, !amx.tile<16x16xf32>
+        }
+      }
+    }
+    return
+  }
+
+  // Performance test with MLIR-level timing: larger matrices for meaningful benchmarks.
+  func.func @amx_main() {
+    %c512 = arith.constant 512 : index
+    %c1024 = arith.constant 1024 : index
+    %c2048 = arith.constant 2048 : index
+    %c4096 = arith.constant 4096 : index
+
+    // Test with 512x2048x1024 matrices: A[512x1024] × B[1024x2048] = C[512x2048]
+    // Allocate matrices. Bpack stores the 1024x2048 B in packed form [K/2 x 2N].
+    %A = memref.alloc(%c512, %c1024) : memref<?x?xbf16> // 512x1024
+    %Bpack = memref.alloc(%c512, %c4096) : memref<?x?xbf16> // 512x4096 (pre-packed 1024x2048)
+    %C = memref.alloc(%c512, %c2048) : memref<?x?xf32> // 512x2048
+
+    // Initialize A = 1.0bf16, Bpack = 1.0bf16, C = 0.0f32
+    %one_bf16 = arith.constant 1.0 : bf16
+    %zero_f32 = arith.constant 0.0 : f32
+
+    linalg.fill ins(%one_bf16 : bf16) outs(%A : memref<?x?xbf16>)
+    linalg.fill ins(%one_bf16 : bf16) outs(%Bpack : memref<?x?xbf16>)
+    linalg.fill ins(%zero_f32 : f32) outs(%C : memref<?x?xf32>)
+
+    // Start timing
+    %t_start = call @rtclock() : () -> f64
+
+    // Call AMX kernel
+    call @amx_bf16_matmul(%A, %Bpack, %C, %c512, %c2048, %c1024)
+      : (memref<?x?xbf16>, memref<?x?xbf16>, memref<?x?xf32>, index, index, index) -> ()
+
+    // End timing (only measure computation, not printing)
+    %t_end = call @rtclock() : () -> f64
+    %computation_time = arith.subf %t_end, %t_start : f64
+
+    // Print the entire output matrix
+    // All elements should be ~1024.0f (since A=1.0, B=1.0, K=1024)
+    // CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [512, 2048] strides = [2048, 1] data =
+    %Cu = memref.cast %C : memref<?x?xf32> to memref<*xf32>
+    call @printMemrefF32(%Cu) : (memref<*xf32>) -> ()
+
+    // Print timing result (computation only, excluding printing time)
+    // CHECK: {{[0-9]+\.[0-9]+}}
+    vector.print %computation_time : f64
+
+    memref.dealloc %C : memref<?x?xf32>
+    memref.dealloc %Bpack : memref<?x?xbf16>
+    memref.dealloc %A : memref<?x?xbf16>
+    return
+  }
+}
+
diff --git a/examples/BuddyMatmul/amx-wrapper.c b/examples/BuddyMatmul/amx-wrapper.c
new file mode 100644
index 0000000000..7327f64e5b
--- /dev/null
+++ b/examples/BuddyMatmul/amx-wrapper.c
@@ -0,0 +1,69 @@
+//===- amx-wrapper.c - AMX Permission and MLIR Entry Wrapper --------------===//
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file wraps AMX permission setup and calls the MLIR-generated entry
+// point.
+//
+//===----------------------------------------------------------------------===//
+
+#include <errno.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/syscall.h>
+#include <sys/time.h>
+#include <unistd.h>
+
+// arch_prctl() request code for permission to use a dynamic XSTATE feature.
+#define ARCH_REQ_XCOMP_PERM 0x1023
+// XSTATE component number of the AMX tile data. This is the component that
+// has to be requested; the tile configuration component is number 17.
+#define XFEATURE_XTILEDATA 18
+
+extern void _mlir_ciface_amx_main(void);
+
+// MLIR rtclock function implementation: wall-clock time in seconds.
+double _mlir_ciface_rtclock(void) {
+  struct timeval tv;
+  gettimeofday(&tv, NULL);
+  return tv.tv_sec + tv.tv_usec * 1e-6;
+}
+
+int main(void) {
+  printf("Checking AMX support...\n");
+
+  // Ask the kernel for permission to use the AMX tile data state. A failure is
+  // reported but deliberately not fatal, so the example is still attempted.
+  long ret = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA);
+  if (ret != 0) {
+    printf("Warning: Failed to request AMX tile data permission: %ld (errno: %s)\n",
+           ret, strerror(errno));
+    printf("This might be due to kernel version or configuration.\n");
+    printf("Attempting to run anyway...\n");
+  } else {
+    printf("AMX tile data permission granted\n");
+  }
+
+  printf("Starting AMX computation...\n");
+
+  // Call the MLIR-generated main function
+  _mlir_ciface_amx_main();
+
+  printf("AMX computation completed successfully!\n");
+
+  return 0;
+}
diff --git a/examples/BuddyMatmul/linalg-bf16-matmul.mlir b/examples/BuddyMatmul/linalg-bf16-matmul.mlir
new file mode 100644
index 0000000000..280596021a
--- /dev/null
+++ b/examples/BuddyMatmul/linalg-bf16-matmul.mlir
@@ -0,0 +1,105 @@
+// RUN: buddy-opt %s \
+// RUN:     -convert-linalg-to-loops \
+// RUN:     -lower-affine \
+// RUN:     -convert-scf-to-cf \
+// RUN:     -convert-vector-to-llvm \
+// RUN:     -convert-arith-to-llvm \
+// RUN:     -finalize-memref-to-llvm \
+// RUN:     -convert-cf-to-llvm \
+// RUN:     -convert-func-to-llvm \
+// RUN:     -reconcile-unrealized-casts \
+// RUN: | mlir-runner -e linalg_main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
+// RUN:     -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
+// RUN: | FileCheck %s
+
+// Regular linalg.matmul for performance comparison
+module {
+  // External functions for timing and printing
+  func.func private @rtclock() -> f64
+  func.func private @printMemrefF32(memref<*xf32>)
+
+  // Regular linalg matmul kernel for comparison
+  func.func @linalg_bf16_matmul(
+      %A: memref<?x?xbf16>, // [M x K], row-major
+      %B: memref<?x?xbf16>, // [K x N], row-major
+      %C: memref<?x?xf32>, // [M x N], row-major
+      %M: index, %N: index, %K: index) {
+
+    // Allocate f32 versions for linalg.matmul
+    %A_f32 = memref.alloc(%M, %K) : memref<?x?xf32>
+    %B_f32 = memref.alloc(%K, %N) : memref<?x?xf32>
+
+    // Copy and convert bf16 to f32 element-wise
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    scf.for %i = %c0 to %M step %c1 {
+      scf.for %j = %c0 to %K step %c1 {
+        %val_bf16 = memref.load %A[%i, %j] : memref<?x?xbf16>
+        %val_f32 = arith.extf %val_bf16 : bf16 to f32
+        memref.store %val_f32, %A_f32[%i, %j] : memref<?x?xf32>
+      }
+    }
+    scf.for %i = %c0 to %K step %c1 {
+      scf.for %j = %c0 to %N step %c1 {
+        %val_bf16 = memref.load %B[%i, %j] : memref<?x?xbf16>
+        %val_f32 = arith.extf %val_bf16 : bf16 to f32
+        memref.store %val_f32, %B_f32[%i, %j] : memref<?x?xf32>
+      }
+    }
+
+    linalg.matmul ins(%A_f32, %B_f32 : memref<?x?xf32>, memref<?x?xf32>)
+                  outs(%C : memref<?x?xf32>)
+
+    memref.dealloc %A_f32 : memref<?x?xf32>
+    memref.dealloc %B_f32 : memref<?x?xf32>
+    return
+  }
+
+  // Performance test with MLIR-level timing: same size as AMX version for comparison.
+  func.func @linalg_main() {
+    %c512 = arith.constant 512 : index
+    %c1024 = arith.constant 1024 : index
+    %c2048 = arith.constant 2048 : index
+
+    // Test with 512x2048x1024 matrices: A[512x1024] × B[1024x2048] = C[512x2048] (same as AMX version)
+    // Allocate matrices
+    %A = memref.alloc(%c512, %c1024) : memref<?x?xbf16> // 512x1024
+    %B = memref.alloc(%c1024, %c2048) : memref<?x?xbf16> // 1024x2048
+    %C = memref.alloc(%c512, %c2048) : memref<?x?xf32> // 512x2048
+
+    // Initialize A = 1.0bf16, B = 1.0bf16, C = 0.0f32
+    %one_bf16 = arith.constant 1.0 : bf16
+    %zero_f32 = arith.constant 0.0 : f32
+
+    linalg.fill ins(%one_bf16 : bf16) outs(%A : memref<?x?xbf16>)
+    linalg.fill ins(%one_bf16 : bf16) outs(%B : memref<?x?xbf16>)
+    linalg.fill ins(%zero_f32 : f32) outs(%C : memref<?x?xf32>)
+
+    // Start timing
+    %t_start = call @rtclock() : () -> f64
+
+    // Call linalg kernel
+    call @linalg_bf16_matmul(%A, %B, %C, %c512, %c2048, %c1024)
+      : (memref<?x?xbf16>, memref<?x?xbf16>, memref<?x?xf32>, index, index, index) -> ()
+
+    // End timing (only measure computation, not printing)
+    %t_end = call @rtclock() : () -> f64
+    %computation_time = arith.subf %t_end, %t_start : f64
+
+    // Print the entire output matrix
+    // All elements should be ~1024.0f (since A=1.0, B=1.0, K=1024)
+    // CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [512, 2048] strides = [2048, 1] data =
+    %Cu = memref.cast %C : memref<?x?xf32> to memref<*xf32>
+    call @printMemrefF32(%Cu) : (memref<*xf32>) -> ()
+
+    // Print timing result (computation only, excluding printing time)
+    // CHECK: {{[0-9]+\.[0-9]+}}
+    vector.print %computation_time : f64
+
+    memref.dealloc %C : memref<?x?xf32>
+    memref.dealloc %B : memref<?x?xbf16>
+    memref.dealloc %A : memref<?x?xbf16>
+    return
+  }
+}
diff --git a/examples/BuddyMatmul/makefile b/examples/BuddyMatmul/makefile
index f5ccd10747..d99ccc0dfa 100644
--- a/examples/BuddyMatmul/makefile
+++ b/examples/BuddyMatmul/makefile
@@ -179,3 +179,70 @@ batchmatmul-vectorization-run:
 		-reconcile-unrealized-casts | \
 	${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
 		-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}
+
+# AMX BF16 test
+# AMX requires kernel-level permissions that cannot be obtained in JIT environments
+amx-bf16-matmul-run:
+	@echo "AMX JIT version not supported due to permission issues."
+	@echo "Running simplified AOT version instead..."
+	@${BUDDY_OPT} ./amx-bf16-matmul.mlir \
+		--llvm-request-c-wrappers \
+		-convert-linalg-to-loops \
+		-lower-affine \
+		-convert-scf-to-cf \
+		-convert-vector-to-llvm="enable-amx" \
+		-convert-arith-to-llvm \
+		-finalize-memref-to-llvm \
+		-convert-to-llvm \
+		-reconcile-unrealized-casts | \
+	${MLIR_TRANSLATE} --mlir-to-llvmir -o amx-quick.ll
+	@${LLC} -mtriple ${MTRIPLE} -mattr=+amx-bf16,+amx-tile,+amx-int8 ${OPT_FLAG} -filetype=obj amx-quick.ll -o amx-quick.o
+	@gcc -c amx-wrapper.c -o amx-wrapper-quick.o
+	@g++ amx-wrapper-quick.o amx-quick.o -o amx-quick -L${LLVM_BUILD_DIR}/lib -lmlir_runner_utils -lmlir_c_runner_utils -lpthread -Wl,-rpath,${LLVM_BUILD_DIR}/lib
+	@echo "Running AMX quick test..."
+	@LD_LIBRARY_PATH=${LLVM_BUILD_DIR}/lib ./amx-quick
+	@rm -f amx-quick.ll amx-quick.o amx-wrapper-quick.o amx-quick
+
+# Check what AMX instructions are generated
+amx-bf16-matmul-asm:
+	@${BUDDY_OPT} ./amx-bf16-matmul.mlir \
+		-convert-linalg-to-loops \
+		-lower-affine \
+		-convert-scf-to-cf \
+		-convert-vector-to-llvm="enable-amx" \
+		-convert-arith-to-llvm \
+		-finalize-memref-to-llvm \
+		-convert-to-llvm \
+		-reconcile-unrealized-casts | \
+	${MLIR_TRANSLATE} --mlir-to-llvmir | \
+	${LLC} -mtriple ${MTRIPLE} -mattr=+amx-bf16,+amx-tile,+amx-int8 ${OPT_FLAG} -filetype=asm -o amx-asm.s
+	@echo "Generated assembly in amx-asm.s"
+
+linalg-bf16-matmul-run:
+	@echo "Running linalg JIT version with MLIR-level timing..."
+	@${BUDDY_OPT} ./linalg-bf16-matmul.mlir \
+		-eliminate-empty-tensors \
+		-empty-tensor-to-alloc-tensor \
+		-one-shot-bufferize="bufferize-function-boundaries" \
+		-matmul-parallel-vectorization-optimize \
+		-batchmatmul-optimize \
+		-convert-linalg-to-affine-loops \
+		-affine-loop-fusion \
+		-convert-vector-to-scf \
+		-expand-strided-metadata \
+		-lower-affine \
+		-memref-expand \
+		-arith-expand \
+		-convert-vector-to-llvm \
+		-convert-arith-to-llvm \
+		-finalize-memref-to-llvm \
+		-convert-scf-to-cf \
+		-convert-cf-to-llvm \
+		-convert-openmp-to-llvm \
+		-convert-arith-to-llvm \
+		-convert-math-to-llvm \
+		-convert-math-to-libm \
+		-convert-func-to-llvm \
+		-reconcile-unrealized-casts | \
+	${MLIR_CPU_RUNNER} ${OPT_FLAG} -e linalg_main -entry-point-result=void \
+		-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}