Skip to content

Commit

Permalink
Support formatting specified lines (#4020)
Browse files Browse the repository at this point in the history
  • Loading branch information
yilei committed Nov 7, 2023
1 parent ecbd9e8 commit 46be1f8
Show file tree
Hide file tree
Showing 20 changed files with 1,358 additions and 28 deletions.
3 changes: 3 additions & 0 deletions CHANGES.md
Expand Up @@ -6,6 +6,9 @@

<!-- Include any especially major or disruptive changes here -->

- Support formatting ranges of lines with the new `--line-ranges` command-line option
(#4020).

### Stable style

- Fix crash on formatting bytes strings that look like docstrings (#4003)
Expand Down
17 changes: 17 additions & 0 deletions docs/usage_and_configuration/the_basics.md
Expand Up @@ -175,6 +175,23 @@ All done! ✨ 🍰 ✨
1 file would be reformatted.
```

### `--line-ranges`

When specified, _Black_ will try its best to only format these lines.

This option can be specified multiple times, and a union of the lines will be formatted.
Each range must be specified as two integers connected by a `-`: `<START>-<END>`. The
`<START>` and `<END>` integer indices are 1-based and inclusive on both ends.

_Black_ may still format lines outside of the ranges for multi-line statements.
Formatting more than one file or any ipynb files with this option is not supported. This
option cannot be specified in the `pyproject.toml` config.

Example: `black --line-ranges=1-10 --line-ranges=21-30 test.py` will format lines from
`1` to `10` and `21` to `30`.

This option is mainly for editor integrations, such as "Format Selection".

#### `--color` / `--no-color`

Show (or do not show) colored diff. Only applies when `--diff` is given.
Expand Down
130 changes: 109 additions & 21 deletions src/black/__init__.py
Expand Up @@ -13,6 +13,7 @@
from pathlib import Path
from typing import (
Any,
Collection,
Dict,
Generator,
Iterator,
Expand Down Expand Up @@ -77,6 +78,7 @@
from black.output import color_diff, diff, dump_to_file, err, ipynb_diff, out
from black.parsing import InvalidInput # noqa F401
from black.parsing import lib2to3_parse, parse_ast, stringify_ast
from black.ranges import adjusted_lines, convert_unchanged_lines, parse_line_ranges
from black.report import Changed, NothingChanged, Report
from black.trans import iter_fexpr_spans
from blib2to3.pgen2 import token
Expand Down Expand Up @@ -163,6 +165,12 @@ def read_pyproject_toml(
"extend-exclude", "Config key extend-exclude must be a string"
)

line_ranges = config.get("line_ranges")
if line_ranges is not None:
raise click.BadOptionUsage(
"line-ranges", "Cannot use line-ranges in the pyproject.toml file."
)

default_map: Dict[str, Any] = {}
if ctx.default_map:
default_map.update(ctx.default_map)
Expand Down Expand Up @@ -304,6 +312,19 @@ def validate_regex(
is_flag=True,
help="Don't write the files back, just output a diff for each file on stdout.",
)
@click.option(
"--line-ranges",
multiple=True,
metavar="START-END",
help=(
"When specified, _Black_ will try its best to only format these lines. This"
" option can be specified multiple times, and a union of the lines will be"
" formatted. Each range must be specified as two integers connected by a `-`:"
" `<START>-<END>`. The `<START>` and `<END>` integer indices are 1-based and"
" inclusive on both ends."
),
default=(),
)
@click.option(
"--color/--no-color",
is_flag=True,
Expand Down Expand Up @@ -443,6 +464,7 @@ def main( # noqa: C901
target_version: List[TargetVersion],
check: bool,
diff: bool,
line_ranges: Sequence[str],
color: bool,
fast: bool,
pyi: bool,
Expand Down Expand Up @@ -544,6 +566,18 @@ def main( # noqa: C901
python_cell_magics=set(python_cell_magics),
)

lines: List[Tuple[int, int]] = []
if line_ranges:
if ipynb:
err("Cannot use --line-ranges with ipynb files.")
ctx.exit(1)

try:
lines = parse_line_ranges(line_ranges)
except ValueError as e:
err(str(e))
ctx.exit(1)

if code is not None:
# Run in quiet mode by default with -c; the extra output isn't useful.
# You can still pass -v to get verbose output.
Expand All @@ -553,7 +587,12 @@ def main( # noqa: C901

if code is not None:
reformat_code(
content=code, fast=fast, write_back=write_back, mode=mode, report=report
content=code,
fast=fast,
write_back=write_back,
mode=mode,
report=report,
lines=lines,
)
else:
assert root is not None # root is only None if code is not None
Expand Down Expand Up @@ -588,10 +627,14 @@ def main( # noqa: C901
write_back=write_back,
mode=mode,
report=report,
lines=lines,
)
else:
from black.concurrency import reformat_many

if lines:
err("Cannot use --line-ranges to format multiple files.")
ctx.exit(1)
reformat_many(
sources=sources,
fast=fast,
Expand Down Expand Up @@ -714,7 +757,13 @@ def path_empty(


def reformat_code(
content: str, fast: bool, write_back: WriteBack, mode: Mode, report: Report
content: str,
fast: bool,
write_back: WriteBack,
mode: Mode,
report: Report,
*,
lines: Collection[Tuple[int, int]] = (),
) -> None:
"""
Reformat and print out `content` without spawning child processes.
Expand All @@ -727,7 +776,7 @@ def reformat_code(
try:
changed = Changed.NO
if format_stdin_to_stdout(
content=content, fast=fast, write_back=write_back, mode=mode
content=content, fast=fast, write_back=write_back, mode=mode, lines=lines
):
changed = Changed.YES
report.done(path, changed)
Expand All @@ -741,7 +790,13 @@ def reformat_code(
# not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26
@mypyc_attr(patchable=True)
def reformat_one(
src: Path, fast: bool, write_back: WriteBack, mode: Mode, report: "Report"
src: Path,
fast: bool,
write_back: WriteBack,
mode: Mode,
report: "Report",
*,
lines: Collection[Tuple[int, int]] = (),
) -> None:
"""Reformat a single file under `src` without spawning child processes.
Expand All @@ -766,15 +821,17 @@ def reformat_one(
mode = replace(mode, is_pyi=True)
elif src.suffix == ".ipynb":
mode = replace(mode, is_ipynb=True)
if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
if format_stdin_to_stdout(
fast=fast, write_back=write_back, mode=mode, lines=lines
):
changed = Changed.YES
else:
cache = Cache.read(mode)
if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
if not cache.is_changed(src):
changed = Changed.CACHED
if changed is not Changed.CACHED and format_file_in_place(
src, fast=fast, write_back=write_back, mode=mode
src, fast=fast, write_back=write_back, mode=mode, lines=lines
):
changed = Changed.YES
if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
Expand All @@ -794,6 +851,8 @@ def format_file_in_place(
mode: Mode,
write_back: WriteBack = WriteBack.NO,
lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy
*,
lines: Collection[Tuple[int, int]] = (),
) -> bool:
"""Format file under `src` path. Return True if changed.
Expand All @@ -813,7 +872,9 @@ def format_file_in_place(
header = buf.readline()
src_contents, encoding, newline = decode_bytes(buf.read())
try:
dst_contents = format_file_contents(src_contents, fast=fast, mode=mode)
dst_contents = format_file_contents(
src_contents, fast=fast, mode=mode, lines=lines
)
except NothingChanged:
return False
except JSONDecodeError:
Expand Down Expand Up @@ -858,6 +919,7 @@ def format_stdin_to_stdout(
content: Optional[str] = None,
write_back: WriteBack = WriteBack.NO,
mode: Mode,
lines: Collection[Tuple[int, int]] = (),
) -> bool:
"""Format file on stdin. Return True if changed.
Expand All @@ -876,7 +938,7 @@ def format_stdin_to_stdout(

dst = src
try:
dst = format_file_contents(src, fast=fast, mode=mode)
dst = format_file_contents(src, fast=fast, mode=mode, lines=lines)
return True

except NothingChanged:
Expand Down Expand Up @@ -904,7 +966,11 @@ def format_stdin_to_stdout(


def check_stability_and_equivalence(
src_contents: str, dst_contents: str, *, mode: Mode
src_contents: str,
dst_contents: str,
*,
mode: Mode,
lines: Collection[Tuple[int, int]] = (),
) -> None:
"""Perform stability and equivalence checks.
Expand All @@ -913,10 +979,16 @@ def check_stability_and_equivalence(
content differently.
"""
assert_equivalent(src_contents, dst_contents)
assert_stable(src_contents, dst_contents, mode=mode)
assert_stable(src_contents, dst_contents, mode=mode, lines=lines)


def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileContent:
def format_file_contents(
src_contents: str,
*,
fast: bool,
mode: Mode,
lines: Collection[Tuple[int, int]] = (),
) -> FileContent:
"""Reformat contents of a file and return new contents.
If `fast` is False, additionally confirm that the reformatted code is
Expand All @@ -926,13 +998,15 @@ def format_file_contents(src_contents: str, *, fast: bool, mode: Mode) -> FileCo
if mode.is_ipynb:
dst_contents = format_ipynb_string(src_contents, fast=fast, mode=mode)
else:
dst_contents = format_str(src_contents, mode=mode)
dst_contents = format_str(src_contents, mode=mode, lines=lines)
if src_contents == dst_contents:
raise NothingChanged

if not fast and not mode.is_ipynb:
# Jupyter notebooks will already have been checked above.
check_stability_and_equivalence(src_contents, dst_contents, mode=mode)
check_stability_and_equivalence(
src_contents, dst_contents, mode=mode, lines=lines
)
return dst_contents


Expand Down Expand Up @@ -1043,7 +1117,9 @@ def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileCon
raise NothingChanged


def format_str(src_contents: str, *, mode: Mode) -> str:
def format_str(
src_contents: str, *, mode: Mode, lines: Collection[Tuple[int, int]] = ()
) -> str:
"""Reformat a string and return new contents.
`mode` determines formatting options, such as how many characters per line are
Expand Down Expand Up @@ -1073,16 +1149,20 @@ def f(
hey
"""
dst_contents = _format_str_once(src_contents, mode=mode)
dst_contents = _format_str_once(src_contents, mode=mode, lines=lines)
# Forced second pass to work around optional trailing commas (becoming
# forced trailing commas on pass 2) interacting differently with optional
# parentheses. Admittedly ugly.
if src_contents != dst_contents:
return _format_str_once(dst_contents, mode=mode)
if lines:
lines = adjusted_lines(lines, src_contents, dst_contents)
return _format_str_once(dst_contents, mode=mode, lines=lines)
return dst_contents


def _format_str_once(src_contents: str, *, mode: Mode) -> str:
def _format_str_once(
src_contents: str, *, mode: Mode, lines: Collection[Tuple[int, int]] = ()
) -> str:
src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
dst_blocks: List[LinesBlock] = []
if mode.target_versions:
Expand All @@ -1097,15 +1177,19 @@ def _format_str_once(src_contents: str, *, mode: Mode) -> str:
if supports_feature(versions, feature)
}
normalize_fmt_off(src_node, mode)
lines = LineGenerator(mode=mode, features=context_manager_features)
if lines:
# This should be called after normalize_fmt_off.
convert_unchanged_lines(src_node, lines)

line_generator = LineGenerator(mode=mode, features=context_manager_features)
elt = EmptyLineTracker(mode=mode)
split_line_features = {
feature
for feature in {Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF}
if supports_feature(versions, feature)
}
block: Optional[LinesBlock] = None
for current_line in lines.visit(src_node):
for current_line in line_generator.visit(src_node):
block = elt.maybe_empty_lines(current_line)
dst_blocks.append(block)
for line in transform_line(
Expand Down Expand Up @@ -1373,12 +1457,16 @@ def assert_equivalent(src: str, dst: str) -> None:
) from None


def assert_stable(src: str, dst: str, mode: Mode) -> None:
def assert_stable(
src: str, dst: str, mode: Mode, *, lines: Collection[Tuple[int, int]] = ()
) -> None:
"""Raise AssertionError if `dst` reformats differently the second time."""
# We shouldn't call format_str() here, because that formats the string
# twice and may hide a bug where we bounce back and forth between two
# versions.
newdst = _format_str_once(dst, mode=mode)
if lines:
lines = adjusted_lines(lines, src, dst)
newdst = _format_str_once(dst, mode=mode, lines=lines)
if dst != newdst:
log = dump_to_file(
str(mode),
Expand Down
28 changes: 28 additions & 0 deletions src/black/nodes.py
Expand Up @@ -935,3 +935,31 @@ def is_part_of_annotation(leaf: Leaf) -> bool:
return True
ancestor = ancestor.parent
return False


def first_leaf(node: LN) -> Optional[Leaf]:
"""Returns the first leaf of the ancestor node."""
if isinstance(node, Leaf):
return node
elif not node.children:
return None
else:
return first_leaf(node.children[0])


def last_leaf(node: LN) -> Optional[Leaf]:
"""Returns the last leaf of the ancestor node."""
if isinstance(node, Leaf):
return node
elif not node.children:
return None
else:
return last_leaf(node.children[-1])


def furthest_ancestor_with_last_leaf(leaf: Leaf) -> LN:
"""Returns the furthest ancestor that has this leaf node as the last leaf."""
node: LN = leaf
while node.parent and node.parent.children and node is node.parent.children[-1]:
node = node.parent
return node

0 comments on commit 46be1f8

Please sign in to comment.