Skip to content

Commit 0b08cf3

Browse files
committed
fix: DeferStreamDirectiveOnValidOperations fragment tracking
Replicates graphql/graphql-js@c3e5513
1 parent 014a7d5 commit 0b08cf3

2 files changed

Lines changed: 405 additions & 47 deletions

File tree

src/graphql/validation/rules/defer_stream_directive_on_valid_operations_rule.py

Lines changed: 121 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,81 +6,157 @@
66

77
from ...error import GraphQLError
88
from ...language import (
9+
ArgumentNode,
910
BooleanValueNode,
1011
DirectiveNode,
1112
FragmentDefinitionNode,
12-
Node,
13+
FragmentSpreadNode,
1314
OperationDefinitionNode,
1415
OperationType,
16+
SelectionSetNode,
1517
VariableNode,
1618
)
17-
from ...type import GraphQLDeferDirective, GraphQLStreamDirective
18-
from . import ASTValidationRule, ValidationContext
19+
from ...type import (
20+
GraphQLDeferDirective,
21+
GraphQLIncludeDirective,
22+
GraphQLSkipDirective,
23+
GraphQLStreamDirective,
24+
)
25+
from . import ASTValidationRule
1926

2027
__all__ = ["DeferStreamDirectiveOnValidOperationsRule"]
2128

2229

23-
def if_argument_can_be_false(node: DirectiveNode) -> bool:
30+
def get_directive(node: Any, name: str) -> DirectiveNode | None:
31+
for directive in node.directives or ():
32+
if directive.name.value == name:
33+
return directive
34+
return None
35+
36+
37+
def get_if_argument(node: DirectiveNode) -> ArgumentNode | None:
2438
for argument in node.arguments or ():
2539
if argument.name.value == "if":
26-
if isinstance(argument.value, BooleanValueNode):
27-
if argument.value.value:
28-
return False
29-
elif not isinstance(argument.value, VariableNode):
30-
return False
31-
return True
32-
return False
40+
return argument
41+
return None
42+
43+
44+
def if_argument_can_be_false(node: DirectiveNode) -> bool:
45+
# @defer(if: false) / @stream(if: false)
46+
# @defer(if: $shouldDefer) / @stream(if: $shouldStream)
47+
if_argument = get_if_argument(node)
48+
if not if_argument:
49+
return False
50+
if isinstance(if_argument.value, BooleanValueNode):
51+
if if_argument.value.value:
52+
return False
53+
elif not isinstance(if_argument.value, VariableNode):
54+
return False
55+
return True
56+
57+
58+
def can_be_skipped_via_skip_directive(node: DirectiveNode) -> bool:
59+
# @skip(if: true)
60+
# @skip(if: $shouldSkip)
61+
if_argument = get_if_argument(node)
62+
if not if_argument:
63+
# Missing `if` is reported by ProvidedRequiredArgumentsRule. For this rule,
64+
# treat malformed @skip as potentially skipped to avoid duplicate errors.
65+
return True
66+
# If argument is a static boolean, it is always skipped if true,
67+
# never skipped if false; otherwise it can be skipped via a variable.
68+
if isinstance(if_argument.value, BooleanValueNode):
69+
return if_argument.value.value
70+
return True
71+
72+
73+
def can_be_skipped_via_include_directive(node: DirectiveNode) -> bool:
74+
# @include(if: false)
75+
# @include(if: $shouldInclude)
76+
if_argument = get_if_argument(node)
77+
if not if_argument:
78+
# Missing `if` is reported by ProvidedRequiredArgumentsRule. For this rule,
79+
# treat malformed @include as not skippable.
80+
return False
81+
# If argument is a static boolean, it is always skipped if false,
82+
# never skipped if true; otherwise it can be skipped via a variable.
83+
if isinstance(if_argument.value, BooleanValueNode):
84+
return not if_argument.value.value
85+
return True
3386

3487

