Skip to content

Commit

Permalink
TYP: Add common type overloads of subplot_mosaic
Browse files Browse the repository at this point in the history
I'll assert, without proof, that passing a single string, a mosaic list
of strings, or a mosaic list all of the same type is more common than
passing arbitrary unrelated hashables. Thus it is somewhat convenient if
the return type stipulates that the resulting dictionary is also keyed
with strings or the common type.

This also fixes the type of the `per_subplot_kw` argument, which also
allows dictionary keys of tuples of the entries.
  • Loading branch information
QuLogic committed Aug 10, 2023
1 parent 3ab6e1f commit 6fcee04
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 27 deletions.
8 changes: 2 additions & 6 deletions lib/matplotlib/figure.py
Expand Up @@ -1798,15 +1798,11 @@ def _norm_per_subplot_kw(per_subplot_kw):
if isinstance(k, tuple):
for sub_key in k:
if sub_key in expanded:
raise ValueError(
f'The key {sub_key!r} appears multiple times.'
)
raise ValueError(f'The key {sub_key!r} appears multiple times.')
expanded[sub_key] = v
else:
if k in expanded:
raise ValueError(
f'The key {k!r} appears multiple times.'
)
raise ValueError(f'The key {k!r} appears multiple times.')
expanded[k] = v
return expanded

Expand Down
52 changes: 40 additions & 12 deletions lib/matplotlib/figure.pyi
@@ -1,4 +1,9 @@
from collections.abc import Callable, Hashable, Iterable
import os
from typing import Any, IO, Literal, TypeVar, overload

import numpy as np
from numpy.typing import ArrayLike

from matplotlib.artist import Artist
from matplotlib.axes import Axes, SubplotBase
Expand All @@ -19,14 +24,10 @@ from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle, Patch
from matplotlib.text import Text
from matplotlib.transforms import Affine2D, Bbox, BboxBase, Transform

import numpy as np
from numpy.typing import ArrayLike

from collections.abc import Callable, Iterable
from typing import Any, IO, Literal, overload
from .typing import ColorType, HashableList

_T = TypeVar("_T")

class SubplotParams:
def __init__(
self,
Expand Down Expand Up @@ -226,21 +227,48 @@ class FigureBase(Artist):
*,
bbox_extra_artists: Iterable[Artist] | None = ...,
) -> Bbox: ...

# Any in list of list is recursive list[list[Hashable | list[Hashable | ...]]] but that can't really be type checked
@overload
def subplot_mosaic(
self,
mosaic: str | HashableList,
mosaic: str,
*,
sharex: bool = ...,
sharey: bool = ...,
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
empty_sentinel: str = ...,
subplot_kw: dict[str, Any] | None = ...,
per_subplot_kw: dict[str | tuple[str], dict[str, Any]] | None = ...,
gridspec_kw: dict[str, Any] | None = ...,
) -> dict[str, Axes]: ...
@overload
def subplot_mosaic(
self,
mosaic: list[HashableList[_T]],
*,
sharex: bool = ...,
sharey: bool = ...,
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
empty_sentinel: _T = ...,
subplot_kw: dict[str, Any] | None = ...,
per_subplot_kw: dict[_T | tuple[_T], dict[str, Any]] | None = ...,
gridspec_kw: dict[str, Any] | None = ...,
) -> dict[_T, Axes]: ...
@overload
def subplot_mosaic(
self,
mosaic: list[HashableList[Hashable]],
*,
sharex: bool = ...,
sharey: bool = ...,
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
empty_sentinel: Any = ...,
subplot_kw: dict[str, Any] | None = ...,
per_subplot_kw: dict[Any, dict[str, Any]] | None = ...,
gridspec_kw: dict[str, Any] | None = ...
) -> dict[Any, Axes]: ...
per_subplot_kw: dict[Hashable | tuple[Hashable], dict[str, Any]] | None = ...,
gridspec_kw: dict[str, Any] | None = ...,
) -> dict[Hashable, Axes]: ...

class SubFigure(FigureBase):
figure: Figure
Expand Down
68 changes: 61 additions & 7 deletions lib/matplotlib/pyplot.py
Expand Up @@ -125,6 +125,7 @@

_P = ParamSpec('_P')
_R = TypeVar('_R')
_T = TypeVar('_T')


