Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
687 changes: 677 additions & 10 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pybetter/__main__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""App's entry point."""

import sys

from pybetter.cli import main
Expand Down
69 changes: 38 additions & 31 deletions pybetter/cli.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
"""Command line interface definition."""

import sys
import time
from typing import List, FrozenSet, Tuple, Type, Iterable
from typing import FrozenSet, Iterable, List, Tuple, Type

import libcst as cst
import click
import libcst as cst
from pyemojify import emojify

from pybetter.improvements import (
FixNotInConditionOrder,
BaseImprovement,
FixBooleanEqualityChecks,
FixEqualsNone,
FixMissingAllAttribute,
FixMutableDefaultArgs,
FixNotInConditionOrder,
FixParenthesesInReturn,
FixMissingAllAttribute,
FixEqualsNone,
FixBooleanEqualityChecks,
FixTrivialFmtStringCreation,
FixTrivialNestedWiths,
FixUnhashableList,
)
from pybetter.utils import resolve_paths, create_diff, prettify_time_interval
from pybetter.utils import create_diff, prettify_time_interval, resolve_paths

ALL_IMPROVEMENTS = (
FixNotInConditionOrder,
Expand All @@ -32,31 +34,41 @@
FixUnhashableList,
)

ALL_CODES = frozenset(improvement.CODE for improvement in ALL_IMPROVEMENTS)


def filter_improvements_by_code(code_list: str) -> FrozenSet[str]:
all_codes = frozenset([improvement.CODE for improvement in ALL_IMPROVEMENTS])
codes = frozenset([code.strip() for code in code_list.split(",")]) - {""}
codes = frozenset(
code.strip() for code in code_list.split(",")
) - {""}

if not codes:
return frozenset()

wrong_codes = codes.difference(all_codes)
wrong_codes = ','.join(codes - ALL_CODES)
if wrong_codes:
print(
emojify(
f":no_entry_sign: Unknown improvements selected: {','.join(wrong_codes)}"
)
)
print(emojify(
f":no_entry_sign: Unknown improvements selected: {wrong_codes}",
))
return frozenset()

return codes


def process_file(
source: str, improvements: Iterable[Type[BaseImprovement]]
source: str, improvements: Iterable[Type[BaseImprovement]],
) -> Tuple[str, List[BaseImprovement]]:
tree: cst.Module = cst.parse_module(source)
modified_tree: cst.Module = tree
"""
Apply some improvements to the source file.

Arguments:
source: Python source
improvements: list of improvements to apply

Returns:
A tuple of processed code and list of applied improvements.
"""
modified_tree: cst.Module = cst.parse_module(source)
improvements_applied = []

