Skip to content

Commit cb4b6ad

Browse files
authored
[CIR] Add the ability to detect if SwitchOp covers all the cases (#171246)
1 parent 4f9d5a8 commit cb4b6ad

File tree

8 files changed

+90
-65
lines changed

8 files changed

+90
-65
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,6 +1080,12 @@ def CIR_SwitchOp : CIR_Op<"switch", [
10801080
conditionally executing multiple regions of code. The operand to an switch
10811081
is an integral condition value.
10821082

1083+
Besides taking an integer condition and CIR regions, it also accepts an
1084+
`all_enum_cases_covered` attribute indicating whether all enum cases are
1085+
handled by the operation. Note that the presence of a default CaseOp does
1086+
not imply `all_enum_cases_covered`. The original AST switch must explicitly list
1087+
every enum case.
1088+
10831089
The set of `cir.case` operations and their enclosing `cir.switch`
10841090
represent the semantics of a C/C++ switch statement. Users can use
10851091
`collectCases(llvm::SmallVector<CaseOp> &cases)` to collect the `cir.case`
@@ -1206,7 +1212,10 @@ def CIR_SwitchOp : CIR_Op<"switch", [
12061212
```
12071213
}];
12081214

1209-
let arguments = (ins CIR_IntType:$condition);
1215+
let arguments = (ins
1216+
CIR_IntType:$condition,
1217+
UnitAttr:$allEnumCasesCovered
1218+
);
12101219

12111220
let regions = (region AnyRegion:$body);
12121221

@@ -1217,9 +1226,9 @@ def CIR_SwitchOp : CIR_Op<"switch", [
12171226
];
12181227

12191228
let assemblyFormat = [{
1220-
custom<SwitchOp>(
1221-
$body, $condition, type($condition)
1222-
)
1229+
`(` $condition `:` qualified(type($condition)) `)`
1230+
(`allEnumCasesCovered` $allEnumCasesCovered^)?
1231+
$body
12231232
attr-dict
12241233
}];
12251234

clang/lib/CIR/CodeGen/CIRGenStmt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,8 @@ mlir::LogicalResult CIRGenFunction::emitSwitchStmt(const clang::SwitchStmt &s) {
11051105
terminateBody(builder, caseOp.getCaseRegion(), caseOp.getLoc());
11061106
terminateBody(builder, swop.getBody(), swop.getLoc());
11071107

1108+
swop.setAllEnumCasesCovered(s.isAllEnumCasesCovered());
1109+
11081110
return res;
11091111
}
11101112

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,44 +1359,6 @@ void cir::CaseOp::build(OpBuilder &builder, OperationState &result,
13591359
// SwitchOp
13601360
//===----------------------------------------------------------------------===//
13611361

1362-
static ParseResult parseSwitchOp(OpAsmParser &parser, mlir::Region &regions,
1363-
mlir::OpAsmParser::UnresolvedOperand &cond,
1364-
mlir::Type &condType) {
1365-
cir::IntType intCondType;
1366-
1367-
if (parser.parseLParen())
1368-
return mlir::failure();
1369-
1370-
if (parser.parseOperand(cond))
1371-
return mlir::failure();
1372-
if (parser.parseColon())
1373-
return mlir::failure();
1374-
if (parser.parseCustomTypeWithFallback(intCondType))
1375-
return mlir::failure();
1376-
condType = intCondType;
1377-
1378-
if (parser.parseRParen())
1379-
return mlir::failure();
1380-
if (parser.parseRegion(regions, /*arguments=*/{}, /*argTypes=*/{}))
1381-
return failure();
1382-
1383-
return mlir::success();
1384-
}
1385-
1386-
static void printSwitchOp(OpAsmPrinter &p, cir::SwitchOp op,
1387-
mlir::Region &bodyRegion, mlir::Value condition,
1388-
mlir::Type condType) {
1389-
p << "(";
1390-
p << condition;
1391-
p << " : ";
1392-
p.printStrippedAttrOrType(condType);
1393-
p << ")";
1394-
1395-
p << ' ';
1396-
p.printRegion(bodyRegion, /*printEntryBlockArgs=*/false,
1397-
/*printBlockTerminators=*/true);
1398-
}
1399-
14001362
void cir::SwitchOp::getSuccessorRegions(
14011363
mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &region) {
14021364
if (!point.isParent()) {

clang/test/CIR/CodeGen/atomic.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,7 +1143,7 @@ int atomic_load_dynamic_order(int *ptr, int order) {
11431143

11441144
// CIR: %[[PTR:.+]] = cir.load align(8) %{{.+}} : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i>
11451145
// CIR-NEXT: %[[ORDER:.+]] = cir.load align(4) %{{.+}} : !cir.ptr<!s32i>, !s32i
1146-
// CIR-NEXT: cir.switch (%[[ORDER]] : !s32i) {
1146+
// CIR-NEXT: cir.switch(%[[ORDER]] : !s32i) {
11471147
// CIR-NEXT: cir.case(default, []) {
11481148
// CIR-NEXT: %[[RES:.+]] = cir.load align(4) syncscope(system) atomic(relaxed) %[[PTR]] : !cir.ptr<!s32i>, !s32i
11491149
// CIR-NEXT: cir.store align(4) %[[RES]], %[[RES_SLOT:.+]] : !s32i, !cir.ptr<!s32i>
@@ -1219,7 +1219,7 @@ void atomic_store_dynamic_order(int *ptr, int order) {
12191219

12201220
// CIR: %[[PTR:.+]] = cir.load align(8) %{{.+}} : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i>
12211221
// CIR-NEXT: %[[ORDER:.+]] = cir.load align(4) %{{.+}} : !cir.ptr<!s32i>, !s32i
1222-
// CIR: cir.switch (%[[ORDER]] : !s32i) {
1222+
// CIR: cir.switch(%[[ORDER]] : !s32i) {
12231223
// CIR-NEXT: cir.case(default, []) {
12241224
// CIR-NEXT: %[[VALUE:.+]] = cir.load align(4) %{{.+}} : !cir.ptr<!s32i>, !s32i
12251225
// CIR-NEXT: cir.store align(4) atomic(relaxed) %[[VALUE]], %[[PTR]] : !s32i, !cir.ptr<!s32i>
@@ -1288,7 +1288,7 @@ int atomic_load_and_store_dynamic_order(int *ptr, int order) {
12881288

12891289
// CIR: %[[PTR:.+]] = cir.load align(8) %{{.+}} : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i>
12901290
// CIR-NEXT: %[[ORDER:.+]] = cir.load align(4) %{{.+}} : !cir.ptr<!s32i>, !s32i
1291-
// CIR: cir.switch (%[[ORDER]] : !s32i) {
1291+
// CIR: cir.switch(%[[ORDER]] : !s32i) {
12921292
// CIR-NEXT: cir.case(default, []) {
12931293
// CIR-NEXT: %[[LIT:.+]] = cir.load align(4) %{{.+}} : !cir.ptr<!s32i>, !s32i
12941294
// CIR-NEXT: %[[RES:.+]] = cir.atomic.xchg relaxed %[[PTR]], %[[LIT]] : (!cir.ptr<!s32i>, !s32i) -> !s32i

clang/test/CIR/CodeGen/switch.cpp

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ void sw1(int a) {
2020
}
2121

2222
// CIR: cir.func{{.*}} @_Z3sw1i
23-
// CIR: cir.switch (%[[COND:.*]] : !s32i) {
23+
// CIR: cir.switch(%[[COND:.*]] : !s32i) {
2424
// CIR-NEXT: cir.case(equal, [#cir.int<0> : !s32i]) {
2525
// CIR: cir.break
2626
// CIR: cir.case(equal, [#cir.int<1> : !s32i]) {
@@ -101,7 +101,7 @@ void sw2(int a) {
101101
// CIR: cir.scope {
102102
// CIR-NEXT: %[[YOLO:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["yolo", init]
103103
// CIR-NEXT: %[[FOMO:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["fomo", init]
104-
// CIR: cir.switch (%[[COND:.*]] : !s32i) {
104+
// CIR: cir.switch(%[[COND:.*]] : !s32i) {
105105
// CIR-NEXT: cir.case(equal, [#cir.int<3> : !s32i]) {
106106
// CIR-NEXT: %[[ZERO:.*]] = cir.const #cir.int<0> : !s32i
107107
// CIR-NEXT: cir.store{{.*}} %[[ZERO]], %[[FOMO]] : !s32i, !cir.ptr<!s32i>
@@ -154,7 +154,7 @@ void sw3(int a) {
154154
// CIR: cir.func{{.*}} @_Z3sw3i
155155
// CIR: cir.scope {
156156
// CIR-NEXT: %[[COND:.*]] = cir.load{{.*}} %[[A:.*]] : !cir.ptr<!s32i>, !s32i
157-
// CIR-NEXT: cir.switch (%[[COND]] : !s32i) {
157+
// CIR-NEXT: cir.switch(%[[COND]] : !s32i) {
158158
// CIR-NEXT: cir.case(default, []) {
159159
// CIR-NEXT: cir.break
160160
// CIR-NEXT: }
@@ -196,7 +196,7 @@ int sw4(int a) {
196196
}
197197

198198
// CIR: cir.func{{.*}} @_Z3sw4i
199-
// CIR: cir.switch (%[[COND:.*]] : !s32i) {
199+
// CIR: cir.switch(%[[COND:.*]] : !s32i) {
200200
// CIR-NEXT: cir.case(equal, [#cir.int<42> : !s32i]) {
201201
// CIR-NEXT: cir.scope {
202202
// CIR-NEXT: %[[THREE:.*]] = cir.const #cir.int<3> : !s32i
@@ -264,7 +264,7 @@ void sw5(int a) {
264264
}
265265

266266
// CIR: cir.func{{.*}} @_Z3sw5i
267-
// CIR: cir.switch (%[[A:.*]] : !s32i) {
267+
// CIR: cir.switch(%[[A:.*]] : !s32i) {
268268
// CIR-NEXT: cir.case(equal, [#cir.int<1> : !s32i]) {
269269
// CIR-NEXT: cir.yield
270270
// CIR-NEXT: }
@@ -314,7 +314,7 @@ void sw6(int a) {
314314
}
315315

316316
// CIR: cir.func{{.*}} @_Z3sw6i
317-
// CIR: cir.switch (%[[A:.*]] : !s32i) {
317+
// CIR: cir.switch(%[[A:.*]] : !s32i) {
318318
// CIR-NEXT: cir.case(equal, [#cir.int<0> : !s32i]) {
319319
// CIR-NEXT: cir.yield
320320
// CIR-NEXT: }
@@ -406,7 +406,7 @@ void sw7(int a) {
406406

407407
// CIR: cir.func{{.*}} @_Z3sw7i
408408
// CIR: %[[X:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["x"]
409-
// CIR: cir.switch (%[[A:.*]] : !s32i)
409+
// CIR: cir.switch(%[[A:.*]] : !s32i)
410410
// CIR-NEXT: cir.case(equal, [#cir.int<0> : !s32i]) {
411411
// CIR-NEXT: cir.yield
412412
// CIR-NEXT: }
@@ -499,7 +499,7 @@ void sw8(int a) {
499499
}
500500

501501
// CIR: cir.func{{.*}} @_Z3sw8i
502-
// CIR: cir.switch (%[[A:.*]] : !s32i)
502+
// CIR: cir.switch(%[[A:.*]] : !s32i)
503503
// CIR-NEXT: cir.case(equal, [#cir.int<3> : !s32i]) {
504504
// CIR-NEXT: cir.break
505505
// CIR-NEXT: }
@@ -557,7 +557,7 @@ void sw9(int a) {
557557
}
558558

559559
// CIR: cir.func{{.*}} @_Z3sw9i
560-
// CIR: cir.switch (%[[A:.*]] : !s32i)
560+
// CIR: cir.switch(%[[A:.*]] : !s32i)
561561
// CIR-NEXT: cir.case(equal, [#cir.int<3> : !s32i]) {
562562
// CIR-NEXT: cir.break
563563
// CIR-NEXT: }
@@ -616,7 +616,7 @@ void sw10(int a) {
616616
}
617617

618618
// CIR: cir.func{{.*}} @_Z4sw10i
619-
// CIR: cir.switch (%[[A:.*]] : !s32i)
619+
// CIR: cir.switch(%[[A:.*]] : !s32i)
620620
// CIR-NEXT: cir.case(equal, [#cir.int<3> : !s32i]) {
621621
// CIR-NEXT: cir.break
622622
// CIR-NEXT: }
@@ -688,7 +688,7 @@ void sw11(int a) {
688688
}
689689

690690
// CIR: cir.func{{.*}} @_Z4sw11i
691-
// CIR: cir.switch (%[[A:.*]] : !s32i)
691+
// CIR: cir.switch(%[[A:.*]] : !s32i)
692692
// CIR-NEXT: cir.case(equal, [#cir.int<3> : !s32i]) {
693693
// CIR-NEXT: cir.break
694694
// CIR-NEXT: }
@@ -1063,7 +1063,7 @@ int nested_switch(int a) {
10631063
return 0;
10641064
}
10651065

1066-
// CIR: cir.switch (%[[COND:.*]] : !s32i) {
1066+
// CIR: cir.switch(%[[COND:.*]] : !s32i) {
10671067
// CIR: cir.case(equal, [#cir.int<0> : !s32i]) {
10681068
// CIR: cir.yield
10691069
// CIR: }
@@ -1198,7 +1198,7 @@ int sw_return_multi_cases(int x) {
11981198
}
11991199

12001200
// CIR-LABEL: cir.func{{.*}} @_Z21sw_return_multi_casesi
1201-
// CIR: cir.switch (%{{.*}} : !s32i) {
1201+
// CIR: cir.switch(%{{.*}} : !s32i) {
12021202
// CIR-NEXT: cir.case(equal, [#cir.int<0> : !s32i]) {
12031203
// CIR: %[[ZERO:.*]] = cir.const #cir.int<0> : !s32i
12041204
// CIR: cir.store{{.*}} %[[ZERO]], %{{.*}} : !s32i, !cir.ptr<!s32i>
@@ -1270,3 +1270,25 @@ int sw_return_multi_cases(int x) {
12701270
// OGCG: [[RETURN]]:
12711271
// OGCG: %[[RETVAL_LOAD:.*]] = load i32, ptr %[[RETVAL]], align 4
12721272
// OGCG: ret i32 %[[RETVAL_LOAD]]
1273+
1274+
enum M {
1275+
Six,
1276+
Seven
1277+
};
1278+
1279+
void testSwitchCoverAllCase(M m) {
1280+
switch (m) {
1281+
case Six:case Seven:
1282+
break;
1283+
}
1284+
}
1285+
// CIR: cir.switch(%[[ARG:.*]] : !s32i) allEnumCasesCovered {
1286+
1287+
void testSwitchNotCoverAllCase(M m) {
1288+
switch (m) {
1289+
case Six:
1290+
default:
1291+
break;
1292+
}
1293+
}
1294+
// CIR: cir.switch(%[[ARG:.*]] : !s32i) {

clang/test/CIR/CodeGen/switch_flat_op.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ void swf(int a) {
2121
// BEFORE: cir.func{{.*}} @_Z3swfi
2222
// BEFORE: %[[VAR_B:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["b", init] {alignment = 4 : i64}
2323
// BEFORE: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
24-
// BEFORE: cir.switch (%[[COND:.*]] : !s32i) {
24+
// BEFORE: cir.switch(%[[COND:.*]] : !s32i) {
2525
// BEFORE: cir.case(equal, [#cir.int<3> : !s32i]) {
2626
// BEFORE: %[[LOAD_B_EQ:.*]] = cir.load{{.*}} %[[VAR_B]] : !cir.ptr<!s32i>, !s32i
2727
// BEFORE: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i

clang/test/CIR/IR/switch.cir

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ cir.func @s0() {
2121
cir.return
2222
}
2323

24-
// CHECK: cir.switch (%0 : !s32i) {
24+
// CHECK: cir.switch(%0 : !s32i) {
2525
// CHECK-NEXT: cir.case(default, []) {
2626
// CHECK-NEXT: cir.return
2727
// CHECK-NEXT: }
@@ -36,3 +36,33 @@ cir.func @s0() {
3636
// CHECK-NEXT: }
3737
// CHECK-NEXT: cir.yield
3838
// CHECK-NEXT: }
39+
40+
41+
// Pretends that this is lowered from a C file and was tagged with allEnumCasesCovered = true
42+
cir.func @s1(%1 : !s32i) {
43+
cir.switch (%1 : !s32i) allEnumCasesCovered {
44+
cir.case (default, []) {
45+
cir.return
46+
}
47+
cir.case (equal, [#cir.int<1> : !s32i]) {
48+
cir.yield
49+
}
50+
cir.case (equal, [#cir.int<2> : !s32i]) {
51+
cir.yield
52+
}
53+
cir.yield
54+
} { }
55+
cir.return
56+
}
57+
// CHECK: cir.switch(%[[ARG:.*]] : !s32i) allEnumCasesCovered {
58+
// CHECK-NEXT: cir.case(default, []) {
59+
// CHECK-NEXT: cir.return
60+
// CHECK-NEXT: }
61+
// CHECK-NEXT: cir.case(equal, [#cir.int<1> : !s32i]) {
62+
// CHECK-NEXT: cir.yield
63+
// CHECK-NEXT: }
64+
// CHECK-NEXT: cir.case(equal, [#cir.int<2> : !s32i]) {
65+
// CHECK-NEXT: cir.yield
66+
// CHECK-NEXT: }
67+
// CHECK-NEXT: cir.yield
68+
// CHECK-NEXT: }

clang/test/CIR/Transforms/switch-fold.cir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ module {
2727
cir.return
2828
}
2929
//CHECK: cir.func @foldCascade
30-
//CHECK: cir.switch (%[[COND:.*]] : !s32i) {
30+
//CHECK: cir.switch(%[[COND:.*]] : !s32i) {
3131
//CHECK-NEXT: cir.case(anyof, [#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i]) {
3232
//CHECK-NEXT: %[[TWO:.*]] = cir.const #cir.int<2> : !s32i
3333
//CHECK-NEXT: cir.store %[[TWO]], %[[ARG0:.*]] : !s32i, !cir.ptr<!s32i>
@@ -66,7 +66,7 @@ module {
6666
cir.return
6767
}
6868
//CHECK: @foldCascade2
69-
//CHECK: cir.switch (%[[COND2:.*]] : !s32i) {
69+
//CHECK: cir.switch(%[[COND2:.*]] : !s32i) {
7070
//CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<2> : !s32i, #cir.int<4> : !s32i]) {
7171
//CHECK: cir.break
7272
//cehck: }
@@ -106,7 +106,7 @@ module {
106106
cir.return
107107
}
108108
//CHECK: cir.func @foldCascade3
109-
//CHECK: cir.switch (%[[COND3:.*]] : !s32i) {
109+
//CHECK: cir.switch(%[[COND3:.*]] : !s32i) {
110110
//CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
111111
//CHECK: cir.break
112112
//CHECK: }
@@ -142,7 +142,7 @@ module {
142142
cir.return
143143
}
144144
//CHECK: cir.func @foldCascadeWithDefault
145-
//CHECK: cir.switch (%[[COND:.*]] : !s32i) {
145+
//CHECK: cir.switch(%[[COND:.*]] : !s32i) {
146146
//CHECK: cir.case(equal, [#cir.int<3> : !s32i]) {
147147
//CHECK: cir.break
148148
//CHECK: }
@@ -187,7 +187,7 @@ module {
187187
cir.return
188188
}
189189
//CHECK: cir.func @foldAllCascade
190-
//CHECK: cir.switch (%[[COND:.*]] : !s32i) {
190+
//CHECK: cir.switch(%[[COND:.*]] : !s32i) {
191191
//CHECK: cir.case(anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i, #cir.int<5> : !s32i]) {
192192
//CHECK: cir.yield
193193
//CHECK: }

0 commit comments

Comments
 (0)