Skip to content

Commit

Permalink
Pydantic field validator and comment restored.
Browse files Browse the repository at this point in the history
Signed-off-by: ahn <[email protected]>
  • Loading branch information
ahn committed Jan 31, 2025
1 parent ba57e44 commit 390987d
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 9 deletions.
11 changes: 10 additions & 1 deletion docling/datamodel/pipeline_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class AcceleratorOptions(BaseSettings):
num_threads: int = 4
device: str = "auto"

@validator("device")
@field_validator("device")
def validate_device(cls, value):
# "auto", "cpu", "cuda", "mps", or "cuda:N"
if value in {d.value for d in AcceleratorDevice} or re.match(
Expand All @@ -55,6 +55,15 @@ def validate_device(cls, value):
@model_validator(mode="before")
@classmethod
def check_alternative_envvars(cls, data: Any) -> Any:
r"""
Set num_threads from the "alternative" envvar OMP_NUM_THREADS.
The alternative envvar is used only if it is valid and the regular envvar is not set.
Notice: The standard pydantic settings mechanism with parameter "aliases" does not provide
the same functionality. In case the alias envvar is set and the user tries to override the
parameter in settings initialization, Pydantic treats the parameter provided in __init__()
        as an extra input instead of simply overwriting the envvar value for that parameter.
"""
if isinstance(data, dict):
input_num_threads = data.get("num_threads")
if input_num_threads is None:
Expand Down
13 changes: 13 additions & 0 deletions docling/models/easyocr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
_log = logging.getLogger(__name__)


def unwrap_dataparallel(model):
    """Return the bare module underneath a ``torch.nn.DataParallel`` wrapper.

    If *model* is not wrapped, it is returned unchanged. Used before calling
    ``.to(device)`` on the reader's detector/recognizer, since the move must
    target the underlying module rather than the wrapper.
    """
    return model.module if isinstance(model, torch.nn.DataParallel) else model


class EasyOcrModel(BaseOcrModel):
def __init__(
self,
Expand Down Expand Up @@ -51,6 +57,7 @@ def __init__(
for x in [
AcceleratorDevice.CUDA.value,
AcceleratorDevice.MPS.value,
"cuda:",
]
]
)
Expand All @@ -71,6 +78,12 @@ def __init__(
verbose=False,
)

self.reader.device = device
self.reader.detector = unwrap_dataparallel(self.reader.detector).to(device)
self.reader.recognizer = unwrap_dataparallel(self.reader.recognizer).to(
device
)

def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
Expand Down
14 changes: 6 additions & 8 deletions docs/examples/run_with_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ def main():
# accelerator_options = AcceleratorOptions(
# num_threads=8, device=AcceleratorDevice.AUTO
# )
accelerator_options = AcceleratorOptions(
num_threads=8, device=AcceleratorDevice.CPU
)
# accelerator_options = AcceleratorOptions(
# num_threads=8, device=AcceleratorDevice.CPU
# )
# accelerator_options = AcceleratorOptions(
# num_threads=8, device=AcceleratorDevice.MPS
# )
Expand All @@ -31,9 +31,7 @@ def main():
# )

# easyocr doesn't support cuda:N allocation
# accelerator_options = AcceleratorOptions(
# num_threads=8, device="cuda:1"
# )
accelerator_options = AcceleratorOptions(num_threads=8, device="cuda:0")

pipeline_options = PdfPipelineOptions()
pipeline_options.accelerator_options = accelerator_options
Expand All @@ -59,8 +57,8 @@ def main():
# List with total time per document
doc_conversion_secs = conversion_result.timings["pipeline_total"].times

md = doc.export_to_markdown()
print(md)
# md = doc.export_to_markdown()
# print(md)
print(f"Conversion secs: {doc_conversion_secs}")


Expand Down

0 comments on commit 390987d

Please sign in to comment.