feat!: New API for models initialization with accelerators parameters. Use HF implementation for LayoutPredictor. Migrate models to safetensors format. #50
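For orientation before the diffs: pieced together from the changes below, the new initialization flow looks roughly like this sketch (the repo id, revision, and artifact sub-path are taken from the demo diff; the device and thread values are placeholders):

import os

from huggingface_hub import snapshot_download

from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor

# Fetch the safetensors weights plus the HF processor/model configs
download_path = snapshot_download(repo_id="ds4sd/docling-models", revision="v2.1.0")
artifact_path = os.path.join(download_path, "model_artifacts/layout")

# Breaking change: the device ("cpu", "cuda" or "mps") and the CPU thread count
# are now explicit constructor parameters instead of env-var driven settings
predictor = LayoutPredictor(artifact_path, device="cpu", num_threads=4)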

Merged Dec 11, 2024 · 40 commits
Changes from 39 commits

Commits
9ecbcb5
Update model codes to choose proper torch device
cau-git Nov 1, 2024
2cf5840
Fix mapping model to device
cau-git Nov 1, 2024
4f8d1c4
feat: Introduce a LayoutPredictor implementation based on safe tensors
nikos-livathinos Nov 14, 2024
2f4694a
feat: LayoutPredictor implementation based on safe tensors
nikos-livathinos Nov 14, 2024
6a9f85f
chore: Add transformers dependency
nikos-livathinos Nov 14, 2024
a75e12d
chore(LayoutPredictor): Rename the JIT version of LayoutPredictor as …
nikos-livathinos Nov 15, 2024
1081e2f
fix(LayoutPredictor): Fix info()
nikos-livathinos Nov 15, 2024
39e7d9d
feat(DemoLayoutPredictor): Refactor demo for LayoutPredictor to save …
nikos-livathinos Nov 15, 2024
7a76ba9
chore(tests): Name the test for the JIT LayoutPredictor as test_layou…
nikos-livathinos Nov 15, 2024
83976af
feat(test): WIP: Add tests for the safe tensors LayoutPredictor
nikos-livathinos Nov 15, 2024
087019c
fix(DemoLayoutPredictor): save_predictions()
nikos-livathinos Nov 15, 2024
6cd3708
chore(DemoLayoutPredictor): Switch to ST
nikos-livathinos Nov 15, 2024
551ab8d
fix(LayoutModel): Fix the device usage in the HF implementation of La…
nikos-livathinos Nov 15, 2024
a17cfd2
chore: Add accelerate in the dependencies
nikos-livathinos Nov 15, 2024
6cb0437
fix(LayoutPredictor-JIT): Fix the cropping of bbox
nikos-livathinos Nov 15, 2024
0e05fa5
chore: Code formatting
nikos-livathinos Nov 15, 2024
9b96d1f
fix(LayoutPredictor): Fix cropping of bbox
nikos-livathinos Nov 15, 2024
4f8c3d3
Code formatting
nikos-livathinos Nov 15, 2024
517b930
fix(demo_layout_predictor): Fix the saving of the predictions and tim…
nikos-livathinos Nov 15, 2024
9ef17fa
chore: Code formatting
nikos-livathinos Nov 15, 2024
68475aa
fix(LayoutPredictor): Introduce init parameters to set the device and…
nikos-livathinos Nov 18, 2024
21426da
fix(LayoutModel): Set by default the device to "auto". Improve loggin…
nikos-livathinos Nov 18, 2024
261e8e1
Merge branch 'main' into nli/performance
nikos-livathinos Dec 2, 2024
0575c18
feat: Introduce new constructor signature for both LayoutPredictor, T…
nikos-livathinos Dec 2, 2024
c4af734
chore: Code styling. Add black, isort in the pre-commit
nikos-livathinos Dec 3, 2024
7c153d5
fix: Refactor the predictor's signature to have the device as string.…
nikos-livathinos Dec 4, 2024
c73f914
Rebase to release_v3
cau-git Dec 4, 2024
94a5c98
feat: Change layout model base threshold and remove blacklist labels …
cau-git Dec 4, 2024
937ad3c
chore: Remove the jit implementation of layout predictor
nikos-livathinos Dec 4, 2024
a052bed
chore: Lower the min transformers version to 4.42.0
nikos-livathinos Dec 6, 2024
ff2d02a
fix: Remove the accelerate package dependency
nikos-livathinos Dec 8, 2024
4ea78fd
fix(demo): Add default number of CPU threads in the layout demo.
nikos-livathinos Dec 8, 2024
d0c856b
chore: Introduce a direct dependency for safetensors
nikos-livathinos Dec 9, 2024
fd37342
feat: Refactor TFPredictor to load the safetensor file. Update unit t…
nikos-livathinos Dec 9, 2024
213f124
chore: Change the safetensors version to 0.4.3
nikos-livathinos Dec 10, 2024
68f0cdf
chore: Tests: Update HF tag to PR/2
nikos-livathinos Dec 10, 2024
a83882c
Increase layout conf threshold
cau-git Dec 10, 2024
ea32d43
chore: Update the revision for HF models to v2.1.0
nikos-livathinos Dec 10, 2024
8d8c2f8
fix: Code styling
nikos-livathinos Dec 10, 2024
ae497d5
Rebase from main
cau-git Dec 11, 2024
24 changes: 6 additions & 18 deletions .pre-commit-config.yaml
@@ -19,21 +19,9 @@ repos:
         entry: poetry lock --check
         pass_filenames: false
         language: system
-
-  # Ready to be enabled soon
-  # - repo: local
-  #   hooks:
-  #     - id: system
-  #       name: flake8
-  #       entry: poetry run flake8 docling_ibm_models
-  #       pass_filenames: false
-  #       language: system
-  #       files: '\.py$'
-  # - repo: local
-  #   hooks:
-  #     - id: system
-  #       name: MyPy
-  #       entry: poetry run mypy docling_ibm_models
-  #       pass_filenames: false
-  #       language: system
-  #       files: '\.py$'
+  #     - id: system
+  #       name: MyPy
+  #       entry: poetry run mypy docling_ibm_models
+  #       pass_filenames: false
+  #       language: system
+  #       files: '\.py$'
107 changes: 67 additions & 40 deletions demo/demo_layout_predictor.py
@@ -10,15 +10,52 @@
 from pathlib import Path

 import numpy as np
-from PIL import Image, ImageDraw
+import torch
+from huggingface_hub import snapshot_download
+from PIL import Image, ImageDraw, ImageFont

 from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor


+def save_predictions(prefix: str, viz_dir: str, img_fn: str, img, predictions: dict):
+    img_path = Path(img_fn)
+
+    image = img.copy()
+    draw = ImageDraw.Draw(image)
+
+    predictions_filename = f"{prefix}_{img_path.stem}.txt"
+    predictions_fn = os.path.join(viz_dir, predictions_filename)
+    with open(predictions_fn, "w") as fd:
+        for pred in predictions:
+            bbox = [
+                round(pred["l"], 2),
+                round(pred["t"], 2),
+                round(pred["r"], 2),
+                round(pred["b"], 2),
+            ]
+            label = pred["label"]
+            confidence = round(pred["confidence"], 3)
+
+            # Save the predictions in txt file
+            pred_txt = f"{prefix} {img_fn}: {label} - {bbox} - {confidence}\n"
+            fd.write(pred_txt)
+
+            # Draw the bbox and label
+            draw.rectangle(bbox, outline="orange")
+            txt = f"{label}: {confidence}"
+            draw.text(
+                (bbox[0], bbox[1]), text=txt, font=ImageFont.load_default(), fill="blue"
+            )
+
+    draw_filename = f"{prefix}_{img_path.name}"
+    draw_fn = os.path.join(viz_dir, draw_filename)
+    image.save(draw_fn)
+
+
 def demo(
     logger: logging.Logger,
     artifact_path: str,
+    device: str,
     num_threads: int,
     img_dir: str,
     viz_dir: str,
@@ -30,58 +30,67 @@ def demo(
     pdf_image = pyvips.Image.new_from_file("test_data/ADS.2007.page_123.pdf", page=0)
     """
     # Create the layout predictor
-    lpredictor = LayoutPredictor(artifact_path, num_threads=num_threads)
-    logger.info("LayoutPredictor settings: {}".format(lpredictor.info()))
+    lpredictor = LayoutPredictor(artifact_path, device=device, num_threads=num_threads)

     # Predict all test png images
+    t0 = time.perf_counter()
+    img_counter = 0
     for img_fn in Path(img_dir).rglob("*.png"):
+        img_counter += 1
         logger.info("Predicting '%s'...", img_fn)
-        start_t = time.time()

         with Image.open(img_fn) as image:
             # Predict layout
+            img_t0 = time.perf_counter()
             preds = list(lpredictor.predict(image))
-            dt_ms = 1000 * (time.time() - start_t)
-            logger.debug("Time elapsed for prediction(ms): %s", dt_ms)
-
-            # Draw predictions
-            out_img = image.copy()
-            draw = ImageDraw.Draw(out_img)
-
-            for i, pred in enumerate(preds):
-                score = pred["confidence"]
-                label = pred["label"]
-                box = [
-                    round(pred["l"]),
-                    round(pred["t"]),
-                    round(pred["r"]),
-                    round(pred["b"]),
-                ]
-
-                # Draw bbox and label
-                draw.rectangle(
-                    box,
-                    outline="red",
-                )
-                draw.text(
-                    (box[0], box[1]),
-                    text=str(label),
-                    fill="blue",
-                )
-                logger.info("%s: [label|score|bbox] = ['%s' | %s | %s]", i, label, score, box)
-
-            save_fn = os.path.join(viz_dir, os.path.basename(img_fn))
-            out_img.save(save_fn)
-            logger.info("Saving prediction visualization in: '%s'", save_fn)
+            img_ms = 1000 * (time.perf_counter() - img_t0)
+            logger.debug("Prediction(ms): {:.2f}".format(img_ms))
+
+            # Save predictions
+            logger.info("Saving prediction visualization in: '%s'", viz_dir)
+            save_predictions("ST", viz_dir, img_fn, image, preds)
+    total_ms = 1000 * (time.perf_counter() - t0)
+    avg_ms = (total_ms / img_counter) if img_counter > 0 else 0
+    logger.info(
+        "For {} images(ms): [total|avg] = [{:.1f}|{:.1f}]".format(
+            img_counter, total_ms, avg_ms
+        )
+    )


 def main(args):
     r""" """
     num_threads = int(args.num_threads) if args.num_threads is not None else None
+    device = args.device.lower()
     img_dir = args.img_dir
     viz_dir = args.viz_dir

     # Initialize logger
+    logging.basicConfig(level=logging.DEBUG)
     logger = logging.getLogger("LayoutPredictor")
     logger.setLevel(logging.DEBUG)
     if not logger.hasHandlers():
@@ -96,11 +118,13 @@ def main(args):
     Path(viz_dir).mkdir(parents=True, exist_ok=True)

     # Download models from HF
-    download_path = snapshot_download(repo_id="ds4sd/docling-models")
-    artifact_path = os.path.join(download_path, "model_artifacts/layout/beehive_v0.0.5_pt")
+    download_path = snapshot_download(
+        repo_id="ds4sd/docling-models", revision="v2.1.0"
+    )
+    artifact_path = os.path.join(download_path, "model_artifacts/layout")

     # Test the LayoutPredictor
-    demo(logger, artifact_path, num_threads, img_dir, viz_dir)
+    demo(logger, artifact_path, device, num_threads, img_dir, viz_dir)


if __name__ == "__main__":
@@ -109,7 +133,10 @@ def main(args):
"""
parser = argparse.ArgumentParser(description="Test the LayoutPredictor")
parser.add_argument(
"-n", "--num_threads", required=False, default=None, help="Number of threads"
"-d", "--device", required=False, default="cpu", help="One of [cpu, cuda, mps]"
)
parser.add_argument(
"-n", "--num_threads", required=False, default=4, help="Number of threads"
)
parser.add_argument(
"-i",
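A usage note on the refactored demo: save_predictions() writes one text line per detection in the format "{prefix} {img_fn}: {label} - {bbox} - {confidence}" and draws the same boxes onto a copy of the page image. With the "ST" prefix used above, a hypothetical page would produce entries along these lines (values invented for illustration):

ST tests/page_01.png: Text - [12.3, 45.6, 320.0, 78.9] - 0.934
ST tests/page_01.png: Table - [40.0, 100.2, 560.5, 410.8] - 0.871

The annotated copy is saved as ST_page_01.png in the viz directory.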
140 changes: 74 additions & 66 deletions docling_ibm_models/layoutmodel/layout_predictor.py
@@ -2,6 +2,7 @@
 # Copyright IBM Corp. 2024 - 2024
 # SPDX-License-Identifier: MIT
 #
+import logging
 import os
 from collections.abc import Iterable
 from typing import Union
@@ -10,38 +11,30 @@
 import torch
 import torchvision.transforms as T
 from PIL import Image
+from transformers import RTDetrForObjectDetection, RTDetrImageProcessor

-MODEL_CHECKPOINT_FN = "model.pt"
-DEFAULT_NUM_THREADS = 4
+_log = logging.getLogger(__name__)


 class LayoutPredictor:
     r"""
-    Document layout prediction using torch
-    """
+    Document layout prediction using safe tensors
+    """

     def __init__(
-        self, artifact_path: str, num_threads: int = None, use_cpu_only: bool = False
+        self,
+        artifact_path: str,
+        device: str = "cpu",
+        num_threads: int = 4,
     ):
-        r"""
+        """
         Provide the artifact path that contains the LayoutModel file

-        The number of threads is decided, in the following order, by:
-        1. The init method parameter `num_threads`, if it is set.
-        2. The envvar "OMP_NUM_THREADS", if it is set.
-        3. The default value DEFAULT_NUM_THREADS.
-
-        The execution provided is decided, in the following order:
-        1. If the init method parameter `cpu_only` is True or the envvar "USE_CPU_ONLY" is set,
-           it uses the "CPUExecutionProvider".
-        3. Otherwise if the "CUDAExecutionProvider" is present, use:
-           ["CUDAExecutionProvider", "CPUExecutionProvider"]:
-
         Parameters
         ----------
         artifact_path: Path for the model torch file.
-        num_threads: (Optional) Number of threads to run the inference.
-        use_cpu_only: (Optional) If True, it forces CPU as the execution provider.
+        device: (Optional) device to run the inference.
+        num_threads: (Optional) Number of threads to run the inference if device = 'cpu'

         Raises
         ------
@@ -70,40 +63,51 @@ def __init__(
         }

         # Blacklisted classes
-        self._black_classes = set(["Form", "Key-Value Region"])
+        self._black_classes = set()  # ["Form", "Key-Value Region"])

         # Set basic params
-        self._threshold = 0.6  # Score threshold
+        self._threshold = 0.3  # Score threshold
         self._image_size = 640
         self._size = np.asarray([[self._image_size, self._image_size]], dtype=np.int64)
-        self._use_cpu_only = use_cpu_only or ("USE_CPU_ONLY" in os.environ)
-
-        # Model file
-        self._torch_fn = os.path.join(artifact_path, MODEL_CHECKPOINT_FN)
-        if not os.path.isfile(self._torch_fn):
-            raise FileNotFoundError("Missing torch file: {}".format(self._torch_fn))
-
-        # Get env vars
-        if num_threads is None:
-            num_threads = int(os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS))
+        # Set number of threads for CPU
+        self._device = torch.device(device)
         self._num_threads = num_threads
+        if device == "cpu":
+            torch.set_num_threads(self._num_threads)
+
+        # Model file and configurations
+        self._st_fn = os.path.join(artifact_path, "model.safetensors")
+        if not os.path.isfile(self._st_fn):
+            raise FileNotFoundError("Missing safe tensors file: {}".format(self._st_fn))
+
-        self.model = torch.jit.load(self._torch_fn)
+        # Load model and move to device
+        processor_config = os.path.join(artifact_path, "preprocessor_config.json")
+        model_config = os.path.join(artifact_path, "config.json")
+        self._image_processor = RTDetrImageProcessor.from_json_file(processor_config)
+        self._model = RTDetrForObjectDetection.from_pretrained(
+            artifact_path, config=model_config
+        ).to(self._device)
+        self._model.eval()
+
+        _log.debug("LayoutPredictor settings: {}".format(self.info()))

     def info(self) -> dict:
-        r"""
+        """
         Get information about the configuration of LayoutPredictor
         """
         info = {
-            "torch_file": self._torch_fn,
-            "use_cpu_only": self._use_cpu_only,
+            "safe_tensors_file": self._st_fn,
+            "device": self._device.type,
             "num_threads": self._num_threads,
             "image_size": self._image_size,
             "threshold": self._threshold,
         }
         return info

+    @torch.inference_mode()
     def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
-        r"""
+        """
         Predict bounding boxes for a given image.
         The origin (0, 0) is the top-left corner and the predicted bbox coords are provided as:
         [left, top, right, bottom]
@@ -128,40 +132,44 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
         else:
             raise TypeError("Not supported input image format")

+        resize = {"height": self._image_size, "width": self._image_size}
+        inputs = self._image_processor(
+            images=page_img,
+            return_tensors="pt",
+            size=resize,
+        ).to(self._device)
+        outputs = self._model(**inputs)
+        results = self._image_processor.post_process_object_detection(
+            outputs,
+            target_sizes=torch.tensor([page_img.size[::-1]]),
+            threshold=self._threshold,
+        )
+
         w, h = page_img.size
-        orig_size = torch.tensor([w, h])[None]

-        transforms = T.Compose(
-            [
-                T.Resize((640, 640)),
-                T.ToTensor(),
-            ]
-        )
-        img = transforms(page_img)[None]
-        # Predict
-        with torch.no_grad():
-            labels, boxes, scores = self.model(img, orig_size)
+        result = results[0]
+        for score, label_id, box in zip(
+            result["scores"], result["labels"], result["boxes"]
+        ):
+            score = float(score.item())
+
+            label_id = int(label_id.item()) + 1  # Advance the label_id
+            label_str = self._classes_map[label_id]

-        # Yield output
-        for label_idx, box, score in zip(labels[0], boxes[0], scores[0]):
             # Filter out blacklisted classes
-            label_idx = int(label_idx.item())
-            score = float(score.item())
-            label = self._classes_map[label_idx + 1]
-            if label in self._black_classes:
+            if label_str in self._black_classes:
                 continue

-            # Check against threshold
-            if score > self._threshold:
-                l = min(w, max(0, box[0]))
-                t = min(h, max(0, box[1]))
-                r = min(w, max(0, box[2]))
-                b = min(h, max(0, box[3]))
-                yield {
-                    "l": l,
-                    "t": t,
-                    "r": r,
-                    "b": b,
-                    "label": label,
-                    "confidence": score,
-                }
+            bbox_float = [float(b.item()) for b in box]
+            l = min(w, max(0, bbox_float[0]))
+            t = min(h, max(0, bbox_float[1]))
+            r = min(w, max(0, bbox_float[2]))
+            b = min(h, max(0, bbox_float[3]))
+            yield {
+                "l": l,
+                "t": t,
+                "r": r,
+                "b": b,
+                "label": label_str,
+                "confidence": score,
+            }
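To close out this file: predict() now resizes the page through the RT-DETR image processor, runs the HF model, applies the 0.3 threshold during post-processing, and yields one dict per detection with top-left-origin pixel coordinates clipped to the page bounds. A minimal consumption sketch (paths are placeholders; the artifact layout matches the demo above):

from PIL import Image

from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor

predictor = LayoutPredictor("model_artifacts/layout", device="cpu", num_threads=4)
with Image.open("page_01.png") as img:
    for pred in predictor.predict(img):
        # Each dict carries l/t/r/b pixel coords, a label string and a confidence
        print(pred["label"], round(pred["confidence"], 3),
              (pred["l"], pred["t"], pred["r"], pred["b"]))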