Skip to content

Commit

Permalink
Add hash comparison for pyc cache files
Browse files Browse the repository at this point in the history
  • Loading branch information
cdce8p committed Sep 8, 2023
1 parent 0a06db0 commit 452c2b9
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 23 deletions.
1 change: 1 addition & 0 deletions changelog/11418.feature.rst
@@ -0,0 +1 @@
Added hash comparison for pyc cache files, so that a cached pyc is still reused when only the source file's mtime has changed.
37 changes: 25 additions & 12 deletions src/_pytest/assertion/rewrite.py
Expand Up @@ -155,11 +155,11 @@ def exec_module(self, module: types.ModuleType) -> None:
co = _read_pyc(fn, pyc, state.trace)
if co is None:
state.trace(f"rewriting {fn!r}")
source_stat, co = _rewrite_test(fn, self.config)
source_stat, source_hash, co = _rewrite_test(fn, self.config)
if write:
self._writing_pyc = True
try:
_write_pyc(state, co, source_stat, pyc)
_write_pyc(state, co, source_stat, source_hash, pyc)
finally:
self._writing_pyc = False
else:
Expand Down Expand Up @@ -290,7 +290,7 @@ def get_resource_reader(self, name: str) -> TraversableResources: # type: ignor


def _write_pyc_fp(
fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
fp: IO[bytes], source_stat: os.stat_result, source_hash: bytes, co: types.CodeType
) -> None:
# Technically, we don't have to have the same pyc format as
# (C)Python, since these "pycs" should never be seen by builtin
Expand All @@ -302,21 +302,25 @@ def _write_pyc_fp(
# as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
size = source_stat.st_size & 0xFFFFFFFF
# 64-bit source file hash
source_hash = source_hash[:8]
# "<LL" stands for 2 unsigned longs, little-endian.
fp.write(struct.pack("<LL", mtime, size))
fp.write(source_hash)
fp.write(marshal.dumps(co))


def _write_pyc(
state: "AssertionState",
co: types.CodeType,
source_stat: os.stat_result,
source_hash: bytes,
pyc: Path,
) -> bool:
proc_pyc = f"{pyc}.{os.getpid()}"
try:
with open(proc_pyc, "wb") as fp:
_write_pyc_fp(fp, source_stat, co)
_write_pyc_fp(fp, source_stat, source_hash, co)
except OSError as e:
state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
return False
Expand All @@ -332,15 +336,18 @@ def _write_pyc(
return True


def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
def _rewrite_test(
fn: Path, config: Config
) -> Tuple[os.stat_result, bytes, types.CodeType]:
"""Read and rewrite *fn* and return its stat result, source hash and code object."""
stat = os.stat(fn)
source = fn.read_bytes()
source_hash: bytes = importlib.util.source_hash(source) # type: ignore[assignment]
strfn = str(fn)
tree = ast.parse(source, filename=strfn)
rewrite_asserts(tree, source, strfn, config)
co = compile(tree, strfn, "exec", dont_inherit=True)
return stat, co
return stat, source_hash, co


def _read_pyc(
Expand All @@ -359,12 +366,12 @@ def _read_pyc(
stat_result = os.stat(source)
mtime = int(stat_result.st_mtime)
size = stat_result.st_size
data = fp.read(16)
data = fp.read(24)
except OSError as e:
trace(f"_read_pyc({source}): OSError {e}")
return None
# Check for invalid or out of date pyc file.
if len(data) != (16):
if len(data) != (24):
trace("_read_pyc(%s): invalid pyc (too short)" % source)
return None
if data[:4] != importlib.util.MAGIC_NUMBER:
Expand All @@ -373,14 +380,20 @@ def _read_pyc(
if data[4:8] != b"\x00\x00\x00\x00":
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
return None
mtime_data = data[8:12]
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
trace("_read_pyc(%s): out of date" % source)
return None
size_data = data[12:16]
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
return None
mtime_data = data[8:12]
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
trace("_read_pyc(%s): out of date" % source)
hash = data[16:24]
source_hash: bytes = importlib.util.source_hash(source.read_bytes()) # type: ignore[assignment]
if source_hash[:8] == hash:
trace("_read_pyc(%s): source hash match (no change detected)" % source)

Check warning on line 393 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L393

Added line #L393 was not covered by tests
else:
trace("_read_pyc(%s): hash doesn't match" % source)
return None
try:
co = marshal.load(fp)
except Exception as e:
Expand Down
35 changes: 24 additions & 11 deletions testing/test_assertrewrite.py
Expand Up @@ -10,7 +10,9 @@
import textwrap
import zipfile
from functools import partial
from importlib.util import source_hash
from pathlib import Path
from time import sleep
from typing import cast
from typing import Dict
from typing import Generator
Expand Down Expand Up @@ -1039,12 +1041,14 @@ def test_write_pyc(self, pytester: Pytester, tmp_path) -> None:
state = AssertionState(config, "rewrite")
tmp_path.joinpath("source.py").touch()
source_path = str(tmp_path)
source_bytes = tmp_path.joinpath("source.py").read_bytes()
pycpath = tmp_path.joinpath("pyc")
co = compile("1", "f.py", "single")
assert _write_pyc(state, co, os.stat(source_path), pycpath)
hash: bytes = source_hash(source_bytes) # type: ignore[assignment]
assert _write_pyc(state, co, os.stat(source_path), hash, pycpath)

with mock.patch.object(os, "replace", side_effect=OSError):
assert not _write_pyc(state, co, os.stat(source_path), pycpath)
assert not _write_pyc(state, co, os.stat(source_path), hash, pycpath)

def test_resources_provider_for_loader(self, pytester: Pytester) -> None:
"""
Expand Down Expand Up @@ -1116,8 +1120,15 @@ def test_read_pyc_success(self, tmp_path: Path, pytester: Pytester) -> None:

fn.write_text("def test(): assert True", encoding="utf-8")

source_stat, co = _rewrite_test(fn, config)
_write_pyc(state, co, source_stat, pyc)
source_stat, hash, co = _rewrite_test(fn, config)
_write_pyc(state, co, source_stat, hash, pyc)
assert _read_pyc(fn, pyc, state.trace) is not None

# Reading the pyc should still succeed if only the mtime changed:
# _read_pyc falls back to comparing the source hash.
sleep(0.1)
fn.touch()
assert source_stat.st_mtime != os.stat(fn).st_mtime
assert _read_pyc(fn, pyc, state.trace) is not None

def test_read_pyc_more_invalid(self, tmp_path: Path) -> None:
Expand All @@ -1138,31 +1149,33 @@ def test_read_pyc_more_invalid(self, tmp_path: Path) -> None:
os.utime(source, (mtime_int, mtime_int))

size = len(source_bytes).to_bytes(4, "little")
hash: bytes = source_hash(source_bytes) # type: ignore[assignment]
hash = hash[:8]

code = marshal.dumps(compile(source_bytes, str(source), "exec"))

# Good header.
pyc.write_bytes(magic + flags + mtime + size + code)
pyc.write_bytes(magic + flags + mtime + size + hash + code)
assert _read_pyc(source, pyc, print) is not None

# Too short.
pyc.write_bytes(magic + flags + mtime)
assert _read_pyc(source, pyc, print) is None

# Bad magic.
pyc.write_bytes(b"\x12\x34\x56\x78" + flags + mtime + size + code)
pyc.write_bytes(b"\x12\x34\x56\x78" + flags + mtime + size + hash + code)
assert _read_pyc(source, pyc, print) is None

# Unsupported flags.
pyc.write_bytes(magic + b"\x00\xff\x00\x00" + mtime + size + code)
pyc.write_bytes(magic + b"\x00\xff\x00\x00" + mtime + size + hash + code)
assert _read_pyc(source, pyc, print) is None

# Bad mtime.
pyc.write_bytes(magic + flags + b"\x58\x3d\xb0\x5f" + size + code)
# Bad size.
pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + hash + code)
assert _read_pyc(source, pyc, print) is None

# Bad size.
pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + code)
# Bad mtime + bad hash.
pyc.write_bytes(magic + flags + b"\x58\x3d\xb0\x5f" + size + b"\x00" * 8 + code)
assert _read_pyc(source, pyc, print) is None

def test_reload_is_same_and_reloads(self, pytester: Pytester) -> None:
Expand Down

0 comments on commit 452c2b9

Please sign in to comment.