|
6 | 6 | from pathlib import Path |
7 | 7 | from typing import List |
8 | 8 | import fitz |
| 9 | +from import_deps import ModuleSet |
| 10 | +from graphlib import TopologicalSorter, CycleError |
9 | 11 | import yaml |
10 | 12 |
|
11 | 13 | from agent.class_types import AgentConfig |
|
16 | 18 | UNIT_TESTS_INFO_HEADER = "\n\n>>> Here are the Unit Tests Information:\n" |
17 | 19 | LINT_INFO_HEADER = "\n\n>>> Here is the Lint Information:\n" |
18 | 20 | SPEC_INFO_HEADER = "\n\n>>> Here is the Specification Information:\n" |
| 21 | +IMPORT_DEPENDENCIES_HEADER = "\n\n>>> Here are the Import Dependencies:\n" |
19 | 22 | # prefix components: |
20 | 23 | space = " " |
21 | 24 | branch = "│ " |
@@ -190,25 +193,97 @@ def _find_files_to_edit(base_dir: str, src_dir: str, test_dir: str) -> list[str] |
190 | 193 | return files |
191 | 194 |
|
192 | 195 |
|
193 | | -def get_target_edit_files(target_dir: str, src_dir: str, test_dir: str) -> list[str]: |
| 196 | +def ignore_cycles(graph: dict) -> list[str]: |
| 197 | + """Ignore the cycles in the graph.""" |
| 198 | + ts = TopologicalSorter(graph) |
| 199 | + try: |
| 200 | + return list(ts.static_order()) |
| 201 | + except CycleError as e: |
| 202 | + # print(f"Cycle detected: {e.args[1]}") |
| 203 | + # You can either break the cycle by modifying the graph or handle it as needed. |
| 204 | + # For now, let's just remove the first node in the cycle and try again. |
| 205 | + cycle_nodes = e.args[1] |
| 206 | + node_to_remove = cycle_nodes[0] |
| 207 | + # print(f"Removing node {node_to_remove} to resolve cycle.") |
| 208 | + graph.pop(node_to_remove, None) |
| 209 | + return ignore_cycles(graph) |
| 210 | + |
| 211 | + |
| 212 | +def topological_sort_based_on_dependencies( |
| 213 | + pkg_paths: list[str], |
| 214 | +) -> tuple[list[str], dict]: |
| 215 | + """Topological sort based on dependencies.""" |
| 216 | + module_set = ModuleSet([str(p) for p in pkg_paths]) |
| 217 | + |
| 218 | + import_dependencies = {} |
| 219 | + for path in sorted(module_set.by_path.keys()): |
| 220 | + module_name = ".".join(module_set.by_path[path].fqn) |
| 221 | + mod = module_set.by_name[module_name] |
| 222 | + try: |
| 223 | + imports = module_set.get_imports(mod) |
| 224 | + import_dependencies[path] = set([str(x) for x in imports]) |
| 225 | + except Exception: |
| 226 | + import_dependencies[path] = set() |
| 227 | + |
| 228 | + import_dependencies_files = ignore_cycles(import_dependencies) |
| 229 | + |
| 230 | + return import_dependencies_files, import_dependencies |
| 231 | + |
| 232 | + |
| 233 | +def get_target_edit_files( |
| 234 | + local_repo: git.Repo, |
| 235 | + src_dir: str, |
| 236 | + test_dir: str, |
| 237 | + latest_commit: str, |
| 238 | + reference_commit: str, |
| 239 | +) -> tuple[list[str], dict]: |
194 | 240 | """Find the files with functions with the pass statement.""" |
| 241 | + target_dir = str(local_repo.working_dir) |
195 | 242 | files = _find_files_to_edit(target_dir, src_dir, test_dir) |
196 | 243 | filtered_files = [] |
197 | 244 | for file_path in files: |
198 | | - with open(file_path, "r", encoding="utf-8", errors="ignore") as file: |
| 245 | + with open(file_path, "r", encoding="utf-8-sig", errors="ignore") as file: |
199 | 246 | content = file.read() |
200 | 247 | if len(content.splitlines()) > 1500: |
201 | 248 | continue |
202 | 249 | if " pass" in content: |
203 | 250 | filtered_files.append(file_path) |
| 251 | + # Change to reference commit to get the correct dependencies |
| 252 | + local_repo.git.checkout(reference_commit) |
| 253 | + |
| 254 | + topological_sort_files, import_dependencies = ( |
| 255 | + topological_sort_based_on_dependencies(filtered_files) |
| 256 | + ) |
| 257 | + if len(topological_sort_files) != len(filtered_files): |
| 258 | + if len(topological_sort_files) < len(filtered_files): |
| 259 | + # Find the missing elements |
| 260 | + missing_files = set(filtered_files) - set(topological_sort_files) |
| 261 | + # Add the missing files to the end of the list |
| 262 | + topological_sort_files = topological_sort_files + list(missing_files) |
| 263 | + else: |
| 264 | + raise ValueError( |
| 265 | + "topological_sort_files should not be longer than filtered_files" |
| 266 | + ) |
| 267 | + assert len(topological_sort_files) == len( |
| 268 | + filtered_files |
| 269 | + ), "all files should be included" |
| 270 | + |
| 271 | + # change to latest commit |
| 272 | + local_repo.git.checkout(latest_commit) |
204 | 273 |
|
205 | 274 | # Remove the base_dir prefix |
206 | | - filtered_files = [ |
207 | | - file.replace(target_dir, "").lstrip("/") for file in filtered_files |
| 275 | + topological_sort_files = [ |
| 276 | + file.replace(target_dir, "").lstrip("/") for file in topological_sort_files |
208 | 277 | ] |
209 | | - # Only keep python files |
210 | 278 |
|
211 | | - return filtered_files |
| 279 | + # Remove the base_dir prefix from import dependencies |
| 280 | + import_dependencies_without_prefix = {} |
| 281 | + for key, value in import_dependencies.items(): |
| 282 | + key_without_prefix = key.replace(target_dir, "").lstrip("/") |
| 283 | + value_without_prefix = [v.replace(target_dir, "").lstrip("/") for v in value] |
| 284 | + import_dependencies_without_prefix[key_without_prefix] = value_without_prefix |
| 285 | + |
| 286 | + return topological_sort_files, import_dependencies_without_prefix |
212 | 287 |
|
213 | 288 |
|
214 | 289 | def get_message( |
@@ -268,6 +343,20 @@ def get_message( |
268 | 343 | return message_to_agent |
269 | 344 |
|
270 | 345 |
|
| 346 | +def update_message_with_dependencies(message: str, dependencies: list[str]) -> str: |
| 347 | + """Update the message with the dependencies.""" |
| 348 | + if len(dependencies) == 0: |
| 349 | + return message |
| 350 | + import_dependencies_info = f"\n{IMPORT_DEPENDENCIES_HEADER}" |
| 351 | + for dependency in dependencies: |
| 352 | + with open(dependency, "r") as file: |
| 353 | + import_dependencies_info += ( |
| 354 | + f"\nHere is the content of the file {dependency}:\n{file.read()}" |
| 355 | + ) |
| 356 | + message += import_dependencies_info |
| 357 | + return message |
| 358 | + |
| 359 | + |
271 | 360 | def get_specification(specification_pdf_path: Path) -> str: |
272 | 361 | """Get the reference for a given specification PDF path.""" |
273 | 362 | # TODO: after pdf_to_text is available, use it to extract the text from the PDF |
|
0 commit comments