diff --git a/.github/workflows/test_and_publish.yml b/.github/workflows/test_and_publish.yml new file mode 100644 index 0000000..06a7853 --- /dev/null +++ b/.github/workflows/test_and_publish.yml @@ -0,0 +1,176 @@ +name: Tests + +on: + push: + branches: + - main + - '*.*.*' + + pull_request: + branches: + - main + - '*.*.*' + + release: + types: [ published ] + +jobs: + + black: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@master + with: + python-version: '3.13' + + - name: Install packages + run: pip install -r requirements_dev.txt + + - name: Black + run: | + black --check -l 120 simple_oauth2/ tests/ + + isort: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@master + with: + python-version: '3.13' + + - name: Install packages + run: pip install -r requirements_dev.txt + + - name: Isort + run: | + isort --check simple_oauth2/ tests/ + + pycodestyle: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@master + with: + python-version: '3.13' + + - name: Install packages + run: pip install -r requirements_dev.txt + + - name: Pycodestyle + run: | + pycodestyle simple_oauth2/ tests/ + + pydocstyle: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@master + with: + python-version: '3.13' + + - name: Install packages + run: pip install -r requirements_dev.txt + + - name: Pydocstyle + run: | + pydocstyle --count simple_oauth2/ + + mypy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@master + with: + python-version: '3.13' + + - name: Install packages + run: pip install -r requirements_dev.txt + + - name: Mypy + run: | + mypy simple_oauth2 --disallow-untyped-def + + bandit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@master + with: + python-version: '3.13' + + - name: Install packages + run: pip install -r requirements_dev.txt + + - name: Bandit + run: | + bandit --ini=setup.cfg -ll 2> /dev/null + + + test: + needs: [black, isort, pycodestyle, pydocstyle, bandit] + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [ '3.10', '3.11', '3.12', '3.13'] + django-version: [ 42, 51, 52 ] + exclude: + - python-version: 3.13 + django-version: 42 + + steps: + - uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@master + with: + python-version: ${{ matrix.python-version }} + - name: Install Tox and any other packages + run: | + pip install tox + - name: Python ${{ matrix.python-version }}, Django ${{ matrix.django-version }}, + run: tox -e py-django${{ matrix.django-version }} + + - name: Upload coverage to Codecov + if: matrix.python-version == 3.13 && matrix.django-version == 52 + uses: codecov/codecov-action@v5 + with: + file: ./coverage.xml + token: ${{ secrets.CODECOV_TOKEN }} + + publish: + needs: test + if: github.event_name == 'release' && github.event.action == 'published' + runs-on: ubuntu-latest + continue-on-error: true + + steps: + - uses: actions/checkout@master + + - name: Set up Python 3.13 + uses: actions/setup-python@v4 + with: + python-version: '3.13' + + - name: Creating Built Distributions + run: | + pip install setuptools + python setup.py sdist + + - name: Publish distribution to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + password: ${{ secrets.pypi_password }} + skip_existing: true diff --git a/.gitignore b/.gitignore index b7faf40..cc1fe83 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,10 @@ +# OS generated files +.DS_Store + +# IDEs and editors +.vscode/ +.idea/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[codz] @@ -50,6 +57,7 @@ coverage.xml .hypothesis/ .pytest_cache/ cover/ +junit.xml # Translations *.mo @@ -182,9 +190,9 @@ cython_debug/ .abstra/ # Visual Studio Code -# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore -# and can be added to the global gitignore or merged into this file. However, if you prefer, +# and can be added to the global gitignore or merged into this file. However, if you prefer, # you could uncomment the following to ignore the entire vscode folder # .vscode/ diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml new file mode 100644 index 0000000..541d95a --- /dev/null +++ b/.gitlab-ci.yml @@ -0,0 +1,85 @@ +image: docker:latest + +services: + - docker:dind + +stages: + - checks + - test + - build + - deploy + +variables: + BUILD_TOOLS_PATH: /home/gitlab-runner/codoc-tools/docker_build + DOCKERFILE_DIR_PATH: $GITLAB_PROJECT_DIR + +################################################################################ +################################# TEMPLATES #################################### +################################################################################ + +.checks: &checks + stage: checks + only: + - merge_requests + before_script: + - python3.11 -m venv venv + - source venv/bin/activate + - python3.11 -m pip install -U pip setuptools wheel + +.test: &test + stage: test + only: + - merge_requests + before_script: + - python3.11 -m venv venv + - source venv/bin/activate + - python3.11 -m pip install -U pip setuptools wheel + - python3.11 -m pip install -r requirements_dev.txt + +################################################################################ +################################### CHECKS ##################################### +################################################################################ + +job:flake8: + <<: *checks + script: + - python3.11 -m pip install flake8 flake8-junit-report + - flake8 . --output-file=flake8.txt + after_script: + - flake8_junit flake8.txt flake8_junit.xml + artifacts: + reports: + junit: flake8_junit.xml + +job:isort: + <<: *checks + script: + - python3.11 -m pip install "isort==5.10.1" + - isort --check django_test/ {{ app_name }}/ + +job:black: + <<: *checks + script: + - python3.11 -m pip install "black==22.3.0" + - black --check --exclude="^.*\b((migrations))\b.*$" -l 120 test_project/ {{ app_name }}/ + +job:pydocstyle: + <<: *checks + script: + - python3.11 -m pip install pydocstyle + - ./bin/pydocstyle.sh + + +################################################################################ +################################### TESTS ###################################### +################################################################################ + +job:unittest: + <<: *test + script: + - python3.11 -m pytest --create-db -vvv -s --color=yes --durations=0 --durations-min=1.0 --cov=. --cov-report term + artifacts: + reports: + junit: junit.xml + coverage: '/TOTAL.*\s+(\d+%)$/' + resource_group: {{ app_name }}_unittest diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..8c900a6 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,4 @@ +Changelog +========= + +## 0.1.0 {% now "SHORT_DATE_FORMAT" %} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..9919af5 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,108 @@ +# Contributing + +Contributions are welcome, and they are greatly appreciated! Every little bit helps, and credit will always be given. + +You can contribute in many ways: + +## Types of Contributions + +### Report Bugs + +Report bugs at https://github.com/Codoc-os/drf-simple-oauth2. + +If you are reporting a bug, please include: + +* Your operating system name and version. +* Any details about your local setup that might be helpful in troubleshooting. +* Detailed steps to reproduce the bug. + +### Fix Bugs + +Look through the GitHub issues for bugs. Anything tagged with "bug" +is open to whoever wants to implement it. + +### Implement Features + +Look through the GitHub issues for features. Anything tagged with "feature" +is open to whoever wants to implement it. + +### Write Documentation + +Django Opensearch DSL could always use more documentation, whether as part of the official Django Opensearch DSL docs, +in docstrings, or even on the web in blog posts, articles, and such. + +### Submit Feedback + +The best way to send feedback is to file an issue at https://github.com/Codoc-os/drf-simple-oauth2/issues. + +If you are proposing a feature: + +* Explain in detail how it would work. +* Keep the scope as narrow as possible, to make it easier to implement. +* Remember that this is a volunteer-driven project, and that contributions are welcome :) + +--- + +## Setting up local environment + +Ready to contribute? Here's how to set up `drf-simple-oauth2` for local development. + +1. Fork the `drf-simple-oauth2` repo on GitHub. + +2. Clone your fork locally: + + * `git clone git@github.com:/drf-simple-oauth2.git` + +3. Install your local copy into a virtualenv. + + ```bash + python3 -m venv venv + source venv/bin/activate + pip3 install -r requirements.txt + pip3 install -r requirements_dev.txt + ``` + +4. Create a branch for local development: + + * `git checkout -b name-of-your-bugfix-or-feature` + + Now you can make your changes locally. + +## Testing your changes + +Tests must be written inside the `tests/` Django's project. This project contains three directory. + +You can interact with this project using the root-level `manage.py`. + +If you need to manually tests some of your feature, you can create a `sqlite3` +database with `python3 manage.py migrate`. + +## Submitting your changes + +1. Ensure your code is correctly formatted and documented: + +```sh +./bin/pre_commit.sh +``` + +2. Commit your changes and push your branch to GitHub: + +```sh +git add . +git commit -m "Your detailed description of your changes." +git push origin name-of-your-bugfix-or-feature +``` + +3. Submit a pull request through the GitHub website. + +## Pull Request Guidelines + +Before you submit a pull request, check that it meets these guidelines: + +1. The pull request should include tests + +2. If the pull request adds functionality, the documentation should be updated. + +3. The pull request should pass all checks and tests. Check + https://github.com/qcoumes/drf-simple-oauth2/actions + and make sure that the tests pass for all supported Python versions. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..5c3c7da --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,5 @@ +include CONTRIBUTING.md +include LICENSE +include README.md +include CHANGELOG.md +recursive-include simple_oauth2 *py diff --git a/README.md b/README.md index fe0fdcf..f66f244 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,150 @@ -# drf-simple-oauth2 -Simple OAuth2 client package allowing to define OAuth2 / OpenID providers through settings. +# DRF Simple OAuth2 + +[![PyPI Version](https://badge.fury.io/py/drf-simple-oauth2.svg)](https://badge.fury.io/py/drf-simple-oauth2) +[![Documentation Status](https://readthedocs.org/projects/drf-simple-oauth2/badge/?version=latest)](https://drf-simple-oauth2.readthedocs.io/en/latest/?badge=latest) +![Tests](https://github.com/Codoc-os/drf-simple-oauth2/workflows/Tests/badge.svg) +[![Python 3.10+](https://img.shields.io/badge/Python-3.10+-brightgreen.svg)](#) +[![Django 4.2+](https://img.shields.io/badge/Django-4.2+-brightgreen.svg)](#) +[![License MIT](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://github.com/Codoc-os/drf-simple-oauth2/blob/master/LICENSE) +[![codecov](https://codecov.io/gh/Codoc-os/drf-simple-oauth2/branch/master/graph/badge.svg)](https://codecov.io/gh/Codoc-os/drf-simple-oauth2) +[![CodeFactor](https://www.codefactor.io/repository/github/Codoc-os/drf-simple-oauth2/badge)](https://www.codefactor.io/repository/github/Codoc-os/drf-simple-oauth2) + +**DRF Simple OAuth2** is an OAuth2/OpenID Connect client for Django REST Framework. It lets you define one or many providers entirely via settings. + +You can view the full documentation at . + +## Features + +- Provides endpoints for the OAuth2 Authorization Code flow. +- Supports multiple OAuth2/OpenID providers at once. +- Supports PKCE (Proof Key for Code Exchange). +- Customize the user creation/update logic using the information retrieved from the provider. + +## Requirements + +`drf-simple-oauth2` supports the officially supported versions of its dependencies (mainstream & LTS): + +- **Python** — see the [supported versions](https://devguide.python.org/versions/). +- **Django** — see the [supported versions](https://www.djangoproject.com/download/#supported-versions). + +## Installation + +Install via `pip`: + +```bash +pip install drf-simple-oauth2 +``` + +Add the app: + +```python +INSTALLED_APPS = [ + ... + "simple_oauth2", + ... +] +``` + +Include URLs in your project’s `urls.py`: + +```python +from django.urls import include, path + +urlpatterns = [ + ... + path("", include("simple_oauth2.urls", namespace="simple_oauth2")), + ... +] +``` + +## Configuration + +Define `SIMPLE_OAUTH2` in your Django settings: + +```python +SIMPLE_OAUTH2 = { + "auth0": { + "CLIENT_ID": "", + "CLIENT_SECRET": "", + "BASE_URL": "..auth0.com", + "REDIRECT_URI": "http://localhost:8080/app/auth0/callback", + "POST_LOGOUT_REDIRECT_URI": "http://localhost:8080/app", + }, + "google": { + "CLIENT_ID": "", + "CLIENT_SECRET": "", + "BASE_URL": "accounts.google.com", + "REDIRECT_URI": "http://localhost:8080/app/google/callback", + "POST_LOGOUT_REDIRECT_URI": "http://localhost:8080/app", + }, +} +``` + +See the [documentation](settings.md#available-settings) for all available settings. + +## Usage + +*The following assumes you mounted `simple_oauth2` URLs at the empty path (`""`) as shown above.* + +The flow below describes interactions between a frontend (**App**), a backend (**API**), and an OAuth2/OpenID provider (**Provider**). +You can find a more detailed explanation in the [documentation](flow.md). + +### 1) Redirect the user to the provider + +Request the provider-specific authorization URL from your API, then redirect the browser to it: + +```http +GET http://localhost:8000/oauth2/url/?provider=auth0 +``` +```json +{ + "url": "https://example.com/authorize?response_type=code&client_id=client&scope=openid+profile+email&nonce=085c979c02ecb914a4c6210ad1902037825c18fe8d9b0a1ca0daae113b7747035170e9400c6ec5c7439e1caa3249cc20d52975b34777778c2949f63a14accfb0&state=9143617326b20fa6b3f436001096f5365e1ccb2689becc75091399fb3b3b4f834333f4dada0c44b2d167326d6ddc279698a0b05a0720c45620b8696e944101c4&redirect_uri=https%3A%2F%2Fexample.com%2Fcallback&code_challenge=vo8kwt0Nrf.jfMj8HmMGKJeGJH8SFY8bVhKidrQkg7q2IeW~nfRrdlM4QosTTgjMnMmyzVAC3i5n.lOPx0NJvgB1G7~FSaDVwhTFM-UehPrp6~~lht6jbLVs-9Tlxsld&code_challenge_method=plain" +} +``` + +After consent, the provider redirects back to your App at `REDIRECT_URI` with `code` and `state` parameters. + +### 2) Exchange the code at your API + +POST the `code` and `state` to your API: + +```http +POST http://localhost:8000/oauth2/token/ +{ + "provider": "auth0", + "code": "", + "state": "" +} +``` +```json +{ + "api": { + "access": "", + "refresh": "" + }, + "provider": { + "access_token": "", + "id_token": "", + "refresh_token": "", + "logout_url": "https://example.com/v2/logout?..." + } +} +``` + +> Note: `refresh_token` may be absent depending on the provider configuration. + +The response payload is produced by `TOKEN_PAYLOAD_HANDLER`, which issues JWTs via `djangorestframework-simplejwt` and returns the provider’s tokens. +The `api` object contains tokens for authenticating against **your API**; the `provider` object contains tokens for the **provider**. + +Calling `/oauth2/token/` will also ensure a user exists in your database. This is handled by `TOKEN_USERINFO_HANDLER` (defaults to `simple_oauth2.utils.get_user`). It first tries to match a user via the `sub` claim from the ID Token; otherwise, it uses claims/UserInfo to retrieve or create a user. + +You can customize both behaviors per provider via the [`TOKEN_PAYLOAD_HANDLER`](settings.md#token_payload_handler) and [`TOKEN_USERINFO_HANDLER`](settings.md#token_userinfo_handler)` settings. + +### 3) Logout + +To log out from both your API and the provider: + +1. Log the user out from your API. +2. Redirect the user to the `logout_url` returned in the `provider` object from `/oauth2/token/`. + +The provider will redirect back to your App using `POST_LOGOUT_REDIRECT_URI`. diff --git a/bin/colors.sh b/bin/colors.sh new file mode 100644 index 0000000..e3ee2cd --- /dev/null +++ b/bin/colors.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +# Reset +Color_Off=$'\e[0m' # Text Reset + +# Regular Colors +Red=$'\e[0;31m' # Red +Green=$'\e[0;32m' # Green +Yellow=$'\e[0;33m' # Yellow +Purple=$'\e[0;35m' # Purple +Cyan=$'\e[0;36m' # Cyan diff --git a/bin/pre_commit.sh b/bin/pre_commit.sh new file mode 100755 index 0000000..889084a --- /dev/null +++ b/bin/pre_commit.sh @@ -0,0 +1,96 @@ +#!/bin/bash + +BASE_PATH="$(dirname "$0")" +source "$BASE_PATH/colors.sh" +EXIT_CODE=0 + + +################################################################################ +# ISORT # +################################################################################ +echo -n "${Cyan}Formatting import with isort... $Color_Off" +out=$(isort simple_oauth2/ tests/) +if [ ! -z "$out" ] ; then + echo "" + echo -e "$out" +fi +echo "${Green}Ok ✅ $Color_Off" +echo "" + +################################################################################ +# BLACK # +################################################################################ +echo "${Cyan}Formatting code with black...$Color_Off" +black -l 120 simple_oauth2/ tests/ +echo "" + + +################################################################################ +# PYCODESTYLE # +################################################################################ +echo -n "${Cyan}Running pycodestyle... $Color_Off" +out=$(pycodestyle simple_oauth2 tests) +if [ "$?" -ne 0 ] ; then + echo "${Red}Error !$Color_Off" + echo -e "$out" + EXIT_CODE=1 +else + echo "${Green}Ok ✅ $Color_Off" +fi +echo "" + + +################################################################################ +# PYDOCSTYLE # +################################################################################ +echo -n "${Cyan}Running pydocstyle... $Color_Off" +out=$(pydocstyle --count simple_oauth2/) +if [ "$?" -ne 0 ] ; then + echo "${Red}Error !$Color_Off" + echo -e "$out" + EXIT_CODE=1 +else + echo "${Green}Ok ✅ $Color_Off" +fi +echo "" + + +################################################################################ +# MYPY # +################################################################################ +echo -n "${Cyan}Running mypy... $Color_Off" +out=$(mypy simple_oauth2 --disallow-untyped-def) +if [ "$?" -ne 0 ] ; then + echo "${Red}Error !$Color_Off" + echo -e "$out" + EXIT_CODE=1 +else + echo "${Green}Ok ✅ $Color_Off" +fi +echo "" + + +################################################################################ +# BANDIT # +################################################################################ +echo -n "${Cyan}Running bandit... $Color_Off" +out=$(bandit --ini=setup.cfg -ll 2> /dev/null) +if [ "$?" -ne 0 ] ; then + echo "${Red}Error !$Color_Off" + echo -e "$out" + EXIT_CODE=1 +else + echo "${Green}Ok ✅ $Color_Off" +fi +echo "" + + + +################################################################################ + + +if [ $EXIT_CODE = 1 ] ; then + echo "${Red}⚠ You must fix the errors before committing ⚠$Color_Off" + exit $EXIT_CODE +fi +echo "${Purple}✨ You can commit without any worry ✨$Color_Off" diff --git a/docs/AUTHORS.md b/docs/AUTHORS.md new file mode 100644 index 0000000..a561296 --- /dev/null +++ b/docs/AUTHORS.md @@ -0,0 +1,7 @@ +# Credits + +### Contributors + +* [Quentin Coumes (Codoc)](https://github.com/qcoumes) +* [Andrea Chávez Herrejón (Codoc)](https://github.com/andreach2713) + diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md new file mode 120000 index 0000000..04c99a5 --- /dev/null +++ b/docs/CHANGELOG.md @@ -0,0 +1 @@ +../CHANGELOG.md \ No newline at end of file diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md new file mode 120000 index 0000000..44fcc63 --- /dev/null +++ b/docs/CONTRIBUTING.md @@ -0,0 +1 @@ +../CONTRIBUTING.md \ No newline at end of file diff --git a/docs/LICENSE.md b/docs/LICENSE.md new file mode 120000 index 0000000..ea5b606 --- /dev/null +++ b/docs/LICENSE.md @@ -0,0 +1 @@ +../LICENSE \ No newline at end of file diff --git a/docs/css/misc.css b/docs/css/misc.css new file mode 100644 index 0000000..a2fd9a8 --- /dev/null +++ b/docs/css/misc.css @@ -0,0 +1,8 @@ +td, th { + vertical-align: middle; + padding: 5px; +} + +.wy-nav-content { + max-width: 1000px; +} diff --git a/docs/flow.md b/docs/flow.md new file mode 100644 index 0000000..b615cdd --- /dev/null +++ b/docs/flow.md @@ -0,0 +1,69 @@ +# OAuth 2.0 Authorization Code Flow + +![OAuth 2.0 Authorization Code Flow](img/flow.png) + +The Authorization Code Flow (defined in [OAuth 2.0 RFC 6749, §4.1](https://datatracker.ietf.org/doc/html/rfc6749#section-4.1)) +exchanges a short-lived authorization code for tokens. + +## 1) Retrieve the authorization URL from the API + +The client application (**App**) requests the authorization URL from the backend (**API**) +using `GET /oauth2/url/?provider=`. + +## 2) The API generates the authorization URL + +The API generates a unique `state` and builds the authorization URL using the +provider’s `AUTHORIZATION_PATH` and configured settings (`CLIENT_ID`, `REDIRECT_URI`, etc.), +then returns it to the App. + +## 3) Redirect the user to the provider + +The App redirects the browser to the provider’s authorization endpoint using that URL. + +## 4) The user is redirected back to the App + +After the user authenticates and authorizes the application, the provider redirects +back to the App’s `REDIRECT_URI` with an authorization `code` and the original `state`. + +## 5) The App sends the code to the API + +The App sends `code` and `state` to the API via `POST /oauth2/token/`. + +## 6) The API exchanges the code for tokens + +The API verifies the `state` (and `code_verifier` if PKCE is enabled) and exchanges the +`code` for tokens using the provider’s `TOKEN_PATH` and configured settings. + +## 7) The provider returns tokens to the API + +The provider validates the code and returns an ID Token, Access Token, and optionally a +Refresh Token to the API. + +## 8) The API retrieves user info + +The API uses the Access Token to fetch user information from the provider’s `USERINFO_PATH`. + +## 9) The provider returns user info to the API + +The provider returns user information (e.g., `sub`, `email`, `preferred_username`) +according to the scopes requested in the authorization URL. + +## 10) The API creates or updates the user + +The API creates or updates a user in its database using the ID Token and UserInfo via +the `TOKEN_USERINFO_HANDLER()` callable. + +## 11) The API responds to the App + +The API returns the payload produced by `TOKEN_PAYLOAD_HANDLER()`. This typically +includes tokens for authenticating against the API (e.g., JWTs), the provider’s tokens, +and a `logout_url`. + +## 12) The App redirects the user to the provider’s logout URL + +When the user wants to log out, the App logs the user out of the API and then redirects +to the `logout_url` returned in the previous step. + +## 13) The provider redirects the user back to the App + +The provider logs the user out and redirects back to the App’s `POST_LOGOUT_REDIRECT_URI`. diff --git a/docs/img/flow.png b/docs/img/flow.png new file mode 100644 index 0000000..bdd8ce6 Binary files /dev/null and b/docs/img/flow.png differ diff --git a/docs/index.md b/docs/index.md new file mode 120000 index 0000000..32d46ee --- /dev/null +++ b/docs/index.md @@ -0,0 +1 @@ +../README.md \ No newline at end of file diff --git a/docs/models.md b/docs/models.md new file mode 100644 index 0000000..b7b3591 --- /dev/null +++ b/docs/models.md @@ -0,0 +1,50 @@ +# Models + +## Overview + +Internally, `drf-simple-oauth2` uses two models to manage OAuth2 authentication: +`simple_oauth2.models.Session` and `simple_oauth2.models.Sub`. + +## Session + +The `Session` model represents an OAuth2 authorization session. It lets +`drf-simple-oauth2` track the state of an ongoing OAuth2 flow, from the moment +the authorization URL is generated until the user is authenticated and tokens +are issued. + +When a session is started (via `Session.start()`), a new random `state` is +generated, which is expected to be unique for a given provider. The `state` +prevents CSRF during the OAuth2 authorization flow and helps correlate the +callback with the original request. + +Sessions have a limited lifetime (configurable via +[`AUTHORIZATION_SESSION_LIFETIME`](settings.md#authorization_session_lifetime)). + +You may want to set up a periodic cleanup task to delete expired or otherwise +old sessions. You can filter on the `status` field to remove only sessions that +are failed, expired, or completed. + +`status` can have the following values (see `simple_oauth2.enums.Status`): + +- `pending`: The URL has been generated, but the call to `/oauth2/token/` has not yet + been made. +- `completed`: The call to `/oauth2/token/` was successful, and tokens have been issued. +- `token_failed`: The provider denied the authorization request, or an error occurred + while exchanging the authorization code for tokens. +- `userinfo_failed`: An error occurred while fetching the UserInfo from the provider. +- `expired`: The session has expired. + +> **Note**: The status of a session is not automatically set to `expired` at the +> moment it expires. It is marked as such only when `has_expired()` is called. In a +> cleanup job, you can proactively check for expiration using `created_at` or by +> invoking `has_expired()` before deciding to delete a session. + +## Sub + +The `Sub` model represents a unique user identifier (`sub`) from a given provider. +It is used to quickly find a user who has previously authenticated with a specific +provider, even if fields commonly used for identification (such as `username` or +`email`) have changed at the provider. + +A given `sub` is unique **per provider** (i.e., the `(provider, sub)` pair uniquely +identifies an external account). diff --git a/docs/settings.md b/docs/settings.md new file mode 100644 index 0000000..3d4a8f7 --- /dev/null +++ b/docs/settings.md @@ -0,0 +1,247 @@ +# Settings + +## Defining settings + +All settings live in a single `SIMPLE_OAUTH2` dictionary. +Each top-level key is the provider name; each value is a dict of that provider’s settings. + +The minimal required settings per provider are: `CLIENT_ID`, `CLIENT_SECRET`, `REDIRECT_URI`, +`POST_LOGOUT_REDIRECT_URI`, and `BASE_URL`. + +See below for an example configuration: + +```python +SIMPLE_OAUTH2 = { + "auth0": { + "CLIENT_ID": "", + "CLIENT_SECRET": "", + "BASE_URL": "..auth0.com", + "REDIRECT_URI": "http://localhost:8080/app/auth0/callback", + "POST_LOGOUT_REDIRECT_URI": "http://localhost:8080/app", + }, + "google": { + "CLIENT_ID": "", + "CLIENT_SECRET": "", + "BASE_URL": "accounts.google.com", + "REDIRECT_URI": "http://localhost:8080/app/google/callback", + "POST_LOGOUT_REDIRECT_URI": "http://localhost:8080/app", + }, +} +``` + +This assumes the provider exposes the standard OpenID Connect discovery document at +`/.well-known/openid-configuration`. When present, it is used to automatically +populate the remaining required settings: `AUTHORIZATION_PATH`, `TOKEN_PATH`, +`USERINFO_PATH`, `JWKS_PATH`, `LOGOUT_PATH`, and `SIGNING_ALGORITHMS`. + +If the provider serves its discovery document at a non-standard path, set +`OPENID_CONFIGURATION_PATH`. If no discovery document is available, you **must** +manually specify all required settings. + +Optional settings are listed below. + +## Accessing settings + +Use `simple_oauth2.settings.oauth2_settings`, which maps provider names to their +resolved settings: + +```python +from simple_oauth2.settings import oauth2_settings + +auth0_client_id = oauth2_settings["auth0"].CLIENT_ID +``` + +## Available settings + +The following settings are available for each provider. + +### `CLIENT_ID` + +*Required* + +Your OAuth2 client ID. + +### `CLIENT_SECRET` + +*Required* + +Your OAuth2 client secret. + +### `BASE_URL` + +*Required* + +The provider’s base domain. +Examples: Auth0 — `..auth0.com`; Google — `accounts.google.com`. + +### `REDIRECT_URI` + +*Required* + +The URI the provider redirects to after authorization. +Must match the value registered with the provider. + +### `POST_LOGOUT_REDIRECT_URI` + +*Required* + +The URI the provider redirects to after logout. +Must match the value registered with the provider. + +### `AUTHORIZATION_PATH` + +*Required; usually discovered via the provider’s OpenID configuration* + +Path to the authorization endpoint, typically `/authorize`. + +### `TOKEN_PATH` + +*Required; usually discovered via the provider’s OpenID configuration* + +Path to the token endpoint, typically `/oauth/token` or `/token`. + +### `USERINFO_PATH` + +*Required; usually discovered via the provider’s OpenID configuration* + +Path to the UserInfo endpoint, typically `/userinfo`. + +### `JWKS_PATH` + +*Required; usually discovered via the provider’s OpenID configuration* + +Path to the JWKS document, typically `/.well-known/jwks.json`. + +### `LOGOUT_PATH` + +*Required; usually discovered via the provider’s OpenID configuration* + +Path to the logout endpoint, typically `/logout`. + +### `SIGNING_ALGORITHMS` + +*Required; usually discovered via the provider’s OpenID configuration* + +List of algorithms used to sign ID tokens, e.g. `["RS256"]`. + +### `OPENID_CONFIGURATION_PATH` + +*Optional* + +Path to the provider’s OpenID configuration document. +Defaults to `/.well-known/openid-configuration`. Set this if the provider uses a non-standard path. + +### `SCOPES` + +*Optional* + +Scopes requested during authorization. +Defaults to `["openid", "profile", "email"]`. + +### `USE_PKCE` + +*Optional* + +Whether to use PKCE (Proof Key for Code Exchange). +Defaults to `True`. + +### `CODE_CHALLENGE_METHOD` + +*Optional* + +PKCE code challenge method. +Defaults to `S256`. Some providers only support `plain`. + +### `AUTHORIZATION_SESSION_LIFETIME` + +*Optional* + +Lifetime (in seconds) of an authorization session (from URL generation to the call to `/oauth2/token/`). +Defaults to `300` (5 minutes). + +### `AUTHORIZATION_EXTRA_PARAMETERS` + +*Optional* + +Extra query parameters to include in the authorization URL. +Some providers require additional parameters. +Defaults to `{}`. + +### `TOKEN_USERINFO_HANDLER` + +*Optional* + +Callable that creates/updates and returns the authenticated user using the ID token and UserInfo response. +Defaults to `simple_oauth2.utils.get_user`. + +Signature: + +```python +def custom_get_user( + provider, + userinfo: dict, + **kwargs: Any, +) -> models.Model: + ... +``` + +- `provider`: the provider settings used for authentication. +- `userinfo`: the UserInfo response (dict). +- `kwargs`: may contain additional items such as encoded `id_token` and `access_token`. Do **not** assume these are + always present. Tokens may include extra claims and can be decoded via `simple_oauth2.utils.decode_token`. + +Return a model instance representing the authenticated user. + +The contents of `userinfo` and the decoded ID token depend on the provider and requested scopes. At minimum, one of +them should include a `sub` claim that uniquely identifies the user at the provider. + +### `TOKEN_PAYLOAD_HANDLER` + +*Optional* + +Callable that builds the JSON payload returned by the `/oauth2/token/` endpoint. +Defaults to `simple_oauth2.utils.simple_jwt_authenticate`, which issues JWTs via `djangorestframework-simplejwt` and +returns the provider’s tokens. + +Signature: + +```python +def custom_token_payload( + provider, + oauth2_tokens: dict[str, str], + user: models.Model, +) -> dict: + ... +``` + +- `provider`: the provider settings used for authentication. +- `oauth2_tokens`: tokens returned by the provider (usually `access_token`, `id_token`, and optionally `refresh_token`). +- `user`: the user returned by `TOKEN_USERINFO_HANDLER`. + +Return a dict to be serialized as the `/oauth2/token/` response. + +It should at least include a valid logout URI (see `simple_oauth2.utils.simple_jwt_authenticate` for an example of +logout URI generation). + +### `VERIFY_SSL` + +*Optional* + +Whether the CA certificate must be verified when using `urllib.request` or +`requests`. +Default to `True`. + +### `TIMEOUT` + +*Optional* + +The number of seconds waiting for a response before issuing a timeout when using +`urllib.request` or `requests`. +Default to `5`. + +### `ALLOW_REDIRECTS` + +*Optional* + +Enable / disable redirection when using `urllib.request` or `requests`. +Default to `True`. diff --git a/manage.py b/manage.py new file mode 100755 index 0000000..2de3a7d --- /dev/null +++ b/manage.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +"""Django's command-line utility for administrative tasks.""" +import os +import sys + +sys.path.append(os.path.join(os.path.dirname(__file__), "tests/test_project/")) + + +def main(): + """Run administrative tasks.""" + os.environ.setdefault( + "DJANGO_SETTINGS_MODULE", "tests.test_project.test_project_conf.settings" + ) + try: + from django.core.management import execute_from_command_line + except ImportError as exc: + raise ImportError( + "Couldn't import Django. Are you sure it's installed and " + "available on your PYTHONPATH environment variable? Did you " + "forget to activate a virtual environment?" + ) from exc + execute_from_command_line(sys.argv) + + +if __name__ == "__main__": + main() diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000..085972f --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,26 @@ +site_name: DRF Simple OAuth2 +site_author: qcoumes +repo_url: https://github.com/Codoc-os/drf-simple-oauth2 +docs_dir: docs/ +use_directory_urls: true + +nav: + - "Gettings Started": + - Introduction: index.md + + - "Documentation": + - Settings: settings.md + - Models: models.md + - Authorization Flow: flow.md + + - "Other": + - Contributing: CONTRIBUTING.md + - Changelog: CHANGELOG.md + - Authors: AUTHORS.md + - License: LICENSE.md + - Github: https://github.com/Codoc-os/drf-simple-oauth2 + +extra_css: + - css/misc.css + +theme: readthedocs diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f2e3b2b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +django>=4.2.0,<6.0.0 +djangorestframework>=3.0.0,<4.0.0 +djangorestframework-simplejwt>=5.0.0,<6.0.0 +pyjwt>=2.6.0,<3.0.0 +requests>=2.28.0,<3.0.0 diff --git a/requirements_dev.txt b/requirements_dev.txt new file mode 100644 index 0000000..56a06be --- /dev/null +++ b/requirements_dev.txt @@ -0,0 +1,16 @@ +bandit +black +coverage +djangorestframework-simplejwt +django-extensions +flake8 +isort +mkdocs +mypy +pycodestyle +pydocstyle +pyflakes +pytest +pytest-cov +pytest-django +types-requests diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..ef38af8 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,112 @@ +######################## +##### Tox settings ##### +######################## +[tox:tox] +distshare = {homedir}/.tox/distshare + envlist = py{310,311,312,313}-django{42,51,52} +skipsdist = true +skip_missing_interpreters = true +indexserver = + pypi = https://pypi.python.org/simple + +[testenv] +setenv = + PYTHONPATH = {toxinidir} + DJANGO_SETTINGS_MODULE = tests.test_project.test_project_conf.settings +deps = + -rrequirements.txt + -rrequirements_dev.txt + django42: django>=4.2.0,<4.3.0 + django51: django>=5.1.0,<5.2.0 + django52: django>=5.2.0,<6.0.0 +commands = + python3 -m pytest --create-db -vvv -ss --color=yes --durations=0 --durations-min=1.0 --cov=. --cov-report term + coverage xml + + +######################## +### Checks settings #### +######################## +[pycodestyle] +count = True +max-line-length = 120 +max-doc-length = 100 +exclude = venv, .tox +ignore = W503, W504, W605, E121, E123, E126, E203, E501 +# W503: Line break occurred before a binary operator +# W504: Line break occurred after a binary operator +# W605: Invalid escape sequence +# E121: Continuation line under-indented for hanging indent +# E123: Closing bracket does not match indentation of opening bracket's line +# E126: Continuation line over-indented for hanging indent +# E203: Whitespace before ':' +# E501: Line too long + +[mypy] +ignore_missing_imports = True +no_implicit_optional = False +disable_error_code = attr-defined,index,valid-type,union-attr +# valid-type: To remove when dropping python 3.9 support + +[pydocstyle] +convention = numpy +match-dir = (?!tests|migrations|\.).* +match = (?!test_|conftest|manage\.py).*\.py +add_ignore = D100, D104, D105, D106 +# D100: Missing docstring in public module +# D104: Missing docstring in public package +# D105: Missing docstring in magic method +# D106: Missing docstring in public nested class + +[tool:isort] +profile = black +line_length = 120 +src_paths = simple_oauth2,tests + + +[bandit] +targets = simple_oauth2, tests +exclude = venv, .tox +recursive = True +quiet = True +format = custom +msg-template = {abspath}:{line} - {test_id} - {severity} - {msg} + + +############################# +##### Pytest settings ####### +############################# +[tool:pytest] +DJANGO_SETTINGS_MODULE = tests.test_project.test_project_conf.settings +filterwarnings = + error +pythonpath = tests/, tests/test_project/ + + + +############################# +##### Coverage settings ##### +############################# +[coverage:report] +exclude_lines = + pragma: no cover + def __repr__ + def __str__ + TYPE_CHECKING + raise NotImplementedError + @abstractmethod + if verbosity + if verbose +include = + simple_oauth2/* +omit = + venv/* + site-packages/* + +[coverage:run] +branch = True +source = + simple_oauth2 + +[coverage:html] +title = DRF Simple OAuth2's Coverage diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..9154456 --- /dev/null +++ b/setup.py @@ -0,0 +1,47 @@ +"""Setuptools entry point.""" +import codecs +import os + +from setuptools import setup + +DIRNAME = os.path.dirname(__file__) +CLASSIFIERS = [ + "Development Status :: 5 - Production/Stable", + "Framework :: Django :: 4.2", + "Framework :: Django :: 5.1", + "Framework :: Django :: 5.2", + "Intended Audience :: Developers", + "Natural Language :: English", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] +LONG_DESCRIPTION = ( + codecs.open(os.path.join(DIRNAME, "README.md"), encoding="utf-8").read() + + "\n" + + codecs.open(os.path.join(DIRNAME, "docs/CHANGELOG.md"), encoding="utf-8").read() +) +REQUIREMENTS = [ + "django>=4.2.0,<6.0.0", + "djangorestframework>=3.0.0,<4.0.0", + "djangorestframework-simplejwt>=5.0.0,<6.0.0", + "pyjwt>=2.6.0,<3.0.0", + "requests>=2.32.0,<3.0.0", +] + +setup( + name="simple_oauth2", + version="1.0.0", + description=""" Simple OAuth2 client package allowing to define OAuth2 / OpenID providers through settings. """, + long_description=LONG_DESCRIPTION, + long_description_content_type="text/markdown", + author="Quentin Coumes (Codoc)", + author_email="quentin@codoc.co", + url="https://github.com/Codoc-os/drf-simple-oauth2", + packages=["simple_oauth2"], + include_package_data=True, + install_requires=REQUIREMENTS, + keywords="django simple_oauth2 oauth2 oauth openid authentication", + classifiers=CLASSIFIERS, +) diff --git a/simple_oauth2/__init__.py b/simple_oauth2/__init__.py new file mode 100644 index 0000000..5becc17 --- /dev/null +++ b/simple_oauth2/__init__.py @@ -0,0 +1 @@ +__version__ = "1.0.0" diff --git a/simple_oauth2/apps.py b/simple_oauth2/apps.py new file mode 100644 index 0000000..a608679 --- /dev/null +++ b/simple_oauth2/apps.py @@ -0,0 +1,8 @@ +from django.apps import AppConfig + + +class SimpleOauth2Config(AppConfig): + """Configuration for the simple_oauth2 app.""" + + default_auto_field = "django.db.models.BigAutoField" + name = "simple_oauth2" diff --git a/simple_oauth2/enums.py b/simple_oauth2/enums.py new file mode 100644 index 0000000..72ce485 --- /dev/null +++ b/simple_oauth2/enums.py @@ -0,0 +1,11 @@ +from django.db import models + + +class Status(models.TextChoices): + """Valid status values for a Session.""" + + PENDING = "pending" + TOKEN_FAILED = "token_failed" + USERINFO_FAILED = "userinfo_failed" + EXPIRED = "expired" + COMPLETED = "completed" diff --git a/simple_oauth2/exceptions.py b/simple_oauth2/exceptions.py new file mode 100644 index 0000000..c267a3a --- /dev/null +++ b/simple_oauth2/exceptions.py @@ -0,0 +1,10 @@ +class SimpleOAuth2Error(Exception): + """Base exception for simple_oauth2's exceptions.""" + + +class UnknownProvider(SimpleOAuth2Error): + """Raised when an unknown OAuth2 provider is referenced.""" + + def __init__(self, provider: str): + self.provider = provider + super().__init__(f"Unknown OAuth2 provider '{provider}'") diff --git a/simple_oauth2/migrations/0001_initial.py b/simple_oauth2/migrations/0001_initial.py new file mode 100644 index 0000000..8d03da4 --- /dev/null +++ b/simple_oauth2/migrations/0001_initial.py @@ -0,0 +1,112 @@ +# Generated by Django 4.1.13 on 2025-09-08 04:51 + +import django.core.validators +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + +import simple_oauth2.utils + + +class Migration(migrations.Migration): + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="Session", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("_provider", models.CharField(db_column="provider", max_length=255)), + ( + "nonce", + models.CharField(default=simple_oauth2.utils.generate_nonce, max_length=255), + ), + ( + "state", + models.CharField(default=simple_oauth2.utils.generate_state, max_length=255), + ), + ( + "code_verifier", + models.CharField( + default=simple_oauth2.utils.generate_code_verifier, + max_length=128, + validators=[ + django.core.validators.RegexValidator( + message="code_verifier must be 43–128 chars of unreserved URL characters (RFC 7636).", + regex="^[A-Za-z0-9\\-._~]{43,128}$", + ) + ], + ), + ), + ( + "status", + models.CharField( + choices=[ + ("PENDING", "Pending"), + ("TOKEN_FAILED", "Token Failed"), + ("USERINFO_FAILED", "Userinfo Failed"), + ("EXPIRED", "Expired"), + ("COMPLETED", "Completed"), + ], + default="PENDING", + max_length=20, + ), + ), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("completed_at", models.DateTimeField(null=True)), + ], + ), + migrations.CreateModel( + name="Sub", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("_provider", models.CharField(db_column="provider", max_length=255)), + ("sub", models.CharField(max_length=255)), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("last_login", models.DateTimeField(auto_now=True)), + ( + "user", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + ), + migrations.AddIndex( + model_name="session", + index=models.Index(fields=["created_at"], name="simple_oaut_created_7c172d_idx"), + ), + migrations.AddConstraint( + model_name="session", + constraint=models.UniqueConstraint(fields=("_provider", "state"), name="session_provider_state_key"), + ), + migrations.AddIndex( + model_name="sub", + index=models.Index(fields=["sub"], name="simple_oaut_sub_7eb942_idx"), + ), + migrations.AddConstraint( + model_name="sub", + constraint=models.UniqueConstraint(fields=("_provider", "sub"), name="sub_provider_sub_key"), + ), + ] diff --git a/simple_oauth2/migrations/__init__.py b/simple_oauth2/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/simple_oauth2/models.py b/simple_oauth2/models.py new file mode 100644 index 0000000..c9b6ff7 --- /dev/null +++ b/simple_oauth2/models.py @@ -0,0 +1,192 @@ +import base64 +import datetime +import hashlib +import urllib.parse +from typing import Any + +import requests +from django.contrib.auth import get_user_model +from django.core.exceptions import NON_FIELD_ERRORS +from django.core.validators import RegexValidator +from django.db import IntegrityError, models, transaction +from django.utils import timezone +from rest_framework.exceptions import ValidationError + +from simple_oauth2 import utils +from simple_oauth2.enums import Status +from simple_oauth2.exceptions import SimpleOAuth2Error, UnknownProvider +from simple_oauth2.settings import OAuth2ProviderSettings, oauth2_settings + +PKCE_VALIDATOR = RegexValidator( + regex=r"^[A-Za-z0-9\-._~]{43,128}$", + message="code_verifier must be 43–128 chars of unreserved URL characters (RFC 7636).", +) + + +class Session(models.Model): + """Represent an OAuth2 authorization session.""" + + _provider = models.CharField(max_length=255, db_column="provider") + + nonce = models.CharField(max_length=255, default=utils.generate_nonce) + state = models.CharField(max_length=255, default=utils.generate_state) + code_verifier = models.CharField(max_length=128, default=utils.generate_code_verifier, validators=(PKCE_VALIDATOR,)) + + status = models.CharField(max_length=20, choices=Status.choices, default=Status.PENDING) + created_at = models.DateTimeField(auto_now_add=True) + completed_at = models.DateTimeField(null=True) + + class Meta: + constraints = (models.UniqueConstraint(fields=("_provider", "state"), name="session_provider_state_key"),) + indexes = (models.Index(fields=("created_at",)),) + + @classmethod + def start(cls, provider: str) -> "Session": + """Try creating a unique Session while avoiding race condition.""" + if provider not in oauth2_settings: + raise UnknownProvider(provider) + for _ in range(10): + try: + with transaction.atomic(): + return Session.objects.create(_provider=provider) + except IntegrityError: + continue + raise SimpleOAuth2Error("Could not create a unique authorization session after 10 attempts.") + + @property + def provider(self) -> OAuth2ProviderSettings: + """Return the provider configuration associated with this Session.""" + if self._provider in oauth2_settings: + return oauth2_settings[self._provider] + raise UnknownProvider(self._provider) + + @property + def use_pkce(self) -> bool: + """Return whether or not the Session uses PKCE.""" + return self.provider.USE_PKCE + + @property + def code_challenge_method(self) -> str: + """Return the code challenge method associated with this Session.""" + return self.provider.CODE_CHALLENGE_METHOD + + @property + def code_challenge(self) -> str: + """Return the code challenge from code_verifier and code_challenge_method.""" + if self.code_challenge_method.lower() == "plain": + return self.code_verifier + + elif self.code_challenge_method.lower() == "s256": + digest = hashlib.sha256(self.code_verifier.encode("ascii")).digest() + return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + + raise SimpleOAuth2Error(f"Unknown challenge method: '{self.code_challenge_method}'") + + def has_expired(self) -> bool: + """Return whether the authorization session as expired.""" + now = timezone.now().astimezone(datetime.timezone.utc) + if timezone.is_naive(self.created_at): + ts = timezone.make_aware(self.created_at, datetime.timezone.utc) + else: + ts = self.created_at.astimezone(datetime.timezone.utc) + if (now - ts).total_seconds() > self.provider.AUTHORIZATION_SESSION_LIFETIME: + return True + return False + + def authentication_url(self) -> str: + """Generate the OAuth2 authentication URL for the given alias.""" + params = { + "response_type": "code", + "client_id": self.provider.CLIENT_ID, + "scope": " ".join(self.provider.SCOPES), + "nonce": self.nonce, + "state": self.state, + "redirect_uri": self.provider.REDIRECT_URI, + } + if self.use_pkce: + params |= { + "code_challenge": self.code_challenge, + "code_challenge_method": self.provider.CODE_CHALLENGE_METHOD, + } + for key, value in self.provider.AUTHORIZATION_EXTRA_PARAMETERS.items(): + params[key] = value + return f"{self.provider.authorization_uri()}?{urllib.parse.urlencode(params)}" + + def get_tokens(self, code: str) -> dict[str, str]: + """Exchange the authorization code for tokens.""" + if self.has_expired(): + self.status = Status.EXPIRED + self.save(update_fields=["status"]) + raise ValidationError({NON_FIELD_ERRORS: "Authorization session has expired."}) + + data = { + "grant_type": "authorization_code", + "client_id": self.provider.CLIENT_ID, + "client_secret": self.provider.CLIENT_SECRET, + "code": code, + "redirect_uri": self.provider.REDIRECT_URI, + } + if self.use_pkce: + data["code_verifier"] = self.code_verifier + + try: + response = requests.post( + self.provider.token_uri(), + data=data, + timeout=self.provider.TIMEOUT, + verify=self.provider.VERIFY_SSL, + allow_redirects=self.provider.ALLOW_REDIRECTS, + ) + response.raise_for_status() + except requests.RequestException as e: + self.status = Status.TOKEN_FAILED + self.save(update_fields=("status",)) + raise ValidationError( + {"__all__": f"Failed to retrieve tokens from provider: {e.response.content.decode()}"} + ) + + return response.json() + + def get_user(self, access_token: str, **kwargs: Any) -> models.Model: + """Fetch the user infos from the provider and feet it to the handler.""" + try: + response = requests.get( + self.provider.userinfo_uri(), + headers={"Authorization": f"Bearer {access_token}"}, + timeout=self.provider.TIMEOUT, + verify=self.provider.VERIFY_SSL, + allow_redirects=self.provider.ALLOW_REDIRECTS, + ) + response.raise_for_status() + except requests.RequestException as e: + self.status = Status.USERINFO_FAILED + self.save(update_fields=("status",)) + raise ValidationError( + {"__all__": f"Failed to fetch user info from the provider: {e.response.content.decode()}."} + ) + + user = self.provider.TOKEN_USERINFO_HANDLER(self.provider, response.json(), **kwargs) + + return user + + def get_payload(self, oauth2_tokens: dict[str, str], user: models.Model) -> dict: + """Fetch the payload from the provider and feed it to the handler.""" + payload = self.provider.TOKEN_PAYLOAD_HANDLER(self.provider, oauth2_tokens, user) + self.status = Status.COMPLETED + self.completed_at = timezone.now() + self.save(update_fields=("status", "completed_at")) + return payload + + +class Sub(models.Model): + """Link a user to a given OAuth2 provider and its sub.""" + + _provider = models.CharField(max_length=255, db_column="provider") + sub = models.CharField(max_length=255) + user = models.ForeignKey(get_user_model(), on_delete=models.CASCADE) + created_at = models.DateTimeField(auto_now_add=True) + last_login = models.DateTimeField(auto_now=True) + + class Meta: + constraints = (models.UniqueConstraint(fields=("_provider", "sub"), name="sub_provider_sub_key"),) + indexes = (models.Index(fields=("sub",)),) diff --git a/simple_oauth2/serializers.py b/simple_oauth2/serializers.py new file mode 100644 index 0000000..6a9b96c --- /dev/null +++ b/simple_oauth2/serializers.py @@ -0,0 +1,9 @@ +from rest_framework import serializers + + +class TokenSerializer(serializers.Serializer): + """Serializer for the token endpoint input data.""" + + provider = serializers.CharField() + code = serializers.CharField() + state = serializers.CharField() diff --git a/simple_oauth2/settings.py b/simple_oauth2/settings.py new file mode 100644 index 0000000..b9deab1 --- /dev/null +++ b/simple_oauth2/settings.py @@ -0,0 +1,148 @@ +import logging +import urllib.parse +from typing import Any, Callable + +import requests +from django.conf import settings +from django.core.exceptions import ImproperlyConfigured +from django.utils.module_loading import import_string + +logger = logging.getLogger(__name__) + +# Default settings for OAuth2 providers +DEFAULTS = { + "OPENID_CONFIGURATION_PATH": "/.well-known/openid-configuration", + "AUTHORIZATION_SESSION_LIFETIME": 300, # 5 minutes + "AUTHORIZATION_EXTRA_PARAMETERS": {}, + "TOKEN_USERINFO_HANDLER": "simple_oauth2.utils.get_user", + "TOKEN_PAYLOAD_HANDLER": "simple_oauth2.utils.simple_jwt_authenticate", + "CODE_CHALLENGE_METHOD": "S256", + "SCOPES": ["openid", "profile", "email"], + "USE_PKCE": True, + "VERIFY_SSL": True, + "TIMEOUT": 5, + "ALLOW_REDIRECTS": True, +} + +# Mapping of settings to OpenID Connect configuration keys +CONFIGURATION_KEY = { + "AUTHORIZATION_PATH": "authorization_endpoint", + "TOKEN_PATH": "token_endpoint", + "USERINFO_PATH": "userinfo_endpoint", + "LOGOUT_PATH": "end_session_endpoint", + "JWKS_PATH": "jwks_uri", + "SIGNING_ALGORITHMS": "id_token_signing_alg_values_supported", +} + +# Settings that may be imported from strings +IMPORT_STRINGS = {"TOKEN_USERINFO_HANDLER", "TOKEN_PAYLOAD_HANDLER"} + +# Mandatory settings that must be either loaded from OpenID configuration, +# or provided by the user +MANDATORY = { + "CLIENT_ID", + "CLIENT_SECRET", + "REDIRECT_URI", + "POST_LOGOUT_REDIRECT_URI", + "BASE_URL", + "AUTHORIZATION_PATH", + "TOKEN_PATH", + "USERINFO_PATH", + "JWKS_PATH", + "LOGOUT_PATH", + "SIGNING_ALGORITHMS", +} + + +def import_from_string(v: str, provider: str, setting_name: str) -> type: + """Attempt to import a class from a string representation.""" + try: + return import_string(v) + except ImportError as e: # pragma: no cover + raise ImportError( + f"Could not import {v} for SIMPLE_OAUTH2 setting '{provider}[{setting_name}]' {e.__class__.__name__}: {e}." + ) + + +class OAuth2ProviderSettings: + """ + A settings object, that allows OAuth2 Provider settings to be accessed as properties. + + Any setting with string import paths will be automatically resolved + and return the class, rather than the string literal. + """ + + def __init__(self, alias: str, user_settings: dict): + self._alias = alias + + # Try retrieving some settings from the provider's OpenID configuration + user_settings = DEFAULTS | user_settings + url = urllib.parse.urljoin(user_settings["BASE_URL"], user_settings["OPENID_CONFIGURATION_PATH"]) + provider_settings = self._load_settings_from_provider( + url, user_settings["TIMEOUT"], user_settings["VERIFY_SSL"], user_settings["ALLOW_REDIRECTS"] + ) + + settings = provider_settings | user_settings + if missing := MANDATORY - {k for k, v in settings.items() if v}: # pragma: no cover + raise ImproperlyConfigured( + f"OAuth2 provider '{self.alias}' is missing mandatory settings: {', '.join(missing)}" + ) + + self._user_settings = settings + + def __getattr__(self, attr: str) -> Any: + """Return the setting value or raise an AttributeError.""" + if attr not in self._user_settings: # pragma: no cover + raise AttributeError(f"Invalid SIMPLE_OAUTH2 setting: '{self.alias}[{attr}]'") + if attr in IMPORT_STRINGS: + return self._perform_import(self._user_settings[attr], attr) + return self._user_settings[attr] + + def _load_settings_from_provider( + self, url: str, timeout: int, verify: bool, allow_redirects: bool + ) -> dict[str, Any]: + """Load settings from the provider's OpenID configuration endpoint.""" + try: + response = requests.get(url, timeout=timeout, verify=verify, allow_redirects=allow_redirects) + response.raise_for_status() + except requests.RequestException as e: # pragma: no cover + logger.warning("Could not fetch '%s' OpenID configuration from '%s': %s", self.alias, url, e) + return {} + configuration = response.json() + return {setting: configuration[path] for setting, path in CONFIGURATION_KEY.items()} + + def _perform_import(self, value: str, setting_name: str) -> Callable: # pragma: no cover + """Import a class from a string representation.""" + if isinstance(value, str): + return import_from_string(value, self.alias, setting_name) + elif isinstance(value, Callable): + return value + raise ImproperlyConfigured( + f"SIMPLE_OAUTH2 setting '{self.alias}[{setting_name}]' must be a string or callable." + ) + + @property + def alias(self) -> str: + """Return the provider alias.""" + return self._alias + + def jwks_uri(self) -> str: + """Return the full JWKS URL.""" + return urllib.parse.urljoin(self.BASE_URL, self.JWKS_PATH) + + def authorization_uri(self) -> str: + """Return the full authorization URL.""" + return urllib.parse.urljoin(self.BASE_URL, self.AUTHORIZATION_PATH) + + def token_uri(self) -> str: + """Return the full token URL.""" + return urllib.parse.urljoin(self.BASE_URL, self.TOKEN_PATH) + + def userinfo_uri(self) -> str: + """Return the full token URL.""" + return urllib.parse.urljoin(self.BASE_URL, self.USERINFO_PATH) + + +oauth2_settings = { + alias: OAuth2ProviderSettings(alias, settings) for alias, settings in getattr(settings, "SIMPLE_OAUTH2", {}).items() +} diff --git a/simple_oauth2/urls.py b/simple_oauth2/urls.py new file mode 100644 index 0000000..cd7eb67 --- /dev/null +++ b/simple_oauth2/urls.py @@ -0,0 +1,13 @@ +from django.urls import include, path +from rest_framework import routers + +from simple_oauth2 import views + +app_name = "simple_oauth2" + +router = routers.SimpleRouter() +router.register(r"oauth2", views.OAuth2ViewSet, basename="oauth2") + +urlpatterns = [ + path("", include(router.urls)), +] diff --git a/simple_oauth2/utils.py b/simple_oauth2/utils.py new file mode 100644 index 0000000..74edb1f --- /dev/null +++ b/simple_oauth2/utils.py @@ -0,0 +1,86 @@ +import secrets +import ssl +import string +import urllib.parse +from typing import Any + +import jwt +from django.contrib.auth import get_user_model +from django.db import models +from rest_framework.exceptions import ValidationError + +from simple_oauth2.settings import OAuth2ProviderSettings + + +def generate_nonce(size: int = 128) -> str: + """Generate a random nonce as an hexadecimal string.""" + return secrets.token_hex(size // 2) + + +def generate_state(size: int = 128) -> str: + """Generate a random state as an hexadecimal string.""" + return secrets.token_hex(size // 2) + + +def generate_code_verifier(size: int = 128) -> str: + """Generate a random code verifier.""" + if not (43 <= size <= 128): + raise ValueError("code_verifier must be 43..128 chars (RFC 7636).") + return "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(size)) + + +def decode_jwt(provider: OAuth2ProviderSettings, token: str) -> dict: + """Decode a JWT token using the provider's JKWS URI.""" + try: + ssl_context = ssl._create_unverified_context() if not provider.VERIFY_SSL else None # nosec: B323 + jwks_client = jwt.PyJWKClient(provider.jwks_uri(), ssl_context=ssl_context, timeout=provider.TIMEOUT) + key = jwks_client.get_signing_key_from_jwt(token).key + except jwt.PyJWKClientConnectionError as e: # pragma: no cover + raise ValidationError({"__all__": f"Failed to fetch JKWS from the provider: {e}"}) + return jwt.decode(token, key, algorithms=provider.SIGNING_ALGORITHMS, audience=provider.CLIENT_ID) + + +def get_user(provider: OAuth2ProviderSettings, userinfo: dict, **kwargs: Any) -> models.Model: + """Create or update a user from the id_token and the userinfo dictionary.""" + from simple_oauth2.models import Sub + + User = get_user_model() + + claims = decode_jwt(provider=provider, token=kwargs["id_token"]) if "id_token" in kwargs else {} + sub_value = claims.get("sub") or userinfo.get("sub") + username = (claims.get("preferred_username") or userinfo.get("preferred_username")) or ( + claims.get("email") or userinfo.get("email") + ) + try: + sub = Sub.objects.get(_provider=provider.alias, sub=sub_value) + sub.save(update_fields=["last_login"]) # Update the 'last_login' timestamp + user = sub.user + except Sub.DoesNotExist: + user = User.objects.create_user(username=username) + Sub.objects.create(user=user, _provider=provider.alias, sub=sub_value) + + user.first_name = claims.get("given_name", "") or userinfo.get("given_name", "") + user.last_name = claims.get("family_name", "") or userinfo.get("family_name", "") + user.email = claims.get("email", "") or userinfo.get("email", "") + user.save() + + return user + + +def simple_jwt_authenticate( + provider: OAuth2ProviderSettings, oauth2_tokens: dict[str, str], user: models.Model +) -> dict: + """Create a simple JWT token payload.""" + from rest_framework_simplejwt.tokens import RefreshToken + + params = { + "client_id": provider.CLIENT_ID, + "id_token_hint": oauth2_tokens["id_token"], + "post_logout_redirect_uri": provider.POST_LOGOUT_REDIRECT_URI, + } + url = urllib.parse.urljoin(provider.BASE_URL, provider.LOGOUT_PATH) + refresh = RefreshToken.for_user(user) + return { + "api": {"refresh": str(refresh), "access": str(refresh.access_token)}, + "provider": oauth2_tokens | {"logout_url": f"{url}?{urllib.parse.urlencode(params)}"}, + } diff --git a/simple_oauth2/views.py b/simple_oauth2/views.py new file mode 100644 index 0000000..efdaaea --- /dev/null +++ b/simple_oauth2/views.py @@ -0,0 +1,56 @@ +from rest_framework import viewsets +from rest_framework.decorators import action +from rest_framework.exceptions import ValidationError +from rest_framework.request import Request +from rest_framework.response import Response + +from simple_oauth2.enums import Status +from simple_oauth2.exceptions import UnknownProvider +from simple_oauth2.models import Session +from simple_oauth2.serializers import TokenSerializer +from simple_oauth2.settings import oauth2_settings + + +class OAuth2ViewSet(viewsets.GenericViewSet): + """Allows authentication through FranceConnect.""" + + authentication_classes: list[type] = [] + permission_classes: list[type] = [] + serializer_class = TokenSerializer + + @action(detail=False, methods=("get",)) + def url(self, request: Request) -> Response: + """Return the OAuth2 URL the login button must use.""" + if (provider := request.query_params.get("provider")) is None: + raise ValidationError( + {"provider": f"OAuth2 provider required, allowed values are: '{', '.join(oauth2_settings)}'"} + ) + try: + session = Session.start(provider) + except UnknownProvider: + raise ValidationError( + { + "provider": f"Unknown OAuth2 provider '{provider}', allowed values are: '{', '.join(oauth2_settings)}'" + } + ) + return Response({"url": session.authentication_url()}) + + @action(detail=False, methods=("post",), serializer_class=TokenSerializer) + def token(self, request: Request) -> Response: + """Return tokens using the code obtained from your provider.""" + serializer = self.get_serializer(data=self.request.data) + serializer.is_valid(raise_exception=True) + + try: + session = Session.objects.get( + _provider=serializer.validated_data["provider"], + state=serializer.validated_data["state"], + status=Status.PENDING, + ) + except Session.DoesNotExist: + raise ValidationError({"state": "No ongoing session matches the provided state."}) + + oauth2_tokens = session.get_tokens(serializer.validated_data["code"]) + user = session.get_user(**oauth2_tokens) + payload = session.get_payload(oauth2_tokens, user) + return Response(payload) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..af8f71f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,140 @@ +import functools +from types import SimpleNamespace +from typing import Any +from unittest import mock + +import jwt +import requests + +from simple_oauth2.settings import OAuth2ProviderSettings + +CONFIGURATION_RESPONSE = { + "authorization_endpoint": "https://example.com/authorize", + "token_endpoint": "https://example.com/token", + "userinfo_endpoint": "https://example.com/userinfo", + "jwks_uri": "https://example.com/.well-known/jwks.json", + "end_session_endpoint": "https://example.com/logout", + "login_endpoint": "https://example.com/login", + "id_token_signing_alg_values_supported": ["plain", "HS256"], +} + +SIMPLE_OAUTH2_SETTINGS = { + "pkce-plain": { + "CLIENT_ID": "pkce-plain", + "CLIENT_SECRET": "pkce-plain", + "BASE_URL": "https://example.com", + "REDIRECT_URI": "https://example.com/callback", + "POST_LOGOUT_REDIRECT_URI": "https://example.com/logout", + "CODE_CHALLENGE_METHOD": "plain", + }, + "pkce-s256": { + "CLIENT_ID": "pkce-s256", + "CLIENT_SECRET": "pkce-s256", + "BASE_URL": "https://example.com", + "REDIRECT_URI": "https://example.com/callback", + "POST_LOGOUT_REDIRECT_URI": "https://example.com/logout", + "CODE_CHALLENGE_METHOD": "s256", + }, + "pkce-unknown-alg": { + "CLIENT_ID": "pkce-s256", + "CLIENT_SECRET": "pkce-s256", + "BASE_URL": "https://example.com", + "REDIRECT_URI": "https://example.com/callback", + "POST_LOGOUT_REDIRECT_URI": "https://example.com/logout", + "CODE_CHALLENGE_METHOD": "unknown", + }, + "no-pkce": { + "CLIENT_ID": "no-pkce", + "CLIENT_SECRET": "no-pkce", + "BASE_URL": "https://example.com", + "REDIRECT_URI": "https://example.com/callback", + "POST_LOGOUT_REDIRECT_URI": "https://example.com/logout", + "USE_PKCE": False, + }, + "extra-params": { + "CLIENT_ID": "extra-params", + "CLIENT_SECRET": "extra-params", + "BASE_URL": "https://example.com", + "REDIRECT_URI": "https://example.com/callback", + "POST_LOGOUT_REDIRECT_URI": "https://example.com/logout", + "USE_PKCE": False, + "AUTHORIZATION_EXTRA_PARAMETERS": {"foo": "bar", "baz": "qux"}, + }, + "token-fails": { + "CLIENT_ID": "token-fails", + "CLIENT_SECRET": "token-fails", + "BASE_URL": "https://example.com", + "REDIRECT_URI": "https://example.com/callback", + "POST_LOGOUT_REDIRECT_URI": "https://example.com/logout", + "USE_PKCE": False, + "TOKEN_PATH": "/unknown", + }, + "userinfo-fails": { + "CLIENT_ID": "userinfo-fails", + "CLIENT_SECRET": "userinfo-fails", + "BASE_URL": "https://example.com", + "REDIRECT_URI": "https://example.com/callback", + "POST_LOGOUT_REDIRECT_URI": "https://example.com/logout", + "USE_PKCE": False, + "USERINFO_PATH": "/unknown", + }, +} + + +def raise_request_exception(): + """Helper to raise a requests.RequestException from a lambda.""" + raise requests.RequestException(response=SimpleNamespace(status_code="404", content='{"detail": "error"}'.encode())) + + +def mocked_requests(method, url, data=None, **kwargs): + """Mocked requests.request return values.""" + match (method.upper(), url): + case "GET", "https://example.com/.well-known/openid-configuration": + return SimpleNamespace( + status_code="200", json=lambda: CONFIGURATION_RESPONSE, raise_for_status=lambda: None + ) + case "POST", "https://example.com/token": + return SimpleNamespace( + status_code="200", + json=lambda: { + "access_token": "abcdefghijklmnopqrstuwxz", + "id_token": jwt.encode( + { + "sub": "1234567890", + "email": "test@test.com", + "preferred_username": "test", + "aud": data["code"][:-5], + }, + "key", + algorithm="HS256", + ), + }, + raise_for_status=lambda: None, + ) + case "GET", "https://example.com/userinfo": + return SimpleNamespace( + status_code="200", + json=lambda: {"sub": "1234567890", "email": "test@test.com", "preferred_username": "test"}, + raise_for_status=lambda: None, + ) + case _: + return SimpleNamespace( + status_code="404", json=lambda: {"detail": "Not found."}, raise_for_status=raise_request_exception() + ) + + +class override_oauth2_settings: + """Patch 'settings.oauth2_settings' with the provided SIMPLE_OAUTH2 dict.""" + + def __init__(self, simple_oauth2: dict): + self.simple_oauth2 = simple_oauth2 + + def __call__(self, func): + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any): + with mock.patch("requests.api.request", side_effect=mocked_requests): + new = {alias: OAuth2ProviderSettings(alias, settings) for alias, settings in self.simple_oauth2.items()} + with mock.patch.dict("simple_oauth2.settings.oauth2_settings", new, clear=True): + return func() + + return wrapper diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..aed0f6e --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,87 @@ +from contextlib import contextmanager +from datetime import timedelta +from typing import Any, Callable, TypeVar + +import pytest +from django.db import models +from django.utils import timezone + +from simple_oauth2.exceptions import SimpleOAuth2Error, UnknownProvider +from simple_oauth2.models import Session +from simple_oauth2.utils import generate_state +from tests.conftest import SIMPLE_OAUTH2_SETTINGS, override_oauth2_settings + +STATE1 = generate_state() +STATE2 = generate_state() + +T = TypeVar("T") + + +@contextmanager +def override_field_default(model: type[models.Model], field_name: str, default: Callable[[], Any]): + field = model._meta.get_field(field_name) + old = field.default + try: + field._get_default = default + field.default = default + yield + finally: + field.default = old + field._get_default = old + + +class ValuesDefault: + """Callable class to return values from a list one after the other.""" + + def __init__(self, values: list[T]): + self.values = values + self.index = 0 + + def __call__(self) -> T: + value = self.values[self.index] + self.index = (self.index + 1) % len(self.values) + return value + + +@pytest.mark.django_db +@override_oauth2_settings(SIMPLE_OAUTH2_SETTINGS) +def test_start_multiple_try(): + with override_field_default(Session, "state", ValuesDefault([STATE1] * 5 + [STATE2])): + session = Session.start("no-pkce") + assert session.state == STATE1 + session2 = Session.start("no-pkce") + assert session2.state == STATE2 + + +@pytest.mark.django_db +@override_oauth2_settings(SIMPLE_OAUTH2_SETTINGS) +def test_start_integrity_error(): + with override_field_default(Session, "state", ValuesDefault([STATE1] * 11)): + Session.start("no-pkce") + with pytest.raises( + SimpleOAuth2Error, match="Could not create a unique authorization session after 10 attempts." + ): + Session.start("no-pkce") + + +@override_oauth2_settings(SIMPLE_OAUTH2_SETTINGS) +def test_provider_unknown(): + session = Session(_provider="unknown") + with pytest.raises(UnknownProvider, match="Unknown OAuth2 provider 'unknown'"): + _ = session.provider + + +@override_oauth2_settings(SIMPLE_OAUTH2_SETTINGS) +def test_has_expired(): + session = Session(_provider="no-pkce", created_at=timezone.now()) + assert not session.has_expired() + session.created_at = timezone.now() - timedelta(days=1) + assert session.has_expired() + + +@override_oauth2_settings(SIMPLE_OAUTH2_SETTINGS) +def test_has_expired_naive(): + session = Session(_provider="no-pkce", created_at=timezone.now().replace(tzinfo=None)) + assert not session.has_expired() + session.created_at = (timezone.now() - timedelta(days=1)).replace(tzinfo=None) + assert session.has_expired() diff --git a/tests/test_project/__init__.py b/tests/test_project/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_project/test_project_conf/__init__.py b/tests/test_project/test_project_conf/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_project/test_project_conf/asgi.py b/tests/test_project/test_project_conf/asgi.py new file mode 100644 index 0000000..2ed72e0 --- /dev/null +++ b/tests/test_project/test_project_conf/asgi.py @@ -0,0 +1,16 @@ +""" +ASGI config for test_project project. + +It exposes the ASGI callable as a module-level variable named ``application``. + +For more information on this file, see +https://docs.djangoproject.com/en/4.1/howto/deployment/asgi/ +""" + +import os + +from django.core.asgi import get_asgi_application + +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "test_project_conf.settings") + +application = get_asgi_application() diff --git a/tests/test_project/test_project_conf/settings.py b/tests/test_project/test_project_conf/settings.py new file mode 100644 index 0000000..7116dd4 --- /dev/null +++ b/tests/test_project/test_project_conf/settings.py @@ -0,0 +1,121 @@ +""" +Django settings for test_project project. + +Generated by 'django-admin startproject' using Django 4.1.3. + +For more information on this file, see +https://docs.djangoproject.com/en/4.1/topics/settings/ + +For the full list of settings and their values, see +https://docs.djangoproject.com/en/4.1/ref/settings/ +""" + +import os +import sys +from pathlib import Path + +# Build paths inside the project like this: BASE_DIR / 'subdir'. +BASE_DIR = Path(__file__).resolve().parent.parent + +sys.path.append(str(BASE_DIR)) + +# Quick-start development settings - unsuitable for production +# See https://docs.djangoproject.com/en/4.1/howto/deployment/checklist/ + +# SECURITY WARNING: keep the secret key used in production secret! +SECRET_KEY = "django-insecure-o2p02q(1rym0^aqt4=3$w+mokgz^#je^*l&9v-v_ma71276je#" + +# SECURITY WARNING: don't run with debug turned on in production! +DEBUG = True + +ALLOWED_HOSTS = [] + +# Application definition +INSTALLED_APPS = [ + "django.contrib.admin", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", + "simple_oauth2", +] + +MIDDLEWARE = [ + "django.middleware.security.SecurityMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", + "django.middleware.common.CommonMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", + "django.middleware.clickjacking.XFrameOptionsMiddleware", +] + +ROOT_URLCONF = "test_project_conf.urls" + +TEMPLATES = [ + { + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", + ], + }, + }, +] + +WSGI_APPLICATION = "test_project_conf.wsgi.application" + +# Database +# https://docs.djangoproject.com/en/4.1/ref/settings/#databases + +DATABASES = { + "default": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": os.path.join(BASE_DIR, "db.sqlite3"), + } +} + +# Password validation +# https://docs.djangoproject.com/en/4.1/ref/settings/#auth-password-validators + +AUTH_PASSWORD_VALIDATORS = [ + { + "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator", + }, + { + "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator", + }, + { + "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator", + }, + { + "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator", + }, +] + +# Internationalization +# https://docs.djangoproject.com/en/4.1/topics/i18n/ + +LANGUAGE_CODE = "en-us" + +TIME_ZONE = "UTC" + +USE_I18N = True + +USE_TZ = True + +# Static files (CSS, JavaScript, Images) +# https://docs.djangoproject.com/en/4.1/howto/static-files/ + +STATIC_URL = "static/" + +# Default primary key field type +# https://docs.djangoproject.com/en/4.1/ref/settings/#default-auto-field + +DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" diff --git a/tests/test_project/test_project_conf/urls.py b/tests/test_project/test_project_conf/urls.py new file mode 100644 index 0000000..b40e7f2 --- /dev/null +++ b/tests/test_project/test_project_conf/urls.py @@ -0,0 +1,21 @@ +"""test_project URL Configuration + +The `urlpatterns` list routes URLs to views. For more information please see: + https://docs.djangoproject.com/en/4.1/topics/http/urls/ +Examples: +Function views + 1. Add an import: from my_app import views + 2. Add a URL to urlpatterns: path('', views.home, name='home') +Class-based views + 1. Add an import: from other_app.views import Home + 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') +Including another URLconf + 1. Import the include() function: from django.urls import include, path + 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) +""" + +from django.urls import include, path + +urlpatterns = [ + path("oauth2/", include("simple_oauth2.urls", namespace="simple_oauth2")), +] diff --git a/tests/test_project/test_project_conf/wsgi.py b/tests/test_project/test_project_conf/wsgi.py new file mode 100644 index 0000000..6179ef3 --- /dev/null +++ b/tests/test_project/test_project_conf/wsgi.py @@ -0,0 +1,16 @@ +""" +WSGI config for test_project project. + +It exposes the WSGI callable as a module-level variable named ``application``. + +For more information on this file, see +https://docs.djangoproject.com/en/4.1/howto/deployment/wsgi/ +""" + +import os + +from django.core.wsgi import get_wsgi_application + +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "test_project_conf.settings") + +application = get_wsgi_application() diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..9a5c73c --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,47 @@ +import pytest +from django.contrib.auth import get_user_model + +from simple_oauth2.models import Sub +from simple_oauth2.settings import oauth2_settings +from simple_oauth2.utils import generate_code_verifier, get_user +from tests.conftest import SIMPLE_OAUTH2_SETTINGS, override_oauth2_settings + + +def test_generate_code_verifier(): + assert generate_code_verifier() + with pytest.raises(ValueError): + generate_code_verifier(10) + with pytest.raises(ValueError): + generate_code_verifier(200) + + +@pytest.mark.django_db +@override_oauth2_settings(SIMPLE_OAUTH2_SETTINGS) +def test_get_user(): + user = get_user( + provider=oauth2_settings["no-pkce"], + userinfo={ + "sub": "123456", + "preferred_username": "user", + "email": "test@test.com", + }, + ) + assert Sub.objects.filter(sub="123456").exists() + assert user.username == "user" + assert user.email == "test@test.com" + + +@pytest.mark.django_db +@override_oauth2_settings(SIMPLE_OAUTH2_SETTINGS) +def test_get_user_through_sub(): + created = get_user_model().objects.create(username="user1", email="test@test.com") + Sub.objects.create(user=created, _provider="no-pkce", sub="123456") + got = get_user( + provider=oauth2_settings["no-pkce"], + userinfo={ + "sub": "123456", + "preferred_username": "user2", # Use another username to ensure sub is used + "email": "test@test.com", + }, + ) + assert created.pk == got.pk diff --git a/tests/test_viewsets.py b/tests/test_viewsets.py new file mode 100644 index 0000000..fc1960d --- /dev/null +++ b/tests/test_viewsets.py @@ -0,0 +1,238 @@ +from datetime import timedelta +from types import SimpleNamespace +from unittest.mock import patch +from urllib.parse import parse_qs, urlparse + +import pytest +from django.urls import reverse +from django.utils import timezone +from rest_framework.test import APIClient + +from simple_oauth2.exceptions import SimpleOAuth2Error +from simple_oauth2.models import Session +from simple_oauth2.settings import oauth2_settings +from tests.conftest import SIMPLE_OAUTH2_SETTINGS, override_oauth2_settings + + +@override_oauth2_settings(SIMPLE_OAUTH2_SETTINGS) +def test_oauth2_url_no_provider(): + client = APIClient() + response = client.get(reverse("simple_oauth2:oauth2-url")) + assert response.status_code == 400 + assert response.json() == { + "provider": f"OAuth2 provider required, allowed values are: '{', '.join(oauth2_settings.keys())}'" + } + + +@pytest.mark.django_db +@override_oauth2_settings(SIMPLE_OAUTH2_SETTINGS) +def test_oauth2_url_unknown_provider(): + client = APIClient() + response = client.get(reverse("simple_oauth2:oauth2-url"), {"provider": "unknown"}) + assert response.status_code == 400 + assert response.json() == { + "provider": f"Unknown OAuth2 provider 'unknown', allowed values are: '{', '.join(oauth2_settings.keys())}'" + } + + +@pytest.mark.django_db +@override_oauth2_settings(SIMPLE_OAUTH2_SETTINGS) +def test_oauth2_url_pkce_unknown_alg(): + client = APIClient() + with pytest.raises(SimpleOAuth2Error, match="Unknown challenge method: 'unknown'"): + client.get(reverse("simple_oauth2:oauth2-url"), {"provider": "pkce-unknown-alg"}) + + +@pytest.mark.django_db +@override_oauth2_settings(SIMPLE_OAUTH2_SETTINGS) +def test_oauth2_extra_params_url(): + client = APIClient() + response = client.get(reverse("simple_oauth2:oauth2-url"), {"provider": "extra-params"}) + assert response.status_code == 200, response.json() + assert "url" in response.json() + url = response.json()["url"] + query_params = parse_qs(urlparse(url).query) + assert query_params["client_id"] == ["extra-params"] + assert query_params["response_type"] == ["code"] + assert query_params["redirect_uri"] == ["https://example.com/callback"] + assert query_params["scope"] == ["openid profile email"] + assert "state" in query_params + assert "nonce" in query_params + assert "foo" in query_params + assert query_params["foo"] == ["bar"] + assert "baz" in query_params + assert query_params["baz"] == ["qux"] + + +@pytest.mark.django_db +@override_oauth2_settings(SIMPLE_OAUTH2_SETTINGS) +def test_oauth2_token_unknown_session(): + client = APIClient() + response = client.post( + reverse("simple_oauth2:oauth2-token"), {"provider": "pkce", "state": "unknown", "code": "code"} + ) + assert response.status_code == 400 + assert response.json() == {"state": "No ongoing session matches the provided state."} + + +@pytest.mark.django_db +@override_oauth2_settings(SIMPLE_OAUTH2_SETTINGS) +def test_oauth2_token_expired(): + client = APIClient() + response = client.get(reverse("simple_oauth2:oauth2-url"), {"provider": "no-pkce"}) + state = parse_qs(urlparse(response.json()["url"]).query)["state"][0] + Session.objects.filter(state=state).update(created_at=timezone.now() - timedelta(days=1)) + response = client.post( + reverse("simple_oauth2:oauth2-token"), {"provider": "no-pkce", "state": state, "code": "no-pkce-code"} + ) + assert response.status_code == 400 + assert response.json() == {"__all__": "Authorization session has expired."} + + +@pytest.mark.django_db +@override_oauth2_settings(SIMPLE_OAUTH2_SETTINGS) +def test_oauth2_token_fails(): + client = APIClient() + response = client.get(reverse("simple_oauth2:oauth2-url"), {"provider": "token-fails"}) + state = parse_qs(urlparse(response.json()["url"]).query)["state"][0] + response = client.post( + reverse("simple_oauth2:oauth2-token"), {"provider": "token-fails", "state": state, "code": "token-fails-code"} + ) + assert response.status_code == 400 + assert response.json() == {"__all__": 'Failed to retrieve tokens from provider: {"detail": "error"}'} + + +@pytest.mark.django_db +@override_oauth2_settings(SIMPLE_OAUTH2_SETTINGS) +def test_oauth2_userinfo_fails(): + client = APIClient() + response = client.get(reverse("simple_oauth2:oauth2-url"), {"provider": "userinfo-fails"}) + state = parse_qs(urlparse(response.json()["url"]).query)["state"][0] + response = client.post( + reverse("simple_oauth2:oauth2-token"), + {"provider": "userinfo-fails", "state": state, "code": "userinfo-fails-code"}, + ) + assert response.status_code == 400 + assert response.json() == {"__all__": 'Failed to fetch user info from the provider: {"detail": "error"}.'} + + +@pytest.mark.django_db +@override_oauth2_settings(SIMPLE_OAUTH2_SETTINGS) +def test_oauth2_without_pkce_url(): + client = APIClient() + response = client.get(reverse("simple_oauth2:oauth2-url"), {"provider": "no-pkce"}) + assert response.status_code == 200 + assert "url" in response.json() + url = response.json()["url"] + query_params = parse_qs(urlparse(url).query) + assert query_params["client_id"] == ["no-pkce"] + assert query_params["response_type"] == ["code"] + assert query_params["redirect_uri"] == ["https://example.com/callback"] + assert query_params["scope"] == ["openid profile email"] + assert "state" in query_params + assert "nonce" in query_params + assert "code_challenge" not in query_params + assert "code_challenge_method" not in query_params + + +@pytest.mark.django_db +@patch("jwt.jwks_client.PyJWKClient.get_signing_key", side_effect=lambda token: SimpleNamespace(key="key")) +@override_oauth2_settings(SIMPLE_OAUTH2_SETTINGS) +def test_oauth2_without_pkce_token(): + client = APIClient() + response = client.get(reverse("simple_oauth2:oauth2-url"), {"provider": "no-pkce"}) + state = parse_qs(urlparse(response.json()["url"]).query)["state"][0] + response = client.post( + reverse("simple_oauth2:oauth2-token"), {"provider": "no-pkce", "state": state, "code": "no-pkce-code"} + ) + assert response.status_code == 200, response.json() + data = response.json() + assert "api" in data + assert "access" in data["api"] + assert "refresh" in data["api"] + assert "provider" in data + assert "logout_url" in data["provider"] + assert "id_token" in data["provider"] + assert "access_token" in data["provider"] + + +@pytest.mark.django_db +@override_oauth2_settings(SIMPLE_OAUTH2_SETTINGS) +def test_oauth2_with_pkce_plain_url(): + client = APIClient() + response = client.get(reverse("simple_oauth2:oauth2-url"), {"provider": "pkce-plain"}) + assert response.status_code == 200 + assert "url" in response.json() + url = response.json()["url"] + query_params = parse_qs(urlparse(url).query) + assert query_params["client_id"] == ["pkce-plain"] + assert query_params["response_type"] == ["code"] + assert query_params["redirect_uri"] == ["https://example.com/callback"] + assert query_params["scope"] == ["openid profile email"] + assert "state" in query_params + assert "nonce" in query_params + assert "code_challenge" in query_params + assert "code_challenge_method" in query_params + assert query_params["code_challenge_method"] == ["plain"] + + +@pytest.mark.django_db +@patch("jwt.jwks_client.PyJWKClient.get_signing_key", side_effect=lambda token: SimpleNamespace(key="key")) +@override_oauth2_settings(SIMPLE_OAUTH2_SETTINGS) +def test_oauth2_with_pkce_plain_token(): + client = APIClient() + response = client.get(reverse("simple_oauth2:oauth2-url"), {"provider": "pkce-plain"}) + state = parse_qs(urlparse(response.json()["url"]).query)["state"][0] + response = client.post( + reverse("simple_oauth2:oauth2-token"), {"provider": "pkce-plain", "state": state, "code": "pkce-plain-code"} + ) + assert response.status_code == 200, response.json() + data = response.json() + assert "api" in data + assert "access" in data["api"] + assert "refresh" in data["api"] + assert "provider" in data + assert "logout_url" in data["provider"] + assert "id_token" in data["provider"] + assert "access_token" in data["provider"] + + +@pytest.mark.django_db +@override_oauth2_settings(SIMPLE_OAUTH2_SETTINGS) +def test_oauth2_with_pkce_s256_url(): + client = APIClient() + response = client.get(reverse("simple_oauth2:oauth2-url"), {"provider": "pkce-s256"}) + assert response.status_code == 200 + assert "url" in response.json() + url = response.json()["url"] + query_params = parse_qs(urlparse(url).query) + assert query_params["client_id"] == ["pkce-s256"] + assert query_params["response_type"] == ["code"] + assert query_params["redirect_uri"] == ["https://example.com/callback"] + assert query_params["scope"] == ["openid profile email"] + assert "state" in query_params + assert "nonce" in query_params + assert "code_challenge" in query_params + assert "code_challenge_method" in query_params + assert query_params["code_challenge_method"] == ["s256"] + + +@pytest.mark.django_db +@patch("jwt.jwks_client.PyJWKClient.get_signing_key", side_effect=lambda token: SimpleNamespace(key="key")) +@override_oauth2_settings(SIMPLE_OAUTH2_SETTINGS) +def test_oauth2_with_pkce_s256_token(): + client = APIClient() + response = client.get(reverse("simple_oauth2:oauth2-url"), {"provider": "pkce-s256"}) + state = parse_qs(urlparse(response.json()["url"]).query)["state"][0] + response = client.post( + reverse("simple_oauth2:oauth2-token"), {"provider": "pkce-s256", "state": state, "code": "pkce-s256-code"} + ) + assert response.status_code == 200, response.json() + data = response.json() + assert "api" in data + assert "access" in data["api"] + assert "refresh" in data["api"] + assert "provider" in data + assert "logout_url" in data["provider"] + assert "id_token" in data["provider"] + assert "access_token" in data["provider"]