diff --git a/tests/__init__.py b/tests/__init__.py index 41cf6b56e..e69de29bb 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,6 +0,0 @@ -from unittest.mock import MagicMock - - -class AsyncMock(MagicMock): - async def __call__(self, *args, **kwargs): - return super(AsyncMock, self).__call__(*args, **kwargs) diff --git a/tests/conftest.py b/tests/conftest.py index c83aca648..5ea34a25b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,7 @@ from datetime import datetime, timezone from base64 import b64encode from unittest import mock +from unittest.mock import AsyncMock from aiohttp_apispec import validation_middleware from aiohttp import web import aiohttp_jinja2 @@ -64,7 +65,6 @@ from app.api.rest_api import RestApi from app import version -from tests import AsyncMock DIR = os.path.dirname(os.path.abspath(__file__)) CONFIG_DIR = os.path.join(DIR, '..', 'conf') diff --git a/tests/planners/test_atomic.py b/tests/planners/test_atomic.py index bd481fa69..3a9adbdb1 100644 --- a/tests/planners/test_atomic.py +++ b/tests/planners/test_atomic.py @@ -1,7 +1,6 @@ -from tests import AsyncMock - import pytest +from unittest.mock import AsyncMock from app.planners.atomic import LogicalPlanner @@ -15,7 +14,10 @@ def __init__(self): self.adversary = AdversaryStub() self.agents = ['agent_1'] self.wait_for_links_completion = AsyncMock() - self.apply = AsyncMock() + self.apply = AsyncMock(side_effect=self._apply_side_effect) + + def _apply_side_effect(self, value): + return value.id class PlanningSvcStub(): @@ -34,9 +36,10 @@ def __init__(self, ability_id): class LinkStub(): def __init__(self, ability_id): self.ability = AbilityStub(ability_id) + self.id = 'link_' + ability_id def __eq__(self, other): - return self.ability.ability_id == other.ability.ability_id + return self.ability.ability_id == other.ability.ability_id and self.id == other.id @pytest.fixture @@ -64,8 +67,8 @@ def test_atomic_with_links_in_order(self, event_loop, atomic_planner): assert atomic_planner.operation.apply.call_count == 1 assert atomic_planner.operation.wait_for_links_completion.call_count == 1 - atomic_planner.operation.apply.assert_called_with(LinkStub('ability_b')) - atomic_planner.operation.wait_for_links_completion.assert_called_with([LinkStub('ability_b')]) + atomic_planner.operation.apply.assert_awaited_with(LinkStub('ability_b')) + atomic_planner.operation.wait_for_links_completion.assert_awaited_with(['link_ability_b']) def test_atomic_with_links_out_of_order(self, event_loop, atomic_planner): @@ -80,8 +83,8 @@ def test_atomic_with_links_out_of_order(self, event_loop, atomic_planner): assert atomic_planner.operation.apply.call_count == 1 assert atomic_planner.operation.wait_for_links_completion.call_count == 1 - atomic_planner.operation.apply.assert_called_with(LinkStub('ability_b')) - atomic_planner.operation.wait_for_links_completion.assert_called_with([LinkStub('ability_b')]) + atomic_planner.operation.apply.assert_awaited_with(LinkStub('ability_b')) + atomic_planner.operation.wait_for_links_completion.assert_awaited_with(['link_ability_b']) def test_atomic_no_links(self, event_loop, atomic_planner): diff --git a/tests/services/test_file_svc.py b/tests/services/test_file_svc.py index fe7c6c16b..5aecb7c84 100644 --- a/tests/services/test_file_svc.py +++ b/tests/services/test_file_svc.py @@ -5,8 +5,8 @@ import yaml from base64 import b64encode -from tests import AsyncMock from asyncio import Future +from unittest.mock import AsyncMock from app.data_encoders.base64_basic import Base64Encoder from app.data_encoders.plain_text import PlainTextEncoder diff --git a/tests/services/test_planning_svc.py b/tests/services/test_planning_svc.py index 655b3de16..1b2f7af4a 100644 --- a/tests/services/test_planning_svc.py +++ b/tests/services/test_planning_svc.py @@ -1,7 +1,7 @@ import pytest import asyncio import base64 -from unittest.mock import MagicMock +from unittest.mock import MagicMock, AsyncMock from app.objects.c_adversary import Adversary from app.objects.c_obfuscator import Obfuscator @@ -11,7 +11,6 @@ from app.objects.secondclass.c_fact import Fact from app.objects.secondclass.c_requirement import Requirement from app.utility.base_world import BaseWorld -from tests import AsyncMock stop_bucket_exhaustion_params = [