Skip to content
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,20 @@ services:
volumes:
- ./logs:/app/logs
- ./.env:/app/.env
arax_pathfinder:
container_name: arax_pathfinder
build:
context: .
dockerfile: workers/arax_pathfinder/Dockerfile
restart: unless-stopped
depends_on:
shepherd_db:
condition: service_healthy
shepherd_broker:
condition: service_healthy
volumes:
- ./logs:/app/logs
- ./.env:/app/.env

arax_rank:
container_name: arax_rank
Expand Down
16 changes: 16 additions & 0 deletions shepherd_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,23 @@ class Settings(BaseSettings):
# Synchronous knowledge-graph retrieval endpoint (Strider).
sync_kg_retrieval_url: str = "https://strider.renci.org/query"
# Default data tier used when a query does not specify one.
default_data_tier: int = 0
# Omnicorp overlay service used for literature co-occurrence.
omnicorp_url: str = "https://aragorn-ranker.renci.org/omnicorp_overlay"

# ARAX configs
# ARAX query endpoint POSTed to by the arax worker.
arax_url: str = "https://arax.ncats.io/shepherd/api/arax/v1.4/query"
# Plover KG2c instance consumed by the pathfinder worker.
plover_url: str = "https://kg2cplover3.rtx.ai:9990"
# Database address strings handed to Pathfinder; format appears to be
# driver:host:user:database -- TODO confirm against Pathfinder's parser.
curie_ngd_addr: str = (
    "mysql:arax-databases-mysql.rtx.ai:public_ro:curie_ngd_v1_0_kg2_10_2"
)
node_degree_addr: str = (
    "mysql:arax-databases-mysql.rtx.ai:public_ro:kg2c_v1_0_kg2_10_2"
)
# Biolink model version passed to BiolinkHelper by the pathfinder worker.
arax_biolink_version: str = "4.2.5"
# Source of ARAX's general-concepts blocklist ("curies" and "synonyms").
arax_blocked_list_url: str = (
    "https://raw.githubusercontent.com/RTXteam/RTX/master/"
    "code/ARAX/KnowledgeSources/general_concepts.json"
)
# End of ARAX configs

# Node normalization service base URL.
node_norm: str = "https://biothings.ci.transltr.io/nodenorm/api/"

pathfinder_redis_host: str = "host.docker.internal"
Expand Down
61 changes: 41 additions & 20 deletions workers/arax/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,51 @@
tracer = setup_tracer(STREAM)


def is_pathfinder_query(message):
    """Determine whether a TRAPI message is a pathfinder query.

    A pathfinder query is identified by exactly one entry in the query
    graph's "paths" mapping.

    Args:
        message: TRAPI envelope; message["message"]["query_graph"] may be
            missing, None, or malformed (e.g. "query_graph": None).

    Returns:
        bool: True when the query graph contains exactly one path.

    Raises:
        Exception: when more than one path is present, or when paths and
            edges are mixed in the same query graph.
    """
    if not isinstance(message, dict):
        return False
    # Any level of the envelope may be missing or explicitly None, so
    # coalesce each step to a dict instead of relying on bare excepts.
    qgraph = (message.get("message") or {}).get("query_graph") or {}
    qedges = qgraph.get("edges") or {}
    qpaths = qgraph.get("paths") or {}
    if len(qpaths) > 1:
        raise Exception("Only a single path is supported", 400)
    if (len(qpaths) > 0) and (len(qedges) > 0):
        raise Exception("Mixed mode pathfinder queries are not supported", 400)
    return len(qpaths) == 1

