Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 161 additions & 4 deletions mineru_vl_utils/mineru_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import asyncio
import math
import time
import re
import os
import subprocess
from concurrent.futures import Executor
from typing import Literal, Sequence
from datetime import datetime

from PIL import Image

Expand All @@ -13,7 +17,6 @@

_layout_re = r"^<\|box_start\|>(\d+)\s+(\d+)\s+(\d+)\s+(\d+)<\|box_end\|><\|ref_start\|>(\w+?)<\|ref_end\|>(.*)$"


class MinerUSamplingParams(SamplingParams):
def __init__(
self,
Expand Down Expand Up @@ -59,7 +62,6 @@ def __init__(
"<|rotate_left|>": 270,
}


def _convert_bbox(bbox: Sequence[int] | Sequence[str]) -> list[float] | None:
bbox = tuple(map(int, bbox))
if any(coord < 0 or coord > 1000 for coord in bbox):
Expand Down Expand Up @@ -188,7 +190,7 @@ def post_process(self, blocks: list[ContentBlock]) -> list[ContentBlock]:
abandon_paratext=self.abandon_paratext,
debug=self.debug,
)

def batch_prepare_for_layout(
self,
executor: Executor | None,
Expand Down Expand Up @@ -259,6 +261,20 @@ async def aio_post_process(
loop = asyncio.get_running_loop()
return await loop.run_in_executor(executor, self.post_process, blocks)

render_check_template = r'''
\documentclass{article}
\usepackage{amsmath}
\usepackage{amssymb}
\usepackage{mathtools}
\usepackage{mathrsfs}
\usepackage{arydshln}
\usepackage{extarrows}
\usepackage[CJKmath]{xeCJK}
\setCJKmainfont{simsun.ttc}
\begin{document}
%s
\end{document}
'''

class MinerUClient:
def __init__(
Expand All @@ -279,6 +295,7 @@ def __init__(
min_image_edge: int = 28,
max_image_edge_ratio: float = 50,
handle_equation_block: bool = True,
equation_render_check: bool = False,
abandon_list: bool = False,
abandon_paratext: bool = False,
incremental_priority: bool = False,
Expand Down Expand Up @@ -378,6 +395,7 @@ def __init__(
self.executor = executor
self.use_tqdm = use_tqdm
self.debug = debug
self.equation_render_check = equation_render_check

if backend in ("http-client", "vllm-async-engine"):
self.batching_mode = "concurrent"
Expand Down Expand Up @@ -626,6 +644,133 @@ async def aio_concurrent_two_step_extract(
tqdm_desc="Two Step Extraction",
)

def xelatex_render(
self,
latex,
indice,
):
img_idx, idx = indice
now = datetime.now()
render_check_string = f"\\begin{{displaymath}}{latex}\\end{{displaymath}}"
render_check_string = render_check_template.replace("%s", render_check_string)
timestamp = f"{now.strftime('%Y%m%d%H%M%S')}{now.microsecond // 1000:03d}_{img_idx}_{idx}"
with open(f"render_check_log/src/{timestamp}.tex", "w") as w:
print(render_check_string, file=w)
if not os.path.exists(f"render_check_log/output/{timestamp}"):
os.makedirs(f"render_check_log/output/{timestamp}")

try:
render_log = subprocess.run(
[
'/mnt/petrelfs/zhaozhiyuan/formula/textlive/texlive/bin/x86_64-linux/xelatex',
'-interaction=nonstopmode',
f'-output-directory=./render_check_log/output/{timestamp}',
f'./render_check_log/src/{timestamp}.tex',
],
capture_output=True,
text=True,
timeout=60,
)

render_log = render_log.stdout

except subprocess.TimeoutExpired:
print(f"XeLaTeX process timed out after 30 seconds for: {timestamp}")
return False

except FileNotFoundError:
print(f"XeLaTeX executable not found at specified path for: {timestamp}")
return False

except Exception as e:
print(f"Unexpected error during formula render check for: {timestamp}: {str(e)}")
return False

# render check
if "double subscript" in render_log.lower() or \
"double superscript" in render_log.lower() or \
"missing" in render_log.lower() or \
"invalid" in render_log.lower() or \
"latex error" in render_log.lower() or \
"extra" in render_log.lower() or \
"forgotten" in render_log.lower() or \
"undefined" in render_log.lower() or \
"illegal" in render_log.lower() or \
"runaway" in render_log.lower():
return False
else:
return True

def equation_render_fix(
self,
blocks_list,
prepared_inputs,
client,
executor,
):
try_cnt = 0
equation_idx_list = []
for img_idx, (block_images, prompts, params, indices) in enumerate(prepared_inputs):
equation_idx_list.extend([
(img_idx, idx) for idx in indices \
if blocks_list[img_idx][idx].type == "equation"
])
render_flag_list = [False for _ in range(len(equation_idx_list))]
cand_latex_list = [
blocks_list[img_idx][idx].content \
for (img_idx, idx) in equation_idx_list
]

# mapping indice -> inputs
indice2img = {}
for img_idx, (block_images, _, _, indices) in enumerate(prepared_inputs):
for block_image, indice in zip(block_images, indices):
indice2img[(img_idx, indice)] = block_image

while not all(render for render in render_flag_list) and try_cnt < 20:
print(f"check {len(equation_idx_list)} latex formulas in try {try_cnt}...")
start = time.time()
render_flag_list = list(executor.map(
self.xelatex_render, cand_latex_list, equation_idx_list
))
end = time.time()
print(f"render check complete in {round(end - start, 3)} seconds")

# fix block content
equation_idx_list_fail = []
cand_latex_list_fail = []
for latex_idx, (render_flag, (img_idx, idx)) in enumerate(zip(
render_flag_list,
equation_idx_list,
)):
if render_flag:
blocks_list[img_idx][idx].content = cand_latex_list[latex_idx]
else:
equation_idx_list_fail.append((img_idx, idx))

# re-inference for fail latex
equation_image_list_fail = [indice2img[indice] for indice in equation_idx_list_fail]
cand_latex_list = self.client.batch_predict(
equation_image_list_fail,
["Formula Recognition:\n" for _ in range(len(equation_image_list_fail))],
[
MinerUSamplingParams(
temperature=4.0 + try_cnt/10,
top_p=0.3 + try_cnt/100,
top_k=10,
presence_penalty=1.0,
frequency_penalty=0.05,
repetition_penalty=1.0,
no_repeat_ngram_size=100,
max_new_tokens=None
) for _ in range(len(equation_image_list_fail))
],
None,
)
equation_idx_list = [indice for indice in equation_idx_list_fail]
try_cnt += 1


def stepping_two_step_extract(
self,
images: list[Image.Image],
Expand All @@ -647,7 +792,19 @@ def stepping_two_step_extract(
outputs = self.client.batch_predict(all_images, all_prompts, all_params, priority)
for (img_idx, idx), output in zip(all_indices, outputs):
blocks_list[img_idx][idx].content = output
return self.helper.batch_post_process(self.executor, blocks_list)

if self.equation_render_check:
self.equation_render_fix(
blocks_list,
prepared_inputs,
self.client,
self.executor,
)

return self.helper.batch_post_process(
self.executor,
blocks_list,
)

async def aio_stepping_two_step_extract(
self,
Expand Down