Skip to content

Commit 6115194

Browse files
Update type anotations (#784)
* Update the functions to use the right type anotations * Use right import path and refactor types * Set the manageType to use union of the config * Update the type of the app * Add type and also remove return on run validator as it never returns * Use right variable for the _chunks and avoid using of comparism operator for type checks * Update the types and import path * Use the optional for backwards py versions * Add py.typed * Solve the nonetype bug * Solve python lower versions break * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b16846e commit 6115194

File tree

15 files changed

+548
-376
lines changed

15 files changed

+548
-376
lines changed

push_notifications/admin.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from django.contrib import admin, messages
33
from django.utils.encoding import force_str
44
from django.utils.translation import gettext_lazy as _
5+
from django.http import HttpRequest
6+
from django.db.models import QuerySet
57

68
from .exceptions import APNSServerError, GCMError, WebPushError
79
from .models import APNSDevice, GCMDevice, WebPushDevice, WNSDevice
@@ -20,9 +22,9 @@ class DeviceAdmin(admin.ModelAdmin):
2022
if hasattr(User, "USERNAME_FIELD"):
2123
search_fields = ("name", "device_id", "user__%s" % (User.USERNAME_FIELD))
2224
else:
23-
search_fields = ("name", "device_id")
25+
search_fields = ("name", "device_id", "")
2426

25-
def send_messages(self, request, queryset, bulk=False):
27+
def send_messages(self, request: HttpRequest, queryset: QuerySet, bulk: bool = False) -> None:
2628
"""
2729
Provides error handling for DeviceAdmin send_message and send_bulk_message methods.
2830
"""
@@ -105,22 +107,22 @@ def send_messages(self, request, queryset, bulk=False):
105107
msg = _("All messages were sent: %s" % (ret))
106108
self.message_user(request, msg)
107109

108-
def send_message(self, request, queryset):
110+
def send_message(self, request: HttpRequest, queryset: QuerySet) -> None:
109111
self.send_messages(request, queryset)
110112

111113
send_message.short_description = _("Send test message")
112114

113-
def send_bulk_message(self, request, queryset):
115+
def send_bulk_message(self, request: HttpRequest, queryset: QuerySet) -> None:
114116
self.send_messages(request, queryset, True)
115117

116118
send_bulk_message.short_description = _("Send test message in bulk")
117119

118-
def enable(self, request, queryset):
120+
def enable(self, request: HttpRequest, queryset: QuerySet) -> None:
119121
queryset.update(active=True)
120122

121123
enable.short_description = _("Enable selected devices")
122124

123-
def disable(self, request, queryset):
125+
def disable(self, request: HttpRequest, queryset: QuerySet) -> None:
124126
queryset.update(active=False)
125127

126128
disable.short_description = _("Disable selected devices")
@@ -132,7 +134,7 @@ class GCMDeviceAdmin(DeviceAdmin):
132134
)
133135
list_filter = ("active", "cloud_message_type")
134136

135-
def send_messages(self, request, queryset, bulk=False):
137+
def send_messages(self, request: HttpRequest, queryset: QuerySet, bulk: bool = False) -> None:
136138
"""
137139
Provides error handling for DeviceAdmin send_message and send_bulk_message methods.
138140
"""
@@ -171,7 +173,7 @@ class WebPushDeviceAdmin(DeviceAdmin):
171173
if hasattr(User, "USERNAME_FIELD"):
172174
search_fields = ("name", "registration_id", "user__%s" % (User.USERNAME_FIELD))
173175
else:
174-
search_fields = ("name", "registration_id")
176+
search_fields = ("name", "registration_id", "")
175177

176178

177179
admin.site.register(APNSDevice, DeviceAdmin)

push_notifications/api/rest_framework.py

Lines changed: 61 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
from rest_framework.serializers import ModelSerializer, Serializer, ValidationError
55
from rest_framework.viewsets import ModelViewSet
66

7-
from ..fields import UNSIGNED_64BIT_INT_MAX_VALUE, hex_re
8-
from ..models import APNSDevice, GCMDevice, WebPushDevice, WNSDevice
9-
from ..settings import PUSH_NOTIFICATIONS_SETTINGS as SETTINGS
7+
from push_notifications.fields import UNSIGNED_64BIT_INT_MAX_VALUE, hex_re
8+
from push_notifications.models import APNSDevice, GCMDevice, WebPushDevice, WNSDevice
9+
from push_notifications.settings import PUSH_NOTIFICATIONS_SETTINGS as SETTINGS
10+
from typing import Any, Union, Dict, Optional
1011

