Skip to content

Commit 8b160fb

Browse files
jaladreipsigcbot
authored andcommitted
Improve struct support in EmitPass
* Emit LifetimeStart on insertvalue to improve codegen. * Implement emitting select for structs * Improve struct symbol elision on insertvalue
1 parent 9a6a542 commit 8b160fb

File tree

4 files changed

+87
-2
lines changed

4 files changed

+87
-2
lines changed

IGC/Compiler/CISACodeGen/CShader.cpp

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2118,7 +2118,8 @@ CVariable* CShader::GetStructVariable(llvm::Value* v)
21182118
isa<InsertValueInst>(v) ||
21192119
isa<CallInst>(v) ||
21202120
isa<Argument>(v) ||
2121-
isa<PHINode>(v),
2121+
isa<PHINode>(v) ||
2122+
isa<SelectInst>(v),
21222123
"Invalid instruction using struct type!");
21232124

21242125
if (isa<InsertValueInst>(v))
@@ -2176,6 +2177,29 @@ CVariable* CShader::GetStructVariable(llvm::Value* v)
21762177
}
21772178
}
21782179
}
2180+
else if (auto* SI = dyn_cast<SelectInst>(v))
2181+
{
2182+
if (IGC_IS_FLAG_ENABLED(EnableDeSSA) && m_deSSA)
2183+
{
2184+
e_alignment pAlign = EALIGN_GRF;
2185+
Value* rVal = m_deSSA->getRootValue(v, &pAlign);
2186+
2187+
// If a struct type is coalesced with another non-struct type,
2188+
// need to call createAliasIfNeeded(). Otherwise, all coalesced
2189+
// structs are of the same type (Byte).
2190+
if (rVal && !rVal->getType()->isStructTy()) {
2191+
CVariable* rootV = GetSymbol(rVal);
2192+
return createAliasIfNeeded(v, rootV);
2193+
}
2194+
2195+
v = rVal ? rVal : v;
2196+
auto it = symbolMapping.find(v);
2197+
if (it != symbolMapping.end())
2198+
{
2199+
return it->second;
2200+
}
2201+
}
2202+
}
21792203
else if (isa<CallInst>(v) || isa<Argument>(v))
21802204
{
21812205
// For now, special handling of bitcasttostruct intrinsic

IGC/Compiler/CISACodeGen/EmitVISAPass.cpp

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2993,6 +2993,7 @@ void EmitPass::EmitInsertValueToStruct(InsertValueInst* inst)
29932993
CVariable* SrcV = GetSymbol(src0);
29942994
if (DstV != SrcV)
29952995
{
2996+
m_encoder->Lifetime(VISAVarLifetime::LIFETIME_START, DstV);
29962997
// Only copy SrcV's elements that have been initialized already.
29972998
// For the example, DstV = %13, SrcV = %9;
29982999
// 'toBeCopied' will be field {0, 1}, not {0,1,2,3}.
@@ -3278,6 +3279,60 @@ void EmitPass::EmitExtractValueFromLayoutStruct(ExtractValueInst* EVI)
32783279
}
32793280
}
32803281

