diff --git a/tests/unit/base/test_apps.py b/tests/unit/base/test_apps.py new file mode 100644 index 0000000000..3747976efb --- /dev/null +++ b/tests/unit/base/test_apps.py @@ -0,0 +1,7 @@ +from base.apps import BaseConfig +from django.test import TestCase + + +class BaseConfigTest(TestCase): + def test_base_config(self): + self.assertEqual(BaseConfig.name, "base") diff --git a/tests/unit/base/test_utils.py b/tests/unit/base/test_utils.py index 1dec3aeea2..3701157fc0 100644 --- a/tests/unit/base/test_utils.py +++ b/tests/unit/base/test_utils.py @@ -1,10 +1,25 @@ import os +import unittest from datetime import timedelta +from unittest import TestCase +from unittest.mock import MagicMock, patch +import botocore import requests import responses from allauth.account.models import EmailAddress -from base.utils import RandomFileName, is_user_a_staff, send_slack_notification +from base.utils import ( + RandomFileName, + decode_data, + encode_data, + get_boto3_client, + get_or_create_sqs_queue, + get_url_from_hostname, + is_user_a_staff, + mock_if_non_prod_aws, + send_email, + send_slack_notification, +) from challenges.models import Challenge, ChallengePhase from django.conf import settings from django.contrib.auth.models import User @@ -17,6 +32,7 @@ from rest_framework.test import APIClient, APITestCase from scripts import seed +from settings.common import SQS_RETENTION_PERIOD class BaseAPITestClass(APITestCase): @@ -134,7 +150,6 @@ def test_if_slack_notification_works(self): class TestUserIsStaff(BaseAPITestClass): - def test_if_user_is_staff(self): self.user = User.objects.create( username="someuser1", @@ -154,3 +169,237 @@ def test_if_user_is_not_staff(self): self.user.is_staff = False self.user.save() self.assertFalse(is_user_a_staff(self.user)) + + +class TestGetUrlFromHostname(TestCase): + def test_debug_mode(self): + settings.DEBUG = True + hostname = "example.com" + expected_url = "http://example.com" + self.assertEqual(get_url_from_hostname(hostname), expected_url) + + def test_test_mode(self): + settings.TEST = True + hostname = "example.com" + expected_url = "http://example.com" + self.assertEqual(get_url_from_hostname(hostname), expected_url) + + def test_production_mode(self): + settings.DEBUG = False + settings.TEST = False + hostname = "example.com" + expected_url = "https://example.com" + self.assertEqual(get_url_from_hostname(hostname), expected_url) + + +class TestAwsReturnFunc(BaseAPITestClass): + def test_mock_if_non_prod_aws_returns_original_func_when_not_debug_or_test( + self, + ): + with patch("django.conf.settings.DEBUG", False): + with patch("django.conf.settings.TEST", False): + aws_mocker = MagicMock() + func = MagicMock(return_value="mocked_value") + decorated_func = mock_if_non_prod_aws(aws_mocker)(func) + result = decorated_func() + self.assertEqual(result, "mocked_value") + aws_mocker.assert_not_called() + + +class TestGetOrCreateSqsQueue(BaseAPITestClass): + @patch("base.utils.boto3.resource") + @patch("base.utils.settings.DEBUG", True) + @patch("base.utils.settings.TEST", False) + def test_debug_mode_queue_name(self, mock_boto3): + mock_sqs = MagicMock() + mock_boto3.return_value = mock_sqs + queue_name = "test_queue" + + get_or_create_sqs_queue(queue_name) + + mock_boto3.assert_called_with( + "sqs", + endpoint_url="http://sqs:9324", + region_name="us-east-1", + aws_secret_access_key="x", + aws_access_key_id="x", + ) + mock_sqs.get_queue_by_name.assert_called_with( + QueueName="evalai_submission_queue" + ) + + @patch("base.utils.boto3.resource") + @patch("base.utils.settings.DEBUG", False) + @patch("base.utils.settings.TEST", False) + def test_non_debug_non_test_mode_with_challenge(self, mock_boto3): + mock_sqs = MagicMock() + mock_boto3.return_value = mock_sqs + queue_name = "test_queue" + challenge = MagicMock() + challenge.use_host_sqs = True + challenge.queue_aws_region = "us-west-2" + challenge.aws_secret_access_key = "challenge_secret" + challenge.aws_access_key_id = "challenge_key" + + get_or_create_sqs_queue(queue_name, challenge) + + mock_boto3.assert_called_with( + "sqs", + region_name="us-west-2", + aws_secret_access_key="challenge_secret", + aws_access_key_id="challenge_key", + ) + mock_sqs.get_queue_by_name.assert_called_with(QueueName=queue_name) + + @patch("base.utils.boto3.resource") + @patch("base.utils.settings.DEBUG", False) + @patch("base.utils.settings.TEST", False) + def test_non_debug_non_test_mode_without_challenge(self, mock_boto3): + mock_sqs = MagicMock() + mock_boto3.return_value = mock_sqs + queue_name = "test_queue" + + with patch.dict( + "os.environ", + { + "AWS_DEFAULT_REGION": "us-east-1", + "AWS_SECRET_ACCESS_KEY": "env_secret", + "AWS_ACCESS_KEY_ID": "env_key", + }, + ): + get_or_create_sqs_queue(queue_name) + + mock_boto3.assert_called_with( + "sqs", + region_name="us-east-1", + aws_secret_access_key="env_secret", + aws_access_key_id="env_key", + ) + mock_sqs.get_queue_by_name.assert_called_with(QueueName=queue_name) + + @patch("base.utils.boto3.resource") + @patch("base.utils.logger") + @patch("base.utils.settings") + def test_get_or_create_sqs_queue_exception_logging( + self, mock_settings, mock_logger, mock_boto3_resource + ): + mock_settings.DEBUG = False + mock_settings.TEST = False + mock_sqs = MagicMock() + mock_boto3_resource.return_value = mock_sqs + + error_response = { + "Error": {"Code": "SomeOtherError", "Message": "An error occurred"} + } + client_error = botocore.exceptions.ClientError( + error_response, "GetQueueUrl" + ) + + mock_sqs.get_queue_by_name.side_effect = client_error + + queue_name = "test_queue" + get_or_create_sqs_queue(queue_name) + + mock_logger.exception.assert_called_once_with( + "Cannot get queue: {}".format(queue_name) + ) + + @patch("base.utils.boto3.resource") + @patch("base.utils.settings.DEBUG", False) + @patch("base.utils.settings.TEST", False) + @patch("base.utils.logger") + def test_queue_creation_on_non_existent_queue( + self, mock_logger, mock_boto3 + ): + mock_sqs = MagicMock() + mock_boto3.return_value = mock_sqs + queue_name = "test_queue" + challenge = None + mock_sqs.get_queue_by_name.side_effect = ( + botocore.exceptions.ClientError( + {"Error": {"Code": "AWS.SimpleQueueService.NonExistentQueue"}}, + "GetQueueUrl", + ) + ) + + get_or_create_sqs_queue(queue_name, challenge) + + mock_sqs.create_queue.assert_called_with( + QueueName=queue_name, + Attributes={"MessageRetentionPeriod": SQS_RETENTION_PERIOD}, + ) + mock_logger.exception.assert_not_called() + + +class TestSendEmail(unittest.TestCase): + @patch("base.utils.sendgrid.SendGridAPIClient") + @patch("base.utils.os.environ.get") + @patch("base.utils.logger") + def test_send_email_success( + self, mock_logger, mock_get_env, mock_sendgrid_client + ): + mock_get_env.return_value = "fake_api_key" + mock_sg_instance = MagicMock() + mock_sendgrid_client.return_value = mock_sg_instance + + send_email( + sender="sender@example.com", + recipient="recipient@example.com", + template_id="template_id", + template_data={"key": "value"}, + ) + + mock_sendgrid_client.assert_called_once_with(api_key="fake_api_key") + mock_sg_instance.client.mail.send.post.assert_called_once() + mock_logger.warning.assert_not_called() + + @patch("base.utils.sendgrid.SendGridAPIClient") + @patch("base.utils.os.environ.get") + @patch("base.utils.logger") + def test_send_email_exception( + self, mock_logger, mock_get_env, mock_sendgrid_client + ): + # Mock environment variable + mock_get_env.return_value = "fake_api_key" + mock_sendgrid_client.side_effect = Exception("SendGrid error") + + send_email( + sender="sender@example.com", + recipient="recipient@example.com", + template_id="template_id", + template_data={"key": "value"}, + ) + + mock_logger.warning.assert_called_once_with( + "Cannot make sendgrid call. " + "Please check if SENDGRID_API_KEY is present." + ) + + +class TestGetBoto3Client(unittest.TestCase): + @patch("base.utils.boto3.client") + @patch("base.utils.logger") + def test_get_boto3_client_exception(self, mock_logger, mock_boto3_client): + mock_boto3_client.side_effect = Exception("Boto3 error") + + aws_keys = { + "AWS_REGION": "us-west-2", + "AWS_ACCESS_KEY_ID": "fake_access_key_id", + "AWS_SECRET_ACCESS_KEY": "fake_secret_access_key", + } + + get_boto3_client("s3", aws_keys) + mock_logger.exception.assert_called_once() + args, kwargs = mock_logger.exception.call_args + self.assertIsInstance(args[0], Exception) + self.assertEqual(str(args[0]), "Boto3 error") + + +class TestDataEncoding(unittest.TestCase): + def test_encode_data_empty_list(self): + data = [] + self.assertEqual(encode_data(data), []) + + def test_decode_data_empty_list(self): + data = [] + self.assertEqual(decode_data(data), [])