Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix AST safety check false negative #4270

Merged
merged 8 commits into from Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES.md
Expand Up @@ -11,6 +11,10 @@
<!-- Changes that affect Black's stable style -->

- Don't move comments along with delimiters, which could cause crashes (#4248)
- Strengthen AST safety check to catch more unsafe changes to strings. Previous versions
of Black would incorrectly format the contents of certain unusual f-strings containing
nested strings with the same quote type. Now, Black will crash on such strings until
support for the new f-string syntax is implemented. (#4270)

### Preview style

Expand Down
15 changes: 10 additions & 5 deletions src/black/__init__.py
Expand Up @@ -77,8 +77,13 @@
syms,
)
from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
from black.parsing import InvalidInput # noqa F401
from black.parsing import lib2to3_parse, parse_ast, stringify_ast
from black.parsing import ( # noqa F401
ASTSafetyError,
InvalidInput,
lib2to3_parse,
parse_ast,
stringify_ast,
)
from black.ranges import adjusted_lines, convert_unchanged_lines, parse_line_ranges
from black.report import Changed, NothingChanged, Report
from black.trans import iter_fexpr_spans
Expand Down Expand Up @@ -1511,7 +1516,7 @@ def assert_equivalent(src: str, dst: str) -> None:
try:
src_ast = parse_ast(src)
except Exception as exc:
raise AssertionError(
raise ASTSafetyError(
"cannot use --safe with this file; failed to parse source file AST: "
f"{exc}\n"
"This could be caused by running Black with an older Python version "
Expand All @@ -1522,7 +1527,7 @@ def assert_equivalent(src: str, dst: str) -> None:
dst_ast = parse_ast(dst)
except Exception as exc:
log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst)
raise AssertionError(
raise ASTSafetyError(
f"INTERNAL ERROR: Black produced invalid code: {exc}. "
"Please report a bug on https://github.com/psf/black/issues. "
f"This invalid output might be helpful: {log}"
Expand All @@ -1532,7 +1537,7 @@ def assert_equivalent(src: str, dst: str) -> None:
dst_ast_str = "\n".join(stringify_ast(dst_ast))
if src_ast_str != dst_ast_str:
log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst"))
raise AssertionError(
raise ASTSafetyError(
"INTERNAL ERROR: Black produced code that is not equivalent to the"
" source. Please report a bug on "
f"https://github.com/psf/black/issues. This diff might be helpful: {log}"
Expand Down
42 changes: 34 additions & 8 deletions src/black/parsing.py
Expand Up @@ -110,6 +110,10 @@ def lib2to3_unparse(node: Node) -> str:
return code


class ASTSafetyError(Exception):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably inherit from AssertionError since you are replacing raise AssertionError with this new exception.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't provide compatibility guarantees for this function. I don't think inheriting from AssertionError makes semantic sense.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the response. That makes sense considering that many people probably don't use this functionality directly and I maybe agree inheriting from AssertionError is odd from a semantic/purity standpoint, even if it's less practical for compatibility (which as you mention is a non-goal here).

"""Raised when Black's generated code is not equivalent to the old AST."""


def _parse_single_version(
src: str, version: Tuple[int, int], *, type_comments: bool
) -> ast.AST:
Expand Down Expand Up @@ -154,9 +158,20 @@ def _normalize(lineend: str, value: str) -> str:
return normalized.strip()


def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
def stringify_ast(node: ast.AST) -> Iterator[str]:
"""Simple visitor generating strings to compare ASTs by content."""
return _stringify_ast(node, [])


def _stringify_ast_with_new_parent(
node: ast.AST, parent_stack: List[ast.AST], new_parent: ast.AST
) -> Iterator[str]:
parent_stack.append(new_parent)
yield from _stringify_ast(node, parent_stack)
parent_stack.pop()


def _stringify_ast(node: ast.AST, parent_stack: List[ast.AST]) -> Iterator[str]:
if (
isinstance(node, ast.Constant)
and isinstance(node.value, str)
Expand All @@ -167,7 +182,7 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
# over the kind
node.kind = None

yield f"{' ' * depth}{node.__class__.__name__}("
yield f"{' ' * len(parent_stack)}{node.__class__.__name__}("

for field in sorted(node._fields): # noqa: F402
# TypeIgnore has only one field 'lineno' which breaks this comparison
Expand All @@ -179,7 +194,7 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
except AttributeError:
continue

yield f"{' ' * (depth + 1)}{field}="
yield f"{' ' * (len(parent_stack) + 1)}{field}="

if isinstance(value, list):
for item in value:
Expand All @@ -191,20 +206,28 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
and isinstance(item, ast.Tuple)
):
for elt in item.elts:
yield from stringify_ast(elt, depth + 2)
yield from _stringify_ast_with_new_parent(
elt, parent_stack, node
)

elif isinstance(item, ast.AST):
yield from stringify_ast(item, depth + 2)
yield from _stringify_ast_with_new_parent(item, parent_stack, node)

elif isinstance(value, ast.AST):
yield from stringify_ast(value, depth + 2)
yield from _stringify_ast_with_new_parent(value, parent_stack, node)

else:
normalized: object
if (
isinstance(node, ast.Constant)
and field == "value"
and isinstance(value, str)
and len(parent_stack) >= 2
and isinstance(parent_stack[-1], ast.Expr)
and isinstance(
parent_stack[-2],
(ast.FunctionDef, ast.AsyncFunctionDef, ast.Module, ast.ClassDef),
)
):
# Constant strings may be indented across newlines, if they are
# docstrings; fold spaces after newlines when comparing. Similarly,
Expand All @@ -215,6 +238,9 @@ def stringify_ast(node: ast.AST, depth: int = 0) -> Iterator[str]:
normalized = value.rstrip()
else:
normalized = value
yield f"{' ' * (depth + 2)}{normalized!r}, # {value.__class__.__name__}"
yield (
f"{' ' * (len(parent_stack) + 2)}{normalized!r}, #"
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved
f" {value.__class__.__name__}"
)

yield f"{' ' * depth}) # /{node.__class__.__name__}"
yield f"{' ' * len(parent_stack)}) # /{node.__class__.__name__}"
122 changes: 108 additions & 14 deletions tests/test_black.py
Expand Up @@ -46,6 +46,7 @@
from black.debug import DebugVisitor
from black.mode import Mode, Preview
from black.output import color_diff, diff
from black.parsing import ASTSafetyError
from black.report import Report

# Import other test classes
Expand Down Expand Up @@ -1473,10 +1474,6 @@ def test_normalize_line_endings(self) -> None:
ff(test_file, write_back=black.WriteBack.YES)
self.assertEqual(test_file.read_bytes(), expected)

def test_assert_equivalent_different_asts(self) -> None:
with self.assertRaises(AssertionError):
black.assert_equivalent("{}", "None")

def test_root_logger_not_used_directly(self) -> None:
def fail(*args: Any, **kwargs: Any) -> None:
self.fail("Record created with root logger")
Expand Down Expand Up @@ -1962,16 +1959,6 @@ def test_for_handled_unexpected_eof_error(self) -> None:

exc_info.match("Cannot parse: 2:0: EOF in multi-line statement")

def test_equivalency_ast_parse_failure_includes_error(self) -> None:
with pytest.raises(AssertionError) as err:
black.assert_equivalent("a«»a = 1", "a«»a = 1")

err.match("--safe")
# Unfortunately the SyntaxError message has changed in newer versions so we
# can't match it directly.
err.match("invalid character")
err.match(r"\(<unknown>, line 1\)")

def test_line_ranges_with_code_option(self) -> None:
code = textwrap.dedent("""\
if a == b:
Expand Down Expand Up @@ -2822,6 +2809,113 @@ def test_format_file_contents(self) -> None:
black.format_file_contents("x = 1\n", fast=True, mode=black.Mode())


class TestASTSafety(BlackBaseTestCase):
def check_ast_equivalence(
self, source: str, dest: str, *, should_fail: bool = False
) -> None:
# If we get a failure, make sure it's not because the code itself
# is invalid, since that will also cause assert_equivalent() to throw
# ASTSafetyError.
source = textwrap.dedent(source)
dest = textwrap.dedent(dest)
black.parse_ast(source)
black.parse_ast(dest)
if should_fail:
with self.assertRaises(ASTSafetyError):
black.assert_equivalent(source, dest)
else:
black.assert_equivalent(source, dest)

def test_assert_equivalent_basic(self) -> None:
self.check_ast_equivalence("{}", "None", should_fail=True)
self.check_ast_equivalence("1+2", "1 + 2")
self.check_ast_equivalence("hi # comment", "hi")

def test_assert_equivalent_del(self) -> None:
self.check_ast_equivalence("del (a, b)", "del a, b")

def test_assert_equivalent_strings(self) -> None:
self.check_ast_equivalence('x = "x"', 'x = " x "', should_fail=True)
self.check_ast_equivalence(
'''
"""docstring """
''',
'''
"""docstring"""
''',
)
self.check_ast_equivalence(
'''
"""docstring """
''',
'''
"""ddocstring"""
''',
should_fail=True,
)
self.check_ast_equivalence(
'''
class A:
"""

docstring


"""
''',
'''
class A:
"""docstring"""
''',
)
self.check_ast_equivalence(
"""
def f():
" docstring "
""",
'''
def f():
"""docstring"""
''',
)
self.check_ast_equivalence(
"""
async def f():
" docstring "
""",
'''
async def f():
"""docstring"""
''',
)

def test_assert_equivalent_fstring(self) -> None:
major, minor = sys.version_info[:2]
if major < 3 or (major == 3 and minor < 12):
pytest.skip("relies on 3.12+ syntax")
# https://github.com/psf/black/issues/4268
self.check_ast_equivalence(
"""print(f"{"|".join(['a','b','c'])}")""",
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved
"""print(f"{" | ".join([a,b,c])}")""",
should_fail=True,
)
self.check_ast_equivalence(
"""print(f"{"|".join(['a','b','c'])}")""",
"""print(f"{" | ".join(['a','b','c'])}")""",
should_fail=True,
)

def test_equivalency_ast_parse_failure_includes_error(self) -> None:
with pytest.raises(ASTSafetyError) as err:
black.assert_equivalent("a«»a = 1", "a«»a = 1")

err.match("--safe")
# Unfortunately the SyntaxError message has changed in newer versions so we
# can't match it directly.
err.match("invalid character")
err.match(r"\(<unknown>, line 1\)")


try:
with open(black.__file__, "r", encoding="utf-8") as _bf:
black_source_lines = _bf.readlines()
Expand Down