Skip to content

Commit

Permalink
Merge pull request #1567 from myhloli/dev
Browse files Browse the repository at this point in the history
refactor(table): add device configuration for Unitable model
  • Loading branch information
myhloli authored Jan 17, 2025
2 parents fd5427a + e64d4fe commit af3ec55
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from rapid_table import RapidTable, RapidTableInput
from rapid_table.main import ModelType

from magic_pdf.libs.config_reader import get_device


class RapidTableModel(object):
def __init__(self, ocr_engine, table_sub_model_name):
Expand All @@ -13,7 +15,7 @@ def __init__(self, ocr_engine, table_sub_model_name):
input_args = RapidTableInput()
elif table_sub_model_name in sub_model_list:
if torch.cuda.is_available() and table_sub_model_name == "unitable":
input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True)
input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True, device=get_device())
else:
input_args = RapidTableInput(model_type=table_sub_model_name)
else:
Expand Down

0 comments on commit af3ec55

Please sign in to comment.