async def arax(task, logger: logging.Logger):
    """Route an ARAX task.

    Pathfinder queries (single-path query graphs) are handed off to the
    arax.pathfinder stream via wrap_up_task; all other queries are POSTed
    to the ARAX service and the (provenance-augmented) response is saved
    under the task's response id.

    Args:
        task: stream task tuple; task[1] carries query_id/response_id.
        logger: task-scoped logger.
    """
    start = time.time()
    query_id = task[1]["query_id"]
    logger.info(f"Getting message from db for query id {query_id}")
    message = await get_message(query_id, logger)
    if is_pathfinder_query(message):
        # Delegate to the dedicated pathfinder worker stream.
        workflow = [{"id": "arax.pathfinder"}]
        await wrap_up_task(STREAM, GROUP, task, workflow, logger)
    else:
        try:
            message["submitter"] = "Shepherd"
            logger.info(f"Get the message from db {message}")
            headers = {"Content-Type": "application/json"}
            response = requests.post(settings.arax_url, json=message, headers=headers)
            logger.info(f"Status Code from ARAX response: {response.status_code}")
            result = response.json()
            result = add_shepherd_arax_to_edge_sources(result)
        except Exception as e:
            # Save an error envelope rather than crashing the worker.
            logger.error(f"Error occurred calling ARAX service: {e}")
            result = {"status": "error", "error": str(e)}
        response_id = task[1]["response_id"]
        await save_message(response_id, result, logger)
        task[1]["workflow"] = json.dumps([{"id": "arax"}])

    logger.info(f"Finished task {task[0]} in {time.time() - start}")


async def process_task(task, parent_ctx, logger: logging.Logger, limiter):
Expand Down
34 changes: 34 additions & 0 deletions workers/arax_pathfinder/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Use RENCI python base image
FROM ghcr.io/translatorsri/renci-python-image:3.11.5

# Add image info (key=value form; bare "LABEL key value" is legacy syntax)
LABEL org.opencontainers.image.source="https://github.com/BioPack-team/shepherd"

# Deterministic hashing across container restarts
ENV PYTHONHASHSEED=0

# set up requirements
WORKDIR /app

# make sure all is writeable for the nru USER later on
RUN chmod -R 777 .

# Install requirements
COPY ./shepherd_utils ./shepherd_utils
COPY ./pyproject.toml .
RUN pip install .

COPY ./workers/arax_pathfinder/requirements.txt .
RUN pip install -r requirements.txt

# switch to the non-root user (nru). defined in the base image
USER nru

# Copy in files
COPY ./workers/arax_pathfinder ./

# Variables that can be overriden
CMD ["python", "worker.py"]
Empty file.
2 changes: 2 additions & 0 deletions workers/arax_pathfinder/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
catrax-pathfinder==1.2.2
biolink-helper-pkg==1.0.0
208 changes: 208 additions & 0 deletions workers/arax_pathfinder/worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
"""Arax ARA Pathfinder module."""

import requests
import asyncio
import json
import logging
import time
import uuid
from pathlib import Path
from pathfinder.Pathfinder import Pathfinder
from biolink_helper_pkg import BiolinkHelper

from shepherd_utils.inject_shepherd_arax_provenance import add_shepherd_arax_to_edge_sources
from shepherd_utils.config import settings
from shepherd_utils.db import (
get_message,
save_message,
)
from shepherd_utils.otel import setup_tracer
from shepherd_utils.shared import (
get_tasks,
wrap_up_task,
)

# Queue (Redis stream) this worker consumes.
STREAM = "arax.pathfinder"
# Consumer group, most likely you don't need to change this.
GROUP = "consumer"
# Short random suffix so each worker process registers as a distinct consumer.
CONSUMER = str(uuid.uuid4())[:8]
# Maximum number of in-flight tasks for this worker.
TASK_LIMIT = 100
tracer = setup_tracer(STREAM)

# Tuning knobs forwarded verbatim to Pathfinder.get_paths() in
# execute_pathfinding_sync(); exact semantics are defined by the
# catrax-pathfinder package -- TODO confirm against its docs.
NUM_TOTAL_HOPS = 4
MAX_HOPS_TO_EXPLORE = 4
MAX_PATHFINDER_PATHS = 500
PRUNE_TOP_K = 200
NODE_DEGREE_THRESHOLD = 1000000

# Local cache path for the downloaded ARAX general-concepts blocklist.
OUT_PATH = Path("general_concepts.json")


