Skip to content

Commit

Permalink
proper model download
Browse files Browse the repository at this point in the history
  • Loading branch information
Matteo Omenetti [email protected] committed Jan 14, 2025
1 parent 66700ab commit 8e838f8
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
6 changes: 3 additions & 3 deletions demo/demo_code_formula_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ def main(args):

# ! TODO: change this
# Download models from HF
# download_path = snapshot_download(repo_id="ds4sd/docling-models", revision="v2.1.0")
# artifact_path = os.path.join(download_path, "model_artifacts/layout")
artifact_path = "/dccstor/doc_fig_class/DocFM-Vision-Pretrainer/Vary-master/checkpoints_new_code_equation_model/checkpoint-7000/"
download_path = snapshot_download(repo_id="ds4sd/CodeFormula")
artifact_path = os.path.join(download_path, "")
# artifact_path = "/dccstor/doc_fig_class/DocFM-Vision-Pretrainer/Vary-master/checkpoints_new_code_equation_model/checkpoint-7000/"

# Test the LayoutPredictor
demo(logger, artifact_path, device, num_threads, image_dir, viz_dir)
Expand Down
14 changes: 7 additions & 7 deletions tests/test_code_formula_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
# Copyright IBM Corp. 2024 - 2024
# SPDX-License-Identifier: MIT
#
import os
import numpy as np
import pytest
from PIL import Image

from docling_ibm_models.code_formula_model.code_formula_predictor import CodeFormulaPredictor

from huggingface_hub import snapshot_download

@pytest.fixture(scope="module")
def init() -> dict:
r"""
Expand All @@ -28,19 +31,16 @@ def init() -> dict:
},
],
"info": {
"device": "cpu",
"device": "auto",
"temperature": 0,
},
}

# Download models from HF
# TODO: do this
# download_path = snapshot_download(repo_id="ds4sd/docling-models", revision="v2.1.0")
# artifact_path = os.path.join(download_path, "model_artifacts/layout")
download_path = snapshot_download(repo_id="ds4sd/CodeFormula")
artifact_path = os.path.join(download_path, "")

init["artifact_path"] = (
"/dccstor/doc_fig_class/DocFM-Vision-Pretrainer/Vary-master/checkpoints_new_code_equation_model/checkpoint-7000/"
)
init["artifact_path"] = artifact_path

return init

Expand Down
2 changes: 1 addition & 1 deletion tests/test_data/code_formula/gt/code.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ int main(){
}

return 0;
}
}

0 comments on commit 8e838f8

Please sign in to comment.