From 3d5011a474907996dcd9a5fab2dcaa177734ce8f Mon Sep 17 00:00:00 2001
From: Florian Borchert
Date: Thu, 5 Dec 2024 11:00:32 +0100
Subject: [PATCH] Model inside image

---
 Dockerfile                       |  5 ++++-
 download_models.py               | 15 +++++++++++++++
 run_snomed_german_recommender.py |  8 ++++++++
 3 files changed, 27 insertions(+), 1 deletion(-)
 create mode 100644 download_models.py

diff --git a/Dockerfile b/Dockerfile
index e7fa40f..99efd70 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -23,7 +23,10 @@ RUN conda install -n xmen -c conda-forge nmslib cymem murmurhash -y
 # Install pip dependencies
 RUN pip install --no-cache-dir -r requirements.txt
 
+# Download the pre-trained models so they are cached in the Docker image
+RUN python download_models.py
+
 EXPOSE 5000
 
 # Define the command to run the server with parameters
-CMD ["conda", "run", "-n", "xmen", "python3", "run_snomed_german_recommender.py", "--no-gpu", "--port", "5000", "index"]
+CMD ["conda", "run", "-n", "xmen", "python3", "run_snomed_german_recommender.py", "--no-gpu", "--port", "5000", "index", "--num_recs", "10"]
diff --git a/download_models.py b/download_models.py
new file mode 100644
index 0000000..5db49fb
--- /dev/null
+++ b/download_models.py
@@ -0,0 +1,15 @@
+from xmen.linkers.model_wrapper import Model_Wrapper
+from xmen.linkers import SapBERTLinker
+from transformers import logging as tf_logging
+import logging
+
+logging.basicConfig(level=logging.INFO)
+tf_logging.set_verbosity_info()
+
+def download_models():
+    """ Downloads the Hugging Face models required for the project. """
+    Model_Wrapper().load_model(SapBERTLinker.CROSS_LINGUAL, use_cuda=False)
+
+if __name__ == '__main__':
+
+    download_models()
\ No newline at end of file
diff --git a/run_snomed_german_recommender.py b/run_snomed_german_recommender.py
index 655c4c2..c933c00 100644
--- a/run_snomed_german_recommender.py
+++ b/run_snomed_german_recommender.py
@@ -14,6 +14,12 @@ from utils import handle_dates
 
 
 
+from transformers import logging as tf_logging
+import logging
+
+logging.basicConfig(level=logging.INFO)
+tf_logging.set_verbosity_info()
+
 class xMENSNOMEDLinker(Classifier):
     def __init__(self, linker: EntityLinker, top_k = 3):
         self.linker = linker
@@ -48,8 +54,10 @@ def run():
     # Suppress InconsistentVersionWarning from TF-IDF vectorizer
     warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
 
+    print("Loading xMEN SNOMED Linker...", flush=True)
     linker = default_ensemble(args.index_base_path, cuda=args.gpu)
 
+    print("Starting Ariadne server...", flush=True)
     server = Server()
     server.add_classifier("xmen_snomed", xMENSNOMEDLinker(linker, top_k=args.num_recs))
 