Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +0,0 @@
from unittest.mock import MagicMock


class AsyncMock(MagicMock):
async def __call__(self, *args, **kwargs):
return super(AsyncMock, self).__call__(*args, **kwargs)
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from datetime import datetime, timezone
from base64 import b64encode
from unittest import mock
from unittest.mock import AsyncMock
from aiohttp_apispec import validation_middleware
from aiohttp import web
import aiohttp_jinja2
Expand Down Expand Up @@ -64,7 +65,6 @@
from app.api.rest_api import RestApi

from app import version
from tests import AsyncMock

DIR = os.path.dirname(os.path.abspath(__file__))
CONFIG_DIR = os.path.join(DIR, '..', 'conf')
Expand Down
19 changes: 11 additions & 8 deletions tests/planners/test_atomic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@

from tests import AsyncMock

import pytest
from unittest.mock import AsyncMock

from app.planners.atomic import LogicalPlanner

Expand All @@ -15,7 +14,10 @@ def __init__(self):
self.adversary = AdversaryStub()
self.agents = ['agent_1']
self.wait_for_links_completion = AsyncMock()
self.apply = AsyncMock()
self.apply = AsyncMock(side_effect=self._apply_side_effect)

def _apply_side_effect(self, value):
return value.id


class PlanningSvcStub():
Expand All @@ -34,9 +36,10 @@ def __init__(self, ability_id):
class LinkStub():
def __init__(self, ability_id):
self.ability = AbilityStub(ability_id)
self.id = 'link_' + ability_id

def __eq__(self, other):
return self.ability.ability_id == other.ability.ability_id
return self.ability.ability_id == other.ability.ability_id and self.id == other.id


@pytest.fixture
Expand Down Expand Up @@ -64,8 +67,8 @@ def test_atomic_with_links_in_order(self, event_loop, atomic_planner):

assert atomic_planner.operation.apply.call_count == 1
assert atomic_planner.operation.wait_for_links_completion.call_count == 1
atomic_planner.operation.apply.assert_called_with(LinkStub('ability_b'))
atomic_planner.operation.wait_for_links_completion.assert_called_with([LinkStub('ability_b')])
atomic_planner.operation.apply.assert_awaited_with(LinkStub('ability_b'))
atomic_planner.operation.wait_for_links_completion.assert_awaited_with(['link_ability_b'])

def test_atomic_with_links_out_of_order(self, event_loop, atomic_planner):

Expand All @@ -80,8 +83,8 @@ def test_atomic_with_links_out_of_order(self, event_loop, atomic_planner):

assert atomic_planner.operation.apply.call_count == 1
assert atomic_planner.operation.wait_for_links_completion.call_count == 1
atomic_planner.operation.apply.assert_called_with(LinkStub('ability_b'))
atomic_planner.operation.wait_for_links_completion.assert_called_with([LinkStub('ability_b')])
atomic_planner.operation.apply.assert_awaited_with(LinkStub('ability_b'))
atomic_planner.operation.wait_for_links_completion.assert_awaited_with(['link_ability_b'])

def test_atomic_no_links(self, event_loop, atomic_planner):

Expand Down
2 changes: 1 addition & 1 deletion tests/services/test_file_svc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import yaml

from base64 import b64encode
from tests import AsyncMock
from asyncio import Future
from unittest.mock import AsyncMock

from app.data_encoders.base64_basic import Base64Encoder
from app.data_encoders.plain_text import PlainTextEncoder
Expand Down
3 changes: 1 addition & 2 deletions tests/services/test_planning_svc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import asyncio
import base64
from unittest.mock import MagicMock
from unittest.mock import MagicMock, AsyncMock

from app.objects.c_adversary import Adversary
from app.objects.c_obfuscator import Obfuscator
Expand All @@ -11,7 +11,6 @@
from app.objects.secondclass.c_fact import Fact
from app.objects.secondclass.c_requirement import Requirement
from app.utility.base_world import BaseWorld
from tests import AsyncMock


stop_bucket_exhaustion_params = [
Expand Down