1112

1213
# Fields
@@ -15,25 +16,30 @@ class HexIntegerField(IntegerField):
1516
Store an integer represented as a hex string of form "0x01".
1617
"""
1718

18-
def to_internal_value(self, data):
19+
def to_internal_value(self, data: Union[str, int]) -> int:
1920
# validate hex string and convert it to the unsigned
2021
# integer representation for internal use
2122
try:
22-
data = int(data, 16) if type(data) != int else data
23+
data = int(data, 16) if not isinstance(data, int) else data
2324
except ValueError:
2425
raise ValidationError("Device ID is not a valid hex number")
2526
return super().to_internal_value(data)
2627

27-
def to_representation(self, value):
28+
def to_representation(self, value: int) -> int:
2829
return value
2930

3031

3132
# Serializers
3233
class DeviceSerializerMixin(ModelSerializer):
3334
class Meta:
3435
fields = (
35-
"id", "name", "application_id", "registration_id", "device_id",
36-
"active", "date_created"
36+
"id",
37+
"name",
38+
"application_id",
39+
"registration_id",
40+
"device_id",
41+
"active",
42+
"date_created",
3743
)
3844
read_only_fields = ("date_created",)
3945

@@ -45,8 +51,7 @@ class APNSDeviceSerializer(ModelSerializer):
4551
class Meta(DeviceSerializerMixin.Meta):
4652
model = APNSDevice
4753

48-
def validate_registration_id(self, value):
49-
54+
def validate_registration_id(self, value: str) -> str:
5055
# https://developer.apple.com/documentation/uikit/uiapplicationdelegate/1622958-application
5156
# As of 02/2023 APNS tokens (registration_id) "are of variable length. Do not hard-code their size."
5257
if hex_re.match(value) is None:
@@ -56,10 +61,10 @@ def validate_registration_id(self, value):
5661

5762

5863
class UniqueRegistrationSerializerMixin(Serializer):
59-
def validate(self, attrs):
60-
devices = None
61-
primary_key = None
62-
request_method = None
64+
def validate(self, attrs: Dict[str, Any]) -> Dict[str, Any]:
65+
devices: Optional[Any] = None
66+
primary_key: Optional[Any] = None
67+
request_method: Optional[str] = None
6368

6469
if self.initial_data.get("registration_id", None):
6570
if self.instance:
@@ -76,9 +81,10 @@ def validate(self, attrs):
7681

7782
Device = self.Meta.model
7883
if request_method == "update":
79-
reg_id = attrs.get("registration_id", self.instance.registration_id)
80-
devices = Device.objects.filter(registration_id=reg_id) \
81-
.exclude(id=primary_key)
84+
reg_id: str = attrs.get("registration_id", self.instance.registration_id)
85+
devices = Device.objects.filter(registration_id=reg_id).exclude(
86+
id=primary_key
87+
)
8288
elif request_method == "create":
8389
devices = Device.objects.filter(registration_id=attrs["registration_id"])
8490

@@ -92,20 +98,26 @@ class GCMDeviceSerializer(UniqueRegistrationSerializerMixin, ModelSerializer):
9298
help_text="ANDROID_ID / TelephonyManager.getDeviceId() (e.g: 0x01)",
9399
style={"input_type": "text"},
94100
required=False,
95-
allow_null=True
101+
allow_null=True,
96102
)
97103

98104
class Meta(DeviceSerializerMixin.Meta):
99105
model = GCMDevice
100106
fields = (
101-
"id", "name", "registration_id", "device_id", "active", "date_created",
102-
"cloud_message_type", "application_id",
107+
"id",
108+
"name",
109+
"registration_id",
110+
"device_id",
111+
"active",
112+
"date_created",
113+
"cloud_message_type",
114+
"application_id",
103115
)
104116
extra_kwargs = {"id": {"read_only": False, "required": False}}
105117

106-
def validate_device_id(self, value):
118+
def validate_device_id(self, value: Optional[int] = None) -> Optional[int]:
107119
# device ids are 64 bit unsigned values
108-
if value > UNSIGNED_64BIT_INT_MAX_VALUE:
120+
if value is not None and value > UNSIGNED_64BIT_INT_MAX_VALUE:
109121
raise ValidationError("Device ID is out of range")
110122
return value
111123

@@ -119,26 +131,36 @@ class WebPushDeviceSerializer(UniqueRegistrationSerializerMixin, ModelSerializer
119131
class Meta(DeviceSerializerMixin.Meta):
120132
model = WebPushDevice
121133
fields = (
122-
"id", "name", "registration_id", "active", "date_created",
123-
"p256dh", "auth", "browser", "application_id",
134+
"id",
135+
"name",
136+
"registration_id",
137+
"active",
138+
"date_created",
139+
"p256dh",
140+
"auth",
141+
"browser",
142+
"application_id",
124143
)
125144

126145

127146
# Permissions
128147
class IsOwner(permissions.BasePermission):
129-
def has_object_permission(self, request, view, obj):
148+
def has_object_permission(self, request: Any, view: Any, obj: Any) -> bool:
130149
# must be the owner to view the object
131150
return obj.user == request.user
132151

133152

134153
# Mixins
135154
class DeviceViewSetMixin:
136-
lookup_field = "registration_id"
137-
138-
def create(self, request, *args, **kwargs):
139-
serializer = None
140-
is_update = False
141-
if SETTINGS.get("UPDATE_ON_DUPLICATE_REG_ID") and self.lookup_field in request.data:
155+
lookup_field: str = "registration_id"
156+
157+
def create(self, request: Any, *args: Any, **kwargs: Any) -> Response:
158+
serializer: Optional[Any] = None
159+
is_update: bool = False
160+
if (
161+
SETTINGS.get("UPDATE_ON_DUPLICATE_REG_ID")
162+
and self.lookup_field in request.data
163+
):
142164
instance = self.queryset.model.objects.filter(
143165
registration_id=request.data[self.lookup_field]
144166
).first()
@@ -155,23 +177,25 @@ def create(self, request, *args, **kwargs):
155177
else:
156178
self.perform_create(serializer)
157179
headers = self.get_success_headers(serializer.data)
158-
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
180+
return Response(
181+
serializer.data, status=status.HTTP_201_CREATED, headers=headers
182+
)
159183

160-
def perform_create(self, serializer):
184+
def perform_create(self, serializer: Serializer) -> Any:
161185
if self.request.user.is_authenticated:
162186
serializer.save(user=self.request.user)
163187
return super().perform_create(serializer)
164188

165-
def perform_update(self, serializer):
189+
def perform_update(self, serializer: Serializer) -> Any:
166190
if self.request.user.is_authenticated:
167191
serializer.save(user=self.request.user)
168192
return super().perform_update(serializer)
169193

170194

171195
class AuthorizedMixin:
172-
permission_classes = (permissions.IsAuthenticated, IsOwner)
196+
permission_classes: tuple = (permissions.IsAuthenticated, IsOwner)
173197

174-
def get_queryset(self):
198+
def get_queryset(self) -> Any:
175199
# filter all devices to only those belonging to the current user
176200
return self.queryset.filter(user=self.request.user)
177201

@@ -207,7 +231,7 @@ class WNSDeviceAuthorizedViewSet(AuthorizedMixin, WNSDeviceViewSet):
207231
class WebPushDeviceViewSet(DeviceViewSetMixin, ModelViewSet):
208232
queryset = WebPushDevice.objects.all()
209233
serializer_class = WebPushDeviceSerializer
210-
lookup_value_regex = '.+'
234+
lookup_value_regex: str = ".+"
211235

212236

213237
class WebPushDeviceAuthorizedViewSet(AuthorizedMixin, WebPushDeviceViewSet):

push_notifications/apns.py

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,18 @@
55
"""
66

