[MLX backend] Support for MLX backend across layers tests #2337
base: mlx
Conversation
* Support None for max_shard_size
* Add unit test
* fix-sharded-weights-typeerror
* Add unit test case
* Moved common initialization code to setUp for cleaner tests, and updated all three relevant test cases to use the shared setup.
The inputs to `generate` are `"prompts"`, not `"text"`. Fixes keras-team#1685
* routine HF sync
* code reformat
Bumps the python group with 2 updates: torch and torchvision.
* Updates `torch` from 2.6.0+cu126 to 2.7.0+cu126
* Updates `torchvision` from 0.21.0+cu126 to 0.22.0+cu126

Signed-off-by: dependabot[bot] <[email protected]>
* Modify TransformerEncoder masking documentation
* Added space before parenthesis
* Fix Mistral conversion script. This commit addresses several issues in the Mistral checkpoint conversion script:
  - Adds `dropout` to the model initialization to match the Hugging Face model.
  - Replaces `requests.get` with `hf_hub_download` for more reliable tokenizer downloads.
  - Adds support for both `tokenizer.model` and `tokenizer.json` to handle different Mistral versions.
  - Fixes a `TypeError` in the `save_to_preset` function call.
* address format issues
* adopted to latest hub style
* address format issues

Co-authored-by: laxmareddyp <laxmareddyp@laxma-n2-highmem-256gbram.us-central1-f.c.gtech-rmi-dev.internal>
Updates the requirements on [tensorflow-cpu](https://github.com/tensorflow/tensorflow), [tensorflow](https://github.com/tensorflow/tensorflow), [tensorflow-text](https://github.com/tensorflow/text), torch, torchvision, and [tensorflow[and-cuda]](https://github.com/tensorflow/tensorflow) to permit the latest versions:
* `tensorflow-cpu` to 2.19.0
* `tensorflow` to 2.19.0
* `tensorflow-text` to 2.19.0
* `torch` from 2.7.0+cu126 to 2.7.1+cu126
* `torchvision` from 0.22.0+cu126 to 0.22.1+cu126
* `tensorflow[and-cuda]` to 2.19.0

Signed-off-by: dependabot[bot] <[email protected]>
* init
* update
* bug fixes
* add qwen causal lm test
* fix qwen3 tests
* support flash-attn at torch backend
* fix (×4)
* fix conflict (×6)
* format
* init: Add initial project structure and files
* bug: Small bug related to weight loading in the conversion script
* finalizing: Add TIMM preprocessing layer
* incorporate reviews: Consolidate stage configurations and improve API consistency
* bug: Unexpected argument error in JAX with Keras 3.5
* small addition for the D-FINE to come: No changes to the existing HGNetV2
* D-FINE JIT compile: Remove non-essential conditional statement
* refactor: Address reviews and fix some nits
* Register qwen3 presets
* fix format
* Safetensors conversion
* Reformatted
* corrected and formatted into a util file
* test cases wip
* unit tests for safetensors conversion
* rename vocab.spm
* reformatted
* address comments
* minor changes
* backend agnostic
* address comments
* convert_to_safetensor
* Compatible with all backends
* Cosmetic changes
* Cosmetic changes (1)
* Cosmetic changes (2)
* Cosmetic changes (3)
* Address comments

Co-authored-by: Abheesht Sharma <[email protected]>
* Add DINOV2.
* Replace SigLIP with DINOV2.
* Replace `mlp_ratio` with `intermediate_dim`. Fix the interpolation of position embeddings. Refine the docstrings.
* Update SD3 scheduler and the dtype of the text encoders.
* Fix the test.
* Fix torch float16 issues and jax take issue. Add numeric checks for SD3 scheduler and text encoders.
* Fix CLIP test.
* Refactor CLIP models.
* Update CLIP conversion script.
* Update `from_config`.
* Fix tests.
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Summary of Changes
Hello @acsweet, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request is a foundational step towards integrating the MLX backend into Keras Hub, ensuring that existing layers tests are compatible with MLX's current capabilities and limitations. Beyond backend support, it introduces several new cutting-edge model architectures (DINOV2, HGNetV2, Qwen3) and refactors existing model components for better structure. It also includes significant improvements to the model preset management and Hugging Face conversion tools, streamlining the process of working with diverse models.
Highlights
- MLX Backend Compatibility: Implemented conditional logic in layers tests to accommodate the MLX backend's current limitations, such as skipping quantization tests, handling masking differently, and adjusting data types for operations like matrix multiplication. This ensures existing tests can run or be appropriately skipped on the MLX backend.
- New Model Architectures: Introduced comprehensive support for DINOV2, HGNetV2, and Qwen3 models. This includes adding their respective backbones, image converters, preprocessors, and a suite of specialized layers, significantly expanding the available models in Keras Hub.
- CLIP Model Refactoring: Restructured CLIP model layers by moving common components like `CLIPVisionEmbedding`, `CLIPEncoderLayer`, `CLIPVisionPooler`, `CLIPTextPooler`, and `CLIPHead` into a new, dedicated `clip_layers.py` file. This improves modularity and organization within the CLIP model directory.
- Enhanced Preset Management & Conversion Tools: Improved the `preset_utils` to support sharded weights and more flexible loading/saving configurations. Additionally, new utilities were added for exporting Keras models to Hugging Face's `safetensors` format (see the sketch after this list), and existing checkpoint conversion scripts for various models (DINOV2, HGNetV2, Mistral, Qwen3, Stable Diffusion 3) were updated.
- Dependency Updates: Updated core dependencies including TensorFlow, JAX, Torch, and torchvision to newer versions. Crucially, new dependencies for the MLX backend (`pybind11`, `cmake`, `mlx`) and `transformers` were added to the `requirements.txt` files.
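As a rough illustration of the safetensors export idea referenced above, here is a minimal sketch using the `safetensors` library directly (`export_to_safetensors` is a hypothetical helper; the PR's actual utilities and key naming will differ):

```python
from safetensors.numpy import save_file

def export_to_safetensors(model, path):
    # In Keras 3, `variable.path` is a unique identifier such as
    # "dense/kernel", making it a usable tensor key.
    tensors = {v.path: v.numpy() for v in model.weights}
    save_file(tensors, path)
```

A real converter additionally maps Keras weight names onto the Hugging Face checkpoint layout, which is what the PR's utilities handle.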
Code Review
This pull request introduces MLX backend support, adds several new models (DINOV2, HGNetV2, Qwen3), refactors the CLIP model, and adds a new feature to export models to Hugging Face format. The changes are extensive but well-structured, significantly expanding the capabilities of Keras Hub. My review focuses on ensuring the correctness and maintainability of the new and refactored code. I've identified a few areas for improvement, mainly related to code duplication, consistency, and clarity.
```python
if keras.config.backend() == "mlx":
    backend.set_keras_mask(decoder_sequence, mask)
    outputs = decoder(decoder_sequence, encoder_sequence)
    self.assertAllEqual(backend.get_keras_mask(outputs), mask)
else:
    decoder_sequence._keras_mask = mask
    outputs = decoder(decoder_sequence, encoder_sequence)
    self.assertAllEqual(outputs._keras_mask, mask)
```
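In isolation, the branch above can be wrapped into a reusable test helper. A minimal sketch, assuming the `set_keras_mask`/`get_keras_mask` utilities live on the Keras backend module in the MLX branch (the import path here is an assumption):

```python
import keras
from keras.src import backend  # assumed home of the mask utilities

def get_output_mask(layer, inputs, mask, **kwargs):
    # MLX arrays cannot carry a `_keras_mask` attribute, so on that
    # backend the mask is attached and read back through utility
    # functions instead of attribute access.
    if keras.config.backend() == "mlx":
        backend.set_keras_mask(inputs, mask)
        return backend.get_keras_mask(layer(inputs, **kwargs))
    inputs._keras_mask = mask
    return layer(inputs, **kwargs)._keras_mask
```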
```diff
 if self.to_lower:
-    x = tf.strings.lower(x)
+    prompts = tf.strings.lower(prompts)
```
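For reference, `tf.strings.lower` operates elementwise over a string tensor, which is why the fix must target the `prompts` tensor itself:

```python
import tensorflow as tf

prompts = tf.constant(["The Quick Brown Fox", "HELLO World"])
print(tf.strings.lower(prompts))
# tf.Tensor([b'the quick brown fox' b'hello world'], shape=(2,), dtype=string)
```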
```python
dropout_rate=0.0,
drop_path_rate=0.0,
image_shape=(224, 224, 3),
position_embedding_shape=(518, 518, 3),
```
The default value for `position_embedding_shape` is `(518, 518, 3)`, which includes a channel dimension, but the implementation only uses the first two dimensions. Update the default value to `(518, 518)` to match the documentation and usage.
```diff
- position_embedding_shape=(518, 518, 3),
+ position_embedding_shape=(518, 518),
```
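To make the point concrete: DINOV2 interpolates its position embeddings over a patch grid derived only from the spatial dimensions. A hypothetical sketch (patch size 14 is the standard DINOV2 value):

```python
position_embedding_shape = (518, 518)  # (height, width); no channel dim
patch_size = 14

# Only the spatial dimensions determine the position-embedding grid.
grid = (
    position_embedding_shape[0] // patch_size,
    position_embedding_shape[1] // patch_size,
)
print(grid)  # (37, 37) patches per side at 518x518
```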
```python
class RMSNormalization(layers.Layer):
    """A normalization layer for MMDiT that implements RMS normalization."""

    def __init__(self, axis=-1, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis
        self.epsilon = epsilon

    def build(self, input_shape):
        if isinstance(self.axis, list):
            shape = tuple([input_shape[dim] for dim in self.axis])
        else:
            shape = (input_shape[self.axis],)
            self.axis = [self.axis]

        self.scale = self.add_weight(
            name="scale", shape=shape, initializer="ones"
        )

        self.built = True

    def call(self, x):
        x = ops.cast(
            x, keras.backend.result_type(self.compute_dtype, "float32")
        )
        rrms = ops.rsqrt(
            ops.mean(ops.square(x), axis=self.axis, keepdims=True)
            + self.epsilon
        )
        return (x * rrms) * ops.cast(self.scale, x.dtype)

    def compute_output_shape(self, input_shape):
        if isinstance(self.axis, int):
            axes = [self.axis]
        else:
            axes = self.axis

        for axis in axes:
            if axis >= len(input_shape) or axis < -len(input_shape):
                raise ValueError(
                    f"Axis {axis} is out of bounds for "
                    f"input shape {input_shape}. "
                    f"Received: axis={self.axis}"
                )
        return input_shape

    def get_config(self):
        config = super().get_config()
        config.update({"axis": self.axis, "epsilon": self.epsilon})
        return config
```
The implementation of `RMSNormalization` has some repeated logic for handling `self.axis` in `build()` and `compute_output_shape()`. Also, `self.axis` is modified in `build()`, which can be a confusing side effect. Consider normalizing `self.axis` to a list in `__init__` to simplify the other methods and make the layer's behavior more predictable.
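A minimal sketch of that suggestion (illustrative only, not the merged code):

```python
def __init__(self, axis=-1, epsilon=1e-6, **kwargs):
    super().__init__(**kwargs)
    # Normalize `axis` to a list once, so `build()` and
    # `compute_output_shape()` can treat it uniformly and neither
    # method mutates it as a side effect.
    self.axis = [axis] if isinstance(axis, int) else list(axis)
    self.epsilon = epsilon
```

One caveat: `get_config()` would then round-trip an integer `axis` as a one-element list, which the list-handling branches already accept.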
```python
    # mlx array to tf tensor currently only supports flat arrays
    array_shape = inputs.shape
    inputs = inputs.flatten()
    inputs = tf.convert_to_tensor(memoryview(inputs))
    inputs = tf.reshape(inputs, array_shape)
else:
    inputs = tf.convert_to_tensor(inputs)
```
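Standalone, the conversion path above looks roughly like this; a sketch assuming `mlx` is installed (MLX arrays expose the buffer protocol, which `memoryview` relies on):

```python
import mlx.core as mx
import tensorflow as tf

def mlx_to_tf(inputs):
    # tf.convert_to_tensor currently only handles flat MLX arrays, so
    # flatten, convert through a memoryview, and restore the shape.
    array_shape = inputs.shape
    flat = inputs.flatten()
    tensor = tf.convert_to_tensor(memoryview(flat))
    return tf.reshape(tensor, array_shape)

print(mlx_to_tf(mx.arange(6).reshape(2, 3)))  # shape (2, 3) tf.Tensor
```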
Description of the change
This PR is the first (of more to come) to add support for the MLX backend of Keras to work with Keras Hub.
The focus of this PR is tests in `keras_hub/src/layers`, with the following notes:

- Masking is checked through the backend's mask utilities rather than the `_keras_mask` attribute, as that attribute cannot be added to MLX `array` objects.

Any feedback on changes and implementations here would be appreciated. Will modify as needed!
Note that I have not included MLX in the GitHub CI yet. A few tests still fail in the Linux environment, so this might need to wait until MLX is available via Keras nightly; I'm also open to any other suggestions.
Checklist