Skip to content

Commit 1ff5f32

Browse files
joaosaffranjoaosaffran
and
joaosaffran
authored
[DXIL] Add support for root signature flag element in DXContainer (#123147)
Adding support for Root Signature Flags Element extraction and writing to DXContainer. - Adding an analysis to deal with RootSignature metadata definition - Adding validation for Flag - writing RootSignature blob into DXIL Closes: [126632](#126632) --------- Co-authored-by: joaosaffran <[email protected]>
1 parent e3cab30 commit 1ff5f32

24 files changed

+563
-53
lines changed

llvm/include/llvm/BinaryFormat/DXContainer.h

+3-13
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
#define LLVM_BINARYFORMAT_DXCONTAINER_H
1515

1616
#include "llvm/ADT/StringRef.h"
17-
#include "llvm/Support/BinaryStreamError.h"
18-
#include "llvm/Support/Error.h"
1917
#include "llvm/Support/SwapByteOrder.h"
2018
#include "llvm/TargetParser/Triple.h"
2119

@@ -550,18 +548,10 @@ static_assert(sizeof(ProgramSignatureElement) == 32,
550548

551549
struct RootSignatureValidations {
552550

553-
static Expected<uint32_t> validateRootFlag(uint32_t Flags) {
554-
if ((Flags & ~0x80000fff) != 0)
555-
return llvm::make_error<BinaryStreamError>("Invalid Root Signature flag");
556-
return Flags;
557-
}
558-
559-
static Expected<uint32_t> validateVersion(uint32_t Version) {
560-
if (Version == 1 || Version == 2)
561-
return Version;
551+
static bool isValidRootFlag(uint32_t Flags) { return (Flags & ~0xfff) == 0; }
562552

563-
return llvm::make_error<BinaryStreamError>(
564-
"Invalid Root Signature Version");
553+
static bool isValidVersion(uint32_t Version) {
554+
return (Version == 1 || Version == 2);
565555
}
566556
};
567557

llvm/include/llvm/MC/DXContainerRootSignature.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@ namespace llvm {
1414
class raw_ostream;
1515

1616
namespace mcdxbc {
17-
struct RootSignatureHeader {
17+
struct RootSignatureDesc {
1818
uint32_t Version = 2;
1919
uint32_t NumParameters = 0;
2020
uint32_t RootParametersOffset = 0;
2121
uint32_t NumStaticSamplers = 0;
2222
uint32_t StaticSamplersOffset = 0;
2323
uint32_t Flags = 0;
2424

25-
void write(raw_ostream &OS);
25+
void write(raw_ostream &OS) const;
2626
};
2727
} // namespace mcdxbc
2828
} // namespace llvm

llvm/include/llvm/ObjectYAML/DXContainerYAML.h

+6-6
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ struct ShaderHash {
7474
};
7575

7676
#define ROOT_ELEMENT_FLAG(Num, Val) bool Val = false;
77-
struct RootSignatureDesc {
78-
RootSignatureDesc() = default;
79-
RootSignatureDesc(const object::DirectX::RootSignature &Data);
77+
struct RootSignatureYamlDesc {
78+
RootSignatureYamlDesc() = default;
79+
RootSignatureYamlDesc(const object::DirectX::RootSignature &Data);
8080

8181
uint32_t Version;
8282
uint32_t NumParameters;
@@ -176,7 +176,7 @@ struct Part {
176176
std::optional<ShaderHash> Hash;
177177
std::optional<PSVInfo> Info;
178178
std::optional<DXContainerYAML::Signature> Signature;
179-
std::optional<DXContainerYAML::RootSignatureDesc> RootSignature;
179+
std::optional<DXContainerYAML::RootSignatureYamlDesc> RootSignature;
180180
};
181181

182182
struct Object {
@@ -259,9 +259,9 @@ template <> struct MappingTraits<DXContainerYAML::Signature> {
259259
static void mapping(IO &IO, llvm::DXContainerYAML::Signature &El);
260260
};
261261

262-
template <> struct MappingTraits<DXContainerYAML::RootSignatureDesc> {
262+
template <> struct MappingTraits<DXContainerYAML::RootSignatureYamlDesc> {
263263
static void mapping(IO &IO,
264-
DXContainerYAML::RootSignatureDesc &RootSignature);
264+
DXContainerYAML::RootSignatureYamlDesc &RootSignature);
265265
};
266266

267267
} // namespace yaml

llvm/lib/MC/DXContainerRootSignature.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
using namespace llvm;
1313
using namespace llvm::mcdxbc;
1414

15-
void RootSignatureHeader::write(raw_ostream &OS) {
15+
void RootSignatureDesc::write(raw_ostream &OS) const {
1616

1717
support::endian::write(OS, Version, llvm::endianness::little);
1818
support::endian::write(OS, NumParameters, llvm::endianness::little);

llvm/lib/Object/DXContainer.cpp

+12-10
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ static Error parseFailed(const Twine &Msg) {
2020
return make_error<GenericBinaryError>(Msg.str(), object_error::parse_failed);
2121
}
2222

23+
static Error validationFailed(const Twine &Msg) {
24+
return make_error<StringError>(Msg.str(), inconvertibleErrorCode());
25+
}
26+
2327
template <typename T>
2428
static Error readStruct(StringRef Buffer, const char *Src, T &Struct) {
2529
// Don't read before the beginning or past the end of the file
@@ -254,11 +258,10 @@ Error DirectX::RootSignature::parse(StringRef Data) {
254258
support::endian::read<uint32_t, llvm::endianness::little>(Current);
255259
Current += sizeof(uint32_t);
256260

257-
Expected<uint32_t> MaybeVersion =
258-
dxbc::RootSignatureValidations::validateVersion(VValue);
259-
if (Error E = MaybeVersion.takeError())
260-
return E;
261-
Version = MaybeVersion.get();
261+
if (!dxbc::RootSignatureValidations::isValidVersion(VValue))
262+
return validationFailed("unsupported root signature version read: " +
263+
llvm::Twine(VValue));
264+
Version = VValue;
262265

263266
NumParameters =
264267
support::endian::read<uint32_t, llvm::endianness::little>(Current);
@@ -280,11 +283,10 @@ Error DirectX::RootSignature::parse(StringRef Data) {
280283
support::endian::read<uint32_t, llvm::endianness::little>(Current);
281284
Current += sizeof(uint32_t);
282285

283-
Expected<uint32_t> MaybeFlag =
284-
dxbc::RootSignatureValidations::validateRootFlag(FValue);
285-
if (Error E = MaybeFlag.takeError())
286-
return E;
287-
Flags = MaybeFlag.get();
286+
if (!dxbc::RootSignatureValidations::isValidRootFlag(FValue))
287+
return validationFailed("unsupported root signature flag value read: " +
288+
llvm::Twine(FValue));
289+
Flags = FValue;
288290

289291
return Error::success();
290292
}

llvm/lib/ObjectYAML/DXContainerEmitter.cpp

+9-9
Original file line numberDiff line numberDiff line change
@@ -266,15 +266,15 @@ void DXContainerWriter::writeParts(raw_ostream &OS) {
266266
if (!P.RootSignature.has_value())
267267
continue;
268268

269-
mcdxbc::RootSignatureHeader Header;
270-
Header.Flags = P.RootSignature->getEncodedFlags();
271-
Header.Version = P.RootSignature->Version;
272-
Header.NumParameters = P.RootSignature->NumParameters;
273-
Header.RootParametersOffset = P.RootSignature->RootParametersOffset;
274-
Header.NumStaticSamplers = P.RootSignature->NumStaticSamplers;
275-
Header.StaticSamplersOffset = P.RootSignature->StaticSamplersOffset;
276-
277-
Header.write(OS);
269+
mcdxbc::RootSignatureDesc RS;
270+
RS.Flags = P.RootSignature->getEncodedFlags();
271+
RS.Version = P.RootSignature->Version;
272+
RS.NumParameters = P.RootSignature->NumParameters;
273+
RS.RootParametersOffset = P.RootSignature->RootParametersOffset;
274+
RS.NumStaticSamplers = P.RootSignature->NumStaticSamplers;
275+
RS.StaticSamplersOffset = P.RootSignature->StaticSamplersOffset;
276+
277+
RS.write(OS);
278278
break;
279279
}
280280
uint64_t BytesWritten = OS.tell() - DataStart;

llvm/lib/ObjectYAML/DXContainerYAML.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ DXContainerYAML::ShaderFeatureFlags::ShaderFeatureFlags(uint64_t FlagData) {
2929
#include "llvm/BinaryFormat/DXContainerConstants.def"
3030
}
3131

32-
DXContainerYAML::RootSignatureDesc::RootSignatureDesc(
32+
DXContainerYAML::RootSignatureYamlDesc::RootSignatureYamlDesc(
3333
const object::DirectX::RootSignature &Data)
3434
: Version(Data.getVersion()), NumParameters(Data.getNumParameters()),
3535
RootParametersOffset(Data.getRootParametersOffset()),
@@ -41,7 +41,7 @@ DXContainerYAML::RootSignatureDesc::RootSignatureDesc(
4141
#include "llvm/BinaryFormat/DXContainerConstants.def"
4242
}
4343

44-
uint32_t DXContainerYAML::RootSignatureDesc::getEncodedFlags() {
44+
uint32_t DXContainerYAML::RootSignatureYamlDesc::getEncodedFlags() {
4545
uint64_t Flag = 0;
4646
#define ROOT_ELEMENT_FLAG(Num, Val) \
4747
if (Val) \
@@ -209,8 +209,8 @@ void MappingTraits<DXContainerYAML::Signature>::mapping(
209209
IO.mapRequired("Parameters", S.Parameters);
210210
}
211211

212-
void MappingTraits<DXContainerYAML::RootSignatureDesc>::mapping(
213-
IO &IO, DXContainerYAML::RootSignatureDesc &S) {
212+
void MappingTraits<DXContainerYAML::RootSignatureYamlDesc>::mapping(
213+
IO &IO, DXContainerYAML::RootSignatureYamlDesc &S) {
214214
IO.mapRequired("Version", S.Version);
215215
IO.mapRequired("NumParameters", S.NumParameters);
216216
IO.mapRequired("RootParametersOffset", S.RootParametersOffset);

llvm/lib/Target/DirectX/CMakeLists.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ add_llvm_target(DirectXCodeGen
3333
DXILResourceAccess.cpp
3434
DXILShaderFlags.cpp
3535
DXILTranslateMetadata.cpp
36-
36+
DXILRootSignature.cpp
37+
3738
LINK_COMPONENTS
3839
Analysis
3940
AsmPrinter

llvm/lib/Target/DirectX/DXContainerGlobals.cpp

+36
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
#include "DXILRootSignature.h"
1314
#include "DXILShaderFlags.h"
1415
#include "DirectX.h"
1516
#include "llvm/ADT/SmallVector.h"
@@ -25,7 +26,9 @@
2526
#include "llvm/MC/DXContainerPSVInfo.h"
2627
#include "llvm/Pass.h"
2728
#include "llvm/Support/MD5.h"
29+
#include "llvm/TargetParser/Triple.h"
2830
#include "llvm/Transforms/Utils/ModuleUtils.h"
31+
#include <optional>
2932

3033
using namespace llvm;
3134
using namespace llvm::dxil;
@@ -41,6 +44,7 @@ class DXContainerGlobals : public llvm::ModulePass {
4144
GlobalVariable *buildSignature(Module &M, Signature &Sig, StringRef Name,
4245
StringRef SectionName);
4346
void addSignature(Module &M, SmallVector<GlobalValue *> &Globals);
47+
void addRootSignature(Module &M, SmallVector<GlobalValue *> &Globals);
4448
void addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV);
4549
void addPipelineStateValidationInfo(Module &M,
4650
SmallVector<GlobalValue *> &Globals);
@@ -60,6 +64,7 @@ class DXContainerGlobals : public llvm::ModulePass {
6064
void getAnalysisUsage(AnalysisUsage &AU) const override {
6165
AU.setPreservesAll();
6266
AU.addRequired<ShaderFlagsAnalysisWrapper>();
67+
AU.addRequired<RootSignatureAnalysisWrapper>();
6368
AU.addRequired<DXILMetadataAnalysisWrapperPass>();
6469
AU.addRequired<DXILResourceTypeWrapperPass>();
6570
AU.addRequired<DXILResourceBindingWrapperPass>();
@@ -73,6 +78,7 @@ bool DXContainerGlobals::runOnModule(Module &M) {
7378
Globals.push_back(getFeatureFlags(M));
7479
Globals.push_back(computeShaderHash(M));
7580
addSignature(M, Globals);
81+
addRootSignature(M, Globals);
7682
addPipelineStateValidationInfo(M, Globals);
7783
appendToCompilerUsed(M, Globals);
7884
return true;
@@ -144,6 +150,36 @@ void DXContainerGlobals::addSignature(Module &M,
144150
Globals.emplace_back(buildSignature(M, OutputSig, "dx.osg1", "OSG1"));
145151
}
146152

153+
void DXContainerGlobals::addRootSignature(Module &M,
154+
SmallVector<GlobalValue *> &Globals) {
155+
156+
dxil::ModuleMetadataInfo &MMI =
157+
getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
158+
159+
// Root Signature in Library don't compile to DXContainer.
160+
if (MMI.ShaderProfile == llvm::Triple::Library)
161+
return;
162+
163+
assert(MMI.EntryPropertyVec.size() == 1);
164+
165+
auto &RSA = getAnalysis<RootSignatureAnalysisWrapper>();
166+
const Function *EntryFunction = MMI.EntryPropertyVec[0].Entry;
167+
const auto &FuncRs = RSA.find(EntryFunction);
168+
169+
if (FuncRs == RSA.end())
170+
return;
171+
172+
const RootSignatureDesc &RS = FuncRs->second;
173+
SmallString<256> Data;
174+
raw_svector_ostream OS(Data);
175+
176+
RS.write(OS);
177+
178+
Constant *Constant =
179+
ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false);
180+
Globals.emplace_back(buildContainerGlobal(M, Constant, "dx.rts0", "RTS0"));
181+
}
182+
147183
void DXContainerGlobals::addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV) {
148184
const DXILBindingMap &DBM =
149185
getAnalysis<DXILResourceBindingWrapperPass>().getBindingMap();

0 commit comments

Comments
 (0)