Skip to content

Commit

Permalink
Begin mypy type-checking (#229)
Browse files Browse the repository at this point in the history
* mypy config
* narrow initial type-checking
* add type stubs
* allow any generics
* compat typing
* make types
* add type check step to `make test`
* pre-commit fixes
* remove explicit override

this would require depending on typing_extensions

* exceptions
* implicity rexport
* type ignores
* remove unused
* formatting

---------

Co-authored-by: Giorgio Salluzzo <giorgio.salluzzo@gmail.com>
  • Loading branch information
Kilo59 and mindflayer committed May 8, 2024
1 parent 389d95e commit c434799
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 25 deletions.
7 changes: 6 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@ setup: develop

develop: install-dev-requirements install-test-requirements

test:
types:
@echo "Type checking Python files"
.venv/bin/mypy --pretty
@echo ""

test: types
@echo "Running Python tests"
export VIRTUAL_ENV=.venv; .venv/bin/wait-for-it --service httpbin.local:443 --service localhost:6379 --timeout 5 -- .venv/bin/pytest tests/ || exit 1
@echo ""
Expand Down
11 changes: 7 additions & 4 deletions mocket/compat.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,30 @@
from __future__ import annotations

import codecs
import os
import shlex
from typing import Any, Final

ENCODING = os.getenv("MOCKET_ENCODING", "utf-8")
ENCODING: Final[str] = os.getenv("MOCKET_ENCODING", "utf-8")

text_type = str
byte_type = bytes
basestring = (str,)


def encode_to_bytes(s, encoding=ENCODING):
def encode_to_bytes(s: str | bytes, encoding: str = ENCODING) -> bytes:
if isinstance(s, text_type):
s = s.encode(encoding)
return byte_type(s)


def decode_from_bytes(s, encoding=ENCODING):
def decode_from_bytes(s: str | bytes, encoding: str = ENCODING) -> str:
if isinstance(s, byte_type):
s = codecs.decode(s, encoding, "ignore")
return text_type(s)


def shsplit(s):
def shsplit(s: str | bytes) -> list[str]:
s = decode_from_bytes(s)
return shlex.split(s)

Expand Down
50 changes: 30 additions & 20 deletions mocket/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
from __future__ import annotations

import binascii
import io
import os
import ssl
from typing import Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, ClassVar

from .compat import decode_from_bytes, encode_to_bytes
from .exceptions import StrictMocketException

if TYPE_CHECKING:
from _typeshed import ReadableBuffer
from typing_extensions import NoReturn

SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2


class MocketSocketCore(io.BytesIO):
def write(self, content):
def write( # type: ignore[override] # BytesIO returns int
self,
content: ReadableBuffer,
) -> None:
super(MocketSocketCore, self).write(content)

from mocket import Mocket
Expand All @@ -20,7 +29,7 @@ def write(self, content):
os.write(Mocket.w_fd, content)


def hexdump(binary_string):
def hexdump(binary_string: bytes) -> str:
r"""
>>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F"))
True
Expand All @@ -29,7 +38,7 @@ def hexdump(binary_string):
return " ".join(a + b for a, b in zip(bs[::2], bs[1::2]))


def hexload(string):
def hexload(string: str) -> bytes:
r"""
>>> hexload("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F") == encode_to_bytes("bar foobar foo")
True
Expand All @@ -38,39 +47,40 @@ def hexload(string):
return encode_to_bytes(binascii.unhexlify(string_no_spaces))


def get_mocketize(wrapper_):
def get_mocketize(wrapper_: Callable) -> Callable:
import decorator

if decorator.__version__ < "5": # pragma: no cover
if decorator.__version__ < "5": # type: ignore[attr-defined] # pragma: no cover
return decorator.decorator(wrapper_)
return decorator.decorator(wrapper_, kwsyntax=True)
return decorator.decorator( # type: ignore[call-arg] # kwsyntax
wrapper_,
kwsyntax=True,
)


class MocketMode:
__shared_state = {}
STRICT = None
STRICT_ALLOWED = None
__shared_state: ClassVar[dict[str, Any]] = {}
STRICT: ClassVar = None
STRICT_ALLOWED: ClassVar = None

def __init__(self):
def __init__(self) -> None:
self.__dict__ = self.__shared_state

def is_allowed(self, location: Union[str, Tuple[str, int]]) -> bool:
def is_allowed(self, location: str | tuple[str, int]) -> bool:
"""
Checks if (`host`, `port`) or at least `host`
are allowed locations to perform real `socket` calls
"""
if not self.STRICT:
return True
try:
host, _ = location
except ValueError:
host = None
return location in self.STRICT_ALLOWED or (
host is not None and host in self.STRICT_ALLOWED
)

host_allowed = False
if isinstance(location, tuple):
host_allowed = location[0] in self.STRICT_ALLOWED
return host_allowed or location in self.STRICT_ALLOWED

@staticmethod
def raise_not_allowed():
def raise_not_allowed() -> NoReturn:
from .mocket import Mocket

current_entries = [
Expand Down
26 changes: 26 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ test = [
"twine",
"fastapi",
"wait-for-it",
"mypy",
"types-decorator",
]
speedups = [
"xxhash;platform_python_implementation=='CPython'",
Expand All @@ -81,3 +83,27 @@ include = [
exclude = [
".*",
]

[tool.mypy]
python_version = "3.8"
files = [
"mocket/exceptions.py",
"mocket/compat.py",
"mocket/utils.py",
# "tests/"
]
strict = true
warn_unused_configs = true
ignore_missing_imports = true
warn_redundant_casts = true
warn_unused_ignores = true
show_error_codes = true
implicit_reexport = true
disallow_any_generics = false
follow_imports = "silent" # enable this once majority is typed
enable_error_code = ['ignore-without-code']
disable_error_code = ["no-untyped-def"] # enable this once full type-coverage is reached

[[tool.mypy.overrides]]
module = "tests.*"
disable_error_code = ['type-arg', 'no-untyped-def']

0 comments on commit c434799

Please sign in to comment.