Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 69 additions & 3 deletions clang/lib/DPCT/RulesAsm/AsmMigration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -556,10 +556,15 @@ bool SYCLGenBase::emitVectorType(const InlineAsmVectorType *T) {
return SYCLGenError();
OS() << ", ";
switch (T->getKind()) {
case InlineAsmVectorType::x1:
OS() << 1;
break;
case InlineAsmVectorType::v2:
case InlineAsmVectorType::x2:
OS() << 2;
break;
case InlineAsmVectorType::v4:
case InlineAsmVectorType::x4:
OS() << 4;
break;
case InlineAsmVectorType::v8:
Expand Down Expand Up @@ -589,9 +594,9 @@ bool SYCLGenBase::emitVariableDeclaration(const InlineAsmVarDecl *D) {

bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
// Address expression only support ld/st/red & atom instructions.
if (!CurrInst ||
!CurrInst->is(asmtok::op_st, asmtok::op_ld, asmtok::op_atom,
asmtok::op_prefetch, asmtok::op_red, asmtok::op_cp)) {
if (!CurrInst || !CurrInst->is(asmtok::op_st, asmtok::op_ld, asmtok::op_atom,
asmtok::op_prefetch, asmtok::op_red,
asmtok::op_cp, asmtok::op_ldmatrix)) {
return SYCLGenError();
}
std::string Type;
Expand Down Expand Up @@ -624,6 +629,8 @@ bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
if (CurrInst->is(asmtok::op_prefetch, asmtok::op_red) ||
CanSuppressCast(Dst->getSymbol()))
OS() << llvm::formatv("{0}", Reg);
else if (CurrInst->is(asmtok::op_ldmatrix))
OS() << llvm::formatv("(uintptr_t){0}", Reg);
else
OS() << llvm::formatv("(({0} *)(uintptr_t){1})", Type, Reg);
break;
Expand Down Expand Up @@ -1305,6 +1312,64 @@ class SYCLGen : public SYCLGenBase {
return SYCLGenSuccess();
}

bool handle_ldmatrix(const InlineAsmInstruction *Inst) override {
if (Inst->getNumInputOperands() != 1)
return SYCLGenError();

const auto *Type = dyn_cast<InlineAsmBuiltinType>(Inst->getType(0));

if (!Type || Type->getKind() != InlineAsmBuiltinType::b16)
return SYCLGenError();

const InlineAsmVectorExpr *VE;
if (VE = dyn_cast<InlineAsmVectorExpr>(Inst->getOutputOperand())) {
auto numOutputOperands = VE->getNumElements();
if (Inst->hasAttr(InstAttr::x1)) {
if (numOutputOperands != 1)
return SYCLGenError();
} else if (Inst->hasAttr(InstAttr::x2)) {
if (numOutputOperands != 2)
return SYCLGenError();
} else if (Inst->hasAttr(InstAttr::x4)) {
if (numOutputOperands != 4)
return SYCLGenError();
}
} else {
return SYCLGenError();
}

llvm::SaveAndRestore<const InlineAsmInstruction *> Store(CurrInst);
CurrInst = Inst;
const auto *Src =
dyn_cast_or_null<InlineAsmAddressExpr>(Inst->getInputOperand(0));
if (!Src)
return false;

OS() << MapNames::getDpctNamespace() << "experimental::matrix::ldmatrix(";
if (emitStmt(Src)) {
return SYCLGenError();
}
for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) {
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
continue;
OS() << ", &";
if (emitStmt(VE->getElement(Inst)))
return SYCLGenError();
}
if (Inst->hasAttr(InstAttr::trans))
OS() << ", true";
OS() << ");";
const auto *KernelDecl = getImmediateOuterFuncDecl(GAS);
if (KernelDecl) {
auto FuncInfo = DeviceFunctionDecl::LinkRedecls(KernelDecl);
if (FuncInfo)
FuncInfo->addSubGroupSizeRequest(32, GAS->getBeginLoc(),
DpctGlobalInfo::getSubGroup(GAS));
}

return SYCLGenSuccess();
}

bool handle_prefetch(const InlineAsmInstruction *Inst) override {
if (!DpctGlobalInfo::useExtPrefetch() || Inst->getNumInputOperands() != 1)
return SYCLGenError();
Expand Down Expand Up @@ -2881,6 +2946,7 @@ class SYCLGen : public SYCLGenBase {
bool handle_ld(const InlineAsmInstruction *Inst) override {
if (Inst->getNumInputOperands() != 1)
return SYCLGenError();

llvm::SaveAndRestore<const InlineAsmInstruction *> Store(CurrInst);
CurrInst = Inst;
const auto *Src =
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class InlineAsmBuiltinType : public InlineAsmType {
// This class is used for device asm vector types.
class InlineAsmVectorType : public InlineAsmType {
public:
enum VecKind { v2, v4, v8 };
enum VecKind { v2, v4, v8, x1, x2, x4 };

private:
VecKind Kind;
Expand Down
46 changes: 32 additions & 14 deletions clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ InlineAsmStmtResult InlineAsmParser::ParseInstruction() {
if (!Tok.getIdentifier() || !Tok.getIdentifier()->isInstruction())
return AsmStmtError();

InlineAsmIdentifierInfo *Opcode = Tok.getIdentifier();
Opcode = Tok.getIdentifier();
ConsumeToken();

SmallVector<InstAttr, 4> Attrs;
Expand Down Expand Up @@ -736,20 +736,38 @@ InlineAsmExprResult InlineAsmParser::ActOnParenExpr(InlineAsmExpr *SubExpr) {
InlineAsmExprResult
InlineAsmParser::ActOnVectorExpr(ArrayRef<InlineAsmExpr *> Vec) {

// Vector size must be 2, 4, or 8.
// Vector size for ldmatrix are 1, 2, 4
// size(x) = 2 * sizeof(v).
InlineAsmVectorType::VecKind Kind;
switch (Vec.size()) {
case 2:
Kind = InlineAsmVectorType::v2;
break;
case 4:
Kind = InlineAsmVectorType::v4;
break;
case 8:
Kind = InlineAsmVectorType::v8;
break;
default:
return AsmExprError();
if (Opcode->getTokenID() == asmtok::op_ldmatrix) {
switch (Vec.size()) {
case 1:
Kind = InlineAsmVectorType::x1;
break;
case 2:
Kind = InlineAsmVectorType::x2;
break;
case 4:
Kind = InlineAsmVectorType::x4;
break;
default:
return AsmExprError();
}
} else {
// Vector size must be 2, 4, or 8.
switch (Vec.size()) {
case 2:
Kind = InlineAsmVectorType::v2;
break;
case 4:
Kind = InlineAsmVectorType::v4;
break;
case 8:
Kind = InlineAsmVectorType::v8;
break;
default:
return AsmExprError();
}
}

InlineAsmBuiltinType *ElementType = nullptr;
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/DPCT/RulesAsm/Parser/AsmParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ class InlineAsmParser {
};

public:
InlineAsmIdentifierInfo *Opcode;

InlineAsmParser(InlineAsmContext &Ctx, SourceMgr &Mgr)
: Lexer(*Mgr.getMemoryBuffer(Mgr.getMainFileID())), Context(Ctx),
SrcMgr(Mgr), CurScope(nullptr) {
Expand Down
10 changes: 10 additions & 0 deletions clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,14 @@ MODIFIER(v2, ".v2")
MODIFIER(v4, ".v4")
MODIFIER(v8, ".v8")

// Matrix modifiers
MODIFIER(x1, ".x1")
MODIFIER(x2, ".x2")
MODIFIER(x4, ".x4")

// Matrix shape
MODIFIER(m8n8, ".m8n8")

STATE_SPACE(reg, ".reg")
STATE_SPACE(sreg, ".sreg")
STATE_SPACE(const, ".const")
Expand Down Expand Up @@ -420,6 +428,8 @@ MODIFIER(ecr, ".ecr")
MODIFIER(rc16, ".rc16")
MODIFIER(cs, ".cs")
MODIFIER(to, ".to")
MODIFIER(aligned, ".aligned")
MODIFIER(trans, ".trans")

#undef LINKAGE
#undef TARGET
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/DPCT/SrcAPI/APINames_ASM.inc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ ENTRY("griddepcontrol", "griddepcontrol", false, NO_FLAG, P1, "Comment")
ENTRY("isspacep", "isspacep", false, NO_FLAG, P1, "Comment")
ENTRY("istypep", "istypep", false, NO_FLAG, P1, "Comment")
ENTRY("ld", "ld", true, NO_FLAG, P1, "Partial")
ENTRY("ldmatrix", "ldmatrix", false, NO_FLAG, P1, "Comment")
ENTRY("ldmatrix", "ldmatrix", true, NO_FLAG, P1, "Successful")
ENTRY("ldu", "ldu", false, NO_FLAG, P1, "Comment")
ENTRY("lg2", "lg2", true, NO_FLAG, P1, "Successful")
ENTRY("lop3", "lop3", true, NO_FLAG, P1, "Successful")
Expand Down
Loading