Skip to content

Commit

Permalink
Add invalidation-mode option
Browse files Browse the repository at this point in the history
  • Loading branch information
cdce8p committed Sep 10, 2023
1 parent 6e814c8 commit 109346a
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 43 deletions.
1 change: 1 addition & 0 deletions src/_pytest/assertion/__init__.py
Expand Up @@ -83,6 +83,7 @@ class AssertionState:
def __init__(self, config: Config, mode) -> None:
    """Record the assertion mode and the settings taken from *config*."""
    self.mode = mode
    self.trace = config.trace.root.get("assertion")
    # Value of the ``invalidationmode`` option; ``_write_pyc`` forwards it
    # to ``_write_pyc_fp`` to select the pyc header layout.
    self.invalidation_mode = config.option.invalidationmode
    # No rewriting hook is attached at construction time.
    self.hook: Optional[rewrite.AssertionRewritingHook] = None


Expand Down
72 changes: 44 additions & 28 deletions src/_pytest/assertion/rewrite.py
@@ -1,4 +1,5 @@
"""Rewrite assertion AST to produce nice error messages."""
import _imp
import ast
import errno
import functools
Expand All @@ -21,6 +22,7 @@
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Literal
from typing import Optional
from typing import Sequence
from typing import Set
Expand Down Expand Up @@ -290,23 +292,31 @@ def get_resource_reader(self, name: str) -> TraversableResources: # type: ignor


def _write_pyc_fp(
    fp: IO[bytes],
    source_stat: os.stat_result,
    source_hash: bytes,
    co: types.CodeType,
    invalidation_mode: Literal["timestamp", "checked-hash"],
) -> None:
    """Write a pyc header followed by the marshalled code object to *fp*.

    The 16-byte header is laid out as described in PEP 552:

    * ``"timestamp"``: magic number, zero flags, then the source mtime and
      size, each truncated to 32 bits and packed little-endian.
    * ``"checked-hash"``: magic number, flags ``0x03``, then the first
      8 bytes of *source_hash*.

    :param fp: Binary file object to write to.
    :param source_stat: Stat result of the source file (timestamp mode only).
    :param source_hash: Hash of the source (checked-hash mode only).
    :param co: Code object to serialize after the header.
    :param invalidation_mode: Which of the two header layouts to emit.
    :raises ValueError: If *invalidation_mode* is not a supported value.
        Nothing is written to *fp* in that case.
    """
    # Technically, we don't have to have the same pyc format as
    # (C)Python, since these "pycs" should never be seen by builtin
    # import. However, there's little reason to deviate.
    # https://www.python.org/dev/peps/pep-0552/
    #
    # Build the mode-specific header fields *before* writing anything, so an
    # unsupported mode cannot leave a pyc without flags/validation fields.
    if invalidation_mode == "timestamp":
        flags = b"\x00\x00\x00\x00"
        # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
        mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
        size = source_stat.st_size & 0xFFFFFFFF
        # "<LL" stands for 2 unsigned longs, little-endian.
        validation = struct.pack("<LL", mtime, size)
    elif invalidation_mode == "checked-hash":
        flags = b"\x03\x00\x00\x00"
        # 64-bit source file hash
        validation = source_hash[:8]
    else:
        raise ValueError(
            f"unsupported pyc invalidation mode: {invalidation_mode!r}"
        )

    fp.write(importlib.util.MAGIC_NUMBER)
    fp.write(flags)
    fp.write(validation)
    fp.write(marshal.dumps(co))