for case_cls in improvements:
Expand All @@ -70,12 +82,7 @@ def process_file(
return modified_tree.code, improvements_applied


@click.group()
def cli():
pass


@cli.command()
@click.command()
@click.option(
"--noop",
is_flag=True,
Expand Down Expand Up @@ -105,18 +112,18 @@ def cli():
)
@click.argument("paths", type=click.Path(), nargs=-1)
def main(paths, noop: bool, show_diff: bool, selected: str, excluded: str):
"""Make your code better."""
if not paths:
print(emojify("Nothing to do. :sleeping:"))
return

selected_improvements = list(ALL_IMPROVEMENTS)

if selected and excluded:
print(
emojify(
":no_entry_sign: '--select' and '--exclude' options are mutually exclusive!"
)
)
print(emojify(
":no_entry_sign: '--select' and '--exclude' options"
" are mutually exclusive!",
))
return

if selected:
Expand Down Expand Up @@ -149,7 +156,7 @@ def main(paths, noop: bool, show_diff: bool, selected: str, excluded: str):

start_ts = time.process_time()
processed_source, applied = process_file(
original_source, selected_improvements
original_source, selected_improvements,
)
end_ts = time.process_time()

Expand Down Expand Up @@ -189,4 +196,4 @@ def main(paths, noop: bool, show_diff: bool, selected: str, excluded: str):
print(emojify(f":sparkles: All done! :sparkles: :clock2: {time_taken}"))


__all__ = ["main", "process_file"]
__all__ = ("main", "process_file")
31 changes: 25 additions & 6 deletions pybetter/improvements.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,49 @@
from typing_extensions import Type

from pybetter.transformers.all_attribute import AllAttributeTransformer
from pybetter.transformers.base import NoqaDetectionVisitor, NoqaAwareTransformer
from pybetter.transformers.boolean_equality import BooleanLiteralEqualityTransformer
from pybetter.transformers.base import (
NoqaAwareTransformer,
NoqaDetectionVisitor,
)
from pybetter.transformers.boolean_equality import (
BooleanLiteralEqualityTransformer,
)
from pybetter.transformers.empty_fstring import TrivialFmtStringTransformer
from pybetter.transformers.equals_none import EqualsNoneIsNoneTransformer
from pybetter.transformers.mutable_args import ArgEmptyInitTransformer
from pybetter.transformers.nested_withs import NestedWithTransformer
from pybetter.transformers.not_in import NotInConditionTransformer
from pybetter.transformers.parenthesized_return import RemoveParenthesesFromReturn
from pybetter.transformers.parenthesized_return import (
RemoveParenthesesFromReturn,
)
from pybetter.transformers.unhashable_list import UnhashableListTransformer


class BaseImprovement(ABC):
"""Base class for improvements."""

CODE: str
NAME: str
DESCRIPTION: str
TRANSFORMER: Type[NoqaAwareTransformer]

def improve(self, tree: cst.Module):
"""Apply improvement to the syntax tree.

Arguments:
tree: syntax tree

Returns:
None
"""
noqa_detector = NoqaDetectionVisitor()
wrapper = MetadataWrapper(tree)

with noqa_detector.resolve(wrapper):
wrapper.visit(noqa_detector)
transformer = self.TRANSFORMER(self.CODE, noqa_detector.get_noqa_lines())
transformer = self.TRANSFORMER(
self.CODE, noqa_detector.get_noqa_lines(),
)
return wrapper.visit(transformer)


Expand Down Expand Up @@ -95,7 +114,7 @@ class FixUnhashableList(BaseImprovement):
TRANSFORMER = UnhashableListTransformer


__all__ = [
__all__ = (
"BaseImprovement",
"FixBooleanEqualityChecks",
"FixEqualsNone",
Expand All @@ -106,4 +125,4 @@ class FixUnhashableList(BaseImprovement):
"FixTrivialFmtStringCreation",
"FixTrivialNestedWiths",
"FixUnhashableList",
]
)
21 changes: 11 additions & 10 deletions pybetter/transformers/all_attribute.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional, Union

import libcst as cst
from libcst.metadata import ScopeProvider, GlobalScope
import libcst.matchers as m
import libcst.matchers as m # noqa: WPS111,WPS301,WPS347
from libcst.metadata import GlobalScope, ScopeProvider

from pybetter.transformers.base import NoqaAwareTransformer

Expand All @@ -17,7 +17,9 @@ def __init__(self, *args, **kwargs):

def process_node(
self,
node: Union[cst.FunctionDef, cst.ClassDef, cst.BaseAssignTargetExpression],
node: Union[
cst.FunctionDef, cst.ClassDef, cst.BaseAssignTargetExpression,
],
name: str,
) -> None:
scope = self.get_metadata(ScopeProvider, node)
Expand All @@ -31,27 +33,26 @@ def visit_AssignTarget(self, node: cst.AssignTarget) -> Optional[bool]:
self.already_exists = True
else:
self.process_node(node.target, target.value)
return None

def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
self.process_node(node, node.name.value)
return None

def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
self.process_node(node, node.name.value)
return None

def leave_Module(
self, original_node: cst.Module, updated_node: cst.Module
self, original_node: cst.Module, updated_node: cst.Module,
) -> cst.Module:
if not self.names or self.already_exists:
return original_node

modified_body = list(original_node.body)
config = original_node.config_for_parsing

list_of_names = f",{config.default_newline}{config.default_indent}".join(
[repr(name) for name in sorted(self.names)]
list_of_names = (
f",{config.default_newline}{config.default_indent}".join(
repr(name) for name in sorted(self.names)
)
)

all_names = cst.parse_statement(
Expand All @@ -68,4 +69,4 @@ def leave_Module(
return updated_node.with_changes(body=modified_body)


__all__ = ["AllAttributeTransformer"]
__all__ = ("AllAttributeTransformer",)
40 changes: 25 additions & 15 deletions pybetter/transformers/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import re
from abc import ABCMeta
from typing import Dict, Optional, FrozenSet
from typing import Dict, FrozenSet, Optional

import libcst as cst
from libcst.matchers import MatcherDecoratableTransformer
from libcst.metadata import PositionProvider

NOQA_MARKUP_REGEX = re.compile(r"noqa(?:: ((?:B[0-9]{3},)+(?:B[0-9]{3})|(B[0-9]{3})))?")
NOQA_MARKUP_REGEX = re.compile(
r"noqa(?:: ((?:B[0-9]{3},)+(?:B[0-9]{3})|(B[0-9]{3})))?",
)
NOQA_CATCHALL: str = "B999"

NoqaLineMapping = Dict[int, FrozenSet[str]]
Expand All @@ -21,14 +23,15 @@ def __init__(self):
super().__init__()

def visit_Comment(self, node: cst.Comment) -> Optional[bool]:
m = re.search(NOQA_MARKUP_REGEX, node.value)
if m:
codes = m.group(1)
position: cst.metadata.CodeRange = self.get_metadata(PositionProvider, node)
if codes:
self._line_to_code[position.start.line] = frozenset(codes.split(","))
else:
self._line_to_code[position.start.line] = frozenset({NOQA_CATCHALL})
match = re.search(NOQA_MARKUP_REGEX, node.value)
if match:
codes = match.group(1)
position: cst.metadata.CodeRange = self.get_metadata(
PositionProvider, node,
)
self._line_to_code[position.start.line] = frozenset(
codes.split(",") if codes else {NOQA_CATCHALL},
)

return True

Expand All @@ -46,7 +49,8 @@ def __new__(cls, name, bases, attrs):


class NoqaAwareTransformer(
MatcherDecoratableTransformer, metaclass=PositionProviderEnsuranceMetaclass
MatcherDecoratableTransformer,
metaclass=PositionProviderEnsuranceMetaclass,
):
METADATA_DEPENDENCIES = (PositionProvider,) # type: ignore

Expand All @@ -56,15 +60,21 @@ def __init__(self, code: str, noqa_lines: NoqaLineMapping):
super().__init__()

def on_visit(self, node: cst.CSTNode):
position: cst.metadata.CodeRange = self.get_metadata(PositionProvider, node)
position: cst.metadata.CodeRange = self.get_metadata(
PositionProvider, node,
)
applicable_noqa: FrozenSet[str] = self.noqa_lines.get(
position.start.line, frozenset()
position.start.line, frozenset(),
)

if self.check_code in applicable_noqa or NOQA_CATCHALL in applicable_noqa:
if ( # noqa: WPS337
self.check_code in applicable_noqa
) or (
NOQA_CATCHALL in applicable_noqa
):
return False

return super().on_visit(node)


__all__ = ["NoqaAwareTransformer", "NoqaDetectionVisitor", "NoqaLineMapping"]
__all__ = ("NoqaAwareTransformer", "NoqaDetectionVisitor", "NoqaLineMapping")
Loading