Skip to content
Open
1 change: 1 addition & 0 deletions include/oklt/core/target_backends.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ enum struct TargetBackend : unsigned char {
CUDA, ///< CUDA backend.
HIP, ///< HIP backend.
DPCPP, ///< DPCPP backend.
METAL, ///< Metal backend.

_LAUNCHER, ///< Launcher backend.
};
Expand Down
14 changes: 14 additions & 0 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,20 @@ set (OCCA_TRANSPILER_SOURCES
attributes/backend/dpcpp/common.cpp
attributes/backend/dpcpp/common.h

# Metal
attributes/backend/metal/kernel.cpp
attributes/backend/metal/translation_unit.cpp
attributes/backend/metal/outer.cpp
attributes/backend/metal/inner.cpp
attributes/backend/metal/tile.cpp
attributes/backend/metal/shared.cpp
attributes/backend/metal/restrict.cpp
attributes/backend/metal/atomic.cpp
attributes/backend/metal/barrier.cpp
attributes/backend/metal/exclusive.cpp
attributes/backend/metal/common.cpp
attributes/backend/metal/common.h

# Serial subset
attributes/utils/serial_subset/empty.cpp
attributes/utils/serial_subset/kernel.cpp
Expand Down
24 changes: 24 additions & 0 deletions lib/attributes/backend/metal/atomic.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include "attributes/backend/metal/common.h"

#include <spdlog/spdlog.h>

namespace {
using namespace oklt;
using namespace clang;

HandleResult handleAtomicStmtAttribute(SessionStage& s, const Stmt& stmt, const Attr& a) {
SPDLOG_DEBUG("Handle attribute [{}]", a.getNormalizedFullName());

removeAttribute(s, a);
return {};
}

__attribute__((constructor)) void registerAttrBackend() {
auto ok =
registerBackendHandler(TargetBackend::METAL, ATOMIC_ATTR_NAME, handleAtomicStmtAttribute);

if (!ok) {
SPDLOG_ERROR("[METAL] Failed to register {} attribute handler", ATOMIC_ATTR_NAME);
}
}
} // namespace
31 changes: 31 additions & 0 deletions lib/attributes/backend/metal/barrier.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#include "attributes/backend/metal/common.h"

#include <clang/AST/Attr.h>
#include <clang/AST/Stmt.h>

#include <spdlog/spdlog.h>

namespace {
using namespace oklt;
using namespace clang;

oklt::HandleResult handleBarrierAttribute(SessionStage& stage,
const clang::Stmt& stmt,
const clang::Attr& attr) {
SPDLOG_DEBUG("Handle [@barrier] attribute");

auto range = getAttrFullSourceRange(attr);
stage.getRewriter().ReplaceText(range, metal::SYNC_THREADS_BARRIER);

return {};
}

__attribute__((constructor)) void registerAttrBackend() {
auto ok =
registerBackendHandler(TargetBackend::METAL, BARRIER_ATTR_NAME, handleBarrierAttribute);

if (!ok) {
SPDLOG_ERROR("[METAL] Failed to register {} attribute handler", BARRIER_ATTR_NAME);
}
}
} // namespace
62 changes: 62 additions & 0 deletions lib/attributes/backend/metal/common.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#include "attributes/backend/metal/common.h"
#include "core/sema/okl_sema_ctx.h"
#include "core/utils/range_to_string.h"
#include "util/string_utils.hpp"

#include <clang/Rewrite/Core/Rewriter.h>

namespace oklt::metal {
using namespace clang;

std::string axisToStr(const Axis& axis) {
static std::map<Axis, std::string> mapping{{Axis::X, "x"}, {Axis::Y, "y"}, {Axis::Z, "z"}};
return mapping[axis];
}

std::string getIdxVariable(const AttributedLoop& loop) {
auto strAxis = axisToStr(loop.axis);
switch (loop.type) {
case (LoopType::Inner):
return util::fmt("_occa_thread_position.{}", strAxis).value();
case (LoopType::Outer):
return util::fmt(" _occa_group_position.{}", strAxis).value();
default: // Incorrect case
return "";
}
}

std::string getTiledVariableName(const OklLoopInfo& forLoop) {
return "_occa_tiled_" + forLoop.var.name;
}

std::string buildInnerOuterLoopIdxLine(const OklLoopInfo& forLoop,
const AttributedLoop& loop,
int& openedScopeCounter,
oklt::Rewriter& rewriter) {
static_cast<void>(openedScopeCounter);
auto idx = getIdxVariable(loop);
auto op = forLoop.IsInc() ? "+" : "-";

std::string res;
if (forLoop.isUnary()) {
res = std::move(util::fmt("{} {} = ({}) {} {};\n",
forLoop.var.typeName,
forLoop.var.name,
getLatestSourceText(forLoop.range.start, rewriter),
op,
idx)
.value());
} else {
res = std::move(util::fmt("{} {} = ({}) {} (({}) * {});\n",
forLoop.var.typeName,
forLoop.var.name,
getLatestSourceText(forLoop.range.start, rewriter),
op,
getLatestSourceText(forLoop.inc.val, rewriter),
idx)
.value());
}
return res;
}

} // namespace oklt::metal
36 changes: 36 additions & 0 deletions lib/attributes/backend/metal/common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include "attributes/attribute_names.h"
#include "attributes/utils/code_gen.h"
#include "attributes/utils/default_handlers.h"
#include "attributes/utils/kernel_utils.h"
#include "attributes/utils/utils.h"
#include "core/handler_manager/backend_handler.h"
#include "core/rewriter/rewriter_proxy.h"
#include "core/sema/okl_sema_ctx.h"
#include "core/transpiler_session/session_stage.h"
#include "core/utils/attributes.h"
#include "core/utils/range_to_string.h"

#include <string>

namespace clang {
class Rewriter;
}

namespace oklt {
struct OklLoopInfo;
}

namespace oklt::metal {
std::string axisToStr(const Axis& axis);
std::string getIdxVariable(const AttributedLoop& loop);
std::string getTiledVariableName(const OklLoopInfo& forLoop);

// Produces something like: int i = start +- (inc * _occa_group_position.x);
// or: int i = start +- (inc * _occa_thread_position.x);
std::string buildInnerOuterLoopIdxLine(const OklLoopInfo& forLoop,
const AttributedLoop& loop,
int& openedScopeCounter,
oklt::Rewriter& rewriter);

const std::string SYNC_THREADS_BARRIER = "threadgroup_barrier(mem_flags::mem_threadgroup)";
} // namespace oklt::metal
53 changes: 53 additions & 0 deletions lib/attributes/backend/metal/exclusive.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include "attributes/backend/metal/common.h"

#include <spdlog/spdlog.h>

namespace {
using namespace oklt;
using namespace clang;

HandleResult handleExclusiveDeclAttribute(SessionStage& s, const Decl& decl, const Attr& a) {
SPDLOG_DEBUG("Handle [@exclusive] attribute (Decl)");

removeAttribute(s, a);
return {};
}

HandleResult handleExclusiveVarAttribute(SessionStage& s, const VarDecl& decl, const Attr& a) {
SPDLOG_DEBUG("Handle [@exclusive] attribute");

removeAttribute(s, a);

auto& sema = s.tryEmplaceUserCtx<OklSemaCtx>();
auto loopInfo = sema.getLoopInfo();
if (loopInfo && loopInfo->isRegular()) {
loopInfo = loopInfo->getAttributedParent();
}
if (loopInfo && loopInfo->has(LoopType::Inner)) {
return tl::make_unexpected(
Error{{}, "Cannot define [@exclusive] variables inside an [@inner] loop"});
}

auto child = loopInfo ? loopInfo->getFirstAttributedChild() : nullptr;
bool isInnerChild = child && child->has(LoopType::Inner);
if (!loopInfo || !loopInfo->has(LoopType::Outer) || !isInnerChild) {
return tl::make_unexpected(
Error{{}, "Must define [@exclusive] variables between [@outer] and [@inner] loops"});
}

return defaultHandleExclusiveDeclAttribute(s, decl, a);
}

__attribute__((constructor)) void registerAttrBackend() {
auto ok = registerBackendHandler(
TargetBackend::METAL, EXCLUSIVE_ATTR_NAME, handleExclusiveDeclAttribute);
ok &= registerBackendHandler(
TargetBackend::METAL, EXCLUSIVE_ATTR_NAME, handleExclusiveVarAttribute);
ok &= registerBackendHandler(
TargetBackend::METAL, EXCLUSIVE_ATTR_NAME, defaultHandleExclusiveStmtAttribute);

if (!ok) {
SPDLOG_ERROR("[METAL] Failed to register {} attribute handler", EXCLUSIVE_ATTR_NAME);
}
}
} // namespace
49 changes: 49 additions & 0 deletions lib/attributes/backend/metal/inner.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#include "attributes/backend/metal/common.h"
#include "attributes/frontend/params/loop.h"

#include <spdlog/spdlog.h>

namespace {
using namespace oklt;
using namespace clang;

HandleResult handleInnerAttribute(SessionStage& s,
const clang::ForStmt& forStmt,
const clang::Attr& a,
const AttributedLoop* params) {
SPDLOG_DEBUG("Handle [@inner] attribute");
handleChildAttr(s, forStmt, NO_BARRIER_ATTR_NAME);

auto& sema = s.tryEmplaceUserCtx<OklSemaCtx>();
auto loopInfo = sema.getLoopInfo(forStmt);
if (!loopInfo) {
return tl::make_unexpected(
Error{std::error_code(), "@inner: failed to fetch loop meta data from sema"});
}

// Auto Axis in loopInfo are replaced with specific.
// TODO: maybe somehow update params earlier?
auto updatedParams = *params;
updatedParams.axis = loopInfo->axis.front();

std::string afterRBraceCode = "";
if (loopInfo->shouldSync()) {
afterRBraceCode += metal::SYNC_THREADS_BARRIER + ";\n";
}

int openedScopeCounter = 0;
auto prefixCode = metal::buildInnerOuterLoopIdxLine(
*loopInfo, updatedParams, openedScopeCounter, s.getRewriter());
auto suffixCode = buildCloseScopes(openedScopeCounter);

return replaceAttributedLoop(s, forStmt, a, suffixCode, afterRBraceCode, prefixCode, true);
}

__attribute__((constructor)) void registerBackendHandler() {
auto ok = registerBackendHandler(TargetBackend::METAL, INNER_ATTR_NAME, handleInnerAttribute);

if (!ok) {
SPDLOG_ERROR("[METAL] Failed to register {} attribute handler", INNER_ATTR_NAME);
}
}
} // namespace
Loading