diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 46f9d911..fae4a611 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -21,6 +21,9 @@ jobs: pip install -r requirements/requirements.txt pip install -r requirements/dev-requirements.txt pip install -r requirements/test-requirements.txt + pip install -r requirements/extras/authenticator.txt + pip install -r requirements/extras/pynacl.txt + - name: Lint run: ./scripts/lint.sh @@ -59,6 +62,8 @@ jobs: python -m pip install --upgrade pip pip install -r requirements/requirements.txt pip install -r requirements/test-requirements.txt + pip install -r requirements/extras/authenticator.txt + pip install -r requirements/extras/pynacl.txt - name: Test with pytest, Postgres run: ./scripts/test-postgres.sh env: @@ -86,6 +91,8 @@ jobs: python -m pip install --upgrade pip pip install -r requirements/requirements.txt pip install -r requirements/test-requirements.txt + pip install -r requirements/extras/authenticator.txt + pip install -r requirements/extras/pynacl.txt - name: Test with pytest, SQLite run: ./scripts/test-sqlite.sh - name: Upload coverage diff --git a/.gitignore b/.gitignore index 21e7046b..1e200d95 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,6 @@ docs/source/_build/ example_projects/token_auth/ .env/ .venv/ + +# Playwright +videos/ diff --git a/docs/source/encryption/index.rst b/docs/source/encryption/index.rst new file mode 100644 index 00000000..6d60ed27 --- /dev/null +++ b/docs/source/encryption/index.rst @@ -0,0 +1,8 @@ +Encryption +========== + +.. toctree:: + :maxdepth: 1 + + ./introduction + ./providers diff --git a/docs/source/encryption/introduction.rst b/docs/source/encryption/introduction.rst new file mode 100644 index 00000000..4f4a5a06 --- /dev/null +++ b/docs/source/encryption/introduction.rst @@ -0,0 +1,6 @@ +Introduction +============ + +Piccolo API provides some wrappers around popular encryption libraries. + +These are current used by :ref:`Multifactor Authentication `. diff --git a/docs/source/encryption/providers.rst b/docs/source/encryption/providers.rst new file mode 100644 index 00000000..3ddfbd9c --- /dev/null +++ b/docs/source/encryption/providers.rst @@ -0,0 +1,69 @@ +Providers +========= + +.. currentmodule:: piccolo_api.encryption.providers + +``EncryptionProvider`` +---------------------- + +.. autoclass:: EncryptionProvider + +``FernetProvider`` +------------------ + +.. autoclass:: FernetProvider + +``PlainTextProvider`` +--------------------- + +.. autoclass:: PlainTextProvider + +``XChaCha20Provider`` +--------------------- + +.. autoclass:: XChaCha20Provider + +------------------------------------------------------------------------------- + +Dependencies +------------ + +When first using some of the providers, you will be prompted to install the +underlying encryption library. + +For example, with ``XChaCha20Provider``, you need to install ``pynacl`` as +follows: + +.. code-block:: bash + + pip install piccolo_api[pynacl] + +------------------------------------------------------------------------------- + +Example usage +------------- + +All of the providers work the same (except their parameters may be different). + +Here's an example using ``XChaCha20Provider``: + +.. code-block:: python + + >>> from piccolo_api.encryption.providers import XChaCha20Provider + + >>> encryption_key = XChaCha20Provider.get_new_key() + >>> provider = XChaCha20Provider(encryption_key=encryption_key) + + >>> encrypted = provider.encrypt("hello world") + >>> print(provider.decrypt(encrypted)) + "hello world" + +------------------------------------------------------------------------------- + +Which provider to use? +---------------------- + +``XChaCha20Provider`` is the most secure. + +You may decide to use ``FernetProvider`` if you already have the Python +``cryptography`` library as a dependency in your project. diff --git a/docs/source/index.rst b/docs/source/index.rst index bd4538f1..a054638e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -26,6 +26,7 @@ ASGI app, covering authentication, security, and more. ./csp/index ./csrf/index + ./encryption/index ./rate_limiting/index .. toctree:: @@ -35,6 +36,7 @@ ASGI app, covering authentication, security, and more. ./which_authentication/index ./jwt/index ./session_auth/index + ./mfa/index ./token_auth/index ./register/index ./change_password/index diff --git a/docs/source/mfa/endpoints.rst b/docs/source/mfa/endpoints.rst new file mode 100644 index 00000000..c8ba4cf7 --- /dev/null +++ b/docs/source/mfa/endpoints.rst @@ -0,0 +1,21 @@ +Endpoints +========= + +You must mount these ASGI endpoints in your app. + +.. currentmodule:: piccolo_api.mfa.endpoints + +``mfa_setup`` +------------------------- + +.. autofunction:: mfa_setup + +.. image:: images/mfa_register_endpoint.jpg + + +``session_login`` +----------------- + +Make sure you pass the ``mfa_providers`` argument to +:func:`session_login `, +so it knows to look for an MFA token. diff --git a/docs/source/mfa/example.rst b/docs/source/mfa/example.rst new file mode 100644 index 00000000..ae528e2e --- /dev/null +++ b/docs/source/mfa/example.rst @@ -0,0 +1,13 @@ +Full Example +============ + +Let's look at what an entire app looks like, which uses session auth, along +with MFA (using the Authenticator provider). + +------------------------------------------------------------------------------- + +Starlette +--------- + +.. include:: ../../../example_projects/mfa_demo/app.py + :code: python diff --git a/docs/source/mfa/images/mfa_register_endpoint.jpg b/docs/source/mfa/images/mfa_register_endpoint.jpg new file mode 100644 index 00000000..894082b6 Binary files /dev/null and b/docs/source/mfa/images/mfa_register_endpoint.jpg differ diff --git a/docs/source/mfa/index.rst b/docs/source/mfa/index.rst new file mode 100644 index 00000000..8ae65be7 --- /dev/null +++ b/docs/source/mfa/index.rst @@ -0,0 +1,13 @@ +.. _MFA: + +Multi-Factor Authentication +=========================== + +.. toctree:: + :maxdepth: 1 + + ./introduction + ./endpoints + ./providers + ./tables + ./example diff --git a/docs/source/mfa/introduction.rst b/docs/source/mfa/introduction.rst new file mode 100644 index 00000000..ce09290a --- /dev/null +++ b/docs/source/mfa/introduction.rst @@ -0,0 +1,13 @@ +Introduction +============ + +What is Multi-Factor Authentication (MFA)? +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +MFA provides additional security to :ref:`SessionAuth`. + +As well as needing a username and password to login, the user must provide an +additional piece of information. + +One of the most popular ways of doing this is by providing a code generated by +an authenticator app on the user's phone. diff --git a/docs/source/mfa/providers.rst b/docs/source/mfa/providers.rst new file mode 100644 index 00000000..0c3e4791 --- /dev/null +++ b/docs/source/mfa/providers.rst @@ -0,0 +1,23 @@ +Providers +========= + +Most of the MFA code is fairly generic, but ``Providers`` implement the logic +which is specific to its particular authentication mechanism. + +For example, ``AuthenticatorProvider`` knows how to authenticate tokens which +come from an authenticator app on a user's phone, and knows how to generate new +secrets which allow users to enable MFA. + +.. currentmodule:: piccolo_api.mfa.provider + +``MFAProvider`` +--------------- + +.. autoclass:: MFAProvider + +.. currentmodule:: piccolo_api.mfa.authenticator.provider + +``AuthenticatorProvider`` +------------------------- + +.. autoclass:: AuthenticatorProvider diff --git a/docs/source/mfa/tables.rst b/docs/source/mfa/tables.rst new file mode 100644 index 00000000..c497dfb5 --- /dev/null +++ b/docs/source/mfa/tables.rst @@ -0,0 +1,35 @@ +Tables +====== + +``AuthenticatorSecret`` +----------------------- + +This is required by :class:`AuthenticatorProvider `. + +To create this table, you can using Piccolo's migrations. + +Add ``piccolo_api.mfa.authenticator.piccolo_app`` to ``APP_REGISTRY`` in +``piccolo_conf.py``: + +.. code-block:: python + + APP_REGISTRY = AppRegistry( + apps=[ + "piccolo_api.mfa.authenticator.piccolo_app", + ... + ] + ) + +Then run the migrations: + +.. code-block:: bash + + piccolo migrations forwards mfa_authenticator + +Alternatively, if not using Piccolo migrations, you can create the table +manually: + +.. code-block:: pycon + + >>> from piccolo_api.mfa.authenticator.table import AuthenticatorProvider + >>> AuthenticatorProvider.create_table().run_sync() diff --git a/e2e/__init__.py b/e2e/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/e2e/conftest.py b/e2e/conftest.py new file mode 100644 index 00000000..84def3cb --- /dev/null +++ b/e2e/conftest.py @@ -0,0 +1,62 @@ +import os +import time +from http.client import HTTPConnection +from subprocess import Popen + +import pytest + +HOST = "localhost" +PORT = 8000 +BASE_URL = f"http://{HOST}:{PORT}" + + +@pytest.fixture +def browser_context_args(): + return {"record_video_dir": "videos/"} + + +@pytest.fixture +def context(context): + # We don't need a really long timeout. + # The timeout determines how long Playwright waits for a HTML element to + # become available. + # By default it's 30 seconds, which is way too long when testing an app + # locally. + context.set_default_timeout(10000) + yield context + + +@pytest.fixture +def mfa_app(): + """ + Running dev server and Playwright test in parallel. + More info https://til.simonwillison.net/pytest/playwright-pytest + """ + path = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "example_projects", + "mfa_demo", + ) + + process = Popen( + ["python", "-m", "main", "--reset-db"], + cwd=path, + ) + retries = 5 + while retries > 0: + conn = HTTPConnection(f"{HOST}:{PORT}") + try: + conn.request("HEAD", "/") + response = conn.getresponse() + if response is not None: + yield process + break + except ConnectionRefusedError: + time.sleep(1) + retries -= 1 + + if not retries: + raise RuntimeError("Failed to start http server") + else: + process.terminate() + process.wait() diff --git a/e2e/pages.py b/e2e/pages.py new file mode 100644 index 00000000..236a466c --- /dev/null +++ b/e2e/pages.py @@ -0,0 +1,66 @@ +""" +By using pages we can make out test more scalable. + +https://playwright.dev/docs/pom +""" + +from playwright.sync_api import Page + +USERNAME = "piccolo" +PASSWORD = "piccolo123" + + +class LoginPage: + url = "http://localhost:8000/login/" + + def __init__(self, page: Page): + self.page = page + self.username_input = page.locator('input[name="username"]') + self.password_input = page.locator('input[name="password"]') + self.button = page.locator("button") + + def reset(self): + self.page.goto(self.url) + + def login(self, username: str = USERNAME, password: str = PASSWORD): + self.username_input.fill(username) + self.password_input.fill(password) + self.button.click() + + +class RegisterPage: + url = "http://localhost:8000/register/" + + def __init__(self, page: Page): + self.page = page + self.username_input = page.locator("[name=username]") + self.email_input = page.locator("[name=email]") + self.password_input = page.locator("[name=password]") + self.confirm_password_input = page.locator("[name=confirm_password]") + self.button = page.locator("button") + + def reset(self): + self.page.goto(self.url) + + def login(self, username: str = USERNAME, password: str = PASSWORD): + self.username_input.fill(username) + self.email_input.fill("test@piccolo-orm.com") + self.password_input.fill(password) + self.confirm_password_input.fill(password) + self.button.click() + + +class MFASetupPage: + url = "http://localhost:8000/private/mfa-setup/" + + def __init__(self, page: Page): + self.page = page + self.password_input = page.locator("[name=password]") + self.button = page.locator("button") + + def reset(self): + self.page.goto(self.url) + + def register(self, password: str = PASSWORD): + self.password_input.fill(password) + self.button.click() diff --git a/e2e/test_mfa.py b/e2e/test_mfa.py new file mode 100644 index 00000000..ba8c2c6b --- /dev/null +++ b/e2e/test_mfa.py @@ -0,0 +1,29 @@ +from playwright.async_api import Page + +from .pages import LoginPage, MFASetupPage, RegisterPage + + +def test_mfa_signup(page: Page, mfa_app): + """ + Make sure we create an account and sign up for MFA. + """ + register_page = RegisterPage(page=page) + register_page.reset() + register_page.login() + + login_page = LoginPage(page=page) + login_page.reset() + login_page.login() + + mfa_setup_page = MFASetupPage(page=page) + mfa_setup_page.reset() + + # Test an incorrect password + # TODO - assert response code is correct + mfa_setup_page.register(password="fake_password_123") + + # Test the correct password + # TODO - make sure it navigated to the right page + mfa_setup_page.register() + + mfa_setup_page.reset() diff --git a/example_projects/__init__.py b/example_projects/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/example_projects/mfa_demo/README.md b/example_projects/mfa_demo/README.md new file mode 100644 index 00000000..571779a8 --- /dev/null +++ b/example_projects/mfa_demo/README.md @@ -0,0 +1,28 @@ +# MFA demo + +This project demos how to use the MFA with the `session_login` endpoint. + +## Setup + +### Install requirements + +```bash +pip install -r requirements.txt +``` + +### Create database + +Make sure a Postgres database exists, called 'piccolo_api_mfa'. See +`piccolo_conf.py` for the full details. + +### Run migrations + +``` +piccolo migrations forwards all +``` + +## Run the app + +```bash +python main.py +``` diff --git a/example_projects/mfa_demo/app.py b/example_projects/mfa_demo/app.py new file mode 100644 index 00000000..257f861e --- /dev/null +++ b/example_projects/mfa_demo/app.py @@ -0,0 +1,102 @@ +import os + +from jinja2 import Environment, FileSystemLoader +from starlette.applications import Starlette +from starlette.endpoints import HTTPEndpoint +from starlette.middleware import Middleware +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.requests import Request +from starlette.responses import HTMLResponse, RedirectResponse +from starlette.routing import Mount, Route + +from piccolo_api.csrf.middleware import CSRFMiddleware +from piccolo_api.encryption.providers import XChaCha20Provider +from piccolo_api.mfa.authenticator.provider import AuthenticatorProvider +from piccolo_api.mfa.endpoints import mfa_setup +from piccolo_api.register.endpoints import register +from piccolo_api.session_auth.endpoints import session_login, session_logout +from piccolo_api.session_auth.middleware import SessionsAuthBackend + +EXAMPLE_DB_ENCRYPTION_KEY = b"W\x8b&E[\x8elr\xba\xb7\x19g\n\xd5`g\xea!Q#\x97\xcf\xed\xdd+\xc7\x0e\xf7P\x82\xdf\x86" # noqa: E501 + + +environment = Environment( + loader=FileSystemLoader( + os.path.join(os.path.dirname(__file__), "templates"), + ), + autoescape=True, +) + + +class HomeEndpoint(HTTPEndpoint): + async def get(self, request): + home_template = environment.get_template("home.html") + + return HTMLResponse(content=home_template.render()) + + +class PrivateEndpoint(HTTPEndpoint): + async def get(self, request): + return HTMLResponse( + content=( + "" + "

