From a05f0aa6928d2273ea89adde205a72d32a27c000 Mon Sep 17 00:00:00 2001 From: Dmitry Dygalo Date: Wed, 18 Nov 2020 19:25:47 +0100 Subject: [PATCH 1/4] Skip resolving recursive references in `resolve_all_refs` --- src/hypothesis_jsonschema/_canonicalise.py | 58 +++++---- tests/test_canonicalise.py | 130 +++++++++++++++++++++ 2 files changed, 167 insertions(+), 21 deletions(-) diff --git a/src/hypothesis_jsonschema/_canonicalise.py b/src/hypothesis_jsonschema/_canonicalise.py index a8d2565..5f15eab 100644 --- a/src/hypothesis_jsonschema/_canonicalise.py +++ b/src/hypothesis_jsonschema/_canonicalise.py @@ -18,6 +18,7 @@ import re from copy import deepcopy from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union +from urllib.parse import urljoin import jsonschema from hypothesis.errors import InvalidArgument @@ -576,16 +577,25 @@ def resolve_remote(self, uri: str) -> NoReturn: ) +def is_recursive_reference(reference: str, resolver: LocalResolver) -> bool: + """Detect if the given reference is recursive.""" + # Special case: a reference to the schema's root is always recursive + if reference == "#": + return True + # During reference resolving the scope might go to external schemas. `hypothesis-jsonschema` does not support + # schemas behind remote references, but the underlying `jsonschema` library includes meta schemas for + # different JSON Schema drafts that are available transparently, and they count as external schemas in this context. + # For this reason we need to check the reference relatively to the base uri. + full_reference = urljoin(resolver.base_uri, reference) + # If a fully-qualified reference is in the resolution stack, then we encounter it for the second time. + # Therefore it is a recursive reference. + return full_reference in resolver._scopes_stack + + def resolve_all_refs( schema: Union[bool, Schema], *, resolver: LocalResolver = None ) -> Schema: - """ - Resolve all references in the given schema. - - This handles nested definitions, but not recursive definitions. - The latter require special handling to convert to strategies and are much - less common, so we just ignore them (and error out) for now. - """ + """Resolve all non-recursive references in the given schema.""" if isinstance(schema, bool): return canonicalish(schema) assert isinstance(schema, dict), schema @@ -597,27 +607,31 @@ def resolve_all_refs( ) if "$ref" in schema: - s = dict(schema) - ref = s.pop("$ref") - with resolver.resolving(ref) as got: - if s == {}: - return resolve_all_refs(got, resolver=resolver) - m = merged([s, got]) - if m is None: # pragma: no cover - msg = f"$ref:{ref!r} had incompatible base schema {s!r}" - raise HypothesisRefResolutionError(msg) - return resolve_all_refs(m, resolver=resolver) - assert "$ref" not in schema + # Recursive references are skipped to avoid infinite recursion. + if not is_recursive_reference(schema["$ref"], resolver): + s = dict(schema) + ref = s.pop("$ref") + with resolver.resolving(ref) as got: + if s == {}: + return resolve_all_refs(deepcopy(got), resolver=resolver) + m = merged([s, got]) + if m is None: # pragma: no cover + msg = f"$ref:{ref!r} had incompatible base schema {s!r}" + raise HypothesisRefResolutionError(msg) + # `deepcopy` is not needed, because, the schemas are copied inside the `merged` call above + return resolve_all_refs(m, resolver=resolver) for key in SCHEMA_KEYS: val = schema.get(key, False) if isinstance(val, list): schema[key] = [ - resolve_all_refs(v, resolver=resolver) if isinstance(v, dict) else v + resolve_all_refs(deepcopy(v), resolver=resolver) + if isinstance(v, dict) + else v for v in val ] elif isinstance(val, dict): - schema[key] = resolve_all_refs(val, resolver=resolver) + schema[key] = resolve_all_refs(deepcopy(val), resolver=resolver) else: assert isinstance(val, bool) for key in SCHEMA_OBJECT_KEYS: # values are keys-to-schema-dicts, not schemas @@ -625,7 +639,9 @@ def resolve_all_refs( subschema = schema[key] assert isinstance(subschema, dict) schema[key] = { - k: resolve_all_refs(v, resolver=resolver) if isinstance(v, dict) else v + k: resolve_all_refs(deepcopy(v), resolver=resolver) + if isinstance(v, dict) + else v for k, v in subschema.items() } assert isinstance(schema, dict) diff --git a/tests/test_canonicalise.py b/tests/test_canonicalise.py index 86fd4e0..f89cc01 100644 --- a/tests/test_canonicalise.py +++ b/tests/test_canonicalise.py @@ -558,3 +558,133 @@ def test_validators_use_proper_draft(): } cc = canonicalish(schema) jsonschema.validators.validator_for(cc).check_schema(cc) + + +# Reference to itself +ROOT_REFERENCE = {"$ref": "#"} +# One extra nesting level +NESTED = {"not": {"$ref": "#/not"}} +# The same as above, but includes "$id". +NESTED_WITH_ID = { + "not": {"$ref": "#/not"}, + "$id": "http://json-schema.org/draft-07/schema#", +} +SELF_REFERENTIAL = {"foo": {"$ref": "#foo"}, "not": {"$ref": "#foo"}} + + +@pytest.mark.parametrize( + "schema, expected", + ( + (ROOT_REFERENCE, ROOT_REFERENCE), + (NESTED, NESTED), + (NESTED_WITH_ID, NESTED_WITH_ID), + # "foo" content should be inlined as is, because "#" is recursive (special case) + ( + {"foo": {"$ref": "#"}, "not": {"$ref": "#foo"}}, + {"foo": {"$ref": "#"}, "not": {"$ref": "#"}}, + ), + # "foo" content should be inlined as is, because it points to itself + ( + SELF_REFERENTIAL, + SELF_REFERENTIAL, + ), + # The same as above, but with one extra nesting level + ( + {"foo": {"not": {"$ref": "#foo"}}, "not": {"$ref": "#foo"}}, + # 1. We start from resolving "$ref" in "not" + # 2. at this point we don't know this path is recursive, so we follow to "foo" + # 3. inside "foo" we found a reference to "foo", which means it is recursive + {"foo": {"not": {"$ref": "#foo"}}, "not": {"not": {"$ref": "#foo"}}}, + ), + # Circular reference between two schemas + ( + {"foo": {"$ref": "#bar"}, "bar": {"$ref": "#foo"}, "not": {"$ref": "#foo"}}, + # 1. We start in "not" and follow to "foo" + # 2. In "foo" we follow to "bar" + # 3. Here we see a reference to previously seen scope, which means it is a recursive path + # We take the schema where we stop and inline it to the starting point (therefore it is `{"$ref": "#foo"}`) + {"foo": {"$ref": "#bar"}, "bar": {"$ref": "#foo"}, "not": {"$ref": "#foo"}}, + ), + ), +) +def test_skip_recursive_references_simple_schemas(schema, expected): + # When there is a recursive reference, it should not be resolved + assert resolve_all_refs(schema) == expected + + +@pytest.mark.parametrize( + "schema, resolved", + ( + # NOTE. The `resolved` fixture does not include "definitions" to save visual space here, but it is extended + # with it in the test body. + # The reference target is behind two references, that share the same definition path. Not a recursive reference + ( + { + "definitions": { + "properties": { + "foo": {"type": "string"}, + "bar": {"$ref": "#/definitions/properties/foo"}, + }, + }, + "not": {"$ref": "#/definitions/properties/bar"}, + }, + { + "not": {"type": "string"}, + }, + ), + # Here we need to resolve multiple references while being on the same resolution scope: + # "#/definitions/foo" contains two references + ( + { + "definitions": { + "foo": { + "properties": { + "bar": {"$ref": "#/definitions/spam"}, + "baz": {"$ref": "#/definitions/spam"}, + } + }, + "spam": {"type": "string"}, + }, + "properties": {"foo": {"$ref": "#/definitions/foo"}}, + }, + { + "properties": { + "foo": { + "properties": { + "bar": {"type": "string"}, + "baz": {"type": "string"}, + } + } + }, + }, + ), + # Similar to the one above, but recursive + ( + { + "definitions": { + "foo": { + "properties": { + "bar": {"$ref": "#/definitions/spam"}, + "baz": {"$ref": "#/definitions/spam"}, + } + }, + "spam": {"$ref": "#/definitions/foo"}, + }, + "properties": {"foo": {"$ref": "#/definitions/foo"}}, + }, + { + "properties": { + "foo": { + "properties": { + "bar": {"$ref": "#/definitions/foo"}, + "baz": {"$ref": "#/definitions/foo"}, + } + } + }, + }, + ), + ), +) +def test_skip_recursive_references_complex_schemas(schema, resolved): + resolved["definitions"] = schema["definitions"] + assert resolve_all_refs(schema) == resolved From 4dc6e4c318ebcf54f0cbbc2c8c62284b85dfbc51 Mon Sep 17 00:00:00 2001 From: Dmitry Dygalo Date: Thu, 19 Nov 2020 15:50:21 +0100 Subject: [PATCH 2/4] Do not inline recursive references at all --- src/hypothesis_jsonschema/_canonicalise.py | 57 +++++++++++++++------- src/hypothesis_jsonschema/_from_schema.py | 2 +- tests/test_canonicalise.py | 22 +++------ 3 files changed, 48 insertions(+), 33 deletions(-) diff --git a/src/hypothesis_jsonschema/_canonicalise.py b/src/hypothesis_jsonschema/_canonicalise.py index 5f15eab..cdb9594 100644 --- a/src/hypothesis_jsonschema/_canonicalise.py +++ b/src/hypothesis_jsonschema/_canonicalise.py @@ -594,10 +594,13 @@ def is_recursive_reference(reference: str, resolver: LocalResolver) -> bool: def resolve_all_refs( schema: Union[bool, Schema], *, resolver: LocalResolver = None -) -> Schema: - """Resolve all non-recursive references in the given schema.""" +) -> Tuple[Schema, bool]: + """Resolve all non-recursive references in the given schema. + + When a recursive reference is detected, it stops traversing the currently resolving branch and leaves it as is. + """ if isinstance(schema, bool): - return canonicalish(schema) + return canonicalish(schema), False assert isinstance(schema, dict), schema if resolver is None: resolver = LocalResolver.from_schema(deepcopy(schema)) @@ -620,32 +623,52 @@ def resolve_all_refs( raise HypothesisRefResolutionError(msg) # `deepcopy` is not needed, because, the schemas are copied inside the `merged` call above return resolve_all_refs(m, resolver=resolver) + else: + return schema, True for key in SCHEMA_KEYS: val = schema.get(key, False) if isinstance(val, list): - schema[key] = [ - resolve_all_refs(deepcopy(v), resolver=resolver) - if isinstance(v, dict) - else v - for v in val - ] + value = [] + for v in val: + if isinstance(v, dict): + resolved, is_recursive = resolve_all_refs( + deepcopy(v), resolver=resolver + ) + if is_recursive: + return schema, True + else: + value.append(resolved) + else: + value.append(v) + schema[key] = value elif isinstance(val, dict): - schema[key] = resolve_all_refs(deepcopy(val), resolver=resolver) + resolved, is_recursive = resolve_all_refs(deepcopy(val), resolver=resolver) + if is_recursive: + return schema, True + else: + schema[key] = resolved else: assert isinstance(val, bool) for key in SCHEMA_OBJECT_KEYS: # values are keys-to-schema-dicts, not schemas if key in schema: subschema = schema[key] assert isinstance(subschema, dict) - schema[key] = { - k: resolve_all_refs(deepcopy(v), resolver=resolver) - if isinstance(v, dict) - else v - for k, v in subschema.items() - } + value = {} + for k, v in subschema.items(): + if isinstance(v, dict): + resolved, is_recursive = resolve_all_refs( + deepcopy(v), resolver=resolver + ) + if is_recursive: + return schema, True + else: + value[k] = resolved + else: + value[k] = v + schema[key] = value assert isinstance(schema, dict) - return schema + return schema, False def merged(schemas: List[Any]) -> Optional[Schema]: diff --git a/src/hypothesis_jsonschema/_from_schema.py b/src/hypothesis_jsonschema/_from_schema.py index 46cc3e1..5b017da 100644 --- a/src/hypothesis_jsonschema/_from_schema.py +++ b/src/hypothesis_jsonschema/_from_schema.py @@ -114,7 +114,7 @@ def __from_schema( custom_formats: Dict[str, st.SearchStrategy[str]] = None, ) -> st.SearchStrategy[JSONType]: try: - schema = resolve_all_refs(schema) + schema, _ = resolve_all_refs(schema) except RecursionError: raise HypothesisRefResolutionError( f"Could not resolve recursive references in schema={schema!r}" diff --git a/tests/test_canonicalise.py b/tests/test_canonicalise.py index f89cc01..41ad126 100644 --- a/tests/test_canonicalise.py +++ b/tests/test_canonicalise.py @@ -578,12 +578,12 @@ def test_validators_use_proper_draft(): (ROOT_REFERENCE, ROOT_REFERENCE), (NESTED, NESTED), (NESTED_WITH_ID, NESTED_WITH_ID), - # "foo" content should be inlined as is, because "#" is recursive (special case) + # "foo" content should not be inlined, because "#" is recursive (special case) ( {"foo": {"$ref": "#"}, "not": {"$ref": "#foo"}}, - {"foo": {"$ref": "#"}, "not": {"$ref": "#"}}, + {"foo": {"$ref": "#"}, "not": {"$ref": "#foo"}}, ), - # "foo" content should be inlined as is, because it points to itself + # "foo" content should not be inlined, because it points to itself ( SELF_REFERENTIAL, SELF_REFERENTIAL, @@ -594,7 +594,7 @@ def test_validators_use_proper_draft(): # 1. We start from resolving "$ref" in "not" # 2. at this point we don't know this path is recursive, so we follow to "foo" # 3. inside "foo" we found a reference to "foo", which means it is recursive - {"foo": {"not": {"$ref": "#foo"}}, "not": {"not": {"$ref": "#foo"}}}, + {"foo": {"not": {"$ref": "#foo"}}, "not": {"$ref": "#foo"}}, ), # Circular reference between two schemas ( @@ -602,14 +602,13 @@ def test_validators_use_proper_draft(): # 1. We start in "not" and follow to "foo" # 2. In "foo" we follow to "bar" # 3. Here we see a reference to previously seen scope, which means it is a recursive path - # We take the schema where we stop and inline it to the starting point (therefore it is `{"$ref": "#foo"}`) {"foo": {"$ref": "#bar"}, "bar": {"$ref": "#foo"}, "not": {"$ref": "#foo"}}, ), ), ) def test_skip_recursive_references_simple_schemas(schema, expected): # When there is a recursive reference, it should not be resolved - assert resolve_all_refs(schema) == expected + assert resolve_all_refs(schema)[0] == expected @pytest.mark.parametrize( @@ -673,18 +672,11 @@ def test_skip_recursive_references_simple_schemas(schema, expected): "properties": {"foo": {"$ref": "#/definitions/foo"}}, }, { - "properties": { - "foo": { - "properties": { - "bar": {"$ref": "#/definitions/foo"}, - "baz": {"$ref": "#/definitions/foo"}, - } - } - }, + "properties": {"foo": {"$ref": "#/definitions/foo"}}, }, ), ), ) def test_skip_recursive_references_complex_schemas(schema, resolved): resolved["definitions"] = schema["definitions"] - assert resolve_all_refs(schema) == resolved + assert resolve_all_refs(schema)[0] == resolved From e9e6d519afaaa008216c82754096e5a34e5a62a9 Mon Sep 17 00:00:00 2001 From: Dmitry Dygalo Date: Wed, 2 Dec 2020 10:02:07 +0100 Subject: [PATCH 3/4] Do not expect failures from tests for schemas with recursive references --- src/hypothesis_jsonschema/_from_schema.py | 8 +------- tests/test_from_schema.py | 11 +---------- 2 files changed, 2 insertions(+), 17 deletions(-) diff --git a/src/hypothesis_jsonschema/_from_schema.py b/src/hypothesis_jsonschema/_from_schema.py index 5b017da..fa52891 100644 --- a/src/hypothesis_jsonschema/_from_schema.py +++ b/src/hypothesis_jsonschema/_from_schema.py @@ -17,7 +17,6 @@ FALSEY, TRUTHY, TYPE_STRINGS, - HypothesisRefResolutionError, Schema, canonicalish, get_integer_bounds, @@ -113,12 +112,7 @@ def __from_schema( *, custom_formats: Dict[str, st.SearchStrategy[str]] = None, ) -> st.SearchStrategy[JSONType]: - try: - schema, _ = resolve_all_refs(schema) - except RecursionError: - raise HypothesisRefResolutionError( - f"Could not resolve recursive references in schema={schema!r}" - ) from None + schema, _ = resolve_all_refs(schema) # We check for _FORMATS_TOKEN to avoid re-validating known good data. if custom_formats is not None and _FORMATS_TOKEN not in custom_formats: assert isinstance(custom_formats, dict) diff --git a/tests/test_from_schema.py b/tests/test_from_schema.py index 4a6e246..d3e3032 100644 --- a/tests/test_from_schema.py +++ b/tests/test_from_schema.py @@ -21,7 +21,6 @@ from hypothesis.internal.reflection import proxies from hypothesis_jsonschema._canonicalise import ( - HypothesisRefResolutionError, canonicalish, resolve_all_refs, ) @@ -242,16 +241,8 @@ def inner(*args, **kwargs): assert isinstance(name, str) try: f(*args, **kwargs) - assert name not in RECURSIVE_REFS except jsonschema.exceptions.RefResolutionError as err: - if ( - isinstance(err, HypothesisRefResolutionError) - or isinstance(err._cause, HypothesisRefResolutionError) - ) and ( - "does not fetch remote references" in str(err) - or name in RECURSIVE_REFS - and "Could not resolve recursive references" in str(err) - ): + if "does not fetch remote references" in str(err): pytest.xfail() raise From 4af9e0b49a0c14a42fdb854ab47c19242c841bad Mon Sep 17 00:00:00 2001 From: Dmitry Dygalo Date: Wed, 2 Dec 2020 10:21:53 +0100 Subject: [PATCH 4/4] Pass resolver to validators --- src/hypothesis_jsonschema/_canonicalise.py | 101 ++++++++++------- src/hypothesis_jsonschema/_from_schema.py | 123 ++++++++++++++------- tests/test_canonicalise.py | 7 +- tests/test_from_schema.py | 5 +- 4 files changed, 148 insertions(+), 88 deletions(-) diff --git a/src/hypothesis_jsonschema/_canonicalise.py b/src/hypothesis_jsonschema/_canonicalise.py index cdb9594..9123ae5 100644 --- a/src/hypothesis_jsonschema/_canonicalise.py +++ b/src/hypothesis_jsonschema/_canonicalise.py @@ -79,9 +79,20 @@ def _get_validator_class(schema: Schema) -> JSONSchemaValidator: return validator -def make_validator(schema: Schema) -> JSONSchemaValidator: +class LocalResolver(jsonschema.RefResolver): + def resolve_remote(self, uri: str) -> NoReturn: + raise HypothesisRefResolutionError( + f"hypothesis-jsonschema does not fetch remote references (uri={uri!r})" + ) + + +def make_validator( + schema: Schema, resolver: LocalResolver = None +) -> JSONSchemaValidator: + if resolver is None: + resolver = LocalResolver.from_schema(schema) validator = _get_validator_class(schema) - return validator(schema) + return validator(schema, resolver=resolver) class HypothesisRefResolutionError(jsonschema.exceptions.RefResolutionError): @@ -203,7 +214,7 @@ def get_integer_bounds(schema: Schema) -> Tuple[Optional[int], Optional[int]]: return lower, upper -def canonicalish(schema: JSONType) -> Dict[str, Any]: +def canonicalish(schema: JSONType, resolver: LocalResolver = None) -> Dict[str, Any]: """Convert a schema into a more-canonical form. This is obviously incomplete, but improves best-effort recognition of @@ -225,12 +236,15 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: "but expected a dict." ) + if resolver is None: + resolver = LocalResolver.from_schema(schema) + if "const" in schema: - if not make_validator(schema).is_valid(schema["const"]): + if not make_validator(schema, resolver=resolver).is_valid(schema["const"]): return FALSEY return {"const": schema["const"]} if "enum" in schema: - validator = make_validator(schema) + validator = make_validator(schema, resolver=resolver) enum_ = sorted( (v for v in schema["enum"] if validator.is_valid(v)), key=sort_key ) @@ -254,15 +268,15 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: # Recurse into the value of each keyword with a schema (or list of them) as a value for key in SCHEMA_KEYS: if isinstance(schema.get(key), list): - schema[key] = [canonicalish(v) for v in schema[key]] + schema[key] = [canonicalish(v, resolver=resolver) for v in schema[key]] elif isinstance(schema.get(key), (bool, dict)): - schema[key] = canonicalish(schema[key]) + schema[key] = canonicalish(schema[key], resolver=resolver) else: assert key not in schema, (key, schema[key]) for key in SCHEMA_OBJECT_KEYS: if key in schema: schema[key] = { - k: v if isinstance(v, list) else canonicalish(v) + k: v if isinstance(v, list) else canonicalish(v, resolver=resolver) for k, v in schema[key].items() } @@ -308,7 +322,9 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: if "array" in type_ and "contains" in schema: if isinstance(schema.get("items"), dict): - contains_items = merged([schema["contains"], schema["items"]]) + contains_items = merged( + [schema["contains"], schema["items"]], resolver=resolver + ) if contains_items is not None: schema["contains"] = contains_items @@ -462,9 +478,9 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: type_.remove(t) if t not in ("integer", "number"): not_["type"].remove(t) - not_ = canonicalish(not_) + not_ = canonicalish(not_, resolver=resolver) - m = merged([not_, {**schema, "type": type_}]) + m = merged([not_, {**schema, "type": type_}], resolver=resolver) if m is not None: not_ = m if not_ != FALSEY: @@ -543,7 +559,7 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: else: tmp = schema.copy() ao = tmp.pop("allOf") - out = merged([tmp] + ao) + out = merged([tmp] + ao, resolver=resolver) if isinstance(out, dict): # pragma: no branch schema = out # TODO: this assertion is soley because mypy 0.750 doesn't know @@ -555,7 +571,7 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: one_of = sorted(one_of, key=encode_canonical_json) one_of = [s for s in one_of if s != FALSEY] if len(one_of) == 1: - m = merged([schema, one_of[0]]) + m = merged([schema, one_of[0]], resolver=resolver) if m is not None: # pragma: no branch return m if (not one_of) or one_of.count(TRUTHY) > 1: @@ -570,13 +586,6 @@ def canonicalish(schema: JSONType) -> Dict[str, Any]: FALSEY = canonicalish(False) -class LocalResolver(jsonschema.RefResolver): - def resolve_remote(self, uri: str) -> NoReturn: - raise HypothesisRefResolutionError( - f"hypothesis-jsonschema does not fetch remote references (uri={uri!r})" - ) - - def is_recursive_reference(reference: str, resolver: LocalResolver) -> bool: """Detect if the given reference is recursive.""" # Special case: a reference to the schema's root is always recursive @@ -593,7 +602,7 @@ def is_recursive_reference(reference: str, resolver: LocalResolver) -> bool: def resolve_all_refs( - schema: Union[bool, Schema], *, resolver: LocalResolver = None + schema: Union[bool, Schema], *, resolver: LocalResolver ) -> Tuple[Schema, bool]: """Resolve all non-recursive references in the given schema. @@ -602,8 +611,6 @@ def resolve_all_refs( if isinstance(schema, bool): return canonicalish(schema), False assert isinstance(schema, dict), schema - if resolver is None: - resolver = LocalResolver.from_schema(deepcopy(schema)) if not isinstance(resolver, jsonschema.RefResolver): raise InvalidArgument( f"resolver={resolver} (type {type(resolver).__name__}) is not a RefResolver" @@ -617,7 +624,7 @@ def resolve_all_refs( with resolver.resolving(ref) as got: if s == {}: return resolve_all_refs(deepcopy(got), resolver=resolver) - m = merged([s, got]) + m = merged([s, got], resolver=resolver) if m is None: # pragma: no cover msg = f"$ref:{ref!r} had incompatible base schema {s!r}" raise HypothesisRefResolutionError(msg) @@ -671,7 +678,7 @@ def resolve_all_refs( return schema, False -def merged(schemas: List[Any]) -> Optional[Schema]: +def merged(schemas: List[Any], resolver: LocalResolver = None) -> Optional[Schema]: """Merge *n* schemas into a single schema, or None if result is invalid. Takes the logical intersection, so any object that validates against the returned @@ -684,7 +691,9 @@ def merged(schemas: List[Any]) -> Optional[Schema]: It's currently also used for keys that could be merged but aren't yet. """ assert schemas, "internal error: must pass at least one schema to merge" - schemas = sorted((canonicalish(s) for s in schemas), key=upper_bound_instances) + schemas = sorted( + (canonicalish(s, resolver=resolver) for s in schemas), key=upper_bound_instances + ) if any(s == FALSEY for s in schemas): return FALSEY out = schemas[0] @@ -693,11 +702,11 @@ def merged(schemas: List[Any]) -> Optional[Schema]: continue # If we have a const or enum, this is fairly easy by filtering: if "const" in out: - if make_validator(s).is_valid(out["const"]): + if make_validator(s, resolver=resolver).is_valid(out["const"]): continue return FALSEY if "enum" in out: - validator = make_validator(s) + validator = make_validator(s, resolver=resolver) enum_ = [v for v in out["enum"] if validator.is_valid(v)] if not enum_: return FALSEY @@ -748,21 +757,23 @@ def merged(schemas: List[Any]) -> Optional[Schema]: else: out_combined = merged( [s for p, s in out_pat.items() if re.search(p, prop_name)] - or [out_add] + or [out_add], + resolver=resolver, ) if prop_name in s_props: s_combined = s_props[prop_name] else: s_combined = merged( [s for p, s in s_pat.items() if re.search(p, prop_name)] - or [s_add] + or [s_add], + resolver=resolver, ) if out_combined is None or s_combined is None: # pragma: no cover # Note that this can only be the case if we were actually going to # use the schema which we attempted to merge, i.e. prop_name was # not in the schema and there were unmergable pattern schemas. return None - m = merged([out_combined, s_combined]) + m = merged([out_combined, s_combined], resolver=resolver) if m is None: return None out_props[prop_name] = m @@ -770,14 +781,17 @@ def merged(schemas: List[Any]) -> Optional[Schema]: # simpler as we merge with either an identical pattern, or additionalProperties. if out_pat or s_pat: for pattern in set(out_pat) | set(s_pat): - m = merged([out_pat.get(pattern, out_add), s_pat.get(pattern, s_add)]) + m = merged( + [out_pat.get(pattern, out_add), s_pat.get(pattern, s_add)], + resolver=resolver, + ) if m is None: # pragma: no cover return None out_pat[pattern] = m out["patternProperties"] = out_pat # Finally, we merge togther the additionalProperties schemas. if out_add or s_add: - m = merged([out_add, s_add]) + m = merged([out_add, s_add], resolver=resolver) if m is None: # pragma: no cover return None out["additionalProperties"] = m @@ -811,7 +825,7 @@ def merged(schemas: List[Any]) -> Optional[Schema]: return None if "contains" in out and "contains" in s and out["contains"] != s["contains"]: # If one `contains` schema is a subset of the other, we can discard it. - m = merged([out["contains"], s["contains"]]) + m = merged([out["contains"], s["contains"]], resolver=resolver) if m == out["contains"] or m == s["contains"]: out["contains"] = m s.pop("contains") @@ -841,7 +855,7 @@ def merged(schemas: List[Any]) -> Optional[Schema]: v = {"required": v} elif isinstance(sval, list): sval = {"required": sval} - m = merged([v, sval]) + m = merged([v, sval], resolver=resolver) if m is None: return None odeps[k] = m @@ -855,26 +869,27 @@ def merged(schemas: List[Any]) -> Optional[Schema]: [ out.get("additionalItems", TRUTHY), s.get("additionalItems", TRUTHY), - ] + ], + resolver=resolver, ) for a, b in itertools.zip_longest(oitems, sitems): if a is None: a = out.get("additionalItems", TRUTHY) elif b is None: b = s.get("additionalItems", TRUTHY) - out["items"].append(merged([a, b])) + out["items"].append(merged([a, b], resolver=resolver)) elif isinstance(oitems, list): - out["items"] = [merged([x, sitems]) for x in oitems] + out["items"] = [merged([x, sitems], resolver=resolver) for x in oitems] out["additionalItems"] = merged( - [out.get("additionalItems", TRUTHY), sitems] + [out.get("additionalItems", TRUTHY), sitems], resolver=resolver ) elif isinstance(sitems, list): - out["items"] = [merged([x, oitems]) for x in sitems] + out["items"] = [merged([x, oitems], resolver=resolver) for x in sitems] out["additionalItems"] = merged( - [s.get("additionalItems", TRUTHY), oitems] + [s.get("additionalItems", TRUTHY), oitems], resolver=resolver ) else: - out["items"] = merged([oitems, sitems]) + out["items"] = merged([oitems, sitems], resolver=resolver) if out["items"] is None: return None if isinstance(out["items"], list) and None in out["items"]: @@ -898,7 +913,7 @@ def merged(schemas: List[Any]) -> Optional[Schema]: # If non-validation keys like `title` or `description` don't match, # that doesn't really matter and we'll just go with first we saw. return None - out = canonicalish(out) + out = canonicalish(out, resolver=resolver) if out == FALSEY: return FALSEY assert isinstance(out, dict) diff --git a/src/hypothesis_jsonschema/_from_schema.py b/src/hypothesis_jsonschema/_from_schema.py index fa52891..b48dfc6 100644 --- a/src/hypothesis_jsonschema/_from_schema.py +++ b/src/hypothesis_jsonschema/_from_schema.py @@ -4,6 +4,7 @@ import math import operator import re +from copy import deepcopy from fractions import Fraction from functools import partial from typing import Any, Callable, Dict, List, NoReturn, Optional, Set, Union @@ -17,6 +18,7 @@ FALSEY, TRUTHY, TYPE_STRINGS, + LocalResolver, Schema, canonicalish, get_integer_bounds, @@ -41,11 +43,15 @@ def merged_as_strategies( - schemas: List[Schema], custom_formats: Optional[Dict[str, st.SearchStrategy[str]]] + schemas: List[Schema], + custom_formats: Optional[Dict[str, st.SearchStrategy[str]]], + resolver: LocalResolver, ) -> st.SearchStrategy[JSONType]: assert schemas, "internal error: must pass at least one schema to merge" if len(schemas) == 1: - return from_schema(schemas[0], custom_formats=custom_formats) + return __from_schema( + schemas[0], custom_formats=custom_formats, resolver=resolver + ) # Try to merge combinations of strategies. strats = [] combined: Set[str] = set() @@ -55,13 +61,13 @@ def merged_as_strategies( ): if combined.issuperset(group): continue - s = merged([inputs[g] for g in group]) + s = merged([inputs[g] for g in group], resolver=resolver) if s is not None and s != FALSEY: - validators = [make_validator(s) for s in schemas] + validators = [make_validator(s, resolver=resolver) for s in schemas] strats.append( - from_schema(s, custom_formats=custom_formats).filter( - lambda obj: all(v.is_valid(obj) for v in validators) - ) + __from_schema( + s, custom_formats=custom_formats, resolver=resolver + ).filter(lambda obj: all(v.is_valid(obj) for v in validators)) ) combined.update(group) return st.one_of(strats) @@ -78,7 +84,8 @@ def from_schema( everything else in drafts 04, 05, and 07 is fully tested and working. """ try: - return __from_schema(schema, custom_formats=custom_formats) + resolver = LocalResolver.from_schema(deepcopy(schema)) + return __from_schema(schema, custom_formats=custom_formats, resolver=resolver) except Exception as err: error = err @@ -111,8 +118,9 @@ def __from_schema( schema: Union[bool, Schema], *, custom_formats: Dict[str, st.SearchStrategy[str]] = None, + resolver: LocalResolver, ) -> st.SearchStrategy[JSONType]: - schema, _ = resolve_all_refs(schema) + schema, _ = resolve_all_refs(schema, resolver=resolver) # We check for _FORMATS_TOKEN to avoid re-validating known good data. if custom_formats is not None and _FORMATS_TOKEN not in custom_formats: assert isinstance(custom_formats, dict) @@ -135,7 +143,7 @@ def __from_schema( } custom_formats[_FORMATS_TOKEN] = None # type: ignore - schema = canonicalish(schema) + schema = canonicalish(schema, resolver=resolver) # Boolean objects are special schemata; False rejects all and True accepts all. if schema == FALSEY: return st.nothing() @@ -153,32 +161,37 @@ def __from_schema( if "not" in schema: not_ = schema.pop("not") assert isinstance(not_, dict) - validator = make_validator(not_).is_valid - return from_schema(schema, custom_formats=custom_formats).filter( - lambda v: not validator(v) - ) + validator = make_validator(not_, resolver=resolver).is_valid + return __from_schema( + schema, custom_formats=custom_formats, resolver=resolver + ).filter(lambda v: not validator(v)) if "anyOf" in schema: tmp = schema.copy() ao = tmp.pop("anyOf") assert isinstance(ao, list) - return st.one_of([merged_as_strategies([tmp, s], custom_formats) for s in ao]) + return st.one_of( + [ + merged_as_strategies([tmp, s], custom_formats, resolver=resolver) + for s in ao + ] + ) if "allOf" in schema: tmp = schema.copy() ao = tmp.pop("allOf") assert isinstance(ao, list) - return merged_as_strategies([tmp] + ao, custom_formats) + return merged_as_strategies([tmp] + ao, custom_formats, resolver=resolver) if "oneOf" in schema: tmp = schema.copy() oo = tmp.pop("oneOf") assert isinstance(oo, list) - schemas = [merged([tmp, s]) for s in oo] + schemas = [merged([tmp, s], resolver=resolver) for s in oo] return st.one_of( [ - from_schema(s, custom_formats=custom_formats) + __from_schema(s, custom_formats=custom_formats, resolver=resolver) for s in schemas if s is not None ] - ).filter(make_validator(schema).is_valid) + ).filter(make_validator(schema, resolver=resolver).is_valid) # Simple special cases if "enum" in schema: assert schema["enum"], "Canonicalises to non-empty list or FALSEY" @@ -189,18 +202,21 @@ def __from_schema( map_: Dict[str, Callable[[Schema], st.SearchStrategy[JSONType]]] = { "null": lambda _: st.none(), "boolean": lambda _: st.booleans(), - "number": number_schema, - "integer": integer_schema, + "number": partial(number_schema, resolver=resolver), + "integer": partial(integer_schema, resolver=resolver), "string": partial(string_schema, custom_formats), - "array": partial(array_schema, custom_formats), - "object": partial(object_schema, custom_formats), + "array": partial(array_schema, custom_formats, resolver=resolver), + "object": partial(object_schema, custom_formats, resolver=resolver), } assert set(map_) == set(TYPE_STRINGS) return st.one_of([map_[t](schema) for t in get_type(schema)]) def _numeric_with_multiplier( - min_value: Optional[float], max_value: Optional[float], schema: Schema + min_value: Optional[float], + max_value: Optional[float], + schema: Schema, + resolver: LocalResolver, ) -> st.SearchStrategy[float]: """Handle numeric schemata containing the multipleOf key.""" multiple_of = schema["multipleOf"] @@ -218,23 +234,23 @@ def _numeric_with_multiplier( return ( st.integers(min_value, max_value) .map(lambda x: x * multiple_of) - .filter(make_validator(schema).is_valid) + .filter(make_validator(schema, resolver=resolver).is_valid) ) -def integer_schema(schema: dict) -> st.SearchStrategy[float]: +def integer_schema(schema: dict, resolver: LocalResolver) -> st.SearchStrategy[float]: """Handle integer schemata.""" min_value, max_value = get_integer_bounds(schema) if "multipleOf" in schema: - return _numeric_with_multiplier(min_value, max_value, schema) + return _numeric_with_multiplier(min_value, max_value, schema, resolver) return st.integers(min_value, max_value) -def number_schema(schema: dict) -> st.SearchStrategy[float]: +def number_schema(schema: dict, resolver: LocalResolver) -> st.SearchStrategy[float]: """Handle numeric schemata.""" min_value, max_value, exclude_min, exclude_max = get_number_bounds(schema) if "multipleOf" in schema: - return _numeric_with_multiplier(min_value, max_value, schema) + return _numeric_with_multiplier(min_value, max_value, schema, resolver) return st.floats( min_value=min_value, max_value=max_value, @@ -416,10 +432,14 @@ def string_schema( def array_schema( - custom_formats: Dict[str, st.SearchStrategy[str]], schema: dict + custom_formats: Dict[str, st.SearchStrategy[str]], + schema: dict, + resolver: LocalResolver, ) -> st.SearchStrategy[List[JSONType]]: """Handle schemata for arrays.""" - _from_schema_ = partial(from_schema, custom_formats=custom_formats) + _from_schema_ = partial( + __from_schema, custom_formats=custom_formats, resolver=resolver + ) items = schema.get("items", {}) additional_items = schema.get("additionalItems", {}) min_size = schema.get("minItems", 0) @@ -437,10 +457,14 @@ def array_schema( # allowed to do so. We'll skip the None (unmergable / no contains) cases # below, and let Hypothesis ignore the FALSEY cases for us. if "contains" in schema: - for i, mrgd in enumerate(merged([schema["contains"], s]) for s in items): + for i, mrgd in enumerate( + merged([schema["contains"], s], resolver=resolver) for s in items + ): if mrgd is not None: items_strats[i] |= _from_schema_(mrgd) - contains_additional = merged([schema["contains"], additional_items]) + contains_additional = merged( + [schema["contains"], additional_items], resolver=resolver + ) if contains_additional is not None: additional_items_strat |= _from_schema_(contains_additional) @@ -477,12 +501,17 @@ def not_seen(elem: JSONType) -> bool: items_strat = _from_schema_(items) if "contains" in schema: contains_strat = _from_schema_(schema["contains"]) - if merged([items, schema["contains"]]) != schema["contains"]: + if ( + merged([items, schema["contains"]], resolver=resolver) + != schema["contains"] + ): # We only need this filter if we couldn't merge items in when # canonicalising. Note that for list-items, above, we just skip # the mixed generation in this case (because they tend to be # heterogeneous) and hope it works out anyway. - contains_strat = contains_strat.filter(make_validator(items).is_valid) + contains_strat = contains_strat.filter( + make_validator(items, resolver=resolver).is_valid + ) items_strat |= contains_strat strat = st.lists( @@ -493,12 +522,14 @@ def not_seen(elem: JSONType) -> bool: ) if "contains" not in schema: return strat - contains = make_validator(schema["contains"]).is_valid + contains = make_validator(schema["contains"], resolver=resolver).is_valid return strat.filter(lambda val: any(contains(x) for x in val)) def object_schema( - custom_formats: Dict[str, st.SearchStrategy[str]], schema: dict + custom_formats: Dict[str, st.SearchStrategy[str]], + schema: dict, + resolver: LocalResolver, ) -> st.SearchStrategy[Dict[str, JSONType]]: """Handle a manageable subset of possible schemata for objects.""" required = schema.get("required", []) # required keys @@ -527,13 +558,13 @@ def object_schema( st.sampled_from(sorted(dep_names) + sorted(dep_schemas) + sorted(properties)) if (dep_names or dep_schemas or properties) else st.nothing(), - from_schema(names, custom_formats=custom_formats) + __from_schema(names, custom_formats=custom_formats, resolver=resolver) if additional_allowed else st.nothing(), st.one_of([st.from_regex(p) for p in sorted(patterns)]), ) all_names_strategy = st.one_of([s for s in name_strats if not s.is_empty]).filter( - make_validator(names).is_valid + make_validator(names, resolver=resolver).is_valid ) @st.composite # type: ignore @@ -576,12 +607,20 @@ def from_object_schema(draw: Any) -> Any: pattern_schemas.insert(0, properties[key]) if pattern_schemas: - out[key] = draw(merged_as_strategies(pattern_schemas, custom_formats)) + out[key] = draw( + merged_as_strategies( + pattern_schemas, custom_formats, resolver=resolver + ) + ) else: - out[key] = draw(from_schema(additional, custom_formats=custom_formats)) + out[key] = draw( + __from_schema( + additional, custom_formats=custom_formats, resolver=resolver + ) + ) for k, v in dep_schemas.items(): - if k in out and not make_validator(v).is_valid(out): + if k in out and not make_validator(v, resolver=resolver).is_valid(out): out.pop(key) elements.reject() diff --git a/tests/test_canonicalise.py b/tests/test_canonicalise.py index 41ad126..30aa3f8 100644 --- a/tests/test_canonicalise.py +++ b/tests/test_canonicalise.py @@ -9,6 +9,7 @@ from hypothesis_jsonschema import from_schema from hypothesis_jsonschema._canonicalise import ( FALSEY, + LocalResolver, canonicalish, get_type, make_validator, @@ -608,7 +609,8 @@ def test_validators_use_proper_draft(): ) def test_skip_recursive_references_simple_schemas(schema, expected): # When there is a recursive reference, it should not be resolved - assert resolve_all_refs(schema)[0] == expected + resolver = LocalResolver.from_schema(schema) + assert resolve_all_refs(schema, resolver=resolver)[0] == expected @pytest.mark.parametrize( @@ -679,4 +681,5 @@ def test_skip_recursive_references_simple_schemas(schema, expected): ) def test_skip_recursive_references_complex_schemas(schema, resolved): resolved["definitions"] = schema["definitions"] - assert resolve_all_refs(schema)[0] == resolved + resolver = LocalResolver.from_schema(schema) + assert resolve_all_refs(schema, resolver=resolver)[0] == resolved diff --git a/tests/test_from_schema.py b/tests/test_from_schema.py index d3e3032..f531544 100644 --- a/tests/test_from_schema.py +++ b/tests/test_from_schema.py @@ -21,6 +21,7 @@ from hypothesis.internal.reflection import proxies from hypothesis_jsonschema._canonicalise import ( + LocalResolver, canonicalish, resolve_all_refs, ) @@ -179,7 +180,9 @@ def test_invalid_schemas_are_invalid(name): @pytest.mark.parametrize("name", sorted(NON_EXISTENT_REF_SCHEMAS)) def test_invalid_ref_schemas_are_invalid(name): with pytest.raises(Exception): - resolve_all_refs(catalog[name]) + schema = catalog[name] + resolver = LocalResolver.from_schema(schema) + resolve_all_refs(schema, resolver=resolver) RECURSIVE_REFS = {