diff --git a/src/scheduler/analysis/scheduler.py b/src/scheduler/analysis/scheduler.py
index 918250790..211861f45 100644
--- a/src/scheduler/analysis/scheduler.py
+++ b/src/scheduler/analysis/scheduler.py
@@ -105,6 +105,7 @@ def __init__(
         post_analysis: Optional[Callable[[str, str, dict], None]] = None,
         db_interface=None,
         unpacking_locks: UnpackingLockManager | None = None,
+        status: AnalysisStatus | None = None,
     ):
         self.analysis_plugins = {}
         self._plugin_runners = {}
@@ -115,7 +116,7 @@
         self.unpacking_locks = unpacking_locks
         self.scheduling_lock = Lock()
 
-        self.status = AnalysisStatus()
+        self.status = status or AnalysisStatus()
         self.task_scheduler = AnalysisTaskScheduler(self.analysis_plugins)
         self.schedule_processes = []
         self.result_collector_processes = []
diff --git a/src/scheduler/analysis_status.py b/src/scheduler/analysis_status.py
index 6a4f56a98..f350b9ec0 100644
--- a/src/scheduler/analysis_status.py
+++ b/src/scheduler/analysis_status.py
@@ -12,7 +12,6 @@
 
 import config
 from helperFunctions.process import stop_process
-from objects.firmware import Firmware
 from storage.redis_status_interface import RedisStatusInterface
 
 if TYPE_CHECKING:
@@ -45,38 +44,39 @@ def shutdown(self):
         self._worker.shutdown()
         self._manager.shutdown()
 
-    def add_update(self, fw_object: Firmware | FileObject, included_files: list[str] | set[str]):
-        self.add_object(fw_object)
+    def add_update(self, fw_object: FileObject, included_files: list[str] | set[str]):
+        # normally, status is initialized during unpacking, but since unpacking is skipped for updates, we need to
+        # init it here first before adding the object
+        self.init_firmware(fw_object)
         self._worker.queue.put((_UpdateType.add_update, fw_object.uid, included_files))
 
-    def add_object(self, fw_object: Firmware | FileObject):
-        if isinstance(fw_object, Firmware):
-            self._worker.queue.put(
-                (
-                    _UpdateType.add_firmware,
-                    fw_object.uid,
-                    fw_object.files_included,
-                    fw_object.get_hid(),
-                    fw_object.scheduled_analysis,
-                )
+    def init_firmware(self, fw_object: FileObject):
+        self._worker.queue.put(
+            (
+                _UpdateType.add_firmware,
+                fw_object.uid,
+                fw_object.get_hid(),
+                fw_object.scheduled_analysis,
             )
-        else:
-            self._worker.queue.put(
-                (
-                    _UpdateType.add_file,
-                    fw_object.uid,
-                    fw_object.root_uid,
-                    fw_object.files_included,
-                )
+        )
+
+    def add_object(self, fw_object: FileObject):
+        self._worker.queue.put(
+            (
+                _UpdateType.add_file,
+                fw_object.uid,
+                fw_object.root_uid,
+                fw_object.files_included,
             )
+        )
 
     def add_analysis(self, fw_object: FileObject, plugin: str):
         self._worker.queue.put((_UpdateType.add_analysis, fw_object.root_uid, plugin))
 
-    def remove_object(self, fw_object: Firmware | FileObject):
+    def remove_object(self, fw_object: FileObject):
         self._worker.queue.put((_UpdateType.remove_file, fw_object.uid, fw_object.root_uid))
 
-    def fw_analysis_is_in_progress(self, fw_object: Firmware | FileObject) -> bool:
+    def fw_analysis_is_in_progress(self, fw_object: FileObject) -> bool:
         return fw_object.root_uid in self._currently_analyzed or fw_object.uid in self._currently_analyzed
 
     def cancel_analysis(self, root_uid: str):
@@ -93,7 +93,7 @@ class FwAnalysisStatus:
     start_time: float = field(default_factory=time)
     completed_files: set[str] = field(default_factory=set)
     total_files_with_duplicates: int = 1
-    unpacked_files_count: int = 1
+    unpacked_files_count: int = 0
     analyzed_files_count: int = 0
 
 
@@ -109,14 +109,16 @@ def __init__(self, currently_analyzed_fw: dict):
         self.redis = RedisStatusInterface()
 
     def start(self):
-        self._running.value = 1
-        self._worker_process = Process(target=self._worker_loop)
-        self._worker_process.start()
+        if self._running.value == 0:
+            self._running.value = 1
+            self._worker_process = Process(target=self._worker_loop)
+            self._worker_process.start()
 
     def shutdown(self):
-        self._running.value = 0
-        if self._worker_process is not None:
-            stop_process(self._worker_process, timeout=10)
+        if self._running.value == 1:
+            self._running.value = 0
+            if self._worker_process is not None:
+                stop_process(self._worker_process, timeout=10)
 
     def _worker_loop(self):
         logging.debug(f'starting analysis status worker (pid: {os.getpid()})')
@@ -157,11 +159,11 @@ def _add_update(self, fw_uid: str, included_files: set[str]):
         status.total_files_with_duplicates = file_count
         status.files_to_analyze = {fw_uid, *included_files}
 
-    def _add_firmware(self, uid: str, included_files: set[str], hid: str, scheduled_analyses: list[str] | None):
+    def _add_firmware(self, uid: str, hid: str, scheduled_analyses: list[str] | None):
         self.currently_running[uid] = FwAnalysisStatus(
-            files_to_unpack=set(included_files),
+            files_to_unpack={uid},
             files_to_analyze={uid},
-            total_files_count=1 + len(included_files),
+            total_files_count=1,
             hid=hid,
             analysis_plugins={p: 0 for p in scheduled_analyses or []},
         )
diff --git a/src/scheduler/unpacking_scheduler.py b/src/scheduler/unpacking_scheduler.py
index a42e20ddb..e94a47914 100644
--- a/src/scheduler/unpacking_scheduler.py
+++ b/src/scheduler/unpacking_scheduler.py
@@ -30,6 +30,7 @@
 
 if TYPE_CHECKING:
     from objects.file import FileObject
+    from scheduler.analysis_status import AnalysisStatus
 
 
 class NoFreeWorker(RuntimeError):  # noqa: N818
@@ -51,6 +52,7 @@ def __init__(
         fs_organizer=None,
         unpacking_locks=None,
         db_interface=BackendDbInterface,
+        status: AnalysisStatus | None = None,
     ):
         self.stop_condition = Value('i', 0)
         self.throttle_condition = Value('i', 0)
@@ -62,6 +64,7 @@
         self.post_unpack = post_unpack
         self.unpacking_locks = unpacking_locks
         self.unpacker = Unpacker(fs_organizer=fs_organizer, unpacking_locks=unpacking_locks)
+        self.status = status
 
         self.manager = None
         self.workers = None
@@ -125,6 +128,8 @@ def add_task(self, fw: Firmware):
         schedule a firmware_object for unpacking
         """
         fw.root_uid = fw.uid  # make sure the root_uid is set correctly for unpacking and analysis scheduling
+        if self.status is not None:
+            self.status.init_firmware(fw)  # initialize unpacking and analysis progress tracking
         self.in_queue.put(fw)
 
     def get_scheduled_workload(self):
diff --git a/src/start_fact_backend.py b/src/start_fact_backend.py
index 77f2324f4..36f247760 100755
--- a/src/start_fact_backend.py
+++ b/src/start_fact_backend.py
@@ -58,6 +58,7 @@ def __init__(self):
             post_unpack=self.analysis_service.start_analysis_of_object,
             analysis_workload=self.analysis_service.get_combined_analysis_workload,
             unpacking_locks=self.unpacking_lock_manager,
+            status=self.analysis_service.status,
         )
         self.compare_service = ComparisonScheduler()
         self.intercom = InterComBackEndBinding(
diff --git a/src/test/conftest.py b/src/test/conftest.py
index 596cf82b5..3e1ef2738 100644
--- a/src/test/conftest.py
+++ b/src/test/conftest.py
@@ -9,6 +9,7 @@
 
 import config
 from scheduler.analysis import AnalysisScheduler
+from scheduler.analysis_status import AnalysisStatus
 from scheduler.comparison_scheduler import ComparisonScheduler
 from scheduler.unpacking_scheduler import UnpackingScheduler
 from storage.db_connection import ReadOnlyConnection, ReadWriteConnection
@@ -204,6 +205,14 @@ def database_interfaces(_database_interfaces) -> DatabaseInterfaces:
     _database_interfaces.admin.intercom.deleted_files.clear()
 
 
+@pytest.fixture
+def analysis_status():
+    status = AnalysisStatus()
+    status.start()
+    yield status
+    status.shutdown()
+
+
 @pytest.fixture
 def common_db(database_interfaces) -> DbInterfaceCommon:
     """Convenience fixture. Equivalent to ``database_interfaces.common``."""
@@ -313,6 +322,7 @@ def analysis_scheduler(  # noqa: PLR0913
     analysis_finished_counter,
     unpacking_lock_manager,
     test_config,
+    analysis_status,
     monkeypatch,
 ) -> AnalysisScheduler:
     """Returns an instance of :py:class:`~scheduler.analysis.AnalysisScheduler`.
@@ -325,6 +335,7 @@
     _analysis_scheduler = AnalysisScheduler(
         post_analysis=lambda *_: None,
         unpacking_locks=unpacking_lock_manager,
+        status=analysis_status,
     )
 
     fs_organizer = test_config.fs_organizer_class()
@@ -396,6 +407,7 @@ def unpacking_scheduler(
     test_config,
     unpacking_finished_event,
     unpacking_finished_counter,
+    analysis_status,
 ) -> UnpackingScheduler:
     """Returns an instance of :py:class:`~scheduler.unpacking_scheduler.UnpackingScheduler`.
     The scheduler has some extra testing features. See :py:class:`SchedulerTestConfig` for the features.
     """
@@ -418,6 +430,7 @@ def _post_unpack_hook(fw):
         fs_organizer=fs_organizer,
         unpacking_locks=unpacking_lock_manager,
         db_interface=test_config.backend_db_class,
+        status=analysis_status,
     )
 
     add_task = _unpacking_scheduler.add_task
diff --git a/src/test/unit/scheduler/test_analysis_status.py b/src/test/unit/scheduler/test_analysis_status.py
index ad54b2c22..42b50638a 100644
--- a/src/test/unit/scheduler/test_analysis_status.py
+++ b/src/test/unit/scheduler/test_analysis_status.py
@@ -14,9 +14,22 @@ class TestAnalysisStatus:
     def setup_method(self):
         self.status = AnalysisStatus()
 
-    def test_add_firmware_to_current_analyses(self):
+    def test_init_firmware(self):
         fw = Firmware(binary=b'foo')
         fw.files_included = ['foo', 'bar']
+        self.status.init_firmware(fw)
+        self.status._worker._update_status()
+
+        assert fw.uid in self.status._worker.currently_running
+        result = self.status._worker.currently_running[fw.uid]
+        assert result.files_to_unpack == {fw.uid}
+        assert result.files_to_analyze == {fw.uid}
+        assert result.completed_files == set()
+        assert result.unpacked_files_count == 0
+        assert result.analyzed_files_count == 0
+        assert result.total_files_count == 1
+
+        # after unpacking, the file is added again with add_object to add the included files
         self.status.add_object(fw)
         self.status._worker._update_status()
 
@@ -38,6 +51,7 @@ def test_add_file_to_current_analyses(self):
                 hid='',
                 total_files_count=2,
                 total_files_with_duplicates=2,
+                unpacked_files_count=1,
             )
         }
         fo = FileObject(binary=b'foo')
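
For context, a minimal sketch of the shared-status wiring this patch introduces, mirroring src/start_fact_backend.py above. This is a sketch only: all constructor arguments not relevant here (database interface, unpacking locks, workload callbacks) are omitted, and actually instantiating the schedulers starts worker processes.

    # Sketch: both schedulers report into a single AnalysisStatus instance.
    from scheduler.analysis import AnalysisScheduler
    from scheduler.unpacking_scheduler import UnpackingScheduler

    analysis_service = AnalysisScheduler()  # creates its own AnalysisStatus when none is passed in
    unpacking_service = UnpackingScheduler(
        post_unpack=analysis_service.start_analysis_of_object,
        status=analysis_service.status,  # share the analysis scheduler's status instance
    )
    # UnpackingScheduler.add_task() now calls status.init_firmware(fw) before unpacking,
    # so progress tracking for a firmware starts as soon as it is scheduled.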