Expand All @@ -320,7 +330,7 @@ def _write_pyc(
proc_pyc = f"{pyc}.{os.getpid()}"
try:
with open(proc_pyc, "wb") as fp:
_write_pyc_fp(fp, source_stat, source_hash, co)
_write_pyc_fp(fp, source_stat, source_hash, co, state.invalidation_mode)
except OSError as e:
state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
return False
Expand Down Expand Up @@ -367,35 +377,41 @@ def _read_pyc(
stat_result = os.stat(source)
mtime = int(stat_result.st_mtime)
size = stat_result.st_size
data = fp.read(24)
data = fp.read(16)
except OSError as e:
trace(f"_read_pyc({source}): OSError {e}")
return None
# Check for invalid or out of date pyc file.
if len(data) != (24):
if len(data) != (16):
trace("_read_pyc(%s): invalid pyc (too short)" % source)
return None
if data[:4] != importlib.util.MAGIC_NUMBER:
trace("_read_pyc(%s): invalid pyc (bad magic number)" % source)
return None
if data[4:8] != b"\x00\x00\x00\x00":
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
return None
size_data = data[12:16]
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
return None
mtime_data = data[8:12]
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
trace("_read_pyc(%s): out of date" % source)
hash = data[16:24]

Check warning on line 391 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L391

Added line #L391 was not covered by tests
hash_based = getattr(_imp, "check_hash_based_pycs", "default") == "always"
if data[4:8] == b"\x00\x00\x00\x00" and not hash_based:
trace("_read_pyc(%s): timestamp based" % source)
mtime_data = data[8:12]
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
trace("_read_pyc(%s): out of date" % source)
return None
size_data = data[12:16]
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
return None
elif data[4:8] == b"\x03\x00\x00\x00":
trace("_read_pyc(%s): hash based" % source)
hash = data[8:16]
# source_hash returns bytes not int: https://github.com/python/typeshed/pull/10686
source_hash: bytes = importlib.util.source_hash(source.read_bytes()) # type: ignore[assignment]
if source_hash[:8] == hash:
trace("_read_pyc(%s): source hash match (no change detected)" % source)
else:
if source_hash[:8] != hash:
trace("_read_pyc(%s): hash doesn't match" % source)
return None
else:
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
return None

try:
co = marshal.load(fp)
except Exception as e:
Expand Down
7 changes: 7 additions & 0 deletions src/_pytest/main.py
Expand Up @@ -215,6 +215,13 @@ def pytest_addoption(parser: Parser) -> None:
help="Prepend/append to sys.path when importing test modules and conftest "
"files. Default: prepend.",
)
group.addoption(
"--invalidation-mode",
default="timestamp",
choices=["timestamp", "checked-hash"],
dest="invalidationmode",
help="Pytest pyc cache invalidation mode. Default: timestamp.",
)

group = parser.getgroup("debugconfig", "test session debugging and configuration")
group.addoption(
Expand Down
97 changes: 82 additions & 15 deletions testing/test_assertrewrite.py
@@ -1,3 +1,4 @@
import _imp
import ast
import errno
import glob
Expand Down Expand Up @@ -1124,13 +1125,37 @@ def test_read_pyc_success(self, tmp_path: Path, pytester: Pytester) -> None:
_write_pyc(state, co, source_stat, hash, pyc)
assert _read_pyc(fn, pyc, state.trace) is not None

# pyc read should still work if only the mtime changed
# Fallback to hash comparison
new_mtime = source_stat.st_mtime + 1.2
os.utime(fn, (new_mtime, new_mtime))
assert source_stat.st_mtime != os.stat(fn).st_mtime
pyc_bytes = pyc.read_bytes()
assert pyc_bytes[4] == 0 # timestamp flag set

def test_read_pyc_success_hash(self, tmp_path: Path, pytester: Pytester) -> None:
    """A pyc written in checked-hash mode is readable and carries the hash header."""
    from _pytest.assertion import AssertionState
    from _pytest.assertion.rewrite import _read_pyc
    from _pytest.assertion.rewrite import _rewrite_test
    from _pytest.assertion.rewrite import _write_pyc

    # Test private attribute didn't change
    known_modes = {"default", "always", "never"}
    assert getattr(_imp, "check_hash_based_pycs", None) in known_modes

    config = pytester.parseconfig("--invalidation-mode=checked-hash")
    state = AssertionState(config, "rewrite")

    source_path = tmp_path / "source.py"
    pyc_path = Path(str(source_path) + "c")
    source_path.write_text("def test(): assert True", encoding="utf-8")

    # Rewrite the module, cache it, then read the cache back.
    source_stat, src_hash, code = _rewrite_test(source_path, config)
    _write_pyc(state, code, source_stat, src_hash, pyc_path)
    assert _read_pyc(source_path, pyc_path, state.trace) is not None

    header = pyc_path.read_bytes()
    assert header[4] == 3  # checked-hash flag set
    assert header[8:16] == src_hash

def test_read_pyc_more_invalid(self, tmp_path: Path) -> None:
from _pytest.assertion.rewrite import _read_pyc

Expand All @@ -1149,36 +1174,78 @@ def test_read_pyc_more_invalid(self, tmp_path: Path) -> None:
os.utime(source, (mtime_int, mtime_int))

size = len(source_bytes).to_bytes(4, "little")
# source_hash returns bytes not int: https://github.com/python/typeshed/pull/10686
hash: bytes = source_hash(source_bytes) # type: ignore[assignment]
hash = hash[:8]

code = marshal.dumps(compile(source_bytes, str(source), "exec"))

# Good header.
pyc.write_bytes(magic + flags + mtime + size + hash + code)
pyc.write_bytes(magic + flags + mtime + size + code)
assert _read_pyc(source, pyc, print) is not None

# Too short.
pyc.write_bytes(magic + flags + mtime)
assert _read_pyc(source, pyc, print) is None

# Bad magic.
pyc.write_bytes(b"\x12\x34\x56\x78" + flags + mtime + size + hash + code)
pyc.write_bytes(b"\x12\x34\x56\x78" + flags + mtime + size + code)
assert _read_pyc(source, pyc, print) is None

# Unsupported flags.
pyc.write_bytes(magic + b"\x00\xff\x00\x00" + mtime + size + hash + code)
pyc.write_bytes(magic + b"\x00\xff\x00\x00" + mtime + size + code)
assert _read_pyc(source, pyc, print) is None

# Bad size.
pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + hash + code)
# Bad mtime.
pyc.write_bytes(magic + flags + b"\x58\x3d\xb0\x5f" + size + code)
assert _read_pyc(source, pyc, print) is None

# Bad mtime + bad hash.
pyc.write_bytes(magic + flags + b"\x58\x3d\xb0\x5f" + size + b"\x00" * 8 + code)
# Bad size.
pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + code)
assert _read_pyc(source, pyc, print) is None

