[SelectionDAG][X86] Split via Concat <n x T> vector types for atomic load #120640

Open
wants to merge 1 commit into
base: users/jofrn/spr/main/2894ccd1
1 change: 1 addition & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -960,6 +960,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
  void SplitVecRes_FPOp_MultiType(SDNode *N, SDValue &Lo, SDValue &Hi);
  void SplitVecRes_IS_FPCLASS(SDNode *N, SDValue &Lo, SDValue &Hi);
  void SplitVecRes_INSERT_VECTOR_ELT(SDNode *N, SDValue &Lo, SDValue &Hi);
  void SplitVecRes_ATOMIC_LOAD(AtomicSDNode *LD);
  void SplitVecRes_LOAD(LoadSDNode *LD, SDValue &Lo, SDValue &Hi);
  void SplitVecRes_VP_LOAD(VPLoadSDNode *LD, SDValue &Lo, SDValue &Hi);
  void SplitVecRes_VP_STRIDED_LOAD(VPStridedLoadSDNode *SLD, SDValue &Lo,
33 changes: 33 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1173,6 +1173,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
    SplitVecRes_STEP_VECTOR(N, Lo, Hi);
    break;
  case ISD::SIGN_EXTEND_INREG: SplitVecRes_InregOp(N, Lo, Hi); break;
  case ISD::ATOMIC_LOAD:
    SplitVecRes_ATOMIC_LOAD(cast<AtomicSDNode>(N));
    break;
  case ISD::LOAD:
    SplitVecRes_LOAD(cast<LoadSDNode>(N), Lo, Hi);
    break;
@@ -1423,6 +1426,36 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
    SetSplitVector(SDValue(N, ResNo), Lo, Hi);
}

void DAGTypeLegalizer::SplitVecRes_ATOMIC_LOAD(AtomicSDNode *LD) {
Contributor

This is missing the Lo / Hi out arguments that all of the other SplitVecRes cases take. You should still be trying to respect the result of DAG.GetSplitDestVTs. With this you are bypassing SetSplitVector.

Contributor

Bump

  SDLoc dl(LD);

  EVT MemoryVT = LD->getMemoryVT();
  unsigned NumElts = MemoryVT.getVectorMinNumElements();

  EVT IntMemoryVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16, NumElts);
Contributor

This shouldn't be assuming the type. It should still follow along with SplitVecRes_LOAD by using GetSplitDestVTs.

Contributor Author

This is where the type is coerced so that we load integers. It didn't seem appropriate to split into Lo and Hi here since atomics are different. GetSplitDestVTs will return floats when given a float, but we do not want to load from those (so as to reuse infrastructure), and we want to load the full size at once (not with a split size).
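
To make the idea in this comment concrete outside of LLVM, here is a minimal standalone sketch, assuming a little-endian target such as x86. It reads all 32 bits of a two-lane, 16-bit-element vector with one acquire load on an integer, and only afterwards pulls out the lanes. The names `loadAllLanes` and `extractLane` are hypothetical and do not exist in LLVM; the lanes stay as raw `uint16_t` bit patterns because reinterpreting them as half or bfloat is a pure bitcast.

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>

// Hypothetical helper: one acquire load covers every lane, mirroring the
// single integer-typed atomic load described in the comment above.
inline std::uint32_t loadAllLanes(const std::atomic<std::uint32_t> &Mem) {
  return Mem.load(std::memory_order_acquire);
}

// Hypothetical helper: lane I of a 2 x 16-bit vector packed into 32 bits.
// On a little-endian target, lane 0 occupies the low 16 bits.
inline std::uint16_t extractLane(std::uint32_t Bits, unsigned I) {
  assert(I < 2 && "a 32-bit value holds two 16-bit lanes");
  return static_cast<std::uint16_t>(Bits >> (16 * I));
}
```

The point of the sketch is only that no lane is read from memory separately: atomicity comes from the single wide load, and the per-lane work happens on a value already in a register.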

  EVT ElemVT =
      EVT::getVectorVT(*DAG.getContext(), MemoryVT.getVectorElementType(), 1);

  // Create a single atomic to load all the elements at once.
  SDValue Atomic =
      DAG.getAtomicLoad(ISD::NON_EXTLOAD, dl, IntMemoryVT, IntMemoryVT,
                        LD->getChain(), LD->getBasePtr(),
                        LD->getMemOperand());

  // Instead of splitting, put all the elements back into a vector.
  SmallVector<SDValue, 4> Ops;
Contributor

This loop can be replaced with DAG.ExtractVectorElements (but I don't think you should be scalarizing here like this).

  for (unsigned i = 0; i < NumElts; ++i) {
    SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i16, Atomic,
                              DAG.getVectorIdxConstant(i, dl));
    Elt = DAG.getBitcast(ElemVT, Elt);
    Ops.push_back(Elt);
  }
  SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, dl, MemoryVT, Ops);
Contributor

I don't think using CONCAT_VECTORS with 2 scale elements is valid. This should be setting the Lo and Hi fields like the other SplitVecRes_* functions do anyway, rather than doing manual replacement. You will still need to replace the load chain manually here, though.

Contributor Author

@jofrn, Jan 22, 2025

It is valid like so:

t21: v2i16,ch = AtomicLoad<(load acquire (s32) from %ir.x)> t0, t2
      t23: i16 = extract_vector_elt t21, Constant:i64<0>
    t24: v1bf16 = bitcast t23
      t26: i16 = extract_vector_elt t21, Constant:i64<1>
    t27: v1bf16 = bitcast t26
  t28: v2bf16 = concat_vectors t24, t27
t4: v2f16 = bitcast t28

In this way, if t4 is consumed by another concat_vectors, then they can all be combined via EltsFromConsecutiveLoads into a BUILD_VECTOR.

Contributor

I'd consider this a bug; I would expect this to assert in getNode.

Contributor Author

@jofrn, Jan 29, 2025

I think we may be looking at different things. What do you mean by a 2 scale element being used? The operands to concat_vectors are v1bf16. The result is a v2bf16.
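
The typing rule being discussed here can be pictured with a toy model, independent of LLVM. In this sketch every operand has the same vector type (N lanes of T) and the result has N * M lanes of T, so two one-lane operands give a two-lane result. `concatVectors` is a hypothetical name, and the 16-bit values stand in for raw bfloat bit patterns.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Toy model of a concat over M operands that all share the type
// "N lanes of T"; the result type is "N * M lanes of T".
template <typename T, std::size_t N, std::size_t M>
std::array<T, N * M>
concatVectors(const std::array<std::array<T, N>, M> &Ops) {
  std::array<T, N * M> Result{};
  for (std::size_t Op = 0; Op < M; ++Op)
    for (std::size_t Lane = 0; Lane < N; ++Lane)
      Result[Op * N + Lane] = Ops[Op][Lane]; // operand order is lane order
  return Result;
}
```

Passing scalars instead of one-lane vectors does not type-check in this model, which is the distinction the two comments above appear to be circling.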


  ReplaceValueWith(SDValue(LD, 0), Concat);
Contributor

This replacement should be done piecewise in the caller. You should be inserting casting code to satisfy the two DAG.GetSplitDestVTs pieces.
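
As a rough picture of the shape requested here, the standalone sketch below keeps a single wide atomic load but hands the result back as two halves through Lo / Hi out-parameters, leaving it to the caller to register them. It is a model in plain C++, not LLVM code: `Half2Bits` and `splitAtomicLoad4x16` are hypothetical names, a little-endian lane order is assumed, and lanes are raw 16-bit patterns.

```cpp
#include <array>
#include <atomic>
#include <cassert>
#include <cstdint>

// Two 16-bit lanes, kept as raw bit patterns (hypothetical alias).
using Half2Bits = std::array<std::uint16_t, 2>;

// Hypothetical model of a 4 x 16-bit atomic load that is split in two:
// the memory access is still one 64-bit acquire load, and the halves are
// returned through Lo and Hi like the other split routines do.
inline void splitAtomicLoad4x16(const std::atomic<std::uint64_t> &Mem,
                                Half2Bits &Lo, Half2Bits &Hi) {
  const std::uint64_t Bits = Mem.load(std::memory_order_acquire);
  auto Lane = [Bits](unsigned I) {
    assert(I < 4 && "a 64-bit value holds four 16-bit lanes");
    return static_cast<std::uint16_t>(Bits >> (16 * I));
  };
  Lo = {Lane(0), Lane(1)}; // lower half of the vector
  Hi = {Lane(2), Lane(3)}; // upper half of the vector
}
```

This keeps the property the author wants (the full width is read at once) while matching the calling convention the reviewer describes; whether that maps cleanly onto the legalizer is exactly what the thread is debating.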

  ReplaceValueWith(SDValue(LD, 1), LD->getChain());
}

void DAGTypeLegalizer::IncrementPointer(MemSDNode *N, EVT MemVT,
                                        MachinePointerInfo &MPI, SDValue &Ptr,
                                        uint64_t *ScaledOffset) {
171 changes: 171 additions & 0 deletions llvm/test/CodeGen/X86/atomic-load-store.ll
@@ -204,6 +204,76 @@ define <2 x float> @atomic_vec2_float_align(ptr %x) {
  ret <2 x float> %ret
}

define <2 x half> @atomic_vec2_half(ptr %x) {
; CHECK3-LABEL: atomic_vec2_half:
; CHECK3: ## %bb.0:
; CHECK3-NEXT: movl (%rdi), %eax
; CHECK3-NEXT: movd %eax, %xmm1
; CHECK3-NEXT: shrl $16, %eax
; CHECK3-NEXT: pinsrw $0, %eax, %xmm2
; CHECK3-NEXT: movdqa {{.*#+}} xmm0 = [65535,0,65535,65535,65535,65535,65535,65535]
; CHECK3-NEXT: pand %xmm0, %xmm1
; CHECK3-NEXT: pslld $16, %xmm2
; CHECK3-NEXT: pandn %xmm2, %xmm0
; CHECK3-NEXT: por %xmm1, %xmm0
; CHECK3-NEXT: retq
;
; CHECK0-LABEL: atomic_vec2_half:
; CHECK0: ## %bb.0:
; CHECK0-NEXT: movl (%rdi), %eax
; CHECK0-NEXT: movl %eax, %ecx
; CHECK0-NEXT: shrl $16, %ecx
; CHECK0-NEXT: movw %cx, %dx
; CHECK0-NEXT: ## implicit-def: $ecx
; CHECK0-NEXT: movw %dx, %cx
; CHECK0-NEXT: ## implicit-def: $xmm2
; CHECK0-NEXT: pinsrw $0, %ecx, %xmm2
; CHECK0-NEXT: movd %eax, %xmm0
; CHECK0-NEXT: movaps {{.*#+}} xmm1 = [65535,0,65535,65535,65535,65535,65535,65535]
; CHECK0-NEXT: pand %xmm1, %xmm0
; CHECK0-NEXT: pslld $16, %xmm2
; CHECK0-NEXT: pandn %xmm2, %xmm1
; CHECK0-NEXT: por %xmm1, %xmm0
; CHECK0-NEXT: retq
  %ret = load atomic <2 x half>, ptr %x acquire, align 4
  ret <2 x half> %ret
}

define <2 x bfloat> @atomic_vec2_bfloat(ptr %x) {
; CHECK3-LABEL: atomic_vec2_bfloat:
; CHECK3: ## %bb.0:
; CHECK3-NEXT: movl (%rdi), %eax
; CHECK3-NEXT: movd %eax, %xmm1
; CHECK3-NEXT: shrl $16, %eax
; CHECK3-NEXT: movdqa {{.*#+}} xmm0 = [65535,0,65535,65535,65535,65535,65535,65535]
; CHECK3-NEXT: pand %xmm0, %xmm1
; CHECK3-NEXT: pinsrw $0, %eax, %xmm2
; CHECK3-NEXT: pslld $16, %xmm2
; CHECK3-NEXT: pandn %xmm2, %xmm0
; CHECK3-NEXT: por %xmm1, %xmm0
; CHECK3-NEXT: retq
;
; CHECK0-LABEL: atomic_vec2_bfloat:
; CHECK0: ## %bb.0:
; CHECK0-NEXT: movl (%rdi), %eax
; CHECK0-NEXT: movl %eax, %ecx
; CHECK0-NEXT: shrl $16, %ecx
; CHECK0-NEXT: ## kill: def $cx killed $cx killed $ecx
; CHECK0-NEXT: movd %eax, %xmm0
; CHECK0-NEXT: movaps {{.*#+}} xmm1 = [65535,0,65535,65535,65535,65535,65535,65535]
; CHECK0-NEXT: pand %xmm1, %xmm0
; CHECK0-NEXT: ## implicit-def: $eax
; CHECK0-NEXT: movw %cx, %ax
; CHECK0-NEXT: ## implicit-def: $xmm2
; CHECK0-NEXT: pinsrw $0, %eax, %xmm2
; CHECK0-NEXT: pslld $16, %xmm2
; CHECK0-NEXT: pandn %xmm2, %xmm1
; CHECK0-NEXT: por %xmm1, %xmm0
; CHECK0-NEXT: retq
  %ret = load atomic <2 x bfloat>, ptr %x acquire, align 4
  ret <2 x bfloat> %ret
}

define <1 x ptr> @atomic_vec1_ptr(ptr %x) nounwind {
; CHECK3-LABEL: atomic_vec1_ptr:
; CHECK3: ## %bb.0:
@@ -376,6 +446,107 @@ define <4 x i16> @atomic_vec4_i16(ptr %x) nounwind {
  ret <4 x i16> %ret
}

define <4 x half> @atomic_vec4_half(ptr %x) nounwind {
; CHECK3-LABEL: atomic_vec4_half:
; CHECK3: ## %bb.0:
; CHECK3-NEXT: movq (%rdi), %rax
; CHECK3-NEXT: movl %eax, %ecx
; CHECK3-NEXT: shrl $16, %ecx
; CHECK3-NEXT: pinsrw $0, %ecx, %xmm1
; CHECK3-NEXT: movq %rax, %rcx
; CHECK3-NEXT: shrq $32, %rcx
; CHECK3-NEXT: pinsrw $0, %ecx, %xmm2
; CHECK3-NEXT: movq %rax, %xmm0
; CHECK3-NEXT: shrq $48, %rax
; CHECK3-NEXT: pinsrw $0, %eax, %xmm3
; CHECK3-NEXT: punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm3[0],xmm2[1],xmm3[1],xmm2[2],xmm3[2],xmm2[3],xmm3[3]
; CHECK3-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
; CHECK3-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1]
; CHECK3-NEXT: retq
;
; CHECK0-LABEL: atomic_vec4_half:
; CHECK0: ## %bb.0:
; CHECK0-NEXT: movq (%rdi), %rax
; CHECK0-NEXT: movl %eax, %ecx
; CHECK0-NEXT: shrl $16, %ecx
; CHECK0-NEXT: movw %cx, %dx
; CHECK0-NEXT: ## implicit-def: $ecx
; CHECK0-NEXT: movw %dx, %cx
; CHECK0-NEXT: ## implicit-def: $xmm2
; CHECK0-NEXT: pinsrw $0, %ecx, %xmm2
; CHECK0-NEXT: movq %rax, %rcx
; CHECK0-NEXT: shrq $32, %rcx
; CHECK0-NEXT: movw %cx, %dx
; CHECK0-NEXT: ## implicit-def: $ecx
; CHECK0-NEXT: movw %dx, %cx
; CHECK0-NEXT: ## implicit-def: $xmm1
; CHECK0-NEXT: pinsrw $0, %ecx, %xmm1
; CHECK0-NEXT: movq %rax, %rcx
; CHECK0-NEXT: shrq $48, %rcx
; CHECK0-NEXT: movw %cx, %dx
; CHECK0-NEXT: ## implicit-def: $ecx
; CHECK0-NEXT: movw %dx, %cx
; CHECK0-NEXT: ## implicit-def: $xmm3
; CHECK0-NEXT: pinsrw $0, %ecx, %xmm3
; CHECK0-NEXT: movq %rax, %xmm0
; CHECK0-NEXT: punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm3[0],xmm1[1],xmm3[1],xmm1[2],xmm3[2],xmm1[3],xmm3[3]
; CHECK0-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1],xmm0[2],xmm2[2],xmm0[3],xmm2[3]
; CHECK0-NEXT: unpcklps {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
; CHECK0-NEXT: retq
  %ret = load atomic <4 x half>, ptr %x acquire, align 8
  ret <4 x half> %ret
}

define <4 x bfloat> @atomic_vec4_bfloat(ptr %x) nounwind {
; CHECK3-LABEL: atomic_vec4_bfloat:
; CHECK3: ## %bb.0:
; CHECK3-NEXT: movq (%rdi), %rax
; CHECK3-NEXT: movq %rax, %xmm0
; CHECK3-NEXT: movl %eax, %ecx
; CHECK3-NEXT: shrl $16, %ecx
; CHECK3-NEXT: movq %rax, %rdx
; CHECK3-NEXT: shrq $32, %rdx
; CHECK3-NEXT: shrq $48, %rax
; CHECK3-NEXT: pinsrw $0, %eax, %xmm1
; CHECK3-NEXT: pinsrw $0, %edx, %xmm2
; CHECK3-NEXT: punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm1[0],xmm2[1],xmm1[1],xmm2[2],xmm1[2],xmm2[3],xmm1[3]
; CHECK3-NEXT: pinsrw $0, %ecx, %xmm1
; CHECK3-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3]
; CHECK3-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1]
; CHECK3-NEXT: retq
;
; CHECK0-LABEL: atomic_vec4_bfloat:
; CHECK0: ## %bb.0:
; CHECK0-NEXT: movq (%rdi), %rax
; CHECK0-NEXT: movq %rax, %xmm0
; CHECK0-NEXT: movl %eax, %ecx
; CHECK0-NEXT: shrl $16, %ecx
; CHECK0-NEXT: ## kill: def $cx killed $cx killed $ecx
; CHECK0-NEXT: movq %rax, %rdx
; CHECK0-NEXT: shrq $32, %rdx
; CHECK0-NEXT: ## kill: def $dx killed $dx killed $rdx
; CHECK0-NEXT: shrq $48, %rax
; CHECK0-NEXT: movw %ax, %si
; CHECK0-NEXT: ## implicit-def: $eax
; CHECK0-NEXT: movw %si, %ax
; CHECK0-NEXT: ## implicit-def: $xmm2
; CHECK0-NEXT: pinsrw $0, %eax, %xmm2
; CHECK0-NEXT: ## implicit-def: $eax
; CHECK0-NEXT: movw %dx, %ax
; CHECK0-NEXT: ## implicit-def: $xmm1
; CHECK0-NEXT: pinsrw $0, %eax, %xmm1
; CHECK0-NEXT: punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm2[0],xmm1[1],xmm2[1],xmm1[2],xmm2[2],xmm1[3],xmm2[3]
; CHECK0-NEXT: ## implicit-def: $eax
; CHECK0-NEXT: movw %cx, %ax
; CHECK0-NEXT: ## implicit-def: $xmm2
; CHECK0-NEXT: pinsrw $0, %eax, %xmm2
; CHECK0-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1],xmm0[2],xmm2[2],xmm0[3],xmm2[3]
; CHECK0-NEXT: unpcklps {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
; CHECK0-NEXT: retq
  %ret = load atomic <4 x bfloat>, ptr %x acquire, align 8
  ret <4 x bfloat> %ret
}

define <4 x float> @atomic_vec4_float_align(ptr %x) nounwind {
; CHECK-LABEL: atomic_vec4_float_align:
; CHECK: ## %bb.0: