Skip to content

Commit

Permalink
Add invalidation-mode option
Browse files Browse the repository at this point in the history
  • Loading branch information
cdce8p committed Apr 25, 2024
1 parent ac98ff5 commit 2f12594
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 42 deletions.
1 change: 1 addition & 0 deletions src/_pytest/assertion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class AssertionState:
def __init__(self, config: Config, mode) -> None:
    """Record the assertion-rewriting settings derived from ``config``.

    :param config: Pytest config object; its tracer and parsed known-args
        namespace are read here.
    :param mode: Assertion mode, stored unchanged on the instance.
    """
    self.mode = mode
    # Trace channel tagged "assertion", taken from the config's root tracer.
    self.trace = config.trace.root.get("assertion")
    # Parsed value of the option whose dest is ``invalidationmode``
    # (the ``--invalidation-mode`` command-line flag).
    self.invalidation_mode = config.known_args_namespace.invalidationmode
    # No rewriting hook is attached at construction time.
    self.hook: Optional[rewrite.AssertionRewritingHook] = None


Expand Down
73 changes: 45 additions & 28 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Literal
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union

import _imp

from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
from _pytest._io.saferepr import saferepr
from _pytest._version import version
Expand Down Expand Up @@ -299,23 +302,31 @@ def get_resource_reader(self, name: str) -> TraversableResources:


def _write_pyc_fp(
fp: IO[bytes], source_stat: os.stat_result, source_hash: bytes, co: types.CodeType
fp: IO[bytes],
source_stat: os.stat_result,
source_hash: bytes,
co: types.CodeType,
invalidation_mode: Literal["timestamp", "checked-hash"],
) -> None:
# Technically, we don't have to have the same pyc format as
# (C)Python, since these "pycs" should never be seen by builtin
# import. However, there's little reason to deviate.
fp.write(importlib.util.MAGIC_NUMBER)
# https://www.python.org/dev/peps/pep-0552/
flags = b"\x00\x00\x00\x00"
fp.write(flags)
# as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
size = source_stat.st_size & 0xFFFFFFFF
# 64-bit source file hash
source_hash = source_hash[:8]
# "<LL" stands for 2 unsigned longs, little-endian.
fp.write(struct.pack("<LL", mtime, size))
fp.write(source_hash)
if invalidation_mode == "timestamp":
flags = b"\x00\x00\x00\x00"
fp.write(flags)
# as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
size = source_stat.st_size & 0xFFFFFFFF
# "<LL" stands for 2 unsigned longs, little-endian.
fp.write(struct.pack("<LL", mtime, size))
elif invalidation_mode == "checked-hash":
flags = b"\x03\x00\x00\x00"
fp.write(flags)

Check warning on line 326 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L325-L326

Added lines #L325 - L326 were not covered by tests
# 64-bit source file hash
source_hash = source_hash[:8]
fp.write(source_hash)

Check warning on line 329 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L328-L329

Added lines #L328 - L329 were not covered by tests
fp.write(marshal.dumps(co))


