Skip to content

Commit 359ada9

Browse files
committed
[CIR] Data member pointer comparison and casts
This patch adds CIRGen and LLVM lowering support for the following language features related to pointers to data members: - Comparisons between pointers to data members. - Casting from pointers to data members to boolean. - Reinterpret casts between pointers to data members.
1 parent 1029b19 commit 359ada9

File tree

9 files changed

+205
-16
lines changed

9 files changed

+205
-16
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

+2-1
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def CK_FloatComplexToIntegralComplex
123123
def CK_IntegralComplexCast : I32EnumAttrCase<"int_complex", 23>;
124124
def CK_IntegralComplexToFloatComplex
125125
: I32EnumAttrCase<"int_complex_to_float_complex", 24>;
126+
def CK_MemberPtrToBoolean : I32EnumAttrCase<"member_ptr_to_bool", 25>;
126127

127128
def CastKind : I32EnumAttr<
128129
"CastKind",
@@ -135,7 +136,7 @@ def CastKind : I32EnumAttr<
135136
CK_FloatComplexToReal, CK_IntegralComplexToReal, CK_FloatComplexToBoolean,
136137
CK_IntegralComplexToBoolean, CK_FloatComplexCast,
137138
CK_FloatComplexToIntegralComplex, CK_IntegralComplexCast,
138-
CK_IntegralComplexToFloatComplex]> {
139+
CK_IntegralComplexToFloatComplex, CK_MemberPtrToBoolean]> {
139140
let cppNamespace = "::cir";
140141
}
141142

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

+17-5
Original file line numberDiff line numberDiff line change
@@ -932,7 +932,12 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
932932
};
933933

934934
if (const MemberPointerType *MPT = LHSTy->getAs<MemberPointerType>()) {
935-
assert(0 && "not implemented");
935+
assert(E->getOpcode() == BO_EQ || E->getOpcode() == BO_NE);
936+
mlir::Value lhs = CGF.emitScalarExpr(E->getLHS());
937+
mlir::Value rhs = CGF.emitScalarExpr(E->getRHS());
938+
cir::CmpOpKind kind = ClangCmpToCIRCmp(E->getOpcode());
939+
Result =
940+
Builder.createCompare(CGF.getLoc(E->getExprLoc()), kind, lhs, rhs);
936941
} else if (!LHSTy->isAnyComplexType() && !RHSTy->isAnyComplexType()) {
937942
BinOpInfo BOInfo = emitBinOps(E);
938943
mlir::Value LHS = BOInfo.LHS;
@@ -1741,8 +1746,11 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
17411746
auto Ty = mlir::cast<cir::DataMemberType>(CGF.convertType(DestTy));
17421747
return Builder.getNullDataMemberPtr(Ty, CGF.getLoc(E->getExprLoc()));
17431748
}
1744-
case CK_ReinterpretMemberPointer:
1745-
llvm_unreachable("NYI");
1749+
case CK_ReinterpretMemberPointer: {
1750+
mlir::Value src = Visit(E);
1751+
return Builder.createBitcast(CGF.getLoc(E->getExprLoc()), src,
1752+
CGF.convertType(DestTy));
1753+
}
17461754
case CK_BaseToDerivedMemberPointer:
17471755
case CK_DerivedToBaseMemberPointer: {
17481756
mlir::Value src = Visit(E);
@@ -1875,8 +1883,12 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
18751883
return emitPointerToBoolConversion(Visit(E), E->getType());
18761884
case CK_FloatingToBoolean:
18771885
return emitFloatToBoolConversion(Visit(E), CGF.getLoc(E->getExprLoc()));
1878-
case CK_MemberPointerToBoolean:
1879-
llvm_unreachable("NYI");
1886+
case CK_MemberPointerToBoolean: {
1887+
mlir::Value memPtr = Visit(E);
1888+
return Builder.createCast(CGF.getLoc(CE->getSourceRange()),
1889+
cir::CastKind::member_ptr_to_bool, memPtr,
1890+
CGF.convertType(DestTy));
1891+
}
18801892
case CK_FloatingComplexToReal:
18811893
case CK_IntegralComplexToReal:
18821894
case CK_FloatingComplexToBoolean:

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,11 @@ LogicalResult cir::CastOp::verify() {
530530
return success();
531531
}
532532

533+
// Handle the data member pointer types.
534+
if (mlir::isa<cir::DataMemberType>(srcType) &&
535+
mlir::isa<cir::DataMemberType>(resType))
536+
return success();
537+
533538
// This is the only cast kind where we don't want vector types to decay
534539
// into the element type.
535540
if ((!mlir::isa<cir::VectorType>(getSrc().getType()) ||
@@ -705,6 +710,13 @@ LogicalResult cir::CastOp::verify() {
705710
<< "requires !cir.complex<!cir.float> type for result";
706711
return success();
707712
}
713+
case cir::CastKind::member_ptr_to_bool: {
714+
if (!mlir::isa<cir::DataMemberType>(srcType))
715+
return emitOpError() << "requires !cir.data_member type for source";
716+
if (!mlir::isa<cir::BoolType>(resType))
717+
return emitOpError() << "requires !cir.bool type for result";
718+
return success();
719+
}
708720
}
709721

710722
llvm_unreachable("Unknown CastOp kind?");

clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h

+13
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,19 @@ class CIRCXXABI {
9797
virtual mlir::Value
9898
lowerDerivedDataMember(cir::DerivedDataMemberOp op, mlir::Value loweredSrc,
9999
mlir::OpBuilder &builder) const = 0;
100+
101+
virtual mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs,
102+
mlir::Value loweredRhs,
103+
mlir::OpBuilder &builder) const = 0;
104+
105+
virtual mlir::Value
106+
lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy,
107+
mlir::Value loweredSrc,
108+
mlir::OpBuilder &builder) const = 0;
109+
110+
virtual mlir::Value
111+
lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc,
112+
mlir::OpBuilder &builder) const = 0;
100113
};
101114

