Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion aioresponses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
from .core import CallbackResult, aioresponses

__version__ = '0.7.9'
Expand Down
33 changes: 5 additions & 28 deletions aioresponses/compat.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
# -*- coding: utf-8 -*-
import asyncio # noqa: F401
import sys
from typing import Dict, Optional, Union # noqa
import asyncio
from re import Pattern
from typing import Optional, Union
from urllib.parse import parse_qsl, urlencode

from aiohttp import __version__ as aiohttp_version, StreamReader
from aiohttp import RequestInfo, StreamReader
from aiohttp.client_proto import ResponseHandler
from multidict import MultiDict
from packaging.version import Version
from yarl import URL

if sys.version_info < (3, 7):
from re import _pattern_type as Pattern
else:
from re import Pattern

AIOHTTP_VERSION = Version(aiohttp_version)


def stream_reader_factory( # noqa
Expand All @@ -27,7 +19,7 @@ def stream_reader_factory( # noqa

def merge_params(
url: 'Union[URL, str]',
params: Optional[Dict] = None
params: dict | None = None
) -> 'URL':
url = URL(url)
if params:
Expand All @@ -43,25 +35,10 @@ def normalize_url(url: 'Union[URL, str]') -> 'URL':
return url.with_query(urlencode(sorted(parse_qsl(url.query_string))))


try:
from aiohttp import RequestInfo
except ImportError:
class RequestInfo(object):
__slots__ = ('url', 'method', 'headers', 'real_url')

def __init__(
self, url: URL, method: str, headers: Dict, real_url: str
):
self.url = url
self.method = method
self.headers = headers
self.real_url = real_url

__all__ = [
'URL',
'Pattern',
'RequestInfo',
'AIOHTTP_VERSION',
'merge_params',
'stream_reader_factory',
'normalize_url',
Expand Down
98 changes: 46 additions & 52 deletions aioresponses/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
import asyncio
import copy
import inspect
Expand All @@ -7,16 +6,15 @@
from functools import wraps
from typing import (
Any,
Callable,
cast,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from collections.abc import Callable
from unittest.mock import Mock, patch
from uuid import uuid4

Expand All @@ -29,15 +27,14 @@
)
from aiohttp.helpers import TimerNoop
from multidict import CIMultiDict, CIMultiDictProxy
from packaging.version import Version

from .compat import (
URL,
Pattern,
stream_reader_factory,
merge_params,
normalize_url,
RequestInfo, AIOHTTP_VERSION,
RequestInfo,
)

_FuncT = TypeVar("_FuncT", bound=Callable[..., Any])
Expand All @@ -47,12 +44,12 @@ class CallbackResult:

def __init__(self, method: str = hdrs.METH_GET,
status: int = 200,
body: Union[str, bytes] = '',
body: str | bytes = '',
content_type: str = 'application/json',
payload: Optional[Dict] = None,
headers: Optional[Dict] = None,
response_class: Optional[Type[ClientResponse]] = None,
reason: Optional[str] = None):
payload: dict | None = None,
headers: dict | None = None,
response_class: type[ClientResponse] | None = None,
reason: str | None = None):
self.method = method
self.status = status
self.body = body
Expand All @@ -63,22 +60,22 @@ def __init__(self, method: str = hdrs.METH_GET,
self.reason = reason


class RequestMatch(object):
class RequestMatch:
url_or_pattern = None # type: Union[URL, Pattern]

def __init__(self, url: Union[URL, str, Pattern],
def __init__(self, url: URL | str | Pattern,
method: str = hdrs.METH_GET,
status: int = 200,
body: Union[str, bytes] = '',
payload: Optional[Dict] = None,
exception: Optional[Exception] = None,
headers: Optional[Dict] = None,
body: str | bytes = '',
payload: dict | None = None,
exception: Exception | None = None,
headers: dict | None = None,
content_type: str = 'application/json',
response_class: Optional[Type[ClientResponse]] = None,
response_class: type[ClientResponse] | None = None,
timeout: bool = False,
repeat: Union[bool, int] = False,
reason: Optional[str] = None,
callback: Optional[Callable] = None):
repeat: bool | int = False,
reason: str | None = None,
callback: Callable | None = None):
if isinstance(url, Pattern):
self.url_or_pattern = url
self.match_func = self.match_regexp
Expand Down Expand Up @@ -118,7 +115,7 @@ def match(self, method: str, url: URL) -> bool:
return False
return self.match_func(url)

def _build_raw_headers(self, headers: Dict) -> Tuple:
def _build_raw_headers(self, headers: dict) -> tuple:
"""
Convert a dict of headers to a tuple of tuples

Expand All @@ -131,14 +128,14 @@ def _build_raw_headers(self, headers: Dict) -> Tuple:

def _build_response(self, url: 'Union[URL, str]',
method: str = hdrs.METH_GET,
request_headers: Optional[Dict] = None,
request_headers: dict | None = None,
status: int = 200,
body: Union[str, bytes] = '',
body: str | bytes = '',
content_type: str = 'application/json',
payload: Optional[Dict] = None,
headers: Optional[Dict] = None,
response_class: Optional[Type[ClientResponse]] = None,
reason: Optional[str] = None) -> ClientResponse:
payload: dict | None = None,
headers: dict | None = None,
response_class: type[ClientResponse] | None = None,
reason: str | None = None) -> ClientResponse:
if response_class is None:
response_class = ClientResponse
if payload is not None:
Expand Down Expand Up @@ -220,10 +217,10 @@ def __repr__(self) -> str:
RequestCall = namedtuple('RequestCall', ['args', 'kwargs'])


class aioresponses(object):
class aioresponses:
"""Mock aiohttp requests made by ClientSession."""
_matches = None # type: Dict[str, RequestMatch]
_responses: List[ClientResponse] = None
_responses: list[ClientResponse] = None
requests = None # type: Dict

def __init__(self, **kwargs: Any):
Expand All @@ -243,7 +240,7 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.stop()

def __call__(self, f: _FuncT) -> _FuncT:
def _pack_arguments(ctx, *args, **kwargs) -> Tuple[Tuple, Dict]:
def _pack_arguments(ctx, *args, **kwargs) -> tuple[tuple, dict]:
if self._param:
kwargs[self._param] = ctx
else:
Expand Down Expand Up @@ -303,16 +300,16 @@ def options(self, url: 'Union[URL, str, Pattern]', **kwargs: Any) -> None:

def add(self, url: 'Union[URL, str, Pattern]', method: str = hdrs.METH_GET,
status: int = 200,
body: Union[str, bytes] = '',
exception: Optional[Exception] = None,
body: str | bytes = '',
exception: Exception | None = None,
content_type: str = 'application/json',
payload: Optional[Dict] = None,
headers: Optional[Dict] = None,
response_class: Optional[Type[ClientResponse]] = None,
repeat: Union[bool, int] = False,
payload: dict | None = None,
headers: dict | None = None,
response_class: type[ClientResponse] | None = None,
repeat: bool | int = False,
timeout: bool = False,
reason: Optional[str] = None,
callback: Optional[Callable] = None) -> None:
reason: str | None = None,
callback: Callable | None = None) -> None:

self._matches[str(uuid4())] = (RequestMatch(
url,
Expand All @@ -335,7 +332,7 @@ def _format_call_signature(self, *args, **kwargs) -> str:
formatted_args = ''
args_string = ', '.join([repr(arg) for arg in args])
kwargs_string = ', '.join([
'%s=%r' % (key, value) for key, value in kwargs.items()
f'{key}={value!r}' for key, value in kwargs.items()
])
if args_string:
formatted_args = args_string
Expand Down Expand Up @@ -405,7 +402,7 @@ def assert_called_with(self, url: 'Union[URL, str, Pattern]',
actual
)
raise AssertionError(
'%s != %s' % (expected_string, actual_string)
f'{expected_string} != {actual_string}'
)

def assert_any_call(self, url: 'Union[URL, str, Pattern]',
Expand Down Expand Up @@ -438,7 +435,7 @@ def assert_called_once_with(self, *args: Any, **kwargs: Any):
self.assert_called_with(*args, **kwargs)

@staticmethod
def is_exception(resp_or_exc: Union[ClientResponse, Exception]) -> bool:
def is_exception(resp_or_exc: ClientResponse | Exception) -> bool:
if inspect.isclass(resp_or_exc):
parent_classes = set(inspect.getmro(resp_or_exc))
if {Exception, BaseException} & parent_classes:
Expand Down Expand Up @@ -499,21 +496,18 @@ async def match(

async def _request_mock(self, orig_self: ClientSession,
method: str, url: 'Union[URL, str]',
*args: Tuple,
*args: tuple,
**kwargs: Any) -> 'ClientResponse':
"""Return mocked response object or raise connection error."""
if orig_self.closed:
raise RuntimeError('Session is closed')

if AIOHTTP_VERSION >= Version('3.8.0'):
# Join url with ClientSession._base_url
url = orig_self._build_url(url)
url_origin = str(url)
# Combine ClientSession headers with passed headers
if orig_self.headers:
kwargs["headers"] = orig_self._prepare_headers(kwargs.get("headers"))
else:
url_origin = url
# Join url with ClientSession._base_url
url = orig_self._build_url(url)
url_origin = str(url)
# Combine ClientSession headers with passed headers
if orig_self.headers:
kwargs["headers"] = orig_self._prepare_headers(kwargs.get("headers"))

url = normalize_url(merge_params(url, kwargs.get('params')))
url_str = str(url)
Expand All @@ -536,7 +530,7 @@ async def _request_mock(self, orig_self: ClientSession,
orig_self, method, url_origin, *args, **kwargs
))
raise ClientConnectionError(
'Connection refused: {} {}'.format(method, url)
f'Connection refused: {method} {url}'
)
self._responses.append(response)

Expand Down
17 changes: 8 additions & 9 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
-r requirements.txt
pip
wheel
flake8==5.0.4
tox==3.19.0
coverage==5.2.1
Sphinx==1.5.6
pytest==7.1.3
pytest-cov==2.10.1
pytest-html==2.1.1
tox==4.46.3
coverage==7.13.4
Sphinx==7.1.2
pytest==8.4.0
pytest-cov==6.0.0
pytest-html==4.0.0
ddt==1.4.1
typing
asynctest==0.13.0
yarl==1.9.4
yarl==1.22.0
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
packaging>=22.0
aiohttp>=3.3.0,<4.0.0
aiohttp>=3.8.0,<4.0.0
6 changes: 3 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ classifier =
License :: OSI Approved :: MIT License
Natural Language :: English
Programming Language :: Python :: 3
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3.12
Programming Language :: Python :: 3.13
Programming Language :: Python :: 3.14

[files]
packages =
Expand Down
Loading