Skip to content

Commit

Permalink
Start rechecking types on tests (#6555)
Browse files Browse the repository at this point in the history
Reenabling type checking in the tests.

This is the first batch, still a bunch of other files currently
disabled.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Dreamsorcerer and pre-commit-ci[bot] committed May 13, 2023
1 parent 0761da6 commit cbbf36c
Show file tree
Hide file tree
Showing 13 changed files with 251 additions and 180 deletions.
3 changes: 0 additions & 3 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,5 @@ ignore_missing_imports = True
[mypy-gunicorn.*]
ignore_missing_imports = True

[mypy-uvloop]
ignore_missing_imports = True

[mypy-python_on_whales]
ignore_missing_imports = True
22 changes: 17 additions & 5 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""WebSocket protocol versions 13 and 8."""

import asyncio
import collections
import functools
import json
import random
Expand All @@ -10,7 +9,18 @@
import zlib
from enum import IntEnum
from struct import Struct
from typing import Any, Callable, List, Optional, Pattern, Set, Tuple, Union, cast
from typing import (
Any,
Callable,
List,
NamedTuple,
Optional,
Pattern,
Set,
Tuple,
Union,
cast,
)

from typing_extensions import Final

Expand Down Expand Up @@ -80,10 +90,12 @@ class WSMsgType(IntEnum):
DEFAULT_LIMIT: Final[int] = 2**16


_WSMessageBase = collections.namedtuple("_WSMessageBase", ["type", "data", "extra"])

class WSMessage(NamedTuple):
type: WSMsgType
# To type correctly, this would need some kind of tagged union for each type.
data: Any
extra: Optional[str]

class WSMessage(_WSMessageBase):
def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
"""Return parsed JSON data.
Expand Down
15 changes: 9 additions & 6 deletions aiohttp/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import contextlib
import inspect
import warnings
from typing import Any, Awaitable, Callable, Dict, Generator, Optional, Type, Union
from typing import Any, Awaitable, Callable, Dict, Iterator, Optional, Type, Union

import pytest

Expand All @@ -22,9 +22,11 @@
try:
import uvloop
except ImportError: # pragma: no cover
uvloop = None
uvloop = None # type: ignore[assignment]

AiohttpClient = Callable[[Union[Application, BaseTestServer]], Awaitable[TestClient]]
AiohttpRawServer = Callable[[Application], Awaitable[RawTestServer]]
AiohttpServer = Callable[[Application], Awaitable[TestServer]]


def pytest_addoption(parser): # type: ignore[no-untyped-def]
Expand Down Expand Up @@ -193,6 +195,7 @@ def pytest_generate_tests(metafunc): # type: ignore[no-untyped-def]
return

loops = metafunc.config.option.aiohttp_loop
avail_factories: Dict[str, Type[asyncio.AbstractEventLoopPolicy]]
avail_factories = {"pyloop": asyncio.DefaultEventLoopPolicy}

if uvloop is not None: # pragma: no cover
Expand Down Expand Up @@ -242,13 +245,13 @@ def proactor_loop(): # type: ignore[no-untyped-def]


@pytest.fixture
def aiohttp_unused_port(): # type: ignore[no-untyped-def]
def aiohttp_unused_port() -> Callable[[], int]:
"""Return a port that is unused on the current host."""
return _unused_port


@pytest.fixture
def aiohttp_server(loop): # type: ignore[no-untyped-def]
def aiohttp_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpServer]:
"""Factory to create a TestServer instance, given an app.
aiohttp_server(app, **kwargs)
Expand All @@ -271,7 +274,7 @@ async def finalize() -> None:


@pytest.fixture
def aiohttp_raw_server(loop): # type: ignore[no-untyped-def]
def aiohttp_raw_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpRawServer]:
"""Factory to create a RawTestServer instance, given a web handler.
aiohttp_raw_server(handler, **kwargs)
Expand Down Expand Up @@ -323,7 +326,7 @@ def test_login(aiohttp_client):
@pytest.fixture
def aiohttp_client(
loop: asyncio.AbstractEventLoop, aiohttp_client_cls: Type[TestClient]
) -> Generator[AiohttpClient, None, None]:
) -> Iterator[AiohttpClient]:
"""Factory to create a TestClient instance.
aiohttp_client(app, **kwargs)
Expand Down
35 changes: 21 additions & 14 deletions aiohttp/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from .client_ws import ClientWebSocketResponse
from .helpers import _SENTINEL, PY_38, sentinel
from .http import HttpVersion, RawRequestMessage
from .typedefs import StrOrURL
from .web import (
Application,
AppRunner,
Expand Down Expand Up @@ -148,14 +149,14 @@ async def start_server(self, **kwargs: Any) -> None:
async def _make_runner(self, **kwargs: Any) -> BaseRunner:
pass

def make_url(self, path: str) -> URL:
def make_url(self, path: StrOrURL) -> URL:
assert self._root is not None
url = URL(path)
if not self.skip_url_asserts:
assert not url.is_absolute()
return self._root.join(url)
else:
return URL(str(self._root) + path)
return URL(str(self._root) + str(path))

@property
def started(self) -> bool:
Expand Down Expand Up @@ -304,16 +305,20 @@ def session(self) -> ClientSession:
"""
return self._session

