diff --git a/data_prep/project_context/context_introspector.py b/data_prep/project_context/context_introspector.py index 2f24f71b3e..1d5f05206b 100644 --- a/data_prep/project_context/context_introspector.py +++ b/data_prep/project_context/context_introspector.py @@ -162,6 +162,31 @@ def _get_function_implementation(self) -> str: return function_source + def _get_macro_block(self) -> list[str]: + """Queries FI to check macro block around this function.""" + project = self._benchmark.project + if not self._benchmark.function_dict: + return [] + + source = self._benchmark.function_dict.get('function_filename', '') + start = self._benchmark.function_dict.get('source_line_begin', 0) + end = self._benchmark.function_dict.get('source_end_begin', 99999) + macro_block = introspector.query_introspector_macro_block( + project, source, start, end) + + results = set() + for macro in macro_block: + conditions = macro.get('conditions', []) + for condition in conditions: + cond_type = condition.get('type') + cond_detail = condition.get('condition') + if not cond_type or not cond_detail: + continue + + results.add(f'#{cond_type} {cond_detail}') + + return list(results) + def _get_xrefs_to_function(self) -> list[str]: """Queries FI for function being fuzzed.""" project = self._benchmark.project @@ -177,6 +202,7 @@ def _get_xrefs_to_function(self) -> list[str]: def get_context_info(self) -> dict: """Retrieves contextual information and stores them in a dictionary.""" xrefs = self._get_xrefs_to_function() + macro = self._get_macro_block() func_source = self._get_function_implementation() files = self._get_files_to_include() decl = self._get_embeddable_declaration() @@ -184,6 +210,7 @@ def get_context_info(self) -> dict: context_info = { 'xrefs': xrefs, + 'macro_block': macro, 'func_source': func_source, 'files': files, 'decl': decl, diff --git a/llm_toolkit/prompt_builder.py b/llm_toolkit/prompt_builder.py index f18c633307..48efecf356 100644 --- a/llm_toolkit/prompt_builder.py +++ b/llm_toolkit/prompt_builder.py @@ -203,6 +203,7 @@ def format_context(self, context_info: dict) -> str: must_insert=context_info['decl'], func_source=context_info['func_source'], xrefs='\n'.join(context_info['xrefs']), + macro='\n'.join(context_info['macro_block']), include_statement=context_info['header'], ) diff --git a/prompts/template_xml/context.txt b/prompts/template_xml/context.txt index 0a578bb317..f34eddbe88 100644 --- a/prompts/template_xml/context.txt +++ b/prompts/template_xml/context.txt @@ -25,3 +25,12 @@ Here is the source code for functions which reference the function being tested: {{ xrefs }} {% endif %} +{% if macro %} + +Here are all the macro conditions wrapped around or present in the target functions: + +{{ macro }} + +{% endif %} + +