Skip to content

Commit

Permalink
Add pybind11 tests for insert typing.List
Browse files Browse the repository at this point in the history
  • Loading branch information
chadrik committed Nov 16, 2023
1 parent 827361f commit ecc6b2e
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 11 deletions.
16 changes: 16 additions & 0 deletions mypy/stubgenc.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,22 @@ def __init__(
self.resort_members = self.is_c_module
super().__init__(_all_, include_private, export_less, include_docstrings)
self.module_name = module_name
self.known_imports.update(
{
"typing": [
"Any",
"Callable",
"ClassVar",
"Dict",
"Iterable",
"Iterator",
"List",
"Optional",
"Tuple",
"Union",
]
}
)

def get_default_function_sig(self, func: object, ctx: FunctionContext) -> FunctionSig:
argspec = None
Expand Down
18 changes: 9 additions & 9 deletions mypy/stubutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,14 @@ def __init__(
self.sig_generators = self.get_sig_generators()
# populated by visit_mypy_file
self.module_name: str = ""
# These are "soft" imports for objects which might appear in annotations but not have
# a corresponding import statement.
self.known_imports = {
"_typeshed": ["Incomplete"],
"typing": ["Any", "TypeVar", "NamedTuple"],
"collections.abc": ["Generator"],
"typing_extensions": ["TypedDict", "ParamSpec", "TypeVarTuple"],
}

def get_sig_generators(self) -> list[SignatureGenerator]:
return []
Expand Down Expand Up @@ -667,15 +675,7 @@ def set_defined_names(self, defined_names: set[str]) -> None:
for name in self._all_ or ():
self.import_tracker.reexport(name)

# These are "soft" imports for objects which might appear in annotations but not have
# a corresponding import statement.
known_imports = {
"_typeshed": ["Incomplete"],
"typing": ["Any", "TypeVar", "NamedTuple"],
"collections.abc": ["Generator"],
"typing_extensions": ["TypedDict", "ParamSpec", "TypeVarTuple"],
}
for pkg, imports in known_imports.items():
for pkg, imports in self.known_imports.items():
for t in imports:
# require=False means that the import won't be added unless require_name() is called
# for the object during generation.
Expand Down
7 changes: 7 additions & 0 deletions test-data/pybind11_mypy_demo/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

#include <cmath>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

Expand Down Expand Up @@ -102,6 +103,11 @@ struct Point {
return distance_to(other.x, other.y);
}

std::vector<double> as_vector()
{
return std::vector<double>{x, y};
}

double x, y;
};

Expand Down Expand Up @@ -134,6 +140,7 @@ void bind_basics(py::module& basics) {
.def(py::init<double, double>(), py::arg("x"), py::arg("y"))
.def("distance_to", py::overload_cast<double, double>(&Point::distance_to, py::const_), py::arg("x"), py::arg("y"))
.def("distance_to", py::overload_cast<const Point&>(&Point::distance_to, py::const_), py::arg("other"))
.def("as_vector", &Point::as_vector)
// Note that the trailing newline is required because the generated docstring
// is concatenated to the signature, e.g.:
// 'some docstring\n(self: pybind11_mypy_demo.basics.Point) -> float\n'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, overload
from typing import ClassVar, List, overload

PI: float
__version__: str
Expand Down Expand Up @@ -73,6 +73,8 @@ class Point:
2. __init__(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> None
"""
def as_vector(self) -> List[float]:
"""as_vector(self: pybind11_mypy_demo.basics.Point) -> List[float]"""
@overload
def distance_to(self, x: float, y: float) -> float:
"""distance_to(*args, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, overload
from typing import ClassVar, List, overload

PI: float
__version__: str
Expand Down Expand Up @@ -47,6 +47,7 @@ class Point:
def __init__(self) -> None: ...
@overload
def __init__(self, x: float, y: float) -> None: ...
def as_vector(self) -> List[float]: ...
@overload
def distance_to(self, x: float, y: float) -> float: ...
@overload
Expand Down

0 comments on commit ecc6b2e

Please sign in to comment.