diff --git a/refactor/actions.py b/refactor/actions.py index 173a0e4..8e4135d 100644 --- a/refactor/actions.py +++ b/refactor/actions.py @@ -4,6 +4,7 @@ import warnings from contextlib import suppress from dataclasses import dataclass, field, replace +from pprint import pprint from typing import Generic, TypeVar, cast from refactor.ast import split_lines @@ -91,11 +92,11 @@ def apply(self, context: Context, source: str) -> str: view = slice(lineno - 1, end_lineno) source_lines = lines[view] - indentation, start_prefix = find_indent(source_lines[0][:col_offset]) - end_suffix = source_lines[-1][end_col_offset:] replacement = split_lines(self._resynthesize(context)) - replacement.apply_indentation( - indentation, start_prefix=start_prefix, end_suffix=end_suffix + # Applies the block indentation only if the replacement lines are different from source + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), ) lines[view] = replacement @@ -170,16 +171,16 @@ class LazyInsertAfter(_LazyActionMixin[ast.stmt, ast.stmt]): def apply(self, context: Context, source: str) -> str: lines = split_lines(source, encoding=context.file_info.get_encoding()) - indentation, start_prefix = find_indent( - lines[self.node.lineno - 1][: self.node.col_offset] - ) replacement = split_lines(context.unparse(self.build())) - replacement.apply_indentation(indentation, start_prefix=start_prefix) + replacement.apply_source_formatting( + source_lines=lines, + markers=(self.node.lineno - 1, self.node.col_offset, None), + ) original_node_end = cast(int, self.node.end_lineno) - 1 if lines[original_node_end].endswith(lines._newline_type): - replacement[-1] += lines._newline_type + replacement[-1] += lines._newline_type if not replacement[-1].endswith(lines._newline_type) else "" else: # If the original anchor's last line doesn't end with a newline, # then we need to also prevent our new source from ending with @@ -212,13 +213,13 @@ class LazyInsertBefore(_LazyActionMixin[ast.stmt, ast.stmt]): def apply(self, context: Context, source: str) -> str: lines = split_lines(source, encoding=context.file_info.get_encoding()) - indentation, start_prefix = find_indent( - lines[self.node.lineno - 1][: self.node.col_offset] - ) replacement = split_lines(context.unparse(self.build())) - replacement.apply_indentation(indentation, start_prefix=start_prefix) - replacement[-1] += lines._newline_type + replacement.apply_source_formatting( + source_lines=lines, + markers=(self.node.lineno - 1, self.node.col_offset, None), + ) + replacement[-1] += lines._newline_type if not replacement[-1].endswith(lines._newline_type) else "" original_node_start = cast(int, self.node.lineno) for line in reversed(replacement): diff --git a/refactor/ast.py b/refactor/ast.py index 3d23262..345a972 100644 --- a/refactor/ast.py +++ b/refactor/ast.py @@ -4,15 +4,18 @@ import io import operator import os +import re import tokenize from collections import UserList, UserString from collections.abc import Generator, Iterator from contextlib import contextmanager, nullcontext from dataclasses import dataclass from functools import cached_property -from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast +from pprint import pprint +from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast, Tuple, Set, Dict, List, Callable from refactor import common +from refactor.common import find_indent, extract_str_difference, find_indent_comments DEFAULT_ENCODING = "utf-8" @@ -27,29 +30,89 @@ class Lines(UserList[StringType]): def __post_init__(self) -> None: super().__init__(self.lines) self.lines = self.data + self.best_matches: list[StringType] = [] def join(self) -> str: """Return the combined source code.""" return "".join(map(str, self.lines)) - def apply_indentation( - self, - indentation: StringType, - *, - start_prefix: AnyStringType = "", - end_suffix: AnyStringType = "", + @staticmethod + def find_best_matching_source_line(line: str, source_lines: Lines, percentile: float = 10) -> Tuple[ + str | None, str, str]: + """Finds the best matching line in a list of lines + Returns the source line indentation and comments""" + _line: str | None = None + changes: List[Dict[str, Tuple[str, str, str] | Dict[str, str | float | Set[str]]]] = [] + for _l in source_lines.lines: + _line: str = str(_l) + # Estimate the changes between the two lines + changes.append(extract_str_difference(_line, line, without_comments=True)) + indentation, _, comments = find_indent_comments(_line) + changes[-1]['output'] = _line, indentation, "" if "#" in line else comments + + sort_key = lambda x: x['a']['percent'] + x['b']['percent'] + sorted_changes = sorted(changes, key=sort_key) + for i, change in enumerate(sorted_changes): + # There should be minimal changes to the original line - how to estimate that threshold? + if change['a']['percent'] < percentile: + return change['output'] + return None, "", "" + + def apply_source_formatting( + self, + source_lines: Lines, + *, + markers: Tuple[int, int, int | None] = None, + comments_separator: str = " " ) -> None: - """Apply the given indentation, optionally with start and end prefixes - to the bound source lines.""" + """Apply the indentation from source_lines when the first several characters match + + :param source_lines: Original lines in source code + :param markers: Indentation and prefix parameters. Tuple of (start line, col_offset, end_suffix | None) + :param comments_separator: Separator for comments + """ + + def reconstruct_comments(line_to_update, new, original) -> str: + if new and not new.isspace(): + return line_to_update + comments_separator + new + if original and not original.isspace(): + original = re.sub(self._newline_type, '', original) + if line_to_update and line_to_update[-1] == self._newline_type: + return line_to_update[:-1] + comments_separator + original + line_to_update[-1] + return line_to_update + comments_separator + original + return line_to_update + + block_indentation, start_prefix = find_indent(source_lines[markers[0]][:markers[1]]) for index, line in enumerate(self.data): - if index == 0: - self.data[index] = indentation + str(start_prefix) + str(line) # type: ignore + # Let's see if we find a matching original line using statistical change to original lines (default < 10%) + (original_line, + indentation, + comments) = self.find_best_matching_source_line(line, source_lines[markers[0]:]) + + if original_line is None: + # If there is no good match as original line, implement the block_separator + self.data[index] = block_indentation + str(line) else: - self.data[index] = indentation + line # type: ignore + # Remove the line indentation, collect comments + _, line, new_comments = find_indent_comments(line) + line: str = reconstruct_comments(line, new_comments, comments) + self.data[index] = indentation + str(line) - if len(self.data) >= 1: - self.data[-1] += str(end_suffix) # type: ignore + # Edge cases: First line, override with the block indentation, keeping the comments + if index == 0: + self.data[index] = block_indentation + str(start_prefix) + str(line) + + # For the last line: + # if no good match, append the end_suffix + # otherwise, add the newline if present in the original + # Note that this seems more appropriate here than in the actions.py + if index == len(self.data) - 1: + if original_line is None: + end_suffix = "" if markers[2] is None else source_lines[-1][markers[2]:] + self.data[index] = self.data[index] + str(end_suffix) + elif original_line[-1] == self._newline_type: + self.data[index] = self.data[index] + self._newline_type @cached_property def _newline_type(self) -> str: @@ -77,7 +140,7 @@ def __getitem__(self, index: SupportsIndex | slice) -> SourceSegment: # re-implements the direct indexing as slicing (e.g. a[1] is a[1:2], with # error handling). direct_index = operator.index(index) - view = raw_line[direct_index : direct_index + 1].decode( + view = raw_line[direct_index: direct_index + 1].decode( encoding=self.encoding ) if not view: @@ -200,9 +263,9 @@ def maybe_retrieve(self, node: ast.AST) -> bool: @contextmanager def _collect_stmt_comments(self, node: ast.AST) -> Iterator[None]: def _write_if_unseen_comment( - line_no: int, - line: str, - comment_begin: int, + line_no: int, + line: str, + comment_begin: int, ) -> None: if line_no in self._visited_comment_lines: # We have already written this comment as the diff --git a/refactor/common.py b/refactor/common.py index 68dbfaf..5a2d633 100644 --- a/refactor/common.py +++ b/refactor/common.py @@ -2,12 +2,14 @@ import ast import copy +import difflib +import re from collections import deque from collections.abc import Iterable, Iterator from dataclasses import dataclass from functools import cache, singledispatch, wraps from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast +from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast, Set, List, Tuple, Dict, AnyStr if TYPE_CHECKING: from refactor.context import Context @@ -76,7 +78,7 @@ def is_truthy(op: ast.cmpop) -> bool | None: def _type_checker( - *types: type, binders: Iterable[Callable[[type], bool]] = () + *types: type, binders: Iterable[Callable[[type], bool]] = () ) -> Callable[[Any], bool]: binders = [getattr(binder, "fast_checker", binder) for binder in binders] @@ -134,7 +136,7 @@ def _get_known_location_from_source(source: str, location: PositionType) -> str if start_line == end_line: return lines[start_line][start_col:end_col] - start, *middle, end = lines[start_line : end_line + 1] + start, *middle, end = lines[start_line: end_line + 1] new_lines = (start[start_col:], *middle, end[:end_col]) return "\n".join(new_lines) @@ -176,6 +178,31 @@ def find_indent(source: str) -> tuple[str, str]: return source[:index], source[index:] +def find_comments(source: str) -> tuple[str, str]: + """Split the given line into the current indentation + and the remaining characters.""" + index = 0 + for index, char in enumerate(source, 1): + if char == "#": + break + return source[:index], source[index:] + + +def find_indent_comments(source: str) -> tuple[str, str, str]: + """Split the given line into the current indentation + and the remaining characters.""" + indent, comment = -1, -1 + for index, char in enumerate(source, 1): + if not char.isspace() and indent == -1: + indent = index - 1 + if char == "#": + comment = index - 1 + break + if comment != -1: + return source[:indent], source[indent:comment].strip(), source[comment:] + return source[:indent], source[indent:], "" + + def find_closest(node: ast.AST, *targets: ast.AST) -> ast.AST: """Find the closest node to the given ``node`` from the given sequence of ``targets`` (uses absolute distance from starting points).""" @@ -201,6 +228,106 @@ def extract_from_text(text: str) -> ast.AST: return ast.parse(text).body[0] +def split_python_wise(x: str, seps: List[str] = " ()[]{}\"'"): + default_sep = seps[0] + for s in seps[1:]: + x = x.replace(s, default_sep + s + default_sep) + return [i.strip() for i in x.split(default_sep)] + + +def split_on_separators(string: str, separators: List[str] = "()[]{}'" + '"') -> List[str]: + pattern = "|".join([f"{re.escape(sep)}(?!{re.escape(sep)})" for sep in separators] + [" "]) + result = [s + s if s in separators else s for s in re.split(pattern, string)] + separators_found = [s for s in separators if s in string] + return result + separators_found + + +def extract_str_difference(a: str, + b: str, + without_comments: bool = True, + ignore_leading_spaces: bool = True + ) -> Dict[str, Dict[str, str | float | Set[str]]]: + """Returns a set of "words" that are different between 2 strings""" + # Remove comments if requested + a = re.match(r'^([^#]*)', a).group(1) if without_comments else a + b = re.match(r'^([^#]*)', b).group(1) if without_comments else b + + # Remove leading white spaces if requested + a = m.group(1) if ignore_leading_spaces and (m := re.match(r'^\s*?([\S\n].*)', a)) else a + b = m.group(1) if ignore_leading_spaces and (m := re.match(r'^\s*?([\S\n].*)', b)) else b + + differences: Dict[str, Dict[str, str | float | Set[str]]] = { + "a": {"changes": set(), "percent": 0.0}, + "b": {"changes": set(), "percent": 0.0}} + + raw_diff: Set[str] = set(split_on_separators(a)).symmetric_difference(set(split_on_separators(b))) + for item in raw_diff: + if item in a and item not in b: + differences['a']['changes'].add(item) + differences['a']['percent'] = differences['a']['percent'] + len(item) + else: + differences['b']['changes'].add(item) + differences['b']['percent'] = differences['b']['percent'] + len(item) + + differences['a']['percent'] = differences['a']['percent'] / (len(a.split()) + 1) * 100 + differences['b']['percent'] = differences['b']['percent'] / (len(b.split()) + 1) * 100 + return differences + + +def extract_string_differences(a: str, + b: str, + without_comments: bool = True, + ignore_leading_spaces: bool = True, + ignore_spaces: bool = False) -> Dict[str, Dict[str, str | float]]: + """Calculate the difference between two strings. Optionally removes that comments. + + :param str a: The first string to compare. + :param str b: The second string to compare. + :param bool with_comments: Optional removal of comments from the extraction. + :param bool ignore_spaces: Optional inclusion of space counting. + :returns: The differences and percentiles between the two strings in a dictionary. + :rtype: Dict + """ + # Remove comments if requested + a = re.match(r'^([^#]*)', a).group(1) if without_comments else a + b = re.match(r'^([^#]*)', b).group(1) if without_comments else b + + # Remove leading white spaces if requested + a = re.match(r'^\s*?([\S].*)$', a).group(1) if ignore_leading_spaces else a + b = re.match(r'^\s*?([\S].*)$', b).group(1) if ignore_leading_spaces else b + + # Remove white spaces if requested + a = "".join(a.split() if ignore_spaces else list(a)) + b = "".join(b.split() if ignore_spaces else list(b)) + + # Initialize the SequenceMatcher + matcher = difflib.SequenceMatcher(a=a, b=b) + + # Store the differences between the two strings + differences = { + "a": {"changes": "", "percent": 0.0}, + "b": {"changes": "", "percent": 0.0}, + "common": {"changes": "", "percent": 0.0}, + } + + # Iterate over the opcodes and update the difference counter + for tag, i1, i2, j1, j2 in matcher.get_opcodes(): + if tag == "replace": + differences["a"]["changes"] += a[i1:i2] + differences["b"]["changes"] += b[j1:j2] + elif tag == "delete": + differences["a"]["changes"] += a[i1:i2] + elif tag == "insert": + differences["b"]["changes"] += b[j1:j2] + elif tag == "equal": + differences["common"]["changes"] += a[i1:i2] + + for key, value in differences.items(): + value["percent"] = len(value["changes"]) / len(a) * 100 + + return differences + + _POSITIONAL_ATTRIBUTES = ( "lineno", "col_offset", diff --git a/tests/test_ast.py b/tests/test_ast.py index 1e5dc1d..ddde822 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -6,8 +6,9 @@ import pytest -from refactor import common +from refactor import common, Context from refactor.ast import BaseUnparser, PreciseUnparser, split_lines +from refactor.common import position_for, clone def test_split_lines(): @@ -169,6 +170,7 @@ def test_precise_unparser_indented_literals(): """\ def func(): if something: + # On change, comments are removed print( "bleh" "zoom" @@ -240,3 +242,374 @@ def foo(): base = PreciseUnparser(source=source) assert base.unparse(tree) + "\n" == expected_src + + +def test_precise_unparser_custom_indent_no_changes(): + source = """def func(): + if something: + # Arguments have custom indentation + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + # Arguments have custom indentation + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + tree = ast.parse(source) + + base = PreciseUnparser(source=source) + assert base.unparse(tree) + "\n" == expected_src + + +def test_precise_unparser_custom_indent_del(): + source = """def func(): + if something: + # Arguments have custom indentation + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + print(call(.1), maybe+something_else_that_is_very_very_very_long, thing . a) +""" + + tree = ast.parse(source) + del tree.body[0].body[0].body[0].value.args[2] + + base = PreciseUnparser(source=source) + assert base.unparse(tree) + "\n" == expected_src + + +def test_apply_source_formatting_maintains_with_await_0(): + source = """def func(): + if something: + # Comments are retrieved + print( + call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + await print( + call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0].value + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = ast.Await(node_to_replace) + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + +def test_apply_source_formatting_maintains_with_await_1(): + source = """def func(): + if something: + # Comments are retrieved + print(call(.1), # Comments are retrieved. separation can be updated to x spaces (1 default) + maybe+something_else_that_is_very_very_very_long, + maybe / other,# Comments are retrieved. separation can be updated to x spaces (1 default) + thing . a + ) +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + await print(call(.1), # Comments are retrieved. separation can be updated to x spaces (1 default) + maybe+something_else_that_is_very_very_very_long, + maybe / other, # Comments are retrieved. separation can be updated to x spaces (1 default) + thing . a + ) +""" + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0].value + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = ast.Await(node_to_replace) + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + +def test_apply_source_formatting_maintains_with_call(): + source = """def func(): + if something: + # Comments are retrieved + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + call_instead(print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + )) +""" + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0].value + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = ast.Call(func=ast.Name(id="call_instead"), args=[node_to_replace], keywords=[]) + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + +def test_apply_source_formatting_maintains_with_call_on_closing_parens(): + source = """def func(): + if something: + # Comments are retrieved + print(call(.1), # Comments are retrieved, spacing of x spaces (1 default) + maybe+something_else_that_is_very_very_very_long, + maybe / other, # Comments are retrieved, spacing of x spaces (1 default) + thing . a + ) # Non-standard indent is conserved and comments, spacing of x spaces (1 default) +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + call_instead(print(call(.1), # Comments are retrieved, spacing of x spaces (1 default) + maybe+something_else_that_is_very_very_very_long, + maybe / other, # Comments are retrieved, spacing of x spaces (1 default) + thing . a + )) # Non-standard indent is conserved and comments, spacing of x spaces (1 default) +""" + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0].value + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = ast.Call(func=ast.Name(id="call_instead"), args=[node_to_replace], keywords=[]) + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + +def test_apply_source_formatting_maintains_with_async_complex(): + source = """def func(): + if something: + # Comments are retrieved + with print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) as p: + do_something() +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + async with print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) as p: + do_something() +""" + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0] + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = clone(node_to_replace) + new_node.__class__ = ast.AsyncWith + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + +def test_apply_source_formatting_maintains_with_async(): + source = """def func(): + if something: + # Comments are retrieved + with something: # comment2, spacing of x spaces (1 default) + a = 1 # Non-standard indent is conserved, spacing of x spaces (1 default) + b = 2 # Non-standard indent is conserved, spacing of x spaces (1 default) +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + async with something: # comment2, spacing of x spaces (1 default) + a = 1 # Non-standard indent is conserved, spacing of x spaces (1 default) + b = 2 # Non-standard indent is conserved, spacing of x spaces (1 default) +""" + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0] + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = clone(node_to_replace) + new_node.__class__ = ast.AsyncWith + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + +def test_apply_source_formatting_maintains_with_fstring(): + source = '''def f(): + return """ +a +""" +''' + + expected_src = '''def f(): + return F(""" +a +""") +''' + + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].value + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + new_node = ast.Call(func=ast.Name(id="F"), args=[node_to_replace], keywords=[]) + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src + + +def test_apply_source_formatting_does_not_with_change(): + source = """def func(): + if something: + # Comments are retrieved + print(call(.1), + maybe+something_else_that_is_very_very_very_long, + maybe / other, + thing . a + ) +""" + + expected_src = """def func(): + if something: + # Comments are retrieved + await print(call(.1), maybe+something_else_that_is_very_very_very_long, thing . a) +""" + + source_tree = ast.parse(source) + context = Context(source, source_tree) + + node_to_replace = source_tree.body[0].body[0].body[0].value + + (lineno, col_offset, end_lineno, end_col_offset) = position_for(node_to_replace) + view = slice(lineno - 1, end_lineno) + + lines = split_lines(source) + source_lines = lines[view] + + del node_to_replace.args[2] + new_node = ast.Await(node_to_replace) + replacement = split_lines(context.unparse(new_node)) + replacement.apply_source_formatting( + source_lines=source_lines, + markers=(0, col_offset, end_col_offset), + ) + lines[view] = replacement + assert lines.join() == expected_src diff --git a/tests/test_core.py b/tests/test_core.py index e28f540..fcb61cf 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -50,14 +50,14 @@ def test_apply_simple(source, expected, target_func, replacement): [ ( """ - import x # comments + import x # comments will be copied since identical python statement print(x.y) # comments here def something(x, y): return x + y # comments """, """ - import x # comments - import x + import x # comments will be copied since identical python statement + import x # comments will be copied since identical python statement print(x.y) # comments here def something(x, y): return x + y # comments