Skip to content

Commit

Permalink
[aio types] Fix some grpc.aio python types (#32475)
Browse files Browse the repository at this point in the history
With these, it is actually possible to have typed client stubs where the
return type is correctly inferred.

It's only for the non-streaming calls, because there is
`RequestIterableType` for the streaming ones (but it's just Any with
extra steps and would require much more work).

---------

Co-authored-by: Xuan Wang <xuanwn@google.com>
  • Loading branch information
Tasssadar and XuanWang-Amos committed May 15, 2023
1 parent b8a6b42 commit 5bd38df
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 17 deletions.
6 changes: 3 additions & 3 deletions src/python/grpcio/grpc/aio/_base_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from abc import ABCMeta
from abc import abstractmethod
from typing import AsyncIterator, Awaitable, Generic, Optional, Union
from typing import Any, AsyncIterator, Generator, Generic, Optional, Union

import grpc

Expand Down Expand Up @@ -141,7 +141,7 @@ class UnaryUnaryCall(Generic[RequestType, ResponseType],
"""The abstract base class of an unary-unary RPC on the client-side."""

@abstractmethod
def __await__(self) -> Awaitable[ResponseType]:
def __await__(self) -> Generator[Any, None, ResponseType]:
"""Await the response message to be ready.
Returns:
Expand Down Expand Up @@ -197,7 +197,7 @@ async def done_writing(self) -> None:
"""

@abstractmethod
def __await__(self) -> Awaitable[ResponseType]:
def __await__(self) -> Generator[Any, None, ResponseType]:
"""Await the response message to be ready.
Returns:
Expand Down
16 changes: 9 additions & 7 deletions src/python/grpcio/grpc/aio/_base_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,33 @@
"""Abstract base classes for Channel objects and Multicallable objects."""

import abc
from typing import Any, Optional
from typing import Generic, Optional

import grpc

from . import _base_call
from ._typing import DeserializingFunction
from ._typing import MetadataType
from ._typing import RequestIterableType
from ._typing import RequestType
from ._typing import ResponseType
from ._typing import SerializingFunction


class UnaryUnaryMultiCallable(abc.ABC):
class UnaryUnaryMultiCallable(Generic[RequestType, ResponseType], abc.ABC):
"""Enables asynchronous invocation of a unary-call RPC."""

@abc.abstractmethod
def __call__(
self,
request: Any,
request: RequestType,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryUnaryCall:
) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]:
"""Asynchronously invokes the underlying RPC.
Args:
Expand All @@ -63,20 +65,20 @@ def __call__(
"""


class UnaryStreamMultiCallable(abc.ABC):
class UnaryStreamMultiCallable(Generic[RequestType, ResponseType], abc.ABC):
"""Enables asynchronous invocation of a server-streaming RPC."""

@abc.abstractmethod
def __call__(
self,
request: Any,
request: RequestType,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryStreamCall:
) -> _base_call.UnaryStreamCall[RequestType, ResponseType]:
"""Asynchronously invokes the underlying RPC.
Args:
Expand Down
7 changes: 4 additions & 3 deletions src/python/grpcio/grpc/aio/_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import inspect
import logging
import traceback
from typing import AsyncIterator, Optional, Tuple
from typing import Any, AsyncIterator, Generator, Generic, Optional, Tuple

import grpc
from grpc import _common
Expand Down Expand Up @@ -252,7 +252,7 @@ class _APIStyle(enum.IntEnum):
READER_WRITER = 2


class _UnaryResponseMixin(Call):
class _UnaryResponseMixin(Call, Generic[ResponseType]):
_call_response: asyncio.Task

def _init_unary_response_mixin(self, response_task: asyncio.Task):
Expand All @@ -265,7 +265,7 @@ def cancel(self) -> bool:
else:
return False

def __await__(self) -> ResponseType:
def __await__(self) -> Generator[Any, None, ResponseType]:
"""Wait till the ongoing RPC request finishes."""
try:
response = yield from self._call_response
Expand Down Expand Up @@ -573,6 +573,7 @@ async def wait_for_connection(self) -> None:
await self._raise_for_status()


# pylint: disable=too-many-ancestors
class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
_base_call.StreamUnaryCall):
"""Object for managing stream-unary RPC calls.
Expand Down
10 changes: 6 additions & 4 deletions src/python/grpcio/grpc/aio/_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
from ._typing import ChannelArgumentType
from ._typing import DeserializingFunction
from ._typing import RequestIterableType
from ._typing import RequestType
from ._typing import ResponseType
from ._typing import SerializingFunction
from ._utils import _timeout_to_deadline

Expand Down Expand Up @@ -121,14 +123,14 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable,

def __call__(
self,
request: Any,
request: RequestType,
*,
timeout: Optional[float] = None,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryUnaryCall:
) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]:

metadata = self._init_metadata(metadata, compression)
if not self._interceptors:
Expand All @@ -152,14 +154,14 @@ class UnaryStreamMultiCallable(_BaseMultiCallable,

def __call__(
self,
request: Any,
request: RequestType,
*,
timeout: Optional[float] = None,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryStreamCall:
) -> _base_call.UnaryStreamCall[RequestType, ResponseType]:

metadata = self._init_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout)
Expand Down

0 comments on commit 5bd38df

Please sign in to comment.