4
4
import logging
5
5
import os
6
6
from difflib import SequenceMatcher
7
- from typing import Any
7
+ from typing import Any , Optional
8
8
9
9
from data_prep import introspector
10
10
from experiment import benchmark as benchmarklib
@@ -165,12 +165,14 @@ def get_context_info(self) -> dict:
165
165
func_source = self ._get_function_implementation ()
166
166
files = self ._get_files_to_include ()
167
167
decl = self ._get_embeddable_declaration ()
168
+ header = self .get_prefixed_header_file ()
168
169
169
170
context_info = {
170
171
'xrefs' : xrefs ,
171
172
'func_source' : func_source ,
172
173
'files' : files ,
173
- 'decl' : decl
174
+ 'decl' : decl ,
175
+ 'header' : header ,
174
176
}
175
177
176
178
logging .debug ('Context: %s' , context_info )
@@ -229,7 +231,7 @@ def get_same_header_file_paths(self, wrong_file: str) -> list[str]:
229
231
for header in header_list :
230
232
correct_file_name = os .path .splitext (os .path .basename (header ))
231
233
if wrong_file_name == correct_file_name :
232
- candidate_headers .append (header )
234
+ candidate_headers .append (os . path . normpath ( header ) )
233
235
234
236
return candidate_headers [:5 ]
235
237
@@ -245,11 +247,23 @@ def get_similar_header_file_paths(self, wrong_file: str) -> list[str]:
245
247
candidate_headers = sorted (candidate_header_scores ,
246
248
key = lambda x : candidate_header_scores [x ],
247
249
reverse = True )
248
- return candidate_headers [:5 ]
250
+ return [os .path .normpath (header ) for header in candidate_headers [:5 ]]
251
+
252
+ def _get_header_files_to_include (self , func_sig : str ) -> Optional [str ]:
253
+ """Retrieves the header file of the function signature."""
254
+ header_file = introspector .query_introspector_header_files_to_include (
255
+ self ._benchmark .project , func_sig )
256
+ return header_file [0 ] if header_file else None
249
257
250
- def get_target_function_file_path (self ) -> str :
258
+ def _get_target_function_file_path (self ) -> str :
251
259
"""Retrieves the header/source file of the function under test."""
252
- # Step 1: Find a header file that shares the same name as the source file.
260
+ # Step 1: Find a header file from the default API.
261
+ header_file = self ._get_header_files_to_include (
262
+ self ._benchmark .function_signature )
263
+ if header_file :
264
+ return header_file
265
+
266
+ # Step 2: Find a header file that shares the same name as the source file.
253
267
# TODO: Make this more robust, e.g., when header file and base file do not
254
268
# share the same basename.
255
269
source_file = introspector .query_introspector_source_file_path (
@@ -264,5 +278,41 @@ def get_target_function_file_path(self) -> str:
264
278
if candidate_headers :
265
279
return candidate_headers [0 ]
266
280
267
- # Step 2 : Use the source file If it does not have a same-name-header.
281
+ # Step 3 : Use the source file If it does not have a same-name-header.
268
282
return source_file
283
+
284
+ def get_prefixed_header_file (self , func_sig : str = '' ) -> Optional [str ]:
285
+ """Retrieves the header_file with `extern "C"` if needed."""
286
+ if func_sig :
287
+ header_file = self ._get_header_files_to_include (func_sig )
288
+ else :
289
+ header_file = self ._get_target_function_file_path ()
290
+
291
+ if not header_file :
292
+ return None
293
+ include_statement = f'#include "{ os .path .normpath (header_file )} "'
294
+ return (f'extern "C" {{\n { include_statement } \n }}'
295
+ if self ._benchmark .needs_extern else include_statement )
296
+
297
+ def get_prefixed_header_file_by_name (self , func_name : str ) -> Optional [str ]:
298
+ """Retrieves the header file based on function name with `extern "C"` if
299
+ needed."""
300
+ func_sig = introspector .query_introspector_function_signature (
301
+ self ._benchmark .project , func_name )
302
+ return self .get_prefixed_header_file (func_sig )
303
+
304
+ def get_prefixed_source_file (self ,
305
+ function_signature : str = '' ) -> Optional [str ]:
306
+ """Retrieves the source file with `extern "C"` if needed."""
307
+ if function_signature :
308
+ source_file = introspector .query_introspector_source_file_path (
309
+ self ._benchmark .project , function_signature )
310
+ else :
311
+ source_file = introspector .query_introspector_source_file_path (
312
+ self ._benchmark .project , self ._benchmark .function_signature )
313
+ if not source_file :
314
+ return None
315
+
316
+ include_statement = f'#include "{ os .path .normpath (source_file )} "'
317
+ return (f'extern "C" {{\n { include_statement } \n }}'
318
+ if self ._benchmark .needs_extern else include_statement )
0 commit comments