Skip to content

Retrieve user identifier from form data #114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions flask_multipass/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class AuthProvider(metaclass=SupportsMeta):
#: form in your application, specify a :class:`~flask_wtf.Form`
#: here (usually containing a username/email and a password field).
login_form = None
#: The field name in the login form that contains the identifier.
#: Useful to reliably retrieve identifier data in applications that use
#: multiple auth providers.
identifier_field_name = None

def __init__(self, multipass, name, settings):
self.multipass = multipass
Expand Down Expand Up @@ -113,5 +117,11 @@ def process_logout(self, return_url):
"""
return None

def get_identifier(self, data):
"""Get the user identifier from form data."""
if self.identifier_field_name is None:
raise NotImplementedError('No identifier field name set')
return data.get(self.identifier_field_name)

def __repr__(self):
return f'<{type(self).__name__}({self.name})>'
2 changes: 2 additions & 0 deletions flask_multipass/providers/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class SQLAlchemyAuthProviderBase(AuthProvider):
#: i.e. the username. This needs to be a SQLAlchemy column object,
#: e.g. ``Identity.identifier``
identifier_column = None
#: The field name in the ``login_form`` that contains the identifier
identifier_field_name = 'identifier'

def check_password(self, identity, password):
"""Checks the entered password.
Expand Down
1 change: 1 addition & 0 deletions flask_multipass/providers/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class StaticAuthProvider(AuthProvider):
"""

login_form = StaticLoginForm
identifier_field_name = 'username'

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down
9 changes: 6 additions & 3 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,10 @@ class NoProvider:
class InvalidProvider(AuthProvider, IdentityProvider):
pass

pytest.raises(TypeError, get_provider_base, NoProvider)
pytest.raises(TypeError, get_provider_base, InvalidProvider)
with pytest.raises(TypeError):
get_provider_base(NoProvider)
with pytest.raises(TypeError):
get_provider_base(InvalidProvider)


def test_login_view(mocker):
Expand Down Expand Up @@ -267,7 +269,8 @@ def test_validate_provider_map(valid, auth_providers, identity_providers, provid
if valid:
validate_provider_map(state)
else:
pytest.raises(ValueError, validate_provider_map, state)
with pytest.raises(ValueError):
validate_provider_map(state)


def test_classproperty():
Expand Down