Skip to content

Commit 9f3d8e8

Browse files
authored
[CIR] Upstream support for while and do..while loops (#133157)
This adds basic support for while and do..while loops. Support for break and continue are left for a subsequent patch.
1 parent ce296f1 commit 9f3d8e8

File tree

8 files changed

+351
-19
lines changed

8 files changed

+351
-19
lines changed

Diff for: clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

+16
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,22 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
111111
return cir::BoolAttr::get(getContext(), getBoolTy(), state);
112112
}
113113

114+
/// Create a do-while operation.
115+
cir::DoWhileOp createDoWhile(
116+
mlir::Location loc,
117+
llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> condBuilder,
118+
llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> bodyBuilder) {
119+
return create<cir::DoWhileOp>(loc, condBuilder, bodyBuilder);
120+
}
121+
122+
/// Create a while operation.
123+
cir::WhileOp createWhile(
124+
mlir::Location loc,
125+
llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> condBuilder,
126+
llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)> bodyBuilder) {
127+
return create<cir::WhileOp>(loc, condBuilder, bodyBuilder);
128+
}
129+
114130
/// Create a for operation.
115131
cir::ForOp createFor(
116132
mlir::Location loc,

Diff for: clang/include/clang/CIR/Dialect/IR/CIRDialect.h

+3
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
#include "clang/CIR/Interfaces/CIRLoopOpInterface.h"
3333
#include "clang/CIR/Interfaces/CIROpInterfaces.h"
3434

35+
using BuilderCallbackRef =
36+
llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>;
37+
3538
// TableGen'erated files for MLIR dialects require that a macro be defined when
3639
// they are included. GET_OP_CLASSES tells the file to define the classes for
3740
// the operations of that dialect.

Diff for: clang/include/clang/CIR/Dialect/IR/CIROps.td

+100-3
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,8 @@ def StoreOp : CIR_Op<"store", [
424424
// ReturnOp
425425
//===----------------------------------------------------------------------===//
426426

427-
def ReturnOp : CIR_Op<"return", [ParentOneOf<["FuncOp", "ScopeOp", "ForOp"]>,
427+
def ReturnOp : CIR_Op<"return", [ParentOneOf<["FuncOp", "ScopeOp", "DoWhileOp",
428+
"WhileOp", "ForOp"]>,
428429
Terminator]> {
429430
let summary = "Return from function";
430431
let description = [{
@@ -511,7 +512,8 @@ def ConditionOp : CIR_Op<"condition", [
511512
//===----------------------------------------------------------------------===//
512513

513514
def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
514-
ParentOneOf<["ScopeOp", "ForOp"]>]> {
515+
ParentOneOf<["ScopeOp", "WhileOp", "ForOp",
516+
"DoWhileOp"]>]> {
515517
let summary = "Represents the default branching behaviour of a region";
516518
let description = [{
517519
The `cir.yield` operation terminates regions on different CIR operations,
@@ -759,11 +761,106 @@ def BrCondOp : CIR_Op<"brcond",
759761
}];
760762
}
761763

764+
//===----------------------------------------------------------------------===//
765+
// Common loop op definitions
766+
//===----------------------------------------------------------------------===//
767+
768+
class LoopOpBase<string mnemonic> : CIR_Op<mnemonic, [
769+
LoopOpInterface,
770+
NoRegionArguments,
771+
]> {
772+
let extraClassDefinition = [{
773+
void $cppClass::getSuccessorRegions(
774+
mlir::RegionBranchPoint point,
775+
llvm::SmallVectorImpl<mlir::RegionSuccessor> &regions) {
776+
LoopOpInterface::getLoopOpSuccessorRegions(*this, point, regions);
777+
}
778+
llvm::SmallVector<Region *> $cppClass::getLoopRegions() {
779+
return {&getBody()};
780+
}
781+
}];
782+
}
783+
784+
//===----------------------------------------------------------------------===//
785+
// While & DoWhileOp
786+
//===----------------------------------------------------------------------===//
787+
788+
class WhileOpBase<string mnemonic> : LoopOpBase<mnemonic> {
789+
defvar isWhile = !eq(mnemonic, "while");
790+
let summary = "C/C++ " # !if(isWhile, "while", "do-while") # " loop";
791+
let builders = [
792+
OpBuilder<(ins "BuilderCallbackRef":$condBuilder,
793+
"BuilderCallbackRef":$bodyBuilder), [{
794+
mlir::OpBuilder::InsertionGuard guard($_builder);
795+
$_builder.createBlock($_state.addRegion());
796+
}] # !if(isWhile, [{
797+
condBuilder($_builder, $_state.location);
798+
$_builder.createBlock($_state.addRegion());
799+
bodyBuilder($_builder, $_state.location);
800+
}], [{
801+
bodyBuilder($_builder, $_state.location);
802+
$_builder.createBlock($_state.addRegion());
803+
condBuilder($_builder, $_state.location);
804+
}])>
805+
];
806+
}
807+
808+
def WhileOp : WhileOpBase<"while"> {
809+
let regions = (region SizedRegion<1>:$cond, MinSizedRegion<1>:$body);
810+
let assemblyFormat = "$cond `do` $body attr-dict";
811+
812+
let description = [{
813+
Represents a C/C++ while loop. It consists of two regions:
814+
815+
- `cond`: single block region with the loop's condition. Should be
816+
terminated with a `cir.condition` operation.
817+
- `body`: contains the loop body and an arbitrary number of blocks.
818+
819+
Example:
820+
821+
```mlir
822+
cir.while {
823+
cir.break
824+
^bb2:
825+
cir.yield
826+
} do {
827+
cir.condition %cond : cir.bool
828+
}
829+
```
830+
}];
831+
}
832+
833+
def DoWhileOp : WhileOpBase<"do"> {
834+
let regions = (region MinSizedRegion<1>:$body, SizedRegion<1>:$cond);
835+
let assemblyFormat = " $body `while` $cond attr-dict";
836+
837+
let extraClassDeclaration = [{
838+
mlir::Region &getEntry() { return getBody(); }
839+
}];
840+
841+
let description = [{
842+
Represents a C/C++ do-while loop. Identical to `cir.while` but the
843+
condition is evaluated after the body.
844+
845+
Example:
846+
847+
```mlir
848+
cir.do {
849+
cir.break
850+
^bb2:
851+
cir.yield
852+
} while {
853+
cir.condition %cond : cir.bool
854+
}
855+
```
856+
}];
857+
}
858+
762859
//===----------------------------------------------------------------------===//
763860
// ForOp
764861
//===----------------------------------------------------------------------===//
765862

766-
def ForOp : CIR_Op<"for", [LoopOpInterface, NoRegionArguments]> {
863+
def ForOp : LoopOpBase<"for"> {
767864
let summary = "C/C++ for loop counterpart";
768865
let description = [{
769866
Represents a C/C++ for loop. It consists of three regions:

Diff for: clang/lib/CIR/CodeGen/CIRGenFunction.h

+4
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,8 @@ class CIRGenFunction : public CIRGenTypeCache {
395395

396396
LValue emitBinaryOperatorLValue(const BinaryOperator *e);
397397

398+
mlir::LogicalResult emitDoStmt(const clang::DoStmt &s);
399+
398400
/// Emit an expression as an initializer for an object (variable, field, etc.)
399401
/// at the given location. The expression is not necessarily the normal
400402
/// initializer for the object, and the address is not necessarily
@@ -493,6 +495,8 @@ class CIRGenFunction : public CIRGenTypeCache {
493495
/// inside a function, including static vars etc.
494496
void emitVarDecl(const clang::VarDecl &d);
495497

498+
mlir::LogicalResult emitWhileStmt(const clang::WhileStmt &s);
499+
496500
/// ----------------------
497501
/// CIR build helpers
498502
/// -----------------

Diff for: clang/lib/CIR/CodeGen/CIRGenStmt.cpp

+111-2
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ mlir::LogicalResult CIRGenFunction::emitStmt(const Stmt *s,
7575

7676
case Stmt::ForStmtClass:
7777
return emitForStmt(cast<ForStmt>(*s));
78+
case Stmt::WhileStmtClass:
79+
return emitWhileStmt(cast<WhileStmt>(*s));
80+
case Stmt::DoStmtClass:
81+
return emitDoStmt(cast<DoStmt>(*s));
7882

7983
case Stmt::OMPScopeDirectiveClass:
8084
case Stmt::OMPErrorDirectiveClass:
@@ -97,8 +101,6 @@ mlir::LogicalResult CIRGenFunction::emitStmt(const Stmt *s,
97101
case Stmt::SYCLKernelCallStmtClass:
98102
case Stmt::IfStmtClass:
99103
case Stmt::SwitchStmtClass:
100-
case Stmt::WhileStmtClass:
101-
case Stmt::DoStmtClass:
102104
case Stmt::CoroutineBodyStmtClass:
103105
case Stmt::CoreturnStmtClass:
104106
case Stmt::CXXTryStmtClass:
@@ -387,3 +389,110 @@ mlir::LogicalResult CIRGenFunction::emitForStmt(const ForStmt &s) {
387389
terminateBody(builder, forOp.getBody(), getLoc(s.getEndLoc()));
388390
return mlir::success();
389391
}
392+
393+
mlir::LogicalResult CIRGenFunction::emitDoStmt(const DoStmt &s) {
394+
cir::DoWhileOp doWhileOp;
395+
396+
// TODO: pass in array of attributes.
397+
auto doStmtBuilder = [&]() -> mlir::LogicalResult {
398+
mlir::LogicalResult loopRes = mlir::success();
399+
assert(!cir::MissingFeatures::loopInfoStack());
400+
// From LLVM: if there are any cleanups between here and the loop-exit
401+
// scope, create a block to stage a loop exit along.
402+
// We probably already do the right thing because of ScopeOp, but make
403+
// sure we handle all cases.
404+
assert(!cir::MissingFeatures::requiresCleanups());
405+
406+
doWhileOp = builder.createDoWhile(
407+
getLoc(s.getSourceRange()),
408+
/*condBuilder=*/
409+
[&](mlir::OpBuilder &b, mlir::Location loc) {
410+
assert(!cir::MissingFeatures::createProfileWeightsForLoop());
411+
assert(!cir::MissingFeatures::emitCondLikelihoodViaExpectIntrinsic());
412+
// C99 6.8.5p2/p4: The first substatement is executed if the
413+
// expression compares unequal to 0. The condition must be a
414+
// scalar type.
415+
mlir::Value condVal = evaluateExprAsBool(s.getCond());
416+
builder.createCondition(condVal);
417+
},
418+
/*bodyBuilder=*/
419+
[&](mlir::OpBuilder &b, mlir::Location loc) {
420+
// The scope of the do-while loop body is a nested scope.
421+
if (emitStmt(s.getBody(), /*useCurrentScope=*/false).failed())
422+
loopRes = mlir::failure();
423+
emitStopPoint(&s);
424+
});
425+
return loopRes;
426+
};
427+
428+
mlir::LogicalResult res = mlir::success();
429+
mlir::Location scopeLoc = getLoc(s.getSourceRange());
430+
builder.create<cir::ScopeOp>(scopeLoc, /*scopeBuilder=*/
431+
[&](mlir::OpBuilder &b, mlir::Location loc) {
432+
LexicalScope lexScope{
433+
*this, loc, builder.getInsertionBlock()};
434+
res = doStmtBuilder();
435+
});
436+
437+
if (res.failed())
438+
return res;
439+
440+
terminateBody(builder, doWhileOp.getBody(), getLoc(s.getEndLoc()));
441+
return mlir::success();
442+
}
443+
444+
mlir::LogicalResult CIRGenFunction::emitWhileStmt(const WhileStmt &s) {
445+
cir::WhileOp whileOp;
446+
447+
// TODO: pass in array of attributes.
448+
auto whileStmtBuilder = [&]() -> mlir::LogicalResult {
449+
mlir::LogicalResult loopRes = mlir::success();
450+
assert(!cir::MissingFeatures::loopInfoStack());
451+
// From LLVM: if there are any cleanups between here and the loop-exit
452+
// scope, create a block to stage a loop exit along.
453+
// We probably already do the right thing because of ScopeOp, but make
454+
// sure we handle all cases.
455+
assert(!cir::MissingFeatures::requiresCleanups());
456+
457+
whileOp = builder.createWhile(
458+
getLoc(s.getSourceRange()),
459+
/*condBuilder=*/
460+
[&](mlir::OpBuilder &b, mlir::Location loc) {
461+
assert(!cir::MissingFeatures::createProfileWeightsForLoop());
462+
assert(!cir::MissingFeatures::emitCondLikelihoodViaExpectIntrinsic());
463+
mlir::Value condVal;
464+
// If the for statement has a condition scope,
465+
// emit the local variable declaration.
466+
if (s.getConditionVariable())
467+
emitDecl(*s.getConditionVariable());
468+
// C99 6.8.5p2/p4: The first substatement is executed if the
469+
// expression compares unequal to 0. The condition must be a
470+
// scalar type.
471+
condVal = evaluateExprAsBool(s.getCond());
472+
builder.createCondition(condVal);
473+
},
474+
/*bodyBuilder=*/
475+
[&](mlir::OpBuilder &b, mlir::Location loc) {
476+
// The scope of the while loop body is a nested scope.
477+
if (emitStmt(s.getBody(), /*useCurrentScope=*/false).failed())
478+
loopRes = mlir::failure();
479+
emitStopPoint(&s);
480+
});
481+
return loopRes;
482+
};
483+
484+
mlir::LogicalResult res = mlir::success();
485+
mlir::Location scopeLoc = getLoc(s.getSourceRange());
486+
builder.create<cir::ScopeOp>(scopeLoc, /*scopeBuilder=*/
487+
[&](mlir::OpBuilder &b, mlir::Location loc) {
488+
LexicalScope lexScope{
489+
*this, loc, builder.getInsertionBlock()};
490+
res = whileStmtBuilder();
491+
});
492+
493+
if (res.failed())
494+
return res;
495+
496+
terminateBody(builder, whileOp.getBody(), getLoc(s.getEndLoc()));
497+
return mlir::success();
498+
}

Diff for: clang/lib/CIR/Dialect/IR/CIRDialect.cpp

-14
Original file line numberDiff line numberDiff line change
@@ -538,20 +538,6 @@ Block *cir::BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
538538
return nullptr;
539539
}
540540

541-
//===----------------------------------------------------------------------===//
542-
// ForOp
543-
//===----------------------------------------------------------------------===//
544-
545-
void cir::ForOp::getSuccessorRegions(
546-
mlir::RegionBranchPoint point,
547-
llvm::SmallVectorImpl<mlir::RegionSuccessor> &regions) {
548-
LoopOpInterface::getLoopOpSuccessorRegions(*this, point, regions);
549-
}
550-
551-
llvm::SmallVector<Region *> cir::ForOp::getLoopRegions() {
552-
return {&getBody()};
553-
}
554-
555541
//===----------------------------------------------------------------------===//
556542
// GlobalOp
557543
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)