|
| 1 | +# Add-ons for graphene-relay for subscriptions |
| 2 | + |
| 3 | +import re |
| 4 | + |
| 5 | +from graphene.utils.get_unbound_function import get_unbound_function |
| 6 | +from graphene.utils.props import props |
| 7 | +from graphene.types.field import Field |
| 8 | +from graphene.types.objecttype import ObjectType, ObjectTypeOptions |
| 9 | +from graphene.types.utils import yank_fields_from_attrs |
| 10 | +from graphene.types.interface import Interface |
| 11 | + |
| 12 | +from graphene.types import Field, InputObjectType, String |
| 13 | + |
| 14 | + |
| 15 | +class SubscriptionOptions(ObjectTypeOptions): |
| 16 | + arguments = None # type: Dict[str, Argument] |
| 17 | + output = None # type: Type[ObjectType] |
| 18 | + resolver = None # type: Callable |
| 19 | + interfaces = () # type: Iterable[Type[Interface]] |
| 20 | + |
| 21 | + |
| 22 | +class Subscription(ObjectType): |
| 23 | + |
| 24 | + @classmethod |
| 25 | + def __init_subclass_with_meta__( |
| 26 | + cls, |
| 27 | + interfaces=(), |
| 28 | + resolver=None, |
| 29 | + output=None, |
| 30 | + arguments=None, |
| 31 | + _meta=None, |
| 32 | + **options, |
| 33 | + ): |
| 34 | + if not _meta: |
| 35 | + _meta = SubscriptionOptions(cls) |
| 36 | + |
| 37 | + output = output or getattr(cls, "Output", None) |
| 38 | + fields = {} |
| 39 | + |
| 40 | + for interface in interfaces: |
| 41 | + assert issubclass( |
| 42 | + interface, Interface |
| 43 | + ), f'All interfaces of {cls.__name__} must be a subclass of Interface. Received "{interface}".' |
| 44 | + fields.update(interface._meta.fields) |
| 45 | + |
| 46 | + if not output: |
| 47 | + # If output is defined, we don't need to get the fields |
| 48 | + fields = {} |
| 49 | + for base in reversed(cls.__mro__): |
| 50 | + fields.update(yank_fields_from_attrs(base.__dict__, _as=Field)) |
| 51 | + output = cls |
| 52 | + |
| 53 | + if not arguments: |
| 54 | + input_class = getattr(cls, "Arguments", None) |
| 55 | + if not input_class: |
| 56 | + input_class = getattr(cls, "Input", None) |
| 57 | + |
| 58 | + if input_class: |
| 59 | + arguments = props(input_class) |
| 60 | + else: |
| 61 | + arguments = {} |
| 62 | + |
| 63 | + if not resolver: |
| 64 | + subscribe = getattr(cls, "subscribe", None) |
| 65 | + assert subscribe, "All subscriptions must define a subscribe method in it" |
| 66 | + resolver = get_unbound_function(subscribe) |
| 67 | + |
| 68 | + if _meta.fields: |
| 69 | + _meta.fields.update(fields) |
| 70 | + else: |
| 71 | + _meta.fields = fields |
| 72 | + |
| 73 | + _meta.interfaces = interfaces |
| 74 | + _meta.output = output |
| 75 | + _meta.resolver = resolver |
| 76 | + _meta.arguments = arguments |
| 77 | + |
| 78 | + super(Subscription, cls).__init_subclass_with_meta__(_meta=_meta, **options) |
| 79 | + |
| 80 | + @classmethod |
| 81 | + def Field( |
| 82 | + cls, name=None, description=None, deprecation_reason=None, required=False |
| 83 | + ): |
| 84 | + """ Mount instance of subscription Field. """ |
| 85 | + return Field( |
| 86 | + cls._meta.output, |
| 87 | + args=cls._meta.arguments, |
| 88 | + resolver=cls._meta.resolver, |
| 89 | + name=name, |
| 90 | + description=description or cls._meta.description, |
| 91 | + deprecation_reason=deprecation_reason, |
| 92 | + required=required, |
| 93 | + ) |
| 94 | + |
| 95 | + |
| 96 | +class ClientIDSubscription(Subscription): |
| 97 | + class Meta: |
| 98 | + abstract = True |
| 99 | + |
| 100 | + @classmethod |
| 101 | + def __init_subclass_with_meta__( |
| 102 | + cls, output=None, input_fields=None, arguments=None, name=None, **options |
| 103 | + ): |
| 104 | + input_class = getattr(cls, "Input", None) |
| 105 | + base_name = re.sub("Payload$", "", name or cls.__name__) |
| 106 | + |
| 107 | + assert not output, "Can't specify any output" |
| 108 | + assert not arguments, "Can't specify any arguments" |
| 109 | + |
| 110 | + bases = (InputObjectType,) |
| 111 | + if input_class: |
| 112 | + bases += (input_class,) |
| 113 | + |
| 114 | + if not input_fields: |
| 115 | + input_fields = {} |
| 116 | + |
| 117 | + cls.Input = type( |
| 118 | + f"{base_name}Input", |
| 119 | + bases, |
| 120 | + dict(input_fields, client_subscription_id=String(name="clientSubscriptionId")), |
| 121 | + ) |
| 122 | + |
| 123 | + arguments = dict( |
| 124 | + input=cls.Input(required=True) |
| 125 | + # 'client_subscription_id': String(name='clientSubscriptionId') |
| 126 | + ) |
| 127 | + subscribe_and_get_payload = getattr(cls, "subscribe_and_get_payload", None) |
| 128 | + if cls.subscribe and cls.subscribe.__func__ == ClientIDSubscription.subscribe.__func__: |
| 129 | + assert subscribe_and_get_payload, ( |
| 130 | + f"{name or cls.__name__}.subscribe_and_get_payload method is required" |
| 131 | + " in a ClientIDSubscription." |
| 132 | + ) |
| 133 | + |
| 134 | + if not name: |
| 135 | + name = f"{base_name}Payload" |
| 136 | + |
| 137 | + super(ClientIDSubscription, cls).__init_subclass_with_meta__( |
| 138 | + output=None, arguments=arguments, name=name, **options |
| 139 | + ) |
| 140 | + cls._meta.fields["client_subscription_id"] = Field(String, name="clientSubscriptionId") |
| 141 | + |
| 142 | + @classmethod |
| 143 | + def subscribe(cls, root, info, input): |
| 144 | + def on_resolve(payload): |
| 145 | + def set_client_subscription_id(item): |
| 146 | + try: |
| 147 | + item.client_subscription_id = input.get("client_subscription_id") |
| 148 | + except Exception: |
| 149 | + raise Exception( |
| 150 | + f"Cannot set client_subscription_id in the payload object {repr(payload)}" |
| 151 | + ) |
| 152 | + return item |
| 153 | + return payload.map(set_client_subscription_id) |
| 154 | + |
| 155 | + result = cls.subscribe_and_get_payload(root, info, **input) |
| 156 | + return on_resolve(result) |
0 commit comments