
[MLX backend] Support for MLX backend across layers tests #2337


Open · wants to merge 23 commits into base: mlx

Conversation


@acsweet commented Jul 19, 2025

Description of the change

This PR is the first (of more to come) adding support for the MLX backend of Keras in Keras Hub.
The focus of this PR is tests in keras_hub/src/layers, with the following notes:

  • Quantization is not yet implemented for the MLX backend
  • MLX does not currently support float8
  • MLX does not support matmul with integer data types
  • masking is handled through backend helpers rather than a _keras_mask attribute, since attributes cannot be added to MLX array objects (see the sketch after this list)
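
A minimal sketch of the masking note above, mirroring the test snippet reviewed further down in this PR. The helper module that exposes `set_keras_mask` / `get_keras_mask` is not shown in the reviewed snippets, so it is passed in explicitly here and the function names are illustrative:

```python
import keras


def attach_mask(x, mask, backend_module):
    """Attach `mask` to `x` the way the layers tests in this PR do."""
    if keras.config.backend() == "mlx":
        # MLX arrays cannot carry extra attributes, so the mask travels
        # through backend helpers instead of a `_keras_mask` attribute.
        backend_module.set_keras_mask(x, mask)
    else:
        x._keras_mask = mask
    return x


def read_mask(x, backend_module):
    """Read the mask back, mirroring the assertions in the tests."""
    if keras.config.backend() == "mlx":
        return backend_module.get_keras_mask(x)
    return x._keras_mask
```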

Any feedback on changes and implementations here would be appreciated. Will modify as needed!

Note that I have not included MLX in the GitHub CI yet. A few tests still fail in the Linux environment, so this might need to wait until MLX is available via the Keras nightly build; I'm also open to other suggestions.

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and works with all backends (TensorFlow, JAX, and PyTorch).
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have followed the Keras Hub Model contribution guidelines in making these changes.
  • I have followed the Keras Hub API design guidelines in making these changes.
  • I have signed the Contributor License Agreement.

laxmareddyp and others added 23 commits June 9, 2025 16:30
* Support None for max_shard_size

* Add unit test
* fix-sharded-weights-typeerror

* Add unit test case

* Moved common initialization code to setUp for cleanliness and updated all three relevant test cases to use the shared setup.
The inputs to `generate` are `"prompts"`, not `"text"`.

Fixes keras-team#1685
* routine HF sync

* code reformat
Bumps the python group with 2 updates: torch and torchvision.


Updates `torch` from 2.6.0+cu126 to 2.7.0+cu126

Updates `torchvision` from 0.21.0+cu126 to 0.22.0+cu126

---
updated-dependencies:
- dependency-name: torch
  dependency-version: 2.7.0+cu126
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: python
- dependency-name: torchvision
  dependency-version: 0.22.0+cu126
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: python
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
* Modify TransformerEncoder masking documentation

* Added space before parenthesis
* Fix Mistral conversion script

This commit addresses several issues in the Mistral checkpoint conversion script:

- Adds `dropout` to the model initialization to match the Hugging Face model.
- Replaces `requests.get` with `hf_hub_download` for more reliable tokenizer downloads.
- Adds support for both `tokenizer.model` and `tokenizer.json` to handle different Mistral versions.
- Fixes a `TypeError` in the `save_to_preset` function call.

* address format issues

* adapted to latest hub style

* address format issues

---------

Co-authored-by: laxmareddyp <laxmareddyp@laxma-n2-highmem-256gbram.us-central1-f.c.gtech-rmi-dev.internal>
Updates the requirements on [tensorflow-cpu](https://github.com/tensorflow/tensorflow), [tensorflow](https://github.com/tensorflow/tensorflow), [tensorflow-text](https://github.com/tensorflow/text), torch, torchvision and [tensorflow[and-cuda]](https://github.com/tensorflow/tensorflow) to permit the latest version.

Updates `tensorflow-cpu` to 2.19.0
- [Release notes](https://github.com/tensorflow/tensorflow/releases)
- [Changelog](https://github.com/tensorflow/tensorflow/blob/master/RELEASE.md)
- [Commits](tensorflow/tensorflow@v2.18.1...v2.19.0)

Updates `tensorflow` to 2.19.0
- [Release notes](https://github.com/tensorflow/tensorflow/releases)
- [Changelog](https://github.com/tensorflow/tensorflow/blob/master/RELEASE.md)
- [Commits](tensorflow/tensorflow@v2.18.1...v2.19.0)

Updates `tensorflow-text` to 2.19.0
- [Release notes](https://github.com/tensorflow/text/releases)
- [Commits](tensorflow/text@v2.18.0...v2.19.0)

Updates `torch` from 2.7.0+cu126 to 2.7.1+cu126

Updates `torchvision` from 0.22.0+cu126 to 0.22.1+cu126

Updates `tensorflow[and-cuda]` to 2.19.0
- [Release notes](https://github.com/tensorflow/tensorflow/releases)
- [Changelog](https://github.com/tensorflow/tensorflow/blob/master/RELEASE.md)
- [Commits](tensorflow/tensorflow@v2.18.0...v2.19.0)

---
updated-dependencies:
- dependency-name: tensorflow-cpu
  dependency-version: 2.19.0
  dependency-type: direct:production
  dependency-group: python
- dependency-name: tensorflow
  dependency-version: 2.19.0
  dependency-type: direct:production
  dependency-group: python
- dependency-name: tensorflow-text
  dependency-version: 2.19.0
  dependency-type: direct:production
  dependency-group: python
- dependency-name: torch
  dependency-version: 2.7.1+cu126
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: python
- dependency-name: torchvision
  dependency-version: 0.22.1+cu126
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: python
- dependency-name: tensorflow[and-cuda]
  dependency-version: 2.19.0
  dependency-type: direct:production
  dependency-group: python
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
* init

* update

* bug fixes

* add qwen causal lm test

* fix qwen3 tests
* support flash-attn on the torch backend

* fix

* fix

* fix

* fix conflict

* fix conflict

* fix conflict

* fix conflict

* fix conflict

* fix conflict

* format
* init: Add initial project structure and files

* bug: Small bug related to weight loading in the conversion script

* finalizing: Add TIMM preprocessing layer

* incorporate reviews: Consolidate stage configurations and improve API consistency

* bug: Unexpected argument error in JAX with Keras 3.5

* small addition for the D-FINE to come: No changes to the existing HGNetV2

* D-FINE JIT compile: Remove non-essential conditional statement

* refactor: Address reviews and fix some nits
* Register qwen3 presets

* fix format
* Safetensors conversion

* Reformatted

* corrected and formatted into a util file

* test cases wip

* unit tests for safetensors conversion

* rename vocab.spm

* reformatted

* address comments

* minor changes

* backend agnostic

* address comments

* convert_to_safetensor

* Compatible with all backends

* Cosmetic changes

* Cosmetic changes (1)

* Cosmetic changes (2)

* Cosmetic changes (3)

* Address comments

---------

Co-authored-by: Abheesht Sharma <[email protected]>
* Add DINOV2.

* Replace SigLIP with DINOV2.

* Replace `mlp_ratio` with `intermediate_dim`. Fix the interpolation of position embeddings. Refine the docstrings.
* Update SD3 scheduler and the dtype of the text encoders.

* Fix the test.

* Fix torch float16 issues and jax take issue. Add numeric checks for SD3 scheduler and text encoders.

* Fix CLIP test.

* Refactor CLIP models.

* Update CLIP conversion script.

* Update `from_config`.

* Fix tests.

google-cla bot commented Jul 19, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.


@gemini-code-assist bot left a comment


Summary of Changes

Hello @acsweet, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request is a foundational step towards integrating the MLX backend into Keras Hub, ensuring that existing layers tests are compatible with MLX's current capabilities and limitations. Beyond backend support, it introduces several new cutting-edge model architectures (DINOV2, HGNetV2, Qwen3) and refactors existing model components for better structure. It also includes significant improvements to the model preset management and Hugging Face conversion tools, streamlining the process of working with diverse models.

Highlights

  • MLX Backend Compatibility: Implemented conditional logic in layers tests to accommodate the MLX backend's current limitations, such as skipping quantization tests, handling masking differently, and adjusting data types for operations like matrix multiplication. This ensures existing tests can run or be appropriately skipped on the MLX backend (see the sketch after this list).
  • New Model Architectures: Introduced comprehensive support for DINOV2, HGNetV2, and Qwen3 models. This includes adding their respective backbones, image converters, preprocessors, and a suite of specialized layers, significantly expanding the available models in Keras Hub.
  • CLIP Model Refactoring: Restructured CLIP model layers by moving common components like CLIPVisionEmbedding, CLIPEncoderLayer, CLIPVisionPooler, CLIPTextPooler, and CLIPHead into a new, dedicated clip_layers.py file. This improves modularity and organization within the CLIP model directory.
  • Enhanced Preset Management & Conversion Tools: Improved the preset_utils to support sharded weights and more flexible loading/saving configurations. Additionally, new utilities were added for exporting Keras models to Hugging Face's safetensors format, and existing checkpoint conversion scripts for various models (DINOV2, HGNetV2, Mistral, Qwen3, Stable Diffusion 3) were updated.
  • Dependency Updates: Updated core dependencies including TensorFlow, JAX, Torch, and torchvision to newer versions. Crucially, new dependencies for the MLX backend (pybind11, cmake, mlx) and transformers were added to the requirements.txt files.
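
As an illustration of the compatibility guards described in the first highlight, a layers test might skip or adapt behavior per backend as sketched below. The decorator-based skip and the test names are hypothetical, not the exact tests in this PR:

```python
import keras
import numpy as np
import pytest
from keras import ops


@pytest.mark.skipif(
    keras.config.backend() == "mlx",
    reason="Quantization is not yet implemented for the MLX backend.",
)
def test_quantized_output():
    ...  # quantization-specific assertions, skipped entirely on MLX


def test_matmul_path():
    a = np.arange(6, dtype="int32").reshape(2, 3)
    b = np.arange(12, dtype="int32").reshape(3, 4)
    if keras.config.backend() == "mlx":
        # MLX does not support matmul on integer dtypes, so cast first.
        a, b = a.astype("float32"), b.astype("float32")
    out = ops.matmul(a, b)
    assert tuple(ops.shape(out)) == (2, 4)
```
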
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist bot left a comment


Code Review

This pull request introduces MLX backend support, adds several new models (DINOV2, HGNetV2, Qwen3), refactors the CLIP model, and adds a new feature to export models to Hugging Face format. The changes are extensive but well-structured, significantly expanding the capabilities of Keras Hub. My review focuses on ensuring the correctness and maintainability of the new and refactored code. I've identified a few areas for improvement, mainly related to code duplication, consistency, and clarity.

Comment on lines +119 to +126
if keras.config.backend() == "mlx":
    backend.set_keras_mask(decoder_sequence, mask)
    outputs = decoder(decoder_sequence, encoder_sequence)
    self.assertAllEqual(backend.get_keras_mask(outputs), mask)
else:
    decoder_sequence._keras_mask = mask
    outputs = decoder(decoder_sequence, encoder_sequence)
    self.assertAllEqual(outputs._keras_mask, mask)


Severity: medium

This block of code for testing mask propagation is nearly identical to the one in test_mask_propagation_without_cross_attention (lines 136-143). To improve maintainability and reduce redundancy, consider extracting this logic into a helper method.
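
One way the suggested helper could look, using the same backend branching as the snippet above. The method name and signature are illustrative, and `backend` refers to whichever module the test file already imports for the mask helpers:

```python
def assert_mask_propagates(self, decoder, decoder_sequence, mask, encoder_sequence=None):
    # Build the call arguments with or without cross-attention inputs.
    args = (
        (decoder_sequence,)
        if encoder_sequence is None
        else (decoder_sequence, encoder_sequence)
    )
    if keras.config.backend() == "mlx":
        backend.set_keras_mask(decoder_sequence, mask)
        outputs = decoder(*args)
        self.assertAllEqual(backend.get_keras_mask(outputs), mask)
    else:
        decoder_sequence._keras_mask = mask
        outputs = decoder(*args)
        self.assertAllEqual(outputs._keras_mask, mask)
```

Both mask-propagation tests could then delegate to this helper.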

if self.to_lower:
    x = tf.strings.lower(x)
    prompts = tf.strings.lower(prompts)


Severity: medium

To ensure backend-agnosticism, use keras.ops.strings.lower instead of tf.strings.lower. This ensures the preprocessor works correctly with JAX, PyTorch, and MLX backends.

Suggested change
-    prompts = tf.strings.lower(prompts)
+    prompts = keras.ops.strings.lower(prompts)

    dropout_rate=0.0,
    drop_path_rate=0.0,
    image_shape=(224, 224, 3),
    position_embedding_shape=(518, 518, 3),


Severity: medium

The default value for position_embedding_shape is (518, 518, 3), which includes a channel dimension, but the implementation only uses the first two dimensions. Update the default value to (518, 518) to match the documentation and usage.

Suggested change
-    position_embedding_shape=(518, 518, 3),
+    position_embedding_shape=(518, 518),

Comment on lines +19 to +69
class RMSNormalization(layers.Layer):
    """A normalization layer for MMDiT that implements RMS normalization."""

    def __init__(self, axis=-1, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis
        self.epsilon = epsilon

    def build(self, input_shape):
        if isinstance(self.axis, list):
            shape = tuple([input_shape[dim] for dim in self.axis])
        else:
            shape = (input_shape[self.axis],)
            self.axis = [self.axis]

        self.scale = self.add_weight(
            name="scale", shape=shape, initializer="ones"
        )

        self.built = True

    def call(self, x):
        x = ops.cast(
            x, keras.backend.result_type(self.compute_dtype, "float32")
        )
        rrms = ops.rsqrt(
            ops.mean(ops.square(x), axis=self.axis, keepdims=True)
            + self.epsilon
        )
        return (x * rrms) * ops.cast(self.scale, x.dtype)

    def compute_output_shape(self, input_shape):
        if isinstance(self.axis, int):
            axes = [self.axis]
        else:
            axes = self.axis

        for axis in axes:
            if axis >= len(input_shape) or axis < -len(input_shape):
                raise ValueError(
                    f"Axis {axis} is out of bounds for "
                    f"input shape {input_shape}. "
                    f"Received: axis={self.axis}"
                )
        return input_shape

    def get_config(self):
        config = super().get_config()
        config.update({"axis": self.axis, "epsilon": self.epsilon})
        return config


Severity: medium

The implementation of RMSNormalization has some repeated logic for handling self.axis in build() and compute_output_shape(). Also, self.axis is modified in build(), which can be a confusing side effect. Consider normalizing self.axis to a list in __init__ to simplify the other methods and make the layer's behavior more predictable.
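
A sketch of that suggestion, normalizing `axis` in `__init__` so the other methods stay branch-free. Note this would also change what `get_config` serializes for `axis` from an int to a list, which may or may not be acceptable here:

```python
class RMSNormalization(layers.Layer):
    def __init__(self, axis=-1, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        # Normalize `axis` to a list once, up front.
        self.axis = [axis] if isinstance(axis, int) else list(axis)
        self.epsilon = epsilon

    def build(self, input_shape):
        # No isinstance branching needed anymore.
        shape = tuple(input_shape[dim] for dim in self.axis)
        self.scale = self.add_weight(
            name="scale", shape=shape, initializer="ones"
        )
        self.built = True
```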

Comment on lines +249 to +255
    # mlx array to tf tensor currently only supports flat arrays
    array_shape = inputs.shape
    inputs = inputs.flatten()
    inputs = tf.convert_to_tensor(memoryview(inputs))
    inputs = tf.reshape(inputs, array_shape)
else:
    inputs = tf.convert_to_tensor(inputs)


Severity: medium

The comment on line 249 explains that the MLX array to TF tensor conversion only supports flat arrays. Add more context, such as whether this is a known limitation in TensorFlow's MLX integration and a link to a tracking issue, if one exists.
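
For reference, the pattern from the snippet above in standalone form. Shapes and values are illustrative, and it assumes, as the PR's comment states, that only flat contiguous MLX arrays convert cleanly via a memoryview:

```python
import mlx.core as mx
import tensorflow as tf

x = mx.arange(12, dtype=mx.float32).reshape(3, 4)

# Flatten to 1-D, hand TensorFlow a memoryview over the MLX buffer,
# then restore the original shape on the TensorFlow side.
shape = x.shape
flat = x.flatten()
t = tf.convert_to_tensor(memoryview(flat))
t = tf.reshape(t, shape)
```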
