Commit dcf296a

ChinChangYang and claude committed
Add MLX backend for Apple Silicon
Introduces a new neural-net backend (USE_BACKEND=MLX) targeting Apple Silicon via Apple's MLX framework. The backend implements the full nninterface contract (model load, batched evaluation, FP16/FP32 paths) and ships with a Winograd 3x3 convolution path plus an adaptive per-shape tuner that picks the fastest implementation for each conv-3x3 shape at model load.

Backend
- cpp/neuralnet/mlxbackend.cpp: backend implementation. Supports variable board sizes 7-19 via masking, FP16/FP32 with the mlxUseFP16 config (default auto -> fp16), and the same input feature layout as the other backends. Mish activation runs as FP16-safe (asserts on ACTIVATION_MISH_SCALE8 so out-of-range variants are caught explicitly rather than silently truncated).
- cpp/neuralnet/mlxwinograd.h: F(4x4, 3x3) Winograd transform with fused activation + residual add.
- cpp/neuralnet/mlxwinotuner.{cpp,h}: per-shape Winograd tuner with adaptive scoring (rotates the candidate set per shape, scores by median-time delta against a baked-default baseline). Logs the conv-3x3 shape distribution at model load.
- cpp/neuralnet/mlxtests.cpp: unit tests for the Winograd path and tuner numeric-consistency, gated under runnnlayertests.

Build / wiring
- cpp/CMakeLists.txt: USE_BACKEND=MLX target. MLX requires CMake 3.27 (cmake_minimum_required stays at 3.18.2 so other backends keep building on older CMake). Links Homebrew's prebuilt libmlx.dylib; OSX deployment target intentionally not pinned so the executable's minos matches the dylib it was linked against.
- cpp/main.cpp, cpp/program/setup.cpp, cpp/command/benchmark.cpp: wire MLX into backend selection / benchmark.
- cpp/configs/{gtp,analysis,match,contribute}_example.cfg: document mlxUseFP16 (auto / true / false), default auto -> fp16.
- Compiling.md: build instructions for the MLX backend.

Validation
- Cross-backend validation against an Eigen reference (testgpuerror) for b18c384nbt, b40v8, and humanv0 nets shows FP32 max winrate error 0.00095% and FP16 max 2.63%, well within the existing backend tolerances.

This is the squash of 130 commits from feature/mlx-backend.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
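
The commit message names an F(4x4, 3x3) Winograd transform. As a small illustration of the same idea, and not the code in mlxwinograd.h, the one-dimensional F(2,3) case below produces two outputs of a 3-tap filter with four multiplications instead of six. The function names are hypothetical and chosen only for this sketch.

```cpp
#include <array>
#include <cassert>
#include <cmath>

// Direct 1D correlation: two outputs from four inputs and a 3-tap filter
// (6 multiplications).
std::array<float, 2> directConv3(const std::array<float, 4>& d,
                                 const std::array<float, 3>& g) {
  return {{d[0] * g[0] + d[1] * g[1] + d[2] * g[2],
           d[1] * g[0] + d[2] * g[1] + d[3] * g[2]}};
}

// Winograd F(2,3): the same two outputs with 4 multiplications.
// Hypothetical helper for illustration; the commit uses the 2D F(4x4, 3x3) form.
std::array<float, 2> winogradF23(const std::array<float, 4>& d,
                                 const std::array<float, 3>& g) {
  // Filter transform (can be precomputed once per filter).
  const float u0 = g[0];
  const float u1 = 0.5f * (g[0] + g[1] + g[2]);
  const float u2 = 0.5f * (g[0] - g[1] + g[2]);
  const float u3 = g[2];
  // Input transform.
  const float v0 = d[0] - d[2];
  const float v1 = d[1] + d[2];
  const float v2 = d[2] - d[1];
  const float v3 = d[1] - d[3];
  // Elementwise products, then the output transform.
  const float m0 = u0 * v0;
  const float m1 = u1 * v1;
  const float m2 = u2 * v2;
  const float m3 = u3 * v3;
  return {{m0 + m1 + m2, m1 - m2 - m3}};
}

// Returns true when both paths agree within a small tolerance.
bool winogradMatchesDirect(const std::array<float, 4>& d,
                           const std::array<float, 3>& g) {
  const std::array<float, 2> a = directConv3(d, g);
  const std::array<float, 2> b = winogradF23(d, g);
  return std::fabs(a[0] - b[0]) < 1e-5f && std::fabs(a[1] - b[1]) < 1e-5f;
}
```

The transforms trade multiplications for additions, which is also why FP16 results can drift slightly from a direct convolution and why a numeric-consistency test against a reference is useful.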
1 parent a1d832c commit dcf296a
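
The tuner is described as scoring candidates by median-time delta against a baked-default baseline. A minimal sketch of that scoring step follows, assuming timings have already been collected. The `Candidate` struct and the `medianOf` and `pickFastest` helpers are hypothetical and are not taken from mlxwinotuner.cpp; candidate rotation per shape is omitted.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical record: one implementation and its measured run times.
struct Candidate {
  std::string name;
  std::vector<double> timesMs;
};

// Median of a non-empty sample set (taken by value so the caller's data
// is not reordered).
double medianOf(std::vector<double> samples) {
  assert(!samples.empty());
  const std::size_t mid = samples.size() / 2;
  std::nth_element(samples.begin(), samples.begin() + mid, samples.end());
  const double upper = samples[mid];
  if (samples.size() % 2 == 1) return upper;
  // Even count: average the two middle values.
  const double lower = *std::max_element(samples.begin(), samples.begin() + mid);
  return 0.5 * (lower + upper);
}

// Picks the candidate whose median beats the baseline median by the largest
// margin; falls back to the baseline when no candidate is faster.
std::string pickFastest(const Candidate& baseline,
                        const std::vector<Candidate>& candidates) {
  const double baseMedian = medianOf(baseline.timesMs);
  std::string best = baseline.name;
  double bestDelta = 0.0;  // negative delta means faster than the baseline
  for (const Candidate& c : candidates) {
    const double delta = medianOf(c.timesMs) - baseMedian;
    if (delta < bestDelta) {
      bestDelta = delta;
      best = c.name;
    }
  }
  return best;
}
```

Using the median rather than the mean keeps a single slow outlier run from disqualifying an otherwise faster candidate.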

14 files changed

Lines changed: 4785 additions & 3 deletions

Compiling.md

Lines changed: 2 additions & 1 deletion
@@ -133,14 +133,15 @@ As also mentioned in the instructions below but repeated here for visibility, if
       * AppleClang and Swift compilers: `xcode-select --install`.
       * If using the Metal backend, [Ninja](https://ninja-build.org): `brew install ninja`
       * If using the Metal backend, protobuf and abseil: `brew install protobuf abseil`
+      * If using the MLX backend (Apple Silicon only): `brew install mlx` (≥0.18). Requires CMake ≥3.27. KataGo finds MLX via CMake's default search (Homebrew installs it at `/opt/homebrew/share/cmake/MLX/`); override with `-DMLX_ROOT=/path/to/mlx/cmake` if needed.
       * libzip: `brew install libzip`.
       * If you want to do self-play training and research, probably Google perftools `brew install gperftools` for TCMalloc or some other better malloc implementation. For unknown reasons, the allocation pattern in self-play with large numbers of threads and parallel games causes a lot of memory fragmentation under glibc malloc that will eventually run your machine out of memory, but better mallocs handle it fine.
       * If compiling to contribute to public distributed training runs, OpenSSL is required (`brew install openssl`).
    * Clone this repo:
       * `git clone https://github.com/lightvector/KataGo.git`
    * Compile using CMake and make in the cpp directory:
       * `cd KataGo/cpp`
-      * `cmake . -G Ninja -DUSE_BACKEND=METAL` or `cmake . -DUSE_BACKEND=OPENCL` or `cmake . -DUSE_BACKEND=EIGEN` depending on which backend you want.
+      * `cmake . -G Ninja -DUSE_BACKEND=METAL` or `cmake . -DUSE_BACKEND=MLX` or `cmake . -DUSE_BACKEND=OPENCL` or `cmake . -DUSE_BACKEND=EIGEN` depending on which backend you want.
       * Specify also `-DUSE_TCMALLOC=1` if using TCMalloc.
       * Compiling will also call git commands to embed the git hash into the compiled executable, specify also `-DNO_GIT_REVISION=1` to disable it if this is causing issues for you.
       * Specify `-DUSE_AVX2=1` to also compile Eigen with AVX2 and FMA support, which will make it incompatible with old CPUs but much faster. Intel-based Macs with new processors support AVX2, but Apple Silicon Macs do not support AVX2 natively. (If you want to go further, you can also add `-DCMAKE_CXX_FLAGS='-march=native'` which will specialize to precisely your machine's CPU, but the exe might not run on other machines at all).

cpp/CMakeLists.txt

Lines changed: 51 additions & 2 deletions
@@ -1,4 +1,23 @@
 cmake_minimum_required(VERSION 3.18.2)
+
+# Pre-project MLX setup. KataGo's MLX path enforces CMake 3.27 via the guard
+# below (MLX itself requires only 3.25 - 3.27 is chosen to match
+# cmake_policy(VERSION 3.27)); the global cmake_minimum_required stays at
+# 3.18.2 so non-MLX backends keep building on older CMake.
+#
+# The OSX deployment target is deliberately NOT pinned here. KataGo links
+# Homebrew's prebuilt libmlx.dylib, whose minos reflects the macOS it was
+# bottled on - that dylib, not this build, sets the real minimum macOS.
+# Pinning a lower value only stamps a misleading minos on the executable and
+# triggers a "linking with dylib built for newer version" linker warning;
+# letting CMake default the target to the build host keeps minos honest.
+if(USE_BACKEND STREQUAL "MLX")
+  if(CMAKE_VERSION VERSION_LESS 3.27)
+    message(FATAL_ERROR "KataGo's USE_BACKEND=MLX path requires CMake 3.27 or newer. You have ${CMAKE_VERSION}. Install via: brew install cmake")
+  endif()
+  cmake_policy(VERSION 3.27)
+endif()
+
 if(USE_BACKEND STREQUAL "METAL")
   project(katago LANGUAGES CXX Swift)
 else()
@@ -44,7 +63,7 @@ endif()
 set(BUILD_DISTRIBUTED 0 CACHE BOOL "Build with http support for contributing to distributed training")
 set(USE_BACKEND CACHE STRING "Neural net backend")
 string(TOUPPER "${USE_BACKEND}" USE_BACKEND)
-set_property(CACHE USE_BACKEND PROPERTY STRINGS "" CUDA TENSORRT OPENCL EIGEN METAL)
+set_property(CACHE USE_BACKEND PROPERTY STRINGS "" CUDA TENSORRT OPENCL EIGEN MLX METAL)

 set(USE_TCMALLOC 0 CACHE BOOL "Use TCMalloc")
 set(NO_GIT_REVISION 0 CACHE BOOL "Disable embedding the git revision into the compiled exe")
@@ -158,8 +177,35 @@ elseif(USE_BACKEND STREQUAL "EIGEN")
   set(NEURALNET_BACKEND_SOURCES
     neuralnet/eigenbackend.cpp
   )
+elseif(USE_BACKEND STREQUAL "MLX")
+  message(STATUS "-DUSE_BACKEND=MLX, using MLX backend for Apple Silicon.")
+
+  if(NOT APPLE)
+    message(FATAL_ERROR "USE_BACKEND=MLX is only supported on macOS. Detected: ${CMAKE_SYSTEM_NAME}")
+  endif()
+  if(CMAKE_OSX_ARCHITECTURES)
+    if(NOT CMAKE_OSX_ARCHITECTURES STREQUAL "arm64")
+      message(FATAL_ERROR "USE_BACKEND=MLX requires arm64. Got: ${CMAKE_OSX_ARCHITECTURES}")
+    endif()
+  elseif(NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
+    message(FATAL_ERROR "USE_BACKEND=MLX requires Apple Silicon (arm64). Detected: ${CMAKE_SYSTEM_PROCESSOR}")
+  endif()
+
+  set(MLX_MIN_VERSION "0.18")
+  set(MLX_ROOT "" CACHE PATH "Optional path to MLX's CMake package; leave empty to use CMake's default search (e.g. Homebrew's /opt/homebrew/share/cmake/MLX/)")
+
+  # Homebrew installs MLX's CMake config to /opt/homebrew/share/cmake/MLX/, which is
+  # on CMake's default search path. MLX_ROOT, when set, is added as an extra hint.
+  find_package(MLX ${MLX_MIN_VERSION} CONFIG REQUIRED HINTS "${MLX_ROOT}")
+  message(STATUS "Found MLX ${MLX_VERSION} at ${MLX_LIBRARY}")
+
+  set(NEURALNET_BACKEND_SOURCES
+    neuralnet/mlxbackend.cpp
+    neuralnet/mlxwinotuner.cpp
+    neuralnet/mlxtests.cpp
+  )
 elseif(USE_BACKEND STREQUAL "")
-  message(WARNING "${ColorBoldRed}WARNING: Using dummy neural net backend, intended for non-neural-net testing only, will fail on any code path requiring a neural net. To use neural net, specify -DUSE_BACKEND=CUDA or -DUSE_BACKEND=TENSORRT or -DUSE_BACKEND=OPENCL or -DUSE_BACKEND=EIGEN to compile with the respective backend.${ColorReset}")
+  message(WARNING "${ColorBoldRed}WARNING: Using dummy neural net backend, intended for non-neural-net testing only, will fail on any code path requiring a neural net. To use neural net, specify -DUSE_BACKEND=CUDA or -DUSE_BACKEND=TENSORRT or -DUSE_BACKEND=OPENCL or -DUSE_BACKEND=EIGEN or -DUSE_BACKEND=MLX or -DUSE_BACKEND=METAL to compile with the respective backend.${ColorReset}")
   set(NEURALNET_BACKEND_SOURCES neuralnet/dummybackend.cpp)
 else()
   message(FATAL_ERROR "Unrecognized backend: " ${USE_BACKEND})
@@ -496,6 +542,9 @@ elseif(USE_BACKEND STREQUAL "EIGEN")
       message(STATUS "Found Eigen3 at ${EIGEN3_INCLUDE_DIRS}")
     endif()
   endif()
+elseif(USE_BACKEND STREQUAL "MLX")
+  target_compile_definitions(katago PRIVATE USE_MLX_BACKEND)
+  target_link_libraries(katago mlx)
 endif()

 if(USE_BIGGER_BOARDS_EXPENSIVE)

cpp/command/benchmark.cpp

Lines changed: 3 additions & 0 deletions
@@ -267,6 +267,9 @@ int MainCmds::benchmark(const vector<string>& args) {
 #endif
 #ifdef USE_EIGEN_BACKEND
   cout << "You are currently using the Eigen (CPU) version of KataGo. Due to having no GPU, it may be slow." << endl;
+#endif
+#ifdef USE_MLX_BACKEND
+  cout << "Your GTP config is currently set to mlxUseFP16 = " << nnEval->getUsingFP16Mode().toString() << endl;
 #endif
   cout << endl;
   cout << "Your GTP config is currently set to use numSearchThreads = " << params.numThreads << endl;

cpp/configs/analysis_example.cfg

Lines changed: 12 additions & 0 deletions
@@ -298,6 +298,18 @@ nnRandomize = true
 # It defaults to min(numAnalysisThreads * numSearchThreadsPerAnalysisThread, numCPUCores).
 # numEigenThreadsPerModel = X

+# ------------------------------
+# MLX-specific settings
+# ------------------------------
+# These only apply when using the MLX backend (Apple Silicon).
+
+# Whether to use FP16 (half precision) for neural net evaluation on MLX.
+# FP16 is faster than FP32 on Apple Silicon via the MLX Winograd path.
+# Set `false` for bit-exact FP32 reproducibility.
+#
+# Default: auto (resolves to fp16 on MLX).
+# mlxUseFP16 = auto
+

 # Misc Behavior --------------------

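The config comments describe mlxUseFP16 as a three-valued setting (auto / true / false) where auto resolves to fp16 on MLX. A minimal sketch of that resolution rule follows. The names `Fp16Setting`, `parseFp16Setting`, and `resolveUseFp16` are hypothetical and are not taken from mlxbackend.cpp or KataGo's config parser.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical tri-state mirroring the documented values auto / true / false.
enum class Fp16Setting { Auto, On, Off };

// Parses the config string; rejects anything outside the documented values.
Fp16Setting parseFp16Setting(const std::string& value) {
  if (value == "auto") return Fp16Setting::Auto;
  if (value == "true") return Fp16Setting::On;
  if (value == "false") return Fp16Setting::Off;
  throw std::invalid_argument(
      "mlxUseFP16 must be auto, true, or false; got: " + value);
}

// Per the config comments, auto resolves to FP16 on MLX, so only an
// explicit false selects the FP32 path.
bool resolveUseFp16(Fp16Setting setting) {
  return setting != Fp16Setting::Off;
}
```

Under this reading, leaving the option commented out behaves the same as `mlxUseFP16 = auto`, and `mlxUseFP16 = false` is the only way to get the FP32 path.
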
cpp/configs/contribute_example.cfg

Lines changed: 12 additions & 0 deletions
@@ -139,3 +139,15 @@ watchOngoingGameInFileName = watchgame.txt
 # This is the number of CPU threads for evaluating the neural net on the Eigen backend.
 # It defaults to numSearchThreads.
 # numEigenThreadsPerModel = X
+
+# ------------------------------
+# MLX-specific settings
+# ------------------------------
+# These only apply when using the MLX backend (Apple Silicon).
+
+# Whether to use FP16 (half precision) for neural net evaluation on MLX.
+# FP16 is faster than FP32 on Apple Silicon via the MLX Winograd path.
+# Set `false` for bit-exact FP32 reproducibility.
+#
+# Default: auto (resolves to fp16 on MLX).
+# mlxUseFP16 = auto

cpp/configs/gtp_example.cfg

Lines changed: 12 additions & 0 deletions
@@ -539,6 +539,18 @@ searchFactorWhenWinningThreshold = 0.95
 # Default: numSearchThreads
 # numEigenThreadsPerModel = X

+# ------------------------------
+# MLX-specific settings
+# ------------------------------
+# These only apply when using the MLX backend (Apple Silicon).
+
+# Whether to use FP16 (half precision) for neural net evaluation on MLX.
+# FP16 is faster than FP32 on Apple Silicon via the MLX Winograd path.
+# Set `false` for bit-exact FP32 reproducibility.
+#
+# Default: auto (resolves to fp16 on MLX).
+# mlxUseFP16 = auto
+
 # ===========================================================================
 # Root move selection and biases
 # ===========================================================================

cpp/configs/match_example.cfg

Lines changed: 12 additions & 0 deletions
@@ -197,6 +197,18 @@ numNNServerThreadsPerModel = 1
 # It defaults to numSearchThreads.
 # numEigenThreadsPerModel = X

+# ------------------------------
+# MLX-specific settings
+# ------------------------------
+# These only apply when using the MLX backend (Apple Silicon).
+
+# Whether to use FP16 (half precision) for neural net evaluation on MLX.
+# FP16 is faster than FP32 on Apple Silicon via the MLX Winograd path.
+# Set `false` for bit-exact FP32 reproducibility.
+#
+# Default: auto (resolves to fp16 on MLX).
+# mlxUseFP16 = auto
+

 # Root move selection and biases------------------------------------------------------------------------------
 # Uncomment and edit any of the below values to change them from their default.

cpp/main.cpp

Lines changed: 4 additions & 0 deletions
@@ -246,6 +246,8 @@ string Version::getKataGoVersionFullInfo() {
   out << "Using OpenCL backend" << endl;
 #elif defined(USE_EIGEN_BACKEND)
   out << "Using Eigen(CPU) backend" << endl;
+#elif defined(USE_MLX_BACKEND)
+  out << "Using MLX backend" << endl;
 #else
   out << "Using dummy backend" << endl;
 #endif
@@ -282,6 +284,8 @@ string Version::getGitRevisionWithBackend() {
   s += "-opencl";
 #elif defined(USE_EIGEN_BACKEND)
   s += "-eigen";
+#elif defined(USE_MLX_BACKEND)
+  s += "-mlx";
 #else
   s += "-dummy";
 #endif

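The main.cpp hunks extend an `#if` / `#elif` chain keyed on the compile definition that CMakeLists.txt sets (`USE_MLX_BACKEND`). A reduced, self-contained sketch of that pattern follows. The free function `backendSuffix` is hypothetical; in the commit the suffix is appended inside `Version::getGitRevisionWithBackend`, and the real chain covers more backends than shown here.

```cpp
#include <string>

// Returns the backend suffix selected at compile time. Exactly one branch
// survives preprocessing, depending on which USE_*_BACKEND macro the build
// system defined (for example via target_compile_definitions).
std::string backendSuffix() {
#if defined(USE_EIGEN_BACKEND)
  return "-eigen";
#elif defined(USE_MLX_BACKEND)
  return "-mlx";
#else
  return "-dummy";
#endif
}
```

Compiling this with `-DUSE_MLX_BACKEND` selects the `-mlx` branch; with no backend macro defined it falls through to `-dummy`, matching the dummy-backend fallback in the diff.
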