Skip to content

Commit

Permalink
respect magic trailing commas in return types (#3916)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakkdl committed Oct 4, 2023
1 parent 947bd38 commit 36078bc
Show file tree
Hide file tree
Showing 6 changed files with 361 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGES.md
Expand Up @@ -18,6 +18,7 @@

- Long type hints are now wrapped in parentheses and properly indented when split across
multiple lines (#3899)
- Magic trailing commas are now respected in return types. (#3916)

### Configuration

Expand Down
36 changes: 35 additions & 1 deletion src/black/linegen.py
Expand Up @@ -573,7 +573,7 @@ def transform_line(
transformers = [string_merge, string_paren_strip]
else:
transformers = []
elif line.is_def:
elif line.is_def and not should_split_funcdef_with_rhs(line, mode):
transformers = [left_hand_split]
else:

Expand Down Expand Up @@ -652,6 +652,40 @@ def _rhs(
yield line


def should_split_funcdef_with_rhs(line: Line, mode: Mode) -> bool:
"""If a funcdef has a magic trailing comma in the return type, then we should first
split the line with rhs to respect the comma.
"""
if Preview.respect_magic_trailing_comma_in_return_type not in mode:
return False

return_type_leaves: List[Leaf] = []
in_return_type = False

for leaf in line.leaves:
if leaf.type == token.COLON:
in_return_type = False
if in_return_type:
return_type_leaves.append(leaf)
if leaf.type == token.RARROW:
in_return_type = True

# using `bracket_split_build_line` will mess with whitespace, so we duplicate a
# couple lines from it.
result = Line(mode=line.mode, depth=line.depth)
leaves_to_track = get_leaves_inside_matching_brackets(return_type_leaves)
for leaf in return_type_leaves:
result.append(
leaf,
preformatted=True,
track_bracket=id(leaf) in leaves_to_track,
)

# we could also return true if the line is too long, and the return type is longer
# than the param list. Or if `should_split_rhs` returns True.
return result.magic_trailing_comma is not None


class _BracketSplitComponent(Enum):
head = auto()
body = auto()
Expand Down
1 change: 1 addition & 0 deletions src/black/mode.py
Expand Up @@ -181,6 +181,7 @@ class Preview(Enum):
string_processing = auto()
parenthesize_conditional_expressions = auto()
parenthesize_long_type_hints = auto()
respect_magic_trailing_comma_in_return_type = auto()
skip_magic_trailing_comma_in_subscript = auto()
wrap_long_dict_values_in_parens = auto()
wrap_multiple_context_managers_in_parens = auto()
Expand Down
11 changes: 11 additions & 0 deletions tests/data/preview/return_annotation_brackets_string.py
Expand Up @@ -2,6 +2,10 @@
def frobnicate() -> "ThisIsTrulyUnreasonablyExtremelyLongClassName | list[ThisIsTrulyUnreasonablyExtremelyLongClassName]":
pass

# splitting the string breaks if there's any parameters
def frobnicate(a) -> "ThisIsTrulyUnreasonablyExtremelyLongClassName | list[ThisIsTrulyUnreasonablyExtremelyLongClassName]":
pass

# output

# Long string example
Expand All @@ -10,3 +14,10 @@ def frobnicate() -> (
" list[ThisIsTrulyUnreasonablyExtremelyLongClassName]"
):
pass


# splitting the string breaks if there's any parameters
def frobnicate(
a,
) -> "ThisIsTrulyUnreasonablyExtremelyLongClassName | list[ThisIsTrulyUnreasonablyExtremelyLongClassName]":
pass
300 changes: 300 additions & 0 deletions tests/data/preview_py_310/funcdef_return_type_trailing_comma.py
@@ -0,0 +1,300 @@
# normal, short, function definition
def foo(a, b) -> tuple[int, float]: ...


# normal, short, function definition w/o return type
def foo(a, b): ...


# no splitting
def foo(a: A, b: B) -> list[p, q]:
pass


# magic trailing comma in param list
def foo(a, b,): ...


# magic trailing comma in nested params in param list
def foo(a, b: tuple[int, float,]): ...


# magic trailing comma in return type, no params
def a() -> tuple[
a,
b,
]: ...


# magic trailing comma in return type, params
def foo(a: A, b: B) -> list[
p,
q,
]:
pass


# magic trailing comma in param list and in return type
def foo(
a: a,
b: b,
) -> list[
a,
a,
]:
pass


# long function definition, param list is longer
def aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa(
bbbbbbbbbbbbbbbbbb,
) -> cccccccccccccccccccccccccccccc: ...


# long function definition, return type is longer
# this should maybe split on rhs?
def aaaaaaaaaaaaaaaaa(bbbbbbbbbbbbbbbbbb) -> list[
Ccccccccccccccccccccccccccccccccccccccccccccccccccc, Dddddd
]: ...


# long return type, no param list
def foo() -> list[
Loooooooooooooooooooooooooooooooooooong,
Loooooooooooooooooooong,
Looooooooooooong,
]: ...


# long function name, no param list, no return value
def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong():
pass


# long function name, no param list
def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong() -> (
list[int, float]
): ...


# long function name, no return value
def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong(
a, b
): ...


# unskippable type hint (??)
def foo(a) -> list[aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa]: # type: ignore
pass


def foo(a) -> list[
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
]: # abpedeifnore
pass

def foo(a, b: list[Bad],): ... # type: ignore

# don't lose any comments (no magic)
def foo( # 1
a, # 2
b) -> list[ # 3
a, # 4
b]: # 5
... # 6


# don't lose any comments (param list magic)
def foo( # 1
a, # 2
b,) -> list[ # 3
a, # 4
b]: # 5
... # 6


# don't lose any comments (return type magic)
def foo( # 1
a, # 2
b) -> list[ # 3
a, # 4
b,]: # 5
... # 6


# don't lose any comments (both magic)
def foo( # 1
a, # 2
b,) -> list[ # 3
a, # 4
b,]: # 5
... # 6

# real life example
def SimplePyFn(
context: hl.GeneratorContext,
buffer_input: Buffer[UInt8, 2],
func_input: Buffer[Int32, 2],
float_arg: Scalar[Float32],
offset: int = 0,
) -> tuple[
Buffer[UInt8, 2],
Buffer[UInt8, 2],
]: ...
# output
# normal, short, function definition
def foo(a, b) -> tuple[int, float]: ...


# normal, short, function definition w/o return type
def foo(a, b): ...


# no splitting
def foo(a: A, b: B) -> list[p, q]:
pass


# magic trailing comma in param list
def foo(
a,
b,
): ...


# magic trailing comma in nested params in param list
def foo(
a,
b: tuple[
int,
float,
],
): ...


# magic trailing comma in return type, no params
def a() -> tuple[
a,
b,
]: ...


# magic trailing comma in return type, params
def foo(a: A, b: B) -> list[
p,
q,
]:
pass


# magic trailing comma in param list and in return type
def foo(
a: a,
b: b,
) -> list[
a,
a,
]:
pass


# long function definition, param list is longer
def aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa(
bbbbbbbbbbbbbbbbbb,
) -> cccccccccccccccccccccccccccccc: ...


# long function definition, return type is longer
# this should maybe split on rhs?
def aaaaaaaaaaaaaaaaa(
bbbbbbbbbbbbbbbbbb,
) -> list[Ccccccccccccccccccccccccccccccccccccccccccccccccccc, Dddddd]: ...


# long return type, no param list
def foo() -> list[
Loooooooooooooooooooooooooooooooooooong,
Loooooooooooooooooooong,
Looooooooooooong,
]: ...


# long function name, no param list, no return value
def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong():
pass


# long function name, no param list
def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong() -> (
list[int, float]
): ...


# long function name, no return value
def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong(
a, b
): ...


# unskippable type hint (??)
def foo(a) -> list[aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa]: # type: ignore
pass


def foo(
a,
) -> list[
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
]: # abpedeifnore
pass


def foo(
a,
b: list[Bad],
): ... # type: ignore


# don't lose any comments (no magic)
def foo(a, b) -> list[a, b]: # 1 # 2 # 3 # 4 # 5
... # 6


# don't lose any comments (param list magic)
def foo( # 1
a, # 2
b,
) -> list[a, b]: # 3 # 4 # 5
... # 6


# don't lose any comments (return type magic)
def foo(a, b) -> list[ # 1 # 2 # 3
a, # 4
b,
]: # 5
... # 6


# don't lose any comments (both magic)
def foo( # 1
a, # 2
b,
) -> list[ # 3
a, # 4
b,
]: # 5
... # 6


# real life example
def SimplePyFn(
context: hl.GeneratorContext,
buffer_input: Buffer[UInt8, 2],
func_input: Buffer[Int32, 2],
float_arg: Scalar[Float32],
offset: int = 0,
) -> tuple[
Buffer[UInt8, 2],
Buffer[UInt8, 2],
]: ...

0 comments on commit 36078bc

Please sign in to comment.