Skip to content

Commit 466550d

Browse files
AdvancedCompilershuailong616zhzhcookie
authored
[TTIR] add expression restructuring pass (#4)
* add expression restructuring pass * add test for expression restructuring pass * [CI/CD] Add operators test to CI --------- Co-authored-by: shuailong616 <[email protected]> Co-authored-by: zhengyang <[email protected]>
1 parent c1c8ba8 commit 466550d

File tree

8 files changed

+278
-3
lines changed

8 files changed

+278
-3
lines changed

.github/workflows/nv-build-and-test.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,15 @@ jobs:
2020
- name: Checkout code
2121
uses: actions/checkout@v4
2222

23-
- name: FlagTree Build on NVIDIA-A100
23+
- name: FlagTree Build
2424
shell: bash
2525
run: |
2626
source ~/env.sh
2727
cd python
28-
MAX_JOBS=20 pip3.11 install . --no-build-isolation
28+
MAX_JOBS=32 pip3.11 install . --no-build-isolation
2929
30-
- name: FlagTree Test on NVIDIA-A100
30+
- name: FlagTree Test
3131
shell: bash
3232
run: |
3333
pytest -s python/test/unit
34+
pytest -s python/test/operators

include/triton/Dialect/Triton/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ std::unique_ptr<Pass> createCombineOpsPass();
1111
std::unique_ptr<Pass> createReorderBroadcastPass();
1212
std::unique_ptr<Pass> createRewriteTensorPointerPass();
1313

14+
std::unique_ptr<Pass> createExpressionRestructingPass();
1415
} // namespace triton
1516

1617
#define GEN_PASS_REGISTRATION

include/triton/Dialect/Triton/Transforms/Passes.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,18 @@ def TritonRewriteTensorPointer : Pass</*cli-arg*/"triton-rewrite-tensor-pointer"
4141
let dependentDialects = ["mlir::triton::TritonDialect"];
4242
}
4343

44+
def TritonExpressionRestructing : Pass</*cli-arg*/"triton-expression-resturcting", /*Op*/"mlir::ModuleOp"> {
45+
let summary = "ExpressionRestructing";
46+
let description = [{
47+
transform a = b / c; d = a / e; to a = c * e; d = b / a;
48+
transform a = b + c; d = a + c; to a = c + c; d = b + a;
49+
transform a = b - c; d = a - c; to a = c + c; d = b - a;
50+
transform a = b * c; d = a * c; to a = c * c; d = b * a;
51+
}];
52+
53+
let constructor = "mlir::triton::createExpressionRestructingPass()";
54+
55+
let dependentDialects = ["mlir::triton::TritonDialect", "mlir::arith::ArithDialect"];
56+
}
57+
4458
#endif