3588
class DeferStreamDirectiveOnValidOperationsRule(ASTValidationRule):
36-
"""Defer and stream directives are used on valid root field
89+
"""Defer and stream directives are used on valid operations
3790
38-
A GraphQL document is only valid if defer directives are not used on root
39-
mutation or subscription types.
91+
A GraphQL document is only valid if defer and stream directives are not used
92+
on root mutation or subscription types.
4093
"""
4194

42-
def __init__(self, context: ValidationContext) -> None:
43-
super().__init__(context)
44-
self.fragments_used_on_subscriptions: set[str] = set()
45-
4695
def enter_operation_definition(
4796
self, operation: OperationDefinitionNode, *_args: Any
4897
) -> None:
49-
if operation.operation == OperationType.SUBSCRIPTION:
50-
fragments = self.context.get_recursively_referenced_fragments(operation)
51-
for fragment in fragments:
52-
self.fragments_used_on_subscriptions.add(fragment.name.value)
98+
if operation.operation != OperationType.SUBSCRIPTION:
99+
return
100+
fragments: dict[str, FragmentDefinitionNode] = {}
101+
for definition in self.context.document.definitions:
102+
if isinstance(definition, FragmentDefinitionNode):
103+
fragments[definition.name.value] = definition
104+
self.forbid_unconditional_defer_stream(
105+
fragments, operation.selection_set, [], set()
106+
)
53107

54-
def enter_directive(
108+
def forbid_unconditional_defer_stream(
55109
self,
56-
node: DirectiveNode,
57-
_key: Any,
58-
_parent: Any,
59-
_path: Any,
60-
ancestors: list[Node],
110+
fragments: dict[str, FragmentDefinitionNode],
111+
selection_set: SelectionSetNode,
112+
parent_nodes: list[FragmentSpreadNode],
113+
visited_fragments: set[str],
61114
) -> None:
62-
try:
63-
definition_node = ancestors[2]
64-
except IndexError: # pragma: no cover
65-
return
66-
if (
67-
isinstance(definition_node, FragmentDefinitionNode)
68-
and definition_node.name.value in self.fragments_used_on_subscriptions
69-
) or (
70-
isinstance(definition_node, OperationDefinitionNode)
71-
and definition_node.operation == OperationType.SUBSCRIPTION
72-
):
73-
if node.name.value == GraphQLDeferDirective.name:
74-
if not if_argument_can_be_false(node):
115+
for selection in selection_set.selections:
116+
skip = get_directive(selection, GraphQLSkipDirective.name)
117+
if skip and can_be_skipped_via_skip_directive(skip):
118+
continue
119+
include = get_directive(selection, GraphQLIncludeDirective.name)
120+
if include and can_be_skipped_via_include_directive(include):
121+
continue
122+
for directive in selection.directives or ():
123+
name = directive.name.value
124+
if name == GraphQLDeferDirective.name:
125+
if if_argument_can_be_false(directive):
126+
continue
75127
msg = (
76128
"Defer directive not supported on subscription operations."
77129
" Disable `@defer` by setting the `if` argument to `false`."
78130
)
79-
self.report_error(GraphQLError(msg, node))
80-
elif node.name.value == GraphQLStreamDirective.name: # noqa: SIM102
81-
if not if_argument_can_be_false(node):
131+
elif name == GraphQLStreamDirective.name:
132+
if if_argument_can_be_false(directive):
133+
continue
82134
msg = (
83135
"Stream directive not supported on subscription operations."
84136
" Disable `@stream` by setting the `if` argument to `false`."
85137
)
86-
self.report_error(GraphQLError(msg, node))
138+
else:
139+
continue
140+
self.report_error(GraphQLError(msg, [directive, *parent_nodes]))
141+
if isinstance(selection, FragmentSpreadNode):
142+
fragment_name = selection.name.value
143+
if fragment_name in visited_fragments:
144+
continue
145+
visited_fragments.add(fragment_name)
146+
fragment = fragments.get(fragment_name)
147+
if fragment:
148+
self.forbid_unconditional_defer_stream(
149+
fragments,
150+
fragment.selection_set,
151+
[selection, *parent_nodes],
152+
visited_fragments,
153+
)
154+
else:
155+
child_selection_set = getattr(selection, "selection_set", None)
156+
if child_selection_set:
157+
self.forbid_unconditional_defer_stream(
158+
fragments,
159+
child_selection_set,
160+
parent_nodes,
161+
visited_fragments,
162+
)

0 commit comments

Comments
 (0)