Expand All @@ -329,7 +340,7 @@ def _write_pyc(
proc_pyc = f"{pyc}.{os.getpid()}"
try:
with open(proc_pyc, "wb") as fp:
_write_pyc_fp(fp, source_stat, source_hash, co)
_write_pyc_fp(fp, source_stat, source_hash, co, state.invalidation_mode)
except OSError as e:
state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
return False
Expand Down Expand Up @@ -375,34 +386,40 @@ def _read_pyc(
stat_result = os.stat(source)
mtime = int(stat_result.st_mtime)
size = stat_result.st_size
data = fp.read(24)
data = fp.read(16)
except OSError as e:
trace(f"_read_pyc({source}): OSError {e}")
return None
# Check for invalid or out of date pyc file.
if len(data) != (24):
if len(data) != (16):
trace("_read_pyc(%s): invalid pyc (too short)" % source)
return None
if data[:4] != importlib.util.MAGIC_NUMBER:
trace("_read_pyc(%s): invalid pyc (bad magic number)" % source)
return None
if data[4:8] != b"\x00\x00\x00\x00":
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
return None
size_data = data[12:16]
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
return None
mtime_data = data[8:12]
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
trace("_read_pyc(%s): out of date" % source)
hash = data[16:24]

hash_based = getattr(_imp, "check_hash_based_pycs", "default") == "always"
if data[4:8] == b"\x00\x00\x00\x00" and not hash_based:
trace("_read_pyc(%s): timestamp based" % source)
mtime_data = data[8:12]
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
trace("_read_pyc(%s): out of date" % source)
return None

Check warning on line 407 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L406-L407

Added lines #L406 - L407 were not covered by tests
size_data = data[12:16]
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
return None

Check warning on line 411 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L410-L411

Added lines #L410 - L411 were not covered by tests
elif data[4:8] == b"\x03\x00\x00\x00":
trace("_read_pyc(%s): hash based" % source)
hash = data[8:16]
source_hash = importlib.util.source_hash(source.read_bytes())

Check warning on line 415 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L413-L415

Added lines #L413 - L415 were not covered by tests
if source_hash[:8] == hash:
trace("_read_pyc(%s): source hash match (no change detected)" % source)
else:
if source_hash[:8] != hash:
trace("_read_pyc(%s): hash doesn't match" % source)
return None

Check warning on line 418 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L417-L418

Added lines #L417 - L418 were not covered by tests
else:
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
return None

try:
co = marshal.load(fp)
except Exception as e:
Expand Down
8 changes: 8 additions & 0 deletions src/_pytest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,14 @@ def pytest_addoption(parser: Parser) -> None:
help="Prepend/append to sys.path when importing test modules and conftest "
"files. Default: prepend.",
)
group.addoption(
"--invalidation-mode",
default="timestamp",
choices=["timestamp", "checked-hash"],
dest="invalidationmode",
help="Pytest pyc cache invalidation mode. Default: timestamp.",
)

parser.addini(
"consider_namespace_packages",
type="bool",
Expand Down
96 changes: 82 additions & 14 deletions testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from unittest import mock
import zipfile

import _imp

import _pytest._code
from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
from _pytest.assertion import util
Expand Down Expand Up @@ -1128,13 +1130,37 @@ def test_read_pyc_success(self, tmp_path: Path, pytester: Pytester) -> None:
_write_pyc(state, co, source_stat, hash, pyc)
assert _read_pyc(fn, pyc, state.trace) is not None

# pyc read should still work if only the mtime changed
# Fallback to hash comparison
new_mtime = source_stat.st_mtime + 1.2
os.utime(fn, (new_mtime, new_mtime))
assert source_stat.st_mtime != os.stat(fn).st_mtime
pyc_bytes = pyc.read_bytes()
assert pyc_bytes[4] == 0 # timestamp flag set

def test_read_pyc_success_hash(self, tmp_path: Path, pytester: Pytester) -> None:
    """A pyc written in checked-hash mode is readable and carries the hash."""
    from _pytest.assertion import AssertionState
    from _pytest.assertion.rewrite import _read_pyc
    from _pytest.assertion.rewrite import _rewrite_test
    from _pytest.assertion.rewrite import _write_pyc

    config = pytester.parseconfig("--invalidation-mode=checked-hash")
    state = AssertionState(config, "rewrite")

    source_path = tmp_path / "source.py"
    pyc_path = source_path.with_suffix(".pyc")

    # Guard against the private ``_imp`` attribute changing its value set.
    known_settings = {"default", "always", "never"}
    current_setting = getattr(_imp, "check_hash_based_pycs", None)
    assert current_setting in known_settings

    source_path.write_text("def test(): assert True", encoding="utf-8")
    source_stat, digest, code = _rewrite_test(source_path, config)
    _write_pyc(state, code, source_stat, digest, pyc_path)
    loaded = _read_pyc(source_path, pyc_path, state.trace)
    assert loaded is not None

    header = pyc_path.read_bytes()
    # First flags byte is 3 for a checked-hash pyc.
    assert header[4] == 3
    # The validation field holds the source hash.
    assert header[8:16] == digest

def test_read_pyc_more_invalid(self, tmp_path: Path) -> None:
from _pytest.assertion.rewrite import _read_pyc

Expand All @@ -1153,35 +1179,77 @@ def test_read_pyc_more_invalid(self, tmp_path: Path) -> None:
os.utime(source, (mtime_int, mtime_int))

size = len(source_bytes).to_bytes(4, "little")
hash = source_hash(source_bytes)
hash = hash[:8]

code = marshal.dumps(compile(source_bytes, str(source), "exec"))

# Good header.
pyc.write_bytes(magic + flags + mtime + size + hash + code)
pyc.write_bytes(magic + flags + mtime + size + code)
assert _read_pyc(source, pyc, print) is not None

# Too short.
pyc.write_bytes(magic + flags + mtime)
assert _read_pyc(source, pyc, print) is None

# Bad magic.
pyc.write_bytes(b"\x12\x34\x56\x78" + flags + mtime + size + hash + code)
pyc.write_bytes(b"\x12\x34\x56\x78" + flags + mtime + size + code)
assert _read_pyc(source, pyc, print) is None

# Unsupported flags.
pyc.write_bytes(magic + b"\x00\xff\x00\x00" + mtime + size + hash + code)
pyc.write_bytes(magic + b"\x00\xff\x00\x00" + mtime + size + code)
assert _read_pyc(source, pyc, print) is None

# Bad size.
pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + hash + code)
# Bad mtime.
pyc.write_bytes(magic + flags + b"\x58\x3d\xb0\x5f" + size + code)
assert _read_pyc(source, pyc, print) is None

# Bad mtime + bad hash.
pyc.write_bytes(magic + flags + b"\x58\x3d\xb0\x5f" + size + b"\x00" * 8 + code)
# Bad size.
pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + code)
assert _read_pyc(source, pyc, print) is None

