Commit e1cba8a

draft for picture description models
Signed-off-by: Michele Dolfi <[email protected]>
1 parent 6c22cba commit e1cba8a

File tree

docling/datamodel/pipeline_options.py
docling/models/pic_description_api_model.py
docling/models/pic_description_base_model.py
docling/models/pic_description_vllm_model.py
docling/pipeline/standard_pdf_pipeline.py

5 files changed: +279 -2 lines changed
docling/datamodel/pipeline_options.py

Lines changed: 46 additions & 2 deletions
@@ -1,8 +1,8 @@
 from enum import Enum
 from pathlib import Path
-from typing import List, Literal, Optional, Union
+from typing import Annotated, Any, Dict, List, Literal, Optional, Union

-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import AnyUrl, BaseModel, ConfigDict, Field


 class TableFormerMode(str, Enum):
@@ -61,6 +61,46 @@ class TesseractOcrOptions(OcrOptions):
     )


+class PicDescBaseOptions(BaseModel):
+    kind: str
+    batch_size: int = 8
+    scale: float = 2
+
+    bitmap_area_threshold: float = (
+        0.2  # percentage of the area for a bitmap to be processed with the models
+    )
+
+
+class PicDescApiOptions(PicDescBaseOptions):
+    kind: Literal["api"] = "api"
+
+    url: AnyUrl = AnyUrl("")
+    headers: Dict[str, str] = {}
+    params: Dict[str, Any] = {}
+    timeout: float = 20
+
+    llm_prompt: str = ""
+    provenance: str = ""
+
+
+class PicDescVllmOptions(PicDescBaseOptions):
+    kind: Literal["vllm"] = "vllm"
+
+    # For more example parameters see https://docs.vllm.ai/en/latest/getting_started/examples/offline_inference_vision_language.html
+
+    # Parameters for LLaVA-1.6/LLaVA-NeXT
+    llm_name: str = "llava-hf/llava-v1.6-mistral-7b-hf"
+    llm_prompt: str = "[INST] <image>\nDescribe the image in details. [/INST]"
+    llm_extra: Dict[str, Any] = dict(max_model_len=8192)
+
+    # Parameters for Phi-3-Vision
+    # llm_name: str = "microsoft/Phi-3-vision-128k-instruct"
+    # llm_prompt: str = "<|user|>\n<|image_1|>\nDescribe the image in details.<|end|>\n<|assistant|>\n"
+    # llm_extra: Dict[str, Any] = dict(max_num_seqs=5, trust_remote_code=True)
+
+    sampling_params: Dict[str, Any] = dict(max_tokens=64, seed=42)
+
+
 class PipelineOptions(BaseModel):
     create_legacy_output: bool = (
         True  # This default will be set to False in a future version of docling
@@ -71,11 +111,15 @@ class PdfPipelineOptions(PipelineOptions):
     artifacts_path: Optional[Union[Path, str]] = None
     do_table_structure: bool = True  # True: perform table structure extraction
     do_ocr: bool = True  # True: perform OCR, replace programmatic PDF text
+    do_picture_description: bool = False

     table_structure_options: TableStructureOptions = TableStructureOptions()
     ocr_options: Union[EasyOcrOptions, TesseractCliOcrOptions, TesseractOcrOptions] = (
         Field(EasyOcrOptions(), discriminator="kind")
     )
+    picture_description_options: Annotated[
+        Union[PicDescApiOptions, PicDescVllmOptions], Field(discriminator="kind")
+    ] = PicDescApiOptions()  # TODO: needs defaults or optional

     images_scale: float = 1.0
     generate_page_images: bool = False
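
For orientation, a minimal sketch of how the new API backend could be configured through these options. The endpoint URL, credentials, model name, and prompt below are made-up placeholders, not values from this commit:

from docling.datamodel.pipeline_options import PdfPipelineOptions, PicDescApiOptions

pipeline_options = PdfPipelineOptions(
    do_picture_description=True,
    picture_description_options=PicDescApiOptions(
        # Hypothetical OpenAI-compatible chat-completions endpoint.
        url="http://localhost:8000/v1/chat/completions",
        headers={"Authorization": "Bearer <token>"},  # placeholder credentials
        params={"model": "my-vision-model"},          # merged verbatim into the request payload
        llm_prompt="Describe the image in three sentences.",
        provenance="my-vision-model",
        timeout=60,
    ),
)
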
docling/models/pic_description_api_model.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+import base64
+import io
+import logging
+from typing import List, Optional
+
+import httpx
+from docling_core.types.doc import PictureItem
+from docling_core.types.doc.document import (  # TODO: move import to docling_core.types.doc
+    PictureDescriptionData,
+)
+from pydantic import BaseModel, ConfigDict
+
+from docling.datamodel.pipeline_options import PicDescApiOptions
+from docling.models.pic_description_base_model import PictureDescriptionBaseModel
+
+_log = logging.getLogger(__name__)
+
+
+class ChatMessage(BaseModel):
+    role: str
+    content: str
+
+
+class ResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    finish_reason: str
+
+
+class ResponseUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
+class ApiResponse(BaseModel):
+    model_config = ConfigDict(
+        protected_namespaces=(),
+    )
+
+    id: str
+    model: Optional[str] = None  # returned by openai
+    choices: List[ResponseChoice]
+    created: int
+    usage: ResponseUsage
+
+
+class PictureDescriptionApiModel(PictureDescriptionBaseModel):
+
+    def __init__(self, enabled: bool, options: PicDescApiOptions):
+        super().__init__(enabled=enabled, options=options)
+        self.options: PicDescApiOptions
+
+    def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
+        assert picture.image is not None
+
+        img_io = io.BytesIO()
+        picture.image.pil_image.save(img_io, "PNG")
+
+        image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": self.options.llm_prompt,
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": f"data:image/png;base64,{image_base64}"},
+                    },
+                ],
+            }
+        ]
+
+        payload = {
+            "messages": messages,
+            **self.options.params,
+        }
+
+        r = httpx.post(
+            str(self.options.url),
+            headers=self.options.headers,
+            json=payload,
+            timeout=self.options.timeout,
+        )
+        if not r.is_success:
+            _log.error(f"Error calling the API. Response was {r.text}")
+        r.raise_for_status()
+
+        api_resp = ApiResponse.model_validate_json(r.text)
+        generated_text = api_resp.choices[0].message.content.strip()
+
+        return PictureDescriptionData(
+            provenance=self.options.provenance,
+            text=generated_text,
+        )
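
The ApiResponse model above mirrors an OpenAI-style chat-completions response. A small sketch of the shape it validates, with all values invented for illustration (the model itself uses model_validate_json on the raw response text):

example_response = {
    "id": "chatcmpl-123",
    "model": "my-vision-model",  # optional field, returned by OpenAI-compatible servers
    "created": 1730000000,
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "A bar chart comparing quarterly revenue."},
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 815, "completion_tokens": 12, "total_tokens": 827},
}

# Validating the dict directly yields the same parsed objects:
# ApiResponse.model_validate(example_response).choices[0].message.content
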
docling/models/pic_description_base_model.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+import logging
+from pathlib import Path
+from typing import Any, Iterable
+
+from docling_core.types.doc import (
+    DoclingDocument,
+    NodeItem,
+    PictureClassificationClass,
+    PictureItem,
+)
+from docling_core.types.doc.document import (  # TODO: move import to docling_core.types.doc
+    PictureDescriptionData,
+)
+
+from docling.datamodel.pipeline_options import PicDescBaseOptions
+from docling.models.base_model import BaseEnrichmentModel
+
+
+class PictureDescriptionBaseModel(BaseEnrichmentModel):
+
+    def __init__(self, enabled: bool, options: PicDescBaseOptions):
+        self.enabled = enabled
+        self.options = options
+        self.provenance = "TODO"
+
+    def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
+        # TODO: once the image classifier is active, we can differentiate among image types
+        return self.enabled and isinstance(element, PictureItem)
+
+    def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
+        raise NotImplementedError
+
+    def __call__(
+        self, doc: DoclingDocument, element_batch: Iterable[NodeItem]
+    ) -> Iterable[Any]:
+        if not self.enabled:
+            return
+
+        for element in element_batch:
+            assert isinstance(element, PictureItem)
+            assert element.image is not None
+
+            annotation = self._annotate_image(element)
+            element.annotations.append(annotation)
+
+            yield element
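
The base class leaves only _annotate_image for concrete backends to implement; the batching, enablement check, and annotation bookkeeping live in __call__. A minimal sketch of a custom subclass (the fixed caption and provenance string are placeholders):

from docling_core.types.doc import PictureItem
from docling_core.types.doc.document import PictureDescriptionData

from docling.datamodel.pipeline_options import PicDescBaseOptions
from docling.models.pic_description_base_model import PictureDescriptionBaseModel


class DummyPictureDescriptionModel(PictureDescriptionBaseModel):
    """Illustrative subclass that returns a fixed caption for every picture."""

    def __init__(self, enabled: bool, options: PicDescBaseOptions):
        super().__init__(enabled=enabled, options=options)
        self.provenance = "dummy-captioner"

    def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
        return PictureDescriptionData(provenance=self.provenance, text="A picture.")
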
docling/models/pic_description_vllm_model.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+import json
+from typing import List
+
+from docling_core.types.doc import PictureItem
+from docling_core.types.doc.document import (  # TODO: move import to docling_core.types.doc
+    PictureDescriptionData,
+)
+
+from docling.datamodel.pipeline_options import PicDescVllmOptions
+from docling.models.pic_description_base_model import PictureDescriptionBaseModel
+
+
+class PictureDescriptionVllmModel(PictureDescriptionBaseModel):
+
+    def __init__(self, enabled: bool, options: PicDescVllmOptions):
+        super().__init__(enabled=enabled, options=options)
+        self.options: PicDescVllmOptions
+
+        if self.enabled:
+            raise NotImplementedError
+
+        if self.enabled:
+            try:
+                from vllm import LLM, SamplingParams  # type: ignore
+            except ImportError:
+                raise ImportError(
+                    "VLLM is not installed. Please install Docling with the required extras `pip install docling[vllm]`."
+                )
+
+            self.sampling_params = SamplingParams(**self.options.sampling_params)  # type: ignore
+            self.llm = LLM(model=self.options.llm_name, **self.options.llm_extra)  # type: ignore
+
+            # Generate a stable hash from the extra parameters
+            def create_hash(t):
+                return ""
+
+            params_hash = create_hash(
+                json.dumps(self.options.llm_extra, sort_keys=True)
+                + json.dumps(self.options.sampling_params, sort_keys=True)
+            )
+            self.provenance = f"{self.options.llm_name}-{params_hash[:8]}"
+
+    def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
+        assert picture.image is not None
+
+        from vllm import RequestOutput
+
+        inputs = [
+            {
+                "prompt": self.options.llm_prompt,
+                "multi_modal_data": {"image": picture.image.pil_image},
+            }
+        ]
+        outputs: List[RequestOutput] = self.llm.generate(  # type: ignore
+            inputs, sampling_params=self.sampling_params  # type: ignore
+        )
+
+        generated_text = outputs[0].outputs[0].text
+        return PictureDescriptionData(provenance=self.provenance, text=generated_text)
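
The create_hash helper above is left as a stub in this draft. If a real digest is wanted for the provenance string, one straightforward option is hashlib, for example:

import hashlib


def create_hash(t: str) -> str:
    # Stable hex digest of the serialized parameters; the provenance string keeps only the first 8 characters.
    return hashlib.sha256(t.encode("utf-8")).hexdigest()
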

docling/pipeline/standard_pdf_pipeline.py

Lines changed: 29 additions & 0 deletions
@@ -11,6 +11,8 @@
 from docling.datamodel.pipeline_options import (
     EasyOcrOptions,
     PdfPipelineOptions,
+    PicDescApiOptions,
+    PicDescVllmOptions,
     TesseractCliOcrOptions,
     TesseractOcrOptions,
 )
@@ -23,6 +25,9 @@
     PagePreprocessingModel,
     PagePreprocessingOptions,
 )
+from docling.models.pic_description_api_model import PictureDescriptionApiModel
+from docling.models.pic_description_base_model import PictureDescriptionBaseModel
+from docling.models.pic_description_vllm_model import PictureDescriptionVllmModel
 from docling.models.table_structure_model import TableStructureModel
 from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
 from docling.models.tesseract_ocr_model import TesseractOcrModel
@@ -83,8 +88,15 @@ def __init__(self, pipeline_options: PdfPipelineOptions):
             PageAssembleModel(options=PageAssembleOptions(keep_images=keep_images)),
         ]

+        # Picture description model
+        if (pic_desc_model := self.get_pic_description_model()) is None:
+            raise RuntimeError(
+                f"The specified picture description kind is not supported: {pipeline_options.picture_description_options.kind}."
+            )
+
         self.enrichment_pipe = [
             # Other models working on `NodeItem` elements in the DoclingDocument
+            pic_desc_model,
         ]

     @staticmethod
@@ -120,6 +132,23 @@ def get_ocr_model(self) -> Optional[BaseOcrModel]:
             )
         return None

+    def get_pic_description_model(self) -> Optional[PictureDescriptionBaseModel]:
+        if isinstance(
+            self.pipeline_options.picture_description_options, PicDescApiOptions
+        ):
+            return PictureDescriptionApiModel(
+                enabled=self.pipeline_options.do_picture_description,
+                options=self.pipeline_options.picture_description_options,
+            )
+        elif isinstance(
+            self.pipeline_options.picture_description_options, PicDescVllmOptions
+        ):
+            return PictureDescriptionVllmModel(
+                enabled=self.pipeline_options.do_picture_description,
+                options=self.pipeline_options.picture_description_options,
+            )
+        return None
+
     def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page:
         with TimeRecorder(conv_res, "page_init"):
             page._backend = conv_res.input._backend.load_page(page.page_no)  # type: ignore
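
Putting the pieces together, a sketch of how the enrichment could be exercised end to end. It assumes the existing DocumentConverter / PdfFormatOption API, that generate_picture_images keeps the picture crops needed by the enrichment step, and a hypothetical local endpoint and input file:

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions, PicDescApiOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling_core.types.doc import PictureItem

pipeline_options = PdfPipelineOptions(
    do_picture_description=True,
    generate_picture_images=True,  # keep picture images so _annotate_image has something to send
    picture_description_options=PicDescApiOptions(
        url="http://localhost:8000/v1/chat/completions",  # hypothetical OpenAI-compatible server
        llm_prompt="Describe the image in details.",
        provenance="example-vlm",
    ),
)

converter = DocumentConverter(
    format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
)
result = converter.convert("report.pdf")  # hypothetical input document

# Each processed PictureItem now carries a PictureDescriptionData annotation.
for item, _level in result.document.iterate_items():
    if isinstance(item, PictureItem):
        for annotation in item.annotations:
            print(annotation.provenance, "->", annotation.text)
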
