Skip to content

Commit

Permalink
Changing print to logger and some format
Browse files Browse the repository at this point in the history
  • Loading branch information
Tracy.Wu committed Jul 4, 2021
1 parent 2a803d6 commit 876c481
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 29 deletions.
42 changes: 15 additions & 27 deletions fastapi_opa/auth/auth_saml.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
Expand All @@ -6,28 +7,20 @@

from onelogin.saml2.auth import OneLogin_Saml2_Auth
from onelogin.saml2.utils import OneLogin_Saml2_Utils
from pydantic.main import BaseModel
from starlette.requests import Request
from starlette.responses import RedirectResponse

from fastapi_opa.auth.auth_interface import AuthInterface
from fastapi_opa.auth.exceptions import SAMLException

logger = logging.getLogger(__name__)


@dataclass
class SAMLConfig:
settings_directory: str


class Userdata(BaseModel):
samlUserdata: Dict
samlNameId: str
samlNameIdFormat: str
samlNameIdNameQualifier: str
samlNameIdSPNameQualifier: str
samlSessionIndex: str


class SAMLAuthentication(AuthInterface):
def __init__(self, config: SAMLConfig):
self.config = config
Expand All @@ -40,27 +33,27 @@ async def authenticate(
auth = await self.init_saml_auth(request_args)

if "acs" in request.query_params:
print(datetime.utcnow(), '--acs--')
logger.debug(datetime.utcnow(), '--acs--')
return await self.assertion_consumer_service(auth, request_args)
# potentially extend with logout here
elif 'sso' in request.query_params:
print(datetime.utcnow(), '--sso--')
logger.debug(datetime.utcnow(), '--sso--')
return await self.single_sign_on(auth)
# TODO: check below code
# If AuthNRequest ID need to be stored in order to later validate it, do instead
# sso_built_url = auth.login()
# request.session['AuthNRequestID'] = auth.get_last_request_id()
# return redirect(sso_built_url)
elif 'sso2' in request.query_params:
print(datetime.utcnow(), '--sso2--')
logger.debug(datetime.utcnow(), '--sso2--')
return_to = '%sattrs/' % request.base_url
return RedirectResponse(auth.login(return_to))
elif 'slo' in request.query_params:
print(datetime.utcnow(), '--slo--')
logger.debug(datetime.utcnow(), '--slo--')
return await self.single_log_out(auth)
# TODO: handle sls
# elif 'sls' in request.query_params:
# print(datetime.utcnow(), '--sls--')
# logger.debug(datetime.utcnow(), '--sls--')
# request_id = None
# if 'LogoutRequestID' in request.query_params['post_data']:
# request_id = req_args['post_data']['LogoutRequestID']
Expand All @@ -84,17 +77,12 @@ async def init_saml_auth(self, request_args: Dict) -> OneLogin_Saml2_Auth:

@staticmethod
async def single_log_out(auth: OneLogin_Saml2_Auth) -> RedirectResponse:
name_id = session_index = name_id_format = name_id_nq = name_id_spnq = None
if auth.get_nameid():
name_id = auth.get_nameid()
if auth.get_session_index():
session_index = auth.get_session_index()
if auth.get_nameid_format():
name_id_format = auth.get_nameid_format()
if auth.get_nameid_spnq():
name_id_spnq = auth.get_nameid_spnq()
if auth.get_nameid_nq():
name_id_nq = auth.get_nameid_nq()

name_id = auth.get_nameid() or None
session_index = auth.get_session_index() or None
name_id_format = auth.get_nameid_format() or None
name_id_spnq = auth.get_nameid_spnq() or None
name_id_nq = auth.get_nameid_nq() or None
return RedirectResponse(
auth.logout(name_id=name_id, session_index=session_index, nq=name_id_nq, name_id_format=name_id_format,
spnq=name_id_spnq))
Expand All @@ -107,7 +95,7 @@ async def single_sign_on(auth: OneLogin_Saml2_Auth) -> RedirectResponse:
@staticmethod
async def assertion_consumer_service(
auth: OneLogin_Saml2_Auth, request_args: Dict
) -> Union[RedirectResponse, Userdata]:
) -> Union[RedirectResponse, Dict]:
auth.process_response()
errors = auth.get_errors()
if not len(errors) == 0:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_oidc_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def construct_jwt(
headers: Optional[Dict] = None,
):
iat_timestamp = datetime.datetime.utcnow().timestamp()
delta_day = 1000000 # Unit: day
delta_days = 1000000
# This or patch jwt.decode
if not msg:
msg = {
Expand All @@ -150,7 +150,7 @@ def construct_jwt(
"jti": "68f7cf57-110d-4cbf-9f29-0f5ad4c90328",
"sub": "test-sub",
"iat": int(iat_timestamp),
"exp": int(iat_timestamp + 3600 * 24 * delta_day),
"exp": int(iat_timestamp + 3600 * 24 * delta_days),
}
if algorithm == "HS256":
return jwt.encode(msg, "secret", algorithm=algorithm), msg
Expand Down

0 comments on commit 876c481

Please sign in to comment.