77
import time
8-
8+
from typing import Optional, Dict, Any, List, Union
99
from apns2 import client as apns2_client
1010
from apns2 import credentials as apns2_credentials
1111
from apns2 import errors as apns2_errors
1212
from apns2 import payload as apns2_payload
1313

1414
from . import models
1515
from .conf import get_manager
16-
from .exceptions import APNSError, APNSUnsupportedPriority, APNSServerError
16+
from .exceptions import APNSUnsupportedPriority, APNSServerError
1717

1818

19-
def _apns_create_socket(creds=None, application_id=None):
19+
def _apns_create_socket(creds: Optional[apns2_credentials.Credentials] = None, application_id: Optional[str] = None) -> apns2_client.APNsClient:
2020
if creds is None:
2121
if not get_manager().has_auth_token_creds(application_id):
2222
cert = get_manager().get_apns_certificate(application_id)
@@ -39,31 +39,48 @@ def _apns_create_socket(creds=None, application_id=None):
3939

4040

4141
def _apns_prepare(
42-
token, alert, application_id=None, badge=None, sound=None, category=None,
43-
content_available=False, action_loc_key=None, loc_key=None, loc_args=[],
44-
extra={}, mutable_content=False, thread_id=None, url_args=None):
45-
if action_loc_key or loc_key or loc_args:
46-
apns2_alert = apns2_payload.PayloadAlert(
47-
body=alert if alert else {}, body_localized_key=loc_key,
48-
body_localized_args=loc_args, action_localized_key=action_loc_key)
49-
else:
50-
apns2_alert = alert
51-
52-
if callable(badge):
53-
badge = badge(token)
54-
55-
return apns2_payload.Payload(
56-
alert=apns2_alert, badge=badge, sound=sound, category=category,
57-
url_args=url_args, custom=extra, thread_id=thread_id,
58-
content_available=content_available, mutable_content=mutable_content)
42+
token: str,
43+
alert: Optional[str],
44+
application_id: Optional[str] = None,
45+
badge: Optional[int] = None,
46+
sound: Optional[str] = None,
47+
category: Optional[str] = None,
48+
content_available: bool = False,
49+
action_loc_key: Optional[str] = None,
50+
loc_key: Optional[str] = None,
51+
loc_args: List[Any] = [],
52+
extra: Dict[str, Any] = {},
53+
mutable_content: bool = False,
54+
thread_id: Optional[str] = None,
55+
url_args: Optional[list] = None
56+
) -> apns2_payload.Payload:
57+
if action_loc_key or loc_key or loc_args:
58+
apns2_alert = apns2_payload.PayloadAlert(
59+
body=alert if alert else {}, body_localized_key=loc_key,
60+
body_localized_args=loc_args, action_localized_key=action_loc_key)
61+
else:
62+
apns2_alert = alert
63+
64+
if callable(badge):
65+
badge = badge(token)
66+
67+
return apns2_payload.Payload(
68+
alert=apns2_alert, badge=badge, sound=sound, category=category,
69+
url_args=url_args, custom=extra, thread_id=thread_id,
70+
content_available=content_available, mutable_content=mutable_content)
5971

