Skip to content
Merged
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
ab2964d
added scoring model
uhbrar Feb 23, 2026
6167713
update worker requirements
uhbrar Feb 23, 2026
7eae783
add classifier weights
uhbrar Feb 23, 2026
f1d0f34
Update pathfinder predicates
maximusunc Feb 21, 2026
6373d84
Update Gandalf and gc
maximusunc Feb 21, 2026
c5d23e7
Bump patch version
maximusunc Feb 21, 2026
46556c2
make scoring model work in docker
uhbrar Feb 24, 2026
092b5b1
use gpu for embedding
uhbrar Feb 24, 2026
16a095f
update predicates
uhbrar Feb 24, 2026
495cdb7
respect timeout
uhbrar Feb 24, 2026
469e70c
Merge branch 'main' into scoring_model
uhbrar Feb 24, 2026
4af90ea
handle no paths
uhbrar Feb 24, 2026
a44ccd3
add gandalf folder to gitignore
uhbrar Feb 24, 2026
d0bdbc4
make gpu optional
uhbrar Feb 25, 2026
2996ee5
fix vram persistence
uhbrar Feb 27, 2026
7e36023
address pr comments
uhbrar Mar 5, 2026
8bd48bc
remove extra logging
uhbrar Mar 5, 2026
a027551
styling
uhbrar Mar 5, 2026
fedb9c2
fix analysis indexing
uhbrar Mar 5, 2026
0e9cc63
use threadpoolexecutor with lock
uhbrar Mar 11, 2026
25b5313
add error handling
uhbrar Mar 11, 2026
56e4ee8
Merge branch 'main' into scoring_model
uhbrar Mar 12, 2026
c0af0fd
Update scorer to main updates and fix some linting
maximusunc Mar 12, 2026
2de5f7b
add checks for empty results
uhbrar Mar 12, 2026
dc4ed24
change task_limit
uhbrar Mar 12, 2026
09a0001
load model first
uhbrar Mar 12, 2026
c9728f4
black
uhbrar Mar 12, 2026
4b121ad
fix model loading
uhbrar Mar 13, 2026
38885c3
Make development on non-gpu machines more bearable
maximusunc Mar 13, 2026
d78c0ef
Bump minor version
maximusunc Mar 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 191 additions & 7 deletions workers/score_paths/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,131 @@
import logging
import time
import uuid
from random import shuffle
from shepherd_utils.db import get_message, save_message, get_query_state
from xgboost import XGBClassifier
from sentence_transformers import SentenceTransformer
from bmt import Toolkit
from shepherd_utils.db import get_message, save_message
from shepherd_utils.shared import get_tasks, wrap_up_task
from shepherd_utils.otel import setup_tracer


# Queue name
STREAM = "score_paths"
# Redis consumer-group name shared by all workers on this stream.
GROUP = "consumer"
# Short unique id distinguishing this worker instance within the group.
CONSUMER = str(uuid.uuid4())[:8]
# Maximum number of tasks this worker will process concurrently.
TASK_LIMIT = 100
tracer = setup_tracer(STREAM)
# Path classifier; its weights are loaded in the __main__ guard below
# before the polling loop starts.
clf = XGBClassifier()
# Biolink Model Toolkit, used for category/predicate lookups.
bmt = Toolkit()
# NOTE: constructed at import time — downloads/loads SapBERT on first run,
# so importing this module is expensive by design (one load per worker).
model = SentenceTransformer("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")

def get_most_specific_category(categories, logger):
    """Pick the most specific Biolink category from *categories*.

    Categories unknown to the Biolink model are logged and dropped. Among
    the remaining ones, a category is considered less specific (dominated)
    when it is a strict ancestor of another surviving category. The first
    non-dominated category (in input order) is returned as a bmt element.

    Args:
        categories: iterable of Biolink category curies/names.
        logger: logger for reporting unknown categories.

    Returns:
        The bmt element for the chosen category, or None when no input
        category exists in the Biolink model.
    """
    recognized = []
    for category in categories:
        if not bmt.get_element(category):
            logger.error(f"Category {category} doesn't exist.")
        else:
            recognized.append(category)

    if not recognized:
        return None

    # Keep only categories that no other valid category refines.
    leaves = [
        category
        for category in recognized
        if not any(
            category in bmt.get_ancestors(other, reflexive=False)
            for other in recognized
            if other != category
        )
    ]

    return bmt.get_element(leaves[0])

