diff --git a/refactor/actions.py b/refactor/actions.py index 13e8d76..756577e 100644 --- a/refactor/actions.py +++ b/refactor/actions.py @@ -78,13 +78,17 @@ def _replace_input(self, node: ast.AST) -> _LazyActionMixin[K, T]: class _ReplaceCodeSegmentAction(BaseAction): def apply(self, context: Context, source: str) -> str: + # The decorators are removed in the 'lines' but present in the 'context` + # This lead to the 'replacement' containing the decorators and the returned + # 'lines' to duplicate them. Proposed workaround is to add the decorators in + # the 'view', in case the '_resynthesize()' adds/modifies them lines = split_lines(source, encoding=context.file_info.get_encoding()) ( lineno, col_offset, end_lineno, end_col_offset, - ) = self._get_segment_span(context) + ) = self._get_decorated_segment_span(context) view = slice(lineno - 1, end_lineno) source_lines = lines[view] @@ -102,6 +106,9 @@ def apply(self, context: Context, source: str) -> str: def _get_segment_span(self, context: Context) -> PositionType: raise NotImplementedError + def _get_decorated_segment_span(self, context: Context) -> PositionType: + raise NotImplementedError + def _resynthesize(self, context: Context) -> str: raise NotImplementedError @@ -121,6 +128,13 @@ class LazyReplace(_ReplaceCodeSegmentAction, _LazyActionMixin[ast.AST, ast.AST]) def _get_segment_span(self, context: Context) -> PositionType: return position_for(self.node) + def _get_decorated_segment_span(self, context: Context) -> PositionType: + lineno, col_offset, end_lineno, end_col_offset = position_for(self.node) + # Add the decorators to the segment span to resolve an issue with def -> async def + if hasattr(self.node, "decorator_list") and len(getattr(self.node, "decorator_list")) > 0: + lineno, _, _, _ = position_for(getattr(self.node, "decorator_list")[0]) + return lineno, col_offset, end_lineno, end_col_offset + def _resynthesize(self, context: Context) -> str: return context.unparse(self.build()) @@ -228,6 +242,9 @@ class _Rename(Replace): def _get_segment_span(self, context: Context) -> PositionType: return self.identifier_span + def _get_decorated_segment_span(self, context: Context) -> PositionType: + return self.identifier_span + def _resynthesize(self, context: Context) -> str: return self.target.name @@ -260,6 +277,13 @@ def is_critical_node(self, context: Context) -> bool: def _get_segment_span(self, context: Context) -> PositionType: return position_for(self.node) + def _get_decorated_segment_span(self, context: Context) -> PositionType: + lineno, col_offset, end_lineno, end_col_offset = position_for(self.node) + # Add the decorators to the segment span to resolve an issue with def -> async def + if hasattr(self.node, "decorator_list") and len(getattr(self.node, "decorator_list")) > 0: + lineno, _, _, _ = position_for(getattr(self.node, "decorator_list")[0]) + return lineno, col_offset, end_lineno, end_col_offset + def _resynthesize(self, context: Context) -> str: if self.is_critical_node(context): raise InvalidActionError( diff --git a/tests/test_common.py b/tests/test_common.py index 5ff4b9d..49b5085 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -137,6 +137,24 @@ def func(): assert position_for(right_node) == (3, 23, 3, 25) +def test_get_positions_with_decorator(): + source = textwrap.dedent( + """\ + @deco0 + @deco1(arg0, + arg1) + def func(): + if a > 5: + return 5 + 3 + 25 + elif b > 10: + return 1 + 3 + 5 + 7 + """ + ) + tree = ast.parse(source) + right_node = tree.body[0].body[0].body[0].value.right + assert position_for(right_node) == (6, 23, 6, 25) + + def test_singleton(): from dataclasses import dataclass diff --git a/tests/test_complete_rules.py b/tests/test_complete_rules.py index b46c14e..c8a9296 100644 --- a/tests/test_complete_rules.py +++ b/tests/test_complete_rules.py @@ -296,6 +296,107 @@ def match(self, node): return AsyncifierAction(node) +class MakeFunctionAsyncWithDecorators(Rule): + INPUT_SOURCE = """ + @deco0 + @deco1(arg0, + arg1) + def something(): + a += .1 + '''you know + this is custom + literal + ''' + print(we, + preserve, + everything + ) + return ( + right + "?") + """ + + EXPECTED_SOURCE = """ + @deco0 + @deco1(arg0, + arg1) + async def something(): + a += .1 + '''you know + this is custom + literal + ''' + print(we, + preserve, + everything + ) + return ( + right + "?") + """ + + def match(self, node): + assert isinstance(node, ast.FunctionDef) + return AsyncifierAction(node) + + +class OnlyKeywordArgumentDefaultNotSetCheckRule(Rule): + context_providers = (context.Scope,) + + INPUT_SOURCE = """ + class Klass: + def method(self, *, a): + print() + + lambda self, *, a: print + + """ + + EXPECTED_SOURCE = """ + class Klass: + def method(self, *, a=None): + print() + + lambda self, *, a=None: print + + """ + + def match(self, node: ast.AST) -> BaseAction | None: + assert isinstance(node, (ast.FunctionDef, ast.Lambda)) + assert any(kw_default is None for kw_default in node.args.kw_defaults) + + if isinstance(node, ast.Lambda) and not ( + isinstance(node.body, ast.Name) and isinstance(node.body.ctx, ast.Load) + ): + scope = self.context["scope"].resolve(node.body) + scope.definitions.get(node.body.id, []) + + elif isinstance(node, ast.FunctionDef): + for stmt in node.body: + for identifier in ast.walk(stmt): + if not ( + isinstance(identifier, ast.Name) + and isinstance(identifier.ctx, ast.Load) + ): + continue + + scope = self.context["scope"].resolve(identifier) + while not scope.definitions.get(identifier.id, []): + scope = scope.parent + if scope is None: + break + + kw_defaults = [] + for kw_default in node.args.kw_defaults: + if kw_default is None: + kw_defaults.append(ast.Constant(value=None)) + else: + kw_defaults.append(kw_default) + + target = deepcopy(node) + target.args.kw_defaults = kw_defaults + + return Replace(node, target) + + class OnlyKeywordArgumentDefaultNotSetCheckRule(Rule): context_providers = (context.Scope,) @@ -944,6 +1045,7 @@ def match(self, node: ast.AST) -> Iterator[Replace | InsertAfter]: @pytest.mark.parametrize( "rule", [ + MakeFunctionAsyncWithDecorators, ReplaceNexts, ReplacePlaceholders, PropagateConstants,