Skip to content

Commit

Permalink
Merge pull request #740 from myhloli/para-split-v3
Browse files Browse the repository at this point in the history
feat(list&index block): detect and merge list and index blocks
  • Loading branch information
myhloli authored Oct 14, 2024
2 parents c479245 + 1f1dd35 commit 702b6ac
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 20 deletions.
33 changes: 19 additions & 14 deletions magic_pdf/dict2md/ocr_mkcontent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
from magic_pdf.libs.ocr_content_type import BlockType, ContentType
from magic_pdf.para.para_split_v3 import ListLineTag


def __is_hyphen_at_line_end(line):
Expand Down Expand Up @@ -124,7 +125,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
for para_block in paras_of_layout:
para_text = ''
para_type = para_block['type']
if para_type == BlockType.Text:
if para_type in [BlockType.Text, BlockType.List, BlockType.Index]:
para_text = merge_para_with_text(para_block, parse_type=parse_type, lang=lang)
elif para_type == BlockType.Title:
para_text = f'# {merge_para_with_text(para_block, parse_type=parse_type, lang=lang)}'
Expand Down Expand Up @@ -177,22 +178,26 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
return page_markdown


def merge_para_with_text(para_block, parse_type="auto", lang=None):

def detect_language(text):
en_pattern = r'[a-zA-Z]+'
en_matches = re.findall(en_pattern, text)
en_length = sum(len(match) for match in en_matches)
if len(text) > 0:
if en_length / len(text) >= 0.5:
return 'en'
else:
return 'unknown'
def detect_language(text):
en_pattern = r'[a-zA-Z]+'
en_matches = re.findall(en_pattern, text)
en_length = sum(len(match) for match in en_matches)
if len(text) > 0:
if en_length / len(text) >= 0.5:
return 'en'
else:
return 'empty'
return 'unknown'
else:
return 'empty'


def merge_para_with_text(para_block, parse_type="auto", lang=None):
para_text = ''
for line in para_block['lines']:
for i, line in enumerate(para_block['lines']):

if i >= 1 and line.get(ListLineTag.IS_LIST_START_LINE, False):
para_text += ' \n'

line_text = ''
line_lang = ''
for span in line['spans']:
Expand Down
13 changes: 13 additions & 0 deletions magic_pdf/libs/draw_bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
titles_list = []
texts_list = []
interequations_list = []
lists_list = []
indexs_list = []
for page in pdf_info:

page_dropped_list = []
Expand All @@ -83,6 +85,8 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
titles = []
texts = []
interequations = []
lists = []
indexs = []

for dropped_bbox in page['discarded_blocks']:
page_dropped_list.append(dropped_bbox['bbox'])
Expand Down Expand Up @@ -115,6 +119,11 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
texts.append(bbox)
elif block['type'] == BlockType.InterlineEquation:
interequations.append(bbox)
elif block['type'] == BlockType.List:
lists.append(bbox)
elif block['type'] == BlockType.Index:
indexs.append(bbox)

tables_list.append(tables)
tables_body_list.append(tables_body)
tables_caption_list.append(tables_caption)
Expand All @@ -126,6 +135,8 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
titles_list.append(titles)
texts_list.append(texts)
interequations_list.append(interequations)
lists_list.append(lists)
indexs_list.append(indexs)

layout_bbox_list = []

Expand Down Expand Up @@ -160,6 +171,8 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
draw_bbox_without_number(i, texts_list, page, [153, 0, 76], True)
draw_bbox_without_number(i, interequations_list, page, [0, 255, 0],
True)
draw_bbox_without_number(i, lists_list, page, [40, 169, 92], True)
draw_bbox_without_number(i, indexs_list, page, [40, 169, 92], True)

draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False, draw_bbox=False)

Expand Down
2 changes: 2 additions & 0 deletions magic_pdf/libs/ocr_content_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class BlockType:
InterlineEquation = 'interline_equation'
Footnote = 'footnote'
Discarded = 'discarded'
List = 'list'
Index = 'index'


