Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved typing for builtins #127

Merged
merged 4 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
280 changes: 4 additions & 276 deletions asyncstdlib/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@
Optional,
Dict,
Any,
overload,
)
import builtins as _sync_builtins

from ._typing import T, T1, T2, T3, T4, T5, R, HK, LT, ADD, AnyIterable
from ._typing import T, R, HK, LT, AnyIterable
from ._core import (
aiter,
ScopedIter,
Expand All @@ -27,14 +26,6 @@
__ANEXT_DEFAULT = Sentinel("<no default>")


@overload
async def anext(iterator: AsyncIterator[T]) -> T: ...


@overload
async def anext(iterator: AsyncIterator[T], default: T) -> T: ...


async def anext(
iterator: AsyncIterator[T], default: Union[Sentinel, T] = __ANEXT_DEFAULT
) -> T:
Expand Down Expand Up @@ -63,16 +54,6 @@ async def anext(
__ITER_DEFAULT = Sentinel("<no default>")


@overload
def iter(subject: AnyIterable[T]) -> AsyncIterator[T]:
pass


@overload
def iter(subject: Callable[[], Awaitable[T]], sentinel: T) -> AsyncIterator[T]:
pass


def iter(
subject: Union[AnyIterable[T], Callable[[], Awaitable[T]]],
sentinel: Union[Sentinel, T] = __ITER_DEFAULT,
Expand Down Expand Up @@ -116,7 +97,7 @@ async def acallable_iterator(
value = await subject()


async def all(iterable: AnyIterable[T]) -> bool:
async def all(iterable: AnyIterable[Any]) -> bool:
"""
Return :py:data:`True` if none of the elements of the (async) ``iterable`` are false
"""
Expand All @@ -127,7 +108,7 @@ async def all(iterable: AnyIterable[T]) -> bool:
return True


async def any(iterable: AnyIterable[T]) -> bool:
async def any(iterable: AnyIterable[Any]) -> bool:
"""
Return :py:data:`False` if none of the elements of the (async) ``iterable`` are true
"""
Expand All @@ -138,68 +119,6 @@ async def any(iterable: AnyIterable[T]) -> bool:
return False


@overload
def zip(
__it1: AnyIterable[T1],
*,
strict: bool = ...,
) -> AsyncIterator[Tuple[T1]]: ...


@overload
def zip(
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
*,
strict: bool = ...,
) -> AsyncIterator[Tuple[T1, T2]]: ...


@overload
def zip(
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
*,
strict: bool = ...,
) -> AsyncIterator[Tuple[T1, T2, T3]]: ...


@overload
def zip(
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
__it4: AnyIterable[T4],
*,
strict: bool = ...,
) -> AsyncIterator[Tuple[T1, T2, T3, T4]]: ...


@overload
def zip(
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
__it4: AnyIterable[T4],
__it5: AnyIterable[T5],
*,
strict: bool = ...,
) -> AsyncIterator[Tuple[T1, T2, T3, T4, T5]]: ...


@overload
def zip(
__it1: AnyIterable[Any],
__it2: AnyIterable[Any],
__it3: AnyIterable[Any],
__it4: AnyIterable[Any],
__it5: AnyIterable[Any],
*iterables: AnyIterable[Any],
strict: bool = ...,
) -> AsyncIterator[Tuple[Any, ...]]: ...


async def zip(
*iterables: AnyIterable[Any], strict: bool = False
) -> AsyncIterator[Tuple[Any, ...]]:
Expand Down Expand Up @@ -285,118 +204,6 @@ async def _zip_inner_strict(
return


@overload
def map(
function: Callable[[T1], Awaitable[R]],
__it1: AnyIterable[T1],
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[[T1], R],
__it1: AnyIterable[T1],
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[[T1, T2], Awaitable[R]],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[[T1, T2], R], __it1: AnyIterable[T1], __it2: AnyIterable[T2]
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[[T1, T2, T3], Awaitable[R]],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[[T1, T2, T3], R],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[[T1, T2, T3, T4], Awaitable[R]],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
__it4: AnyIterable[T4],
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[[T1, T2, T3, T4], R],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
__it4: AnyIterable[T4],
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[[T1, T2, T3, T4, T5], Awaitable[R]],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
__it4: AnyIterable[T4],
__it5: AnyIterable[T5],
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[[T1, T2, T3, T4, T5], R],
__it1: AnyIterable[T1],
__it2: AnyIterable[T2],
__it3: AnyIterable[T3],
__it4: AnyIterable[T4],
__it5: AnyIterable[T5],
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[..., Awaitable[R]],
__it1: AnyIterable[Any],
__it2: AnyIterable[Any],
__it3: AnyIterable[Any],
__it4: AnyIterable[Any],
__it5: AnyIterable[Any],
*iterable: AnyIterable[Any],
) -> AsyncIterator[R]: ...


@overload
def map(
function: Callable[..., R],
__it1: AnyIterable[Any],
__it2: AnyIterable[Any],
__it3: AnyIterable[Any],
__it4: AnyIterable[Any],
__it5: AnyIterable[Any],
*iterable: AnyIterable[Any],
) -> AsyncIterator[R]: ...


async def map(
function: Union[Callable[..., R], Callable[..., Awaitable[R]]],
*iterable: AnyIterable[Any],
Expand Down Expand Up @@ -428,26 +235,6 @@ async def map(
__MIN_MAX_DEFAULT = Sentinel("<no default>")


@overload
async def max(iterable: AnyIterable[LT], *, key: None = ...) -> LT: ...


@overload
async def max(
iterable: AnyIterable[LT], *, key: None = ..., default: T
) -> Union[LT, T]: ...


@overload
async def max(iterable: AnyIterable[T1], *, key: Callable[[T1], LT] = ...) -> T1: ...


@overload
async def max(
iterable: AnyIterable[T1], *, key: Callable[[T1], LT] = ..., default: T2
) -> Union[T1, T2]: ...


async def max(
iterable: AnyIterable[Any],
*,
Expand All @@ -474,26 +261,6 @@ async def max(
return await _min_max(iterable, key, True, default)


@overload
async def min(iterable: AnyIterable[LT], *, key: None = ...) -> LT: ...


@overload
async def min(
iterable: AnyIterable[LT], *, key: None = ..., default: T
) -> Union[LT, T]: ...


@overload
async def min(iterable: AnyIterable[T1], *, key: Callable[[T1], LT] = ...) -> T1: ...


@overload
async def min(
iterable: AnyIterable[T1], *, key: Callable[[T1], LT] = ..., default: T2
) -> Union[T1, T2]: ...


async def min(
iterable: AnyIterable[Any],
*,
Expand Down Expand Up @@ -594,18 +361,6 @@ async def enumerate(
count += 1


@overload
async def sum(iterable: AnyIterable[int]) -> int: ...


@overload
async def sum(iterable: AnyIterable[float]) -> float: ...


@overload
async def sum(iterable: AnyIterable[ADD], start: ADD) -> ADD: ...


async def sum(iterable: AnyIterable[Any], start: Any = 0) -> Any:
"""
Sum of ``start`` and all elements in the (async) iterable
Expand All @@ -632,21 +387,6 @@ async def tuple(iterable: Union[Iterable[T], AsyncIterable[T]] = ()) -> Tuple[T,
return (*[element async for element in aiter(iterable)],)


@overload
async def dict( # noqa: F811
iterable: Union[Iterable[Tuple[HK, T]], AsyncIterable[Tuple[HK, T]]] = (),
) -> Dict[HK, T]:
pass


@overload # noqa: F811
async def dict( # noqa: F811
iterable: Union[Iterable[Tuple[HK, T]], AsyncIterable[Tuple[HK, T]]] = (),
**kwargs: T,
) -> Dict[Union[HK, str], T]:
pass


async def dict( # noqa: F811
iterable: Union[Iterable[Tuple[HK, T]], AsyncIterable[Tuple[HK, T]]] = (),
**kwargs: T,
Expand Down Expand Up @@ -674,18 +414,6 @@ async def set(iterable: Union[Iterable[T], AsyncIterable[T]] = ()) -> Set[T]:
return {element async for element in aiter(iterable)}


@overload
async def sorted(
iterable: AnyIterable[LT], *, key: None = ..., reverse: bool = ...
) -> List[LT]: ...


@overload
async def sorted(
iterable: AnyIterable[T], *, key: Callable[[T], LT], reverse: bool = ...
) -> List[T]: ...


async def sorted(
iterable: AnyIterable[T],
*,
Expand Down Expand Up @@ -716,7 +444,7 @@ async def sorted(
try:
return _sync_builtins.sorted(iterable, reverse=reverse) # type: ignore
except TypeError:
items = [item async for item in aiter(iterable)]
items: "_sync_builtins.list[Any]" = [item async for item in aiter(iterable)]
items.sort(reverse=reverse)
return items
else:
Expand Down