Skip to content

Commit

Permalink
TYP: Add common type overloads of subplot_mosaic
Browse files Browse the repository at this point in the history
I'll assert, without proof, that passing a single string, a mosaic list
of strings, or a mosaic list all of the same type is more common than
passing arbitrary unrelated hashables. Thus it is somewhat convenient if
the return type stipulates that the resulting dictionary is also keyed
with strings or the common type.

This also fixes the type of the `per_subplot_kw` argument, which also
allows dictionary keys of tuples of the entries.
  • Loading branch information
QuLogic committed Aug 19, 2023
1 parent d85bc15 commit 4530267
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 27 deletions.
6 changes: 6 additions & 0 deletions doc/missing-references.json
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@
"HashableList": [
"lib/matplotlib/pyplot.py:docstring of matplotlib.pyplot.subplot_mosaic:1"
],
"HashableList[_HT]": [
"doc/docstring of builtins.list:17"
],
"LineStyleType": [
"lib/matplotlib/pyplot.py:docstring of matplotlib.pyplot.eventplot:1",
"lib/matplotlib/pyplot.py:docstring of matplotlib.pyplot.hlines:1",
Expand Down Expand Up @@ -701,6 +704,9 @@
"matplotlib.animation.TimedAnimation.to_jshtml": [
"doc/api/_as_gen/matplotlib.animation.TimedAnimation.rst:28:<autosummary>:1"
],
"matplotlib.typing._HT": [
"doc/docstring of builtins.list:17"
],
"mpl_toolkits.axislines.Axes": [
"lib/mpl_toolkits/axisartist/axis_artist.py:docstring of mpl_toolkits.axisartist.axis_artist:7"
],
Expand Down
8 changes: 2 additions & 6 deletions lib/matplotlib/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1798,15 +1798,11 @@ def _norm_per_subplot_kw(per_subplot_kw):
if isinstance(k, tuple):
for sub_key in k:
if sub_key in expanded:
raise ValueError(
f'The key {sub_key!r} appears multiple times.'
)
raise ValueError(f'The key {sub_key!r} appears multiple times.')
expanded[sub_key] = v
else:
if k in expanded:
raise ValueError(
f'The key {k!r} appears multiple times.'
)
raise ValueError(f'The key {k!r} appears multiple times.')
expanded[k] = v
return expanded

Expand Down
52 changes: 40 additions & 12 deletions lib/matplotlib/figure.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from collections.abc import Callable, Hashable, Iterable
import os
from typing import Any, IO, Literal, TypeVar, overload

import numpy as np
from numpy.typing import ArrayLike

from matplotlib.artist import Artist
from matplotlib.axes import Axes, SubplotBase
Expand All @@ -19,14 +24,10 @@ from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle, Patch
from matplotlib.text import Text
from matplotlib.transforms import Affine2D, Bbox, BboxBase, Transform

import numpy as np
from numpy.typing import ArrayLike

from collections.abc import Callable, Iterable
from typing import Any, IO, Literal, overload
from .typing import ColorType, HashableList

_T = TypeVar("_T")

class SubplotParams:
def __init__(
self,
Expand Down Expand Up @@ -226,21 +227,48 @@ class FigureBase(Artist):
*,
bbox_extra_artists: Iterable[Artist] | None = ...,
) -> Bbox: ...

# Any in list of list is recursive list[list[Hashable | list[Hashable | ...]]] but that can't really be type checked
@overload
def subplot_mosaic(
self,
mosaic: str | HashableList,
mosaic: str,
*,
sharex: bool = ...,
sharey: bool = ...,
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
empty_sentinel: str = ...,
subplot_kw: dict[str, Any] | None = ...,
per_subplot_kw: dict[str | tuple[str, ...], dict[str, Any]] | None = ...,
gridspec_kw: dict[str, Any] | None = ...,
) -> dict[str, Axes]: ...
@overload
def subplot_mosaic(
self,
mosaic: list[HashableList[_T]],
*,
sharex: bool = ...,
sharey: bool = ...,
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
empty_sentinel: _T = ...,
subplot_kw: dict[str, Any] | None = ...,
per_subplot_kw: dict[_T | tuple[_T, ...], dict[str, Any]] | None = ...,
gridspec_kw: dict[str, Any] | None = ...,
) -> dict[_T, Axes]: ...
@overload
def subplot_mosaic(
self,
mosaic: list[HashableList[Hashable]],
*,
sharex: bool = ...,
sharey: bool = ...,
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
empty_sentinel: Any = ...,
subplot_kw: dict[str, Any] | None = ...,
per_subplot_kw: dict[Any, dict[str, Any]] | None = ...,
gridspec_kw: dict[str, Any] | None = ...
) -> dict[Any, Axes]: ...
per_subplot_kw: dict[Hashable | tuple[Hashable, ...], dict[str, Any]] | None = ...,
gridspec_kw: dict[str, Any] | None = ...,
) -> dict[Hashable, Axes]: ...

