|
6 | 6 |
|
7 | 7 | from ...error import GraphQLError |
8 | 8 | from ...language import ( |
| 9 | + ArgumentNode, |
9 | 10 | BooleanValueNode, |
10 | 11 | DirectiveNode, |
11 | 12 | FragmentDefinitionNode, |
12 | | - Node, |
| 13 | + FragmentSpreadNode, |
13 | 14 | OperationDefinitionNode, |
14 | 15 | OperationType, |
| 16 | + SelectionSetNode, |
15 | 17 | VariableNode, |
16 | 18 | ) |
17 | | -from ...type import GraphQLDeferDirective, GraphQLStreamDirective |
18 | | -from . import ASTValidationRule, ValidationContext |
| 19 | +from ...type import ( |
| 20 | + GraphQLDeferDirective, |
| 21 | + GraphQLIncludeDirective, |
| 22 | + GraphQLSkipDirective, |
| 23 | + GraphQLStreamDirective, |
| 24 | +) |
| 25 | +from . import ASTValidationRule |
19 | 26 |
|
20 | 27 | __all__ = ["DeferStreamDirectiveOnValidOperationsRule"] |
21 | 28 |
|
22 | 29 |
|
23 | | -def if_argument_can_be_false(node: DirectiveNode) -> bool: |
| 30 | +def get_directive(node: Any, name: str) -> DirectiveNode | None: |
| 31 | + for directive in node.directives or (): |
| 32 | + if directive.name.value == name: |
| 33 | + return directive |
| 34 | + return None |
| 35 | + |
| 36 | + |
| 37 | +def get_if_argument(node: DirectiveNode) -> ArgumentNode | None: |
24 | 38 | for argument in node.arguments or (): |
25 | 39 | if argument.name.value == "if": |
26 | | - if isinstance(argument.value, BooleanValueNode): |
27 | | - if argument.value.value: |
28 | | - return False |
29 | | - elif not isinstance(argument.value, VariableNode): |
30 | | - return False |
31 | | - return True |
32 | | - return False |
| 40 | + return argument |
| 41 | + return None |
| 42 | + |
| 43 | + |
| 44 | +def if_argument_can_be_false(node: DirectiveNode) -> bool: |
| 45 | + # @defer(if: false) / @stream(if: false) |
| 46 | + # @defer(if: $shouldDefer) / @stream(if: $shouldStream) |
| 47 | + if_argument = get_if_argument(node) |
| 48 | + if not if_argument: |
| 49 | + return False |
| 50 | + if isinstance(if_argument.value, BooleanValueNode): |
| 51 | + if if_argument.value.value: |
| 52 | + return False |
| 53 | + elif not isinstance(if_argument.value, VariableNode): |
| 54 | + return False |
| 55 | + return True |
| 56 | + |
| 57 | + |
| 58 | +def can_be_skipped_via_skip_directive(node: DirectiveNode) -> bool: |
| 59 | + # @skip(if: true) |
| 60 | + # @skip(if: $shouldSkip) |
| 61 | + if_argument = get_if_argument(node) |
| 62 | + if not if_argument: |
| 63 | + # Missing `if` is reported by ProvidedRequiredArgumentsRule. For this rule, |
| 64 | + # treat malformed @skip as potentially skipped to avoid duplicate errors. |
| 65 | + return True |
| 66 | + # If argument is a static boolean, it is always skipped if true, |
| 67 | + # never skipped if false; otherwise it can be skipped via a variable. |
| 68 | + if isinstance(if_argument.value, BooleanValueNode): |
| 69 | + return if_argument.value.value |
| 70 | + return True |
| 71 | + |
| 72 | + |
| 73 | +def can_be_skipped_via_include_directive(node: DirectiveNode) -> bool: |
| 74 | + # @include(if: false) |
| 75 | + # @include(if: $shouldInclude) |
| 76 | + if_argument = get_if_argument(node) |
| 77 | + if not if_argument: |
| 78 | + # Missing `if` is reported by ProvidedRequiredArgumentsRule. For this rule, |
| 79 | + # treat malformed @include as not skippable. |
| 80 | + return False |
| 81 | + # If argument is a static boolean, it is always skipped if false, |
| 82 | + # never skipped if true; otherwise it can be skipped via a variable. |
| 83 | + if isinstance(if_argument.value, BooleanValueNode): |
| 84 | + return not if_argument.value.value |
| 85 | + return True |
33 | 86 |
|
34 | 87 |
|
35 | 88 | class DeferStreamDirectiveOnValidOperationsRule(ASTValidationRule): |
36 | | - """Defer and stream directives are used on valid root field |
| 89 | + """Defer and stream directives are used on valid operations |
37 | 90 |
|
38 | | - A GraphQL document is only valid if defer directives are not used on root |
39 | | - mutation or subscription types. |
| 91 | + A GraphQL document is only valid if defer and stream directives are not used |
| 92 | + on root mutation or subscription types. |
40 | 93 | """ |
41 | 94 |
|
42 | | - def __init__(self, context: ValidationContext) -> None: |
43 | | - super().__init__(context) |
44 | | - self.fragments_used_on_subscriptions: set[str] = set() |
45 | | - |
46 | 95 | def enter_operation_definition( |
47 | 96 | self, operation: OperationDefinitionNode, *_args: Any |
48 | 97 | ) -> None: |
49 | | - if operation.operation == OperationType.SUBSCRIPTION: |
50 | | - fragments = self.context.get_recursively_referenced_fragments(operation) |
51 | | - for fragment in fragments: |
52 | | - self.fragments_used_on_subscriptions.add(fragment.name.value) |
| 98 | + if operation.operation != OperationType.SUBSCRIPTION: |
| 99 | + return |
| 100 | + fragments: dict[str, FragmentDefinitionNode] = {} |
| 101 | + for definition in self.context.document.definitions: |
| 102 | + if isinstance(definition, FragmentDefinitionNode): |
| 103 | + fragments[definition.name.value] = definition |
| 104 | + self.forbid_unconditional_defer_stream( |
| 105 | + fragments, operation.selection_set, [], set() |
| 106 | + ) |
53 | 107 |
|
54 | | - def enter_directive( |
| 108 | + def forbid_unconditional_defer_stream( |
55 | 109 | self, |
56 | | - node: DirectiveNode, |
57 | | - _key: Any, |
58 | | - _parent: Any, |
59 | | - _path: Any, |
60 | | - ancestors: list[Node], |
| 110 | + fragments: dict[str, FragmentDefinitionNode], |
| 111 | + selection_set: SelectionSetNode, |
| 112 | + parent_nodes: list[FragmentSpreadNode], |
| 113 | + visited_fragments: set[str], |
61 | 114 | ) -> None: |
62 | | - try: |
63 | | - definition_node = ancestors[2] |
64 | | - except IndexError: # pragma: no cover |
65 | | - return |
66 | | - if ( |
67 | | - isinstance(definition_node, FragmentDefinitionNode) |
68 | | - and definition_node.name.value in self.fragments_used_on_subscriptions |
69 | | - ) or ( |
70 | | - isinstance(definition_node, OperationDefinitionNode) |
71 | | - and definition_node.operation == OperationType.SUBSCRIPTION |
72 | | - ): |
73 | | - if node.name.value == GraphQLDeferDirective.name: |
74 | | - if not if_argument_can_be_false(node): |
| 115 | + for selection in selection_set.selections: |
| 116 | + skip = get_directive(selection, GraphQLSkipDirective.name) |
| 117 | + if skip and can_be_skipped_via_skip_directive(skip): |
| 118 | + continue |
| 119 | + include = get_directive(selection, GraphQLIncludeDirective.name) |
| 120 | + if include and can_be_skipped_via_include_directive(include): |
| 121 | + continue |
| 122 | + for directive in selection.directives or (): |
| 123 | + name = directive.name.value |
| 124 | + if name == GraphQLDeferDirective.name: |
| 125 | + if if_argument_can_be_false(directive): |
| 126 | + continue |
75 | 127 | msg = ( |
76 | 128 | "Defer directive not supported on subscription operations." |
77 | 129 | " Disable `@defer` by setting the `if` argument to `false`." |
78 | 130 | ) |
79 | | - self.report_error(GraphQLError(msg, node)) |
80 | | - elif node.name.value == GraphQLStreamDirective.name: # noqa: SIM102 |
81 | | - if not if_argument_can_be_false(node): |
| 131 | + elif name == GraphQLStreamDirective.name: |
| 132 | + if if_argument_can_be_false(directive): |
| 133 | + continue |
82 | 134 | msg = ( |
83 | 135 | "Stream directive not supported on subscription operations." |
84 | 136 | " Disable `@stream` by setting the `if` argument to `false`." |
85 | 137 | ) |
86 | | - self.report_error(GraphQLError(msg, node)) |
| 138 | + else: |
| 139 | + continue |
| 140 | + self.report_error(GraphQLError(msg, [directive, *parent_nodes])) |
| 141 | + if isinstance(selection, FragmentSpreadNode): |
| 142 | + fragment_name = selection.name.value |
| 143 | + if fragment_name in visited_fragments: |
| 144 | + continue |
| 145 | + visited_fragments.add(fragment_name) |
| 146 | + fragment = fragments.get(fragment_name) |
| 147 | + if fragment: |
| 148 | + self.forbid_unconditional_defer_stream( |
| 149 | + fragments, |
| 150 | + fragment.selection_set, |
| 151 | + [selection, *parent_nodes], |
| 152 | + visited_fragments, |
| 153 | + ) |
| 154 | + else: |
| 155 | + child_selection_set = getattr(selection, "selection_set", None) |
| 156 | + if child_selection_set: |
| 157 | + self.forbid_unconditional_defer_stream( |
| 158 | + fragments, |
| 159 | + child_selection_set, |
| 160 | + parent_nodes, |
| 161 | + visited_fragments, |
| 162 | + ) |
0 commit comments