Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/retriever/data_tiers/tier_1/elasticsearch/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)
from retriever.data_tiers.tier_1.elasticsearch.types import (
ESEdge,
ESNode,
ESPayload,
)
from retriever.data_tiers.utils import (
Expand Down Expand Up @@ -172,6 +173,31 @@ async def run(

return results

async def fetch_single_node(self, _curie: str) -> ESNode | None:
    """Fetch a single canonical node from the Elasticsearch backend."""
    # Guard: an open connection is required before any search can run.
    if self.es_connection is None:
        raise RuntimeError(
            "Must use ElasticSearchDriver.connect() before fetching node metadata."
        )

    # Only one document is needed; extra matches are merely logged below.
    result = await self.es_connection.search(
        index="ubergraph_nodes",
        size=1,
        query={"term": {"id": _curie}},
    )

    matches = result["hits"]["hits"]
    if not matches:
        return None

    match_count = result["hits"]["total"]["value"]
    if match_count > 1:
        log.warning(
            f"Found {match_count} canonical node hits for {_curie} in `ubergraph_nodes`; using the first match."
        )

    return ESNode.from_dict(matches[0]["_source"])

@override
@tracer.start_as_current_span("elasticsearch_query")
async def run_query(
Expand Down
33 changes: 23 additions & 10 deletions src/retriever/data_tiers/tier_1/elasticsearch/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,27 @@ def build_attributes(

return attributes

def build_single_node(
    self, node: ESNode, attributes: list[AttributeDict] | None = None
) -> NodeDict:
    """Build a single TRAPI node from the given knowledge.

    Args:
        node: Canonical Elasticsearch node to convert.
        attributes: Pre-built TRAPI attributes to attach. When None, the
            attributes are derived from *node* via ``build_attributes``.

    Returns:
        A TRAPI-compliant ``NodeDict`` with name, categories, and attributes.
    """
    if attributes is None:
        # Cases that require additional formatting to be TRAPI-compliant
        special_cases: SpecialCaseDict = {}
        attributes = self.build_attributes(node, special_cases)

    # NOTE: the original pre-initialized a throwaway empty list that was
    # immediately overwritten; reassigning the parameter avoids that dead store.
    return NodeDict(
        name=node.name,
        categories=[
            BiolinkEntity(biolink.ensure_prefix(cat)) for cat in node.category
        ],
        attributes=attributes,
    )

def build_nodes(
self, edges: list[ESEdge], query_subject: QNodeDict, query_object: QNodeDict
) -> dict[CURIE, NodeDict]:
Expand All @@ -288,8 +309,6 @@ def build_nodes(
node_ids[node_pos] = node_id
if node_id in nodes:
continue
attributes: list[AttributeDict] = []

# Cases that require additional formatting to be TRAPI-compliant
special_cases: SpecialCaseDict = {}

Expand All @@ -298,17 +317,11 @@ def build_nodes(
constraints = (
query_subject if node_pos == "subject" else query_object
).get("constraints", []) or []

if not attributes_meet_contraints(constraints, attributes):
continue

trapi_node = NodeDict(
name=node.name,
categories=[
BiolinkEntity(biolink.ensure_prefix(cat))
for cat in node.category
],
attributes=attributes,
)
trapi_node = self.build_single_node(node, attributes)

nodes[node_id] = trapi_node

Expand Down
45 changes: 45 additions & 0 deletions src/retriever/lookup/qgx.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from opentelemetry import trace

from retriever.config.general import CONFIG
from retriever.data_tiers import tier_manager
from retriever.data_tiers.tier_1.elasticsearch.driver import ElasticSearchDriver
from retriever.data_tiers.tier_1.elasticsearch.transpiler import ElasticsearchTranspiler
from retriever.lookup.branch import (
Branch,
BranchID,
Expand Down Expand Up @@ -199,6 +202,8 @@ async def execute(self) -> LookupArtifacts:
await asyncio.sleep(0)
results.append(part.as_result(self.k_agraph))

await self.hydrate_missing_nodes()

timeout_task.cancel()
end_time = time.time()
duration_ms = math.ceil((end_time - self.start_time) * 1000)
Expand Down Expand Up @@ -234,6 +239,46 @@ async def execute(self) -> LookupArtifacts:
[], self.kgraph, self.aux_graphs, self.job_log.get_logs(), error=True
)

async def hydrate_missing_nodes(self) -> None:
    """Hydrate skeletal KG nodes using tier-1 canonical node metadata.

    Scans the knowledge graph for nodes that are missing categories or a
    name and backfills them from the tier-1 Elasticsearch canonical-node
    index. Failures for individual CURIEs are logged and skipped so one
    bad node cannot abort the whole lookup.
    """
    # A node is "skeletal" when it has no categories or no human-readable
    # name. Truthiness handles missing keys, empty lists, AND explicit None
    # (the original len(...) == 0 check raised TypeError on a None value).
    incomplete_nodes = [
        curie
        for curie, node in self.kgraph["nodes"].items()
        if not node.get("categories") or not node.get("name")
    ]
    if not incomplete_nodes:
        return

    driver = cast(ElasticSearchDriver, tier_manager.get_driver(1))
    transpiler = cast(ElasticsearchTranspiler, tier_manager.get_transpiler(1))

    hydrated_count = 0

    for curie in incomplete_nodes:
        try:
            fetched = await driver.fetch_single_node(str(curie))
            if fetched is None:
                self.job_log.warning(
                    f"Unable to hydrate node metadata for {curie}: no canonical tier-1 node was found."
                )
                continue

            trapi_node = transpiler.build_single_node(fetched)
            # Merge via update_kgraph rather than assigning directly so any
            # existing node/edge bookkeeping stays intact.
            update_kgraph(
                self.kgraph,
                KnowledgeGraphDict(nodes={curie: trapi_node}, edges={}),
            )
            hydrated_count += 1
        except Exception:
            # Best-effort hydration: record the failure and move on.
            self.job_log.exception(f"Failed to hydrate node metadata for {curie}.")

    if hydrated_count > 0:
        self.job_log.debug(f"Hydrated {hydrated_count} skeletal KG nodes.")

async def expand_initial_subclasses(self) -> None:
"""Check if any pinned nodes have subclasses and expand them accordingly."""
for qnode_id, node in self.qgraph["nodes"].items():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,28 @@ def generate_qgraph_with_qualifier_constraints(qualifier_constraints: list[Quali
}
})

# Query graph for node-hydration tests: a pinned chemical ("sn",
# CHEBI:48927) that affects a pinned gene/protein ("on", NCBIGene:4314).
# Used to exercise the tier-1 canonical-node metadata backfill path.
HYDRATION_QGRAPH = qg({
    "nodes": {
        "on": {
            "categories": ["biolink:Gene", "biolink:Protein"],
            "ids": ["NCBIGene:4314"],
            "is_set": False,
        },
        "sn": {
            "categories": ["biolink:ChemicalEntity"],
            "ids": ["CHEBI:48927"],
            "is_set": False,
        },
    },
    "edges": {
        "e00": {
            "object": "on",
            "predicates": ["biolink:affects"],
            "subject": "sn",
        }
    },
})


def generate_qgraph_with_attribute_constraints(constraints: list[AttributeConstraintDict]):
"""Generate a QGraph with attribute constraints."""
Expand Down
35 changes: 29 additions & 6 deletions tests/data_tiers/tier_1/elasticsearch_tests/test_tier1_driver.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import importlib
from typing import Iterator, cast, Any
from collections.abc import Iterator
from typing import Any, cast

import pytest
import retriever.config.general as general_mod
import retriever.data_tiers.tier_1.elasticsearch.driver as driver_mod
from retriever.data_tiers.tier_1.elasticsearch.meta import extract_metadata_entries_from_blob, get_t1_indices
from retriever.data_tiers.tier_1.elasticsearch.transpiler import ElasticsearchTranspiler
from retriever.data_tiers.tier_1.elasticsearch.types import ESPayload, ESEdge
from retriever.data_tiers.tier_1.elasticsearch.types import ESPayload, ESEdge, ESNode
from payload.trapi_qgraphs import DINGO_QGRAPH, VALID_REGEX_QGRAPHS, INVALID_REGEX_QGRAPHS, ID_BYPASS_PAYLOAD
from retriever.utils.redis import RedisClient
from test_tier1_transpiler import _convert_triple, _convert_batch_triple
Expand Down Expand Up @@ -93,11 +94,11 @@ def mock_elasticsearch_config(monkeypatch: pytest.MonkeyPatch) -> Iterator[None]
"payload, expected",
[
(PAYLOAD_0, 0),
(PAYLOAD_1, 4),
(PAYLOAD_1, 2),
(PAYLOAD_2, 32),
(
[PAYLOAD_0, PAYLOAD_1, PAYLOAD_2],
[0, 4, 32]
[0, 2, 32]
)
],
ids=[
Expand Down Expand Up @@ -203,13 +204,35 @@ async def test_metadata_retrieval():
await driver.close()


@pytest.mark.usefixtures("mock_elasticsearch_config")
@pytest.mark.asyncio
async def test_fetch_single_node():
    """fetch_single_node returns the canonical ES node for a known CURIE."""
    driver: driver_mod.ElasticSearchDriver = driver_mod.ElasticSearchDriver()

    try:
        await driver.connect()
        assert driver.es_connection is not None
    except Exception:
        pytest.skip("skipping fetch_single_node test: cannot connect")

    try:
        node = await driver.fetch_single_node("CHEBI:48927")

        # Check for None before isinstance so a miss produces a clear
        # failure message instead of a confusing isinstance error.
        assert node is not None
        assert isinstance(node, ESNode)
        assert node.id == "CHEBI:48927"
        assert len(node.category) > 0
        assert node.name == "N-acyl-L-alpha-amino acid"
    finally:
        # Always release the connection, even when an assertion fails;
        # the original leaked the connection on test failure.
        await driver.close()


@pytest.mark.usefixtures("mock_elasticsearch_config")
@pytest.mark.asyncio
@pytest.mark.parametrize(
"qgraph, expected_hits",
[
(DINGO_QGRAPH, 8),
(ID_BYPASS_PAYLOAD, 6181), # <-- adjust to the real number
(ID_BYPASS_PAYLOAD, 4176),  # expected hit count for the ID-bypass payload
],
)
async def test_end_to_end(qgraph, expected_hits):
Expand Down Expand Up @@ -381,6 +404,6 @@ async def test_ubergraph_info_retrieval():
# print(k, v)

# assert "mapping" in info
assert len(info) == 122707
assert len(info) == 122176

await driver.close()
47 changes: 47 additions & 0 deletions tests/test_trapi_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,53 @@ async def test_subclass_case2_id_to_cat(tier: int) -> None:
assert "UMLS:C3273258" in n0_ids, "UMLS:C3273258 must appear in n0 result bindings"


@pytest.mark.live
@pytest.mark.asyncio
async def test_tier1_query_hydrates_empty_query_nodes() -> None:
    """
    Tier-1 query should hydrate the pinned ancestor node metadata rather than
    leaving an empty shell node in the final knowledge graph.
    """
    gene_qnode = {
        "categories": ["biolink:Gene", "biolink:Protein"],
        "ids": ["NCBIGene:4314"],
        "is_set": False,
    }
    chemical_qnode = {
        "categories": ["biolink:ChemicalEntity"],
        "ids": ["CHEBI:48927"],
        "is_set": False,
    }
    query_graph = {
        "nodes": {"on": gene_qnode, "sn": chemical_qnode},
        "edges": {
            "e00": {
                "object": "on",
                "predicates": ["biolink:affects"],
                "subject": "sn",
            }
        },
    }

    async with httpx.AsyncClient(timeout=60) as client:
        response = await client.post(QUERY_ENDPOINT, json=_request(1, query_graph))

    msg = _assert_ok(response)
    assert msg["results"], "Expected at least one tier-1 result for CHEBI:48927 affects NCBIGene:4314"

    kg_nodes = msg["knowledge_graph"]["nodes"]
    assert len(kg_nodes) == 3, "Expected hydrated ancestor, descendant support node, and target gene node"

    # Every node we expect in the KG, with the reason it must be present.
    required_nodes = [
        ("CHEBI:48927", "Pinned CHEBI:48927 must appear in the knowledge graph"),
        ("CHEBI:71223", "Subclass support node CHEBI:71223 must appear in the knowledge graph"),
        ("NCBIGene:4314", "Pinned NCBIGene:4314 must appear in the knowledge graph"),
    ]
    for curie, reason in required_nodes:
        assert curie in kg_nodes, reason

    assert kg_nodes["CHEBI:48927"]["categories"], "CHEBI:48927 should be hydrated with categories"
    assert kg_nodes["CHEBI:48927"].get("name") == "N-acyl-L-alpha-amino acid"
    assert all(node.get("categories") for node in kg_nodes.values()), (
        "No returned knowledge graph node should have empty categories after hydration"
    )


# ---------------------------------------------------------------------------
# Multi-hop queries
#
Expand Down
Loading