Skip to content

Commit

Permalink
Improve partitioning quality by adding special case to `_supplement_t…
Browse files Browse the repository at this point in the history
…ext` for `List-item`s (#1214)
  • Loading branch information
MarkLindblad authored Mar 7, 2025
1 parent 4a9e52e commit 07c7ac3
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions lib/sycamore/sycamore/transforms/detr_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,18 @@ def _init_model(self):

@staticmethod
def _supplement_text(inferred: list[Element], text: list[Element], threshold: float = 0.5) -> list[Element]:
# We first check IOU between inferred object and pdf miner text object, we also check if a detected object
# fully contains a pdf miner text object. After that, we combined all texts belonging a detected object and
# update its text representation. We allow multiple detected objects contain the same text, we hold on solving
# this.
"""
Associates extracted text with inferred objects. Meant to be called pagewise. Uses complete containment (the
text's bbox is fully within the inferred object's bbox), IOU (intersection over union), and IOB (intersection
over bounding box) to determine if a text object is associated with an inferred object. We allow multiple
detected objects to contain the same text, we are holding on solving this.
Once all text that can be associated has been, the text representation of the inferred object is updated to
incorporate its associated text.
In order to handle list items properly, we treat them as a special case.
"""
logger.info("running _supplement_text")

unmatched = text.copy()
for index_i, i in enumerate(inferred):
Expand All @@ -122,10 +130,16 @@ def _supplement_text(inferred: list[Element], text: list[Element], threshold: fl
matches = []
full_text = []
font_sizes = []
for m in matched:
is_list_item = i.type == "List-item"
num_matched = len(matched)
for m_index, m in enumerate(matched):
matches.append(m)
if m.text_representation:
full_text.append(m.text_representation)
if text_to_add := m.text_representation:
if (
is_list_item and m_index + 1 < num_matched and text_to_add[-1] == "\n"
): # special case for list items
text_to_add = text_to_add[:-1]
full_text.append(text_to_add)
if font_size := m.properties.get("font_size"):
font_sizes.append(font_size)
if isinstance(i, TableElement):
Expand Down

0 comments on commit 07c7ac3

Please sign in to comment.