-
Notifications
You must be signed in to change notification settings - Fork 27
feat: support metal cpp #295
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 9 commits
a844fa0
8632abc
988ebd7
92782fa
914ed28
9a22bbb
67473e2
43aea1d
55a9904
1b4fe2a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| [general] | ||
| name = "relu" | ||
| universal = false | ||
|
|
||
| [torch] | ||
| src = [ | ||
| "torch-ext/torch_binding.cpp", | ||
| "torch-ext/torch_binding.h", | ||
| ] | ||
|
|
||
|
|
||
| [kernel.relu_metal] | ||
| backend = "metal" | ||
| src = [ | ||
| "relu/relu.cpp", | ||
| "relu/metallib_loader.mm", | ||
| "relu/relu_cpp.metal", | ||
| "relu/common.h", | ||
| ] | ||
| depends = [ "torch", "metal-cpp" ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| { | ||
| description = "Flake for ReLU metal cpp kernel"; | ||
|
|
||
| inputs = { | ||
| kernel-builder.url = "path:../.."; | ||
| }; | ||
|
|
||
| outputs = | ||
| { | ||
| self, | ||
| kernel-builder, | ||
| }: | ||
| kernel-builder.lib.genFlakeOutputs { | ||
| inherit self; | ||
| path = ./.; | ||
| }; | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| #pragma once | ||
|
|
||
| #include <metal_stdlib> | ||
| using namespace metal; | ||
|
|
||
| // Common constants and utilities for Metal kernels | ||
| constant float RELU_THRESHOLD = 0.0f; |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| #import <Metal/Metal.h> | ||
| #include <ATen/mps/MPSDevice.h> | ||
| #include <ATen/mps/MPSStream.h> | ||
|
|
||
| #ifdef EMBEDDED_METALLIB_HEADER | ||
| #include EMBEDDED_METALLIB_HEADER | ||
| #else | ||
| #error "EMBEDDED_METALLIB_HEADER not defined" | ||
| #endif | ||
|
|
||
| // C++ interface to load the embedded metallib without exposing ObjC types | ||
| extern "C" { | ||
| void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg) { | ||
| id<MTLDevice> mtlDevice = (__bridge id<MTLDevice>)device; | ||
| NSError* error = nil; | ||
|
|
||
| id<MTLLibrary> library = EMBEDDED_METALLIB_NAMESPACE::createLibrary(mtlDevice, &error); | ||
|
|
||
| if (!library && errorMsg && error) { | ||
| *errorMsg = strdup([error.localizedDescription UTF8String]); | ||
| } | ||
|
|
||
| // Manually retain since we're not using ARC | ||
| // The caller will wrap in NS::TransferPtr which assumes ownership | ||
| if (library) { | ||
| [library retain]; | ||
| } | ||
| return (__bridge void*)library; | ||
| } | ||
|
|
||
| // Get PyTorch's MPS device (returns id<MTLDevice> as void*) | ||
| void* getMPSDevice() { | ||
| return (__bridge void*)at::mps::MPSDevice::getInstance()->device(); | ||
| } | ||
|
|
||
| // Get PyTorch's current MPS command queue (returns id<MTLCommandQueue> as void*) | ||
| void* getMPSCommandQueue() { | ||
| return (__bridge void*)at::mps::getCurrentMPSStream()->commandQueue(); | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,118 @@ | ||
| #define NS_PRIVATE_IMPLEMENTATION | ||
| #define MTL_PRIVATE_IMPLEMENTATION | ||
|
|
||
| // Include metal-cpp headers from system | ||
| #include <Metal/Metal.hpp> | ||
| #include <Foundation/NSSharedPtr.hpp> | ||
|
|
||
| #include <torch/torch.h> | ||
|
|
||
| // C interface from metallib_loader.mm | ||
| extern "C" void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg); | ||
| extern "C" void* getMPSDevice(); | ||
| extern "C" void* getMPSCommandQueue(); | ||
|
|
||
| namespace { | ||
|
|
||
| MTL::Buffer* getMTLBuffer(const torch::Tensor& tensor) { | ||
| return reinterpret_cast<MTL::Buffer*>(const_cast<void*>(tensor.storage().data())); | ||
| } | ||
|
|
||
| NS::String* makeNSString(const std::string& value) { | ||
| return NS::String::string(value.c_str(), NS::StringEncoding::UTF8StringEncoding); | ||
| } | ||
|
|
||
| MTL::Library* loadLibrary(MTL::Device* device) { | ||
| const char* errorMsg = nullptr; | ||
| void* library = loadEmbeddedMetalLibrary(reinterpret_cast<void*>(device), &errorMsg); | ||
|
|
||
| TORCH_CHECK(library != nullptr, "Failed to create Metal library from embedded data: ", | ||
| errorMsg ? errorMsg : "Unknown error"); | ||
|
|
||
| if (errorMsg) { | ||
| free(const_cast<char*>(errorMsg)); | ||
| } | ||
|
|
||
| return reinterpret_cast<MTL::Library*>(library); | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| void dispatchReluKernel(const torch::Tensor& input, torch::Tensor& output) { | ||
| // Use PyTorch's MPS device and command queue (these are borrowed references, not owned) | ||
| MTL::Device* device = reinterpret_cast<MTL::Device*>(getMPSDevice()); | ||
| TORCH_CHECK(device != nullptr, "Failed to get MPS device"); | ||
|
|
||
| MTL::CommandQueue* commandQueue = reinterpret_cast<MTL::CommandQueue*>(getMPSCommandQueue()); | ||
| TORCH_CHECK(commandQueue != nullptr, "Failed to get MPS command queue"); | ||
|
|
||
| MTL::Library* libraryPtr = reinterpret_cast<MTL::Library*>(loadLibrary(device)); | ||
| NS::SharedPtr<MTL::Library> library = NS::TransferPtr(libraryPtr); | ||
|
|
||
| const std::string kernelName = | ||
| std::string("relu_forward_kernel_") + (input.scalar_type() == torch::kFloat ? "float" : "half"); | ||
| NS::SharedPtr<NS::String> kernelNameString = NS::TransferPtr(makeNSString(kernelName)); | ||
|
|
||
| NS::SharedPtr<MTL::Function> computeFunction = | ||
| NS::TransferPtr(library->newFunction(kernelNameString.get())); | ||
| TORCH_CHECK(computeFunction.get() != nullptr, "Failed to create Metal function for ", kernelName); | ||
|
|
||
| NS::Error* pipelineError = nullptr; | ||
| NS::SharedPtr<MTL::ComputePipelineState> pipelineState = | ||
| NS::TransferPtr(device->newComputePipelineState(computeFunction.get(), &pipelineError)); | ||
| TORCH_CHECK(pipelineState.get() != nullptr, | ||
| "Failed to create compute pipeline state: ", | ||
| pipelineError ? pipelineError->localizedDescription()->utf8String() : "Unknown error"); | ||
|
|
||
| // Don't use SharedPtr for command buffer/encoder - they're managed by PyTorch's command queue | ||
| MTL::CommandBuffer* commandBuffer = commandQueue->commandBuffer(); | ||
| TORCH_CHECK(commandBuffer != nullptr, "Failed to create Metal command buffer"); | ||
|
|
||
| MTL::ComputeCommandEncoder* encoder = commandBuffer->computeCommandEncoder(); | ||
| TORCH_CHECK(encoder != nullptr, "Failed to create compute command encoder"); | ||
|
|
||
| encoder->setComputePipelineState(pipelineState.get()); | ||
|
|
||
| auto* inputBuffer = getMTLBuffer(input); | ||
| auto* outputBuffer = getMTLBuffer(output); | ||
| TORCH_CHECK(inputBuffer != nullptr, "Input buffer is null"); | ||
| TORCH_CHECK(outputBuffer != nullptr, "Output buffer is null"); | ||
|
|
||
| encoder->setBuffer(inputBuffer, input.storage_offset() * input.element_size(), 0); | ||
| encoder->setBuffer(outputBuffer, output.storage_offset() * output.element_size(), 1); | ||
|
|
||
| const NS::UInteger totalThreads = input.numel(); | ||
| NS::UInteger threadGroupSize = pipelineState->maxTotalThreadsPerThreadgroup(); | ||
| if (threadGroupSize > totalThreads) { | ||
| threadGroupSize = totalThreads; | ||
| } | ||
|
|
||
| const MTL::Size gridSize = MTL::Size::Make(totalThreads, 1, 1); | ||
| const MTL::Size threadsPerThreadgroup = MTL::Size::Make(threadGroupSize, 1, 1); | ||
|
|
||
| encoder->dispatchThreads(gridSize, threadsPerThreadgroup); | ||
| encoder->endEncoding(); | ||
|
|
||
| commandBuffer->commit(); | ||
| } | ||
|
|
||
| void relu(torch::Tensor& out, const torch::Tensor& input) { | ||
| TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor"); | ||
| TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); | ||
| TORCH_CHECK(input.scalar_type() == torch::kFloat || input.scalar_type() == torch::kHalf, | ||
| "Unsupported data type: ", input.scalar_type()); | ||
|
|
||
| TORCH_CHECK(input.sizes() == out.sizes(), | ||
| "Tensors must have the same shape. Got input shape: ", | ||
| input.sizes(), " and output shape: ", out.sizes()); | ||
|
|
||
| TORCH_CHECK(input.scalar_type() == out.scalar_type(), | ||
| "Tensors must have the same data type. Got input dtype: ", | ||
| input.scalar_type(), " and output dtype: ", out.scalar_type()); | ||
|
|
||
| TORCH_CHECK(input.device() == out.device(), | ||
| "Tensors must be on the same device. Got input device: ", | ||
| input.device(), " and output device: ", out.device()); | ||
|
|
||
| dispatchReluKernel(input, out); | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| #include <metal_stdlib> | ||
| #include "common.h" | ||
| using namespace metal; | ||
|
|
||
| kernel void relu_forward_kernel_float(device const float *inA [[buffer(0)]], | ||
| device float *outC [[buffer(1)]], | ||
| uint index [[thread_position_in_grid]]) { | ||
| // Explicitly write to output | ||
| outC[index] = max(RELU_THRESHOLD, inA[index]); | ||
| } | ||
|
|
||
| kernel void relu_forward_kernel_half(device const half *inA [[buffer(0)]], | ||
| device half *outC [[buffer(1)]], | ||
| uint index [[thread_position_in_grid]]) { | ||
| // Explicitly write to output | ||
| outC[index] = max(static_cast<half>(0.0), inA[index]); | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| import platform | ||
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
|
|
||
| import relu | ||
|
|
||
|
|
||
| def test_relu(): | ||
| if platform.system() == "Darwin": | ||
| device = torch.device("mps") | ||
| elif hasattr(torch, "xpu") and torch.xpu.is_available(): | ||
| device = torch.device("xpu") | ||
| elif torch.version.cuda is not None and torch.cuda.is_available(): | ||
| device = torch.device("cuda") | ||
| else: | ||
| device = torch.device("cpu") | ||
| x = torch.randn(1024, 1024, dtype=torch.float32, device=device) | ||
| torch.testing.assert_allclose(F.relu(x), relu.relu(x)) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| from typing import Optional | ||
|
|
||
| import torch | ||
|
|
||
| from ._ops import ops | ||
|
|
||
|
|
||
| def relu(x: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor: | ||
| if out is None: | ||
| out = torch.empty_like(x) | ||
| ops.relu(out, x) | ||
| return out |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,19 @@ | ||
| #include <torch/library.h> | ||
|
|
||
| #include "registration.h" | ||
| #include "torch_binding.h" | ||
|
|
||
| TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | ||
| ops.def("relu(Tensor! out, Tensor input) -> ()"); | ||
| #if defined(CPU_KERNEL) | ||
| ops.impl("relu", torch::kCPU, &relu); | ||
| #elif defined(CUDA_KERNEL) || defined(ROCM_KERNEL) | ||
| ops.impl("relu", torch::kCUDA, &relu); | ||
| #elif defined(METAL_KERNEL) | ||
| ops.impl("relu", torch::kMPS, relu); | ||
| #elif defined(XPU_KERNEL) | ||
| ops.impl("relu", torch::kXPU, &relu); | ||
| #endif | ||
| } | ||
|
|
||
| REGISTER_EXTENSION(TORCH_EXTENSION_NAME) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| #pragma once | ||
|
|
||
| #include <torch/torch.h> | ||
|
|
||
| void relu(torch::Tensor &out, torch::Tensor const &input); |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| # /// script | ||
|
||
| # requires-python = ">=3.10" | ||
| # dependencies = ["kernels", "torch", "numpy"] | ||
| # /// | ||
| from kernels import get_local_kernel | ||
| import torch | ||
| from pathlib import Path | ||
|
|
||
| relu = get_local_kernel(Path("examples/relu-metal-cpp/result"), "relu").relu | ||
|
|
||
| input = torch.tensor([-1.0, -1.5, 0.0, 2.0, 3.5], device="mps", dtype=torch.float16) | ||
| out = relu(input) | ||
| ref = torch.relu(input) | ||
|
|
||
| assert torch.allclose(out, ref), f"Float16 failed: {out} != {ref}" | ||
|
|
||
| print(out.cpu().numpy()) | ||
| print(ref.cpu().numpy()) | ||
|
|
||
| print("PASS") | ||
| # [0. 0. 0. 2. 3.5] | ||
| # [0. 0. 0. 2. 3.5] | ||
| # PASS | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need a version like cutlass? My first guess is not, since on Mac we always have everything the latest, but I thought I'd check.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good question! I agree we probably do not need a version here and we'll always prefer latest