Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions optimum/exporters/onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@
"validate_models_outputs",
"onnx_export_from_model",
],
"utils": ["MODEL_TYPES_REQUIRING_POSITION_IDS", "get_metaclip_2_models_for_export"],
"utils": [
"MODEL_TYPES_REQUIRING_POSITION_IDS",
"get_metaclip_2_models_for_export",
"get_lighton_ocr_models_for_export",
],
"__main__": ["main_export"],
}

Expand All @@ -44,7 +48,11 @@
validate_model_outputs,
validate_models_outputs,
)
from optimum.exporters.onnx.utils import MODEL_TYPES_REQUIRING_POSITION_IDS, get_metaclip_2_models_for_export
from optimum.exporters.onnx.utils import (
MODEL_TYPES_REQUIRING_POSITION_IDS,
get_lighton_ocr_models_for_export,
get_metaclip_2_models_for_export,
)
else:
import sys

Expand Down
328 changes: 328 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
CohereModelPatcher,
FluxTransformerModelPatcher,
GptOssModelPatcher,
LightonOcrModelPatcher,
MetaCLIP2Patcher,
MgpstrModelPatcher,
MoonshineModelPatcher,
Expand Down Expand Up @@ -581,6 +582,333 @@ class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator


class LightonOcrNormalizedConfig(NormalizedTextConfig):
    """Normalized config for LightOn OCR that reads from the text_config sub-config.

    The class-level names map the normalized attributes to dotted ``text_config.*`` paths.
    ``__init__`` additionally stores the resolved values directly on the instance, falling back to
    the top-level config when it has no ``text_config`` attribute.
    """

    NUM_LAYERS = "text_config.num_hidden_layers"
    NUM_ATTENTION_HEADS = "text_config.num_attention_heads"
    NUM_KEY_VALUE_HEADS = "text_config.num_key_value_heads"
    HIDDEN_SIZE = "text_config.hidden_size"
    VOCAB_SIZE = "text_config.vocab_size"

    def __init__(self, config, **kwargs):
        """Resolve the text-model dimensions of ``config`` and cache them on the instance.

        Args:
            config: Model configuration. Its ``text_config`` attribute is used when present,
                otherwise ``config`` itself is read.
            **kwargs: Forwarded unchanged to the parent constructor, so that factory helpers which
                build the class with extra keyword arguments (e.g. the ``with_args`` style used
                elsewhere in this file) keep working with this subclass.
                NOTE(review): accepted keywords are defined by the parent class — confirm there.
        """
        super().__init__(config, **kwargs)
        text_config = config.text_config if hasattr(config, "text_config") else config
        self.hidden_size = text_config.hidden_size
        self.num_attention_heads = text_config.num_attention_heads
        # Configs without an explicit KV-head count use one KV head per attention head.
        self.num_key_value_heads = getattr(text_config, "num_key_value_heads", text_config.num_attention_heads)
        self.num_layers = text_config.num_hidden_layers
        self.vocab_size = text_config.vocab_size
        # Configs without an explicit head_dim derive it from hidden size and head count.
        self.head_dim = getattr(text_config, "head_dim", self.hidden_size // self.num_attention_heads)
        if hasattr(config, "vision_config"):
            self.vision_config = config.vision_config


def _register_lighton_ocr_in_transformers():
    """Alias the ``lighton_ocr`` model type to the Mistral3 classes in the transformers Auto registries.

    A ``LightonOcrConfig`` / ``LightonOcrForConditionalGeneration`` pair is derived from the Mistral3
    classes and registered with ``AutoConfig`` and ``AutoModelForImageTextToText``.

    The registration is best-effort because it runs at module import time:
      * if transformers (or its Mistral3 classes) cannot be imported, nothing is registered;
      * if ``lighton_ocr`` is already a known model type, nothing is registered;
      * if the registration call itself fails, the error is logged at debug level and ignored.
    """
    import logging

    try:
        from transformers import (
            CONFIG_MAPPING,
            AutoConfig,
            AutoModelForImageTextToText,
            Mistral3Config,
            Mistral3ForConditionalGeneration,
        )
    except ImportError:
        # Without the Mistral3 classes there is nothing to alias; importing this module must still work.
        return

    # CONFIG_MAPPING is the registry AutoConfig resolves model types from.
    if "lighton_ocr" in CONFIG_MAPPING:
        return

    class LightonOcrConfig(Mistral3Config):
        model_type = "lighton_ocr"

    class LightonOcrForConditionalGeneration(Mistral3ForConditionalGeneration):
        config_class = LightonOcrConfig
        # Regex prefixes of checkpoint weight names and the Mistral3 module paths they are renamed to.
        _checkpoint_conversion_mapping = {  # noqa: RUF012
            "^model.vision_encoder.": "model.vision_tower.",
            "^model.vision_projection.": "model.multi_modal_projector.",
            "^model.language_model.lm_head.": "lm_head.",
            "^model.language_model.model.": "model.language_model.",
        }

    # Report the aliases as living in the Mistral3 modules.
    # NOTE(review): presumably so module-based lookups treat them like the originals — confirm.
    LightonOcrConfig.__module__ = Mistral3Config.__module__
    LightonOcrForConditionalGeneration.__module__ = Mistral3ForConditionalGeneration.__module__

    try:
        AutoConfig.register("lighton_ocr", LightonOcrConfig)
        AutoModelForImageTextToText.register(LightonOcrConfig, LightonOcrForConditionalGeneration)
    except Exception as err:
        # Import-time boundary: a failed alias must not break the module, but it should be traceable.
        logging.getLogger(__name__).debug("Skipping 'lighton_ocr' registration in transformers: %s", err)


_register_lighton_ocr_in_transformers()


# Register this ONNX config for the "mistral3" and "lighton_ocr" model types, each for the
# "image-text-to-text" and "image-text-to-text-with-past" tasks of the transformers library.
@register_tasks_manager_onnx(
"mistral3",
*["image-text-to-text", "image-text-to-text-with-past"],
library_name="transformers",
)
@register_tasks_manager_onnx(
"lighton_ocr",
*["image-text-to-text", "image-text-to-text-with-past"],
library_name="transformers",
)
class LightonOcrOnnxConfig(OnnxConfig):
# ONNX export configuration whose inputs, outputs and dummy inputs depend on the `component`
# passed at construction ("vision_encoder", "embed_tokens" or "decoder").
NORMALIZED_CONFIG_CLASS = LightonOcrNormalizedConfig
# Minimum transformers version declared for this config.
MIN_TRANSFORMERS_VERSION = version.parse("4.56.0")
SUPPORTS_PAST = True
# Only one export variant is declared, and it is also the default.
VARIANTS = { # noqa: RUF012
"split": "Vision encoder, token embeddings, and decoder are exported as separate ONNX models.",
}
DEFAULT_VARIANT = "split"
# Patcher class associated with this config.
_MODEL_PATCHER = LightonOcrModelPatcher

def __init__(
    self,
    config: PretrainedConfig,
    task: str = "image-text-to-text",
    int_dtype: str = "int64",
    float_dtype: str = "fp32",
    variant: str = "split",
    component: str | None = None,
    use_past: bool = False,
    use_past_in_inputs: bool = False,
    preprocessors: list[Any] | None = None,
):
    """Create the export config for one component of the model.

    Args:
        config: Model configuration forwarded to the parent constructor.
        task: Export task name forwarded to the parent constructor.
        int_dtype: Integer dtype name forwarded to the parent constructor.
        float_dtype: Float dtype name forwarded to the parent constructor.
        variant: Export variant stored on the instance.
        component: Which sub-model this config describes; the other members compare it against
            "vision_encoder", "embed_tokens" and "decoder".
        use_past: Stored on the instance as-is.
        use_past_in_inputs: Whether past key/values are declared as decoder inputs.
        preprocessors: Forwarded to the parent constructor.
    """
    parent_arguments = {
        "config": config,
        "task": task,
        "int_dtype": int_dtype,
        "float_dtype": float_dtype,
        "preprocessors": preprocessors,
    }
    super().__init__(**parent_arguments)

    # Options specific to this config.
    self.variant = variant
    self.component = component
    self.use_past = use_past
    self.use_past_in_inputs = use_past_in_inputs

    # Merged-decoder state; a fresh config is not merged and has no cache branch selected.
    self.is_merged = False
    self.use_cache_branch = None

@property
def inputs(self) -> dict[str, dict[int, str]]:
    """Return the ONNX input names and their dynamic axes for the selected component.

    Returns:
        A mapping ``input name -> {axis index: axis name}``. An unrecognised component yields an
        empty mapping. For the decoder, past key/value inputs are appended only when
        ``use_past_in_inputs`` is set.
    """
    selected = self.component

    if selected == "vision_encoder":
        return {"pixel_values": {0: "batch_size", 2: "height", 3: "width"}}

    if selected == "embed_tokens":
        return {"input_ids": {0: "batch_size", 1: "sequence_length"}}

    if selected != "decoder":
        return {}

    decoder_axes = {
        "inputs_embeds": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "past_sequence_length + sequence_length"},
        "position_ids": {0: "batch_size", 1: "sequence_length"},
    }
    if self.use_past_in_inputs:
        self._add_past_key_values(decoder_axes, direction="inputs")
    return decoder_axes

