From 5bd5b80afdec77349a50d3c95b0d85d50f1ec0a9 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Thu, 11 Jan 2024 11:30:22 +0200 Subject: [PATCH] nodes: add `Node.iterparents()` function This is a useful addition to the existing `listchain`. While `listchain` returns top-to-bottom, `iterparents` is bottom-to-top and doesn't require an internal full iteration + `reverse`. --- changelog/11801.improvement.rst | 2 ++ src/_pytest/fixtures.py | 24 ++++++++------------- src/_pytest/nodes.py | 38 ++++++++++++++++----------------- src/_pytest/python.py | 4 +--- 4 files changed, 30 insertions(+), 38 deletions(-) create mode 100644 changelog/11801.improvement.rst diff --git a/changelog/11801.improvement.rst b/changelog/11801.improvement.rst new file mode 100644 index 00000000000..3046edca0c2 --- /dev/null +++ b/changelog/11801.improvement.rst @@ -0,0 +1,2 @@ +Added the :func:`iterparents() <_pytest.nodes.Node.iterparents>` helper method on nodes. +It is similar to :func:`listchain <_pytest.nodes.Node.listchain>`, but goes from bottom to top, and returns an iterator, not a list. diff --git a/src/_pytest/fixtures.py b/src/_pytest/fixtures.py index 13c1790bea1..dd37f8ec3e3 100644 --- a/src/_pytest/fixtures.py +++ b/src/_pytest/fixtures.py @@ -116,22 +116,16 @@ def pytest_sessionstart(session: "Session") -> None: def get_scope_package( node: nodes.Item, fixturedef: "FixtureDef[object]", -) -> Optional[Union[nodes.Item, nodes.Collector]]: +) -> Optional[nodes.Node]: from _pytest.python import Package - current: Optional[Union[nodes.Item, nodes.Collector]] = node - while current and ( - not isinstance(current, Package) or current.nodeid != fixturedef.baseid - ): - current = current.parent # type: ignore[assignment] - if current is None: - return node.session - return current + for parent in node.iterparents(): + if isinstance(parent, Package) and parent.nodeid == fixturedef.baseid: + return parent + return node.session -def get_scope_node( - node: nodes.Node, scope: Scope -) -> Optional[Union[nodes.Item, nodes.Collector]]: +def get_scope_node(node: nodes.Node, scope: Scope) -> Optional[nodes.Node]: import _pytest.python if scope is Scope.Function: @@ -738,7 +732,7 @@ def node(self): scope = self._scope if scope is Scope.Function: # This might also be a non-function Item despite its attribute name. - node: Optional[Union[nodes.Item, nodes.Collector]] = self._pyfuncitem + node: Optional[nodes.Node] = self._pyfuncitem elif scope is Scope.Package: node = get_scope_package(self._pyfuncitem, self._fixturedef) else: @@ -1513,7 +1507,7 @@ def pytest_plugin_registered(self, plugin: _PluggyPlugin) -> None: def _getautousenames(self, node: nodes.Node) -> Iterator[str]: """Return the names of autouse fixtures applicable to node.""" - for parentnode in reversed(list(nodes.iterparentnodes(node))): + for parentnode in node.listchain(): basenames = self._nodeid_autousenames.get(parentnode.nodeid) if basenames: yield from basenames @@ -1781,7 +1775,7 @@ def getfixturedefs( def _matchfactories( self, fixturedefs: Iterable[FixtureDef[Any]], node: nodes.Node ) -> Iterator[FixtureDef[Any]]: - parentnodeids = {n.nodeid for n in nodes.iterparentnodes(node)} + parentnodeids = {n.nodeid for n in node.iterparents()} for fixturedef in fixturedefs: if fixturedef.baseid in parentnodeids: yield fixturedef diff --git a/src/_pytest/nodes.py b/src/_pytest/nodes.py index bc6d6f4dd50..e45a515b0ae 100644 --- a/src/_pytest/nodes.py +++ b/src/_pytest/nodes.py @@ -49,15 +49,6 @@ tracebackcutdir = Path(_pytest.__file__).parent -def iterparentnodes(node: "Node") -> Iterator["Node"]: - """Return the parent nodes, including the node itself, from the node - upwards.""" - parent: Optional[Node] = node - while parent is not None: - yield parent - parent = parent.parent - - _NodeType = TypeVar("_NodeType", bound="Node") @@ -265,12 +256,20 @@ def setup(self) -> None: def teardown(self) -> None: pass - def listchain(self) -> List["Node"]: - """Return list of all parent collectors up to self, starting from - the root of collection tree. + def iterparents(self) -> Iterator["Node"]: + """Iterate over all parent collectors starting from and including self + up to the root of the collection tree. - :returns: The nodes. + .. versionadded:: 8.1 """ + parent: Optional[Node] = self + while parent is not None: + yield parent + parent = parent.parent + + def listchain(self) -> List["Node"]: + """Return a list of all parent collectors starting from the root of the + collection tree down to and including self.""" chain = [] item: Optional[Node] = self while item is not None: @@ -319,7 +318,7 @@ def iter_markers_with_node( :param name: If given, filter the results by the name attribute. :returns: An iterator of (node, mark) tuples. """ - for node in reversed(self.listchain()): + for node in self.iterparents(): for mark in node.own_markers: if name is None or getattr(mark, "name", None) == name: yield node, mark @@ -363,17 +362,16 @@ def addfinalizer(self, fin: Callable[[], object]) -> None: self.session._setupstate.addfinalizer(fin, self) def getparent(self, cls: Type[_NodeType]) -> Optional[_NodeType]: - """Get the next parent node (including self) which is an instance of + """Get the closest parent node (including self) which is an instance of the given class. :param cls: The node class to search for. :returns: The node, if found. """ - current: Optional[Node] = self - while current and not isinstance(current, cls): - current = current.parent - assert current is None or isinstance(current, cls) - return current + for node in self.iterparents(): + if isinstance(node, cls): + return node + return None def _traceback_filter(self, excinfo: ExceptionInfo[BaseException]) -> Traceback: return excinfo.traceback diff --git a/src/_pytest/python.py b/src/_pytest/python.py index 36d2eba0323..7623a97489c 100644 --- a/src/_pytest/python.py +++ b/src/_pytest/python.py @@ -333,10 +333,8 @@ def _getobj(self): def getmodpath(self, stopatmodule: bool = True, includemodule: bool = False) -> str: """Return Python path relative to the containing module.""" - chain = self.listchain() - chain.reverse() parts = [] - for node in chain: + for node in self.iterparents(): name = node.name if isinstance(node, Module): name = os.path.splitext(name)[0]