Skip to content

Commit

Permalink
feat(model): improve batch analysis logic and support npu
Browse files Browse the repository at this point in the history
- Add support for NPU (Neural Processing Unit) when available
- Implement batch analysis for GPU and NPU devices
- Optimize memory usage and improve performance
- Update logging and error handling
  • Loading branch information
myhloli committed Jan 15, 2025
1 parent 84f808f commit f350222
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 104 deletions.
172 changes: 87 additions & 85 deletions magic_pdf/model/batch_analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@
from PIL import Image

from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_device
from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
# from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
# from magic_pdf.data.dataset import Dataset
# from magic_pdf.libs.clean_memory import clean_memory
# from magic_pdf.libs.config_reader import get_device
# from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
from magic_pdf.model.pdf_extract_kit import CustomPEKModel
from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list)
from magic_pdf.operators.models import InferenceResult
# from magic_pdf.operators.models import InferenceResult

YOLO_LAYOUT_BASE_BATCH_SIZE = 4
MFD_BASE_BATCH_SIZE = 1
Expand Down Expand Up @@ -91,10 +91,12 @@ def __call__(self, images: list) -> list:
images,
batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
)
mfr_count = 0
for image_index in range(len(images)):
images_layout_res[image_index] += images_formula_list[image_index]
mfr_count += len(images_formula_list[image_index])
logger.info(
f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}'
f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
)

# 清理显存
Expand Down Expand Up @@ -195,81 +197,81 @@ def __call__(self, images: list) -> list:
return images_layout_res


def doc_batch_analyze(
    dataset: Dataset,
    ocr: bool = False,
    show_log: bool = False,
    start_page_id=0,
    end_page_id=None,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
    batch_ratio: int | None = None,
) -> InferenceResult:
    """Perform batch analysis on a document dataset.

    Pages whose index lies in the inclusive range
    ``[start_page_id, end_page_id]`` are sent through ``BatchAnalyze`` in a
    single call; every other page gets an empty ``layout_dets`` list. One
    entry per page of ``dataset`` is always produced.

    Args:
        dataset (Dataset): The dataset containing document pages to be analyzed.
        ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
        show_log (bool, optional): Flag to enable logging. Defaults to False.
        start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
        end_page_id (int, optional): The ending page ID for analysis (inclusive).
            Defaults to None, which means analyze till the last page. An explicit
            ``0`` restricts the analysis to the first page.
        lang (str, optional): Language for OCR. An empty string is treated as None. Defaults to None.
        layout_model (optional): Layout model to be used for analysis. Defaults to None.
        formula_enable (optional): Flag to enable formula detection. Defaults to None.
        table_enable (optional): Flag to enable table detection. Defaults to None.
        batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.

    Raises:
        CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.

    Returns:
        InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
    """
    if not torch.cuda.is_available():
        raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')

    lang = None if lang == '' else lang
    # TODO: auto detect batch size
    batch_ratio = 1 if batch_ratio is None else batch_ratio
    # Compare against None explicitly: `end_page_id=0` is a legitimate value
    # ("only the first page") and must not fall back to the whole document.
    if end_page_id is None:
        end_page_id = len(dataset)

    model_manager = ModelSingleton()
    custom_model: CustomPEKModel = model_manager.get_model(
        ocr, show_log, lang, layout_model, formula_enable, table_enable
    )
    batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)

    # Single pass over the dataset: fetch each page image once, remember its
    # dimensions, and collect the images that fall inside the requested range.
    # NOTE(review): assumes get_image() returns the same data on repeated
    # calls, so fetching it once per page is equivalent - confirm.
    page_infos = []
    in_range_flags = []
    images = []
    for index in range(len(dataset)):
        img_dict = dataset.get_page(index).get_image()
        page_infos.append(
            {'page_no': index, 'height': img_dict['height'], 'width': img_dict['width']}
        )
        in_range = start_page_id <= index <= end_page_id
        in_range_flags.append(in_range)
        if in_range:
            images.append(img_dict['img'])

    # batch analyze
    analyze_result = batch_model(images)

    # Re-associate batch results with their pages. A running cursor replaces
    # the repeated `pop(0)`, which shifts the whole list on every call.
    model_json = []
    result_cursor = 0
    for page_info, in_range in zip(page_infos, in_range_flags):
        if in_range:
            result = analyze_result[result_cursor]
            result_cursor += 1
        else:
            result = []
        model_json.append({'layout_dets': result, 'page_info': page_info})

    # TODO: clean memory when gpu memory is not enough
    clean_memory_start_time = time.time()
    clean_memory(get_device())
    logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')

    return InferenceResult(model_json, dataset)
