Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(parse_pipeline): Resolve post-processing exceptions caused by partial PDFs due to file corruption or non-standard format by forcing a re-print. #957

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 47 additions & 3 deletions magic_pdf/tools/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from magic_pdf.pipe.UNIPipe import UNIPipe
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
import fitz
# from io import BytesIO
# from pypdf import PdfReader, PdfWriter


def prepare_env(output_dir, pdf_file_name, method):
Expand All @@ -26,6 +29,42 @@ def prepare_env(output_dir, pdf_file_name, method):
return local_image_dir, local_md_dir


# def convert_pdf_bytes_to_bytes_by_pypdf(pdf_bytes, start_page_id=0, end_page_id=None):
# # 将字节数据包装在 BytesIO 对象中
# pdf_file = BytesIO(pdf_bytes)
# # 读取 PDF 的字节数据
# reader = PdfReader(pdf_file)
# # 创建一个新的 PDF 写入器
# writer = PdfWriter()
# # 将所有页面添加到新的 PDF 写入器中
# end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(reader.pages) - 1
# if end_page_id > len(reader.pages) - 1:
# logger.warning("end_page_id is out of range, use pdf_docs length")
# end_page_id = len(reader.pages) - 1
# for i, page in enumerate(reader.pages):
# if start_page_id <= i <= end_page_id:
# writer.add_page(page)
# # 创建一个字节缓冲区来存储输出的 PDF 数据
# output_buffer = BytesIO()
# # 将 PDF 写入字节缓冲区
# writer.write(output_buffer)
# # 获取字节缓冲区的内容
# converted_pdf_bytes = output_buffer.getvalue()
# return converted_pdf_bytes


def convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id=0, end_page_id=None):
document = fitz.open("pdf", pdf_bytes)
output_document = fitz.open()
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(document) - 1
if end_page_id > len(document) - 1:
logger.warning("end_page_id is out of range, use pdf_docs length")
end_page_id = len(document) - 1
output_document.insert_pdf(document, from_page=start_page_id, to_page=end_page_id)
output_bytes = output_document.tobytes()
return output_bytes


def do_parse(
output_dir,
pdf_file_name,
Expand Down Expand Up @@ -55,6 +94,8 @@ def do_parse(
f_draw_model_bbox = True
f_draw_line_sort_bbox = True

pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id, end_page_id)

orig_model_list = copy.deepcopy(model_list)
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name,
parse_method)
Expand All @@ -66,15 +107,18 @@ def do_parse(
if parse_method == 'auto':
jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
# start_page_id=start_page_id, end_page_id=end_page_id,
lang=lang,
layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
elif parse_method == 'txt':
pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
# start_page_id=start_page_id, end_page_id=end_page_id,
lang=lang,
layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
elif parse_method == 'ocr':
pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id, lang=lang,
# start_page_id=start_page_id, end_page_id=end_page_id,
lang=lang,
layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
else:
logger.error('unknown parse method')
Expand Down
Loading