From c1bdad7100b83f0196364d4e48433d738a5d5b2f Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Sat, 7 Oct 2023 17:08:48 +0200
Subject: [PATCH] split long function parameter type hints without parentheses
---
CHANGES.md | 2 +
src/black/brackets.py | 11 +
src/black/linegen.py | 202 +++++++++++++++++-
src/black/lines.py | 22 ++
src/black/mode.py | 1 +
.../pep604_union_types_line_breaks.py | 150 ++++++++++++-
6 files changed, 372 insertions(+), 16 deletions(-)
diff --git a/CHANGES.md b/CHANGES.md
index 062a195717d..9e5aa7ce76e 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -19,6 +19,8 @@
- Long type hints are now wrapped in parentheses and properly indented when split across
multiple lines (#3899)
- Magic trailing commas are now respected in return types. (#3916)
+- Long type hints in parameter lists now get split across multiple lines and properly
+ indented without being wrapped in parentheses (#3930)
### Configuration
diff --git a/src/black/brackets.py b/src/black/brackets.py
index 85dac6edd1e..ec7ad908c69 100644
--- a/src/black/brackets.py
+++ b/src/black/brackets.py
@@ -67,6 +67,17 @@ class BracketTracker:
_lambda_argument_depths: List[int] = field(default_factory=list)
invisible: List[Leaf] = field(default_factory=list)
+ def copy(self) -> "BracketTracker":
+ return BracketTracker(
+ self.depth,
+ self.bracket_match.copy(),
+ self.delimiters.copy(),
+ self.previous,
+ self._for_loop_depths.copy(),
+ self._lambda_argument_depths.copy(),
+ self.invisible.copy(),
+ )
+
def mark(self, leaf: Leaf) -> None:
"""Mark `leaf` with bracket-related metadata. Keep track of delimiters.
diff --git a/src/black/linegen.py b/src/black/linegen.py
index bdc4ee54ab2..ffc673ff81a 100644
--- a/src/black/linegen.py
+++ b/src/black/linegen.py
@@ -5,7 +5,7 @@
from dataclasses import replace
from enum import Enum, auto
from functools import partial, wraps
-from typing import Collection, Iterator, List, Optional, Set, Union, cast
+from typing import Any, Collection, Iterator, List, Optional, Set, Tuple, Union, cast
from black.brackets import (
COMMA_PRIORITY,
@@ -398,18 +398,42 @@ def visit_factor(self, node: Node) -> Iterator[Line]:
yield from self.visit_default(node)
def visit_tname(self, node: Node) -> Iterator[Line]:
- """
- Add potential parentheses around types in function parameter lists to be made
- into real parentheses in case the type hint is too long to fit on a line
+ """Remove unnecessary parentheses around types in PEP604 VBAR-separated
+ parameter lists, and in other cases add potential parentheses around types in
+ function parameter lists to be made into real parentheses in case the type hint
+ is too long to fit on a line.
+
Examples:
def foo(a: int, b: float = 7): ...
+ def bar(c: (int|float)): ...
->
def foo(a: (int), b: (float) = 7): ...
+ def bar(c: int|float): ...
"""
- if Preview.parenthesize_long_type_hints in self.mode:
- assert len(node.children) == 3
+ assert (
+ len(node.children) == 3
+ ), "type hints should always have three children: name, colon, type"
+ typehint = node.children[2].children
+
+ # Ensure PEP604 VBAR-separated typehints are not wrapped in parens so they
+ # get properly handled by func_typehint_split.
+ # Invalid typehints with other operators than `|` parsed as syms.expr are
+ # safe to remove parens from.
+ if Preview.split_long_param_type_without_parens in self.mode and (
+ len(typehint) == 3
+ and typehint[0].type == token.LPAR
+ and typehint[-1].type == token.RPAR
+ and typehint[1].type == syms.expr
+ ):
+ typehint[0].remove()
+ typehint[-1].remove()
+ # ensure trivial and non-PEP604 typehints are wrapped in invisible parens
+ elif Preview.parenthesize_long_type_hints in self.mode and (
+ Preview.split_long_param_type_without_parens not in self.mode
+ or (len(typehint) < 3 or typehint[1].type != token.VBAR)
+ ):
if maybe_make_parens_invisible_in_atom(node.children[2], parent=node):
wrap_in_parentheses(node, node.children[2], visible=False)
@@ -614,6 +638,7 @@ def _rhs(
string_merge,
string_paren_strip,
string_split,
+ func_typehint_split,
delimiter_split,
standalone_comment_split,
string_paren_wrap,
@@ -1056,6 +1081,171 @@ def _safe_add_trailing_comma(safe: bool, delimiter_priority: int, line: Line) ->
return line
+@dont_increase_indentation
+def func_typehint_split(
+ line: Line, features: Collection[Feature], mode: Mode
+) -> Iterator[Line]:
+ """Split a long typehint in a parameter list, indenting subsequent rows.
+ Examples (assuming short line width):
+ def foo(
+ arg: VeryVeryLongType1 | VeryVeryLongType2,
+ arg2: VeryVeryLongType1 = bar()
+ very_very_long_arg_name: VeryVeryLongType | VeryVeryLongType2
+ short_arg: ShortType # very very very long comment
+ | ShortType
+ ): ...
+
+ =>
+ def foo(
+ arg: VeryVeryLongType1
+ | VeryVeryLongType2,
+ arg2: VeryVeryLongType1
+ = bar()
+ very_very_long_arg_name:
+ VeryLongType
+ | VeryLongType2
+ short_arg:
+ ShortType # very very very long comment
+ | ShortType
+ ): ...
+ """
+ if not Preview.split_long_param_type_without_parens:
+ raise CannotSplit("not enabled")
+ if not line.leaves:
+ raise CannotSplit("line empty")
+ if line.leaves[0].parent is None or not line.leaves[0].parent.type == syms.tname:
+ raise CannotSplit("not a tname")
+
+ # split tname into a list of lists of leaves for each `| TYPENAME`
+ # also gives where the colon is in the first line, in case we want to split it
+ result, first_colon_idx = _split_tname(line.leaves)
+
+ # parse the first line, see if COLON TYPENAME is on the same line, with no brackets,
+ # in which case we want to move the TYPENAME to a different line.
+ lines = list(_split_first_typehint(line, result[0]))
+ if lines:
+ yield from lines
+ start_with: int = 1
+ else:
+ start_with = 0
+ result.insert(1, result[0][first_colon_idx + 1 :])
+ result[0] = result[0][: first_colon_idx + 1]
+
+ # parse the leaves and yield the lines
+ for idx, leaves in enumerate(result):
+ if idx < start_with:
+ continue
+ depth = line.depth + (0 if idx == 0 else 1)
+ current_line = Line(
+ mode=line.mode,
+ depth=depth,
+ inside_brackets=line.inside_brackets,
+ )
+ for leaf in leaves:
+ yield from _append_to_line(leaf, current_line, depth=line.depth + 1)
+
+ for comment in line.comments_after(leaf):
+ yield from _append_to_line(comment, current_line, depth=line.depth + 1)
+
+ if current_line:
+ yield current_line
+
+
+def _split_tname(leaves: List[Leaf]) -> Tuple[List[List[Leaf]], int]:
+ result: List[List[Leaf]] = [[]]
+ has_vbar = False
+ has_bracket = False
+ first_colon_idx: Optional[int] = None
+ matching_par: Optional[Leaf] = None
+
+ for idx, leaf in enumerate(leaves):
+ # We only want to trigger on brackets/colon/vbar/equal not inside other brackets
+ if matching_par is None:
+ if leaf.type in OPENING_BRACKETS:
+ matching_par = leaf
+ has_bracket = True
+ elif leaf.type == token.COLON and first_colon_idx is not None:
+ raise CannotSplit(
+ "This function should only be used on a single typehint. It will"
+ " run later after the line has been split by delimiter_split"
+ )
+ elif leaf.type == token.COLON:
+ first_colon_idx = idx
+ elif leaf.type in (token.VBAR, token.EQUAL):
+ result.append([])
+ has_vbar |= leaf.type == token.VBAR
+ elif leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is matching_par:
+ matching_par = None
+ result[-1].append(leaf)
+
+ assert first_colon_idx is not None, "All tnames should have a colon"
+
+ # don't split non-typehints, and non-vbar typehints with brackets are better split
+ # conventionally
+ if not has_vbar and has_bracket:
+ raise CannotSplit("Non-vbar typehint with brackets.")
+
+ return result, first_colon_idx
+
+
+def _split_first_typehint(original_line: Line, leaves: List[Leaf]) -> Iterator[Line]:
+ """
+ parse the first set of leaves from a tname, check if COLON TYPENAME is on the same
+ line, with no brackets, in which case we want to move the TYPENAME to a different
+ line.
+
+ In that case, this function will yield nothing. Otherwise it will yield that first
+ line of code, plus any additional lines caused by comments associated with the line.
+ This is messier than just returning a bool, but avoids having to reparse the line.
+ """
+
+ lines: List[Line] = []
+
+ current_line = Line(
+ original_line.mode,
+ depth=original_line.depth,
+ inside_brackets=original_line.inside_brackets,
+ )
+ for leaf in leaves:
+ lines.extend(_append_to_line(leaf, current_line, depth=original_line.depth + 1))
+ for comment in original_line.comments_after(leaf):
+ lines.extend(
+ _append_to_line(comment, current_line, depth=original_line.depth + 1)
+ )
+ lines.append(current_line)
+
+ for line in lines:
+ has_brackets = False
+ has_colon = False
+ has_type_after_colon = False
+ for leaf in line.leaves:
+ has_brackets |= leaf.type in OPENING_BRACKETS
+ has_type_after_colon |= has_colon and leaf.type == token.NAME
+ has_colon |= leaf.type == token.COLON
+ # this line is too long, has no brackets to split by, and contains `: typename`,
+ # so we should split and reparse it
+ if (
+ has_colon
+ and has_type_after_colon
+ and not has_brackets
+ and not is_line_short_enough(line, mode=original_line.mode)
+ ):
+ return
+ else:
+ yield from lines
+
+
+def _append_to_line(leaf: Leaf, current_line: Line, **kwargs: Any) -> Iterator[Line]:
+ """Append `leaf` to current line or to new line if appending impossible."""
+ try:
+ current_line.append_safe(leaf, preformatted=True)
+ except ValueError:
+ yield current_line.deep_copy()
+
+ current_line.reset(**kwargs)
+ current_line.append(leaf)
+
+
@dont_increase_indentation
def delimiter_split(
line: Line, features: Collection[Feature], mode: Mode
diff --git a/src/black/lines.py b/src/black/lines.py
index 71b657a0654..7317f730496 100644
--- a/src/black/lines.py
+++ b/src/black/lines.py
@@ -59,6 +59,16 @@ class Line:
should_split_rhs: bool = False
magic_trailing_comma: Optional[Leaf] = None
+ def reset(self, depth: Optional[int] = None) -> None:
+ if depth is not None:
+ self.depth = depth
+ self.leaves = []
+ self.comments = {}
+ self.bracket_tracker = BracketTracker()
+ self.inside_brackets = False
+ self.should_split_rhs = False
+ self.magic_trailing_comma = None
+
def append(
self, leaf: Leaf, preformatted: bool = False, track_bracket: bool = False
) -> None:
@@ -456,6 +466,18 @@ def enumerate_with_length(
yield index, leaf, length
+ def deep_copy(self) -> "Line":
+ return Line(
+ mode=self.mode,
+ depth=self.depth,
+ leaves=self.leaves.copy(),
+ comments=self.comments.copy(),
+ bracket_tracker=self.bracket_tracker.copy(),
+ inside_brackets=self.inside_brackets,
+ should_split_rhs=self.should_split_rhs,
+ magic_trailing_comma=self.magic_trailing_comma,
+ )
+
def clone(self) -> "Line":
return Line(
mode=self.mode,
diff --git a/src/black/mode.py b/src/black/mode.py
index 30c5d2f1b2f..dc14f7c2065 100644
--- a/src/black/mode.py
+++ b/src/black/mode.py
@@ -181,6 +181,7 @@ class Preview(Enum):
string_processing = auto()
parenthesize_conditional_expressions = auto()
parenthesize_long_type_hints = auto()
+ split_long_param_type_without_parens = auto()
respect_magic_trailing_comma_in_return_type = auto()
skip_magic_trailing_comma_in_subscript = auto()
wrap_long_dict_values_in_parens = auto()
diff --git a/tests/data/preview_py_310/pep604_union_types_line_breaks.py b/tests/data/preview_py_310/pep604_union_types_line_breaks.py
index 9c4ab870766..9bb172e3047 100644
--- a/tests/data/preview_py_310/pep604_union_types_line_breaks.py
+++ b/tests/data/preview_py_310/pep604_union_types_line_breaks.py
@@ -83,6 +83,54 @@ def f(
...
+# don't lose comments
+def f( # 1
+ looooooooooooooooooooooooooong # 2
+ : # 3
+ Loooooooooooooooooooooooooooooooooooooooooooooooong # 4
+ | # 5
+ Loooooooooooooooooooooooooooooooooooooooooooooooong # 6
+ = # 7
+ 3 # 8
+ ): # 9
+ ... # 10
+
+
+loooooooooooooooooooooooooooooong: (
+ Loooooooooooooooooooooooooooong # aoeuaoeuaoeuaoeuaoeuaoeu
+)
+
+
+def foo(
+ loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong:
+ Loooooooooooooooooooooooooooong,
+): ...
+
+
+def foo(
+ loooooooooooooooong: Tuple[
+ Looooooooooooooooooooooooooooong, Loooooooooooooong
+ ], # aoeuaoeuaoeu
+): ...
+
+
+def foo(a: loooooooooooooong| loooooooooooooooooooooooooooong| looooooooooong| looooooooooong = 3): ...
+def foo(c: (loooooooooooooong| loooooooooooooooooooooooooooong| looooooooooong| looooooooooong) = 4): ...
+def foo(bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb: bbbbbbbbbbbbbbbbbbbbloooooooooooooong| loooooooooooooooooooooooooooong| looooooooooong| looooooooooong = 5): ...
+def foo(not_too_long_param_name: not_too_long_typehint # but there's a reasonably long comment here
+ ): ...
+def foo(not_too_long_param_name: not_too_long_typehint # but there's a reasonably long comment here
+ | loooooooooooooooooooooooooooong| looooooooooong| looooooooooong = 5): ...
+def foo(a: tuple[int, int,] = (7, 3,)): ...
+def foo(a: (
+ # comment_inside_paren
+ loooooooooooooong
+ # comment2
+ |
+ # comment3
+ loooooooooooooong
+ )): ...
+
# output
# This has always worked
z = (
@@ -158,12 +206,10 @@ def foo(i: (int,)) -> None: ...
def foo(
i: int,
- x: (
- Loooooooooooooooooooooooong
+ x: Loooooooooooooooooooooooong
| Looooooooooooooooong
| Looooooooooooooooooooong
- | Looooooong
- ),
+ | Looooooong,
*,
s: str,
) -> None:
@@ -172,16 +218,100 @@ def foo(
@app.get("/path/")
async def foo(
- q: str | None = Query(
- None, title="Some long title", description="Some long description"
- )
+ q: str
+ | None
+ = Query(None, title="Some long title", description="Some long description")
):
pass
def f(
- max_jobs: int | None = Option(
- None, help="Maximum number of jobs to launch. And some additional text."
- ),
+ max_jobs: int
+ | None
+ = Option(
+ None, help="Maximum number of jobs to launch. And some additional text."
+ ),
another_option: bool = False,
): ...
+
+
+# don't lose comments
+def f( # 1
+ looooooooooooooooooooooooooong: # 2 # 3
+ Loooooooooooooooooooooooooooooooooooooooooooooooong # 4
+ | Loooooooooooooooooooooooooooooooooooooooooooooooong # 5 # 6
+ = 3, # 7 # 8
+): # 9
+ ... # 10
+
+
+loooooooooooooooooooooooooooooong: (
+ Loooooooooooooooooooooooooooong # aoeuaoeuaoeuaoeuaoeuaoeu
+)
+
+
+def foo(
+ loooooooooooooooooooooooooooooooooooooooooooooooooooooooooooong:
+ Loooooooooooooooooooooooooooong,
+): ...
+
+
+def foo(
+ loooooooooooooooong: Tuple[
+ Looooooooooooooooooooooooooooong, Loooooooooooooong
+ ], # aoeuaoeuaoeu
+): ...
+
+
+def foo(
+ a: loooooooooooooong
+ | loooooooooooooooooooooooooooong
+ | looooooooooong
+ | looooooooooong
+ = 3,
+): ...
+def foo(
+ c: loooooooooooooong
+ | loooooooooooooooooooooooooooong
+ | looooooooooong
+ | looooooooooong
+ = 4,
+): ...
+def foo(
+ bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb:
+ bbbbbbbbbbbbbbbbbbbbloooooooooooooong
+ | loooooooooooooooooooooooooooong
+ | looooooooooong
+ | looooooooooong
+ = 5,
+): ...
+def foo(
+ not_too_long_param_name:
+ not_too_long_typehint, # but there's a reasonably long comment here
+): ...
+def foo(
+ not_too_long_param_name:
+ not_too_long_typehint # but there's a reasonably long comment here
+ | loooooooooooooooooooooooooooong
+ | looooooooooong
+ | looooooooooong
+ = 5,
+): ...
+def foo(
+ a: tuple[
+ int,
+ int,
+ ] = (
+ 7,
+ 3,
+ )
+): ...
+def foo(
+ a:
+ # comment_inside_paren
+ loooooooooooooong
+ # comment2
+ |
+ # comment3
+ loooooooooooooong,
+): ...