From f119d972e8b4cc2889a0c65ee438ec8dda9cb72c Mon Sep 17 00:00:00 2001 From: Arthur Chan Date: Fri, 9 May 2025 18:53:35 +0000 Subject: [PATCH] xref_context: Add test xref in project context Signed-off-by: Arthur Chan --- data_prep/introspector.py | 30 ++++++++++++++++++- .../project_context/context_introspector.py | 15 ++++++++++ llm_toolkit/prompt_builder.py | 1 + prompts/template_xml/context.txt | 5 ++++ 4 files changed, 50 insertions(+), 1 deletion(-) diff --git a/data_prep/introspector.py b/data_prep/introspector.py index 331c71f37e..a85165cc8b 100755 --- a/data_prep/introspector.py +++ b/data_prep/introspector.py @@ -60,6 +60,7 @@ INTROSPECTOR_ORACLE_ALL_PUBLIC_CANDIDATES = '' INTROSPECTOR_ORACLE_OPTIMAL = '' INTROSPECTOR_ORACLE_ALL_TESTS = '' +INTROSPECTOR_ORACLE_ALL_TESTS_XREF = '' INTROSPECTOR_FUNCTION_SOURCE = '' INTROSPECTOR_PROJECT_SOURCE = '' INTROSPECTOR_XREF = '' @@ -108,7 +109,8 @@ def set_introspector_endpoints(endpoint): INTROSPECTOR_FUNCTION_WITH_MATCHING_RETURN_TYPE, \ INTROSPECTOR_ORACLE_ALL_TESTS, INTROSPECTOR_JVM_PROPERTIES, \ INTROSPECTOR_TEST_SOURCE, INTROSPECTOR_HARNESS_SOURCE_AND_EXEC, \ - INTROSPECTOR_JVM_PUBLIC_CLASSES, INTROSPECTOR_LANGUAGE_STATS + INTROSPECTOR_JVM_PUBLIC_CLASSES, INTROSPECTOR_LANGUAGE_STATS, \ + INTROSPECTOR_ORACLE_ALL_TESTS_XREF INTROSPECTOR_ENDPOINT = endpoint @@ -141,6 +143,8 @@ def set_introspector_endpoints(endpoint): INTROSPECTOR_FUNCTION_WITH_MATCHING_RETURN_TYPE = ( f'{INTROSPECTOR_ENDPOINT}/function-with-matching-return-type') INTROSPECTOR_ORACLE_ALL_TESTS = f'{INTROSPECTOR_ENDPOINT}/project-tests' + INTROSPECTOR_ORACLE_ALL_TESTS_XREF = ( + f'{INTROSPECTOR_ENDPOINT}/project-tests-for-functions') INTROSPECTOR_JVM_PROPERTIES = f'{INTROSPECTOR_ENDPOINT}/jvm-method-properties' INTROSPECTOR_HARNESS_SOURCE_AND_EXEC = ( f'{INTROSPECTOR_ENDPOINT}/harness-source-and-executable') @@ -232,6 +236,30 @@ def query_introspector_for_tests(project: str) -> list[str]: return _get_data(resp, 'test-file-list', []) +def query_introspector_for_tests_xref( + project: str, functions: Optional[list[str]]) -> list[str]: + """Gets the list of functions and xref test files in the target project.""" + data = {'project': project} + if functions: + data['functions'] = ','.join(functions) + + resp = _query_introspector(INTROSPECTOR_ORACLE_ALL_TESTS_XREF, data) + + test_files = _get_data(resp, 'test-files', {}) + + handled = set() + result_list = [] + for test_paths in test_files.values(): + for test_path in test_paths: + if test_path in handled: + continue + + handled.add(test_path) + result_list.append(query_introspector_test_source(project, test_path)) + + return result_list + + def query_introspector_for_harness_intrinsics( project: str) -> list[dict[str, str]]: """Gets the list of test files in the target project.""" diff --git a/data_prep/project_context/context_introspector.py b/data_prep/project_context/context_introspector.py index 2f24f71b3e..b8ee498238 100644 --- a/data_prep/project_context/context_introspector.py +++ b/data_prep/project_context/context_introspector.py @@ -174,9 +174,23 @@ def _get_xrefs_to_function(self) -> list[str]: 'function_signature: %s', project, func_sig) return xrefs + def _get_test_xrefs_to_function(self) -> list[str]: + """Queries FI for test source calling the function being fuzzed.""" + project = self._benchmark.project + func_name = self._benchmark.function_name + xrefs = introspector.query_introspector_for_tests_xref(project, [func_name]) + + if not xrefs: + logging.warning( + 'Could not retrieve tests xrefs for project: %s ' + 'function_signature: %s', project, func_name) + + return xrefs + def get_context_info(self) -> dict: """Retrieves contextual information and stores them in a dictionary.""" xrefs = self._get_xrefs_to_function() + test_xrefs = self._get_test_xrefs_to_function() func_source = self._get_function_implementation() files = self._get_files_to_include() decl = self._get_embeddable_declaration() @@ -188,6 +202,7 @@ def get_context_info(self) -> dict: 'files': files, 'decl': decl, 'header': header, + 'test_xrefs': test_xrefs, } logging.info('Context: %s', context_info) diff --git a/llm_toolkit/prompt_builder.py b/llm_toolkit/prompt_builder.py index 930392bb84..52898a0153 100644 --- a/llm_toolkit/prompt_builder.py +++ b/llm_toolkit/prompt_builder.py @@ -204,6 +204,7 @@ def format_context(self, context_info: dict) -> str: func_source=context_info['func_source'], xrefs='\n'.join(context_info['xrefs']), include_statement=context_info['header'], + tests_xrefs='\n'.join(context_info['tests_xrefs']), ) def _select_examples(self, examples: list[list], diff --git a/prompts/template_xml/context.txt b/prompts/template_xml/context.txt index 0a578bb317..ff6981f2f7 100644 --- a/prompts/template_xml/context.txt +++ b/prompts/template_xml/context.txt @@ -24,4 +24,9 @@ Here is the source code for functions which reference the function being tested: {{ xrefs }} + +Here is the source code for the tests/examples that reference the function being tested: + +{{ tests_xrefs }} + {% endif %}