def get_subsequent_assign_lines(cls: type) -> Set[int]:
    """Return the continuation line numbers of multiline class-level assignments.

    The class source is wrapped in a synthetic ``if True:`` block (with every source line
    indented by one space) so that a class defined at any indentation level, e.g. nested
    inside a function, still parses. The wrapper adds exactly one line in front of the
    source, so a line number reported by ``ast`` is one greater than the 1-indexed line
    number within the class source itself.

    :param cls: The class whose source code is retrieved with ``inspect.getsource``.
    :return: The 1-indexed line numbers, relative to the class source, of every line of an
             ``Assign``/``AnnAssign`` statement in the class body after its first line.
             An empty set is returned (and a warning is emitted) when the source cannot be
             parsed into exactly one class definition.
    :raises OSError: Propagated from ``inspect.getsource`` if the source cannot be retrieved.
    :raises TypeError: Propagated from ``inspect.getsource`` for objects without source.
    """
    parse_warning = (
        "Could not parse class source code to extract comments. "
        "Comments in the help string may be incorrect."
    )

    # Get source code of class
    source = inspect.getsource(cls)

    # Parse source code using ast (with an if statement to avoid indentation errors)
    wrapped_source = f"if True:\n{textwrap.indent(source, ' ')}"
    try:
        wrapper = ast.parse(wrapped_source).body[0]
    except SyntaxError:
        # Best effort: source that cannot be parsed only degrades the help string
        warnings.warn(parse_warning)
        return set()

    # The only accepted shape is: if statement -> exactly one statement -> class definition
    class_node = None
    if isinstance(wrapper, ast.If) and len(wrapper.body) == 1:
        class_node = wrapper.body[0]

    if not isinstance(class_node, ast.ClassDef):
        warnings.warn(parse_warning)
        return set()

    # The +1 offset from the "if True:" line turns [lineno, end_lineno) in the wrapped
    # source into (first line, last line] of the statement in the class source,
    # i.e. every line of the assignment except the first
    subsequent_lines: Set[int] = set()
    for node in class_node.body:
        if isinstance(node, (ast.Assign, ast.AnnAssign)):
            subsequent_lines.update(range(node.lineno, node.end_lineno))

    return subsequent_lines
in subsequent_assign_lines: + continue + for i, token in enumerate(tokens): # Skip whitespace if token["token"].strip() == "": continue @@ -244,8 +298,21 @@ def get_class_variables(cls: type) -> Dict[str, Dict[str, str]]: and token["token"][:1] in {'"', "'"} ): sep = " " if variable_to_comment[class_variable]["comment"] else "" + + # Identify the quote character (single or double) quote_char = token["token"][:1] - variable_to_comment[class_variable]["comment"] += sep + token["token"].strip(quote_char).strip() + + # Identify the number of quote characters at the start of the string + num_quote_chars = len(token["token"]) - len(token["token"].lstrip(quote_char)) + + # Remove the number of quote characters at the start of the string and the end of the string + token["token"] = token["token"][num_quote_chars:-num_quote_chars] + + # Remove the unicode escape sequences (e.g. "\"") + token["token"] = bytes(token["token"], encoding='ascii').decode('unicode-escape') + + # Add the token to the comment, stripping whitespace + variable_to_comment[class_variable]["comment"] += sep + token["token"].strip() # Match class variable class_variable = None diff --git a/tests/test_utils.py b/tests/test_utils.py index a3fabcf..33729ca 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -300,6 +300,30 @@ class TripleQuoteMultiline: class_variables = {"bar": {"comment": "biz baz"}, "hi": {"comment": "Hello there"}} self.assertEqual(get_class_variables(TripleQuoteMultiline), class_variables) + def test_comments_with_quotes(self): + class MultiquoteMultiline: + bar: int = 0 + '\'\'biz baz\'' + + hi: str + "\"Hello there\"\"" + + class_variables = {} + class_variables['bar'] = {'comment': "''biz baz'"} + class_variables['hi'] = {'comment': '"Hello there""'} + self.assertEqual(get_class_variables(MultiquoteMultiline), class_variables) + + def test_multiline_argument(self): + class MultilineArgument: + bar: str = ( + "This is a multiline argument" + " that should not be included in 
the docstring" + ) + """biz baz""" + + class_variables = {"bar": {"comment": "biz baz"}} + self.assertEqual(get_class_variables(MultilineArgument), class_variables) + def test_single_quote_multiline(self): class SingleQuoteMultiline: bar: int = 0