Skip to content

Commit

Permalink
Merge pull request #672 from myhloli/add-layoutreader
Browse files Browse the repository at this point in the history
feat:add layoutreader to sort blocks
  • Loading branch information
myhloli authored Sep 30, 2024
2 parents 9c38c88 + fcf2424 commit bcbee13
Show file tree
Hide file tree
Showing 11 changed files with 742 additions and 42 deletions.
10 changes: 10 additions & 0 deletions magic_pdf/libs/clean_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) Opendatalab. All rights reserved.
import torch
import gc


def clean_memory():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
gc.collect()
124 changes: 88 additions & 36 deletions magic_pdf/libs/draw_bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def draw_bbox_without_number(i, bbox_list, page, rgb_config, fill_config):
) # Draw the rectangle


def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config):
def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config, draw_bbox=True):
new_rgb = []
for item in rgb_config:
item = float(item) / 255
Expand All @@ -42,31 +42,31 @@ def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config):
for j, bbox in enumerate(page_data):
x0, y0, x1, y1 = bbox
rect_coords = fitz.Rect(x0, y0, x1, y1) # Define the rectangle
if fill_config:
page.draw_rect(
rect_coords,
color=None,
fill=new_rgb,
fill_opacity=0.3,
width=0.5,
overlay=True,
) # Draw the rectangle
else:
page.draw_rect(
rect_coords,
color=new_rgb,
fill=None,
fill_opacity=1,
width=0.5,
overlay=True,
) # Draw the rectangle
if draw_bbox:
if fill_config:
page.draw_rect(
rect_coords,
color=None,
fill=new_rgb,
fill_opacity=0.3,
width=0.5,
overlay=True,
) # Draw the rectangle
else:
page.draw_rect(
rect_coords,
color=new_rgb,
fill=None,
fill_opacity=1,
width=0.5,
overlay=True,
) # Draw the rectangle
page.insert_text(
(x0, y0 + 10), str(j + 1), fontsize=10, color=new_rgb
(x1+2, y0 + 10), str(j + 1), fontsize=10, color=new_rgb
) # Insert the index in the top left corner of the rectangle


def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
layout_bbox_list = []
dropped_bbox_list = []
tables_list, tables_body_list = [], []
tables_caption_list, tables_footnote_list = [], []
Expand All @@ -76,16 +76,14 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
texts_list = []
interequations_list = []
for page in pdf_info:
page_layout_list = []

page_dropped_list = []
tables, tables_body, tables_caption, tables_footnote = [], [], [], []
imgs, imgs_body, imgs_caption, imgs_footnote = [], [], [], []
titles = []
texts = []
interequations = []
for layout in page['layout_bboxes']:
page_layout_list.append(layout['layout_bbox'])
layout_bbox_list.append(page_layout_list)

for dropped_bbox in page['discarded_blocks']:
page_dropped_list.append(dropped_bbox['bbox'])
dropped_bbox_list.append(page_dropped_list)
Expand Down Expand Up @@ -129,9 +127,19 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
texts_list.append(texts)
interequations_list.append(interequations)

layout_bbox_list = []

for page in pdf_info:
page_block_list = []
for block in page['para_blocks']:
bbox = block['bbox']
page_block_list.append(bbox)
layout_bbox_list.append(page_block_list)

pdf_docs = fitz.open('pdf', pdf_bytes)

for i, page in enumerate(pdf_docs):
draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)

draw_bbox_without_number(i, dropped_bbox_list, page, [158, 158, 158],
True)
draw_bbox_without_number(i, tables_list, page, [153, 153, 0],
Expand All @@ -146,13 +154,15 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
draw_bbox_without_number(i, imgs_body_list, page, [153, 255, 51], True)
draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255],
True)
draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102],
draw_bbox_without_number(i, imgs_footnote_list, page, [255, 178, 102],
True),
draw_bbox_without_number(i, titles_list, page, [102, 102, 255], True)
draw_bbox_without_number(i, texts_list, page, [153, 0, 76], True)
draw_bbox_without_number(i, interequations_list, page, [0, 255, 0],
True)

draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False, draw_bbox=False)

# Save the PDF
pdf_docs.save(f'{out_path}/{filename}_layout.pdf')

