2 changes: 1 addition & 1 deletion .gitmodules
@@ -1,7 +1,7 @@
[submodule "llvm"]
path = llvm
url = https://github.com/llvm/llvm-project.git
branch = main
branch = llvmorg-21.1.0-rc2
shallow = true
[submodule "thirdparty/mimalloc"]
path = thirdparty/mimalloc
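The submodule now pins llvmorg-21.1.0-rc2 instead of tracking main, with a shallow clone. A minimal Python sketch for refreshing an existing checkout after pulling this change is shown below; it simply shells out to git and assumes the submodule path is llvm and that a git CLI is on PATH.

# Sketch: re-sync the llvm submodule after the pin changes (run from the repo root).
import subprocess

subprocess.run(["git", "submodule", "sync", "--", "llvm"], check=True)
subprocess.run(["git", "submodule", "update", "--init", "--depth", "1", "--", "llvm"], check=True)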
4 changes: 2 additions & 2 deletions backend/llvm/lib/CodeGen/CMakeLists.txt
@@ -31,10 +31,10 @@ add_llvm_component_library(LLVMBuddyCodeGen
${LLVM_CodeGen_DIR}/DwarfEHPrepare.cpp
${LLVM_CodeGen_DIR}/EarlyIfConversion.cpp
${LLVM_CodeGen_DIR}/EdgeBundles.cpp
${LLVM_CodeGen_DIR}/EHContGuardCatchret.cpp
${LLVM_CodeGen_DIR}/EHContGuardTargets.cpp
${LLVM_CodeGen_DIR}/ExecutionDomainFix.cpp
${LLVM_CodeGen_DIR}/ExpandLargeDivRem.cpp
${LLVM_CodeGen_DIR}/ExpandLargeFpConvert.cpp
${LLVM_CodeGen_DIR}/ExpandFp.cpp
${LLVM_CodeGen_DIR}/ExpandMemCmp.cpp
${LLVM_CodeGen_DIR}/ExpandPostRAPseudos.cpp
${LLVM_CodeGen_DIR}/ExpandReductions.cpp
1 change: 0 additions & 1 deletion backend/llvm/lib/IR/CMakeLists.txt
@@ -74,7 +74,6 @@ add_llvm_component_library(LLVMBuddyCore
${LLVM_IR_DIR}/User.cpp
${LLVM_IR_DIR}/Value.cpp
${LLVM_IR_DIR}/ValueSymbolTable.cpp
${LLVM_IR_DIR}/VectorBuilder.cpp
${LLVM_IR_DIR}/VectorTypeUtils.cpp
${LLVM_IR_DIR}/Verifier.cpp
${LLVM_IR_DIR}/VFABIDemangler.cpp
1 change: 0 additions & 1 deletion backend/llvm/lib/Target/CMakeLists.txt
@@ -6,7 +6,6 @@ set(LLVM_Target_DIR ${LLVM_MAIN_SRC_DIR}/lib/Target)

add_llvm_component_library(LLVMBuddyTarget
${LLVM_Target_DIR}/Target.cpp
${LLVM_Target_DIR}/TargetIntrinsicInfo.cpp
${LLVM_Target_DIR}/TargetLoweringObjectFile.cpp
${LLVM_Target_DIR}/TargetMachine.cpp
${LLVM_Target_DIR}/TargetMachineC.cpp
57 changes: 23 additions & 34 deletions backend/llvm/lib/Target/RISCV/CMakeLists.txt
@@ -49,6 +49,7 @@ tablegen(LLVM RISCVGenRegisterInfo.inc -gen-register-info)
tablegen(LLVM RISCVGenSearchableTables.inc -gen-searchable-tables)
tablegen(LLVM RISCVGenSubtargetInfo.inc -gen-subtarget)
tablegen(LLVM RISCVGenExegesis.inc -gen-exegesis)
tablegen(LLVM RISCVGenSDNodeInfo.inc -gen-sd-node-info)

set(LLVM_TARGET_DEFINITIONS ${LLVM_TARGET_RISCV_DIR}/RISCVGISel.td)
tablegen(LLVM RISCVGenGlobalISel.inc -gen-global-isel)
@@ -189,42 +190,30 @@ endforeach()

# Build LLVMBuddyRISCVDesc target.
add_llvm_component_library(LLVMBuddyRISCVDesc
${CMAKE_CURRENT_BINARY_DIR}/RISCVAsmPrinter.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVCallingConv.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVCodeGenPrepare.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVConstantPoolValue.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVDeadRegisterDefinitions.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVExpandAtomicPseudoInsts.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVExpandPseudoInsts.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVFrameLowering.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVGatherScatterLowering.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVISelDAGToDAG.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVLandingPadSetup.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVMergeBaseOffset.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVMoveMerger.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVOptWInstrs.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVPostRAExpandPseudoInsts.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVPushPopOptimizer.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVRedundantCopyElimination.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVRegisterInfo.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVSelectionDAGInfo.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVSubtarget.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVTargetMachine.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVTargetObjectFile.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVTargetTransformInfo.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVVectorMaskDAGMutation.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVVectorPeephole.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVZacasABIFix.cpp

${CMAKE_CURRENT_BINARY_DIR}/RISCVAsmBackend.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVBaseInfo.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVELFObjectWriter.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVInstPrinter.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVMCAsmInfo.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVMCCodeEmitter.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVMCExpr.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVMCObjectFileInfo.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVMCTargetDesc.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVMatInt.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVTargetStreamer.cpp
${CMAKE_CURRENT_BINARY_DIR}/RISCVELFStreamer.cpp

# Add *.h files to track the copies above.
${CMAKE_CURRENT_BINARY_DIR}/RISCVCallingConv.h
${CMAKE_CURRENT_BINARY_DIR}/RISCVInstrInfo.h
${CMAKE_CURRENT_BINARY_DIR}/RISCVRegisterInfo.h
${CMAKE_CURRENT_BINARY_DIR}/RISCVSelectionDAGInfo.h
${CMAKE_CURRENT_BINARY_DIR}/RISCVSubtarget.h
${CMAKE_CURRENT_BINARY_DIR}/RISCVTargetMachine.h
${CMAKE_CURRENT_BINARY_DIR}/RISCVTargetObjectFile.h
${CMAKE_CURRENT_BINARY_DIR}/RISCVAsmBackend.h
${CMAKE_CURRENT_BINARY_DIR}/RISCVBaseInfo.h
${CMAKE_CURRENT_BINARY_DIR}/RISCVELFStreamer.h
${CMAKE_CURRENT_BINARY_DIR}/RISCVFixupKinds.h
${CMAKE_CURRENT_BINARY_DIR}/RISCVInstPrinter.h
${CMAKE_CURRENT_BINARY_DIR}/RISCVMatInt.h
${CMAKE_CURRENT_BINARY_DIR}/RISCVMCAsmInfo.h
${CMAKE_CURRENT_BINARY_DIR}/RISCVMCObjectFileInfo.h
${CMAKE_CURRENT_BINARY_DIR}/RISCVMCTargetDesc.h
${CMAKE_CURRENT_BINARY_DIR}/RISCVTargetStreamer.h


LINK_COMPONENTS
24 changes: 20 additions & 4 deletions backend/llvm/lib/Target/RISCV/RISCVInstrInfoBuddyExt.td
@@ -17,6 +17,22 @@
// This is the instruction information file of RISC-V buddy extension.
//
//===----------------------------------------------------------------------===//
//
// IMPORTANT: Instruction Encoding Fix (LLVM 21.0.0 RC1 Upgrade)
// ------------------------------------------------------------
// The CONFIG_* instructions originally had identical encodings (funct7=0b0000000),
// which caused decoding conflicts in LLVM's TableGen disassembler.
//
// Fixed encodings:
// - CONFIG_LD: funct7 = 0b0010110 (was 0b0000000)
// - CONFIG_ST: funct7 = 0b0010111 (was 0b0000000)
// - CONFIG_EX: funct7 = 0b0011000 (was 0b0000000)
// - CONFIG_NORM: funct7 = 0b0011001 (was 0b0000000)
//
// All instructions retain funct3=0b011 and use OPC_CUSTOM_3.
// This ensures each instruction has a unique bit pattern for proper decoding.
//
//===----------------------------------------------------------------------===//

include "llvm/IR/IntrinsicsRISCVBuddyExt.td"

@@ -60,25 +76,25 @@ def FLUSH : RVInstR<0b0000111, 0b011, OPC_CUSTOM_3, (outs),
}

let Predicates = [HasBuddyExt] in
def CONFIG_LD : RVInstR<0b0000000, 0b011, OPC_CUSTOM_3, (outs),
def CONFIG_LD : RVInstR<0b0010110, 0b011, OPC_CUSTOM_3, (outs),
(ins GPR:$rs1, GPR:$rs2), "config_ld", "$rs1, $rs2"> {
let rd = 0;
}

let Predicates = [HasBuddyExt] in
def CONFIG_ST : RVInstR<0b0000000, 0b011, OPC_CUSTOM_3, (outs),
def CONFIG_ST : RVInstR<0b0010111, 0b011, OPC_CUSTOM_3, (outs),
(ins GPR:$rs1, GPR:$rs2), "config_st", "$rs1, $rs2"> {
let rd = 0;
}

let Predicates = [HasBuddyExt] in
def CONFIG_EX : RVInstR<0b0000000, 0b011, OPC_CUSTOM_3,(outs),
def CONFIG_EX : RVInstR<0b0011000, 0b011, OPC_CUSTOM_3,(outs),
(ins GPR:$rs1, GPR:$rs2), "config_ex", "$rs1, $rs2"> {
let rd = 0;
}

let Predicates = [HasBuddyExt] in
def CONFIG_NORM : RVInstR<0b0000000, 0b011, OPC_CUSTOM_3,(outs),
def CONFIG_NORM : RVInstR<0b0011001, 0b011, OPC_CUSTOM_3,(outs),
(ins GPR:$rs1, GPR:$rs2), "config_norm", "$rs1, $rs2"> {
let rd = 0;
}
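To see why the new funct7 values resolve the TableGen decoder conflict, here is a minimal sketch (not part of the patch) that packs the four CONFIG_* R-type words and checks that they are distinct. It assumes OPC_CUSTOM_3 is the standard RISC-V custom-3 major opcode (0b1111011); the operands rd=x0, rs1=a0, rs2=a1 are arbitrary examples.

# Sketch: verify the CONFIG_* encodings no longer collide.
OPC_CUSTOM_3 = 0b1111011  # assumption: standard custom-3 major opcode

def encode_rtype(funct7, rs2, rs1, funct3, rd, opcode):
    # R-type layout: funct7[31:25] rs2[24:20] rs1[19:15] funct3[14:12] rd[11:7] opcode[6:0]
    return (funct7 << 25) | (rs2 << 20) | (rs1 << 15) | (funct3 << 12) | (rd << 7) | opcode

funct7_values = {
    "config_ld":   0b0010110,
    "config_st":   0b0010111,
    "config_ex":   0b0011000,
    "config_norm": 0b0011001,
}

words = {name: encode_rtype(f7, rs2=11, rs1=10, funct3=0b011, rd=0, opcode=OPC_CUSTOM_3)
         for name, f7 in funct7_values.items()}
assert len(set(words.values())) == len(words)  # each instruction gets a unique bit pattern
for name, word in words.items():
    print(f"{name}: 0x{word:08x}")

With the old shared funct7 of 0b0000000, all four mnemonics would have produced the same 32-bit word for a given operand pair, which is exactly the decoding conflict described in the header comment above.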
1 change: 0 additions & 1 deletion backend/llvm/lib/Transforms/Vectorize/CMakeLists.txt
@@ -8,7 +8,6 @@ add_llvm_component_library(LLVMBuddyVectorize
${LLVM_Vectorize_DIR}/Vectorize.cpp
${LLVM_Vectorize_DIR}/VectorCombine.cpp
${LLVM_Vectorize_DIR}/VPlan.cpp
${LLVM_Vectorize_DIR}/VPlanHCFGBuilder.cpp
${LLVM_Vectorize_DIR}/VPlanRecipes.cpp
${LLVM_Vectorize_DIR}/VPlanSLP.cpp
${LLVM_Vectorize_DIR}/VPlanTransforms.cpp
2 changes: 1 addition & 1 deletion examples/BuddyConvolution/conv2d-nhwc-fhwc-opt.mlir
@@ -4,7 +4,7 @@
// RUN: --one-shot-bufferize="bufferize-function-boundaries" \
// RUN: -convert-scf-to-cf \
// RUN: -convert-cf-to-llvm \
// RUN: -convert-vector-to-llvm \
// RUN: -convert-vector-to-llvm -convert-ub-to-llvm \
// RUN: -convert-arith-to-llvm \
// RUN: -finalize-memref-to-llvm \
// RUN: -convert-func-to-llvm \
33 changes: 13 additions & 20 deletions examples/BuddyNext/next-attention-fusion.mlir
@@ -29,18 +29,13 @@ module {
memref.global "private" constant @__constant_32x128x40xf32 : memref<32x128x40xf32> = dense<2.000000e+00> {alignment = 64 : i64}
memref.global "private" constant @__constant_32x40x128xf32 : memref<32x40x128xf32> = dense<3.000000e+00> {alignment = 64 : i64}
memref.global "private" constant @__constant_1x32x40x40xf32 : memref<1x32x40x40xf32> = dense<11.3137083> {alignment = 64 : i64}
func.func @kenerl(%arg0: tensor<32x40x128xf32>, %arg1: tensor<32x128x40xf32>, %arg2: tensor<1x1x40x40xf32>, %arg3: tensor<1x32x40x128xf32>) {
func.func @kenerl(%arg0: memref<32x40x128xf32>, %arg1: memref<32x128x40xf32>, %arg2: memref<1x1x40x40xf32>, %arg3: memref<1x32x40x128xf32>) {
%t_start = call @rtclock() : () -> f64
%cst = arith.constant 0.0883883461 : f32
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%cst_1 = arith.constant 1.000000e+00 : f32
%cst_2 = arith.constant -3.40282347E+38 : f32
%0 = bufferization.to_memref %arg3 : tensor<1x32x40x128xf32> to memref<1x32x40x128xf32, strided<[?, ?, ?, ?], offset: ?>>
%1 = bufferization.to_memref %arg2 : tensor<1x1x40x40xf32> to memref<1x1x40x40xf32, strided<[?, ?, ?, ?], offset: ?>>
%2 = bufferization.to_memref %arg1 : tensor<32x128x40xf32> to memref<32x128x40xf32, strided<[?, ?, ?], offset: ?>>
%3 = bufferization.to_memref %arg0 : tensor<32x40x128xf32> to memref<32x40x128xf32, strided<[?, ?, ?], offset: ?>>

// MatMul
// %0 = tosa.matmul %t0, %t1 : (tensor<32x40x128xf32>, tensor<32x128x40xf32>) -> tensor<32x40x40xf32>
// Initialize MatMul Output.
@@ -57,8 +52,8 @@ module {
affine.for %arg5 = 0 to 40 {
affine.for %arg6 = 0 to 40 {
affine.for %arg7 = 0 to 128 {
%5 = affine.load %3[%arg4, %arg5, %arg7] : memref<32x40x128xf32, strided<[?, ?, ?], offset: ?>>
%6 = affine.load %2[%arg4, %arg7, %arg6] : memref<32x128x40xf32, strided<[?, ?, ?], offset: ?>>
%5 = affine.load %arg0[%arg4, %arg5, %arg7] : memref<32x40x128xf32>
%6 = affine.load %arg1[%arg4, %arg7, %arg6] : memref<32x128x40xf32>
%7 = affine.load %alloc[%arg4, %arg5, %arg6] : memref<32x40x40xf32>
%8 = arith.mulf %5, %6 : f32
%9 = arith.addf %7, %8 : f32
@@ -72,7 +67,8 @@ module {
// %1 = tosa.reshape %0 {new_shape = array<i64: 1, 32, 40, 40>} : (tensor<32x40x40xf32>) -> tensor<1x32x40x40xf32>
// %2 = "tosa.const"() <{value = dense<11.3137083> : tensor<1x32x40x40xf32>}> : () -> tensor<1x32x40x40xf32>
// %3 = tosa.reciprocal %2 : (tensor<1x32x40x40xf32>) -> tensor<1x32x40x40xf32>
// %4 = tosa.mul %1, %3 {shift = 0 : i8} : (tensor<1x32x40x40xf32>, tensor<1x32x40x40xf32>) -> tensor<1x32x40x40xf32>
// %shift_4 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// %4 = tosa.mul %1, %3, %shift_4 : (tensor<1x32x40x40xf32>, tensor<1x32x40x40xf32>, tensor<1xi8>) -> tensor<1x32x40x40xf32>
// %5 = tosa.add %4, %t2 : (tensor<1x32x40x40xf32>, tensor<1x1x40x40xf32>) -> tensor<1x32x40x40xf32>
// %6 = tosa.reduce_max %5 {axis = 3 : i32} : (tensor<1x32x40x40xf32>) -> tensor<1x32x40x1xf32>
%expand_shape = memref.expand_shape %alloc [[0, 1], [2], [3]] output_shape [1, 32, 40, 40]: memref<32x40x40xf32> into memref<1x32x40x40xf32>
@@ -93,7 +89,7 @@ module {
// Fusion point: reshape + constant + reciprocal -> %cst
%6 = arith.mulf %5, %cst : f32
// Fusion point: addition
%7 = affine.load %1[%c0, %c0, %arg6, %arg7] : memref<1x1x40x40xf32, strided<[?, ?, ?, ?], offset: ?>>
%7 = affine.load %arg2[%c0, %c0, %arg6, %arg7] : memref<1x1x40x40xf32>
%8 = arith.addf %6, %7 : f32
// Fusion point: reduce max
%9 = affine.load %alloc_6[%arg4, %arg5, %arg6] : memref<1x32x40xf32>
@@ -142,7 +138,8 @@ module {

// Fusion: Reciprocal + Multiplication
// %10 = tosa.reciprocal %9 : (tensor<1x32x40x1xf32>) -> tensor<1x32x40x1xf32>
// %11 = tosa.mul %8, %10 {shift = 0 : i8} : (tensor<1x32x40x40xf32>, tensor<1x32x40x1xf32>) -> tensor<1x32x40x40xf32>
// %shift_11 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// %11 = tosa.mul %8, %10, %shift_11 : (tensor<1x32x40x40xf32>, tensor<1x32x40x1xf32>, tensor<1xi8>) -> tensor<1x32x40x40xf32>
%expand_shape_11 = memref.expand_shape %alloc_10 [[0], [1], [2, 3]] output_shape [1, 32, 40, 1]: memref<1x32x40xf32> into memref<1x32x40x1xf32>
%alloc_13 = memref.alloc() {alignment = 64 : i64} : memref<1x32x40x40xf32>
affine.for %arg4 = 0 to 1 {
@@ -171,7 +168,7 @@ module {
%collapse_shape = memref.collapse_shape %alloc_13 [[0, 1], [2], [3]] : memref<1x32x40x40xf32> into memref<32x40x40xf32>
%alloc_14 = memref.alloc() {alignment = 64 : i64} : memref<1x32x40x128xf32>
// SSA value %0 is from %arg3
memref.copy %0, %alloc_14 : memref<1x32x40x128xf32, strided<[?, ?, ?, ?], offset: ?>> to memref<1x32x40x128xf32>
memref.copy %arg3, %alloc_14 : memref<1x32x40x128xf32> to memref<1x32x40x128xf32>
%collapse_shape_15 = memref.collapse_shape %alloc_14 [[0, 1], [2], [3]] : memref<1x32x40x128xf32> into memref<32x40x128xf32>

// MatMul
@@ -222,14 +219,10 @@ module {
}
func.func @main() {
%0 = memref.get_global @__constant_32x40x128xf32 : memref<32x40x128xf32>
%1 = bufferization.to_tensor %0 restrict: memref<32x40x128xf32> to tensor<32x40x128xf32>
%2 = memref.get_global @__constant_32x128x40xf32 : memref<32x128x40xf32>
%3 = bufferization.to_tensor %2 restrict: memref<32x128x40xf32> to tensor<32x128x40xf32>
%4 = memref.get_global @__constant_1x1x40x40xf32 : memref<1x1x40x40xf32>
%5 = bufferization.to_tensor %4 restrict: memref<1x1x40x40xf32> to tensor<1x1x40x40xf32>
%6 = memref.get_global @__constant_1x32x40x128xf32 : memref<1x32x40x128xf32>
%7 = bufferization.to_tensor %6 restrict: memref<1x32x40x128xf32> to tensor<1x32x40x128xf32>
call @kenerl(%1, %3, %5, %7) : (tensor<32x40x128xf32>, tensor<32x128x40xf32>, tensor<1x1x40x40xf32>, tensor<1x32x40x128xf32>) -> ()
%1 = memref.get_global @__constant_32x128x40xf32 : memref<32x128x40xf32>
%2 = memref.get_global @__constant_1x1x40x40xf32 : memref<1x1x40x40xf32>
%3 = memref.get_global @__constant_1x32x40x128xf32 : memref<1x32x40x128xf32>
call @kenerl(%0, %1, %2, %3) : (memref<32x40x128xf32>, memref<32x128x40xf32>, memref<1x1x40x40xf32>, memref<1x32x40x128xf32>) -> ()
return
}
func.func private @printMemrefF32(memref<*xf32>)
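As a cross-check on what the rewritten @kenerl computes, the NumPy sketch below mirrors the TOSA ops quoted in the comments: a batched Q*K^T matmul, scaling by roughly 1/sqrt(128) (1/11.3137083 = 0.0883883461), an additive mask, a numerically stable softmax over the last axis, and a final matmul with V. The Q/K fill values match the globals shown above; the mask and V fills are placeholder assumptions, since their initializers are not visible in this hunk.

import numpy as np

# Q (32x40x128) and K^T (32x128x40) use the fills from the visible globals;
# mask (1x1x40x40) and V (1x32x40x128) fills are assumptions for illustration.
q = np.full((32, 40, 128), 3.0, dtype=np.float32)
kt = np.full((32, 128, 40), 2.0, dtype=np.float32)
mask = np.zeros((1, 1, 40, 40), dtype=np.float32)        # assumption
v = np.ones((1, 32, 40, 128), dtype=np.float32)          # assumption

scores = (q @ kt).reshape(1, 32, 40, 40)                 # tosa.matmul + reshape
scores = scores * 0.0883883461                           # mul by reciprocal of 11.3137083
scores = scores + mask                                   # broadcast add of the mask
scores = scores - scores.max(axis=3, keepdims=True)      # reduce_max + sub
weights = np.exp(scores)                                 # exp
weights = weights / weights.sum(axis=3, keepdims=True)   # reduce_sum + reciprocal + mul
out = weights.reshape(32, 40, 40) @ v.reshape(32, 40, 128)
print(out.reshape(1, 32, 40, 128).shape)                 # (1, 32, 40, 128)

The same reference applies to next-attention-loop.mlir below, which implements the identical computation without the fusion.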
33 changes: 13 additions & 20 deletions examples/BuddyNext/next-attention-loop.mlir
@@ -29,18 +29,13 @@ module {
memref.global "private" constant @__constant_32x128x40xf32 : memref<32x128x40xf32> = dense<2.000000e+00> {alignment = 64 : i64}
memref.global "private" constant @__constant_32x40x128xf32 : memref<32x40x128xf32> = dense<3.000000e+00> {alignment = 64 : i64}
memref.global "private" constant @__constant_1x32x40x40xf32 : memref<1x32x40x40xf32> = dense<11.3137083> {alignment = 64 : i64}
func.func @kenerl(%arg0: tensor<32x40x128xf32>, %arg1: tensor<32x128x40xf32>, %arg2: tensor<1x1x40x40xf32>, %arg3: tensor<1x32x40x128xf32>) {
func.func @kenerl(%arg0: memref<32x40x128xf32>, %arg1: memref<32x128x40xf32>, %arg2: memref<1x1x40x40xf32>, %arg3: memref<1x32x40x128xf32>) {
%t_start = call @rtclock() : () -> f64
%cst = arith.constant 0.0883883461 : f32
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%cst_1 = arith.constant 1.000000e+00 : f32
%cst_2 = arith.constant -3.40282347E+38 : f32
%0 = bufferization.to_memref %arg3 : tensor<1x32x40x128xf32> to memref<1x32x40x128xf32, strided<[?, ?, ?, ?], offset: ?>>
%1 = bufferization.to_memref %arg2 : tensor<1x1x40x40xf32> to memref<1x1x40x40xf32, strided<[?, ?, ?, ?], offset: ?>>
%2 = bufferization.to_memref %arg1 : tensor<32x128x40xf32> to memref<32x128x40xf32, strided<[?, ?, ?], offset: ?>>
%3 = bufferization.to_memref %arg0 : tensor<32x40x128xf32> to memref<32x40x128xf32, strided<[?, ?, ?], offset: ?>>

// MatMul
// %0 = tosa.matmul %t0, %t1 : (tensor<32x40x128xf32>, tensor<32x128x40xf32>) -> tensor<32x40x40xf32>
// Initialize MatMul Output.
@@ -57,8 +52,8 @@ module {
affine.for %arg5 = 0 to 40 {
affine.for %arg6 = 0 to 40 {
affine.for %arg7 = 0 to 128 {
%5 = affine.load %3[%arg4, %arg5, %arg7] : memref<32x40x128xf32, strided<[?, ?, ?], offset: ?>>
%6 = affine.load %2[%arg4, %arg7, %arg6] : memref<32x128x40xf32, strided<[?, ?, ?], offset: ?>>
%5 = affine.load %arg0[%arg4, %arg5, %arg7] : memref<32x40x128xf32>
%6 = affine.load %arg1[%arg4, %arg7, %arg6] : memref<32x128x40xf32>
%7 = affine.load %alloc[%arg4, %arg5, %arg6] : memref<32x40x40xf32>
%8 = arith.mulf %5, %6 : f32
%9 = arith.addf %7, %8 : f32
@@ -85,7 +80,8 @@ module {
}

// Multiplication
// %4 = tosa.mul %1, %3 {shift = 0 : i8} : (tensor<1x32x40x40xf32>, tensor<1x32x40x40xf32>) -> tensor<1x32x40x40xf32>
// %shift_4 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// %4 = tosa.mul %1, %3, %shift_4 : (tensor<1x32x40x40xf32>, tensor<1x32x40x40xf32>, tensor<1xi8>) -> tensor<1x32x40x40xf32>
%alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<1x32x40x40xf32>
affine.for %arg4 = 0 to 1 {
affine.for %arg5 = 0 to 32 {
@@ -108,7 +104,7 @@ module {
affine.for %arg6 = 0 to 40 {
affine.for %arg7 = 0 to 40 {
%5 = affine.load %alloc_4[%c0, %arg5, %arg6, %arg7] : memref<1x32x40x40xf32>
%6 = affine.load %1[%c0, %c0, %arg6, %arg7] : memref<1x1x40x40xf32, strided<[?, ?, ?, ?], offset: ?>>
%6 = affine.load %arg2[%c0, %c0, %arg6, %arg7] : memref<1x1x40x40xf32>
%7 = arith.addf %5, %6 : f32
affine.store %7, %alloc_5[%arg4, %arg5, %arg6, %arg7] : memref<1x32x40x40xf32>
}
@@ -220,7 +216,8 @@ module {
}

// Multiplication
// %11 = tosa.mul %8, %10 {shift = 0 : i8} : (tensor<1x32x40x40xf32>, tensor<1x32x40x1xf32>) -> tensor<1x32x40x40xf32>
// %shift_11 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// %11 = tosa.mul %8, %10, %shift_11 : (tensor<1x32x40x40xf32>, tensor<1x32x40x1xf32>, tensor<1xi8>) -> tensor<1x32x40x40xf32>
%alloc_13 = memref.alloc() {alignment = 64 : i64} : memref<1x32x40x40xf32>
affine.for %arg4 = 0 to 1 {
affine.for %arg5 = 0 to 32 {
@@ -245,7 +242,7 @@ module {
%collapse_shape = memref.collapse_shape %alloc_13 [[0, 1], [2], [3]] : memref<1x32x40x40xf32> into memref<32x40x40xf32>
%alloc_14 = memref.alloc() {alignment = 64 : i64} : memref<1x32x40x128xf32>
// SSA value %0 is from %arg3
memref.copy %0, %alloc_14 : memref<1x32x40x128xf32, strided<[?, ?, ?, ?], offset: ?>> to memref<1x32x40x128xf32>
memref.copy %arg3, %alloc_14 : memref<1x32x40x128xf32> to memref<1x32x40x128xf32>
%collapse_shape_15 = memref.collapse_shape %alloc_14 [[0, 1], [2], [3]] : memref<1x32x40x128xf32> into memref<32x40x128xf32>

// MatMul
@@ -297,14 +294,10 @@ module {
}
func.func @main() {
%0 = memref.get_global @__constant_32x40x128xf32 : memref<32x40x128xf32>
%1 = bufferization.to_tensor %0 restrict: memref<32x40x128xf32> to tensor<32x40x128xf32>
%2 = memref.get_global @__constant_32x128x40xf32 : memref<32x128x40xf32>
%3 = bufferization.to_tensor %2 restrict: memref<32x128x40xf32> to tensor<32x128x40xf32>
%4 = memref.get_global @__constant_1x1x40x40xf32 : memref<1x1x40x40xf32>
%5 = bufferization.to_tensor %4 restrict: memref<1x1x40x40xf32> to tensor<1x1x40x40xf32>
%6 = memref.get_global @__constant_1x32x40x128xf32 : memref<1x32x40x128xf32>
%7 = bufferization.to_tensor %6 restrict: memref<1x32x40x128xf32> to tensor<1x32x40x128xf32>
call @kenerl(%1, %3, %5, %7) : (tensor<32x40x128xf32>, tensor<32x128x40xf32>, tensor<1x1x40x40xf32>, tensor<1x32x40x128xf32>) -> ()
%1 = memref.get_global @__constant_32x128x40xf32 : memref<32x128x40xf32>
%2 = memref.get_global @__constant_1x1x40x40xf32 : memref<1x1x40x40xf32>
%3 = memref.get_global @__constant_1x32x40x128xf32 : memref<1x32x40x128xf32>
call @kenerl(%0, %1, %2, %3) : (memref<32x40x128xf32>, memref<32x128x40xf32>, memref<1x1x40x40xf32>, memref<1x32x40x128xf32>) -> ()
return
}
func.func private @printMemrefF32(memref<*xf32>)