def test_read_pyc_more_invalid_hash(self, tmp_path: Path) -> None:
    """Exercise ``_read_pyc`` with hash-based headers and forced hash checking."""
    from _pytest.assertion.rewrite import _read_pyc

    source = tmp_path / "source.py"
    pyc = tmp_path / "source.pyc"

    source_bytes = b"def test(): pass\n"
    source.write_bytes(source_bytes)

    # Give the source a fixed mtime so the timestamp header below matches it.
    mtime = b"\x58\x3c\xb0\x5f"
    mtime_int = int.from_bytes(mtime, "little")
    os.utime(source, (mtime_int, mtime_int))

    magic = importlib.util.MAGIC_NUMBER
    size = len(source_bytes).to_bytes(4, "little")
    digest = source_hash(source_bytes)[:8]
    code = marshal.dumps(compile(source_bytes, str(source), "exec"))

    timestamp_header = b"\x00\x00\x00\x00" + mtime + size
    hash_header = b"\x03\x00\x00\x00" + digest
    bad_hash_header = b"\x03\x00\x00\x00" + b"\x00" * 8

    def load(header: bytes):
        # Write a pyc with the given flags + validation field, then read it back.
        pyc.write_bytes(magic + header + code)
        return _read_pyc(source, pyc, print)

    # check_hash_based_pycs == "default" with hash based pyc file.
    assert load(hash_header) is not None

    with mock.patch.object(_imp, "check_hash_based_pycs", "always"):
        # check_hash_based_pycs == "always" with hash based pyc file.
        assert load(hash_header) is not None

        # Bad hash.
        assert load(bad_hash_header) is None

        # check_hash_based_pycs == "always" with timestamp based pyc file.
        assert load(timestamp_header) is None

def test_reload_is_same_and_reloads(self, pytester: Pytester) -> None:
"""Reloading a (collected) module after change picks up the change."""
pytester.makeini(
Expand Down

0 comments on commit 2f12594

Please sign in to comment.