# def doc_batch_analyze(
# dataset: Dataset,
# ocr: bool = False,
# show_log: bool = False,
# start_page_id=0,
# end_page_id=None,
# lang=None,
# layout_model=None,
# formula_enable=None,
# table_enable=None,
# batch_ratio: int | None = None,
# ) -> InferenceResult:
# """Perform batch analysis on a document dataset.
#
# Args:
# dataset (Dataset): The dataset containing document pages to be analyzed.
# ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
# show_log (bool, optional): Flag to enable logging. Defaults to False.
# start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
# end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
# lang (str, optional): Language for OCR. Defaults to None.
# layout_model (optional): Layout model to be used for analysis. Defaults to None.
# formula_enable (optional): Flag to enable formula detection. Defaults to None.
# table_enable (optional): Flag to enable table detection. Defaults to None.
# batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
#
# Raises:
# CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
#
# Returns:
# InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
# """
#
# if not torch.cuda.is_available():
# raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
#
# lang = None if lang == '' else lang
# # TODO: auto detect batch size
# batch_ratio = 1 if batch_ratio is None else batch_ratio
# end_page_id = end_page_id if end_page_id else len(dataset)
#
# model_manager = ModelSingleton()
# custom_model: CustomPEKModel = model_manager.get_model(
# ocr, show_log, lang, layout_model, formula_enable, table_enable
# )
# batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
#
# model_json = []
#
# # batch analyze
# images = []
# for index in range(len(dataset)):
# if start_page_id <= index <= end_page_id:
# page_data = dataset.get_page(index)
# img_dict = page_data.get_image()
# images.append(img_dict['img'])
# analyze_result = batch_model(images)
#
# for index in range(len(dataset)):
# page_data = dataset.get_page(index)
# img_dict = page_data.get_image()
# page_width = img_dict['width']
# page_height = img_dict['height']
# if start_page_id <= index <= end_page_id:
# result = analyze_result.pop(0)
# else:
# result = []
#
# page_info = {'page_no': index, 'height': page_height, 'width': page_width}
# page_dict = {'layout_dets': result, 'page_info': page_info}
# model_json.append(page_dict)
#
# # TODO: clean memory when gpu memory is not enough
# clean_memory_start_time = time.time()
# clean_memory(get_device())
# logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
#
# return InferenceResult(model_json, dataset)
84 changes: 66 additions & 18 deletions magic_pdf/model/doc_analyze_by_custom_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@

# 关闭paddle的信号处理
import paddle
import torch
from loguru import logger

from magic_pdf.model.batch_analyze import BatchAnalyze
from magic_pdf.model.sub_modules.model_utils import get_vram

paddle.disable_signal_handler()

os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
Expand Down Expand Up @@ -154,33 +158,77 @@ def doc_analyze(
table_enable=None,
) -> InferenceResult:

end_page_id = end_page_id if end_page_id else len(dataset)

model_manager = ModelSingleton()
custom_model = model_manager.get_model(
ocr, show_log, lang, layout_model, formula_enable, table_enable
)

batch_analyze = False
device = get_device()

npu_support = False
if str(device).startswith("npu"):
import torch_npu
if torch_npu.npu.is_available():
npu_support = True

if torch.cuda.is_available() and device != 'cpu' or npu_support:
gpu_memory = get_vram(device)
if gpu_memory is not None and gpu_memory >= 7:
batch_ratio = int((gpu_memory-3) // 1.5)
if batch_ratio >= 1:
logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
batch_analyze = True

model_json = []
doc_analyze_start = time.time()

if end_page_id is None:
end_page_id = len(dataset)

for index in range(len(dataset)):
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
img = img_dict['img']
page_width = img_dict['width']
page_height = img_dict['height']
if start_page_id <= index <= end_page_id:
page_start = time.time()
result = custom_model(img)
logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
else:
result = []
if batch_analyze:
# batch analyze
images = []
for index in range(len(dataset)):
if start_page_id <= index <= end_page_id:
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
images.append(img_dict['img'])
analyze_result = batch_model(images)

for index in range(len(dataset)):
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
page_width = img_dict['width']
page_height = img_dict['height']
if start_page_id <= index <= end_page_id:
result = analyze_result.pop(0)
else:
result = []

page_info = {'page_no': index, 'height': page_height, 'width': page_width}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)

page_info = {'page_no': index, 'height': page_height, 'width': page_width}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
else:
# single analyze

for index in range(len(dataset)):
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
img = img_dict['img']
page_width = img_dict['width']
page_height = img_dict['height']
if start_page_id <= index <= end_page_id:
page_start = time.time()
result = custom_model(img)
logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
else:
result = []

page_info = {'page_no': index, 'height': page_height, 'width': page_width}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)

gc_start = time.time()
clean_memory(get_device())
Expand Down
2 changes: 1 addition & 1 deletion magic_pdf/model/pdf_extract_kit.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def __call__(self, image):
logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')

# 清理显存
clean_vram(self.device, vram_threshold=8)
clean_vram(self.device, vram_threshold=6)

# 从layout_res中获取ocr区域、表格区域、公式区域
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
Expand Down

0 comments on commit f350222

Please sign in to comment.