From 6fcee0401c0897c5c3e318ac7fc432245ce917b8 Mon Sep 17 00:00:00 2001 From: Elliott Sales de Andrade Date: Thu, 10 Aug 2023 03:52:43 -0400 Subject: [PATCH] TYP: Add common type overloads of subplot_mosaic I'll assert, without proof, that passing a single string, a mosaic list of strings, or a mosaic list all of the same type is more common than passing arbitrary unrelated hashables. Thus it is somewhat convenient if the return type stipulates that the resulting dictionary is also keyed with strings or the common type. This also fixes the type of the `per_subplot_kw` argument, which also allows dictionary keys of tuples of the entries. --- lib/matplotlib/figure.py | 8 ++--- lib/matplotlib/figure.pyi | 52 +++++++++++++++++++++++------- lib/matplotlib/pyplot.py | 68 +++++++++++++++++++++++++++++++++++---- lib/matplotlib/typing.py | 5 +-- 4 files changed, 106 insertions(+), 27 deletions(-) diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index ce263c3d8d1c..d4661f8b5195 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -1798,15 +1798,11 @@ def _norm_per_subplot_kw(per_subplot_kw): if isinstance(k, tuple): for sub_key in k: if sub_key in expanded: - raise ValueError( - f'The key {sub_key!r} appears multiple times.' - ) + raise ValueError(f'The key {sub_key!r} appears multiple times.') expanded[sub_key] = v else: if k in expanded: - raise ValueError( - f'The key {k!r} appears multiple times.' - ) + raise ValueError(f'The key {k!r} appears multiple times.') expanded[k] = v return expanded diff --git a/lib/matplotlib/figure.pyi b/lib/matplotlib/figure.pyi index 40e8fce0321f..1de0adc58a78 100644 --- a/lib/matplotlib/figure.pyi +++ b/lib/matplotlib/figure.pyi @@ -1,4 +1,9 @@ +from collections.abc import Callable, Hashable, Iterable import os +from typing import Any, IO, Literal, TypeVar, overload + +import numpy as np +from numpy.typing import ArrayLike from matplotlib.artist import Artist from matplotlib.axes import Axes, SubplotBase @@ -19,14 +24,10 @@ from matplotlib.lines import Line2D from matplotlib.patches import Rectangle, Patch from matplotlib.text import Text from matplotlib.transforms import Affine2D, Bbox, BboxBase, Transform - -import numpy as np -from numpy.typing import ArrayLike - -from collections.abc import Callable, Iterable -from typing import Any, IO, Literal, overload from .typing import ColorType, HashableList +_T = TypeVar("_T") + class SubplotParams: def __init__( self, @@ -226,11 +227,38 @@ class FigureBase(Artist): *, bbox_extra_artists: Iterable[Artist] | None = ..., ) -> Bbox: ... - - # Any in list of list is recursive list[list[Hashable | list[Hashable | ...]]] but that can't really be type checked + @overload def subplot_mosaic( self, - mosaic: str | HashableList, + mosaic: str, + *, + sharex: bool = ..., + sharey: bool = ..., + width_ratios: ArrayLike | None = ..., + height_ratios: ArrayLike | None = ..., + empty_sentinel: str = ..., + subplot_kw: dict[str, Any] | None = ..., + per_subplot_kw: dict[str | tuple[str], dict[str, Any]] | None = ..., + gridspec_kw: dict[str, Any] | None = ..., + ) -> dict[str, Axes]: ... + @overload + def subplot_mosaic( + self, + mosaic: list[HashableList[_T]], + *, + sharex: bool = ..., + sharey: bool = ..., + width_ratios: ArrayLike | None = ..., + height_ratios: ArrayLike | None = ..., + empty_sentinel: _T = ..., + subplot_kw: dict[str, Any] | None = ..., + per_subplot_kw: dict[_T | tuple[_T], dict[str, Any]] | None = ..., + gridspec_kw: dict[str, Any] | None = ..., + ) -> dict[_T, Axes]: ... + @overload + def subplot_mosaic( + self, + mosaic: list[HashableList[Hashable]], *, sharex: bool = ..., sharey: bool = ..., @@ -238,9 +266,9 @@ class FigureBase(Artist): height_ratios: ArrayLike | None = ..., empty_sentinel: Any = ..., subplot_kw: dict[str, Any] | None = ..., - per_subplot_kw: dict[Any, dict[str, Any]] | None = ..., - gridspec_kw: dict[str, Any] | None = ... - ) -> dict[Any, Axes]: ... + per_subplot_kw: dict[Hashable | tuple[Hashable], dict[str, Any]] | None = ..., + gridspec_kw: dict[str, Any] | None = ..., + ) -> dict[Hashable, Axes]: ... class SubFigure(FigureBase): figure: Figure diff --git a/lib/matplotlib/pyplot.py b/lib/matplotlib/pyplot.py index 11a998334b6e..dac019a13aff 100644 --- a/lib/matplotlib/pyplot.py +++ b/lib/matplotlib/pyplot.py @@ -125,6 +125,7 @@ _P = ParamSpec('_P') _R = TypeVar('_R') + _T = TypeVar('_T') # We may not need the following imports here: @@ -1600,8 +1601,56 @@ def subplots( return fig, axs +@overload +def subplot_mosaic( + mosaic: str, + *, + sharex: bool = ..., + sharey: bool = ..., + width_ratios: ArrayLike | None = ..., + height_ratios: ArrayLike | None = ..., + empty_sentinel: str = ..., + subplot_kw: dict[str, Any] | None = ..., + gridspec_kw: dict[str, Any] | None = ..., + per_subplot_kw: dict[str | tuple[str], dict[str, Any]] | None = ..., + **fig_kw: Any +) -> tuple[Figure, dict[str, matplotlib.axes.Axes]]: ... + + +@overload +def subplot_mosaic( + mosaic: list[HashableList[_T]], + *, + sharex: bool = ..., + sharey: bool = ..., + width_ratios: ArrayLike | None = ..., + height_ratios: ArrayLike | None = ..., + empty_sentinel: _T = ..., + subplot_kw: dict[str, Any] | None = ..., + gridspec_kw: dict[str, Any] | None = ..., + per_subplot_kw: dict[_T | tuple[_T], dict[str, Any]] | None = ..., + **fig_kw: Any +) -> tuple[Figure, dict[_T, matplotlib.axes.Axes]]: ... + + +@overload +def subplot_mosaic( + mosaic: list[HashableList[Hashable]], + *, + sharex: bool = ..., + sharey: bool = ..., + width_ratios: ArrayLike | None = ..., + height_ratios: ArrayLike | None = ..., + empty_sentinel: Any = ..., + subplot_kw: dict[str, Any] | None = ..., + gridspec_kw: dict[str, Any] | None = ..., + per_subplot_kw: dict[Hashable | tuple[Hashable], dict[str, Any]] | None = ..., + **fig_kw: Any +) -> tuple[Figure, dict[Hashable, matplotlib.axes.Axes]]: ... + + def subplot_mosaic( - mosaic: str | HashableList, + mosaic: str | list[HashableList[_T]] | list[HashableList[Hashable]], *, sharex: bool = False, sharey: bool = False, @@ -1610,9 +1659,13 @@ def subplot_mosaic( empty_sentinel: Any = '.', subplot_kw: dict[str, Any] | None = None, gridspec_kw: dict[str, Any] | None = None, - per_subplot_kw: dict[Hashable, dict[str, Any]] | None = None, - **fig_kw -) -> tuple[Figure, dict[Hashable, matplotlib.axes.Axes]]: + per_subplot_kw: dict[str | tuple[str], dict[str, Any]] | + dict[_T | tuple[_T], dict[str, Any]] | + dict[Hashable | tuple[Hashable], dict[str, Any]] | None = None, + **fig_kw: Any +) -> tuple[Figure, dict[str, matplotlib.axes.Axes]] | \ + tuple[Figure, dict[_T, matplotlib.axes.Axes]] | \ + tuple[Figure, dict[Hashable, matplotlib.axes.Axes]]: """ Build a layout of Axes based on ASCII art or nested lists. @@ -1714,12 +1767,13 @@ def subplot_mosaic( """ fig = figure(**fig_kw) - ax_dict = fig.subplot_mosaic( - mosaic, sharex=sharex, sharey=sharey, + ax_dict = fig.subplot_mosaic( # type: ignore[misc] + mosaic, # type: ignore[arg-type] + sharex=sharex, sharey=sharey, height_ratios=height_ratios, width_ratios=width_ratios, subplot_kw=subplot_kw, gridspec_kw=gridspec_kw, empty_sentinel=empty_sentinel, - per_subplot_kw=per_subplot_kw, + per_subplot_kw=per_subplot_kw, # type: ignore[arg-type] ) return fig, ax_dict diff --git a/lib/matplotlib/typing.py b/lib/matplotlib/typing.py index e6b9ada7d8a2..02059be94ba2 100644 --- a/lib/matplotlib/typing.py +++ b/lib/matplotlib/typing.py @@ -11,7 +11,7 @@ """ from collections.abc import Hashable, Sequence import pathlib -from typing import Any, Literal, Union +from typing import Any, Literal, TypeVar, Union from . import path from ._enums import JoinStyle, CapStyle @@ -55,5 +55,6 @@ Sequence[Union[str, pathlib.Path, dict[str, Any]]], ] -HashableList = list[Union[Hashable, "HashableList"]] +_HT = TypeVar("_HT", bound=Hashable) +HashableList = list[Union[_HT, "HashableList[_HT]"]] """A nested list of Hashable values."""