@property
def outputs(self) -> dict[str, dict[int, str]]:
    """Return the ONNX output names and their dynamic axes for the selected component.

    Returns:
        A mapping ``output name -> {axis index: axis name}``. An unrecognised component yields an
        empty mapping. For the decoder, the present key/value outputs are always appended after
        ``logits``.
    """
    selected = self.component

    if selected == "vision_encoder":
        return {"image_features": {0: "batch_size", 1: "num_image_tokens"}}

    if selected == "embed_tokens":
        return {"inputs_embeds": {0: "batch_size", 1: "sequence_length"}}

    if selected != "decoder":
        return {}

    decoder_axes = {"logits": {0: "batch_size", 1: "sequence_length"}}
    self._add_past_key_values(decoder_axes, direction="outputs")
    return decoder_axes

def _add_past_key_values(self, inputs_or_outputs: dict[str, dict[int, str]], direction: str):
    """Add per-layer key/value cache entries to ``inputs_or_outputs`` in place.

    Args:
        inputs_or_outputs: Axis mapping to extend; it is modified in place.
        direction: ``"inputs"`` adds ``past_key_values.{i}.key/value`` entries whose axis 2 is
            ``past_sequence_length``; any other value adds ``present.{i}.key/value`` entries whose
            axis 2 is ``past_sequence_length + sequence_length``.

    The number of layers is read from ``text_config.num_hidden_layers`` of the wrapped config, or
    from the config itself when it has no ``text_config``.
    """
    language_config = getattr(self._config, "text_config", self._config)
    layer_count = language_config.num_hidden_layers

    if direction == "inputs":
        prefix, length_axis = "past_key_values", "past_sequence_length"
    else:
        prefix, length_axis = "present", "past_sequence_length + sequence_length"

    for layer in range(layer_count):
        for tensor_kind in ("key", "value"):
            # A fresh axes dict per entry, so entries never share state.
            inputs_or_outputs[f"{prefix}.{layer}.{tensor_kind}"] = {0: "batch_size", 2: length_axis}

