Skip to content

Commit

Permalink
Support typing via PyRight (#138)
Browse files Browse the repository at this point in the history
* fix mixed named and anonymous arguments

* configure pyright

* reduce old-style types

* run pyright in CI

* use separate typetest

* document typing explicitly
  • Loading branch information
maxfischer2781 committed Mar 23, 2024
1 parent c1c8e06 commit 04e0925
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 25 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/verification.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[test]
pip install .[test,typetest]
- name: Lint with flake8
run: |
flake8 asyncstdlib unittests
Expand All @@ -28,3 +28,5 @@ jobs:
- name: Verify with MyPy
run: |
mypy --pretty
- name: Verify with PyRight
uses: jakebailey/pyright-action@v2
9 changes: 4 additions & 5 deletions asyncstdlib/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
AsyncGenerator,
Iterable,
AsyncIterable,
Union,
Generic,
Optional,
Awaitable,
Expand Down Expand Up @@ -100,7 +99,7 @@ def borrow(iterator: AsyncIterator[T]) -> AsyncGenerator[T, None]:


def awaitify(
function: Union[Callable[..., T], Callable[..., Awaitable[T]]]
function: "Callable[..., Awaitable[T]] | Callable[..., T]",
) -> Callable[..., Awaitable[T]]:
"""Ensure that ``function`` can be used in ``await`` expressions"""
if iscoroutinefunction(function):
Expand All @@ -114,16 +113,16 @@ class Awaitify(Generic[T]):

__slots__ = "__wrapped__", "_async_call"

def __init__(self, function: Union[Callable[..., T], Callable[..., Awaitable[T]]]):
def __init__(self, function: "Callable[..., Awaitable[T]] | Callable[..., T]"):
self.__wrapped__ = function
self._async_call: Optional[Callable[..., Awaitable[T]]] = None
self._async_call: "Callable[..., Awaitable[T]] | None" = None

def __call__(self, *args: Any, **kwargs: Any) -> Awaitable[T]:
if (async_call := self._async_call) is None:
value = self.__wrapped__(*args, **kwargs)
if isinstance(value, Awaitable):
self._async_call = self.__wrapped__ # type: ignore
return value
return value # pyright: ignore
else:
self._async_call = force_async(self.__wrapped__) # type: ignore
return await_value(value)
Expand Down
12 changes: 12 additions & 0 deletions asyncstdlib/builtins.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -77,37 +77,43 @@ def zip(
def map(
function: Callable[[T1], Awaitable[R]],
__it1: AnyIterable[T1],
/,
) -> AsyncIterator[R]: ...
@overload
def map(
function: Callable[[T1], R],
__it1: AnyIterable[T1],
/,
) -> AsyncIterator[R]: ...
@overload
def map(
function: Callable[[T1, T2], Awaitable[R]],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
/,
) -> AsyncIterator[R]: ...
@overload
def map(
function: Callable[[T1, T2], R],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
/,
) -> AsyncIterator[R]: ...
@overload
def map(
function: Callable[[T1, T2, T3], Awaitable[R]],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
/,
) -> AsyncIterator[R]: ...
@overload
def map(
function: Callable[[T1, T2, T3], R],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
/,
) -> AsyncIterator[R]: ...
@overload
def map(
Expand All @@ -116,6 +122,7 @@ def map(
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
__it4: AnyIterable[T4],
/,
) -> AsyncIterator[R]: ...
@overload
def map(
Expand All @@ -124,6 +131,7 @@ def map(
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
__it4: AnyIterable[T4],
/,
) -> AsyncIterator[R]: ...
@overload
def map(
Expand All @@ -133,6 +141,7 @@ def map(
__it3: AnyIterable[T3],
__it4: AnyIterable[T4],
__it5: AnyIterable[T5],
/,
) -> AsyncIterator[R]: ...
@overload
def map(
Expand All @@ -142,6 +151,7 @@ def map(
__it3: AnyIterable[T3],
__it4: AnyIterable[T4],
__it5: AnyIterable[T5],
/,
) -> AsyncIterator[R]: ...
@overload
def map(
Expand All @@ -151,6 +161,7 @@ def map(
__it3: AnyIterable[Any],
__it4: AnyIterable[Any],
__it5: AnyIterable[Any],
/,
*iterable: AnyIterable[Any],
) -> AsyncIterator[R]: ...
@overload
Expand All @@ -161,6 +172,7 @@ def map(
__it3: AnyIterable[Any],
__it4: AnyIterable[Any],
__it5: AnyIterable[Any],
/,
*iterable: AnyIterable[Any],
) -> AsyncIterator[R]: ...
@overload
Expand Down
18 changes: 8 additions & 10 deletions asyncstdlib/heapq.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from typing import (
Generic,
AsyncIterator,
Tuple,
List,
Optional,
Callable,
Any,
Expand Down Expand Up @@ -53,21 +51,21 @@ def __init__(
@classmethod
def from_iters(
cls,
iterables: Tuple[AnyIterable[T], ...],
iterables: "tuple[AnyIterable[T], ...]",
reverse: bool,
key: Callable[[T], Awaitable[LT]],
) -> "AsyncIterator[_KeyIter[LT]]": ...

@overload
@classmethod
def from_iters(
cls, iterables: Tuple[AnyIterable[LT], ...], reverse: bool, key: None
cls, iterables: "tuple[AnyIterable[LT], ...]", reverse: bool, key: None
) -> "AsyncIterator[_KeyIter[LT]]": ...

@classmethod
async def from_iters(
cls,
iterables: Tuple[AnyIterable[Any], ...],
iterables: "tuple[AnyIterable[Any], ...]",
reverse: bool,
key: Optional[Callable[[Any], Any]],
) -> "AsyncIterator[_KeyIter[Any]]":
Expand Down Expand Up @@ -124,10 +122,10 @@ async def merge(
"""
a_key = awaitify(key) if key is not None else None
# sortable iterators with (reverse) position to ensure stable sort for ties
iter_heap: List[Tuple[_KeyIter[Any], int]] = [
iter_heap: "list[tuple[_KeyIter[Any], int]]" = [
(itr, idx if not reverse else -idx)
async for idx, itr in a_enumerate(
_KeyIter.from_iters(iterables, reverse, a_key)
_KeyIter[Any].from_iters(iterables, reverse, a_key)
)
]
try:
Expand Down Expand Up @@ -175,7 +173,7 @@ async def _largest(
n: int,
key: Callable[[T], Awaitable[LT]],
reverse: bool,
) -> List[T]:
) -> "list[T]":
ordered: Callable[[LT], LT] = ReverseLT if reverse else lambda x: x # type: ignore
async with ScopedIter(iterable) as iterator:
# assign an ordering to items to solve ties
Expand Down Expand Up @@ -207,7 +205,7 @@ async def nlargest(
iterable: AnyIterable[T],
n: int,
key: Optional[Callable[[Any], Awaitable[Any]]] = None,
) -> List[T]:
) -> "list[T]":
"""
Return a sorted list of the ``n`` largest elements from the (async) iterable
Expand All @@ -229,7 +227,7 @@ async def nsmallest(
iterable: AnyIterable[T],
n: int,
key: Optional[Callable[[Any], Awaitable[Any]]] = None,
) -> List[T]:
) -> "list[T]":
"""
Return a sorted list of the ``n`` smallest elements from the (async) iterable
Expand Down
3 changes: 1 addition & 2 deletions asyncstdlib/itertools.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ from typing import (
Iterable,
Callable,
TypeVar,
Self,
overload,
)
from typing_extensions import Literal
from typing_extensions import Literal, Self

from ._typing import AnyIterable, ADD, T, T1, T2, T3, T4, T5

Expand Down
23 changes: 16 additions & 7 deletions docs/source/contributing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ where you can report bugs, request improvements or propose changes.

- For bug reports and feature requests simply `open a new issue`_
and fill in the appropriate template.
- Even for content submissions it is highly recommended to make sure an issue
exists - this allows you to get early feedback and document the development.
- Even for content submissions make sure `an issue exists`_ - this
allows you to get early feedback and document the development.
You can use whatever tooling you like to create the content,
but the next sections give a rough outline on how to proceed.

.. _asyncstdlib GitHub repository: https://github.com/maxfischer2781/asyncstdlib
.. _open a new issue: https://github.com/maxfischer2781/asyncstdlib/issues/new/choose
.. _an issue exists: https://github.com/maxfischer2781/asyncstdlib/issues

Submitting Content
==================
Expand All @@ -32,26 +33,34 @@ the extras ``test`` and ``doc``, respectively.
.. note::

Ideally you develop with the repository checked out locally and a separate `Python venv`_.
If you have the venv active and the current working directory is the repository root,
simply run `python -m pip install -e '.[test,doc]'` to install all dependencies.
If you have the venv active and are at the repository root,
run ``python -m pip install -e '.[test,typetest,doc]'`` to install all dependencies.

.. _`GitHub Fork and Pull Request`: https://guides.github.com/activities/forking/
.. _`Python venv`: https://docs.python.org/3/library/venv.html

Testing Code
------------

Code is verified locally using the tools `flake8`, `black`, `pytest` and `mypy`.
If you do not have your own preferences we recommend the following order:
Code can be verified locally using the tools `flake8`, `black`, `pytest`, `pyright` and `mypy`.
You should always verify that the basic checks pass:

.. code:: bash
python -m black asyncstdlib unittests
python -m flake8 asyncstdlib unittests
python -m pytest
This runs tests from simplest to most advanced and should allow a quick development cycle.

In many cases you can rely on your IDE for type checking.
For major typing related changes, run the full type checking:

.. code:: bash
python -m mypy --pretty
python -m pyright
This runs tests from simplest to most advanced and should allow you quick development.
Note that some additional checks are run on GitHub to check test coverage and code health.

Building Docs
Expand Down
11 changes: 11 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ test = [
"pytest-cov",
"flake8-2020",
"mypy; implementation_name=='cpython'",
]
typetest = [
"mypy; implementation_name=='cpython'",
"pyright",
"typing-extensions",
]
doc = ["sphinx", "sphinxcontrib-trio"]
Expand Down Expand Up @@ -64,6 +68,13 @@ warn_return_any = true
no_implicit_reexport = true
strict_equality = true

[tool.pyright]
include = ["asyncstdlib", "typetests"]
typeCheckingMode = "strict"
pythonPlatform = "All"
pythonVersion = "3.8"
verboseOutput = true

[tool.pytest.ini_options]
testpaths = [
"unittests",
Expand Down

0 comments on commit 04e0925

Please sign in to comment.