-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Matteo Omenetti <[email protected]>
- Loading branch information
1 parent
d5b2c07
commit 6206687
Showing 8 changed files with 398 additions and 44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,274 @@ | ||
import re | ||
from pathlib import Path | ||
from typing import Iterable, List, Literal, Optional, Tuple | ||
|
||
from docling_core.types.doc import CodeItem, DoclingDocument, NodeItem, TextItem | ||
from docling_core.types.doc.base import BoundingBox | ||
from docling_core.types.doc.labels import CodeLanguageLabel, DocItemLabel | ||
from PIL import Image | ||
from pydantic import BaseModel | ||
|
||
from docling.datamodel.base_models import ItemAndImageEnrichmentElement | ||
from docling.datamodel.document import ConversionResult | ||
from docling.datamodel.pipeline_options import AcceleratorOptions | ||
from docling.models.base_model import BaseItemAndImageEnrichmentModel | ||
from docling.utils.accelerator_utils import decide_device | ||
|
||
|
||
class CodeFormulaModelOptions(BaseModel):
    """
    Settings that control what the CodeFormulaModel enriches.

    Attributes
    ----------
    kind : Literal["code_formula"]
        Type of the model. The only accepted value is "code_formula",
        which is also the default.
    do_code_enrichment : bool
        Whether code enrichment is enabled. Defaults to True.
    do_formula_enrichment : bool
        Whether formula enrichment is enabled. Defaults to True.
    """

    # Fixed tag naming this options type.
    kind: Literal["code_formula"] = "code_formula"
    # Independent switches, consulted by CodeFormulaModel.is_processable.
    do_code_enrichment: bool = True
    do_formula_enrichment: bool = True