Expand Down Expand Up @@ -211,9 +221,9 @@ def get_span_info(span):
# 构造其余useful_list
for block in page['para_blocks']:
if block['type'] in [
BlockType.Text,
BlockType.Title,
BlockType.InterlineEquation,
BlockType.Text,
BlockType.Title,
BlockType.InterlineEquation,
]:
for line in block['lines']:
for span in line['spans']:
Expand All @@ -232,10 +242,8 @@ def get_span_info(span):
for i, page in enumerate(pdf_docs):
# 获取当前页面的数据
draw_bbox_without_number(i, text_list, page, [255, 0, 0], False)
draw_bbox_without_number(i, inline_equation_list, page, [0, 255, 0],
False)
draw_bbox_without_number(i, interline_equation_list, page, [0, 0, 255],
False)
draw_bbox_without_number(i, inline_equation_list, page, [0, 255, 0], False)
draw_bbox_without_number(i, interline_equation_list, page, [0, 0, 255], False)
draw_bbox_without_number(i, image_list, page, [255, 204, 0], False)
draw_bbox_without_number(i, table_list, page, [204, 0, 255], False)
draw_bbox_without_number(i, dropped_list, page, [158, 158, 158], False)
Expand All @@ -244,7 +252,7 @@ def get_span_info(span):
pdf_docs.save(f'{out_path}/{filename}_spans.pdf')


def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
dropped_bbox_list = []
tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], []
Expand Down Expand Up @@ -279,7 +287,7 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
elif layout_det['category_id'] == CategoryId.ImageCaption:
imgs_caption.append(bbox)
elif layout_det[
'category_id'] == CategoryId.InterlineEquation_YOLO:
'category_id'] == CategoryId.InterlineEquation_YOLO:
interequations.append(bbox)
elif layout_det['category_id'] == CategoryId.Abandon:
page_dropped_list.append(bbox)
Expand Down Expand Up @@ -316,3 +324,47 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):

# Save the PDF
pdf_docs.save(f'{out_path}/{filename}_model.pdf')


def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
layout_bbox_list = []

for page in pdf_info:
page_line_list = []
for block in page['preproc_blocks']:
if block['type'] in ['text', 'title', 'interline_equation']:
for line in block['lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
if block['type'] in ['table', 'image']:
bbox = block['bbox']
index = block['index']
page_line_list.append({'index': index, 'bbox': bbox})
# for line in block['lines']:
# bbox = line['bbox']
# index = line['index']
# page_line_list.append({'index': index, 'bbox': bbox})
sorted_bboxes = sorted(page_line_list, key=lambda x: x['index'])
layout_bbox_list.append(sorted_bbox['bbox'] for sorted_bbox in sorted_bboxes)
pdf_docs = fitz.open('pdf', pdf_bytes)
for i, page in enumerate(pdf_docs):
draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)

pdf_docs.save(f'{out_path}/{filename}_line_sort.pdf')


def draw_layout_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
layout_bbox_list = []

for page in pdf_info:
page_block_list = []
for block in page['para_blocks']:
bbox = block['bbox']
page_block_list.append(bbox)
layout_bbox_list.append(page_block_list)
pdf_docs = fitz.open('pdf', pdf_bytes)
for i, page in enumerate(pdf_docs):
draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)

pdf_docs.save(f'{out_path}/{filename}_layout_sort.pdf')
3 changes: 3 additions & 0 deletions magic_pdf/model/pdf_extract_kit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time

from magic_pdf.libs.Constants import *
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.model.model_list import AtomicModel

os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
Expand Down Expand Up @@ -330,6 +331,8 @@ def __call__(self, image):
elif int(res['category_id']) in [5]:
table_res_list.append(res)

clean_memory()

# ocr识别
if self.apply_ocr:
ocr_start = time.time()
Expand Down
Empty file added magic_pdf/model/v3/__init__.py
Empty file.
125 changes: 125 additions & 0 deletions magic_pdf/model/v3/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from collections import defaultdict
from typing import List, Dict

import torch
from transformers import LayoutLMv3ForTokenClassification

MAX_LEN = 510
CLS_TOKEN_ID = 0
UNK_TOKEN_ID = 3
EOS_TOKEN_ID = 2


