Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stubgen: Preserve simple defaults in function signatures #15355

Merged
merged 9 commits into from
Nov 27, 2023
74 changes: 71 additions & 3 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
NameExpr,
OpExpr,
OverloadedFuncDef,
SetExpr,
Statement,
StrExpr,
TempNode,
Expand Down Expand Up @@ -762,14 +763,17 @@ def visit_func_def(self, o: FuncDef) -> None:
args.append("*")

if arg_.initializer:
default, valid = self.get_str_default_of_node(arg_.initializer)
if not valid or len(default) > 200:
default = "..."
if not annotation:
typename = self.get_str_type_of_node(arg_.initializer, True, False)
if typename == "":
annotation = "=..."
annotation = f"={default}"
else:
annotation = f": {typename} = ..."
annotation = f": {typename} = {default}"
else:
annotation += " = ..."
annotation += f" = {default}"
arg = name + annotation
elif kind == ARG_STAR:
arg = f"*{name}{annotation}"
Expand Down Expand Up @@ -1524,6 +1528,70 @@ def maybe_unwrap_unary_expr(self, expr: Expression) -> Expression:
# This is some other unary expr, we cannot do anything with it (yet?).
return expr

def get_str_default_of_node(self, rvalue: Expression) -> tuple[str, bool]:
"""Get a string representation of the default value of a node.

Returns a 2-tuple of the default and whether or not it is valid.
"""
if isinstance(rvalue, NameExpr):
if rvalue.name in ("None", "True", "False"):
return rvalue.name, True
elif isinstance(rvalue, (IntExpr, FloatExpr)):
return f"{rvalue.value}", True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this work correctly for NaN and infinity?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this work correctly for NaN and infinity?

Do these have a literal syntax? float("nan") is a call expression that is ignored when used as a default value in the current implementation. I'll add a test.

elif isinstance(rvalue, UnaryExpr):
if isinstance(rvalue.expr, (IntExpr, FloatExpr)):
return f"{rvalue.op}{rvalue.expr.value}", True
elif isinstance(rvalue, StrExpr):
return repr(rvalue.value), True
elif isinstance(rvalue, BytesExpr):
return f"b{rvalue.value!r}", True
elif isinstance(rvalue, TupleExpr):
items_defaults = []
for e in rvalue.items:
e_default, valid = self.get_str_default_of_node(e)
if not valid:
break
items_defaults.append(e_default)
else:
closing = ",)" if len(items_defaults) == 1 else ")"
default = "(" + ", ".join(items_defaults) + closing
return default, True
elif isinstance(rvalue, ListExpr):
items_defaults = []
for e in rvalue.items:
e_default, valid = self.get_str_default_of_node(e)
if not valid:
break
items_defaults.append(e_default)
else:
default = "[" + ", ".join(items_defaults) + "]"
return default, True
elif isinstance(rvalue, SetExpr):
items_defaults = []
for e in rvalue.items:
e_default, valid = self.get_str_default_of_node(e)
if not valid:
break
items_defaults.append(e_default)
else:
if items_defaults:
default = "{" + ", ".join(items_defaults) + "}"
return default, True
elif isinstance(rvalue, DictExpr):
items_defaults = []
for k, v in rvalue.items:
if k is None:
break
k_default, k_valid = self.get_str_default_of_node(k)
v_default, v_valid = self.get_str_default_of_node(v)
if not (k_valid and v_valid):
break
items_defaults.append(f"{k_default}: {v_default}")
else:
default = "{" + ", ".join(items_defaults) + "}"
return default, True
return "...", False

def print_annotation(self, t: Type) -> str:
printer = AnnotationPrinter(self)
return t.accept(printer)
Expand Down
87 changes: 67 additions & 20 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -22,35 +22,35 @@ def g(arg) -> None: ...
def f(a, b=2): ...
def g(b=-1, c=0): ...
[out]
def f(a, b: int = ...) -> None: ...
def g(b: int = ..., c: int = ...) -> None: ...
def f(a, b: int = 2) -> None: ...
def g(b: int = -1, c: int = 0) -> None: ...

[case testDefaultArgNone]
def f(x=None): ...
[out]
from _typeshed import Incomplete

def f(x: Incomplete | None = ...) -> None: ...
def f(x: Incomplete | None = None) -> None: ...

[case testDefaultArgBool]
def f(x=True, y=False): ...
[out]
def f(x: bool = ..., y: bool = ...) -> None: ...
def f(x: bool = True, y: bool = False) -> None: ...

[case testDefaultArgStr]
def f(x='foo'): ...
def f(x='foo',y="how's quotes"): ...
[out]
def f(x: str = ...) -> None: ...
def f(x: str = 'foo', y: str = "how's quotes") -> None: ...

[case testDefaultArgBytes]
def f(x=b'foo'): ...
def f(x=b'foo',y=b"what's up"): ...
[out]
def f(x: bytes = ...) -> None: ...
def f(x: bytes = b'foo', y: bytes = b"what's up") -> None: ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add tests for bytes containing non-ASCII characters? Not convinced that would be handled correctly.


