diff --git a/.gitignore b/.gitignore index 49e14ccc8d..0b199c7e56 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ venv*/ .python-version build/ dist/ +.vscode/ diff --git a/docs/advanced.md b/docs/advanced.md index 1b0ecee7c5..d1ca5bd38d 100644 --- a/docs/advanced.md +++ b/docs/advanced.md @@ -1057,3 +1057,39 @@ Which we can use in the same way: >>> response.json() {"text": "Hello, world!"} ``` + + +## Custom JSON library support + +By default `httpx` uses built-in python library `json` to decode and encode +JSON data. +It is possible to use some other, higher performance library or add some custom +encoder options. + +You can make `dumps` return pretty-printed JSON by default. +```python +import httpx +import json + +def _my_dumps(obj, **kwargs): + return json.dumps(obj, + ensure_ascii=kwargs.pop('ensure_ascii', False), + indent=kwargs.pop('indent', 4), + sort_keys=kwargs.pop('sort_keys', True), + **kwargs + ) + +with httpx.Client(json_encoder=_my_dumps) as client: + response = client.post('http://httpbin.org/anything', json={'Hello': 'World!', '🙂': '👋'}) + print(response.json()) +``` + +Or use another library for JSON support. +```python +import httpx +import orjson + +with httpx.Client(json_encoder=orjson.dumps, json_decoder=orjson.loads) as client: + response = client.post('http://httpbin.org/anything', json={'Hello': 'World!', '🙂': '👋'}) + print(response.json()) +``` diff --git a/httpx/_client.py b/httpx/_client.py index d15c004530..6eb9909f2d 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -1,5 +1,6 @@ import datetime -import enum +import functools +import json import typing import warnings from types import TracebackType @@ -90,6 +91,8 @@ def __init__( event_hooks: typing.Dict[str, typing.List[typing.Callable]] = None, base_url: URLTypes = "", trust_env: bool = True, + json_encoder: typing.Callable = json.dumps, + json_decoder: typing.Callable = json.loads, ): event_hooks = {} if event_hooks is None else event_hooks @@ -108,6 +111,8 @@ def __init__( self._trust_env = trust_env self._netrc = NetRCInfo() self._state = ClientState.UNOPENED + self._json_encoder = json_encoder + self._json_decoder = json_decoder @property def is_closed(self) -> bool: @@ -317,6 +322,7 @@ def build_request( params=params, headers=headers, cookies=cookies, + json_encoder=self._json_encoder, ) def _merge_url(self, url: URLTypes) -> URL: @@ -411,7 +417,12 @@ def _build_redirect_request(self, request: Request, response: Response) -> Reque stream = self._redirect_stream(request, method) cookies = Cookies(self.cookies) return Request( - method=method, url=url, headers=headers, cookies=cookies, stream=stream + method=method, + url=url, + headers=headers, + cookies=cookies, + stream=stream, + json_encoder=self._json_encoder, ) def _redirect_method(self, request: Request, response: Response) -> str: @@ -548,6 +559,8 @@ class Client(BaseClient): rather than sending actual network requests. * **trust_env** - *(optional)* Enables or disables usage of environment variables for configuration. + * **json_encoder** - *(optional)* A custom function to encode request data to JSON. + * **json_decoder** - *(optional)* A custom function to decode response data from JSON. """ def __init__( @@ -570,6 +583,8 @@ def __init__( transport: httpcore.SyncHTTPTransport = None, app: typing.Callable = None, trust_env: bool = True, + json_encoder: typing.Callable = json.dumps, + json_decoder: typing.Callable = json.loads, ): super().__init__( auth=auth, @@ -581,6 +596,8 @@ def __init__( event_hooks=event_hooks, base_url=base_url, trust_env=trust_env, + json_encoder=json_encoder, + json_decoder=json_decoder, ) if http2: @@ -878,6 +895,8 @@ def on_close(response: Response) -> None: ext=ext, request=request, on_close=on_close, + json_encoder=self._json_encoder, + json_decoder=self._json_decoder, ) self.cookies.extract_cookies(response) @@ -1185,6 +1204,8 @@ class AsyncClient(BaseClient): rather than sending actual network requests. * **trust_env** - *(optional)* Enables or disables usage of environment variables for configuration. + * **json_encoder** - *(optional)* A custom function to encode request data to JSON. + * **json_decoder** - *(optional)* A custom function to decode response data from JSON. """ def __init__( @@ -1207,6 +1228,8 @@ def __init__( transport: httpcore.AsyncHTTPTransport = None, app: typing.Callable = None, trust_env: bool = True, + json_encoder: typing.Callable = json.dumps, + json_decoder: typing.Callable = json.loads, ): super().__init__( auth=auth, @@ -1218,6 +1241,8 @@ def __init__( event_hooks=event_hooks, base_url=base_url, trust_env=trust_env, + json_encoder=json_encoder, + json_decoder=json_decoder, ) if http2: @@ -1519,6 +1544,8 @@ async def on_close(response: Response) -> None: ext=ext, request=request, on_close=on_close, + json_encoder=self._json_encoder, + json_decoder=self._json_decoder, ) self.cookies.extract_cookies(response) diff --git a/httpx/_content.py b/httpx/_content.py index bf402c9e29..8bbd9f8ecb 100644 --- a/httpx/_content.py +++ b/httpx/_content.py @@ -1,5 +1,6 @@ import inspect -from json import dumps as json_dumps +import json +import typing from typing import ( Any, AsyncIterable, @@ -137,8 +138,12 @@ def encode_html(html: str) -> Tuple[Dict[str, str], ByteStream]: return headers, PlainByteStream(body) -def encode_json(json: Any) -> Tuple[Dict[str, str], ByteStream]: - body = json_dumps(json).encode("utf-8") +def encode_json( + json: Any, encoder: typing.Callable = json.dumps +) -> Tuple[Dict[str, str], ByteStream]: + body: Union[bytes, str] = encoder(json) + if isinstance(body, str): + body = body.encode("utf-8") content_length = str(len(body)) content_type = "application/json" headers = {"Content-Length": content_length, "Content-Type": content_type} @@ -151,6 +156,7 @@ def encode_request( files: RequestFiles = None, json: Any = None, boundary: bytes = None, + json_encoder: typing.Callable = json.dumps, ) -> Tuple[Dict[str, str], ByteStream]: """ Handles encoding the given `content`, `data`, `files`, and `json`, @@ -173,7 +179,7 @@ def encode_request( elif data: return encode_urlencoded_data(data) elif json is not None: - return encode_json(json) + return encode_json(json, encoder=json_encoder) return {}, PlainByteStream(b"") @@ -183,6 +189,7 @@ def encode_response( text: str = None, html: str = None, json: Any = None, + json_encoder: typing.Callable = json.dumps, ) -> Tuple[Dict[str, str], ByteStream]: """ Handles encoding the given `content`, returning a two-tuple of @@ -195,6 +202,6 @@ def encode_response( elif html is not None: return encode_html(html) elif json is not None: - return encode_json(json) + return encode_json(json, encoder=json_encoder) return {}, PlainByteStream(b"") diff --git a/httpx/_models.py b/httpx/_models.py index c981c740bf..56d3b760ed 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -2,7 +2,7 @@ import contextlib import datetime import email.message -import json as jsonlib +import json import typing import urllib.request from collections.abc import MutableMapping @@ -793,6 +793,7 @@ def __init__( files: RequestFiles = None, json: typing.Any = None, stream: ByteStream = None, + json_encoder: typing.Callable = json.dumps, ): if isinstance(method, bytes): self.method = method.decode("ascii").upper() @@ -820,7 +821,9 @@ def __init__( self.stream = stream self._prepare({}) else: - headers, stream = encode_request(content, data, files, json) + headers, stream = encode_request( + content, data, files, json, json_encoder=json_encoder + ) self._prepare(headers) self.stream = stream @@ -903,6 +906,8 @@ def __init__( ext: dict = None, history: typing.List["Response"] = None, on_close: typing.Callable = None, + json_encoder: typing.Callable = json.dumps, + json_decoder: typing.Callable = json.loads, ): self.status_code = status_code self.headers = Headers(headers) @@ -936,7 +941,9 @@ def __init__( # from the transport API. self.stream = stream else: - headers, stream = encode_response(content, text, html, json) + headers, stream = encode_response( + content, text, html, json, json_encoder=json_encoder + ) self._prepare(headers) self.stream = stream if content is None or isinstance(content, (bytes, str)): @@ -944,6 +951,7 @@ def __init__( self.read() self._num_bytes_downloaded = 0 + self._json_decoder = json_decoder def _prepare(self, default_headers: typing.Dict[str, str]) -> None: for key, value in default_headers.items(): @@ -1106,14 +1114,17 @@ def raise_for_status(self) -> None: raise HTTPStatusError(message, request=request, response=self) def json(self, **kwargs: typing.Any) -> typing.Any: + """ + For available `kwargs` see `json.loads` definition. + """ if self.charset_encoding is None and self.content and len(self.content) > 3: encoding = guess_json_utf(self.content) if encoding is not None: try: - return jsonlib.loads(self.content.decode(encoding), **kwargs) + return self._json_decoder(self.content.decode(encoding), **kwargs) except UnicodeDecodeError: pass - return jsonlib.loads(self.text, **kwargs) + return self._json_decoder(self.text, **kwargs) @property def cookies(self) -> "Cookies": diff --git a/tests/client/test_async_client.py b/tests/client/test_async_client.py index 44ff90fe51..6eb4117c9b 100644 --- a/tests/client/test_async_client.py +++ b/tests/client/test_async_client.py @@ -1,4 +1,6 @@ +import json from datetime import timedelta +from decimal import Decimal import httpcore import pytest @@ -242,3 +244,20 @@ async def test_deleting_unclosed_async_client_causes_warning(): await client.get("http://example.com") with pytest.warns(UserWarning): del client + + +@pytest.mark.usefixtures("async_environment") +async def test_post_json_overriden_decoder(server): + def _my_loads(s, **kwargs): + return json.loads(s, parse_float=lambda v: Decimal(v), **kwargs) + + url = server.url + async with httpx.AsyncClient(json_decoder=_my_loads) as client: + response = await client.post( + url.copy_with(path="/echo_body"), + json={"text": "Hello, world!", "decimal": 0.12345}, + ) + assert response.status_code == 200 + data = response.json() + + assert isinstance(data["decimal"], Decimal) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index a41f4232fb..1f2b75a5ae 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -1,4 +1,6 @@ +import json from datetime import timedelta +from decimal import Decimal import httpcore import pytest @@ -300,3 +302,19 @@ def test_raw_client_header(): ["User-Agent", f"python-httpx/{httpx.__version__}"], ["Example-Header", "example-value"], ] + + +def test_post_json_overriden_decoder(server): + def _my_loads(s, **kwargs): + return json.loads(s, parse_float=lambda v: Decimal(v), **kwargs) + + url = server.url + with httpx.Client(json_decoder=_my_loads) as client: + response = client.post( + url.copy_with(path="/echo_body"), + json={"text": "Hello, world!", "decimal": 0.12345}, + ) + assert response.status_code == 200 + data = response.json() + + assert isinstance(data["decimal"], Decimal) diff --git a/tests/test_content.py b/tests/test_content.py index 384f9f2287..3c335435f3 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -1,4 +1,5 @@ import io +import json import typing import pytest @@ -357,3 +358,34 @@ async def hello_world(): def test_response_invalid_argument(): with pytest.raises(TypeError): encode_response(123) # type: ignore + + +@pytest.mark.asyncio +async def test_json_content_overriden_encoder(): + def _my_dumps(obj, **kwargs): + return json.dumps(obj, ensure_ascii=False, sort_keys=True, **kwargs).encode( + "utf-8" + ) + + data = { + "こんにちは": "世界", + "ওহে": "বিশ্ব!", + "Привет": "мир!", + "Hello": "world!", + } + headers, stream = encode_request(json=data, json_encoder=_my_dumps) + + assert isinstance(stream, typing.Iterable) + assert isinstance(stream, typing.AsyncIterable) + + sync_content = b"".join([part for part in stream]) + async_content = b"".join([part async for part in stream]) + + result = _my_dumps(data) + + assert headers == { + "Content-Length": str(len(result)), + "Content-Type": "application/json", + } + assert sync_content == result + assert async_content == result