Skip to content

Commit

Permalink
Merge pull request #898 from myhloli/fix-line-over-512
Browse files Browse the repository at this point in the history
feat(model): add xycut algorithm for block sorting
  • Loading branch information
myhloli authored Nov 7, 2024
2 parents 54844a5 + 7d5850e commit 2600d32
Show file tree
Hide file tree
Showing 3 changed files with 306 additions and 21 deletions.
14 changes: 10 additions & 4 deletions magic_pdf/libs/draw_bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,10 +369,16 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
if block['type'] in [BlockType.Image, BlockType.Table]:
for sub_block in block['blocks']:
if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
for line in sub_block['virtual_lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:
for line in sub_block['virtual_lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
else:
for line in sub_block['lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
elif sub_block['type'] in [BlockType.ImageCaption, BlockType.TableCaption, BlockType.ImageFootnote, BlockType.TableFootnote]:
for line in sub_block['lines']:
bbox = line['bbox']
Expand Down
242 changes: 242 additions & 0 deletions magic_pdf/model/v3/xycut.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
from typing import List
import cv2
import numpy as np


def projection_by_bboxes(boxes: np.ndarray, axis: int) -> np.ndarray:
    """Build a per-pixel projection histogram from a set of bounding boxes.

    Args:
        boxes: Integer array of shape [N, 4]. Columns ``axis`` and
            ``axis + 2`` are read as the (start, end) of each box along the
            projection axis.
        axis: 0 projects the x coordinates (columns 0 and 2), 1 projects the
            y coordinates (columns 1 and 3).

    Returns:
        1-D integer array whose value at position ``p`` is the number of boxes
        covering ``p``. Its length is the largest coordinate found on the
        chosen axis rather than the real page size, because only the gaps
        between boxes matter. An empty ``boxes`` yields an empty array.
    """
    assert axis in [0, 1]
    spans = np.asarray(boxes)[:, axis::2]
    if spans.size == 0:
        return np.zeros(0, dtype=int)

    # A negative coordinate used as a slice bound would wrap around to the end
    # of the histogram, so clamp everything to the first valid position.
    spans = np.clip(spans, 0, None)

    length = np.max(spans)
    res = np.zeros(length, dtype=int)
    # Slicing keeps degenerate spans (start >= end) harmless: they add nothing.
    for start, end in spans:
        res[start:end] += 1
    return res


# from: https://dothinking.github.io/2021-06-19-%E9%80%92%E5%BD%92%E6%8A%95%E5%BD%B1%E5%88%86%E5%89%B2%E7%AE%97%E6%B3%95/#:~:text=%E9%80%92%E5%BD%92%E6%8A%95%E5%BD%B1%E5%88%86%E5%89%B2%EF%BC%88Recursive%20XY,%EF%BC%8C%E5%8F%AF%E4%BB%A5%E5%88%92%E5%88%86%E6%AE%B5%E8%90%BD%E3%80%81%E8%A1%8C%E3%80%82
def split_projection_profile(arr_values: np.array, min_value: float, min_gap: float):
    """Split a projection profile into runs separated by wide enough gaps.

    Positions whose profile value is strictly greater than ``min_value`` are
    treated as occupied. Neighbouring occupied positions are placed in the
    same group unless the distance between them is greater than ``min_gap``.

    Example: for the profile ``[0, 1, 1, 0, 0, 0, 1, 1, 1, 0]`` with
    ``min_value=0`` and ``min_gap=1`` the groups are ``[1, 3)`` and ``[6, 9)``,
    returned as ``([1, 6], [3, 9])``.

    Args:
        arr_values (np.array): 1-d array representing the projection profile.
        min_value (float): Positions with a value not above this are ignored.
        min_gap (float): Gaps not larger than this do not split a group.

    Returns:
        tuple: ``(starts, ends)`` arrays of group boundaries, where each end
        is exclusive so it can be used directly as a slice bound. ``None``
        when no position exceeds ``min_value``.
    """
    # Positions where the profile rises above the threshold.
    occupied = np.nonzero(arr_values > min_value)[0]
    if occupied.size == 0:
        return None

    # A jump larger than min_gap between consecutive occupied positions marks
    # an empty interval that separates two groups.
    steps = np.diff(occupied)
    cut_positions = np.nonzero(steps > min_gap)[0]
    gap_starts = occupied[cut_positions]
    gap_ends = occupied[cut_positions + 1]

    # The position closing an empty interval opens the next group, and the
    # position opening an empty interval closes the previous group.
    group_starts = np.concatenate((occupied[:1], gap_ends))
    group_ends = np.concatenate((gap_starts, occupied[-1:])) + 1  # exclusive end

    return group_starts, group_ends


def recursive_xy_cut(boxes: np.ndarray, indices: List[int], res: List[int]):
    """Order boxes by recursively cutting the layout along y, then along x.

    The boxes are first split into horizontal bands using their projection
    onto the y axis. Each band is then split into columns using its projection
    onto the x axis. A band that cannot be split into columns is emitted as-is
    (sorted by left edge); otherwise every column is processed recursively.

    Args:
        boxes: Integer array of shape (N, 4), read as [x0, y0, x1, y1].
            Negative coordinates are clamped to 0 before processing, because
            they would otherwise fall outside every band and the box would be
            silently left out of ``res``.
        indices: Array-like of length N. Throughout the recursion it holds the
            index of each box in the original, top-level input.
        res: Output list. Original indices are appended to it in the computed
            order; the function itself returns ``None``.

    Note:
        A box with zero extent adds nothing to a projection, so it may still be
        omitted from ``res`` if its start lies outside every detected band.
    """
    boxes = np.asarray(boxes)
    # `indices` is fancy-indexed below, so a plain list must be converted.
    indices = np.asarray(indices)
    assert len(boxes) == len(indices)
    if len(boxes) == 0:
        return

    # Work on a clamped copy; the caller's array is left untouched.
    boxes = np.clip(boxes, 0, None)

    # Project onto the y axis.
    y_order = boxes[:, 1].argsort()
    y_sorted_boxes = boxes[y_order]
    y_sorted_indices = indices[y_order]

    y_projection = projection_by_bboxes(boxes=y_sorted_boxes, axis=1)
    pos_y = split_projection_profile(y_projection, 0, 1)
    if pos_y is None:
        return

    arr_y0, arr_y1 = pos_y
    for r0, r1 in zip(arr_y0, arr_y1):
        # [r0, r1) is a horizontal band that contains boxes; each band is
        # further cut vertically.
        in_band = (r0 <= y_sorted_boxes[:, 1]) & (y_sorted_boxes[:, 1] < r1)
        y_sorted_boxes_chunk = y_sorted_boxes[in_band]
        y_sorted_indices_chunk = y_sorted_indices[in_band]

        x_order = y_sorted_boxes_chunk[:, 0].argsort()
        x_sorted_boxes_chunk = y_sorted_boxes_chunk[x_order]
        x_sorted_indices_chunk = y_sorted_indices_chunk[x_order]

        # Project onto the x axis.
        x_projection = projection_by_bboxes(boxes=x_sorted_boxes_chunk, axis=0)
        pos_x = split_projection_profile(x_projection, 0, 1)
        if pos_x is None:
            continue

        arr_x0, arr_x1 = pos_x
        if len(arr_x0) == 1:
            # The band cannot be split along x: emit it left to right.
            res.extend(x_sorted_indices_chunk)
            continue

        # The band splits into several columns: recurse into each of them.
        for c0, c1 in zip(arr_x0, arr_x1):
            in_column = (c0 <= x_sorted_boxes_chunk[:, 0]) & (
                x_sorted_boxes_chunk[:, 0] < c1
            )
            recursive_xy_cut(
                x_sorted_boxes_chunk[in_column],
                x_sorted_indices_chunk[in_column],
                res,
            )


def points_to_bbox(points):
    """Return the axis-aligned bounding box of a four-corner polygon.

    Args:
        points: Flat sequence of 8 values, [x1, y1, x2, y2, x3, y3, x4, y4].

    Returns:
        ``[left, top, right, bottom]`` with every value clamped to be >= 0.
    """
    assert len(points) == 8

    xs = points[::2]   # x1, x2, x3, x4
    ys = points[1::2]  # y1, y2, y3, y4
    return [
        max(min(xs), 0),
        max(min(ys), 0),
        max(max(xs), 0),
        max(max(ys), 0),
    ]


def bbox2points(bbox):
    """Expand ``[left, top, right, bottom]`` into a flat list of four corners.

    The corners are listed as top-left, top-right, bottom-right, bottom-left,
    giving ``[x1, y1, x2, y2, x3, y3, x4, y4]``.
    """
    x_min, y_min, x_max, y_max = bbox
    top_left = [x_min, y_min]
    top_right = [x_max, y_min]
    bottom_right = [x_max, y_max]
    bottom_left = [x_min, y_max]
    return [*top_left, *top_right, *bottom_right, *bottom_left]


def vis_polygon(img, points, thickness=2, color=None):
    """Draw the closed outline of a four-corner polygon with ``cv2.line``.

    Args:
        img: Image passed to ``cv2.line`` as the drawing target.
        points: Indexable of at least four (x, y) pairs; only the first four
            are used, joined as 0->1, 1->2, 2->3 and 3->0.
        thickness: Line thickness forwarded to ``cv2.line``.
        color: Line colour forwarded to ``cv2.line``; the same colour is used
            for all four edges.

    Returns:
        The ``img`` object that was passed in.
    """
    corner_count = 4
    for corner in range(corner_count):
        start = points[corner]
        # Wrap around so the last edge closes the outline back to corner 0.
        end = points[(corner + 1) % corner_count]
        cv2.line(
            img,
            (start[0], start[1]),
            (end[0], end[1]),
            color=color,
            thickness=thickness,
        )
    return img


def vis_points(
    img: np.ndarray, points, texts: List[str] = None, color=(0, 200, 0)
) -> np.ndarray:
    """Draw polygons on an image, optionally with a text label at each centre.

    Args:
        img: Image to draw on.
        points: [N, 8] polygons, each row being x1,y1,x2,y2,x3,y3,x4,y4.
        texts: Optional labels, one per polygon. When ``None`` only the
            outlines are drawn.
        color: Colour used for the outlines and for the filled label
            background.

    Returns:
        The image with the drawings applied.
    """
    points = np.array(points)
    if texts is not None:
        assert len(texts) == points.shape[0]

    font = cv2.FONT_HERSHEY_SIMPLEX
    for i, _points in enumerate(points):
        vis_polygon(img, _points.reshape(-1, 2), thickness=2, color=color)

        # Labels are optional: without this guard the default `texts=None`
        # would fail on `texts[i]`.
        if texts is None:
            continue

        left, top, right, bottom = points_to_bbox(_points)
        cx = (left + right) // 2
        cy = (top + bottom) // 2

        txt = texts[i]
        cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0]
        # Shift left by roughly half the label width (5 px per character) so
        # the text sits near the polygon centre.
        text_x = cx - 5 * len(txt)

        # Filled rectangle behind the text so the white label stays readable.
        img = cv2.rectangle(
            img,
            (text_x, cy - cat_size[1] - 5),
            (text_x + cat_size[0], cy - 5),
            color,
            -1,
        )

        img = cv2.putText(
            img,
            txt,
            (text_x, cy - 5),
            font,
            0.5,
            (255, 255, 255),
            thickness=1,
            lineType=cv2.LINE_AA,
        )

    return img


def vis_polygons_with_index(image, points):
    """Draw every polygon on a copy of ``image``, labelled with its position.

    Each polygon is labelled with its index in ``points`` ("0", "1", ...).
    The drawing is delegated to ``vis_points`` and performed on
    ``image.copy()``.
    """
    labels = list(map(str, range(len(points))))
    return vis_points(image.copy(), points, labels)
71 changes: 54 additions & 17 deletions magic_pdf/pdf_parse_union_core_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
from magic_pdf.pre_proc.ocr_detect_all_bboxes import \
ocr_prepare_bboxes_for_layout_split_v2
from magic_pdf.pre_proc.ocr_dict_merge import (fill_spans_in_blocks,
fix_block_spans,
fix_discarded_block, fix_block_spans_v2)
fix_discarded_block,
fix_block_spans_v2)
from magic_pdf.pre_proc.ocr_span_list_modify import (
get_qa_need_list_v2, remove_overlaps_low_confidence_spans,
remove_overlaps_min_spans)
Expand Down Expand Up @@ -174,23 +174,57 @@ def do_predict(boxes: List[List[int]], model) -> List[int]:


def cal_block_index(fix_blocks, sorted_bboxes):
for block in fix_blocks:

line_index_list = []
if len(block['lines']) == 0:
block['index'] = sorted_bboxes.index(block['bbox'])
else:
if sorted_bboxes is not None:
# 使用layoutreader排序
for block in fix_blocks:
line_index_list = []
if len(block['lines']) == 0:
block['index'] = sorted_bboxes.index(block['bbox'])
else:
for line in block['lines']:
line['index'] = sorted_bboxes.index(line['bbox'])
line_index_list.append(line['index'])
median_value = statistics.median(line_index_list)
block['index'] = median_value

# 删除图表body block中的虚拟line信息, 并用real_lines信息回填
if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
block['virtual_lines'] = copy.deepcopy(block['lines'])
block['lines'] = copy.deepcopy(block['real_lines'])
del block['real_lines']
else:
# 使用xycut排序
block_bboxes = []
for block in fix_blocks:
block_bboxes.append(block['bbox'])

# 删除图表body block中的虚拟line信息, 并用real_lines信息回填
if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
block['virtual_lines'] = copy.deepcopy(block['lines'])
block['lines'] = copy.deepcopy(block['real_lines'])
del block['real_lines']

import numpy as np
from magic_pdf.model.v3.xycut import recursive_xy_cut

random_boxes = np.array(block_bboxes)
np.random.shuffle(random_boxes)
res = []
recursive_xy_cut(np.asarray(random_boxes).astype(int), np.arange(len(block_bboxes)), res)
assert len(res) == len(block_bboxes)
sorted_boxes = random_boxes[np.array(res)].tolist()

for i, block in enumerate(fix_blocks):
block['index'] = sorted_boxes.index(block['bbox'])

# 生成line index
sorted_blocks = sorted(fix_blocks, key=lambda b: b['index'])
line_inedx = 1
for block in sorted_blocks:
for line in block['lines']:
line['index'] = sorted_bboxes.index(line['bbox'])
line_index_list.append(line['index'])
median_value = statistics.median(line_index_list)
block['index'] = median_value

# 删除图表body block中的虚拟line信息, 并用real_lines信息回填
if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
block['virtual_lines'] = copy.deepcopy(block['lines'])
block['lines'] = copy.deepcopy(block['real_lines'])
del block['real_lines']
line['index'] = line_inedx
line_inedx += 1

return fix_blocks

Expand Down Expand Up @@ -264,6 +298,9 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
block['lines'].append({'bbox': line, 'spans': []})
page_line_list.extend(lines)

if len(page_line_list) > 512: # layoutreader最高支持512line
return None

# 使用layoutreader排序
x_scale = 1000.0 / page_w
y_scale = 1000.0 / page_h
Expand Down

0 comments on commit 2600d32

Please sign in to comment.