|
||
|
||
class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
    """
    Model for processing and enriching documents with code and formula predictions.

    Attributes
    ----------
    enabled : bool
        True if the model is enabled, False otherwise.
    options : CodeFormulaModelOptions
        Configuration options for the CodeFormulaModel.
    code_formula_model : CodeFormulaPredictor
        The predictor model for code and formula processing. Only assigned
        when the model is enabled.

    Methods
    -------
    __init__(self, enabled, artifacts_path, options, accelerator_options)
        Initializes the CodeFormulaModel with the given configuration options.
    download_models_hf(local_dir, force)
        Downloads the model artifacts from the Hugging Face Hub.
    is_processable(self, doc, element)
        Determines if a given element in a document can be processed by the model.
    prepare_element(self, conv_res, element)
        Crops the page image around an element so it can be fed to the predictor.
    __call__(self, doc, element_batch)
        Processes the given batch of elements and enriches them with predictions.
    """

    images_scale = 1.66  # = 120 dpi, aligned with training data resolution

    # Matches a leading "<_language_>" tag:
    #   ^<_([^>]+)_> : "<_something_>" at the start, capturing "something" (group 1)
    #   \s*          : optional whitespace after the tag
    #   (.*)         : everything that follows (group 2)
    # re.DOTALL lets group 2 span multiple lines.
    _CODE_LANGUAGE_PATTERN = re.compile(r"^<_([^>]+)_>\s*(.*)", flags=re.DOTALL)

    def __init__(
        self,
        enabled: bool,
        artifacts_path: Optional[Path],
        options: CodeFormulaModelOptions,
        accelerator_options: AcceleratorOptions,
    ):
        """
        Initializes the CodeFormulaModel with the given configuration.

        Parameters
        ----------
        enabled : bool
            True if the model is enabled, False otherwise.
        artifacts_path : Optional[Path]
            Path to the directory containing the model artifacts. When None
            (and the model is enabled), the artifacts are fetched with
            ``download_models_hf``.
        options : CodeFormulaModelOptions
            Configuration options for the model.
        accelerator_options : AcceleratorOptions
            Options specifying the device and number of threads for acceleration.
        """
        self.enabled = enabled
        self.options = options

        if self.enabled:
            device = decide_device(accelerator_options.device)

            # Imported here so the predictor package is only loaded when the
            # model is actually enabled.
            from docling_ibm_models.code_formula_model.code_formula_predictor import (
                CodeFormulaPredictor,
            )

            if artifacts_path is None:
                artifacts_path = self.download_models_hf()

            self.code_formula_model = CodeFormulaPredictor(
                artifacts_path=artifacts_path,
                device=device,
                num_threads=accelerator_options.num_threads,
            )

    @staticmethod
    def download_models_hf(
        local_dir: Optional[Path] = None, force: bool = False
    ) -> Path:
        """
        Downloads the model artifacts from the Hugging Face Hub.

        Fetches revision "v1.0.0" of the "ds4sd/CodeFormula" repository.
        Hugging Face progress bars are disabled before the download starts.

        Parameters
        ----------
        local_dir : Optional[Path]
            Forwarded to ``snapshot_download`` as ``local_dir``.
        force : bool
            Forwarded to ``snapshot_download`` as ``force_download``.

        Returns
        -------
        Path
            The path returned by ``snapshot_download``.
        """
        from huggingface_hub import snapshot_download
        from huggingface_hub.utils import disable_progress_bars

        disable_progress_bars()
        download_path = snapshot_download(
            repo_id="ds4sd/CodeFormula",
            force_download=force,
            local_dir=local_dir,
            revision="v1.0.0",
        )

        return Path(download_path)

    def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
        """
        Determines if a given element in a document can be processed by the model.

        An element qualifies when the model is enabled and the element is
        either a ``CodeItem`` (with code enrichment switched on) or a
        ``TextItem`` labelled as a formula (with formula enrichment switched on).

        Parameters
        ----------
        doc : DoclingDocument
            The document being processed.
        element : NodeItem
            The element within the document to check.

        Returns
        -------
        bool
            True if the element can be processed, False otherwise.
        """
        return self.enabled and (
            (isinstance(element, CodeItem) and self.options.do_code_enrichment)
            or (
                isinstance(element, TextItem)
                and element.label == DocItemLabel.FORMULA
                and self.options.do_formula_enrichment
            )
        )

    def _extract_code_language(self, input_string: str) -> Tuple[str, Optional[str]]:
        """
        Extracts a programming language tag from the start of a string.

        Checks whether the (possibly multi-line) input starts with a tag of
        the form ``<_some_language_>``. If so, the tag and any whitespace
        directly after it are removed.

        Parameters
        ----------
        input_string : str
            The input string, which may start with ``<_language_>``.

        Returns
        -------
        Tuple[str, Optional[str]]
            ``(remainder, language)`` when a tag is found, where ``remainder``
            is everything after the tag; otherwise ``(input_string, None)``.
        """
        match = self._CODE_LANGUAGE_PATTERN.match(input_string)
        if match is None:
            return input_string, None

        language, remainder = match.groups()
        return remainder, language

    def _get_code_language_enum(self, value: Optional[str]) -> CodeLanguageLabel:
        """
        Converts a string to a corresponding `CodeLanguageLabel` enum member.

        Parameters
        ----------
        value : Optional[str]
            The string representation of the code language, or None.

        Returns
        -------
        CodeLanguageLabel
            The matching enum member, or `CodeLanguageLabel.UNKNOWN` when
            `value` is not a string or does not match any member.
        """
        if not isinstance(value, str):
            return CodeLanguageLabel.UNKNOWN

        try:
            return CodeLanguageLabel(value)
        except ValueError:
            return CodeLanguageLabel.UNKNOWN

    def prepare_element(
        self, conv_res: ConversionResult, element: NodeItem
    ) -> Optional[ItemAndImageEnrichmentElement]:
        """
        Crops the page image around an element so it can be fed to the predictor.

        The element's first provenance bounding box is enlarged by 3% of its
        width and height on every side before cropping.

        Parameters
        ----------
        conv_res : ConversionResult
            The conversion result holding the document and its pages.
        element : NodeItem
            The candidate element.

        Returns
        -------
        Optional[ItemAndImageEnrichmentElement]
            The element paired with its cropped image, or None when the
            element is not processable or carries no provenance.
        """
        if not self.is_processable(doc=conv_res.document, element=element):
            return None

        assert isinstance(element, TextItem)

        # Without provenance there is no page or bounding box to crop from,
        # so skip the element instead of failing on prov[0].
        if not element.prov:
            return None
        element_prov = element.prov[0]

        expansion_factor = 0.03  # Adjust the expansion percentage as needed
        bbox = element_prov.bbox
        width = bbox.r - bbox.l
        # NOTE(review): presumably t < b for a top-left origin, making height
        # negative so the same arithmetic still grows the box - confirm.
        height = bbox.t - bbox.b

        # Create the expanded bounding box
        expanded_bbox = BoundingBox(
            l=bbox.l - width * expansion_factor,  # Expand left
            t=bbox.t + height * expansion_factor,  # Expand top
            r=bbox.r + width * expansion_factor,  # Expand right
            b=bbox.b - height * expansion_factor,  # Expand bottom
            coord_origin=bbox.coord_origin,  # Preserve coordinate origin
        )

        # page_no is shifted by one to index into conv_res.pages.
        page_ix = element_prov.page_no - 1
        cropped_image = conv_res.pages[page_ix].get_image(
            scale=self.images_scale, cropbox=expanded_bbox
        )
        return ItemAndImageEnrichmentElement(item=element, image=cropped_image)

    def __call__(
        self,
        doc: DoclingDocument,
        element_batch: Iterable[ItemAndImageEnrichmentElement],
    ) -> Iterable[NodeItem]:
        """
        Processes the given batch of elements and enriches them with predictions.

        When the model is disabled, the items are yielded unchanged. Otherwise
        each item's ``text`` is replaced by the predictor output; for code
        items a leading ``<_language_>`` tag is stripped from the output and
        stored as ``code_language``.

        Parameters
        ----------
        doc : DoclingDocument
            The document being processed.
        element_batch : Iterable[ItemAndImageEnrichmentElement]
            A batch of items paired with their cropped images.

        Returns
        -------
        Iterable[NodeItem]
            An iterable of the (enriched) items.
        """
        if not self.enabled:
            for element in element_batch:
                yield element.item
            return

        labels: List[str] = []
        images: List[Image.Image] = []
        elements: List[TextItem] = []
        for el in element_batch:
            assert isinstance(el.item, TextItem)
            elements.append(el.item)
            labels.append(el.item.label)
            images.append(el.image)

        # Nothing to predict: do not invoke the model on an empty batch.
        if not elements:
            return

        outputs = self.code_formula_model.predict(images, labels)

        for item, output in zip(elements, outputs):
            if isinstance(item, CodeItem):
                output, code_language = self._extract_code_language(output)
                item.code_language = self._get_code_language_enum(code_language)
            item.text = output

            yield item
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.