Skip to content

Commit

Permalink
Backport PR matplotlib#26491: TYP: Add common-type overloads of subpl…
Browse files Browse the repository at this point in the history
…ot_mosaic
  • Loading branch information
greglucas authored and QuLogic committed Sep 7, 2023
1 parent 70f4990 commit 0421fde
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 27 deletions.
6 changes: 6 additions & 0 deletions doc/missing-references.json
Expand Up @@ -152,6 +152,9 @@
"HashableList": [
"lib/matplotlib/pyplot.py:docstring of matplotlib.pyplot.subplot_mosaic:1"
],
"HashableList[_HT]": [
"doc/docstring of builtins.list:17"
],
"LineStyleType": [
"lib/matplotlib/pyplot.py:docstring of matplotlib.pyplot.eventplot:1",
"lib/matplotlib/pyplot.py:docstring of matplotlib.pyplot.hlines:1",
Expand Down Expand Up @@ -701,6 +704,9 @@
"matplotlib.animation.TimedAnimation.to_jshtml": [
"doc/api/_as_gen/matplotlib.animation.TimedAnimation.rst:28:<autosummary>:1"
],
"matplotlib.typing._HT": [
"doc/docstring of builtins.list:17"
],
"mpl_toolkits.axislines.Axes": [
"lib/mpl_toolkits/axisartist/axis_artist.py:docstring of mpl_toolkits.axisartist.axis_artist:7"
],
Expand Down
8 changes: 2 additions & 6 deletions lib/matplotlib/figure.py
Expand Up @@ -1812,15 +1812,11 @@ def _norm_per_subplot_kw(per_subplot_kw):
if isinstance(k, tuple):
for sub_key in k:
if sub_key in expanded:
raise ValueError(
f'The key {sub_key!r} appears multiple times.'
)
raise ValueError(f'The key {sub_key!r} appears multiple times.')
expanded[sub_key] = v
else:
if k in expanded:
raise ValueError(
f'The key {k!r} appears multiple times.'
)
raise ValueError(f'The key {k!r} appears multiple times.')
expanded[k] = v
return expanded

Expand Down
52 changes: 40 additions & 12 deletions lib/matplotlib/figure.pyi
@@ -1,4 +1,9 @@
from collections.abc import Callable, Hashable, Iterable
import os
from typing import Any, IO, Literal, TypeVar, overload

import numpy as np
from numpy.typing import ArrayLike

from matplotlib.artist import Artist
from matplotlib.axes import Axes, SubplotBase
Expand All @@ -19,14 +24,10 @@ from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle, Patch
from matplotlib.text import Text
from matplotlib.transforms import Affine2D, Bbox, BboxBase, Transform

import numpy as np
from numpy.typing import ArrayLike

from collections.abc import Callable, Iterable
from typing import Any, IO, Literal, overload
from .typing import ColorType, HashableList

_T = TypeVar("_T")

class SubplotParams:
def __init__(
self,
Expand Down Expand Up @@ -226,21 +227,48 @@ class FigureBase(Artist):
*,
bbox_extra_artists: Iterable[Artist] | None = ...,
) -> Bbox: ...

# Any in list of list is recursive list[list[Hashable | list[Hashable | ...]]] but that can't really be type checked
@overload
def subplot_mosaic(
self,
mosaic: str | HashableList,
mosaic: str,
*,
sharex: bool = ...,
sharey: bool = ...,
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
empty_sentinel: str = ...,
subplot_kw: dict[str, Any] | None = ...,
per_subplot_kw: dict[str | tuple[str, ...], dict[str, Any]] | None = ...,
gridspec_kw: dict[str, Any] | None = ...,
) -> dict[str, Axes]: ...
@overload
def subplot_mosaic(
self,
mosaic: list[HashableList[_T]],
*,
sharex: bool = ...,
sharey: bool = ...,
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
empty_sentinel: _T = ...,
subplot_kw: dict[str, Any] | None = ...,
per_subplot_kw: dict[_T | tuple[_T, ...], dict[str, Any]] | None = ...,
gridspec_kw: dict[str, Any] | None = ...,
) -> dict[_T, Axes]: ...
@overload
def subplot_mosaic(
self,
mosaic: list[HashableList[Hashable]],
*,
sharex: bool = ...,
sharey: bool = ...,
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
empty_sentinel: Any = ...,
subplot_kw: dict[str, Any] | None = ...,
per_subplot_kw: dict[Any, dict[str, Any]] | None = ...,
gridspec_kw: dict[str, Any] | None = ...
) -> dict[Any, Axes]: ...
per_subplot_kw: dict[Hashable | tuple[Hashable, ...], dict[str, Any]] | None = ...,
gridspec_kw: dict[str, Any] | None = ...,
) -> dict[Hashable, Axes]: ...

