Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add hash comparison for pyc cache files #11418

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/11418.improvement.rst
@@ -0,0 +1 @@
Added hash comparison for pyc cache files.
1 change: 1 addition & 0 deletions src/_pytest/assertion/__init__.py
Expand Up @@ -94,6 +94,7 @@ class AssertionState:
def __init__(self, config: Config, mode) -> None:
self.mode = mode
self.trace = config.trace.root.get("assertion")
self.invalidation_mode = config.known_args_namespace.invalidationmode
self.hook: Optional[rewrite.AssertionRewritingHook] = None


Expand Down
74 changes: 52 additions & 22 deletions src/_pytest/assertion/rewrite.py
Expand Up @@ -23,13 +23,16 @@
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Literal
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union

import _imp

from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
from _pytest._io.saferepr import saferepr
from _pytest._version import version
Expand Down Expand Up @@ -166,11 +169,11 @@
co = _read_pyc(fn, pyc, state.trace)
if co is None:
state.trace(f"rewriting {fn!r}")
source_stat, co = _rewrite_test(fn, self.config)
source_stat, source_hash, co = _rewrite_test(fn, self.config)
if write:
self._writing_pyc = True
try:
_write_pyc(state, co, source_stat, pyc)
_write_pyc(state, co, source_stat, source_hash, pyc)
finally:
self._writing_pyc = False
else:
Expand Down Expand Up @@ -299,33 +302,45 @@


def _write_pyc_fp(
fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
fp: IO[bytes],
source_stat: os.stat_result,
source_hash: bytes,
co: types.CodeType,
invalidation_mode: Literal["timestamp", "checked-hash"],
) -> None:
# Technically, we don't have to have the same pyc format as
# (C)Python, since these "pycs" should never be seen by builtin
# import. However, there's little reason to deviate.
fp.write(importlib.util.MAGIC_NUMBER)
# https://www.python.org/dev/peps/pep-0552/
flags = b"\x00\x00\x00\x00"
fp.write(flags)
# as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
size = source_stat.st_size & 0xFFFFFFFF
# "<LL" stands for 2 unsigned longs, little-endian.
fp.write(struct.pack("<LL", mtime, size))
if invalidation_mode == "timestamp":
flags = b"\x00\x00\x00\x00"
fp.write(flags)
# as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
size = source_stat.st_size & 0xFFFFFFFF
# "<LL" stands for 2 unsigned longs, little-endian.
fp.write(struct.pack("<LL", mtime, size))
elif invalidation_mode == "checked-hash":
flags = b"\x03\x00\x00\x00"
fp.write(flags)

Check warning on line 326 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L325-L326

Added lines #L325 - L326 were not covered by tests
# 64-bit source file hash
source_hash = source_hash[:8]
fp.write(source_hash)

Check warning on line 329 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L328-L329

Added lines #L328 - L329 were not covered by tests
fp.write(marshal.dumps(co))


def _write_pyc(
state: "AssertionState",
co: types.CodeType,
source_stat: os.stat_result,
source_hash: bytes,
pyc: Path,
) -> bool:
proc_pyc = f"{pyc}.{os.getpid()}"
try:
with open(proc_pyc, "wb") as fp:
_write_pyc_fp(fp, source_stat, co)
_write_pyc_fp(fp, source_stat, source_hash, co, state.invalidation_mode)
except OSError as e:
state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
return False
Expand All @@ -341,15 +356,18 @@
return True


def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
def _rewrite_test(
fn: Path, config: Config
) -> Tuple[os.stat_result, bytes, types.CodeType]:
"""Read and rewrite *fn* and return the code object."""
stat = os.stat(fn)
source = fn.read_bytes()
source_hash = importlib.util.source_hash(source)
strfn = str(fn)
tree = ast.parse(source, filename=strfn)
rewrite_asserts(tree, source, strfn, config)
co = compile(tree, strfn, "exec", dont_inherit=True)
return stat, co
return stat, source_hash, co


def _read_pyc(
Expand Down Expand Up @@ -379,17 +397,29 @@
if data[:4] != importlib.util.MAGIC_NUMBER:
trace("_read_pyc(%s): invalid pyc (bad magic number)" % source)
return None
if data[4:8] != b"\x00\x00\x00\x00":

hash_based = getattr(_imp, "check_hash_based_pycs", "default") == "always"
if data[4:8] == b"\x00\x00\x00\x00" and not hash_based:
trace("_read_pyc(%s): timestamp based" % source)
mtime_data = data[8:12]
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
trace("_read_pyc(%s): out of date" % source)
return None

Check warning on line 407 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L406-L407

Added lines #L406 - L407 were not covered by tests
size_data = data[12:16]
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
return None

Check warning on line 411 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L410-L411

Added lines #L410 - L411 were not covered by tests
elif data[4:8] == b"\x03\x00\x00\x00":
trace("_read_pyc(%s): hash based" % source)
hash = data[8:16]
source_hash = importlib.util.source_hash(source.read_bytes())

Check warning on line 415 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L413-L415

Added lines #L413 - L415 were not covered by tests
if source_hash[:8] != hash:
trace("_read_pyc(%s): hash doesn't match" % source)
return None

Check warning on line 418 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L417-L418

Added lines #L417 - L418 were not covered by tests
else:
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
return None
mtime_data = data[8:12]
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
trace("_read_pyc(%s): out of date" % source)
return None
size_data = data[12:16]
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
return None

try:
co = marshal.load(fp)
except Exception as e:
Expand Down
8 changes: 8 additions & 0 deletions src/_pytest/main.py
Expand Up @@ -223,6 +223,14 @@ def pytest_addoption(parser: Parser) -> None:
help="Prepend/append to sys.path when importing test modules and conftest "
"files. Default: prepend.",
)
group.addoption(
"--invalidation-mode",
default="timestamp",
choices=["timestamp", "checked-hash"],
dest="invalidationmode",
help="Pytest pyc cache invalidation mode. Default: timestamp.",
)

