Skip to content

Commit

Permalink
Fix AST safety check false negative (psf#4270)
Browse files Browse the repository at this point in the history
Fixes psf#4268

Previously we would allow whitespace changes in all strings; now we
allow them only in docstrings.

Co-authored-by: Shantanu <12621235+hauntsaninja@users.noreply.github.com>
  • Loading branch information
2 people authored and sumezulike committed Mar 10, 2024
1 parent 647dd2c commit bf9469f
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 27 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Expand Up @@ -11,6 +11,10 @@
<!-- Changes that affect Black's stable style -->

- Don't move comments along with delimiters, which could cause crashes (#4248)
- Strengthen AST safety check to catch more unsafe changes to strings. Previous versions
of Black would incorrectly format the contents of certain unusual f-strings containing
nested strings with the same quote type. Now, Black will crash on such strings until
support for the new f-string syntax is implemented. (#4270)

### Preview style

Expand Down
15 changes: 10 additions & 5 deletions src/black/__init__.py
Expand Up @@ -77,8 +77,13 @@
syms,
)
from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
from black.parsing import InvalidInput # noqa F401
from black.parsing import lib2to3_parse, parse_ast, stringify_ast
from black.parsing import ( # noqa F401
ASTSafetyError,
InvalidInput,
lib2to3_parse,
parse_ast,
stringify_ast,
)
from black.ranges import adjusted_lines, convert_unchanged_lines, parse_line_ranges
from black.report import Changed, NothingChanged, Report
from black.trans import iter_fexpr_spans
Expand Down Expand Up @@ -1511,7 +1516,7 @@ def assert_equivalent(src: str, dst: str) -> None:
try:
src_ast = parse_ast(src)
except Exception as exc:
raise AssertionError(
raise ASTSafetyError(
"cannot use --safe with this file; failed to parse source file AST: "
f"{exc}\n"
"This could be caused by running Black with an older Python version "
Expand All @@ -1522,7 +1527,7 @@ def assert_equivalent(src: str, dst: str) -> None:
dst_ast = parse_ast(dst)
except Exception as exc:
log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
raise AssertionError(
raise ASTSafetyError(
f"INTERNAL ERROR: Black produced invalid code: {exc}. "
"Please report a bug on https://github.com/psf/black/issues. "
f"This invalid output might be helpful: {log}"
Expand All @@ -1532,7 +1537,7 @@ def assert_equivalent(src: str, dst: str) -> None:
dst_ast_str = "\n".join(stringify_ast(dst_ast))
if src_ast_str != dst_ast_str:
log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
raise AssertionError(
raise ASTSafetyError(
"INTERNAL ERROR: Black produced code that is not equivalent to the"
" source. Please report a bug on "
f"https://github.com/psf/black/issues. This diff might be helpful: {log}"
Expand Down
42 changes: 34 additions & 8 deletions src/black/parsing.py
Expand Up @@ -110,6 +110,10 @@ def lib2to3_unparse(node: Node) -> str:
return code


class ASTSafetyError(Exception):
    """Signal that Black's AST safety check did not pass.

    Raised by the equivalence check when the source or the formatted output
    cannot be parsed into an AST, or when the two resulting ASTs differ.
    """


def _parse_single_version(
src: str, version: Tuple[int, int], *, type_comments: bool
) -> ast.AST:
Expand Down Expand Up @@ -154,9 +158,20 @@ def _normalize(lineend: str, value: str) -> str:
return normalized.strip()


def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
def stringify_ast(node: ast.AST) -> Iterator[str]:
    """Produce a line-by-line textual rendering of *node* for AST comparison.

    This is the public entry point: it starts the recursive walk at *node*
    with an empty parent stack, so the root is rendered with no indentation.
    """
    no_parents: List[ast.AST] = []
    return _stringify_ast(node, no_parents)


def _stringify_ast_with_new_parent(
    node: ast.AST, parent_stack: List[ast.AST], new_parent: ast.AST
) -> Iterator[str]:
    """Yield the rendering of *node* as a child of *new_parent*.

    *new_parent* is pushed onto the shared *parent_stack* for the duration of
    the walk over *node* and removed again afterwards, so the caller's stack
    is left exactly as it was found.
    """
    parent_stack.append(new_parent)
    try:
        yield from _stringify_ast(node, parent_stack)
    finally:
        # Pop even if the walk raises or the consumer closes the generator
        # early; otherwise the shared stack would keep a stale parent.
        parent_stack.pop()


def _stringify_ast(node: ast.AST, parent_stack: List[ast.AST]) -> Iterator[str]:
if (
isinstance(node, ast.Constant)
and isinstance(node.value, str)
Expand All @@ -167,7 +182,7 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
# over the kind
node.kind = None

yield f"{' ' * depth}{node.__class__.__name__}("
yield f"{' ' * len(parent_stack)}{node.__class__.__name__}("

for field in sorted(node._fields): # noqa: F402
# TypeIgnore has only one field 'lineno' which breaks this comparison
Expand All @@ -179,7 +194,7 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
except AttributeError:
continue

yield f"{' ' * (depth + 1)}{field}="
yield f"{' ' * (len(parent_stack) + 1)}{field}="

if isinstance(value, list):
for item in value:
Expand All @@ -191,20 +206,28 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
and isinstance(item, ast.Tuple)
):
for elt in item.elts:
yield from stringify_ast(elt, depth + 2)
yield from _stringify_ast_with_new_parent(
elt, parent_stack, node
)

elif isinstance(item, ast.AST):
yield from stringify_ast(item, depth + 2)
yield from _stringify_ast_with_new_parent(item, parent_stack, node)

elif isinstance(value, ast.AST):
yield from stringify_ast(value, depth + 2)
yield from _stringify_ast_with_new_parent(value, parent_stack, node)

else:
normalized: object
if (
isinstance(node, ast.Constant)
and field == "value"
and isinstance(value, str)
and len(parent_stack) >= 2
and isinstance(parent_stack[-1], ast.Expr)
and isinstance(
parent_stack[-2],
(ast.FunctionDef, ast.AsyncFunctionDef, ast.Module, ast.ClassDef),
)
):
# Constant strings may be indented across newlines, if they are
# docstrings; fold spaces after newlines when comparing. Similarly,
Expand All @@ -215,6 +238,9 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
normalized = value.rstrip()
else:
normalized = value
yield f"{' ' * (depth + 2)}{normalized!r}, # {value.__class__.__name__}"
yield (
f"{' ' * (len(parent_stack) + 1)}{normalized!r}, #"
f" {value.__class__.__name__}"
)

yield f"{' ' * depth}) # /{node.__class__.__name__}"
yield f"{' ' * len(parent_stack)}) # /{node.__class__.__name__}"
122 changes: 108 additions & 14 deletions tests/test_black.py
Expand Up @@ -46,6 +46,7 @@
from black.debug import DebugVisitor
from black.mode import Mode, Preview
from black.output import color_diff, diff
from black.parsing import ASTSafetyError
from black.report import Report

# Import other test classes
Expand Down Expand Up @@ -1473,10 +1474,6 @@ def test_normalize_line_endings(self) -> None:
ff(test_file, write_back=black.WriteBack.YES)
self.assertEqual(test_file.read_bytes(), expected)

def test_assert_equivalent_different_asts(self) -> None:
with self.assertRaises(AssertionError):
black.assert_equivalent("{}", "None")

def test_root_logger_not_used_directly(self) -> None:
def fail(*args: Any, **kwargs: Any) -> None:
self.fail("Record created with root logger")
Expand Down Expand Up @@ -1962,16 +1959,6 @@ def test_for_handled_unexpected_eof_error(self) -> None:

exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")

def test_equivalency_ast_parse_failure_includes_error(self) -> None:
with pytest.raises(AssertionError) as err:
black.assert_equivalent("a«»a = 1", "a«»a = 1")

err.match("--safe")
# Unfortunately the SyntaxError message has changed in newer versions so we
# can't match it directly.
err.match("invalid character")
err.match(r"\(<unknown>, line 1\)")

def test_line_ranges_with_code_option(self) -> None:
code = textwrap.dedent("""\
if a == b:
Expand Down Expand Up @@ -2822,6 +2809,113 @@ def test_format_file_contents(self) -> None:
black.format_file_contents("x = 1\n", fast=True, mode=black.Mode())


# Tests for black.assert_equivalent(), the AST safety check that compares the
# AST of the source with the AST of the formatted output.
class TestASTSafety(BlackBaseTestCase):
# Helper: dedent both snippets, confirm each one parses on its own, then
# check that assert_equivalent() either accepts the pair or, when
# should_fail is true, raises ASTSafetyError.
def check_ast_equivalence(
self, source: str, dest: str, *, should_fail: bool = False
) -> None:
# If we get a failure, make sure it's not because the code itself
# is invalid, since that will also cause assert_equivalent() to throw
# ASTSafetyError.
source = textwrap.dedent(source)
dest = textwrap.dedent(dest)
black.parse_ast(source)
black.parse_ast(dest)
if should_fail:
with self.assertRaises(ASTSafetyError):
black.assert_equivalent(source, dest)
else:
black.assert_equivalent(source, dest)

# Different expressions must be rejected; changes to operator spacing and
# the removal of a comment must be accepted.
def test_assert_equivalent_basic(self) -> None:
self.check_ast_equivalence("{}", "None", should_fail=True)
self.check_ast_equivalence("1+2", "1 + 2")
self.check_ast_equivalence("hi # comment", "hi")

# Dropping the parentheses around the targets of a del statement must be
# accepted.
def test_assert_equivalent_del(self) -> None:
self.check_ast_equivalence("del (a, b)", "del a, b")

# Whitespace changes inside an ordinary string must be rejected, while
# whitespace changes at the edges of module, class, function and async
# function docstrings must be accepted. Changing the docstring text itself
# must still be rejected.
def test_assert_equivalent_strings(self) -> None:
self.check_ast_equivalence('x = "x"', 'x = " x "', should_fail=True)
self.check_ast_equivalence(
'''
"""docstring """
''',
'''
"""docstring"""
''',
)
self.check_ast_equivalence(
'''
"""docstring """
''',
'''
"""ddocstring"""
''',
should_fail=True,
)
self.check_ast_equivalence(
'''
class A:
"""
docstring
"""
''',
'''
class A:
"""docstring"""
''',
)
self.check_ast_equivalence(
"""
def f():
" docstring "
""",
'''
def f():
"""docstring"""
''',
)
self.check_ast_equivalence(
"""
async def f():
" docstring "
""",
'''
async def f():
"""docstring"""
''',
)

# A changed string nested inside an f-string replacement field must be
# rejected. The inputs reuse the enclosing quote type inside the f-string,
# so the test is skipped on interpreters older than 3.12.
def test_assert_equivalent_fstring(self) -> None:
major, minor = sys.version_info[:2]
if major < 3 or (major == 3 and minor < 12):
pytest.skip("relies on 3.12+ syntax")
# https://github.com/psf/black/issues/4268
self.check_ast_equivalence(
"""print(f"{"|".join([a,b,c])}")""",
"""print(f"{" | ".join([a,b,c])}")""",
should_fail=True,
)
self.check_ast_equivalence(
"""print(f"{"|".join(['a','b','c'])}")""",
"""print(f"{" | ".join(['a','b','c'])}")""",
should_fail=True,
)

# When the source cannot be parsed, the raised ASTSafetyError must mention
# --safe and include the underlying parser error and its location.
def test_equivalency_ast_parse_failure_includes_error(self) -> None:
with pytest.raises(ASTSafetyError) as err:
black.assert_equivalent("a«»a = 1", "a«»a = 1")

err.match("--safe")
# Unfortunately the SyntaxError message has changed in newer versions so we
# can't match it directly.
err.match("invalid character")
err.match(r"\(<unknown>, line 1\)")


try:
with open(black.__file__, "r", encoding="utf-8") as _bf:
black_source_lines = _bf.readlines()
Expand Down

0 comments on commit bf9469f

Please sign in to comment.