
Commit 600d217

Refactor docparse.py for better readability and organization

1 parent ea0536c

File tree
1 file changed: +103 −121 lines changed

databend_aiserver/udfs/docparse.py

Lines changed: 103 additions & 121 deletions
@@ -15,12 +15,12 @@
 from __future__ import annotations
 
 import logging
-from collections import OrderedDict
 import mimetypes
 import os
 import tempfile
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Protocol
+from time import perf_counter, perf_counter_ns
+from typing import Any, Dict, List, Optional, Protocol, Tuple
 
 from databend_udf import StageLocation, udf
 from docling.document_converter import DocumentConverter, PdfFormatOption
@@ -33,8 +33,6 @@
 from docling.chunking import HybridChunker
 from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
 from transformers import AutoTokenizer
-from opendal import exceptions as opendal_exceptions
-from time import perf_counter, perf_counter_ns
 
 from databend_aiserver.runtime import DeviceRequest, choose_device, get_runtime
 from databend_aiserver.stages.operator import (
@@ -69,11 +67,9 @@ class _DoclingBackend:
     name = "docling"
 
     def __init__(self) -> None:
-        self.choice = self._choose_device()
-        self.accel = self._build_accelerator(self.choice)
-        self.ocr_provider = self._select_ocr_provider()
+        self.accel = self._build_accelerator()
 
-    def _choose_device(self):
+    def _build_accelerator(self):
         override = os.getenv("AISERVER_DOCLING_DEVICE")
         req = DeviceRequest(task="docling", allow_gpu=True, allow_mps=True, explicit=override)
         choice = choose_device(req)
@@ -84,9 +80,7 @@ def _choose_device(self):
             choice.reason,
             override,
         )
-        return choice
 
-    def _build_accelerator(self, choice):
         if AcceleratorOptions is None or AcceleratorDevice is None:
             return None
         if choice.device.startswith("cuda"):
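Note: the `AcceleratorOptions is None` guard above implies these docling types are imported defensively at module top, so the module still loads on docling builds that lack them. A minimal sketch of that pattern, assuming the import site and module path (neither is shown in this diff):

    # Hypothetical guarded import; the real import path may differ by docling version.
    try:
        from docling.datamodel.pipeline_options import AcceleratorDevice, AcceleratorOptions
    except ImportError:  # older docling builds without accelerator options
        AcceleratorDevice = None
        AcceleratorOptions = None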
@@ -95,42 +89,26 @@ def _build_accelerator(self, choice):
             return AcceleratorOptions(device=AcceleratorDevice.MPS)
         return AcceleratorOptions(device=AcceleratorDevice.CPU)
 
-    def _select_ocr_provider(self) -> Optional[str]:
-        runtime = get_runtime()
-        providers = runtime.capabilities.onnx_providers
-        if runtime.capabilities.device_kind == "cuda" and "CUDAExecutionProvider" in providers:
-            choice = "CUDAExecutionProvider"
-        else:
-            choice = "CPUExecutionProvider"
-        logger.info("Docling OCR provider: %s (available=%s)", choice, providers)
-        return choice
-
     def _build_converter(self):
-        # Docling expects accelerator via pipeline options, not constructor kwargs.
         format_options: Dict[InputFormat, Any] = {}
         if self.accel is not None:
             pdf_opts = ThreadedPdfPipelineOptions()
             pdf_opts.accelerator_options = self.accel
-            format_options[InputFormat.PDF] = PdfFormatOption(
-                pipeline_options=pdf_opts
-            )
+            format_options[InputFormat.PDF] = PdfFormatOption(pipeline_options=pdf_opts)
 
         try:
-            return DocumentConverter(
-                format_options=format_options if format_options else None
-            )
+            return DocumentConverter(format_options=format_options if format_options else None)
         except TypeError:
-            # Extremely old docling builds may not accept format_options; fall back.
-            logger.warning(
-                "Installed docling version does not support format_options; using defaults"
-            )
+            logger.warning("Installed docling version does not support format_options; using defaults")
             return DocumentConverter()
 
     def convert(self, stage_location: StageLocation, path: str) -> tuple[ConversionResult, int]:
         t_start = perf_counter()
         raw = load_stage_file(stage_location, path)
         suffix = stage_file_suffix(path)
         converter = self._build_converter()
+
+        # Try processing from memory stream first
         if DocumentStream is not None:
             try:
                 stream = DocumentStream(
@@ -139,25 +117,19 @@ def convert(self, stage_location: StageLocation, path: str) -> tuple[ConversionResult, int]:
                     mime_type=mimetypes.guess_type(f"file{suffix}")[0] or "application/octet-stream",
                 )
                 result = converter.convert(stream)
-                logger.info(
-                    "Docling convert path=%s stream=memory bytes=%s duration=%.3fs",
-                    path,
-                    len(raw),
-                    perf_counter() - t_start,
-                )
+                logger.info("Docling convert path=%s stream=memory bytes=%s duration=%.3fs",
+                            path, len(raw), perf_counter() - t_start)
                 return result, len(raw)
             except Exception:
                 pass
+
+        # Fallback to temp file
         with tempfile.TemporaryDirectory() as tmpdir:
             tmp_path = Path(tmpdir) / f"doc{suffix}"
             tmp_path.write_bytes(raw)
             result = converter.convert(tmp_path)
-            logger.info(
-                "Docling convert path=%s stream=tempfile bytes=%s duration=%.3fs",
-                path,
-                len(raw),
-                perf_counter() - t_start,
-            )
+            logger.info("Docling convert path=%s stream=tempfile bytes=%s duration=%.3fs",
+                        path, len(raw), perf_counter() - t_start)
             return result, len(raw)
 
 
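Note: `convert` keeps the memory-first approach but degrades gracefully to a temp file. A minimal standalone sketch of the same pattern, with hypothetical names (`convert_bytes` is not part of this module, and a bare `BytesIO` stands in for `DocumentStream`):

    import tempfile
    from io import BytesIO
    from pathlib import Path

    def convert_bytes(converter, raw: bytes, suffix: str):
        # Try the in-memory route first; fall back to a temp file on any failure.
        try:
            return converter.convert(BytesIO(raw))
        except Exception:
            with tempfile.TemporaryDirectory() as tmpdir:
                tmp = Path(tmpdir) / f"doc{suffix}"
                tmp.write_bytes(raw)
                return converter.convert(tmp)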
@@ -181,6 +153,76 @@ def _get_hf_tokenizer(model_name: str) -> HuggingFaceTokenizer:
     return _TOKENIZER_CACHE[model_name]
 
 
+def _resolve_full_path(stage_location: StageLocation, path: str) -> str:
+    resolved_path = resolve_stage_subpath(stage_location, path)
+    storage = stage_location.storage or {}
+    storage_root = str(storage.get("root", "") or "")
+    bucket = storage.get("bucket") or storage.get("name")
+
+    if storage_root.startswith("s3://"):
+        base = storage_root.rstrip("/")
+        return f"{base}/{resolved_path}"
+    elif bucket:
+        base = f"s3://{bucket}"
+        if storage_root:
+            base = f"{base}/{storage_root.strip('/')}"
+        return f"{base}/{resolved_path}"
+
+    return resolved_path or path
+
+
+def _chunk_document(doc: Any) -> Tuple[List[Dict[str, Any]], bool]:
+    """Chunk the document and return pages/chunks and a fallback flag."""
+    markdown = doc.export_to_markdown()
+    tokenizer = _get_hf_tokenizer(DEFAULT_EMBED_MODEL)
+    chunker = HybridChunker(tokenizer=tokenizer)
+
+    try:
+        chunks = list(chunker.chunk(dl_doc=doc))
+        if not chunks:
+            return [{"index": 0, "content": markdown}], True
+
+        return [
+            {"index": idx, "content": chunker.contextualize(chunk)}
+            for idx, chunk in enumerate(chunks)
+        ], False
+    except Exception:
+        return [{"index": 0, "content": markdown}], True
+
+
+def _format_response(
+    path: str,
+    full_path: str,
+    pages: List[Dict[str, Any]],
+    file_size: int,
+    timings: Dict[str, float],
+    fallback: bool
+) -> Dict[str, Any]:
+    duration_ms = timings["total"]
+    payload: Dict[str, Any] = {
+        "metadata": {
+            "chunk_count": len(pages),
+            "chunk_size": DEFAULT_CHUNK_SIZE,
+            "duration_ms": duration_ms,
+            "file_size": file_size,
+            "filename": Path(path).name,
+            "path": full_path,
+            "timings_ms": timings,
+            "version": 1,
+        },
+        "chunks": pages,
+    }
+
+    if fallback:
+        payload["error_information"] = [
+            {
+                "type": "ChunkingFallback",
+                "message": "chunker failed or returned empty; returned full markdown instead",
+            }
+        ]
+    return payload
+
+
 @udf(
     name="ai_parse_document",
     stage_refs=["stage_location"],
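Note: `_resolve_full_path` has three branches: an absolute s3:// root, a bucket (plus optional relative root), and a plain local path. A self-contained restatement of that branching on a plain dict, under hypothetical names (`join_stage_path` is not part of this module, and the `resolve_stage_subpath` call is replaced by a precomputed `resolved_path`):

    def join_stage_path(storage: dict, resolved_path: str) -> str:
        # Mirrors _resolve_full_path's branching, minus the StageLocation plumbing.
        storage_root = str(storage.get("root", "") or "")
        bucket = storage.get("bucket") or storage.get("name")
        if storage_root.startswith("s3://"):
            return f"{storage_root.rstrip('/')}/{resolved_path}"
        if bucket:
            base = f"s3://{bucket}"
            if storage_root:
                base = f"{base}/{storage_root.strip('/')}"
            return f"{base}/{resolved_path}"
        return resolved_path

    assert join_stage_path({"root": "s3://b/p"}, "doc.pdf") == "s3://b/p/doc.pdf"
    assert join_stage_path({"bucket": "b", "root": "p/"}, "doc.pdf") == "s3://b/p/doc.pdf"
    assert join_stage_path({}, "doc.pdf") == "doc.pdf"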
@@ -189,13 +231,7 @@ def _get_hf_tokenizer(model_name: str) -> HuggingFaceTokenizer:
     io_threads=4,
 )
 def ai_parse_document(stage_location: StageLocation, path: str) -> Dict[str, Any]:
-    """Parse a document and return Snowflake-compatible layout output.
-
-    Simplified semantics:
-    - Always processes the full document.
-    - Always returns Markdown layout in ``content``.
-    - Includes ``pages`` array with per-page content when possible.
-    """
+    """Parse a document and return Snowflake-compatible layout output."""
     try:
         t_total_ns = perf_counter_ns()
         runtime = get_runtime()
@@ -205,91 +241,36 @@ def ai_parse_document(stage_location: StageLocation, path: str) -> Dict[str, Any]:
             runtime.capabilities.preferred_device,
             runtime.capabilities.device_kind,
         )
+
         backend = _get_doc_parser_backend()
-
         t_convert_start_ns = perf_counter_ns()
         result, file_size = backend.convert(stage_location, path)
         t_convert_end_ns = perf_counter_ns()
 
-        doc = result.document
-        markdown = doc.export_to_markdown()
-
-        # Docling chunking: tokenizer aligned with embedding model.
-        tokenizer = _get_hf_tokenizer(DEFAULT_EMBED_MODEL)
-        chunker = HybridChunker(tokenizer=tokenizer)
-
-        fallback = False
-        try:
-            chunks = list(chunker.chunk(dl_doc=doc))
-            pages: List[Dict[str, Any]] = [
-                {"index": idx, "content": chunker.contextualize(chunk)}
-                for idx, chunk in enumerate(chunks)
-            ]
-        except Exception:
-            pages = [{"index": 0, "content": markdown}]
-            fallback = True
-        if not pages:
-            pages = [{"index": 0, "content": markdown}]
-            fallback = True
-
-        chunk_count = len(pages)
-
+        pages, fallback = _chunk_document(result.document)
         t_chunk_end_ns = perf_counter_ns()
-        duration_ms = (t_chunk_end_ns - t_total_ns) / 1_000_000.0
-
-        # Output shape:
-        # { "chunks": [...], "metadata": {...}, "error_information": [...] }
-        resolved_path = resolve_stage_subpath(stage_location, path)
-        storage = stage_location.storage or {}
-        storage_root = str(storage.get("root", "") or "")
-        bucket = storage.get("bucket") or storage.get("name")
-
-        if storage_root.startswith("s3://"):
-            base = storage_root.rstrip("/")
-            full_path = f"{base}/{resolved_path}"
-        elif bucket:
-            base = f"s3://{bucket}"
-            if storage_root:
-                base = f"{base}/{storage_root.strip('/')}"
-            full_path = f"{base}/{resolved_path}"
-        else:
-            full_path = resolved_path or path
-
-        # Keep metadata first for predictable JSON ordering.
-        payload: Dict[str, Any] = {
-            "metadata": {
-                "chunk_count": chunk_count,
-                "chunk_size": DEFAULT_CHUNK_SIZE,
-                "duration_ms": duration_ms,
-                "file_size": file_size if file_size is not None else 0,
-                "filename": Path(path).name,
-                "path": full_path or path,
-                "timings_ms": {
-                    "convert": (t_convert_end_ns - t_convert_start_ns) / 1_000_000.0,
-                    "chunk": (t_chunk_end_ns - t_convert_end_ns) / 1_000_000.0,
-                    "total": duration_ms,
-                },
-                "version": 1,
-            },
-            "chunks": pages,
+
+        full_path = _resolve_full_path(stage_location, path)
+
+        timings = {
+            "convert": (t_convert_end_ns - t_convert_start_ns) / 1_000_000.0,
+            "chunk": (t_chunk_end_ns - t_convert_end_ns) / 1_000_000.0,
+            "total": (t_chunk_end_ns - t_total_ns) / 1_000_000.0,
         }
-        if fallback:
-            payload["error_information"] = [
-                {
-                    "type": "ChunkingFallback",
-                    "message": "chunker failed or returned empty; returned full markdown instead",
-                }
-            ]
+
+        payload = _format_response(path, full_path, pages, file_size, timings, fallback)
+
         logger.info(
             "ai_parse_document path=%s backend=%s chunks=%s fallback=%s duration_ms=%.1f",
             path,
             getattr(backend, "name", "unknown"),
-            chunk_count,
+            len(pages),
             fallback,
-            duration_ms,
+            timings["total"],
         )
         return payload
-    except Exception as exc:  # pragma: no cover - defensive for unexpected docling errors
+
+    except Exception as exc:  # pragma: no cover
         return {
             "metadata": {
                 "path": path,
@@ -298,3 +279,4 @@ def ai_parse_document(stage_location: StageLocation, path: str) -> Dict[str, Any]:
             "chunks": [],
             "error_information": [{"message": str(exc), "type": exc.__class__.__name__}],
         }
+
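Note: after this refactor the success payload is assembled entirely by `_format_response`. An illustrative result, with all concrete values hypothetical:

    {
        "metadata": {
            "chunk_count": 3,
            "chunk_size": 512,      # DEFAULT_CHUNK_SIZE; value illustrative
            "duration_ms": 1240.7,  # equals timings_ms["total"]
            "file_size": 83211,
            "filename": "report.pdf",
            "path": "s3://bucket/stage/report.pdf",
            "timings_ms": {"convert": 1105.2, "chunk": 135.5, "total": 1240.7},
            "version": 1,
        },
        "chunks": [
            {"index": 0, "content": "..."},
        ],
    }

"error_information" appears only when chunking fell back to full markdown, and the except branch returns the same shape with empty "chunks".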
0 commit comments