diff --git a/fastapi_opa/auth/auth_saml.py b/fastapi_opa/auth/auth_saml.py index 3ff4ce8..c819dcd 100644 --- a/fastapi_opa/auth/auth_saml.py +++ b/fastapi_opa/auth/auth_saml.py @@ -1,16 +1,21 @@ +import json +import logging from dataclasses import dataclass from pathlib import Path from typing import Dict from typing import Union from onelogin.saml2.auth import OneLogin_Saml2_Auth +from onelogin.saml2.settings import OneLogin_Saml2_Settings from onelogin.saml2.utils import OneLogin_Saml2_Utils from starlette.requests import Request -from starlette.responses import RedirectResponse +from starlette.responses import RedirectResponse, Response from fastapi_opa.auth.auth_interface import AuthInterface from fastapi_opa.auth.exceptions import SAMLException +logger = logging.getLogger(__name__) + @dataclass class SAMLConfig: @@ -28,9 +33,27 @@ async def authenticate( request_args = await self.prepare_request(request) auth = await self.init_saml_auth(request_args) - if "acs" in request.query_params: - return await self.assertion_consumer_service(auth, request_args) - # potentially extend with logout here + if 'sso' in request.query_params: + logger.debug('--sso--') + return await self.single_sign_on(auth) + + elif 'sso2' in request.query_params: + logger.debug('--sso2--') + return_to = '%sattrs/' % request.base_url + return await self.single_sign_on(auth, return_to) + + elif "acs" in request.query_params: + logger.debug('--acs--') + return await self.assertion_consumer_service(auth, request_args, request) + + elif 'slo' in request.query_params: + logger.debug('--slo--') + return await self.single_log_out(auth) + + elif 'sls' in request.query_params: + logger.debug('--sls--') + return await self.single_log_out_from_IdP(request) + return await self.single_sign_on(auth) async def init_saml_auth(self, request_args: Dict) -> OneLogin_Saml2_Auth: @@ -39,18 +62,47 @@ async def init_saml_auth(self, request_args: Dict) -> OneLogin_Saml2_Auth: ) @staticmethod - async def single_sign_on(auth: OneLogin_Saml2_Auth) -> RedirectResponse: - redirect_url = auth.login() + async def single_log_out_from_IdP(request: Request) -> \ + Union[RedirectResponse, Dict]: + req_args = await SAMLAuthentication.prepare_request(request) + if not req_args['get_data'].get('SAMLResponse') and request.query_params.get('SAMLResponse'): + req_args['get_data'] = {'SAMLResponse': request.query_params.get('SAMLResponse')} + auth = await SAMLAuthentication.init_saml_auth(req_args) + dscb = lambda: request.session.clear() + url = auth.process_slo(delete_session_cb=dscb) + errors = auth.get_errors() + if len(errors) == 0: + if url is not None: + return RedirectResponse(url) + else: + return {'success_slo': True} + else: + return {'error': auth.get_last_error_reason()} + + @staticmethod + async def single_log_out(auth: OneLogin_Saml2_Auth) -> RedirectResponse: + name_id = auth.get_nameid() + session_index = auth.get_session_index() + name_id_format = auth.get_nameid_format() + name_id_spnq = auth.get_nameid_spnq() + name_id_nq = auth.get_nameid_nq() + return RedirectResponse( + auth.logout(name_id=name_id, session_index=session_index, nq=name_id_nq, name_id_format=name_id_format, + spnq=name_id_spnq)) + + @staticmethod + async def single_sign_on(auth: OneLogin_Saml2_Auth, url: str = None) -> RedirectResponse: + redirect_url = auth.login(url) return RedirectResponse(redirect_url) @staticmethod async def assertion_consumer_service( - auth: OneLogin_Saml2_Auth, request_args: Dict + auth: OneLogin_Saml2_Auth, request_args: Dict, request: Request ) -> Union[RedirectResponse, Dict]: auth.process_response() errors = auth.get_errors() if not len(errors) == 0: - raise SAMLException() + raise SAMLException(auth.get_last_error_reason()) userdata = { "samlUserdata": auth.get_attributes(), "samlNameId": auth.get_nameid(), @@ -59,6 +111,7 @@ async def assertion_consumer_service( "samlNameIdSPNameQualifier": auth.get_nameid_spnq(), "samlSessionIndex": auth.get_session_index(), } + request.session['saml_session'] = json.dumps(userdata) self_url = OneLogin_Saml2_Utils.get_self_url(request_args) if "RelayState" in request_args.get("post_data") and self_url.rstrip( @@ -69,17 +122,30 @@ async def assertion_consumer_service( request_args.get("post_data", {}).get("RelayState") ) ) - else: - return userdata + + return userdata @staticmethod async def prepare_request(request: Request): + form_data = await request.form() return { "https": "on" if request.url.scheme == "https" else "off", "http_host": request.url.hostname, "server_port": request.url.port, "script_name": request.url.path, - "post_data": await request.form() + "post_data": form_data, # Uncomment if using ADFS - # "lowercase_urlencoding": True + # "lowercase_urlencoding": True, + 'get_data': form_data } + + async def get_metadata(self, request: Request): + saml_settings = OneLogin_Saml2_Settings(custom_base_path=self.custom_folder, + sp_validation_only=True) + metadata = saml_settings.get_sp_metadata() + errors = saml_settings.validate_metadata(metadata) + status_code = 200 + if len(errors) != 0: + metadata = ', '.join(errors) + status_code = 500 + return Response(content=metadata, media_type="application/xml", status_code=status_code) diff --git a/tests/test_oidc_auth.py b/tests/test_oidc_auth.py index 229f280..1165628 100644 --- a/tests/test_oidc_auth.py +++ b/tests/test_oidc_auth.py @@ -140,18 +140,17 @@ def construct_jwt( msg: Dict[str, Any] = None, headers: Optional[Dict] = None, ): - iat = datetime.datetime.utcnow() - exp = datetime.datetime.utcnow() + datetime.timedelta( - days=1000000 - ) # This or patch jwt.decode + iat_timestamp = datetime.datetime.utcnow().timestamp() + delta_days = 1000000 + # This or patch jwt.decode if not msg: msg = { "name": "John Doe", "aud": "example-client", "jti": "68f7cf57-110d-4cbf-9f29-0f5ad4c90328", "sub": "test-sub", - "iat": int(iat.timestamp()), - "exp": int(exp.timestamp()), + "iat": int(iat_timestamp), + "exp": int(iat_timestamp + 3600 * 24 * delta_days), } if algorithm == "HS256": return jwt.encode(msg, "secret", algorithm=algorithm), msg diff --git a/tests/test_saml_auth.py b/tests/test_saml_auth.py index 1b42ecf..61a0294 100644 --- a/tests/test_saml_auth.py +++ b/tests/test_saml_auth.py @@ -15,18 +15,40 @@ async def test_single_sign_on(): saml_auth_mock = Mock() saml_auth_mock.login.return_value = "http://idp.com/cryptic-stuff" - response = await saml_auth.single_sign_on(saml_auth_mock) + url = "http://idp.com/cryptic-stuff/attrs" + response = await saml_auth.single_sign_on(saml_auth_mock, url) assert isinstance(response, RedirectResponse) assert response.headers.get("location") == "http://idp.com/cryptic-stuff" +@pytest.mark.asyncio +async def test_single_sign_on_with_parameter(): + saml_conf = SAMLConfig(settings_directory="./tests/test_data/saml") + saml_auth = SAMLAuthentication(saml_conf) + + def side_effect(url): + return url + + saml_auth_mock = Mock() + saml_auth_mock.login = Mock(side_effect=side_effect) + attr_url = "http://idp.com/cryptic-stuff/attrs" + response = await saml_auth.single_sign_on(saml_auth_mock, attr_url) + + assert isinstance(response, RedirectResponse) + assert response.headers.get("location") == attr_url + + @pytest.mark.asyncio @patch("fastapi_opa.auth.auth_saml.OneLogin_Saml2_Utils") async def test_assertion_consumer_service(saml_util_mock): saml_util_mock.get_self_url.return_value = "http://sp.com" saml_conf = SAMLConfig(settings_directory="./tests/test_data/saml") saml_auth = SAMLAuthentication(saml_conf) + + request_mock = Mock() + request_mock.session.__setitem__ = Mock() + saml_auth_mock = Mock() saml_auth_mock.get_errors.return_value = [] saml_auth_mock.get_attributes.return_value = { @@ -48,7 +70,7 @@ async def test_assertion_consumer_service(saml_util_mock): saml_auth_mock.get_session_index.return_value = "8167416b-6a10-4a4c-889c-7574074e3fc5::f1eaf88b-2bb9-4d2e-8d3d-39587ba1ef37" # noqa response = await saml_auth.assertion_consumer_service( - saml_auth_mock, {"post_data": []} + saml_auth_mock, {"post_data": []}, request_mock ) expected = { "samlUserdata": { @@ -67,4 +89,48 @@ async def test_assertion_consumer_service(saml_util_mock): "samlNameIdSPNameQualifier": None, "samlSessionIndex": "8167416b-6a10-4a4c-889c-7574074e3fc5::f1eaf88b-2bb9-4d2e-8d3d-39587ba1ef37", # noqa } + + request_mock.session.__setitem__.assert_called_once() assert expected == response + + +@pytest.mark.asyncio +async def test_single_log_out(): + saml_conf = SAMLConfig(settings_directory="./tests/test_data/saml") + saml_auth = SAMLAuthentication(saml_conf) + + saml_auth_mock = Mock() + saml_auth_mock.get_slo_url.return_value = "http://idp.com" + saml_auth_mock.get_self_url_no_query.return_value = "http://idp.com" + saml_auth_mock.get_nameid.return_value = "alice" + saml_auth_mock.get_nameid_format.return_value = ( + "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified" + ) + saml_auth_mock.get_nameid_nq.return_value = None + saml_auth_mock.get_nameid_spnq.return_value = None + saml_auth_mock.get_session_index.return_value = "8167416b-6a10-4a4c-889c-7574074e3fc5::f1eaf88b-2bb9-4d2e-8d3d-39587ba1ef37" # noqa + + response = await saml_auth.single_log_out(saml_auth_mock) + assert isinstance(response, RedirectResponse) + assert response.status_code == 307 + + +async def async_return(result): + import asyncio + f = asyncio.Future() + f.set_result(result) + return f + + +@pytest.mark.asyncio +@patch("fastapi_opa.auth.auth_saml.OneLogin_Saml2_Settings") +async def test_get_index_metadata(saml_settings_mock): + saml_conf = SAMLConfig(settings_directory="./tests/test_data/saml") + saml_auth = SAMLAuthentication(saml_conf) + saml_settings_mock.get_sp_metadata.return_value = '' + saml_settings_mock.validate_metadata.return_value = [] + + request_mock = Mock() + response = await saml_auth.get_metadata(request_mock) + assert response.status_code == 200 + assert response.headers['content-type'] == 'application/xml'