3 changes: 3 additions & 0 deletions .github/workflows/build_kernel_macos.yaml
@@ -26,3 +26,6 @@ jobs:
# kernels. Also run tests once we have a macOS runner.
- name: Build relu kernel
run: ( cd examples/relu && nix build .\#redistributable.torch29-metal-aarch64-darwin -L )

- name: Build relu metal cpp kernel
run: ( cd examples/relu-metal-cpp && nix build .\#redistributable.torch29-metal-aarch64-darwin -L )
2 changes: 2 additions & 0 deletions build2cmake/src/config/v2.rs
@@ -247,6 +247,8 @@ pub enum Dependencies {
Cutlass4_0,
#[serde(rename = "cutlass_sycl")]
CutlassSycl,
#[serde(rename = "metal-cpp")]
MetalCpp,
Member:

Do we need a version like cutlass? My first guess is no, since on macOS we always have the latest of everything, but I thought I'd check.

Collaborator Author:

Good question! I agree we probably don't need a version here; we'll always prefer the latest.

Torch,
}

20 changes: 20 additions & 0 deletions examples/relu-metal-cpp/build.toml
@@ -0,0 +1,20 @@
[general]
name = "relu"
universal = false

[torch]
src = [
"torch-ext/torch_binding.cpp",
"torch-ext/torch_binding.h",
]


[kernel.relu_metal]
backend = "metal"
src = [
"relu/relu.cpp",
"relu/metallib_loader.mm",
"relu/relu_cpp.metal",
"relu/common.h",
]
depends = [ "torch", "metal-cpp" ]
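# "torch" and "metal-cpp" above are resolved by kernel-builder; see lib/deps.nix for the Nix packages they map to.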
17 changes: 17 additions & 0 deletions examples/relu-metal-cpp/flake.nix
@@ -0,0 +1,17 @@
{
description = "Flake for ReLU metal cpp kernel";

inputs = {
kernel-builder.url = "path:../..";
};

outputs =
{
self,
kernel-builder,
}:
kernel-builder.lib.genFlakeOutputs {
inherit self;
path = ./.;
};
}
7 changes: 7 additions & 0 deletions examples/relu-metal-cpp/relu/common.h
@@ -0,0 +1,7 @@
#pragma once

#include <metal_stdlib>
using namespace metal;

// Common constants and utilities for Metal kernels
constant float RELU_THRESHOLD = 0.0f;
40 changes: 40 additions & 0 deletions examples/relu-metal-cpp/relu/metallib_loader.mm
@@ -0,0 +1,40 @@
#import <Metal/Metal.h>
#include <ATen/mps/MPSDevice.h>
#include <ATen/mps/MPSStream.h>

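// The build compiles the .metal sources to a .metallib and embeds it as a generated
// C header; EMBEDDED_METALLIB_HEADER and EMBEDDED_METALLIB_NAMESPACE are defined at
// compile time to point at that header and its namespace.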
#ifdef EMBEDDED_METALLIB_HEADER
#include EMBEDDED_METALLIB_HEADER
#else
#error "EMBEDDED_METALLIB_HEADER not defined"
#endif

// C++ interface to load the embedded metallib without exposing ObjC types
extern "C" {
void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg) {
id<MTLDevice> mtlDevice = (__bridge id<MTLDevice>)device;
NSError* error = nil;

id<MTLLibrary> library = EMBEDDED_METALLIB_NAMESPACE::createLibrary(mtlDevice, &error);

if (!library && errorMsg && error) {
*errorMsg = strdup([error.localizedDescription UTF8String]);
}

// Manually retain since we're not using ARC
// The caller will wrap in NS::TransferPtr which assumes ownership
if (library) {
[library retain];
}
return (__bridge void*)library;
}

// Get PyTorch's MPS device (returns id<MTLDevice> as void*)
void* getMPSDevice() {
return (__bridge void*)at::mps::MPSDevice::getInstance()->device();
}

// Get PyTorch's current MPS command queue (returns id<MTLCommandQueue> as void*)
void* getMPSCommandQueue() {
return (__bridge void*)at::mps::getCurrentMPSStream()->commandQueue();
}
}
118 changes: 118 additions & 0 deletions examples/relu-metal-cpp/relu/relu.cpp
@@ -0,0 +1,118 @@
#define NS_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION

// Include metal-cpp headers from system
#include <Metal/Metal.hpp>
#include <Foundation/NSSharedPtr.hpp>

#include <torch/torch.h>

// C interface from metallib_loader.mm
extern "C" void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg);
extern "C" void* getMPSDevice();
extern "C" void* getMPSCommandQueue();

namespace {

MTL::Buffer* getMTLBuffer(const torch::Tensor& tensor) {
return reinterpret_cast<MTL::Buffer*>(const_cast<void*>(tensor.storage().data()));
}

NS::String* makeNSString(const std::string& value) {
return NS::String::string(value.c_str(), NS::StringEncoding::UTF8StringEncoding);
}

MTL::Library* loadLibrary(MTL::Device* device) {
const char* errorMsg = nullptr;
void* library = loadEmbeddedMetalLibrary(reinterpret_cast<void*>(device), &errorMsg);

TORCH_CHECK(library != nullptr, "Failed to create Metal library from embedded data: ",
errorMsg ? errorMsg : "Unknown error");

if (errorMsg) {
free(const_cast<char*>(errorMsg));
}

return reinterpret_cast<MTL::Library*>(library);
}

} // namespace

void dispatchReluKernel(const torch::Tensor& input, torch::Tensor& output) {
// Use PyTorch's MPS device and command queue (these are borrowed references, not owned)
MTL::Device* device = reinterpret_cast<MTL::Device*>(getMPSDevice());
TORCH_CHECK(device != nullptr, "Failed to get MPS device");

MTL::CommandQueue* commandQueue = reinterpret_cast<MTL::CommandQueue*>(getMPSCommandQueue());
TORCH_CHECK(commandQueue != nullptr, "Failed to get MPS command queue");

MTL::Library* libraryPtr = reinterpret_cast<MTL::Library*>(loadLibrary(device));
NS::SharedPtr<MTL::Library> library = NS::TransferPtr(libraryPtr);

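// Select the kernel variant by dtype; the name must match one of the [[kernel]]
// functions defined in relu_cpp.metal.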
const std::string kernelName =
std::string("relu_forward_kernel_") + (input.scalar_type() == torch::kFloat ? "float" : "half");
NS::SharedPtr<NS::String> kernelNameString = NS::TransferPtr(makeNSString(kernelName));

NS::SharedPtr<MTL::Function> computeFunction =
NS::TransferPtr(library->newFunction(kernelNameString.get()));
TORCH_CHECK(computeFunction.get() != nullptr, "Failed to create Metal function for ", kernelName);

NS::Error* pipelineError = nullptr;
NS::SharedPtr<MTL::ComputePipelineState> pipelineState =
NS::TransferPtr(device->newComputePipelineState(computeFunction.get(), &pipelineError));
TORCH_CHECK(pipelineState.get() != nullptr,
"Failed to create compute pipeline state: ",
pipelineError ? pipelineError->localizedDescription()->utf8String() : "Unknown error");

// Don't use SharedPtr for command buffer/encoder - they're managed by PyTorch's command queue
MTL::CommandBuffer* commandBuffer = commandQueue->commandBuffer();
TORCH_CHECK(commandBuffer != nullptr, "Failed to create Metal command buffer");

MTL::ComputeCommandEncoder* encoder = commandBuffer->computeCommandEncoder();
TORCH_CHECK(encoder != nullptr, "Failed to create compute command encoder");

encoder->setComputePipelineState(pipelineState.get());

auto* inputBuffer = getMTLBuffer(input);
auto* outputBuffer = getMTLBuffer(output);
TORCH_CHECK(inputBuffer != nullptr, "Input buffer is null");
TORCH_CHECK(outputBuffer != nullptr, "Output buffer is null");

encoder->setBuffer(inputBuffer, input.storage_offset() * input.element_size(), 0);
encoder->setBuffer(outputBuffer, output.storage_offset() * output.element_size(), 1);

const NS::UInteger totalThreads = input.numel();
NS::UInteger threadGroupSize = pipelineState->maxTotalThreadsPerThreadgroup();
if (threadGroupSize > totalThreads) {
threadGroupSize = totalThreads;
}

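// dispatchThreads performs a non-uniform dispatch, so the grid does not need to be
// padded up to a multiple of the threadgroup size.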
const MTL::Size gridSize = MTL::Size::Make(totalThreads, 1, 1);
const MTL::Size threadsPerThreadgroup = MTL::Size::Make(threadGroupSize, 1, 1);

encoder->dispatchThreads(gridSize, threadsPerThreadgroup);
encoder->endEncoding();

commandBuffer->commit();
}

void relu(torch::Tensor& out, const torch::Tensor& input) {
TORCH_CHECK(input.device().is_mps(), "input must be an MPS tensor");
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
TORCH_CHECK(input.scalar_type() == torch::kFloat || input.scalar_type() == torch::kHalf,
"Unsupported data type: ", input.scalar_type());

TORCH_CHECK(input.sizes() == out.sizes(),
"Tensors must have the same shape. Got input shape: ",
input.sizes(), " and output shape: ", out.sizes());

TORCH_CHECK(input.scalar_type() == out.scalar_type(),
"Tensors must have the same data type. Got input dtype: ",
input.scalar_type(), " and output dtype: ", out.scalar_type());

TORCH_CHECK(input.device() == out.device(),
"Tensors must be on the same device. Got input device: ",
input.device(), " and output device: ", out.device());

dispatchReluKernel(input, out);
}
17 changes: 17 additions & 0 deletions examples/relu-metal-cpp/relu/relu_cpp.metal
@@ -0,0 +1,17 @@
#include <metal_stdlib>
#include "common.h"
using namespace metal;

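// No bounds check is needed below: the host dispatches with dispatchThreads
// (non-uniform dispatch), so thread_position_in_grid never exceeds the element count.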
kernel void relu_forward_kernel_float(device const float *inA [[buffer(0)]],
device float *outC [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
// Explicitly write to output
outC[index] = max(RELU_THRESHOLD, inA[index]);
}

kernel void relu_forward_kernel_half(device const half *inA [[buffer(0)]],
device half *outC [[buffer(1)]],
uint index [[thread_position_in_grid]]) {
// Explicitly write to output
outC[index] = max(static_cast<half>(0.0), inA[index]);
}
Empty file.
19 changes: 19 additions & 0 deletions examples/relu-metal-cpp/tests/test_relu.py
@@ -0,0 +1,19 @@
import platform

import torch
import torch.nn.functional as F

import relu


def test_relu():
if platform.system() == "Darwin":
device = torch.device("mps")
elif hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device("xpu")
elif torch.version.cuda is not None and torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
x = torch.randn(1024, 1024, dtype=torch.float32, device=device)
torch.testing.assert_close(F.relu(x), relu.relu(x))
12 changes: 12 additions & 0 deletions examples/relu-metal-cpp/torch-ext/relu/__init__.py
@@ -0,0 +1,12 @@
from typing import Optional

import torch

from ._ops import ops


def relu(x: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
if out is None:
out = torch.empty_like(x)
ops.relu(out, x)
return out
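
A minimal usage sketch (hypothetical, assuming the built kernel is importable as relu and an MPS device is available):

import torch
import relu

x = torch.randn(1024, device="mps")
y = relu.relu(x)             # allocates a fresh output tensor
buf = torch.empty_like(x)
relu.relu(x, out=buf)        # writes into a preallocated tensor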
19 changes: 19 additions & 0 deletions examples/relu-metal-cpp/torch-ext/torch_binding.cpp
@@ -0,0 +1,19 @@
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("relu(Tensor! out, Tensor input) -> ()");
#if defined(CPU_KERNEL)
ops.impl("relu", torch::kCPU, &relu);
#elif defined(CUDA_KERNEL) || defined(ROCM_KERNEL)
ops.impl("relu", torch::kCUDA, &relu);
#elif defined(METAL_KERNEL)
ops.impl("relu", torch::kMPS, &relu);
#elif defined(XPU_KERNEL)
ops.impl("relu", torch::kXPU, &relu);
#endif
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
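
For reference, the op registered above can also be called straight through the dispatcher. A minimal sketch, assuming the extension were registered under the plain namespace relu; in a real build TORCH_EXTENSION_NAME is mangled, and the generated _ops module resolves the actual namespace:

import torch

x = torch.randn(16, device="mps")
out = torch.empty_like(x)
# Schema "relu(Tensor! out, Tensor input) -> ()": fills out in place.
torch.ops.relu.relu(out, x)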
5 changes: 5 additions & 0 deletions examples/relu-metal-cpp/torch-ext/torch_binding.h
@@ -0,0 +1,5 @@
#pragma once

#include <torch/torch.h>

void relu(torch::Tensor &out, torch::Tensor const &input);
18 changes: 9 additions & 9 deletions flake.lock

Some generated files are not rendered by default.

3 changes: 3 additions & 0 deletions lib/deps.nix
@@ -33,6 +33,9 @@ let
#torch.cxxdev
];
"cutlass_sycl" = [ torch.xpuPackages.cutlass-sycl ];
"metal-cpp" = [
pkgs.metal-cpp.dev
];
};
in
let
23 changes: 23 additions & 0 deletions test-kernel.py
@@ -0,0 +1,23 @@
# /// script
Member:

I think this is another development file that can be removed? (I think I missed it the first time.)

Collaborator Author:

ahh, I think I actually added it in 914ed28. Thanks for catching that!

Removed.

# requires-python = ">=3.10"
# dependencies = ["kernels", "torch", "numpy"]
# ///
from kernels import get_local_kernel
import torch
from pathlib import Path

relu = get_local_kernel(Path("examples/relu-metal-cpp/result"), "relu").relu

input = torch.tensor([-1.0, -1.5, 0.0, 2.0, 3.5], device="mps", dtype=torch.float16)
out = relu(input)
ref = torch.relu(input)

assert torch.allclose(out, ref), f"Float16 failed: {out} != {ref}"

print(out.cpu().numpy())
print(ref.cpu().numpy())

print("PASS")
# [0. 0. 0. 2. 3.5]
# [0. 0. 0. 2. 3.5]
# PASS