Skip to content

Commit

Permalink
Correct test_populated to test presence of group from OIDC_DEFAULT_AF…
Browse files Browse the repository at this point in the history
…FILIATION
  • Loading branch information
Badatos committed Feb 5, 2025
1 parent 10fa81b commit e62bb86
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 22 deletions.
16 changes: 9 additions & 7 deletions pod/authentication/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@ def update_owner_params(user, params) -> None:
OIDC_DEFAULT_AFFILIATION = getattr(
settings, "OIDC_DEFAULT_AFFILIATION", DEFAULT_AFFILIATION
)
OIDC_DEFAULT_ACCESS_GROUP_CODE_NAMES = getattr(
settings, "OIDC_DEFAULT_ACCESS_GROUP_CODE_NAMES", []
)

OIDC_CLAIM_AFFILIATION = getattr(settings, "OIDC_CLAIM_AFFILIATION", "affiliations")


Expand Down Expand Up @@ -127,9 +125,13 @@ def _assign_affiliations(self, user, claims) -> None:
for affiliation in affiliations:
self._add_access_group(user, affiliation)

if not affiliations:
for code_name in OIDC_DEFAULT_ACCESS_GROUP_CODE_NAMES:
self._safe_add_access_group(user, code_name)
OIDC_DEFAULT_ACCESS_GROUP_CODE_NAMES = getattr(
settings, "OIDC_DEFAULT_ACCESS_GROUP_CODE_NAMES", []
)

# Add default access groups
for code_name in OIDC_DEFAULT_ACCESS_GROUP_CODE_NAMES:
self._safe_assign_access_group(user, code_name)

user.owner.affiliation = (
affiliations[0] if affiliations else OIDC_DEFAULT_AFFILIATION
Expand All @@ -145,7 +147,7 @@ def _add_access_group(self, user, affiliation) -> None:
accessgroup.save()
user.owner.accessgroup_set.add(accessgroup)

def _safe_add_access_group(self, user, code_name) -> None:
def _safe_assign_access_group(self, user, code_name) -> None:
"""Safely add an access group if it exists."""
try:
user.owner.accessgroup_set.add(AccessGroup.objects.get(code_name=code_name))
Expand Down
42 changes: 29 additions & 13 deletions pod/authentication/tests/test_populated.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,6 @@ def test_make_profile(self) -> None:
# ggignore-end
OIDC_OP_TOKEN_ENDPOINT="https://auth.server.com/oauth/token",
OIDC_OP_USER_ENDPOINT="https://auth.server.com/oauth/userinfo",
OIDC_DEFAULT_ACCESS_GROUP_CODE_NAMES=["specific", "unique"],
)
class PopulatedOIDCTestCase(TestCase):
@override_settings(OIDC_DEFAULT_AFFILIATION=UNSTAFFABLE_AFFILIATION)
Expand Down Expand Up @@ -583,9 +582,15 @@ def test_OIDC_django_admin_user_with_default_staff_affiliation(self) -> None:
" of PopulatedOIDCTestCase: OK!"
)

@override_settings(
OIDC_DEFAULT_ACCESS_GROUP_CODE_NAMES=["specific"],
)
def test_OIDC_user_with_default_access_group(self) -> None:
backends.OIDC_DEFAULT_ACCESS_GROUP_CODE_NAMES = ["specific"]
for code_name in settings.OIDC_DEFAULT_ACCESS_GROUP_CODE_NAMES + [
ACCESS_GROUP_CODE_NAMES = getattr(
settings, "OIDC_DEFAULT_ACCESS_GROUP_CODE_NAMES", []
)
self.assertEqual(AccessGroup.objects.all().count(), 0)
for code_name in ACCESS_GROUP_CODE_NAMES + [
"useless",
"dull",
]:
Expand All @@ -597,18 +602,26 @@ def test_OIDC_user_with_default_access_group(self) -> None:
)
self.assertEqual(user.first_name, "Jean")
self.assertEqual(user.last_name, "Fit")
self.assertEqual(AccessGroup.objects.all().count(), 3)
self.assertEqual(user.owner.accessgroup_set.all().count(), 1)

self.assertEqual(AccessGroup.objects.all().count(), 4)
# Assert that user has DEFAULT_ACCESS_GROUP_CODE_NAMES + DEFAULT_AFFILIATION as accessgroup_set
self.assertEqual(user.owner.accessgroup_set.all().count(), 2)
# OIDC new user should have the specified access group from settings
self.assertTrue(user.owner.accessgroup_set.filter(code_name="specific").exists())
self.assertTrue(
user.owner.accessgroup_set.filter(code_name="specific").exists()
) # OIDC new user should have the specified access group from in settings
user.owner.accessgroup_set.filter(
code_name=backends.OIDC_DEFAULT_AFFILIATION
).exists()
)
print(
" ---> test_OIDC_user_with_default_access_group"
" of PopulatedOIDCTestCase: OK!"
)

@override_settings(
OIDC_DEFAULT_ACCESS_GROUP_CODE_NAMES=["specific", "unique"],
)
def test_OIDC_user_with_multiple_default_access_groups(self) -> None:
backends.OIDC_DEFAULT_ACCESS_GROUP_CODE_NAMES = ["specific", "unique"]
for code_name in settings.OIDC_DEFAULT_ACCESS_GROUP_CODE_NAMES + ["dull"]:
AccessGroup.objects.create(
code_name=code_name, display_name=f"Access group {code_name}"
Expand All @@ -618,21 +631,24 @@ def test_OIDC_user_with_multiple_default_access_groups(self) -> None:
)
self.assertEqual(user.first_name, "Jean")
self.assertEqual(user.last_name, "Fit")
self.assertEqual(AccessGroup.objects.all().count(), 3)
self.assertEqual(AccessGroup.objects.all().count(), 4)
self.assertEqual(
user.owner.accessgroup_set.all().count(), 2
) # OIDC new user should have 2 access groups
user.owner.accessgroup_set.all().count(), 3
) # OIDC new user should have 3 access groups
self.assertTrue(user.owner.accessgroup_set.filter(code_name="specific").exists())
self.assertTrue(user.owner.accessgroup_set.filter(code_name="unique").exists())
self.assertTrue(
user.owner.accessgroup_set.filter(
code_name=backends.OIDC_DEFAULT_AFFILIATION
).exists()
)
print(
" ---> test_OIDC_user_with_multiple_default_access_groups"
" of PopulatedOIDCTestCase: OK!"
)

@override_settings(OIDC_CLAIM_AFFILIATION="affiliations")
def test_OIDC_user_with_multiple_claim_access_groups(self) -> None:
"""Test if user is added to access groups from OIDC affiliations."""

user = OIDCBackend().create_user(
claims={
OIDC_CLAIM_GIVEN_NAME: "Jean",
Expand Down
4 changes: 2 additions & 2 deletions pod/main/configuration.json
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@
]
},
"pod_version_end": "",
"pod_version_init": "3.1"
"pod_version_init": "4.0"
},
"OIDC_DEFAULT_ACCESS_GROUP_CODE_NAMES": {
"default_value": "[]",
Expand Down Expand Up @@ -1191,7 +1191,7 @@
"pod_version_end": "",
"pod_version_init": ""
}

},
"title": {
"en": "Seaker application configuration",
Expand Down

0 comments on commit e62bb86

Please sign in to comment.