Skip to content

Commit

Permalink
Merge pull request #698 from myhloli/dev
Browse files Browse the repository at this point in the history
feat(layoutreader): support local model directory and improve model loading
  • Loading branch information
myhloli authored Oct 8, 2024
2 parents 3fb0494 + 0b2b0ce commit 8786d20
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 8 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ This project currently uses PyMuPDF to achieve advanced functionality. However,
- [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy)
- [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)
- [PyMuPDF](https://github.com/pymupdf/PyMuPDF)
- [layoutreader](https://github.com/ppaanngggg/layoutreader)
- [fast-langdetect](https://github.com/LlmKira/fast-langdetect)
- [pdfminer.six](https://github.com/pdfminer/pdfminer.six)

Expand Down
1 change: 1 addition & 0 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ TODO
- [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy)
- [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)
- [PyMuPDF](https://github.com/pymupdf/PyMuPDF)
- [layoutreader](https://github.com/ppaanngggg/layoutreader)
- [fast-langdetect](https://github.com/LlmKira/fast-langdetect)
- [pdfminer.six](https://github.com/pdfminer/pdfminer.six)

Expand Down
1 change: 1 addition & 0 deletions docs/download_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# use modelscope sdk download models
from modelscope import snapshot_download

# Fetch the main PDF-Extract-Kit bundle and the layoutreader reading-order model
# (they live in separate modelscope repos, so two downloads are required).
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
# Report BOTH locations: the config needs "models-dir" and
# "layoutreader-model-dir"; previously the layoutreader path was never shown.
print(f"model dir is: {model_dir}/models")
print(f"layoutreader model dir is: {layoutreader_model_dir}")
1 change: 1 addition & 0 deletions docs/download_models_hf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from huggingface_hub import snapshot_download

# Fetch the main PDF-Extract-Kit bundle and the layoutreader reading-order model
# from Hugging Face (separate repos, so two downloads are required).
model_dir = snapshot_download('opendatalab/PDF-Extract-Kit')
layoutreader_model_dir = snapshot_download('hantian/layoutreader')
# Report BOTH locations: the config needs "models-dir" and
# "layoutreader-model-dir"; previously the layoutreader path was never shown.
print(f"model dir is: {model_dir}/models")
print(f"layoutreader model dir is: {layoutreader_model_dir}")
7 changes: 7 additions & 0 deletions docs/how_to_download_models_zh_cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ python脚本执行完毕后,会输出模型下载目录
如此前通过 git lfs 下载过模型文件,可以进入到之前的下载目录中,通过`git pull`命令更新模型。

> 0.9.x及以后版本由于新增layout排序模型,且该模型和此前的模型不在同一仓库,不能通过`git pull`命令更新,需要单独下载。
>
>```
>from modelscope import snapshot_download
>snapshot_download('ppaanngggg/layoutreader')
>```

## 2. 通过 Hugging Face 或 Model Scope 下载过模型
如此前通过 HuggingFace 或 Model Scope 下载过模型,可以重复执行此前的模型下载python脚本,将会自动将模型目录更新到最新版本。
1 change: 1 addition & 0 deletions magic-pdf.template.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"bucket-name-2":["ak", "sk", "endpoint"]
},
"models-dir":"/tmp/models",
"layoutreader-model-dir":"/tmp/layoutreader",
"device-mode":"cpu",
"table-config": {
"model": "TableMaster",
Expand Down
12 changes: 12 additions & 0 deletions magic_pdf/libs/config_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ def get_local_models_dir():
return models_dir


def get_local_layoutreader_model_dir():
    """Resolve the directory holding the layoutreader model.

    Prefers the ``layoutreader-model-dir`` entry from the user config file;
    when that entry is missing or points at a path that does not exist, falls
    back to the default modelscope cache location under the user's home
    directory (a warning is logged in that case).
    """
    configured_dir = read_config().get("layoutreader-model-dir")
    if configured_dir is not None and os.path.exists(configured_dir):
        return configured_dir
    # Fall back to where the modelscope SDK caches the model by default.
    fallback_dir = os.path.join(
        os.path.expanduser("~"), ".cache/modelscope/hub/ppaanngggg/layoutreader"
    )
    logger.warning(f"'layoutreader-model-dir' not exists, use {fallback_dir} as default")
    return fallback_dir


def get_device():
config = read_config()
device = config.get("device-mode")
Expand Down
19 changes: 11 additions & 8 deletions magic_pdf/pdf_parse_union_core_v2.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import statistics
import time

Expand All @@ -9,6 +10,7 @@

from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.commons import fitz, get_delta_time
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.drop_reason import DropReason
from magic_pdf.libs.hash_utils import compute_md5
Expand Down Expand Up @@ -95,7 +97,7 @@ def replace_text_span(pymu_spans, ocr_spans):
return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans


def model_init(model_name: str, local_path=None):
def model_init(model_name: str):
from transformers import LayoutLMv3ForTokenClassification
if torch.cuda.is_available():
device = torch.device("cuda")
Expand All @@ -108,9 +110,13 @@ def model_init(model_name: str, local_path=None):
supports_bfloat16 = False

if model_name == "layoutreader":
if local_path:
model = LayoutLMv3ForTokenClassification.from_pretrained(local_path)
# 检测modelscope的缓存目录是否存在
layoutreader_model_dir = get_local_layoutreader_model_dir()
if os.path.exists(layoutreader_model_dir):
model = LayoutLMv3ForTokenClassification.from_pretrained(layoutreader_model_dir)
else:
logger.warning(
f"local layoutreader model not exists, use online model from huggingface")
model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
# 检查设备是否支持 bfloat16
if supports_bfloat16:
Expand All @@ -131,12 +137,9 @@ def __new__(cls, *args, **kwargs):
cls._instance = super().__new__(cls)
return cls._instance

def get_model(self, model_name: str, local_path=None):
def get_model(self, model_name: str):
if model_name not in self._models:
if local_path:
self._models[model_name] = model_init(model_name=model_name, local_path=local_path)
else:
self._models[model_name] = model_init(model_name=model_name)
self._models[model_name] = model_init(model_name=model_name)
return self._models[model_name]


Expand Down

0 comments on commit 8786d20

Please sign in to comment.