diff --git a/judge/models/contest.py b/judge/models/contest.py index 6f6ac1693..02b69a10b 100644 --- a/judge/models/contest.py +++ b/judge/models/contest.py @@ -267,7 +267,11 @@ def can_see_full_scoreboard(self, user): return True if user.profile.id in self.editor_ids: return True - if self.view_contest_scoreboard.filter(id=user.profile.id).exists(): + cache = getattr(self, '_prefetched_objects_cache', {}) + if 'view_contest_scoreboard' in cache: + if any(p.id == user.profile.id for p in self.view_contest_scoreboard.all()): + return True + elif self.view_contest_scoreboard.filter(id=user.profile.id).exists(): return True if self.scoreboard_visibility == self.SCOREBOARD_AFTER_PARTICIPATION and self.has_completed_contest(user): return True @@ -377,12 +381,27 @@ def ended(self): @cached_property def author_ids(self): + cache = getattr(self, '_prefetched_objects_cache', {}) + if 'authors' in cache: + return frozenset(p.id for p in self.authors.all()) return Contest.authors.through.objects.filter(contest=self).values_list('profile_id', flat=True) @cached_property def editor_ids(self): - return self.author_ids.union( - Contest.curators.through.objects.filter(contest=self).values_list('profile_id', flat=True)) + cache = getattr(self, '_prefetched_objects_cache', {}) + if 'authors' in cache: + author_ids = frozenset(p.id for p in self.authors.all()) + else: + author_ids = frozenset( + Contest.authors.through.objects.filter(contest=self).values_list('profile_id', flat=True), + ) + + if 'curators' in cache: + return author_ids | frozenset(p.id for p in self.curators.all()) + + return author_ids | frozenset( + Contest.curators.through.objects.filter(contest=self).values_list('profile_id', flat=True), + ) @cached_property def tester_ids(self): diff --git a/judge/tests/test_contest_list_perf.py b/judge/tests/test_contest_list_perf.py new file mode 100644 index 000000000..cb78f91d1 --- /dev/null +++ b/judge/tests/test_contest_list_perf.py @@ -0,0 +1,51 @@ +from datetime import timedelta + +from django.contrib.auth import get_user_model +from django.db import connection, reset_queries +from django.test import Client, TestCase, override_settings +from django.utils import timezone + +from judge.models import Contest, Profile + +User = get_user_model() + +NUM_CONTESTS = 10 +QUERY_BUDGET = 35 + + +class ContestListPerfTest(TestCase): + """Ensure ContestList does not generate N+1 queries for hidden-scoreboard contests.""" + + @classmethod + def setUpTestData(cls): + cls.user = User.objects.create_user(username='perf_user', password='pass') + Profile.objects.create(user=cls.user) + + now = timezone.now() + for i in range(NUM_CONTESTS): + Contest.objects.create( + key=f'hidden_contest_{i}', + name=f'Hidden Contest {i}', + start_time=now - timedelta(hours=2), + end_time=now - timedelta(hours=1), + is_visible=True, + scoreboard_visibility=Contest.SCOREBOARD_HIDDEN, + ) + + def setUp(self): + self.client = Client() + self.client.login(username='perf_user', password='pass') + + def test_query_count_bounded(self): + """Query count must stay within budget regardless of number of hidden-scoreboard contests.""" + with override_settings(DEBUG=True): + reset_queries() + response = self.client.get('/contests/') + query_count = len(connection.queries) + + self.assertEqual(response.status_code, 200) + print(f'\n [{NUM_CONTESTS} hidden-scoreboard contests] query count: {query_count} (budget: {QUERY_BUDGET})') + self.assertLessEqual( + query_count, QUERY_BUDGET, + msg=f'Expected <= {QUERY_BUDGET} queries for {NUM_CONTESTS} contests, got {query_count}.', + ) diff --git a/judge/views/contests.py b/judge/views/contests.py index 1b92563d2..1673bba57 100644 --- a/judge/views/contests.py +++ b/judge/views/contests.py @@ -104,7 +104,9 @@ def _now(self): return timezone.now() def _get_queryset(self): - return super().get_queryset().prefetch_related('tags', 'organization', 'authors', 'curators', 'testers') + return super().get_queryset().prefetch_related( + 'tags', 'organization', 'authors', 'curators', 'testers', 'view_contest_scoreboard', + ) def get_queryset(self): self.search_query = None @@ -133,7 +135,8 @@ def get_context_data(self, **kwargs): for participation in ContestParticipation.objects.filter(virtual=0, user=self.request.profile, contest_id__in=present) \ .select_related('contest') \ - .prefetch_related('contest__authors', 'contest__curators', 'contest__testers') \ + .prefetch_related('contest__authors', 'contest__curators', 'contest__testers', + 'contest__view_contest_scoreboard') \ .annotate(key=F('contest__key')): if participation.ended: finished.add(participation.contest.key)