From 1caa2eb390f4b53495bfc3e9bc351305e9f09a51 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 5 Nov 2025 16:02:18 -0500 Subject: [PATCH 1/2] feat: add metal-cpp as a known dependency and relu metal cpp example --- build2cmake/src/config/v2.rs | 2 + examples/relu/build.toml | 18 +- examples/relu/flake.lock | 165 ++++++++++++++++++ examples/relu/relu_metal_cpp/common.h | 10 ++ .../relu/relu_metal_cpp/metallib_loader.mm | 24 +++ examples/relu/relu_metal_cpp/relu.cpp | 118 +++++++++++++ examples/relu/relu_metal_cpp/relu_cpp.metal | 17 ++ flake.lock | 19 +- flake.nix | 2 +- lib/deps.nix | 3 + lib/torch-extension/arch.nix | 31 +++- 11 files changed, 392 insertions(+), 17 deletions(-) create mode 100644 examples/relu/flake.lock create mode 100644 examples/relu/relu_metal_cpp/common.h create mode 100644 examples/relu/relu_metal_cpp/metallib_loader.mm create mode 100644 examples/relu/relu_metal_cpp/relu.cpp create mode 100644 examples/relu/relu_metal_cpp/relu_cpp.metal diff --git a/build2cmake/src/config/v2.rs b/build2cmake/src/config/v2.rs index ecbdd9ec..0f8457e9 100644 --- a/build2cmake/src/config/v2.rs +++ b/build2cmake/src/config/v2.rs @@ -247,6 +247,8 @@ pub enum Dependencies { Cutlass4_0, #[serde(rename = "cutlass_sycl")] CutlassSycl, + #[serde(rename = "metal-cpp")] + MetalCpp, Torch, } diff --git a/examples/relu/build.toml b/examples/relu/build.toml index 84eb068a..238711b5 100644 --- a/examples/relu/build.toml +++ b/examples/relu/build.toml @@ -13,14 +13,24 @@ backend = "cuda" depends = ["torch"] src = ["relu_cuda/relu.cu"] +# [kernel.relu_metal] +# backend = "metal" +# src = [ +# "relu_metal/relu.mm", +# "relu_metal/relu.metal", +# "relu_metal/common.h", +# ] +# depends = [ "torch" ] + [kernel.relu_metal] backend = "metal" src = [ - "relu_metal/relu.mm", - "relu_metal/relu.metal", - "relu_metal/common.h", + "relu_metal_cpp/relu.cpp", + "relu_metal_cpp/metallib_loader.mm", + "relu_metal_cpp/relu_cpp.metal", + "relu_metal_cpp/common.h", ] -depends = [ "torch" ] +depends = [ "torch", "metal-cpp" ] [kernel.relu_rocm] backend = "rocm" diff --git a/examples/relu/flake.lock b/examples/relu/flake.lock new file mode 100644 index 00000000..c563093e --- /dev/null +++ b/examples/relu/flake.lock @@ -0,0 +1,165 @@ +{ + "nodes": { + "flake-compat": { + "locked": { + "lastModified": 1761588595, + "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-compat_2": { + "locked": { + "lastModified": 1761588595, + "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_2": { + "inputs": { + "systems": "systems_2" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + 
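The metal-cpp dependency introduced above is wired in twice: as a Dependencies::MetalCpp variant in build2cmake and as a depends entry in the example's build.toml, so the Metal kernel can be written against the header-only C++ bindings rather than Objective-C++. One property of metal-cpp worth keeping in mind when reading relu.cpp later in this patch: exactly one translation unit must define the *_PRIVATE_IMPLEMENTATION macros before including the headers, otherwise the binding's class and selector symbols are either missing or defined twice at link time. A minimal, self-contained sketch (illustrative only, not part of this patch; it assumes the metal-cpp headers are on the include path, which is what the new dependency provides):

    // Exactly one .cpp in the target defines these before including metal-cpp,
    // so the header-only bindings emit their implementation symbols once.
    #define NS_PRIVATE_IMPLEMENTATION
    #define MTL_PRIVATE_IMPLEMENTATION
    #include <Foundation/Foundation.hpp>
    #include <Metal/Metal.hpp>

    #include <cstdio>

    int main() {
        // Query the default GPU; on Apple Silicon this is the unified-memory device.
        MTL::Device* device = MTL::CreateSystemDefaultDevice();
        if (device == nullptr) {
            std::fprintf(stderr, "no Metal device available\n");
            return 1;
        }
        std::printf("Metal device: %s\n", device->name()->utf8String());
        device->release();  // CreateSystemDefaultDevice returns an owned (+1) reference
        return 0;
    }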
"repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "hf-nix": { + "inputs": { + "flake-compat": "flake-compat_2", + "flake-utils": "flake-utils_2", + "nixpkgs": "nixpkgs" + }, + "locked": { + "lastModified": 1762375532, + "narHash": "sha256-6JLEGsTvjrb9ZV2N7SFO1lNEmgx/bWFylDneyw81kb4=", + "owner": "huggingface", + "repo": "hf-nix", + "rev": "71b94809ae4a2bdd5d847a5873c511e2fe2c2a95", + "type": "github" + }, + "original": { + "owner": "huggingface", + "ref": "add-metal-cpp-package", + "repo": "hf-nix", + "type": "github" + } + }, + "kernel-builder": { + "inputs": { + "flake-compat": "flake-compat", + "flake-utils": "flake-utils", + "hf-nix": "hf-nix", + "nixpkgs": [ + "kernel-builder", + "hf-nix", + "nixpkgs" + ] + }, + "locked": { + "path": "../..", + "type": "path" + }, + "original": { + "path": "../..", + "type": "path" + }, + "parent": [] + }, + "nixpkgs": { + "locked": { + "lastModified": 1762328495, + "narHash": "sha256-IUZvw5kvLiExApP9+SK/styzEKSqfe0NPclu9/z85OQ=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "4c621660e393922cf68cdbfc40eb5a2d54d3989a", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixos-unstable-small", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "kernel-builder": "kernel-builder" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/examples/relu/relu_metal_cpp/common.h b/examples/relu/relu_metal_cpp/common.h new file mode 100644 index 00000000..1b891fad --- /dev/null +++ b/examples/relu/relu_metal_cpp/common.h @@ -0,0 +1,10 @@ +#ifndef COMMON_H +#define COMMON_H + +#include +using namespace metal; + +// Common constants and utilities for Metal kernels +constant float RELU_THRESHOLD = 0.0f; + +#endif // COMMON_H \ No newline at end of file diff --git a/examples/relu/relu_metal_cpp/metallib_loader.mm b/examples/relu/relu_metal_cpp/metallib_loader.mm new file mode 100644 index 00000000..be5bc836 --- /dev/null +++ b/examples/relu/relu_metal_cpp/metallib_loader.mm @@ -0,0 +1,24 @@ +#import +#import + +#ifdef EMBEDDED_METALLIB_HEADER +#include EMBEDDED_METALLIB_HEADER +#else +#error "EMBEDDED_METALLIB_HEADER not defined" +#endif + +// C++ interface to load the embedded metallib without exposing ObjC types +extern "C" { + void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg) { + id mtlDevice = (__bridge id)device; + NSError* error = nil; + + id library = EMBEDDED_METALLIB_NAMESPACE::createLibrary(mtlDevice, &error); + + if (!library && errorMsg && error) { + *errorMsg = strdup([error.localizedDescription UTF8String]); + } + + return (__bridge void*)library; + } +} diff --git a/examples/relu/relu_metal_cpp/relu.cpp b/examples/relu/relu_metal_cpp/relu.cpp new file mode 100644 
index 00000000..68acfc89 --- /dev/null +++ b/examples/relu/relu_metal_cpp/relu.cpp @@ -0,0 +1,118 @@ +#define NS_PRIVATE_IMPLEMENTATION +#define MTL_PRIVATE_IMPLEMENTATION + +// Include metal-cpp headers from system +#include +#include +#include + +#include + +// C interface from metallib_loader.mm +extern "C" void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg); + +namespace { + +MTL::Buffer* getMTLBuffer(const torch::Tensor& tensor) { + return reinterpret_cast(const_cast(tensor.storage().data())); +} + +NS::String* makeNSString(const std::string& value) { + return NS::String::string(value.c_str(), NS::StringEncoding::UTF8StringEncoding); +} + +MTL::Library* loadLibrary(MTL::Device* device) { + const char* errorMsg = nullptr; + void* library = loadEmbeddedMetalLibrary(reinterpret_cast(device), &errorMsg); + + TORCH_CHECK(library != nullptr, "Failed to create Metal library from embedded data: ", + errorMsg ? errorMsg : "Unknown error"); + + if (errorMsg) { + free(const_cast(errorMsg)); + } + + return reinterpret_cast(library); +} + +} // namespace + +torch::Tensor& dispatchReluKernel(const torch::Tensor& input, torch::Tensor& output) { + NS::SharedPtr pool = NS::TransferPtr(NS::AutoreleasePool::alloc()->init()); + + NS::SharedPtr device = NS::TransferPtr(MTL::CreateSystemDefaultDevice()); + TORCH_CHECK(device.get() != nullptr, "Failed to create Metal device"); + + NS::SharedPtr commandQueue = NS::TransferPtr(device->newCommandQueue()); + TORCH_CHECK(commandQueue.get() != nullptr, "Failed to create Metal command queue"); + + NS::SharedPtr library = NS::TransferPtr(loadLibrary(device.get())); + + const std::string kernelName = + std::string("relu_forward_kernel_") + (input.scalar_type() == torch::kFloat ? "float" : "half"); + NS::SharedPtr kernelNameString = NS::TransferPtr(makeNSString(kernelName)); + + NS::SharedPtr computeFunction = + NS::TransferPtr(library->newFunction(kernelNameString.get())); + TORCH_CHECK(computeFunction.get() != nullptr, "Failed to create Metal function for ", kernelName); + + NS::Error* pipelineError = nullptr; + NS::SharedPtr pipelineState = + NS::TransferPtr(device->newComputePipelineState(computeFunction.get(), &pipelineError)); + TORCH_CHECK(pipelineState.get() != nullptr, + "Failed to create compute pipeline state: ", + pipelineError ? 
pipelineError->localizedDescription()->utf8String() : "Unknown error"); + + NS::SharedPtr commandBuffer = NS::TransferPtr(commandQueue->commandBuffer()); + TORCH_CHECK(commandBuffer.get() != nullptr, "Failed to create Metal command buffer"); + + NS::SharedPtr encoder = + NS::TransferPtr(commandBuffer->computeCommandEncoder()); + TORCH_CHECK(encoder.get() != nullptr, "Failed to create compute command encoder"); + + encoder->setComputePipelineState(pipelineState.get()); + + auto* inputBuffer = getMTLBuffer(input); + auto* outputBuffer = getMTLBuffer(output); + + encoder->setBuffer(inputBuffer, input.storage_offset() * input.element_size(), 0); + encoder->setBuffer(outputBuffer, output.storage_offset() * output.element_size(), 1); + + const NS::UInteger totalThreads = input.numel(); + NS::UInteger threadGroupSize = pipelineState->maxTotalThreadsPerThreadgroup(); + if (threadGroupSize > totalThreads) { + threadGroupSize = totalThreads; + } + + const MTL::Size gridSize = MTL::Size::Make(totalThreads, 1, 1); + const MTL::Size threadsPerThreadgroup = MTL::Size::Make(threadGroupSize, 1, 1); + + encoder->dispatchThreads(gridSize, threadsPerThreadgroup); + encoder->endEncoding(); + + commandBuffer->commit(); + commandBuffer->waitUntilCompleted(); + + return output; +} + +void relu(torch::Tensor& out, const torch::Tensor& input) { + TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor"); + TORCH_CHECK(input.is_contiguous(), "input must be contiguous"); + TORCH_CHECK(input.scalar_type() == torch::kFloat || input.scalar_type() == torch::kHalf, + "Unsupported data type: ", input.scalar_type()); + + TORCH_CHECK(input.sizes() == out.sizes(), + "Tensors must have the same shape. Got input shape: ", + input.sizes(), " and output shape: ", out.sizes()); + + TORCH_CHECK(input.scalar_type() == out.scalar_type(), + "Tensors must have the same data type. Got input dtype: ", + input.scalar_type(), " and output dtype: ", out.scalar_type()); + + TORCH_CHECK(input.device() == out.device(), + "Tensors must be on the same device. 
Got input device: ", + input.device(), " and output device: ", out.device()); + + dispatchReluKernel(input, out); +} diff --git a/examples/relu/relu_metal_cpp/relu_cpp.metal b/examples/relu/relu_metal_cpp/relu_cpp.metal new file mode 100644 index 00000000..969ec170 --- /dev/null +++ b/examples/relu/relu_metal_cpp/relu_cpp.metal @@ -0,0 +1,17 @@ +#include +#include "common.h" +using namespace metal; + +kernel void relu_forward_kernel_float(device const float *inA [[buffer(0)]], + device float *outC [[buffer(1)]], + uint index [[thread_position_in_grid]]) { + // Explicitly write to output + outC[index] = max(RELU_THRESHOLD, inA[index]); +} + +kernel void relu_forward_kernel_half(device const half *inA [[buffer(0)]], + device half *outC [[buffer(1)]], + uint index [[thread_position_in_grid]]) { + // Explicitly write to output + outC[index] = max(static_cast(0.0), inA[index]); +} diff --git a/flake.lock b/flake.lock index e7602650..9e99b726 100644 --- a/flake.lock +++ b/flake.lock @@ -17,11 +17,11 @@ }, "flake-compat_2": { "locked": { - "lastModified": 1747046372, - "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=", + "lastModified": 1761588595, + "narHash": "sha256-XKUZz9zewJNUj46b4AJdiRZJAvSZ0Dqj2BNfXvFlJC4=", "owner": "edolstra", "repo": "flake-compat", - "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885", + "rev": "f387cd2afec9419c8ee37694406ca490c3f34ee5", "type": "github" }, "original": { @@ -73,26 +73,27 @@ "nixpkgs": "nixpkgs" }, "locked": { - "lastModified": 1762268370, - "narHash": "sha256-gf3TJcaiHdw3dvLL7RF6hc/5BLzQDQj5oakFrKZkOZo=", + "lastModified": 1762375532, + "narHash": "sha256-6JLEGsTvjrb9ZV2N7SFO1lNEmgx/bWFylDneyw81kb4=", "owner": "huggingface", "repo": "hf-nix", - "rev": "25c23c765a907d1a5528c5ce65c58a73e974603f", + "rev": "71b94809ae4a2bdd5d847a5873c511e2fe2c2a95", "type": "github" }, "original": { "owner": "huggingface", + "ref": "add-metal-cpp-package", "repo": "hf-nix", "type": "github" } }, "nixpkgs": { "locked": { - "lastModified": 1755963616, - "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=", + "lastModified": 1762328495, + "narHash": "sha256-IUZvw5kvLiExApP9+SK/styzEKSqfe0NPclu9/z85OQ=", "owner": "nixos", "repo": "nixpkgs", - "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4", + "rev": "4c621660e393922cf68cdbfc40eb5a2d54d3989a", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 081180f2..63caacef 100644 --- a/flake.nix +++ b/flake.nix @@ -5,7 +5,7 @@ flake-utils.url = "github:numtide/flake-utils"; nixpkgs.follows = "hf-nix/nixpkgs"; flake-compat.url = "github:edolstra/flake-compat"; - hf-nix.url = "github:huggingface/hf-nix"; + hf-nix.url = "github:huggingface/hf-nix/add-metal-cpp-package"; }; outputs = diff --git a/lib/deps.nix b/lib/deps.nix index 9e8c6c81..2fb4e4fa 100644 --- a/lib/deps.nix +++ b/lib/deps.nix @@ -33,6 +33,9 @@ let #torch.cxxdev ]; "cutlass_sycl" = [ torch.xpuPackages.cutlass-sycl ]; + "metal-cpp" = lib.optionals pkgs.stdenv.hostPlatform.isDarwin [ + pkgs.metal-cpp.dev + ]; }; in let diff --git a/lib/torch-extension/arch.nix b/lib/torch-extension/arch.nix index 0d5468c6..ea38e6d5 100644 --- a/lib/torch-extension/arch.nix +++ b/lib/torch-extension/arch.nix @@ -25,6 +25,7 @@ # Build inputs apple-sdk_15, + metal-cpp, clr, oneapi-torch-dev, onednn-xpu, @@ -66,9 +67,33 @@ let # On Darwin, we need the host's xcrun for `xcrun metal` to compile Metal shaders. # It's not supported by the nixpkgs shim. xcrunHost = writeScriptBin "xcrunHost" '' - # Use system SDK for Metal files. 
- unset DEVELOPER_DIR - /usr/bin/xcrun $@ + # When called with '-sdk macosx metal/metallib', call the tool directly to avoid SDK issues + # Check for metallib first as it's more specific + if [[ "$*" =~ "metallib" ]]; then + # Find the metallib linker (air-lld) from the Metal toolchain + METALLIB_BIN=$(ls /var/run/com.apple.security.cryptexd/mnt/com.apple.MobileAsset.MetalToolchain*/Metal.xctoolchain/usr/bin/air-lld 2>/dev/null | head -n 1) + if [ -z "$METALLIB_BIN" ]; then + echo "Error: metallib (air-lld) not found" >&2 + exit 1 + fi + # Remove only '-sdk macosx metallib' as command arguments + ARGS=$(echo "$@" | sed 's/-sdk macosx metallib //') + $METALLIB_BIN $ARGS + elif [[ "$*" =~ "metal" ]]; then + # Find the metal compiler from the Metal toolchain + METAL_BIN=$(ls /var/run/com.apple.security.cryptexd/mnt/com.apple.MobileAsset.MetalToolchain*/Metal.xctoolchain/usr/bin/metal 2>/dev/null | head -n 1) + if [ -z "$METAL_BIN" ]; then + echo "Error: Metal compiler not found" >&2 + exit 1 + fi + # Remove only '-sdk macosx metal' as command arguments + ARGS=$(echo "$@" | sed 's/-sdk macosx metal //') + $METAL_BIN $ARGS + else + # For other commands, use system SDK + unset DEVELOPER_DIR + /usr/bin/xcrun $@ + fi ''; metalSupport = buildConfig.metal or false; From ebcf7b11b4bfa0bdcb4cd8dd3f48160275c123e5 Mon Sep 17 00:00:00 2001 From: drbh Date: Thu, 6 Nov 2025 13:32:05 -0500 Subject: [PATCH 2/2] fix: adjust cpp driver to correctly process data --- .../relu/relu_metal_cpp/metallib_loader.mm | 17 ++++++++++ examples/relu/relu_metal_cpp/relu.cpp | 33 ++++++++++--------- lib/torch-extension/arch.nix | 25 +++++++++----- 3 files changed, 50 insertions(+), 25 deletions(-) diff --git a/examples/relu/relu_metal_cpp/metallib_loader.mm b/examples/relu/relu_metal_cpp/metallib_loader.mm index be5bc836..050ee791 100644 --- a/examples/relu/relu_metal_cpp/metallib_loader.mm +++ b/examples/relu/relu_metal_cpp/metallib_loader.mm @@ -1,5 +1,7 @@ #import #import +#include +#include #ifdef EMBEDDED_METALLIB_HEADER #include EMBEDDED_METALLIB_HEADER @@ -19,6 +21,21 @@ *errorMsg = strdup([error.localizedDescription UTF8String]); } + // Manually retain since we're not using ARC + // The caller will wrap in NS::TransferPtr which assumes ownership + if (library) { + [library retain]; + } return (__bridge void*)library; } + + // Get PyTorch's MPS device (returns id as void*) + void* getMPSDevice() { + return (__bridge void*)at::mps::MPSDevice::getInstance()->device(); + } + + // Get PyTorch's current MPS command queue (returns id as void*) + void* getMPSCommandQueue() { + return (__bridge void*)at::mps::getCurrentMPSStream()->commandQueue(); + } } diff --git a/examples/relu/relu_metal_cpp/relu.cpp b/examples/relu/relu_metal_cpp/relu.cpp index 68acfc89..c07ad544 100644 --- a/examples/relu/relu_metal_cpp/relu.cpp +++ b/examples/relu/relu_metal_cpp/relu.cpp @@ -10,6 +10,8 @@ // C interface from metallib_loader.mm extern "C" void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg); +extern "C" void* getMPSDevice(); +extern "C" void* getMPSCommandQueue(); namespace { @@ -37,16 +39,16 @@ MTL::Library* loadLibrary(MTL::Device* device) { } // namespace -torch::Tensor& dispatchReluKernel(const torch::Tensor& input, torch::Tensor& output) { - NS::SharedPtr pool = NS::TransferPtr(NS::AutoreleasePool::alloc()->init()); +void dispatchReluKernel(const torch::Tensor& input, torch::Tensor& output) { + // Use PyTorch's MPS device and command queue (these are borrowed references, not owned) + MTL::Device* device = 
reinterpret_cast(getMPSDevice()); + TORCH_CHECK(device != nullptr, "Failed to get MPS device"); - NS::SharedPtr device = NS::TransferPtr(MTL::CreateSystemDefaultDevice()); - TORCH_CHECK(device.get() != nullptr, "Failed to create Metal device"); + MTL::CommandQueue* commandQueue = reinterpret_cast(getMPSCommandQueue()); + TORCH_CHECK(commandQueue != nullptr, "Failed to get MPS command queue"); - NS::SharedPtr commandQueue = NS::TransferPtr(device->newCommandQueue()); - TORCH_CHECK(commandQueue.get() != nullptr, "Failed to create Metal command queue"); - - NS::SharedPtr library = NS::TransferPtr(loadLibrary(device.get())); + MTL::Library* libraryPtr = reinterpret_cast(loadLibrary(device)); + NS::SharedPtr library = NS::TransferPtr(libraryPtr); const std::string kernelName = std::string("relu_forward_kernel_") + (input.scalar_type() == torch::kFloat ? "float" : "half"); @@ -63,17 +65,19 @@ torch::Tensor& dispatchReluKernel(const torch::Tensor& input, torch::Tensor& out "Failed to create compute pipeline state: ", pipelineError ? pipelineError->localizedDescription()->utf8String() : "Unknown error"); - NS::SharedPtr commandBuffer = NS::TransferPtr(commandQueue->commandBuffer()); - TORCH_CHECK(commandBuffer.get() != nullptr, "Failed to create Metal command buffer"); + // Don't use SharedPtr for command buffer/encoder - they're managed by PyTorch's command queue + MTL::CommandBuffer* commandBuffer = commandQueue->commandBuffer(); + TORCH_CHECK(commandBuffer != nullptr, "Failed to create Metal command buffer"); - NS::SharedPtr encoder = - NS::TransferPtr(commandBuffer->computeCommandEncoder()); - TORCH_CHECK(encoder.get() != nullptr, "Failed to create compute command encoder"); + MTL::ComputeCommandEncoder* encoder = commandBuffer->computeCommandEncoder(); + TORCH_CHECK(encoder != nullptr, "Failed to create compute command encoder"); encoder->setComputePipelineState(pipelineState.get()); auto* inputBuffer = getMTLBuffer(input); auto* outputBuffer = getMTLBuffer(output); + TORCH_CHECK(inputBuffer != nullptr, "Input buffer is null"); + TORCH_CHECK(outputBuffer != nullptr, "Output buffer is null"); encoder->setBuffer(inputBuffer, input.storage_offset() * input.element_size(), 0); encoder->setBuffer(outputBuffer, output.storage_offset() * output.element_size(), 1); @@ -91,9 +95,6 @@ torch::Tensor& dispatchReluKernel(const torch::Tensor& input, torch::Tensor& out encoder->endEncoding(); commandBuffer->commit(); - commandBuffer->waitUntilCompleted(); - - return output; } void relu(torch::Tensor& out, const torch::Tensor& input) { diff --git a/lib/torch-extension/arch.nix b/lib/torch-extension/arch.nix index ea38e6d5..a1c81509 100644 --- a/lib/torch-extension/arch.nix +++ b/lib/torch-extension/arch.nix @@ -67,30 +67,37 @@ let # On Darwin, we need the host's xcrun for `xcrun metal` to compile Metal shaders. # It's not supported by the nixpkgs shim. 
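The [library retain] added to metallib_loader.mm in this second patch is the ownership fix that makes the bridge sound: __bridge alone hands back a +0 pointer, while NS::TransferPtr on the C++ side adopts whatever it receives as a +1 reference and releases it when the SharedPtr is destroyed. A small sketch of the receiving side under that contract (the extern "C" prototype matches the loader in this example; std::free pairs with the strdup on the Objective-C side):

    #include <Foundation/Foundation.hpp>
    #include <Metal/Metal.hpp>
    #include <cstdlib>

    // C bridge defined in metallib_loader.mm; after this patch it returns a retained
    // id<MTLLibrary> as a void*.
    extern "C" void* loadEmbeddedMetalLibrary(void* device, const char** errorMsg);

    NS::SharedPtr<MTL::Library> adoptEmbeddedLibrary(MTL::Device* device) {
        const char* errorMsg = nullptr;
        auto* raw = static_cast<MTL::Library*>(loadEmbeddedMetalLibrary(device, &errorMsg));
        if (errorMsg != nullptr) {
            std::free(const_cast<char*>(errorMsg));  // strdup'd on the Objective-C side
        }
        // TransferPtr adopts the +1 reference produced by the retain in the loader;
        // NS::RetainPtr would be the right choice for a borrowed (+0) pointer instead.
        return NS::TransferPtr(raw);
    }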
   xcrunHost = writeScriptBin "xcrunHost" ''
-    # When called with '-sdk macosx metal/metallib', call the tool directly to avoid SDK issues
-    # Check for metallib first as it's more specific
+    echo "Calling command: $*"
+
+    # Check if we are invoking metallib or metal
     if [[ "$*" =~ "metallib" ]]; then
-      # Find the metallib linker (air-lld) from the Metal toolchain
+
+      # If metallib is requested, find the air-lld from the Metal toolchain
       METALLIB_BIN=$(ls /var/run/com.apple.security.cryptexd/mnt/com.apple.MobileAsset.MetalToolchain*/Metal.xctoolchain/usr/bin/air-lld 2>/dev/null | head -n 1)
       if [ -z "$METALLIB_BIN" ]; then
         echo "Error: metallib (air-lld) not found" >&2
         exit 1
       fi
-      # Remove only '-sdk macosx metallib' as command arguments
-      ARGS=$(echo "$@" | sed 's/-sdk macosx metallib //')
-      $METALLIB_BIN $ARGS
+
+      # Remove the '-sdk macosx metallib' and other unsupported flags from the command arguments
+      ARGS=$(echo "$@" | sed 's/-sdk macosx metallib //' | sed 's/-mmacosx-version-min=[^ ]* //')
+      # Add platform version for macOS 15+ to support Metal 3.2 / AIR 2.7
+      $METALLIB_BIN -platform_version macos 15.0 15.0 $ARGS
+
     elif [[ "$*" =~ "metal" ]]; then
-      # Find the metal compiler from the Metal toolchain
+
+      # If metal is requested, find the metal compiler from the Metal toolchain
       METAL_BIN=$(ls /var/run/com.apple.security.cryptexd/mnt/com.apple.MobileAsset.MetalToolchain*/Metal.xctoolchain/usr/bin/metal 2>/dev/null | head -n 1)
       if [ -z "$METAL_BIN" ]; then
         echo "Error: Metal compiler not found" >&2
         exit 1
       fi
-      # Remove only '-sdk macosx metal' as command arguments
+
+      # Remove the '-sdk macosx metal' from the command arguments
       ARGS=$(echo "$@" | sed 's/-sdk macosx metal //')
       $METAL_BIN $ARGS
     else
-      # For other commands, use system SDK
+      # In all other cases, just use the host xcrun
       unset DEVELOPER_DIR
       /usr/bin/xcrun $@
     fi
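End to end, the example can be exercised from libtorch once the extension is built; after this second patch the kernel encodes onto PyTorch's shared MPS stream, so it is ordered with other MPS work and torch::mps::synchronize() is enough to wait for the result. A small smoke test, assuming the relu(out, input) prototype from relu_metal_cpp/relu.cpp is visible to the caller (how it is exported to Python or C++ callers is outside this diff):

    #include <torch/torch.h>

    // Entry point defined in relu_metal_cpp/relu.cpp; the exact header or binding
    // through which it is exposed is assumed here.
    void relu(torch::Tensor& out, const torch::Tensor& input);

    int main() {
        if (!torch::mps::is_available()) {
            return 0;  // the Metal path only runs on an Apple GPU exposed through MPS
        }
        auto opts = torch::TensorOptions().device(torch::kMPS).dtype(torch::kFloat);
        torch::Tensor input = torch::randn({1 << 20}, opts);
        torch::Tensor output = torch::empty_like(input);

        relu(output, input);
        torch::mps::synchronize();  // drain the shared MPS command queue

        // Compare against the built-in reference on the same device.
        TORCH_CHECK(torch::allclose(output, torch::relu(input)), "relu kernel mismatch");
        return 0;
    }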