def make_url(self, path: str) -> URL:
def make_url(self, path: StrOrURL) -> URL:
return self._server.make_url(path)

async def _request(self, method: str, path: str, **kwargs: Any) -> ClientResponse:
async def _request(
self, method: str, path: StrOrURL, **kwargs: Any
) -> ClientResponse:
resp = await self._session.request(method, self.make_url(path), **kwargs)
# save it to close later
self._responses.append(resp)
return resp

def request(self, method: str, path: str, **kwargs: Any) -> _RequestContextManager:
def request(
self, method: str, path: StrOrURL, **kwargs: Any
) -> _RequestContextManager:
"""Routes a request to tested http server.
The interface is identical to aiohttp.ClientSession.request,
Expand All @@ -323,43 +328,45 @@ def request(self, method: str, path: str, **kwargs: Any) -> _RequestContextManag
"""
return _RequestContextManager(self._request(method, path, **kwargs))

def get(self, path: str, **kwargs: Any) -> _RequestContextManager:
def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP GET request."""
return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs))

def post(self, path: str, **kwargs: Any) -> _RequestContextManager:
def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP POST request."""
return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs))

def options(self, path: str, **kwargs: Any) -> _RequestContextManager:
def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP OPTIONS request."""
return _RequestContextManager(self._request(hdrs.METH_OPTIONS, path, **kwargs))

def head(self, path: str, **kwargs: Any) -> _RequestContextManager:
def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP HEAD request."""
return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs))

def put(self, path: str, **kwargs: Any) -> _RequestContextManager:
def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP PUT request."""
return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs))

def patch(self, path: str, **kwargs: Any) -> _RequestContextManager:
def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP PATCH request."""
return _RequestContextManager(self._request(hdrs.METH_PATCH, path, **kwargs))

def delete(self, path: str, **kwargs: Any) -> _RequestContextManager:
def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager:
"""Perform an HTTP PATCH request."""
return _RequestContextManager(self._request(hdrs.METH_DELETE, path, **kwargs))

def ws_connect(self, path: str, **kwargs: Any) -> _WSRequestContextManager:
def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager:
"""Initiate websocket connection.
The api corresponds to aiohttp.ClientSession.ws_connect.
"""
return _WSRequestContextManager(self._ws_connect(path, **kwargs))

async def _ws_connect(self, path: str, **kwargs: Any) -> ClientWebSocketResponse:
async def _ws_connect(
self, path: StrOrURL, **kwargs: Any
) -> ClientWebSocketResponse:
ws = await self._session.ws_connect(self.make_url(path), **kwargs)
self._websockets.append(ws)
return ws
Expand Down
4 changes: 2 additions & 2 deletions aiohttp/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def init_signals(self) -> None:
# there is no need to reset it.
signal.signal(signal.SIGCHLD, signal.SIG_DFL)

def handle_quit(self, sig: int, frame: FrameType) -> None:
def handle_quit(self, sig: int, frame: Optional[FrameType]) -> None:
self.alive = False

# worker_int callback
Expand All @@ -194,7 +194,7 @@ def handle_quit(self, sig: int, frame: FrameType) -> None:
# wakeup closing process
self._notify_waiter_done()

def handle_abort(self, sig: int, frame: FrameType) -> None:
def handle_abort(self, sig: int, frame: Optional[FrameType]) -> None:
self.alive = False
self.exit_code = 1
self.cfg.worker_abort(self)
Expand Down
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ cchardet==2.1.7; python_version < "3.10" # Unmaintained: aio-libs/aiohttp#6819
charset-normalizer==2.0.12
frozenlist==1.3.1
gunicorn==20.1.0
uvloop==0.14.0; platform_system!="Windows" and implementation_name=="cpython" and python_version<"3.9" # MagicStack/uvloop#14
uvloop==0.17.0; platform_system!="Windows" and implementation_name=="cpython" and python_version<"3.9" # MagicStack/uvloop#14
yarl==1.9.2
2 changes: 1 addition & 1 deletion requirements/constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ uritemplate==4.1.1
# via gidgethub
urllib3==1.26.7
# via requests
uvloop==0.14.0 ; platform_system != "Windows" and implementation_name == "cpython" and python_version < "3.9"
uvloop==0.17.0 ; platform_system != "Windows" and implementation_name == "cpython" and python_version < "3.9"
# via -r requirements/base.txt
virtualenv==20.10.0
# via pre-commit
Expand Down
1 change: 1 addition & 0 deletions requirements/lint.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ mypy==0.982; implementation_name=="cpython"
pre-commit==2.17.0
pytest==6.2.5
slotscheck==0.8.0
uvloop==0.17.0; platform_system!="Windows"

0 comments on commit cbbf36c

Please sign in to comment.