From 0b9e8b5db60cab74a8095967dc18c4915d8fc2bb Mon Sep 17 00:00:00 2001 From: memento Date: Thu, 29 Dec 2022 18:10:47 -0600 Subject: [PATCH 1/6] Clunky attempt at preserving comments --- refactor/actions.py | 16 +- refactor/ast.py | 51 ++++-- refactor/common.py | 21 ++- tests/test_ast.py | 375 +++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 444 insertions(+), 19 deletions(-) diff --git a/refactor/actions.py b/refactor/actions.py index 13e8d76..4905a7a 100644 --- a/refactor/actions.py +++ b/refactor/actions.py @@ -89,11 +89,11 @@ def apply(self, context: Context, source: str) -> str: view = slice(lineno - 1, end_lineno) source_lines = lines[view] - indentation, start_prefix = find_indent(source_lines[0][:col_offset]) - end_suffix = source_lines[-1][end_col_offset:] replacement = split_lines(self._resynthesize(context)) - replacement.apply_indentation( - indentation, start_prefix=start_prefix, end_suffix=end_suffix + # Applies the block indentation only if the replacement lines are different from source + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), ) lines[view] = replacement @@ -168,12 +168,12 @@ class LazyInsertAfter(_LazyActionMixin[ast.stmt, ast.stmt]): def apply(self, context: Context, source: str) -> str: lines = split_lines(source, encoding=context.file_info.get_encoding()) - indentation, start_prefix = find_indent( - lines[self.node.lineno - 1][: self.node.col_offset] - ) replacement = split_lines(context.unparse(self.build())) - replacement.apply_indentation(indentation, start_prefix=start_prefix) + replacement.apply_source_formatting( + source_lines=lines, + markers=(self.node.lineno - 1, self.node.col_offset, None), + ) original_node_end = cast(int, self.node.end_lineno) - 1 if lines[original_node_end].endswith(lines._newline_type): diff --git a/refactor/ast.py b/refactor/ast.py index 3d23262..294a614 100644 --- a/refactor/ast.py +++ b/refactor/ast.py @@ -4,15 +4,17 @@ import io import operator import os +import re import tokenize from collections import UserList, UserString from collections.abc import Generator, Iterator from contextlib import contextmanager, nullcontext from dataclasses import dataclass from functools import cached_property -from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast +from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast, Tuple from refactor import common +from refactor.common import find_indent, extract_str_difference DEFAULT_ENCODING = "utf-8" @@ -32,21 +34,52 @@ def join(self) -> str: """Return the combined source code.""" return "".join(map(str, self.lines)) - def apply_indentation( + def apply_source_formatting( self, - indentation: StringType, + source_lines: Lines, *, - start_prefix: AnyStringType = "", - end_suffix: AnyStringType = "", + markers: Tuple[int, int, int | None] = None, ) -> None: - """Apply the given indentation, optionally with start and end prefixes - to the bound source lines.""" + """Apply the indentation from source_lines when the first several characters match + :param source_lines: Original lines in source code + :param markers: Indentation and prefix parameters. Tuple of (start line, col_offset, end_suffix | None) + """ + + indentation, start_prefix = find_indent(source_lines[markers[0]][:markers[1]]) + end_suffix = "" if markers[2] is None else source_lines[-1][markers[2]:] + + original_line: str | None for index, line in enumerate(self.data): + comments: str = "" + p_common = 0 + if index < len(source_lines): + original_line = str(source_lines[index]) + difference, p_common = extract_str_difference(original_line, line, with_comments=True) + print(difference) + if "#" in difference: + m = re.search(r"(#.+)\n", original_line) + comments = " " + m.group(1) if m and m.group(1) != "" else "" + difference, p_common = extract_str_difference(original_line, line) + else: + original_line = None + + line_w_comments: str = line[:-1] + comments + line[-1] if len(line) > 0 and line[-1] == "\n" else line + comments + line_w_comments = line_w_comments if p_common < 0.25 else line + if index == 0: - self.data[index] = indentation + str(start_prefix) + str(line) # type: ignore + self.data[index] = indentation + str(start_prefix) + str(line_w_comments) # type: ignore + elif index == len(self.data) - 1: + if original_line is not None and original_line.startswith(line[:-1]): + self.data[index] = line # type: ignore + else: + self.data[index] = indentation + line # type: ignore + + # The updated line can have an extra wrapping in brackets + elif original_line is not None and original_line.startswith(line[:-1]): + self.data[index] = line_w_comments # type: ignore else: - self.data[index] = indentation + line # type: ignore + self.data[index] = indentation + line_w_comments # type: ignore if len(self.data) >= 1: self.data[-1] += str(end_suffix) # type: ignore diff --git a/refactor/common.py b/refactor/common.py index 68dbfaf..6de51fb 100644 --- a/refactor/common.py +++ b/refactor/common.py @@ -2,12 +2,13 @@ import ast import copy +import re from collections import deque from collections.abc import Iterable, Iterator from dataclasses import dataclass from functools import cache, singledispatch, wraps from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast +from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast, Set, List, Tuple if TYPE_CHECKING: from refactor.context import Context @@ -201,6 +202,24 @@ def extract_from_text(text: str) -> ast.AST: return ast.parse(text).body[0] +def extract_str_difference(s1: str, s2: str, with_comments: bool = False) -> Tuple[Set[str], float]: + """Returns a set of "words" that are different between 2 strings""" + if not with_comments: + s1 = re.match(r'^([^#]*)', s1).group() + s2 = re.match(r'^([^#]*)', s2).group() + difference: Set[str] = set(s1.split()).symmetric_difference(s2.split()) + count_changed_chars: int = 0 + for w in difference: + count_changed_chars = count_changed_chars + len(w) + percentile_change: float = count_changed_chars / (len(s1 + s2)) + return set(s1.split()).symmetric_difference(s2.split()), percentile_change + + +def refactored_matching(s1: str, s2: str) -> Set[str]: + """Returns whether the changes in two strings can be considered minimal""" + return set(s1.split()).symmetric_difference(s2.split()) + + _POSITIONAL_ATTRIBUTES = ( "lineno", "col_offset", diff --git a/tests/test_ast.py b/tests/test_ast.py index 1e5dc1d..e79b53a 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -6,8 +6,9 @@ import pytest -from refactor import common +from refactor import common, Context from refactor.ast import BaseUnparser, PreciseUnparser, split_lines +from refactor.common import position_for, clone def test_split_lines(): @@ -169,6 +170,7 @@ def test_precise_unparser_indented_literals(): """\ def func(): if something: + # On change, comments are removed print( "bleh" "zoom" @@ -240,3 +242,374 @@ def foo(): base = PreciseUnparser(source=source) assert base.unparse(tree) + "\n" == expected_src + + +def test_precise_unparser_custom_indent_no_changes(): + source = """def func(): + if something: + # Arguments have custom indentation + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + # Arguments have custom indentation + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + tree = ast.parse(source) + + base = PreciseUnparser(source=source) + assert base.unparse(tree) + "\n" == expected_src + + +def test_precise_unparser_custom_indent_del(): + source = """def func(): + if something: + # Arguments have custom indentation + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + print(call(.1), maybe+something_else_that_is_very_very_very_long, thing . a) +""" + + tree = ast.parse(source) + del tree.body[0].body[0].body[0].value.args[2] + + base = PreciseUnparser(source=source) + assert base.unparse(tree) + "\n" == expected_src + + +def test_apply_source_formatting_maintains_with_await_0(): + source = """def func(): + if something: + # Comments are retrieved + print( + call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + await print( + call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0].value + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = ast.Await(node_to_replace) + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + +def test_apply_source_formatting_maintains_with_await_1(): + source = """def func(): + if something: + # Comments are retrieved + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + await print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0].value + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = ast.Await(node_to_replace) + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + +def test_apply_source_formatting_maintains_with_call(): + source = """def func(): + if something: + # Comments are retrieved + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + call_instead(print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + )) +""" + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0].value + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = ast.Call(func=ast.Name(id="call_instead"), args=[node_to_replace], keywords=[]) + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + +def test_apply_source_formatting_maintains_with_call_on_closing_parens(): + source = """def func(): + if something: + # Comments are retrieved + print(call(.1), # This comment is unchanged + maybe+something_else_that_is_very_very_very_long, + maybe / other, # This comment is unchanged + thing . a + ) # This is mis-aligned and spacing of comment doesn't change on last line +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + call_instead(print(call(.1), # This comment is unchanged + maybe+something_else_that_is_very_very_very_long, + maybe / other, # This comment is unchanged + thing . a + )) # This is mis-aligned and spacing of comment doesn't change on last line +""" + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0].value + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = ast.Call(func=ast.Name(id="call_instead"), args=[node_to_replace], keywords=[]) + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + +def test_apply_source_formatting_maintains_with_async_complex(): + source = """def func(): + if something: + # Comments are retrieved + with print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) as p: + do_something() +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + async with print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) as p: + do_something() +""" + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0] + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = clone(node_to_replace) + new_node.__class__ = ast.AsyncWith + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + +def test_apply_source_formatting_maintains_with_async(): + source = """def func(): + if something: + # Comments are retrieved + with something: # comment2 becomes spacing of 2 + a = 1 # Non-standard indent + b = 2 # Non-standard indent, comment is unchanged due to 'end_suffix' +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + async with something: # comment2 becomes spacing of 2 + a = 1 # Non-standard indent + b = 2 # Non-standard indent, comment is unchanged due to 'end_suffix' +""" + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0] + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = clone(node_to_replace) + new_node.__class__ = ast.AsyncWith + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + +def test_apply_source_formatting_maintains_with_fstring(): + source = '''def f(): + return """ +a +""" +''' + + expected_src = '''def f(): + return F(""" +a +""") +''' + + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].value + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = ast.Call(func=ast.Name(id="F"), args=[node_to_replace], keywords=[]) + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + +def test_apply_source_formatting_does_not_with_change(): + source = """def func(): + if something: + # Comments are retrieved + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + await print(call(.1), maybe+something_else_that_is_very_very_very_long, thing . a) +""" + + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0].value + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + del node_to_replace.args[2] + new_node = ast.Await(node_to_replace) + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src From 14e6e8716ffd9e1392ab6521042367a6cf27d23c Mon Sep 17 00:00:00 2001 From: memento Date: Sat, 31 Dec 2022 18:31:07 -0600 Subject: [PATCH 2/6] Better? some edge cases failing --- refactor/ast.py | 88 +++++++++++++++++----------- refactor/common.py | 142 +++++++++++++++++++++++++++++++++++++++------ tests/test_ast.py | 32 +++++----- 3 files changed, 195 insertions(+), 67 deletions(-) diff --git a/refactor/ast.py b/refactor/ast.py index 294a614..569e7a2 100644 --- a/refactor/ast.py +++ b/refactor/ast.py @@ -14,7 +14,7 @@ from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast, Tuple from refactor import common -from refactor.common import find_indent, extract_str_difference +from refactor.common import find_indent, extract_str_difference, find_indent_comments DEFAULT_ENCODING = "utf-8" @@ -29,60 +29,82 @@ class Lines(UserList[StringType]): def __post_init__(self) -> None: super().__init__(self.lines) self.lines = self.data + self.best_matches: list[StringType] = [] def join(self) -> str: """Return the combined source code.""" return "".join(map(str, self.lines)) + @staticmethod + def find_best_matching_source_line(line: str, source_lines: Lines, percentile: float = 10) -> Tuple[str | None, str, str]: + """Finds the best matching line in a list of lines + Returns the source line indentation and comments""" + _line: str | None = None + for _l in source_lines.lines: + _line: str = str(_l) + indentation, _, comments = find_indent_comments(_line) + + # Estimate the changes between the two lines + changes = extract_str_difference(_line, line, without_comments=True) + + # There should be minimal changes - how to estimate that threshold? + if changes['a']['percent'] < percentile: + if "#" in line: + return _line, indentation, "" + return _line, indentation, comments + return None, "", "" + def apply_source_formatting( self, source_lines: Lines, *, markers: Tuple[int, int, int | None] = None, + comments_separator: str = " " ) -> None: """Apply the indentation from source_lines when the first several characters match :param source_lines: Original lines in source code :param markers: Indentation and prefix parameters. Tuple of (start line, col_offset, end_suffix | None) + :param comments_separator: Separator for comments """ - indentation, start_prefix = find_indent(source_lines[markers[0]][:markers[1]]) + block_indentation, start_prefix = find_indent(source_lines[markers[0]][:markers[1]]) end_suffix = "" if markers[2] is None else source_lines[-1][markers[2]:] - original_line: str | None for index, line in enumerate(self.data): - comments: str = "" - p_common = 0 - if index < len(source_lines): - original_line = str(source_lines[index]) - difference, p_common = extract_str_difference(original_line, line, with_comments=True) - print(difference) - if "#" in difference: - m = re.search(r"(#.+)\n", original_line) - comments = " " + m.group(1) if m and m.group(1) != "" else "" - difference, p_common = extract_str_difference(original_line, line) - else: - original_line = None - - line_w_comments: str = line[:-1] + comments + line[-1] if len(line) > 0 and line[-1] == "\n" else line + comments - line_w_comments = line_w_comments if p_common < 0.25 else line + print(f">{line}<>{end_suffix}<") + # Let's see if we find a matching original line using statistical change to original lines (default < 10%) + original_line, indentation, comments = self.find_best_matching_source_line(line, source_lines[markers[0]:]) + + if original_line is not None: + # Remove the line indentation, collect comments + _, line, new_comments = find_indent_comments(line) + + # Update for comments either on the 'line' or on the original line + if new_comments and not new_comments.isspace(): + # 'line' include comments, keep and implement 2 spaces separation + line = line + comments_separator + new_comments + + elif comments and not comments.isspace(): + # Comments from original line may have end-of-line, using the 'line' terminator + comments = re.sub(self._newline_type, '', comments) + # If line has a return, insert the comments just before it + # Use 2 space separator as recommended by PyCharm (from PEP?) + if line and line[-1] == self._newline_type: + line = line[:-1] + comments_separator + comments + line[-1] + else: + line = line + comments_separator + comments + + self.data[index] = indentation + str(line) if index == 0: - self.data[index] = indentation + str(start_prefix) + str(line_w_comments) # type: ignore - elif index == len(self.data) - 1: - if original_line is not None and original_line.startswith(line[:-1]): - self.data[index] = line # type: ignore - else: - self.data[index] = indentation + line # type: ignore - - # The updated line can have an extra wrapping in brackets - elif original_line is not None and original_line.startswith(line[:-1]): - self.data[index] = line_w_comments # type: ignore - else: - self.data[index] = indentation + line_w_comments # type: ignore - - if len(self.data) >= 1: - self.data[-1] += str(end_suffix) # type: ignore + self.data[index] = block_indentation + str(start_prefix) + str(line) + + if index == len(self.data) - 1: + if original_line is None: + self.data[index] = self.data[index] + str(end_suffix) + elif original_line[-1] == self._newline_type: + self.data[index] = self.data[index] + self._newline_type @cached_property def _newline_type(self) -> str: diff --git a/refactor/common.py b/refactor/common.py index 6de51fb..cc6a8c0 100644 --- a/refactor/common.py +++ b/refactor/common.py @@ -2,13 +2,14 @@ import ast import copy +import difflib import re from collections import deque from collections.abc import Iterable, Iterator from dataclasses import dataclass from functools import cache, singledispatch, wraps from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast, Set, List, Tuple +from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast, Set, List, Tuple, Dict, AnyStr if TYPE_CHECKING: from refactor.context import Context @@ -77,7 +78,7 @@ def is_truthy(op: ast.cmpop) -> bool | None: def _type_checker( - *types: type, binders: Iterable[Callable[[type], bool]] = () + *types: type, binders: Iterable[Callable[[type], bool]] = () ) -> Callable[[Any], bool]: binders = [getattr(binder, "fast_checker", binder) for binder in binders] @@ -135,7 +136,7 @@ def _get_known_location_from_source(source: str, location: PositionType) -> str if start_line == end_line: return lines[start_line][start_col:end_col] - start, *middle, end = lines[start_line : end_line + 1] + start, *middle, end = lines[start_line: end_line + 1] new_lines = (start[start_col:], *middle, end[:end_col]) return "\n".join(new_lines) @@ -177,6 +178,31 @@ def find_indent(source: str) -> tuple[str, str]: return source[:index], source[index:] +def find_comments(source: str) -> tuple[str, str]: + """Split the given line into the current indentation + and the remaining characters.""" + index = 0 + for index, char in enumerate(source, 1): + if char == "#": + break + return source[:index], source[index:] + + +def find_indent_comments(source: str) -> tuple[str, str, str]: + """Split the given line into the current indentation + and the remaining characters.""" + indent, comment = -1, -1 + for index, char in enumerate(source, 1): + if not char.isspace() and indent == -1: + indent = index - 1 + if char == "#": + comment = index - 1 + break + if comment != -1: + return source[:indent], source[indent:comment].strip(), source[comment:] + return source[:indent], source[indent:], "" + + def find_closest(node: ast.AST, *targets: ast.AST) -> ast.AST: """Find the closest node to the given ``node`` from the given sequence of ``targets`` (uses absolute distance from starting points).""" @@ -202,22 +228,102 @@ def extract_from_text(text: str) -> ast.AST: return ast.parse(text).body[0] -def extract_str_difference(s1: str, s2: str, with_comments: bool = False) -> Tuple[Set[str], float]: +def split_python_wise(x: str, seps: List[str] = " ()[]{}\"'"): + default_sep = seps[0] + for s in seps[1:]: + x = x.replace(s, default_sep + s + default_sep) + return [i.strip() for i in x.split(default_sep)] + + +def split_on_separators(string: str, separators: List[str] = "()[]{}'" + '"') -> List[str]: + pattern = "|".join([f"{re.escape(sep)}(?!{re.escape(sep)})" for sep in separators] + [" "]) + return [s + s if s in separators else s for s in re.split(pattern, string)] + + +def extract_str_difference(a: str, + b: str, + without_comments: bool = True, + ignore_leading_spaces: bool = True + ) -> Dict[str, Dict[str, str | float | Set[str]]]: """Returns a set of "words" that are different between 2 strings""" - if not with_comments: - s1 = re.match(r'^([^#]*)', s1).group() - s2 = re.match(r'^([^#]*)', s2).group() - difference: Set[str] = set(s1.split()).symmetric_difference(s2.split()) - count_changed_chars: int = 0 - for w in difference: - count_changed_chars = count_changed_chars + len(w) - percentile_change: float = count_changed_chars / (len(s1 + s2)) - return set(s1.split()).symmetric_difference(s2.split()), percentile_change - - -def refactored_matching(s1: str, s2: str) -> Set[str]: - """Returns whether the changes in two strings can be considered minimal""" - return set(s1.split()).symmetric_difference(s2.split()) + # Remove comments if requested + a = re.match(r'^([^#]*)', a).group(1) if without_comments else a + b = re.match(r'^([^#]*)', b).group(1) if without_comments else b + + # Remove leading white spaces if requested + a = re.match(r'^\s*?([\S].*)$', a).group(1) if ignore_leading_spaces else a + b = re.match(r'^\s*?([\S].*)$', b).group(1) if ignore_leading_spaces else b + + differences: Dict[str, Dict[str, str | float | Set[str]]] = { + "a": {"changes": set(), "percent": 0.0}, + "b": {"changes": set(), "percent": 0.0}} + + raw_diff: Set[str] = set(split_on_separators(a)).symmetric_difference(set(split_on_separators(b))) + for item in raw_diff: + if item in a and item not in b: + differences['a']['changes'].add(item) + differences['a']['percent'] = differences['a']['percent'] + len(item) + else: + differences['b']['changes'].add(item) + differences['b']['percent'] = differences['b']['percent'] + len(item) + + differences['a']['percent'] = differences['a']['percent'] / len(a.split()) * 100 + differences['b']['percent'] = differences['b']['percent'] / len(b.split()) * 100 + return differences + + +def extract_string_differences(a: str, + b: str, + without_comments: bool = True, + ignore_leading_spaces: bool = True, + ignore_spaces: bool = False) -> Dict[str, Dict[str, str | float]]: + """Calculate the difference between two strings. Optionally removes that comments. + + :param str a: The first string to compare. + :param str b: The second string to compare. + :param bool with_comments: Optional removal of comments from the extraction. + :param bool ignore_spaces: Optional inclusion of space counting. + :returns: The differences and percentiles between the two strings in a dictionary. + :rtype: Dict + """ + # Remove comments if requested + a = re.match(r'^([^#]*)', a).group(1) if without_comments else a + b = re.match(r'^([^#]*)', b).group(1) if without_comments else b + + # Remove leading white spaces if requested + a = re.match(r'^\s*?([\S].*)$', a).group(1) if ignore_leading_spaces else a + b = re.match(r'^\s*?([\S].*)$', b).group(1) if ignore_leading_spaces else b + + # Remove white spaces if requested + a = "".join(a.split() if ignore_spaces else list(a)) + b = "".join(b.split() if ignore_spaces else list(b)) + + # Initialize the SequenceMatcher + matcher = difflib.SequenceMatcher(a=a, b=b) + + # Store the differences between the two strings + differences = { + "a": {"changes": "", "percent": 0.0}, + "b": {"changes": "", "percent": 0.0}, + "common": {"changes": "", "percent": 0.0}, + } + + # Iterate over the opcodes and update the difference counter + for tag, i1, i2, j1, j2 in matcher.get_opcodes(): + if tag == "replace": + differences["a"]["changes"] += a[i1:i2] + differences["b"]["changes"] += b[j1:j2] + elif tag == "delete": + differences["a"]["changes"] += a[i1:i2] + elif tag == "insert": + differences["b"]["changes"] += b[j1:j2] + elif tag == "equal": + differences["common"]["changes"] += a[i1:i2] + + for key, value in differences.items(): + value["percent"] = len(value["changes"]) / len(a) * 100 + + return differences _POSITIONAL_ATTRIBUTES = ( diff --git a/tests/test_ast.py b/tests/test_ast.py index e79b53a..ddde822 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -341,9 +341,9 @@ def test_apply_source_formatting_maintains_with_await_1(): source = """def func(): if something: # Comments are retrieved - print(call(.1), + print(call(.1), # Comments are retrieved. separation can be updated to x spaces (1 default) maybe+something_else_that_is_very_very_very_long, - maybe / other, + maybe / other,# Comments are retrieved. separation can be updated to x spaces (1 default) thing . a ) """ @@ -351,9 +351,9 @@ def test_apply_source_formatting_maintains_with_await_1(): expected_src = """def func(): if something: # Comments are retrieved - await print(call(.1), + await print(call(.1), # Comments are retrieved. separation can be updated to x spaces (1 default) maybe+something_else_that_is_very_very_very_long, - maybe / other, + maybe / other, # Comments are retrieved. separation can be updated to x spaces (1 default) thing . a ) """ @@ -423,21 +423,21 @@ def test_apply_source_formatting_maintains_with_call_on_closing_parens(): source = """def func(): if something: # Comments are retrieved - print(call(.1), # This comment is unchanged + print(call(.1), # Comments are retrieved, spacing of x spaces (1 default) maybe+something_else_that_is_very_very_very_long, - maybe / other, # This comment is unchanged + maybe / other, # Comments are retrieved, spacing of x spaces (1 default) thing . a - ) # This is mis-aligned and spacing of comment doesn't change on last line + ) # Non-standard indent is conserved and comments, spacing of x spaces (1 default) """ expected_src = """def func(): if something: # Comments are retrieved - call_instead(print(call(.1), # This comment is unchanged + call_instead(print(call(.1), # Comments are retrieved, spacing of x spaces (1 default) maybe+something_else_that_is_very_very_very_long, - maybe / other, # This comment is unchanged + maybe / other, # Comments are retrieved, spacing of x spaces (1 default) thing . a - )) # This is mis-aligned and spacing of comment doesn't change on last line + )) # Non-standard indent is conserved and comments, spacing of x spaces (1 default) """ source_tree = ast.parse(source) context = Context(source, source_tree) @@ -508,17 +508,17 @@ def test_apply_source_formatting_maintains_with_async(): source = """def func(): if something: # Comments are retrieved - with something: # comment2 becomes spacing of 2 - a = 1 # Non-standard indent - b = 2 # Non-standard indent, comment is unchanged due to 'end_suffix' + with something: # comment2, spacing of x spaces (1 default) + a = 1 # Non-standard indent is conserved, spacing of x spaces (1 default) + b = 2 # Non-standard indent is conserved, spacing of x spaces (1 default) """ expected_src = """def func(): if something: # Comments are retrieved - async with something: # comment2 becomes spacing of 2 - a = 1 # Non-standard indent - b = 2 # Non-standard indent, comment is unchanged due to 'end_suffix' + async with something: # comment2, spacing of x spaces (1 default) + a = 1 # Non-standard indent is conserved, spacing of x spaces (1 default) + b = 2 # Non-standard indent is conserved, spacing of x spaces (1 default) """ source_tree = ast.parse(source) context = Context(source, source_tree) From 6ad219cd432f62ebc1933efdcc57dfe7f170d85b Mon Sep 17 00:00:00 2001 From: memento Date: Sun, 1 Jan 2023 15:35:44 -0600 Subject: [PATCH 3/6] Passing all tests, most original, some modified to represent the new logic --- refactor/actions.py | 4 +++- refactor/ast.py | 54 ++++++++++++++++++++++++++------------------- refactor/common.py | 12 +++++----- tests/test_core.py | 6 ++--- 4 files changed, 44 insertions(+), 32 deletions(-) diff --git a/refactor/actions.py b/refactor/actions.py index 4905a7a..f28229c 100644 --- a/refactor/actions.py +++ b/refactor/actions.py @@ -4,6 +4,7 @@ import warnings from contextlib import suppress from dataclasses import dataclass, field, replace +from pprint import pprint from typing import Generic, TypeVar, cast from refactor.ast import split_lines @@ -177,7 +178,8 @@ def apply(self, context: Context, source: str) -> str: original_node_end = cast(int, self.node.end_lineno) - 1 if lines[original_node_end].endswith(lines._newline_type): - replacement[-1] += lines._newline_type + pprint(replacement) + replacement[-1] += lines._newline_type if not replacement[-1].endswith(lines._newline_type) else "" else: # If the original anchor's last line doesn't end with a newline, # then we need to also prevent our new source from ending with diff --git a/refactor/ast.py b/refactor/ast.py index 569e7a2..3c75adc 100644 --- a/refactor/ast.py +++ b/refactor/ast.py @@ -11,7 +11,8 @@ from contextlib import contextmanager, nullcontext from dataclasses import dataclass from functools import cached_property -from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast, Tuple +from pprint import pprint +from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast, Tuple, Set, Dict, List, Callable from refactor import common from refactor.common import find_indent, extract_str_difference, find_indent_comments @@ -36,30 +37,33 @@ def join(self) -> str: return "".join(map(str, self.lines)) @staticmethod - def find_best_matching_source_line(line: str, source_lines: Lines, percentile: float = 10) -> Tuple[str | None, str, str]: + def find_best_matching_source_line(line: str, source_lines: Lines, percentile: float = 10) -> Tuple[ + str | None, str, str]: """Finds the best matching line in a list of lines Returns the source line indentation and comments""" _line: str | None = None + changes: List[Dict[str, Tuple[str, str, str] | Dict[str, str | float | Set[str]]]] = [] for _l in source_lines.lines: _line: str = str(_l) - indentation, _, comments = find_indent_comments(_line) - # Estimate the changes between the two lines - changes = extract_str_difference(_line, line, without_comments=True) - - # There should be minimal changes - how to estimate that threshold? - if changes['a']['percent'] < percentile: - if "#" in line: - return _line, indentation, "" - return _line, indentation, comments + changes.append(extract_str_difference(_line, line, without_comments=True)) + indentation, _, comments = find_indent_comments(_line) + changes[-1]['output'] = _line, indentation, "" if "#" in line else comments + + sort_key = lambda x: x['a']['percent'] + x['b']['percent'] + sorted_changes = sorted(changes, key=sort_key) + for i, change in enumerate(sorted_changes): + # There should be minimal changes to the original line - how to estimate that threshold? + if change['a']['percent'] < percentile: + return change['output'] return None, "", "" def apply_source_formatting( - self, - source_lines: Lines, - *, - markers: Tuple[int, int, int | None] = None, - comments_separator: str = " " + self, + source_lines: Lines, + *, + markers: Tuple[int, int, int | None] = None, + comments_separator: str = " " ) -> None: """Apply the indentation from source_lines when the first several characters match @@ -72,10 +76,12 @@ def apply_source_formatting( end_suffix = "" if markers[2] is None else source_lines[-1][markers[2]:] for index, line in enumerate(self.data): - print(f">{line}<>{end_suffix}<") # Let's see if we find a matching original line using statistical change to original lines (default < 10%) - original_line, indentation, comments = self.find_best_matching_source_line(line, source_lines[markers[0]:]) + (original_line, + indentation, + comments) = self.find_best_matching_source_line(line, source_lines[markers[0]:]) + print(f">{line[:-1]}<") if original_line is not None: # Remove the line indentation, collect comments _, line, new_comments = find_indent_comments(line) @@ -95,7 +101,9 @@ def apply_source_formatting( else: line = line + comments_separator + comments - self.data[index] = indentation + str(line) + self.data[index] = indentation + str(line) + else: + self.data[index] = block_indentation + str(line) if index == 0: self.data[index] = block_indentation + str(start_prefix) + str(line) @@ -132,7 +140,7 @@ def __getitem__(self, index: SupportsIndex | slice) -> SourceSegment: # re-implements the direct indexing as slicing (e.g. a[1] is a[1:2], with # error handling). direct_index = operator.index(index) - view = raw_line[direct_index : direct_index + 1].decode( + view = raw_line[direct_index: direct_index + 1].decode( encoding=self.encoding ) if not view: @@ -255,9 +263,9 @@ def maybe_retrieve(self, node: ast.AST) -> bool: @contextmanager def _collect_stmt_comments(self, node: ast.AST) -> Iterator[None]: def _write_if_unseen_comment( - line_no: int, - line: str, - comment_begin: int, + line_no: int, + line: str, + comment_begin: int, ) -> None: if line_no in self._visited_comment_lines: # We have already written this comment as the diff --git a/refactor/common.py b/refactor/common.py index cc6a8c0..5a2d633 100644 --- a/refactor/common.py +++ b/refactor/common.py @@ -237,7 +237,9 @@ def split_python_wise(x: str, seps: List[str] = " ()[]{}\"'"): def split_on_separators(string: str, separators: List[str] = "()[]{}'" + '"') -> List[str]: pattern = "|".join([f"{re.escape(sep)}(?!{re.escape(sep)})" for sep in separators] + [" "]) - return [s + s if s in separators else s for s in re.split(pattern, string)] + result = [s + s if s in separators else s for s in re.split(pattern, string)] + separators_found = [s for s in separators if s in string] + return result + separators_found def extract_str_difference(a: str, @@ -251,8 +253,8 @@ def extract_str_difference(a: str, b = re.match(r'^([^#]*)', b).group(1) if without_comments else b # Remove leading white spaces if requested - a = re.match(r'^\s*?([\S].*)$', a).group(1) if ignore_leading_spaces else a - b = re.match(r'^\s*?([\S].*)$', b).group(1) if ignore_leading_spaces else b + a = m.group(1) if ignore_leading_spaces and (m := re.match(r'^\s*?([\S\n].*)', a)) else a + b = m.group(1) if ignore_leading_spaces and (m := re.match(r'^\s*?([\S\n].*)', b)) else b differences: Dict[str, Dict[str, str | float | Set[str]]] = { "a": {"changes": set(), "percent": 0.0}, @@ -267,8 +269,8 @@ def extract_str_difference(a: str, differences['b']['changes'].add(item) differences['b']['percent'] = differences['b']['percent'] + len(item) - differences['a']['percent'] = differences['a']['percent'] / len(a.split()) * 100 - differences['b']['percent'] = differences['b']['percent'] / len(b.split()) * 100 + differences['a']['percent'] = differences['a']['percent'] / (len(a.split()) + 1) * 100 + differences['b']['percent'] = differences['b']['percent'] / (len(b.split()) + 1) * 100 return differences diff --git a/tests/test_core.py b/tests/test_core.py index e28f540..fcb61cf 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -50,14 +50,14 @@ def test_apply_simple(source, expected, target_func, replacement): [ ( """ - import x # comments + import x # comments will be copied since identical python statement print(x.y) # comments here def something(x, y): return x + y # comments """, """ - import x # comments - import x + import x # comments will be copied since identical python statement + import x # comments will be copied since identical python statement print(x.y) # comments here def something(x, y): return x + y # comments From 710e1fcd9a314ea7e6e231819a6b5a856c69973b Mon Sep 17 00:00:00 2001 From: memento Date: Sun, 1 Jan 2023 15:56:08 -0600 Subject: [PATCH 4/6] Correction in merged InsertBefore --- refactor/actions.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/refactor/actions.py b/refactor/actions.py index 72043ab..8e4135d 100644 --- a/refactor/actions.py +++ b/refactor/actions.py @@ -180,7 +180,6 @@ def apply(self, context: Context, source: str) -> str: original_node_end = cast(int, self.node.end_lineno) - 1 if lines[original_node_end].endswith(lines._newline_type): - pprint(replacement) replacement[-1] += lines._newline_type if not replacement[-1].endswith(lines._newline_type) else "" else: # If the original anchor's last line doesn't end with a newline, @@ -214,13 +213,13 @@ class LazyInsertBefore(_LazyActionMixin[ast.stmt, ast.stmt]): def apply(self, context: Context, source: str) -> str: lines = split_lines(source, encoding=context.file_info.get_encoding()) - indentation, start_prefix = find_indent( - lines[self.node.lineno - 1][: self.node.col_offset] - ) replacement = split_lines(context.unparse(self.build())) - replacement.apply_indentation(indentation, start_prefix=start_prefix) - replacement[-1] += lines._newline_type + replacement.apply_source_formatting( + source_lines=lines, + markers=(self.node.lineno - 1, self.node.col_offset, None), + ) + replacement[-1] += lines._newline_type if not replacement[-1].endswith(lines._newline_type) else "" original_node_start = cast(int, self.node.lineno) for line in reversed(replacement): From 1b7279ee75f0104bf22feeffb756917d5fd335b2 Mon Sep 17 00:00:00 2001 From: memento Date: Sun, 1 Jan 2023 16:00:42 -0600 Subject: [PATCH 5/6] Lost print --- refactor/ast.py | 1 - 1 file changed, 1 deletion(-) diff --git a/refactor/ast.py b/refactor/ast.py index 3c75adc..042ca0a 100644 --- a/refactor/ast.py +++ b/refactor/ast.py @@ -81,7 +81,6 @@ def apply_source_formatting( indentation, comments) = self.find_best_matching_source_line(line, source_lines[markers[0]:]) - print(f">{line[:-1]}<") if original_line is not None: # Remove the line indentation, collect comments _, line, new_comments = find_indent_comments(line) From fcf2919e86a5519400265b03fc245d38ddf9f3cc Mon Sep 17 00:00:00 2001 From: memento Date: Sun, 1 Jan 2023 19:18:44 -0600 Subject: [PATCH 6/6] A bit cleaner? --- refactor/ast.py | 41 +++++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/refactor/ast.py b/refactor/ast.py index 042ca0a..345a972 100644 --- a/refactor/ast.py +++ b/refactor/ast.py @@ -72,8 +72,17 @@ def apply_source_formatting( :param comments_separator: Separator for comments """ + def reconstruct_comments(line_to_update, new, original) -> str: + if new and not new.isspace(): + return line_to_update + comments_separator + new + if original and not original.isspace(): + original = re.sub(self._newline_type, '', original) + if line_to_update and line_to_update[-1] == self._newline_type: + return line_to_update[:-1] + comments_separator + original + line_to_update[-1] + return line_to_update + comments_separator + original + return line_to_update + block_indentation, start_prefix = find_indent(source_lines[markers[0]][:markers[1]]) - end_suffix = "" if markers[2] is None else source_lines[-1][markers[2]:] for index, line in enumerate(self.data): # Let's see if we find a matching original line using statistical change to original lines (default < 10%) @@ -81,34 +90,26 @@ def apply_source_formatting( indentation, comments) = self.find_best_matching_source_line(line, source_lines[markers[0]:]) - if original_line is not None: + if original_line is None: + # If there is no good match as original line, implement the block_separator + self.data[index] = block_indentation + str(line) + else: # Remove the line indentation, collect comments _, line, new_comments = find_indent_comments(line) - - # Update for comments either on the 'line' or on the original line - if new_comments and not new_comments.isspace(): - # 'line' include comments, keep and implement 2 spaces separation - line = line + comments_separator + new_comments - - elif comments and not comments.isspace(): - # Comments from original line may have end-of-line, using the 'line' terminator - comments = re.sub(self._newline_type, '', comments) - # If line has a return, insert the comments just before it - # Use 2 space separator as recommended by PyCharm (from PEP?) - if line and line[-1] == self._newline_type: - line = line[:-1] + comments_separator + comments + line[-1] - else: - line = line + comments_separator + comments - + line: str = reconstruct_comments(line, new_comments, comments) self.data[index] = indentation + str(line) - else: - self.data[index] = block_indentation + str(line) + # Edge cases: First line, override with the block indentation, keeping the comments if index == 0: self.data[index] = block_indentation + str(start_prefix) + str(line) + # For the last line: + # if no good match, append the end_suffix + # otherwise, add the newline if present in the original + # Note that this seems more appropriate here than in the actions.py if index == len(self.data) - 1: if original_line is None: + end_suffix = "" if markers[2] is None else source_lines[-1][markers[2]:] self.data[index] = self.data[index] + str(end_suffix) elif original_line[-1] == self._newline_type: self.data[index] = self.data[index] + self._newline_type