diff --git a/README.md b/README.md index fb11ce66..0d9ef5b2 100644 --- a/README.md +++ b/README.md @@ -78,11 +78,26 @@ For either type, it can accept: ### gcm This sends messages via Google/Firebase Cloud Messaging (GCM/FCM) -and hence can be used to deliver notifications to Android apps. It -expects the 'api_key' parameter to contain the 'Server key', -which can be acquired from Firebase Console at: -`https://console.firebase.google.com/project//settings/cloudmessaging/` +and hence can be used to deliver notifications to Android apps. +The expected configuration depends on which version of the firebase api you +wish to use. + +For legacy API, it expects: + +- the `api_key` parameter to contain the `Server key`, + which can be acquired from Firebase Console at: + `https://console.firebase.google.com/project//settings/cloudmessaging/` + +For API v1, it expects: + +- the `api_version` parameter to contain `v1` +- the `project_id` parameter to contain the `Project ID`, + which can be acquired from Firebase Console at: + `https://console.cloud.google.com/project//settings/general/` +- the `service_account_file` parameter to contain the path to the service account file, + which can be acquired from Firebase Console at: + `https://console.firebase.google.com/project//settings/serviceaccounts/adminsdk` Using an HTTP Proxy for outbound traffic ---------------------------------------- diff --git a/changelog.d/361.feature b/changelog.d/361.feature new file mode 100644 index 00000000..a5751518 --- /dev/null +++ b/changelog.d/361.feature @@ -0,0 +1 @@ +Adds the ability to use the new FCM v1 API. diff --git a/docs/applications.md b/docs/applications.md index 53eeaf1c..2989a665 100644 --- a/docs/applications.md +++ b/docs/applications.md @@ -211,7 +211,14 @@ may be useful for reference. ### Firebase Cloud Messaging -The client will receive a message with an FCM `data` payload with this structure: +The client will receive a message with an FCM `data` payload with a structure depending on the api version used: + +Please note that fields may be truncated if they are large, so that they fit +within FCM's limit. +Please also note that some fields will be unavailable if you registered a pusher +with `event_id_only` format. + +#### Legacy API ```json { @@ -232,10 +239,24 @@ The client will receive a message with an FCM `data` payload with this structure } ``` -Please note that fields may be truncated if they are large, so that they fit -within FCM's limit. -Please also note that some fields will be unavailable if you registered a pusher -with `event_id_only` format. +#### API v1 + +```json +{ + "event_id": "$3957tyerfgewrf384", + "type": "m.room.message", + "sender": "@exampleuser:example.org", + "room_name": "Mission Control", + "room_alias": "#exampleroom:example.org", + "sender_display_name": "Major Tom", + "content_msgtype": "m.text", + "content_body": "I'm floating in a most peculiar way." + "room_id": "!slw48wfj34rtnrf:example.org", + "prio": "high", + "unread": "2", + "missed_calls": "1" +} +``` ### WebPush diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md index 7f22ec7f..c3168006 100644 --- a/docs/troubleshooting.md +++ b/docs/troubleshooting.md @@ -207,7 +207,7 @@ EOF ``` -## Example of an FCM request +## Example of an FCM request (Legacy API) HTTP data sent to `https://fcm.googleapis.com/fcm/send`: diff --git a/pyproject.toml b/pyproject.toml index 0caaf93f..a4351ffd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,7 @@ dependencies = [ "attrs>=19.2.0", "cryptography>=2.6.1", "idna>=2.8", + "google-auth>=2.27.0", "jaeger-client>=4.0.0", "matrix-common==1.3.0", "opentracing>=2.2.0", diff --git a/stubs/google/__init__.pyi b/stubs/google/__init__.pyi new file mode 100644 index 00000000..e69de29b diff --git a/stubs/google/auth/__init__.pyi b/stubs/google/auth/__init__.pyi new file mode 100644 index 00000000..0d33517f --- /dev/null +++ b/stubs/google/auth/__init__.pyi @@ -0,0 +1,3 @@ +from google.auth._default import default + +__all__ = ["default"] diff --git a/stubs/google/auth/_default.pyi b/stubs/google/auth/_default.pyi new file mode 100644 index 00000000..4fcef899 --- /dev/null +++ b/stubs/google/auth/_default.pyi @@ -0,0 +1,15 @@ +from typing import Optional + +from google.auth.transport.requests import Request + +class Credentials: + token = "token" + + def refresh(self, request: Request) -> None: ... + +def default( + scopes: Optional[list[str]] = None, + request: Optional[str] = None, + quota_project_id: Optional[int] = None, + default_scopes: Optional[list[str]] = None, +) -> tuple[Credentials, Optional[str]]: ... diff --git a/stubs/google/auth/transport/__init__.pyi b/stubs/google/auth/transport/__init__.pyi new file mode 100644 index 00000000..a5ffe934 --- /dev/null +++ b/stubs/google/auth/transport/__init__.pyi @@ -0,0 +1,3 @@ +from google.auth.transport.requests import Request + +__all__ = ["Request"] diff --git a/stubs/google/auth/transport/requests.pyi b/stubs/google/auth/transport/requests.pyi new file mode 100644 index 00000000..2eda7fcb --- /dev/null +++ b/stubs/google/auth/transport/requests.pyi @@ -0,0 +1 @@ +class Request: ... diff --git a/stubs/google/oauth2/__init__.pyi b/stubs/google/oauth2/__init__.pyi new file mode 100644 index 00000000..2436b62b --- /dev/null +++ b/stubs/google/oauth2/__init__.pyi @@ -0,0 +1,3 @@ +from google.oauth2.service_account import Credentials + +__all__ = ["Credentials"] diff --git a/stubs/google/oauth2/service_account.pyi b/stubs/google/oauth2/service_account.pyi new file mode 100644 index 00000000..41542d0e --- /dev/null +++ b/stubs/google/oauth2/service_account.pyi @@ -0,0 +1,16 @@ +from typing import Optional + +from google.auth.transport.requests import Request + +class Credentials: + token = "token" + + def refresh(self, request: Request) -> None: ... + @staticmethod + def from_service_account_file( + service_account_file: str, + scopes: Optional[list[str]] = None, + request: Optional[str] = None, + quota_project_id: Optional[int] = None, + default_scopes: Optional[list[str]] = None, + ) -> Credentials: ... diff --git a/sygnal.yaml.sample b/sygnal.yaml.sample index 810821fc..dd117418 100644 --- a/sygnal.yaml.sample +++ b/sygnal.yaml.sample @@ -205,9 +205,12 @@ apps: # This is an example GCM/FCM push configuration. # - #com.example.myapp.android: + #im.vector.app: # type: gcm - # api_key: your_api_key_for_gcm + # #api_key: + # api_version: v1 + # project_id: project-id + # service_account_file: /path/to/service_account.json # # # This is the maximum number of connections to GCM servers at any one time # # the default is 20. diff --git a/sygnal/exceptions.py b/sygnal/exceptions.py index 62019589..41ab8d7c 100644 --- a/sygnal/exceptions.py +++ b/sygnal/exceptions.py @@ -39,6 +39,18 @@ def __init__(self, *args: object, custom_retry_delay: Optional[int] = None) -> N self.custom_retry_delay = custom_retry_delay +class NotificationQuotaDispatchException(Exception): + """ + To be used by pushkins for errors that are do to exceeding the quota + limits and are hopefully temporary, so the request should possibly be + retried soon. + """ + + def __init__(self, *args: object, custom_retry_delay: Optional[int] = None) -> None: + super().__init__(*args) + self.custom_retry_delay = custom_retry_delay + + class ProxyConnectError(ConnectError): """ Exception raised when we are unable to start a connection using a HTTP proxy diff --git a/sygnal/gcmpushkin.py b/sygnal/gcmpushkin.py index 7ff63370..1177c0aa 100644 --- a/sygnal/gcmpushkin.py +++ b/sygnal/gcmpushkin.py @@ -17,9 +17,12 @@ import json import logging import time +from enum import Enum from io import BytesIO from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple +import google.auth.transport.requests +from google.oauth2 import service_account from opentracing import Span, logs, tags from prometheus_client import Counter, Gauge, Histogram from twisted.internet.defer import DeferredSemaphore @@ -29,6 +32,7 @@ from sygnal.exceptions import ( NotificationDispatchException, + NotificationQuotaDispatchException, PushkinSetupException, TemporaryNotificationDispatchException, ) @@ -70,10 +74,14 @@ logger = logging.getLogger(__name__) GCM_URL = b"https://fcm.googleapis.com/fcm/send" +GCM_URL_V1 = "https://fcm.googleapis.com/v1/projects/{ProjectID}/messages:send" MAX_TRIES = 3 RETRY_DELAY_BASE = 10 +RETRY_DELAY_BASE_QUOTA_EXCEEDED = 60 MAX_BYTES_PER_FIELD = 1024 +AUTH_SCOPES = ["https://www.googleapis.com/auth/firebase.messaging"] + # The error codes that mean a registration ID will never # succeed and we should reject it upstream. # We include NotRegistered here too for good measure, even @@ -95,6 +103,11 @@ DEFAULT_MAX_CONNECTIONS = 20 +class APIVersion(Enum): + Legacy = "legacy" + V1 = "v1" + + class GcmPushkin(ConcurrencyLimitedPushkin): """ Pushkin that relays notifications to Google/Firebase Cloud Messaging. @@ -103,8 +116,11 @@ class GcmPushkin(ConcurrencyLimitedPushkin): UNDERSTOOD_CONFIG_FIELDS = { "type", "api_key", + "api_version", "fcm_options", "max_connections", + "project_id", + "service_account_file", } | ConcurrencyLimitedPushkin.UNDERSTOOD_CONFIG_FIELDS def __init__(self, name: str, sygnal: "Sygnal", config: Dict[str, Any]) -> None: @@ -137,9 +153,38 @@ def __init__(self, name: str, sygnal: "Sygnal", config: Dict[str, Any]) -> None: proxy_url_str=proxy_url, ) - self.api_key = self.get_config("api_key", str) - if not self.api_key: - raise PushkinSetupException("No API key set in config") + self.api_version = APIVersion.Legacy + version_str = self.get_config("api_version", str) + if not version_str: + logger.warning( + "API version not set in config, defaulting to %s", + self.api_version.value, + ) + else: + try: + self.api_version = APIVersion(version_str) + except ValueError: + raise PushkinSetupException( + "Invalid API version set in config", + version_str, + ) + + if self.api_version is APIVersion.Legacy: + self.api_key = self.get_config("api_key", str) + if not self.api_key: + raise PushkinSetupException("No API key set in config") + + self.project_id = self.get_config("project_id", str) + if self.api_version is APIVersion.V1 and not self.project_id: + raise PushkinSetupException( + "Must configure `project_id` when using FCM api v1", + ) + + self.service_account_file = self.get_config("service_account_file", str) + if self.api_version is APIVersion.V1 and not self.service_account_file: + raise PushkinSetupException( + "Must configure `service_account_file` when using FCM api v1", + ) # Use the fcm_options config dictionary as a foundation for the body; # this lets the Sygnal admin choose custom FCM options @@ -186,12 +231,16 @@ async def _perform_http_request( with PENDING_REQUESTS_GAUGE.track_inprogress(): await self.connection_semaphore.acquire() + url = GCM_URL + if self.api_version is APIVersion.V1: + url = str.encode(GCM_URL_V1.format(ProjectID=self.project_id)) + try: with SEND_TIME_HISTOGRAM.time(): with ACTIVE_REQUESTS_GAUGE.track_inprogress(): response = await self.http_agent.request( b"POST", - GCM_URL, + url, headers=Headers(headers), bodyProducer=body_producer, ) @@ -215,8 +264,6 @@ async def _request_dispatch( ) -> Tuple[List[str], List[str]]: poke_start_time = time.time() - failed = [] - response, response_text = await self._perform_http_request(body, headers) RESPONSE_STATUS_CODES_COUNTER.labels( @@ -227,6 +274,39 @@ async def _request_dispatch( span.set_tag(tags.HTTP_STATUS_CODE, response.code) + if self.api_version is APIVersion.Legacy: + return self._handle_legacy_response( + n, + log, + response, + response_text, + pushkeys, + span, + ) + elif self.api_version is APIVersion.V1: + return self._handle_v1_response( + log, + response, + response_text, + pushkeys, + span, + ) + else: + log.warn( + "Processing response for unknown API version: %s", self.api_version + ) + return [], [] + + def _handle_legacy_response( + self, + n: Notification, + log: NotificationLoggerAdapter, + response: IResponse, + response_text: str, + pushkeys: List[str], + span: Span, + ) -> Tuple[List[str], List[str]]: + failed = [] if 500 <= response.code < 600: log.debug("%d from server, waiting to try again", response.code) @@ -319,25 +399,112 @@ async def _request_dispatch( f"Unknown GCM response code {response.code}" ) + def _handle_v1_response( + self, + log: NotificationLoggerAdapter, + response: IResponse, + response_text: str, + pushkeys: List[str], + span: Span, + ) -> Tuple[List[str], List[str]]: + if 500 <= response.code < 600: + log.debug("%d from server, waiting to try again", response.code) + + retry_after = None + + for header_value in response.headers.getRawHeaders( + b"retry-after", default=[] + ): + retry_after = int(header_value) + span.log_kv({"event": "gcm_retry_after", "retry_after": retry_after}) + + raise TemporaryNotificationDispatchException( + "GCM server error, hopefully temporary.", custom_retry_delay=retry_after + ) + elif response.code == 400: + log.error( + "%d from server, we have sent something invalid! Error: %r", + response.code, + response_text, + ) + # permanent failure: give up + raise NotificationDispatchException("Invalid request") + elif response.code == 401: + log.error( + "401 from server! Our API key is invalid? Error: %r", response_text + ) + # permanent failure: give up + raise NotificationDispatchException("Not authorised to push") + elif response.code == 403: + log.error("403 from server! Sender ID mismatch! Error: %r", response_text) + # permanent failure: give up + raise NotificationDispatchException("Sender ID mismatch") + elif response.code == 429: + log.debug("%d from server, waiting to try again", response.code) + + # Minimum 1 minute delay required + retry_after = None + + for header_value in response.headers.getRawHeaders( + b"retry-after", default=[] + ): + retry_after = int(header_value) + + span.log_kv({"event": "gcm_retry_after", "retry_after": retry_after}) + raise NotificationQuotaDispatchException( + "Message rate quota exceeded.", custom_retry_delay=retry_after + ) + elif response.code == 404: + log.info("Reg IDs %r get 404 response; assuming unregistered", pushkeys) + return pushkeys, [] + elif 200 <= response.code < 300: + return [], [] + else: + raise NotificationDispatchException( + f"Unknown GCM response code {response.code}" + ) + + def _get_access_token(self) -> str: + """Retrieve a valid access token that can be used to authorize requests. + + :return: Access token. + """ + # TODO: Should we use the environment variable approach instead? + # export GOOGLE_APPLICATION_CREDENTIALS=/path/to/key.json + # credentials, project = google.auth.default(scopes=AUTH_SCOPES) + credentials = service_account.Credentials.from_service_account_file( + str(self.service_account_file), + scopes=AUTH_SCOPES, + ) + request = google.auth.transport.requests.Request() + credentials.refresh(request) + return credentials.token + async def _dispatch_notification_unlimited( self, n: Notification, device: Device, context: NotificationContext ) -> List[str]: log = NotificationLoggerAdapter(logger, {"request_id": context.request_id}) - # `_dispatch_notification_unlimited` gets called once for each device in the - # `Notification` with a matching app ID. We do something a little dirty and - # perform all of our dispatches the first time we get called for a - # `Notification` and do nothing for the rest of the times we get called. - pushkeys = [ - device.pushkey for device in n.devices if self.handles_appid(device.app_id) - ] - # `pushkeys` ought to never be empty here. At the very least it should contain - # `device`'s pushkey. - - if pushkeys[0] != device.pushkey: - # We've already been asked to dispatch for this `Notification` and have - # previously sent out the notification to all devices. - return [] + pushkeys: list[str] = [] + if self.api_version is APIVersion.Legacy: + # `_dispatch_notification_unlimited` gets called once for each device in the + # `Notification` with a matching app ID. We do something a little dirty and + # perform all of our dispatches the first time we get called for a + # `Notification` and do nothing for the rest of the times we get called. + pushkeys = [ + device.pushkey + for device in n.devices + if self.handles_appid(device.app_id) + ] + # `pushkeys` ought to never be empty here. At the very least it should contain + # `device`'s pushkey. + + if pushkeys[0] != device.pushkey: + # We've already been asked to dispatch for this `Notification` and have + # previously sent out the notification to all devices. + return [] + elif self.api_version is APIVersion.V1: + pushkeys = [device.pushkey] # The pushkey is kind of secret because you can use it to send push # to someone. @@ -350,7 +517,7 @@ async def _dispatch_notification_unlimited( # TODO: Implement collapse_key to queue only one message per room. failed: List[str] = [] - data = GcmPushkin._build_data(n, device) + data = GcmPushkin._build_data(n, device, self.api_version) # Reject pushkey(s) if default_payload is misconfigured if data is None: @@ -363,18 +530,32 @@ async def _dispatch_notification_unlimited( headers = { "User-Agent": ["sygnal"], "Content-Type": ["application/json"], - "Authorization": ["key=%s" % (self.api_key,)], } + if self.api_version == APIVersion.Legacy: + headers["Authorization"] = ["key=%s" % (self.api_key,)] + elif self.api_version is APIVersion.V1: + headers["Authorization"] = ["Bearer %s" % (self._get_access_token(),)] + body = self.base_request_body.copy() body["data"] = data - body["priority"] = "normal" if n.prio == "low" else "high" + if self.api_version is APIVersion.Legacy: + body["priority"] = "normal" if n.prio == "low" else "high" + elif self.api_version is APIVersion.V1: + priority = {"priority": "normal" if n.prio == "low" else "high"} + body["android"] = priority for retry_number in range(0, MAX_TRIES): - if len(pushkeys) == 1: - body["to"] = pushkeys[0] - else: - body["registration_ids"] = pushkeys + if self.api_version is APIVersion.Legacy: + if len(pushkeys) == 1: + body["to"] = pushkeys[0] + else: + body["registration_ids"] = pushkeys + elif self.api_version is APIVersion.V1: + body["token"] = device.pushkey + new_body = body + body = {} + body["message"] = new_body log.info( "Sending (attempt %i) => %r room:%s, event:%s", @@ -413,6 +594,24 @@ async def _dispatch_notification_unlimited( {"event": "temporary_fail", "retrying_in": retry_delay} ) + await twisted_sleep( + retry_delay, twisted_reactor=self.sygnal.reactor + ) + except NotificationQuotaDispatchException as exc: + retry_delay = RETRY_DELAY_BASE_QUOTA_EXCEEDED * (2**retry_number) + if exc.custom_retry_delay is not None: + retry_delay = exc.custom_retry_delay + + log.warning( + "Quota exceeded, will retry in %d seconds", + retry_delay, + exc_info=True, + ) + + span_parent.log_kv( + {"event": "temporary_fail", "retrying_in": retry_delay} + ) + await twisted_sleep( retry_delay, twisted_reactor=self.sygnal.reactor ) @@ -424,7 +623,11 @@ async def _dispatch_notification_unlimited( return failed @staticmethod - def _build_data(n: Notification, device: Device) -> Optional[Dict[str, Any]]: + def _build_data( + n: Notification, + device: Device, + api_version: APIVersion, + ) -> Optional[Dict[str, Any]]: """ Build the payload data to be sent. Args: @@ -465,12 +668,24 @@ def _build_data(n: Notification, device: Device) -> Optional[Dict[str, Any]]: if data[attr] is not None and len(data[attr]) > MAX_BYTES_PER_FIELD: data[attr] = data[attr][0:MAX_BYTES_PER_FIELD] + if api_version is APIVersion.V1: + if "content" in data: + for attr, value in data["content"].items(): + if not isinstance(value, str): + continue + data["content_" + attr] = value + del data["content"] + data["prio"] = "high" if n.prio == "low": data["prio"] = "normal" if getattr(n, "counts", None): - data["unread"] = n.counts.unread - data["missed_calls"] = n.counts.missed_calls + if api_version is APIVersion.Legacy: + data["unread"] = n.counts.unread + data["missed_calls"] = n.counts.missed_calls + elif api_version is APIVersion.V1: + data["unread"] = str(n.counts.unread) + data["missed_calls"] = str(n.counts.missed_calls) return data diff --git a/tests/test_gcm.py b/tests/test_gcm.py index 3a4f61f9..023fae59 100644 --- a/tests/test_gcm.py +++ b/tests/test_gcm.py @@ -14,6 +14,7 @@ # limitations under the License. import json from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Tuple +from unittest.mock import MagicMock from sygnal.gcmpushkin import GcmPushkin @@ -25,6 +26,16 @@ DEVICE_EXAMPLE = {"app_id": "com.example.gcm", "pushkey": "spqr", "pushkey_ts": 42} DEVICE_EXAMPLE2 = {"app_id": "com.example.gcm", "pushkey": "spqr2", "pushkey_ts": 42} +DEVICE_EXAMPLE_APIV1 = { + "app_id": "com.example.gcm.apiv1", + "pushkey": "spqr", + "pushkey_ts": 42, +} +DEVICE_EXAMPLE2_APIV1 = { + "app_id": "com.example.gcm.apiv1", + "pushkey": "spqr2", + "pushkey_ts": 42, +} DEVICE_EXAMPLE_WITH_DEFAULT_PAYLOAD = { "app_id": "com.example.gcm", "pushkey": "spqr", @@ -38,6 +49,19 @@ } }, } +DEVICE_EXAMPLE_APIV1_WITH_DEFAULT_PAYLOAD = { + "app_id": "com.example.gcm.apiv1", + "pushkey": "spqr", + "pushkey_ts": 42, + "data": { + "default_payload": { + "aps": { + "mutable-content": 1, + "alert": {"loc-key": "SINGLE_UNREAD", "loc-args": []}, + } + } + }, +} DEVICE_EXAMPLE_WITH_BAD_DEFAULT_PAYLOAD = { "app_id": "com.example.gcm", @@ -86,18 +110,28 @@ async def _perform_http_request( # type: ignore[override] self.num_requests += 1 return self.preloaded_response, json.dumps(self.preloaded_response_payload) + def _get_access_token(self) -> str: + return "token" + class GcmTestCase(testutils.TestCase): def config_setup(self, config: Dict[str, Any]) -> None: config["apps"]["com.example.gcm"] = { "type": "tests.test_gcm.TestGcmPushkin", "api_key": "kii", + "api_version": "legacy", } config["apps"]["com.example.gcm.ios"] = { "type": "tests.test_gcm.TestGcmPushkin", "api_key": "kii", "fcm_options": {"content_available": True, "mutable_content": True}, } + config["apps"]["com.example.gcm.apiv1"] = { + "type": "tests.test_gcm.TestGcmPushkin", + "api_version": "v1", + "project_id": "example_project", + "service_account_file": "/path/to/file.json", + } def get_test_pushkin(self, name: str) -> TestGcmPushkin: pushkin = self.sygnal.pushkins[name] @@ -109,15 +143,101 @@ def test_expected(self) -> None: Tests the expected case: a good response from GCM leads to a good response from Sygnal. """ + self.apns_pushkin_snotif = MagicMock() gcm = self.get_test_pushkin("com.example.gcm") gcm.preload_with_response( 200, {"results": [{"message_id": "msg42", "registration_id": "spqr"}]} ) + # type safety: using ignore here due to mypy not handling monkeypatching, + # see https://github.com/python/mypy/issues/2427 + gcm._request_dispatch = self.apns_pushkin_snotif # type: ignore[assignment] # noqa: E501 + + method = self.apns_pushkin_snotif + method.side_effect = testutils.make_async_magic_mock(([], [])) + resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE])) + self.assertEqual(1, method.call_count) + notification_req = method.call_args.args + + self.assertEqual( + { + "data": { + "event_id": "$qTOWWTEL48yPm3uT-gdNhFcoHxfKbZuqRVnnWWSkGBs", + "type": "m.room.message", + "sender": "@exampleuser:matrix.org", + "room_name": "Mission Control", + "room_alias": "#exampleroom:matrix.org", + "membership": None, + "sender_display_name": "Major Tom", + "content": { + "msgtype": "m.text", + "body": "I'm floating in a most peculiar way.", + "other": 1, + }, + "room_id": "!slw48wfj34rtnrf:example.com", + "prio": "high", + "unread": 2, + "missed_calls": 1, + }, + "priority": "high", + "to": "spqr", + }, + notification_req[2], + ) + + self.assertEqual(resp, {"rejected": []}) + + def test_expected_api_v1(self) -> None: + """ + Tests the expected case: a good response from GCM leads to a good + response from Sygnal. + """ + self.apns_pushkin_snotif = MagicMock() + gcm = self.get_test_pushkin("com.example.gcm.apiv1") + gcm.preload_with_response( + 200, {"results": [{"message_id": "msg42", "registration_id": "spqr"}]} + ) + + # type safety: using ignore here due to mypy not handling monkeypatching, + # see https://github.com/python/mypy/issues/2427 + gcm._request_dispatch = self.apns_pushkin_snotif # type: ignore[assignment] # noqa: E501 + + method = self.apns_pushkin_snotif + method.side_effect = testutils.make_async_magic_mock(([], [])) + + resp = self._request(self._make_dummy_notification([DEVICE_EXAMPLE_APIV1])) + + self.assertEqual(1, method.call_count) + notification_req = method.call_args.args + + self.assertEqual( + { + "message": { + "data": { + "event_id": "$qTOWWTEL48yPm3uT-gdNhFcoHxfKbZuqRVnnWWSkGBs", + "type": "m.room.message", + "sender": "@exampleuser:matrix.org", + "room_name": "Mission Control", + "room_alias": "#exampleroom:matrix.org", + "membership": None, + "sender_display_name": "Major Tom", + "content_msgtype": "m.text", + "content_body": "I'm floating in a most peculiar way.", + "room_id": "!slw48wfj34rtnrf:example.com", + "prio": "high", + "unread": "2", + "missed_calls": "1", + }, + "android": {"priority": "high"}, + "token": "spqr", + } + }, + notification_req[2], + ) + self.assertEqual(resp, {"rejected": []}) - self.assertEqual(gcm.num_requests, 1) def test_expected_with_default_payload(self) -> None: """ @@ -136,6 +256,23 @@ def test_expected_with_default_payload(self) -> None: self.assertEqual(resp, {"rejected": []}) self.assertEqual(gcm.num_requests, 1) + def test_expected_api_v1_with_default_payload(self) -> None: + """ + Tests the expected case: a good response from GCM leads to a good + response from Sygnal. + """ + gcm = self.get_test_pushkin("com.example.gcm.apiv1") + gcm.preload_with_response( + 200, {"results": [{"message_id": "msg42", "registration_id": "spqr"}]} + ) + + resp = self._request( + self._make_dummy_notification([DEVICE_EXAMPLE_APIV1_WITH_DEFAULT_PAYLOAD]) + ) + + self.assertEqual(resp, {"rejected": []}) + self.assertEqual(gcm.num_requests, 1) + def test_misformed_default_payload_rejected(self) -> None: """ Tests that a non-dict default_payload is rejected. @@ -192,6 +329,31 @@ def test_batching(self) -> None: self.assertEqual(gcm.last_request_body["registration_ids"], ["spqr", "spqr2"]) self.assertEqual(gcm.num_requests, 1) + def test_batching_api_v1(self) -> None: + """ + Tests that multiple GCM devices have their notification delivered to GCM + separately, instead of being delivered together. + """ + gcm = self.get_test_pushkin("com.example.gcm.apiv1") + gcm.preload_with_response( + 200, + { + "results": [ + {"registration_id": "spqr", "message_id": "msg42"}, + {"registration_id": "spqr2", "message_id": "msg42"}, + ] + }, + ) + + resp = self._request( + self._make_dummy_notification([DEVICE_EXAMPLE_APIV1, DEVICE_EXAMPLE2_APIV1]) + ) + + self.assertEqual(resp, {"rejected": []}) + assert gcm.last_request_body is not None + self.assertEqual(gcm.last_request_body["message"]["token"], "spqr2") + self.assertEqual(gcm.num_requests, 2) + def test_batching_individual_failure(self) -> None: """ Tests that multiple GCM devices have their notification delivered to GCM diff --git a/tests/testutils.py b/tests/testutils.py index 30f84583..29467be2 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -106,6 +106,7 @@ def _make_dummy_notification(self, devices): "content": { "msgtype": "m.text", "body": "I'm floating in a most peculiar way.", + "other": 1, }, "counts": {"unread": 2, "missed_calls": 1}, "devices": devices,