Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1513,17 +1513,41 @@ mlir::Value ScalarExprEmitter::emitSub(const BinOpInfo &Ops) {
if (!mlir::isa<cir::PointerType>(Ops.RHS.getType()))
return emitPointerArithmetic(CGF, Ops, /*isSubtraction=*/true);

// Otherwise, this is a pointer subtraction
const BinaryOperator *Expr = cast<BinaryOperator>(Ops.E);
QualType ElementType = Expr->getLHS()->getType()->getPointeeType();

// Do the raw subtraction part.
//
// Check if this is a VLA pointee type.
if (const auto *VLA = CGF.getContext().getAsVariableArrayType(ElementType)) {
llvm_unreachable("NYI: CIR ptrdiff on VLA pointee");
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also need a missing else { ... }, we try to keep the skeleton close to make it easier to spot differences against classic codegen, it also needs to keep the address space cast dance in some other function to make it easier to see that difference.

// TODO(cir): note for LLVM lowering out of this; when expanding this into
// LLVM we shall take VLA's, division by element size, etc.
// now we just ensure that all pointers are on the proper address space.
mlir::Value LHS = Ops.LHS;
mlir::Value RHS = Ops.RHS;
mlir::Location loc = CGF.getLoc(Ops.Loc);

cir::PointerType LHSPtrTy = mlir::dyn_cast<cir::PointerType>(LHS.getType());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

names should be camelBack, pre-existing ones you can leave as is, at some point they'd be migrated, but new ones should fit the coding style

cir::PointerType RHSPtrTy = mlir::dyn_cast<cir::PointerType>(RHS.getType());

if (LHSPtrTy && RHSPtrTy) {
auto LHSAS = LHSPtrTy.getAddrSpace();
auto RHSAS = RHSPtrTy.getAddrSpace();

if (LHSAS != RHSAS) {
// Different address spaces → use addrspacecast
RHS = Builder.createAddrSpaceCast(RHS, LHSPtrTy);
} else if (LHSPtrTy != RHSPtrTy) {
// Same addrspace but different pointee/type → bitcast is fine
RHS = Builder.createBitcast(RHS, LHSPtrTy);
}
}
// Do the raw subtraction part.
//
// See more in `EmitSub` in CGExprScalar.cpp.
assert(!cir::MissingFeatures::llvmLoweringPtrDiffConsidersPointee());
return cir::PtrDiffOp::create(Builder, CGF.getLoc(Ops.Loc), CGF.PtrDiffTy,
Ops.LHS, Ops.RHS);
LHS, RHS);
}