102115
/// Creates an Itanium-family ABI.

clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp

+48-4
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,18 @@ class ItaniumCXXABI : public CIRCXXABI {
7373
mlir::Value lowerDerivedDataMember(cir::DerivedDataMemberOp op,
7474
mlir::Value loweredSrc,
7575
mlir::OpBuilder &builder) const override;
76+
77+
mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs,
78+
mlir::Value loweredRhs,
79+
mlir::OpBuilder &builder) const override;
80+
81+
mlir::Value lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy,
82+
mlir::Value loweredSrc,
83+
mlir::OpBuilder &builder) const override;
84+
85+
mlir::Value
86+
lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc,
87+
mlir::OpBuilder &builder) const override;
7688
};
7789

7890
} // namespace
@@ -89,18 +101,23 @@ bool ItaniumCXXABI::classifyReturnType(LowerFunctionInfo &FI) const {
89101
return false;
90102
}
91103

92-
mlir::Type ItaniumCXXABI::lowerDataMemberType(
93-
cir::DataMemberType type, const mlir::TypeConverter &typeConverter) const {
104+
static mlir::Type getABITypeForDataMember(LowerModule &lowerMod) {
94105
// Itanium C++ ABI 2.3:
95106
// A pointer to data member is an offset from the base address of
96107
// the class object containing it, represented as a ptrdiff_t
97-
const clang::TargetInfo &target = LM.getTarget();
108+
const clang::TargetInfo &target = lowerMod.getTarget();
98109
clang::TargetInfo::IntType ptrdiffTy =
99110
target.getPtrDiffType(clang::LangAS::Default);
100-
return cir::IntType::get(type.getContext(), target.getTypeWidth(ptrdiffTy),
111+
return cir::IntType::get(lowerMod.getMLIRContext(),
112+
target.getTypeWidth(ptrdiffTy),
101113
target.isTypeSigned(ptrdiffTy));
102114
}
103115

116+
mlir::Type ItaniumCXXABI::lowerDataMemberType(
117+
cir::DataMemberType type, const mlir::TypeConverter &typeConverter) const {
118+
return getABITypeForDataMember(LM);
119+
}
120+
104121
mlir::TypedAttr ItaniumCXXABI::lowerDataMemberConstant(
105122
cir::DataMemberAttr attr, const mlir::DataLayout &layout,
106123
const mlir::TypeConverter &typeConverter) const {
@@ -175,6 +192,33 @@ ItaniumCXXABI::lowerDerivedDataMember(cir::DerivedDataMemberOp op,
175192
/*isDerivedToBase=*/false, builder);
176193
}
177194

195+
mlir::Value ItaniumCXXABI::lowerDataMemberCmp(cir::CmpOp op,
196+
mlir::Value loweredLhs,
197+
mlir::Value loweredRhs,
198+
mlir::OpBuilder &builder) const {
199+
return builder.create<cir::CmpOp>(op.getLoc(), op.getKind(), loweredLhs,
200+
loweredRhs);
201+
}
202+
203+
mlir::Value
204+
ItaniumCXXABI::lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy,
205+
mlir::Value loweredSrc,
206+
mlir::OpBuilder &builder) const {
207+
return builder.create<cir::CastOp>(op.getLoc(), loweredDstTy,
208+
cir::CastKind::bitcast, loweredSrc);
209+
}
210+
211+
mlir::Value
212+
ItaniumCXXABI::lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc,
213+
mlir::OpBuilder &builder) const {
214+
// Itanium C++ ABI 2.3:
215+
// A NULL pointer is represented as -1.
216+
auto nullAttr = cir::IntAttr::get(getABITypeForDataMember(LM), -1);
217+
auto nullValue = builder.create<cir::ConstantOp>(op.getLoc(), nullAttr);
218+
return builder.create<cir::CmpOp>(op.getLoc(), cir::CmpOpKind::ne, loweredSrc,
219+
nullValue);
220+
}
221+
178222
CIRCXXABI *CreateItaniumCXXABI(LowerModule &LM) {
179223
switch (LM.getCXXABIKind()) {
180224
// Note that AArch64 uses the generic ItaniumCXXABI class since it doesn't

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

+31-3
Original file line numberDiff line numberDiff line change
@@ -1299,8 +1299,18 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
12991299
}
13001300
case cir::CastKind::bitcast: {
13011301
auto dstTy = castOp.getType();
1302-
auto llvmSrcVal = adaptor.getOperands().front();
13031302
auto llvmDstTy = getTypeConverter()->convertType(dstTy);
1303+
1304+
if (mlir::isa<cir::DataMemberType>(castOp.getSrc().getType())) {
1305+
mlir::Value loweredResult = lowerMod->getCXXABI().lowerDataMemberBitcast(
1306+
castOp, llvmDstTy, src, rewriter);
1307+
rewriter.replaceOp(castOp, loweredResult);
1308+
return mlir::success();
1309+
}
1310+
if (mlir::isa<cir::MethodType>(castOp.getSrc().getType()))
1311+
llvm_unreachable("NYI");
1312+
1313+
auto llvmSrcVal = adaptor.getOperands().front();
13041314
rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(castOp, llvmDstTy,
13051315
llvmSrcVal);
13061316
return mlir::success();
@@ -1324,6 +1334,16 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
13241334
llvmSrcVal);
13251335
break;
13261336
}
1337+
case cir::CastKind::member_ptr_to_bool: {
1338+
mlir::Value loweredResult;
1339+
if (mlir::isa<cir::MethodType>(castOp.getSrc().getType()))
1340+
llvm_unreachable("NYI");
1341+
else
1342+
loweredResult = lowerMod->getCXXABI().lowerDataMemberToBoolCast(
1343+
castOp, src, rewriter);
1344+
rewriter.replaceOp(castOp, loweredResult);
1345+
break;
1346+
}
13271347
default: {
13281348
return castOp.emitError("Unhandled cast kind: ")
13291349
<< castOp.getKindAttrName();
@@ -2902,6 +2922,14 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
29022922
mlir::ConversionPatternRewriter &rewriter) const {
29032923
auto type = cmpOp.getLhs().getType();
29042924

2925+
if (mlir::isa<cir::DataMemberType>(type)) {
2926+
assert(lowerMod && "lowering module is not available");
2927+
mlir::Value loweredResult = lowerMod->getCXXABI().lowerDataMemberCmp(
2928+
cmpOp, adaptor.getLhs(), adaptor.getRhs(), rewriter);
2929+
rewriter.replaceOp(cmpOp, loweredResult);
2930+
return mlir::success();
2931+
}
2932+
29052933
// Lower to LLVM comparison op.
29062934
// if (auto intTy = mlir::dyn_cast<cir::IntType>(type)) {
29072935
if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) {
@@ -4087,6 +4115,7 @@ void populateCIRToLLVMConversionPatterns(
40874115
argsVarMap, patterns.getContext());
40884116
patterns.add<
40894117
// clang-format off
4118+
CIRToLLVMCastOpLowering,
40904119
CIRToLLVMLoadOpLowering,
40914120
CIRToLLVMStoreOpLowering,
40924121
CIRToLLVMGlobalOpLowering,
@@ -4096,14 +4125,14 @@ void populateCIRToLLVMConversionPatterns(
40964125
patterns.add<
40974126
// clang-format off
40984127
CIRToLLVMBaseDataMemberOpLowering,
4128+
CIRToLLVMCmpOpLowering,
40994129
CIRToLLVMDerivedDataMemberOpLowering,
41004130
CIRToLLVMGetRuntimeMemberOpLowering
41014131
// clang-format on
41024132
>(converter, patterns.getContext(), lowerModule);
41034133
patterns.add<
41044134
// clang-format off
41054135
CIRToLLVMPtrStrideOpLowering,
4106-
CIRToLLVMCastOpLowering,
41074136
CIRToLLVMInlineAsmOpLowering
41084137
// clang-format on
41094138
>(converter, patterns.getContext(), dataLayout);
@@ -4132,7 +4161,6 @@ void populateCIRToLLVMConversionPatterns(
41324161
CIRToLLVMCallOpLowering,
41334162
CIRToLLVMCatchParamOpLowering,
41344163
CIRToLLVMClearCacheOpLowering,
4135-
CIRToLLVMCmpOpLowering,
41364164
CIRToLLVMCmpThreeWayOpLowering,
41374165
CIRToLLVMComplexCreateOpLowering,
41384166
CIRToLLVMComplexImagOpLowering,

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

+12-3
Original file line numberDiff line numberDiff line change
@@ -232,16 +232,18 @@ class CIRToLLVMBrCondOpLowering
232232
};
233233

234234
class CIRToLLVMCastOpLowering : public mlir::OpConversionPattern<cir::CastOp> {
235+
cir::LowerModule *lowerMod;
235236
mlir::DataLayout const &dataLayout;
236237

237238
mlir::Type convertTy(mlir::Type ty) const;
238239

239240
public:
240241
CIRToLLVMCastOpLowering(const mlir::TypeConverter &typeConverter,
241242
mlir::MLIRContext *context,
243+
cir::LowerModule *lowerModule,
242244
mlir::DataLayout const &dataLayout)
243-
: OpConversionPattern(typeConverter, context), dataLayout(dataLayout) {}
244-
using mlir::OpConversionPattern<cir::CastOp>::OpConversionPattern;
245+
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule),
246+
dataLayout(dataLayout) {}
245247

246248
mlir::LogicalResult
247249
matchAndRewrite(cir::CastOp op, OpAdaptor,
@@ -649,8 +651,15 @@ class CIRToLLVMShiftOpLowering
649651
};
650652

651653
class CIRToLLVMCmpOpLowering : public mlir::OpConversionPattern<cir::CmpOp> {
654+
cir::LowerModule *lowerMod;
655+
652656
public:
653-
using mlir::OpConversionPattern<cir::CmpOp>::OpConversionPattern;
657+
CIRToLLVMCmpOpLowering(const mlir::TypeConverter &typeConverter,
658+
mlir::MLIRContext *context,
659+
cir::LowerModule *lowerModule)
660+
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {
661+
setHasBoundedRewriteRecursion();
662+
}
654663

655664
mlir::LogicalResult
656665
matchAndRewrite(cir::CmpOp op, OpAdaptor,

clang/test/CIR/CodeGen/pointer-to-data-member-cast.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,29 @@ auto derived_to_base_zero_offset(int Derived::*ptr) -> int Base1::* {
7070
// LLVM-NEXT: %[[#ret:]] = load i64, ptr %[[#ret_slot]]
7171
// LLVM-NEXT: ret i64 %[[#ret]]
7272
}
73+
74+
struct Foo {
75+
int a;
76+
};
77+
78+
struct Bar {
79+
int a;
80+
};
81+
82+
bool to_bool(int Foo::*x) {
83+
return x;
84+
}
85+
86+
// CIR-LABEL: @_Z7to_boolM3Fooi
87+
// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
88+
// CIR-NEXT: %{{.+}} = cir.cast(member_ptr_to_bool, %[[#x]] : !cir.data_member<!s32i in !ty_Foo>), !cir.bool
89+
// CIR: }
90+
91+
auto bitcast(int Foo::*x) {
92+
return reinterpret_cast<int Bar::*>(x);
93+
}
94+
95+
// CIR-LABEL: @_Z7bitcastM3Fooi
96+
// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
97+
// CIR-NEXT: %{{.+}} = cir.cast(bitcast, %[[#x]] : !cir.data_member<!s32i in !ty_Foo>), !cir.data_member<!s32i in !ty_Bar>
98+
// CIR: }
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++17 -fclangir -emit-cir %s -o %t.cir
2+
// RUN: FileCheck --input-file=%t.cir --check-prefix=CIR %s
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++17 -fclangir -emit-llvm %s -o %t.ll
4+
// RUN: FileCheck --input-file=%t.ll --check-prefix=LLVM %s
5+
6+
struct Foo {
7+
int a;
8+
};
9+
10+
struct Bar {
11+
int a;
12+
};
13+
14+
bool eq(int Foo::*x, int Foo::*y) {
15+
return x == y;
16+
}
17+
18+
// CIR-LABEL: @_Z2eqM3FooiS0_
19+
// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
20+
// CIR-NEXT: %[[#y:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
21+
// CIR-NEXT: %{{.+}} = cir.cmp(eq, %[[#x]], %[[#y]]) : !cir.data_member<!s32i in !ty_Foo>, !cir.bool
22+
// CIR: }
23+
24+
// LLVM-LABEL: @_Z2eqM3FooiS0_
25+
// LLVM: %[[#x:]] = load i64, ptr %{{.+}}, align 8
26+
// LLVM-NEXT: %[[#y:]] = load i64, ptr %{{.+}}, align 8
27+
// LLVM-NEXT: %{{.+}} = icmp eq i64 %[[#x]], %[[#y]]
28+
// LLVM: }
29+
30+
bool ne(int Foo::*x, int Foo::*y) {
31+
return x != y;
32+
}
33+
34+
// CIR-LABEL: @_Z2neM3FooiS0_
35+
// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
36+
// CIR-NEXT: %[[#y:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
37+
// CIR-NEXT: %{{.+}} = cir.cmp(ne, %[[#x]], %[[#y]]) : !cir.data_member<!s32i in !ty_Foo>, !cir.bool
38+
// CIR: }
39+
40+
// LLVM-LABEL: @_Z2neM3FooiS0_
41+
// LLVM: %[[#x:]] = load i64, ptr %{{.+}}, align 8
42+
// LLVM-NEXT: %[[#y:]] = load i64, ptr %{{.+}}, align 8
43+
// LLVM-NEXT: %{{.+}} = icmp ne i64 %[[#x]], %[[#y]]
44+
// LLVM: }

0 commit comments

Comments
 (0)