Skip to content

Commit

Permalink
增加ocr版本解析功能
Browse files Browse the repository at this point in the history
  • Loading branch information
myhloli committed Mar 6, 2024
1 parent 2e487ca commit 701f384
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 1 deletion.
29 changes: 29 additions & 0 deletions demo/ocr_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os

from loguru import logger

from magic_pdf.dict2md.ocr_mkcontent import mk_nlp_markdown
from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr


def save_markdown(markdown_text, input_filepath):
    """Write ``markdown_text`` to a ``.md`` file next to ``input_filepath``.

    The output file keeps the directory and base name of ``input_filepath``;
    only the extension is replaced by ``.md``. The file is opened in ``'w'``
    mode, so an existing file at that path is overwritten.
    """
    # Split the path once into (directory, file name).
    parent_dir, leaf_name = os.path.split(input_filepath)
    # Drop the extension from the file name.
    stem, _extension = os.path.splitext(leaf_name)
    # Build the output path alongside the input file.
    target_path = os.path.join(parent_dir, stem + ".md")

    # Write the Markdown text to the .md file.
    with open(target_path, 'w', encoding='utf-8') as out:
        out.write(markdown_text)


if __name__ == '__main__':
    # Demo driver: parse one OCR JSON dump, render it to Markdown, log the
    # result and save it as a .md file beside the input.
    # NOTE(review): the input path is hard-coded to a local Windows location.
    source_json = r"D:\project\20231108code-clean\ocr\new\demo_4\ocr_0.json"
    rendered = mk_nlp_markdown(parse_pdf_by_ocr(source_json))
    logger.info(rendered)
    save_markdown(rendered, source_json)

21 changes: 21 additions & 0 deletions magic_pdf/dict2md/ocr_mkcontent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
def mk_nlp_markdown(pdf_info_dict: dict):
    """Render a parsed-page dictionary as a single Markdown string.

    Every line found under ``page_info["preproc_blocks"][i]["lines"]`` becomes
    one Markdown line. The spans of a line are joined with a single space.

    Span handling:
        * every ``$`` in the span content is escaped as ``\\$``;
        * ``inline_equation`` spans are wrapped in ``$...$``;
        * ``displayed_equation`` spans are wrapped in ``$$`` fences on their
          own lines;
        * any other span type is emitted as (escaped) plain text.

    Args:
        pdf_info_dict: mapping of page key to a page-info dict. Pages whose
            ``"preproc_blocks"`` entry is missing or empty are skipped.

    Returns:
        The Markdown text, lines joined with ``"\\n"``. Each line ends with
        two spaces, which Markdown renders as a hard line break.

    Raises:
        KeyError: if a block has no ``"lines"``, a line has no ``"spans"``, or
            a span lacks ``"content"`` or ``"type"``.
    """
    markdown = []

    for page_info in pdf_info_dict.values():
        blocks = page_info.get("preproc_blocks")
        if not blocks:
            continue
        for block in blocks:
            for line in block['lines']:
                pieces = []
                for span in line['spans']:
                    # Escape "$" so literal dollar signs are not read as math
                    # delimiters. The backslash is doubled because "\$" is not
                    # a valid Python escape sequence.
                    content = span['content'].replace('$', '\\$')
                    span_type = span['type']
                    if span_type == 'inline_equation':
                        content = f"${content}$"
                    elif span_type == 'displayed_equation':
                        content = f"$$\n{content}\n$$"
                    pieces.append(content)
                # Two trailing spaces force a Markdown hard line break; a
                # single trailing space does not.
                markdown.append(' '.join(pieces).strip() + '  ')
    return '\n'.join(markdown)
34 changes: 33 additions & 1 deletion magic_pdf/libs/boxbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,20 @@ def __overlap_y(Ay1, Ay2, By1, By2):
return x0_1<=x0_2<=x1_1 and vertical_overlap_cond


