Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[aio types] Fix some grpc.aio python types #32475

Merged
merged 4 commits into from
May 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/python/grpcio/grpc/aio/_base_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from abc import ABCMeta
from abc import abstractmethod
from typing import AsyncIterator, Awaitable, Generic, Optional, Union
from typing import Any, AsyncIterator, Generator, Generic, Optional, Union

import grpc

Expand Down Expand Up @@ -141,7 +141,7 @@ class UnaryUnaryCall(Generic[RequestType, ResponseType],
"""The abstract base class of an unary-unary RPC on the client-side."""

@abstractmethod
def __await__(self) -> Awaitable[ResponseType]:
def __await__(self) -> Generator[Any, None, ResponseType]:
"""Await the response message to be ready.

Returns:
Expand Down Expand Up @@ -197,7 +197,7 @@ async def done_writing(self) -> None:
"""

@abstractmethod
def __await__(self) -> Awaitable[ResponseType]:
def __await__(self) -> Generator[Any, None, ResponseType]:
"""Await the response message to be ready.

Returns:
Expand Down
16 changes: 9 additions & 7 deletions src/python/grpcio/grpc/aio/_base_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,33 @@
"""Abstract base classes for Channel objects and Multicallable objects."""

import abc
from typing import Any, Optional
from typing import Generic, Optional

import grpc

from . import _base_call
from ._typing import DeserializingFunction
from ._typing import MetadataType
from ._typing import RequestIterableType
from ._typing import RequestType
from ._typing import ResponseType
from ._typing import SerializingFunction


class UnaryUnaryMultiCallable(abc.ABC):
class UnaryUnaryMultiCallable(Generic[RequestType, ResponseType], abc.ABC):
"""Enables asynchronous invocation of a unary-call RPC."""

@abc.abstractmethod
def __call__(
self,
request: Any,
request: RequestType,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryUnaryCall:
) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]:
"""Asynchronously invokes the underlying RPC.

Args:
Expand All @@ -63,20 +65,20 @@ def __call__(
"""


class UnaryStreamMultiCallable(abc.ABC):
class UnaryStreamMultiCallable(Generic[RequestType, ResponseType], abc.ABC):
"""Enables asynchronous invocation of a server-streaming RPC."""

@abc.abstractmethod
def __call__(
self,
request: Any,
request: RequestType,
*,
timeout: Optional[float] = None,
metadata: Optional[MetadataType] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryStreamCall:
) -> _base_call.UnaryStreamCall[RequestType, ResponseType]:
"""Asynchronously invokes the underlying RPC.

Args:
Expand Down
7 changes: 4 additions & 3 deletions src/python/grpcio/grpc/aio/_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import inspect
import logging
import traceback
from typing import AsyncIterator, Optional, Tuple
from typing import Any, AsyncIterator, Generator, Generic, Optional, Tuple

import grpc
from grpc import _common
Expand Down Expand Up @@ -252,7 +252,7 @@ class _APIStyle(enum.IntEnum):
READER_WRITER = 2


class _UnaryResponseMixin(Call):
class _UnaryResponseMixin(Call, Generic[ResponseType]):
_call_response: asyncio.Task

def _init_unary_response_mixin(self, response_task: asyncio.Task):
Expand All @@ -265,7 +265,7 @@ def cancel(self) -> bool:
else:
return False

def __await__(self) -> ResponseType:
def __await__(self) -> Generator[Any, None, ResponseType]:
Tasssadar marked this conversation as resolved.
Show resolved Hide resolved
"""Wait till the ongoing RPC request finishes."""
try:
response = yield from self._call_response
Expand Down Expand Up @@ -573,6 +573,7 @@ async def wait_for_connection(self) -> None:
await self._raise_for_status()


# pylint: disable=too-many-ancestors
class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
_base_call.StreamUnaryCall):
"""Object for managing stream-unary RPC calls.
Expand Down
10 changes: 6 additions & 4 deletions src/python/grpcio/grpc/aio/_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
from ._typing import ChannelArgumentType
from ._typing import DeserializingFunction
from ._typing import RequestIterableType
from ._typing import RequestType
from ._typing import ResponseType
from ._typing import SerializingFunction
from ._utils import _timeout_to_deadline

Expand Down Expand Up @@ -121,14 +123,14 @@ class UnaryUnaryMultiCallable(_BaseMultiCallable,

def __call__(
self,
request: Any,
request: RequestType,
*,
timeout: Optional[float] = None,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryUnaryCall:
) -> _base_call.UnaryUnaryCall[RequestType, ResponseType]:

metadata = self._init_metadata(metadata, compression)
if not self._interceptors:
Expand All @@ -152,14 +154,14 @@ class UnaryStreamMultiCallable(_BaseMultiCallable,

def __call__(
self,
request: Any,
request: RequestType,
*,
timeout: Optional[float] = None,
metadata: Optional[Metadata] = None,
credentials: Optional[grpc.CallCredentials] = None,
wait_for_ready: Optional[bool] = None,
compression: Optional[grpc.Compression] = None
) -> _base_call.UnaryStreamCall:
) -> _base_call.UnaryStreamCall[RequestType, ResponseType]:

metadata = self._init_metadata(metadata, compression)
deadline = _timeout_to_deadline(timeout)
Expand Down