Skip to content

Commit

Permalink
added if statement for backend
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Omenetti <[email protected]>
  • Loading branch information
Matteo-Omenetti committed Jan 23, 2025
1 parent d5b2c07 commit 6206687
Show file tree
Hide file tree
Showing 8 changed files with 398 additions and 44 deletions.
14 changes: 5 additions & 9 deletions docling/datamodel/pipeline_options.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
import logging
import os
import warnings
from enum import Enum
from pathlib import Path
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union
from typing import Any, List, Literal, Optional, Union

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
)
from typing_extensions import deprecated
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

_log = logging.getLogger(__name__)

Expand Down Expand Up @@ -225,6 +219,8 @@ class PdfPipelineOptions(PipelineOptions):
artifacts_path: Optional[Union[Path, str]] = None
do_table_structure: bool = True # True: perform table structure extraction
do_ocr: bool = True # True: perform OCR, replace programmatic PDF text
do_code_enrichment: bool = False # True: perform code OCR
do_formula_enrichment: bool = False # True: perform formula OCR, return Latex code

table_structure_options: TableStructureOptions = TableStructureOptions()
ocr_options: Union[
Expand Down
274 changes: 274 additions & 0 deletions docling/models/code_formula_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
import re
from pathlib import Path
from typing import Iterable, List, Literal, Optional, Tuple

from docling_core.types.doc import CodeItem, DoclingDocument, NodeItem, TextItem
from docling_core.types.doc.base import BoundingBox
from docling_core.types.doc.labels import CodeLanguageLabel, DocItemLabel
from PIL import Image
from pydantic import BaseModel

from docling.datamodel.base_models import ItemAndImageEnrichmentElement
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import AcceleratorOptions
from docling.models.base_model import BaseItemAndImageEnrichmentModel
from docling.utils.accelerator_utils import decide_device


class CodeFormulaModelOptions(BaseModel):
"""
Configuration options for the CodeFormulaModel.
Attributes
----------
kind : str
Type of the model. Fixed value "code_formula".
do_code_enrichment : bool
True if code enrichment is enabled, False otherwise.
do_formula_enrichment : bool
True if formula enrichment is enabled, False otherwise.
"""

kind: Literal["code_formula"] = "code_formula"
do_code_enrichment: bool = True
do_formula_enrichment: bool = True


class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
"""
Model for processing and enriching documents with code and formula predictions.
Attributes
----------
enabled : bool
True if the model is enabled, False otherwise.
options : CodeFormulaModelOptions
Configuration options for the CodeFormulaModel.
code_formula_model : CodeFormulaPredictor
The predictor model for code and formula processing.
Methods
-------
__init__(self, enabled, artifacts_path, accelerator_options, code_formula_options)
Initializes the CodeFormulaModel with the given configuration options.
is_processable(self, doc, element)
Determines if a given element in a document can be processed by the model.
__call__(self, doc, element_batch)
Processes the given batch of elements and enriches them with predictions.
"""

images_scale = 1.66 # = 120 dpi, aligned with training data resolution

def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
options: CodeFormulaModelOptions,
accelerator_options: AcceleratorOptions,
):
"""
Initializes the CodeFormulaModel with the given configuration.
Parameters
----------
enabled : bool
True if the model is enabled, False otherwise.
artifacts_path : Path
Path to the directory containing the model artifacts.
options : CodeFormulaModelOptions
Configuration options for the model.
accelerator_options : AcceleratorOptions
Options specifying the device and number of threads for acceleration.
"""
self.enabled = enabled
self.options = options

if self.enabled:
device = decide_device(accelerator_options.device)

from docling_ibm_models.code_formula_model.code_formula_predictor import (
CodeFormulaPredictor,
)

if artifacts_path is None:
artifacts_path = self.download_models_hf()

self.code_formula_model = CodeFormulaPredictor(
artifacts_path=artifacts_path,
device=device,
num_threads=accelerator_options.num_threads,
)