Private page

" + ) + ) + + +def on_auth_error(request: Request, exc: Exception): + return RedirectResponse("/login/") + + +private_app = Starlette( + routes=[ + Route("/", PrivateEndpoint), + Route("/logout/", session_logout(redirect_to="/")), + Route( + "/mfa-setup/", + mfa_setup( + provider=AuthenticatorProvider( + encryption_provider=XChaCha20Provider( + encryption_key=EXAMPLE_DB_ENCRYPTION_KEY + ) + ) + ), + ), + ], + middleware=[ + Middleware( + AuthenticationMiddleware, + on_error=on_auth_error, + backend=SessionsAuthBackend(admin_only=False), + ), + ], + debug=True, +) + + +app = Starlette( + routes=[ + Route("/", HomeEndpoint), + Route( + "/login/", + session_login( + mfa_providers=[ + AuthenticatorProvider( + encryption_provider=XChaCha20Provider( + encryption_key=EXAMPLE_DB_ENCRYPTION_KEY + ) + ) + ] + ), + ), + Route( + "/register/", + register(redirect_to="/login/", user_defaults={"active": True}), + ), + Mount("/private/", private_app), + ], + middleware=[ + Middleware(CSRFMiddleware, allow_form_param=True), + ], +) diff --git a/example_projects/mfa_demo/main.py b/example_projects/mfa_demo/main.py new file mode 100644 index 00000000..f28f4e1a --- /dev/null +++ b/example_projects/mfa_demo/main.py @@ -0,0 +1,28 @@ +import os +import sys + +# Modify the path, so piccolo_api is available +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + + +def reset_db(): + print("Resetting DB ...") + + from piccolo.apps.user.tables import BaseUser + + from piccolo_api.mfa.authenticator.tables import AuthenticatorSecret + from piccolo_api.session_auth.tables import SessionsBase + + BaseUser.delete(force=True).run_sync() + AuthenticatorSecret.delete(force=True).run_sync() + SessionsBase.delete(force=True).run_sync() + + +if __name__ == "__main__": + + if "--reset-db" in sys.argv: + reset_db() + + import uvicorn + + uvicorn.run("app:app", reload=True) diff --git a/example_projects/mfa_demo/piccolo_conf.py b/example_projects/mfa_demo/piccolo_conf.py new file mode 100644 index 00000000..8825471d --- /dev/null +++ b/example_projects/mfa_demo/piccolo_conf.py @@ -0,0 +1,26 @@ +import os +import sys + +from piccolo.conf.apps import AppRegistry +from piccolo.engine.postgres import PostgresEngine + +# Modify the path, so piccolo_api is available +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + +DB = PostgresEngine( + config={ + "database": "piccolo_api_mfa", + "user": "postgres", + "password": "", + "host": "localhost", + "port": 5432, + } +) + +APP_REGISTRY = AppRegistry( + apps=[ + "piccolo.apps.user.piccolo_app", + "piccolo_api.session_auth.piccolo_app", + "piccolo_api.mfa.authenticator.piccolo_app", + ] +) diff --git a/example_projects/mfa_demo/requirements.txt b/example_projects/mfa_demo/requirements.txt new file mode 100644 index 00000000..1ca5e39e --- /dev/null +++ b/example_projects/mfa_demo/requirements.txt @@ -0,0 +1,5 @@ +starlette +uvicorn[all] +piccolo[postgres] +httpx +python-multipart diff --git a/example_projects/mfa_demo/templates/home.html b/example_projects/mfa_demo/templates/home.html new file mode 100644 index 00000000..25d5c5ed --- /dev/null +++ b/example_projects/mfa_demo/templates/home.html @@ -0,0 +1,24 @@ + + + + + + Home + + + + + +