# We may not need the following imports here:
Expand Down Expand Up @@ -1600,8 +1601,56 @@ def subplots(
return fig, axs


@overload
def subplot_mosaic(
mosaic: str,
*,
sharex: bool = ...,
sharey: bool = ...,
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
empty_sentinel: str = ...,
subplot_kw: dict[str, Any] | None = ...,
gridspec_kw: dict[str, Any] | None = ...,
per_subplot_kw: dict[str | tuple[str], dict[str, Any]] | None = ...,
**fig_kw: Any
) -> tuple[Figure, dict[str, matplotlib.axes.Axes]]: ...


@overload
def subplot_mosaic(
mosaic: list[HashableList[_T]],
*,
sharex: bool = ...,
sharey: bool = ...,
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
empty_sentinel: _T = ...,
subplot_kw: dict[str, Any] | None = ...,
gridspec_kw: dict[str, Any] | None = ...,
per_subplot_kw: dict[_T | tuple[_T], dict[str, Any]] | None = ...,
**fig_kw: Any
) -> tuple[Figure, dict[_T, matplotlib.axes.Axes]]: ...


@overload
def subplot_mosaic(
mosaic: list[HashableList[Hashable]],
*,
sharex: bool = ...,
sharey: bool = ...,
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
empty_sentinel: Any = ...,
subplot_kw: dict[str, Any] | None = ...,
gridspec_kw: dict[str, Any] | None = ...,
per_subplot_kw: dict[Hashable | tuple[Hashable], dict[str, Any]] | None = ...,
**fig_kw: Any
) -> tuple[Figure, dict[Hashable, matplotlib.axes.Axes]]: ...


def subplot_mosaic(
mosaic: str | HashableList,
mosaic: str | list[HashableList[_T]] | list[HashableList[Hashable]],
*,
sharex: bool = False,
sharey: bool = False,
Expand All @@ -1610,9 +1659,13 @@ def subplot_mosaic(
empty_sentinel: Any = '.',
subplot_kw: dict[str, Any] | None = None,
gridspec_kw: dict[str, Any] | None = None,
per_subplot_kw: dict[Hashable, dict[str, Any]] | None = None,
**fig_kw
) -> tuple[Figure, dict[Hashable, matplotlib.axes.Axes]]:
per_subplot_kw: dict[str | tuple[str], dict[str, Any]] |
dict[_T | tuple[_T], dict[str, Any]] |
dict[Hashable | tuple[Hashable], dict[str, Any]] | None = None,
**fig_kw: Any
) -> tuple[Figure, dict[str, matplotlib.axes.Axes]] | \
tuple[Figure, dict[_T, matplotlib.axes.Axes]] | \
tuple[Figure, dict[Hashable, matplotlib.axes.Axes]]:
"""
Build a layout of Axes based on ASCII art or nested lists.
Expand Down Expand Up @@ -1714,12 +1767,13 @@ def subplot_mosaic(
"""
fig = figure(**fig_kw)
ax_dict = fig.subplot_mosaic(
mosaic, sharex=sharex, sharey=sharey,
ax_dict = fig.subplot_mosaic( # type: ignore[misc]
mosaic, # type: ignore[arg-type]
sharex=sharex, sharey=sharey,
height_ratios=height_ratios, width_ratios=width_ratios,
subplot_kw=subplot_kw, gridspec_kw=gridspec_kw,
empty_sentinel=empty_sentinel,
per_subplot_kw=per_subplot_kw,
per_subplot_kw=per_subplot_kw, # type: ignore[arg-type]
)
return fig, ax_dict

Expand Down
5 changes: 3 additions & 2 deletions lib/matplotlib/typing.py
Expand Up @@ -11,7 +11,7 @@
"""
from collections.abc import Hashable, Sequence
import pathlib
from typing import Any, Literal, Union
from typing import Any, Literal, TypeVar, Union

from . import path
from ._enums import JoinStyle, CapStyle
Expand Down Expand Up @@ -55,5 +55,6 @@
Sequence[Union[str, pathlib.Path, dict[str, Any]]],
]

HashableList = list[Union[Hashable, "HashableList"]]
_HT = TypeVar("_HT", bound=Hashable)
HashableList = list[Union[_HT, "HashableList[_HT]"]]
"""A nested list of Hashable values."""

0 comments on commit 6fcee04

Please sign in to comment.