
Commit 2f37716 (parent: 41248c8)

[feature] support batch_matmul

File tree: 13 files changed, +235 -9 lines


include/soda/Conversion/KernelsToSODA/LinalgToCGRA.h (+4)

@@ -11,6 +11,7 @@ struct LogicalResult;
 namespace linalg {
 // class DotOp;
 class MatmulOp;
+class BatchMatmulOp;
 class Conv2DOp;
 class GenericOp;
 // TODO: add more ops
@@ -22,6 +23,9 @@ class GenericOp;
 /// Convert linalg Matmul op into CGRA.
 LogicalResult convertLinalgMatmulToCGRALaunch(linalg::MatmulOp matmulOp);
 
+/// Convert linalg BatchMatmul op into CGRA.
+LogicalResult convertLinalgBatchMatmulToCGRALaunch(linalg::BatchMatmulOp batchMatmulOp);
+
 /// Convert linalg Conv op into CGRA.
 LogicalResult convertLinalgConvToCGRALaunch(linalg::Conv2DOp convOp);

include/soda/Conversion/KernelsToSODA/LinalgToCGRAPass.h (+1)

@@ -18,6 +18,7 @@ class Pass;
 /// Create a pass that converts linalg ops into soda launch ops.
 // std::unique_ptr<OperationPass<func::FuncOp>> createLinalgDotToSODAPass();
 std::unique_ptr<OperationPass<func::FuncOp>> createLinalgMatmulToCGRAPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createLinalgBatchMatmulToCGRAPass();
 std::unique_ptr<OperationPass<func::FuncOp>> createLinalgConvToCGRAPass();
 std::unique_ptr<OperationPass<func::FuncOp>> createLinalgGenericToCGRAPass();

include/soda/Conversion/Passes.td (+6)

@@ -87,6 +87,12 @@ def ConvertLinalgMatmulToCGRA : Pass<"convert-linalg-matmul-to-cgra", "func::FuncOp"> {
   let dependentDialects = ["soda::SODADialect"];
 }
 
+def ConvertLinalgBatchMatmulToCGRA : Pass<"convert-linalg-batch_matmul-to-cgra", "func::FuncOp"> {
+  let summary = "Offload (nested) linalg::batch_matmul Ops for CGRA acceleration";
+  let constructor = "mlir::createLinalgBatchMatmulToCGRAPass()";
+  let dependentDialects = ["soda::SODADialect"];
+}
+
 def ConvertLinalgConvToCGRA : Pass<"convert-linalg-conv-to-cgra", "func::FuncOp"> {
   let summary = "Offload (nested) linalg::conv Ops for CGRA acceleration";
   let constructor = "mlir::createLinalgConvToCGRAPass()";

include/soda/Dialect/SODA/SODAOps.td (+18)

@@ -453,6 +453,24 @@ def SODA_MatmulOp : SODA_Op<"cgra.matmul", [NoSideEffect]>,
   // let hasVerifier = 1;
 }
 
+def SODA_BatchMatmulOp : SODA_Op<"cgra.batch_matmul", [NoSideEffect]>,
+    Arguments<(ins Variadic<AnyType>:$operands)>, Results<(outs)> {
+  let summary = "CGRA BatchMatmul operation.";
+  let description = [{
+    A soda operation `cgra.batch_matmul` to replace `linalg.batch_matmul`. The operands
+    and output are the same.
+  }];
+
+  let builders = [OpBuilder<(ins), [{ // empty}]>];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+  // let arguments = (ins AnyType:$operandA, AnyType:$operandB, AnyType:$operandC);
+
+  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+
+  // let hasVerifier = 1;
+}
+
 def SODA_FusionOp : SODA_Op<"cgra.fusion", [NoSideEffect]>,
     Arguments<(ins Variadic<AnyType>:$operands)>, Results<(outs)> {
   let summary = "CGRA fused operation.";

include/soda/Dialect/SODA/Utils.h (+1)

@@ -25,6 +25,7 @@ namespace soda {
 class SODAFuncOp;
 class LaunchOp;
 class MatmulOp;
+class BatchMatmulOp;
 } // namespace soda
 
 /// Get a soda.func created from outlining the region of a soda.launch op with the

lib/Conversion/KernelsToSODA/LinalgToCGRA.cpp (+37)

@@ -28,6 +28,8 @@ struct LinalgToCGRAConverter {
   template <class T>
   void createMatmulLaunch(T rootOp);
   template <class T>
+  void createBatchMatmulLaunch(T rootOp);
+  template <class T>
   void createGenericLaunch(T rootOp);
 };
 
@@ -84,6 +86,29 @@ void LinalgToCGRAConverter::createMatmulLaunch(T rootLinalgOp) {
   }
 }
 
+/// Add a CGRA launch operation around the "linalg.batch_matmul" op.
+template <class T>
+void LinalgToCGRAConverter::createBatchMatmulLaunch(T rootLinalgOp) {
+  OpBuilder builder(rootLinalgOp.getOperation());
+
+  if (dyn_cast<linalg::BatchMatmulOp>(&rootLinalgOp) != nullptr) {
+
+    // Create a launch op and move target op into the region
+    Location loc = rootLinalgOp.getLoc();
+    auto launchOp = builder.create<soda::LaunchOp>(loc);
+    builder.setInsertionPointToEnd(&launchOp.body().front());
+    builder.create<soda::TerminatorOp>(loc);
+    builder.setInsertionPointToStart(&launchOp.body().front());
+
+    Operation *newOp = builder.create<soda::BatchMatmulOp>(loc, rootLinalgOp->getOperands());
+
+    auto results = newOp->getResults();
+    rootLinalgOp->replaceAllUsesWith(results);
+    rootLinalgOp->erase();
+  }
+}
+
 /// Add a CGRA launch operation around the "linalg.generic" op.
 template <class T>
 void LinalgToCGRAConverter::createGenericLaunch(T rootLinalgOp) {
@@ -150,6 +175,18 @@ LogicalResult mlir::convertLinalgMatmulToCGRALaunch(linalg::MatmulOp op) {
   return ::convertLinalgMatmulToCGRALaunch(op);
 }
 
+static LogicalResult convertLinalgBatchMatmulToCGRALaunch(linalg::BatchMatmulOp op) {
+  LinalgToCGRAConverter converter;
+  converter.createBatchMatmulLaunch(op);
+  return success();
+}
+
+LogicalResult mlir::convertLinalgBatchMatmulToCGRALaunch(linalg::BatchMatmulOp op) {
+  return ::convertLinalgBatchMatmulToCGRALaunch(op);
+}
+
 static LogicalResult convertLinalgConvToCGRALaunch(linalg::Conv2DOp op) {
 
   LinalgToCGRAConverter converter;
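Note: to make the effect of createBatchMatmulLaunch concrete, the linalg op is replaced in place by a launch region that holds the CGRA op and a terminator. A hand-written before/after sketch on bufferized IR (value names, shapes, and the exact printed spelling of the launch and terminator ops are assumptions):

  // before (illustrative)
  linalg.batch_matmul ins(%A, %B : memref<2x4x8xf32>, memref<2x8x16xf32>)
                      outs(%C : memref<2x4x16xf32>)

  // after (illustrative)
  soda.launch {
    soda.cgra.batch_matmul %A, %B, %C : memref<2x4x8xf32>, memref<2x8x16xf32>, memref<2x4x16xf32>
    soda.terminator
  }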

lib/Conversion/KernelsToSODA/LinalgToCGRAPass.cpp (+38 -5)

@@ -44,8 +44,7 @@ namespace {
 // };
 
 // A pass that traverses top-level matmuls in the function and converts them to
-// SODA launch operations. Nested launches are not allowed, so this does not
-// walk the function recursively to avoid considering nested matmuls.
+// CGRA launch operations.
 struct LinalgMatmulMapper : public ConvertLinalgMatmulToCGRABase<LinalgMatmulMapper> {
   LinalgMatmulMapper() = default;
 
@@ -73,8 +72,38 @@ struct LinalgMatmulMapper : public ConvertLinalgMatmulToCGRABase<LinalgMatmulMapper> {
   }
 };
 
+// A pass that traverses top-level batch_matmuls in the function and converts them to
+// CGRA launch operations.
+struct LinalgBatchMatmulMapper : public ConvertLinalgBatchMatmulToCGRABase<LinalgBatchMatmulMapper> {
+  LinalgBatchMatmulMapper() = default;
+
+  void runOnInnerOp(scf::ForOp &forOp) {
+    for (Operation &innerOp : llvm::make_early_inc_range(forOp.getBody()->getOperations())) {
+      if (auto innerMatmulOp = dyn_cast<linalg::BatchMatmulOp>(&innerOp)) {
+        if (failed(convertLinalgBatchMatmulToCGRALaunch(innerMatmulOp))) {
+          signalPassFailure();
+        }
+      } else if (auto forOp = dyn_cast<scf::ForOp>(&innerOp)) {
+        runOnInnerOp(forOp);
+      }
+    }
+  }
+
+  void runOnOperation() override {
+    for (Operation &op : llvm::make_early_inc_range(getOperation().getOps())) {
+      if (auto matmulOp = dyn_cast<linalg::BatchMatmulOp>(&op)) {
+        if (failed(convertLinalgBatchMatmulToCGRALaunch(matmulOp)))
+          signalPassFailure();
+      } else if (auto forOp = dyn_cast<scf::ForOp>(&op)) {
+        runOnInnerOp(forOp);
+      }
+    }
+  }
+};
+
 // A pass that traverses top-level conv in the function and converts them to
-// SODA launch operations. Nested launches are not allowed, so this does not
+// CGRA launch operations. Nested launches are not allowed, so this does not
 // walk the function recursively to avoid considering nested conv.
 struct LinalgConvMapper : public ConvertLinalgConvToCGRABase<LinalgConvMapper> {
   LinalgConvMapper() = default;
@@ -90,8 +119,7 @@ struct LinalgConvMapper : public ConvertLinalgConvToCGRABase<LinalgConvMapper> {
 };
 
 // A pass that traverses top-level GenericOps in the function and converts them
-// to SODA launch operations. Nested launches are not allowed, so this does not
-// walk the function recursively to avoid considering nested GenericOp.
+// to CGRA launch operations.
 struct LinalgGenericMapper : public ConvertLinalgGenericToCGRABase<LinalgGenericMapper> {
   LinalgGenericMapper() = default;
 
@@ -130,6 +158,11 @@ mlir::createLinalgMatmulToCGRAPass() {
   return std::make_unique<LinalgMatmulMapper>();
 }
 
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::createLinalgBatchMatmulToCGRAPass() {
+  return std::make_unique<LinalgBatchMatmulMapper>();
+}
+
 std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::createLinalgConvToCGRAPass() {
   return std::make_unique<LinalgConvMapper>();

lib/Dialect/SODA/Transforms/HostGeneration.cpp (+12 -4)

@@ -19,7 +19,10 @@
 #include "soda/Dialect/SODA/Utils.h"
 
 #include <iostream>
+#include <map>
+#include <string>
 
+using namespace std;
 using namespace mlir;
 
 namespace {
@@ -78,6 +81,7 @@ class SODALaunchFuncLowering : public OpRewritePattern<soda::LaunchFuncOp> {
                       .str();
 
     auto func = module.lookupSymbol<func::FuncOp>(newName);
+
     if (!func) {
 
       // Get callee
@@ -118,7 +122,11 @@ class SODALaunchCGRALowering : public OpRewritePattern<soda::LaunchCGRAOp> {
     auto newName = "cgra_" + Twine(op.getKernelName()).str();
     auto func = module.lookupSymbol<func::FuncOp>(newName);
 
-    if (!func) {
+    // std::cout<<"found func... "<<newName<<std::endl;
+    while (func) {
+      newName += "_";
+      func = module.lookupSymbol<func::FuncOp>(newName);
+    }
 
     // Get callee
     Operation *kernelFunc = module.lookupSymbol(op.kernelAttr());
@@ -130,12 +138,12 @@ class SODALaunchCGRALowering : public OpRewritePattern<soda::LaunchCGRAOp> {
     if (kernelSODAFunction == NULL)
       std::cout<<"kernelSODAFunction is NULL"<<std::endl;
     FunctionType funcTy = kernelSODAFunction.getFunctionType();
-    func::FuncOp func = rewriter.create<func::FuncOp>(
+    func::FuncOp updatedFunc = rewriter.create<func::FuncOp>(
        rewriter.getUnknownLoc(), newName, funcTy);
-    func.setPrivate();
+    updatedFunc.setPrivate();
 
     rewriter.setInsertionPoint(op);
-  }
+    // }
 
     assert(
         isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, newName)));

lib/Dialect/SODA/Transforms/KernelOutlining.cpp (+2)

@@ -350,6 +350,8 @@ class CGRAKernelOutliningPass
       kernelFnName += (*op.body().front().op_begin<soda::FusionOp>())->getAttr("pattern").cast<StringAttr>().str();
     } else if (op.body().front().op_begin<soda::MatmulOp>() != op.body().front().op_end<soda::MatmulOp>()) {
       kernelFnName = "matmul";
+    } else if (op.body().front().op_begin<soda::BatchMatmulOp>() != op.body().front().op_end<soda::BatchMatmulOp>()) {
+      kernelFnName = "batch_matmul";
     } else {
       kernelFnName = "generic_" + to_string(genericFuncCount);
       isGenericFunc = true;

sim/CGRAFunc.h (+37)

@@ -36,6 +36,43 @@ void matmul(DataReq& input, DataReq& output, Simulator& sim) {
   }
 }
 
+void batch_matmul(DataReq& input, DataReq& output, Simulator& sim) {
+
+  MemRef inA = input.memRefs[0];
+  MemRef inB = input.memRefs[1];
+  MemRef out = output.memRefs[0];
+
+  // work-around for the current bug in MLIR tiling memref
+  int row = out.offset / (out.sizes[1] * out.strides[1]);
+  int col = out.offset % (out.sizes[1] * out.strides[1]) / out.sizes[2];
+  string locKey = to_string(row) + "," + to_string(col);
+  if (sim.matmulLocCount.find(locKey) == sim.matmulLocCount.end()) {
+    sim.matmulLocCount.insert({locKey, -1});
+  }
+  sim.matmulLocCount[locKey] += 1;
+
+  int64_t offsetA = row * inA.sizes[1] * inA.strides[1] + sim.matmulLocCount[locKey] * inA.sizes[2];
+  int64_t offsetB = col * inB.sizes[2] + sim.matmulLocCount[locKey] * inB.sizes[1] * inB.strides[1];
+
+  cout<<"offsetA: "<<offsetA<<"; offsetB: "<<offsetB<<endl;
+
+  for (int b=0; b<out.sizes[0]; ++b) {
+    for (int i=0; i<out.sizes[1]; ++i) {
+      for (int j=0; j<out.sizes[2]; ++j) {
+        for (int k=0; k<inB.sizes[1]; ++k) {
+          out.aligned[b*out.strides[0]+out.offset+i*out.strides[1]+j] += inA.aligned[b*inA.strides[0]+offsetA+i*inA.strides[1]+k] * inB.aligned[b*inB.strides[0]+offsetB+k*inB.strides[1]+j];
+        }
+      }
+    }
+  }
+
+  // reset the locCount
+  if (sim.matmulLocCount[locKey] == inA.strides[1] / inA.sizes[2] - 1) {
+    sim.matmulLocCount[locKey] = -1;
+  }
+}
+
 void fusion_add_max_add(DataReq& input, DataReq& output, Simulator& sim) {
 
   MemRef inA = input.memRefs[0];
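Note: the four-level loop above is an ordinary batched matrix multiply applied to the tile selected by offsetA/offsetB. A plain, untiled reference that the simulated kernel can be checked against, assuming dense row-major buffers (this helper is not part of the commit):

  // reference_batch_matmul: hypothetical standalone check, not part of this commit.
  // Computes C[b][i][j] += A[b][i][k] * B[b][k][j] over dense row-major buffers.
  #include <cstdint>
  #include <vector>

  void reference_batch_matmul(const std::vector<float>& A, const std::vector<float>& B,
                              std::vector<float>& C,
                              int64_t batch, int64_t M, int64_t K, int64_t N) {
    for (int64_t b = 0; b < batch; ++b)
      for (int64_t i = 0; i < M; ++i)
        for (int64_t j = 0; j < N; ++j)
          for (int64_t k = 0; k < K; ++k)
            C[(b * M + i) * N + j] += A[(b * M + i) * K + k] * B[(b * K + k) * N + j];
  }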

sim/GlobalRuntime.cpp (+76)

@@ -56,6 +56,82 @@ extern "C" void cgra_matmul(float* a_allocated, float* a_aligned, int64_t a_offs
 */
 }
 
+extern "C" void cgra_matmul_(float* a_allocated, float* a_aligned, int64_t a_offset, int64_t a_size0, int64_t a_size1, int64_t a_stride0, int64_t a_stride1,
+                             float* b_allocated, float* b_aligned, int64_t b_offset, int64_t b_size0, int64_t b_size1, int64_t b_stride0, int64_t b_stride1,
+                             float* c_allocated, float* c_aligned, int64_t c_offset, int64_t c_size0, int64_t c_size1, int64_t c_stride0, int64_t c_stride1) {
+  cgra_matmul(a_allocated, a_aligned, a_offset, a_size0, a_size1, a_stride0, a_stride1,
+              b_allocated, b_aligned, b_offset, b_size0, b_size1, b_stride0, b_stride1,
+              c_allocated, c_aligned, c_offset, c_size0, c_size1, c_stride0, c_stride1);
+}
+
+extern "C" void cgra_batch_matmul(float* a_allocated, float* a_aligned, int64_t a_offset, int64_t a_size0, int64_t a_size1, int64_t a_size2, int64_t a_stride0, int64_t a_stride1, int64_t a_stride2,
+                                  float* b_allocated, float* b_aligned, int64_t b_offset, int64_t b_size0, int64_t b_size1, int64_t b_size2, int64_t b_stride0, int64_t b_stride1, int64_t b_stride2,
+                                  float* c_allocated, float* c_aligned, int64_t c_offset, int64_t c_size0, int64_t c_size1, int64_t c_size2, int64_t c_stride0, int64_t c_stride1, int64_t c_stride2) {
+
+  // prepare inputs
+  vector<int64_t> a_sizes = {a_size0, a_size1, a_size2};
+  vector<int64_t> a_strides = {a_stride0, a_stride1, a_stride2};
+  MemRef memRef0(a_allocated, a_aligned, a_offset, a_sizes, a_strides, 3);
+
+  vector<int64_t> b_sizes = {b_size0, b_size1, b_size2};
+  vector<int64_t> b_strides = {b_stride0, b_stride1, b_stride2};
+  MemRef memRef1(b_allocated, b_aligned, b_offset, b_sizes, b_strides, 3);
+
+  DataReq input;
+  input.assembleReq(memRef0);
+  input.assembleReq(memRef1);
+
+  // prepare outputs
+  vector<int64_t> c_sizes = {c_size0, c_size1, c_size2};
+  vector<int64_t> c_strides = {c_stride0, c_stride1, c_stride2};
+  MemRef memRef2(c_allocated, c_aligned, c_offset, c_sizes, c_strides, 3);
+
+  DataReq output;
+  output.assembleReq(memRef2);
+
+  // issue READ/EXECUTE/WRITE requests for simulation
+  cgra->issueRD(input);
+  cgra->issueEX("batch_matmul");
+  cgra->issueWR(output, true);
+}
+
+extern "C" void cgra_batch_matmul_(float* a_allocated, float* a_aligned, int64_t a_offset, int64_t a_size0, int64_t a_size1, int64_t a_size2, int64_t a_stride0, int64_t a_stride1, int64_t a_stride2,
+                                   float* b_allocated, float* b_aligned, int64_t b_offset, int64_t b_size0, int64_t b_size1, int64_t b_size2, int64_t b_stride0, int64_t b_stride1, int64_t b_stride2,
+                                   float* c_allocated, float* c_aligned, int64_t c_offset, int64_t c_size0, int64_t c_size1, int64_t c_size2, int64_t c_stride0, int64_t c_stride1, int64_t c_stride2) {
+
+  // prepare inputs
+  vector<int64_t> a_sizes = {a_size0, a_size1, a_size2};
+  vector<int64_t> a_strides = {a_stride0, a_stride1, a_stride2};
+  MemRef memRef0(a_allocated, a_aligned, a_offset, a_sizes, a_strides, 3);
+
+  vector<int64_t> b_sizes = {b_size0, b_size1, b_size2};
+  vector<int64_t> b_strides = {b_stride0, b_stride1, b_stride2};
+  MemRef memRef1(b_allocated, b_aligned, b_offset, b_sizes, b_strides, 3);
+
+  DataReq input;
+  input.assembleReq(memRef0);
+  input.assembleReq(memRef1);
+
+  // prepare outputs
+  vector<int64_t> c_sizes = {c_size0, c_size1, c_size2};
+  vector<int64_t> c_strides = {c_stride0, c_stride1, c_stride2};
+  MemRef memRef2(c_allocated, c_aligned, c_offset, c_sizes, c_strides, 3);
+
+  DataReq output;
+  output.assembleReq(memRef2);
+
+  // issue READ/EXECUTE/WRITE requests for simulation
+  cgra->issueRD(input);
+  cgra->issueEX("batch_matmul");
+  cgra->issueWR(output, true);
+
+  cout<<"calculated output for cgra_batch_matmul() a_alloc: "<<a_allocated<<"; a_aligned: "<<a_aligned<<"; a_offset: "<<a_offset<<"; a_size0: "<<a_size0<<"; a_size1: "<<a_size1<<"; a_size2: "<<a_size2<<"; a_stride0: "<<a_stride0<<"; a_stride1: "<<a_stride1<<"; a_stride2: "<<a_stride2<<endl;
+  cout<<"calculated output for cgra_batch_matmul() b_alloc: "<<b_allocated<<"; b_aligned: "<<b_aligned<<"; b_offset: "<<b_offset<<"; b_size0: "<<b_size0<<"; b_size1: "<<b_size1<<"; b_size2: "<<b_size2<<"; b_stride0: "<<b_stride0<<"; b_stride1: "<<b_stride1<<"; b_stride2: "<<b_stride2<<endl;
+  cout<<"calculated output for cgra_batch_matmul() c_alloc: "<<c_allocated<<"; c_aligned: "<<c_aligned<<"; c_offset: "<<c_offset<<"; c_size0: "<<c_size0<<"; c_size1: "<<c_size1<<"; c_size2: "<<c_size2<<"; c_stride0: "<<c_stride0<<"; c_stride1: "<<c_stride1<<"; c_stride2: "<<c_stride2<<endl;
+  cout<<"check total cycles: "<<cgra->getTotalCycles()<<endl;
+}
+
 // This fusion is an example for add+max+add. A robust fusion call should
 // be able to figure out what type of operation chain is targeted.
 extern "C" void cgra_fusion_add_max_add(float* a_allocated, float* a_aligned, int64_t a_offset, int64_t a_size0, int64_t a_size1, int64_t a_stride0, int64_t a_stride1,

sim/Simulator.cpp (+2)

@@ -13,8 +13,10 @@ Simulator::Simulator(bool enableDoubleBuffer) {
 
 void Simulator::registerPredefinedMappings() {
   exCycleMap.insert({"matmul", 20});
+  exCycleMap.insert({"batch_matmul", 20});
   exCycleMap.insert({"fusion_add_max_add", 20});
   exFuncMap["matmul"] = matmul;
+  exFuncMap["batch_matmul"] = batch_matmul;
   exFuncMap["fusion_add_max_add"] = fusion_add_max_add;
 }