class SubFigure(FigureBase):
figure: Figure
Expand Down
68 changes: 61 additions & 7 deletions lib/matplotlib/pyplot.py
Expand Up @@ -125,6 +125,7 @@

_P = ParamSpec('_P')
_R = TypeVar('_R')
_T = TypeVar('_T')


# We may not need the following imports here:
Expand Down Expand Up @@ -1602,8 +1603,56 @@ def subplots(
return fig, axs


@overload
def subplot_mosaic(
mosaic: str,
*,
sharex: bool = ...,
sharey: bool = ...,
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
empty_sentinel: str = ...,
subplot_kw: dict[str, Any] | None = ...,
gridspec_kw: dict[str, Any] | None = ...,
per_subplot_kw: dict[str | tuple[str, ...], dict[str, Any]] | None = ...,
**fig_kw: Any
) -> tuple[Figure, dict[str, matplotlib.axes.Axes]]: ...


@overload
def subplot_mosaic(
mosaic: list[HashableList[_T]],
*,
sharex: bool = ...,
sharey: bool = ...,
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
empty_sentinel: _T = ...,
subplot_kw: dict[str, Any] | None = ...,
gridspec_kw: dict[str, Any] | None = ...,
per_subplot_kw: dict[_T | tuple[_T, ...], dict[str, Any]] | None = ...,
**fig_kw: Any
) -> tuple[Figure, dict[_T, matplotlib.axes.Axes]]: ...


@overload
def subplot_mosaic(
mosaic: list[HashableList[Hashable]],
*,
sharex: bool = ...,
sharey: bool = ...,
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
empty_sentinel: Any = ...,
subplot_kw: dict[str, Any] | None = ...,
gridspec_kw: dict[str, Any] | None = ...,
per_subplot_kw: dict[Hashable | tuple[Hashable, ...], dict[str, Any]] | None = ...,
**fig_kw: Any
) -> tuple[Figure, dict[Hashable, matplotlib.axes.Axes]]: ...


def subplot_mosaic(
mosaic: str | HashableList,
mosaic: str | list[HashableList[_T]] | list[HashableList[Hashable]],
*,
sharex: bool = False,
sharey: bool = False,
Expand All @@ -1612,9 +1661,13 @@ def subplot_mosaic(
empty_sentinel: Any = '.',
subplot_kw: dict[str, Any] | None = None,
gridspec_kw: dict[str, Any] | None = None,
per_subplot_kw: dict[Hashable, dict[str, Any]] | None = None,
**fig_kw
) -> tuple[Figure, dict[Hashable, matplotlib.axes.Axes]]:
per_subplot_kw: dict[str | tuple[str, ...], dict[str, Any]] |
dict[_T | tuple[_T, ...], dict[str, Any]] |
dict[Hashable | tuple[Hashable, ...], dict[str, Any]] | None = None,
**fig_kw: Any
) -> tuple[Figure, dict[str, matplotlib.axes.Axes]] | \
tuple[Figure, dict[_T, matplotlib.axes.Axes]] | \
tuple[Figure, dict[Hashable, matplotlib.axes.Axes]]:
"""
Build a layout of Axes based on ASCII art or nested lists.
Expand Down Expand Up @@ -1716,12 +1769,13 @@ def subplot_mosaic(
"""
fig = figure(**fig_kw)
ax_dict = fig.subplot_mosaic(
mosaic, sharex=sharex, sharey=sharey,
ax_dict = fig.subplot_mosaic( # type: ignore[misc]
mosaic, # type: ignore[arg-type]
sharex=sharex, sharey=sharey,
height_ratios=height_ratios, width_ratios=width_ratios,
subplot_kw=subplot_kw, gridspec_kw=gridspec_kw,
empty_sentinel=empty_sentinel,
per_subplot_kw=per_subplot_kw,
per_subplot_kw=per_subplot_kw, # type: ignore[arg-type]
)
return fig, ax_dict

Expand Down
5 changes: 3 additions & 2 deletions lib/matplotlib/typing.py
Expand Up @@ -11,7 +11,7 @@
"""
from collections.abc import Hashable, Sequence
import pathlib
from typing import Any, Literal, Union
from typing import Any, Literal, TypeVar, Union

from . import path
from ._enums import JoinStyle, CapStyle
Expand Down Expand Up @@ -55,5 +55,6 @@
Sequence[Union[str, pathlib.Path, dict[str, Any]]],
]

HashableList = list[Union[Hashable, "HashableList"]]
_HT = TypeVar("_HT", bound=Hashable)
HashableList = list[Union[_HT, "HashableList[_HT]"]]
"""A nested list of Hashable values."""

0 comments on commit 0421fde

Please sign in to comment.