parser.addini(
"consider_namespace_packages",
type="bool",
Expand Down
88 changes: 84 additions & 4 deletions testing/test_assertrewrite.py
Expand Up @@ -4,6 +4,7 @@
from functools import partial
import glob
import importlib
from importlib.util import source_hash
import marshal
import os
from pathlib import Path
Expand All @@ -21,6 +22,8 @@
from unittest import mock
import zipfile

import _imp

import _pytest._code
from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
from _pytest.assertion import util
Expand Down Expand Up @@ -1043,12 +1046,14 @@ def test_write_pyc(self, pytester: Pytester, tmp_path) -> None:
state = AssertionState(config, "rewrite")
tmp_path.joinpath("source.py").touch()
source_path = str(tmp_path)
source_bytes = tmp_path.joinpath("source.py").read_bytes()
pycpath = tmp_path.joinpath("pyc")
co = compile("1", "f.py", "single")
assert _write_pyc(state, co, os.stat(source_path), pycpath)
hash = source_hash(source_bytes)
assert _write_pyc(state, co, os.stat(source_path), hash, pycpath)

with mock.patch.object(os, "replace", side_effect=OSError):
assert not _write_pyc(state, co, os.stat(source_path), pycpath)
assert not _write_pyc(state, co, os.stat(source_path), hash, pycpath)

def test_resources_provider_for_loader(self, pytester: Pytester) -> None:
"""
Expand Down Expand Up @@ -1121,10 +1126,41 @@ def test_read_pyc_success(self, tmp_path: Path, pytester: Pytester) -> None:

fn.write_text("def test(): assert True", encoding="utf-8")

source_stat, co = _rewrite_test(fn, config)
_write_pyc(state, co, source_stat, pyc)
source_stat, hash, co = _rewrite_test(fn, config)
_write_pyc(state, co, source_stat, hash, pyc)
assert _read_pyc(fn, pyc, state.trace) is not None

pyc_bytes = pyc.read_bytes()
assert pyc_bytes[4] == 0 # timestamp flag set

def test_read_pyc_success_hash(self, tmp_path: Path, pytester: Pytester) -> None:
from _pytest.assertion import AssertionState
from _pytest.assertion.rewrite import _read_pyc
from _pytest.assertion.rewrite import _rewrite_test
from _pytest.assertion.rewrite import _write_pyc

config = pytester.parseconfig("--invalidation-mode=checked-hash")
state = AssertionState(config, "rewrite")

fn = tmp_path / "source.py"
pyc = Path(str(fn) + "c")

# Test private attribute didn't change
assert getattr(_imp, "check_hash_based_pycs", None) in {
"default",
"always",
"never",
}

fn.write_text("def test(): assert True", encoding="utf-8")
source_stat, hash, co = _rewrite_test(fn, config)
_write_pyc(state, co, source_stat, hash, pyc)
assert _read_pyc(fn, pyc, state.trace) is not None

pyc_bytes = pyc.read_bytes()
assert pyc_bytes[4] == 3 # checked-hash flag set
assert pyc_bytes[8:16] == hash

def test_read_pyc_more_invalid(self, tmp_path: Path) -> None:
from _pytest.assertion.rewrite import _read_pyc

Expand Down Expand Up @@ -1170,6 +1206,50 @@ def test_read_pyc_more_invalid(self, tmp_path: Path) -> None:
pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + code)
assert _read_pyc(source, pyc, print) is None

def test_read_pyc_more_invalid_hash(self, tmp_path: Path) -> None:
from _pytest.assertion.rewrite import _read_pyc

source = tmp_path / "source.py"
pyc = tmp_path / "source.pyc"

source_bytes = b"def test(): pass\n"
source.write_bytes(source_bytes)

magic = importlib.util.MAGIC_NUMBER

flags = b"\x00\x00\x00\x00"
flags_hash = b"\x03\x00\x00\x00"

mtime = b"\x58\x3c\xb0\x5f"
mtime_int = int.from_bytes(mtime, "little")
os.utime(source, (mtime_int, mtime_int))

size = len(source_bytes).to_bytes(4, "little")

hash = source_hash(source_bytes)
hash = hash[:8]

code = marshal.dumps(compile(source_bytes, str(source), "exec"))

# check_hash_based_pycs == "default" with hash based pyc file.
pyc.write_bytes(magic + flags_hash + hash + code)
assert _read_pyc(source, pyc, print) is not None

# check_hash_based_pycs == "always" with hash based pyc file.
with mock.patch.object(_imp, "check_hash_based_pycs", "always"):
pyc.write_bytes(magic + flags_hash + hash + code)
assert _read_pyc(source, pyc, print) is not None

# Bad hash.
with mock.patch.object(_imp, "check_hash_based_pycs", "always"):
pyc.write_bytes(magic + flags_hash + b"\x00" * 8 + code)
assert _read_pyc(source, pyc, print) is None

# check_hash_based_pycs == "always" with timestamp based pyc file.
with mock.patch.object(_imp, "check_hash_based_pycs", "always"):
pyc.write_bytes(magic + flags + mtime + size + code)
assert _read_pyc(source, pyc, print) is None

def test_reload_is_same_and_reloads(self, pytester: Pytester) -> None:
"""Reloading a (collected) module after change picks up the change."""
pytester.makeini(
Expand Down