class DataCollator:
def __call__(self, features: List[dict]) -> Dict[str, torch.Tensor]:
bbox = []
labels = []
input_ids = []
attention_mask = []

# clip bbox and labels to max length, build input_ids and attention_mask
for feature in features:
_bbox = feature["source_boxes"]
if len(_bbox) > MAX_LEN:
_bbox = _bbox[:MAX_LEN]
_labels = feature["target_index"]
if len(_labels) > MAX_LEN:
_labels = _labels[:MAX_LEN]
_input_ids = [UNK_TOKEN_ID] * len(_bbox)
_attention_mask = [1] * len(_bbox)
assert len(_bbox) == len(_labels) == len(_input_ids) == len(_attention_mask)
bbox.append(_bbox)
labels.append(_labels)
input_ids.append(_input_ids)
attention_mask.append(_attention_mask)

# add CLS and EOS tokens
for i in range(len(bbox)):
bbox[i] = [[0, 0, 0, 0]] + bbox[i] + [[0, 0, 0, 0]]
labels[i] = [-100] + labels[i] + [-100]
input_ids[i] = [CLS_TOKEN_ID] + input_ids[i] + [EOS_TOKEN_ID]
attention_mask[i] = [1] + attention_mask[i] + [1]

# padding to max length
max_len = max(len(x) for x in bbox)
for i in range(len(bbox)):
bbox[i] = bbox[i] + [[0, 0, 0, 0]] * (max_len - len(bbox[i]))
labels[i] = labels[i] + [-100] * (max_len - len(labels[i]))
input_ids[i] = input_ids[i] + [EOS_TOKEN_ID] * (max_len - len(input_ids[i]))
attention_mask[i] = attention_mask[i] + [0] * (
max_len - len(attention_mask[i])
)

ret = {
"bbox": torch.tensor(bbox),
"attention_mask": torch.tensor(attention_mask),
"labels": torch.tensor(labels),
"input_ids": torch.tensor(input_ids),
}
# set label > MAX_LEN to -100, because original labels may be > MAX_LEN
ret["labels"][ret["labels"] > MAX_LEN] = -100
# set label > 0 to label-1, because original labels are 1-indexed
ret["labels"][ret["labels"] > 0] -= 1
return ret


def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
attention_mask = [1] + [1] * len(boxes) + [1]
return {
"bbox": torch.tensor([bbox]),
"attention_mask": torch.tensor([attention_mask]),
"input_ids": torch.tensor([input_ids]),
}


def prepare_inputs(
inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification
) -> Dict[str, torch.Tensor]:
ret = {}
for k, v in inputs.items():
v = v.to(model.device)
if torch.is_floating_point(v):
v = v.to(model.dtype)
ret[k] = v
return ret


def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
"""
parse logits to orders
:param logits: logits from model
:param length: input length
:return: orders
"""
logits = logits[1 : length + 1, :length]
orders = logits.argsort(descending=False).tolist()
ret = [o.pop() for o in orders]
while True:
order_to_idxes = defaultdict(list)
for idx, order in enumerate(ret):
order_to_idxes[order].append(idx)
# filter idxes len > 1
order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
if not order_to_idxes:
break
# filter
for order, idxes in order_to_idxes.items():
# find original logits of idxes
idxes_to_logit = {}
for idx in idxes:
idxes_to_logit[idx] = logits[idx, order]
idxes_to_logit = sorted(
idxes_to_logit.items(), key=lambda x: x[1], reverse=True
)
# keep the highest logit as order, set others to next candidate
for idx, _ in idxes_to_logit[1:]:
ret[idx] = orders[idx].pop()

return ret


def check_duplicate(a: List[int]) -> bool:
return len(a) != len(set(a))
2 changes: 1 addition & 1 deletion magic_pdf/pdf_parse_by_ocr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from magic_pdf.pdf_parse_union_core import pdf_parse_union
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union


def parse_pdf_by_ocr(pdf_bytes,
Expand Down
2 changes: 1 addition & 1 deletion magic_pdf/pdf_parse_by_txt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from magic_pdf.pdf_parse_union_core import pdf_parse_union
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union


def parse_pdf_by_txt(
Expand Down
Loading

0 comments on commit bcbee13

Please sign in to comment.