lib/Dialect/Triton/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ add_triton_library(TritonTransforms
66
Combine.cpp
77
ReorderBroadcast.cpp
88
RewriteTensorPointer.cpp
9+
ExpressionRestructing.cpp
10+
911

1012
DEPENDS
1113
TritonTransformsIncGen
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
#include <memory>
2+
3+
#include "mlir/IR/BuiltinAttributes.h"
4+
#include "mlir/IR/Matchers.h"
5+
#include "mlir/IR/PatternMatch.h"
6+
#include "mlir/Pass/Pass.h"
7+
#include "mlir/Support/LLVM.h"
8+
#include "mlir/Support/LogicalResult.h"
9+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
10+
#include "triton/Analysis/Utility.h"
11+
#include "triton/Dialect/Triton/IR/Dialect.h"
12+
#include "triton/Dialect/Triton/Transforms/Passes.h"
13+
14+
#define GEN_PASS_CLASSES
15+
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
16+
17+
using namespace mlir;
18+
using llvm::ArrayRef;
19+
namespace mlir::triton {
20+
21+
struct Div2Mul : public OpRewritePattern<arith::DivFOp> {
22+
using OpRewritePattern<arith::DivFOp>::OpRewritePattern;
23+
24+
LogicalResult matchAndRewrite(arith::DivFOp op,
25+
PatternRewriter &rewriter) const override {
26+
Value result = op.getResult();
27+
Value l = op.getLhs();
28+
Value r = op.getRhs();
29+
auto loc = op.getLoc();
30+
31+
if (!result.hasOneUse())
32+
return failure();
33+
for (auto &use : result.getUses()) {
34+
if (!dyn_cast<arith::DivFOp>(use.getOwner()))
35+
return failure();
36+
auto DivUser = dyn_cast<arith::DivFOp>(use.getOwner());
37+
if (DivUser.getLhs() != op.getResult())
38+
return failure();
39+
auto originalInsertionPoint = rewriter.saveInsertionPoint();
40+
rewriter.setInsertionPointAfter(DivUser);
41+
auto loc_div = DivUser.getLoc();
42+
auto product =
43+
rewriter.create<arith::MulFOp>(loc_div, r, DivUser.getRhs());
44+
rewriter.setInsertionPointAfter(product);
45+
auto ResultEnd =
46+
rewriter.create<arith::DivFOp>(loc_div, l, product.getResult());
47+
rewriter.restoreInsertionPoint(originalInsertionPoint);
48+
rewriter.replaceOp(op, product.getResult());
49+
DivUser.replaceAllUsesWith(ResultEnd.getResult());
50+
rewriter.eraseOp(DivUser);
51+
}
52+
return success();
53+
}
54+
};
55+
56+
struct Mul2Mul : public OpRewritePattern<arith::MulFOp> {
57+
using OpRewritePattern<arith::MulFOp>::OpRewritePattern;
58+
59+
LogicalResult matchAndRewrite(arith::MulFOp op,
60+
PatternRewriter &rewriter) const override {
61+
Value result = op.getResult();
62+
Value l = op.getLhs();
63+
Value r = op.getRhs();
64+
auto loc = op.getLoc();
65+
66+
if (!result.hasOneUse())
67+
return failure();
68+
for (auto &use : result.getUses()) {
69+
if (!dyn_cast<arith::MulFOp>(use.getOwner()))
70+
return failure();
71+
auto MulUser = dyn_cast<arith::MulFOp>(use.getOwner());
72+
if (!(MulUser.getLhs() == op.getResult() &&
73+
((MulUser.getRhs().getDefiningOp<arith::ConstantOp>() &&
74+
r.getDefiningOp<arith::ConstantOp>()) ||
75+
(r == MulUser.getRhs()))))
76+
return failure();
77+
auto originalInsertionPoint = rewriter.saveInsertionPoint();
78+
rewriter.setInsertionPointAfter(MulUser);
79+
auto loc_mul = MulUser.getLoc();
80+
auto product =
81+
rewriter.create<arith::MulFOp>(loc_mul, r, MulUser.getRhs());
82+
rewriter.setInsertionPointAfter(product);
83+
auto ResultEnd =
84+
rewriter.create<arith::MulFOp>(loc_mul, l, product.getResult());
85+
rewriter.restoreInsertionPoint(originalInsertionPoint);
86+
rewriter.replaceOp(op, product.getResult());
87+
MulUser.replaceAllUsesWith(ResultEnd.getResult());
88+
rewriter.eraseOp(MulUser);
89+
}
90+
return success();
91+
}
92+
};
93+
94+
struct Add2Add : public OpRewritePattern<arith::AddFOp> {
95+
using OpRewritePattern<arith::AddFOp>::OpRewritePattern;
96+
97+
LogicalResult matchAndRewrite(arith::AddFOp op,
98+
PatternRewriter &rewriter) const override {
99+
Value result = op.getResult();
100+
Value l = op.getLhs();
101+
Value r = op.getRhs();
102+
auto loc = op.getLoc();
103+
104+
if (!result.hasOneUse())
105+
return failure();
106+
for (auto &use : result.getUses()) {
107+
if (!dyn_cast<arith::AddFOp>(use.getOwner()))
108+
return failure();
109+
auto AddUser = dyn_cast<arith::AddFOp>(use.getOwner());
110+
if (!(AddUser.getLhs() == op.getResult() &&
111+
((AddUser.getRhs().getDefiningOp<arith::ConstantOp>() &&
112+
r.getDefiningOp<arith::ConstantOp>()) ||
113+
(r == AddUser.getRhs()))))
114+
return failure();
115+
auto originalInsertionPoint = rewriter.saveInsertionPoint();
116+
rewriter.setInsertionPointAfter(AddUser);
117+
auto loc_add = AddUser.getLoc();
118+
auto sum = rewriter.create<arith::AddFOp>(loc_add, r, AddUser.getRhs());
119+
rewriter.setInsertionPointAfter(sum);
120+
auto ResultEnd =
121+
rewriter.create<arith::AddFOp>(loc_add, l, sum.getResult());
122+
rewriter.restoreInsertionPoint(originalInsertionPoint);
123+
rewriter.replaceOp(op, sum.getResult());
124+
AddUser.replaceAllUsesWith(ResultEnd.getResult());
125+
rewriter.eraseOp(AddUser);
126+
}
127+
return success();
128+
}
129+
};
130+
131+
struct Sub2Add : public OpRewritePattern<arith::SubFOp> {
132+
using OpRewritePattern<arith::SubFOp>::OpRewritePattern;
133+
134+
LogicalResult matchAndRewrite(arith::SubFOp op,
135+
PatternRewriter &rewriter) const override {
136+
Value result = op.getResult();
137+
Value l = op.getLhs();
138+
Value r = op.getRhs();
139+
auto loc = op.getLoc();
140+
141+
if (!result.hasOneUse())
142+
return failure();
143+
for (auto &use : result.getUses()) {
144+
if (!dyn_cast<arith::SubFOp>(use.getOwner()))
145+
return failure();
146+
auto SubUser = dyn_cast<arith::SubFOp>(use.getOwner());
147+
if (!(SubUser.getLhs() == op.getResult() &&
148+
((SubUser.getRhs().getDefiningOp<arith::ConstantOp>() &&
149+
r.getDefiningOp<arith::ConstantOp>()) ||
150+
(r == SubUser.getRhs()))))
151+
return failure();
152+
auto originalInsertionPoint = rewriter.saveInsertionPoint();
153+
rewriter.setInsertionPointAfter(SubUser);
154+
auto loc_sub = SubUser.getLoc();
155+
auto sum = rewriter.create<arith::AddFOp>(loc_sub, r, SubUser.getRhs());
156+
rewriter.setInsertionPointAfter(sum);
157+
auto ResultEnd =
158+
rewriter.create<arith::SubFOp>(loc_sub, l, sum.getResult());
159+
rewriter.restoreInsertionPoint(originalInsertionPoint);
160+
rewriter.replaceOp(op, sum.getResult());
161+
SubUser.replaceAllUsesWith(ResultEnd.getResult());
162+
rewriter.eraseOp(SubUser);
163+
}
164+
return success();
165+
}
166+
};
167+
168+
class ExpressionRestructingPass
169+
: public TritonExpressionRestructingBase<ExpressionRestructingPass> {
170+
public:
171+
void runOnOperation() override {
172+
MLIRContext *context = &getContext();
173+
RewritePatternSet patterns(context);
174+
ModuleOp m = getOperation();
175+
patterns.add<Div2Mul>(context);
176+
patterns.add<Mul2Mul>(context);
177+
patterns.add<Add2Add>(context);
178+
patterns.add<Sub2Add>(context);
179+
180+
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
181+
signalPassFailure();
182+
}
183+
};
184+
185+
std::unique_ptr<mlir::Pass> createExpressionRestructingPass() {
186+
return std::make_unique<ExpressionRestructingPass>();
187+
}
188+
189+
} // namespace mlir::triton

python/src/passes.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ void init_triton_passes_ttir(py::module &&m) {
3737
using namespace mlir::triton;
3838
ADD_PASS_WRAPPER_0("add_combine", createCombineOpsPass);
3939
ADD_PASS_WRAPPER_0("add_reorder_broadcast", createReorderBroadcastPass);
40+
ADD_PASS_WRAPPER_0("add_expression_restructing",
41+
createExpressionRestructingPass);
4042
ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer",
4143
createRewriteTensorPointerPass);
4244
ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir",
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import triton
2+
import triton.language as tl
3+
import torch
4+
5+
import pytest
6+
7+
VEC_SHAPES = [[64, 640], [32, 128], [128, 256]]
8+
9+
10+
def custom_rand_strided(shape, strides, device, dtype, seed=0):
11+
torch.manual_seed(seed)
12+
total_size = sum((s - 1) * st for s, st in zip(shape, strides)) + 1
13+
storage = torch.randn(total_size, device=device, dtype=dtype)
14+
return torch.as_strided(storage, size=shape, stride=strides)
15+
16+
17+
def torch_equivalent(arg_0, arg_1, arg_2, arg_3):
18+
reshaped_arg_0 = arg_0.view(arg_2.shape[0], arg_2.shape[0], arg_2.shape[2])
19+
reshaped_arg_3 = arg_3.squeeze(-1)
20+
tmp0 = -reshaped_arg_0
21+
tmp4 = arg_1 * arg_2
22+
tmp7 = reshaped_arg_3 + 1e-06
23+
tmp8 = tmp4 / tmp7.unsqueeze(-1)
24+
tmp9 = tmp8 / tmp7.unsqueeze(-1)
25+
result = tmp0 * tmp9
26+
return result
27+
28+
29+
@triton.jit
30+
def expression_restructuring_function_test(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr2, rnumel):
31+
XBLOCK: tl.constexpr = 1
32+
xoffset = tl.program_id(0) * XBLOCK
33+
RBLOCK: tl.constexpr = 1024
34+
xindex = tl.full([1], xoffset, tl.int32)
35+
rindex = tl.arange(0, RBLOCK)[:]
36+
rmask = rindex < rnumel
37+
r1 = rindex
38+
x0 = xindex
39+
tmp0 = tl.load(in_ptr0 + (r1 + (rnumel * x0)), rmask, other=0)
40+
tmp2 = tl.load(in_ptr1 + (r1), rmask, eviction_policy='evict_last', other=0)
41+
tmp3 = tl.load(in_ptr2 + (r1 + (rnumel * x0)), rmask, other=0)
42+
tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
43+
tmp1 = -tmp0
44+
tmp4 = tmp2 * tmp3
45+
tmp6 = 1e-06
46+
tmp7 = tmp5 + tmp6
47+
tmp8 = tmp4 / tmp7
48+
tmp9 = tmp8 / tmp7
49+
tmp10 = tmp1 * tmp9
50+
tl.store(out_ptr2 + (r1 + (rnumel * x0)), tmp10, rmask)
51+
52+
53+
@pytest.mark.parametrize("vec_shape", VEC_SHAPES)
54+
def test_accruacy_kernel(vec_shape):
55+
x = vec_shape[0]
56+
y = vec_shape[1]
57+
arg_0 = custom_rand_strided((x * x, y), (y, 1), dtype=torch.float32, device='cuda')
58+
arg_1 = custom_rand_strided((y, ), (1, ), dtype=torch.float32, device='cuda')
59+
arg_2 = custom_rand_strided((x, x, y), (x * y, y, 1), dtype=torch.float32, device='cuda')
60+
arg_3 = custom_rand_strided((x, x, 1), (x, 1, 1), dtype=torch.float32, device='cuda')
61+
triton_result = custom_rand_strided((x, x, y), (x * y, y, 1), dtype=torch.float32, device='cuda')
62+
grid = lambda meta: (x * x, )
63+
expression_restructuring_function_test[grid](arg_0, arg_1, arg_2, arg_3, triton_result, y)
64+
torch_result = torch_equivalent(arg_0, arg_1, arg_2, arg_3)
65+
torch.testing.assert_close(triton_result, torch_result)

third_party/nvidia/backend/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def make_ttir(mod, metadata, opt):
144144
passes.common.add_canonicalizer(pm)
145145
passes.ttir.add_reorder_broadcast(pm)
146146
passes.common.add_cse(pm)
147+
passes.ttir.add_expression_restructing(pm)
147148
passes.common.add_licm(pm)
148149
passes.common.add_symbol_dce(pm)
149150
pm.run(mod)

0 commit comments

Comments
 (0)