def flatten_past_key_values(self, flattened_output, name, idx, t):
    """Store the first two items of ``t`` in ``flattened_output`` as ``{name}.{idx}.key`` / ``.value``.

    ``flattened_output`` is modified in place; nothing is returned.
    """
    entry_prefix = f"{name}.{idx}"
    flattened_output[entry_prefix + ".key"] = t[0]
    flattened_output[entry_prefix + ".value"] = t[1]


def flatten_output_collection_property(self, name: str, field) -> dict[str, Any]:
    """Flatten a collection-valued model input or output into named entries.

    Args:
        name: Name of the collection. ``"present"`` and ``"past_key_values"`` are handled here;
            every other name is delegated to the parent implementation.
        field: Iterable with one item per layer; each item is handed to
            ``flatten_past_key_values``.

    Returns:
        A new dict holding the flattened entries.
    """
    if name not in ("present", "past_key_values"):
        return super().flatten_output_collection_property(name, field)

    per_layer_entries = {}
    for layer_index, layer_cache in enumerate(field):
        self.flatten_past_key_values(per_layer_entries, name, layer_index, layer_cache)
    return per_layer_entries

def _generate_dummy_inputs_for_validation_inner(
    self, reference_model_inputs: dict[str, Any], onnx_input_names: list[str]
) -> dict[str, Any]:
    """Drop ``past_key_values`` from the decoder inputs when the ONNX graph has no such inputs.

    Args:
        reference_model_inputs: Inputs prepared for the reference model.
        onnx_input_names: Input names of the exported ONNX graph; only consulted when
            ``reference_model_inputs`` contains a ``"past_key_values"`` entry.

    Returns:
        For non-decoder components, ``reference_model_inputs`` itself. For the decoder, a new dict
        with the same entries, minus ``"past_key_values"`` when no ONNX input name starts with
        ``"past_key_values."``.
    """
    if self.component != "decoder":
        return reference_model_inputs

    return {
        input_name: input_value
        for input_name, input_value in reference_model_inputs.items()
        # The name scan is only reached for the cache entry itself.
        if input_name != "past_key_values"
        or any(onnx_name.startswith("past_key_values.") for onnx_name in onnx_input_names)
    }

