Skip to content

(Experimental) Preserve comments on mostly similar lines #77

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions refactor/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import warnings
from contextlib import suppress
from dataclasses import dataclass, field, replace
from pprint import pprint
from typing import Generic, TypeVar, cast

from refactor.ast import split_lines
Expand Down Expand Up @@ -91,11 +92,11 @@ def apply(self, context: Context, source: str) -> str:
view = slice(lineno - 1, end_lineno)
source_lines = lines[view]

indentation, start_prefix = find_indent(source_lines[0][:col_offset])
end_suffix = source_lines[-1][end_col_offset:]
replacement = split_lines(self._resynthesize(context))
replacement.apply_indentation(
indentation, start_prefix=start_prefix, end_suffix=end_suffix
# Applies the block indentation only if the replacement lines are different from source
replacement.apply_source_formatting(
source_lines=source_lines,
markers=(0, col_offset, end_col_offset),
)

lines[view] = replacement
Expand Down Expand Up @@ -170,16 +171,16 @@ class LazyInsertAfter(_LazyActionMixin[ast.stmt, ast.stmt]):

def apply(self, context: Context, source: str) -> str:
lines = split_lines(source, encoding=context.file_info.get_encoding())
indentation, start_prefix = find_indent(
lines[self.node.lineno - 1][: self.node.col_offset]
)

replacement = split_lines(context.unparse(self.build()))
replacement.apply_indentation(indentation, start_prefix=start_prefix)
replacement.apply_source_formatting(
source_lines=lines,
markers=(self.node.lineno - 1, self.node.col_offset, None),
)

original_node_end = cast(int, self.node.end_lineno) - 1
if lines[original_node_end].endswith(lines._newline_type):
replacement[-1] += lines._newline_type
replacement[-1] += lines._newline_type if not replacement[-1].endswith(lines._newline_type) else ""
else:
# If the original anchor's last line doesn't end with a newline,
# then we need to also prevent our new source from ending with
Expand Down Expand Up @@ -212,13 +213,13 @@ class LazyInsertBefore(_LazyActionMixin[ast.stmt, ast.stmt]):

def apply(self, context: Context, source: str) -> str:
lines = split_lines(source, encoding=context.file_info.get_encoding())
indentation, start_prefix = find_indent(
lines[self.node.lineno - 1][: self.node.col_offset]
)

replacement = split_lines(context.unparse(self.build()))
replacement.apply_indentation(indentation, start_prefix=start_prefix)
replacement[-1] += lines._newline_type
replacement.apply_source_formatting(
source_lines=lines,
markers=(self.node.lineno - 1, self.node.col_offset, None),
)
replacement[-1] += lines._newline_type if not replacement[-1].endswith(lines._newline_type) else ""

original_node_start = cast(int, self.node.lineno)
for line in reversed(replacement):
Expand Down
99 changes: 81 additions & 18 deletions refactor/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@
import io
import operator
import os
import re
import tokenize
from collections import UserList, UserString
from collections.abc import Generator, Iterator
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from functools import cached_property
from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast
from pprint import pprint
from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast, Tuple, Set, Dict, List, Callable

from refactor import common
from refactor.common import find_indent, extract_str_difference, find_indent_comments

DEFAULT_ENCODING = "utf-8"

Expand All @@ -27,29 +30,89 @@ class Lines(UserList[StringType]):
def __post_init__(self) -> None:
    """Initialise the ``UserList`` backing store from the ``lines`` field."""
    # Hand the initial lines to UserList, which copies them into self.data.
    super().__init__(self.lines)
    # Re-alias so self.lines and self.data are the same list object and
    # mutations through either name stay in sync.
    self.lines = self.data
    # Match bookkeeping; no reader is visible in this view —
    # NOTE(review): confirm best_matches is actually consumed somewhere.
    self.best_matches: list[StringType] = []

def join(self) -> str:
    """Concatenate every stored line into one source-code string."""
    return "".join(str(line) for line in self.lines)

def apply_indentation(
self,
indentation: StringType,
*,
start_prefix: AnyStringType = "",
end_suffix: AnyStringType = "",
@staticmethod
def find_best_matching_source_line(line: str, source_lines: Lines, percentile: float = 10) -> Tuple[
    str | None, str, str]:
    """Locate the source line that is most similar to *line*.

    Candidates are ranked by the combined difference percentage in both
    directions; the first (smallest) candidate whose source-side change
    percentage stays below *percentile* wins.

    Returns a ``(source line, indentation, comments)`` triple, or
    ``(None, "", "")`` when no candidate is close enough.
    """
    candidates = []
    for raw in source_lines.lines:
        text = str(raw)
        # Estimate the changes between the candidate and the new line.
        diff = extract_str_difference(text, line, without_comments=True)
        indentation, _, comments = find_indent_comments(text)
        # When the new line carries its own comment, the source comment
        # is dropped rather than merged.
        diff['output'] = text, indentation, "" if "#" in line else comments
        candidates.append(diff)

    def combined_change(entry):
        return entry['a']['percent'] + entry['b']['percent']

    # There should be minimal changes to the original line - how to
    # estimate that threshold?
    for candidate in sorted(candidates, key=combined_change):
        if candidate['a']['percent'] < percentile:
            return candidate['output']
    return None, "", ""

def apply_source_formatting(
self,
source_lines: Lines,
*,
markers: Tuple[int, int, int | None] = None,
comments_separator: str = " "
) -> None:
"""Apply the given indentation, optionally with start and end prefixes
to the bound source lines."""
"""Apply the indentation from source_lines when the first several characters match

:param source_lines: Original lines in source code
:param markers: Indentation and prefix parameters. Tuple of (start line, col_offset, end_suffix | None)
:param comments_separator: Separator for comments
"""

def reconstruct_comments(line_to_update, new, original) -> str:
if new and not new.isspace():
return line_to_update + comments_separator + new
if original and not original.isspace():
original = re.sub(self._newline_type, '', original)
if line_to_update and line_to_update[-1] == self._newline_type:
return line_to_update[:-1] + comments_separator + original + line_to_update[-1]
return line_to_update + comments_separator + original
return line_to_update

block_indentation, start_prefix = find_indent(source_lines[markers[0]][:markers[1]])

for index, line in enumerate(self.data):
if index == 0:
self.data[index] = indentation + str(start_prefix) + str(line) # type: ignore
# Let's see if we find a matching original line using statistical change to original lines (default < 10%)
(original_line,
indentation,
comments) = self.find_best_matching_source_line(line, source_lines[markers[0]:])

if original_line is None:
# If there is no good match as original line, implement the block_separator
self.data[index] = block_indentation + str(line)
else:
self.data[index] = indentation + line # type: ignore
# Remove the line indentation, collect comments
_, line, new_comments = find_indent_comments(line)
line: str = reconstruct_comments(line, new_comments, comments)
self.data[index] = indentation + str(line)

if len(self.data) >= 1:
self.data[-1] += str(end_suffix) # type: ignore
# Edge cases: First line, override with the block indentation, keeping the comments
if index == 0:
self.data[index] = block_indentation + str(start_prefix) + str(line)

# For the last line:
# if no good match, append the end_suffix
# otherwise, add the newline if present in the original
# Note that this seems more appropriate here than in the actions.py
if index == len(self.data) - 1:
if original_line is None:
end_suffix = "" if markers[2] is None else source_lines[-1][markers[2]:]
self.data[index] = self.data[index] + str(end_suffix)
elif original_line[-1] == self._newline_type:
self.data[index] = self.data[index] + self._newline_type

@cached_property
def _newline_type(self) -> str:
Expand Down Expand Up @@ -77,7 +140,7 @@ def __getitem__(self, index: SupportsIndex | slice) -> SourceSegment:
# re-implements the direct indexing as slicing (e.g. a[1] is a[1:2], with
# error handling).
direct_index = operator.index(index)
view = raw_line[direct_index : direct_index + 1].decode(
view = raw_line[direct_index: direct_index + 1].decode(
encoding=self.encoding
)
if not view:
Expand Down Expand Up @@ -200,9 +263,9 @@ def maybe_retrieve(self, node: ast.AST) -> bool:
@contextmanager
def _collect_stmt_comments(self, node: ast.AST) -> Iterator[None]:
def _write_if_unseen_comment(
line_no: int,
line: str,
comment_begin: int,
line_no: int,
line: str,
comment_begin: int,
) -> None:
if line_no in self._visited_comment_lines:
# We have already written this comment as the
Expand Down
133 changes: 130 additions & 3 deletions refactor/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

import ast
import copy
import difflib
import re
from collections import deque
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from functools import cache, singledispatch, wraps
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast, Set, List, Tuple, Dict, AnyStr

if TYPE_CHECKING:
from refactor.context import Context
Expand Down Expand Up @@ -76,7 +78,7 @@ def is_truthy(op: ast.cmpop) -> bool | None:


def _type_checker(
*types: type, binders: Iterable[Callable[[type], bool]] = ()
*types: type, binders: Iterable[Callable[[type], bool]] = ()
) -> Callable[[Any], bool]:
binders = [getattr(binder, "fast_checker", binder) for binder in binders]

Expand Down Expand Up @@ -134,7 +136,7 @@ def _get_known_location_from_source(source: str, location: PositionType) -> str
if start_line == end_line:
return lines[start_line][start_col:end_col]

start, *middle, end = lines[start_line : end_line + 1]
start, *middle, end = lines[start_line: end_line + 1]
new_lines = (start[start_col:], *middle, end[:end_col])
return "\n".join(new_lines)

Expand Down Expand Up @@ -176,6 +178,31 @@ def find_indent(source: str) -> tuple[str, str]:
return source[:index], source[index:]


def find_comments(source: str) -> tuple[str, str]:
    """Split the given line into the code part and the trailing comment.

    The second item starts at the ``#`` character (inclusive), matching the
    convention of :func:`find_indent_comments`; it is empty when the line has
    no comment.

    NOTE: a ``#`` inside a string literal is treated as a comment marker too.

    :param source: A single line of source code.
    :returns: ``(code, comment)`` tuple.
    """
    # Default: no comment found, everything is code.
    index = len(source)
    for position, char in enumerate(source):
        if char == "#":
            index = position
            break
    return source[:index], source[index:]


def find_indent_comments(source: str) -> tuple[str, str, str]:
    """Split the given line into ``(indentation, code, comment)``.

    The comment (third item) starts at the ``#`` character and includes it.
    When a comment is present, the code part is stripped of surrounding
    whitespace; otherwise the code part keeps its trailing characters
    (including any newline).  A whitespace-only or empty line is returned
    entirely as indentation.

    :param source: A single line of source code.
    :returns: ``(indentation, code, comment)`` tuple.
    """
    indent, comment = -1, -1
    for index, char in enumerate(source):
        if not char.isspace() and indent == -1:
            indent = index
        if char == "#":
            comment = index
            break
    if indent == -1:
        # Whitespace-only (or empty) line: it is all indentation.
        return source, "", ""
    if comment != -1:
        return source[:indent], source[indent:comment].strip(), source[comment:]
    return source[:indent], source[indent:], ""


def find_closest(node: ast.AST, *targets: ast.AST) -> ast.AST:
"""Find the closest node to the given ``node`` from the given
sequence of ``targets`` (uses absolute distance from starting points)."""
Expand All @@ -201,6 +228,106 @@ def extract_from_text(text: str) -> ast.AST:
return ast.parse(text).body[0]


def split_python_wise(x: str, seps: str = " ()[]{}\"'") -> List[str]:
    """Tokenize *x* by padding every separator with spaces and splitting.

    The first character of *seps* is the split character; every other
    separator is first surrounded by it so separators come back as
    stand-alone tokens.  Empty tokens may appear next to separators.

    :param x: The string to tokenize.
    :param seps: Separator characters; ``seps[0]`` is the split character.
    :returns: The stripped tokens, separators included.
    """
    default_sep = seps[0]
    for sep in seps[1:]:
        x = x.replace(sep, default_sep + sep + default_sep)
    return [token.strip() for token in x.split(default_sep)]


def split_on_separators(string: str, separators: List[str] = "()[]{}'" + '"') -> List[str]:
    """Break *string* apart on whitespace and on single separator characters.

    Each separator splits only when it is not immediately repeated, and every
    separator character that occurs anywhere in *string* is appended once at
    the end of the result.
    """
    branches = [f"{re.escape(sep)}(?!{re.escape(sep)})" for sep in separators]
    branches.append(" ")
    splitter = "|".join(branches)

    tokens: List[str] = []
    for piece in re.split(splitter, string):
        # A piece that is itself a separator character gets doubled.
        tokens.append(piece + piece if piece in separators else piece)

    for sep in separators:
        if sep in string:
            tokens.append(sep)
    return tokens


def extract_str_difference(a: str,
                           b: str,
                           without_comments: bool = True,
                           ignore_leading_spaces: bool = True
                           ) -> Dict[str, Dict[str, str | float | Set[str]]]:
    """Return the sets of "words" that differ between the two strings.

    Percentages are the summed length of the changed tokens, normalised by
    the word count of the respective side.
    """
    # Optionally drop trailing ``#`` comments before comparing.
    if without_comments:
        a = re.match(r'^([^#]*)', a).group(1)
        b = re.match(r'^([^#]*)', b).group(1)

    # Optionally drop leading whitespace (no-op for blank/empty strings).
    if ignore_leading_spaces:
        if m := re.match(r'^\s*?([\S\n].*)', a):
            a = m.group(1)
        if m := re.match(r'^\s*?([\S\n].*)', b):
            b = m.group(1)

    result: Dict[str, Dict[str, str | float | Set[str]]] = {
        "a": {"changes": set(), "percent": 0.0},
        "b": {"changes": set(), "percent": 0.0}}

    # Tokens present in exactly one of the two token sets.
    for token in set(split_on_separators(a)) ^ set(split_on_separators(b)):
        side = "a" if (token in a and token not in b) else "b"
        result[side]['changes'].add(token)
        result[side]['percent'] += len(token)

    result['a']['percent'] = result['a']['percent'] / (len(a.split()) + 1) * 100
    result['b']['percent'] = result['b']['percent'] / (len(b.split()) + 1) * 100
    return result


def extract_string_differences(a: str,
                               b: str,
                               without_comments: bool = True,
                               ignore_leading_spaces: bool = True,
                               ignore_spaces: bool = False) -> Dict[str, Dict[str, str | float]]:
    """Calculate the character-level difference between two strings.

    :param str a: The first string to compare.
    :param str b: The second string to compare.
    :param bool without_comments: Optional removal of trailing ``#`` comments.
    :param bool ignore_leading_spaces: Optional removal of leading whitespace.
    :param bool ignore_spaces: Optional removal of all whitespace.
    :returns: The changed characters and percentages (relative to ``a``) for
        ``"a"``, ``"b"`` and the ``"common"`` part.
    :rtype: Dict
    """
    # Remove comments if requested
    a = re.match(r'^([^#]*)', a).group(1) if without_comments else a
    b = re.match(r'^([^#]*)', b).group(1) if without_comments else b

    # Remove leading white spaces if requested.  The walrus guard keeps
    # empty/whitespace-only strings intact instead of crashing on a
    # non-matching pattern (mirrors extract_str_difference).
    a = m.group(1) if ignore_leading_spaces and (m := re.match(r'^\s*?([\S].*)$', a)) else a
    b = m.group(1) if ignore_leading_spaces and (m := re.match(r'^\s*?([\S].*)$', b)) else b

    # Remove white spaces if requested
    a = "".join(a.split() if ignore_spaces else list(a))
    b = "".join(b.split() if ignore_spaces else list(b))

    # Initialize the SequenceMatcher
    matcher = difflib.SequenceMatcher(a=a, b=b)

    # Store the differences between the two strings
    differences: Dict[str, Dict[str, str | float]] = {
        "a": {"changes": "", "percent": 0.0},
        "b": {"changes": "", "percent": 0.0},
        "common": {"changes": "", "percent": 0.0},
    }

    # Iterate over the opcodes and collect the changed/common characters.
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "replace":
            differences["a"]["changes"] += a[i1:i2]
            differences["b"]["changes"] += b[j1:j2]
        elif tag == "delete":
            differences["a"]["changes"] += a[i1:i2]
        elif tag == "insert":
            differences["b"]["changes"] += b[j1:j2]
        elif tag == "equal":
            differences["common"]["changes"] += a[i1:i2]

    # Percentages are expressed relative to the length of ``a``; guard
    # against division by zero when ``a`` ends up empty.
    denominator = len(a) or 1
    for value in differences.values():
        value["percent"] = len(value["changes"]) / denominator * 100

    return differences


_POSITIONAL_ATTRIBUTES = (
"lineno",
"col_offset",
Expand Down
Loading