
Commit

fix(LayoutModel): Set the device to "auto" by default. Improve logging. Improve demo_layout_predictor

Signed-off-by: Nikos Livathinos <[email protected]>
nikos-livathinos committed Nov 18, 2024
1 parent 68475aa commit 21426da
Showing 3 changed files with 49 additions and 37 deletions.
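In caller terms, the change means the predictor now picks its execution device automatically unless one is passed explicitly. A minimal usage sketch (the artifact path and image file are placeholders; predict() is assumed to match the interface shown in layout_predictor_jit.py below):

from PIL import Image

from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor

# device="auto" is now the default: cuda, then mps, then cpu, whichever is available.
predictor = LayoutPredictor("/path/to/model_artifacts", device="auto", num_threads=4)
print(predictor.info())  # the same settings are now also logged at DEBUG level

# predict() yields one dict per detected layout element.
page = Image.open("page_0.png").convert("RGB")
for pred in predictor.predict(page):
    print(pred)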
17 changes: 12 additions & 5 deletions demo/demo_layout_predictor.py
@@ -57,6 +57,7 @@ def save_predictions(prefix: str, viz_dir: str, img_fn: str, img, predictions: d
def demo(
logger: logging.Logger,
artifact_path: str,
device: str,
num_threads: int,
img_dir: str,
viz_dir: str,
@@ -68,8 +69,7 @@ def demo(
pdf_image = pyvips.Image.new_from_file("test_data/ADS.2007.page_123.pdf", page=0)
"""
# Create the layout predictor
lpredictor = LayoutPredictor(artifact_path, num_threads=num_threads)
logger.info("LayoutPredictor settings: {}".format(lpredictor.info()))
lpredictor = LayoutPredictor(artifact_path, device=device, num_threads=num_threads)

# Predict all test png images
t0 = time.perf_counter()
@@ -87,6 +87,9 @@ def demo(

# Save predictions
logger.info("Saving prediction visualization in: '%s'", viz_dir)

# TODO: Switch LayoutModel implementations
# save_predictions("JIT", viz_dir, img_fn, image, preds)
save_predictions("ST", viz_dir, img_fn, image, preds)
total_ms = 1000 * (time.perf_counter() - t0)
avg_ms = (total_ms / img_counter) if img_counter > 0 else 0
@@ -100,10 +103,12 @@ def main(args):
def main(args):
r""" """
num_threads = int(args.num_threads) if args.num_threads is not None else None
device = args.device.lower()
img_dir = args.img_dir
viz_dir = args.viz_dir

# Initialize logger
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("LayoutPredictor")
logger.setLevel(logging.DEBUG)
if not logger.hasHandlers():
@@ -122,19 +127,21 @@ def main(args):
# download_path = snapshot_download(repo_id="ds4sd/docling-models")
# artifact_path = os.path.join(download_path, "model_artifacts/layout/beehive_v0.0.5_pt")

# os.environ["TORCH_DEVICE"] = "cpu"
# artifact_path = "/Users/nli/data/models/layout_model/online_docling_models/v2.0.1"
# artifact_path = "/Users/nli/model_weights/docling/layout_model/online_docling_models/v2.0.1"
artifact_path = "/Users/nli/model_weights/docling/layout_model/safe_tensors"

# Test the LayoutPredictor
demo(logger, artifact_path, num_threads, img_dir, viz_dir)
demo(logger, artifact_path, device, num_threads, img_dir, viz_dir)


if __name__ == "__main__":
r"""
python -m demo.demo_layout_predictor -i <images_dir>
"""
parser = argparse.ArgumentParser(description="Test the LayoutPredictor")
parser.add_argument(
"-d", "--device", required=False, default="cpu", help="One of [cpu, cuda, mps]"
)
parser.add_argument(
"-n", "--num_threads", required=False, default=None, help="Number of threads"
)
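With the new -d/--device option, the invocation shown in the module docstring presumably extends along these lines (flag values per the argparse help above; -d still defaults to "cpu" in the demo):

python -m demo.demo_layout_predictor -i <images_dir> -d mps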
15 changes: 10 additions & 5 deletions docling_ibm_models/layoutmodel/layout_predictor.py
@@ -2,6 +2,7 @@
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import logging
import os
from collections.abc import Iterable
from typing import Union
@@ -12,22 +13,24 @@
from PIL import Image
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor

_log = logging.getLogger(__name__)


class LayoutPredictor:
"""
Document layout prediction using safe tensors
"""

def __init__(self, artifact_path: str, device: str = "cpu", num_threads: int = 4):
def __init__(self, artifact_path: str, device: str = "auto", num_threads: int = 4):
"""
Provide the artifact path that contains the LayoutModel file
Parameters
----------
artifact_path: Path for the model torch file.
device: (Optional) Device to run the inference.
It should be one of: ["cpu", "cuda", "mps"].
Otherwise the best available device is selected
device: (Optional) Device to run the inference. One of: ["cpu", "cuda", "mps", "auto"].
When it is "auto", the best available device is selected.
Default value is "auto"
num_threads: (Optional) Number of threads to run the inference when the device is "cpu".
Raises
Expand Down Expand Up @@ -64,7 +67,7 @@ def __init__(self, artifact_path: str, device: str = "cpu", num_threads: int = 4
self._image_size = 640
self._size = np.asarray([[self._image_size, self._image_size]], dtype=np.int64)

# Set device based on env var or availability
# Set device based on init parameter or availability
device_name = device.lower()
if device_name in ["cuda", "mps", "cpu"]:
self._device = torch.device(device_name)
@@ -94,6 +97,8 @@ def __init__(self, artifact_path: str, device: str = "cpu", num_threads: int = 4
)
self._model.eval()

_log.debug("LayoutPredictor settings: {}".format(self.info()))

def info(self) -> dict:
"""
Get information about the configuration of LayoutPredictor
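The rest of the device-selection block in this file is collapsed in the view above. Judging from the parallel edit in layout_predictor_jit.py below, the fallback presumably continues as in this standalone sketch (select_device is a hypothetical helper name, not part of the file):

import torch

def select_device(device: str = "auto") -> torch.device:
    # An explicit "cpu"/"cuda"/"mps" wins; otherwise prefer cuda, then mps, then cpu.
    device_name = device.lower()
    if device_name in ["cuda", "mps", "cpu"]:
        return torch.device(device_name)
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")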
54 changes: 27 additions & 27 deletions docling_ibm_models/layoutmodel/layout_predictor_jit.py
@@ -2,6 +2,7 @@
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import logging
import os
from collections.abc import Iterable
from typing import Union
@@ -11,6 +12,9 @@
import torchvision.transforms as T
from PIL import Image

_log = logging.getLogger(__name__)


MODEL_CHECKPOINT_FN = "model.pt"
DEFAULT_NUM_THREADS = 4

@@ -20,22 +24,20 @@ class LayoutPredictor:
Document layout prediction using torch
"""

def __init__(self, artifact_path: str, num_threads: int = None):
def __init__(
self, artifact_path: str, device: str = "auto", num_threads: int = None
):
"""
Provide the artifact path that contains the LayoutModel file
The number of threads is decided, in the following order, by:
1. The init method parameter `num_threads`, if it is set.
2. The envvar "OMP_NUM_THREADS", if it is set.
3. The default value DEFAULT_NUM_THREADS.
The execution device is decided by the env var "TORCH_DEVICE" with values:
'cpu', 'cuda', or 'mps'. If not set, automatically selects the best available device.
Parameters
----------
artifact_path: Path for the model torch file.
num_threads: (Optional) Number of threads to run the inference.
device: (Optional) Device to run the inference. One of: ["cpu", "cuda", "mps", "auto"].
When it is "auto", the best available device is selected.
Default value is "auto"
num_threads: (Optional) Number of threads to run the inference when the device is "cpu".
Raises
------
@@ -71,42 +73,40 @@ def __init__(self, artifact_path: str, num_threads: int = None):
self._image_size = 640
self._size = np.asarray([[self._image_size, self._image_size]], dtype=np.int64)

# Set device based on env var or availability
device_name = os.environ.get("TORCH_DEVICE", "").lower()
# Set device based on init parameter or availability
device_name = device.lower()
if device_name in ["cuda", "mps", "cpu"]:
self.device = torch.device(device_name)
self._device = torch.device(device_name)
elif torch.cuda.is_available():
self.device = torch.device("cuda")
self._device = torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
self.device = torch.device("mps")
self._device = torch.device("mps")
else:
self.device = torch.device("cpu")
self._device = torch.device("cpu")

# Model file
self._torch_fn = os.path.join(artifact_path, MODEL_CHECKPOINT_FN)
if not os.path.isfile(self._torch_fn):
raise FileNotFoundError("Missing torch file: {}".format(self._torch_fn))

# Set number of threads for CPU
if self.device.type == "cpu":
if num_threads is None:
num_threads = int(
os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS)
)
if self._device.type == "cpu":
self._num_threads = num_threads
torch.set_num_threads(self._num_threads)

# Load model and move to device
self.model = torch.jit.load(self._torch_fn, map_location=self.device)
self.model.eval()
self._model = torch.jit.load(self._torch_fn, map_location=self._device)
self._model.eval()

_log.debug("LayoutPredictor settings: {}".format(self.info()))

def info(self) -> dict:
"""
Get information about the configuration of LayoutPredictor
"""
info = {
"torch_file": self._torch_fn,
"device": str(self.device),
"device": str(self._device),
"image_size": self._image_size,
"threshold": self._threshold,
}
@@ -140,18 +140,18 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
raise TypeError("Not supported input image format")

w, h = page_img.size
orig_size = torch.tensor([w, h], device=self.device)[None]
orig_size = torch.tensor([w, h], device=self._device)[None]

transforms = T.Compose(
[
T.Resize((640, 640)),
T.ToTensor(),
]
)
img = transforms(page_img)[None].to(self.device)
img = transforms(page_img)[None].to(self._device)

# Predict
labels, boxes, scores = self.model(img, orig_size)
labels, boxes, scores = self._model(img, orig_size)

# Yield output
for label_idx, box, score in zip(labels[0], boxes[0], scores[0]):
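Because the settings dump moved from the demo's logger.info() call into the predictors' module-level logger at DEBUG level, callers outside the demo would presumably opt in with something like this (logger name inferred from the file paths in this commit):

import logging

logging.basicConfig(level=logging.DEBUG)
# or, more narrowly:
logging.getLogger("docling_ibm_models.layoutmodel").setLevel(logging.DEBUG)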
