Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add base app test #4403

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions tests/unit/base/test_apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from django.test import TestCase
from base.apps import BaseConfig


class BaseConfigTest(TestCase):
def test_base_config(self):
self.assertEqual(BaseConfig.name, "base")
13 changes: 13 additions & 0 deletions tests/unit/base/test_seed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import unittest
from unittest.mock import MagicMock
from base.management.commands.seed import Command


class TestSeedCommand(unittest.TestCase):
def test_add_arguments_adds_nc_argument(self):
parser = MagicMock()
command = Command()
command.add_arguments(parser)
parser.add_argument.assert_called_once_with(
'-nc', nargs='?', default=20, type=int, help='Number of challenges.'
)
231 changes: 227 additions & 4 deletions tests/unit/base/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
from unittest import TestCase
import unittest
import requests
import responses

Expand All @@ -8,18 +10,30 @@
from django.contrib.auth.models import User
from django.core.files.uploadedfile import SimpleUploadedFile
from django.utils import timezone

import botocore
from allauth.account.models import EmailAddress
from rest_framework import status
from rest_framework.test import APITestCase, APIClient

from base.utils import RandomFileName, send_slack_notification, is_user_a_staff
from base.utils import (
RandomFileName,
decode_data,
encode_data,
send_slack_notification,
is_user_a_staff,
get_url_from_hostname,
send_email,
mock_if_non_prod_aws,
get_or_create_sqs_queue,
get_boto3_client,
)
from challenges.models import Challenge, ChallengePhase
from hosts.models import ChallengeHostTeam
from jobs.models import Submission
from participants.models import Participant, ParticipantTeam

from unittest.mock import MagicMock, patch
from scripts import seed
from settings.common import SQS_RETENTION_PERIOD


class BaseAPITestClass(APITestCase):
Expand Down Expand Up @@ -137,7 +151,6 @@ def test_if_slack_notification_works(self):


class TestUserIsStaff(BaseAPITestClass):