def post_process_exported_models(
    self,
    path,
    models_and_onnx_configs,
    onnx_files_subpaths,
):
    """Merge the exported decoder and decoder-with-past ONNX files into a single merged decoder.

    Nothing happens unless both the decoder file and the decoder-with-past file exist under
    ``path``. When they do:

    1. ``merge_decoders`` writes the merged model, with ``FORCE_ONNX_EXTERNAL_DATA`` temporarily set
       to ``"1"`` and restored afterwards.
    2. Every occurrence of either decoder file name in ``onnx_files_subpaths`` is replaced by the
       merged file name.
    3. Both decoder configs are flagged as merged: the no-past one uses the non-cache branch with
       past inputs enabled, the with-past one uses the cache branch.
    4. The two source decoder files and their ``*_data`` companions are deleted.

    Args:
        path: Directory containing the exported ONNX files.
        models_and_onnx_configs: Mapping from model name to a sequence whose item ``[1]`` is the
            ONNX config of that model.
        onnx_files_subpaths: File names of the exported models.

    Returns:
        The tuple ``(models_and_onnx_configs, onnx_files_subpaths)``, with the sub-paths updated
        when a merge took place.

    Raises:
        RuntimeError: If ``merge_decoders`` fails; the original error is chained as the cause.
    """
    from pathlib import Path as _Path

    from optimum.exporters.onnx.constants import (
        ONNX_DECODER_MERGED_NAME,
        ONNX_DECODER_NAME,
        ONNX_DECODER_WITH_PAST_NAME,
    )

    export_dir = _Path(path)
    plain_file = export_dir / (ONNX_DECODER_NAME + ".onnx")
    with_past_file = export_dir / (ONNX_DECODER_WITH_PAST_NAME + ".onnx")
    merged_file = export_dir / (ONNX_DECODER_MERGED_NAME + ".onnx")

    both_exported = plain_file.is_file() and with_past_file.is_file()
    if not both_exported:
        return models_and_onnx_configs, onnx_files_subpaths

    import os

    from optimum.onnx import merge_decoders

    # Force external data to avoid Protobuf >2GB limit during merge
    env_name = "FORCE_ONNX_EXTERNAL_DATA"
    saved_value = os.getenv(env_name)
    os.environ[env_name] = "1"

    try:
        merge_decoders(
            decoder=plain_file,
            decoder_with_past=with_past_file,
            save_path=merged_file,
        )
    except Exception as e:
        raise RuntimeError("Unable to merge decoders") from e
    finally:
        # Put the environment back exactly as it was found.
        if saved_value is None:
            os.environ.pop(env_name, None)
        else:
            os.environ[env_name] = saved_value

    # Both decoder entries now point at the same merged file.
    replaced_names = (ONNX_DECODER_NAME + ".onnx", ONNX_DECODER_WITH_PAST_NAME + ".onnx")
    onnx_files_subpaths = [
        merged_file.name if subpath in replaced_names else subpath for subpath in onnx_files_subpaths
    ]

    plain_config = models_and_onnx_configs[ONNX_DECODER_NAME][1]
    plain_config.is_merged = True
    plain_config.use_cache_branch = False
    plain_config.use_past_in_inputs = True

    with_past_config = models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1]
    with_past_config.use_cache_branch = True
    with_past_config.is_merged = True

    # Remove the now-redundant source models together with their external-data companions.
    for obsolete in (plain_file, with_past_file):
        if obsolete.is_file():
            obsolete.unlink()
        companion = _Path(str(obsolete) + "_data")
        if companion.is_file():
            companion.unlink()

    return models_and_onnx_configs, onnx_files_subpaths

