Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR] Data member pointer comparison and casts #1268

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def CK_FloatComplexToIntegralComplex
def CK_IntegralComplexCast : I32EnumAttrCase<"int_complex", 23>;
def CK_IntegralComplexToFloatComplex
: I32EnumAttrCase<"int_complex_to_float_complex", 24>;
def CK_MemberPtrToBoolean : I32EnumAttrCase<"member_ptr_to_bool", 25>;

def CastKind : I32EnumAttr<
"CastKind",
Expand All @@ -135,7 +136,7 @@ def CastKind : I32EnumAttr<
CK_FloatComplexToReal, CK_IntegralComplexToReal, CK_FloatComplexToBoolean,
CK_IntegralComplexToBoolean, CK_FloatComplexCast,
CK_FloatComplexToIntegralComplex, CK_IntegralComplexCast,
CK_IntegralComplexToFloatComplex]> {
CK_IntegralComplexToFloatComplex, CK_MemberPtrToBoolean]> {
let cppNamespace = "::cir";
}

Expand Down
22 changes: 17 additions & 5 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,12 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
};

if (const MemberPointerType *MPT = LHSTy->getAs<MemberPointerType>()) {
assert(0 && "not implemented");
assert(E->getOpcode() == BO_EQ || E->getOpcode() == BO_NE);
mlir::Value lhs = CGF.emitScalarExpr(E->getLHS());
mlir::Value rhs = CGF.emitScalarExpr(E->getRHS());
cir::CmpOpKind kind = ClangCmpToCIRCmp(E->getOpcode());
Result =
Builder.createCompare(CGF.getLoc(E->getExprLoc()), kind, lhs, rhs);
} else if (!LHSTy->isAnyComplexType() && !RHSTy->isAnyComplexType()) {
BinOpInfo BOInfo = emitBinOps(E);
mlir::Value LHS = BOInfo.LHS;
Expand Down Expand Up @@ -1741,8 +1746,11 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
auto Ty = mlir::cast<cir::DataMemberType>(CGF.convertType(DestTy));
return Builder.getNullDataMemberPtr(Ty, CGF.getLoc(E->getExprLoc()));
}
case CK_ReinterpretMemberPointer:
llvm_unreachable("NYI");
case CK_ReinterpretMemberPointer: {
mlir::Value src = Visit(E);
return Builder.createBitcast(CGF.getLoc(E->getExprLoc()), src,
CGF.convertType(DestTy));
}
case CK_BaseToDerivedMemberPointer:
case CK_DerivedToBaseMemberPointer: {
mlir::Value src = Visit(E);
Expand Down Expand Up @@ -1875,8 +1883,12 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
return emitPointerToBoolConversion(Visit(E), E->getType());
case CK_FloatingToBoolean:
return emitFloatToBoolConversion(Visit(E), CGF.getLoc(E->getExprLoc()));
case CK_MemberPointerToBoolean:
llvm_unreachable("NYI");
case CK_MemberPointerToBoolean: {
mlir::Value memPtr = Visit(E);
return Builder.createCast(CGF.getLoc(CE->getSourceRange()),
cir::CastKind::member_ptr_to_bool, memPtr,
CGF.convertType(DestTy));
}
case CK_FloatingComplexToReal:
case CK_IntegralComplexToReal:
case CK_FloatingComplexToBoolean:
Expand Down
12 changes: 12 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,11 @@ LogicalResult cir::CastOp::verify() {
return success();
}

// Handle the data member pointer types.
if (mlir::isa<cir::DataMemberType>(srcType) &&
mlir::isa<cir::DataMemberType>(resType))
return success();

// This is the only cast kind where we don't want vector types to decay
// into the element type.
if ((!mlir::isa<cir::VectorType>(getSrc().getType()) ||
Expand Down Expand Up @@ -705,6 +710,13 @@ LogicalResult cir::CastOp::verify() {
<< "requires !cir.complex<!cir.float> type for result";
return success();
}
case cir::CastKind::member_ptr_to_bool: {
if (!mlir::isa<cir::DataMemberType>(srcType))
return emitOpError() << "requires !cir.data_member type for source";
if (!mlir::isa<cir::BoolType>(resType))
return emitOpError() << "requires !cir.bool type for result";
return success();
}
}

llvm_unreachable("Unknown CastOp kind?");
Expand Down
13 changes: 13 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,19 @@ class CIRCXXABI {
virtual mlir::Value
lowerDerivedDataMember(cir::DerivedDataMemberOp op, mlir::Value loweredSrc,
mlir::OpBuilder &builder) const = 0;

virtual mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs,
mlir::Value loweredRhs,
mlir::OpBuilder &builder) const = 0;

virtual mlir::Value
lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const = 0;

virtual mlir::Value
lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc,
mlir::OpBuilder &builder) const = 0;
};

