Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[7.4.x] Fix assert rewriting with assignment expressions (#11414) #11419

Merged
merged 2 commits into from
Sep 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ Maho
Maik Figura
Mandeep Bhutani
Manuel Krebber
Marc Mueller
Marc Schlaich
Marcelo Duarte Trevisani
Marcin Bachry
Expand Down
1 change: 1 addition & 0 deletions changelog/11239.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed ``:=`` in asserts impacting unrelated test cases.
54 changes: 40 additions & 14 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import sys
import tokenize
import types
from collections import defaultdict
from pathlib import Path
from pathlib import PurePath
from typing import Callable
Expand Down Expand Up @@ -56,13 +57,20 @@
astNum = ast.Num


class Sentinel:
pass


assertstate_key = StashKey["AssertionState"]()

# pytest caches rewritten pycs in pycache dirs
PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
PYC_EXT = ".py" + (__debug__ and "c" or "o")
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT

# Special marker that denotes we have just left a scope definition
_SCOPE_END_MARKER = Sentinel()


class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
"""PEP302/PEP451 import hook which rewrites asserts."""
Expand Down Expand Up @@ -645,6 +653,8 @@
.push_format_context() and .pop_format_context() which allows
to build another %-formatted string while already building one.

:scope: A tuple containing the current scope used for variables_overwrite.

:variables_overwrite: A dict filled with references to variables
that change value within an assert. This happens when a variable is
reassigned with the walrus operator
Expand All @@ -666,7 +676,10 @@
else:
self.enable_assertion_pass_hook = False
self.source = source
self.variables_overwrite: Dict[str, str] = {}
self.scope: tuple[ast.AST, ...] = ()
self.variables_overwrite: defaultdict[
tuple[ast.AST, ...], Dict[str, str]
] = defaultdict(dict)

def run(self, mod: ast.Module) -> None:
"""Find all assert statements in *mod* and rewrite them."""
Expand Down Expand Up @@ -732,9 +745,17 @@
mod.body[pos:pos] = imports

# Collect asserts.
nodes: List[ast.AST] = [mod]
self.scope = (mod,)
nodes: List[Union[ast.AST, Sentinel]] = [mod]
while nodes:
node = nodes.pop()
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
self.scope = tuple((*self.scope, node))
nodes.append(_SCOPE_END_MARKER)
if node == _SCOPE_END_MARKER:
self.scope = self.scope[:-1]
continue
assert isinstance(node, ast.AST)
for name, field in ast.iter_fields(node):
if isinstance(field, list):
new: List[ast.AST] = []
Expand Down Expand Up @@ -1005,7 +1026,7 @@
]
):
pytest_temp = self.variable()
self.variables_overwrite[
self.variables_overwrite[self.scope][

Check warning on line 1029 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L1029

Added line #L1029 was not covered by tests
v.left.target.id
] = v.left # type:ignore[assignment]
v.left.target.id = pytest_temp
Expand Down Expand Up @@ -1048,17 +1069,20 @@
new_args = []
new_kwargs = []
for arg in call.args:
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite:
arg = self.variables_overwrite[arg.id] # type:ignore[assignment]
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get(
self.scope, {}
):
arg = self.variables_overwrite[self.scope][

Check warning on line 1075 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L1075

Added line #L1075 was not covered by tests
arg.id
] # type:ignore[assignment]
res, expl = self.visit(arg)
arg_expls.append(expl)
new_args.append(res)
for keyword in call.keywords:
if (
isinstance(keyword.value, ast.Name)
and keyword.value.id in self.variables_overwrite
):
keyword.value = self.variables_overwrite[
if isinstance(
keyword.value, ast.Name
) and keyword.value.id in self.variables_overwrite.get(self.scope, {}):
keyword.value = self.variables_overwrite[self.scope][

Check warning on line 1085 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L1085

Added line #L1085 was not covered by tests
keyword.value.id
] # type:ignore[assignment]
res, expl = self.visit(keyword.value)
Expand Down Expand Up @@ -1094,12 +1118,14 @@
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
self.push_format_context()
# We first check if we have overwritten a variable in the previous assert
if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite:
comp.left = self.variables_overwrite[
if isinstance(
comp.left, ast.Name
) and comp.left.id in self.variables_overwrite.get(self.scope, {}):
comp.left = self.variables_overwrite[self.scope][

Check warning on line 1124 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L1124

Added line #L1124 was not covered by tests
comp.left.id
] # type:ignore[assignment]
if isinstance(comp.left, namedExpr):
self.variables_overwrite[
self.variables_overwrite[self.scope][

Check warning on line 1128 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L1128

Added line #L1128 was not covered by tests
comp.left.target.id
] = comp.left # type:ignore[assignment]
left_res, left_expl = self.visit(comp.left)
Expand All @@ -1119,7 +1145,7 @@
and next_operand.target.id == left_res.id
):
next_operand.target.id = self.variable()
self.variables_overwrite[
self.variables_overwrite[self.scope][

Check warning on line 1148 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L1148

Added line #L1148 was not covered by tests
left_res.id
] = next_operand # type:ignore[assignment]
next_res, next_expl = self.visit(next_operand)
Expand Down
22 changes: 22 additions & 0 deletions testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1531,6 +1531,28 @@ def test_gt():
result.stdout.fnmatch_lines(["*assert 4 > 5", "*where 5 = add_one(4)"])


class TestIssue11239:
@pytest.mark.skipif(sys.version_info[:2] <= (3, 7), reason="Only Python 3.8+")
def test_assertion_walrus_different_test_cases(self, pytester: Pytester) -> None:
"""Regression for (#11239)

Walrus operator rewriting would leak to separate test cases if they used the same variables.
"""
pytester.makepyfile(
"""
def test_1():
state = {"x": 2}.get("x")
assert state is not None

def test_2():
db = {"x": 2}
assert (state := db.get("x")) is not None
"""
)
result = pytester.runpytest()
assert result.ret == 0


@pytest.mark.skipif(
sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems"
)
Expand Down