From 58d5aad7ba8bec31cf1ff5dc4a758fabc9b7fe16 Mon Sep 17 00:00:00 2001 From: Elliott Sales de Andrade Date: Thu, 10 Aug 2023 03:52:43 -0400 Subject: [PATCH] TYP: Add common type overloads of subplot_mosaic I'll assert, without proof, that passing a single string, a mosaic list of strings, or a mosaic list all of the same type is more common than passing arbitrary unrelated hashables. Thus it is somewhat convenient if the return type stipulates that the resulting dictionary is also keyed with strings or the common type. This also fixes the type of the `per_subplot_kw` argument, which also allows dictionary keys of tuples of the entries. --- doc/missing-references.json | 6 ++++ lib/matplotlib/figure.py | 8 ++--- lib/matplotlib/figure.pyi | 52 +++++++++++++++++++++------- lib/matplotlib/pyplot.py | 68 +++++++++++++++++++++++++++++++++---- lib/matplotlib/typing.py | 5 +-- 5 files changed, 112 insertions(+), 27 deletions(-) diff --git a/doc/missing-references.json b/doc/missing-references.json index 1061b08b7fe6..364885ab3a74 100644 --- a/doc/missing-references.json +++ b/doc/missing-references.json @@ -152,6 +152,9 @@ "HashableList": [ "lib/matplotlib/pyplot.py:docstring of matplotlib.pyplot.subplot_mosaic:1" ], + "HashableList[_HT]": [ + "doc/docstring of builtins.list:17" + ], "LineStyleType": [ "lib/matplotlib/pyplot.py:docstring of matplotlib.pyplot.eventplot:1", "lib/matplotlib/pyplot.py:docstring of matplotlib.pyplot.hlines:1", @@ -701,6 +704,9 @@ "matplotlib.animation.TimedAnimation.to_jshtml": [ "doc/api/_as_gen/matplotlib.animation.TimedAnimation.rst:28::1" ], + "matplotlib.typing._HT": [ + "doc/docstring of builtins.list:17" + ], "mpl_toolkits.axislines.Axes": [ "lib/mpl_toolkits/axisartist/axis_artist.py:docstring of mpl_toolkits.axisartist.axis_artist:7" ], diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index 9d929acf688c..d37dda204278 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -1749,15 +1749,11 @@ def _norm_per_subplot_kw(per_subplot_kw): if isinstance(k, tuple): for sub_key in k: if sub_key in expanded: - raise ValueError( - f'The key {sub_key!r} appears multiple times.' - ) + raise ValueError(f'The key {sub_key!r} appears multiple times.') expanded[sub_key] = v else: if k in expanded: - raise ValueError( - f'The key {k!r} appears multiple times.' - ) + raise ValueError(f'The key {k!r} appears multiple times.') expanded[k] = v return expanded diff --git a/lib/matplotlib/figure.pyi b/lib/matplotlib/figure.pyi index f53feb78014d..687ae9e500d0 100644 --- a/lib/matplotlib/figure.pyi +++ b/lib/matplotlib/figure.pyi @@ -1,4 +1,9 @@ +from collections.abc import Callable, Hashable, Iterable import os +from typing import Any, IO, Literal, TypeVar, overload + +import numpy as np +from numpy.typing import ArrayLike from matplotlib.artist import Artist from matplotlib.axes import Axes, SubplotBase @@ -19,14 +24,10 @@ from matplotlib.lines import Line2D from matplotlib.patches import Rectangle, Patch from matplotlib.text import Text from matplotlib.transforms import Affine2D, Bbox, BboxBase, Transform - -import numpy as np -from numpy.typing import ArrayLike - -from collections.abc import Callable, Iterable -from typing import Any, IO, Literal, overload from .typing import ColorType, HashableList +_T = TypeVar("_T") + class FigureBase(Artist): artists: list[Artist] lines: list[Line2D] @@ -200,11 +201,38 @@ class FigureBase(Artist): *, bbox_extra_artists: Iterable[Artist] | None = ..., ) -> Bbox: ... - - # Any in list of list is recursive list[list[Hashable | list[Hashable | ...]]] but that can't really be type checked + @overload def subplot_mosaic( self, - mosaic: str | HashableList, + mosaic: str, + *, + sharex: bool = ..., + sharey: bool = ..., + width_ratios: ArrayLike | None = ..., + height_ratios: ArrayLike | None = ..., + empty_sentinel: str = ..., + subplot_kw: dict[str, Any] | None = ..., + per_subplot_kw: dict[str | tuple[str, ...], dict[str, Any]] | None = ..., + gridspec_kw: dict[str, Any] | None = ..., + ) -> dict[str, Axes]: ... + @overload + def subplot_mosaic( + self, + mosaic: list[HashableList[_T]], + *, + sharex: bool = ..., + sharey: bool = ..., + width_ratios: ArrayLike | None = ..., + height_ratios: ArrayLike | None = ..., + empty_sentinel: _T = ..., + subplot_kw: dict[str, Any] | None = ..., + per_subplot_kw: dict[_T | tuple[_T, ...], dict[str, Any]] | None = ..., + gridspec_kw: dict[str, Any] | None = ..., + ) -> dict[_T, Axes]: ... + @overload + def subplot_mosaic( + self, + mosaic: list[HashableList[Hashable]], *, sharex: bool = ..., sharey: bool = ..., @@ -212,9 +240,9 @@ class FigureBase(Artist): height_ratios: ArrayLike | None = ..., empty_sentinel: Any = ..., subplot_kw: dict[str, Any] | None = ..., - per_subplot_kw: dict[Any, dict[str, Any]] | None = ..., - gridspec_kw: dict[str, Any] | None = ... - ) -> dict[Any, Axes]: ... + per_subplot_kw: dict[Hashable | tuple[Hashable, ...], dict[str, Any]] | None = ..., + gridspec_kw: dict[str, Any] | None = ..., + ) -> dict[Hashable, Axes]: ... class SubFigure(FigureBase): figure: Figure diff --git a/lib/matplotlib/pyplot.py b/lib/matplotlib/pyplot.py index cc18b6b21bf0..00e5dea071a4 100644 --- a/lib/matplotlib/pyplot.py +++ b/lib/matplotlib/pyplot.py @@ -125,6 +125,7 @@ _P = ParamSpec('_P') _R = TypeVar('_R') + _T = TypeVar('_T') # We may not need the following imports here: @@ -1602,8 +1603,56 @@ def subplots( return fig, axs +@overload +def subplot_mosaic( + mosaic: str, + *, + sharex: bool = ..., + sharey: bool = ..., + width_ratios: ArrayLike | None = ..., + height_ratios: ArrayLike | None = ..., + empty_sentinel: str = ..., + subplot_kw: dict[str, Any] | None = ..., + gridspec_kw: dict[str, Any] | None = ..., + per_subplot_kw: dict[str | tuple[str, ...], dict[str, Any]] | None = ..., + **fig_kw: Any +) -> tuple[Figure, dict[str, matplotlib.axes.Axes]]: ... + + +@overload +def subplot_mosaic( + mosaic: list[HashableList[_T]], + *, + sharex: bool = ..., + sharey: bool = ..., + width_ratios: ArrayLike | None = ..., + height_ratios: ArrayLike | None = ..., + empty_sentinel: _T = ..., + subplot_kw: dict[str, Any] | None = ..., + gridspec_kw: dict[str, Any] | None = ..., + per_subplot_kw: dict[_T | tuple[_T, ...], dict[str, Any]] | None = ..., + **fig_kw: Any +) -> tuple[Figure, dict[_T, matplotlib.axes.Axes]]: ... + + +@overload +def subplot_mosaic( + mosaic: list[HashableList[Hashable]], + *, + sharex: bool = ..., + sharey: bool = ..., + width_ratios: ArrayLike | None = ..., + height_ratios: ArrayLike | None = ..., + empty_sentinel: Any = ..., + subplot_kw: dict[str, Any] | None = ..., + gridspec_kw: dict[str, Any] | None = ..., + per_subplot_kw: dict[Hashable | tuple[Hashable, ...], dict[str, Any]] | None = ..., + **fig_kw: Any +) -> tuple[Figure, dict[Hashable, matplotlib.axes.Axes]]: ... + + def subplot_mosaic( - mosaic: str | HashableList, + mosaic: str | list[HashableList[_T]] | list[HashableList[Hashable]], *, sharex: bool = False, sharey: bool = False, @@ -1612,9 +1661,13 @@ def subplot_mosaic( empty_sentinel: Any = '.', subplot_kw: dict[str, Any] | None = None, gridspec_kw: dict[str, Any] | None = None, - per_subplot_kw: dict[Hashable, dict[str, Any]] | None = None, - **fig_kw -) -> tuple[Figure, dict[Hashable, matplotlib.axes.Axes]]: + per_subplot_kw: dict[str | tuple[str, ...], dict[str, Any]] | + dict[_T | tuple[_T, ...], dict[str, Any]] | + dict[Hashable | tuple[Hashable, ...], dict[str, Any]] | None = None, + **fig_kw: Any +) -> tuple[Figure, dict[str, matplotlib.axes.Axes]] | \ + tuple[Figure, dict[_T, matplotlib.axes.Axes]] | \ + tuple[Figure, dict[Hashable, matplotlib.axes.Axes]]: """ Build a layout of Axes based on ASCII art or nested lists. @@ -1716,12 +1769,13 @@ def subplot_mosaic( """ fig = figure(**fig_kw) - ax_dict = fig.subplot_mosaic( - mosaic, sharex=sharex, sharey=sharey, + ax_dict = fig.subplot_mosaic( # type: ignore[misc] + mosaic, # type: ignore[arg-type] + sharex=sharex, sharey=sharey, height_ratios=height_ratios, width_ratios=width_ratios, subplot_kw=subplot_kw, gridspec_kw=gridspec_kw, empty_sentinel=empty_sentinel, - per_subplot_kw=per_subplot_kw, + per_subplot_kw=per_subplot_kw, # type: ignore[arg-type] ) return fig, ax_dict diff --git a/lib/matplotlib/typing.py b/lib/matplotlib/typing.py index e6b9ada7d8a2..02059be94ba2 100644 --- a/lib/matplotlib/typing.py +++ b/lib/matplotlib/typing.py @@ -11,7 +11,7 @@ """ from collections.abc import Hashable, Sequence import pathlib -from typing import Any, Literal, Union +from typing import Any, Literal, TypeVar, Union from . import path from ._enums import JoinStyle, CapStyle @@ -55,5 +55,6 @@ Sequence[Union[str, pathlib.Path, dict[str, Any]]], ] -HashableList = list[Union[Hashable, "HashableList"]] +_HT = TypeVar("_HT", bound=Hashable) +HashableList = list[Union[_HT, "HashableList[_HT]"]] """A nested list of Hashable values."""