diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 85493e17c915..8250667b2589 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -123,6 +123,7 @@ def CK_FloatComplexToIntegralComplex def CK_IntegralComplexCast : I32EnumAttrCase<"int_complex", 23>; def CK_IntegralComplexToFloatComplex : I32EnumAttrCase<"int_complex_to_float_complex", 24>; +def CK_MemberPtrToBoolean : I32EnumAttrCase<"member_ptr_to_bool", 25>; def CastKind : I32EnumAttr< "CastKind", @@ -135,7 +136,7 @@ def CastKind : I32EnumAttr< CK_FloatComplexToReal, CK_IntegralComplexToReal, CK_FloatComplexToBoolean, CK_IntegralComplexToBoolean, CK_FloatComplexCast, CK_FloatComplexToIntegralComplex, CK_IntegralComplexCast, - CK_IntegralComplexToFloatComplex]> { + CK_IntegralComplexToFloatComplex, CK_MemberPtrToBoolean]> { let cppNamespace = "::cir"; } diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index c0b6ac3c78e7..836e78c32176 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -932,7 +932,12 @@ class ScalarExprEmitter : public StmtVisitor { }; if (const MemberPointerType *MPT = LHSTy->getAs()) { - assert(0 && "not implemented"); + assert(E->getOpcode() == BO_EQ || E->getOpcode() == BO_NE); + mlir::Value lhs = CGF.emitScalarExpr(E->getLHS()); + mlir::Value rhs = CGF.emitScalarExpr(E->getRHS()); + cir::CmpOpKind kind = ClangCmpToCIRCmp(E->getOpcode()); + Result = + Builder.createCompare(CGF.getLoc(E->getExprLoc()), kind, lhs, rhs); } else if (!LHSTy->isAnyComplexType() && !RHSTy->isAnyComplexType()) { BinOpInfo BOInfo = emitBinOps(E); mlir::Value LHS = BOInfo.LHS; @@ -1741,8 +1746,11 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) { auto Ty = mlir::cast(CGF.convertType(DestTy)); return Builder.getNullDataMemberPtr(Ty, CGF.getLoc(E->getExprLoc())); } - case CK_ReinterpretMemberPointer: - llvm_unreachable("NYI"); + case CK_ReinterpretMemberPointer: { + mlir::Value src = Visit(E); + return Builder.createBitcast(CGF.getLoc(E->getExprLoc()), src, + CGF.convertType(DestTy)); + } case CK_BaseToDerivedMemberPointer: case CK_DerivedToBaseMemberPointer: { mlir::Value src = Visit(E); @@ -1875,8 +1883,12 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) { return emitPointerToBoolConversion(Visit(E), E->getType()); case CK_FloatingToBoolean: return emitFloatToBoolConversion(Visit(E), CGF.getLoc(E->getExprLoc())); - case CK_MemberPointerToBoolean: - llvm_unreachable("NYI"); + case CK_MemberPointerToBoolean: { + mlir::Value memPtr = Visit(E); + return Builder.createCast(CGF.getLoc(CE->getSourceRange()), + cir::CastKind::member_ptr_to_bool, memPtr, + CGF.convertType(DestTy)); + } case CK_FloatingComplexToReal: case CK_IntegralComplexToReal: case CK_FloatingComplexToBoolean: diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index a1eb11007261..edc5eda1f4a1 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -530,6 +530,11 @@ LogicalResult cir::CastOp::verify() { return success(); } + // Handle the data member pointer types. + if (mlir::isa(srcType) && + mlir::isa(resType)) + return success(); + // This is the only cast kind where we don't want vector types to decay // into the element type. if ((!mlir::isa(getSrc().getType()) || @@ -705,6 +710,13 @@ LogicalResult cir::CastOp::verify() { << "requires !cir.complex type for result"; return success(); } + case cir::CastKind::member_ptr_to_bool: { + if (!mlir::isa(srcType)) + return emitOpError() << "requires !cir.data_member type for source"; + if (!mlir::isa(resType)) + return emitOpError() << "requires !cir.bool type for result"; + return success(); + } } llvm_unreachable("Unknown CastOp kind?"); diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h index 830d5589fbe9..a1948059d783 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h @@ -97,6 +97,19 @@ class CIRCXXABI { virtual mlir::Value lowerDerivedDataMember(cir::DerivedDataMemberOp op, mlir::Value loweredSrc, mlir::OpBuilder &builder) const = 0; + + virtual mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs, + mlir::Value loweredRhs, + mlir::OpBuilder &builder) const = 0; + + virtual mlir::Value + lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy, + mlir::Value loweredSrc, + mlir::OpBuilder &builder) const = 0; + + virtual mlir::Value + lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc, + mlir::OpBuilder &builder) const = 0; }; /// Creates an Itanium-family ABI. diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp b/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp index f22eca2f15c6..f3569eca9e0a 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp @@ -73,6 +73,18 @@ class ItaniumCXXABI : public CIRCXXABI { mlir::Value lowerDerivedDataMember(cir::DerivedDataMemberOp op, mlir::Value loweredSrc, mlir::OpBuilder &builder) const override; + + mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs, + mlir::Value loweredRhs, + mlir::OpBuilder &builder) const override; + + mlir::Value lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy, + mlir::Value loweredSrc, + mlir::OpBuilder &builder) const override; + + mlir::Value + lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc, + mlir::OpBuilder &builder) const override; }; } // namespace @@ -89,18 +101,23 @@ bool ItaniumCXXABI::classifyReturnType(LowerFunctionInfo &FI) const { return false; } -mlir::Type ItaniumCXXABI::lowerDataMemberType( - cir::DataMemberType type, const mlir::TypeConverter &typeConverter) const { +static mlir::Type getABITypeForDataMember(LowerModule &lowerMod) { // Itanium C++ ABI 2.3: // A pointer to data member is an offset from the base address of // the class object containing it, represented as a ptrdiff_t - const clang::TargetInfo &target = LM.getTarget(); + const clang::TargetInfo &target = lowerMod.getTarget(); clang::TargetInfo::IntType ptrdiffTy = target.getPtrDiffType(clang::LangAS::Default); - return cir::IntType::get(type.getContext(), target.getTypeWidth(ptrdiffTy), + return cir::IntType::get(lowerMod.getMLIRContext(), + target.getTypeWidth(ptrdiffTy), target.isTypeSigned(ptrdiffTy)); } +mlir::Type ItaniumCXXABI::lowerDataMemberType( + cir::DataMemberType type, const mlir::TypeConverter &typeConverter) const { + return getABITypeForDataMember(LM); +} + mlir::TypedAttr ItaniumCXXABI::lowerDataMemberConstant( cir::DataMemberAttr attr, const mlir::DataLayout &layout, const mlir::TypeConverter &typeConverter) const { @@ -175,6 +192,33 @@ ItaniumCXXABI::lowerDerivedDataMember(cir::DerivedDataMemberOp op, /*isDerivedToBase=*/false, builder); } +mlir::Value ItaniumCXXABI::lowerDataMemberCmp(cir::CmpOp op, + mlir::Value loweredLhs, + mlir::Value loweredRhs, + mlir::OpBuilder &builder) const { + return builder.create(op.getLoc(), op.getKind(), loweredLhs, + loweredRhs); +} + +mlir::Value +ItaniumCXXABI::lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy, + mlir::Value loweredSrc, + mlir::OpBuilder &builder) const { + return builder.create(op.getLoc(), loweredDstTy, + cir::CastKind::bitcast, loweredSrc); +} + +mlir::Value +ItaniumCXXABI::lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc, + mlir::OpBuilder &builder) const { + // Itanium C++ ABI 2.3: + // A NULL pointer is represented as -1. + auto nullAttr = cir::IntAttr::get(getABITypeForDataMember(LM), -1); + auto nullValue = builder.create(op.getLoc(), nullAttr); + return builder.create(op.getLoc(), cir::CmpOpKind::ne, loweredSrc, + nullValue); +} + CIRCXXABI *CreateItaniumCXXABI(LowerModule &LM) { switch (LM.getCXXABIKind()) { // Note that AArch64 uses the generic ItaniumCXXABI class since it doesn't diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 8848ca0f1e7d..6c16732232e2 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -1299,8 +1299,18 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite( } case cir::CastKind::bitcast: { auto dstTy = castOp.getType(); - auto llvmSrcVal = adaptor.getOperands().front(); auto llvmDstTy = getTypeConverter()->convertType(dstTy); + + if (mlir::isa(castOp.getSrc().getType())) { + mlir::Value loweredResult = lowerMod->getCXXABI().lowerDataMemberBitcast( + castOp, llvmDstTy, src, rewriter); + rewriter.replaceOp(castOp, loweredResult); + return mlir::success(); + } + if (mlir::isa(castOp.getSrc().getType())) + llvm_unreachable("NYI"); + + auto llvmSrcVal = adaptor.getOperands().front(); rewriter.replaceOpWithNewOp(castOp, llvmDstTy, llvmSrcVal); return mlir::success(); @@ -1324,6 +1334,16 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite( llvmSrcVal); break; } + case cir::CastKind::member_ptr_to_bool: { + mlir::Value loweredResult; + if (mlir::isa(castOp.getSrc().getType())) + llvm_unreachable("NYI"); + else + loweredResult = lowerMod->getCXXABI().lowerDataMemberToBoolCast( + castOp, src, rewriter); + rewriter.replaceOp(castOp, loweredResult); + break; + } default: { return castOp.emitError("Unhandled cast kind: ") << castOp.getKindAttrName(); @@ -2902,6 +2922,14 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite( mlir::ConversionPatternRewriter &rewriter) const { auto type = cmpOp.getLhs().getType(); + if (mlir::isa(type)) { + assert(lowerMod && "lowering module is not available"); + mlir::Value loweredResult = lowerMod->getCXXABI().lowerDataMemberCmp( + cmpOp, adaptor.getLhs(), adaptor.getRhs(), rewriter); + rewriter.replaceOp(cmpOp, loweredResult); + return mlir::success(); + } + // Lower to LLVM comparison op. // if (auto intTy = mlir::dyn_cast(type)) { if (mlir::isa(type)) { @@ -4087,6 +4115,7 @@ void populateCIRToLLVMConversionPatterns( argsVarMap, patterns.getContext()); patterns.add< // clang-format off + CIRToLLVMCastOpLowering, CIRToLLVMLoadOpLowering, CIRToLLVMStoreOpLowering, CIRToLLVMGlobalOpLowering, @@ -4096,6 +4125,7 @@ void populateCIRToLLVMConversionPatterns( patterns.add< // clang-format off CIRToLLVMBaseDataMemberOpLowering, + CIRToLLVMCmpOpLowering, CIRToLLVMDerivedDataMemberOpLowering, CIRToLLVMGetRuntimeMemberOpLowering // clang-format on @@ -4103,7 +4133,6 @@ void populateCIRToLLVMConversionPatterns( patterns.add< // clang-format off CIRToLLVMPtrStrideOpLowering, - CIRToLLVMCastOpLowering, CIRToLLVMInlineAsmOpLowering // clang-format on >(converter, patterns.getContext(), dataLayout); @@ -4132,7 +4161,6 @@ void populateCIRToLLVMConversionPatterns( CIRToLLVMCallOpLowering, CIRToLLVMCatchParamOpLowering, CIRToLLVMClearCacheOpLowering, - CIRToLLVMCmpOpLowering, CIRToLLVMCmpThreeWayOpLowering, CIRToLLVMComplexCreateOpLowering, CIRToLLVMComplexImagOpLowering, diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index 264ae29a0e85..104ce3a0b105 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -232,6 +232,7 @@ class CIRToLLVMBrCondOpLowering }; class CIRToLLVMCastOpLowering : public mlir::OpConversionPattern { + cir::LowerModule *lowerMod; mlir::DataLayout const &dataLayout; mlir::Type convertTy(mlir::Type ty) const; @@ -239,9 +240,10 @@ class CIRToLLVMCastOpLowering : public mlir::OpConversionPattern { public: CIRToLLVMCastOpLowering(const mlir::TypeConverter &typeConverter, mlir::MLIRContext *context, + cir::LowerModule *lowerModule, mlir::DataLayout const &dataLayout) - : OpConversionPattern(typeConverter, context), dataLayout(dataLayout) {} - using mlir::OpConversionPattern::OpConversionPattern; + : OpConversionPattern(typeConverter, context), lowerMod(lowerModule), + dataLayout(dataLayout) {} mlir::LogicalResult matchAndRewrite(cir::CastOp op, OpAdaptor, @@ -649,8 +651,15 @@ class CIRToLLVMShiftOpLowering }; class CIRToLLVMCmpOpLowering : public mlir::OpConversionPattern { + cir::LowerModule *lowerMod; + public: - using mlir::OpConversionPattern::OpConversionPattern; + CIRToLLVMCmpOpLowering(const mlir::TypeConverter &typeConverter, + mlir::MLIRContext *context, + cir::LowerModule *lowerModule) + : OpConversionPattern(typeConverter, context), lowerMod(lowerModule) { + setHasBoundedRewriteRecursion(); + } mlir::LogicalResult matchAndRewrite(cir::CmpOp op, OpAdaptor, diff --git a/clang/test/CIR/CodeGen/pointer-to-data-member-cast.cpp b/clang/test/CIR/CodeGen/pointer-to-data-member-cast.cpp index 63625236e42a..51913c09af23 100644 --- a/clang/test/CIR/CodeGen/pointer-to-data-member-cast.cpp +++ b/clang/test/CIR/CodeGen/pointer-to-data-member-cast.cpp @@ -70,3 +70,29 @@ auto derived_to_base_zero_offset(int Derived::*ptr) -> int Base1::* { // LLVM-NEXT: %[[#ret:]] = load i64, ptr %[[#ret_slot]] // LLVM-NEXT: ret i64 %[[#ret]] } + +struct Foo { + int a; +}; + +struct Bar { + int a; +}; + +bool to_bool(int Foo::*x) { + return x; +} + +// CIR-LABEL: @_Z7to_boolM3Fooi +// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr>, !cir.data_member +// CIR-NEXT: %{{.+}} = cir.cast(member_ptr_to_bool, %[[#x]] : !cir.data_member), !cir.bool +// CIR: } + +auto bitcast(int Foo::*x) { + return reinterpret_cast(x); +} + +// CIR-LABEL: @_Z7bitcastM3Fooi +// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr>, !cir.data_member +// CIR-NEXT: %{{.+}} = cir.cast(bitcast, %[[#x]] : !cir.data_member), !cir.data_member +// CIR: } diff --git a/clang/test/CIR/CodeGen/pointer-to-data-member-cmp.cpp b/clang/test/CIR/CodeGen/pointer-to-data-member-cmp.cpp new file mode 100644 index 000000000000..ebcf141de32b --- /dev/null +++ b/clang/test/CIR/CodeGen/pointer-to-data-member-cmp.cpp @@ -0,0 +1,44 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++17 -fclangir -emit-cir %s -o %t.cir +// RUN: FileCheck --input-file=%t.cir --check-prefix=CIR %s +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++17 -fclangir -emit-llvm %s -o %t.ll +// RUN: FileCheck --input-file=%t.ll --check-prefix=LLVM %s + +struct Foo { + int a; +}; + +struct Bar { + int a; +}; + +bool eq(int Foo::*x, int Foo::*y) { + return x == y; +} + +// CIR-LABEL: @_Z2eqM3FooiS0_ +// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr>, !cir.data_member +// CIR-NEXT: %[[#y:]] = cir.load %{{.+}} : !cir.ptr>, !cir.data_member +// CIR-NEXT: %{{.+}} = cir.cmp(eq, %[[#x]], %[[#y]]) : !cir.data_member, !cir.bool +// CIR: } + +// LLVM-LABEL: @_Z2eqM3FooiS0_ +// LLVM: %[[#x:]] = load i64, ptr %{{.+}}, align 8 +// LLVM-NEXT: %[[#y:]] = load i64, ptr %{{.+}}, align 8 +// LLVM-NEXT: %{{.+}} = icmp eq i64 %[[#x]], %[[#y]] +// LLVM: } + +bool ne(int Foo::*x, int Foo::*y) { + return x != y; +} + +// CIR-LABEL: @_Z2neM3FooiS0_ +// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr>, !cir.data_member +// CIR-NEXT: %[[#y:]] = cir.load %{{.+}} : !cir.ptr>, !cir.data_member +// CIR-NEXT: %{{.+}} = cir.cmp(ne, %[[#x]], %[[#y]]) : !cir.data_member, !cir.bool +// CIR: } + +// LLVM-LABEL: @_Z2neM3FooiS0_ +// LLVM: %[[#x:]] = load i64, ptr %{{.+}}, align 8 +// LLVM-NEXT: %[[#y:]] = load i64, ptr %{{.+}}, align 8 +// LLVM-NEXT: %{{.+}} = icmp ne i64 %[[#x]], %[[#y]] +// LLVM: }