diff --git a/src/helperFunctions/yara_binary_search.py b/src/helperFunctions/yara_binary_search.py index d0cf52b8d..1d27451bb 100644 --- a/src/helperFunctions/yara_binary_search.py +++ b/src/helperFunctions/yara_binary_search.py @@ -42,13 +42,12 @@ def _execute_yara_search(self, rule_file_path: str | Path, target_path: str | Pa yara_process = subprocess.run(command, shell=True, stdout=PIPE, stderr=STDOUT, text=True, check=False) return yara_process.stdout - def _execute_yara_search_for_single_firmware(self, rule_file_path: str, firmware_uid: str) -> str: - file_paths = self._get_file_paths_of_files_included_in_fw(firmware_uid) - result = (self._execute_yara_search(rule_file_path, path) for path in file_paths) - return '\n'.join(result) - - def _get_file_paths_of_files_included_in_fw(self, fw_uid: str) -> list[str]: - return [self.fs_organizer.generate_path_from_uid(uid) for uid in self.db.get_all_files_in_fw(fw_uid)] + def _do_yara_search_on_incl_files(self, rule_file_path: str, uid: str) -> str: + file_paths = [ + self.fs_organizer.generate_path_from_uid(included_uid) + for included_uid in self.db.get_list_of_all_included_files(uid) + ] + return '\n'.join(self._execute_yara_search(rule_file_path, path) for path in file_paths) @staticmethod def _parse_raw_result(raw_result: str) -> dict[str, dict[str, list[dict]]]: @@ -108,21 +107,21 @@ def get_binary_search_result(self, task: tuple[bytes, str | None]) -> dict[str, :return: dict of matching rules with lists of (unique) matched UIDs as values or an error message. """ with NamedTemporaryFile() as temp_rule_file: - yara_rules, firmware_uid = task + yara_rules, uid = task try: self._prepare_temp_rule_file(temp_rule_file, yara_rules) - raw_result = self._get_raw_result(firmware_uid, temp_rule_file) + raw_result = self._get_raw_result(uid, temp_rule_file) return self._parse_raw_result(raw_result) except yara.SyntaxError as yara_error: return f'There seems to be an error in the rule file:\n{yara_error}' except CalledProcessError as process_error: return f'Error when calling YARA:\n{process_error.output.decode()}' - def _get_raw_result(self, firmware_uid: str | None, temp_rule_file: NamedTemporaryFile) -> str: - if firmware_uid is None: + def _get_raw_result(self, uid: str | None, temp_rule_file: NamedTemporaryFile) -> str: + if uid is None: raw_result = self._execute_yara_search(temp_rule_file.name) else: - raw_result = self._execute_yara_search_for_single_firmware(temp_rule_file.name, firmware_uid) + raw_result = self._do_yara_search_on_incl_files(temp_rule_file.name, uid) return raw_result @staticmethod diff --git a/src/storage/db_interface_common.py b/src/storage/db_interface_common.py index 2f6228223..4741d9c5e 100644 --- a/src/storage/db_interface_common.py +++ b/src/storage/db_interface_common.py @@ -283,10 +283,14 @@ def _generate_file_tree_path(self, uid: str, child_to_parents: dict[str, set[str # ===== included files. ===== - def get_list_of_all_included_files(self, fo: FileObject) -> set[str]: - if isinstance(fo, Firmware): - return self.get_all_files_in_fw(fo.uid) - return self.get_all_files_in_fo(fo) + def get_list_of_all_included_files(self, fo_or_uid: FileObject | str) -> set[str]: + if isinstance(fo_or_uid, str): + uid = fo_or_uid + is_fw = self.is_firmware(uid) + else: + uid = fo_or_uid.uid + is_fw = isinstance(fo_or_uid, Firmware) + return self.get_all_files_in_fw(uid) if is_fw else self.get_all_files_in_fo(uid) def get_all_files_in_fw(self, fw_uid: str) -> set[str]: """Get a set of UIDs of all files (recursively) contained in a firmware""" @@ -294,19 +298,20 @@ def get_all_files_in_fw(self, fw_uid: str) -> set[str]: query = select(fw_files_table.c.file_uid).where(fw_files_table.c.root_uid == fw_uid) return set(session.execute(query).scalars()) - def get_all_files_in_fo(self, fo: FileObject) -> set[str]: + def get_all_files_in_fo(self, uid: str) -> set[str]: """Get a set of UIDs of all files (recursively) contained in a file""" with self.get_read_only_session() as session: - return self._get_files_in_files(session, fo.files_included).union({fo.uid, *fo.files_included}) - - def _get_files_in_files(self, session, uid_set: set[str], recursive: bool = True) -> set[str]: - if not uid_set: - return set() - query = select(FileObjectEntry).filter(FileObjectEntry.uid.in_(uid_set)) - included_files = {child.uid for fo in session.execute(query).scalars() for child in fo.included_files} - if recursive and included_files: - included_files.update(self._get_files_in_files(session, included_files)) - return included_files + # recursive query for included files + bottom_query = ( + select(included_files_table).where(included_files_table.c.parent_uid == uid).cte(recursive=True) + ) + parent_table = aliased(included_files_table) + recursive_parent_child_query = select(parent_table).join( + bottom_query, parent_table.c.parent_uid == bottom_query.c.child_uid + ) + final_query = bottom_query.union_all(recursive_parent_child_query) + included_files = set(session.execute(select(final_query.c.child_uid)).scalars()) + return included_files.union({uid}) # ===== summary ===== diff --git a/src/test/common_helper.py b/src/test/common_helper.py index 400f51adb..ef53ceaaf 100644 --- a/src/test/common_helper.py +++ b/src/test/common_helper.py @@ -199,7 +199,7 @@ def get_number_of_total_matches(self, *_, **__): return 10 def exists(self, uid): - return uid in (self.fw_uid, self.fo_uid, self.fw2_uid, 'error') + return uid in (self.fw_uid, self.fo_uid, self.fw2_uid, 'error', 'uid_in_db') def uid_list_exists(self, uid_list): return set() diff --git a/src/test/integration/storage/test_db_interface_common.py b/src/test/integration/storage/test_db_interface_common.py index 2e360e5dd..7d364fe59 100644 --- a/src/test/integration/storage/test_db_interface_common.py +++ b/src/test/integration/storage/test_db_interface_common.py @@ -100,8 +100,18 @@ def test_all_files_in_fw(backend_db, common_db): def test_all_files_in_fo(backend_db, common_db): fw, parent_fo, child_fo = create_fw_with_parent_and_child() backend_db.insert_multiple_objects(fw, parent_fo, child_fo) - assert common_db.get_all_files_in_fo(fw) == {fw.uid, parent_fo.uid, child_fo.uid} - assert common_db.get_all_files_in_fo(parent_fo) == {parent_fo.uid, child_fo.uid} + assert common_db.get_all_files_in_fo(fw.uid) == {fw.uid, parent_fo.uid, child_fo.uid} + assert common_db.get_all_files_in_fo(parent_fo.uid) == {parent_fo.uid, child_fo.uid} + + +def test_get_list_of_all_included_files(backend_db, common_db): + fw, parent_fo, child_fo = create_fw_with_parent_and_child() + backend_db.insert_multiple_objects(fw, parent_fo, child_fo) + expected_result = {parent_fo.uid, child_fo.uid} + assert common_db.get_list_of_all_included_files(fw) == expected_result + assert common_db.get_list_of_all_included_files(fw.uid) == expected_result + assert common_db.get_list_of_all_included_files(parent_fo) == expected_result + assert common_db.get_list_of_all_included_files(parent_fo.uid) == expected_result def test_get_objects_by_uid_list(backend_db, common_db): diff --git a/src/test/unit/helperFunctions/test_yara_binary_search.py b/src/test/unit/helperFunctions/test_yara_binary_search.py index a4c033573..a4f2b5856 100644 --- a/src/test/unit/helperFunctions/test_yara_binary_search.py +++ b/src/test/unit/helperFunctions/test_yara_binary_search.py @@ -1,5 +1,4 @@ import unittest -from pathlib import Path from subprocess import CalledProcessError from unittest import mock from unittest.mock import patch @@ -17,7 +16,7 @@ class MockCommonDbInterface: @staticmethod - def get_all_files_in_fw(uid): + def get_list_of_all_included_files(uid): if uid == 'single_firmware': return [TEST_FILE_2, TEST_FILE_3] return [] @@ -107,9 +106,3 @@ def test_execute_yara_search_for_single_file(self): target_path=get_test_data_dir() / TEST_FILE_1 / TEST_FILE_1, ) assert 'test_rule' in result - - def test_get_file_paths_of_files_included_in_fo(self): - result = self.yara_binary_scanner._get_file_paths_of_files_included_in_fw('single_firmware') - assert len(result) == 2 - assert Path(result[0]).name == TEST_FILE_2 - assert Path(result[1]).name == TEST_FILE_3 diff --git a/src/web_interface/components/database_routes.py b/src/web_interface/components/database_routes.py index bb4d00901..967f3e1ac 100644 --- a/src/web_interface/components/database_routes.py +++ b/src/web_interface/components/database_routes.py @@ -274,13 +274,13 @@ def show_advanced_search(self, error=None): def start_binary_search(self): error = None if request.method == 'POST': - yara_rule_file, firmware_uid, only_firmware = self._get_items_from_binary_search_request(request) - if firmware_uid and not self._firmware_is_in_db(firmware_uid): - error = f'Error: Firmware with UID {firmware_uid!r} not found in database' + yara_rule_file, uid, only_firmware = self._get_items_from_binary_search_request(request) + if uid and not self.db.frontend.exists(uid): + error = f'Error: File with UID {uid!r} not found in database' elif yara_rule_file is not None: yara_error = self.intercom.get_yara_error(yara_rule_file) if yara_error == '': - request_id = self.intercom.add_binary_search_request(yara_rule_file, firmware_uid) + request_id = self.intercom.add_binary_search_request(yara_rule_file, uid) return redirect( url_for('get_binary_search_results', request_id=request_id, only_firmware=only_firmware) ) @@ -300,9 +300,6 @@ def _get_items_from_binary_search_request(req): only_firmware = req.form.get('only_firmware') is not None return yara_rule_file, firmware_uid, only_firmware - def _firmware_is_in_db(self, firmware_uid: str) -> bool: - return self.db.frontend.is_firmware(firmware_uid) - @roles_accepted(*PRIVILEGES['pattern_search']) @AppRoute('/database/binary_search_results', GET) def get_binary_search_results(self): diff --git a/src/web_interface/templates/show_analysis/button_groups.j2 b/src/web_interface/templates/show_analysis/button_groups.j2 index ccc5a9097..929cbd7ce 100644 --- a/src/web_interface/templates/show_analysis/button_groups.j2 +++ b/src/web_interface/templates/show_analysis/button_groups.j2 @@ -38,8 +38,8 @@ {% endif %} {% if firmware.vendor %} {{ button_tooltip('Update analysis', 'update-button', '/update-analysis/', 'redo-alt') }} - {{ button_tooltip('YARA search', 'yara-button', '/database/binary_search?firmware_uid=', 'search') }} {% endif %} + {{ button_tooltip('YARA search', 'yara-button', '/database/binary_search?firmware_uid=', 'search') }} {% if firmware.vendor and user_has_admin_clearance %}