diff --git a/fastapi_opa/auth/auth_saml.py b/fastapi_opa/auth/auth_saml.py index 18fb0c9..b18a0d8 100644 --- a/fastapi_opa/auth/auth_saml.py +++ b/fastapi_opa/auth/auth_saml.py @@ -1,3 +1,4 @@ +import logging from dataclasses import dataclass from datetime import datetime from pathlib import Path @@ -6,28 +7,20 @@ from onelogin.saml2.auth import OneLogin_Saml2_Auth from onelogin.saml2.utils import OneLogin_Saml2_Utils -from pydantic.main import BaseModel from starlette.requests import Request from starlette.responses import RedirectResponse from fastapi_opa.auth.auth_interface import AuthInterface from fastapi_opa.auth.exceptions import SAMLException +logger = logging.getLogger(__name__) + @dataclass class SAMLConfig: settings_directory: str -class Userdata(BaseModel): - samlUserdata: Dict - samlNameId: str - samlNameIdFormat: str - samlNameIdNameQualifier: str - samlNameIdSPNameQualifier: str - samlSessionIndex: str - - class SAMLAuthentication(AuthInterface): def __init__(self, config: SAMLConfig): self.config = config @@ -40,11 +33,11 @@ async def authenticate( auth = await self.init_saml_auth(request_args) if "acs" in request.query_params: - print(datetime.utcnow(), '--acs--') + logger.debug(datetime.utcnow(), '--acs--') return await self.assertion_consumer_service(auth, request_args) # potentially extend with logout here elif 'sso' in request.query_params: - print(datetime.utcnow(), '--sso--') + logger.debug(datetime.utcnow(), '--sso--') return await self.single_sign_on(auth) # TODO: check below code # If AuthNRequest ID need to be stored in order to later validate it, do instead @@ -52,15 +45,15 @@ async def authenticate( # request.session['AuthNRequestID'] = auth.get_last_request_id() # return redirect(sso_built_url) elif 'sso2' in request.query_params: - print(datetime.utcnow(), '--sso2--') + logger.debug(datetime.utcnow(), '--sso2--') return_to = '%sattrs/' % request.base_url return RedirectResponse(auth.login(return_to)) elif 'slo' in request.query_params: - print(datetime.utcnow(), '--slo--') + logger.debug(datetime.utcnow(), '--slo--') return await self.single_log_out(auth) # TODO: handle sls # elif 'sls' in request.query_params: - # print(datetime.utcnow(), '--sls--') + # logger.debug(datetime.utcnow(), '--sls--') # request_id = None # if 'LogoutRequestID' in request.query_params['post_data']: # request_id = req_args['post_data']['LogoutRequestID'] @@ -84,17 +77,12 @@ async def init_saml_auth(self, request_args: Dict) -> OneLogin_Saml2_Auth: @staticmethod async def single_log_out(auth: OneLogin_Saml2_Auth) -> RedirectResponse: - name_id = session_index = name_id_format = name_id_nq = name_id_spnq = None - if auth.get_nameid(): - name_id = auth.get_nameid() - if auth.get_session_index(): - session_index = auth.get_session_index() - if auth.get_nameid_format(): - name_id_format = auth.get_nameid_format() - if auth.get_nameid_spnq(): - name_id_spnq = auth.get_nameid_spnq() - if auth.get_nameid_nq(): - name_id_nq = auth.get_nameid_nq() + + name_id = auth.get_nameid() or None + session_index = auth.get_session_index() or None + name_id_format = auth.get_nameid_format() or None + name_id_spnq = auth.get_nameid_spnq() or None + name_id_nq = auth.get_nameid_nq() or None return RedirectResponse( auth.logout(name_id=name_id, session_index=session_index, nq=name_id_nq, name_id_format=name_id_format, spnq=name_id_spnq)) @@ -107,7 +95,7 @@ async def single_sign_on(auth: OneLogin_Saml2_Auth) -> RedirectResponse: @staticmethod async def assertion_consumer_service( auth: OneLogin_Saml2_Auth, request_args: Dict - ) -> Union[RedirectResponse, Userdata]: + ) -> Union[RedirectResponse, Dict]: auth.process_response() errors = auth.get_errors() if not len(errors) == 0: diff --git a/tests/test_oidc_auth.py b/tests/test_oidc_auth.py index 8fde35c..1165628 100644 --- a/tests/test_oidc_auth.py +++ b/tests/test_oidc_auth.py @@ -141,7 +141,7 @@ def construct_jwt( headers: Optional[Dict] = None, ): iat_timestamp = datetime.datetime.utcnow().timestamp() - delta_day = 1000000 # Unit: day + delta_days = 1000000 # This or patch jwt.decode if not msg: msg = { @@ -150,7 +150,7 @@ def construct_jwt( "jti": "68f7cf57-110d-4cbf-9f29-0f5ad4c90328", "sub": "test-sub", "iat": int(iat_timestamp), - "exp": int(iat_timestamp + 3600 * 24 * delta_day), + "exp": int(iat_timestamp + 3600 * 24 * delta_days), } if algorithm == "HS256": return jwt.encode(msg, "secret", algorithm=algorithm), msg