diff --git a/docling_ibm_models/layoutmodel/layout_predictor.py b/docling_ibm_models/layoutmodel/layout_predictor.py index df7851e..f9e0301 100644 --- a/docling_ibm_models/layoutmodel/layout_predictor.py +++ b/docling_ibm_models/layoutmodel/layout_predictor.py @@ -5,7 +5,7 @@ import logging import os from collections.abc import Iterable -from typing import Union +from typing import Set, Union import numpy as np import torch @@ -26,6 +26,8 @@ def __init__( artifact_path: str, device: str = "cpu", num_threads: int = 4, + base_treshold: float = 0.3, + blacklist_classes: Set[str] = set(), ): """ Provide the artifact path that contains the LayoutModel file @@ -63,10 +65,10 @@ def __init__( } # Blacklisted classes - self._black_classes = set() # ["Form", "Key-Value Region"]) + self._black_classes = blacklist_classes # set(["Form", "Key-Value Region"]) # Set basic params - self._threshold = 0.3 # Score threshold + self._threshold = base_treshold # Score threshold self._image_size = 640 self._size = np.asarray([[self._image_size, self._image_size]], dtype=np.int64)