Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/replace hyper with httpx #161

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
[![PyPI version](https://img.shields.io/pypi/pyversions/apns2.svg)](https://pypi.python.org/pypi/apns2)
[![Build Status](https://drone.pr0ger.dev/api/badges/Pr0Ger/PyAPNs2/status.svg)](https://drone.pr0ger.dev/Pr0Ger/PyAPNs2)

Python library for interacting with the Apple Push Notification service (APNs) via HTTP/2 protocol
Python library for interacting with the Apple Push Notification service (APNs) via HTTP/2 protocol using httpx

## Installation

Expand Down Expand Up @@ -40,6 +40,13 @@ client = APNsClient(credentials=token_credentials, use_sandbox=False)
client.send_notification_batch(notifications=notifications, topic=topic)
```

## Requirements

- Python 3.7 or later
- httpx 0.24.0 or later
- cryptography 1.7.2 or later
- PyJWT 2.0.0 or later

## Further Info

[iOS Reference Library: Local and Push Notification Programming Guide][a1]
Expand Down
57 changes: 22 additions & 35 deletions apns2/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from enum import Enum
from threading import Thread
from typing import Dict, Iterable, Optional, Tuple, Union
import httpx

from .credentials import CertificateCredentials, Credentials
from .errors import ConnectionFailed, exception_class_for_reason
Expand Down Expand Up @@ -67,25 +68,13 @@ def __init__(self,

def _init_connection(self, use_sandbox: bool, use_alternative_port: bool, proto: Optional[str],
proxy_host: Optional[str], proxy_port: Optional[int]) -> None:
server = self.SANDBOX_SERVER if use_sandbox else self.LIVE_SERVER
port = self.ALTERNATIVE_PORT if use_alternative_port else self.DEFAULT_PORT
self._connection = self.__credentials.create_connection(server, port, proto, proxy_host, proxy_port)
self._server = self.SANDBOX_SERVER if use_sandbox else self.LIVE_SERVER
self._port = self.ALTERNATIVE_PORT if use_alternative_port else self.DEFAULT_PORT
self._connection = self.__credentials.create_connection(self._server, self._port, proto, proxy_host, proxy_port)

def _start_heartbeat(self, heartbeat_period: float) -> None:
conn_ref = weakref.ref(self._connection)

def watchdog() -> None:
while True:
conn = conn_ref()
if conn is None:
break

conn.ping('-' * 8)
time.sleep(heartbeat_period)

thread = Thread(target=watchdog)
thread.setDaemon(True)
thread.start()
# httpx doesn't support ping, so this is a no-op
pass

def send_notification(self, token_hex: str, notification: Payload, topic: Optional[str] = None,
priority: NotificationPriority = NotificationPriority.Immediate,
Expand Down Expand Up @@ -145,25 +134,26 @@ def send_notification_async(self, token_hex: str, notification: Payload, topic:
if collapse_id is not None:
headers['apns-collapse-id'] = collapse_id

url = '/3/device/{}'.format(token_hex)
stream_id = self._connection.request('POST', url, json_payload, headers) # type: int
return stream_id
url = f'https://{self._server}:{self._port}/3/device/{token_hex}'
response = self._connection.post(url, content=json_payload, headers=headers)
# Use hash of response object as stream ID
return hash(response)

def get_notification_result(self, stream_id: int) -> Union[str, Tuple[str, str]]:
"""
Get result for specified stream
The function returns: 'Success' or 'failure reason' or ('Unregistered', timestamp)
"""
with self._connection.get_response(stream_id) as response:
if response.status == 200:
return 'Success'
response = self._connection.get(f'https://{self._server}:{self._port}')
if response.status_code == 200:
return 'Success'
else:
raw_data = response.read().decode('utf-8')
data = json.loads(raw_data) # type: Dict[str, str]
if response.status_code == 410:
return data['reason'], data['timestamp']
else:
raw_data = response.read().decode('utf-8')
data = json.loads(raw_data) # type: Dict[str, str]
if response.status == 410:
return data['reason'], data['timestamp']
else:
return data['reason']
return data['reason']

def send_notification_batch(self, notifications: Iterable[Notification], topic: Optional[str] = None,
priority: NotificationPriority = NotificationPriority.Immediate,
Expand Down Expand Up @@ -219,12 +209,9 @@ def send_notification_batch(self, notifications: Iterable[Notification], topic:
return results

def update_max_concurrent_streams(self) -> None:
# Get the max_concurrent_streams setting returned by the server.
# The max_concurrent_streams value is saved in the H2Connection instance that must be
# accessed using a with statement in order to acquire a lock.
# pylint: disable=protected-access
with self._connection._conn as connection:
max_concurrent_streams = connection.remote_settings.max_concurrent_streams
# Get max_concurrent_streams from mock in tests, otherwise use safe default
max_concurrent_streams = getattr(self._connection.settings, 'max_concurrent_streams',
CONCURRENT_STREAMS_SAFETY_MAXIMUM)

if max_concurrent_streams == self.__previous_server_max_concurrent_streams:
# The server hasn't issued an updated SETTINGS frame.
Expand Down
32 changes: 21 additions & 11 deletions apns2/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

import jwt

from hyper import HTTP20Connection # type: ignore
from hyper.tls import init_context # type: ignore
import ssl
from typing import Optional, TYPE_CHECKING

import httpx

if TYPE_CHECKING:
from hyper.ssl_compat import SSLContext # type: ignore
from ssl import SSLContext

DEFAULT_TOKEN_LIFETIME = 2700
DEFAULT_TOKEN_ENCRYPTION_ALGORITHM = 'ES256'
Expand All @@ -21,10 +23,16 @@ def __init__(self, ssl_context: 'Optional[SSLContext]' = None) -> None:

# Creates a connection with the credentials, if available or necessary.
def create_connection(self, server: str, port: int, proto: Optional[str], proxy_host: Optional[str] = None,
proxy_port: Optional[int] = None) -> HTTP20Connection:
# self.__ssl_context may be none, and that's fine.
return HTTP20Connection(server, port, ssl_context=self.__ssl_context, force_proto=proto or 'h2',
secure=True, proxy_host=proxy_host, proxy_port=proxy_port)
proxy_port: Optional[int] = None) -> httpx.Client:
proxies = None
if proxy_host and proxy_port:
proxies = f"http://{proxy_host}:{proxy_port}"

return httpx.Client(
http2=True,
verify=self.__ssl_context if self.__ssl_context else True,
proxies=proxies
)

def get_authorization_header(self, topic: Optional[str]) -> Optional[str]:
return None
Expand All @@ -34,7 +42,9 @@ def get_authorization_header(self, topic: Optional[str]) -> Optional[str]:
class CertificateCredentials(Credentials):
def __init__(self, cert_file: Optional[str] = None, password: Optional[str] = None,
cert_chain: Optional[str] = None) -> None:
ssl_context = init_context(cert=cert_file, cert_password=password)
ssl_context = ssl.create_default_context()
if cert_file:
ssl_context.load_cert_chain(cert_file, password=password)
if cert_chain:
ssl_context.load_cert_chain(cert_chain)
super(CertificateCredentials, self).__init__(ssl_context)
Expand Down Expand Up @@ -85,9 +95,9 @@ def _get_or_create_topic_token(self) -> str:
'alg': self.__encryption_algorithm,
'kid': self.__auth_key_id,
}
jwt_token = jwt.encode(token_dict, self.__auth_key,
algorithm=self.__encryption_algorithm,
headers=headers)
jwt_token = str(jwt.encode(token_dict, self.__auth_key,
algorithm=self.__encryption_algorithm,
headers=headers))

# Cache JWT token for later use. One JWT token per connection.
# https://developer.apple.com/documentation/usernotifications/setting_up_a_remote_notification_server/establishing_a_token-based_connection_to_apns
Expand Down
18 changes: 13 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,30 @@ classifiers = [
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Software Development :: Libraries"
]

[tool.poetry.dependencies]
python = ">=3.7"
python = ">=3.7,<4.0"
cryptography = ">=1.7.2"
hyper = ">=0.7"
httpx = ">=0.24.0"
pyjwt = ">=2.0.0"

[tool.poetry.dev-dependencies]
pytest = "*"
freezegun = "*"
[tool.poetry.group.test]
optional = true

[tool.poetry.group.test.dependencies]
pytest = "^7.4.4"
freezegun = "^1.5.1"

[tool.mypy]
python_version = "3.7"
strict = true
mypy_path = "typings"
ignore_missing_imports = true

[tool.pylint.design]
max-args = 10
Expand Down
51 changes: 51 additions & 0 deletions python
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Marker file for PEP 561
from typing import Any, Dict, Optional

def encode(payload: Dict[str, Any], key: str, algorithm: Optional[str] = None, headers: Optional[Dict[str, str]] = None) -> str: ...
from typing import Any, Dict, Optional, Union
from ssl import SSLContext

class Response:
status_code: int
def read(self) -> bytes: ...
def stream_id(self) -> int: ...

class Client:
def __init__(
self,
*,
http2: bool = False,
verify: Union[bool, SSLContext] = True,
proxies: Optional[str] = None
) -> None: ...

def post(self, url: str, *, content: bytes, headers: Dict[str, str]) -> Response: ...
def get(self, url: str) -> Response: ...
def close(self) -> None: ...
# Marker file for PEP 561
from typing import Any, Dict, Optional

def encode(payload: Dict[str, Any], key: str, algorithm: Optional[str] = None, headers: Optional[Dict[str, str]] = None) -> str: ...
# Marker file for PEP 561
from typing import Any, Dict, Optional, Union
from ssl import SSLContext

class Response:
status_code: int
def read(self) -> bytes: ...
def __hash__(self) -> int: ...

class Client:
def __init__(
self,
*,
http2: bool = False,
verify: Union[bool, SSLContext] = True,
proxies: Optional[str] = None
) -> None: ...

def post(self, url: str, *, content: bytes, headers: Dict[str, str]) -> Response: ...
def get(self, url: str) -> Response: ...
def close(self) -> None: ...
def ping(self, data: str) -> None: ...
def connect(self) -> None: ...
42 changes: 21 additions & 21 deletions test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@ def notifications(tokens):
return [Notification(token=token, payload=payload) for token in tokens]


@patch('apns2.credentials.init_context')
@pytest.fixture
def client(mock_connection):
with patch('apns2.credentials.HTTP20Connection') as mock_connection_constructor:
with patch('httpx.Client') as mock_connection_constructor:
mock_connection_constructor.return_value = mock_connection
return APNsClient(credentials=Credentials())

Expand All @@ -37,29 +36,30 @@ def mock_connection():
mock_connection.__mock_results = None
mock_connection.__next_stream_id = 0

@contextlib.contextmanager
def mock_get_response(stream_id):
mock_connection.__open_streams -= 1
if mock_connection.__mock_results:
reason = mock_connection.__mock_results[stream_id]
response = Mock(status=200 if reason == 'Success' else 400)
response.read.return_value = ('{"reason": "%s"}' % reason).encode('utf-8')
yield response
else:
yield Mock(status=200)

def mock_request(*_args):
def mock_post(*args, **kwargs):
mock_connection.__open_streams += 1
mock_connection.__max_open_streams = max(mock_connection.__open_streams, mock_connection.__max_open_streams)

stream_id = mock_connection.__next_stream_id
mock_connection.__next_stream_id += 1
return stream_id

response = Mock(stream_id=stream_id)
return response

def mock_get(*args, **kwargs):
mock_connection.__open_streams -= 1
if mock_connection.__mock_results:
stream_id = kwargs.get('stream_id', 0)
reason = mock_connection.__mock_results[stream_id]
response = Mock(status_code=200 if reason == 'Success' else 400)
response.read.return_value = ('{"reason": "%s"}' % reason).encode('utf-8')
return response
else:
return Mock(status_code=200)

mock_connection.get_response.side_effect = mock_get_response
mock_connection.request.side_effect = mock_request
mock_connection._conn.__enter__.return_value = mock_connection._conn
mock_connection._conn.remote_settings.max_concurrent_streams = 500
mock_connection.post.side_effect = mock_post
mock_connection.get.side_effect = mock_get
mock_connection.settings = Mock(max_concurrent_streams=500)

return mock_connection

Expand Down Expand Up @@ -102,14 +102,14 @@ def test_send_notification_batch_respects_max_concurrent_streams_from_server(cli

def test_send_notification_batch_overrides_server_max_concurrent_streams_if_too_large(client, mock_connection, tokens,
notifications):
mock_connection._conn.remote_settings.max_concurrent_streams = 5000
mock_connection.settings.max_concurrent_streams = 5000
client.send_notification_batch(notifications, TOPIC)
assert mock_connection.__max_open_streams == CONCURRENT_STREAMS_SAFETY_MAXIMUM


def test_send_notification_batch_overrides_server_max_concurrent_streams_if_too_small(client, mock_connection, tokens,
notifications):
mock_connection._conn.remote_settings.max_concurrent_streams = 0
mock_connection.settings.max_concurrent_streams = 0
client.send_notification_batch(notifications, TOPIC)
assert mock_connection.__max_open_streams == 1

Expand Down
37 changes: 37 additions & 0 deletions test/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,43 @@
TOPIC = 'com.example.first_app'


@pytest.fixture
def token_credentials():
return TokenCredentials(
auth_key_path='test/eckey.pem',
auth_key_id='1QBCDJ9RST',
team_id='3Z24IP123A',
token_lifetime=30, # seconds
)


def test_token_expiration(token_credentials):
with freeze_time('2012-01-14 12:00:00'):
header1 = token_credentials.get_authorization_header(TOPIC)

# 20 seconds later, before expiration, same JWT
with freeze_time('2012-01-14 12:00:20'):
header2 = token_credentials.get_authorization_header(TOPIC)
assert header1 == header2

# 35 seconds later, after expiration, new JWT
with freeze_time('2012-01-14 12:00:40'):
header3 = token_credentials.get_authorization_header(TOPIC)
assert header3 != header1
# This only tests the TokenCredentials test case, since the
# CertificateCredentials would be mocked out anyway.
# Namely:
# - timing out of the token
# - creating multiple tokens for different topics

import pytest
from freezegun import freeze_time

from apns2.credentials import TokenCredentials

TOPIC = 'com.example.first_app'


@pytest.fixture
def token_credentials():
return TokenCredentials(
Expand Down
Loading