Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix assert rewriting with assignment expressions #11414

Merged
merged 5 commits into from
Sep 9, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ Maho
Maik Figura
Mandeep Bhutani
Manuel Krebber
Marc Mueller
Marc Schlaich
Marcelo Duarte Trevisani
Marcin Bachry
Expand Down
1 change: 1 addition & 0 deletions changelog/11239.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed ``:=`` in asserts impacting unrelated test cases.
54 changes: 40 additions & 14 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import sys
import tokenize
import types
from collections import defaultdict
from enum import Enum
from pathlib import Path
from pathlib import PurePath
from typing import Callable
Expand Down Expand Up @@ -53,6 +55,12 @@
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT


class ScopeEndMarkerType(Enum):
"""Special marker that denotes we have just left a function or class definition."""

ScopeEndMarker = 1


class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
"""PEP302/PEP451 import hook which rewrites asserts."""

Expand Down Expand Up @@ -634,6 +642,8 @@
.push_format_context() and .pop_format_context() which allows
to build another %-formatted string while already building one.

:scope: A tuple containing the current scope used for variables_overwrite.

:variables_overwrite: A dict filled with references to variables
that change value within an assert. This happens when a variable is
reassigned with the walrus operator
Expand All @@ -655,7 +665,10 @@
else:
self.enable_assertion_pass_hook = False
self.source = source
self.variables_overwrite: Dict[str, str] = {}
self.scope: tuple[ast.AST, ...] = ()
self.variables_overwrite: defaultdict[
tuple[ast.AST, ...], Dict[str, str]
] = defaultdict(dict)

def run(self, mod: ast.Module) -> None:
"""Find all assert statements in *mod* and rewrite them."""
Expand Down Expand Up @@ -719,9 +732,17 @@
mod.body[pos:pos] = imports

# Collect asserts.
nodes: List[ast.AST] = [mod]
self.scope = (mod,)
nodes: List[Union[ast.AST, ScopeEndMarkerType]] = [mod]
while nodes:
node = nodes.pop()
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
self.scope = tuple((*self.scope, node))
nodes.append(ScopeEndMarkerType.ScopeEndMarker)
if node is ScopeEndMarkerType.ScopeEndMarker:
self.scope = self.scope[:-1]
continue
assert isinstance(node, ast.AST)
for name, field in ast.iter_fields(node):
if isinstance(field, list):
new: List[ast.AST] = []
Expand Down Expand Up @@ -992,7 +1013,7 @@
]
):
pytest_temp = self.variable()
self.variables_overwrite[
self.variables_overwrite[self.scope][

Check warning on line 1016 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L1016

Added line #L1016 was not covered by tests
v.left.target.id
] = v.left # type:ignore[assignment]
v.left.target.id = pytest_temp
Expand Down Expand Up @@ -1035,17 +1056,20 @@
new_args = []
new_kwargs = []
for arg in call.args:
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite:
arg = self.variables_overwrite[arg.id] # type:ignore[assignment]
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get(
self.scope, {}
):
arg = self.variables_overwrite[self.scope][

Check warning on line 1062 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L1062

Added line #L1062 was not covered by tests
arg.id
] # type:ignore[assignment]
res, expl = self.visit(arg)
arg_expls.append(expl)
new_args.append(res)
for keyword in call.keywords:
if (
isinstance(keyword.value, ast.Name)
and keyword.value.id in self.variables_overwrite
):
keyword.value = self.variables_overwrite[
if isinstance(
keyword.value, ast.Name
) and keyword.value.id in self.variables_overwrite.get(self.scope, {}):
keyword.value = self.variables_overwrite[self.scope][

Check warning on line 1072 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L1072

Added line #L1072 was not covered by tests
keyword.value.id
] # type:ignore[assignment]
res, expl = self.visit(keyword.value)
Expand Down Expand Up @@ -1081,12 +1105,14 @@
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
self.push_format_context()
# We first check if we have overwritten a variable in the previous assert
if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite:
comp.left = self.variables_overwrite[
if isinstance(
comp.left, ast.Name
) and comp.left.id in self.variables_overwrite.get(self.scope, {}):
comp.left = self.variables_overwrite[self.scope][

Check warning on line 1111 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L1111

Added line #L1111 was not covered by tests
comp.left.id
] # type:ignore[assignment]
if isinstance(comp.left, ast.NamedExpr):
self.variables_overwrite[
self.variables_overwrite[self.scope][

Check warning on line 1115 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L1115

Added line #L1115 was not covered by tests
comp.left.target.id
] = comp.left # type:ignore[assignment]
left_res, left_expl = self.visit(comp.left)
Expand All @@ -1106,7 +1132,7 @@
and next_operand.target.id == left_res.id
):
next_operand.target.id = self.variable()
self.variables_overwrite[
self.variables_overwrite[self.scope][

Check warning on line 1135 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L1135

Added line #L1135 was not covered by tests
left_res.id
] = next_operand # type:ignore[assignment]
next_res, next_expl = self.visit(next_operand)
Expand Down
21 changes: 21 additions & 0 deletions testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1543,6 +1543,27 @@ def test_gt():
result.stdout.fnmatch_lines(["*assert 4 > 5", "*where 5 = add_one(4)"])


class TestIssue11239:
def test_assertion_walrus_different_test_cases(self, pytester: Pytester) -> None:
cdce8p marked this conversation as resolved.
Show resolved Hide resolved
"""Regression for (#11239)

Walrus operator rewriting would leak to separate test cases if they used the same variables.
"""
pytester.makepyfile(
"""
def test_1():
state = {"x": 2}.get("x")
assert state is not None

def test_2():
db = {"x": 2}
assert (state := db.get("x")) is not None
"""
)
result = pytester.runpytest()
assert result.ret == 0


@pytest.mark.skipif(
sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems"
)
Expand Down