@staticmethod
def download_models_hf(
local_dir: Optional[Path] = None, force: bool = False
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars

disable_progress_bars()
download_path = snapshot_download(
repo_id="ds4sd/CodeFormula",
force_download=force,
local_dir=local_dir,
revision="v1.0.0",
)

return Path(download_path)

def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
"""
Determines if a given element in a document can be processed by the model.
Parameters
----------
doc : DoclingDocument
The document being processed.
element : NodeItem
The element within the document to check.
Returns
-------
bool
True if the element can be processed, False otherwise.
"""
return self.enabled and (
(isinstance(element, CodeItem) and self.options.do_code_enrichment)
or (
isinstance(element, TextItem)
and element.label == DocItemLabel.FORMULA
and self.options.do_formula_enrichment
)
)

def _extract_code_language(self, input_string: str) -> Tuple[str, Optional[str]]:
"""Extracts a programming language from the beginning of a (possibly multi-line) string.
This function checks if the input string starts with a pattern of the form
``<_some_language_>``. If it does, it extracts the language string and returns
a tuple of (remainder, language). Otherwise, it returns the original string
and `None`.
Args:
input_string (str): The input string, which may start with ``<_language_>``.
Returns:
Tuple[str, Optional[str]]:
A tuple where:
- The first element is either:
- The remainder of the string (everything after ``<_language_>``),
if a match is found; or
- The original string, if no match is found.
- The second element is the extracted language if a match is found;
otherwise, `None`.
"""
# Explanation of the regex:
# ^<_([^>]+)> : match "<_something>" at the start, capturing "something" (Group 1)
# \s* : optional whitespace
# (.*) : capture everything after that in Group 2
#
# We also use re.DOTALL so that the (.*) part can include newlines.
pattern = r"^<_([^>]+)_>\s*(.*)"
match = re.match(pattern, input_string, flags=re.DOTALL)
if match:
language = str(match.group(1)) # the captured programming language
remainder = str(match.group(2)) # everything after the <_language_>
return remainder, language
else:
return input_string, None

def _get_code_language_enum(self, value: Optional[str]) -> CodeLanguageLabel:
"""
Converts a string to a corresponding `CodeLanguageLabel` enum member.
If the provided string does not match any value in `CodeLanguageLabel`,
it defaults to `CodeLanguageLabel.UNKNOWN`.
Args:
value (Optional[str]): The string representation of the code language or None.
Returns:
CodeLanguageLabel: The corresponding enum member if the value is valid,
otherwise `CodeLanguageLabel.UNKNOWN`.
"""
if not isinstance(value, str):
return CodeLanguageLabel.UNKNOWN

try:
return CodeLanguageLabel(value)
except ValueError:
return CodeLanguageLabel.UNKNOWN

def prepare_element(
self, conv_res: ConversionResult, element: NodeItem
) -> Optional[ItemAndImageEnrichmentElement]:
if not self.is_processable(doc=conv_res.document, element=element):
return None

assert isinstance(element, TextItem)

element_prov = element.prov[0]

expansion_factor = 0.03 # Adjust the expansion percentage as needed
bbox = element_prov.bbox
width = bbox.r - bbox.l
height = bbox.t - bbox.b

# Create the expanded bounding box
expanded_bbox = BoundingBox(
l=bbox.l - width * expansion_factor, # Expand left
t=bbox.t + height * expansion_factor, # Expand top
r=bbox.r + width * expansion_factor, # Expand right
b=bbox.b - height * expansion_factor, # Expand bottom
coord_origin=bbox.coord_origin, # Preserve coordinate origin
)

page_ix = element_prov.page_no - 1
cropped_image = conv_res.pages[page_ix].get_image(
scale=self.images_scale, cropbox=expanded_bbox
)
return ItemAndImageEnrichmentElement(item=element, image=cropped_image)

def __call__(
self,
doc: DoclingDocument,
element_batch: Iterable[ItemAndImageEnrichmentElement],
) -> Iterable[NodeItem]:
"""
Processes the given batch of elements and enriches them with predictions.
Parameters
----------
doc : DoclingDocument
The document being processed.
element_batch : Iterable[NodeItem]
A batch of elements to be processed.
Returns
-------
Iterable[Any]
An iterable of enriched elements.
"""
if not self.enabled:
for element in element_batch:
yield element.item
return

labels: List[str] = []
images: List[Image.Image] = []
elements: List[TextItem] = []
for el in element_batch:
assert isinstance(el.item, TextItem)
elements.append(el.item)
labels.append(el.item.label)
images.append(el.image)

outputs = self.code_formula_model.predict(images, labels)

for item, output in zip(elements, outputs):
if isinstance(item, CodeItem):
output, code_language = self._extract_code_language(output)
item.code_language = self._get_code_language_enum(code_language)
item.text = output

yield item
8 changes: 4 additions & 4 deletions docling/pipeline/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
import traceback
from abc import ABC, abstractmethod
from typing import Callable, Iterable, List
from typing import Any, Callable, Iterable, List

from docling_core.types.doc import DoclingDocument, NodeItem

Expand All @@ -18,7 +18,7 @@
from docling.datamodel.document import ConversionResult, InputDocument
from docling.datamodel.pipeline_options import PipelineOptions
from docling.datamodel.settings import settings
from docling.models.base_model import BaseEnrichmentModel
from docling.models.base_model import BaseEnrichmentModel, GenericEnrichmentModel
from docling.utils.profiling import ProfilingScope, TimeRecorder
from docling.utils.utils import chunkify

Expand All @@ -30,7 +30,7 @@ def __init__(self, pipeline_options: PipelineOptions):
self.pipeline_options = pipeline_options
self.keep_images = False
self.build_pipe: List[Callable] = []
self.enrichment_pipe: List[BaseEnrichmentModel] = []
self.enrichment_pipe: List[GenericEnrichmentModel[Any]] = []

def execute(self, in_doc: InputDocument, raises_on_error: bool) -> ConversionResult:
conv_res = ConversionResult(input=in_doc)
Expand Down Expand Up @@ -66,7 +66,7 @@ def _assemble_document(self, conv_res: ConversionResult) -> ConversionResult:
def _enrich_document(self, conv_res: ConversionResult) -> ConversionResult:

def _prepare_elements(
conv_res: ConversionResult, model: BaseEnrichmentModel
conv_res: ConversionResult, model: GenericEnrichmentModel[Any]
) -> Iterable[NodeItem]:
for doc_element, _level in conv_res.document.iterate_items():
prepared_element = model.prepare_element(
Expand Down
21 changes: 19 additions & 2 deletions docling/pipeline/standard_pdf_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import sys
from pathlib import Path
from typing import Iterable, Optional
from typing import Optional

from docling_core.types.doc import DocItem, ImageRef, PictureItem, TableItem

Expand All @@ -17,8 +17,8 @@
TesseractCliOcrOptions,
TesseractOcrOptions,
)
from docling.models.base_model import BasePageModel
from docling.models.base_ocr_model import BaseOcrModel
from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
from docling.models.ds_glm_model import GlmModel, GlmOptions
from docling.models.easyocr_model import EasyOcrModel
from docling.models.layout_model import LayoutModel
Expand Down Expand Up @@ -93,8 +93,25 @@ def __init__(self, pipeline_options: PdfPipelineOptions):

self.enrichment_pipe = [
# Other models working on `NodeItem` elements in the DoclingDocument
# Code Formula Enrichment Model
CodeFormulaModel(
enabled=pipeline_options.do_code_enrichment
or pipeline_options.do_formula_enrichment,
artifacts_path=None,
options=CodeFormulaModelOptions(
do_code_enrichment=pipeline_options.do_code_enrichment,
do_formula_enrichment=pipeline_options.do_formula_enrichment,
),
accelerator_options=pipeline_options.accelerator_options,
),
]

if (
self.pipeline_options.do_formula_enrichment
or self.pipeline_options.do_code_enrichment
):
self.keep_backend = True

@staticmethod
def download_models_hf(
local_dir: Optional[Path] = None, force: bool = False
Expand Down
Loading

0 comments on commit 6206687

Please sign in to comment.