Skip to content

(fix) Proposing a workaround for superfluous indentation on multi-lines #66

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
16 changes: 8 additions & 8 deletions refactor/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ def apply(self, context: Context, source: str) -> str:
view = slice(lineno - 1, end_lineno)
source_lines = lines[view]

indentation, start_prefix = find_indent(source_lines[0][:col_offset])
end_suffix = source_lines[-1][end_col_offset:]
replacement = split_lines(self._resynthesize(context))
replacement.apply_indentation(
indentation, start_prefix=start_prefix, end_suffix=end_suffix
# Applies the block indentation only if the replacement lines are different from source
replacement.apply_source_formatting(
source_lines=source_lines,
markers=(0, col_offset, end_col_offset),
)

lines[view] = replacement
Expand Down Expand Up @@ -168,12 +168,12 @@ class LazyInsertAfter(_LazyActionMixin[ast.stmt, ast.stmt]):

def apply(self, context: Context, source: str) -> str:
lines = split_lines(source, encoding=context.file_info.get_encoding())
indentation, start_prefix = find_indent(
lines[self.node.lineno - 1][: self.node.col_offset]
)

replacement = split_lines(context.unparse(self.build()))
replacement.apply_indentation(indentation, start_prefix=start_prefix)
replacement.apply_source_formatting(
source_lines=lines,
markers=(self.node.lineno - 1, self.node.col_offset, None),
)

original_node_end = cast(int, self.node.end_lineno) - 1
if lines[original_node_end].endswith(lines._newline_type):
Expand Down
31 changes: 23 additions & 8 deletions refactor/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from functools import cached_property
from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast
from typing import Any, ContextManager, Protocol, SupportsIndex, TypeVar, Union, cast, Tuple

from refactor import common
from refactor.common import find_indent

DEFAULT_ENCODING = "utf-8"

Expand All @@ -32,19 +33,33 @@ def join(self) -> str:
"""Return the combined source code."""
return "".join(map(str, self.lines))

def apply_indentation(
def apply_source_formatting(
self,
indentation: StringType,
source_lines: Lines,
*,
start_prefix: AnyStringType = "",
end_suffix: AnyStringType = "",
markers: Tuple[int, int, int | None] = None,
) -> None:
"""Apply the given indentation, optionally with start and end prefixes
to the bound source lines."""
"""Apply the indentation from source_lines when the first several characters match

:param source_lines: Original lines in source code
:param markers: Indentation and prefix parameters. Tuple of (start line, col_offset, end_suffix | None)
"""

indentation, start_prefix = find_indent(source_lines[markers[0]][:markers[1]])
end_suffix = "" if markers[2] is None else source_lines[-1][markers[2]:]

original_line: str | None
for index, line in enumerate(self.data):
if index < len(source_lines):
original_line = source_lines[index]
else:
original_line = None

if index == 0:
self.data[index] = indentation + str(start_prefix) + str(line) # type: ignore
# The updated line can have an extra wrapping in brackets
elif original_line is not None and original_line.startswith(line[:-1]):
self.data[index] = line # type: ignore
else:
self.data[index] = indentation + line # type: ignore

Expand Down Expand Up @@ -77,7 +92,7 @@ def __getitem__(self, index: SupportsIndex | slice) -> SourceSegment:
# re-implements the direct indexing as slicing (e.g. a[1] is a[1:2], with
# error handling).
direct_index = operator.index(index)
view = raw_line[direct_index : direct_index + 1].decode(
view = raw_line[direct_index: direct_index + 1].decode(
encoding=self.encoding
)
if not view:
Expand Down
Loading