Skip to content
Merged
2 changes: 1 addition & 1 deletion scripts/compile-triton.sh
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ build_llvm() {
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=true \
-DLLVM_ENABLE_PROJECTS="mlir" \
-DLLVM_TARGETS_TO_BUILD="X86;NVPTX;AMDGPU" \
-DLLVM_TARGETS_TO_BUILD="X86;NVPTX;AMDGPU;SPIRV" \
-DLLVM_INSTALL_UTILS=true \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DCMAKE_INSTALL_PREFIX=$PACKAGES_DIR/llvm \
Expand Down
39 changes: 30 additions & 9 deletions third_party/intel/lib/Target/SPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,38 @@
# SPIRV-LLVM-Translator is required.
find_package(SPIRVToLLVMTranslator)

add_mlir_translation_library(TritonSPIRV
SPIRVTranslation.cpp
# Check if there is the LLVM SPIR-V backend.
is_llvm_target_library("SPIRV" spirv_present_result INCLUDED_TARGETS)

LINK_COMPONENTS
Core
if(spirv_present_result)
message(STATUS "Found SPIR-V Backend")
add_compile_definitions(LLVM_SPIRV_BACKEND_TARGET_PRESENT)
add_mlir_translation_library(TritonSPIRV
SPIRVTranslation.cpp

LINK_LIBS PUBLIC
TritonLLVMIR
# spirv tools
LLVMSPIRVLib
)
LINK_COMPONENTS
Core
# spirv backend
SPIRVCodeGen

LINK_LIBS PUBLIC
TritonLLVMIR
# spirv tools
LLVMSPIRVLib
)
else()
add_mlir_translation_library(TritonSPIRV
SPIRVTranslation.cpp

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
TritonLLVMIR
# spirv tools
LLVMSPIRVLib
)
endif()

# Add SPIRV-LLVM-Translator include directory.
target_include_directories(TritonSPIRV PRIVATE ${SPIRVToLLVMTranslator_INCLUDE_DIR})
Expand Down
89 changes: 89 additions & 0 deletions third_party/intel/lib/Target/SPIRV/SPIRVTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,85 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/TargetParser/Triple.h"

#if defined(LLVM_SPIRV_BACKEND_TARGET_PRESENT)
namespace llvm {

using namespace llvm;
using namespace SPIRV;

// The LLVM SPIR-V backend exposes an API call that translates LLVM module to
// SPIR-V and writes results into a string as binary SPIR-V output, providing
// diagnostics on fail and means of configuring translation.
extern "C" bool SPIRVTranslate(Module *M, std::string &SpirvObj,
std::string &ErrMsg,
const std::vector<std::string> &AllowExtNames,
llvm::CodeGenOptLevel OLevel,
Triple TargetTriple);

static inline Triple::SubArchType
spirvVersionToSubArch(SPIRV::VersionNumber VN) {
switch (VN) {
case SPIRV::VersionNumber::SPIRV_1_0:
return Triple::SPIRVSubArch_v10;
case VersionNumber::SPIRV_1_1:
return Triple::SPIRVSubArch_v11;
case VersionNumber::SPIRV_1_2:
return Triple::SPIRVSubArch_v12;
case VersionNumber::SPIRV_1_3:
return Triple::SPIRVSubArch_v13;
case VersionNumber::SPIRV_1_4:
return Triple::SPIRVSubArch_v14;
case VersionNumber::SPIRV_1_5:
return Triple::SPIRVSubArch_v15;
case VersionNumber::SPIRV_1_6:
return Triple::SPIRVSubArch_v16;
}
return Triple::NoSubArch;
}

bool runSpirvBackend(Module *M, std::string &Result, std::string &ErrMsg,
const SPIRV::TranslatorOpts &TranslatorOpts) {
static const std::string DefaultTriple = "spirv64v1.6-unknown-unknown";
static const std::vector<std::string> AllowExtNames{"all"};

// Correct the Triple value if needed
Triple TargetTriple(M->getTargetTriple());
if (TargetTriple.isSPIR()) {
TargetTriple.setArch(TargetTriple.getArch() == Triple::spir64
? Triple::spirv64
: Triple::spirv32,
TargetTriple.getSubArch());
M->setTargetTriple(TargetTriple.str());
// We need to reset Data Layout to conform with the TargetMachine
M->setDataLayout("");
}
if (TargetTriple.getTriple().empty())
TargetTriple.setTriple(DefaultTriple);
if (TranslatorOpts.getMaxVersion() != VersionNumber::MaximumVersion) {
TargetTriple.setArch(TargetTriple.getArch(),
spirvVersionToSubArch(TranslatorOpts.getMaxVersion()));
M->setTargetTriple(TargetTriple.str());
}

// Translate the Module into SPIR-V
return SPIRVTranslate(M, Result, ErrMsg, AllowExtNames,
CodeGenOptLevel::Aggressive, TargetTriple);
}

bool runSpirvBackend(Module *M, std::ostream &OS, std::string &ErrMsg,
const SPIRV::TranslatorOpts &TranslatorOpts) {
std::string Result;
bool Status = runSpirvBackend(M, Result, ErrMsg, TranslatorOpts);
if (Status)
OS << Result;
return Status;
}

} // namespace llvm

#endif // LLVM_SPIRV_BACKEND_TARGET_PRESENT

namespace triton {

Expand Down Expand Up @@ -63,7 +142,17 @@ std::string translateLLVMIRToSPIRV(llvm::Module &module) {
SPIRVOpts.setPreserveOCLKernelArgTypeMetadataThroughString(true);
SPIRVOpts.setPreserveAuxData(false);
SPIRVOpts.setSPIRVAllowUnknownIntrinsics({"llvm.genx.GenISA."});

#if defined(LLVM_SPIRV_BACKEND_TARGET_PRESENT)
int SpvTranslateMode = 0;
if (const char *EnvIsBackend = std::getenv("TRITON_USE_SPIRV_BACKEND"))
llvm::StringRef(EnvIsBackend).getAsInteger(10, SpvTranslateMode);
auto success = SpvTranslateMode
? llvm::runSpirvBackend(&module, OS, Err, SPIRVOpts)
: llvm::writeSpirv(&module, SPIRVOpts, OS, Err);
#else
auto success = llvm::writeSpirv(&module, SPIRVOpts, OS, Err);
#endif // LLVM_SPIRV_BACKEND_TARGET_PRESENT

if (!success) {
llvm::errs() << "SPIRVTranslation: SPIRV translation failed with"
Expand Down