def convert_path_to_sentence(source, target, path, knowledge_graph, logger):
    """Render a path between two nodes as an English sentence for embedding.

    Produces text like
    ``"Aspirin (a SmallMolecule) treats Headache (a Disease), which ..."``
    which is later embedded by SapBERT and scored by the classifier.

    Args:
        source: id of the node the path starts from.
        target: id of the node the path ends at.
        path: iterable of edge ids (from an auxiliary graph) forming the path.
        knowledge_graph: TRAPI-style dict with "nodes" and "edges" maps.
        logger: logger for non-fatal data problems.

    Returns:
        The generated sentence string.

    Raises:
        ValueError: if the path is disconnected, a hop has no usable
            predicate, or a node category cannot be determined.
    """
    # Walk from source toward target, following any edge incident to the
    # most recently reached node. If a full pass over the edges adds no new
    # node before reaching target, the path is disconnected.
    path_node_list = [source]
    while target not in path_node_list:
        progress = False
        for edge_id in path:
            edge = knowledge_graph["edges"].get(edge_id)
            if not edge:
                logger.error(f"Edge {edge_id} not found in knowledge graph.")
                continue
            if edge["subject"] == path_node_list[-1]:
                if edge["object"] not in path_node_list:
                    path_node_list.append(edge["object"])
                    progress = True
            elif edge["object"] == path_node_list[-1]:
                if edge["subject"] not in path_node_list:
                    path_node_list.append(edge["subject"])
                    progress = True
        if not progress:
            logger.error("Disconnected Path")
            raise ValueError(f"Could not construct path from {source} to {target}")

    # For each consecutive node pair, collect the set of predicate names,
    # orienting reversed edges via their inverse (or keeping them as-is
    # when the predicate is symmetric).
    current_node = source
    path_predicate_list = []
    for hop_num, next_node in enumerate(path_node_list[1:]):
        path_predicate_list.append(set())
        for edge_id in path:
            edge = knowledge_graph["edges"].get(edge_id)
            if not edge:
                logger.error(f"Edge {edge_id} not found in knowledge graph.")
                continue
            pred = edge["predicate"]
            pred_info = bmt.get_element(pred)
            if not pred_info:
                logger.error(f"Predicate {pred} doesn't exist in Biolink model.")
                continue
            pred = pred_info.name
            if edge["subject"] == current_node and edge["object"] == next_node:
                path_predicate_list[hop_num].add(pred)
            elif edge["object"] == current_node and edge["subject"] == next_node:
                if bmt.is_symmetric(pred):
                    path_predicate_list[hop_num].add(pred)
                else:
                    inv = bmt.get_inverse_predicate(pred)
                    if not inv:
                        logger.error(f"No inverse found for predicate {pred}.")
                    else:
                        path_predicate_list[hop_num].add(inv)
        current_node = next_node

    source_cat = get_most_specific_category(
        knowledge_graph["nodes"][source]["categories"], logger
    )
    if not source_cat:
        raise ValueError(f"Could not determine category for source node {source}.")

    path_sentence = f"{knowledge_graph['nodes'][source]['name']} (a {source_cat.name}) "
    first_hop = True
    for path_node, hop_predicates in zip(path_node_list[1:], path_predicate_list):
        # Sort the predicates: iterating the set directly gives a
        # nondeterministic order, which would make the sentence — and
        # therefore the embedding and the path score — vary between runs.
        hop_preds = sorted(hop_predicates)
        if not first_hop:
            path_sentence += ", which "
        else:
            first_hop = False
        if len(hop_preds) == 0:
            raise ValueError(f"No predicates found for hop to node {path_node}.")
        elif len(hop_preds) == 1:
            path_sentence += f"{hop_preds[0]}"
        else:
            path_sentence += "either ["
            for hop_pred in hop_preds[:-1]:
                path_sentence += f"{hop_pred} or "
            path_sentence += f"{hop_preds[-1]}]"
        node_cat = get_most_specific_category(
            knowledge_graph["nodes"][path_node]["categories"], logger
        )
        if not node_cat:
            raise ValueError(f"Could not determine category for node {path_node}.")
        path_sentence += (
            f" {knowledge_graph['nodes'][path_node]['name']} (a {node_cat.name})"
        )

    logger.debug(f"Generated sentence: {path_sentence}")
    return path_sentence