def test_read_pyc_more_invalid_hash(self, tmp_path: Path) -> None:
    """Exercise ``_read_pyc`` with hash-based headers and the ``always`` check mode."""
    from _pytest.assertion.rewrite import _read_pyc

    source = tmp_path / "source.py"
    pyc = tmp_path / "source.pyc"

    source_bytes = b"def test(): pass\n"
    source.write_bytes(source_bytes)

    magic = importlib.util.MAGIC_NUMBER
    timestamp_flags = b"\x00\x00\x00\x00"
    hash_flags = b"\x03\x00\x00\x00"

    # Pin the source mtime so the timestamp header below matches it.
    mtime = b"\x58\x3c\xb0\x5f"
    mtime_seconds = int.from_bytes(mtime, "little")
    os.utime(source, (mtime_seconds, mtime_seconds))

    size = len(source_bytes).to_bytes(4, "little")

    # source_hash returns bytes not int: https://github.com/python/typeshed/pull/10686
    full_digest: bytes = source_hash(source_bytes)  # type: ignore[assignment]
    digest = full_digest[:8]

    code = marshal.dumps(compile(source_bytes, str(source), "exec"))

    def load(header_tail: bytes):
        # Write magic + the given flags/validation fields + code, then read back.
        pyc.write_bytes(magic + header_tail + code)
        return _read_pyc(source, pyc, print)

    # check_hash_based_pycs == "default" with hash based pyc file.
    assert load(hash_flags + digest) is not None

    # check_hash_based_pycs == "always" with hash based pyc file.
    with mock.patch.object(_imp, "check_hash_based_pycs", "always"):
        assert load(hash_flags + digest) is not None

    # Bad hash.
    with mock.patch.object(_imp, "check_hash_based_pycs", "always"):
        assert load(hash_flags + b"\x00" * 8) is None

    # check_hash_based_pycs == "always" with timestamp based pyc file.
    with mock.patch.object(_imp, "check_hash_based_pycs", "always"):
        assert load(timestamp_flags + mtime + size) is None

def test_reload_is_same_and_reloads(self, pytester: Pytester) -> None:
"""Reloading a (collected) module after change picks up the change."""
pytester.makeini(
Expand Down

0 comments on commit 109346a

Please sign in to comment.