Skip to content

Commit a761e26

Browse files
authored
[RISCV] Allow non-loop invariant steps in RISCVGatherScatterLowering (#122244)
The motivation for this is to allow us to match strided accesses that are emitted from the loop vectorizer with EVL tail folding (see #122232) In these loops the step isn't loop invariant and is based off of @llvm.experimental.get.vector.length. We can relax this as long as we make sure to construct the updates after the definition inside the loop, instead of the preheader. I presume the restriction was previously added so that the step would dominate the insertion point in the preheader. I can't think of why it wouldn't be safe to calculate it in the loop otherwise.
1 parent a5bd01e commit a761e26

File tree

3 files changed

+288
-30
lines changed

3 files changed

+288
-30
lines changed

llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp

+17-6
Original file line numberDiff line numberDiff line change
@@ -211,10 +211,6 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
211211
assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
212212
"Expected one operand of phi to be Inc");
213213

214-
// Only proceed if the step is loop invariant.
215-
if (!L->isLoopInvariant(Step))
216-
return false;
217-
218214
// Step should be a splat.
219215
Step = getSplatValue(Step);
220216
if (!Step)
@@ -298,6 +294,7 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
298294
BasePtr->getIncomingBlock(StartBlock)->getTerminator());
299295
Builder.SetCurrentDebugLocation(DebugLoc());
300296

297+
// TODO: Share this switch with matchStridedStart?
301298
switch (BO->getOpcode()) {
302299
default:
303300
llvm_unreachable("Unexpected opcode!");
@@ -310,18 +307,32 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
310307
}
311308
case Instruction::Mul: {
312309
Start = Builder.CreateMul(Start, SplatOp, "start");
313-
Step = Builder.CreateMul(Step, SplatOp, "step");
314310
Stride = Builder.CreateMul(Stride, SplatOp, "stride");
315311
break;
316312
}
317313
case Instruction::Shl: {
318314
Start = Builder.CreateShl(Start, SplatOp, "start");
319-
Step = Builder.CreateShl(Step, SplatOp, "step");
320315
Stride = Builder.CreateShl(Stride, SplatOp, "stride");
321316
break;
322317
}
323318
}
324319

320+
// If the Step was defined inside the loop, adjust it before its definition
321+
// instead of in the preheader.
322+
if (auto *StepI = dyn_cast<Instruction>(Step); StepI && L->contains(StepI))
323+
Builder.SetInsertPoint(*StepI->getInsertionPointAfterDef());
324+
325+
switch (BO->getOpcode()) {
326+
default:
327+
break;
328+
case Instruction::Mul:
329+
Step = Builder.CreateMul(Step, SplatOp, "step");
330+
break;
331+
case Instruction::Shl:
332+
Step = Builder.CreateShl(Step, SplatOp, "step");
333+
break;
334+
}
335+
325336
Inc->setOperand(StepIndex, Step);
326337
BasePtr->setIncomingValue(StartBlock, Start);
327338
return true;

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-store.ll

+2-2
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,8 @@ for.cond.cleanup: ; preds = %vector.body
320320
define void @gather_unknown_pow2(ptr noalias nocapture %A, ptr noalias nocapture readonly %B, i64 %shift) {
321321
; CHECK-LABEL: @gather_unknown_pow2(
322322
; CHECK-NEXT: entry:
323-
; CHECK-NEXT: [[STEP:%.*]] = shl i64 8, [[SHIFT:%.*]]
324-
; CHECK-NEXT: [[STRIDE:%.*]] = shl i64 1, [[SHIFT]]
323+
; CHECK-NEXT: [[STRIDE:%.*]] = shl i64 1, [[SHIFT:%.*]]
324+
; CHECK-NEXT: [[STEP:%.*]] = shl i64 8, [[SHIFT]]
325325
; CHECK-NEXT: [[TMP0:%.*]] = mul i64 [[STRIDE]], 4
326326
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
327327
; CHECK: vector.body:

0 commit comments

Comments
 (0)