Skip to content

Commit 1195122

Browse files
authored
[SM6.10] Implement FillMatrix Builtin (#8186)
- Implements the FillMatrix Builtin - Add support to map LinAlgMatrix types to the overload type (`dx.types.LinAlgMatrixC10M16N16U0S1` -> `mC10M16N16U0S1`) - Puts in placeholders for remaining DXIL op implementations so they can all be uploaded/merged without ordering or merge conflicts. Fixes #7895
1 parent c140840 commit 1195122

File tree

11 files changed

+439
-26
lines changed

11 files changed

+439
-26
lines changed

include/dxc/DXIL/DxilConstants.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2500,6 +2500,7 @@ extern const char *kDxBreakFuncName;
25002500
extern const char *kDxBreakCondName;
25012501
extern const char *kDxBreakMDName;
25022502
extern const char *kDxIsHelperGlobalName;
2503+
extern const char *kDxLinAlgMatrixTypePrefix;
25032504

25042505
extern const char *kHostLayoutTypePrefix;
25052506

include/dxc/DXIL/DxilUtil.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ bool IsHLSLObjectType(llvm::Type *Ty);
164164
bool IsHLSLRayQueryType(llvm::Type *Ty);
165165
llvm::Type *GetHLSLHitObjectType(llvm::Module *M);
166166
bool IsHLSLHitObjectType(llvm::Type *Ty);
167+
bool IsHLSLLinAlgMatrixType(llvm::Type *Ty);
168+
llvm::StringRef GetHLSLLinAlgMatrixTypeMangling(llvm::StructType *Ty);
167169
bool IsHLSLResourceDescType(llvm::Type *Ty);
168170
bool IsResourceSingleComponent(llvm::Type *Ty);
169171
uint8_t GetResourceComponentCount(llvm::Type *Ty);

lib/DXIL/DxilModule.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ const char *kDxBreakFuncName = "dx.break";
8282
const char *kDxBreakCondName = "dx.break.cond";
8383
const char *kDxBreakMDName = "dx.break.br";
8484
const char *kDxIsHelperGlobalName = "dx.ishelper";
85+
const char *kDxLinAlgMatrixTypePrefix = "dx.types.LinAlgMatrix";
8586

8687
const char *kHostLayoutTypePrefix = "hostlayout.";
8788

lib/DXIL/DxilOperations.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "dxc/DXIL/DxilConstants.h"
1414
#include "dxc/DXIL/DxilInstructions.h"
1515
#include "dxc/DXIL/DxilModule.h"
16+
#include "dxc/DXIL/DxilUtil.h"
1617
#include "dxc/Support/Global.h"
1718

1819
#include "llvm/ADT/ArrayRef.h"
@@ -3173,6 +3174,9 @@ StringRef OP::GetTypeName(Type *Ty, SmallVectorImpl<char> &Storage) {
31733174
return ST->getStructName();
31743175
} else if (TypeSlot == TS_Object) {
31753176
StructType *ST = cast<StructType>(Ty);
3177+
if (dxilutil::IsHLSLLinAlgMatrixType(Ty))
3178+
return (Twine("m") + Twine(dxilutil::GetHLSLLinAlgMatrixTypeMangling(ST)))
3179+
.toStringRef(Storage);
31763180
return ST->getStructName();
31773181
} else if (TypeSlot == TS_Vector) {
31783182
VectorType *VecTy = cast<VectorType>(Ty);

lib/DXIL/DxilUtil.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
///////////////////////////////////////////////////////////////////////////////
1111

1212
#include "dxc/DXIL/DxilUtil.h"
13+
#include "dxc/DXIL/DxilConstants.h"
1314
#include "dxc/DXIL/DxilInstructions.h"
1415
#include "dxc/DXIL/DxilModule.h"
1516
#include "dxc/DXIL/DxilOperations.h"
@@ -18,9 +19,11 @@
1819
#include "dxc/Support/Global.h"
1920

2021
#include "llvm/ADT/StringExtras.h"
22+
#include "llvm/ADT/StringRef.h"
2123
#include "llvm/ADT/Twine.h"
2224
#include "llvm/IR/Constants.h"
2325
#include "llvm/IR/DIBuilder.h"
26+
#include "llvm/IR/DerivedTypes.h"
2427
#include "llvm/IR/DiagnosticInfo.h"
2528
#include "llvm/IR/DiagnosticPrinter.h"
2629
#include "llvm/IR/GetElementPtrTypeIterator.h"
@@ -577,6 +580,9 @@ bool IsHLSLObjectType(llvm::Type *Ty) {
577580

578581
if (IsHLSLHitObjectType(Ty))
579582
return true;
583+
584+
if (IsHLSLLinAlgMatrixType(Ty))
585+
return true;
580586
}
581587
return false;
582588
}
@@ -612,6 +618,19 @@ bool IsHLSLHitObjectType(llvm::Type *Ty) {
612618
return ST->getName() == "dx.types.HitObject";
613619
}
614620

621+
bool IsHLSLLinAlgMatrixType(llvm::Type *Ty) {
622+
llvm::StructType *ST = dyn_cast<llvm::StructType>(Ty);
623+
if (!ST)
624+
return false;
625+
if (!ST->hasName())
626+
return false;
627+
return ST->getName().startswith(DXIL::kDxLinAlgMatrixTypePrefix);
628+
}
629+
630+
StringRef GetHLSLLinAlgMatrixTypeMangling(llvm::StructType *Ty) {
631+
return Ty->getStructName().substr(strlen(DXIL::kDxLinAlgMatrixTypePrefix));
632+
}
633+
615634
bool IsHLSLResourceDescType(llvm::Type *Ty) {
616635
if (llvm::StructType *ST = dyn_cast<llvm::StructType>(Ty)) {
617636
if (!ST->hasName())

lib/HLSL/HLOperationLower.cpp

Lines changed: 156 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6908,6 +6908,131 @@ Value *TranslateVectorAccumulate(CallInst *CI, IntrinsicOp IOP,
69086908
{OpArg, InputVector, MatrixBuffer, MatrixOffset});
69096909
}
69106910

6911+
Value *TranslateLinAlgFillMatrix(CallInst *CI, IntrinsicOp IOP,
6912+
OP::OpCode OpCode,
6913+
HLOperationLowerHelper &Helper,
6914+
HLObjectOperationLowerHelper *ObjHelper,
6915+
bool &Translated) {
6916+
hlsl::OP *HlslOp = &Helper.hlslOP;
6917+
IRBuilder<> Builder(CI);
6918+
6919+
Value *MatrixPtr = CI->getArgOperand(1);
6920+
DXASSERT_NOMSG(isa<PointerType>(MatrixPtr->getType()));
6921+
Type *MatrixType = MatrixPtr->getType()->getPointerElementType();
6922+
Value *Scalar = CI->getArgOperand(2);
6923+
6924+
Constant *OpArg = HlslOp->GetU32Const((unsigned)OpCode);
6925+
Function *DxilFunc =
6926+
HlslOp->GetOpFunc(OpCode, {MatrixType, Scalar->getType()});
6927+
6928+
Value *Matrix = Builder.CreateCall(DxilFunc, {OpArg, Scalar});
6929+
Builder.CreateStore(Matrix, MatrixPtr);
6930+
6931+
return nullptr;
6932+
}
6933+
6934+
Value *TranslateLinAlgMatrixAccumStoreToDescriptor(
6935+
CallInst *CI, IntrinsicOp IOP, OP::OpCode OpCode,
6936+
HLOperationLowerHelper &Helper, HLObjectOperationLowerHelper *ObjHelper,
6937+
bool &Translated) {
6938+
DXASSERT(false, "Not implemented.");
6939+
return nullptr;
6940+
}
6941+
6942+
Value *TranslateLinAlgMatVecMul(CallInst *CI, IntrinsicOp IOP,
6943+
OP::OpCode OpCode,
6944+
HLOperationLowerHelper &Helper,
6945+
HLObjectOperationLowerHelper *ObjHelper,
6946+
bool &Translated) {
6947+
DXASSERT(false, "Not implemented.");
6948+
return nullptr;
6949+
}
6950+
6951+
Value *TranslateLinAlgMatVecMulAdd(CallInst *CI, IntrinsicOp IOP,
6952+
OP::OpCode OpCode,
6953+
HLOperationLowerHelper &Helper,
6954+
HLObjectOperationLowerHelper *ObjHelper,
6955+
bool &Translated) {
6956+
DXASSERT(false, "Not implemented.");
6957+
return nullptr;
6958+
}
6959+
6960+
Value *TranslateLinAlgMatrixLoadFromDescriptor(
6961+
CallInst *CI, IntrinsicOp IOP, OP::OpCode OpCode,
6962+
HLOperationLowerHelper &Helper, HLObjectOperationLowerHelper *ObjHelper,
6963+
bool &Translated) {
6964+
DXASSERT(false, "Not implemented.");
6965+
return nullptr;
6966+
}
6967+
6968+
Value *TranslateLinAlgMatrixOuterProduct(
6969+
CallInst *CI, IntrinsicOp IOP, OP::OpCode OpCode,
6970+
HLOperationLowerHelper &Helper, HLObjectOperationLowerHelper *ObjHelper,
6971+
bool &Translated) {
6972+
DXASSERT(false, "Not implemented.");
6973+
return nullptr;
6974+
}
6975+
6976+
Value *TranslateLinAlgMatrixAccumulate(CallInst *CI, IntrinsicOp IOP,
6977+
OP::OpCode OpCode,
6978+
HLOperationLowerHelper &Helper,
6979+
HLObjectOperationLowerHelper *ObjHelper,
6980+
bool &Translated) {
6981+
DXASSERT(false, "Not implemented.");
6982+
return nullptr;
6983+
}
6984+
6985+
Value *TranslateLinAlgMatrixGetCoordinate(
6986+
CallInst *CI, IntrinsicOp IOP, OP::OpCode OpCode,
6987+
HLOperationLowerHelper &Helper, HLObjectOperationLowerHelper *ObjHelper,
6988+
bool &Translated) {
6989+
DXASSERT(false, "Not implemented.");
6990+
return nullptr;
6991+
}
6992+
6993+
Value *TranslateLinAlgMatrixGetElement(CallInst *CI, IntrinsicOp IOP,
6994+
OP::OpCode OpCode,
6995+
HLOperationLowerHelper &Helper,
6996+
HLObjectOperationLowerHelper *ObjHelper,
6997+
bool &Translated) {
6998+
DXASSERT(false, "Not implemented.");
6999+
return nullptr;
7000+
}
7001+
7002+
Value *TranslateLinAlgMatrixSetElement(CallInst *CI, IntrinsicOp IOP,
7003+
OP::OpCode OpCode,
7004+
HLOperationLowerHelper &Helper,
7005+
HLObjectOperationLowerHelper *ObjHelper,
7006+
bool &Translated) {
7007+
DXASSERT(false, "Not implemented.");
7008+
return nullptr;
7009+
}
7010+
7011+
Value *TranslateLinAlgMatrixMatrixMultiply(
7012+
CallInst *CI, IntrinsicOp IOP, OP::OpCode OpCode,
7013+
HLOperationLowerHelper &Helper, HLObjectOperationLowerHelper *ObjHelper,
7014+
bool &Translated) {
7015+
DXASSERT(false, "Not implemented.");
7016+
return nullptr;
7017+
}
7018+
7019+
Value *TranslateLinAlgMatrixMatrixMultiplyAccumulate(
7020+
CallInst *CI, IntrinsicOp IOP, OP::OpCode OpCode,
7021+
HLOperationLowerHelper &Helper, HLObjectOperationLowerHelper *ObjHelper,
7022+
bool &Translated) {
7023+
DXASSERT(false, "Not implemented.");
7024+
return nullptr;
7025+
}
7026+
7027+
Value *TranslateLinAlgCopyConvertMatrix(CallInst *CI, IntrinsicOp IOP,
7028+
OP::OpCode OpCode,
7029+
HLOperationLowerHelper &Helper,
7030+
HLObjectOperationLowerHelper *ObjHelper,
7031+
bool &Translated) {
7032+
DXASSERT(false, "Not implemented.");
7033+
return nullptr;
7034+
}
7035+
69117036
} // namespace
69127037

69137038
// Lower table.
@@ -7657,44 +7782,50 @@ constexpr IntrinsicLower gLowerTable[] = {
76577782
TranslateHitObjectTriangleObjectPositions,
76587783
DXIL::OpCode::HitObject_TriangleObjectPosition},
76597784

7660-
{IntrinsicOp::IOP___builtin_LinAlg_CopyConvertMatrix, EmptyLower,
7661-
DXIL::OpCode::LinAlgCopyConvertMatrix},
7662-
{IntrinsicOp::IOP___builtin_LinAlg_FillMatrix, EmptyLower,
7785+
{IntrinsicOp::IOP___builtin_LinAlg_CopyConvertMatrix,
7786+
TranslateLinAlgCopyConvertMatrix, DXIL::OpCode::LinAlgCopyConvertMatrix},
7787+
{IntrinsicOp::IOP___builtin_LinAlg_FillMatrix, TranslateLinAlgFillMatrix,
76637788
DXIL::OpCode::LinAlgFillMatrix},
7664-
{IntrinsicOp::IOP___builtin_LinAlg_MatrixGetCoordinate, EmptyLower,
7789+
{IntrinsicOp::IOP___builtin_LinAlg_MatrixGetCoordinate,
7790+
TranslateLinAlgMatrixGetCoordinate,
76657791
DXIL::OpCode::LinAlgMatrixGetCoordinate},
7666-
{IntrinsicOp::IOP___builtin_LinAlg_MatrixGetElement, EmptyLower,
7667-
DXIL::OpCode::LinAlgMatrixGetElement},
7668-
{IntrinsicOp::IOP___builtin_LinAlg_MatrixLength, EmptyLower,
7792+
{IntrinsicOp::IOP___builtin_LinAlg_MatrixGetElement,
7793+
TranslateLinAlgMatrixGetElement, DXIL::OpCode::LinAlgMatrixGetElement},
7794+
{IntrinsicOp::IOP___builtin_LinAlg_MatrixLength, TrivialUnaryOperation,
76697795
DXIL::OpCode::LinAlgMatrixLength},
7670-
{IntrinsicOp::IOP___builtin_LinAlg_MatrixLoadFromDescriptor, EmptyLower,
7796+
{IntrinsicOp::IOP___builtin_LinAlg_MatrixLoadFromDescriptor,
7797+
TranslateLinAlgMatrixLoadFromDescriptor,
76717798
DXIL::OpCode::LinAlgMatrixLoadFromDescriptor},
76727799
{IntrinsicOp::IOP___builtin_LinAlg_MatrixLoadFromMemory, EmptyLower,
76737800
DXIL::OpCode::LinAlgMatrixLoadFromMemory},
7674-
{IntrinsicOp::IOP___builtin_LinAlg_MatrixSetElement, EmptyLower,
7675-
DXIL::OpCode::LinAlgMatrixSetElement},
7676-
{IntrinsicOp::IOP___builtin_LinAlg_MatrixStoreToDescriptor, EmptyLower,
7801+
{IntrinsicOp::IOP___builtin_LinAlg_MatrixSetElement,
7802+
TranslateLinAlgMatrixSetElement, DXIL::OpCode::LinAlgMatrixSetElement},
7803+
{IntrinsicOp::IOP___builtin_LinAlg_MatrixStoreToDescriptor,
7804+
TranslateLinAlgMatrixAccumStoreToDescriptor,
76777805
DXIL::OpCode::LinAlgMatrixStoreToDescriptor},
76787806
{IntrinsicOp::IOP___builtin_LinAlg_MatrixStoreToMemory, EmptyLower,
76797807
DXIL::OpCode::LinAlgMatrixStoreToMemory},
7680-
{IntrinsicOp::IOP___builtin_LinAlg_MatrixAccumulate, EmptyLower,
7681-
DXIL::OpCode::LinAlgMatrixAccumulate},
7682-
{IntrinsicOp::IOP___builtin_LinAlg_MatrixMatrixMultiply, EmptyLower,
7683-
DXIL::OpCode::LinAlgMatrixMultiply},
7808+
{IntrinsicOp::IOP___builtin_LinAlg_MatrixAccumulate,
7809+
TranslateLinAlgMatrixAccumulate, DXIL::OpCode::LinAlgMatrixAccumulate},
7810+
{IntrinsicOp::IOP___builtin_LinAlg_MatrixMatrixMultiply,
7811+
TranslateLinAlgMatrixMatrixMultiply, DXIL::OpCode::LinAlgMatrixMultiply},
76847812
{IntrinsicOp::IOP___builtin_LinAlg_MatrixMatrixMultiplyAccumulate,
7685-
EmptyLower, DXIL::OpCode::LinAlgMatrixMultiplyAccumulate},
7686-
{IntrinsicOp::IOP___builtin_LinAlg_MatrixQueryAccumulatorLayout, EmptyLower,
7687-
DXIL::OpCode::LinAlgMatrixQueryAccumulatorLayout},
7688-
{IntrinsicOp::IOP___builtin_LinAlg_MatrixAccumulateToDescriptor, EmptyLower,
7813+
TranslateLinAlgMatrixMatrixMultiplyAccumulate,
7814+
DXIL::OpCode::LinAlgMatrixMultiplyAccumulate},
7815+
{IntrinsicOp::IOP___builtin_LinAlg_MatrixQueryAccumulatorLayout,
7816+
TrivialNoArgOperation, DXIL::OpCode::LinAlgMatrixQueryAccumulatorLayout},
7817+
{IntrinsicOp::IOP___builtin_LinAlg_MatrixAccumulateToDescriptor,
7818+
TranslateLinAlgMatrixAccumStoreToDescriptor,
76897819
DXIL::OpCode::LinAlgMatrixAccumulateToDescriptor},
76907820
{IntrinsicOp::IOP___builtin_LinAlg_MatrixAccumulateToMemory, EmptyLower,
76917821
DXIL::OpCode::LinAlgMatrixAccumulateToMemory},
7692-
{IntrinsicOp::IOP___builtin_LinAlg_MatrixOuterProduct, EmptyLower,
7693-
DXIL::OpCode::LinAlgMatrixOuterProduct},
7694-
{IntrinsicOp::IOP___builtin_LinAlg_MatrixVectorMultiply, EmptyLower,
7695-
DXIL::OpCode::LinAlgMatVecMul},
7696-
{IntrinsicOp::IOP___builtin_LinAlg_MatrixVectorMultiplyAdd, EmptyLower,
7697-
DXIL::OpCode::LinAlgMatVecMulAdd},
7822+
{IntrinsicOp::IOP___builtin_LinAlg_MatrixOuterProduct,
7823+
TranslateLinAlgMatrixOuterProduct, DXIL::OpCode::LinAlgMatrixOuterProduct},
7824+
{IntrinsicOp::IOP___builtin_LinAlg_MatrixVectorMultiply,
7825+
TranslateLinAlgMatVecMul, DXIL::OpCode::LinAlgMatVecMul},
7826+
{IntrinsicOp::IOP___builtin_LinAlg_MatrixVectorMultiplyAdd,
7827+
TranslateLinAlgMatVecMulAdd, DXIL::OpCode::LinAlgMatVecMulAdd},
7828+
76987829
{IntrinsicOp::IOP_DebugBreak, TrivialNoArgOperation,
76997830
DXIL::OpCode::DebugBreak},
77007831
{IntrinsicOp::IOP_DxIsDebuggerPresent, TranslateWaveToVal,

tools/clang/lib/CodeGen/CGHLSLMS.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "CGRecordLayout.h"
1515
#include "CodeGenFunction.h"
1616
#include "CodeGenModule.h"
17+
#include "dxc/DXIL/DxilConstants.h"
1718
#include "dxc/DXIL/DxilOperations.h"
1819
#include "dxc/DXIL/DxilTypeSystem.h"
1920
#include "dxc/DXIL/DxilUtil.h"
@@ -6621,7 +6622,7 @@ llvm::Type *CGMSHLSLRuntime::ConvertAttributedLinAlgMatrixType(
66216622

66226623
llvm::SmallString<64> Buf;
66236624
llvm::raw_svector_ostream OS(Buf);
6624-
OS << "dx.types.LinAlgMatrix";
6625+
OS << DXIL::kDxLinAlgMatrixTypePrefix;
66256626
T->appendMangledAttributes(OS);
66266627
StringRef TypeName = OS.str();
66276628

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T cs_6_10 -E main %s | FileCheck %s
3+
4+
[numthreads(1,1,1)]
5+
void main() {
6+
// CHECK-LABEL: define void @main()
7+
8+
// CHECK: %{{.*}} = call %dx.types.LinAlgMatrixC4M5N4U1S2 @dx.op.linAlgFillMatrix.mC4M5N4U1S2.i32(i32 -2147483636, i32 {{.*}}) ; LinAlgFillMatrix(value)
9+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(4, 5, 4, 1, 2)]] mat1;
10+
__builtin_LinAlg_FillMatrix(mat1, 5);
11+
// CHECK: %{{.*}} = call %dx.types.LinAlgMatrixC5M3N4U0S0 @dx.op.linAlgFillMatrix.mC5M3N4U0S0.f32(i32 -2147483636, float {{.*}}) ; LinAlgFillMatrix(value)
12+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(5, 3, 4, 0, 0)]] mat2;
13+
__builtin_LinAlg_FillMatrix(mat2, 3.14);
14+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -T lib_6_10 -E main %s -ast-dump-implicit | FileCheck %s
3+
4+
// CHECK: FunctionDecl {{.*}} implicit used __builtin_LinAlg_FillMatrix 'void (__builtin_LinAlgMatrix & {{.*}}, unsigned int)' extern
5+
// CHECK-NEXT: ParmVarDecl {{.*}} ret '__builtin_LinAlgMatrix &&__restrict {{.*}}'
6+
// CHECK-NEXT: ParmVarDecl {{.*}} value 'unsigned int'
7+
// CHECK-NEXT: HLSLIntrinsicAttr {{.*}} Implicit "op" "" 406
8+
// CHECK-NEXT: AvailabilityAttr {{.*}} Implicit 6.10 0 0 ""
9+
10+
[shader("compute")]
11+
[numthreads(1,1,1)]
12+
void main() {
13+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(1, 5, 4, 2, 2)]] mat;
14+
__builtin_LinAlg_FillMatrix(mat, 15);
15+
}

0 commit comments

Comments
 (0)