def download_file(url: str, out_path: Path, overwrite: bool = False) -> Path:
    """Download *url* to *out_path*, reusing an existing file unless asked not to.

    Args:
        url: source URL to fetch.
        out_path: destination path (str or Path).
        overwrite: when True, re-download even if the file already exists.

    Returns:
        Path: the destination path.

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """
    destination = Path(out_path)

    # Cache hit: skip the network round trip entirely.
    if not overwrite and destination.exists():
        return destination

    destination.parent.mkdir(parents=True, exist_ok=True)

    response = requests.get(url, timeout=60)
    response.raise_for_status()

    destination.write_bytes(response.content)
    return destination


def get_blocked_list():
    """Load ARAX's general-concepts blocklist, downloading it on first use.

    Returns:
        tuple: (set of blocked CURIEs, set of lower-cased blocked synonyms).
    """
    # Cached on disk at OUT_PATH; download_file is a no-op when present.
    download_file(settings.arax_blocked_list_url, OUT_PATH, False)

    json_block_list = json.loads(OUT_PATH.read_text())
    curies = set(json_block_list["curies"])
    synonyms = {synonym.lower() for synonym in json_block_list["synonyms"]}
    return curies, synonyms


def execute_pathfinding_sync(pinned_node_ids, pinned_node_keys, intermediate_categories, logger):
    """Run the blocking Pathfinder search between two pinned nodes.

    Intended to be called via asyncio.to_thread (see pathfinder()),
    since Pathfinder.get_paths() is synchronous and long-running.

    Args:
        pinned_node_ids: CURIEs of the two pinned query nodes.
        pinned_node_keys: query-graph keys of the two pinned nodes,
            index-aligned with pinned_node_ids.
        intermediate_categories: at most one Biolink category constraining
            intermediate nodes; may be empty.
        logger: task-scoped logger.

    Returns:
        Tuple (result, aux_graphs, knowledge_graph) from
        Pathfinder.get_paths(); any element may be None.
    """
    blocked_curies, blocked_synonyms = get_blocked_list()

    pathfinder_instance = Pathfinder(
        "MLRepo",
        settings.plover_url,
        settings.curie_ngd_addr,
        settings.node_degree_addr,
        blocked_curies,
        blocked_synonyms,
        logger,
    )

    # BiolinkHelper caches model artifacts; keep them in a writable tmp dir.
    biolink_cache_dir = "/tmp/biolink"
    Path(biolink_cache_dir).mkdir(parents=True, exist_ok=True)
    biolink_helper = BiolinkHelper(settings.arax_biolink_version, biolink_cache_dir)
    # Guard against an empty category list (a path constraint may omit
    # intermediate_categories); fall back to the Biolink root instead of
    # raising IndexError on intermediate_categories[0].
    category = (
        intermediate_categories[0] if intermediate_categories else "biolink:NamedThing"
    )
    descendants = set(biolink_helper.get_descendants(category))

    start = time.perf_counter()
    logger.info("Starting pathfinder.get_paths() in worker thread")

    result, aux_graphs, knowledge_graph = pathfinder_instance.get_paths(
        pinned_node_ids[0],
        pinned_node_ids[1],
        pinned_node_keys[0],
        pinned_node_keys[1],
        NUM_TOTAL_HOPS,
        MAX_HOPS_TO_EXPLORE,
        MAX_PATHFINDER_PATHS,
        PRUNE_TOP_K,
        NODE_DEGREE_THRESHOLD,
        descendants,
    )

    elapsed = time.perf_counter() - start
    logger.info(f"pathfinder.get_paths() finished in {elapsed:.3f} seconds")

    return result, aux_graphs, knowledge_graph


