Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/retriever/config/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap

from retriever.types.general import JsonSerializable
from retriever.types.base import JsonSerializable

yaml = YAML()
CommentedSerializable = JsonSerializable | list[CommentedMap] | dict[str, CommentedMap]
Expand Down
6 changes: 6 additions & 0 deletions src/retriever/types/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# No types in this module may depend on anything from other modules.
JsonPrimitive = None | int | float | str | bool

JsonSerializable = (
JsonPrimitive | list["JsonSerializable"] | dict[str, "JsonSerializable"]
)
9 changes: 0 additions & 9 deletions src/retriever/types/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,5 @@ class LookupArtifacts(NamedTuple):
# A pair of Qnode and CURIE, used to uniquely identify partial results
QNodeCURIEPair = tuple[QNodeID, CURIE]

JsonSerializable = (
None
| int
| float
| str
| bool
| list["JsonSerializable"]
| dict[str, "JsonSerializable"]
)

EntityToEntityMapping = dict[CURIE, list[CURIE]]
12 changes: 7 additions & 5 deletions src/retriever/types/trapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from __future__ import annotations

from enum import Enum
from typing import Any, NotRequired, TypedDict
from typing import NotRequired, TypedDict

from retriever.types.base import JsonSerializable


class SetInterpretationEnum(str, Enum):
Expand Down Expand Up @@ -48,7 +50,7 @@ class QueryDict(TypedDict):

message: MessageDict
log_level: NotRequired[LogLevel | None]
workflow: NotRequired[list[dict[str, Any]] | None]
workflow: NotRequired[list[dict[str, JsonSerializable]] | None]
submitter: NotRequired[str | None]
bypass_cache: NotRequired[bool]
parameters: NotRequired[ParametersDict | None]
Expand Down Expand Up @@ -84,7 +86,7 @@ class ResponseDict(TypedDict):
status: NotRequired[str | None]
description: NotRequired[str | None]
logs: NotRequired[list[LogEntryDict]]
workflow: NotRequired[list[dict[str, Any]] | None]
workflow: NotRequired[list[dict[str, JsonSerializable]] | None]
parameters: NotRequired[ParametersDict | None]
schema_version: NotRequired[str | None]
biolink_version: NotRequired[str | None]
Expand Down Expand Up @@ -249,7 +251,7 @@ class AttributeDict(TypedDict):

attribute_type_id: str
original_attribute_name: NotRequired[str | None]
value: Any
value: JsonSerializable
value_type_id: NotRequired[str | None]
attribute_source: NotRequired[str | None]
value_url: NotRequired[URL | None]
Expand Down Expand Up @@ -335,7 +337,7 @@ class MetaAttributeDict(TypedDict):
"name": str,
"not": NotRequired[bool],
"operator": OperatorEnum,
"value": Any,
"value": JsonSerializable,
"unit_id": NotRequired[str | None],
"unit_name": NotRequired[str | None],
},
Expand Down
57 changes: 45 additions & 12 deletions src/retriever/utils/trapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import uuid
from collections import defaultdict
from collections.abc import Sequence
from typing import Any, cast
from typing import cast

from opentelemetry import trace
from reasoner_pydantic import QueryGraph
from reasoner_pydantic.utils import make_hashable

