@@ -234,8 +234,9 @@ def get_target_edit_files(
234234 local_repo : git .Repo ,
235235 src_dir : str ,
236236 test_dir : str ,
237- latest_commit : str ,
237+ branch : str ,
238238 reference_commit : str ,
239+ use_topo_sort_dependencies : bool = True ,
239240) -> tuple [list [str ], dict ]:
240241 """Find the files with functions with the pass statement."""
241242 target_dir = str (local_repo .working_dir )
@@ -269,7 +270,7 @@ def get_target_edit_files(
269270 ), "all files should be included"
270271
271272 # change to latest commit
272- local_repo .git .checkout (latest_commit )
273+ local_repo .git .checkout (branch )
273274
274275 # Remove the base_dir prefix
275276 topological_sort_files = [
@@ -282,35 +283,88 @@ def get_target_edit_files(
282283 key_without_prefix = key .replace (target_dir , "" ).lstrip ("/" )
283284 value_without_prefix = [v .replace (target_dir , "" ).lstrip ("/" ) for v in value ]
284285 import_dependencies_without_prefix [key_without_prefix ] = value_without_prefix
286+ if use_topo_sort_dependencies :
287+ return topological_sort_files , import_dependencies_without_prefix
288+ else :
289+ filtered_files = [
290+ file .replace (target_dir , "" ).lstrip ("/" ) for file in filtered_files
291+ ]
292+ return filtered_files , import_dependencies_without_prefix
293+
294+
295+ def get_target_edit_files_from_patch (
296+ local_repo : git .Repo , patch : str , use_topo_sort_dependencies : bool = True
297+ ) -> tuple [list [str ], dict ]:
298+ """Get the target files from the patch."""
299+ working_dir = str (local_repo .working_dir )
300+ target_files = set ()
301+ for line in patch .split ("\n " ):
302+ if line .startswith ("+++" ) or line .startswith ("---" ):
303+ file_path = line .split ()[1 ]
304+ if file_path .startswith ("a/" ):
305+ file_path = file_path [2 :]
306+ if file_path .startswith ("b/" ):
307+ file_path = file_path [2 :]
308+ target_files .add (file_path )
309+
310+ target_files_list = list (target_files )
311+ target_files_list = [
312+ os .path .join (working_dir , file_path ) for file_path in target_files_list
313+ ]
285314
286- return topological_sort_files , import_dependencies_without_prefix
315+ if use_topo_sort_dependencies :
316+ topological_sort_files , import_dependencies = (
317+ topological_sort_based_on_dependencies (target_files_list )
318+ )
319+ if len (topological_sort_files ) != len (target_files_list ):
320+ if len (topological_sort_files ) < len (target_files_list ):
321+ missing_files = set (target_files_list ) - set (topological_sort_files )
322+ topological_sort_files = topological_sort_files + list (missing_files )
323+ else :
324+ raise ValueError (
325+ "topological_sort_files should not be longer than target_files_list"
326+ )
327+ assert len (topological_sort_files ) == len (
328+ target_files_list
329+ ), "all files should be included"
330+
331+ topological_sort_files = [
332+ file .replace (working_dir , "" ).lstrip ("/" ) for file in topological_sort_files
333+ ]
334+ for key , value in import_dependencies .items ():
335+ import_dependencies [key ] = [
336+ v .replace (working_dir , "" ).lstrip ("/" ) for v in value
337+ ]
338+ return topological_sort_files , import_dependencies
339+ else :
340+ target_files_list = [
341+ file .replace (working_dir , "" ).lstrip ("/" ) for file in target_files_list
342+ ]
343+ return target_files_list , {}
287344
288345
289346def get_message (
290347 agent_config : AgentConfig ,
291348 repo_path : str ,
292- test_dir : str | None = None ,
293- test_file : str | None = None ,
349+ test_files : list [str ] | None = None ,
294350) -> str :
295351 """Get the message to Aider."""
296352 prompt = f"{ PROMPT_HEADER } " + agent_config .user_prompt
297353
298- if agent_config .use_unit_tests_info and test_dir :
299- unit_tests_info = (
300- f"\n { UNIT_TESTS_INFO_HEADER } "
301- + get_dir_info (
302- dir_path = Path (os .path .join (repo_path , test_dir )),
303- prefix = "" ,
304- include_stubs = True ,
305- )[: agent_config .max_unit_tests_info_length ]
306- )
307- elif agent_config .use_unit_tests_info and test_file :
308- unit_tests_info = (
309- f"\n { UNIT_TESTS_INFO_HEADER } "
310- + get_file_info (
354+ # if agent_config.use_unit_tests_info and test_file:
355+ # unit_tests_info = (
356+ # f"\n{UNIT_TESTS_INFO_HEADER} "
357+ # + get_file_info(
358+ # file_path=Path(os.path.join(repo_path, test_file)), prefix=""
359+ # )[: agent_config.max_unit_tests_info_length]
360+ # )
361+ if agent_config .use_unit_tests_info and test_files :
362+ unit_tests_info = f"\n { UNIT_TESTS_INFO_HEADER } "
363+ for test_file in test_files :
364+ unit_tests_info += get_file_info (
311365 file_path = Path (os .path .join (repo_path , test_file )), prefix = ""
312- )[: agent_config . max_unit_tests_info_length ]
313- )
366+ )
367+ unit_tests_info = unit_tests_info [: agent_config . max_unit_tests_info_length ]
314368 else :
315369 unit_tests_info = ""
316370
@@ -405,6 +459,33 @@ def create_branch(repo: git.Repo, branch: str, from_commit: str) -> None:
405459 raise RuntimeError (f"Failed to create or switch to branch '{ branch } ': { e } " )
406460
407461
462+ def get_changed_files_from_commits (
463+ repo : git .Repo , commit1 : str , commit2 : str
464+ ) -> list [str ]:
465+ """Get the changed files from two commits."""
466+ try :
467+ # Get the commit objects
468+ commit1_obj = repo .commit (commit1 )
469+ commit2_obj = repo .commit (commit2 )
470+
471+ # Get the diff between the two commits
472+ diff = commit1_obj .diff (commit2_obj )
473+
474+ # Extract the changed file paths
475+ changed_files = [item .a_path for item in diff ]
476+
477+ # Check if each changed file is a Python file
478+ python_files = [file for file in changed_files if file .endswith (".py" )]
479+
480+ # Update the changed_files list to only include Python files
481+ changed_files = python_files
482+
483+ return changed_files
484+ except Exception as e :
485+ print (f"An error occurred: { e } " )
486+ return []
487+
488+
408489def args2string (agent_config : AgentConfig ) -> str :
409490 """Converts specific fields from an `AgentConfig` object into a formatted string.
410491
@@ -453,13 +534,14 @@ def get_changed_files(repo: git.Repo) -> list[str]:
453534 return files_changed
454535
455536
456- def get_lint_cmd (repo_name : str , use_lint_info : bool ) -> str :
537+ def get_lint_cmd (repo_name : str , use_lint_info : bool , commit0_config_file : str ) -> str :
457538 """Generate a linting command based on whether to include files.
458539
459540 Args:
460541 ----
461542 repo_name (str): The name of the repository.
462543 use_lint_info (bool): A flag indicating whether to include changed files in the lint command.
544+ commit0_config_file (str): The path to the commit0 dot file.
463545
464546 Returns:
465547 -------
@@ -469,7 +551,9 @@ def get_lint_cmd(repo_name: str, use_lint_info: bool) -> str:
469551 """
470552 lint_cmd = "python -m commit0 lint "
471553 if use_lint_info :
472- lint_cmd += repo_name + " --files "
554+ lint_cmd += (
555+ repo_name + " --commit0-config-file " + commit0_config_file + " --files "
556+ )
473557 else :
474558 lint_cmd = ""
475559 return lint_cmd
0 commit comments