Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to force unions (and options) to be rendered with bars #418

Merged
merged 9 commits into from
Feb 8, 2024
10 changes: 9 additions & 1 deletion src/sphinx_autodoc_typehints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,13 @@ def format_annotation(annotation: Any, config: Config) -> str: # noqa: C901, PL
args_format = "\\[{}]"
formatted_args: str | None = ""

always_use_bars_union: bool = getattr(config, "always_use_bars_union", True)
is_bars_union = full_name == "types.UnionType" or (
always_use_bars_union and type(annotation).__qualname__ == "_UnionGenericAlias"
)
if is_bars_union:
full_name = ""

# Some types require special handling
if full_name == "typing.NewType":
args_format = f"\\(``{annotation.__name__}``, {{}})"
Expand Down Expand Up @@ -248,7 +255,7 @@ def format_annotation(annotation: Any, config: Config) -> str: # noqa: C901, PL
formatted_args = f"\\[\\[{', '.join(fmt[:-1])}], {fmt[-1]}]"
elif full_name == "typing.Literal":
formatted_args = f"\\[{', '.join(f'``{arg!r}``' for arg in args)}]"
elif full_name == "types.UnionType":
elif is_bars_union:
return " | ".join([format_annotation(arg, config) for arg in args])

if args and not formatted_args:
Expand Down Expand Up @@ -929,6 +936,7 @@ def setup(app: Sphinx) -> dict[str, bool]:
app.add_config_value("typehints_use_rtype", True, "env") # noqa: FBT003
app.add_config_value("typehints_defaults", None, "env")
app.add_config_value("simplify_optional_unions", True, "env") # noqa: FBT003
app.add_config_value("always_use_bars_union", False, "env") # noqa: FBT003
app.add_config_value("typehints_formatter", None, "env")
app.add_config_value("typehints_use_signature", False, "env") # noqa: FBT003
app.add_config_value("typehints_use_signature_return", False, "env") # noqa: FBT003
Expand Down
46 changes: 39 additions & 7 deletions tests/test_sphinx_autodoc_typehints.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def test_parse_annotation(annotation: Any, module: str, class_name: str, args: t

@pytest.mark.parametrize(("annotation", "expected_result"), _CASES)
def test_format_annotation(inv: Inventory, annotation: Any, expected_result: str) -> None:
conf = create_autospec(Config, _annotation_globals=globals())
conf = create_autospec(Config, _annotation_globals=globals(), always_use_bars_union=False)
result = format_annotation(annotation, conf)
assert result == expected_result

Expand All @@ -377,7 +377,12 @@ def test_format_annotation(inv: Inventory, annotation: Any, expected_result: str
# encapsulate Union in typing.Optional
expected_result_not_simplified = ":py:data:`~typing.Optional`\\ \\[" + expected_result_not_simplified
expected_result_not_simplified += "]"
conf = create_autospec(Config, simplify_optional_unions=False, _annotation_globals=globals())
conf = create_autospec(
Config,
simplify_optional_unions=False,
_annotation_globals=globals(),
always_use_bars_union=False,
)
assert format_annotation(annotation, conf) == expected_result_not_simplified

# Test with the "fully_qualified" flag turned on
Expand All @@ -397,7 +402,12 @@ def test_format_annotation(inv: Inventory, annotation: Any, expected_result: str
expected_result = expected_result.replace("~nptyping", "nptyping")
expected_result = expected_result.replace("~numpy", "numpy")
expected_result = expected_result.replace("~" + __name__, __name__)
conf = create_autospec(Config, typehints_fully_qualified=True, _annotation_globals=globals())
conf = create_autospec(
Config,
typehints_fully_qualified=True,
_annotation_globals=globals(),
always_use_bars_union=False,
)
assert format_annotation(annotation, conf) == expected_result

# Test for the correct role (class vs data) using the official Sphinx inventory
Expand All @@ -413,6 +423,26 @@ def test_format_annotation(inv: Inventory, annotation: Any, expected_result: str
assert m.group("role") == expected_role


@pytest.mark.parametrize(
("annotation", "expected_result"),
[
("int | float", ":py:class:`int` | :py:class:`float`"),
("int | float | None", ":py:class:`int` | :py:class:`float` | :py:obj:`None`"),
("Union[int, float]", ":py:class:`int` | :py:class:`float`"),
("Union[int, float, None]", ":py:class:`int` | :py:class:`float` | :py:obj:`None`"),
("Optional[int | float]", ":py:class:`int` | :py:class:`float` | :py:obj:`None`"),
("Optional[Union[int, float]]", ":py:class:`int` | :py:class:`float` | :py:obj:`None`"),
("Union[int | float, str]", ":py:class:`int` | :py:class:`float` | :py:class:`str`"),
("Union[int, float] | str", ":py:class:`int` | :py:class:`float` | :py:class:`str`"),
],
)
@pytest.mark.skipif(not PY310_PLUS, reason="| union doesn't work before py310")
def test_always_use_bars_union(annotation: str, expected_result: str) -> None:
conf = create_autospec(Config, always_use_bars_union=True)
result = format_annotation(eval(annotation), conf) # noqa: S307
assert result == expected_result


@pytest.mark.parametrize("library", [typing, typing_extensions], ids=["typing", "typing_extensions"])
@pytest.mark.parametrize(
("annotation", "params", "expected_result"),
Expand Down Expand Up @@ -519,12 +549,13 @@ class dummy_module.DataClass(x)


def maybe_fix_py310(expected_contents: str) -> str:
if sys.version_info >= (3, 11):
return expected_contents
if not PY310_PLUS:
return expected_contents.replace('"', "")

for old, new in [
("bool | None", '"Optional"["bool"]'),
("str | None", '"Optional"["str"]'),
('"str" | "None"', '"Optional"["str"]'),
]:
expected_contents = expected_contents.replace(old, new)
return expected_contents
Expand All @@ -550,15 +581,16 @@ def test_sphinx_output_future_annotations(app: SphinxTestApp, status: StringIO)
Method docstring.

Parameters:
* **x** (bool | None) -- foo
* **x** ("bool" | "None") -- foo

* **y** ("int" | "str" | "float") -- bar

* **z** (str | None) -- baz
* **z** ("str" | "None") -- baz

Return type:
"str"
"""
expected_contents = dedent(expected_contents)
expected_contents = maybe_fix_py310(dedent(expected_contents))
assert contents == expected_contents

Expand Down