Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions tests/unit/base/test_apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from base.apps import BaseConfig
from django.test import TestCase


class BaseConfigTest(TestCase):
def test_base_config(self):
self.assertEqual(BaseConfig.name, "base")
253 changes: 251 additions & 2 deletions tests/unit/base/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
import os
import unittest
from datetime import timedelta
from unittest import TestCase
from unittest.mock import MagicMock, patch

import botocore
import requests
import responses
from allauth.account.models import EmailAddress
from base.utils import RandomFileName, is_user_a_staff, send_slack_notification
from base.utils import (
RandomFileName,
decode_data,
encode_data,
get_boto3_client,
get_or_create_sqs_queue,
get_url_from_hostname,
is_user_a_staff,
mock_if_non_prod_aws,
send_email,
send_slack_notification,
)
from challenges.models import Challenge, ChallengePhase
from django.conf import settings
from django.contrib.auth.models import User
Expand All @@ -17,6 +32,7 @@
from rest_framework.test import APIClient, APITestCase

from scripts import seed
from settings.common import SQS_RETENTION_PERIOD


class BaseAPITestClass(APITestCase):
Expand Down Expand Up @@ -134,7 +150,6 @@ def test_if_slack_notification_works(self):


class TestUserIsStaff(BaseAPITestClass):

