diff --git a/Dockerfile b/Dockerfile index e7fa40f..99efd70 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,7 +23,10 @@ RUN conda install -n xmen -c conda-forge nmslib cymem murmurhash -y # Install pip dependencies RUN pip install --no-cache-dir -r requirements.txt +# Download the pre-trained models so they are cached in the Docker image +RUN python download_models.py + EXPOSE 5000 # Define the command to run the server with parameters -CMD ["conda", "run", "-n", "xmen", "python3", "run_snomed_german_recommender.py", "--no-gpu", "--port", "5000", "index"] +CMD ["conda", "run", "-n", "xmen", "python3", "run_snomed_german_recommender.py", "--no-gpu", "--port", "5000", "index", "--num_recs", "10"] diff --git a/download_models.py b/download_models.py new file mode 100644 index 0000000..5db49fb --- /dev/null +++ b/download_models.py @@ -0,0 +1,15 @@ +from xmen.linkers.model_wrapper import Model_Wrapper +from xmen.linkers import SapBERTLinker +from transformers import logging as tf_logging +import logging + +logging.basicConfig(level=logging.INFO) +tf_logging.set_verbosity_info() + +def download_models(): + """ Downloads the Hugging Face models required for the project. """ + Model_Wrapper().load_model(SapBERTLinker.CROSS_LINGUAL, use_cuda=False) + +if __name__ == '__main__': + + download_models() \ No newline at end of file diff --git a/run_snomed_german_recommender.py b/run_snomed_german_recommender.py index 655c4c2..c933c00 100644 --- a/run_snomed_german_recommender.py +++ b/run_snomed_german_recommender.py @@ -14,6 +14,12 @@ from utils import handle_dates +from transformers import logging as tf_logging +import logging + +logging.basicConfig(level=logging.INFO) +tf_logging.set_verbosity_info() + class xMENSNOMEDLinker(Classifier): def __init__(self, linker: EntityLinker, top_k = 3): self.linker = linker @@ -48,8 +54,10 @@ def run(): # Suppress InconsistentVersionWarning from TF-IDF vectorizer warnings.filterwarnings("ignore", category=InconsistentVersionWarning) + print("Loading xMEN SNOMED Linker...", flush=True) linker = default_ensemble(args.index_base_path, cuda=args.gpu) + print("Starting Ariadne server...", flush=True) server = Server() server.add_classifier("xmen_snomed", xMENSNOMEDLinker(linker, top_k=args.num_recs))