Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions examples/omnisvg_example/omnisvg_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import os

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoProcessor, AutoTokenizer, Qwen2_5_VLForConditionalGeneration

from QEfficient import QEFFAutoModelForImageTextToText


class SketchDecoder(nn.Module):
"""
Autoregressive generative model
"""

def __init__(self, **kwargs):
super().__init__()
self.vocab_size = 196042
self.bos_token_id = 151643
self.eos_token_id = 196041
self.pad_token_id = 151643

config = AutoConfig.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct",
vocab_size=self.vocab_size,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
)

self.transformer = Qwen2_5_VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2.5-VL-3B-Instruct", config=config, attn_implementation="eager", ignore_mismatched_sizes=True
)

self.transformer.resize_token_embeddings(self.vocab_size)

def forward(self, *args, **kwargs):
raise NotImplementedError("Forward pass not included in open-source version")


model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
sketch_decoder = SketchDecoder()
weight_path = "/home/dipankar/omnisvg/OmniSVG"
sketch_weight_file = os.path.join(weight_path, "pytorch_model.bin")
if not os.path.exists(sketch_weight_file):
raise FileNotFoundError(f"pytorch_model.bin not found in {weight_path}")
sketch_decoder.load_state_dict(torch.load(sketch_weight_file))
sketch_decoder.transformer.eval()
qeff_model = QEFFAutoModelForImageTextToText(sketch_decoder.transformer)
qeff_model.export()
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", padding_side="left")
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", padding_side="left")
path = qeff_model.compile(
batch_size=1,
prefill_seq_len=128,
ctx_len=4096,
num_cores=16,
num_devices=8,
height=354,
width=536,
mxfp6_matmul=False,
aic_enable_depth_first=True,
skip_vision=True,
mos=1,
)
Loading