Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,20 @@ services:
volumes:
- ./logs:/app/logs
- ./.env:/app/.env
arax_pathfinder:
container_name: arax_pathfinder
build:
context: .
dockerfile: workers/arax_pathfinder/Dockerfile
restart: unless-stopped
depends_on:
shepherd_db:
condition: service_healthy
shepherd_broker:
condition: service_healthy
volumes:
- ./logs:/app/logs
- ./.env:/app/.env

######### BTE
bte:
Expand Down
12 changes: 12 additions & 0 deletions shepherd_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,19 @@ class Settings(BaseSettings):
kg_retrieval_url: str = "https://strider.renci.org/asyncquery"
sync_kg_retrieval_url: str = "https://strider.renci.org/query"
omnicorp_url: str = "https://aragorn-ranker.renci.org/omnicorp_overlay"

# ARAX configs
# ARAX sync query endpoint used by the arax worker.
arax_url: str = "https://arax.ncats.io/shepherd/api/arax/v1.4/query"
# Plover KG2c endpoint used by the pathfinder worker (test environment).
plover_url: str = "https://kg2cploverdb.test.transltr.io"
# Database address strings handed to Pathfinder; format looks like
# "mysql:<host>:<user>:<database>" — TODO confirm against Pathfinder docs.
curie_ngd_addr: str = "mysql:arax-databases-mysql.rtx.ai:public_ro:curie_ngd_v1_0_kg2_10_2"
node_degree_addr: str = "mysql:arax-databases-mysql.rtx.ai:public_ro:kg2c_v1_0_kg2_10_2"
# Biolink model version passed to BiolinkHelper.
arax_biolink_version: str = "4.2.5"
# Source of ARAX's blocked general-concepts list (curies + synonyms).
arax_blocked_list_url: str = (
    "https://raw.githubusercontent.com/RTXteam/RTX/master/"
    "code/ARAX/KnowledgeSources/general_concepts.json"
)
# End of ARAX configs

node_norm: str = "https://biothings.ci.transltr.io/nodenorm/api/"

pathfinder_redis_host: str = "host.docker.internal"
Expand Down
62 changes: 39 additions & 23 deletions workers/arax/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,49 @@
tracer = setup_tracer(STREAM)


def is_pathfinder_query(message):
    """Return True if the TRAPI message is a pathfinder query.

    A pathfinder query has exactly one entry in query_graph["paths"].

    Raises:
        Exception: ("Only a single path is supported", 400) when more than
            one path is present, or ("Mixed mode pathfinder queries are not
            supported", 400) when both edges and paths are present.
    """
    try:
        # The .get chain can still fail if the input looks like e.g.
        # "query_graph": None (calling .get on None raises AttributeError).
        qedges = message.get("message", {}).get("query_graph", {}).get("edges", {})
    except (AttributeError, TypeError):
        # Narrowed from a bare `except:`, which would also swallow
        # KeyboardInterrupt/SystemExit.
        qedges = {}
    try:
        qpaths = message.get("message", {}).get("query_graph", {}).get("paths", {})
    except (AttributeError, TypeError):
        qpaths = {}
    # Treat explicit None containers ("edges": None) as empty instead of
    # crashing on len(None) below.
    qedges = qedges or {}
    qpaths = qpaths or {}
    if len(qpaths) > 1:
        raise Exception("Only a single path is supported", 400)
    if qpaths and qedges:
        raise Exception("Mixed mode pathfinder queries are not supported", 400)
    return len(qpaths) == 1