MFA Demo

+

First register

+

Then login

+

Then sign up for MFA

+

Then try the private page

+

And logout

+ + + diff --git a/piccolo_api/encryption/__init__.py b/piccolo_api/encryption/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/piccolo_api/encryption/providers.py b/piccolo_api/encryption/providers.py new file mode 100644 index 00000000..fe1eaf58 --- /dev/null +++ b/piccolo_api/encryption/providers.py @@ -0,0 +1,221 @@ +from __future__ import annotations + +import logging +import typing as t +from abc import ABCMeta, abstractmethod + +if t.TYPE_CHECKING: + import cryptography + import nacl + + +logger = logging.getLogger(__name__) + + +def get_cryptography() -> cryptography: # type: ignore + try: + import cryptography + except ImportError as e: + print( + "Install pip install piccolo_api[cryptography] to use this " + "feature." + ) + raise e + + return cryptography + + +class EncryptionProvider(metaclass=ABCMeta): + """ + Base class for encryption providers. Don't use it directly, it must be + subclassed. + """ + + def __init__(self, prefix: str): + self.prefix = prefix + + @abstractmethod + def encrypt(self, value: str, add_prefix: bool = True) -> str: + """ + :param value: + The value to encrypt. + :param add_prefix: + For example, with ``FernetProvider``, it will return a value like: + ``'fernet-abc123'`` if ``add_prefix=True``. It can be useful to + have some idea of how the value was encrypted if stored in a + database. + + """ + raise NotImplementedError() + + @abstractmethod + def decrypt(self, encrypted_value: str, has_prefix: bool = True) -> str: + """ + :param encrypted_value: + The value to decrypt. + :param has_prefix: + If the value has a prefix or not, indicating the algorithm used, + i.e. ``'fernet-abc123'`` or just ``'abc123'``. + + """ + raise NotImplementedError() + + def remove_prefix(self, encrypted_value: str) -> str: + if encrypted_value.startswith(self.prefix): + return encrypted_value.replace(f"{self.prefix}-", "", 1) + else: + raise ValueError( + "Unable to identify which encryption was used - if moving " + "to a new encryption provider, use " + "`migrate_encrypted_value`." + ) + + def add_prefix(self, encrypted_value: str) -> str: + return f"{self.prefix}-{encrypted_value}" + + +class PlainTextProvider(EncryptionProvider): + """ + The values aren't encrypted - can be useful for testing. + """ + + def __init__(self): + super().__init__(prefix="plain") + + def encrypt(self, value: str, add_prefix: bool = True) -> str: + return self.add_prefix(value) if add_prefix else value + + def decrypt(self, encrypted_value: str, has_prefix: bool = True) -> str: + return ( + self.remove_prefix(encrypted_value) + if has_prefix + else encrypted_value + ) + + +class FernetProvider(EncryptionProvider): + + def __init__(self, encryption_key: bytes): + """ + Uses the Fernet algorithm for encryption. + + :param encryption_key: + This can be generated using ``FernetEncryption.get_new_key()``. + + """ + self.encryption_key = encryption_key + super().__init__(prefix="fernet") + + @staticmethod + def get_new_key() -> bytes: + cryptography = get_cryptography() + return cryptography.fernet.Fernet.generate_key() # type: ignore + + def encrypt(self, value: str, add_prefix: bool = True) -> str: + cryptography = get_cryptography() + fernet = cryptography.fernet.Fernet( # type: ignore + self.encryption_key + ) + encrypted_value = fernet.encrypt(value.encode("utf-8")).decode("utf-8") + return ( + self.add_prefix(encrypted_value=encrypted_value) + if add_prefix + else encrypted_value + ) + + def decrypt(self, encrypted_value: str, has_prefix: bool = True) -> str: + if has_prefix: + encrypted_value = self.remove_prefix(encrypted_value) + + cryptography = get_cryptography() + + fernet = cryptography.fernet.Fernet( # type: ignore + self.encryption_key + ) + return fernet.decrypt(encrypted_value.encode("utf-8")).decode("utf-8") + + +def get_nacl_encoding() -> nacl.encoding: # type: ignore + try: + import nacl.encoding + except ImportError as e: + print("Install pip install piccolo_api[pynacl] to use this feature.") + raise e + + return nacl.encoding + + +def get_nacl_utils() -> nacl.utils: # type: ignore + try: + import nacl.utils + except ImportError as e: + print("Install pip install piccolo_api[pynacl] to use this feature.") + raise e + + return nacl.utils + + +def get_nacl_secret() -> nacl.secret: # type: ignore + try: + import nacl.secret + except ImportError as e: + print("Install pip install piccolo_api[pynacl] to use this feature.") + raise e + + return nacl.secret + + +class XChaCha20Provider(EncryptionProvider): + + def __init__(self, encryption_key: bytes): + """ + Uses the XChaCha20-Poly1305 algorithm for encryption. + + This is more secure than ``FernetProvider``. + + :param encryption_key: + This can be generated using ``XChaCha20Provider.get_new_key()``. + + """ + self.encryption_key = encryption_key + super().__init__(prefix="xchacha20") + + @staticmethod + def get_new_key() -> bytes: + nacl_utils = get_nacl_utils() + return nacl_utils.random(nacl.secret.Aead.KEY_SIZE) # type: ignore + + def _get_nacl_box(self) -> nacl.secret.Aead: + nacl_secret = get_nacl_secret() + return nacl_secret.Aead(self.encryption_key) # type: ignore + + def encrypt(self, value: str, add_prefix: bool = True) -> str: + box = self._get_nacl_box() + + encrypted_value = box.encrypt(value.encode()).hex() + + return ( + self.add_prefix(encrypted_value=encrypted_value) + if add_prefix + else encrypted_value + ) + + def decrypt(self, encrypted_value: str, has_prefix: bool = True) -> str: + if has_prefix: + encrypted_value = self.remove_prefix(encrypted_value) + + box = self._get_nacl_box() + + return box.decrypt(bytes.fromhex(encrypted_value)).decode("utf-8") + + +def migrate_encrypted_value( + old_provider: EncryptionProvider, + new_provider: EncryptionProvider, + encrypted_value: str, +): + """ + If you're migrating from one form of encryption to another, you can use + this utility. + """ + return new_provider.encrypt(old_provider.decrypt(encrypted_value)) diff --git a/piccolo_api/mfa/README.md b/piccolo_api/mfa/README.md new file mode 100644 index 00000000..78b017f0 --- /dev/null +++ b/piccolo_api/mfa/README.md @@ -0,0 +1,4 @@ +# MFA + +Multi Factor Authentication - currently using an authenticator app on a mobile +device. diff --git a/piccolo_api/mfa/__init__.py b/piccolo_api/mfa/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/piccolo_api/mfa/authenticator/__init__.py b/piccolo_api/mfa/authenticator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/piccolo_api/mfa/authenticator/piccolo_app.py b/piccolo_api/mfa/authenticator/piccolo_app.py new file mode 100644 index 00000000..2b57d90a --- /dev/null +++ b/piccolo_api/mfa/authenticator/piccolo_app.py @@ -0,0 +1,23 @@ +""" +Import all of the Tables subclasses in your app here, and register them with +the APP_CONFIG. +""" + +import os + +from piccolo.conf.apps import AppConfig + +from .tables import AuthenticatorSecret + +CURRENT_DIRECTORY = os.path.dirname(os.path.abspath(__file__)) + + +APP_CONFIG = AppConfig( + app_name="mfa_authenticator", + migrations_folder_path=os.path.join( + CURRENT_DIRECTORY, "piccolo_migrations" + ), + table_classes=[AuthenticatorSecret], + migration_dependencies=[], + commands=[], +) diff --git a/piccolo_api/mfa/authenticator/piccolo_migrations/__init__.py b/piccolo_api/mfa/authenticator/piccolo_migrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/piccolo_api/mfa/authenticator/piccolo_migrations/mfa_authenticator_2024_08_08t21_41_46_837552.py b/piccolo_api/mfa/authenticator/piccolo_migrations/mfa_authenticator_2024_08_08t21_41_46_837552.py new file mode 100644 index 00000000..e51ea4c0 --- /dev/null +++ b/piccolo_api/mfa/authenticator/piccolo_migrations/mfa_authenticator_2024_08_08t21_41_46_837552.py @@ -0,0 +1,213 @@ +from piccolo.apps.migrations.auto.migration_manager import MigrationManager +from piccolo.columns.column_types import Array, Integer, Text, Timestamptz +from piccolo.columns.defaults.timestamptz import TimestamptzNow +from piccolo.columns.indexes import IndexMethod + +ID = "2024-08-08T21:41:46:837552" +VERSION = "1.16.0" +DESCRIPTION = "Add AuthenticatorSecret table" + + +async def forwards(): + manager = MigrationManager( + migration_id=ID, app_name="mfa_authenticator", description=DESCRIPTION + ) + + manager.add_table( + class_name="AuthenticatorSecret", + tablename="authenticator_secret", + schema=None, + columns=None, + ) + + manager.add_column( + table_class_name="AuthenticatorSecret", + tablename="authenticator_secret", + column_name="user_id", + db_column_name="user_id", + column_class_name="Integer", + column_class=Integer, + params={ + "default": 0, + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + "db_column_name": None, + "secret": False, + }, + schema=None, + ) + + manager.add_column( + table_class_name="AuthenticatorSecret", + tablename="authenticator_secret", + column_name="secret", + db_column_name="secret", + column_class_name="Text", + column_class=Text, + params={ + "default": "", + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + "db_column_name": None, + "secret": True, + }, + schema=None, + ) + + manager.add_column( + table_class_name="AuthenticatorSecret", + tablename="authenticator_secret", + column_name="recovery_codes", + db_column_name="recovery_codes", + column_class_name="Array", + column_class=Array, + params={ + "base_column": Text( + default="", + null=False, + primary_key=False, + unique=False, + index=False, + index_method=IndexMethod.btree, + choices=None, + db_column_name=None, + secret=False, + ), + "default": list, + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + "db_column_name": None, + "secret": True, + }, + schema=None, + ) + + manager.add_column( + table_class_name="AuthenticatorSecret", + tablename="authenticator_secret", + column_name="recovery_codes_used_at", + db_column_name="recovery_codes_used_at", + column_class_name="Array", + column_class=Array, + params={ + "base_column": Timestamptz( + default=TimestamptzNow(), + null=False, + primary_key=False, + unique=False, + index=False, + index_method=IndexMethod.btree, + choices=None, + db_column_name=None, + secret=False, + ), + "default": list, + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + "db_column_name": None, + "secret": False, + }, + schema=None, + ) + + manager.add_column( + table_class_name="AuthenticatorSecret", + tablename="authenticator_secret", + column_name="created_at", + db_column_name="created_at", + column_class_name="Timestamptz", + column_class=Timestamptz, + params={ + "default": TimestamptzNow(), + "null": False, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + "db_column_name": None, + "secret": False, + }, + schema=None, + ) + + manager.add_column( + table_class_name="AuthenticatorSecret", + tablename="authenticator_secret", + column_name="revoked_at", + db_column_name="revoked_at", + column_class_name="Timestamptz", + column_class=Timestamptz, + params={ + "default": None, + "null": True, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + "db_column_name": None, + "secret": False, + }, + schema=None, + ) + + manager.add_column( + table_class_name="AuthenticatorSecret", + tablename="authenticator_secret", + column_name="last_used_at", + db_column_name="last_used_at", + column_class_name="Timestamptz", + column_class=Timestamptz, + params={ + "default": None, + "null": True, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + "db_column_name": None, + "secret": False, + }, + schema=None, + ) + + manager.add_column( + table_class_name="AuthenticatorSecret", + tablename="authenticator_secret", + column_name="last_used_code", + db_column_name="last_used_code", + column_class_name="Text", + column_class=Text, + params={ + "default": None, + "null": True, + "primary_key": False, + "unique": False, + "index": False, + "index_method": IndexMethod.btree, + "choices": None, + "db_column_name": None, + "secret": False, + }, + schema=None, + ) + + return manager diff --git a/piccolo_api/mfa/authenticator/provider.py b/piccolo_api/mfa/authenticator/provider.py new file mode 100644 index 00000000..f87e5677 --- /dev/null +++ b/piccolo_api/mfa/authenticator/provider.py @@ -0,0 +1,154 @@ +import os +import typing as t + +from jinja2 import Environment, FileSystemLoader +from piccolo.apps.user.tables import BaseUser + +from piccolo_api.encryption.providers import EncryptionProvider +from piccolo_api.mfa.authenticator.tables import AuthenticatorSecret +from piccolo_api.mfa.authenticator.utils import get_b64encoded_qr_image +from piccolo_api.mfa.provider import MFAProvider +from piccolo_api.shared.auth.styles import Styles + +MFA_SETUP_TEMPLATE_PATH = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))), + "templates", + "mfa_authenticator_setup.html", +) + + +class AuthenticatorProvider(MFAProvider): + + def __init__( + self, + encryption_provider: EncryptionProvider, + recovery_code_count: int = 8, + secret_table: t.Type[AuthenticatorSecret] = AuthenticatorSecret, + issuer_name: str = "Piccolo-MFA", + register_template_path: t.Optional[str] = None, + styles: t.Optional[Styles] = None, + valid_window: int = 0, + ): + """ + Allows authentication using an authenticator app on the user's phone, + like Google Authenticator. + + :param encryption_provider: + The shared secrets can be encrypted in the database. We recommend + using :class:`XChaCha20Provider `. + Use :class:`PlainTextProvider ` + to store the secrets as plain text. + :param recovery_code_count: + How many recovery codes should be generated. + :param secret_table: + This is the table used to store secrets. You shouldn't have to + override this, unless you subclassed the default + ``AuthenticatorSecret`` table for some reason. + :param issuer_name: + This is how it will be identified in the user's authenticator app. + :param register_template_path: + You can override the HTML template if you want. Try using the + ``styles`` param instead though if possible if you just want basic + visual changes. + :param styles: + Modify the appearance of the HTML template using CSS. + :param valid_window: + Extends the validity to this many counter ticks before and after + the current one. Increasing it is more convenient for users, but + is less secure. + + """ # noqa: E501 + super().__init__( + name="Authenticator App", + ) + + self.encryption_provider = encryption_provider + self.recovery_code_count = recovery_code_count + self.secret_table = secret_table + self.issuer_name = issuer_name + self.styles = styles or Styles() + self.valid_window = valid_window + + # Load the Jinja Template + register_template_path = ( + register_template_path or MFA_SETUP_TEMPLATE_PATH + ) + directory, filename = os.path.split(register_template_path) + environment = Environment( + loader=FileSystemLoader(directory), autoescape=True + ) + self.register_template = environment.get_template(filename) + + async def authenticate_user(self, user: BaseUser, code: str) -> bool: + """ + The code could be a TOTP code, or a recovery code. + """ + return await self.secret_table.authenticate( + user_id=user.id, + code=code, + encryption_provider=self.encryption_provider, + valid_window=self.valid_window, + ) + + async def is_user_enrolled(self, user: BaseUser) -> bool: + return await self.secret_table.is_user_enrolled(user_id=user.id) + + async def send_code(self, *args, **kwargs) -> bool: + """ + Deliberately blank - the user already has the code on their phone. + """ + return False + + ########################################################################### + # Registration + + async def _generate_qrcode_image( + self, secret: AuthenticatorSecret, email: str + ): + uri = secret.get_authentication_setup_uri( + email=email, + encryption_provider=self.encryption_provider, + issuer_name=self.issuer_name, + ) + + return get_b64encoded_qr_image(data=uri) + + async def get_registration_html(self, user: BaseUser) -> str: + """ + When a user wants to register for MFA, this HTML is shown containing + instructions. + """ + secret, recovery_codes = await self.secret_table.create_new( + user_id=user.id, + encryption_provider=self.encryption_provider, + recovery_code_count=self.recovery_code_count, + ) + + qrcode_image = await self._generate_qrcode_image( + secret=secret, email=user.email + ) + + return self.register_template.render( + qrcode_image=qrcode_image, + recovery_codes=recovery_codes, + recovery_codes_str="\n".join(recovery_codes), + styles=self.styles, + ) + + async def get_registration_json(self, user: BaseUser) -> dict: + """ + When a user wants to register for MFA, the client can request a JSON + response, rather than HTML, if they want to render the UI themselves. + """ + secret, recovery_codes = await self.secret_table.create_new( + user_id=user.id, encryption_provider=self.encryption_provider + ) + + qrcode_image = await self._generate_qrcode_image( + secret=secret, email=user.email + ) + + return {"qrcode_image": qrcode_image, "recovery_codes": recovery_codes} + + async def delete_registration(self, user: BaseUser): + await self.secret_table.revoke(user_id=user.id) diff --git a/piccolo_api/mfa/authenticator/tables.py b/piccolo_api/mfa/authenticator/tables.py new file mode 100644 index 00000000..466ac2a6 --- /dev/null +++ b/piccolo_api/mfa/authenticator/tables.py @@ -0,0 +1,247 @@ +from __future__ import annotations + +import datetime +import logging +import typing as t + +from piccolo.apps.user.tables import BaseUser +from piccolo.columns import Array, Integer, Serial, Text, Timestamptz +from piccolo.table import Table + +from piccolo_api.encryption.providers import EncryptionProvider +from piccolo_api.mfa.recovery_codes import generate_recovery_code + +if t.TYPE_CHECKING: # pragma: no cover + import pyotp + + +logger = logging.getLogger(__name__) + + +def get_pyotp() -> pyotp: # type: ignore # pragma: no cover + try: + import pyotp + except ImportError as e: + print( + "Install pip install piccolo_api[authenticator] to use this " + "feature." + ) + raise e + + return pyotp + + +class AuthenticatorSecret(Table): + id: Serial + user_id = Integer(null=False) + secret = Text(secret=True) + recovery_codes = Array( + Text(), + help_text="Used to gain temporary access, if they lose their phone.", + secret=True, + ) + recovery_codes_used_at = Array( + Timestamptz(), + help_text="Whenever a recovery code is used, store a timestamp here.", + ) + created_at = Timestamptz() + revoked_at = Timestamptz( + null=True, + default=None, + help_text=( + "If set, this instance should be considered unusable for " + "authentication purposes." + ), + ) + last_used_at = Timestamptz(null=True, default=None) + last_used_code = Text( + null=True, + default=None, + help_text=( + "We store the last used code, to guard against replay attacks." + ), + ) + + @classmethod + def generate_secret(cls) -> str: + pyotp = get_pyotp() + return pyotp.random_base32() # type: ignore + + @classmethod + async def create_new( + cls, + user_id: int, + encryption_provider: EncryptionProvider, + recovery_code_count: int = 8, + ) -> t.Tuple[AuthenticatorSecret, t.List[str]]: + """ + Returns the new ``AuthenticatorSecret`` and the unhashed recovery + codes. This is the only time the unhashed recovery codes will be + accessible. + + :param user_id: + The user to create the secret for. + :param encryption_provider: + Determines how the secret is stored in the database. + :param recovery_code_count: + How many recovery codes to generate for the user - this allows + them to still gain access if their phone is lost. + + """ + # Generate recovery codes + + recovery_codes = [ + generate_recovery_code() for _ in range(recovery_code_count) + ] + + ####################################################################### + # Hash the recovery codes + + # Use the hashing logic from BaseUser. + # We want to use the same salt for all of the user's recovery codes, + # otherwise logging in using a recovery code will take a long time. + salt = BaseUser.get_salt() + + hashed_recovery_codes = [ + BaseUser.hash_password(password=recovery_code, salt=salt) + for recovery_code in recovery_codes + ] + + ####################################################################### + # Generate a shared secret + + secret = cls.generate_secret() + + # We'll encrypt the secret for storing in the database. + encrypted_secret = encryption_provider.encrypt(value=secret) + + ####################################################################### + + instance = cls( + { + cls.user_id: user_id, + cls.secret: encrypted_secret, + cls.recovery_codes: hashed_recovery_codes, + } + ) + await instance.save() + + return (instance, recovery_codes) + + @classmethod + async def revoke(cls, user_id: int): + now = datetime.datetime.now(tz=datetime.timezone.utc) + await cls.update({cls.revoked_at: now}).where( + cls.user_id == user_id, + cls.revoked_at.is_null(), + ) + + @classmethod + async def authenticate( + cls, + user_id: int, + code: str, + encryption_provider: EncryptionProvider, + valid_window: int = 0, + ) -> bool: + """ + :param valid_window: + Extends the validity to this many counter ticks before and after + the current one. + + """ + secret = ( + await cls.objects() + .where( + cls.user_id == user_id, + cls.revoked_at.is_null(), + ) + .order_by(cls.created_at, ascending=False) + .first() + ) + + if secret is None: + return False + + pyotp = get_pyotp() + + if secret.last_used_code == code: + logger.warning( + f"User {user_id} reused a token - potential replay attack." + ) + return False + + shared_secret = encryption_provider.decrypt( + encrypted_value=secret.secret + ) + totp = pyotp.TOTP(shared_secret) # type: ignore + + if totp.verify(code, valid_window=valid_window): + secret.last_used_at = datetime.datetime.now( + tz=datetime.timezone.utc + ) + secret.last_used_code = code + await secret.save(columns=[cls.last_used_at, cls.last_used_code]) + + return True + + ####################################################################### + # Check recovery code + + # Do a sanity check that it's roughly long enough. + if len(code) > 10 and (recovery_codes := secret.recovery_codes): + first_recovery_code = recovery_codes[0] + + # Get the algorithm, salt etc - they should be the same for each + # of the user's recovery codes, to save overhead. + _, iterations_, salt, _ = BaseUser.split_stored_password( + password=first_recovery_code + ) + + hashed_code = BaseUser.hash_password( + password=code, + salt=salt, + iterations=int(iterations_), + ) + + for recovery_code in recovery_codes: + if recovery_code == hashed_code: + # Remove the recovery code, and record when it was used. + secret.recovery_codes = [ + i for i in recovery_codes if i != recovery_code + ] + secret.recovery_codes_used_at.append( + datetime.datetime.now(tz=datetime.timezone.utc) + ) + await secret.save( + columns=[ + cls.recovery_codes, + cls.recovery_codes_used_at, + ] + ) + + return True + + return False + + @classmethod + async def is_user_enrolled(cls, user_id: int) -> bool: + return await cls.exists().where( + cls.user_id == user_id, cls.revoked_at.is_null() + ) + + def get_authentication_setup_uri( + self, + email: str, + encryption_provider: EncryptionProvider, + issuer_name: str = "Piccolo-MFA", + ) -> str: + pyotp = get_pyotp() + + shared_secret = encryption_provider.decrypt( + encrypted_value=self.secret + ) + + return pyotp.totp.TOTP(shared_secret).provisioning_uri( # type: ignore + name=email, issuer_name=issuer_name + ) diff --git a/piccolo_api/mfa/authenticator/utils.py b/piccolo_api/mfa/authenticator/utils.py new file mode 100644 index 00000000..b6c67521 --- /dev/null +++ b/piccolo_api/mfa/authenticator/utils.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import typing as t +from base64 import b64encode +from io import BytesIO + +if t.TYPE_CHECKING: # pragma: no cover + import qrcode + + +def get_qrcode() -> qrcode: # pragma: no cover + try: + import qrcode + except ImportError as e: + print( + "Install pip install piccolo_api[authenticator] to use this " + "feature." + ) + raise e + + return qrcode + + +def get_b64encoded_qr_image(data: str) -> str: + """ + Creates a QR code from ``data``, and returns a base64 PNG image, which can + be used in a HTML document as follows: + + .. code-block:: html + + + + """ + qrcode = get_qrcode() + + qr = qrcode.QRCode(version=1, box_size=4, border=5) + qr.add_data(data) + qr.make(fit=True) + img = qr.make_image(fill_color="black", back_color="white") + buffered = BytesIO() + img.save(buffered) + return b64encode(buffered.getvalue()).decode("utf-8") diff --git a/piccolo_api/mfa/endpoints.py b/piccolo_api/mfa/endpoints.py new file mode 100644 index 00000000..b4f7187b --- /dev/null +++ b/piccolo_api/mfa/endpoints.py @@ -0,0 +1,177 @@ +import os +import typing as t +from abc import ABCMeta, abstractmethod +from json import JSONDecodeError + +from jinja2 import Environment, FileSystemLoader +from piccolo.apps.user.tables import BaseUser +from starlette.endpoints import HTTPEndpoint +from starlette.requests import Request +from starlette.responses import HTMLResponse, JSONResponse + +from piccolo_api.mfa.provider import MFAProvider +from piccolo_api.shared.auth.styles import Styles + +TEMPLATE_PATH = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "templates", +) + + +environment = Environment( + loader=FileSystemLoader(TEMPLATE_PATH), autoescape=True +) + + +class MFASetupEndpoint(HTTPEndpoint, metaclass=ABCMeta): + + @property + @abstractmethod + def _provider(self) -> MFAProvider: + raise NotImplementedError + + @property + @abstractmethod + def _auth_table(self) -> t.Type[BaseUser]: + raise NotImplementedError + + @property + @abstractmethod + def _styles(self) -> Styles: + raise NotImplementedError + + def _render_register_template( + self, + request: Request, + extra_context: t.Optional[t.Dict] = None, + status_code: int = 200, + ): + template = environment.get_template("mfa_setup.html") + + return HTMLResponse( + status_code=status_code, + content=template.render( + styles=self._styles, + csrftoken=request.scope.get("csrftoken"), + **(extra_context or {}), + ), + ) + + def _render_cancel_template( + self, + request: Request, + ): + template = environment.get_template("mfa_cancel.html") + + return HTMLResponse( + status_code=400, + content=template.render( + styles=self._styles, + csrftoken=request.scope.get("csrftoken"), + ), + ) + + async def get(self, request: Request): + piccolo_user: BaseUser = request.user.user + + if await self._provider.is_user_enrolled(user=piccolo_user): + return self._render_cancel_template(request=request) + else: + return self._render_register_template(request=request) + + async def post(self, request: Request): + piccolo_user: BaseUser = request.user.user + + # Some middleware (for example CSRF) has already awaited the request + # body, and adds it to the request. + body: t.Any = request.scope.get("form") + + if not body: + try: + body = await request.json() + except JSONDecodeError: + body = await request.form() + + if action := body.get("action"): + if action == "register": + + ############################################################### + # If the user is already enrolled, don't proceed. + if await self._provider.is_user_enrolled(user=piccolo_user): + return self._render_cancel_template(request=request) + + ############################################################### + # Make sure the password is correct. + + password = body.get("password") + + if not password or not await self._auth_table.login( + username=piccolo_user.username, password=password + ): + return self._render_register_template( + request=request, + status_code=403, + extra_context={"error": "Incorrect password"}, + ) + + ############################################################### + # Return the content + + if body.get("format") == "json": + json_content = await self._provider.get_registration_json( + user=piccolo_user + ) + return JSONResponse(content=json_content) + else: + html_content = await self._provider.get_registration_html( + user=piccolo_user + ) + return HTMLResponse(content=html_content) + elif action == "revoke": + if password := body.get("password"): + if await self._auth_table.login( + username=piccolo_user.username, password=password + ): + await self._provider.delete_registration( + user=piccolo_user + ) + + template = environment.get_template( + "mfa_disabled.html", + ) + + return HTMLResponse( + content=template.render( + styles=self._styles, + ) + ) + + return HTMLResponse(content="

