diff --git a/mineru_vl_utils/mineru_client.py b/mineru_vl_utils/mineru_client.py index ee8f6a6..eb1999f 100644 --- a/mineru_vl_utils/mineru_client.py +++ b/mineru_vl_utils/mineru_client.py @@ -1,8 +1,12 @@ import asyncio import math +import time import re +import os +import subprocess from concurrent.futures import Executor from typing import Literal, Sequence +from datetime import datetime from PIL import Image @@ -13,7 +17,6 @@ _layout_re = r"^<\|box_start\|>(\d+)\s+(\d+)\s+(\d+)\s+(\d+)<\|box_end\|><\|ref_start\|>(\w+?)<\|ref_end\|>(.*)$" - class MinerUSamplingParams(SamplingParams): def __init__( self, @@ -59,7 +62,6 @@ def __init__( "<|rotate_left|>": 270, } - def _convert_bbox(bbox: Sequence[int] | Sequence[str]) -> list[float] | None: bbox = tuple(map(int, bbox)) if any(coord < 0 or coord > 1000 for coord in bbox): @@ -188,7 +190,7 @@ def post_process(self, blocks: list[ContentBlock]) -> list[ContentBlock]: abandon_paratext=self.abandon_paratext, debug=self.debug, ) - + def batch_prepare_for_layout( self, executor: Executor | None, @@ -259,6 +261,20 @@ async def aio_post_process( loop = asyncio.get_running_loop() return await loop.run_in_executor(executor, self.post_process, blocks) +render_check_template = r''' +\documentclass{article} +\usepackage{amsmath} +\usepackage{amssymb} +\usepackage{mathtools} +\usepackage{mathrsfs} +\usepackage{arydshln} +\usepackage{extarrows} +\usepackage[CJKmath]{xeCJK} +\setCJKmainfont{simsun.ttc} +\begin{document} +%s +\end{document} +''' class MinerUClient: def __init__( @@ -279,6 +295,7 @@ def __init__( min_image_edge: int = 28, max_image_edge_ratio: float = 50, handle_equation_block: bool = True, + equation_render_check: bool = False, abandon_list: bool = False, abandon_paratext: bool = False, incremental_priority: bool = False, @@ -378,6 +395,7 @@ def __init__( self.executor = executor self.use_tqdm = use_tqdm self.debug = debug + self.equation_render_check = equation_render_check if backend in ("http-client", "vllm-async-engine"): self.batching_mode = "concurrent" @@ -626,6 +644,133 @@ async def aio_concurrent_two_step_extract( tqdm_desc="Two Step Extraction", ) + def xelatex_render( + self, + latex, + indice, + ): + img_idx, idx = indice + now = datetime.now() + render_check_string = f"\\begin{{displaymath}}{latex}\\end{{displaymath}}" + render_check_string = render_check_template.replace("%s", render_check_string) + timestamp = f"{now.strftime('%Y%m%d%H%M%S')}{now.microsecond // 1000:03d}_{img_idx}_{idx}" + with open(f"render_check_log/src/{timestamp}.tex", "w") as w: + print(render_check_string, file=w) + if not os.path.exists(f"render_check_log/output/{timestamp}"): + os.makedirs(f"render_check_log/output/{timestamp}") + + try: + render_log = subprocess.run( + [ + '/mnt/petrelfs/zhaozhiyuan/formula/textlive/texlive/bin/x86_64-linux/xelatex', + '-interaction=nonstopmode', + f'-output-directory=./render_check_log/output/{timestamp}', + f'./render_check_log/src/{timestamp}.tex', + ], + capture_output=True, + text=True, + timeout=60, + ) + + render_log = render_log.stdout + + except subprocess.TimeoutExpired: + print(f"XeLaTeX process timed out after 30 seconds for: {timestamp}") + return False + + except FileNotFoundError: + print(f"XeLaTeX executable not found at specified path for: {timestamp}") + return False + + except Exception as e: + print(f"Unexpected error during formula render check for: {timestamp}: {str(e)}") + return False + + # render check + if "double subscript" in render_log.lower() or \ + "double superscript" in render_log.lower() or \ + "missing" in render_log.lower() or \ + "invalid" in render_log.lower() or \ + "latex error" in render_log.lower() or \ + "extra" in render_log.lower() or \ + "forgotten" in render_log.lower() or \ + "undefined" in render_log.lower() or \ + "illegal" in render_log.lower() or \ + "runaway" in render_log.lower(): + return False + else: + return True + + def equation_render_fix( + self, + blocks_list, + prepared_inputs, + client, + executor, + ): + try_cnt = 0 + equation_idx_list = [] + for img_idx, (block_images, prompts, params, indices) in enumerate(prepared_inputs): + equation_idx_list.extend([ + (img_idx, idx) for idx in indices \ + if blocks_list[img_idx][idx].type == "equation" + ]) + render_flag_list = [False for _ in range(len(equation_idx_list))] + cand_latex_list = [ + blocks_list[img_idx][idx].content \ + for (img_idx, idx) in equation_idx_list + ] + + # mapping indice -> inputs + indice2img = {} + for img_idx, (block_images, _, _, indices) in enumerate(prepared_inputs): + for block_image, indice in zip(block_images, indices): + indice2img[(img_idx, indice)] = block_image + + while not all(render for render in render_flag_list) and try_cnt < 20: + print(f"check {len(equation_idx_list)} latex formulas in try {try_cnt}...") + start = time.time() + render_flag_list = list(executor.map( + self.xelatex_render, cand_latex_list, equation_idx_list + )) + end = time.time() + print(f"render check complete in {round(end - start, 3)} seconds") + + # fix block content + equation_idx_list_fail = [] + cand_latex_list_fail = [] + for latex_idx, (render_flag, (img_idx, idx)) in enumerate(zip( + render_flag_list, + equation_idx_list, + )): + if render_flag: + blocks_list[img_idx][idx].content = cand_latex_list[latex_idx] + else: + equation_idx_list_fail.append((img_idx, idx)) + + # re-inference for fail latex + equation_image_list_fail = [indice2img[indice] for indice in equation_idx_list_fail] + cand_latex_list = self.client.batch_predict( + equation_image_list_fail, + ["Formula Recognition:\n" for _ in range(len(equation_image_list_fail))], + [ + MinerUSamplingParams( + temperature=4.0 + try_cnt/10, + top_p=0.3 + try_cnt/100, + top_k=10, + presence_penalty=1.0, + frequency_penalty=0.05, + repetition_penalty=1.0, + no_repeat_ngram_size=100, + max_new_tokens=None + ) for _ in range(len(equation_image_list_fail)) + ], + None, + ) + equation_idx_list = [indice for indice in equation_idx_list_fail] + try_cnt += 1 + + def stepping_two_step_extract( self, images: list[Image.Image], @@ -647,7 +792,19 @@ def stepping_two_step_extract( outputs = self.client.batch_predict(all_images, all_prompts, all_params, priority) for (img_idx, idx), output in zip(all_indices, outputs): blocks_list[img_idx][idx].content = output - return self.helper.batch_post_process(self.executor, blocks_list) + + if self.equation_render_check: + self.equation_render_fix( + blocks_list, + prepared_inputs, + self.client, + self.executor, + ) + + return self.helper.batch_post_process( + self.executor, + blocks_list, + ) async def aio_stepping_two_step_extract( self,