async def arax(task, logger: logging.Logger):
    """Route a single ARAX task.

    Pathfinder queries are forwarded to the dedicated ``arax.pathfinder``
    stream; everything else is POSTed to the ARAX service and the response
    saved under the task's response_id.

    Args:
        task: stream task tuple; task[1] holds "query_id" and "response_id".
        logger: task-scoped logger.
    """
    start = time.time()
    query_id = task[1]["query_id"]
    logger.info(f"Getting message from db for query id {query_id}")
    message = await get_message(query_id, logger)
    # NOTE(review): is_pathfinder_query raises on multi-path/mixed queries;
    # that exception propagates to process_task — confirm that is intended.
    if is_pathfinder_query(message):
        # Hand off to the dedicated pathfinder worker via the workflow.
        workflow = [{"id": "arax.pathfinder"}]
    else:
        try:
            workflow = [{"id": "arax"}]
            message["submitter"] = "Shepherd"
            logger.info(f"Get the message from db {message}")
            headers = {"Content-Type": "application/json"}
            # requests.post is blocking: run it in a worker thread so this
            # coroutine doesn't stall the event loop, and bound it with a
            # timeout so a hung ARAX service can't hang the worker forever.
            # (600s chosen generously — TODO confirm against ARAX SLAs.)
            response = await asyncio.to_thread(
                requests.post,
                settings.arax_url,
                json=message,
                headers=headers,
                timeout=600,
            )
            logger.info(f"Status Code from ARAX response: {response.status_code}")
            result = response.json()
        except Exception as e:
            logger.error(f"Error occurred calling ARAX service: {e}")
            result = {"status": "error", "error": str(e)}
        response_id = task[1]["response_id"]
        await save_message(response_id, result, logger)

    await wrap_up_task(STREAM, GROUP, task, workflow, logger)

    logger.info(f"Finished task {task[0]} in {time.time() - start}")


Expand All @@ -61,7 +77,7 @@ async def process_task(task, parent_ctx, logger, limiter):

async def poll_for_tasks():
    """Continuously pull tasks from the stream and process them concurrently."""
    # Keep strong references to in-flight tasks: the event loop holds only a
    # weak reference to the result of asyncio.create_task, so without this a
    # task could be garbage-collected before it finishes (documented asyncio
    # pitfall).
    in_flight = set()
    async for task, parent_ctx, logger, limiter in get_tasks(
        STREAM, GROUP, CONSUMER, TASK_LIMIT
    ):
        t = asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
        in_flight.add(t)
        t.add_done_callback(in_flight.discard)

Expand Down
34 changes: 34 additions & 0 deletions workers/arax_pathfinder/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Use RENCI python base image
FROM ghcr.io/translatorsri/renci-python-image:3.11.5

# Add image info
# NOTE(review): space-separated LABEL is the legacy form; `LABEL key="value"`
# is the recommended syntax — confirm before changing.
LABEL org.opencontainers.image.source https://github.com/BioPack-team/shepherd

# Disable Python hash randomization so hashing is deterministic across runs.
ENV PYTHONHASHSEED=0

# set up requirements
WORKDIR /app

# make sure all is writeable for the nru USER later on
# (runs before any COPY, so this only affects the empty /app dir itself)
RUN chmod -R 777 .

# Install requirements
# shepherd_utils + pyproject.toml first so the shared package layer caches
COPY ./shepherd_utils ./shepherd_utils
COPY ./pyproject.toml .
RUN pip install .

# Worker-specific dependencies (catrax-pathfinder, biolink-helper-pkg)
COPY ./workers/arax_pathfinder/requirements.txt .
RUN pip install -r requirements.txt

# switch to the non-root user (nru). defined in the base image
USER nru

# Copy in files
COPY ./workers/arax_pathfinder ./

# Set up base for command and any variables
# that shouldn't be modified
# ENTRYPOINT ["uvicorn", "shepherd_server.server:APP"]