mlir::Value ScalarExprEmitter::emitShl(const BinOpInfo &Ops) {
Expand Down
60 changes: 60 additions & 0 deletions clang/test/CIR/CodeGen/HIP/ptr-diff.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#include "cuda.h"

// RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip -fclangir \
// RUN: -fcuda-is-device -fhip-new-launch-api \
// RUN: -I%S/../Inputs/ -emit-cir %s -o %t.ll
// RUN: FileCheck --check-prefix=CIR-DEVICE --input-file=%t.ll %s

// RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip -fclangir \
// RUN: -fcuda-is-device -fhip-new-launch-api \
// RUN: -I%S/../Inputs/ -emit-llvm %s -o %t.ll
// RUN: FileCheck --check-prefix=LLVM-DEVICE --input-file=%t.ll %s

// RUN: %clang_cc1 -triple=amdgcn-amd-amdhsa -x hip \
// RUN: -fcuda-is-device -fhip-new-launch-api \
// RUN: -I%S/../Inputs/ -emit-llvm %s -o %t.ll
// RUN: FileCheck --check-prefix=OGCG-DEVICE --input-file=%t.ll %s

__device__ int ptr_diff() {
const char c_str[] = "c-string";
const char* len = c_str;
return c_str - len;
}


// CIR-DEVICE: %[[#LenLocalAddr:]] = cir.alloca !cir.ptr<!s8i>, !cir.ptr<!cir.ptr<!s8i>>, ["len", init]
// CIR-DEVICE: %[[#GlobalPtr:]] = cir.get_global @_ZZ8ptr_diffvE5c_str : !cir.ptr<!cir.array<!s8i x 9>, addrspace(offload_constant)>
// CIR-DEVICE: %[[#CastDecay:]] = cir.cast array_to_ptrdecay %[[#GlobalPtr]] : !cir.ptr<!cir.array<!s8i x 9>, addrspace(offload_constant)>
// CIR-DEVICE: %[[#LenLocalAddrCast:]] = cir.cast bitcast %[[#LenLocalAddr]] : !cir.ptr<!cir.ptr<!s8i>> -> !cir.ptr<!cir.ptr<!s8i, addrspace(offload_constant)>>
// CIR-DEVICE: cir.store align(8) %[[#CastDecay]], %[[#LenLocalAddrCast]] : !cir.ptr<!s8i, addrspace(offload_constant)>, !cir.ptr<!cir.ptr<!s8i, addrspace(offload_constant)>>
// CIR-DEVICE: %[[#CStr:]] = cir.cast array_to_ptrdecay %[[#GlobalPtr]] : !cir.ptr<!cir.array<!s8i x 9>, addrspace(offload_constant)> -> !cir.ptr<!s8i, addrspace(offload_constant)>
// CIR-DEVICE: %[[#LoadedLenAddr:]] = cir.load align(8) %[[#LenLocalAddr]] : !cir.ptr<!cir.ptr<!s8i>>, !cir.ptr<!s8i> loc(#loc7)
// CIR-DEVICE: %[[#AddrCast:]] = cir.cast address_space %[[#LoadedLenAddr]] : !cir.ptr<!s8i> -> !cir.ptr<!s8i, addrspace(offload_constant)>
// CIR-DEVICE: %[[#DIFF:]] = cir.ptr_diff %[[#CStr]], %[[#AddrCast]] : !cir.ptr<!s8i, addrspace(offload_constant)>

// LLVM-DEVICE: define dso_local i32 @_Z8ptr_diffv()
// LLVM-DEVICE: %[[#GlobalPtrAddr:]] = alloca i32, i64 1, align 4, addrspace(5)
// LLVM-DEVICE: %[[#GlobalPtrCast:]] = addrspacecast ptr addrspace(5) %[[#GlobalPtrAddr]] to ptr
// LLVM-DEVICE: %[[#LenLocalAddr:]] = alloca ptr, i64 1, align 8, addrspace(5)
// LLVM-DEVICE: %[[#LenLocalAddrCast:]] = addrspacecast ptr addrspace(5) %[[#LenLocalAddr]] to ptr
// LLVM-DEVICE: store ptr addrspace(4) @_ZZ8ptr_diffvE5c_str, ptr %[[#LenLocalAddrCast]], align 8
// LLVM-DEVICE: %[[#LoadedAddr:]] = load ptr, ptr %[[#LenLocalAddrCast]], align 8
// LLVM-DEVICE: %[[#CastedVal:]] = addrspacecast ptr %[[#LoadedAddr]] to ptr addrspace(4)
// LLVM-DEVICE: %[[#IntVal:]] = ptrtoint ptr addrspace(4) %[[#CastedVal]] to i64
// LLVM-DEVICE: %[[#SubVal:]] = sub i64 ptrtoint (ptr addrspace(4) @_ZZ8ptr_diffvE5c_str to i64), %[[#IntVal]]

// OGCG-DEVICE: define dso_local noundef i32 @_Z8ptr_diffv() #0
// OGCG-DEVICE: %[[RETVAL:.*]] = alloca i32, align 4, addrspace(5)
// OGCG-DEVICE: %[[C_STR:.*]] = alloca [9 x i8], align 1, addrspace(5)
// OGCG-DEVICE: %[[LEN:.*]] = alloca ptr, align 8, addrspace(5)
// OGCG-DEVICE: %[[RETVAL_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[RETVAL]] to ptr
// OGCG-DEVICE: %[[C_STR_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[C_STR]] to ptr
// OGCG-DEVICE: %[[LEN_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[LEN]] to ptr
// OGCG-DEVICE: %[[ARRAYDECAY:.*]] = getelementptr inbounds [9 x i8], ptr %[[C_STR_ASCAST]], i64 0, i64 0
// OGCG-DEVICE: store ptr %[[ARRAYDECAY]], ptr %[[LEN_ASCAST]], align 8
// OGCG-DEVICE: %[[ARRAYDECAY1:.*]] = getelementptr inbounds [9 x i8], ptr %[[C_STR_ASCAST]], i64 0, i64 0
// OGCG-DEVICE: %[[LOADED:.*]] = load ptr, ptr %[[LEN_ASCAST]], align 8
// OGCG-DEVICE: %[[LHS:.*]] = ptrtoint ptr %[[ARRAYDECAY1]] to i64
// OGCG-DEVICE: %[[RHS:.*]] = ptrtoint ptr %[[LOADED]] to i64
// OGCG-DEVICE: %[[SUB:.*]] = sub i64 %[[LHS]], %[[RHS]]
// OGCG-DEVICE: %[[CONV:.*]] = trunc i64 %[[SUB]] to i32
Loading