Skip to content

Commit

Permalink
Add typing to pylint.pyreverse.inspector (#6614)
Browse files Browse the repository at this point in the history
* Add typing to `pylint.pyreverse.inspector`

* Import `Callable` from `typing` to fix compatibility with Python 3.7 and 3.8

* Fix return type of `interfaces` function
  • Loading branch information
DudeNr33 committed May 15, 2022
1 parent e72fe66 commit c6d5cfe
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 27 deletions.
67 changes: 44 additions & 23 deletions pylint/pyreverse/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,31 @@
Try to resolve definitions (namespace) dictionary, relationship...
"""

from __future__ import annotations

import collections
import os
import traceback
from collections.abc import Generator
from typing import Any, Callable, Optional

import astroid
from astroid import nodes

from pylint.pyreverse import utils

_WrapperFuncT = Callable[[Callable[[str], nodes.Module], str], Optional[nodes.Module]]


def _iface_hdlr(_):
def _iface_hdlr(_: nodes.NodeNG | Any) -> bool:
"""Handler used by interfaces to handle suspicious interface nodes."""
return True


def _astroid_wrapper(func, modname):
def _astroid_wrapper(
func: Callable[[str], nodes.Module], modname: str
) -> nodes.Module | None:
print(f"parsing {modname}...")
try:
return func(modname)
Expand All @@ -32,7 +41,11 @@ def _astroid_wrapper(func, modname):
return None


def interfaces(node, herited=True, handler_func=_iface_hdlr):
def interfaces(
node: nodes.ClassDef,
herited: bool = True,
handler_func: Callable[[nodes.NodeNG | Any], bool] = _iface_hdlr,
) -> Generator[Any, None, None]:
"""Return an iterator on interfaces implemented by the given class node."""
try:
implements = astroid.bases.Instance(node).getattr("__implements__")[0]
Expand All @@ -56,14 +69,14 @@ def interfaces(node, herited=True, handler_func=_iface_hdlr):
class IdGeneratorMixIn:
"""Mixin adding the ability to generate integer uid."""

def __init__(self, start_value=0):
def __init__(self, start_value: int = 0) -> None:
self.id_count = start_value

def init_counter(self, start_value=0):
def init_counter(self, start_value: int = 0) -> None:
"""Init the id counter."""
self.id_count = start_value

def generate_id(self):
def generate_id(self) -> int:
"""Generate a new identifier."""
self.id_count += 1
return self.id_count
Expand All @@ -72,29 +85,29 @@ def generate_id(self):
class Project:
"""A project handle a set of modules / packages."""

def __init__(self, name=""):
def __init__(self, name: str = ""):
self.name = name
self.uid = None
self.path = None
self.modules = []
self.locals = {}
self.uid: int | None = None
self.path: str = ""
self.modules: list[nodes.Module] = []
self.locals: dict[str, nodes.Module] = {}
self.__getitem__ = self.locals.__getitem__
self.__iter__ = self.locals.__iter__
self.values = self.locals.values
self.keys = self.locals.keys
self.items = self.locals.items

def add_module(self, node):
def add_module(self, node: nodes.Module) -> None:
self.locals[node.name] = node
self.modules.append(node)

def get_module(self, name):
def get_module(self, name: str) -> nodes.Module:
return self.locals[name]

def get_children(self):
def get_children(self) -> list[nodes.Module]:
return self.modules

def __repr__(self):
def __repr__(self) -> str:
return f"<Project {self.name!r} at {id(self)} ({len(self.modules)} modules)>"


Expand All @@ -121,7 +134,9 @@ class Linker(IdGeneratorMixIn, utils.LocalsVisitor):
list of implemented interface _objects_ (only on astroid.Class nodes)
"""

def __init__(self, project, inherited_interfaces=0, tag=False):
def __init__(
self, project: Project, inherited_interfaces: bool = False, tag: bool = False
) -> None:
IdGeneratorMixIn.__init__(self)
utils.LocalsVisitor.__init__(self)
# take inherited interface in consideration or not
Expand Down Expand Up @@ -180,7 +195,8 @@ def visit_classdef(self, node: nodes.ClassDef) -> None:
self.handle_assignattr_type(assignattr, node)
# resolve implemented interface
try:
node.implements = list(interfaces(node, self.inherited_interfaces))
ifaces = interfaces(node, self.inherited_interfaces)
node.implements = list(ifaces) if ifaces is not None else []
except astroid.InferenceError:
node.implements = []

Expand Down Expand Up @@ -232,7 +248,7 @@ def visit_assignname(self, node: nodes.AssignName) -> None:
frame.locals_type[node.name] = list(set(current) | utils.infer_node(node))

@staticmethod
def handle_assignattr_type(node, parent):
def handle_assignattr_type(node: nodes.AssignAttr, parent: nodes.ClassDef) -> None:
"""Handle an astroid.assignattr node.
handle instance_attrs_type
Expand Down Expand Up @@ -276,7 +292,7 @@ def visit_importfrom(self, node: nodes.ImportFrom) -> None:
if fullname != basename:
self._imported_module(node, fullname, relative)

def compute_module(self, context_name, mod_path):
def compute_module(self, context_name: str, mod_path: str) -> int:
"""Return true if the module should be added to dependencies."""
package_dir = os.path.dirname(self.project.path)
if context_name == mod_path:
Expand All @@ -285,7 +301,9 @@ def compute_module(self, context_name, mod_path):
return 1
return 0

def _imported_module(self, node, mod_path, relative):
def _imported_module(
self, node: nodes.Import | nodes.ImportFrom, mod_path: str, relative: bool
) -> None:
"""Notify an imported module, used to analyze dependencies."""
module = node.root()
context_name = module.name
Expand All @@ -301,11 +319,14 @@ def _imported_module(self, node, mod_path, relative):


def project_from_files(
files, func_wrapper=_astroid_wrapper, project_name="no name", black_list=("CVS",)
):
files: list[str],
func_wrapper: _WrapperFuncT = _astroid_wrapper,
project_name: str = "no name",
black_list: tuple[str, ...] = ("CVS",),
) -> Project:
"""Return a Project from a list of files or modules."""
# build the project representation
astroid_manager = astroid.manager.AstroidManager()
astroid_manager = astroid.MANAGER
project = Project(project_name)
for something in files:
if not os.path.exists(something):
Expand Down
9 changes: 5 additions & 4 deletions pylint/pyreverse/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import argparse
import itertools
import os
from collections.abc import Iterable

from astroid import modutils, nodes

Expand Down Expand Up @@ -56,18 +57,18 @@ def __init__(self, config: argparse.Namespace) -> None:
)
self.used_colors: dict[str, str] = {}

def write(self, diadefs):
def write(self, diadefs: Iterable[ClassDiagram | PackageDiagram]) -> None:
"""Write files for <project> according to <diadefs>."""
for diagram in diadefs:
basename = diagram.title.strip().replace(" ", "_")
file_name = f"{basename}.{self.config.output_format}"
if os.path.exists(self.config.output_directory):
file_name = os.path.join(self.config.output_directory, file_name)
self.set_printer(file_name, basename)
if diagram.TYPE == "class":
self.write_classes(diagram)
else:
if isinstance(diagram, PackageDiagram):
self.write_packages(diagram)
else:
self.write_classes(diagram)
self.save()

def write_packages(self, diagram: PackageDiagram) -> None:
Expand Down

0 comments on commit c6d5cfe

Please sign in to comment.