/// Creates an Itanium-family ABI.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@ class ItaniumCXXABI : public CIRCXXABI {
mlir::Value lowerDerivedDataMember(cir::DerivedDataMemberOp op,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const override;

mlir::Value lowerDataMemberCmp(cir::CmpOp op, mlir::Value loweredLhs,
mlir::Value loweredRhs,
mlir::OpBuilder &builder) const override;

mlir::Value lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const override;

mlir::Value
lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc,
mlir::OpBuilder &builder) const override;
};

} // namespace
Expand All @@ -89,18 +101,23 @@ bool ItaniumCXXABI::classifyReturnType(LowerFunctionInfo &FI) const {
return false;
}

mlir::Type ItaniumCXXABI::lowerDataMemberType(
cir::DataMemberType type, const mlir::TypeConverter &typeConverter) const {
static mlir::Type getABITypeForDataMember(LowerModule &lowerMod) {
// Itanium C++ ABI 2.3:
// A pointer to data member is an offset from the base address of
// the class object containing it, represented as a ptrdiff_t
const clang::TargetInfo &target = LM.getTarget();
const clang::TargetInfo &target = lowerMod.getTarget();
clang::TargetInfo::IntType ptrdiffTy =
target.getPtrDiffType(clang::LangAS::Default);
return cir::IntType::get(type.getContext(), target.getTypeWidth(ptrdiffTy),
return cir::IntType::get(lowerMod.getMLIRContext(),
target.getTypeWidth(ptrdiffTy),
target.isTypeSigned(ptrdiffTy));
}

mlir::Type ItaniumCXXABI::lowerDataMemberType(
cir::DataMemberType type, const mlir::TypeConverter &typeConverter) const {
return getABITypeForDataMember(LM);
}

mlir::TypedAttr ItaniumCXXABI::lowerDataMemberConstant(
cir::DataMemberAttr attr, const mlir::DataLayout &layout,
const mlir::TypeConverter &typeConverter) const {
Expand Down Expand Up @@ -175,6 +192,33 @@ ItaniumCXXABI::lowerDerivedDataMember(cir::DerivedDataMemberOp op,
/*isDerivedToBase=*/false, builder);
}

mlir::Value ItaniumCXXABI::lowerDataMemberCmp(cir::CmpOp op,
mlir::Value loweredLhs,
mlir::Value loweredRhs,
mlir::OpBuilder &builder) const {
return builder.create<cir::CmpOp>(op.getLoc(), op.getKind(), loweredLhs,
loweredRhs);
}

mlir::Value
ItaniumCXXABI::lowerDataMemberBitcast(cir::CastOp op, mlir::Type loweredDstTy,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const {
return builder.create<cir::CastOp>(op.getLoc(), loweredDstTy,
cir::CastKind::bitcast, loweredSrc);
}

mlir::Value
ItaniumCXXABI::lowerDataMemberToBoolCast(cir::CastOp op, mlir::Value loweredSrc,
mlir::OpBuilder &builder) const {
// Itanium C++ ABI 2.3:
// A NULL pointer is represented as -1.
auto nullAttr = cir::IntAttr::get(getABITypeForDataMember(LM), -1);
auto nullValue = builder.create<cir::ConstantOp>(op.getLoc(), nullAttr);
return builder.create<cir::CmpOp>(op.getLoc(), cir::CmpOpKind::ne, loweredSrc,
nullValue);
}

CIRCXXABI *CreateItaniumCXXABI(LowerModule &LM) {
switch (LM.getCXXABIKind()) {
// Note that AArch64 uses the generic ItaniumCXXABI class since it doesn't
Expand Down
34 changes: 31 additions & 3 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1299,8 +1299,18 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
}
case cir::CastKind::bitcast: {
auto dstTy = castOp.getType();
auto llvmSrcVal = adaptor.getOperands().front();
auto llvmDstTy = getTypeConverter()->convertType(dstTy);

if (mlir::isa<cir::DataMemberType>(castOp.getSrc().getType())) {
mlir::Value loweredResult = lowerMod->getCXXABI().lowerDataMemberBitcast(
castOp, llvmDstTy, src, rewriter);
rewriter.replaceOp(castOp, loweredResult);
return mlir::success();
}
if (mlir::isa<cir::MethodType>(castOp.getSrc().getType()))
llvm_unreachable("NYI");

auto llvmSrcVal = adaptor.getOperands().front();
rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(castOp, llvmDstTy,
llvmSrcVal);
return mlir::success();
Expand All @@ -1324,6 +1334,16 @@ mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
llvmSrcVal);
break;
}
case cir::CastKind::member_ptr_to_bool: {
mlir::Value loweredResult;
if (mlir::isa<cir::MethodType>(castOp.getSrc().getType()))
llvm_unreachable("NYI");
else
loweredResult = lowerMod->getCXXABI().lowerDataMemberToBoolCast(
castOp, src, rewriter);
rewriter.replaceOp(castOp, loweredResult);
break;
}
default: {
return castOp.emitError("Unhandled cast kind: ")
<< castOp.getKindAttrName();
Expand Down Expand Up @@ -2902,6 +2922,14 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
mlir::ConversionPatternRewriter &rewriter) const {
auto type = cmpOp.getLhs().getType();

if (mlir::isa<cir::DataMemberType>(type)) {
assert(lowerMod && "lowering module is not available");
mlir::Value loweredResult = lowerMod->getCXXABI().lowerDataMemberCmp(
cmpOp, adaptor.getLhs(), adaptor.getRhs(), rewriter);
rewriter.replaceOp(cmpOp, loweredResult);
return mlir::success();
}

// Lower to LLVM comparison op.
// if (auto intTy = mlir::dyn_cast<cir::IntType>(type)) {
if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) {
Expand Down Expand Up @@ -4087,6 +4115,7 @@ void populateCIRToLLVMConversionPatterns(
argsVarMap, patterns.getContext());
patterns.add<
// clang-format off
CIRToLLVMCastOpLowering,
CIRToLLVMLoadOpLowering,
CIRToLLVMStoreOpLowering,
CIRToLLVMGlobalOpLowering,
Expand All @@ -4096,14 +4125,14 @@ void populateCIRToLLVMConversionPatterns(
patterns.add<
// clang-format off
CIRToLLVMBaseDataMemberOpLowering,
CIRToLLVMCmpOpLowering,
CIRToLLVMDerivedDataMemberOpLowering,
CIRToLLVMGetRuntimeMemberOpLowering
// clang-format on
>(converter, patterns.getContext(), lowerModule);
patterns.add<
// clang-format off
CIRToLLVMPtrStrideOpLowering,
CIRToLLVMCastOpLowering,
CIRToLLVMInlineAsmOpLowering
// clang-format on
>(converter, patterns.getContext(), dataLayout);
Expand Down Expand Up @@ -4132,7 +4161,6 @@ void populateCIRToLLVMConversionPatterns(
CIRToLLVMCallOpLowering,
CIRToLLVMCatchParamOpLowering,
CIRToLLVMClearCacheOpLowering,
CIRToLLVMCmpOpLowering,
CIRToLLVMCmpThreeWayOpLowering,
CIRToLLVMComplexCreateOpLowering,
CIRToLLVMComplexImagOpLowering,
Expand Down
15 changes: 12 additions & 3 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,16 +232,18 @@ class CIRToLLVMBrCondOpLowering
};

class CIRToLLVMCastOpLowering : public mlir::OpConversionPattern<cir::CastOp> {
cir::LowerModule *lowerMod;
mlir::DataLayout const &dataLayout;

mlir::Type convertTy(mlir::Type ty) const;

public:
CIRToLLVMCastOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
cir::LowerModule *lowerModule,
mlir::DataLayout const &dataLayout)
: OpConversionPattern(typeConverter, context), dataLayout(dataLayout) {}
using mlir::OpConversionPattern<cir::CastOp>::OpConversionPattern;
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule),
dataLayout(dataLayout) {}

mlir::LogicalResult
matchAndRewrite(cir::CastOp op, OpAdaptor,
Expand Down Expand Up @@ -649,8 +651,15 @@ class CIRToLLVMShiftOpLowering
};

class CIRToLLVMCmpOpLowering : public mlir::OpConversionPattern<cir::CmpOp> {
cir::LowerModule *lowerMod;

public:
using mlir::OpConversionPattern<cir::CmpOp>::OpConversionPattern;
CIRToLLVMCmpOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
cir::LowerModule *lowerModule)
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {
setHasBoundedRewriteRecursion();
}

mlir::LogicalResult
matchAndRewrite(cir::CmpOp op, OpAdaptor,
Expand Down
26 changes: 26 additions & 0 deletions clang/test/CIR/CodeGen/pointer-to-data-member-cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,29 @@ auto derived_to_base_zero_offset(int Derived::*ptr) -> int Base1::* {
// LLVM-NEXT: %[[#ret:]] = load i64, ptr %[[#ret_slot]]
// LLVM-NEXT: ret i64 %[[#ret]]
}

struct Foo {
int a;
};

struct Bar {
int a;
};

bool to_bool(int Foo::*x) {
return x;
}

// CIR-LABEL: @_Z7to_boolM3Fooi
// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
// CIR-NEXT: %{{.+}} = cir.cast(member_ptr_to_bool, %[[#x]] : !cir.data_member<!s32i in !ty_Foo>), !cir.bool
// CIR: }

auto bitcast(int Foo::*x) {
return reinterpret_cast<int Bar::*>(x);
}

// CIR-LABEL: @_Z7bitcastM3Fooi
// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
// CIR-NEXT: %{{.+}} = cir.cast(bitcast, %[[#x]] : !cir.data_member<!s32i in !ty_Foo>), !cir.data_member<!s32i in !ty_Bar>
// CIR: }
44 changes: 44 additions & 0 deletions clang/test/CIR/CodeGen/pointer-to-data-member-cmp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++17 -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir --check-prefix=CIR %s
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++17 -fclangir -emit-llvm %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll --check-prefix=LLVM %s

struct Foo {
int a;
};

struct Bar {
int a;
};

bool eq(int Foo::*x, int Foo::*y) {
return x == y;
}

// CIR-LABEL: @_Z2eqM3FooiS0_
// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
// CIR-NEXT: %[[#y:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
// CIR-NEXT: %{{.+}} = cir.cmp(eq, %[[#x]], %[[#y]]) : !cir.data_member<!s32i in !ty_Foo>, !cir.bool
// CIR: }

// LLVM-LABEL: @_Z2eqM3FooiS0_
// LLVM: %[[#x:]] = load i64, ptr %{{.+}}, align 8
// LLVM-NEXT: %[[#y:]] = load i64, ptr %{{.+}}, align 8
// LLVM-NEXT: %{{.+}} = icmp eq i64 %[[#x]], %[[#y]]
// LLVM: }

bool ne(int Foo::*x, int Foo::*y) {
return x != y;
}

// CIR-LABEL: @_Z2neM3FooiS0_
// CIR: %[[#x:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
// CIR-NEXT: %[[#y:]] = cir.load %{{.+}} : !cir.ptr<!cir.data_member<!s32i in !ty_Foo>>, !cir.data_member<!s32i in !ty_Foo>
// CIR-NEXT: %{{.+}} = cir.cmp(ne, %[[#x]], %[[#y]]) : !cir.data_member<!s32i in !ty_Foo>, !cir.bool
// CIR: }

// LLVM-LABEL: @_Z2neM3FooiS0_
// LLVM: %[[#x:]] = load i64, ptr %{{.+}}, align 8
// LLVM-NEXT: %[[#y:]] = load i64, ptr %{{.+}}, align 8
// LLVM-NEXT: %{{.+}} = icmp ne i64 %[[#x]], %[[#y]]
// LLVM: }
Loading