From 1c94fe158b5da755c645f1d3e3cc63c95ed6a892 Mon Sep 17 00:00:00 2001 From: seme0011 Date: Sun, 19 Apr 2026 13:59:43 +0200 Subject: [PATCH 1/6] Add ONNX export support for LightOn OCR models Registers lighton_ocr as a model type and exports it as three separate ONNX files: vision_encoder (ViT + projector), embed_tokens (embedding table), and decoder_model_merged (language model with merged KV cache support). Handles weight key remapping from lighton_ocr to Mistral3 internals and works around the >2GB protobuf limit during decoder merge. --- optimum/exporters/onnx/__init__.py | 12 +- optimum/exporters/onnx/model_configs.py | 346 ++++++++++++++++++++++++ optimum/exporters/onnx/model_patcher.py | 122 +++++++++ optimum/exporters/onnx/utils.py | 48 ++++ 4 files changed, 526 insertions(+), 2 deletions(-) diff --git a/optimum/exporters/onnx/__init__.py b/optimum/exporters/onnx/__init__.py index 7cb02989..2852acc0 100644 --- a/optimum/exporters/onnx/__init__.py +++ b/optimum/exporters/onnx/__init__.py @@ -29,7 +29,11 @@ "validate_models_outputs", "onnx_export_from_model", ], - "utils": ["MODEL_TYPES_REQUIRING_POSITION_IDS", "get_metaclip_2_models_for_export"], + "utils": [ + "MODEL_TYPES_REQUIRING_POSITION_IDS", + "get_metaclip_2_models_for_export", + "get_lighton_ocr_models_for_export", + ], "__main__": ["main_export"], } @@ -44,7 +48,11 @@ validate_model_outputs, validate_models_outputs, ) - from optimum.exporters.onnx.utils import MODEL_TYPES_REQUIRING_POSITION_IDS, get_metaclip_2_models_for_export + from optimum.exporters.onnx.utils import ( + MODEL_TYPES_REQUIRING_POSITION_IDS, + get_lighton_ocr_models_for_export, + get_metaclip_2_models_for_export, + ) else: import sys diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index de696396..ea60333f 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -44,6 +44,7 @@ CohereModelPatcher, FluxTransformerModelPatcher, GptOssModelPatcher, + LightonOcrModelPatcher, MetaCLIP2Patcher, MgpstrModelPatcher, MoonshineModelPatcher, @@ -581,6 +582,351 @@ class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator +class LightonOcrNormalizedConfig(NormalizedTextConfig): + """Normalized config for LightOn OCR that reads from the text_config sub-config.""" + + NUM_LAYERS = "text_config.num_hidden_layers" + NUM_ATTENTION_HEADS = "text_config.num_attention_heads" + NUM_KEY_VALUE_HEADS = "text_config.num_key_value_heads" + HIDDEN_SIZE = "text_config.hidden_size" + VOCAB_SIZE = "text_config.vocab_size" + + def __init__(self, config): + super().__init__(config) + text_config = config.text_config if hasattr(config, "text_config") else config + self.hidden_size = text_config.hidden_size + self.num_attention_heads = text_config.num_attention_heads + self.num_key_value_heads = getattr(text_config, "num_key_value_heads", text_config.num_attention_heads) + self.num_layers = text_config.num_hidden_layers + self.vocab_size = text_config.vocab_size + self.head_dim = getattr(text_config, "head_dim", self.hidden_size // self.num_attention_heads) + if hasattr(config, "vision_config"): + self.vision_config = config.vision_config + + +def _register_lighton_ocr_in_transformers(): + from transformers import AutoConfig, AutoModelForImageTextToText, Mistral3Config, Mistral3ForConditionalGeneration + + if "lighton_ocr" in getattr(AutoConfig, "_model_mapping", {}).get("model_type", {}).__class__.__dict__.get( + "_mapping", {} + ): + return + + class LightonOcrConfig(Mistral3Config): + model_type = "lighton_ocr" + + class LightonOcrForConditionalGeneration(Mistral3ForConditionalGeneration): + config_class = LightonOcrConfig + _checkpoint_conversion_mapping = { # noqa: RUF012 + "^model.vision_encoder.": "model.vision_tower.", + "^model.vision_projection.": "model.multi_modal_projector.", + "^model.language_model.lm_head.": "lm_head.", + "^model.language_model.model.": "model.language_model.", + } + + LightonOcrConfig.__module__ = Mistral3Config.__module__ + LightonOcrForConditionalGeneration.__module__ = Mistral3ForConditionalGeneration.__module__ + + try: + AutoConfig.register("lighton_ocr", LightonOcrConfig) + AutoModelForImageTextToText.register(LightonOcrConfig, LightonOcrForConditionalGeneration) + except Exception: + pass + + +_register_lighton_ocr_in_transformers() + + +@register_tasks_manager_onnx( + "lighton_ocr", + *["image-text-to-text", "image-text-to-text-with-past"], + library_name="transformers", +) +class LightonOcrOnnxConfig(OnnxConfig): + + NORMALIZED_CONFIG_CLASS = LightonOcrNormalizedConfig + MIN_TRANSFORMERS_VERSION = version.parse("4.56.0") + SUPPORTS_PAST = True + VARIANTS = { + "split": "Vision encoder, token embeddings, and decoder are exported as separate ONNX models.", + } + DEFAULT_VARIANT = "split" + _MODEL_PATCHER = LightonOcrModelPatcher + + def __init__( + self, + config: PretrainedConfig, + task: str = "image-text-to-text", + int_dtype: str = "int64", + float_dtype: str = "fp32", + variant: str = "split", + component: str | None = None, + use_past: bool = False, + use_past_in_inputs: bool = False, + preprocessors: list[Any] | None = None, + ): + super().__init__( + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + ) + self.variant = variant + self.component = component + self.use_past = use_past + self.use_past_in_inputs = use_past_in_inputs + self.is_merged = False + self.use_cache_branch = None + + @property + def inputs(self) -> dict[str, dict[int, str]]: + if self.component == "vision_encoder": + return { + "pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}, + } + elif self.component == "embed_tokens": + return { + "input_ids": {0: "batch_size", 1: "sequence_length"}, + } + elif self.component == "decoder": + common_inputs = { + "inputs_embeds": {0: "batch_size", 1: "sequence_length", 2: "hidden_size"}, + "attention_mask": {0: "batch_size", 1: "past_sequence_length + sequence_length"}, + "position_ids": {0: "batch_size", 1: "sequence_length"}, + } + if self.use_past_in_inputs: + self._add_past_key_values(common_inputs, direction="inputs") + return common_inputs + return {} + + @property + def outputs(self) -> dict[str, dict[int, str]]: + if self.component == "vision_encoder": + return { + "image_features": {0: "batch_size", 1: "num_image_tokens", 2: "hidden_size"}, + } + elif self.component == "embed_tokens": + return { + "inputs_embeds": {0: "batch_size", 1: "sequence_length", 2: "hidden_size"}, + } + elif self.component == "decoder": + common_outputs = {"logits": {0: "batch_size", 1: "sequence_length"}} + self._add_past_key_values(common_outputs, direction="outputs") + return common_outputs + return {} + + def _add_past_key_values(self, inputs_or_outputs: dict[str, dict[int, str]], direction: str): + text_config = self._config.text_config if hasattr(self._config, "text_config") else self._config + num_layers = text_config.num_hidden_layers + + if direction == "inputs": + for i in range(num_layers): + inputs_or_outputs[f"past_key_values.{i}.key"] = { + 0: "batch_size", + 2: "past_sequence_length", + } + inputs_or_outputs[f"past_key_values.{i}.value"] = { + 0: "batch_size", + 2: "past_sequence_length", + } + else: + for i in range(num_layers): + inputs_or_outputs[f"present.{i}.key"] = { + 0: "batch_size", + 2: "past_sequence_length + sequence_length", + } + inputs_or_outputs[f"present.{i}.value"] = { + 0: "batch_size", + 2: "past_sequence_length + sequence_length", + } + + def flatten_past_key_values(self, flattened_output, name, idx, t): + flattened_output[f"{name}.{idx}.key"] = t[0] + flattened_output[f"{name}.{idx}.value"] = t[1] + + def flatten_output_collection_property(self, name: str, field) -> dict[str, Any]: + flattened_output = {} + if name in ["present", "past_key_values"]: + for idx, t in enumerate(field): + self.flatten_past_key_values(flattened_output, name, idx, t) + else: + flattened_output = super().flatten_output_collection_property(name, field) + return flattened_output + + def _generate_dummy_inputs_for_validation_inner( + self, reference_model_inputs: dict[str, Any], onnx_input_names: list[str] + ) -> dict[str, Any]: + if self.component != "decoder": + return reference_model_inputs + filtered = {} + for name, value in reference_model_inputs.items(): + if name == "past_key_values": + if any(n.startswith("past_key_values.") for n in onnx_input_names): + filtered[name] = value + else: + filtered[name] = value + return filtered + + def post_process_exported_models( + self, + path, + models_and_onnx_configs, + onnx_files_subpaths, + ): + from pathlib import Path as _Path + + from optimum.exporters.onnx.constants import ( + ONNX_DECODER_MERGED_NAME, + ONNX_DECODER_NAME, + ONNX_DECODER_WITH_PAST_NAME, + ) + + + decoder_path = _Path(path, ONNX_DECODER_NAME + ".onnx") + decoder_with_past_path = _Path(path, ONNX_DECODER_WITH_PAST_NAME + ".onnx") + decoder_merged_path = _Path(path, ONNX_DECODER_MERGED_NAME + ".onnx") + + if decoder_path.is_file() and decoder_with_past_path.is_file(): + import os + + import onnx + from optimum.onnx import graph_transformations, merge_decoders + + _original_check_and_save = graph_transformations.check_and_save_model + + def _large_model_check_and_save(model, save_path): + save_path = _Path(save_path).as_posix() + external_file_name = os.path.basename(save_path) + "_data" + external_file_path = os.path.join(os.path.dirname(save_path), external_file_name) + + if save_path.endswith(".onnx") and os.path.isfile(save_path): + os.remove(save_path) + if os.path.isfile(external_file_path): + os.remove(external_file_path) + + onnx.save( + model, + save_path, + save_as_external_data=True, + location=external_file_name, + all_tensors_to_one_file=True, + convert_attribute=True, + size_threshold=100, + ) + try: + onnx.checker.check_model(save_path) + except Exception as e: + if "No Op registered for" not in str(e): + raise + + graph_transformations.check_and_save_model = _large_model_check_and_save + try: + merge_decoders( + decoder=decoder_path, + decoder_with_past=decoder_with_past_path, + save_path=decoder_merged_path, + ) + except Exception as e: + raise RuntimeError("Unable to merge decoders") from e + finally: + graph_transformations.check_and_save_model = _original_check_and_save + + new_subpaths = [] + for sp in onnx_files_subpaths: + if sp in [ONNX_DECODER_NAME + ".onnx", ONNX_DECODER_WITH_PAST_NAME + ".onnx"]: + new_subpaths.append(decoder_merged_path.name) + else: + new_subpaths.append(sp) + onnx_files_subpaths = new_subpaths + + models_and_onnx_configs[ONNX_DECODER_NAME][1].is_merged = True + models_and_onnx_configs[ONNX_DECODER_NAME][1].use_cache_branch = False + models_and_onnx_configs[ONNX_DECODER_NAME][1].use_past_in_inputs = True + + models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].use_cache_branch = True + models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].is_merged = True + + for p in [decoder_path, decoder_with_past_path]: + if p.is_file(): + p.unlink() + ext_data = _Path(str(p) + "_data") + if ext_data.is_file(): + ext_data.unlink() + + return models_and_onnx_configs, onnx_files_subpaths + + def generate_dummy_inputs(self, framework="pt", **input_shapes): + import torch + + batch_size = input_shapes.get("batch_size", 2) + seq_length = input_shapes.get("sequence_length", 16) + text_config = self._config.text_config if hasattr(self._config, "text_config") else self._config + vision_config = getattr(self._config, "vision_config", None) + + if self.component == "vision_encoder": + num_channels = getattr(vision_config, "num_channels", 3) + image_size = 56 + return { + "pixel_values": torch.randn(1, num_channels, image_size, image_size), + } + elif self.component == "embed_tokens": + return { + "input_ids": torch.randint(0, text_config.vocab_size, (batch_size, seq_length)), + } + elif self.component == "decoder": + hidden_size = text_config.hidden_size + num_layers = text_config.num_hidden_layers + num_kv_heads = getattr(text_config, "num_key_value_heads", text_config.num_attention_heads) + head_dim = getattr(text_config, "head_dim", hidden_size // text_config.num_attention_heads) + + dummy = { + "inputs_embeds": torch.randn(batch_size, seq_length, hidden_size), + "position_ids": torch.arange(seq_length, dtype=torch.long).unsqueeze(0).expand(batch_size, -1), + } + + if self.use_past_in_inputs: + past_seq_length = input_shapes.get("past_sequence_length", 16) + dummy["attention_mask"] = torch.ones(batch_size, past_seq_length + seq_length, dtype=torch.long) + dummy["past_key_values"] = tuple( + ( + torch.randn(batch_size, num_kv_heads, past_seq_length, head_dim), + torch.randn(batch_size, num_kv_heads, past_seq_length, head_dim), + ) + for i in range(num_layers) + ) + else: + dummy["attention_mask"] = torch.ones(batch_size, seq_length, dtype=torch.long) + + return dummy + return {} + + def generate_dummy_inputs_for_validation( + self, reference_model_inputs: dict[str, Any], onnx_input_names: list[str] + ) -> dict[str, Any]: + import torch + + if self.is_merged is True and self.use_cache_branch is not None: + reference_model_inputs["use_cache_branch"] = torch.tensor([self.use_cache_branch], dtype=torch.bool) + + if self.use_cache_branch is False: + text_config = self._config.text_config if hasattr(self._config, "text_config") else self._config + num_layers = text_config.num_hidden_layers + num_kv_heads = getattr(text_config, "num_key_value_heads", text_config.num_attention_heads) + hidden_size = text_config.hidden_size + head_dim = getattr(text_config, "head_dim", hidden_size // text_config.num_attention_heads) + batch_size = reference_model_inputs["inputs_embeds"].shape[0] + reference_model_inputs["past_key_values"] = tuple( + ( + torch.zeros(batch_size, num_kv_heads, 1, head_dim), + torch.zeros(batch_size, num_kv_heads, 1, head_dim), + ) + for _ in range(num_layers) + ) + + return self._generate_dummy_inputs_for_validation_inner(reference_model_inputs, onnx_input_names) + + @register_tasks_manager_onnx("mpt", *[*COMMON_TEXT_GENERATION_TASKS, "text-classification", "token-classification"]) class MPTOnnxConfig(TextDecoderOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index b21aa92d..4460d3af 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -1459,3 +1459,125 @@ def __exit__(self, exc_type, exc_value, traceback): if is_transformers_version(">=", "4.55.0"): GptOssExperts.forward = self.original_gpt_oss_forward + + +class LightonOcrModelPatcher(ModelPatcher): + def __init__( + self, + config: OnnxConfig, + model: PreTrainedModel, + model_kwargs: dict[str, Any] | None = None, + ): + super().__init__(config, model, model_kwargs) + self._export_config = config + + orig_sig = inspect.signature(self.orig_forward) + orig_param_names = list(orig_sig.parameters.keys()) + + @functools.wraps(self.orig_forward) + def patched_forward(*args, **kwargs): + for i, val in enumerate(args): + if i < len(orig_param_names): + kwargs[orig_param_names[i]] = val + + pixel_values = kwargs.get("pixel_values") + input_ids = kwargs.get("input_ids") + inputs_embeds = kwargs.get("inputs_embeds") + attention_mask = kwargs.get("attention_mask") + position_ids = kwargs.get("position_ids") + past_key_values = kwargs.get("past_key_values") + if config.component == "vision_encoder": + vision_tower = model.model.vision_tower + projector = model.model.multi_modal_projector + + patch_size = vision_tower.patch_size + spatial_merge_size = projector.patch_merger.spatial_merge_size + + patch_embeds = vision_tower.patch_conv(pixel_values) + h_patches = patch_embeds.shape[2] + w_patches = patch_embeds.shape[3] + + patch_embeds_flat = patch_embeds[0].flatten(1).T.unsqueeze(0) + patch_embeds_flat = vision_tower.ln_pre(patch_embeds_flat) + + max_width = vision_tower.config.image_size // patch_size + spatial = patch_embeds[0, 0] + row_ids = torch.ones_like(spatial).cumsum(dim=0) - 1 + col_ids = torch.ones_like(spatial).cumsum(dim=1) - 1 + position_ids_vis = (row_ids * max_width + col_ids).reshape(-1).long() + + position_embeddings = vision_tower.patch_positional_embedding(patch_embeds_flat, position_ids_vis) + + transformer_output = vision_tower.transformer( + patch_embeds_flat, + attention_mask=None, + position_embeddings=position_embeddings, + output_hidden_states=False, + output_attentions=False, + return_dict=True, + ) + image_features = transformer_output[0].squeeze(0) + + image_features = projector.norm(image_features) + + d = image_features.shape[-1] + image_grid = image_features.view(h_patches, w_patches, d).permute(2, 0, 1).unsqueeze(0) + grid = torch.nn.functional.unfold( + image_grid, kernel_size=spatial_merge_size, stride=spatial_merge_size + ) + grid = grid.view(d * spatial_merge_size**2, -1).t() + image_features = projector.patch_merger.merging_layer(grid) + + image_features = projector.linear_1(image_features) + image_features = projector.act(image_features) + image_features = projector.linear_2(image_features) + + return {"image_features": image_features.unsqueeze(0)} + + elif config.component == "embed_tokens": + embeds = model.model.language_model.embed_tokens(input_ids) + return {"inputs_embeds": embeds} + + elif config.component == "decoder": + hidden_states = inputs_embeds + language_model = model.model.language_model + + cache = None + if past_key_values is not None: + cache = DynamicCache() + for i, (k, v) in enumerate(past_key_values): + cache.update(k, v, layer_idx=i) + + lm_outputs = language_model( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=cache, + use_cache=True, + return_dict=True, + ) + + hidden_states = lm_outputs.last_hidden_state + logits = model.lm_head(hidden_states) + + result = {"logits": logits} + out_cache = lm_outputs.past_key_values + num_layers = len(language_model.layers) + for i in range(num_layers): + key, value = out_cache[i] + result[f"present.{i}.key"] = key + result[f"present.{i}.value"] = value + return result + + self.patched_forward = patched_forward + + def __enter__(self): + if self._export_config.component == "vision_encoder": + self._orig_attn_impl = self._model.model.vision_tower.config._attn_implementation + self._model.model.vision_tower.config._attn_implementation = "eager" + super().__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + if hasattr(self, "_orig_attn_impl"): + self._model.model.vision_tower.config._attn_implementation = self._orig_attn_impl diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 1da33042..d5679fa0 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -167,6 +167,41 @@ def get_metaclip_2_models_for_export(model: PreTrainedModel, config: ExporterCon return models_for_export +def get_lighton_ocr_models_for_export(model: PreTrainedModel, config: ExporterConfig): + """Create the 4-part split for LightOn OCR: vision_encoder, embed_tokens, decoder, decoder_with_past.""" + models_for_export = {} + + vision_encoder_config = config.__class__( + model.config, task=config.task, variant=config.variant, component="vision_encoder" + ) + embed_tokens_config = config.__class__( + model.config, task=config.task, variant=config.variant, component="embed_tokens" + ) + decoder_config = config.__class__( + model.config, + task=config.task, + variant=config.variant, + component="decoder", + use_past=True, + use_past_in_inputs=False, + ) + decoder_with_past_config = config.__class__( + model.config, + task=config.task, + variant=config.variant, + component="decoder", + use_past=True, + use_past_in_inputs=True, + ) + + models_for_export["vision_encoder"] = (model, vision_encoder_config) + models_for_export["embed_tokens"] = (model, embed_tokens_config) + models_for_export["decoder_model"] = (model, decoder_config) + models_for_export["decoder_with_past_model"] = (model, decoder_with_past_config) + + return models_for_export + + def get_sana_models_for_export(pipeline: DiffusionPipeline, int_dtype: str = "int64", float_dtype: str = "fp32"): import copy @@ -261,6 +296,19 @@ def _get_submodels_and_onnx_configs( export_config.variant = _variant return export_config, get_metaclip_2_models_for_export(model, export_config) + if library_name == "transformers" and model.config.model_type == "lighton_ocr": + export_config_constructor = TasksManager.get_exporter_config_constructor( + model=model, exporter="onnx", task=task, library_name="transformers", model_type="lighton_ocr" + ) + export_config = export_config_constructor( + model.config, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + ) + export_config.variant = _variant + return export_config, get_lighton_ocr_models_for_export(model, export_config) + if library_name == "diffusers" and model.__class__.__name__.startswith("Sana"): return None, get_sana_models_for_export(model, int_dtype, float_dtype) From 4426aea8a10f18d33e5b562efa0a587104cf7b81 Mon Sep 17 00:00:00 2001 From: seme0011 Date: Sun, 19 Apr 2026 14:40:08 +0200 Subject: [PATCH 2/6] Add lighton_ocr tiny test model to CI test suite --- tests/exporters/onnx/utils_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/exporters/onnx/utils_tests.py b/tests/exporters/onnx/utils_tests.py index 626ff417..54c7424e 100644 --- a/tests/exporters/onnx/utils_tests.py +++ b/tests/exporters/onnx/utils_tests.py @@ -130,6 +130,7 @@ "layoutlm": "hf-internal-testing/tiny-random-LayoutLMModel", "layoutlmv3": "hf-internal-testing/tiny-random-LayoutLMv3Model", "lilt": "hf-internal-testing/tiny-random-LiltModel", + "lighton_ocr": "Remidesbois/tiny-random-LightOnOCR", "llama": "fxmarty/tiny-llama-fast-tokenizer", "longt5": "fxmarty/tiny-random-working-LongT5Model", "longformer": "hf-internal-testing/tiny-random-LongformerModel", From eaedd00891725f4655b035183c88388cf8d675ec Mon Sep 17 00:00:00 2001 From: seme0011 Date: Sun, 19 Apr 2026 15:11:18 +0200 Subject: [PATCH 3/6] Fix ruff formatting and lint issues --- optimum/exporters/onnx/model_configs.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index ea60333f..e88cab69 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -643,11 +643,10 @@ class LightonOcrForConditionalGeneration(Mistral3ForConditionalGeneration): library_name="transformers", ) class LightonOcrOnnxConfig(OnnxConfig): - NORMALIZED_CONFIG_CLASS = LightonOcrNormalizedConfig MIN_TRANSFORMERS_VERSION = version.parse("4.56.0") SUPPORTS_PAST = True - VARIANTS = { + VARIANTS = { # noqa: RUF012 "split": "Vision encoder, token embeddings, and decoder are exported as separate ONNX models.", } DEFAULT_VARIANT = "split" @@ -782,7 +781,6 @@ def post_process_exported_models( ONNX_DECODER_WITH_PAST_NAME, ) - decoder_path = _Path(path, ONNX_DECODER_NAME + ".onnx") decoder_with_past_path = _Path(path, ONNX_DECODER_WITH_PAST_NAME + ".onnx") decoder_merged_path = _Path(path, ONNX_DECODER_MERGED_NAME + ".onnx") From e20a0d3ca08bd537305617858428501a04f1cf2d Mon Sep 17 00:00:00 2001 From: Remidesbois Date: Tue, 28 Apr 2026 10:50:04 +0200 Subject: [PATCH 4/6] Feat: Add native support for the official LightOnOCR model via mistral3 model_type --- optimum/exporters/onnx/model_configs.py | 5 +++++ optimum/exporters/onnx/utils.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index e88cab69..2da8c8e4 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -637,6 +637,11 @@ class LightonOcrForConditionalGeneration(Mistral3ForConditionalGeneration): _register_lighton_ocr_in_transformers() +@register_tasks_manager_onnx( + "mistral3", + *["image-text-to-text", "image-text-to-text-with-past"], + library_name="transformers", +) @register_tasks_manager_onnx( "lighton_ocr", *["image-text-to-text", "image-text-to-text-with-past"], diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index d5679fa0..5eeff8a2 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -296,9 +296,9 @@ def _get_submodels_and_onnx_configs( export_config.variant = _variant return export_config, get_metaclip_2_models_for_export(model, export_config) - if library_name == "transformers" and model.config.model_type == "lighton_ocr": + if library_name == "transformers" and getattr(model.config, "model_type", None) in ["lighton_ocr", "mistral3"]: export_config_constructor = TasksManager.get_exporter_config_constructor( - model=model, exporter="onnx", task=task, library_name="transformers", model_type="lighton_ocr" + model=model, exporter="onnx", task=task, library_name="transformers", model_type=model.config.model_type ) export_config = export_config_constructor( model.config, From 5f8bb2b4bb95b851d194fe4b073bdbfcc99e80b3 Mon Sep 17 00:00:00 2001 From: Remidesbois Date: Tue, 28 Apr 2026 10:50:32 +0200 Subject: [PATCH 5/6] Fix: Automatically use external data for >2GB models during merge_decoders --- optimum/exporters/onnx/model_configs.py | 40 ++++++------------------- 1 file changed, 9 insertions(+), 31 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 2da8c8e4..65a8770e 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -793,37 +793,12 @@ def post_process_exported_models( if decoder_path.is_file() and decoder_with_past_path.is_file(): import os - import onnx - from optimum.onnx import graph_transformations, merge_decoders - - _original_check_and_save = graph_transformations.check_and_save_model - - def _large_model_check_and_save(model, save_path): - save_path = _Path(save_path).as_posix() - external_file_name = os.path.basename(save_path) + "_data" - external_file_path = os.path.join(os.path.dirname(save_path), external_file_name) - - if save_path.endswith(".onnx") and os.path.isfile(save_path): - os.remove(save_path) - if os.path.isfile(external_file_path): - os.remove(external_file_path) - - onnx.save( - model, - save_path, - save_as_external_data=True, - location=external_file_name, - all_tensors_to_one_file=True, - convert_attribute=True, - size_threshold=100, - ) - try: - onnx.checker.check_model(save_path) - except Exception as e: - if "No Op registered for" not in str(e): - raise + from optimum.onnx import merge_decoders + + # Force external data to avoid Protobuf >2GB limit during merge + original_force_external = os.getenv("FORCE_ONNX_EXTERNAL_DATA") + os.environ["FORCE_ONNX_EXTERNAL_DATA"] = "1" - graph_transformations.check_and_save_model = _large_model_check_and_save try: merge_decoders( decoder=decoder_path, @@ -833,7 +808,10 @@ def _large_model_check_and_save(model, save_path): except Exception as e: raise RuntimeError("Unable to merge decoders") from e finally: - graph_transformations.check_and_save_model = _original_check_and_save + if original_force_external is not None: + os.environ["FORCE_ONNX_EXTERNAL_DATA"] = original_force_external + else: + os.environ.pop("FORCE_ONNX_EXTERNAL_DATA", None) new_subpaths = [] for sp in onnx_files_subpaths: From a7a0e866a77c0989c561a54ea0d49b8bfbc390eb Mon Sep 17 00:00:00 2001 From: Remidesbois Date: Tue, 28 Apr 2026 17:06:37 +0200 Subject: [PATCH 6/6] Optimized LightOnOCR dynamic axes for stricter exporter compatibility. --- optimum/exporters/onnx/model_configs.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 65a8770e..b59a962d 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -687,7 +687,7 @@ def __init__( def inputs(self) -> dict[str, dict[int, str]]: if self.component == "vision_encoder": return { - "pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}, + "pixel_values": {0: "batch_size", 2: "height", 3: "width"}, } elif self.component == "embed_tokens": return { @@ -695,7 +695,7 @@ def inputs(self) -> dict[str, dict[int, str]]: } elif self.component == "decoder": common_inputs = { - "inputs_embeds": {0: "batch_size", 1: "sequence_length", 2: "hidden_size"}, + "inputs_embeds": {0: "batch_size", 1: "sequence_length"}, "attention_mask": {0: "batch_size", 1: "past_sequence_length + sequence_length"}, "position_ids": {0: "batch_size", 1: "sequence_length"}, } @@ -708,11 +708,11 @@ def inputs(self) -> dict[str, dict[int, str]]: def outputs(self) -> dict[str, dict[int, str]]: if self.component == "vision_encoder": return { - "image_features": {0: "batch_size", 1: "num_image_tokens", 2: "hidden_size"}, + "image_features": {0: "batch_size", 1: "num_image_tokens"}, } elif self.component == "embed_tokens": return { - "inputs_embeds": {0: "batch_size", 1: "sequence_length", 2: "hidden_size"}, + "inputs_embeds": {0: "batch_size", 1: "sequence_length"}, } elif self.component == "decoder": common_outputs = {"logits": {0: "batch_size", 1: "sequence_length"}} @@ -749,6 +749,7 @@ def flatten_past_key_values(self, flattened_output, name, idx, t): flattened_output[f"{name}.{idx}.key"] = t[0] flattened_output[f"{name}.{idx}.value"] = t[1] + def flatten_output_collection_property(self, name: str, field) -> dict[str, Any]: flattened_output = {} if name in ["present", "past_key_values"]: