11# -*- coding: utf-8 -*-
22
33import json
4+ import jsons
45import tempfile
56from pathlib import Path
67from typing import List
1314 HTTPRequest ,
1415 RequestChain ,
1516)
16- from indico .errors import IndicoNotFound
17+ from indico .errors import IndicoNotFound , IndicoInputError
1718from indico .queries .storage import UploadBatched , UploadImages
18- from indico .types .dataset import Dataset
19+ from indico .types .dataset import Dataset , OcrEngine , OmnipageOcrOptionsInput , ReadApiOcrOptionsInput , OcrInputLanguage
1920
2021
2122class ListDatasets (GraphQLRequest ):
@@ -184,14 +185,17 @@ class CreateDataset(RequestChain):
184185 previous = None
185186
186187 def __init__ (
187- self ,
188- name : str ,
189- files : List [str ],
190- wait : bool = True ,
191- dataset_type : str = "TEXT" ,
192- from_local_images : bool = False ,
193- image_filename_col : str = "filename" ,
194- batch_size : int = 20 ,
188+ self ,
189+ name : str ,
190+ files : List [str ],
191+ wait : bool = True ,
192+ dataset_type : str = "TEXT" ,
193+ from_local_images : bool = False ,
194+ image_filename_col : str = "filename" ,
195+ batch_size : int = 20 ,
196+ ocr_engine : OcrEngine = None ,
197+ omnipage_ocr_options : OmnipageOcrOptionsInput = None ,
198+ read_api_ocr_options : ReadApiOcrOptionsInput = None
195199 ):
196200 self .files = files
197201 self .name = name
@@ -200,6 +204,8 @@ def __init__(
200204 self .from_local_images = from_local_images
201205 self .image_filename_col = image_filename_col
202206 self .batch_size = batch_size
207+ if omnipage_ocr_options is not None and read_api_ocr_options is not None :
208+ raise IndicoInputError ("Must supply either omnipage or readapi options but not both." )
203209 super ().__init__ ()
204210
205211 def requests (self ):
@@ -235,7 +241,7 @@ def requests(self):
235241 yield GetDatasetFileStatus (id = dataset_id )
236242 debouncer = Debouncer ()
237243 while not all (
238- f .status in ["DOWNLOADED" , "FAILED" ] for f in self .previous .files
244+ f .status in ["DOWNLOADED" , "FAILED" ] for f in self .previous .files
239245 ):
240246 yield GetDatasetFileStatus (id = self .previous .id )
241247 debouncer .backoff ()
@@ -250,7 +256,7 @@ def requests(self):
250256 debouncer = Debouncer ()
251257 if self .wait is True :
252258 while not all (
253- [f .status in ["PROCESSED" , "FAILED" ] for f in self .previous .files ]
259+ [f .status in ["PROCESSED" , "FAILED" ] for f in self .previous .files ]
254260 ):
255261 yield GetDatasetFileStatus (id = dataset_id )
256262 debouncer .backoff ()
@@ -295,20 +301,32 @@ def process_response(self, response):
295301
296302class CreateEmptyDataset (GraphQLRequest ):
297303 query = """
298- mutation($name: String!, $datasetType: DatasetType) {
299- createDataset(name: $name, datasetType: $datasetType) {
304+ mutation($name: String!, $datasetType: DatasetType, $config: DataConfigInput ) {
305+ createDataset(name: $name, datasetType: $datasetType, config: $config ) {
300306 id
301307 name
302308 }
303309 }
304310 """
305311
def __init__(self, name: str, dataset_type: str = None, ocr_engine: OcrEngine = None,
             omnipage_ocr_options: OmnipageOcrOptionsInput = None,
             readapi_ocr_options: ReadApiOcrOptionsInput = None):
    """
    Build the createDataset mutation request.

    Args:
        name (str): Name sent as the ``name`` variable.
        dataset_type (str): Sent as ``datasetType``; any falsy value is
            replaced with "TEXT".
        ocr_engine (OcrEngine): When given, its ``name`` is sent inside
            ``config.ocrOptions``. When None, ``config`` is built from None
            and the two option arguments below are not sent.
        omnipage_ocr_options (OmnipageOcrOptionsInput): Placed under
            ``omnipageOptions`` when ``ocr_engine`` is given.
        readapi_ocr_options (ReadApiOcrOptionsInput): Placed under
            ``readapiOptions`` when ``ocr_engine`` is given.
    """
    ocr_config = None
    if ocr_engine is not None:
        ocr_options = {
            "ocrEngine": ocr_engine.name,
            "omnipageOptions": omnipage_ocr_options,
            "readapiOptions": readapi_ocr_options,
        }
        ocr_config = {"ocrOptions": ocr_options}

    # Serialise the option objects into plain data; keys are camel-cased and
    # null entries are stripped by jsons.
    serialized_config = jsons.dump(
        ocr_config,
        key_transformer=jsons.KEY_TRANSFORMER_CAMELCASE,
        strip_nulls=True,
    )

    variables = {
        "name": name,
        "datasetType": dataset_type or "TEXT",
        "config": serialized_config,
    }
    super().__init__(self.query, variables=variables)
314332 def process_response (self , response ):
@@ -324,7 +342,6 @@ class _AddFiles(GraphQLRequest):
324342 }
325343 }
326344 """
327-
328345
329346 def __init__ (self , dataset_id : int , metadata : List [str ]):
330347 super ().__init__ (
@@ -358,11 +375,11 @@ class AddFiles(RequestChain):
358375 previous = None
359376
360377 def __init__ (
361- self ,
362- dataset_id : int ,
363- files : List [str ],
364- wait : bool = True ,
365- batch_size : int = 20 ,
378+ self ,
379+ dataset_id : int ,
380+ files : List [str ],
381+ wait : bool = True ,
382+ batch_size : int = 20 ,
366383 ):
367384 self .dataset_id = dataset_id
368385 self .files = files
@@ -380,8 +397,8 @@ def requests(self):
380397 yield GetDatasetFileStatus (id = self .dataset_id )
381398 debouncer = Debouncer ()
382399 while not all (
383- f .status in ["DOWNLOADED" , "FAILED" , "PROCESSED" ]
384- for f in self .previous .files
400+ f .status in ["DOWNLOADED" , "FAILED" , "PROCESSED" ]
401+ for f in self .previous .files
385402 ):
386403 yield GetDatasetFileStatus (id = self .previous .id )
387404 debouncer .backoff ()
@@ -448,10 +465,10 @@ class ProcessFiles(RequestChain):
448465 """
449466
450467 def __init__ (
451- self ,
452- dataset_id : int ,
453- datafile_ids : List [int ],
454- wait : bool = True ,
468+ self ,
469+ dataset_id : int ,
470+ datafile_ids : List [int ],
471+ wait : bool = True ,
455472 ):
456473 self .dataset_id = dataset_id
457474 self .datafile_ids = datafile_ids
@@ -463,7 +480,7 @@ def requests(self):
463480 yield GetDatasetFileStatus (id = self .dataset_id )
464481 if self .wait :
465482 while not all (
466- f .status in ["PROCESSED" , "FAILED" ] for f in self .previous .files
483+ f .status in ["PROCESSED" , "FAILED" ] for f in self .previous .files
467484 ):
468485 yield GetDatasetFileStatus (id = self .dataset_id )
469486 debouncer .backoff ()
@@ -497,7 +514,36 @@ def requests(self):
497514 yield GetDatasetFileStatus (id = self .dataset_id )
498515 if self .wait :
499516 while not all (
500- f .status in ["PROCESSED" , "FAILED" ] for f in self .previous .files
517+ f .status in ["PROCESSED" , "FAILED" ] for f in self .previous .files
501518 ):
502519 yield GetDatasetFileStatus (id = self .dataset_id )
503520 debouncer .backoff ()
521+
522+
class GetOcrEngineLanguageCodes(GraphQLRequest):
    """
    Fetches and lists the available languages by name and code for the given OCR Engine

    Args:
        engine(OcrEngine): The engine to fetch for.

    Returns:
        A list of OcrInputLanguage, one per language entry the response lists
        for the requested engine.

    Raises:
        IndicoInputError: If the response contains no engine whose name
            matches ``engine.name``.
    """

    query = """query{
        ocrOptions {
            engines{
                name
                languages {
                    name
                    code
                }
            }
        }
    }"""

    def __init__(self, engine: OcrEngine):
        self.engine = engine
        super().__init__(self.query)

    def process_response(self, response):
        engines = super().process_response(response)["ocrOptions"]["engines"]
        for entry in engines:
            if entry["name"] == self.engine.name:
                return [OcrInputLanguage(**option) for option in entry["languages"]]
        # An explicit error replaces the bare StopIteration that an unguarded
        # next() would leak (and which turns into RuntimeError inside a generator).
        raise IndicoInputError(
            f"OCR engine {self.engine.name} was not found among the engines "
            "returned by the server."
        )
0 commit comments