def test_if_user_is_staff(self):
self.user = User.objects.create(
username="someuser1",
Expand All @@ -157,3 +170,213 @@ def test_if_user_is_not_staff(self):
self.user.is_staff = False
self.user.save()
self.assertFalse(is_user_a_staff(self.user))


class TestGetUrlFromHostname(TestCase):
def test_debug_mode(self):
settings.DEBUG = True
hostname = 'example.com'
expected_url = 'http://example.com'
self.assertEqual(get_url_from_hostname(hostname), expected_url)

def test_test_mode(self):
settings.TEST = True
hostname = 'example.com'
expected_url = 'http://example.com'
self.assertEqual(get_url_from_hostname(hostname), expected_url)

def test_production_mode(self):
settings.DEBUG = False
settings.TEST = False
hostname = 'example.com'
expected_url = 'https://example.com'
self.assertEqual(get_url_from_hostname(hostname), expected_url)


class TestAwsReturnFunc(BaseAPITestClass):
def test_mock_if_non_prod_aws_returns_original_func_when_not_debug_or_test(self):
with patch('django.conf.settings.DEBUG', False):
with patch('django.conf.settings.TEST', False):
aws_mocker = MagicMock()
func = MagicMock(return_value="mocked_value")
decorated_func = mock_if_non_prod_aws(aws_mocker)(func)
result = decorated_func()
self.assertEqual(result, "mocked_value")
aws_mocker.assert_not_called()


class TestGetOrCreateSqsQueue(BaseAPITestClass):
@patch('base.utils.boto3.resource')
@patch('base.utils.settings.DEBUG', True)
@patch('base.utils.settings.TEST', False)
def test_debug_mode_queue_name(self, mock_boto3):
mock_sqs = MagicMock()
mock_boto3.return_value = mock_sqs
queue_name = "test_queue"

get_or_create_sqs_queue(queue_name)

mock_boto3.assert_called_with(
"sqs",
endpoint_url="http://sqs:9324",
region_name="us-east-1",
aws_secret_access_key="x",
aws_access_key_id="x"
)
mock_sqs.get_queue_by_name.assert_called_with(QueueName="evalai_submission_queue")

@patch('base.utils.boto3.resource')
@patch('base.utils.settings.DEBUG', False)
@patch('base.utils.settings.TEST', False)
def test_non_debug_non_test_mode_with_challenge(self, mock_boto3):
mock_sqs = MagicMock()
mock_boto3.return_value = mock_sqs
queue_name = "test_queue"
challenge = MagicMock()
challenge.use_host_sqs = True
challenge.queue_aws_region = "us-west-2"
challenge.aws_secret_access_key = "challenge_secret"
challenge.aws_access_key_id = "challenge_key"

get_or_create_sqs_queue(queue_name, challenge)

mock_boto3.assert_called_with(
"sqs",
region_name="us-west-2",
aws_secret_access_key="challenge_secret",
aws_access_key_id="challenge_key"
)
mock_sqs.get_queue_by_name.assert_called_with(QueueName=queue_name)

@patch('base.utils.boto3.resource')
@patch('base.utils.settings.DEBUG', False)
@patch('base.utils.settings.TEST', False)
def test_non_debug_non_test_mode_without_challenge(self, mock_boto3):
mock_sqs = MagicMock()
mock_boto3.return_value = mock_sqs
queue_name = "test_queue"

with patch.dict('os.environ', {
'AWS_DEFAULT_REGION': 'us-east-1',
'AWS_SECRET_ACCESS_KEY': 'env_secret',
'AWS_ACCESS_KEY_ID': 'env_key'
}):
get_or_create_sqs_queue(queue_name)

mock_boto3.assert_called_with(
"sqs",
region_name="us-east-1",
aws_secret_access_key="env_secret",
aws_access_key_id="env_key"
)
mock_sqs.get_queue_by_name.assert_called_with(QueueName=queue_name)

@patch('base.utils.boto3.resource')
@patch('base.utils.logger')
@patch('base.utils.settings')
def test_get_or_create_sqs_queue_exception_logging(self, mock_settings, mock_logger, mock_boto3_resource):
mock_settings.DEBUG = False
mock_settings.TEST = False
mock_sqs = MagicMock()
mock_boto3_resource.return_value = mock_sqs

error_response = {'Error': {'Code': 'SomeOtherError', 'Message': 'An error occurred'}}
client_error = botocore.exceptions.ClientError(error_response, 'GetQueueUrl')

mock_sqs.get_queue_by_name.side_effect = client_error

queue_name = "test_queue"
get_or_create_sqs_queue(queue_name)

mock_logger.exception.assert_called_once_with("Cannot get queue: {}".format(queue_name))

@patch('base.utils.boto3.resource')
@patch('base.utils.settings.DEBUG', False)
@patch('base.utils.settings.TEST', False)
@patch('base.utils.logger')
def test_queue_creation_on_non_existent_queue(self, mock_logger, mock_boto3):
mock_sqs = MagicMock()
mock_boto3.return_value = mock_sqs
queue_name = "test_queue"
challenge = None
mock_sqs.get_queue_by_name.side_effect = botocore.exceptions.ClientError(
{"Error": {"Code": "AWS.SimpleQueueService.NonExistentQueue"}}, "GetQueueUrl"
)

get_or_create_sqs_queue(queue_name, challenge)

mock_sqs.create_queue.assert_called_with(
QueueName=queue_name,
Attributes={"MessageRetentionPeriod": SQS_RETENTION_PERIOD}
)
mock_logger.exception.assert_not_called()


class TestSendEmail(unittest.TestCase):

@patch('base.utils.sendgrid.SendGridAPIClient')
@patch('base.utils.os.environ.get')
@patch('base.utils.logger')
def test_send_email_success(self, mock_logger, mock_get_env, mock_sendgrid_client):
mock_get_env.return_value = 'fake_api_key'
mock_sg_instance = MagicMock()
mock_sendgrid_client.return_value = mock_sg_instance

send_email(
sender='[email protected]',
recipient='[email protected]',
template_id='template_id',
template_data={'key': 'value'}
)

mock_sendgrid_client.assert_called_once_with(api_key='fake_api_key')
mock_sg_instance.client.mail.send.post.assert_called_once()
mock_logger.warning.assert_not_called()

@patch('base.utils.sendgrid.SendGridAPIClient')
@patch('base.utils.os.environ.get')
@patch('base.utils.logger')
def test_send_email_exception(self, mock_logger, mock_get_env, mock_sendgrid_client):
# Mock environment variable
mock_get_env.return_value = 'fake_api_key'
mock_sendgrid_client.side_effect = Exception('SendGrid error')

send_email(
sender='[email protected]',
recipient='[email protected]',
template_id='template_id',
template_data={'key': 'value'}
)

mock_logger.warning.assert_called_once_with(
"Cannot make sendgrid call. Please check if SENDGRID_API_KEY is present."
)


class TestGetBoto3Client(unittest.TestCase):
@patch('base.utils.boto3.client')
@patch('base.utils.logger')
def test_get_boto3_client_exception(self, mock_logger, mock_boto3_client):
mock_boto3_client.side_effect = Exception('Boto3 error')

aws_keys = {
"AWS_REGION": "us-west-2",
"AWS_ACCESS_KEY_ID": "fake_access_key_id",
"AWS_SECRET_ACCESS_KEY": "fake_secret_access_key"
}

get_boto3_client('s3', aws_keys)
mock_logger.exception.assert_called_once()
args, kwargs = mock_logger.exception.call_args
self.assertIsInstance(args[0], Exception)
self.assertEqual(str(args[0]), 'Boto3 error')


class TestDataEncoding(unittest.TestCase):
def test_encode_data_empty_list(self):
data = []
self.assertEqual(encode_data(data), [])

def test_decode_data_empty_list(self):
data = []
self.assertEqual(decode_data(data), [])