diff --git a/AUTHORS b/AUTHORS index e8456d92b31..6d860575f37 100644 --- a/AUTHORS +++ b/AUTHORS @@ -231,6 +231,7 @@ Maho Maik Figura Mandeep Bhutani Manuel Krebber +Marc Mueller Marc Schlaich Marcelo Duarte Trevisani Marcin Bachry diff --git a/changelog/11239.bugfix.rst b/changelog/11239.bugfix.rst new file mode 100644 index 00000000000..a486224cdda --- /dev/null +++ b/changelog/11239.bugfix.rst @@ -0,0 +1 @@ +Fixed ``:=`` in asserts impacting unrelated test cases. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index ab83fee32b2..fd23552973e 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -13,6 +13,7 @@ import sys import tokenize import types +from collections import defaultdict from pathlib import Path from pathlib import PurePath from typing import Callable @@ -56,6 +57,10 @@ astNum = ast.Num +class Sentinel: + pass + + assertstate_key = StashKey["AssertionState"]() # pytest caches rewritten pycs in pycache dirs @@ -63,6 +68,9 @@ PYC_EXT = ".py" + (__debug__ and "c" or "o") PYC_TAIL = "." + PYTEST_TAG + PYC_EXT +# Special marker that denotes we have just left a scope definition +_SCOPE_END_MARKER = Sentinel() + class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader): """PEP302/PEP451 import hook which rewrites asserts.""" @@ -645,6 +653,8 @@ class AssertionRewriter(ast.NodeVisitor): .push_format_context() and .pop_format_context() which allows to build another %-formatted string while already building one. + :scope: A tuple containing the current scope used for variables_overwrite. + :variables_overwrite: A dict filled with references to variables that change value within an assert. This happens when a variable is reassigned with the walrus operator @@ -666,7 +676,10 @@ def __init__( else: self.enable_assertion_pass_hook = False self.source = source - self.variables_overwrite: Dict[str, str] = {} + self.scope: tuple[ast.AST, ...] = () + self.variables_overwrite: defaultdict[ + tuple[ast.AST, ...], Dict[str, str] + ] = defaultdict(dict) def run(self, mod: ast.Module) -> None: """Find all assert statements in *mod* and rewrite them.""" @@ -732,9 +745,17 @@ def run(self, mod: ast.Module) -> None: mod.body[pos:pos] = imports # Collect asserts. - nodes: List[ast.AST] = [mod] + self.scope = (mod,) + nodes: List[Union[ast.AST, Sentinel]] = [mod] while nodes: node = nodes.pop() + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + self.scope = tuple((*self.scope, node)) + nodes.append(_SCOPE_END_MARKER) + if node == _SCOPE_END_MARKER: + self.scope = self.scope[:-1] + continue + assert isinstance(node, ast.AST) for name, field in ast.iter_fields(node): if isinstance(field, list): new: List[ast.AST] = [] @@ -1005,7 +1026,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]: ] ): pytest_temp = self.variable() - self.variables_overwrite[ + self.variables_overwrite[self.scope][ v.left.target.id ] = v.left # type:ignore[assignment] v.left.target.id = pytest_temp @@ -1048,17 +1069,20 @@ def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]: new_args = [] new_kwargs = [] for arg in call.args: - if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite: - arg = self.variables_overwrite[arg.id] # type:ignore[assignment] + if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get( + self.scope, {} + ): + arg = self.variables_overwrite[self.scope][ + arg.id + ] # type:ignore[assignment] res, expl = self.visit(arg) arg_expls.append(expl) new_args.append(res) for keyword in call.keywords: - if ( - isinstance(keyword.value, ast.Name) - and keyword.value.id in self.variables_overwrite - ): - keyword.value = self.variables_overwrite[ + if isinstance( + keyword.value, ast.Name + ) and keyword.value.id in self.variables_overwrite.get(self.scope, {}): + keyword.value = self.variables_overwrite[self.scope][ keyword.value.id ] # type:ignore[assignment] res, expl = self.visit(keyword.value) @@ -1094,12 +1118,14 @@ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]: def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: self.push_format_context() # We first check if we have overwritten a variable in the previous assert - if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite: - comp.left = self.variables_overwrite[ + if isinstance( + comp.left, ast.Name + ) and comp.left.id in self.variables_overwrite.get(self.scope, {}): + comp.left = self.variables_overwrite[self.scope][ comp.left.id ] # type:ignore[assignment] if isinstance(comp.left, namedExpr): - self.variables_overwrite[ + self.variables_overwrite[self.scope][ comp.left.target.id ] = comp.left # type:ignore[assignment] left_res, left_expl = self.visit(comp.left) @@ -1119,7 +1145,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: and next_operand.target.id == left_res.id ): next_operand.target.id = self.variable() - self.variables_overwrite[ + self.variables_overwrite[self.scope][ left_res.id ] = next_operand # type:ignore[assignment] next_res, next_expl = self.visit(next_operand) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 778f843e6cf..63353438c95 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1531,6 +1531,27 @@ def test_gt(): result.stdout.fnmatch_lines(["*assert 4 > 5", "*where 5 = add_one(4)"]) +class TestIssue11239: + def test_assertion_walrus_different_test_cases(self, pytester: Pytester) -> None: + """Regression for (#11239) + + Walrus operator rewriting would leak to separate test cases if they used the same variables. + """ + pytester.makepyfile( + """ + def test_1(): + state = {"x": 2}.get("x") + assert state is not None + + def test_2(): + db = {"x": 2} + assert (state := db.get("x")) is not None + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + @pytest.mark.skipif( sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems" )