3282+
void EmitPass::EmitSelectStruct(SelectInst* SI)
3283+
{
3284+
StructType* sTy = dyn_cast<StructType>(SI->getType());
3285+
IGC_ASSERT_MESSAGE(sTy, "This method is only for structs");
3286+
3287+
const uint32_t nLanes = numLanes(m_currShader->m_SIMDSize);
3288+
3289+
auto* srcTrueV = SI->getTrueValue();
3290+
auto* srcFalseV = SI->getFalseValue();
3291+
3292+
auto* srcTrueS = GetSymbol(srcTrueV);
3293+
auto* srcFalseS = GetSymbol(srcFalseV);
3294+
auto* destS = GetSymbol(SI);
3295+
auto* cond = GetSymbol(SI->getCondition());
3296+
3297+
if (srcTrueS != destS && srcFalseS != destS)
3298+
// if we aren't coalescing with any of the operands, emit lifetime start for the struct
3299+
m_encoder->Lifetime(VISAVarLifetime::LIFETIME_START, destS);
3300+
3301+
auto iterator = {
3302+
std::make_tuple(srcTrueV, srcTrueS, false),
3303+
std::make_tuple(srcFalseV, srcFalseS, true)
3304+
};
3305+
3306+
for (auto [srcV, srcS, inv] : iterator)
3307+
{
3308+
// For now, do not support uniform dst and non-uniform src
3309+
IGC_ASSERT_MESSAGE(!srcS->IsUniform() || srcS->IsUniform() == destS->IsUniform(),
3310+
"Can't select non-uniform value into a uniform struct!");
3311+
3312+
if (srcS == destS)
3313+
continue;
3314+
3315+
SmallVector<std::vector<unsigned>> toBeCopied;
3316+
getAllDefinedMembers(srcV, toBeCopied);
3317+
for (const auto& II : toBeCopied)
3318+
{
3319+
3320+
Type* ty;
3321+
uint32_t byteOffset;
3322+
getStructMemberByteOffsetAndType_1(m_DL, sTy, II, ty, byteOffset);
3323+
3324+
uint32_t d_off =
3325+
byteOffset * (destS->IsUniform() ? 1 : nLanes);
3326+
uint32_t s_off =
3327+
byteOffset * (srcS->IsUniform() ? 1 : nLanes);
3328+
3329+
m_encoder->SetPredicate(cond);
3330+
m_encoder->SetInversePredicate(inv);
3331+
emitMayUnalignedVectorCopy(destS, d_off, srcS, s_off, ty);
3332+
}
3333+
}
3334+
}
3335+
32813336
void EmitPass::EmitAddPair(GenIntrinsicInst* GII, const SSource Sources[4], const DstModifier& DstMod) {
32823337
auto [L, H] = getPairOutput(GII);
32833338
CVariable* Lo = L ? GetSymbol(L) : nullptr;
@@ -5135,7 +5190,6 @@ void EmitPass::Select(const SSource sources[3], const DstModifier& modifier)
51355190

51365191
m_encoder->Select(flag, m_destination, src0, src1);
51375192
m_encoder->Push();
5138-
51395193
}
51405194

51415195
void EmitPass::PredAdd(const SSource& pred, bool invert, const SSource sources[2], const DstModifier& modifier)
@@ -11706,6 +11760,9 @@ void EmitPass::EmitNoModifier(llvm::Instruction* inst)
1170611760
case Instruction::ExtractValue:
1170711761
EmitExtractValueFromStruct(cast<ExtractValueInst>(inst));
1170811762
break;
11763+
case Instruction::Select:
11764+
EmitSelectStruct(cast<SelectInst>(inst));
11765+
break;
1170911766
case Instruction::Unreachable:
1171011767
break;
1171111768
default:

IGC/Compiler/CISACodeGen/EmitVISAPass.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ class EmitPass : public llvm::FunctionPass
158158
void EmitExtractValueFromStruct(llvm::ExtractValueInst* EI);
159159
void EmitInsertValueToLayoutStruct(llvm::InsertValueInst* IVI);
160160
void EmitExtractValueFromLayoutStruct(llvm::ExtractValueInst* EVI);
161+
void EmitSelectStruct(llvm::SelectInst* SI);
161162
void emitVectorCopyToAOS(uint32_t AOSBytes,
162163
CVariable* Dst, CVariable* Src, uint32_t nElts,
163164
uint32_t DstSubRegOffset = 0, uint32_t SrcSubRegOffset = 0) {

IGC/Compiler/CISACodeGen/PatternMatchPass.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3811,6 +3811,9 @@ namespace IGC
38113811

38123812
bool CodeGenPatternMatch::MatchSelectModifier(llvm::SelectInst& I)
38133813
{
3814+
if (I.getType()->isAggregateType())
3815+
return MatchSingleInstruction(I);
3816+
38143817
struct SelectPattern : Pattern
38153818
{
38163819
SSource sources[3];

0 commit comments

Comments
 (0)