def generate_dummy_inputs(self, framework="pt", **input_shapes):
    """Build example torch tensors for the component this config describes.

    Args:
        framework: Accepted but not used; torch tensors are always produced.
        **input_shapes: Optional sizes. ``batch_size`` (default 2) and ``sequence_length``
            (default 16) are read, plus ``past_sequence_length`` (default 16) for a decoder with
            past inputs.

    Returns:
        * ``vision_encoder``: ``pixel_values`` for a single 56x56 image (``batch_size`` is not applied).
        * ``embed_tokens``: random ``input_ids`` below the text vocabulary size.
        * ``decoder``: ``inputs_embeds``, ``position_ids``, ``attention_mask`` and, with past inputs,
          ``past_key_values`` as a tuple of per-layer ``(key, value)`` pairs.
        * any other component: an empty dict.
    """
    import torch

    n_batch = input_shapes.get("batch_size", 2)
    n_tokens = input_shapes.get("sequence_length", 16)
    root_config = self._config
    language_config = getattr(root_config, "text_config", root_config)
    image_config = getattr(root_config, "vision_config", None)
    selected = self.component

    if selected == "vision_encoder":
        side = 56
        channels = getattr(image_config, "num_channels", 3)
        return {"pixel_values": torch.randn(1, channels, side, side)}

    if selected == "embed_tokens":
        token_ids = torch.randint(0, language_config.vocab_size, (n_batch, n_tokens))
        return {"input_ids": token_ids}

    if selected != "decoder":
        return {}

    width = language_config.hidden_size
    layer_count = language_config.num_hidden_layers
    # Fall back to one KV head per attention head / an evenly split hidden size when unspecified.
    kv_heads = getattr(language_config, "num_key_value_heads", language_config.num_attention_heads)
    per_head = getattr(language_config, "head_dim", width // language_config.num_attention_heads)

    example = {
        "inputs_embeds": torch.randn(n_batch, n_tokens, width),
        "position_ids": torch.arange(n_tokens, dtype=torch.long).unsqueeze(0).expand(n_batch, -1),
    }

    if not self.use_past_in_inputs:
        example["attention_mask"] = torch.ones(n_batch, n_tokens, dtype=torch.long)
        return example

    n_past = input_shapes.get("past_sequence_length", 16)
    # The mask spans the cached positions followed by the new ones.
    example["attention_mask"] = torch.ones(n_batch, n_past + n_tokens, dtype=torch.long)

    cache_shape = (n_batch, kv_heads, n_past, per_head)
    layer_caches = []
    for _ in range(layer_count):
        layer_caches.append((torch.randn(*cache_shape), torch.randn(*cache_shape)))
    example["past_key_values"] = tuple(layer_caches)

    return example

# Adapt the reference model inputs to what the exported ONNX graph expects during validation.
#
# `reference_model_inputs` is modified in place: for a merged config with a cache branch selected,
# a one-element boolean `use_cache_branch` tensor is added. For the no-cache branch, zero-filled
# key/value tensors with a past length of 1 are added for every layer. The result is then passed
# through `_generate_dummy_inputs_for_validation_inner`, whose return value is returned.
#
# NOTE(review): indentation is not visible in this view, so it cannot be confirmed here whether the
# `use_cache_branch is False` check is nested under the `is_merged` check — verify in the file.
def generate_dummy_inputs_for_validation(
self, reference_model_inputs: dict[str, Any], onnx_input_names: list[str]
) -> dict[str, Any]:
import torch

if self.is_merged is True and self.use_cache_branch is not None:
# Tells the merged graph which branch to run.
reference_model_inputs["use_cache_branch"] = torch.tensor([self.use_cache_branch], dtype=torch.bool)

if self.use_cache_branch is False:
# Read the cache dimensions from the text sub-config, or from the config itself if absent.
text_config = self._config.text_config if hasattr(self._config, "text_config") else self._config
num_layers = text_config.num_hidden_layers
num_kv_heads = getattr(text_config, "num_key_value_heads", text_config.num_attention_heads)
hidden_size = text_config.hidden_size
head_dim = getattr(text_config, "head_dim", hidden_size // text_config.num_attention_heads)
# Batch size is taken from the embeddings already present in the inputs.
batch_size = reference_model_inputs["inputs_embeds"].shape[0]
# Zero-filled cache of past length 1 for each layer.
# NOTE(review): presumably placeholders because the merged graph still declares past inputs — confirm.
reference_model_inputs["past_key_values"] = tuple(
(
torch.zeros(batch_size, num_kv_heads, 1, head_dim),
torch.zeros(batch_size, num_kv_heads, 1, head_dim),
)
for _ in range(num_layers)
)

return self._generate_dummy_inputs_for_validation_inner(reference_model_inputs, onnx_input_names)


@register_tasks_manager_onnx("mpt", *[*COMMON_TEXT_GENERATION_TASKS, "text-classification", "token-classification"])
class MPTOnnxConfig(TextDecoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
Expand Down
Loading