diff --git a/README.md b/README.md
index beefc470..d0f755d7 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@
 [![Downloads](https://static.pepy.tech/badge/magic-pdf)](https://pepy.tech/project/magic-pdf)
 [![Downloads](https://static.pepy.tech/badge/magic-pdf/month)](https://pepy.tech/project/magic-pdf)
-[![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://mineru.org.cn/OpenSourceTools/Extractor?source=github)
+[![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://mineru.net/OpenSourceTools/Extractor?source=github)
 [![HuggingFace](https://img.shields.io/badge/Demo_on_HuggingFace-yellow.svg?logo=&labelColor=white)](https://huggingface.co/spaces/opendatalab/MinerU)
 [![ModelScope](https://img.shields.io/badge/Demo_on_ModelScope-purple?logo=&labelColor=white)](https://www.modelscope.cn/studios/OpenDataLab/MinerU)
 [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/myhloli/3b3a00a4a0a61577b6c30f989092d20d/mineru_demo.ipynb)
@@ -33,7 +33,7 @@
PDF-Extract-Kit: High-Quality PDF Extraction Toolkit🔥🔥🔥
-
+
Easier to use: Just grab MinerU Desktop. No coding, no login, just a simple interface and smooth interactions. Enjoy it without any fuss!🚀🚀🚀
- 👋 join us on Discord and WeChat
+ 👋 join us on Discord and WeChat
 # Changelog
+- 2025/01/22 1.1.0 released. In this version we have focused on improving parsing accuracy and efficiency:
+  - Model capability upgrade (requires re-executing the [model download process](docs/how_to_download_models_en.md) to obtain incremental updates of model files)
+    - The layout recognition model has been upgraded to the latest `doclayout_yolo(2501)` model, improving layout recognition accuracy.
+    - The formula parsing model has been upgraded to the latest `unimernet(2501)` model, improving formula recognition accuracy.
+  - Performance optimization
+    - On devices that meet certain configuration requirements (16GB+ VRAM), by optimizing resource usage and restructuring the processing pipeline, overall parsing speed has been increased by more than 50%.
+  - Parsing effect optimization
+    - Added a new heading classification feature (testing version, enabled by default) to the online demo([mineru.net](https://mineru.net/OpenSourceTools/Extractor)/[huggingface](https://huggingface.co/spaces/opendatalab/MinerU)/[modelscope](https://www.modelscope.cn/studios/OpenDataLab/MinerU)), which supports hierarchical classification of headings, thereby enhancing document structuring.
 - 2025/01/10 1.0.1 released. This is our first official release, where we have introduced a completely new API interface and enhanced compatibility through extensive refactoring, as well as a brand new automatic language identification feature:
   - New API Interface
     - For the data-side API, we have introduced the Dataset class, designed to provide a robust and flexible data processing framework. This framework currently supports a variety of document formats, including images (.jpg and .png), PDFs, Word documents (.doc and .docx), and PowerPoint presentations (.ppt and .pptx). It ensures effective support for data processing tasks ranging from simple to complex.
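Reviewer note: the performance item in the 1.1.0 changelog entry above lines up with the VRAM-based batching that this patch adds to `doc_analyze` in `magic_pdf/model/doc_analyze_by_custom_model.py`. The following is a minimal sketch of that selection logic, assuming the thresholds shown in that hunk. `pick_batch_ratio` is a hypothetical helper name used only for illustration; it is not a function in `magic_pdf`.

```python
from typing import Optional


def pick_batch_ratio(gpu_memory_gb: int) -> Optional[int]:
    """Map a VRAM size in whole GB to a batch ratio.

    Hypothetical helper: the thresholds mirror the doc_analyze hunk in this
    patch. None means the per-page (non-batched) path is kept.
    """
    if gpu_memory_gb < 8:
        return None  # below 8 GB the patch leaves batch_analyze disabled
    if gpu_memory_gb < 10:
        return 2
    if gpu_memory_gb <= 12:
        return 4
    if gpu_memory_gb <= 16:
        return 8
    if gpu_memory_gb <= 24:
        return 16
    return 32


if __name__ == "__main__":
    for size in (6, 8, 12, 16, 24, 48):
        print(size, pick_batch_ratio(size))
```

In the patch itself the memory figure comes from the `VIRTUAL_VRAM_SIZE` environment variable when it is set, and from `get_vram(device)` otherwise, so the override can be used to exercise a specific branch.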
@@ -226,7 +234,7 @@ There are three different ways to experience MinerU:
 ### Online Demo
 Stable Version (Stable version verified by QA):
-[![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://mineru.org.cn/OpenSourceTools/Extractor?source=github)
+[![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://mineru.net/OpenSourceTools/Extractor?source=github)
 Test Version (Synced with dev branch updates, testing new features):
 [![HuggingFace](https://img.shields.io/badge/Demo_on_HuggingFace-yellow.svg?logo=&labelColor=white)](https://huggingface.co/spaces/opendatalab/MinerU)
@@ -273,6 +281,7 @@ You can modify certain configurations in this file to enable or disable features
     },
     "table-config": {
         "model": "rapid_table", // Default to using "rapid_table", can be switched to "tablemaster" or "struct_eqtable".
+        "sub_model": "slanet_plus", // When the model is "rapid_table", you can choose a sub_model. The options are "slanet_plus" and "unitable"
         "enable": true, // The table recognition feature is enabled by default. If you need to disable it, please change the value here to "false".
         "max_time": 400
     }
@@ -356,6 +365,7 @@ TODO
 - [x] Reading order based on the model
 - [x] Recognition of `index` and `list` in the main text
 - [x] Table recognition
+- [x] Heading Classification
 - [ ] Code block recognition in the main text
 - [ ] [Chemical formula recognition](docs/chemical_knowledge_introduction/introduction.pdf)
 - [ ] Geometric shape recognition
@@ -365,7 +375,6 @@ TODO
 - Reading order is determined by the model based on the spatial distribution of readable content, and may be out of order in some areas under extremely complex layouts.
 - Vertical text is not supported.
 - Tables of contents and lists are recognized through rules, and some uncommon list formats may not be recognized.
-- Only one level of headings is supported; hierarchical headings are not currently supported.
 - Code blocks are not yet supported in the layout model.
 - Comic books, art albums, primary school textbooks, and exercises cannot be parsed well.
 - Table recognition may result in row/column recognition errors in complex tables.
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 664fb470..6903b746 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -14,7 +14,7 @@
 [![Downloads](https://static.pepy.tech/badge/magic-pdf)](https://pepy.tech/project/magic-pdf)
 [![Downloads](https://static.pepy.tech/badge/magic-pdf/month)](https://pepy.tech/project/magic-pdf)
-[![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://mineru.org.cn/OpenSourceTools/Extractor?source=github)
+[![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://mineru.net/OpenSourceTools/Extractor?source=github)
 [![ModelScope](https://img.shields.io/badge/Demo_on_ModelScope-purple?logo=&labelColor=white)](https://www.modelscope.cn/studios/OpenDataLab/MinerU)
 [![HuggingFace](https://img.shields.io/badge/Demo_on_HuggingFace-yellow.svg?logo=&labelColor=white)](https://huggingface.co/spaces/opendatalab/MinerU)
 [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/myhloli/3b3a00a4a0a61577b6c30f989092d20d/mineru_demo.ipynb)
@@ -33,19 +33,27 @@
 PDF-Extract-Kit: 高质量PDF解析工具箱🔥🔥🔥
- 👋 join us on Discord and WeChat
+ 👋 join us on Discord and WeChat
 # 更新记录
+- 2025/01/22 1.1.0 发布,在这个版本我们重点提升了解析的精度与效率:
+  - 模型能力升级(需重新执行[模型下载流程](docs/how_to_download_models_zh_cn.md)以获得模型文件的增量更新)
+    - 布局识别模型升级到最新的`doclayout_yolo(2501)`模型,提升了layout识别精度
+    - 公式解析模型升级到最新的`unimernet(2501)`模型,提升了公式识别精度
+  - 性能优化
+    - 在配置满足一定条件(显存16GB+)的设备上,通过优化资源占用和重构处理流水线,整体解析速度提升50%以上
+  - 解析效果优化
+    - 在线demo([mineru.net](https://mineru.net/OpenSourceTools/Extractor)/[huggingface](https://huggingface.co/spaces/opendatalab/MinerU)/[modelscope](https://www.modelscope.cn/studios/OpenDataLab/MinerU))上新增标题分级功能(测试版本,默认开启),支持对标题进行分级,提升文档结构化程度
 - 2025/01/10 1.0.1 发布,这是我们的第一个正式版本,在这个版本中,我们通过大量重构带来了全新的API接口和更广泛的兼容性,以及全新的自动语言识别功能:
   - 全新API接口
     - 对于数据侧API,我们引入了Dataset类,旨在提供一个强大而灵活的数据处理框架。该框架当前支持包括图像(.jpg及.png)、PDF、Word(.doc及.docx)、以及PowerPoint(.ppt及.pptx)在内的多种文档格式,确保了从简单到复杂的数据处理任务都能得到有效的支持。
@@ -227,7 +235,7 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c
 ### 在线体验
 稳定版(经过QA验证的稳定版本):
-[![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://mineru.org.cn/OpenSourceTools/Extractor?source=github)
+[![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://mineru.net/OpenSourceTools/Extractor?source=github)
 测试版(同步dev分支更新,测试新特性):
@@ -277,6 +285,7 @@ pip install -U "magic-pdf[full]" --extra-index-url https://wheels.myhloli.com -i
     },
     "table-config": {
         "model": "rapid_table", // 默认使用"rapid_table",可以切换为"tablemaster"和"struct_eqtable"
+        "sub_model": "slanet_plus", // 当model为"rapid_table"时,可以自选sub_model,可选项为"slanet_plus"和"unitable"
         "enable": true, // 表格识别功能默认是开启的,如果需要关闭请修改此处的值为"false"
         "max_time": 400
     }
@@ -359,6 +368,7 @@ TODO
 - [x] 基于模型的阅读顺序
 - [x] 正文中目录、列表识别
 - [x] 表格识别
+- [x] 标题分级
 - [ ] 正文中代码块识别
 - [ ] [化学式识别](docs/chemical_knowledge_introduction/introduction.pdf)
 - [ ] 几何图形识别
@@ -368,7 +378,6 @@ TODO
 - 阅读顺序基于模型对可阅读内容在空间中的分布进行排序,在极端复杂的排版下可能会部分区域乱序
 - 不支持竖排文字
 - 目录和列表通过规则进行识别,少部分不常见的列表形式可能无法识别
-- 标题只有一级,目前不支持标题分级
 - 代码块在layout模型里还没有支持
 - 漫画书、艺术图册、小学教材、习题尚不能很好解析
 - 表格识别在复杂表格上可能会出现行/列识别错误
diff --git a/docker/ascend_npu/requirements.txt b/docker/ascend_npu/requirements.txt
index 3efefe99..dc757ba3 100644
--- a/docker/ascend_npu/requirements.txt
+++ b/docker/ascend_npu/requirements.txt
@@ -1,7 +1,7 @@
 boto3>=1.28.43
 Brotli>=1.1.0
 click>=8.1.7
-PyMuPDF>=1.24.9
+PyMuPDF>=1.24.9,<=1.24.14
 loguru>=0.6.0
 numpy>=1.21.6,<2.0.0
 fast-langdetect>=0.2.3,<0.3.0
@@ -17,10 +17,9 @@ paddlepaddle==3.0.0b1
 struct-eqtable==0.3.2
 einops
 accelerate
-doclayout_yolo==0.0.2
 rapidocr-paddle
 rapidocr-onnxruntime
-rapid_table==0.3.0
-doclayout-yolo==0.0.2
+rapid-table>=1.0.3,<2.0.0
+doclayout-yolo==0.0.2b1
 openai
 detectron2
diff --git a/docker/china/requirements.txt b/docker/china/requirements.txt
index 39006c35..699e4848 100644
--- a/docker/china/requirements.txt
+++ b/docker/china/requirements.txt
@@ -1,7 +1,7 @@
 boto3>=1.28.43
 Brotli>=1.1.0
 click>=8.1.7
-PyMuPDF>=1.24.9
+PyMuPDF>=1.24.9,<=1.24.14
 loguru>=0.6.0
 numpy>=1.21.6,<2.0.0
 fast-langdetect>=0.2.3,<0.3.0
@@ -16,10 +16,9 @@ paddleocr==2.7.3
 struct-eqtable==0.3.2
 einops
 accelerate
-doclayout_yolo==0.0.2
 rapidocr-paddle
 rapidocr-onnxruntime
-rapid_table==0.3.0
-doclayout-yolo==0.0.2
+rapid-table>=1.0.3,<2.0.0
+doclayout-yolo==0.0.2b1
 openai
 detectron2
diff --git a/docker/global/requirements.txt b/docker/global/requirements.txt
index 39006c35..699e4848 100644
--- a/docker/global/requirements.txt
+++ b/docker/global/requirements.txt
@@ -1,7 +1,7 @@
 boto3>=1.28.43
 Brotli>=1.1.0
 click>=8.1.7
-PyMuPDF>=1.24.9
+PyMuPDF>=1.24.9,<=1.24.14
 loguru>=0.6.0
 numpy>=1.21.6,<2.0.0
 fast-langdetect>=0.2.3,<0.3.0
@@ -16,10 +16,9 @@ paddleocr==2.7.3
 struct-eqtable==0.3.2
 einops
 accelerate
-doclayout_yolo==0.0.2
 rapidocr-paddle
 rapidocr-onnxruntime
-rapid_table==0.3.0
-doclayout-yolo==0.0.2
+rapid-table>=1.0.3,<2.0.0
+doclayout-yolo==0.0.2b1
 openai
 detectron2
diff --git a/magic-pdf.template.json b/magic-pdf.template.json
index ba31d96d..478a9f17 100644
--- a/magic-pdf.template.json
+++ b/magic-pdf.template.json
@@ -16,6 +16,7 @@
     },
     "table-config": {
         "model": "rapid_table",
+        "sub_model": "slanet_plus",
         "enable": true,
         "max_time": 400
     },
@@ -39,5 +40,5 @@
             "enable": false
         }
     },
-    "config_version": "1.1.0"
+    "config_version": "1.1.1"
 }
\ No newline at end of file
diff --git a/magic_pdf/libs/boxbase.py b/magic_pdf/libs/boxbase.py
index 52779a22..2813121b 100644
--- a/magic_pdf/libs/boxbase.py
+++ b/magic_pdf/libs/boxbase.py
@@ -185,10 +185,13 @@ def calculate_iou(bbox1, bbox2):
     bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
     bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
+    if any([bbox1_area == 0, bbox2_area == 0]):
+        return 0
+
     # Compute the intersection over union by taking the intersection area
     # and dividing it by the sum of both areas minus the intersection area
-    iou = intersection_area / float(bbox1_area + bbox2_area -
-                                    intersection_area)
+    iou = intersection_area / float(bbox1_area + bbox2_area - intersection_area)
+
     return iou
diff --git a/magic_pdf/libs/draw_bbox.py b/magic_pdf/libs/draw_bbox.py
index 6d70c913..c2ad21d0 100644
--- a/magic_pdf/libs/draw_bbox.py
+++ b/magic_pdf/libs/draw_bbox.py
@@ -362,12 +362,24 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
     for page in pdf_info:
         page_line_list = []
         for block in page['preproc_blocks']:
-            if block['type'] in [BlockType.Text, BlockType.Title, BlockType.InterlineEquation]:
+            if block['type'] in [BlockType.Text]:
                 for line in block['lines']:
                     bbox = line['bbox']
                     index = line['index']
                     page_line_list.append({'index': index, 'bbox': bbox})
-            if block['type'] in [BlockType.Image, BlockType.Table]:
+            elif block['type'] in [BlockType.Title, BlockType.InterlineEquation]:
+                if 'virtual_lines' in block:
+                    if len(block['virtual_lines']) > 0 and block['virtual_lines'][0].get('index', None) is not None:
+                        for line in block['virtual_lines']:
+                            bbox = line['bbox']
+                            index = line['index']
+                            page_line_list.append({'index': index, 'bbox': bbox})
+                else:
+                    for line in block['lines']:
+                        bbox = line['bbox']
+                        index = line['index']
+                        page_line_list.append({'index': index, 'bbox': bbox})
+            elif block['type'] in [BlockType.Image, BlockType.Table]:
                 for sub_block in block['blocks']:
                     if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
                         if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:
diff --git a/magic_pdf/libs/language.py b/magic_pdf/libs/language.py
index 76e2eac9..73d382b7 100644
--- a/magic_pdf/libs/language.py
+++ b/magic_pdf/libs/language.py
@@ -12,12 +12,20 @@
 from fast_langdetect import detect_language
+def remove_invalid_surrogates(text):
+    # 移除无效的 UTF-16 代理对
+    return ''.join(c for c in text if not (0xD800 <= ord(c) <= 0xDFFF))
+
+
 def detect_lang(text: str) -> str:
     if len(text) == 0:
         return ""
     text = text.replace("\n", "")
+    text = remove_invalid_surrogates(text)
+
+    # print(text)
     try:
         lang_upper = detect_language(text)
     except:
@@ -37,3 +45,4 @@ def detect_lang(text: str) -> str:
     print(detect_lang("This is a test"))
     print(detect_lang("这个是中文测试。"))
     print(detect_lang("这个是中文测试。"))
+    print(detect_lang("〖\ud835\udc46\ud835〗这是个包含utf-16的中文测试"))
\ No newline at end of file
diff --git a/magic_pdf/model/batch_analyze.py b/magic_pdf/model/batch_analyze.py
index f82a7ca3..c1e719dd 100644
--- a/magic_pdf/model/batch_analyze.py
+++ b/magic_pdf/model/batch_analyze.py
@@ -7,19 +7,19 @@ from PIL import Image
 from magic_pdf.config.constants import MODEL_NAME
-from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
-from magic_pdf.data.dataset import Dataset
-from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.libs.config_reader import get_device
-from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
+# from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
+# from magic_pdf.data.dataset import Dataset
+# from magic_pdf.libs.clean_memory import clean_memory
+# from magic_pdf.libs.config_reader import get_device
+# from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
 from magic_pdf.model.pdf_extract_kit import CustomPEKModel
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
 from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
     get_adjusted_mfdetrec_res, get_ocr_result_list)
-from magic_pdf.operators.models import InferenceResult
+# from magic_pdf.operators.models import InferenceResult
-YOLO_LAYOUT_BASE_BATCH_SIZE = 4
+YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
 MFR_BASE_BATCH_SIZE = 16
@@ -44,19 +44,20 @@ def __call__(self, images: list) -> list:
         modified_images = []
         for image_index, image in enumerate(images):
             pil_img = Image.fromarray(image)
-            width, height = pil_img.size
-            if height > width:
-                input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
-                new_image, useful_list = crop_img(
-                    input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
-                )
-                layout_images.append(new_image)
-                modified_images.append([image_index, useful_list])
-            else:
-                layout_images.append(pil_img)
+            # width, height = pil_img.size
+            # if height > width:
+            #     input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
+            #     new_image, useful_list = crop_img(
+            #         input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
+            #     )
+            #     layout_images.append(new_image)
+            #     modified_images.append([image_index, useful_list])
+            # else:
+            layout_images.append(pil_img)
         images_layout_res += self.model.layout_model.batch_predict(
-            layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
+            # layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
+            layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
         )
         for image_index, useful_list in modified_images:
@@ -78,7 +79,8 @@ def __call__(self, images: list) -> list:
             # 公式检测
             mfd_start_time = time.time()
             images_mfd_res = self.model.mfd_model.batch_predict(
-                images, self.batch_ratio * MFD_BASE_BATCH_SIZE
+                # images, self.batch_ratio * MFD_BASE_BATCH_SIZE
+                images, MFD_BASE_BATCH_SIZE
             )
             logger.info(
                 f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
@@ -91,10 +93,12 @@ def __call__(self, images: list) -> list:
                 images,
                 batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
             )
+            mfr_count = 0
             for image_index in range(len(images)):
                 images_layout_res[image_index] += images_formula_list[image_index]
+                mfr_count += len(images_formula_list[image_index])
             logger.info(
-                f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}'
+                f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
             )
         # 清理显存
@@ -159,7 +163,7 @@ def __call__(self, images: list) -> list:
                     elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER:
                         html_code = self.model.table_model.img2html(new_image)
                     elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE:
-                        html_code, table_cell_bboxes, elapse = (
+                        html_code, table_cell_bboxes, logic_points, elapse = (
                             self.model.table_model.predict(new_image)
                         )
                     run_time = time.time() - single_table_start_time
@@ -195,81 +199,81 @@ def __call__(self, images: list) -> list:
         return images_layout_res
-def doc_batch_analyze(
-    dataset: Dataset,
-    ocr: bool = False,
-    show_log: bool = False,
-    start_page_id=0,
-    end_page_id=None,
-    lang=None,
-    layout_model=None,
-    formula_enable=None,
-    table_enable=None,
-    batch_ratio: int | None = None,
-) -> InferenceResult:
-    """Perform batch analysis on a document dataset.
-
-    Args:
-        dataset (Dataset): The dataset containing document pages to be analyzed.
-        ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
-        show_log (bool, optional): Flag to enable logging. Defaults to False.
-        start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
-        end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
-        lang (str, optional): Language for OCR. Defaults to None.
-        layout_model (optional): Layout model to be used for analysis. Defaults to None.
-        formula_enable (optional): Flag to enable formula detection. Defaults to None.
-        table_enable (optional): Flag to enable table detection. Defaults to None.
-        batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
-
-    Raises:
-        CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
-
-    Returns:
-        InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
-    """
-
-    if not torch.cuda.is_available():
-        raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
-
-    lang = None if lang == '' else lang
-    # TODO: auto detect batch size
-    batch_ratio = 1 if batch_ratio is None else batch_ratio
-    end_page_id = end_page_id if end_page_id else len(dataset)
-
-    model_manager = ModelSingleton()
-    custom_model: CustomPEKModel = model_manager.get_model(
-        ocr, show_log, lang, layout_model, formula_enable, table_enable
-    )
-    batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
-
-    model_json = []
-
-    # batch analyze
-    images = []
-    for index in range(len(dataset)):
-        if start_page_id <= index <= end_page_id:
-            page_data = dataset.get_page(index)
-            img_dict = page_data.get_image()
-            images.append(img_dict['img'])
-    analyze_result = batch_model(images)
-
-    for index in range(len(dataset)):
-        page_data = dataset.get_page(index)
-        img_dict = page_data.get_image()
-        page_width = img_dict['width']
-        page_height = img_dict['height']
-        if start_page_id <= index <= end_page_id:
-            result = analyze_result.pop(0)
-        else:
-            result = []
-
-        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
-        page_dict = {'layout_dets': result, 'page_info': page_info}
-        model_json.append(page_dict)
-
-    # TODO: clean memory when gpu memory is not enough
-    clean_memory_start_time = time.time()
-    clean_memory(get_device())
-    logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
-
-    return InferenceResult(model_json, dataset)
+# def doc_batch_analyze(
+#     dataset: Dataset,
+#     ocr: bool = False,
+#     show_log: bool = False,
+#     start_page_id=0,
+#     end_page_id=None,
+#     lang=None,
+#     layout_model=None,
+#     formula_enable=None,
+#     table_enable=None,
+#     batch_ratio: int | None = None,
+# ) -> InferenceResult:
+#     """Perform batch analysis on a document dataset.
+#
+#     Args:
+#         dataset (Dataset): The dataset containing document pages to be analyzed.
+#         ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
+#         show_log (bool, optional): Flag to enable logging. Defaults to False.
+#         start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
+#         end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
+#         lang (str, optional): Language for OCR. Defaults to None.
+#         layout_model (optional): Layout model to be used for analysis. Defaults to None.
+#         formula_enable (optional): Flag to enable formula detection. Defaults to None.
+#         table_enable (optional): Flag to enable table detection. Defaults to None.
+#         batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
+#
+#     Raises:
+#         CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
+#
+#     Returns:
+#         InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
+#     """
+#
+#     if not torch.cuda.is_available():
+#         raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
+#
+#     lang = None if lang == '' else lang
+#     # TODO: auto detect batch size
+#     batch_ratio = 1 if batch_ratio is None else batch_ratio
+#     end_page_id = end_page_id if end_page_id else len(dataset)
+#
+#     model_manager = ModelSingleton()
+#     custom_model: CustomPEKModel = model_manager.get_model(
+#         ocr, show_log, lang, layout_model, formula_enable, table_enable
+#     )
+#     batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
+#
+#     model_json = []
+#
+#     # batch analyze
+#     images = []
+#     for index in range(len(dataset)):
+#         if start_page_id <= index <= end_page_id:
+#             page_data = dataset.get_page(index)
+#             img_dict = page_data.get_image()
+#             images.append(img_dict['img'])
+#     analyze_result = batch_model(images)
+#
+#     for index in range(len(dataset)):
+#         page_data = dataset.get_page(index)
+#         img_dict = page_data.get_image()
+#         page_width = img_dict['width']
+#         page_height = img_dict['height']
+#         if start_page_id <= index <= end_page_id:
+#             result = analyze_result.pop(0)
+#         else:
+#             result = []
+#
+#         page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+#         page_dict = {'layout_dets': result, 'page_info': page_info}
+#         model_json.append(page_dict)
+#
+#     # TODO: clean memory when gpu memory is not enough
+#     clean_memory_start_time = time.time()
+#     clean_memory(get_device())
+#     logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
+#
+#     return InferenceResult(model_json, dataset)
diff --git a/magic_pdf/model/doc_analyze_by_custom_model.py b/magic_pdf/model/doc_analyze_by_custom_model.py
index 88a55c57..f1863b5c 100644
--- a/magic_pdf/model/doc_analyze_by_custom_model.py
+++ b/magic_pdf/model/doc_analyze_by_custom_model.py
@@ -3,8 +3,12 @@
 # 关闭paddle的信号处理
 import paddle
+import torch
 from loguru import logger
+from magic_pdf.model.batch_analyze import BatchAnalyze
+from magic_pdf.model.sub_modules.model_utils import get_vram
+
 paddle.disable_signal_handler()
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
@@ -154,33 +158,88 @@ def doc_analyze(
     table_enable=None,
 ) -> InferenceResult:
+    end_page_id = end_page_id if end_page_id else len(dataset) - 1
+
     model_manager = ModelSingleton()
     custom_model = model_manager.get_model(
         ocr, show_log, lang, layout_model, formula_enable, table_enable
     )
+    batch_analyze = False
+    device = get_device()
+
+    npu_support = False
+    if str(device).startswith("npu"):
+        import torch_npu
+        if torch_npu.npu.is_available():
+            npu_support = True
+
+    if torch.cuda.is_available() and device != 'cpu' or npu_support:
+        gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
+        if gpu_memory is not None and gpu_memory >= 8:
+
+            if 8 <= gpu_memory < 10:
+                batch_ratio = 2
+            elif 10 <= gpu_memory <= 12:
+                batch_ratio = 4
+            elif 12 < gpu_memory <= 16:
+                batch_ratio = 8
+            elif 16 < gpu_memory <= 24:
+                batch_ratio = 16
+            else:
+                batch_ratio = 32
+
+            if batch_ratio >= 1:
+                logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
+                batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
+                batch_analyze = True
+
     model_json = []
     doc_analyze_start = time.time()
-    if end_page_id is None:
-        end_page_id = len(dataset)
-
-    for index in range(len(dataset)):
-        page_data = dataset.get_page(index)
-        img_dict = page_data.get_image()
-        img = img_dict['img']
-        page_width = img_dict['width']
-        page_height = img_dict['height']
-        if start_page_id <= index <= end_page_id:
-            page_start = time.time()
-            result = custom_model(img)
-            logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
-        else:
-            result = []
+    if batch_analyze:
+        # batch analyze
+        images = []
+        for index in range(len(dataset)):
+            if start_page_id <= index <= end_page_id:
+                page_data = dataset.get_page(index)
+                img_dict = page_data.get_image()
+                images.append(img_dict['img'])
+        analyze_result = batch_model(images)
+
+        for index in range(len(dataset)):
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            page_width = img_dict['width']
+            page_height = img_dict['height']
+            if start_page_id <= index <= end_page_id:
+                result = analyze_result.pop(0)
+            else:
+                result = []
+
+            page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)
-        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
-        page_dict = {'layout_dets': result, 'page_info': page_info}
-        model_json.append(page_dict)
+    else:
+        # single analyze
+
+        for index in range(len(dataset)):
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            img = img_dict['img']
+            page_width = img_dict['width']
+            page_height = img_dict['height']
+            if start_page_id <= index <= end_page_id:
+                page_start = time.time()
+                result = custom_model(img)
+                logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
+            else:
+                result = []
+
+            page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)
     gc_start = time.time()
     clean_memory(get_device())
diff --git a/magic_pdf/model/pdf_extract_kit.py b/magic_pdf/model/pdf_extract_kit.py
index 8b482d7f..4edfe59f 100644
--- a/magic_pdf/model/pdf_extract_kit.py
+++ b/magic_pdf/model/pdf_extract_kit.py
@@ -69,6 +69,7 @@ def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
         self.apply_table = self.table_config.get('enable', False)
         self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
         self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
+        self.table_sub_model_name = self.table_config.get('sub_model', None)
         # ocr config
         self.apply_ocr = ocr
@@ -144,7 +145,7 @@ def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
                         model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
                     )
                 ),
-                device=self.device,
+                device='cpu' if str(self.device).startswith("mps") else self.device,
             )
         elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
             self.layout_model = atom_model_manager.get_atom_model(
@@ -174,6 +175,7 @@ def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
                 table_max_time=self.table_max_time,
                 device=self.device,
                 ocr_engine=self.ocr_model,
+                table_sub_model_name=self.table_sub_model_name
             )
         logger.info('DocAnalysis init done!')
@@ -192,24 +194,24 @@ def __call__(self, image):
             layout_res = self.layout_model(image, ignore_catids=[])
         elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
             # doclayout_yolo
-            if height > width:
-                input_res = {"poly":[0,0,width,0,width,height,0,height]}
-                new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
-                paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
-                layout_res = self.layout_model.predict(new_image)
-                for res in layout_res:
-                    p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
-                    p1 = p1 - paste_x + xmin
-                    p2 = p2 - paste_y + ymin
-                    p3 = p3 - paste_x + xmin
-                    p4 = p4 - paste_y + ymin
-                    p5 = p5 - paste_x + xmin
-                    p6 = p6 - paste_y + ymin
-                    p7 = p7 - paste_x + xmin
-                    p8 = p8 - paste_y + ymin
-                    res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
-            else:
-                layout_res = self.layout_model.predict(image)
+            # if height > width:
+            #     input_res = {"poly":[0,0,width,0,width,height,0,height]}
+            #     new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
+            #     paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
+            #     layout_res = self.layout_model.predict(new_image)
+            #     for res in layout_res:
+            #         p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
+            #         p1 = p1 - paste_x + xmin
+            #         p2 = p2 - paste_y + ymin
+            #         p3 = p3 - paste_x + xmin
+            #         p4 = p4 - paste_y + ymin
+            #         p5 = p5 - paste_x + xmin
+            #         p6 = p6 - paste_y + ymin
+            #         p7 = p7 - paste_x + xmin
+            #         p8 = p8 - paste_y + ymin
+            #         res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
+            # else:
+            layout_res = self.layout_model.predict(image)
         layout_cost = round(time.time() - layout_start, 2)
         logger.info(f'layout detection time: {layout_cost}')
@@ -228,7 +230,7 @@ def __call__(self, image):
             logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
         # 清理显存
-        clean_vram(self.device, vram_threshold=8)
+        clean_vram(self.device, vram_threshold=6)
         # 从layout_res中获取ocr区域、表格区域、公式区域
         ocr_res_list, table_res_list, single_page_mfdetrec_res = (
@@ -276,7 +278,7 @@ def __call__(self, image):
                 elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
                     html_code = self.table_model.img2html(new_image)
                 elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
-                    html_code, table_cell_bboxes, elapse = self.table_model.predict(
+                    html_code, table_cell_bboxes, logic_points, elapse = self.table_model.predict(
                         new_image
                     )
                 run_time = time.time() - single_table_start_time
diff --git a/magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py b/magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py
index d27f17fd..784f7af9 100644
--- a/magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py
+++ b/magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py
@@ -9,7 +9,11 @@ def __init__(self, weight, device):
     def predict(self, image):
         layout_res = []
         doclayout_yolo_res = self.model.predict(
-            image, imgsz=1024, conf=0.25, iou=0.45, verbose=False, device=self.device
+            image,
+            imgsz=1280,
+            conf=0.10,
+            iou=0.45,
+            verbose=False, device=self.device
         )[0]
         for xyxy, conf, cla in zip(
             doclayout_yolo_res.boxes.xyxy.cpu(),
@@ -32,8 +36,8 @@ def batch_predict(self, images: list, batch_size: int) -> list:
                 image_res.cpu()
                 for image_res in self.model.predict(
                     images[index : index + batch_size],
-                    imgsz=1024,
-                    conf=0.25,
+                    imgsz=1280,
+                    conf=0.10,
                     iou=0.45,
                     verbose=False,
                     device=self.device,
diff --git a/magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py b/magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
index 54e46c56..9eff1ccd 100644
--- a/magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
+++ b/magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
@@ -89,7 +89,7 @@ def predict(self, mfd_res, image):
             mf_image_list.append(bbox_img)
         dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
-        dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
+        dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
         mfr_res = []
         for mf_img in dataloader:
             mf_img = mf_img.to(self.device)
diff --git a/magic_pdf/model/sub_modules/model_init.py b/magic_pdf/model/sub_modules/model_init.py
index 120d62e2..7e555744 100644
--- a/magic_pdf/model/sub_modules/model_init.py
+++ b/magic_pdf/model/sub_modules/model_init.py
@@ -21,7 +21,7 @@
     TableMasterPaddleModel
-def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None):
+def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
     if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
         table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
     elif table_model_type == MODEL_NAME.TABLE_MASTER:
@@ -31,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
         }
         table_model = TableMasterPaddleModel(config)
     elif table_model_type == MODEL_NAME.RAPID_TABLE:
-        table_model = RapidTableModel(ocr_engine)
+        table_model = RapidTableModel(ocr_engine, table_sub_model_name)
     else:
         logger.error('table model type not allow')
         exit(1)
@@ -163,7 +163,8 @@ def atom_model_init(model_name: str, **kwargs):
             kwargs.get('table_model_path'),
             kwargs.get('table_max_time'),
             kwargs.get('device'),
-            kwargs.get('ocr_engine')
+            kwargs.get('ocr_engine'),
+            kwargs.get('table_sub_model_name')
         )
     elif model_name == AtomicModel.LangDetect:
         if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
diff --git a/magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py b/magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py
index 157fa82f..90eb84a3 100644
--- a/magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py
+++ b/magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py
@@ -7,6 +7,8 @@
 from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
 from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
+import importlib.resources
+from paddleocr import PaddleOCR
 from ppocr.utils.utility import check_and_read
@@ -327,30 +329,35 @@ def get_onnx_model(self, **kwargs):
         return self._models[key]
 def onnx_model_init(key):
-
-    import importlib.resources
-
-    resource_path = importlib.resources.path('rapidocr_onnxruntime.models','')
-
-    onnx_model = None
-    additional_ocr_params = {
-        "use_onnx": True,
-        "det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
-        "rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
-        "cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
-        "det_db_box_thresh": key[1],
-        "use_dilation": key[2],
-        "det_db_unclip_ratio": key[3],
-    }
-    # logger.info(f"additional_ocr_params: {additional_ocr_params}")
-    if key[0] is not None:
-        additional_ocr_params["lang"] = key[0]
-
-    from paddleocr import PaddleOCR
-    onnx_model = PaddleOCR(**additional_ocr_params)
-
-    if onnx_model is None:
-        logger.error('model init failed')
+    if len(key) < 4:
+        logger.error('Invalid key length, expected at least 4 elements')
         exit(1)
-    else:
-        return onnx_model
\ No newline at end of file
+
+    try:
+        with importlib.resources.path('rapidocr_onnxruntime.models', '') as resource_path:
+            additional_ocr_params = {
+                "use_onnx": True,
+                "det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
+                "rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
+                "cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
+
"det_db_box_thresh": key[1], + "use_dilation": key[2], + "det_db_unclip_ratio": key[3], + } + + if key[0] is not None: + additional_ocr_params["lang"] = key[0] + + # logger.info(f"additional_ocr_params: {additional_ocr_params}") + + onnx_model = PaddleOCR(**additional_ocr_params) + + if onnx_model is None: + logger.error('model init failed') + exit(1) + else: + return onnx_model + + except Exception as e: + logger.exception(f'Error initializing model: {e}') + exit(1) \ No newline at end of file diff --git a/magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py b/magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py index 534ae837..5a30c383 100644 --- a/magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +++ b/magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py @@ -2,12 +2,27 @@ import numpy as np import torch from loguru import logger -from rapid_table import RapidTable +from rapid_table import RapidTable, RapidTableInput +from rapid_table.main import ModelType + +from magic_pdf.libs.config_reader import get_device class RapidTableModel(object): - def __init__(self, ocr_engine): - self.table_model = RapidTable() + def __init__(self, ocr_engine, table_sub_model_name): + sub_model_list = [model.value for model in ModelType] + if table_sub_model_name is None: + input_args = RapidTableInput() + elif table_sub_model_name in sub_model_list: + if torch.cuda.is_available() and table_sub_model_name == "unitable": + input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True, device=get_device()) + else: + input_args = RapidTableInput(model_type=table_sub_model_name) + else: + raise ValueError(f"Invalid table_sub_model_name: {table_sub_model_name}. 
It must be one of {sub_model_list}") + + self.table_model = RapidTable(input_args) + # if ocr_engine is None: # self.ocr_model_name = "RapidOCR" # if torch.cuda.is_available(): @@ -45,7 +60,11 @@ def predict(self, image): ocr_result = None if ocr_result: - html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result) - return html_code, table_cell_bboxes, elapse + table_results = self.table_model(np.asarray(image), ocr_result) + html_code = table_results.pred_html + table_cell_bboxes = table_results.cell_bboxes + logic_points = table_results.logic_points + elapse = table_results.elapse + return html_code, table_cell_bboxes, logic_points, elapse else: - return None, None, None + return None, None, None, None diff --git a/magic_pdf/pdf_parse_union_core_v2.py b/magic_pdf/pdf_parse_union_core_v2.py index 5408e6f6..999166f9 100644 --- a/magic_pdf/pdf_parse_union_core_v2.py +++ b/magic_pdf/pdf_parse_union_core_v2.py @@ -1,4 +1,5 @@ import copy +import math import os import re import statistics @@ -12,7 +13,7 @@ from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.config.ocr_content_type import BlockType, ContentType from magic_pdf.data.dataset import Dataset, PageableData -from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio +from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio, __is_overlaps_y_exceeds_threshold from magic_pdf.libs.clean_memory import clean_memory from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config, get_device from magic_pdf.libs.convert_utils import dict_to_list @@ -117,9 +118,10 @@ def fill_char_in_spans(spans, all_chars): for char in all_chars: # 跳过非法bbox的char - x1, y1, x2, y2 = char['bbox'] - if abs(x1 - x2) <= 0.01 or abs(y1 - y2) <= 0.01: - continue + # x1, y1, x2, y2 = char['bbox'] + # if abs(x1 - x2) <= 0.01 or abs(y1 - y2) <= 0.01: + # continue + for span in spans: if calculate_char_in_span(char['bbox'], 
span['bbox'], char['c']): span['chars'].append(char) @@ -173,12 +175,35 @@ def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33): return False +def remove_tilted_line(text_blocks): + for block in text_blocks: + remove_lines = [] + for line in block['lines']: + cosine, sine = line['dir'] + # Compute the angle in radians + angle_radians = math.atan2(sine, cosine) + # Convert radians to degrees + angle_degrees = math.degrees(angle_radians) + if 2 < abs(angle_degrees) < 88: + remove_lines.append(line) + for line in remove_lines: + block['lines'].remove(line) + + def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang): # CIDs are represented as 0xfffd; ligatures are split # text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks'] # CIDs are represented as 0xfffd; ligatures are kept intact - text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks'] + #text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks'] + + # Custom flags produced many 0xfffd characters, possibly because PyMuPDF can handle PDFs with embedded dictionaries on its own; no longer used + text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)['blocks'] + # text_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks'] + + # Remove every line whose angle is not 0 or 90 degrees + remove_tilted_line(text_blocks_raw) + all_pymu_chars = [] for block in text_blocks_raw: for line in block['lines']: @@ -365,10 +390,11 @@ def cal_block_index(fix_blocks, sorted_bboxes): block['index'] = median_value # Delete the virtual line info in image/table body blocks and backfill it from real_lines - if block['type'] in [BlockType.ImageBody, BlockType.TableBody]: - block['virtual_lines'] = copy.deepcopy(block['lines']) - block['lines'] = copy.deepcopy(block['real_lines']) - del block['real_lines'] + if block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.Title, BlockType.InterlineEquation]: + if 'real_lines' in block: + block['virtual_lines'] =
copy.deepcopy(block['lines']) + block['lines'] = copy.deepcopy(block['real_lines']) + del block['real_lines'] else: # Sort with xycut block_bboxes = [] @@ -417,7 +443,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h): block_weight = x1 - x0 # If the block is shorter than n lines of body text, return the block's bbox directly - if line_height * 3 < block_height: + if line_height * 2 < block_height: if ( block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25 ): # Probably a two-column layout, so it can be split more finely @@ -425,16 +451,16 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h): else: # If the block is wider than 0.4 of the page width, split it into 3 lines (a complex layout; figures must not be split too finely) if block_weight > page_w * 0.4: - line_height = (y1 - y0) / 3 lines = 3 + line_height = (y1 - y0) / lines elif block_weight > page_w * 0.25: # (Probably a three-column layout, so split finely as well) lines = int(block_height / line_height) + 1 else: # Check the aspect ratio if block_height / block_weight > 1.2: # Tall, narrow blocks are not split return [[x0, y0, x1, y1]] else: # Otherwise still split into two lines - line_height = (y1 - y0) / 2 lines = 2 + line_height = (y1 - y0) / lines # Determine the y position at which to start drawing lines current_y = y0 @@ -453,30 +479,32 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h): def sort_lines_by_model(fix_blocks, page_w, page_h, line_height): page_line_list = [] + + def add_lines_to_block(b): + line_bboxes = insert_lines_into_block(b['bbox'], line_height, page_w, page_h) + b['lines'] = [] + for line_bbox in line_bboxes: + b['lines'].append({'bbox': line_bbox, 'spans': []}) + page_line_list.extend(line_bboxes) + for block in fix_blocks: if block['type'] in [ - BlockType.Text, BlockType.Title, BlockType.InterlineEquation, + BlockType.Text, BlockType.Title, BlockType.ImageCaption, BlockType.ImageFootnote, BlockType.TableCaption, BlockType.TableFootnote ]: if len(block['lines']) == 0: - bbox = block['bbox'] - lines = insert_lines_into_block(bbox, line_height, page_w, page_h) - for line in lines: - block['lines'].append({'bbox': line, 'spans': []}) - page_line_list.extend(lines) + add_lines_to_block(block) + elif block['type'] in
[BlockType.Title] and len(block['lines']) == 1 and (block['bbox'][3] - block['bbox'][1]) > line_height * 2: + block['real_lines'] = copy.deepcopy(block['lines']) + add_lines_to_block(block) else: for line in block['lines']: bbox = line['bbox'] page_line_list.append(bbox) - elif block['type'] in [BlockType.ImageBody, BlockType.TableBody]: - bbox = block['bbox'] + elif block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.InterlineEquation]: block['real_lines'] = copy.deepcopy(block['lines']) - lines = insert_lines_into_block(bbox, line_height, page_w, page_h) - block['lines'] = [] - for line in lines: - block['lines'].append({'bbox': line, 'spans': []}) - page_line_list.extend(lines) + add_lines_to_block(block) if len(page_line_list) > 200: # layoutreader最高支持512line return None @@ -663,12 +691,77 @@ def parse_page_core( discarded_blocks = magic_model.get_discarded(page_id) text_blocks = magic_model.get_text_blocks(page_id) title_blocks = magic_model.get_title_blocks(page_id) - inline_equations, interline_equations, interline_equation_blocks = ( - magic_model.get_equations(page_id) - ) - + inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations(page_id) page_w, page_h = magic_model.get_page_size(page_id) + def merge_title_blocks(blocks, x_distance_threshold=0.1*page_w): + def merge_two_bbox(b1, b2): + x_min = min(b1['bbox'][0], b2['bbox'][0]) + y_min = min(b1['bbox'][1], b2['bbox'][1]) + x_max = max(b1['bbox'][2], b2['bbox'][2]) + y_max = max(b1['bbox'][3], b2['bbox'][3]) + return x_min, y_min, x_max, y_max + + def merge_two_blocks(b1, b2): + # 合并两个标题块的边界框 + b1['bbox'] = merge_two_bbox(b1, b2) + + # 合并两个标题块的文本内容 + line1 = b1['lines'][0] + line2 = b2['lines'][0] + line1['bbox'] = merge_two_bbox(line1, line2) + line1['spans'].extend(line2['spans']) + + return b1, b2 + + # 按 y 轴重叠度聚集标题块 + y_overlapping_blocks = [] + title_bs = [b for b in blocks if b['type'] == BlockType.Title] + while title_bs: + block1 = 
title_bs.pop(0) + current_row = [block1] + to_remove = [] + for block2 in title_bs: + if ( + __is_overlaps_y_exceeds_threshold(block1['bbox'], block2['bbox'], 0.9) + and len(block1['lines']) == 1 + and len(block2['lines']) == 1 + ): + current_row.append(block2) + to_remove.append(block2) + for b in to_remove: + title_bs.remove(b) + y_overlapping_blocks.append(current_row) + + # Sort by x coordinate and merge title blocks + to_remove_blocks = [] + for row in y_overlapping_blocks: + if len(row) == 1: + continue + + # Sort by x coordinate + row.sort(key=lambda x: x['bbox'][0]) + + merged_block = row[0] + for i in range(1, len(row)): + left_block = merged_block + right_block = row[i] + + left_height = left_block['bbox'][3] - left_block['bbox'][1] + right_height = right_block['bbox'][3] - right_block['bbox'][1] + + if ( + right_block['bbox'][0] - left_block['bbox'][2] < x_distance_threshold + and left_height * 0.95 < right_height < left_height * 1.05 + ): + merged_block, to_remove_block = merge_two_blocks(merged_block, right_block) + to_remove_blocks.append(to_remove_block) + else: + merged_block = right_block + + for b in to_remove_blocks: + blocks.remove(b) + """Collect the bboxes of all blocks together""" # interline_equation_blocks is not accurate enough; switch to interline_equations later interline_equation_blocks = [] @@ -753,6 +846,9 @@ def parse_page_core( """Apply fixes to the blocks""" fix_blocks = fix_block_spans_v2(block_with_spans) + """Merge titles that were split apart on the same line""" + merge_title_blocks(fix_blocks) + """Get all lines and compute the body-text line height""" line_height = get_line_height(fix_blocks) @@ -861,17 +957,23 @@ def pdf_parse_union( formula_aided_config = llm_aided_config.get('formula_aided', None) if formula_aided_config is not None: if formula_aided_config.get('enable', False): + llm_aided_formula_start_time = time.time() llm_aided_formula(pdf_info_dict, formula_aided_config) + logger.info(f'llm aided formula time: {round(time.time() - llm_aided_formula_start_time, 2)}') """Text optimization""" text_aided_config = llm_aided_config.get('text_aided', None) if text_aided_config is not None: if
text_aided_config.get('enable', False): + llm_aided_text_start_time = time.time() llm_aided_text(pdf_info_dict, text_aided_config) + logger.info(f'llm aided text time: {round(time.time() - llm_aided_text_start_time, 2)}') """标题优化""" title_aided_config = llm_aided_config.get('title_aided', None) if title_aided_config is not None: if title_aided_config.get('enable', False): + llm_aided_title_start_time = time.time() llm_aided_title(pdf_info_dict, title_aided_config) + logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}') """dict转list""" pdf_info_list = dict_to_list(pdf_info_dict) diff --git a/magic_pdf/post_proc/llm_aided.py b/magic_pdf/post_proc/llm_aided.py index 90bab31b..5149cb07 100644 --- a/magic_pdf/post_proc/llm_aided.py +++ b/magic_pdf/post_proc/llm_aided.py @@ -83,26 +83,47 @@ def llm_aided_title(pdf_info_dict, title_aided_config): if block["type"] == "title": origin_title_list.append(block) title_text = merge_para_with_text(block) - title_dict[f"{i}"] = title_text + page_line_height_list = [] + for line in block['lines']: + bbox = line['bbox'] + page_line_height_list.append(int(bbox[3] - bbox[1])) + if len(page_line_height_list) > 0: + line_avg_height = sum(page_line_height_list) / len(page_line_height_list) + else: + line_avg_height = int(block['bbox'][3] - block['bbox'][1]) + title_dict[f"{i}"] = [title_text, line_avg_height, int(page_num[5:])+1] i += 1 # logger.info(f"Title list: {title_dict}") title_optimize_prompt = f"""输入的内容是一篇文档中所有标题组成的字典,请根据以下指南优化标题的结果,使结果符合正常文档的层次结构: -1. 保留原始内容: +1. 字典中每个value均为一个list,包含以下元素: + - 标题文本 + - 文本行高是标题所在块的平均行高 + - 标题所在的页码 + +2. 保留原始内容: - 输入的字典中所有元素都是有效的,不能删除字典中的任何元素 - 请务必保证输出的字典中元素的数量和输入的数量一致 -2. 保持字典内key-value的对应关系不变 +3. 保持字典内key-value的对应关系不变 -3. 优化层次结构: +4. 优化层次结构: - 为每个标题元素添加适当的层次结构 - - 标题层级应具有连续性,不能跳过某一层级 + - 行高较大的标题一般是更高级别的标题 + - 标题从前至后的层级必须是连续的,不能跳过层级 - 标题层级最多为4级,不要添加过多的层级 - - 优化后的标题为一个整数,代表该标题的层级 - + - 优化后的标题只保留代表该标题的层级的整数,不要保留其他信息 + +5. 
合理性检查与微调: + - 在完成初步分级后,仔细检查分级结果的合理性 + - 根据上下文关系和逻辑顺序,对不合理的分级进行微调 + - 确保最终的分级结果符合文档的实际结构和逻辑 + IMPORTANT: -请直接返回优化过的由标题层级组成的json,返回的json不需要格式化。 +请直接返回优化过的由标题层级组成的json,格式如下: +{{"0":1,"1":2,"2":2,"3":3}} +返回的json不需要格式化。 Input title list: {title_dict} @@ -110,24 +131,36 @@ def llm_aided_title(pdf_info_dict, title_aided_config): Corrected title list: """ - completion = client.chat.completions.create( - model=title_aided_config["model"], - messages=[ - {'role': 'user', 'content': title_optimize_prompt}], - temperature=0.7, - ) - - json_completion = json.loads(completion.choices[0].message.content) - - # logger.info(f"Title completion: {json_completion}") + retry_count = 0 + max_retries = 3 + json_completion = None - # logger.info(f"len(json_completion): {len(json_completion)}, len(title_dict): {len(title_dict)}") - if len(json_completion) == len(title_dict): + while retry_count < max_retries: try: - for i, origin_title_block in enumerate(origin_title_list): - origin_title_block["level"] = int(json_completion[str(i)]) + completion = client.chat.completions.create( + model=title_aided_config["model"], + messages=[ + {'role': 'user', 'content': title_optimize_prompt}], + temperature=0.7, + ) + json_completion = json.loads(completion.choices[0].message.content) + + # logger.info(f"Title completion: {json_completion}") + # logger.info(f"len(json_completion): {len(json_completion)}, len(title_dict): {len(title_dict)}") + + if len(json_completion) == len(title_dict): + for i, origin_title_block in enumerate(origin_title_list): + origin_title_block["level"] = int(json_completion[str(i)]) + break + else: + logger.warning("The number of titles in the optimized result is not equal to the number of titles in the input.") + retry_count += 1 except Exception as e: - logger.exception(e) - else: - logger.error("The number of titles in the optimized result is not equal to the number of titles in the input.") - + if isinstance(e, json.decoder.JSONDecodeError): + logger.warning(f"JSON 
decode error on attempt {retry_count + 1}: {e}") + else: + logger.exception(e) + retry_count += 1 + + if json_completion is None: + logger.error("Failed to decode JSON after maximum retries.") diff --git a/magic_pdf/pre_proc/ocr_span_list_modify.py b/magic_pdf/pre_proc/ocr_span_list_modify.py index a4ada96e..4354cb35 100644 --- a/magic_pdf/pre_proc/ocr_span_list_modify.py +++ b/magic_pdf/pre_proc/ocr_span_list_modify.py @@ -36,7 +36,7 @@ def remove_overlaps_low_confidence_spans(spans): def check_chars_is_overlap_in_span(chars): for i in range(len(chars)): for j in range(i + 1, len(chars)): - if calculate_iou(chars[i]['bbox'], chars[j]['bbox']) > 0.9: + if calculate_iou(chars[i]['bbox'], chars[j]['bbox']) > 0.35: return True return False diff --git a/magic_pdf/resources/model_config/model_configs.yaml b/magic_pdf/resources/model_config/model_configs.yaml index a11f509f..b75883b5 100644 --- a/magic_pdf/resources/model_config/model_configs.yaml +++ b/magic_pdf/resources/model_config/model_configs.yaml @@ -1,8 +1,8 @@ weights: layoutlmv3: Layout/LayoutLMv3/model_final.pth - doclayout_yolo: Layout/YOLO/doclayout_yolo_ft.pt + doclayout_yolo: Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt - unimernet_small: MFR/unimernet_small + unimernet_small: MFR/unimernet_small_2501 struct_eqtable: TabRec/StructEqTable tablemaster: TabRec/TableMaster rapid_table: TabRec/RapidTable \ No newline at end of file diff --git a/next_docs/en/_static/image/logo.png b/next_docs/en/_static/image/logo.png index 99752da4..09ab46b2 100644 Binary files a/next_docs/en/_static/image/logo.png and b/next_docs/en/_static/image/logo.png differ diff --git a/projects/gradio_app/examples/complex_layout.pdf b/projects/gradio_app/examples/complex_layout.pdf index f9d09673..02f294d7 100644 Binary files a/projects/gradio_app/examples/complex_layout.pdf and b/projects/gradio_app/examples/complex_layout.pdf differ diff --git a/projects/gradio_app/header.html 
b/projects/gradio_app/header.html index af3907d5..21b9184a 100644 --- a/projects/gradio_app/header.html +++ b/projects/gradio_app/header.html @@ -102,7 +102,7 @@ - + @@ -112,7 +112,7 @@ - + diff --git a/requirements.txt b/requirements.txt index 86e03dc0..060bab2d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ fast-langdetect>=0.2.3 loguru>=0.6.0 numpy>=1.21.6,<2.0.0 pydantic>=2.7.2 -PyMuPDF>=1.24.9 +PyMuPDF>=1.24.9,<=1.24.14 scikit-learn>=1.0.2 torch>=2.2.2 transformers diff --git a/scripts/download_models.py b/scripts/download_models.py index 2a8153ed..e1eb7b93 100644 --- a/scripts/download_models.py +++ b/scripts/download_models.py @@ -16,7 +16,7 @@ def download_and_modify_json(url, local_filename, modifications): if os.path.exists(local_filename): data = json.load(open(local_filename)) config_version = data.get('config_version', '0.0.0') - if config_version < '1.1.0': + if config_version < '1.1.1': data = download_json(url) else: data = download_json(url) @@ -35,7 +35,7 @@ def download_and_modify_json(url, local_filename, modifications): "models/Layout/LayoutLMv3/*", "models/Layout/YOLO/*", "models/MFD/YOLO/*", - "models/MFR/unimernet_small/*", + "models/MFR/unimernet_small_2501/*", "models/TabRec/TableMaster/*", "models/TabRec/StructEqTable/*", ] diff --git a/scripts/download_models_hf.py b/scripts/download_models_hf.py index c2b944a5..9a87af23 100644 --- a/scripts/download_models_hf.py +++ b/scripts/download_models_hf.py @@ -16,7 +16,7 @@ def download_and_modify_json(url, local_filename, modifications): if os.path.exists(local_filename): data = json.load(open(local_filename)) config_version = data.get('config_version', '0.0.0') - if config_version < '1.1.0': + if config_version < '1.1.1': data = download_json(url) else: data = download_json(url) @@ -36,7 +36,7 @@ def download_and_modify_json(url, local_filename, modifications): "models/Layout/LayoutLMv3/*", "models/Layout/YOLO/*", "models/MFD/YOLO/*", - "models/MFR/unimernet_small/*", + 
"models/MFR/unimernet_small_2501/*", "models/TabRec/TableMaster/*", "models/TabRec/StructEqTable/*", ] diff --git a/setup.py b/setup.py index e234903a..f3c834d0 100644 --- a/setup.py +++ b/setup.py @@ -48,10 +48,10 @@ def parse_requirements(filename): "struct-eqtable==0.3.2", # table parsing "einops", # struct-eqtable dependency "accelerate", # struct-eqtable dependency - "doclayout_yolo==0.0.2", # doclayout_yolo + "doclayout_yolo==0.0.2b1", # doclayout_yolo "rapidocr-paddle", # rapidocr-paddle "rapidocr_onnxruntime", - "rapid_table==0.3.0", # rapid_table + "rapid_table>=1.0.3,<2.0.0", # rapid_table "PyYAML", # yaml "openai", # openai SDK "detectron2"
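
A side note on the `config_version < '1.1.1'` check in the two download scripts: it compares version strings character by character, which stops matching numeric order once any component reaches two digits. The sketch below is not part of this change. `parse_version` and `needs_refresh` are hypothetical helper names, and it assumes `config_version` is always a purely numeric dotted string.

```python
def parse_version(version: str) -> tuple:
    """Turn '1.10.0' into (1, 10, 0).

    Assumption: every dot-separated part is a plain integer
    (no suffixes such as 'b1' or 'rc2').
    """
    return tuple(int(part) for part in version.split("."))


def needs_refresh(config_version: str, minimum: str = "1.1.1") -> bool:
    """Return True when the local config is older than `minimum`."""
    return parse_version(config_version) < parse_version(minimum)


if __name__ == "__main__":
    # String comparison is lexicographic: "2" sorts after "1",
    # so "1.2.0" is treated as newer than "1.10.0".
    print("1.2.0" < "1.10.0")
    # Tuple comparison is numeric per component: 2 < 10.
    print(parse_version("1.2.0") < parse_version("1.10.0"))
    # A missing config falls back to '0.0.0' in the scripts, which is older.
    print(needs_refresh("0.0.0"))
```

For the versions touched here (`1.1.0` and `1.1.1`) both approaches agree, so this only matters if a component ever grows past 9.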