Skip to content

Commit

Permalink
stubgen: Preserve simple defaults in function signatures (#15355)
Browse files Browse the repository at this point in the history
  • Loading branch information
hamdanal committed Nov 27, 2023
1 parent 1200d1d commit e69c5cd
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 24 deletions.
15 changes: 12 additions & 3 deletions mypy/stubdoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,19 @@ def is_valid_type(s: str) -> bool:
class ArgSig:
"""Signature info for a single argument."""

def __init__(self, name: str, type: str | None = None, default: bool = False):
def __init__(
self,
name: str,
type: str | None = None,
*,
default: bool = False,
default_value: str = "...",
) -> None:
self.name = name
self.type = type
# Does this argument have a default value?
self.default = default
self.default_value = default_value

def is_star_arg(self) -> bool:
return self.name.startswith("*") and not self.name.startswith("**")
Expand All @@ -59,6 +67,7 @@ def __eq__(self, other: Any) -> bool:
self.name == other.name
and self.type == other.type
and self.default == other.default
and self.default_value == other.default_value
)
return False

Expand Down Expand Up @@ -119,10 +128,10 @@ def format_sig(
if arg_type:
arg_def += ": " + arg_type
if arg.default:
arg_def += " = ..."
arg_def += f" = {arg.default_value}"

elif arg.default:
arg_def += "=..."
arg_def += f"={arg.default_value}"

args.append(arg_def)

Expand Down
73 changes: 72 additions & 1 deletion mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
NameExpr,
OpExpr,
OverloadedFuncDef,
SetExpr,
Statement,
StrExpr,
TempNode,
Expand Down Expand Up @@ -491,15 +492,21 @@ def _get_func_args(self, o: FuncDef, ctx: FunctionContext) -> list[ArgSig]:
if kind.is_named() and not any(arg.name.startswith("*") for arg in args):
args.append(ArgSig("*"))

default = "..."
if arg_.initializer:
if not typename:
typename = self.get_str_type_of_node(arg_.initializer, True, False)
potential_default, valid = self.get_str_default_of_node(arg_.initializer)
if valid and len(potential_default) <= 200:
default = potential_default
elif kind == ARG_STAR:
name = f"*{name}"
elif kind == ARG_STAR2:
name = f"**{name}"

args.append(ArgSig(name, typename, default=bool(arg_.initializer)))
args.append(
ArgSig(name, typename, default=bool(arg_.initializer), default_value=default)
)

if ctx.class_info is not None and all(
arg.type is None and arg.default is False for arg in args
Expand Down Expand Up @@ -1234,6 +1241,70 @@ def maybe_unwrap_unary_expr(self, expr: Expression) -> Expression:
# This is some other unary expr, we cannot do anything with it (yet?).
return expr

def get_str_default_of_node(self, rvalue: Expression) -> tuple[str, bool]:
"""Get a string representation of the default value of a node.
Returns a 2-tuple of the default and whether or not it is valid.
"""
if isinstance(rvalue, NameExpr):
if rvalue.name in ("None", "True", "False"):
return rvalue.name, True
elif isinstance(rvalue, (IntExpr, FloatExpr)):
return f"{rvalue.value}", True
elif isinstance(rvalue, UnaryExpr):
if isinstance(rvalue.expr, (IntExpr, FloatExpr)):
return f"{rvalue.op}{rvalue.expr.value}", True
elif isinstance(rvalue, StrExpr):
return repr(rvalue.value), True
elif isinstance(rvalue, BytesExpr):
return "b" + repr(rvalue.value).replace("\\\\", "\\"), True
elif isinstance(rvalue, TupleExpr):
items_defaults = []
for e in rvalue.items:
e_default, valid = self.get_str_default_of_node(e)
if not valid:
break
items_defaults.append(e_default)
else:
closing = ",)" if len(items_defaults) == 1 else ")"
default = "(" + ", ".join(items_defaults) + closing
return default, True
elif isinstance(rvalue, ListExpr):
items_defaults = []
for e in rvalue.items:
e_default, valid = self.get_str_default_of_node(e)
if not valid:
break
items_defaults.append(e_default)
else:
default = "[" + ", ".join(items_defaults) + "]"
return default, True
elif isinstance(rvalue, SetExpr):
items_defaults = []
for e in rvalue.items:
e_default, valid = self.get_str_default_of_node(e)
if not valid:
break
items_defaults.append(e_default)
else:
if items_defaults:
default = "{" + ", ".join(items_defaults) + "}"
return default, True
elif isinstance(rvalue, DictExpr):
items_defaults = []
for k, v in rvalue.items:
if k is None:
break
k_default, k_valid = self.get_str_default_of_node(k)
v_default, v_valid = self.get_str_default_of_node(v)
if not (k_valid and v_valid):
break
items_defaults.append(f"{k_default}: {v_default}")
else:
default = "{" + ", ".join(items_defaults) + "}"
return default, True
return "...", False

def should_reexport(self, name: str, full_module: str, name_is_alias: bool) -> bool:
is_private = self.is_private_name(name, full_module + "." + name)
if (
Expand Down
89 changes: 69 additions & 20 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -27,45 +27,47 @@ def g(arg) -> None: ...
def f(a, b=2): ...
def g(b=-1, c=0): ...
[out]
def f(a, b: int = ...) -> None: ...
def g(b: int = ..., c: int = ...) -> None: ...
def f(a, b: int = 2) -> None: ...
def g(b: int = -1, c: int = 0) -> None: ...

[case testDefaultArgNone]
def f(x=None): ...
[out]
from _typeshed import Incomplete

def f(x: Incomplete | None = ...) -> None: ...
def f(x: Incomplete | None = None) -> None: ...

[case testDefaultArgBool]
def f(x=True, y=False): ...
[out]
def f(x: bool = ..., y: bool = ...) -> None: ...
def f(x: bool = True, y: bool = False) -> None: ...

[case testDefaultArgBool_inspect]
def f(x=True, y=False): ...
[out]
def f(x: bool = ..., y: bool = ...): ...

[case testDefaultArgStr]
def f(x='foo'): ...
def f(x='foo',y="how's quotes"): ...
[out]
def f(x: str = ...) -> None: ...
def f(x: str = 'foo', y: str = "how's quotes") -> None: ...

[case testDefaultArgStr_inspect]
def f(x='foo'): ...
[out]
def f(x: str = ...): ...

[case testDefaultArgBytes]
def f(x=b'foo'): ...
def f(x=b'foo',y=b"what's up",z=b'\xc3\xa0 la une'): ...
[out]
def f(x: bytes = ...) -> None: ...
def f(x: bytes = b'foo', y: bytes = b"what's up", z: bytes = b'\xc3\xa0 la une') -> None: ...

[case testDefaultArgFloat]
def f(x=1.2): ...
def f(x=1.2,y=1e-6,z=0.0,w=-0.0,v=+1.0): ...
def g(x=float("nan"), y=float("inf"), z=float("-inf")): ...
[out]
def f(x: float = ...) -> None: ...
def f(x: float = 1.2, y: float = 1e-06, z: float = 0.0, w: float = -0.0, v: float = +1.0) -> None: ...
def g(x=..., y=..., z=...) -> None: ...

[case testDefaultArgOther]
def f(x=ord): ...
Expand Down Expand Up @@ -126,10 +128,10 @@ def i(a, *, b=1): ...
def j(a, *, b=1, **c): ...
[out]
def f(a, *b, **c) -> None: ...
def g(a, *b, c: int = ...) -> None: ...
def h(a, *b, c: int = ..., **d) -> None: ...
def i(a, *, b: int = ...) -> None: ...
def j(a, *, b: int = ..., **c) -> None: ...
def g(a, *b, c: int = 1) -> None: ...
def h(a, *b, c: int = 1, **d) -> None: ...
def i(a, *, b: int = 1) -> None: ...
def j(a, *, b: int = 1, **c) -> None: ...

[case testClass]
class A:
Expand Down Expand Up @@ -356,8 +358,8 @@ y: Incomplete
def f(x, *, y=1): ...
def g(x, *, y=1, z=2): ...
[out]
def f(x, *, y: int = ...) -> None: ...
def g(x, *, y: int = ..., z: int = ...) -> None: ...
def f(x, *, y: int = 1) -> None: ...
def g(x, *, y: int = 1, z: int = 2) -> None: ...

[case testProperty]
class A:
Expand Down Expand Up @@ -1285,8 +1287,8 @@ from _typeshed import Incomplete

class A:
x: Incomplete
def __init__(self, a: Incomplete | None = ...) -> None: ...
def method(self, a: Incomplete | None = ...) -> None: ...
def __init__(self, a: Incomplete | None = None) -> None: ...
def method(self, a: Incomplete | None = None) -> None: ...

[case testAnnotationImportsFrom]
import foo
Expand Down Expand Up @@ -2514,7 +2516,7 @@ from _typeshed import Incomplete as _Incomplete

Y: _Incomplete

def g(x: _Incomplete | None = ...) -> None: ...
def g(x: _Incomplete | None = None) -> None: ...

x: _Incomplete

Expand Down Expand Up @@ -3503,7 +3505,7 @@ class P(Protocol):
[case testNonDefaultKeywordOnlyArgAfterAsterisk]
def func(*, non_default_kwarg: bool, default_kwarg: bool = True): ...
[out]
def func(*, non_default_kwarg: bool, default_kwarg: bool = ...): ...
def func(*, non_default_kwarg: bool, default_kwarg: bool = True): ...

[case testNestedGenerator]
def f1():
Expand Down Expand Up @@ -3909,6 +3911,53 @@ def gen2() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ...
class X(_Incomplete): ...
class Y(_Incomplete): ...

[case testIgnoreLongDefaults]
def f(x='abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'): ...

def g(x=b'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'): ...

def h(x=123456789012345678901234567890123456789012345678901234567890\
123456789012345678901234567890123456789012345678901234567890\
123456789012345678901234567890123456789012345678901234567890\
123456789012345678901234567890123456789012345678901234567890): ...

[out]
def f(x: str = ...) -> None: ...
def g(x: bytes = ...) -> None: ...
def h(x: int = ...) -> None: ...

[case testDefaultsOfBuiltinContainers]
def f(x=(), y=(1,), z=(1, 2)): ...
def g(x=[], y=[1, 2]): ...
def h(x={}, y={1: 2, 3: 4}): ...
def i(x={1, 2, 3}): ...
def j(x=[(1,"a"), (2,"b")]): ...

[out]
def f(x=(), y=(1,), z=(1, 2)) -> None: ...
def g(x=[], y=[1, 2]) -> None: ...
def h(x={}, y={1: 2, 3: 4}) -> None: ...
def i(x={1, 2, 3}) -> None: ...
def j(x=[(1, 'a'), (2, 'b')]) -> None: ...

[case testDefaultsOfBuiltinContainersWithNonTrivialContent]
def f(x=(1, u.v), y=(k(),), z=(w,)): ...
def g(x=[1, u.v], y=[k()], z=[w]): ...
def h(x={1: u.v}, y={k(): 2}, z={m: m}, w={**n}): ...
def i(x={u.v, 2}, y={3, k()}, z={w}): ...

[out]
def f(x=..., y=..., z=...) -> None: ...
def g(x=..., y=..., z=...) -> None: ...
def h(x=..., y=..., z=..., w=...) -> None: ...
def i(x=..., y=..., z=...) -> None: ...

[case testDataclass]
import dataclasses
import dataclasses as dcs
Expand Down

0 comments on commit e69c5cd

Please sign in to comment.