Skip to content

Commit

Permalink
nodes: add Node.iterchain() function
Browse files Browse the repository at this point in the history
This is a useful addition to the existing `listchain`. While `listchain`
returns top-to-bottom, `iterchain` is bottom-to-top and doesn't require
an internal full iteration + `reverse`.
  • Loading branch information
bluetech committed Jan 11, 2024
1 parent b1c4308 commit 471adac
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 25 deletions.
4 changes: 2 additions & 2 deletions src/_pytest/fixtures.py
Expand Up @@ -1513,7 +1513,7 @@ def pytest_plugin_registered(self, plugin: _PluggyPlugin) -> None:

def _getautousenames(self, node: nodes.Node) -> Iterator[str]:
"""Return the names of autouse fixtures applicable to node."""
for parentnode in reversed(list(nodes.iterparentnodes(node))):
for parentnode in node.listchain():
basenames = self._nodeid_autousenames.get(parentnode.nodeid)
if basenames:
yield from basenames
Expand Down Expand Up @@ -1781,7 +1781,7 @@ def getfixturedefs(
def _matchfactories(
self, fixturedefs: Iterable[FixtureDef[Any]], node: nodes.Node
) -> Iterator[FixtureDef[Any]]:
parentnodeids = {n.nodeid for n in nodes.iterparentnodes(node)}
parentnodeids = {n.nodeid for n in node.iterchain()}
for fixturedef in fixturedefs:
if fixturedef.baseid in parentnodeids:
yield fixturedef
38 changes: 18 additions & 20 deletions src/_pytest/nodes.py
Expand Up @@ -49,15 +49,6 @@
tracebackcutdir = Path(_pytest.__file__).parent


def iterparentnodes(node: "Node") -> Iterator["Node"]:
"""Return the parent nodes, including the node itself, from the node
upwards."""
parent: Optional[Node] = node
while parent is not None:
yield parent
parent = parent.parent


_NodeType = TypeVar("_NodeType", bound="Node")


Expand Down Expand Up @@ -265,12 +256,20 @@ def setup(self) -> None:
def teardown(self) -> None:
pass

def listchain(self) -> List["Node"]:
"""Return list of all parent collectors up to self, starting from
the root of collection tree.
def iterchain(self) -> Iterator["Node"]:
"""Iterate over self and all parent collectors up to root of the
collection tree, starting from self.
:returns: The nodes.
.. versionadded:: 8.1
"""
parent: Optional[Node] = self
while parent is not None:
yield parent
parent = parent.parent

def listchain(self) -> List["Node"]:
"""Return a list of all parent collectors up to and including self,
starting from the root of the collection tree."""
chain = []
item: Optional[Node] = self
while item is not None:
Expand Down Expand Up @@ -319,7 +318,7 @@ def iter_markers_with_node(
:param name: If given, filter the results by the name attribute.
:returns: An iterator of (node, mark) tuples.
"""
for node in reversed(self.listchain()):
for node in self.iterchain():
for mark in node.own_markers:
if name is None or getattr(mark, "name", None) == name:
yield node, mark
Expand Down Expand Up @@ -363,17 +362,16 @@ def addfinalizer(self, fin: Callable[[], object]) -> None:
self.session._setupstate.addfinalizer(fin, self)

def getparent(self, cls: Type[_NodeType]) -> Optional[_NodeType]:
"""Get the next parent node (including self) which is an instance of
"""Get the closest parent node (including self) which is an instance of
the given class.
:param cls: The node class to search for.
:returns: The node, if found.
"""
current: Optional[Node] = self
while current and not isinstance(current, cls):
current = current.parent
assert current is None or isinstance(current, cls)
return current
for node in self.iterchain():
if isinstance(node, cls):
return node
return None

def _traceback_filter(self, excinfo: ExceptionInfo[BaseException]) -> Traceback:
return excinfo.traceback
Expand Down
4 changes: 1 addition & 3 deletions src/_pytest/python.py
Expand Up @@ -333,10 +333,8 @@ def _getobj(self):

def getmodpath(self, stopatmodule: bool = True, includemodule: bool = False) -> str:
"""Return Python path relative to the containing module."""
chain = self.listchain()
chain.reverse()
parts = []
for node in chain:
for node in self.iterchain():
name = node.name
if isinstance(node, Module):
name = os.path.splitext(name)[0]
Expand Down

0 comments on commit 471adac

Please sign in to comment.