class SubFigure(FigureBase):
figure: Figure
Expand Down
68 changes: 61 additions & 7 deletions lib/matplotlib/pyplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@

_P = ParamSpec('_P')
_R = TypeVar('_R')
_T = TypeVar('_T')


# We may not need the following imports here:
Expand Down Expand Up @@ -1600,8 +1601,56 @@ def subplots(
return fig, axs


@overload
def subplot_mosaic(
mosaic: str,
*,
sharex: bool = ...,
sharey: bool = ...,
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
empty_sentinel: str = ...,
subplot_kw: dict[str, Any] | None = ...,
gridspec_kw: dict[str, Any] | None = ...,
per_subplot_kw: dict[str | tuple[str, ...], dict[str, Any]] | None = ...,
**fig_kw: Any
) -> tuple[Figure, dict[str, matplotlib.axes.Axes]]: ...


@overload
def subplot_mosaic(
mosaic: list[HashableList[_T]],
*,
sharex: bool = ...,
sharey: bool = ...,
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
empty_sentinel: _T = ...,
subplot_kw: dict[str, Any] | None = ...,
gridspec_kw: dict[str, Any] | None = ...,
per_subplot_kw: dict[_T | tuple[_T, ...], dict[str, Any]] | None = ...,
**fig_kw: Any
) -> tuple[Figure, dict[_T, matplotlib.axes.Axes]]: ...


@overload
def subplot_mosaic(
mosaic: list[HashableList[Hashable]],
*,
sharex: bool = ...,
sharey: bool = ...,
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
empty_sentinel: Any = ...,
subplot_kw: dict[str, Any] | None = ...,
gridspec_kw: dict[str, Any] | None = ...,
per_subplot_kw: dict[Hashable | tuple[Hashable, ...], dict[str, Any]] | None = ...,
**fig_kw: Any
) -> tuple[Figure, dict[Hashable, matplotlib.axes.Axes]]: ...


def subplot_mosaic(
mosaic: str | HashableList,
mosaic: str | list[HashableList[_T]] | list[HashableList[Hashable]],
*,
sharex: bool = False,
sharey: bool = False,
Expand All @@ -1610,9 +1659,13 @@ def subplot_mosaic(
empty_sentinel: Any = '.',
subplot_kw: dict[str, Any] | None = None,
gridspec_kw: dict[str, Any] | None = None,
per_subplot_kw: dict[Hashable, dict[str, Any]] | None = None,
**fig_kw
) -> tuple[Figure, dict[Hashable, matplotlib.axes.Axes]]:
per_subplot_kw: dict[str | tuple[str, ...], dict[str, Any]] |
dict[_T | tuple[_T, ...], dict[str, Any]] |
dict[Hashable | tuple[Hashable, ...], dict[str, Any]] | None = None,
**fig_kw: Any
) -> tuple[Figure, dict[str, matplotlib.axes.Axes]] | \
tuple[Figure, dict[_T, matplotlib.axes.Axes]] | \
tuple[Figure, dict[Hashable, matplotlib.axes.Axes]]:
"""
Build a layout of Axes based on ASCII art or nested lists.
Expand Down Expand Up @@ -1714,12 +1767,13 @@ def subplot_mosaic(
"""
fig = figure(**fig_kw)
ax_dict = fig.subplot_mosaic(
mosaic, sharex=sharex, sharey=sharey,
ax_dict = fig.subplot_mosaic( # type: ignore[misc]
mosaic, # type: ignore[arg-type]
sharex=sharex, sharey=sharey,
height_ratios=height_ratios, width_ratios=width_ratios,
subplot_kw=subplot_kw, gridspec_kw=gridspec_kw,
empty_sentinel=empty_sentinel,
per_subplot_kw=per_subplot_kw,
per_subplot_kw=per_subplot_kw, # type: ignore[arg-type]
)
return fig, ax_dict

Expand Down
5 changes: 3 additions & 2 deletions lib/matplotlib/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"""
from collections.abc import Hashable, Sequence
import pathlib
from typing import Any, Literal, Union
from typing import Any, Literal, TypeVar, Union

from . import path
from ._enums import JoinStyle, CapStyle
Expand Down Expand Up @@ -55,5 +55,6 @@
Sequence[Union[str, pathlib.Path, dict[str, Any]]],
]

HashableList = list[Union[Hashable, "HashableList"]]
_HT = TypeVar("_HT", bound=Hashable)
HashableList = list[Union[_HT, "HashableList[_HT]"]]
"""A nested list of Hashable values."""

0 comments on commit 4530267

Please sign in to comment.