From f969b438bc642bc639e000d120039bd65b8d19ab Mon Sep 17 00:00:00 2001 From: hj Date: Tue, 6 Aug 2024 04:09:43 +0530 Subject: [PATCH 1/3] Add base test --- tests/unit/base/test_apps.py | 6 + tests/unit/base/test_seed.py | 16 +++ tests/unit/base/test_utils.py | 231 +++++++++++++++++++++++++++++++++- 3 files changed, 249 insertions(+), 4 deletions(-) create mode 100644 tests/unit/base/test_apps.py create mode 100644 tests/unit/base/test_seed.py diff --git a/tests/unit/base/test_apps.py b/tests/unit/base/test_apps.py new file mode 100644 index 0000000000..c447e5920c --- /dev/null +++ b/tests/unit/base/test_apps.py @@ -0,0 +1,6 @@ +from django.test import TestCase +from base.apps import BaseConfig + +class BaseConfigTest(TestCase): + def test_base_config(self): + self.assertEqual(BaseConfig.name, "base") diff --git a/tests/unit/base/test_seed.py b/tests/unit/base/test_seed.py new file mode 100644 index 0000000000..4b7ed1c262 --- /dev/null +++ b/tests/unit/base/test_seed.py @@ -0,0 +1,16 @@ +import unittest +from unittest.mock import MagicMock +from base.management.commands.seed import Command + +class TestSeedCommand(unittest.TestCase): + def test_add_arguments_adds_nc_argument(self): + parser = MagicMock() + command = Command() + command.add_arguments(parser) + parser.add_argument.assert_called_once_with( + '-nc', nargs='?', default=20, type=int, help='Number of challenges.') + + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/unit/base/test_utils.py b/tests/unit/base/test_utils.py index 8dd6758e77..171dce08ae 100644 --- a/tests/unit/base/test_utils.py +++ b/tests/unit/base/test_utils.py @@ -1,4 +1,6 @@ import os +from unittest import TestCase +import unittest import requests import responses @@ -8,18 +10,30 @@ from django.contrib.auth.models import User from django.core.files.uploadedfile import SimpleUploadedFile from django.utils import timezone - +import botocore from allauth.account.models import EmailAddress from rest_framework import status from rest_framework.test import APITestCase, APIClient -from base.utils import RandomFileName, send_slack_notification, is_user_a_staff +from base.utils import ( + RandomFileName, + decode_data, + encode_data, + send_slack_notification, + is_user_a_staff, get_url_from_hostname, + send_email, + mock_if_non_prod_aws, + get_or_create_sqs_queue, + get_boto3_client + ) from challenges.models import Challenge, ChallengePhase from hosts.models import ChallengeHostTeam from jobs.models import Submission from participants.models import Participant, ParticipantTeam - +from unittest.mock import MagicMock, patch +from django.test import override_settings from scripts import seed +from settings.common import SQS_RETENTION_PERIOD class BaseAPITestClass(APITestCase): @@ -137,7 +151,6 @@ def test_if_slack_notification_works(self): class TestUserIsStaff(BaseAPITestClass): - def test_if_user_is_staff(self): self.user = User.objects.create( username="someuser1", @@ -157,3 +170,213 @@ def test_if_user_is_not_staff(self): self.user.is_staff = False self.user.save() self.assertFalse(is_user_a_staff(self.user)) + + +class TestGetUrlFromHostname(TestCase): + def test_debug_mode(self): + settings.DEBUG = True + hostname = 'example.com' + expected_url = 'http://example.com' + self.assertEqual(get_url_from_hostname(hostname), expected_url) + + def test_test_mode(self): + settings.TEST = True + hostname = 'example.com' + expected_url = 'http://example.com' + self.assertEqual(get_url_from_hostname(hostname), expected_url) + + def test_production_mode(self): + settings.DEBUG = False + settings.TEST = False + hostname = 'example.com' + expected_url = 'https://example.com' + self.assertEqual(get_url_from_hostname(hostname), expected_url) + + +class TestAwsReturnFunc(BaseAPITestClass): + def test_mock_if_non_prod_aws_returns_original_func_when_not_debug_or_test(self): + with patch('django.conf.settings.DEBUG', False): + with patch('django.conf.settings.TEST', False): + aws_mocker = MagicMock() + func = MagicMock(return_value="mocked_value") + decorated_func = mock_if_non_prod_aws(aws_mocker)(func) + result = decorated_func() + self.assertEqual(result, "mocked_value") + aws_mocker.assert_not_called() + + +class TestGetOrCreateSqsQueue(BaseAPITestClass): + @patch('base.utils.boto3.resource') + @patch('base.utils.settings.DEBUG', True) + @patch('base.utils.settings.TEST', False) + def test_debug_mode_queue_name(self, mock_boto3): + mock_sqs = MagicMock() + mock_boto3.return_value = mock_sqs + queue_name = "test_queue" + + get_or_create_sqs_queue(queue_name) + + mock_boto3.assert_called_with( + "sqs", + endpoint_url="http://sqs:9324", + region_name="us-east-1", + aws_secret_access_key="x", + aws_access_key_id="x" + ) + mock_sqs.get_queue_by_name.assert_called_with(QueueName="evalai_submission_queue") + + @patch('base.utils.boto3.resource') + @patch('base.utils.settings.DEBUG', False) + @patch('base.utils.settings.TEST', False) + def test_non_debug_non_test_mode_with_challenge(self, mock_boto3): + mock_sqs = MagicMock() + mock_boto3.return_value = mock_sqs + queue_name = "test_queue" + challenge = MagicMock() + challenge.use_host_sqs = True + challenge.queue_aws_region = "us-west-2" + challenge.aws_secret_access_key = "challenge_secret" + challenge.aws_access_key_id = "challenge_key" + + get_or_create_sqs_queue(queue_name, challenge) + + mock_boto3.assert_called_with( + "sqs", + region_name="us-west-2", + aws_secret_access_key="challenge_secret", + aws_access_key_id="challenge_key" + ) + mock_sqs.get_queue_by_name.assert_called_with(QueueName=queue_name) + + @patch('base.utils.boto3.resource') + @patch('base.utils.settings.DEBUG', False) + @patch('base.utils.settings.TEST', False) + def test_non_debug_non_test_mode_without_challenge(self, mock_boto3): + mock_sqs = MagicMock() + mock_boto3.return_value = mock_sqs + queue_name = "test_queue" + + with patch.dict('os.environ', { + 'AWS_DEFAULT_REGION': 'us-east-1', + 'AWS_SECRET_ACCESS_KEY': 'env_secret', + 'AWS_ACCESS_KEY_ID': 'env_key' + }): + get_or_create_sqs_queue(queue_name) + + mock_boto3.assert_called_with( + "sqs", + region_name="us-east-1", + aws_secret_access_key="env_secret", + aws_access_key_id="env_key" + ) + mock_sqs.get_queue_by_name.assert_called_with(QueueName=queue_name) + + @patch('base.utils.boto3.resource') + @patch('base.utils.logger') + @patch('base.utils.settings') + def test_get_or_create_sqs_queue_exception_logging(self, mock_settings, mock_logger, mock_boto3_resource): + mock_settings.DEBUG = False + mock_settings.TEST = False + mock_sqs = MagicMock() + mock_boto3_resource.return_value = mock_sqs + + error_response = {'Error': {'Code': 'SomeOtherError', 'Message': 'An error occurred'}} + client_error = botocore.exceptions.ClientError(error_response, 'GetQueueUrl') + + mock_sqs.get_queue_by_name.side_effect = client_error + + queue_name = "test_queue" + get_or_create_sqs_queue(queue_name) + + mock_logger.exception.assert_called_once_with("Cannot get queue: {}".format(queue_name)) + + @patch('base.utils.boto3.resource') + @patch('base.utils.settings.DEBUG', False) + @patch('base.utils.settings.TEST', False) + @patch('base.utils.logger') + def test_queue_creation_on_non_existent_queue(self, mock_logger, mock_boto3): + mock_sqs = MagicMock() + mock_boto3.return_value = mock_sqs + queue_name = "test_queue" + challenge = None + mock_sqs.get_queue_by_name.side_effect = botocore.exceptions.ClientError( + {"Error": {"Code": "AWS.SimpleQueueService.NonExistentQueue"}}, "GetQueueUrl" + ) + + get_or_create_sqs_queue(queue_name, challenge) + + mock_sqs.create_queue.assert_called_with( + QueueName=queue_name, + Attributes={"MessageRetentionPeriod": SQS_RETENTION_PERIOD} + ) + mock_logger.exception.assert_not_called() + + +class TestSendEmail(unittest.TestCase): + + @patch('base.utils.sendgrid.SendGridAPIClient') + @patch('base.utils.os.environ.get') + @patch('base.utils.logger') + def test_send_email_success(self, mock_logger, mock_get_env, mock_sendgrid_client): + mock_get_env.return_value = 'fake_api_key' + mock_sg_instance = MagicMock() + mock_sendgrid_client.return_value = mock_sg_instance + + send_email( + sender='sender@example.com', + recipient='recipient@example.com', + template_id='template_id', + template_data={'key': 'value'} + ) + + mock_sendgrid_client.assert_called_once_with(api_key='fake_api_key') + mock_sg_instance.client.mail.send.post.assert_called_once() + mock_logger.warning.assert_not_called() + + @patch('base.utils.sendgrid.SendGridAPIClient') + @patch('base.utils.os.environ.get') + @patch('base.utils.logger') + def test_send_email_exception(self, mock_logger, mock_get_env, mock_sendgrid_client): + # Mock environment variable + mock_get_env.return_value = 'fake_api_key' + mock_sendgrid_client.side_effect = Exception('SendGrid error') + + send_email( + sender='sender@example.com', + recipient='recipient@example.com', + template_id='template_id', + template_data={'key': 'value'} + ) + + mock_logger.warning.assert_called_once_with( + "Cannot make sendgrid call. Please check if SENDGRID_API_KEY is present." + ) + + +class TestGetBoto3Client(unittest.TestCase): + @patch('base.utils.boto3.client') + @patch('base.utils.logger') + def test_get_boto3_client_exception(self, mock_logger, mock_boto3_client): + mock_boto3_client.side_effect = Exception('Boto3 error') + + aws_keys = { + "AWS_REGION": "us-west-2", + "AWS_ACCESS_KEY_ID": "fake_access_key_id", + "AWS_SECRET_ACCESS_KEY": "fake_secret_access_key" + } + + get_boto3_client('s3', aws_keys) + mock_logger.exception.assert_called_once() + args, kwargs = mock_logger.exception.call_args + self.assertIsInstance(args[0], Exception) + self.assertEqual(str(args[0]), 'Boto3 error') + + +class TestDataEncoding(unittest.TestCase): + def test_encode_data_empty_list(self): + data = [] + self.assertEqual(encode_data(data), []) + + def test_decode_data_empty_list(self): + data = [] + self.assertEqual(decode_data(data), []) From e149da8c1e0559ce84252fc299285042e382cc60 Mon Sep 17 00:00:00 2001 From: hj Date: Tue, 6 Aug 2024 05:06:35 +0530 Subject: [PATCH 2/3] Build Fix --- tests/unit/base/test_apps.py | 1 + tests/unit/base/test_seed.py | 9 +++----- tests/unit/base/test_utils.py | 40 +++++++++++++++++------------------ 3 files changed, 24 insertions(+), 26 deletions(-) diff --git a/tests/unit/base/test_apps.py b/tests/unit/base/test_apps.py index c447e5920c..80d2bcd929 100644 --- a/tests/unit/base/test_apps.py +++ b/tests/unit/base/test_apps.py @@ -1,6 +1,7 @@ from django.test import TestCase from base.apps import BaseConfig + class BaseConfigTest(TestCase): def test_base_config(self): self.assertEqual(BaseConfig.name, "base") diff --git a/tests/unit/base/test_seed.py b/tests/unit/base/test_seed.py index 4b7ed1c262..d63832f716 100644 --- a/tests/unit/base/test_seed.py +++ b/tests/unit/base/test_seed.py @@ -2,15 +2,12 @@ from unittest.mock import MagicMock from base.management.commands.seed import Command + class TestSeedCommand(unittest.TestCase): def test_add_arguments_adds_nc_argument(self): parser = MagicMock() command = Command() command.add_arguments(parser) parser.add_argument.assert_called_once_with( - '-nc', nargs='?', default=20, type=int, help='Number of challenges.') - - - -if __name__ == '__main__': - unittest.main() \ No newline at end of file + '-nc', nargs='?', default=20, type=int, help='Number of challenges.' + ) \ No newline at end of file diff --git a/tests/unit/base/test_utils.py b/tests/unit/base/test_utils.py index 171dce08ae..4581e956cb 100644 --- a/tests/unit/base/test_utils.py +++ b/tests/unit/base/test_utils.py @@ -18,20 +18,20 @@ from base.utils import ( RandomFileName, decode_data, - encode_data, - send_slack_notification, - is_user_a_staff, get_url_from_hostname, - send_email, - mock_if_non_prod_aws, + encode_data, + send_slack_notification, + is_user_a_staff, + get_url_from_hostname, + send_email, + mock_if_non_prod_aws, get_or_create_sqs_queue, - get_boto3_client - ) + get_boto3_client, +) from challenges.models import Challenge, ChallengePhase from hosts.models import ChallengeHostTeam from jobs.models import Submission from participants.models import Participant, ParticipantTeam from unittest.mock import MagicMock, patch -from django.test import override_settings from scripts import seed from settings.common import SQS_RETENTION_PERIOD @@ -213,9 +213,9 @@ def test_debug_mode_queue_name(self, mock_boto3): mock_sqs = MagicMock() mock_boto3.return_value = mock_sqs queue_name = "test_queue" - + get_or_create_sqs_queue(queue_name) - + mock_boto3.assert_called_with( "sqs", endpoint_url="http://sqs:9324", @@ -237,9 +237,9 @@ def test_non_debug_non_test_mode_with_challenge(self, mock_boto3): challenge.queue_aws_region = "us-west-2" challenge.aws_secret_access_key = "challenge_secret" challenge.aws_access_key_id = "challenge_key" - + get_or_create_sqs_queue(queue_name, challenge) - + mock_boto3.assert_called_with( "sqs", region_name="us-west-2", @@ -255,14 +255,14 @@ def test_non_debug_non_test_mode_without_challenge(self, mock_boto3): mock_sqs = MagicMock() mock_boto3.return_value = mock_sqs queue_name = "test_queue" - + with patch.dict('os.environ', { 'AWS_DEFAULT_REGION': 'us-east-1', 'AWS_SECRET_ACCESS_KEY': 'env_secret', 'AWS_ACCESS_KEY_ID': 'env_key' }): get_or_create_sqs_queue(queue_name) - + mock_boto3.assert_called_with( "sqs", region_name="us-east-1", @@ -279,15 +279,15 @@ def test_get_or_create_sqs_queue_exception_logging(self, mock_settings, mock_log mock_settings.TEST = False mock_sqs = MagicMock() mock_boto3_resource.return_value = mock_sqs - + error_response = {'Error': {'Code': 'SomeOtherError', 'Message': 'An error occurred'}} client_error = botocore.exceptions.ClientError(error_response, 'GetQueueUrl') - + mock_sqs.get_queue_by_name.side_effect = client_error - + queue_name = "test_queue" get_or_create_sqs_queue(queue_name) - + mock_logger.exception.assert_called_once_with("Cannot get queue: {}".format(queue_name)) @patch('base.utils.boto3.resource') @@ -302,9 +302,9 @@ def test_queue_creation_on_non_existent_queue(self, mock_logger, mock_boto3): mock_sqs.get_queue_by_name.side_effect = botocore.exceptions.ClientError( {"Error": {"Code": "AWS.SimpleQueueService.NonExistentQueue"}}, "GetQueueUrl" ) - + get_or_create_sqs_queue(queue_name, challenge) - + mock_sqs.create_queue.assert_called_with( QueueName=queue_name, Attributes={"MessageRetentionPeriod": SQS_RETENTION_PERIOD} From 90fdb029ab452ece67493f772ebd455323c87d79 Mon Sep 17 00:00:00 2001 From: hj Date: Tue, 6 Aug 2024 05:33:59 +0530 Subject: [PATCH 3/3] buld fix --- tests/unit/base/test_seed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/base/test_seed.py b/tests/unit/base/test_seed.py index d63832f716..76cb16edc9 100644 --- a/tests/unit/base/test_seed.py +++ b/tests/unit/base/test_seed.py @@ -10,4 +10,4 @@ def test_add_arguments_adds_nc_argument(self): command.add_arguments(parser) parser.add_argument.assert_called_once_with( '-nc', nargs='?', default=20, type=int, help='Number of challenges.' - ) \ No newline at end of file + )