async def score_paths(task, logger: logging.Logger):
start = time.time()
Expand All @@ -25,12 +138,82 @@ async def score_paths(task, logger: logging.Logger):
workflow = json.loads(task[1]["workflow"])
message = await get_message(response_id, logger)
current_op = workflow[0]

try:
for ind, result in enumerate(message["message"]["results"]):
message["message"]["results"][ind]["analyses"] = [
{**analysis, "score": analysis.get("score", 0) or 0}
for analysis in result["analyses"]
]
for qpath_id, qpath in message["message"]["query_graph"]["paths"].items():
source_qnode = qpath["subject"]
target_qnode = qpath["object"]

for ind, result in enumerate(message["message"]["results"]):
# Fatal at the result level: if we can't find source/target
# node bindings the result is malformed, skip it
try:
source = result["node_bindings"][source_qnode][0]["id"]
target = result["node_bindings"][target_qnode][0]["id"]
except KeyError as e:
logger.error(
f"Result {ind} missing expected node binding {e}, skipping."
)
continue

# Pass 1: collect all sentences for this result
tasks = []
for ana_ind, analysis in enumerate(result["analyses"]):
try:
path_id = analysis["path_bindings"][qpath_id][0]["id"]
sentence = convert_path_to_sentence(
source,
target,
message["message"]["auxiliary_graphs"][path_id]["edges"],
message["message"]["knowledge_graph"],
logger,
)
tasks.append((ana_ind, sentence))
except KeyError as e:
logger.error(
f"Result {ind}, analysis {ana_ind}: missing key {e}, skipping analysis."
)
continue
except ValueError as e:
logger.error(
f"Result {ind}, analysis {ana_ind}: could not build sentence."
)
continue

if not tasks:
logger.warning(
f"Result {ind}: no valid analyses to score, skipping."
)
continue

all_sentences = [t[1] for t in tasks]
try:
all_embeddings = model.encode(
all_sentences, batch_size=32, show_progress_bar=False
)
except Exception as e:
logger.error(
f"Result {ind}: embedding failed due to {e}."
)
message["message"]["results"][ind]["analyses"][ana_ind][
"score"
] = 0.0
continue

for (ana_ind, _), embedding in zip(tasks, all_embeddings):
try:
probs = clf.predict_proba(embedding.reshape(1, -1))[:, 1]
message["message"]["results"][ind]["analyses"][ana_ind][
"score"
] = float(probs[0])
except Exception as e:
logger.error(
f"Result {ind}, analysis {ana_ind}: scoring failed due to {e}."
)
message["message"]["results"][ind]["analyses"][ana_ind][
"score"
] = 0.0
continue
except KeyError as e:
# can't find the right structure of message
err = f"Error scoring paths: {e}"
Expand Down Expand Up @@ -62,4 +245,5 @@ async def poll_for_tasks():


if __name__ == "__main__":
    # Load the pre-trained classifier weights before entering the polling
    # loop so the first task doesn't pay the model-load cost.
    # NOTE(review): relative path — assumes the worker is launched from the
    # repository root; confirm against the deployment entrypoint.
    clf.load_model("workers/score_paths/model_weights/sapbert_classifier_weights.json")
    asyncio.run(poll_for_tasks())
Loading