6072

6173
def _apns_send(
62-
registration_id, alert, batch=False, application_id=None, creds=None, **kwargs
63-
):
74+
registration_id: Union[str, List[str]],
75+
alert: Optional[str] = None,
76+
batch: bool = False,
77+
application_id: Optional[str] = None,
78+
creds: Optional[apns2_credentials.Credentials] = None,
79+
**kwargs: Any
80+
) -> Optional[Dict[str, str]]:
6481
client = _apns_create_socket(creds=creds, application_id=application_id)
6582

66-
notification_kwargs = {}
83+
notification_kwargs: Dict[str, Any] = {}
6784

6885
# if expiration isn"t specified use 1 month from now
6986
notification_kwargs["expiration"] = kwargs.pop("expiration", None)
@@ -97,7 +114,13 @@ def _apns_send(
97114
)
98115

99116

100-
def apns_send_message(registration_id, alert, application_id=None, creds=None, **kwargs):
117+
def apns_send_message(
118+
registration_id: str,
119+
alert: Optional[str] = None,
120+
application_id: Optional[str] = None,
121+
creds: Optional[apns2_credentials.Credentials] = None,
122+
**kwargs: Any
123+
) -> None:
101124
"""
102125
Sends an APNS notification to a single registration_id.
103126
This will send the notification as form data.
@@ -122,8 +145,12 @@ def apns_send_message(registration_id, alert, application_id=None, creds=None, *
122145

123146

124147
def apns_send_bulk_message(
125-
registration_ids, alert, application_id=None, creds=None, **kwargs
126-
):
148+
registration_ids: List[str],
149+
alert: Optional[str] = None,
150+
application_id: Optional[str] = None,
151+
creds: Optional[apns2_credentials.Credentials] = None,
152+
**kwargs: Any
153+
) -> Optional[Dict[str, str]]:
127154
"""
128155
Sends an APNS notification to one or more registration_ids.
129156
The registration_ids argument needs to be a list.

0 commit comments

Comments
 (0)