diff --git a/README.md b/README.md index 017e9d951..487e1365e 100644 --- a/README.md +++ b/README.md @@ -192,6 +192,13 @@ content_type_overrides: application/zip: application/octet-stream ``` +### multiple_media_types + +OpenAPI documents may have more than one media type for a response. By default, `openapi-python-client` only generates a response body parser for the first one it encounters. +This config tells the generator to check the `Content-Type` header of the response and parse the response accordingly. + +For example, this might be useful if an OpenAPI document models a service that returns 503 with a JSON error description when a downstream service fails, but is behind a load balancer that returns 503 with plain text when overloaded. + ## Supported Extensions ### x-enum-varnames diff --git a/end_to_end_tests/multiple-media-types-golden-record/.gitignore b/end_to_end_tests/multiple-media-types-golden-record/.gitignore new file mode 100644 index 000000000..79a2c3d73 --- /dev/null +++ b/end_to_end_tests/multiple-media-types-golden-record/.gitignore @@ -0,0 +1,23 @@ +__pycache__/ +build/ +dist/ +*.egg-info/ +.pytest_cache/ + +# pyenv +.python-version + +# Environments +.env +.venv + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# JetBrains +.idea/ + +/coverage.xml +/.coverage diff --git a/end_to_end_tests/multiple-media-types-golden-record/README.md b/end_to_end_tests/multiple-media-types-golden-record/README.md new file mode 100644 index 000000000..bbe216a9c --- /dev/null +++ b/end_to_end_tests/multiple-media-types-golden-record/README.md @@ -0,0 +1,124 @@ +# multiple-media-types-client +A client library for accessing Multiple media types + +## Usage +First, create a client: + +```python +from multiple_media_types_client import Client + +client = Client(base_url="https://api.example.com") +``` + +If the endpoints you're going to hit require authentication, use `AuthenticatedClient` instead: + +```python +from multiple_media_types_client import AuthenticatedClient + +client = AuthenticatedClient(base_url="https://api.example.com", token="SuperSecretToken") +``` + +Now call your endpoint and use your models: + +```python +from multiple_media_types_client.models import MyDataModel +from multiple_media_types_client.api.my_tag import get_my_data_model +from multiple_media_types_client.types import Response + +with client as client: + my_data: MyDataModel = get_my_data_model.sync(client=client) + # or if you need more info (e.g. status_code) + response: Response[MyDataModel] = get_my_data_model.sync_detailed(client=client) +``` + +Or do the same thing with an async version: + +```python +from multiple_media_types_client.models import MyDataModel +from multiple_media_types_client.api.my_tag import get_my_data_model +from multiple_media_types_client.types import Response + +async with client as client: + my_data: MyDataModel = await get_my_data_model.asyncio(client=client) + response: Response[MyDataModel] = await get_my_data_model.asyncio_detailed(client=client) +``` + +By default, when you're calling an HTTPS API it will attempt to verify that SSL is working correctly. Using certificate verification is highly recommended most of the time, but sometimes you may need to authenticate to a server (especially an internal server) using a custom certificate bundle. + +```python +client = AuthenticatedClient( + base_url="https://internal_api.example.com", + token="SuperSecretToken", + verify_ssl="/path/to/certificate_bundle.pem", +) +``` + +You can also disable certificate validation altogether, but beware that **this is a security risk**. + +```python +client = AuthenticatedClient( + base_url="https://internal_api.example.com", + token="SuperSecretToken", + verify_ssl=False +) +``` + +Things to know: +1. Every path/method combo becomes a Python module with four functions: + 1. `sync`: Blocking request that returns parsed data (if successful) or `None` + 1. `sync_detailed`: Blocking request that always returns a `Request`, optionally with `parsed` set if the request was successful. + 1. `asyncio`: Like `sync` but async instead of blocking + 1. `asyncio_detailed`: Like `sync_detailed` but async instead of blocking + +1. All path/query params, and bodies become method arguments. +1. If your endpoint had any tags on it, the first tag will be used as a module name for the function (my_tag above) +1. Any endpoint which did not have a tag will be in `multiple_media_types_client.api.default` + +## Advanced customizations + +There are more settings on the generated `Client` class which let you control more runtime behavior, check out the docstring on that class for more info. You can also customize the underlying `httpx.Client` or `httpx.AsyncClient` (depending on your use-case): + +```python +from multiple_media_types_client import Client + +def log_request(request): + print(f"Request event hook: {request.method} {request.url} - Waiting for response") + +def log_response(response): + request = response.request + print(f"Response event hook: {request.method} {request.url} - Status {response.status_code}") + +client = Client( + base_url="https://api.example.com", + httpx_args={"event_hooks": {"request": [log_request], "response": [log_response]}}, +) + +# Or get the underlying httpx client to modify directly with client.get_httpx_client() or client.get_async_httpx_client() +``` + +You can even set the httpx client directly, but beware that this will override any existing settings (e.g., base_url): + +```python +import httpx +from multiple_media_types_client import Client + +client = Client( + base_url="https://api.example.com", +) +# Note that base_url needs to be re-set, as would any shared cookies, headers, etc. +client.set_httpx_client(httpx.Client(base_url="https://api.example.com", proxies="http://localhost:8030")) +``` + +## Building / publishing this package +This project uses [Poetry](https://python-poetry.org/) to manage dependencies and packaging. Here are the basics: +1. Update the metadata in pyproject.toml (e.g. authors, version) +1. If you're using a private repository, configure it with Poetry + 1. `poetry config repositories. ` + 1. `poetry config http-basic. ` +1. Publish the client with `poetry publish --build -r ` or, if for public PyPI, just `poetry publish --build` + +If you want to install this client into another project without publishing it (e.g. for development) then: +1. If that project **is using Poetry**, you can simply do `poetry add ` from that project +1. If that project is not using Poetry: + 1. Build a wheel with `poetry build -f wheel` + 1. Install that wheel from the other project `pip install ` diff --git a/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/__init__.py b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/__init__.py new file mode 100644 index 000000000..8342ab82c --- /dev/null +++ b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/__init__.py @@ -0,0 +1,8 @@ +"""A client library for accessing Multiple media types""" + +from .client import AuthenticatedClient, Client + +__all__ = ( + "AuthenticatedClient", + "Client", +) diff --git a/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/api/__init__.py b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/api/__init__.py new file mode 100644 index 000000000..81f9fa241 --- /dev/null +++ b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/api/__init__.py @@ -0,0 +1 @@ +"""Contains methods for accessing the API""" diff --git a/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/api/multiple_media_types/__init__.py b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/api/multiple_media_types/__init__.py new file mode 100644 index 000000000..2d7c0b23d --- /dev/null +++ b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/api/multiple_media_types/__init__.py @@ -0,0 +1 @@ +"""Contains endpoint functions for accessing the API""" diff --git a/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/api/multiple_media_types/post.py b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/api/multiple_media_types/post.py new file mode 100644 index 000000000..29e1121c5 --- /dev/null +++ b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/api/multiple_media_types/post.py @@ -0,0 +1,144 @@ +from http import HTTPStatus +from typing import Any, Literal, Optional, Union, cast + +import httpx + +from ... import errors +from ...client import AuthenticatedClient, Client +from ...models.error_response import ErrorResponse +from ...types import Response + + +def _get_kwargs() -> dict[str, Any]: + _kwargs: dict[str, Any] = { + "method": "post", + "url": "/", + } + + return _kwargs + + +def _parse_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Optional[Union[Any, ErrorResponse, Literal["Why have a fixed response? I dunno"]]]: + if response.status_code == 200: + response_200: Union[Literal["Why have a fixed response? I dunno"], Any] + if response.headers.get("content-type") == "application/json": + response_200 = cast(Literal["Why have a fixed response? I dunno"], response.json()) + if response_200 != "Why have a fixed response? I dunno": + raise ValueError( + f"response_200 must match const 'Why have a fixed response? I dunno', got '{response_200}'" + ) + return response_200 + if response.headers.get("content-type") == "application/octet-stream": + response_200 = cast(Any, response.content) + return response_200 + if response.status_code == 404: + response_404: Any + if response.headers.get("content-type") == "text/plain": + response_404 = cast(Any, response.text) + return response_404 + if response.status_code == 503: + response_503: Union[ErrorResponse, Any] + if response.headers.get("content-type") == "application/json": + response_503 = ErrorResponse.from_dict(response.json()) + + return response_503 + if response.headers.get("content-type") == "text/plain": + response_503 = cast(Any, response.text) + return response_503 + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Response[Union[Any, ErrorResponse, Literal["Why have a fixed response? I dunno"]]]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, response=response), + ) + + +def sync_detailed( + *, + client: Union[AuthenticatedClient, Client], +) -> Response[Union[Any, ErrorResponse, Literal["Why have a fixed response? I dunno"]]]: + """ + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Union[Any, ErrorResponse, Literal['Why have a fixed response? I dunno']]] + """ + + kwargs = _get_kwargs() + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +def sync( + *, + client: Union[AuthenticatedClient, Client], +) -> Optional[Union[Any, ErrorResponse, Literal["Why have a fixed response? I dunno"]]]: + """ + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Union[Any, ErrorResponse, Literal['Why have a fixed response? I dunno']] + """ + + return sync_detailed( + client=client, + ).parsed + + +async def asyncio_detailed( + *, + client: Union[AuthenticatedClient, Client], +) -> Response[Union[Any, ErrorResponse, Literal["Why have a fixed response? I dunno"]]]: + """ + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Union[Any, ErrorResponse, Literal['Why have a fixed response? I dunno']]] + """ + + kwargs = _get_kwargs() + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) + + +async def asyncio( + *, + client: Union[AuthenticatedClient, Client], +) -> Optional[Union[Any, ErrorResponse, Literal["Why have a fixed response? I dunno"]]]: + """ + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Union[Any, ErrorResponse, Literal['Why have a fixed response? I dunno']] + """ + + return ( + await asyncio_detailed( + client=client, + ) + ).parsed diff --git a/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/client.py b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/client.py new file mode 100644 index 000000000..e80446f10 --- /dev/null +++ b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/client.py @@ -0,0 +1,268 @@ +import ssl +from typing import Any, Optional, Union + +import httpx +from attrs import define, evolve, field + + +@define +class Client: + """A class for keeping track of data related to the API + + The following are accepted as keyword arguments and will be used to construct httpx Clients internally: + + ``base_url``: The base URL for the API, all requests are made to a relative path to this URL + + ``cookies``: A dictionary of cookies to be sent with every request + + ``headers``: A dictionary of headers to be sent with every request + + ``timeout``: The maximum amount of a time a request can take. API functions will raise + httpx.TimeoutException if this is exceeded. + + ``verify_ssl``: Whether or not to verify the SSL certificate of the API server. This should be True in production, + but can be set to False for testing purposes. + + ``follow_redirects``: Whether or not to follow redirects. Default value is False. + + ``httpx_args``: A dictionary of additional arguments to be passed to the ``httpx.Client`` and ``httpx.AsyncClient`` constructor. + + + Attributes: + raise_on_unexpected_status: Whether or not to raise an errors.UnexpectedStatus if the API returns a + status code that was not documented in the source OpenAPI document. Can also be provided as a keyword + argument to the constructor. + """ + + raise_on_unexpected_status: bool = field(default=False, kw_only=True) + _base_url: str = field(alias="base_url") + _cookies: dict[str, str] = field(factory=dict, kw_only=True, alias="cookies") + _headers: dict[str, str] = field(factory=dict, kw_only=True, alias="headers") + _timeout: Optional[httpx.Timeout] = field(default=None, kw_only=True, alias="timeout") + _verify_ssl: Union[str, bool, ssl.SSLContext] = field(default=True, kw_only=True, alias="verify_ssl") + _follow_redirects: bool = field(default=False, kw_only=True, alias="follow_redirects") + _httpx_args: dict[str, Any] = field(factory=dict, kw_only=True, alias="httpx_args") + _client: Optional[httpx.Client] = field(default=None, init=False) + _async_client: Optional[httpx.AsyncClient] = field(default=None, init=False) + + def with_headers(self, headers: dict[str, str]) -> "Client": + """Get a new client matching this one with additional headers""" + if self._client is not None: + self._client.headers.update(headers) + if self._async_client is not None: + self._async_client.headers.update(headers) + return evolve(self, headers={**self._headers, **headers}) + + def with_cookies(self, cookies: dict[str, str]) -> "Client": + """Get a new client matching this one with additional cookies""" + if self._client is not None: + self._client.cookies.update(cookies) + if self._async_client is not None: + self._async_client.cookies.update(cookies) + return evolve(self, cookies={**self._cookies, **cookies}) + + def with_timeout(self, timeout: httpx.Timeout) -> "Client": + """Get a new client matching this one with a new timeout (in seconds)""" + if self._client is not None: + self._client.timeout = timeout + if self._async_client is not None: + self._async_client.timeout = timeout + return evolve(self, timeout=timeout) + + def set_httpx_client(self, client: httpx.Client) -> "Client": + """Manually set the underlying httpx.Client + + **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout. + """ + self._client = client + return self + + def get_httpx_client(self) -> httpx.Client: + """Get the underlying httpx.Client, constructing a new one if not previously set""" + if self._client is None: + self._client = httpx.Client( + base_url=self._base_url, + cookies=self._cookies, + headers=self._headers, + timeout=self._timeout, + verify=self._verify_ssl, + follow_redirects=self._follow_redirects, + **self._httpx_args, + ) + return self._client + + def __enter__(self) -> "Client": + """Enter a context manager for self.client—you cannot enter twice (see httpx docs)""" + self.get_httpx_client().__enter__() + return self + + def __exit__(self, *args: Any, **kwargs: Any) -> None: + """Exit a context manager for internal httpx.Client (see httpx docs)""" + self.get_httpx_client().__exit__(*args, **kwargs) + + def set_async_httpx_client(self, async_client: httpx.AsyncClient) -> "Client": + """Manually the underlying httpx.AsyncClient + + **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout. + """ + self._async_client = async_client + return self + + def get_async_httpx_client(self) -> httpx.AsyncClient: + """Get the underlying httpx.AsyncClient, constructing a new one if not previously set""" + if self._async_client is None: + self._async_client = httpx.AsyncClient( + base_url=self._base_url, + cookies=self._cookies, + headers=self._headers, + timeout=self._timeout, + verify=self._verify_ssl, + follow_redirects=self._follow_redirects, + **self._httpx_args, + ) + return self._async_client + + async def __aenter__(self) -> "Client": + """Enter a context manager for underlying httpx.AsyncClient—you cannot enter twice (see httpx docs)""" + await self.get_async_httpx_client().__aenter__() + return self + + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: + """Exit a context manager for underlying httpx.AsyncClient (see httpx docs)""" + await self.get_async_httpx_client().__aexit__(*args, **kwargs) + + +@define +class AuthenticatedClient: + """A Client which has been authenticated for use on secured endpoints + + The following are accepted as keyword arguments and will be used to construct httpx Clients internally: + + ``base_url``: The base URL for the API, all requests are made to a relative path to this URL + + ``cookies``: A dictionary of cookies to be sent with every request + + ``headers``: A dictionary of headers to be sent with every request + + ``timeout``: The maximum amount of a time a request can take. API functions will raise + httpx.TimeoutException if this is exceeded. + + ``verify_ssl``: Whether or not to verify the SSL certificate of the API server. This should be True in production, + but can be set to False for testing purposes. + + ``follow_redirects``: Whether or not to follow redirects. Default value is False. + + ``httpx_args``: A dictionary of additional arguments to be passed to the ``httpx.Client`` and ``httpx.AsyncClient`` constructor. + + + Attributes: + raise_on_unexpected_status: Whether or not to raise an errors.UnexpectedStatus if the API returns a + status code that was not documented in the source OpenAPI document. Can also be provided as a keyword + argument to the constructor. + token: The token to use for authentication + prefix: The prefix to use for the Authorization header + auth_header_name: The name of the Authorization header + """ + + raise_on_unexpected_status: bool = field(default=False, kw_only=True) + _base_url: str = field(alias="base_url") + _cookies: dict[str, str] = field(factory=dict, kw_only=True, alias="cookies") + _headers: dict[str, str] = field(factory=dict, kw_only=True, alias="headers") + _timeout: Optional[httpx.Timeout] = field(default=None, kw_only=True, alias="timeout") + _verify_ssl: Union[str, bool, ssl.SSLContext] = field(default=True, kw_only=True, alias="verify_ssl") + _follow_redirects: bool = field(default=False, kw_only=True, alias="follow_redirects") + _httpx_args: dict[str, Any] = field(factory=dict, kw_only=True, alias="httpx_args") + _client: Optional[httpx.Client] = field(default=None, init=False) + _async_client: Optional[httpx.AsyncClient] = field(default=None, init=False) + + token: str + prefix: str = "Bearer" + auth_header_name: str = "Authorization" + + def with_headers(self, headers: dict[str, str]) -> "AuthenticatedClient": + """Get a new client matching this one with additional headers""" + if self._client is not None: + self._client.headers.update(headers) + if self._async_client is not None: + self._async_client.headers.update(headers) + return evolve(self, headers={**self._headers, **headers}) + + def with_cookies(self, cookies: dict[str, str]) -> "AuthenticatedClient": + """Get a new client matching this one with additional cookies""" + if self._client is not None: + self._client.cookies.update(cookies) + if self._async_client is not None: + self._async_client.cookies.update(cookies) + return evolve(self, cookies={**self._cookies, **cookies}) + + def with_timeout(self, timeout: httpx.Timeout) -> "AuthenticatedClient": + """Get a new client matching this one with a new timeout (in seconds)""" + if self._client is not None: + self._client.timeout = timeout + if self._async_client is not None: + self._async_client.timeout = timeout + return evolve(self, timeout=timeout) + + def set_httpx_client(self, client: httpx.Client) -> "AuthenticatedClient": + """Manually set the underlying httpx.Client + + **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout. + """ + self._client = client + return self + + def get_httpx_client(self) -> httpx.Client: + """Get the underlying httpx.Client, constructing a new one if not previously set""" + if self._client is None: + self._headers[self.auth_header_name] = f"{self.prefix} {self.token}" if self.prefix else self.token + self._client = httpx.Client( + base_url=self._base_url, + cookies=self._cookies, + headers=self._headers, + timeout=self._timeout, + verify=self._verify_ssl, + follow_redirects=self._follow_redirects, + **self._httpx_args, + ) + return self._client + + def __enter__(self) -> "AuthenticatedClient": + """Enter a context manager for self.client—you cannot enter twice (see httpx docs)""" + self.get_httpx_client().__enter__() + return self + + def __exit__(self, *args: Any, **kwargs: Any) -> None: + """Exit a context manager for internal httpx.Client (see httpx docs)""" + self.get_httpx_client().__exit__(*args, **kwargs) + + def set_async_httpx_client(self, async_client: httpx.AsyncClient) -> "AuthenticatedClient": + """Manually the underlying httpx.AsyncClient + + **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout. + """ + self._async_client = async_client + return self + + def get_async_httpx_client(self) -> httpx.AsyncClient: + """Get the underlying httpx.AsyncClient, constructing a new one if not previously set""" + if self._async_client is None: + self._headers[self.auth_header_name] = f"{self.prefix} {self.token}" if self.prefix else self.token + self._async_client = httpx.AsyncClient( + base_url=self._base_url, + cookies=self._cookies, + headers=self._headers, + timeout=self._timeout, + verify=self._verify_ssl, + follow_redirects=self._follow_redirects, + **self._httpx_args, + ) + return self._async_client + + async def __aenter__(self) -> "AuthenticatedClient": + """Enter a context manager for underlying httpx.AsyncClient—you cannot enter twice (see httpx docs)""" + await self.get_async_httpx_client().__aenter__() + return self + + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: + """Exit a context manager for underlying httpx.AsyncClient (see httpx docs)""" + await self.get_async_httpx_client().__aexit__(*args, **kwargs) diff --git a/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/errors.py b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/errors.py new file mode 100644 index 000000000..5f92e76ac --- /dev/null +++ b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/errors.py @@ -0,0 +1,16 @@ +"""Contains shared errors types that can be raised from API functions""" + + +class UnexpectedStatus(Exception): + """Raised by api functions when the response status an undocumented status and Client.raise_on_unexpected_status is True""" + + def __init__(self, status_code: int, content: bytes): + self.status_code = status_code + self.content = content + + super().__init__( + f"Unexpected status code: {status_code}\n\nResponse content:\n{content.decode(errors='ignore')}" + ) + + +__all__ = ["UnexpectedStatus"] diff --git a/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/models/__init__.py b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/models/__init__.py new file mode 100644 index 000000000..8617051d2 --- /dev/null +++ b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/models/__init__.py @@ -0,0 +1,9 @@ +"""Contains all the data models used in inputs/outputs""" + +from .error_response import ErrorResponse +from .error_response_detail import ErrorResponseDetail + +__all__ = ( + "ErrorResponse", + "ErrorResponseDetail", +) diff --git a/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/models/error_response.py b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/models/error_response.py new file mode 100644 index 000000000..b740f15c1 --- /dev/null +++ b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/models/error_response.py @@ -0,0 +1,62 @@ +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, TypeVar + +from attrs import define as _attrs_define + +if TYPE_CHECKING: + from ..models.error_response_detail import ErrorResponseDetail + + +T = TypeVar("T", bound="ErrorResponse") + + +@_attrs_define +class ErrorResponse: + """ + Attributes: + code (str): Error category code + message (str): Human-readable error message + detail (ErrorResponseDetail): Error detail + """ + + code: str + message: str + detail: "ErrorResponseDetail" + + def to_dict(self) -> dict[str, Any]: + code = self.code + + message = self.message + + detail = self.detail.to_dict() + + field_dict: dict[str, Any] = {} + + field_dict.update( + { + "code": code, + "message": message, + "detail": detail, + } + ) + + return field_dict + + @classmethod + def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + from ..models.error_response_detail import ErrorResponseDetail + + d = dict(src_dict) + code = d.pop("code") + + message = d.pop("message") + + detail = ErrorResponseDetail.from_dict(d.pop("detail")) + + error_response = cls( + code=code, + message=message, + detail=detail, + ) + + return error_response diff --git a/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/models/error_response_detail.py b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/models/error_response_detail.py new file mode 100644 index 000000000..e07cad4da --- /dev/null +++ b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/models/error_response_detail.py @@ -0,0 +1,44 @@ +from collections.abc import Mapping +from typing import Any, TypeVar + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +T = TypeVar("T", bound="ErrorResponseDetail") + + +@_attrs_define +class ErrorResponseDetail: + """Error detail""" + + additional_properties: dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> dict[str, Any]: + field_dict: dict[str, Any] = {} + field_dict.update(self.additional_properties) + + return field_dict + + @classmethod + def from_dict(cls: type[T], src_dict: Mapping[str, Any]) -> T: + d = dict(src_dict) + error_response_detail = cls() + + error_response_detail.additional_properties = d + return error_response_detail + + @property + def additional_keys(self) -> list[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/py.typed b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/py.typed new file mode 100644 index 000000000..1aad32711 --- /dev/null +++ b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/py.typed @@ -0,0 +1 @@ +# Marker file for PEP 561 \ No newline at end of file diff --git a/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/types.py b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/types.py new file mode 100644 index 000000000..1b96ca408 --- /dev/null +++ b/end_to_end_tests/multiple-media-types-golden-record/multiple_media_types_client/types.py @@ -0,0 +1,54 @@ +"""Contains some shared types for properties""" + +from collections.abc import Mapping, MutableMapping +from http import HTTPStatus +from typing import IO, BinaryIO, Generic, Literal, Optional, TypeVar, Union + +from attrs import define + + +class Unset: + def __bool__(self) -> Literal[False]: + return False + + +UNSET: Unset = Unset() + +# The types that `httpx.Client(files=)` can accept, copied from that library. +FileContent = Union[IO[bytes], bytes, str] +FileTypes = Union[ + # (filename, file (or bytes), content_type) + tuple[Optional[str], FileContent, Optional[str]], + # (filename, file (or bytes), content_type, headers) + tuple[Optional[str], FileContent, Optional[str], Mapping[str, str]], +] +RequestFiles = list[tuple[str, FileTypes]] + + +@define +class File: + """Contains information for file uploads""" + + payload: BinaryIO + file_name: Optional[str] = None + mime_type: Optional[str] = None + + def to_tuple(self) -> FileTypes: + """Return a tuple representation that httpx will accept for multipart/form-data""" + return self.file_name, self.payload, self.mime_type + + +T = TypeVar("T") + + +@define +class Response(Generic[T]): + """A response from an endpoint""" + + status_code: HTTPStatus + content: bytes + headers: MutableMapping[str, str] + parsed: Optional[T] + + +__all__ = ["UNSET", "File", "FileTypes", "RequestFiles", "Response", "Unset"] diff --git a/end_to_end_tests/multiple-media-types-golden-record/pyproject.toml b/end_to_end_tests/multiple-media-types-golden-record/pyproject.toml new file mode 100644 index 000000000..9f08decf3 --- /dev/null +++ b/end_to_end_tests/multiple-media-types-golden-record/pyproject.toml @@ -0,0 +1,26 @@ +[tool.poetry] +name = "multiple-media-types-client" +version = "0.1.0" +description = "A client library for accessing Multiple media types" +authors = [] +readme = "README.md" +packages = [ + { include = "multiple_media_types_client" }, +] +include = ["CHANGELOG.md", "multiple_media_types_client/py.typed"] + +[tool.poetry.dependencies] +python = "^3.9" +httpx = ">=0.23.0,<0.29.0" +attrs = ">=22.2.0" +python-dateutil = "^2.8.0" + +[build-system] +requires = ["poetry-core>=2.0.0,<3.0.0"] +build-backend = "poetry.core.masonry.api" + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +select = ["F", "I", "UP"] diff --git a/end_to_end_tests/multiple-media-types.config.yml b/end_to_end_tests/multiple-media-types.config.yml new file mode 100644 index 000000000..2c4a7e742 --- /dev/null +++ b/end_to_end_tests/multiple-media-types.config.yml @@ -0,0 +1,5 @@ +multiple_media_types: true +field_prefix: attr_ +content_type_overrides: + openapi/python/client: application/json +generate_all_tags: true diff --git a/end_to_end_tests/multiple-media-types.yml b/end_to_end_tests/multiple-media-types.yml new file mode 100644 index 000000000..13b9bc9c3 --- /dev/null +++ b/end_to_end_tests/multiple-media-types.yml @@ -0,0 +1,51 @@ +openapi: "3.1.0" +info: + title: "Multiple media types" + description: "Test multiple response media types" + version: "0.1.0" +paths: + "/": + post: + tags: [ "multiple_media_types" ] + responses: + "200": + description: "Successful Response" + content: + "application/json": + schema: + const: "Why have a fixed response? I dunno" + "application/octet-stream": {} + "404": + description: "Not Found" + content: + "text/plain": {} + "503": + description: "Server Not Available" + content: + "application/json": + schema: + $ref: "#/components/schemas/ErrorResponse" + "text/plain": {} +components: + schemas: + ErrorResponse: + properties: + code: + type: string + title: Code + description: Error category code + message: + type: string + title: Message + description: Human-readable error message + detail: + type: object + title: Detail + description: Error detail + additionalProperties: false + type: object + required: + - code + - message + - detail + title: ErrorResponse diff --git a/end_to_end_tests/regen_golden_record.py b/end_to_end_tests/regen_golden_record.py index ba608dfa3..d14360291 100644 --- a/end_to_end_tests/regen_golden_record.py +++ b/end_to_end_tests/regen_golden_record.py @@ -69,6 +69,15 @@ def regen_literal_enums_golden_record(): ) +def regen_multiple_media_types_golden_record(): + _regenerate( + spec_file_name="multiple-media-types.yml", + output_dir="multiple-media-types-client", + golden_record_dir="multiple-media-types-golden-record", + config_file_name="multiple-media-types.config.yml", + ) + + def regen_metadata_snapshots(): output_path = Path.cwd() / "test-3-1-features-client" snapshots_dir = Path(__file__).parent / "metadata_snapshots" @@ -147,3 +156,4 @@ def regen_custom_template_golden_record(): regen_docstrings_on_attributes_golden_record() regen_custom_template_golden_record() regen_literal_enums_golden_record() + regen_multiple_media_types_golden_record() diff --git a/end_to_end_tests/test_end_to_end.py b/end_to_end_tests/test_end_to_end.py index 347f72f7e..52538dc4f 100644 --- a/end_to_end_tests/test_end_to_end.py +++ b/end_to_end_tests/test_end_to_end.py @@ -137,6 +137,17 @@ def test_literal_enums_end_to_end(): ) +def test_multiple_media_types(): + config_path = Path(__file__).parent / "multiple-media-types.config.yml" + run_e2e_test( + "multiple-media-types.yml", + [f"--config={config_path}"], + {}, + "multiple-media-types-golden-record", + "multiple-media-types-client" + ) + + @pytest.mark.parametrize( "meta,generated_file,expected_file", ( diff --git a/openapi_python_client/config.py b/openapi_python_client/config.py index 21cb4d182..de961abb9 100644 --- a/openapi_python_client/config.py +++ b/openapi_python_client/config.py @@ -47,6 +47,7 @@ class ConfigFile(BaseModel): generate_all_tags: bool = False http_timeout: int = 5 literal_enums: bool = False + multiple_media_types: bool = False @staticmethod def load_from_path(path: Path) -> "ConfigFile": @@ -77,6 +78,7 @@ class Config: generate_all_tags: bool http_timeout: int literal_enums: bool + multiple_media_types: bool document_source: Union[Path, str] file_encoding: str content_type_overrides: dict[str, str] @@ -119,6 +121,7 @@ def from_sources( generate_all_tags=config_file.generate_all_tags, http_timeout=config_file.http_timeout, literal_enums=config_file.literal_enums, + multiple_media_types=config_file.multiple_media_types, document_source=document_source, file_encoding=file_encoding, overwrite=overwrite, diff --git a/openapi_python_client/parser/openapi.py b/openapi_python_client/parser/openapi.py index 0aab5a717..f1319d390 100644 --- a/openapi_python_client/parser/openapi.py +++ b/openapi_python_client/parser/openapi.py @@ -199,8 +199,11 @@ def _add_responses( continue # No reasons to use lazy imports in endpoints, so add lazy imports to relative here. - endpoint.relative_imports |= response.prop.get_lazy_imports(prefix=models_relative_prefix) - endpoint.relative_imports |= response.prop.get_imports(prefix=models_relative_prefix) + for media_type in response.content: + if not media_type.prop: # pragma: no cover + continue + endpoint.relative_imports |= media_type.prop.get_lazy_imports(prefix=models_relative_prefix) + endpoint.relative_imports |= media_type.prop.get_imports(prefix=models_relative_prefix) endpoint.responses.append(response) return endpoint, schemas @@ -476,11 +479,18 @@ def from_data( def response_type(self) -> str: """Get the Python type of any response from this endpoint""" - types = sorted({response.prop.get_type_string(quoted=False) for response in self.responses}) + types = sorted( + { + media_type.prop.get_type_string(quoted=False) + for response in self.responses + for media_type in response.content + if media_type.prop + } + ) if len(types) == 0: return "Any" if len(types) == 1: - return self.responses[0].prop.get_type_string(quoted=False) + return self.responses[0].content[0].prop.get_type_string(quoted=False) return f"Union[{', '.join(types)}]" def iter_all_parameters(self) -> Iterator[tuple[oai.ParameterLocation, Property]]: diff --git a/openapi_python_client/parser/properties/protocol.py b/openapi_python_client/parser/properties/protocol.py index 327ba0a5e..e40ef109e 100644 --- a/openapi_python_client/parser/properties/protocol.py +++ b/openapi_python_client/parser/properties/protocol.py @@ -62,6 +62,10 @@ class PropertyProtocol(Protocol): template: ClassVar[str] = "any_property.py.jinja" json_is_dict: ClassVar[bool] = False + @property + def type_string(self) -> str: + return self.get_type_string() + @abstractmethod def convert_value(self, value: Any) -> Value | None | PropertyError: """Convert a string value to a Value object""" diff --git a/openapi_python_client/parser/responses.py b/openapi_python_client/parser/responses.py index ec0f6136b..782fbb5dd 100644 --- a/openapi_python_client/parser/responses.py +++ b/openapi_python_client/parser/responses.py @@ -28,13 +28,24 @@ class _ResponseSource(TypedDict): NONE_SOURCE = _ResponseSource(attribute="None", return_type="None") +@define +class MediaType: + """Describes the response for a given content type""" + + content_type: Optional[str] + source: _ResponseSource + prop: Property + data: Union[ + oai.MediaType, oai.Reference, None + ] # Original data which created this response, useful for custom templates + + @define class Response: """Describes a single response for an endpoint""" status_code: HTTPStatus - prop: Property - source: _ResponseSource + content: list[MediaType] data: Union[oai.Response, oai.Reference] # Original data which created this response, useful for custom templates @@ -68,15 +79,21 @@ def empty_response( return Response( data=data, status_code=status_code, - prop=AnyProperty( - name=response_name, - default=None, - required=True, - python_name=PythonIdentifier(value=response_name, prefix=config.field_prefix), - description=data.description if isinstance(data, oai.Response) else None, - example=None, - ), - source=NONE_SOURCE, + content=[ + MediaType( + content_type=None, + prop=AnyProperty( + name=response_name, + default=None, + required=True, + python_name=PythonIdentifier(value=response_name, prefix=config.field_prefix), + description=data.description if isinstance(data, oai.Response) else None, + example=None, + ), + source=NONE_SOURCE, + data=None, + ) + ], ) @@ -116,39 +133,28 @@ def response_from_data( # noqa: PLR0911 ), schemas, ) - + content_types: list[MediaType] = [] for content_type, media_type in content.items(): source = _source_by_content_type(content_type, config) if source is not None: - schema_data = media_type.media_type_schema - break - else: + prop, schemas = property_from_data( + name=response_name, + required=True, + data=media_type.media_type_schema or oai.Schema(), + schemas=schemas, + parent_name=parent_name, + config=config, + ) + if isinstance(prop, PropertyError): + return prop, schemas + if prop.description is None and isinstance(data, oai.Response): + prop.description = data.description + + content_types.append(MediaType(content_type=content_type, source=source, prop=prop, data=media_type)) + if not content_types: return ( ParseError(data=data, detail=f"Unsupported content_type {content}"), schemas, ) - if schema_data is None: - return ( - empty_response( - status_code=status_code, - response_name=response_name, - config=config, - data=data, - ), - schemas, - ) - - prop, schemas = property_from_data( - name=response_name, - required=True, - data=schema_data, - schemas=schemas, - parent_name=parent_name, - config=config, - ) - - if isinstance(prop, PropertyError): - return prop, schemas - - return Response(status_code=status_code, prop=prop, source=source, data=data), schemas + return Response(status_code=status_code, content=content_types, data=data), schemas diff --git a/openapi_python_client/templates/endpoint_module.py.jinja b/openapi_python_client/templates/endpoint_module.py.jinja index 802fcc2ea..e1315e265 100644 --- a/openapi_python_client/templates/endpoint_module.py.jinja +++ b/openapi_python_client/templates/endpoint_module.py.jinja @@ -66,21 +66,48 @@ def _get_kwargs( def _parse_response(*, client: Union[AuthenticatedClient, Client], response: httpx.Response) -> Optional[{{ return_string }}]: - {% for response in endpoint.responses %} +{% for response in endpoint.responses %} if response.status_code == {{ response.status_code.value }}: - {% if parsed_responses %}{% import "property_templates/" + response.prop.template as prop_template %} + {% if not config.multiple_media_types %} + {% if parsed_responses %}{% import "property_templates/" + response.content[0].prop.template as prop_template %} {% if prop_template.construct %} - {{ prop_template.construct(response.prop, response.source.attribute) | indent(8) }} - {% elif response.source.return_type == response.prop.get_type_string() %} - {{ response.prop.python_name }} = {{ response.source.attribute }} + {{ prop_template.construct(response.content[0].prop, response.content[0].source.attribute) | indent(8) }} + {% elif response.content[0].source.return_type == response.content[0].prop.get_type_string() %} + {{ response.content[0].prop.python_name }} = {{ response.content[0].source.attribute }} {% else %} - {{ response.prop.python_name }} = cast({{ response.prop.get_type_string() }}, {{ response.source.attribute }}) + {{ response.content[0].prop.python_name }} = cast({{ response.content[0].prop.get_type_string() }}, {{ response.content[0].source.attribute }}) {% endif %} - return {{ response.prop.python_name }} + return {{ response.content[0].prop.python_name }} {% else %} return None {% endif %} - {% endfor %} + {% else %} + {% if response.content[1:] %} + {{response.content[0].prop.python_name}}: Union[{{ response.content | map(attribute='prop') | map(attribute='type_string') | join(', ') }}] + {% else %} + {{response.content[0].prop.python_name}}: {{response.content[0].prop.type_string}} + {% endif %} + {% for media_type in response.content %} + {% if media_type.content_type %} + if response.headers.get("content-type") == "{{ media_type.content_type }}": + {% else %} + if True: # Any Content-Type header value + {% endif %} + {% if parsed_responses %}{% import "property_templates/" + media_type.prop.template as prop_template %} + {% if prop_template.construct %} + {{ prop_template.construct(media_type.prop, media_type.source.attribute) | indent(12) }} + {% elif media_type.source.return_type == media_type.prop.get_type_string() %} + {{ media_type.prop.python_name }} = {{ media_type.source.attribute }} + {% else %} + {{ media_type.prop.python_name }} = cast({{ media_type.prop.get_type_string() }}, {{ media_type.source.attribute }}) + {% endif %} + return {{ media_type.prop.python_name }} + {% else %} + return None + {% endif %} + {% endfor %} + {%- endif %} +{% endfor %} if client.raise_on_unexpected_status: raise errors.UnexpectedStatus(response.status_code, response.content) else: diff --git a/tests/test_parser/test_openapi.py b/tests/test_parser/test_openapi.py index 3d1391ae2..afc936897 100644 --- a/tests/test_parser/test_openapi.py +++ b/tests/test_parser/test_openapi.py @@ -684,7 +684,8 @@ def test_response_type(self, response_types, expected): endpoint = self.make_endpoint() for response_type in response_types: mock_response = MagicMock() - mock_response.prop.get_type_string.return_value = response_type + mock_response.content = [MagicMock()] + mock_response.content[0].prop.get_type_string.return_value = response_type endpoint.responses.append(mock_response) assert endpoint.response_type() == expected diff --git a/tests/test_parser/test_responses.py b/tests/test_parser/test_responses.py index 8fb04d720..57effe9bd 100644 --- a/tests/test_parser/test_responses.py +++ b/tests/test_parser/test_responses.py @@ -6,7 +6,14 @@ from openapi_python_client.parser import responses from openapi_python_client.parser.errors import ParseError, PropertyError from openapi_python_client.parser.properties import Schemas -from openapi_python_client.parser.responses import JSON_SOURCE, NONE_SOURCE, Response, response_from_data +from openapi_python_client.parser.responses import ( + BYTES_SOURCE, + JSON_SOURCE, + NONE_SOURCE, + MediaType, + Response, + response_from_data, +) MODULE_NAME = "openapi_python_client.parser.responses" @@ -25,13 +32,19 @@ def test_response_from_data_no_content(any_property_factory): assert response == Response( status_code=200, - prop=any_property_factory( - name="response_200", - default=None, - required=True, - description="", - ), - source=NONE_SOURCE, + content=[ + MediaType( + content_type=None, + prop=any_property_factory( + name="response_200", + default=None, + required=True, + description="", + ), + source=NONE_SOURCE, + data=None, + ) + ], data=data, ) @@ -54,7 +67,7 @@ def test_response_from_data_unsupported_content_type(): def test_response_from_data_no_content_schema(any_property_factory): data = oai.Response.model_construct( - description="", + description="Description", content={"application/vnd.api+json; version=2.2": oai.MediaType.model_construct()}, ) config = MagicMock() @@ -70,13 +83,19 @@ def test_response_from_data_no_content_schema(any_property_factory): assert response == Response( status_code=200, - prop=any_property_factory( - name="response_200", - default=None, - required=True, - description=data.description, - ), - source=NONE_SOURCE, + content=[ + MediaType( + content_type="application/vnd.api+json; version=2.2", + prop=any_property_factory( + name="response_200", + default=None, + required=True, + description=data.description, + ), + source=JSON_SOURCE, + data=data.content["application/vnd.api+json; version=2.2"], + ) + ], data=data, ) @@ -131,8 +150,14 @@ def test_response_from_data_property(mocker, any_property_factory): assert response == responses.Response( status_code=400, - prop=prop, - source=JSON_SOURCE, + content=[ + MediaType( + content_type="application/json", + prop=prop, + source=JSON_SOURCE, + data=data.content["application/json"], + ) + ], data=data, ) property_from_data.assert_called_once_with( @@ -166,8 +191,14 @@ def test_response_from_data_reference(mocker, any_property_factory): assert response == responses.Response( status_code=400, - prop=prop, - source=JSON_SOURCE, + content=[ + MediaType( + content_type="application/json", + prop=prop, + source=JSON_SOURCE, + data=predefined_response_data.content["application/json"], + ) + ], data=predefined_response_data, ) @@ -250,12 +281,18 @@ def test_response_from_data_content_type_overrides(any_property_factory): assert response == Response( status_code=200, - prop=any_property_factory( - name="response_200", - default=None, - required=True, - description=data.description, - ), - source=NONE_SOURCE, + content=[ + MediaType( + content_type="application/zip", + prop=any_property_factory( + name="response_200", + default=None, + required=True, + description=data.description, + ), + source=BYTES_SOURCE, + data=data.content["application/zip"], + ) + ], data=data, )