Skip to content

Commit c63a15d

Browse files
authored
pluggable ocr dataset options (#148)
1 parent fff6808 commit c63a15d

File tree

3 files changed

+163
-32
lines changed

3 files changed

+163
-32
lines changed

indico/queries/datasets.py

Lines changed: 77 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# -*- coding: utf-8 -*-
22

33
import json
4+
import jsons
45
import tempfile
56
from pathlib import Path
67
from typing import List
@@ -13,9 +14,9 @@
1314
HTTPRequest,
1415
RequestChain,
1516
)
16-
from indico.errors import IndicoNotFound
17+
from indico.errors import IndicoNotFound, IndicoInputError
1718
from indico.queries.storage import UploadBatched, UploadImages
18-
from indico.types.dataset import Dataset
19+
from indico.types.dataset import Dataset, OcrEngine, OmnipageOcrOptionsInput, ReadApiOcrOptionsInput, OcrInputLanguage
1920

2021

2122
class ListDatasets(GraphQLRequest):
@@ -184,14 +185,17 @@ class CreateDataset(RequestChain):
184185
previous = None
185186

186187
def __init__(
187-
self,
188-
name: str,
189-
files: List[str],
190-
wait: bool = True,
191-
dataset_type: str = "TEXT",
192-
from_local_images: bool = False,
193-
image_filename_col: str = "filename",
194-
batch_size: int = 20,
188+
self,
189+
name: str,
190+
files: List[str],
191+
wait: bool = True,
192+
dataset_type: str = "TEXT",
193+
from_local_images: bool = False,
194+
image_filename_col: str = "filename",
195+
batch_size: int = 20,
196+
ocr_engine: OcrEngine = None,
197+
omnipage_ocr_options: OmnipageOcrOptionsInput = None,
198+
read_api_ocr_options: ReadApiOcrOptionsInput = None
195199
):
196200
self.files = files
197201
self.name = name
@@ -200,6 +204,8 @@ def __init__(
200204
self.from_local_images = from_local_images
201205
self.image_filename_col = image_filename_col
202206
self.batch_size = batch_size
207+
if omnipage_ocr_options is not None and read_api_ocr_options is not None:
208+
raise IndicoInputError("Must supply either omnipage or readapi options but not both.")
203209
super().__init__()
204210

205211
def requests(self):
@@ -235,7 +241,7 @@ def requests(self):
235241
yield GetDatasetFileStatus(id=dataset_id)
236242
debouncer = Debouncer()
237243
while not all(
238-
f.status in ["DOWNLOADED", "FAILED"] for f in self.previous.files
244+
f.status in ["DOWNLOADED", "FAILED"] for f in self.previous.files
239245
):
240246
yield GetDatasetFileStatus(id=self.previous.id)
241247
debouncer.backoff()
@@ -250,7 +256,7 @@ def requests(self):
250256
debouncer = Debouncer()
251257
if self.wait is True:
252258
while not all(
253-
[f.status in ["PROCESSED", "FAILED"] for f in self.previous.files]
259+
[f.status in ["PROCESSED", "FAILED"] for f in self.previous.files]
254260
):
255261
yield GetDatasetFileStatus(id=dataset_id)
256262
debouncer.backoff()
@@ -295,20 +301,32 @@ def process_response(self, response):
295301

296302
class CreateEmptyDataset(GraphQLRequest):
297303
query = """
298-
mutation($name: String!, $datasetType: DatasetType) {
299-
createDataset(name: $name, datasetType: $datasetType) {
304+
mutation($name: String!, $datasetType: DatasetType, $config: DataConfigInput) {
305+
createDataset(name: $name, datasetType: $datasetType, config: $config ) {
300306
id
301307
name
302308
}
303309
}
304310
"""
305311

306-
def __init__(self, name: str, dataset_type: str = None):
312+
def __init__(self, name: str, dataset_type: str = None, ocr_engine: OcrEngine = None,
313+
omnipage_ocr_options: OmnipageOcrOptionsInput = None,
314+
readapi_ocr_options: ReadApiOcrOptionsInput = None):
307315
if not dataset_type:
308316
dataset_type = "TEXT"
309-
317+
config = None
318+
if ocr_engine is not None:
319+
config = {
320+
"ocrOptions": {
321+
"ocrEngine": ocr_engine.name,
322+
"omnipageOptions": omnipage_ocr_options,
323+
"readapiOptions": readapi_ocr_options
324+
}
325+
}
310326
super().__init__(
311-
self.query, variables={"name": name, "datasetType": dataset_type}
327+
self.query, variables={"name": name, "datasetType": dataset_type,
328+
"config": jsons.dump(config, key_transformer=jsons.KEY_TRANSFORMER_CAMELCASE,
329+
strip_nulls=True)}
312330
)
313331

314332
def process_response(self, response):
@@ -324,7 +342,6 @@ class _AddFiles(GraphQLRequest):
324342
}
325343
}
326344
"""
327-
328345

329346
def __init__(self, dataset_id: int, metadata: List[str]):
330347
super().__init__(
@@ -358,11 +375,11 @@ class AddFiles(RequestChain):
358375
previous = None
359376

360377
def __init__(
361-
self,
362-
dataset_id: int,
363-
files: List[str],
364-
wait: bool = True,
365-
batch_size: int = 20,
378+
self,
379+
dataset_id: int,
380+
files: List[str],
381+
wait: bool = True,
382+
batch_size: int = 20,
366383
):
367384
self.dataset_id = dataset_id
368385
self.files = files
@@ -380,8 +397,8 @@ def requests(self):
380397
yield GetDatasetFileStatus(id=self.dataset_id)
381398
debouncer = Debouncer()
382399
while not all(
383-
f.status in ["DOWNLOADED", "FAILED", "PROCESSED"]
384-
for f in self.previous.files
400+
f.status in ["DOWNLOADED", "FAILED", "PROCESSED"]
401+
for f in self.previous.files
385402
):
386403
yield GetDatasetFileStatus(id=self.previous.id)
387404
debouncer.backoff()
@@ -448,10 +465,10 @@ class ProcessFiles(RequestChain):
448465
"""
449466

450467
def __init__(
451-
self,
452-
dataset_id: int,
453-
datafile_ids: List[int],
454-
wait: bool = True,
468+
self,
469+
dataset_id: int,
470+
datafile_ids: List[int],
471+
wait: bool = True,
455472
):
456473
self.dataset_id = dataset_id
457474
self.datafile_ids = datafile_ids
@@ -463,7 +480,7 @@ def requests(self):
463480
yield GetDatasetFileStatus(id=self.dataset_id)
464481
if self.wait:
465482
while not all(
466-
f.status in ["PROCESSED", "FAILED"] for f in self.previous.files
483+
f.status in ["PROCESSED", "FAILED"] for f in self.previous.files
467484
):
468485
yield GetDatasetFileStatus(id=self.dataset_id)
469486
debouncer.backoff()
@@ -497,7 +514,36 @@ def requests(self):
497514
yield GetDatasetFileStatus(id=self.dataset_id)
498515
if self.wait:
499516
while not all(
500-
f.status in ["PROCESSED", "FAILED"] for f in self.previous.files
517+
f.status in ["PROCESSED", "FAILED"] for f in self.previous.files
501518
):
502519
yield GetDatasetFileStatus(id=self.dataset_id)
503520
debouncer.backoff()
521+
522+
523+
class GetOcrEngineLanguageCodes(GraphQLRequest):
524+
"""
525+
Fetches and lists the available languages by name and code for the given OCR Engine
526+
527+
Args:
528+
ocr_engine(OcrEngine): The engine to fetch for.
529+
"""
530+
query = """query{
531+
ocrOptions {
532+
engines{
533+
name
534+
languages {
535+
name
536+
code
537+
}
538+
}
539+
}
540+
}"""
541+
542+
def __init__(self, engine: OcrEngine):
543+
self.engine = engine
544+
super().__init__(self.query)
545+
546+
def process_response(self, response):
547+
data = super().process_response(response)["ocrOptions"]["engines"]
548+
engine_laguages = next(x["languages"] for x in data if x["name"] == self.engine.name)
549+
return [OcrInputLanguage(**option) for option in engine_laguages]

