Skip to content

Commit 5832115

Browse files
error query param in reautentication
1 parent 12d2bec commit 5832115

File tree

4 files changed

+35
-16
lines changed

4 files changed

+35
-16
lines changed

tests/unit/accounts/test_views.py

+2
Original file line numberDiff line numberDiff line change
@@ -3423,12 +3423,14 @@ def test_reauth(self, monkeypatch, pyramid_request, pyramid_services, next_route
34233423
pyramid_request.matched_route = pretend.stub(name=pretend.stub())
34243424
pyramid_request.matchdict = {"foo": "bar"}
34253425
pyramid_request.GET = pretend.stub(mixed=lambda: {"baz": "bar"})
3426+
pyramid_request.params = {}
34263427

34273428
form_obj = pretend.stub(
34283429
next_route=pretend.stub(data=next_route),
34293430
next_route_matchdict=pretend.stub(data="{}"),
34303431
next_route_query=pretend.stub(data="{}"),
34313432
validate=lambda: True,
3433+
password=pretend.stub(errors=[]),
34323434
)
34333435
form_class = pretend.call_recorder(lambda d, **kw: form_obj)
34343436

tests/unit/manage/test_init.py

+1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def test_reauth(self, monkeypatch, require_reauth, needs_reauth_calls):
6666
session=pretend.stub(
6767
needs_reauthentication=pretend.call_recorder(lambda *args: True)
6868
),
69+
params={},
6970
user=pretend.stub(username=pretend.stub()),
7071
matched_route=pretend.stub(name=pretend.stub()),
7172
matchdict={"foo": "bar"},

warehouse/accounts/views.py

+21-15
Original file line numberDiff line numberDiff line change
@@ -1515,7 +1515,6 @@ def profile_public_email(user, request):
15151515

15161516
@view_config(
15171517
route_name="accounts.reauthenticate",
1518-
renderer="re-auth.html",
15191518
uses_session=True,
15201519
require_csrf=True,
15211520
require_methods=False,
@@ -1542,27 +1541,34 @@ def reauthenticate(request, _form_class=ReAuthenticateForm):
15421541
],
15431542
)
15441543

1545-
if form.next_route.data and form.next_route_matchdict.data:
1546-
redirect_to = request.route_path(
1547-
form.next_route.data,
1548-
**json.loads(form.next_route_matchdict.data)
1549-
| dict(_query=json.loads(form.next_route_query.data)),
1550-
)
1551-
else:
1552-
redirect_to = request.route_path("manage.projects")
1544+
next_route = form.next_route.data or "manage.projects"
1545+
next_route_matchdict = json.loads(form.next_route_matchdict.data or "{}")
1546+
next_route_query = json.loads(form.next_route_query.data or "{}")
15531547

1554-
resp = HTTPSeeOther(redirect_to)
1548+
is_valid = form.validate()
15551549

1556-
if request.method == "POST" and form.validate():
1550+
# Ensure errors don't persist across successful validations
1551+
next_route_query.pop("errors", None)
1552+
1553+
if request.method == "POST" and is_valid:
15571554
request.session.record_auth_timestamp()
15581555
request.session.record_password_timestamp(
15591556
user_service.get_password_timestamp(request.user.id)
15601557
)
1561-
return resp
1558+
else:
1559+
# Inject password errors into query if validation failed
1560+
if form.password.errors:
1561+
next_route_query["errors"] = json.dumps({
1562+
"password": [str(e) for e in form.password.errors]
1563+
})
1564+
1565+
redirect_to = request.route_path(
1566+
next_route,
1567+
**next_route_matchdict,
1568+
_query=next_route_query,
1569+
)
15621570

1563-
return {
1564-
"form": form,
1565-
}
1571+
return HTTPSeeOther(redirect_to)
15661572

15671573

15681574
@view_defaults(

warehouse/manage/__init__.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ def reauth_view(view, info):
3535
def wrapped(context, request):
3636
if request.session.needs_reauthentication(time_to_reauth):
3737
user_service = request.find_service(IUserService, context=None)
38-
3938
form = ReAuthenticateForm(
4039
request.POST,
4140
request=request,
@@ -45,6 +44,17 @@ def wrapped(context, request):
4544
next_route_query=json.dumps(request.GET.mixed()),
4645
user_service=user_service,
4746
)
47+
errors_param = request.params.get("errors")
48+
if errors_param:
49+
try:
50+
parsed_errors = json.loads(errors_param)
51+
for field_name, messages in parsed_errors.items():
52+
field = getattr(form, field_name, None)
53+
if field is not None and hasattr(field, "errors"):
54+
field.errors = list(messages)
55+
except (ValueError, TypeError):
56+
# log or ignore bad JSON
57+
pass
4858

4959
return render_to_response(
5060
"re-auth.html",

0 commit comments

Comments
 (0)