class CategoryId:
Expand Down
2 changes: 1 addition & 1 deletion magic_pdf/model/pdf_extract_kit.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def __call__(self, image):
if torch.cuda.is_available():
properties = torch.cuda.get_device_properties(self.device)
total_memory = properties.total_memory / (1024 ** 3) # 将字节转换为 GB
if total_memory <= 8:
if total_memory <= 10:
gc_start = time.time()
clean_memory()
gc_time = round(time.time() - gc_start, 2)
Expand Down
162 changes: 160 additions & 2 deletions magic_pdf/para/para_split_v3.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
import copy

from loguru import logger

from magic_pdf.libs.Constants import LINES_DELETED, CROSS_PAGE
from magic_pdf.libs.ocr_content_type import BlockType, ContentType

LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';')
LIST_END_FLAG = ('.', '。', ';', ';')


class ListLineTag:
IS_LIST_START_LINE = "is_list_start_line"
IS_LIST_END_LINE = "is_list_end_line"


def __process_blocks(blocks):
Expand Down Expand Up @@ -38,7 +47,127 @@ def __process_blocks(blocks):
return result


def __merge_2_blocks(block1, block2):
def __is_list_block(block):
# 一个block如果是list block 应该同时满足以下特征
# 1.block内有多个line 2.block 内有多个line左侧顶格写 3.block内有多个line 右侧不顶格(狗牙状)
# 1.block内有多个line 2.block 内有多个line左侧顶格写 3.多个line以endflag结尾
# 1.block内有多个line 2.block 内有多个line左侧顶格写 3.block内有多个line 左侧不顶格
if len(block['lines']) >= 3:
first_line = block['lines'][0]
line_height = first_line['bbox'][3] - first_line['bbox'][1]
block_weight = block['bbox_fs'][2] - block['bbox_fs'][0]

left_close_num = 0
left_not_close_num = 0
right_not_close_num = 0
lines_text_list = []
for line in block['lines']:

line_text = ""

for span in line['spans']:
span_type = span['type']
if span_type == ContentType.Text:
line_text += span['content'].strip()

lines_text_list.append(line_text)

# 计算line左侧顶格数量是否大于2,是否顶格用abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height/2 来判断
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height/2:
left_close_num += 1
elif line['bbox'][0] - block['bbox_fs'][0] > line_height:
# logger.info(f"{line_text}, {block['bbox_fs']}, {line['bbox']}")
left_not_close_num += 1

# 计算右侧是否不顶格,拍脑袋用0.3block宽度做阈值
closed_area = 0.3 * block_weight
# closed_area = 5 * line_height
if block['bbox_fs'][2] - line['bbox'][2] > closed_area:
right_not_close_num += 1

# 判断lines_text_list中的元素是否有超过80%都以LIST_END_FLAG结尾
line_end_flag = False
if len(lines_text_list) > 0:
num_end_count = 0
for line_text in lines_text_list:
if len(line_text) > 0:
if line_text[-1] in LIST_END_FLAG:
num_end_count += 1

if num_end_count / len(lines_text_list) >= 0.8:
line_end_flag = True

if left_close_num >= 2 and (right_not_close_num >= 2 or line_end_flag or left_not_close_num >= 2):
for line in block['lines']:
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
line[ListLineTag.IS_LIST_START_LINE] = True
if abs(block['bbox_fs'][2] - line['bbox'][2]) > line_height:
line[ListLineTag.IS_LIST_END_LINE] = True

return True
else:
return False
else:
return False


def __is_index_block(block):
# 一个block如果是index block 应该同时满足以下特征
# 1.block内有多个line 2.block 内有多个line两侧均顶格写 3.line的开头或者结尾均为数字
if len(block['lines']) >= 3:
first_line = block['lines'][0]
line_height = first_line['bbox'][3] - first_line['bbox'][1]

left_close_num = 0
right_close_num = 0

lines_text_list = []
for line in block['lines']:

# 计算line左侧顶格数量是否大于2,是否顶格用abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height/2 来判断
if abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height / 2:
left_close_num += 1

# 计算右侧是否不顶格
if abs(block['bbox_fs'][2] - line['bbox'][2]) < line_height / 2:
right_close_num += 1

line_text = ""

for span in line['spans']:
span_type = span['type']
if span_type == ContentType.Text:
line_text += span['content'].strip()

lines_text_list.append(line_text)

# 判断lines_text_list中的元素是否有超过80%都以数字开头或都以数字结尾
line_num_flag = False
if len(lines_text_list) > 0:
num_start_count = 0
num_end_count = 0
for line_text in lines_text_list:
if len(line_text) > 0:
if line_text[0].isdigit():
num_start_count += 1
if line_text[-1].isdigit():
num_end_count += 1

if num_start_count / len(lines_text_list) >= 0.8 or num_end_count / len(lines_text_list) >= 0.8:
line_num_flag = True

if left_close_num >= 2 and right_close_num >= 2 and line_num_flag:
for line in block['lines']:
line[ListLineTag.IS_LIST_START_LINE] = True

return True
else:
return False
else:
return False


def __merge_2_text_blocks(block1, block2):
if len(block1['lines']) > 0:
first_line = block1['lines'][0]
line_height = first_line['bbox'][3] - first_line['bbox'][1]
Expand All @@ -59,17 +188,46 @@ def __merge_2_blocks(block1, block2):
return block1, block2


def __merge_2_list_blocks(block1, block2):

if block1['page_num'] != block2['page_num']:
for line in block1['lines']:
for span in line['spans']:
span[CROSS_PAGE] = True
block2['lines'].extend(block1['lines'])
block1['lines'] = []
block1[LINES_DELETED] = True

return block1, block2


def __para_merge_page(blocks):
page_text_blocks_groups = __process_blocks(blocks)
for text_blocks_group in page_text_blocks_groups:

if len(text_blocks_group) > 0:
# 需要先在合并前对所有block判断是否为list block
for block in text_blocks_group:
if __is_list_block(block):
block['type'] = BlockType.List
elif __is_index_block(block):
block['type'] = BlockType.Index

if len(text_blocks_group) > 1:
# 倒序遍历
for i in range(len(text_blocks_group)-1, -1, -1):
current_block = text_blocks_group[i]

# 检查是否有前一个块
if i - 1 >= 0:
prev_block = text_blocks_group[i - 1]
__merge_2_blocks(current_block, prev_block)

if current_block['type'] == 'text' and prev_block['type'] == 'text':
__merge_2_text_blocks(current_block, prev_block)
if current_block['type'] == BlockType.List and prev_block['type'] == BlockType.List:
__merge_2_list_blocks(current_block, prev_block)
if current_block['type'] == BlockType.Index and prev_block['type'] == BlockType.Index:
__merge_2_list_blocks(current_block, prev_block)
else:
continue

Expand Down
2 changes: 1 addition & 1 deletion magic_pdf/pre_proc/ocr_detect_all_bboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b
all_bboxes = remove_overlaps_min_blocks(all_bboxes)
all_discarded_blocks = remove_overlaps_min_blocks(all_discarded_blocks)
'''将剩余的bbox做分离处理,防止后面分layout时出错'''
# all_bboxes, drop_reasons = remove_overlap_between_bbox_for_block(all_bboxes)
all_bboxes, drop_reasons = remove_overlap_between_bbox_for_block(all_bboxes)

return all_bboxes, all_discarded_blocks

Expand Down
3 changes: 1 addition & 2 deletions magic_pdf/pre_proc/ocr_dict_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ def merge_spans_to_line(spans):
continue

# 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行
if __is_overlaps_y_exceeds_threshold(span['bbox'],
current_line[-1]['bbox']):
if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], 0.6):
current_line.append(span)
else:
# 否则,开始新行
Expand Down

0 comments on commit 702b6ac

Please sign in to comment.