[case testDefaultArgFloat]
def f(x=1.2): ...
def f(x=1.2,y=1e-6,z=0.0,w=-0.0,v=+1.0): ...
[out]
def f(x: float = ...) -> None: ...
def f(x: float = 1.2, y: float = 1e-06, z: float = 0.0, w: float = -0.0, v: float = +1.0) -> None: ...

[case testDefaultArgOther]
def f(x=ord): ...
Expand Down Expand Up @@ -111,10 +111,10 @@ def i(a, *, b=1): ...
def j(a, *, b=1, **c): ...
[out]
def f(a, *b, **c) -> None: ...
def g(a, *b, c: int = ...) -> None: ...
def h(a, *b, c: int = ..., **d) -> None: ...
def i(a, *, b: int = ...) -> None: ...
def j(a, *, b: int = ..., **c) -> None: ...
def g(a, *b, c: int = 1) -> None: ...
def h(a, *b, c: int = 1, **d) -> None: ...
def i(a, *, b: int = 1) -> None: ...
def j(a, *, b: int = 1, **c) -> None: ...

[case testClass]
class A:
Expand Down Expand Up @@ -340,8 +340,8 @@ y: Incomplete
def f(x, *, y=1): ...
def g(x, *, y=1, z=2): ...
[out]
def f(x, *, y: int = ...) -> None: ...
def g(x, *, y: int = ..., z: int = ...) -> None: ...
def f(x, *, y: int = 1) -> None: ...
def g(x, *, y: int = 1, z: int = 2) -> None: ...

[case testProperty]
class A:
Expand Down Expand Up @@ -1081,8 +1081,8 @@ from _typeshed import Incomplete

class A:
x: Incomplete
def __init__(self, a: Incomplete | None = ...) -> None: ...
def method(self, a: Incomplete | None = ...) -> None: ...
def __init__(self, a: Incomplete | None = None) -> None: ...
def method(self, a: Incomplete | None = None) -> None: ...

[case testAnnotationImportsFrom]
import foo
Expand Down Expand Up @@ -2258,7 +2258,7 @@ from _typeshed import Incomplete as _Incomplete

Y: _Incomplete

def g(x: _Incomplete | None = ...) -> None: ...
def g(x: _Incomplete | None = None) -> None: ...

x: _Incomplete

Expand Down Expand Up @@ -3168,7 +3168,7 @@ class P(Protocol):
[case testNonDefaultKeywordOnlyArgAfterAsterisk]
def func(*, non_default_kwarg: bool, default_kwarg: bool = True): ...
[out]
def func(*, non_default_kwarg: bool, default_kwarg: bool = ...): ...
def func(*, non_default_kwarg: bool, default_kwarg: bool = True): ...

[case testNestedGenerator]
def f1():
Expand Down Expand Up @@ -3513,6 +3513,53 @@ def gen2() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ...
class X(_Incomplete): ...
class Y(_Incomplete): ...

[case testIgnoreLongDefaults]
def f(x='abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'): ...

def g(x=b'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'): ...

def h(x=123456789012345678901234567890123456789012345678901234567890\
123456789012345678901234567890123456789012345678901234567890\
123456789012345678901234567890123456789012345678901234567890\
123456789012345678901234567890123456789012345678901234567890): ...

[out]
def f(x: str = ...) -> None: ...
def g(x: bytes = ...) -> None: ...
def h(x: int = ...) -> None: ...

[case testDefaultsOfBuiltinContainers]
def f(x=(), y=(1,), z=(1, 2)): ...
def g(x=[], y=[1, 2]): ...
def h(x={}, y={1: 2, 3: 4}): ...
def i(x={1, 2, 3}): ...
def j(x=[(1,"a"), (2,"b")]): ...

[out]
def f(x=(), y=(1,), z=(1, 2)) -> None: ...
def g(x=[], y=[1, 2]) -> None: ...
def h(x={}, y={1: 2, 3: 4}) -> None: ...
def i(x={1, 2, 3}) -> None: ...
def j(x=[(1, 'a'), (2, 'b')]) -> None: ...

[case testDefaultsOfBuiltinContainersWithNonTrivialContent]
def f(x=(1, u.v), y=(k(),), z=(w,)): ...
def g(x=[1, u.v], y=[k()], z=[w]): ...
def h(x={1: u.v}, y={k(): 2}, z={m: m}, w={**n}): ...
def i(x={u.v, 2}, y={3, k()}, z={w}): ...

[out]
def f(x=..., y=..., z=...) -> None: ...
def g(x=..., y=..., z=...) -> None: ...
def h(x=..., y=..., z=..., w=...) -> None: ...
def i(x=..., y=..., z=...) -> None: ...

[case testDataclass]
import dataclasses
import dataclasses as dcs
Expand Down