def __is_overlaps_y_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.8):
    """Check whether two bboxes overlap on the y-axis by more than a ratio.

    The height of the vertical overlap is measured against the height of the
    shorter of the two boxes.

    Args:
        bbox1: first box as ``(x0, y0, x1, y1)``; only ``y0``/``y1`` are used.
        bbox2: second box, same layout.
        overlap_ratio_threshold: ratio the overlap must exceed
            (strictly greater than), 0.8 by default.

    Returns:
        True when ``overlap / min(height1, height2)`` is greater than the
        threshold. False when the shorter box has no positive height, since
        the ratio is undefined there (this previously raised
        ZeroDivisionError).
    """
    _, y0_1, _, y1_1 = bbox1
    _, y0_2, _, y1_2 = bbox2

    # Height of the shared vertical extent; 0 when the boxes do not intersect.
    overlap = max(0, min(y1_1, y1_2) - max(y0_1, y0_2))
    min_height = min(y1_1 - y0_1, y1_2 - y0_2)

    # A degenerate (zero or negative height) box cannot meaningfully overlap.
    if min_height <= 0:
        return False

    return (overlap / min_height) > overlap_ratio_threshold



def calculate_iou(bbox1, bbox2):
# Determine the coordinates of the intersection rectangle
x_left = max(bbox1[0], bbox2[0])
Expand Down Expand Up @@ -163,7 +177,25 @@ def calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2):
else:
return intersection_area / min_box_area



def get_minbox_if_overlap_by_ratio(bbox1, bbox2, ratio):
    """Return the smaller of two heavily overlapping bboxes.

    The overlap is measured with calculate_overlap_area_2_minbox_area_ratio,
    i.e. the overlapping area relative to the area of the smaller box.

    Returns:
        The bbox with the strictly smaller area when the overlap ratio is
        greater than ``ratio``; None when the ratio is not exceeded or when
        both boxes have the same area.
    """
    left_a, top_a, right_a, bottom_a = bbox1
    left_b, top_b, right_b, bottom_b = bbox2

    ratio_measured = calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2)
    # Guard clause: not enough overlap, nothing to report.
    if not ratio_measured > ratio:
        return None

    size_a = (right_a - left_a) * (bottom_a - top_a)
    size_b = (right_b - left_b) * (bottom_b - top_b)
    if size_a < size_b:
        return bbox1
    if size_b < size_a:
        return bbox2
    # Equal areas: neither box counts as "the smaller one".
    return None

def get_bbox_in_boundry(bboxes:list, boundry:tuple)-> list:
x0, y0, x1, y1 = boundry
new_boxes = [box for box in bboxes if box[0] >= x0 and box[1] >= y0 and box[2] <= x1 and box[3] <= y1]
Expand Down
46 changes: 46 additions & 0 deletions magic_pdf/libs/ocr_dict_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold


def merge_spans(spans):
    """Group spans into lines based on their vertical overlap.

    Spans are sorted by ``y0`` and scanned top to bottom. A span joins the
    current line when it overlaps the line's last span on the y-axis (as
    decided by ``__is_overlaps_y_exceeds_threshold``); otherwise it starts a
    new line. A ``displayed_equation`` span always sits on a line of its own.

    Args:
        spans: list of span dicts, each with a ``"bbox"`` of the form
            ``[x0, y0, x1, y1]`` and a ``"type"``. The list is sorted in place
            by ``y0``.

    Returns:
        A list of line dicts ``{"bbox": [x0, y0, x1, y1], "spans": [...]}``.
        The bbox is the union of the line's span bboxes and the spans are
        ordered by ``x0``. An empty input yields an empty list (previously it
        raised IndexError).
    """
    if not spans:
        return []

    # Sort by the y0 coordinate.
    spans.sort(key=lambda span: span['bbox'][1])

    lines = []
    current_line = [spans[0]]
    for span in spans[1:]:
        # A new line starts when the span is a displayed equation, when the
        # current line already holds one, or when the span does not overlap
        # the last span of the current line on the y-axis.
        starts_new_line = (
            span['type'] == "displayed_equation"
            or any(s['type'] == "displayed_equation" for s in current_line)
            or not __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'])
        )
        if starts_new_line:
            lines.append(current_line)
            current_line = [span]
        else:
            current_line.append(span)

    # Flush the last line (never empty here, since spans is non-empty).
    lines.append(current_line)

    # Compute each line's bounding box and order its spans by x0.
    line_objects = []
    for line in lines:
        line.sort(key=lambda span: span['bbox'][0])
        line_bbox = [
            min(span['bbox'][0] for span in line),  # x0
            min(span['bbox'][1] for span in line),  # y0
            max(span['bbox'][2] for span in line),  # x1
            max(span['bbox'][3] for span in line),  # y1
        ]
        line_objects.append({
            "bbox": line_bbox,
            "spans": line,
        })

    return line_objects
85 changes: 85 additions & 0 deletions magic_pdf/pdf_parse_by_ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import json