indico/types/dataset.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from enum import Enum
12
from typing import List
23

34
from indico.types.base import BaseType
@@ -58,3 +59,69 @@ def labelset_by_name(self, name: str) -> LabelSet:
5859

5960
def datacolumn_by_name(self, name: str) -> DataColumn:
6061
return next(l for l in self.datacolumns if l.name == name)
62+
63+
64+
class TableReadOrder(Enum):
65+
ROW = 0
66+
COLUMN = 1
67+
68+
class OcrEngine(Enum):
69+
"""
70+
Enum representing available OCR engines.
71+
"""
72+
OMNIPAGE = 0
73+
READAPI = 1
74+
pass
75+
76+
class OmnipageOcrOptionsInput(BaseType):
77+
"""
78+
Omnipage specific OCR options for dataset creation.
79+
80+
Args:
81+
auto_rotate(bool): auto rotate.
82+
single_colum(bool): Read table as a single column.
83+
upscale_images(bool): Scale up low-resolution images.
84+
languages(List[OmnipageLanguageCode]): List of languages to use in ocr.
85+
cells(bool): Return table information for post-processing rules
86+
force_render(bool): Force rednering.
87+
native_layout(bool): Native layout.
88+
native_pdf(bool): Native pdf.
89+
table_read_order(TableReadOrder): Read table by row or column.
90+
91+
"""
92+
auto_rotate: bool
93+
single_column: bool
94+
upscale_images: bool
95+
languages: List[str]
96+
cells: bool
97+
force_render: bool
98+
native_layout: bool
99+
native_pdf: bool
100+
table_read_order: TableReadOrder
101+
102+
class ReadApiOcrOptionsInput(BaseType):
103+
"""
104+
Read API OCR options.
105+
106+
Args:
107+
auto_rotate(bool): Auto rotate
108+
single_column(bool): Read table as a single column.
109+
upscale_images(bool): Scale up low resolution images.
110+
languages(List[str]): List of languages to use.
111+
"""
112+
auto_rotate: bool
113+
single_column: bool
114+
upscale_images: bool
115+
languages: List[str]
116+
117+
class OcrInputLanguage(BaseType):
118+
name: str
119+
code: str
120+
121+
class OcrOptionsInput():
122+
"""
123+
Input options for OCR engine.
124+
"""
125+
ocr_engine: OcrEngine
126+
omnipage_options: OmnipageOcrOptionsInput
127+
readapi_options: ReadApiOcrOptionsInput

tests/integration/queries/test_dataset.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
ProcessCSV,
1818
)
1919
from indico.queries.export import CreateExport, DownloadExport
20-
from indico.types.dataset import Dataset
20+
from indico.types.dataset import Dataset, OmnipageOcrOptionsInput, TableReadOrder, OcrEngine
2121
from indico.errors import IndicoRequestError
2222
from tests.integration.data.datasets import airlines_dataset
2323

@@ -175,6 +175,24 @@ def _dataset_complete(dataset):
175175
assert dataset.status == "COMPLETE"
176176

177177

178+
def test_create_with_options(indico):
179+
client = IndicoClient()
180+
config: OmnipageOcrOptionsInput = {
181+
"auto_rotate": True,
182+
"single_column": True,
183+
"upscale_images": True,
184+
"languages": ["ENG", "FIN"],
185+
"force_render": False,
186+
"native_layout": False,
187+
"native_pdf": False,
188+
"table_read_order": TableReadOrder.ROW
189+
}
190+
dataset = client.call(CreateEmptyDataset(name=f"dataset-{int(time.time())}", ocr_engine=OcrEngine.OMNIPAGE,
191+
omnipage_ocr_options=config))
192+
193+
194+
195+
178196
def test_create_from_files_document(indico):
179197
client = IndicoClient()
180198
dataset = client.call(CreateEmptyDataset(name=f"dataset-{int(time.time())}"))

0 commit comments

Comments
 (0)