Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

respect magic trailing commas in return types #3916

Merged
merged 3 commits into from Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Expand Up @@ -18,6 +18,7 @@

- Long type hints are now wrapped in parentheses and properly indented when split across
multiple lines (#3899)
- Magic trailing commas are now respected in return types. (#3916)

### Configuration

Expand Down
36 changes: 35 additions & 1 deletion src/black/linegen.py
Expand Up @@ -573,7 +573,7 @@ def transform_line(
transformers = [string_merge, string_paren_strip]
else:
transformers = []
elif line.is_def:
elif line.is_def and not should_split_funcdef_with_rhs(line, mode):
transformers = [left_hand_split]
else:

Expand Down Expand Up @@ -652,6 +652,40 @@ def _rhs(
yield line


def should_split_funcdef_with_rhs(line: Line, mode: Mode) -> bool:
"""If a funcdef has a magic trailing comma in the return type, then we should first
split the line with rhs to respect the comma.
"""
if Preview.respect_magic_trailing_comma_in_return_type not in mode:
return False

return_type_leaves: List[Leaf] = []
in_return_type = False

for leaf in line.leaves:
if leaf.type == token.COLON:
in_return_type = False
if in_return_type:
return_type_leaves.append(leaf)
if leaf.type == token.RARROW:
in_return_type = True

# using `bracket_split_build_line` will mess with whitespace, so we duplicate a
# couple lines from it.
result = Line(mode=line.mode, depth=line.depth)
leaves_to_track = get_leaves_inside_matching_brackets(return_type_leaves)
for leaf in return_type_leaves:
result.append(
leaf,
preformatted=True,
track_bracket=id(leaf) in leaves_to_track,
)

# we could also return true if the line is too long, and the return type is longer
# than the param list. Or if `should_split_rhs` returns True.
return result.magic_trailing_comma is not None


class _BracketSplitComponent(Enum):
head = auto()
body = auto()
Expand Down
1 change: 1 addition & 0 deletions src/black/mode.py
Expand Up @@ -181,6 +181,7 @@ class Preview(Enum):
string_processing = auto()
parenthesize_conditional_expressions = auto()
parenthesize_long_type_hints = auto()
respect_magic_trailing_comma_in_return_type = auto()
skip_magic_trailing_comma_in_subscript = auto()
wrap_long_dict_values_in_parens = auto()
wrap_multiple_context_managers_in_parens = auto()
Expand Down
11 changes: 11 additions & 0 deletions tests/data/preview/return_annotation_brackets_string.py
Expand Up @@ -2,6 +2,10 @@
def frobnicate() -> "ThisIsTrulyUnreasonablyExtremelyLongClassName | list[ThisIsTrulyUnreasonablyExtremelyLongClassName]":
pass

# splitting the string breaks if there's any parameters
def frobnicate(a) -> "ThisIsTrulyUnreasonablyExtremelyLongClassName | list[ThisIsTrulyUnreasonablyExtremelyLongClassName]":
pass

# output

# Long string example
Expand All @@ -10,3 +14,10 @@ def frobnicate() -> (
" list[ThisIsTrulyUnreasonablyExtremelyLongClassName]"
):
pass


# splitting the string breaks if there's any parameters
def frobnicate(
a,
) -> "ThisIsTrulyUnreasonablyExtremelyLongClassName | list[ThisIsTrulyUnreasonablyExtremelyLongClassName]":
pass
300 changes: 300 additions & 0 deletions tests/data/preview_py_310/funcdef_return_type_trailing_comma.py
@@ -0,0 +1,300 @@
# normal, short, function definition
def foo(a, b) -> tuple[int, float]: ...


# normal, short, function definition w/o return type
def foo(a, b): ...


# no splitting
def foo(a: A, b: B) -> list[p, q]:
pass


# magic trailing comma in param list
def foo(a, b,): ...


# magic trailing comma in nested params in param list
def foo(a, b: tuple[int, float,]): ...


# magic trailing comma in return type, no params
def a() -> tuple[
a,
b,
]: ...


# magic trailing comma in return type, params
def foo(a: A, b: B) -> list[
p,
q,
]:
pass


# magic trailing comma in param list and in return type
def foo(
a: a,
b: b,
) -> list[
a,
a,
]:
pass


# long function definition, param list is longer
def aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa(
bbbbbbbbbbbbbbbbbb,
) -> cccccccccccccccccccccccccccccc: ...


# long function definition, return type is longer
# this should maybe split on rhs?
def aaaaaaaaaaaaaaaaa(bbbbbbbbbbbbbbbbbb) -> list[
Ccccccccccccccccccccccccccccccccccccccccccccccccccc, Dddddd
]: ...


# long return type, no param list
def foo() -> list[
Loooooooooooooooooooooooooooooooooooong,
Loooooooooooooooooooong,
Looooooooooooong,
]: ...


# long function name, no param list, no return value
def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong():
pass


# long function name, no param list
def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong() -> (
list[int, float]
): ...


# long function name, no return value
def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong(
a, b
): ...


# unskippable type hint (??)
def foo(a) -> list[aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa]: # type: ignore
pass


def foo(a) -> list[
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
]: # abpedeifnore
pass

def foo(a, b: list[Bad],): ... # type: ignore

# don't lose any comments (no magic)
def foo( # 1
a, # 2
b) -> list[ # 3
a, # 4
b]: # 5
... # 6


# don't lose any comments (param list magic)
def foo( # 1
a, # 2
b,) -> list[ # 3
a, # 4
b]: # 5
... # 6


# don't lose any comments (return type magic)
def foo( # 1
a, # 2
b) -> list[ # 3
a, # 4
b,]: # 5
... # 6


# don't lose any comments (both magic)
def foo( # 1
a, # 2
b,) -> list[ # 3
a, # 4
b,]: # 5
... # 6

# real life example
def SimplePyFn(
context: hl.GeneratorContext,
buffer_input: Buffer[UInt8, 2],
func_input: Buffer[Int32, 2],
float_arg: Scalar[Float32],
offset: int = 0,
) -> tuple[
Buffer[UInt8, 2],
Buffer[UInt8, 2],
]: ...
# output
# normal, short, function definition
def foo(a, b) -> tuple[int, float]: ...


# normal, short, function definition w/o return type
def foo(a, b): ...


# no splitting
def foo(a: A, b: B) -> list[p, q]:
pass


# magic trailing comma in param list
def foo(
a,
b,
): ...


# magic trailing comma in nested params in param list
def foo(
a,
b: tuple[
int,
float,
],
): ...


# magic trailing comma in return type, no params
def a() -> tuple[
a,
b,
]: ...


# magic trailing comma in return type, params
def foo(a: A, b: B) -> list[
p,
q,
]:
pass


# magic trailing comma in param list and in return type
def foo(
a: a,
b: b,
) -> list[
a,
a,
]:
pass


# long function definition, param list is longer
def aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa(
bbbbbbbbbbbbbbbbbb,
) -> cccccccccccccccccccccccccccccc: ...


# long function definition, return type is longer
# this should maybe split on rhs?
def aaaaaaaaaaaaaaaaa(
bbbbbbbbbbbbbbbbbb,
) -> list[Ccccccccccccccccccccccccccccccccccccccccccccccccccc, Dddddd]: ...


# long return type, no param list
def foo() -> list[
Loooooooooooooooooooooooooooooooooooong,
Loooooooooooooooooooong,
Looooooooooooong,
]: ...


# long function name, no param list, no return value
def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong():
pass


# long function name, no param list
def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong() -> (
list[int, float]
): ...


# long function name, no return value
def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong(
a, b
): ...


# unskippable type hint (??)
def foo(a) -> list[aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa]: # type: ignore
pass


def foo(
a,
) -> list[
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
]: # abpedeifnore
pass


def foo(
a,
b: list[Bad],
): ... # type: ignore


# don't lose any comments (no magic)
def foo(a, b) -> list[a, b]: # 1 # 2 # 3 # 4 # 5
... # 6


# don't lose any comments (param list magic)
def foo( # 1
a, # 2
b,
) -> list[a, b]: # 3 # 4 # 5
... # 6


# don't lose any comments (return type magic)
def foo(a, b) -> list[ # 1 # 2 # 3
a, # 4
b,
]: # 5
... # 6


# don't lose any comments (both magic)
def foo( # 1
a, # 2
b,
) -> list[ # 3
a, # 4
b,
]: # 5
... # 6


# real life example
def SimplePyFn(
context: hl.GeneratorContext,
buffer_input: Buffer[UInt8, 2],
func_input: Buffer[Int32, 2],
float_arg: Scalar[Float32],
offset: int = 0,
) -> tuple[
Buffer[UInt8, 2],
Buffer[UInt8, 2],
]: ...