from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio
from magic_pdf.libs.ocr_dict_merge import merge_spans


def read_json_file(file_path):
    """Load a JSON file and return the decoded object.

    The file is read as UTF-8 explicitly so that non-ASCII text is decoded
    the same way on every platform instead of depending on the locale's
    default encoding.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def construct_page_component(page_id, text_blocks_preproc):
    """Bundle a page's pre-processed blocks with its page index.

    Returns a dict with the blocks under ``'preproc_blocks'`` and the page id
    under ``'page_idx'``.
    """
    return dict(preproc_blocks=text_blocks_preproc, page_idx=page_id)


# Layout categories that become spans, mapped to (span type, source field).
_SPAN_CATEGORIES = {
    13: ('inline_equation', 'latex'),     # 'embedding': embedded formula
    14: ('displayed_equation', 'latex'),  # 'isolated': stand-alone formula
    15: ('text', 'text'),                 # 'ocr_text': OCR-recognised text
}


def _layout_det_to_span(layout_det):
    """Convert one layout detection into a span dict, or None if its category is ignored."""
    mapping = _SPAN_CATEGORIES.get(layout_det['category_id'])
    if mapping is None:
        return None
    span_type, content_key = mapping
    # 'poly' holds eight numbers; the 1st/2nd and 5th/6th are used as the
    # (x0, y0) and (x1, y1) corners of the bbox.
    x0, y0, _, _, x1, y1, _, _ = layout_det['poly']
    return {
        'bbox': [int(x0), int(y0), int(x1), int(y1)],
        'content': layout_det[content_key],
        'type': span_type,
    }


def _remove_overlapping_spans(spans, ratio=0.8):
    """Return spans minus every span that is the smaller box of a pair overlapping by more than ``ratio``."""
    dropped = set()
    for span1 in spans:
        for span2 in spans:
            if span1 is span2:
                continue
            smaller = get_minbox_if_overlap_by_ratio(span1['bbox'], span2['bbox'], ratio)
            if smaller is None:
                continue
            # The helper only reports a box when the areas differ, so the two
            # bboxes cannot be equal here and this comparison is unambiguous.
            victim = span1 if smaller == span1['bbox'] else span2
            dropped.add(id(victim))
    return [span for span in spans if id(span) not in dropped]


def parse_pdf_by_ocr(
    ocr_json_file_path,
    start_page_id=0,
    end_page_id=None,
):
    """Build the per-page block structure from an OCR result JSON file.

    For each page in ``[start_page_id, end_page_id]`` the layout detections of
    category 13 (inline equation), 14 (displayed equation) and 15 (OCR text)
    are turned into spans. Spans that are mostly covered by a larger span are
    dropped, and the rest are merged into lines. No block stitching is done
    yet: each block holds exactly one line and shares that line's bbox.

    Args:
        ocr_json_file_path: path to the OCR JSON file; it is indexed by page
            number, and each page must provide ``'layout_dets'``.
        start_page_id: index of the first page to parse.
        end_page_id: index of the last page to parse (inclusive). ``None``
            means "through the last page"; ``0`` parses up to page 0 only.

    Returns:
        A dict mapping ``"page_<id>"`` to the page component built by
        ``construct_page_component``.

    Raises:
        IndexError: if a requested page index is outside the OCR data.
        KeyError: if a page or detection lacks an expected field.
    """
    ocr_pdf_info = read_json_file(ocr_json_file_path)
    # Compare against None explicitly so that end_page_id=0 is honoured.
    if end_page_id is None:
        end_page_id = len(ocr_pdf_info) - 1

    pdf_info_dict = {}
    for page_id in range(start_page_id, end_page_id + 1):
        layout_dets = ocr_pdf_info[page_id]['layout_dets']

        spans = []
        for layout_det in layout_dets:
            span = _layout_det_to_span(layout_det)
            if span is not None:
                spans.append(span)

        # Remove spans swallowed by a larger overlapping span.
        spans = _remove_overlapping_spans(spans)

        # Merge spans into lines; a page without spans has no lines.
        lines = merge_spans(spans) if spans else []

        # One block per line for now; the block bbox is the line bbox.
        blocks = [{"bbox": line['bbox'], "lines": [line]} for line in lines]

        pdf_info_dict[f"page_{page_id}"] = construct_page_component(page_id, blocks)

    return pdf_info_dict

0 comments on commit 701f384

Please sign in to comment.