Skip to content

Commit

Permalink
correct some bugs after unit testing with @FBrosset
Browse files Browse the repository at this point in the history
  • Loading branch information
Badatos committed Feb 4, 2025
1 parent 79bd664 commit 10fa81b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 21 deletions.
25 changes: 13 additions & 12 deletions pod/authentication/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,11 @@ def update_owner_params(user, params) -> None:

class OIDCBackend(OIDCAuthenticationBackend):
"""OIDC backend authentication."""

def create_user(self, claims):
"""Create user connectd by OIDC."""
user = self._initialize_user(claims)
self._assign_affiliations(user, claims)
user = super(OIDCBackend, self).create_user(claims)
self._initialize_user(user, claims)
user.is_staff = is_staff_affiliation(user.owner.affiliation)
user.owner.save()
user.save()
Expand All @@ -111,19 +112,17 @@ def update_user(self, user, claims):

return user

def _initialize_user(self, claims):
def _initialize_user(self, user, claims):
"""Initialize user object from OIDC claims."""
email = claims.get("email")
username = claims.get(OIDC_CLAIM_PREFERRED_USERNAME, "")
user = User.objects.create_user(username, email=email)
user.first_name = claims.get(OIDC_CLAIM_GIVEN_NAME, "")
user.last_name = claims.get(OIDC_CLAIM_FAMILY_NAME, "")
user.username = username # Evite de récupérer deux fois
user.username = claims.get(OIDC_CLAIM_PREFERRED_USERNAME, "")
self._assign_affiliations(user, claims)
return user

def _assign_affiliations(self, user, claims):
def _assign_affiliations(self, user, claims) -> None:
"""Assign affiliations and access groups to user."""
affiliations = claims.get(OIDC_CLAIM_AFFILIATION, []) or [OIDC_DEFAULT_AFFILIATION]
affiliations = claims.get(OIDC_CLAIM_AFFILIATION, [OIDC_DEFAULT_AFFILIATION])

for affiliation in affiliations:
self._add_access_group(user, affiliation)
Expand All @@ -132,9 +131,11 @@ def _assign_affiliations(self, user, claims):
for code_name in OIDC_DEFAULT_ACCESS_GROUP_CODE_NAMES:
self._safe_add_access_group(user, code_name)

user.owner.affiliation = affiliations[0] if affiliations else OIDC_DEFAULT_AFFILIATION
user.owner.affiliation = (
affiliations[0] if affiliations else OIDC_DEFAULT_AFFILIATION
)

def _add_access_group(self, user, affiliation):
def _add_access_group(self, user, affiliation) -> None:
"""Create or retrieve access group and assign to user."""
accessgroup, created = AccessGroup.objects.get_or_create(code_name=affiliation)
if created:
Expand All @@ -144,7 +145,7 @@ def _add_access_group(self, user, affiliation):
accessgroup.save()
user.owner.accessgroup_set.add(accessgroup)

def _safe_add_access_group(self, user, code_name):
def _safe_add_access_group(self, user, code_name) -> None:
"""Safely add an access group if it exists."""
try:
user.owner.accessgroup_set.add(AccessGroup.objects.get(code_name=code_name))
Expand Down
27 changes: 18 additions & 9 deletions pod/authentication/tests/test_populated.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@ def test_make_profile(self) -> None:

print(" ---> test_make_profile of PopulatedShibTestCase: OK!")


@override_settings(
# ggignore-start
OIDC_RP_CLIENT_ID="MWViNTY2NzJjNGY4YTQ1MTAwMTNiYjk3",
Expand All @@ -549,6 +550,7 @@ def test_make_profile(self) -> None:
OIDC_DEFAULT_ACCESS_GROUP_CODE_NAMES=["specific", "unique"],
)
class PopulatedOIDCTestCase(TestCase):
@override_settings(OIDC_DEFAULT_AFFILIATION=UNSTAFFABLE_AFFILIATION)
def test_OIDC_commoner_with_default_unstaffable_affiliation(self) -> None:
backends.OIDC_DEFAULT_AFFILIATION = UNSTAFFABLE_AFFILIATION
user = OIDCBackend().create_user(
Expand Down Expand Up @@ -626,22 +628,29 @@ def test_OIDC_user_with_multiple_default_access_groups(self) -> None:
" ---> test_OIDC_user_with_multiple_default_access_groups"
" of PopulatedOIDCTestCase: OK!"
)
def test_OIDC_user_with_multiple_claim_access_groups(self):
backends.OIDC_CLAIM_AFFILIATION = ["accessgroup1", "accessgroup2"]

@override_settings(OIDC_CLAIM_AFFILIATION="affiliations")
def test_OIDC_user_with_multiple_claim_access_groups(self) -> None:
"""Test if user is added to access groups from OIDC affiliations."""

user = OIDCBackend().create_user(
claims={OIDC_CLAIM_GIVEN_NAME: "Jean", OIDC_CLAIM_FAMILY_NAME: "Fit"}
claims={
OIDC_CLAIM_GIVEN_NAME: "Jean",
OIDC_CLAIM_FAMILY_NAME: "Fit",
"affiliations": ["accessgroup1", "accessgroup2"],
}
)

for code_name in settings.OIDC_CLAIM_AFFILIATION:
accessgroup, group_created = AccessGroup.objects.get_or_create(code_name=code_name)
user.owner.accessgroup_set.add(accessgroup)

self.assertEqual(user.first_name, "Jean")
self.assertEqual(user.last_name, "Fit")
self.assertEqual(AccessGroup.objects.all().count(), 2)
self.assertEqual(user.owner.accessgroup_set.all().count(), 2)
self.assertTrue(user.owner.accessgroup_set.filter(code_name="accessgroup1").exists())
self.assertTrue(user.owner.accessgroup_set.filter(code_name="accessgroup2").exists())
self.assertTrue(
user.owner.accessgroup_set.filter(code_name="accessgroup1").exists()
)
self.assertTrue(
user.owner.accessgroup_set.filter(code_name="accessgroup2").exists()
)
print(
" ---> test_OIDC_user_with_multiple_claim_access_groups"
" of PopulatedOIDCTestCase: OK!"
Expand Down

0 comments on commit 10fa81b

Please sign in to comment.