from retriever.types.base import JsonSerializable
from retriever.types.trapi import (
CURIE,
AnalysisDict,
Expand Down Expand Up @@ -462,6 +463,40 @@ def meta_attributes_meet_constraints(
return all(constr["id"] in attribute_type_ids for constr in constraints)


def _compare_values(
constr: JsonSerializable, attr: JsonSerializable, operator: OperatorEnum
) -> bool:
if operator == OperatorEnum.EQUAL:
return attr == constr
elif operator in (OperatorEnum.GT, OperatorEnum.LT):
if isinstance(attr, dict | list | None):
raise TypeError("Cannot use operators `>` or `<` with type `{type(a_val}`.")
elif isinstance(constr, dict | list | None):
raise TypeError("Cannot use operators `>` or `<` with type `{type(c_val}`.")

# Ensure values to compare are same type (or are both numbers for int/float)
num_disagreement = isinstance(attr, int | float) != isinstance(
constr, int | float
)
type_disagreement = type(attr) is not type(constr)
if (type_disagreement and not num_disagreement) or num_disagreement:
raise TypeError(
"Cannot compare unalike types (constraint: `{type(c_val)}`, attribute: `{type(a_val)}`)"
)
# NOTE: Doing some bogus casts to make type check understand
# THEY HAVE BEEN CONFIRMED TO BE THE SAME TYPE (see above)
if operator == OperatorEnum.GT:
return cast(int, attr) > cast(int, constr)
else:
return cast(int, attr) < cast(int, constr)
else: # OperatorEnum.MATCH
if not isinstance(constr, str):
raise TypeError(
f"Cannot use constraint value of type `{type(constr)}` as regex pattern."
)
return bool(re.search(constr, str(attr)))


def attribute_meets_constraint(
constraint: AttributeConstraintDict, attribute: AttributeDict
) -> bool:
Expand All @@ -480,22 +515,20 @@ def attribute_meets_constraint(
# Per attribute constraints, all other operators operate
# On either the value itself, or list members if the value is a list
# This way, we can do both at once
attr_values: list[Any] = ( # pyright:ignore[reportUnknownVariableType]
constr_values: list[JsonSerializable] = (
constraint_value if isinstance(constraint_value, list) else [constraint_value]
)
attr_values: list[JsonSerializable] = (
attribute["value"]
if isinstance(attribute["value"], list)
else [attribute["value"]]
)

success = False
for value in attr_values:
if (
(operator == OperatorEnum.EQUAL and (value == constraint_value))
or (operator == OperatorEnum.GT and (value > constraint_value))
or (operator == OperatorEnum.LT and (value < constraint_value))
or (operator == OperatorEnum.MATCH and (re.search(constraint_value, value)))
):
success = True
break
success: bool = False
success = any(
any(_compare_values(c_val, a_val, operator) for a_val in attr_values)
for c_val in constr_values
)

if negated:
success = not success
Expand Down
74 changes: 73 additions & 1 deletion tests/test_utils_trapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,7 @@ def test_shortcuts(self):
def test_equals(
self, numeric_attribute: AttributeDict, numeric_array_attribute: AttributeDict
):

assert attribute_meets_constraint(
AttributeConstraintDict(
name="some_type is equal to 0",
Expand Down Expand Up @@ -722,6 +723,27 @@ def test_equals(
numeric_array_attribute,
)

# Test both array values
assert attribute_meets_constraint(
AttributeConstraintDict(
name="some_type has 0, 1, or 2",
id="some_type",
value=[0, 1, 2],
operator=OperatorEnum.EQUAL,
),
numeric_array_attribute,
)

assert not attribute_meets_constraint(
AttributeConstraintDict(
name="some_type has 3, 4, or 5",
id="some_type",
value=[3, 4, 5],
operator=OperatorEnum.EQUAL,
),
numeric_array_attribute,
)

def test_strict_equals(
self, numeric_attribute: AttributeDict, numeric_array_attribute: AttributeDict
):
Expand Down Expand Up @@ -813,6 +835,19 @@ def test_strict_equals(
def test_greater_than(
self, numeric_attribute: AttributeDict, numeric_array_attribute: AttributeDict
):

# Test fail case
with pytest.raises(TypeError):
assert attribute_meets_constraint(
AttributeConstraintDict(
name="A constraint that should error",
id="some_type",
value={"some": "dictionary"},
operator=OperatorEnum.GT,
),
numeric_attribute,
)

assert attribute_meets_constraint(
AttributeConstraintDict(
name="some_type is greater than -1",
Expand Down Expand Up @@ -922,6 +957,33 @@ def test_greater_than(
def test_less_than(
self, numeric_attribute: AttributeDict, numeric_array_attribute: AttributeDict
):

# Test fail case
with pytest.raises(TypeError):
assert attribute_meets_constraint(
AttributeConstraintDict(
name="A constraint that should error",
id="some_type",
value="a",
operator=OperatorEnum.LT,
),
AttributeDict(
attribute_type_id="some_type", value={"some": "dictionary"}
),
)

# Test fail case
with pytest.raises(TypeError):
assert attribute_meets_constraint(
AttributeConstraintDict(
name="A constraint that should error",
id="some_type",
value="a",
operator=OperatorEnum.LT,
),
AttributeDict(attribute_type_id="some_type", value=1),
)

assert attribute_meets_constraint(
AttributeConstraintDict(
name="some_type is less than 1",
Expand Down Expand Up @@ -1031,6 +1093,17 @@ def test_less_than(
def test_matches(
self, string_attribute: AttributeDict, string_array_attribute: AttributeDict
):
# Test fail case
with pytest.raises(TypeError):
assert attribute_meets_constraint(
AttributeConstraintDict(
name="A constraint that should error",
id="some_type",
value={"some": "dictionary"},
operator=OperatorEnum.MATCH,
),
string_attribute,
)
assert attribute_meets_constraint(
AttributeConstraintDict(
name="some_type ends in 'bc'",
Expand Down Expand Up @@ -1084,7 +1157,6 @@ def test_matches(
string_array_attribute,
)


assert not attribute_meets_constraint(
AttributeConstraintDict(
name="some_type has a value which ends in 'yz'",
Expand Down
Loading