def test_if_user_is_staff(self):
self.user = User.objects.create(
username="someuser1",
Expand All @@ -154,3 +169,237 @@ def test_if_user_is_not_staff(self):
self.user.is_staff = False
self.user.save()
self.assertFalse(is_user_a_staff(self.user))


class TestGetUrlFromHostname(TestCase):
def test_debug_mode(self):
settings.DEBUG = True
hostname = "example.com"
expected_url = "http://example.com"
self.assertEqual(get_url_from_hostname(hostname), expected_url)

def test_test_mode(self):
settings.TEST = True
hostname = "example.com"
expected_url = "http://example.com"
self.assertEqual(get_url_from_hostname(hostname), expected_url)

def test_production_mode(self):
settings.DEBUG = False
settings.TEST = False
hostname = "example.com"
expected_url = "https://example.com"
self.assertEqual(get_url_from_hostname(hostname), expected_url)


class TestAwsReturnFunc(BaseAPITestClass):
def test_mock_if_non_prod_aws_returns_original_func_when_not_debug_or_test(
self,
):
with patch("django.conf.settings.DEBUG", False):
with patch("django.conf.settings.TEST", False):
aws_mocker = MagicMock()
func = MagicMock(return_value="mocked_value")
decorated_func = mock_if_non_prod_aws(aws_mocker)(func)
result = decorated_func()
self.assertEqual(result, "mocked_value")
aws_mocker.assert_not_called()


class TestGetOrCreateSqsQueue(BaseAPITestClass):
@patch("base.utils.boto3.resource")
@patch("base.utils.settings.DEBUG", True)
@patch("base.utils.settings.TEST", False)
def test_debug_mode_queue_name(self, mock_boto3):
mock_sqs = MagicMock()
mock_boto3.return_value = mock_sqs
queue_name = "test_queue"

get_or_create_sqs_queue(queue_name)

mock_boto3.assert_called_with(
"sqs",
endpoint_url="http://sqs:9324",
region_name="us-east-1",
aws_secret_access_key="x",
aws_access_key_id="x",
)
mock_sqs.get_queue_by_name.assert_called_with(
QueueName="evalai_submission_queue"
)

@patch("base.utils.boto3.resource")
@patch("base.utils.settings.DEBUG", False)
@patch("base.utils.settings.TEST", False)
def test_non_debug_non_test_mode_with_challenge(self, mock_boto3):
mock_sqs = MagicMock()
mock_boto3.return_value = mock_sqs
queue_name = "test_queue"
challenge = MagicMock()
challenge.use_host_sqs = True
challenge.queue_aws_region = "us-west-2"
challenge.aws_secret_access_key = "challenge_secret"
challenge.aws_access_key_id = "challenge_key"

get_or_create_sqs_queue(queue_name, challenge)

mock_boto3.assert_called_with(
"sqs",
region_name="us-west-2",
aws_secret_access_key="challenge_secret",
aws_access_key_id="challenge_key",
)
mock_sqs.get_queue_by_name.assert_called_with(QueueName=queue_name)

@patch("base.utils.boto3.resource")
@patch("base.utils.settings.DEBUG", False)
@patch("base.utils.settings.TEST", False)
def test_non_debug_non_test_mode_without_challenge(self, mock_boto3):
mock_sqs = MagicMock()
mock_boto3.return_value = mock_sqs
queue_name = "test_queue"

with patch.dict(
"os.environ",
{
"AWS_DEFAULT_REGION": "us-east-1",
"AWS_SECRET_ACCESS_KEY": "env_secret",
"AWS_ACCESS_KEY_ID": "env_key",
},
):
get_or_create_sqs_queue(queue_name)

mock_boto3.assert_called_with(
"sqs",
region_name="us-east-1",
aws_secret_access_key="env_secret",
aws_access_key_id="env_key",
)
mock_sqs.get_queue_by_name.assert_called_with(QueueName=queue_name)

@patch("base.utils.boto3.resource")
@patch("base.utils.logger")
@patch("base.utils.settings")
def test_get_or_create_sqs_queue_exception_logging(
self, mock_settings, mock_logger, mock_boto3_resource
):
mock_settings.DEBUG = False
mock_settings.TEST = False
mock_sqs = MagicMock()
mock_boto3_resource.return_value = mock_sqs

error_response = {
"Error": {"Code": "SomeOtherError", "Message": "An error occurred"}
}
client_error = botocore.exceptions.ClientError(
error_response, "GetQueueUrl"
)

mock_sqs.get_queue_by_name.side_effect = client_error

queue_name = "test_queue"
get_or_create_sqs_queue(queue_name)

mock_logger.exception.assert_called_once_with(
"Cannot get queue: {}".format(queue_name)
)

@patch("base.utils.boto3.resource")
@patch("base.utils.settings.DEBUG", False)
@patch("base.utils.settings.TEST", False)
@patch("base.utils.logger")
def test_queue_creation_on_non_existent_queue(
self, mock_logger, mock_boto3
):
mock_sqs = MagicMock()
mock_boto3.return_value = mock_sqs
queue_name = "test_queue"
challenge = None
mock_sqs.get_queue_by_name.side_effect = (
botocore.exceptions.ClientError(
{"Error": {"Code": "AWS.SimpleQueueService.NonExistentQueue"}},
"GetQueueUrl",
)
)

get_or_create_sqs_queue(queue_name, challenge)

mock_sqs.create_queue.assert_called_with(
QueueName=queue_name,
Attributes={"MessageRetentionPeriod": SQS_RETENTION_PERIOD},
)
mock_logger.exception.assert_not_called()


class TestSendEmail(unittest.TestCase):
@patch("base.utils.sendgrid.SendGridAPIClient")
@patch("base.utils.os.environ.get")
@patch("base.utils.logger")
def test_send_email_success(
self, mock_logger, mock_get_env, mock_sendgrid_client
):
mock_get_env.return_value = "fake_api_key"
mock_sg_instance = MagicMock()
mock_sendgrid_client.return_value = mock_sg_instance

send_email(
sender="[email protected]",
recipient="[email protected]",
template_id="template_id",
template_data={"key": "value"},
)

mock_sendgrid_client.assert_called_once_with(api_key="fake_api_key")
mock_sg_instance.client.mail.send.post.assert_called_once()
mock_logger.warning.assert_not_called()

@patch("base.utils.sendgrid.SendGridAPIClient")
@patch("base.utils.os.environ.get")
@patch("base.utils.logger")
def test_send_email_exception(
self, mock_logger, mock_get_env, mock_sendgrid_client
):
# Mock environment variable
mock_get_env.return_value = "fake_api_key"
mock_sendgrid_client.side_effect = Exception("SendGrid error")

send_email(
sender="[email protected]",
recipient="[email protected]",
template_id="template_id",
template_data={"key": "value"},
)

mock_logger.warning.assert_called_once_with(
"Cannot make sendgrid call. "
"Please check if SENDGRID_API_KEY is present."
)


class TestGetBoto3Client(unittest.TestCase):
@patch("base.utils.boto3.client")
@patch("base.utils.logger")
def test_get_boto3_client_exception(self, mock_logger, mock_boto3_client):
mock_boto3_client.side_effect = Exception("Boto3 error")

aws_keys = {
"AWS_REGION": "us-west-2",
"AWS_ACCESS_KEY_ID": "fake_access_key_id",
"AWS_SECRET_ACCESS_KEY": "fake_secret_access_key",
}

get_boto3_client("s3", aws_keys)
mock_logger.exception.assert_called_once()
args, kwargs = mock_logger.exception.call_args
self.assertIsInstance(args[0], Exception)
self.assertEqual(str(args[0]), "Boto3 error")


class TestDataEncoding(unittest.TestCase):
def test_encode_data_empty_list(self):
data = []
self.assertEqual(encode_data(data), [])

def test_decode_data_empty_list(self):
data = []
self.assertEqual(decode_data(data), [])