async def pathfinder(task, logger: logging.Logger):
    """Execute an ARAX pathfinder query for a single stream task.

    Loads the TRAPI message for the task's query id, validates that the
    query graph has exactly two pinned nodes and at most one intermediate
    category constraint, runs the blocking path search in a worker thread,
    and saves the resulting message (or an error envelope) under the
    task's response id. Always wraps up the task afterwards.
    """
    start = time.time()
    query_id = task[1]["query_id"]
    workflow = json.loads(task[1]["workflow"])
    response_id = task[1]["response_id"]
    message = await get_message(query_id, logger)
    # Fill in parameter defaults expected downstream.
    parameters = message.get("parameters") or {}
    parameters["timeout"] = parameters.get("timeout", settings.lookup_timeout)
    parameters["tiers"] = parameters.get("tiers") or [0]
    message["parameters"] = parameters

    qgraph = message["message"]["query_graph"]
    pinned_node_keys = []
    pinned_node_ids = []
    for node_key, node in qgraph["nodes"].items():
        # Only record pinned nodes (those with CURIEs) so that keys and
        # ids stay index-aligned when handed to the path search.
        if node.get("ids", None) is not None:
            pinned_node_keys.append(node_key)
            pinned_node_ids.append(node["ids"][0])
    if len(set(pinned_node_ids)) != 2:
        logger.error("Pathfinder queries require two pinned nodes.")
        return message, 500

    intermediate_categories = []
    path_key = next(iter(qgraph["paths"].keys()))
    qpath = qgraph["paths"][path_key]
    constraints = qpath.get("constraints") or []
    if len(constraints) > 1:
        logger.error("Pathfinder queries do not support multiple constraints.")
        return message, 500
    if constraints:
        intermediate_categories = (
            constraints[0].get("intermediate_categories", None) or []
        )
        if len(intermediate_categories) > 1:
            logger.error(
                "Pathfinder queries do not support multiple intermediate categories"
            )
            return message, 500
    if not intermediate_categories:
        # A constraint may omit intermediate_categories entirely; default
        # to the Biolink root so category expansion always has input.
        intermediate_categories = ["biolink:NamedThing"]

    try:
        # Pathfinder.get_paths() is synchronous and slow; keep it off the
        # event loop.
        result, aux_graphs, knowledge_graph = await asyncio.to_thread(
            execute_pathfinding_sync,
            pinned_node_ids,
            pinned_node_keys,
            intermediate_categories,
            logger,
        )

        res = []
        if result is not None:
            res.append(
                {
                    "id": result["id"],
                    "analyses": result["analyses"],
                    "node_bindings": result["node_bindings"],
                    "essence": "result",
                }
            )
        if aux_graphs is None:
            aux_graphs = {}
        if knowledge_graph is None:
            knowledge_graph = {}
        message["message"]["knowledge_graph"] = knowledge_graph
        message["message"]["auxiliary_graphs"] = aux_graphs
        message["message"]["results"] = res

        message = add_shepherd_arax_to_edge_sources(message)

        await save_message(response_id, message, logger)
    except Exception as e:
        logger.error(
            f"PathFinder failed to find paths between {pinned_node_keys[0]} and {pinned_node_keys[1]}. "
            f"Error message is: {e}"
        )
        message = {"status": "error", "error": str(e)}
        await save_message(response_id, message, logger)

    await wrap_up_task(STREAM, GROUP, task, workflow, logger)
    logger.info(f"Task took {time.time() - start}")


async def process_task(task, parent_ctx, logger, limiter):
    """Run one pathfinder task inside an OTEL span.

    The span is ended and the concurrency limiter released even when the
    task raises.
    """
    span = tracer.start_span(STREAM, context=parent_ctx)
    try:
        await pathfinder(task, logger)
    finally:
        span.end()
        limiter.release()


async def poll_for_tasks():
    """Continuously pull tasks from the stream and process them concurrently."""
    async for task, parent_ctx, logger, limiter in get_tasks(
        STREAM, GROUP, CONSUMER, TASK_LIMIT
    ):
        # Fire-and-forget; the limiter (presumably acquired inside
        # get_tasks and released in process_task) bounds concurrency.
        asyncio.create_task(process_task(task, parent_ctx, logger, limiter))


if __name__ == "__main__":
    asyncio.run(poll_for_tasks())
Loading