diff --git a/src/retriever/config/utils.py b/src/retriever/config/utils.py index da241f15..e143531c 100644 --- a/src/retriever/config/utils.py +++ b/src/retriever/config/utils.py @@ -10,7 +10,7 @@ from ruamel.yaml import YAML from ruamel.yaml.comments import CommentedMap -from retriever.types.general import JsonSerializable +from retriever.types.base import JsonSerializable yaml = YAML() CommentedSerializable = JsonSerializable | list[CommentedMap] | dict[str, CommentedMap] diff --git a/src/retriever/types/base.py b/src/retriever/types/base.py new file mode 100644 index 00000000..6e26969c --- /dev/null +++ b/src/retriever/types/base.py @@ -0,0 +1,6 @@ +# No types in this module may depend on anything from other modules. +JsonPrimitive = None | int | float | str | bool + +JsonSerializable = ( + JsonPrimitive | list["JsonSerializable"] | dict[str, "JsonSerializable"] +) diff --git a/src/retriever/types/general.py b/src/retriever/types/general.py index e9c5b979..aa2766f3 100644 --- a/src/retriever/types/general.py +++ b/src/retriever/types/general.py @@ -91,14 +91,5 @@ class LookupArtifacts(NamedTuple): # A pair of Qnode and CURIE, used to uniquely identify partial results QNodeCURIEPair = tuple[QNodeID, CURIE] -JsonSerializable = ( - None - | int - | float - | str - | bool - | list["JsonSerializable"] - | dict[str, "JsonSerializable"] -) EntityToEntityMapping = dict[CURIE, list[CURIE]] diff --git a/src/retriever/types/trapi.py b/src/retriever/types/trapi.py index 6ada6f4d..00894e04 100644 --- a/src/retriever/types/trapi.py +++ b/src/retriever/types/trapi.py @@ -15,7 +15,9 @@ from __future__ import annotations from enum import Enum -from typing import Any, NotRequired, TypedDict +from typing import NotRequired, TypedDict + +from retriever.types.base import JsonSerializable class SetInterpretationEnum(str, Enum): @@ -48,7 +50,7 @@ class QueryDict(TypedDict): message: MessageDict log_level: NotRequired[LogLevel | None] - workflow: NotRequired[list[dict[str, Any]] | 
def _compare_values(
    constr: JsonSerializable, attr: JsonSerializable, operator: OperatorEnum
) -> bool:
    """Evaluate a single constraint value against a single attribute value.

    Args:
        constr: The value supplied by the constraint.
        attr: The value taken from the attribute being tested.
        operator: The comparison to perform. `EQUAL` checks `attr == constr`;
            `GT` / `LT` check `attr > constr` / `attr < constr`; any other
            operator is handled as `MATCH`, a regex search of `constr`
            within `str(attr)`.

    Returns:
        Whether the attribute value satisfies the comparison.

    Raises:
        TypeError: If `GT`/`LT` is used with a dict, list or None operand, or
            with operands of differing types (int/float mixes are allowed);
            or if `MATCH` is used with a non-string constraint value.
    """
    if operator == OperatorEnum.EQUAL:
        return attr == constr

    if operator in (OperatorEnum.GT, OperatorEnum.LT):
        if isinstance(attr, dict | list | None):
            raise TypeError(
                f"Cannot use operators `>` or `<` with type `{type(attr)}`."
            )
        if isinstance(constr, dict | list | None):
            raise TypeError(
                f"Cannot use operators `>` or `<` with type `{type(constr)}`."
            )

        # Values must be the same type, except that int and float may be mixed.
        # bool is a subclass of int, so it is excluded explicitly: a bool is
        # only comparable with another bool (via the same-type rule).
        both_numeric = all(
            isinstance(value, int | float) and not isinstance(value, bool)
            for value in (attr, constr)
        )
        if not both_numeric and type(attr) is not type(constr):
            raise TypeError(
                f"Cannot compare unalike types (constraint: `{type(constr)}`, "
                f"attribute: `{type(attr)}`)"
            )

        # NOTE: The casts only exist to satisfy the type checker; the operands
        # have been confirmed above to be mutually comparable.
        if operator == OperatorEnum.GT:
            return cast(int, attr) > cast(int, constr)
        return cast(int, attr) < cast(int, constr)

    # OperatorEnum.MATCH
    if not isinstance(constr, str):
        raise TypeError(
            f"Cannot use constraint value of type `{type(constr)}` as regex pattern."
        )
    return bool(re.search(constr, str(attr)))
AttributeConstraintDict( + name="some_type has 3, 4, or 5", + id="some_type", + value=[3, 4, 5], + operator=OperatorEnum.EQUAL, + ), + numeric_array_attribute, + ) + def test_strict_equals( self, numeric_attribute: AttributeDict, numeric_array_attribute: AttributeDict ): @@ -813,6 +835,19 @@ def test_strict_equals( def test_greater_than( self, numeric_attribute: AttributeDict, numeric_array_attribute: AttributeDict ): + + # Test fail case + with pytest.raises(TypeError): + assert attribute_meets_constraint( + AttributeConstraintDict( + name="A constraint that should error", + id="some_type", + value={"some": "dictionary"}, + operator=OperatorEnum.GT, + ), + numeric_attribute, + ) + assert attribute_meets_constraint( AttributeConstraintDict( name="some_type is greater than -1", @@ -922,6 +957,33 @@ def test_greater_than( def test_less_than( self, numeric_attribute: AttributeDict, numeric_array_attribute: AttributeDict ): + + # Test fail case + with pytest.raises(TypeError): + assert attribute_meets_constraint( + AttributeConstraintDict( + name="A constraint that should error", + id="some_type", + value="a", + operator=OperatorEnum.LT, + ), + AttributeDict( + attribute_type_id="some_type", value={"some": "dictionary"} + ), + ) + + # Test fail case + with pytest.raises(TypeError): + assert attribute_meets_constraint( + AttributeConstraintDict( + name="A constraint that should error", + id="some_type", + value="a", + operator=OperatorEnum.LT, + ), + AttributeDict(attribute_type_id="some_type", value=1), + ) + assert attribute_meets_constraint( AttributeConstraintDict( name="some_type is less than 1", @@ -1031,6 +1093,17 @@ def test_less_than( def test_matches( self, string_attribute: AttributeDict, string_array_attribute: AttributeDict ): + # Test fail case + with pytest.raises(TypeError): + assert attribute_meets_constraint( + AttributeConstraintDict( + name="A constraint that should error", + id="some_type", + value={"some": "dictionary"}, + 
operator=OperatorEnum.MATCH, + ), + string_attribute, + ) assert attribute_meets_constraint( AttributeConstraintDict( name="some_type ends in 'bc'", @@ -1084,7 +1157,6 @@ def test_matches( string_array_attribute, ) - assert not attribute_meets_constraint( AttributeConstraintDict( name="some_type has a value which ends in 'yz'",