From 5e9bb108c66f37fa04638d5c369d4b4371b4038a Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 3 Oct 2023 13:05:45 +0200
Subject: [PATCH 1/3] respect magic trailing commas in return types
---
CHANGES.md | 1 +
src/black/linegen.py | 36 ++-
.../return_annotation_brackets_string.py | 11 +
.../funcdef_return_type_trailing_comma.py | 300 ++++++++++++++++++
.../return_annotation_brackets.py | 13 +
5 files changed, 360 insertions(+), 1 deletion(-)
create mode 100644 tests/data/preview_py_310/funcdef_return_type_trailing_comma.py
diff --git a/CHANGES.md b/CHANGES.md
index 5e518497c92..ce928b7611b 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -18,6 +18,7 @@
- Long type hints are now wrapped in parentheses and properly indented when split across
multiple lines (#3899)
+- Magic trailing commas are now respected in return types. (#3018)
### Configuration
diff --git a/src/black/linegen.py b/src/black/linegen.py
index 9ddd4619f69..a0237001344 100644
--- a/src/black/linegen.py
+++ b/src/black/linegen.py
@@ -573,7 +573,7 @@ def transform_line(
transformers = [string_merge, string_paren_strip]
else:
transformers = []
- elif line.is_def:
+ elif line.is_def and not should_split_funcdef_with_rhs(line, mode):
transformers = [left_hand_split]
else:
@@ -652,6 +652,40 @@ def _rhs(
yield line
+def should_split_funcdef_with_rhs(line: Line, mode: Mode) -> bool:
+ """If a funcdef has a magic trailing comma in the return type, then we should first
+ split the line with rhs to respect the comma.
+ """
+ if Preview.string_processing not in mode:
+ return False
+
+ return_type_leaves: List[Leaf] = []
+ in_return_type = False
+
+ for leaf in line.leaves:
+ if leaf.type == token.COLON:
+ in_return_type = False
+ if in_return_type:
+ return_type_leaves.append(leaf)
+ if leaf.type == token.RARROW:
+ in_return_type = True
+
+ # using `bracket_split_build_line` will mess with whitespace, so we duplicate a
+ # couple lines from it.
+ result = Line(mode=line.mode, depth=line.depth)
+ leaves_to_track = get_leaves_inside_matching_brackets(return_type_leaves)
+ for leaf in return_type_leaves:
+ result.append(
+ leaf,
+ preformatted=True,
+ track_bracket=id(leaf) in leaves_to_track,
+ )
+
+ # we could also return true if the line is too long, and the return type is longer
+ # than the param list. Or if `should_split_rhs` returns True.
+ return result.magic_trailing_comma is not None
+
+
class _BracketSplitComponent(Enum):
head = auto()
body = auto()
diff --git a/tests/data/preview/return_annotation_brackets_string.py b/tests/data/preview/return_annotation_brackets_string.py
index 6978829fd5c..9148bd045bc 100644
--- a/tests/data/preview/return_annotation_brackets_string.py
+++ b/tests/data/preview/return_annotation_brackets_string.py
@@ -2,6 +2,10 @@
def frobnicate() -> "ThisIsTrulyUnreasonablyExtremelyLongClassName | list[ThisIsTrulyUnreasonablyExtremelyLongClassName]":
pass
+# splitting the string breaks if there's any parameters
+def frobnicate(a) -> "ThisIsTrulyUnreasonablyExtremelyLongClassName | list[ThisIsTrulyUnreasonablyExtremelyLongClassName]":
+ pass
+
# output
# Long string example
@@ -10,3 +14,10 @@ def frobnicate() -> (
" list[ThisIsTrulyUnreasonablyExtremelyLongClassName]"
):
pass
+
+
+# splitting the string breaks if there's any parameters
+def frobnicate(
+ a,
+) -> "ThisIsTrulyUnreasonablyExtremelyLongClassName | list[ThisIsTrulyUnreasonablyExtremelyLongClassName]":
+ pass
diff --git a/tests/data/preview_py_310/funcdef_return_type_trailing_comma.py b/tests/data/preview_py_310/funcdef_return_type_trailing_comma.py
new file mode 100644
index 00000000000..15db772f01e
--- /dev/null
+++ b/tests/data/preview_py_310/funcdef_return_type_trailing_comma.py
@@ -0,0 +1,300 @@
+# normal, short, function definition
+def foo(a, b) -> tuple[int, float]: ...
+
+
+# normal, short, function definition w/o return type
+def foo(a, b): ...
+
+
+# no splitting
+def foo(a: A, b: B) -> list[p, q]:
+ pass
+
+
+# magic trailing comma in param list
+def foo(a, b,): ...
+
+
+# magic trailing comma in nested params in param list
+def foo(a, b: tuple[int, float,]): ...
+
+
+# magic trailing comma in return type, no params
+def a() -> tuple[
+ a,
+ b,
+]: ...
+
+
+# magic trailing comma in return type, params
+def foo(a: A, b: B) -> list[
+ p,
+ q,
+]:
+ pass
+
+
+# magic trailing comma in param list and in return type
+def foo(
+ a: a,
+ b: b,
+) -> list[
+ a,
+ a,
+]:
+ pass
+
+
+# long function definition, param list is longer
+def aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa(
+ bbbbbbbbbbbbbbbbbb,
+) -> cccccccccccccccccccccccccccccc: ...
+
+
+# long function definition, return type is longer
+# this should maybe split on rhs?
+def aaaaaaaaaaaaaaaaa(bbbbbbbbbbbbbbbbbb) -> list[
+ Ccccccccccccccccccccccccccccccccccccccccccccccccccc, Dddddd
+]: ...
+
+
+# long return type, no param list
+def foo() -> list[
+ Loooooooooooooooooooooooooooooooooooong,
+ Loooooooooooooooooooong,
+ Looooooooooooong,
+]: ...
+
+
+# long function name, no param list, no return value
+def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong():
+ pass
+
+
+# long function name, no param list
+def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong() -> (
+ list[int, float]
+): ...
+
+
+# long function name, no return value
+def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong(
+ a, b
+): ...
+
+
+# unskippable type hint (??)
+def foo(a) -> list[aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa]: # type: ignore
+ pass
+
+
+def foo(a) -> list[
+ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+]: # abpedeifnore
+ pass
+
+def foo(a, b: list[Bad],): ... # type: ignore
+
+# don't lose any comments (no magic)
+def foo( # 1
+ a, # 2
+ b) -> list[ # 3
+ a, # 4
+ b]: # 5
+ ... # 6
+
+
+# don't lose any comments (param list magic)
+def foo( # 1
+ a, # 2
+ b,) -> list[ # 3
+ a, # 4
+ b]: # 5
+ ... # 6
+
+
+# don't lose any comments (return type magic)
+def foo( # 1
+ a, # 2
+ b) -> list[ # 3
+ a, # 4
+ b,]: # 5
+ ... # 6
+
+
+# don't lose any comments (both magic)
+def foo( # 1
+ a, # 2
+ b,) -> list[ # 3
+ a, # 4
+ b,]: # 5
+ ... # 6
+
+# real life example
+def SimplePyFn(
+ context: hl.GeneratorContext,
+ buffer_input: Buffer[UInt8, 2],
+ func_input: Buffer[Int32, 2],
+ float_arg: Scalar[Float32],
+ offset: int = 0,
+) -> tuple[
+ Buffer[UInt8, 2],
+ Buffer[UInt8, 2],
+]: ...
+# output
+# normal, short, function definition
+def foo(a, b) -> tuple[int, float]: ...
+
+
+# normal, short, function definition w/o return type
+def foo(a, b): ...
+
+
+# no splitting
+def foo(a: A, b: B) -> list[p, q]:
+ pass
+
+
+# magic trailing comma in param list
+def foo(
+ a,
+ b,
+): ...
+
+
+# magic trailing comma in nested params in param list
+def foo(
+ a,
+ b: tuple[
+ int,
+ float,
+ ],
+): ...
+
+
+# magic trailing comma in return type, no params
+def a() -> tuple[
+ a,
+ b,
+]: ...
+
+
+# magic trailing comma in return type, params
+def foo(a: A, b: B) -> list[
+ p,
+ q,
+]:
+ pass
+
+
+# magic trailing comma in param list and in return type
+def foo(
+ a: a,
+ b: b,
+) -> list[
+ a,
+ a,
+]:
+ pass
+
+
+# long function definition, param list is longer
+def aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa(
+ bbbbbbbbbbbbbbbbbb,
+) -> cccccccccccccccccccccccccccccc: ...
+
+
+# long function definition, return type is longer
+# this should maybe split on rhs?
+def aaaaaaaaaaaaaaaaa(
+ bbbbbbbbbbbbbbbbbb,
+) -> list[Ccccccccccccccccccccccccccccccccccccccccccccccccccc, Dddddd]: ...
+
+
+# long return type, no param list
+def foo() -> list[
+ Loooooooooooooooooooooooooooooooooooong,
+ Loooooooooooooooooooong,
+ Looooooooooooong,
+]: ...
+
+
+# long function name, no param list, no return value
+def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong():
+ pass
+
+
+# long function name, no param list
+def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong() -> (
+ list[int, float]
+): ...
+
+
+# long function name, no return value
+def thiiiiiiiiiiiiiiiiiis_iiiiiiiiiiiiiiiiiiiiiiiiiiiiiis_veeeeeeeeeeeeeeeeeeeeeeery_looooooong(
+ a, b
+): ...
+
+
+# unskippable type hint (??)
+def foo(a) -> list[aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa]: # type: ignore
+ pass
+
+
+def foo(
+ a,
+) -> list[
+ aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa, aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
+]: # abpedeifnore
+ pass
+
+
+def foo(
+ a,
+ b: list[Bad],
+): ... # type: ignore
+
+
+# don't lose any comments (no magic)
+def foo(a, b) -> list[a, b]: # 1 # 2 # 3 # 4 # 5
+ ... # 6
+
+
+# don't lose any comments (param list magic)
+def foo( # 1
+ a, # 2
+ b,
+) -> list[a, b]: # 3 # 4 # 5
+ ... # 6
+
+
+# don't lose any comments (return type magic)
+def foo(a, b) -> list[ # 1 # 2 # 3
+ a, # 4
+ b,
+]: # 5
+ ... # 6
+
+
+# don't lose any comments (both magic)
+def foo( # 1
+ a, # 2
+ b,
+) -> list[ # 3
+ a, # 4
+ b,
+]: # 5
+ ... # 6
+
+
+# real life example
+def SimplePyFn(
+ context: hl.GeneratorContext,
+ buffer_input: Buffer[UInt8, 2],
+ func_input: Buffer[Int32, 2],
+ float_arg: Scalar[Float32],
+ offset: int = 0,
+) -> tuple[
+ Buffer[UInt8, 2],
+ Buffer[UInt8, 2],
+]: ...
diff --git a/tests/data/simple_cases/return_annotation_brackets.py b/tests/data/simple_cases/return_annotation_brackets.py
index 265c30220d8..8509ecdb92c 100644
--- a/tests/data/simple_cases/return_annotation_brackets.py
+++ b/tests/data/simple_cases/return_annotation_brackets.py
@@ -87,6 +87,11 @@ def foo() -> tuple[loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooo
def foo() -> tuple[int, int, int,]:
return 2
+# Magic trailing comma example, with params
+# this is broken - the trailing comma is transferred to the param list. Fixed in preview
+def foo(a,b) -> tuple[int, int, int,]:
+ return 2
+
# output
# Control
def double(a: int) -> int:
@@ -208,3 +213,11 @@ def foo() -> (
]
):
return 2
+
+
+# Magic trailing comma example, with params
+# this is broken - the trailing comma is transferred to the param list. Fixed in preview
+def foo(
+ a, b
+) -> tuple[int, int, int,]:
+ return 2
From 1978872a571ee8bc4dc063fd0bea5e1019a429bd Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 3 Oct 2023 13:17:56 +0200
Subject: [PATCH 2/3] changelog reference PR instead of issue
---
CHANGES.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/CHANGES.md b/CHANGES.md
index ce928b7611b..888824ee055 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -18,7 +18,7 @@
- Long type hints are now wrapped in parentheses and properly indented when split across
multiple lines (#3899)
-- Magic trailing commas are now respected in return types. (#3018)
+- Magic trailing commas are now respected in return types. (#3916)
### Configuration
From be48f7bee25efccdf6957a3c37dcce8a82301819 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Tue, 3 Oct 2023 14:12:27 +0200
Subject: [PATCH 3/3] add preview style
'respect_magic_trailing_comma_in_return_type'
---
src/black/linegen.py | 2 +-
src/black/mode.py | 1 +
2 files changed, 2 insertions(+), 1 deletion(-)
diff --git a/src/black/linegen.py b/src/black/linegen.py
index a0237001344..bdc4ee54ab2 100644
--- a/src/black/linegen.py
+++ b/src/black/linegen.py
@@ -656,7 +656,7 @@ def should_split_funcdef_with_rhs(line: Line, mode: Mode) -> bool:
"""If a funcdef has a magic trailing comma in the return type, then we should first
split the line with rhs to respect the comma.
"""
- if Preview.string_processing not in mode:
+ if Preview.respect_magic_trailing_comma_in_return_type not in mode:
return False
return_type_leaves: List[Leaf] = []
diff --git a/src/black/mode.py b/src/black/mode.py
index f44a821bcd0..30c5d2f1b2f 100644
--- a/src/black/mode.py
+++ b/src/black/mode.py
@@ -181,6 +181,7 @@ class Preview(Enum):
string_processing = auto()
parenthesize_conditional_expressions = auto()
parenthesize_long_type_hints = auto()
+ respect_magic_trailing_comma_in_return_type = auto()
skip_magic_trailing_comma_in_subscript = auto()
wrap_long_dict_values_in_parens = auto()
wrap_multiple_context_managers_in_parens = auto()