
Commit 3c9fe76

Authored by cau-git and Maksym Lysak
feat: [Experimental] Introduce VLM pipeline using HF AutoModelForVision2Seq, featuring SmolDocling model (#1054)
* Skeleton for SmolDocling model and VLM Pipeline (Signed-off-by: Christoph Auer <[email protected]>; Maksym Lysak <[email protected]>)
* wip smolDocling inference and vlm pipeline (Signed-off-by: Maksym Lysak <[email protected]>)
* WIP, first working code for inference of SmolDocling, and vlm pipeline assembly code, example included. (Signed-off-by: Maksym Lysak <[email protected]>)
* Fixes to preserve page image and demo export to html (Signed-off-by: Maksym Lysak <[email protected]>)
* Enabled figure support in vlm_pipeline (Signed-off-by: Maksym Lysak <[email protected]>)
* Fix for table span compute in vlm_pipeline (Signed-off-by: Maksym Lysak <[email protected]>)
* Properly propagating image data per page, together with predicted tags in VLM pipeline. This enables correct figure extraction and page numbers in provenances (Signed-off-by: Maksym Lysak <[email protected]>)
* Cleaned up logs, added pages to vlm_pipeline, basic timing per page measurement in smol_docling models (Signed-off-by: Maksym Lysak <[email protected]>)
* Replaced hardcoded otsl tokens with the ones from docling-core tokens.py enum (Signed-off-by: Maksym Lysak <[email protected]>)
* Added tokens/sec measurement, improved example (Signed-off-by: Maksym Lysak <[email protected]>)
* Added capability for vlm_pipeline to grab text from preconfigured backend (Signed-off-by: Maksym Lysak <[email protected]>)
* Exposed "force_backend_text" as pipeline parameter (Signed-off-by: Maksym Lysak <[email protected]>)
* Flipped keep_backend to True for vlm_pipeline assembly to work (Signed-off-by: Maksym Lysak <[email protected]>)
* Updated vlm pipeline assembly and smol docling model code to support updated doctags (Signed-off-by: Maksym Lysak <[email protected]>)
* Fixing doctags starting tag, that broke elements on first line during assembly (Signed-off-by: Maksym Lysak <[email protected]>)
* Introduced SmolDoclingOptions to configure model parameters (such as query and artifacts path) via client code, see example in minimal_smol_docling. Provisioning for other potential vlm all-in-one models. (Signed-off-by: Maksym Lysak <[email protected]>)
* Moved artifacts_path for SmolDocling into vlm_options instead of global pipeline option (Signed-off-by: Maksym Lysak <[email protected]>)
* New assembly code for latest model revision, updated prompt and parsing of doctags, updated logging (Signed-off-by: Maksym Lysak <[email protected]>)
* Updated example of Smol Docling usage (Signed-off-by: Maksym Lysak <[email protected]>)
* Added captions for the images for SmolDocling assembly code, improved provenance definition for all elements (Signed-off-by: Maksym Lysak <[email protected]>)
* Update minimal smoldocling example (Signed-off-by: Christoph Auer <[email protected]>)
* Fix repo id (Signed-off-by: Christoph Auer <[email protected]>)
* Cleaned up unnecessary logging (Signed-off-by: Maksym Lysak <[email protected]>)
* More elegant solution in removing the input prompt (Signed-off-by: Maksym Lysak <[email protected]>)
* removed minimal_smol_docling example from CI checks (Signed-off-by: Maksym Lysak <[email protected]>)
* Removed special html code wrapping when exporting to docling document, cleaned up comments (Signed-off-by: Maksym Lysak <[email protected]>)
* Addressing PR comments, added enabled property to SmolDocling, and related VLM pipeline option, few other minor things (Signed-off-by: Maksym Lysak <[email protected]>)
* Moved keep_backend = True to vlm pipeline (Signed-off-by: Maksym Lysak <[email protected]>)
* removed pipeline_options.generate_table_images from vlm_pipeline (deprecated in the pipelines) (Signed-off-by: Maksym Lysak <[email protected]>)
* Added example on how to get original predicted doctags in minimal_smol_docling (Signed-off-by: Maksym Lysak <[email protected]>)
* removing changes from base_pipeline (Signed-off-by: Maksym Lysak <[email protected]>)
* Replaced remaining strings to appropriate enums (Signed-off-by: Maksym Lysak <[email protected]>)
* Updated poetry.lock (Signed-off-by: Maksym Lysak <[email protected]>)
* re-built poetry.lock (Signed-off-by: Maksym Lysak <[email protected]>)
* Generalize and refactor VLM pipeline and models (Signed-off-by: Christoph Auer <[email protected]>)
* Rename example (Signed-off-by: Christoph Auer <[email protected]>)
* Move imports (Signed-off-by: Christoph Auer <[email protected]>)
* Expose control over using flash_attention_2 (Signed-off-by: Christoph Auer <[email protected]>)
* Fix VLM example exclusion in CI (Signed-off-by: Christoph Auer <[email protected]>)
* Add back device_map and accelerate (Signed-off-by: Christoph Auer <[email protected]>)
* Make drawing code resilient against bad bboxes (Signed-off-by: Christoph Auer <[email protected]>)
* chore: clean up code and comments (Signed-off-by: Christoph Auer <[email protected]>)
* chore: more cleanup (Signed-off-by: Christoph Auer <[email protected]>)
* chore: fix leftover .to(device) (Signed-off-by: Christoph Auer <[email protected]>)
* fix: add proper table provenance (Signed-off-by: Christoph Auer <[email protected]>)

---------

Signed-off-by: Christoph Auer <[email protected]>
Signed-off-by: Maksym Lysak <[email protected]>
Co-authored-by: Maksym Lysak <[email protected]>
1 parent ab683e4 commit 3c9fe76
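Two of the knobs called out in the commit message, cuda_use_flash_attention2 and force_backend_text, are plain accelerator/pipeline options (see the pipeline_options.py diff below). A hypothetical configuration sketch, not part of this commit's diff; it assumes the accelerator_options field that docling's PipelineOptions already exposes:

```python
# Illustrative only: option names come from this commit's diffs; the
# accelerator_options field on PipelineOptions is assumed from docling's
# existing options model.
from docling.datamodel.pipeline_options import (
    AcceleratorOptions,
    VlmPipelineOptions,
)

pipeline_options = VlmPipelineOptions(
    # Reuse the text extracted by the PDF backend instead of the VLM-generated text.
    force_backend_text=True,
    accelerator_options=AcceleratorOptions(
        device="cuda",
        # Only honoured when running on a CUDA device (see hf_vlm_model.py below).
        cuda_use_flash_attention2=True,
    ),
)
```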

File tree

9 files changed: +1248 -316 lines changed


.github/workflows/checks.yml
+1 -1

@@ -28,7 +28,7 @@ jobs:
         run: |
           for file in docs/examples/*.py; do
             # Skip batch_convert.py
-            if [[ "$(basename "$file")" =~ ^(batch_convert|minimal|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api).py ]]; then
+            if [[ "$(basename "$file")" =~ ^(batch_convert|minimal_vlm_pipeline|minimal|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api).py ]]; then
               echo "Skipping $file"
               continue
             fi

docling/datamodel/base_models.py
+5

@@ -154,6 +154,10 @@ class LayoutPrediction(BaseModel):
     clusters: List[Cluster] = []
 
 
+class VlmPrediction(BaseModel):
+    text: str = ""
+
+
 class ContainerElement(
     BasePageElement
 ):  # Used for Form and Key-Value-Regions, only for typing.
@@ -197,6 +201,7 @@ class PagePredictions(BaseModel):
     tablestructure: Optional[TableStructurePrediction] = None
     figures_classification: Optional[FigureClassificationPrediction] = None
     equations_prediction: Optional[EquationPrediction] = None
+    vlm_response: Optional[VlmPrediction] = None
 
 
 PageElement = Union[TextElement, Table, FigureElement, ContainerElement]
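The new vlm_response field is where the VLM pipeline stores the raw per-page model output (DocTags markup for the SmolDocling preset). A hypothetical sketch of reading it back after a conversion; it assumes a DocumentConverter already configured for the VLM pipeline, as in the usage sketch shown after the pipeline_options.py diff below:

```python
# Illustrative only: `converter` is assumed to be a DocumentConverter wired up
# with the VLM pipeline (see the sketch further down); the input path is a placeholder.
result = converter.convert("report.pdf")

for page in result.pages:
    if page.predictions.vlm_response is not None:
        # For the SmolDocling preset this is the raw DocTags string for the page.
        print(f"--- page {page.page_no} ---")
        print(page.predictions.vlm_response.text)
```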

docling/datamodel/pipeline_options.py
+62 -1

@@ -41,6 +41,7 @@ class AcceleratorOptions(BaseSettings):
 
     num_threads: int = 4
     device: Union[str, AcceleratorDevice] = "auto"
+    cuda_use_flash_attention2: bool = False
 
     @field_validator("device")
     def validate_device(cls, value):
@@ -254,6 +255,45 @@ def repo_cache_folder(self) -> str:
 )
 
 
+class BaseVlmOptions(BaseModel):
+    kind: str
+    prompt: str
+
+
+class ResponseFormat(str, Enum):
+    DOCTAGS = "doctags"
+    MARKDOWN = "markdown"
+
+
+class HuggingFaceVlmOptions(BaseVlmOptions):
+    kind: Literal["hf_model_options"] = "hf_model_options"
+
+    repo_id: str
+    load_in_8bit: bool = True
+    llm_int8_threshold: float = 6.0
+    quantized: bool = False
+
+    response_format: ResponseFormat
+
+    @property
+    def repo_cache_folder(self) -> str:
+        return self.repo_id.replace("/", "--")
+
+
+smoldocling_vlm_conversion_options = HuggingFaceVlmOptions(
+    repo_id="ds4sd/SmolDocling-256M-preview",
+    prompt="Convert this page to docling.",
+    response_format=ResponseFormat.DOCTAGS,
+)
+
+granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
+    repo_id="ibm-granite/granite-vision-3.1-2b-preview",
+    # prompt="OCR the full page to markdown.",
+    prompt="OCR this image.",
+    response_format=ResponseFormat.MARKDOWN,
+)
+
+
 # Define an enum for the backend options
 class PdfBackend(str, Enum):
     """Enum of valid PDF backends."""
@@ -285,7 +325,24 @@ class PipelineOptions(BaseModel):
     enable_remote_services: bool = False
 
 
-class PdfPipelineOptions(PipelineOptions):
+class PaginatedPipelineOptions(PipelineOptions):
+    images_scale: float = 1.0
+    generate_page_images: bool = False
+    generate_picture_images: bool = False
+
+
+class VlmPipelineOptions(PaginatedPipelineOptions):
+    artifacts_path: Optional[Union[Path, str]] = None
+
+    generate_page_images: bool = True
+    force_backend_text: bool = (
+        False  # (To be used with vlms, or other generative models)
+    )
+    # If True, text from backend will be used instead of generated text
+    vlm_options: Union[HuggingFaceVlmOptions] = smoldocling_vlm_conversion_options
+
+
+class PdfPipelineOptions(PaginatedPipelineOptions):
     """Options for the PDF pipeline."""
 
     artifacts_path: Optional[Union[Path, str]] = None
@@ -295,6 +352,10 @@ class PdfPipelineOptions(PipelineOptions):
     do_formula_enrichment: bool = False  # True: perform formula OCR, return Latex code
     do_picture_classification: bool = False  # True: classify pictures in documents
     do_picture_description: bool = False  # True: run describe pictures in documents
+    force_backend_text: bool = (
+        False  # (To be used with vlms, or other generative models)
+    )
+    # If True, text from backend will be used instead of generated text
 
     table_structure_options: TableStructureOptions = TableStructureOptions()
     ocr_options: Union[
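Taken together, these options select the model and its response format for the new pipeline. A minimal usage sketch, not part of this commit's diff: it assumes the VlmPipeline class added under docling/pipeline/ in this commit and docling's existing DocumentConverter/PdfFormatOption wiring, and mirrors the minimal_vlm_pipeline example excluded from CI above:

```python
# Minimal sketch: run the VLM pipeline with the SmolDocling preset (default),
# or swap in the Granite Vision preset for markdown output.
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
    VlmPipelineOptions,
    granite_vision_vlm_conversion_options,
)
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline  # module path assumed

pipeline_options = VlmPipelineOptions()
# pipeline_options.vlm_options = granite_vision_vlm_conversion_options  # markdown responses

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        )
    }
)

result = converter.convert("report.pdf")  # placeholder input path
print(result.document.export_to_markdown())
```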

docling/models/hf_vlm_model.py
+180 (new file)

@@ -0,0 +1,180 @@
+import logging
+import time
+from pathlib import Path
+from typing import Iterable, List, Optional
+
+from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.document import ConversionResult
+from docling.datamodel.pipeline_options import (
+    AcceleratorDevice,
+    AcceleratorOptions,
+    HuggingFaceVlmOptions,
+)
+from docling.datamodel.settings import settings
+from docling.models.base_model import BasePageModel
+from docling.utils.accelerator_utils import decide_device
+from docling.utils.profiling import TimeRecorder
+
+_log = logging.getLogger(__name__)
+
+
+class HuggingFaceVlmModel(BasePageModel):
+
+    def __init__(
+        self,
+        enabled: bool,
+        artifacts_path: Optional[Path],
+        accelerator_options: AcceleratorOptions,
+        vlm_options: HuggingFaceVlmOptions,
+    ):
+        self.enabled = enabled
+
+        self.vlm_options = vlm_options
+
+        if self.enabled:
+            import torch
+            from transformers import (  # type: ignore
+                AutoModelForVision2Seq,
+                AutoProcessor,
+                BitsAndBytesConfig,
+            )
+
+            device = decide_device(accelerator_options.device)
+            self.device = device
+
+            _log.debug("Available device for HuggingFace VLM: {}".format(device))
+
+            repo_cache_folder = vlm_options.repo_id.replace("/", "--")
+
+            # PARAMETERS:
+            if artifacts_path is None:
+                artifacts_path = self.download_models(self.vlm_options.repo_id)
+            elif (artifacts_path / repo_cache_folder).exists():
+                artifacts_path = artifacts_path / repo_cache_folder
+
+            self.param_question = vlm_options.prompt  # "Perform Layout Analysis."
+            self.param_quantization_config = BitsAndBytesConfig(
+                load_in_8bit=vlm_options.load_in_8bit,  # True,
+                llm_int8_threshold=vlm_options.llm_int8_threshold,  # 6.0
+            )
+            self.param_quantized = vlm_options.quantized  # False
+
+            self.processor = AutoProcessor.from_pretrained(artifacts_path)
+            if not self.param_quantized:
+                self.vlm_model = AutoModelForVision2Seq.from_pretrained(
+                    artifacts_path,
+                    device_map=device,
+                    torch_dtype=torch.bfloat16,
+                    _attn_implementation=(
+                        "flash_attention_2"
+                        if self.device.startswith("cuda")
+                        and accelerator_options.cuda_use_flash_attention2
+                        else "eager"
+                    ),
+                )  # .to(self.device)
+
+            else:
+                self.vlm_model = AutoModelForVision2Seq.from_pretrained(
+                    artifacts_path,
+                    device_map=device,
+                    torch_dtype="auto",
+                    quantization_config=self.param_quantization_config,
+                    _attn_implementation=(
+                        "flash_attention_2"
+                        if self.device.startswith("cuda")
+                        and accelerator_options.cuda_use_flash_attention2
+                        else "eager"
+                    ),
+                )  # .to(self.device)
+
+    @staticmethod
+    def download_models(
+        repo_id: str,
+        local_dir: Optional[Path] = None,
+        force: bool = False,
+        progress: bool = False,
+    ) -> Path:
+        from huggingface_hub import snapshot_download
+        from huggingface_hub.utils import disable_progress_bars
+
+        if not progress:
+            disable_progress_bars()
+        download_path = snapshot_download(
+            repo_id=repo_id,
+            force_download=force,
+            local_dir=local_dir,
+            # revision="v0.0.1",
+        )
+
+        return Path(download_path)
+
+    def __call__(
+        self, conv_res: ConversionResult, page_batch: Iterable[Page]
+    ) -> Iterable[Page]:
+        for page in page_batch:
+            assert page._backend is not None
+            if not page._backend.is_valid():
+                yield page
+            else:
+                with TimeRecorder(conv_res, "vlm"):
+                    assert page.size is not None
+
+                    hi_res_image = page.get_image(scale=2.0)  # 144dpi
+                    # hi_res_image = page.get_image(scale=1.0)  # 72dpi
+
+                    if hi_res_image is not None:
+                        im_width, im_height = hi_res_image.size
+
+                    # populate page_tags with predicted doc tags
+                    page_tags = ""
+
+                    if hi_res_image:
+                        if hi_res_image.mode != "RGB":
+                            hi_res_image = hi_res_image.convert("RGB")
+
+                    messages = [
+                        {
+                            "role": "user",
+                            "content": [
+                                {
+                                    "type": "text",
+                                    "text": "This is a page from a document.",
+                                },
+                                {"type": "image"},
+                                {"type": "text", "text": self.param_question},
+                            ],
+                        }
+                    ]
+                    prompt = self.processor.apply_chat_template(
+                        messages, add_generation_prompt=False
+                    )
+                    inputs = self.processor(
+                        text=prompt, images=[hi_res_image], return_tensors="pt"
+                    )
+                    inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+                    start_time = time.time()
+                    # Call model to generate:
+                    generated_ids = self.vlm_model.generate(
+                        **inputs, max_new_tokens=4096, use_cache=True
+                    )
+
+                    generation_time = time.time() - start_time
+                    generated_texts = self.processor.batch_decode(
+                        generated_ids[:, inputs["input_ids"].shape[1] :],
+                        skip_special_tokens=False,
+                    )[0]
+
+                    num_tokens = len(generated_ids[0])
+                    page_tags = generated_texts
+
+                    # inference_time = time.time() - start_time
+                    # tokens_per_second = num_tokens / generation_time
+                    # print("")
+                    # print(f"Page Inference Time: {inference_time:.2f} seconds")
+                    # print(f"Total tokens on page: {num_tokens:.2f}")
+                    # print(f"Tokens/sec: {tokens_per_second:.2f}")
+                    # print("")
+                    page.predictions.vlm_response = VlmPrediction(text=page_tags)
+
+                yield page
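download_models can also be invoked ahead of time to pre-fetch the weights for offline runs; the target folder must follow the repo_cache_folder convention (repo_id with "/" replaced by "--") so the constructor's artifacts_path branch finds it. A hypothetical sketch; the local directory name is made up:

```python
from pathlib import Path

from docling.datamodel.pipeline_options import (
    VlmPipelineOptions,
    smoldocling_vlm_conversion_options,
)
from docling.models.hf_vlm_model import HuggingFaceVlmModel

artifacts_path = Path("./model_artifacts")  # hypothetical local cache directory
repo_id = smoldocling_vlm_conversion_options.repo_id

# Snapshot the HF repo into <artifacts_path>/<repo_cache_folder> so that
# HuggingFaceVlmModel.__init__ resolves it via the
# `elif (artifacts_path / repo_cache_folder).exists()` branch instead of re-downloading.
HuggingFaceVlmModel.download_models(
    repo_id=repo_id,
    local_dir=artifacts_path / smoldocling_vlm_conversion_options.repo_cache_folder,
    progress=True,
)

pipeline_options = VlmPipelineOptions(artifacts_path=artifacts_path)
```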