# Variables that can be overriden
CMD ["python", "worker.py"]
Empty file.
2 changes: 2 additions & 0 deletions workers/arax_pathfinder/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
catrax-pathfinder==1.0.2
biolink-helper-pkg==1.0.0
175 changes: 175 additions & 0 deletions workers/arax_pathfinder/worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""Arax ARA Pathfinder module."""

import requests
import asyncio
import json
import logging
import time
import uuid
from pathlib import Path
from pathfinder.Pathfinder import Pathfinder
from biolink_helper_pkg import BiolinkHelper

from shepherd_utils.config import settings
from shepherd_utils.db import (
get_message,
save_message,
)
from shepherd_utils.otel import setup_tracer
from shepherd_utils.shared import (
get_tasks,
wrap_up_task,
)

# Queue name
STREAM = "arax.pathfinder"
# Consumer group, most likely you don't need to change this.
GROUP = "consumer"
# Short unique consumer id within the group (first 8 chars of a UUID4).
CONSUMER = str(uuid.uuid4())[:8]
# Max number of tasks this worker may hold in flight at once.
TASK_LIMIT = 100
tracer = setup_tracer(STREAM)

# Hop bound passed (twice) to Pathfinder.get_paths.
NUM_TOTAL_HOPS = 4
# Upper bound on the number of paths Pathfinder may return.
MAX_PATHFINDER_PATHS = 500


# Local cache location for the downloaded blocked-concepts list.
OUT_PATH = Path("general_concepts.json")

def download_file(url: str, out_path: Path, overwrite: bool = False) -> Path:
    """Download *url* to *out_path* and return the path.

    If the file already exists and *overwrite* is False, the existing file
    is reused. The payload is written to a temporary sibling file and then
    atomically renamed into place: a failed or interrupted download can
    therefore never leave a truncated file behind — important because later
    calls (overwrite=False) would otherwise trust the partial file forever.

    Args:
        url: source URL.
        out_path: destination path (created parents as needed).
        overwrite: re-download even when the file exists.

    Raises:
        requests.HTTPError: on a non-2xx response.
    """
    out_path = Path(out_path)

    if out_path.exists() and not overwrite:
        return out_path

    out_path.parent.mkdir(parents=True, exist_ok=True)

    r = requests.get(url, timeout=60)
    r.raise_for_status()

    tmp_path = out_path.with_suffix(out_path.suffix + ".tmp")
    tmp_path.write_bytes(r.content)
    # Atomic on POSIX: readers see either the old file or the complete new one.
    tmp_path.replace(out_path)
    return out_path


def get_blocked_list():
    """Return (blocked_curies, blocked_synonyms) from the ARAX
    general-concepts file, downloading it to OUT_PATH if not cached."""
    download_file(settings.arax_blocked_list_url, OUT_PATH, False)

    with open(OUT_PATH, 'r') as blocked_file:
        blocked = json.load(blocked_file)
    lowered_synonyms = {synonym.lower() for synonym in blocked['synonyms']}
    return set(blocked['curies']), lowered_synonyms


def _extract_pinned_nodes(qgraph):
    """Return parallel lists (keys, ids) for the pinned (curie-bearing) nodes.

    Only nodes that carry "ids" are included: appending every node key while
    only appending ids for pinned nodes (as the original loop did) misaligns
    keys and ids as soon as an unpinned node is present.
    """
    pinned_node_keys = []
    pinned_node_ids = []
    for node_key, node in qgraph["nodes"].items():
        if node.get("ids", None) is not None:
            pinned_node_keys.append(node_key)
            pinned_node_ids.append(node["ids"][0])
    return pinned_node_keys, pinned_node_ids


def _extract_intermediate_categories(qpath, logger):
    """Return the single-element intermediate-category list for *qpath*,
    or None when the query is unsupported (multiple constraints or multiple
    intermediate categories). Falls back to ["biolink:NamedThing"] when no
    category is specified, including the "constraints": [] case which would
    otherwise crash on intermediate_categories[0] downstream."""
    constraints = qpath.get("constraints", None)
    if constraints is None:
        return ["biolink:NamedThing"]
    if len(constraints) > 1:
        logger.error("Pathfinder queries do not support multiple constraints.")
        return None
    intermediate_categories = []
    if len(constraints) > 0:
        intermediate_categories = (
            constraints[0].get("intermediate_categories", None) or []
        )
        if len(intermediate_categories) > 1:
            logger.error(
                "Pathfinder queries do not support multiple intermediate categories"
            )
            return None
    return intermediate_categories or ["biolink:NamedThing"]


async def pathfinder(task, logger: logging.Logger):
    """Run one ARAX pathfinder task.

    Loads the TRAPI message for the task's query_id, validates that the
    query pins exactly two nodes and has at most one intermediate-category
    constraint, runs Pathfinder.get_paths between the two pinned curies, and
    saves the resulting message (or an error message) under response_id.
    The task is always acked via wrap_up_task, including on validation
    failure — the original early returns left failed tasks un-acked with no
    saved response.
    """
    start = time.time()
    query_id = task[1]["query_id"]
    workflow = json.loads(task[1]["workflow"])
    response_id = task[1]["response_id"]
    message = await get_message(query_id, logger)
    parameters = message.get("parameters") or {}
    parameters["timeout"] = parameters.get("timeout", settings.lookup_timeout)
    parameters["tiers"] = parameters.get("tiers") or [0]
    message["parameters"] = parameters

    async def _bail(error_msg):
        # Persist the failure and ack the task so it isn't left hanging.
        await save_message(response_id, {"status": "error", "error": error_msg}, logger)
        await wrap_up_task(STREAM, GROUP, task, workflow, logger)

    qgraph = message["message"]["query_graph"]
    pinned_node_keys, pinned_node_ids = _extract_pinned_nodes(qgraph)
    if len(set(pinned_node_ids)) != 2:
        logger.error("Pathfinder queries require two pinned nodes.")
        await _bail("Pathfinder queries require two pinned nodes.")
        return message, 500

    path_key = next(iter(qgraph["paths"].keys()))
    intermediate_categories = _extract_intermediate_categories(
        qgraph["paths"][path_key], logger
    )
    if intermediate_categories is None:
        await _bail("Unsupported pathfinder query constraints.")
        return message, 500

    blocked_curies, blocked_synonyms = get_blocked_list()
    # Renamed from `pathfinder`, which shadowed this coroutine's own name.
    path_engine = Pathfinder(
        "MLRepo",
        settings.plover_url,
        settings.curie_ngd_addr,
        settings.node_degree_addr,
        blocked_curies,
        blocked_synonyms,
        logger
    )

    biolink_cache_dir = "/tmp/biolink"
    Path(biolink_cache_dir).mkdir(parents=True, exist_ok=True)
    biolink_helper = BiolinkHelper(settings.arax_biolink_version, biolink_cache_dir)
    descendants = set(biolink_helper.get_descendants(intermediate_categories[0]))

    try:
        # get_paths is synchronous and heavy (reviewer-flagged): run it in a
        # thread so one long query doesn't block this worker's event loop.
        result, aux_graphs, knowledge_graph = await asyncio.to_thread(
            path_engine.get_paths,
            pinned_node_ids[0],
            pinned_node_ids[1],
            pinned_node_keys[0],
            pinned_node_keys[1],
            NUM_TOTAL_HOPS,
            NUM_TOTAL_HOPS,
            MAX_PATHFINDER_PATHS,
            descendants,
        )
        res = []
        if result is not None:
            res.append({
                "id": result["id"],
                "analyses": result['analyses'],
                "node_bindings": result['node_bindings'],
                "essence": "result"
            })
        if aux_graphs is None:
            aux_graphs = {}
        if knowledge_graph is None:
            knowledge_graph = {}
        message["message"]["knowledge_graph"] = knowledge_graph
        message["message"]["auxiliary_graphs"] = aux_graphs
        message["message"]["results"] = res
        await save_message(response_id, message, logger)
    except Exception as e:
        logger.error(f"PathFinder failed to find paths between {pinned_node_keys[0]} and {pinned_node_keys[1]}. "
                     f"Error message is: {e}")
        message = {"status": "error", "error": str(e)}
        await save_message(response_id, message, logger)

    await wrap_up_task(STREAM, GROUP, task, workflow, logger)
    logger.info(f"Task took {time.time() - start}")


async def process_task(task, parent_ctx, logger, limiter):
    """Run one pathfinder task inside a tracing span.

    The span is ended and the concurrency limiter slot is released in a
    finally block, so both happen even when the task raises.
    """
    span = tracer.start_span(STREAM, context=parent_ctx)
    try:
        await pathfinder(task, logger)
    finally:
        span.end()
        limiter.release()


async def poll_for_tasks():
    """Poll the arax.pathfinder stream and dispatch tasks concurrently."""
    # Hold strong references to spawned tasks: asyncio.create_task results
    # are only weakly referenced by the event loop, so without this a task
    # could be garbage-collected before completing (documented asyncio
    # pitfall).
    in_flight = set()
    async for task, parent_ctx, logger, limiter in get_tasks(
        STREAM, GROUP, CONSUMER, TASK_LIMIT
    ):
        t = asyncio.create_task(process_task(task, parent_ctx, logger, limiter))
        in_flight.add(t)
        t.add_done_callback(in_flight.discard)


if __name__ == "__main__":
    asyncio.run(poll_for_tasks())
Loading