Error

") + + +def mfa_setup( + provider: MFAProvider, + auth_table: t.Type[BaseUser] = BaseUser, + styles: t.Optional[Styles] = None, +) -> t.Type[HTTPEndpoint]: + """ + This endpoint needs to be protected ``SessionAuthMiddleware``, ensuring + that only logged in users can access it. + + We also recommend protecting it with ``RateLimitingMiddleware``, because: + + * Some of the forms accept a password, and we want to protect against brute + forcing. + * Generating secrets and refresh tokens is somewhat expensive, so we want + to protect against abuse. + + Users can setup and manage their MFA setup using this endpoint. + + """ + + class _MFARegisterEndpoint(MFASetupEndpoint): + _auth_table = auth_table + _provider = provider + _styles = styles or Styles() + + return _MFARegisterEndpoint diff --git a/piccolo_api/mfa/provider.py b/piccolo_api/mfa/provider.py new file mode 100644 index 00000000..8551789e --- /dev/null +++ b/piccolo_api/mfa/provider.py @@ -0,0 +1,71 @@ +from abc import ABCMeta, abstractmethod + +from piccolo.apps.user.tables import BaseUser + + +class MFAProvider(metaclass=ABCMeta): + + def __init__(self, name: str = "MFA Code"): + """ + This is the base class which all providers must inherit from. Use it + to build your own custom providers. If you use it directly, it won't + do anything. See :class:`AuthenticatorProvider ` + for a concrete implementation. + + :param token_name: + Each provider should specify a unique ``token_name``, so + when a token is passed to the login endpoint, we know which + ``MFAProvider`` it belongs to. + + """ # noqa: E501 + self.name = name + + @abstractmethod + async def authenticate_user(self, user: BaseUser, code: str) -> bool: + """ + Should return ``True`` if the code is correct for the user. + + The code could be a TOTP code, or a recovery code. + + """ + + @abstractmethod + async def is_user_enrolled(self, user: BaseUser) -> bool: + """ + Should return ``True`` if the user is enrolled in this MFA, and hence + should submit a code. + """ + + @abstractmethod + async def send_code(self, user: BaseUser) -> bool: + """ + If the provider needs to send a code (e.g. if using email or SMS), then + implement it here. + + Return ``True`` if a code was sent, and ``False`` if not (e.g. an app + based TOTP codes). + + """ + + ########################################################################### + # Registration + + @abstractmethod + async def get_registration_html(self, user: BaseUser) -> str: + """ + When a user wants to register for MFA, this HTML is shown containing + instructions. + """ + + @abstractmethod + async def get_registration_json(self, user: BaseUser) -> dict: + """ + When a user wants to register for MFA, the client can request a JSON + response, rather than HTML, if they want to render the UI themselves. + """ + + @abstractmethod + async def delete_registration(self, user: BaseUser): + """ + Used to remove the MFA. + """ diff --git a/piccolo_api/mfa/recovery_codes.py b/piccolo_api/mfa/recovery_codes.py new file mode 100644 index 00000000..2d456b57 --- /dev/null +++ b/piccolo_api/mfa/recovery_codes.py @@ -0,0 +1,52 @@ +import math +import secrets +import string +import typing as t + +DEFAULT_CHARACTERS = string.ascii_lowercase + string.digits + + +def _get_random_string(length: int, characters: t.Sequence[str]) -> str: + """ + :param length: + How long to make the string. + :param characters: + Which characters to randomly pick from. + + """ + return "".join(secrets.choice(characters) for _ in range(length)) + + +def generate_recovery_code( + length: int = 12, + characters: t.Sequence[str] = DEFAULT_CHARACTERS, + separator: str = "-", +): + """ + :param length: + How long the recovery code should be, excluding the separator. Must + be at least 10 (it's unusual for a recovery code to be shorter than + this). + :param characters: + Which characters to randomly pick from. Recovery codes tend to be + case insensitive, and just use a-z and 0-9 (presumably to make them + less error prone for users). + :param separator: + The recovery code will have this character in the middle, making it + easier for users to read (e.g. ``abc123-xyz789``). Specify an empty + string if you want to disable this behaviour. + + """ + if length < 10: + raise ValueError("The length must be at least 10.") + + random_string = _get_random_string(length=length, characters=characters) + + if separator: + split_at = math.ceil(length / 2) + + return separator.join( + [random_string[:split_at], random_string[split_at:]] + ) + + return random_string diff --git a/piccolo_api/session_auth/endpoints.py b/piccolo_api/session_auth/endpoints.py index 4b52f381..a1580231 100644 --- a/piccolo_api/session_auth/endpoints.py +++ b/piccolo_api/session_auth/endpoints.py @@ -19,6 +19,7 @@ ) from starlette.status import HTTP_303_SEE_OTHER +from piccolo_api.mfa.provider import MFAProvider from piccolo_api.session_auth.tables import SessionsBase from piccolo_api.shared.auth.hooks import LoginHooks from piccolo_api.shared.auth.styles import Styles @@ -168,6 +169,11 @@ def _captcha(self) -> t.Optional[Captcha]: def _styles(self) -> t.Optional[Styles]: raise NotImplementedError + @property + @abstractmethod + def _mfa_providers(self) -> t.Optional[t.Sequence[MFAProvider]]: + raise NotImplementedError + def _render_template( self, request: Request, @@ -219,8 +225,8 @@ async def post(self, request: Request) -> Response: except JSONDecodeError: body = await request.form() - username = body.get("username", None) - password = body.get("password", None) + username = body.get("username") + password = body.get("password") return_html = body.get("format") == "html" if (not username) or (not password): @@ -264,6 +270,108 @@ async def post(self, request: Request) -> Response: ) if user_id: + # Apply MFA + if mfa_providers := self._mfa_providers: + user = ( + await self._auth_table.objects() + .where(self._auth_table.id == user_id) + .first() + ) + + assert user is not None + + if enrolled_mfa_providers := [ + mfa_provider + for mfa_provider in mfa_providers + if await mfa_provider.is_user_enrolled(user=user) + ]: + mfa_code = body.get("mfa_code") + + if mfa_code is None: + has_sent_code: t.List[bool] = [] + for mfa_provider in enrolled_mfa_providers: + # Send the code (only used with things like email + # and SMS MFA). + has_sent_code.append( + await mfa_provider.send_code(user=user) + ) + + message = "MFA code required" + if any(has_sent_code): + message += " (we sent you a code)" + + if return_html: + return self._render_template( + request, + template_context={ + "error": message, + "show_mfa_input": True, + "mfa_provider_names": [ + mfa_provider.name + for mfa_provider in enrolled_mfa_providers # noqa: E501 + ], + }, + ) + else: + raise HTTPException( + status_code=401, detail=message + ) + + # Work out which MFA provider to use: + if len(enrolled_mfa_providers) == 1: + active_mfa_provider = enrolled_mfa_providers[0] + else: + mfa_provider_name = body.get("mfa_provider_name") + + if mfa_provider_name is None: + raise HTTPException( + status_code=401, + detail="MFA provider must be specified", + ) + + filtered_mfa_providers = [ + i + for i in enrolled_mfa_providers + if i.name == mfa_provider_name + ] + + if len(filtered_mfa_providers) == 0: + raise HTTPException( + status_code=401, + detail="MFA provider not recognised.", + ) + + if len(filtered_mfa_providers) > 1: + raise HTTPException( + status_code=401, + detail=( + "Multiple matching MFA providers found." + ), + ) + + active_mfa_provider = filtered_mfa_providers[0] + + if not await active_mfa_provider.authenticate_user( + user=user, code=mfa_code + ): + if return_html: + return self._render_template( + request, + template_context={ + "error": "MFA failed", + "show_mfa_input": True, + "mfa_provider_names": { + mfa_provider.name + for mfa_provider in enrolled_mfa_providers # noqa: E501 + }, + }, + ) + else: + raise HTTPException( + status_code=401, + detail="MFA failed", + ) + # Run login_success hooks if self._hooks and self._hooks.login_success: hooks_response = await self._hooks.run_login_success( @@ -349,6 +457,7 @@ def session_login( hooks: t.Optional[LoginHooks] = None, captcha: t.Optional[Captcha] = None, styles: t.Optional[Styles] = None, + mfa_providers: t.Optional[t.Sequence[MFAProvider]] = None, ) -> t.Type[SessionLoginEndpoint]: """ An endpoint for creating a user session. @@ -388,6 +497,9 @@ def session_login( See :class:`Captcha `. :param styles: Modify the appearance of the HTML template using CSS. + :param mfa_providers: + Add additional security to the login process using Multi-Factor + Authentication. """ # noqa: E501 template_path = ( @@ -412,6 +524,7 @@ class _SessionLoginEndpoint(SessionLoginEndpoint): _hooks = hooks _captcha = captcha _styles = styles or Styles() + _mfa_providers = mfa_providers return _SessionLoginEndpoint diff --git a/piccolo_api/shared/auth/styles.py b/piccolo_api/shared/auth/styles.py index 6a6b5be5..b5d082b1 100644 --- a/piccolo_api/shared/auth/styles.py +++ b/piccolo_api/shared/auth/styles.py @@ -19,4 +19,5 @@ class Styles: error_text_color: str = "red" button_color: str = "#419EF8" button_text_color: str = "white" + link_color: str = "#419EF8" border_color: str = "rgba(0, 0, 0, 0.2)" diff --git a/piccolo_api/templates/base.html b/piccolo_api/templates/base.html index b9c34209..3a164926 100644 --- a/piccolo_api/templates/base.html +++ b/piccolo_api/templates/base.html @@ -18,6 +18,7 @@ --error_text_color: {{ styles.error_text_color }}; --button_color: {{ styles.button_color }}; --button_text_color: {{ styles.button_text_color }}; + --link_color: {{ styles.link_color }}; --border_color: {{ styles.border_color }}; } @@ -51,6 +52,11 @@ font-size: 1.8rem; } + a { + color: var(--link_color); + text-decoration: none; + } + p.error { font-size: 0.9rem; color: var(--error_text_color); @@ -65,12 +71,13 @@ label { font-size: 0.85rem; + margin: 0.5rem 0; } input, - label { + label, + select { display: block; - margin: 0.5rem 0; width: 100%; } @@ -79,12 +86,17 @@ border-radius: 0.2rem; } - input { + input, select { border: 1px solid var(--border_color); padding: 0.5rem; margin: 0.5rem 0 0.8rem; } + textarea { + width: 100%; + max-width: 100%; + } + button { background-color: var(--button_color); border: none; @@ -101,6 +113,10 @@ div.captcha { margin-bottom: 0.5rem; } + + div.qr_code { + text-align: center; + } diff --git a/piccolo_api/templates/mfa_authenticator_setup.html b/piccolo_api/templates/mfa_authenticator_setup.html new file mode 100644 index 00000000..714859f7 --- /dev/null +++ b/piccolo_api/templates/mfa_authenticator_setup.html @@ -0,0 +1,21 @@ +{% extends "base.html" %} + +{% block title %}MFA Authenticator Setup{% endblock %} + +{% block content %} +

Authenticator Setup

+ +

Use an authenticator app like Google Authenticator, available on iOS and Android, to scan this QR code:

+ +
+ +
+ +

Copy these recovery codes and keep them safe:

+ +
    + {% for recovery_code in recovery_codes %} +
  • {{ recovery_code }}
  • + {% endfor %} +
+{% endblock %} diff --git a/piccolo_api/templates/mfa_cancel.html b/piccolo_api/templates/mfa_cancel.html new file mode 100644 index 00000000..41abc648 --- /dev/null +++ b/piccolo_api/templates/mfa_cancel.html @@ -0,0 +1,18 @@ +{% extends "base.html" %} + +{% block title %}MFA Authenticator Cancel{% endblock %} + +{% block content %} +

MFA Setup

+ +

You are already enrolled.

+ +
+ + + + + + +
+{% endblock %} diff --git a/piccolo_api/templates/mfa_disabled.html b/piccolo_api/templates/mfa_disabled.html new file mode 100644 index 00000000..954c5f68 --- /dev/null +++ b/piccolo_api/templates/mfa_disabled.html @@ -0,0 +1,11 @@ +{% extends "base.html" %} + +{% block title %}MFA Disabled{% endblock %} + +{% block content %} +

MFA Disabled

+ +

You no longer require MFA to login - consider re-enabling it when you can.

+ +

Re-enable

+{% endblock %} diff --git a/piccolo_api/templates/mfa_setup.html b/piccolo_api/templates/mfa_setup.html new file mode 100644 index 00000000..436ba40b --- /dev/null +++ b/piccolo_api/templates/mfa_setup.html @@ -0,0 +1,22 @@ +{% extends "base.html" %} + +{% block title %}MFA Setup{% endblock %} + +{% block content %} +

MFA Setup

+ + {% if error %} +

{{ error }}

+ {% endif %} + +
+ + + +

Please enter your password to enable MFA:

+ + + + +
+{% endblock %} diff --git a/piccolo_api/templates/session_login.html b/piccolo_api/templates/session_login.html index 524d247d..f808479e 100644 --- a/piccolo_api/templates/session_login.html +++ b/piccolo_api/templates/session_login.html @@ -15,6 +15,20 @@

Login

+ {% if show_mfa_input %} + + + {% if mfa_provider_names|length > 1 %} + + {% endif %} + + + {% endif %} + {% if csrftoken and csrf_cookie_name %} {% endif %} diff --git a/pyproject.toml b/pyproject.toml index d4812e3b..4b33f556 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,8 @@ module = [ "moto", "botocore", "botocore.config", - "httpx" + "httpx", + "qrcode" ] ignore_missing_imports = true diff --git a/requirements/e2e-requirements.txt b/requirements/e2e-requirements.txt new file mode 100644 index 00000000..ed853472 --- /dev/null +++ b/requirements/e2e-requirements.txt @@ -0,0 +1,3 @@ +pytest==8.0.1 +playwright==1.41.2 +pytest-playwright==0.4.4 diff --git a/requirements/extras/authenticator.txt b/requirements/extras/authenticator.txt new file mode 100644 index 00000000..55bd1192 --- /dev/null +++ b/requirements/extras/authenticator.txt @@ -0,0 +1,2 @@ +pyotp==2.9.0 +qrcode==7.4.2 diff --git a/requirements/extras/cryptography.txt b/requirements/extras/cryptography.txt new file mode 100644 index 00000000..ce12e287 --- /dev/null +++ b/requirements/extras/cryptography.txt @@ -0,0 +1 @@ +cryptography==43.0.0 diff --git a/requirements/extras/pynacl.txt b/requirements/extras/pynacl.txt new file mode 100644 index 00000000..ae848ece --- /dev/null +++ b/requirements/extras/pynacl.txt @@ -0,0 +1 @@ +PyNaCl==1.5.0 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index c18d9c93..af69a186 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,5 +1,5 @@ Jinja2>=2.11.0 -piccolo[postgres]>=1.5 +piccolo[postgres]>=1.16.0 pydantic[email]>=2.0 python-multipart>=0.0.5 fastapi>=0.100.0 diff --git a/scripts/run-e2e-test.sh b/scripts/run-e2e-test.sh new file mode 100755 index 00000000..fcaa7341 --- /dev/null +++ b/scripts/run-e2e-test.sh @@ -0,0 +1,6 @@ +#!/bin/bash +# Run end-to-end tests + +extraArgs=$@ + +pytest --ignore=tests -s $extraArgs diff --git a/scripts/test-postgres.sh b/scripts/test-postgres.sh index 656be4dd..2cbada1a 100755 --- a/scripts/test-postgres.sh +++ b/scripts/test-postgres.sh @@ -7,4 +7,4 @@ export PYTHONPATH="$PWD:$PYTHONPATH" export PICCOLO_CONF="tests.postgres_conf" -python -m pytest --cov=piccolo_api --cov-report xml --cov-report html --cov-fail-under 85 -s $@ +python -m pytest --ignore=e2e --cov=piccolo_api --cov-report xml --cov-report html --cov-fail-under 85 -s $@ diff --git a/scripts/test-sqlite.sh b/scripts/test-sqlite.sh index 2388d3da..a1bba931 100755 --- a/scripts/test-sqlite.sh +++ b/scripts/test-sqlite.sh @@ -7,4 +7,4 @@ export PYTHONPATH="$PWD:$PYTHONPATH" export PICCOLO_CONF="tests.sqlite_conf" -python -m pytest --cov=piccolo_api --cov-report xml --cov-report html --cov-fail-under 85 -s $@ +python -m pytest --ignore=e2e --cov=piccolo_api --cov-report xml --cov-report html --cov-fail-under 85 -s $@ diff --git a/tests/mfa/authenticator/test_tables.py b/tests/mfa/authenticator/test_tables.py new file mode 100644 index 00000000..657d451a --- /dev/null +++ b/tests/mfa/authenticator/test_tables.py @@ -0,0 +1,200 @@ +import datetime +from unittest import TestCase +from unittest.mock import MagicMock, patch + +import pyotp +from piccolo.apps.user.tables import BaseUser +from piccolo.testing.test_case import AsyncTableTest + +from example_projects.mfa_demo.app import EXAMPLE_DB_ENCRYPTION_KEY +from piccolo_api.encryption.providers import XChaCha20Provider +from piccolo_api.mfa.authenticator.tables import AuthenticatorSecret + + +class TestGenerateSecret(TestCase): + + def test_generate_secret(self): + """ + Make sure secrets are generated correctly. + """ + secret_1 = AuthenticatorSecret.generate_secret() + secret_2 = AuthenticatorSecret.generate_secret() + + self.assertIsInstance(secret_1, str) + self.assertNotEqual(secret_1, secret_2) + self.assertEqual(len(secret_1), 32) + + +class TestAuthenticate(AsyncTableTest): + + tables = [AuthenticatorSecret, BaseUser] + + @patch("piccolo_api.mfa.authenticator.tables.logger") + async def test_replay_attack(self, logger: MagicMock): + """ + If a token which was just used successfully is reused, it should be + rejected, because it might be a replay attack. + """ + user = await BaseUser.create_user( + username="test", password="test123456" + ) + + code = "123456" + + secret, _ = await AuthenticatorSecret.create_new( + user_id=user.id, + encryption_provider=XChaCha20Provider( + encryption_key=EXAMPLE_DB_ENCRYPTION_KEY + ), + ) + secret.last_used_code = code + await secret.save() + + auth_response = await AuthenticatorSecret.authenticate( + user_id=user.id, + code=code, + encryption_provider=XChaCha20Provider( + encryption_key=EXAMPLE_DB_ENCRYPTION_KEY + ), + ) + assert auth_response is False + + logger.warning.assert_called_with( + "User 1 reused a token - potential replay attack." + ) + + async def test_code(self): + """ + Make sure a valid code can be used to authenticate. + """ + user = await BaseUser.create_user( + username="test", password="test123456" + ) + + encryption_provider = XChaCha20Provider( + encryption_key=EXAMPLE_DB_ENCRYPTION_KEY + ) + + authenticator_secret, _ = await AuthenticatorSecret.create_new( + user_id=user.id, + encryption_provider=encryption_provider, + ) + + secret = encryption_provider.decrypt(authenticator_secret.secret) + + # Make sure a valid code works + auth_response = await AuthenticatorSecret.authenticate( + user_id=user.id, + code=pyotp.TOTP(s=secret).now(), + encryption_provider=encryption_provider, + ) + assert auth_response is True + + # Make sure an invalid code fails + auth_response = await AuthenticatorSecret.authenticate( + user_id=user.id, + code="ABC123", + encryption_provider=encryption_provider, + ) + assert auth_response is False + + async def test_recovery_code(self): + """ + Make sure a valid recovery code can be used to authenticate. + """ + user = await BaseUser.create_user( + username="test", password="test123456" + ) + + encryption_provider = XChaCha20Provider( + encryption_key=EXAMPLE_DB_ENCRYPTION_KEY + ) + + _, recovery_codes = await AuthenticatorSecret.create_new( + user_id=user.id, + encryption_provider=encryption_provider, + ) + + # Make sure a valid recovery code works + auth_response = await AuthenticatorSecret.authenticate( + user_id=user.id, + code=recovery_codes[0], + encryption_provider=encryption_provider, + ) + assert auth_response is True + + # Make sure an invalid recovery code fails + fake_code = "".join("a" for _ in range(len(recovery_codes[0]))) + auth_response = await AuthenticatorSecret.authenticate( + user_id=user.id, + code=fake_code, + encryption_provider=encryption_provider, + ) + assert auth_response is False + + async def test_unenrolled_user(self): + """ + Make sure a user who isn't enrolled fails authentication. + """ + user = await BaseUser.create_user( + username="test", password="test123456" + ) + + auth_response = await AuthenticatorSecret.authenticate( + user_id=user.id, + code="abc123", + encryption_provider=XChaCha20Provider( + encryption_key=EXAMPLE_DB_ENCRYPTION_KEY + ), + ) + assert auth_response is False + + +class TestCreateNew(AsyncTableTest): + + tables = [AuthenticatorSecret, BaseUser] + + async def test_create_new(self): + user = await BaseUser.create_user( + username="test", password="test123456" + ) + + secret, _ = await AuthenticatorSecret.create_new( + user_id=user.id, + encryption_provider=XChaCha20Provider( + encryption_key=EXAMPLE_DB_ENCRYPTION_KEY + ), + ) + + self.assertEqual(secret.id, user.id) + self.assertIsNotNone(secret.secret) + self.assertIsInstance(secret.created_at, datetime.datetime) + self.assertIsNone(secret.last_used_at) + self.assertIsNone(secret.revoked_at) + self.assertIsNone(secret.last_used_code) + + +class TestRevoke(AsyncTableTest): + """ + Make sure we can revoke a user's MFA code. + """ + + tables = [AuthenticatorSecret, BaseUser] + + async def test_revoke(self): + user = await BaseUser.create_user( + username="test", password="test123456" + ) + + secret, _ = await AuthenticatorSecret.create_new( + user_id=user.id, + encryption_provider=XChaCha20Provider( + encryption_key=EXAMPLE_DB_ENCRYPTION_KEY + ), + ) + + await AuthenticatorSecret.revoke(user_id=user.id) + + await secret.refresh() + + assert secret.revoked_at is not None diff --git a/tests/mfa/test_mfa_endpoints.py b/tests/mfa/test_mfa_endpoints.py new file mode 100644 index 00000000..7db234ad --- /dev/null +++ b/tests/mfa/test_mfa_endpoints.py @@ -0,0 +1,67 @@ +from piccolo.apps.user.tables import BaseUser +from piccolo.testing.test_case import AsyncTableTest +from starlette.testclient import TestClient + +from example_projects.mfa_demo.app import app +from piccolo_api.mfa.authenticator.tables import AuthenticatorSecret +from piccolo_api.session_auth.tables import SessionsBase + + +class TestMFARegisterEndpoint(AsyncTableTest): + + tables = [AuthenticatorSecret, BaseUser, SessionsBase] + username = "alice" + password = "test123" + + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + + self.user = await BaseUser.create_user( + username=self.username, password=self.password, active=True + ) + + async def test_register(self): + client = TestClient(app=app) + + # Get a CSRF cookie + response = client.get("/login/") + csrf_token = response.cookies["csrftoken"] + self.assertEqual(response.status_code, 200) + + # Login + response = client.post( + "/login/", + json={"username": self.username, "password": self.password}, + headers={"X-CSRFToken": csrf_token}, + ) + self.assertEqual(response.status_code, 200) + self.assertIn("id", client.cookies) + + # Register for MFA - JSON + response = client.post( + "/private/mfa-setup/", + json={ + "action": "register", + "format": "json", + "password": self.password, + }, + headers={"X-CSRFToken": csrf_token}, + ) + self.assertEqual(response.status_code, 200) + + data = response.json() + self.assertIn("qrcode_image", data) + self.assertIn("recovery_codes", data) + + # Register for MFA - HTML + await AuthenticatorSecret.delete().where( + AuthenticatorSecret.user_id == self.user.id + ) + response = client.post( + "/private/mfa-setup/", + data={"action": "register", "password": self.password}, + headers={"X-CSRFToken": csrf_token}, + ) + self.assertEqual(response.status_code, 200) + html = response.content + self.assertIn(b"Authenticator Setup", html) diff --git a/tests/mfa/test_recovery_codes.py b/tests/mfa/test_recovery_codes.py new file mode 100644 index 00000000..e560955c --- /dev/null +++ b/tests/mfa/test_recovery_codes.py @@ -0,0 +1,25 @@ +from unittest import TestCase + +from piccolo_api.mfa.recovery_codes import generate_recovery_code + + +class TestGenerateRecoveryCode(TestCase): + + def test_randomness(self): + self.assertNotEqual(generate_recovery_code(), generate_recovery_code()) + + def test_response_format(self): + self.assertEqual( + generate_recovery_code(length=10, characters=["a"]), + "aaaaa-aaaaa", + ) + + def test_no_separator(self): + self.assertEqual( + generate_recovery_code(length=10, characters=["a"], separator=""), + "aaaaaaaaaa", + ) + + def test_length(self): + with self.assertRaises(ValueError): + generate_recovery_code(length=6),