Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve caching by comparing file hashes as fallback for mtime and size #3821

Merged
merged 8 commits on Aug 19, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Expand Up @@ -31,6 +31,7 @@
<!-- Changes that improve Black's performance. -->

- Avoid importing `IPython` if notebook cells do not contain magics (#3782)
- Improve caching by comparing file hashes as fallback for mtime and size. (#3821)

### Output

Expand Down
7 changes: 7 additions & 0 deletions docs/contributing/reference/reference_classes.rst
Expand Up @@ -186,6 +186,13 @@ Black Classes
:show-inheritance:
:members:

:class:`Cache`
------------------------

.. autoclass:: black.cache.Cache
:show-inheritance:
:members:

Enum Classes
~~~~~~~~~~~~~

Expand Down
8 changes: 0 additions & 8 deletions docs/contributing/reference/reference_functions.rst
Expand Up @@ -94,18 +94,10 @@ Split functions
Caching
-------

.. autofunction:: black.cache.filter_cached

.. autofunction:: black.cache.get_cache_dir

.. autofunction:: black.cache.get_cache_file

.. autofunction:: black.cache.get_cache_info

.. autofunction:: black.cache.read_cache

.. autofunction:: black.cache.write_cache

Utilities
---------

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Expand Up @@ -67,7 +67,7 @@ dependencies = [
"pathspec>=0.9.0",
"platformdirs>=2",
"tomli>=1.1.0; python_version < '3.11'",
"typing_extensions>=3.10.0.0; python_version < '3.10'",
"typing_extensions>=4.0.1; python_version < '3.11'",
]
dynamic = ["readme", "version"]

Expand Down
11 changes: 4 additions & 7 deletions src/black/__init__.py
Expand Up @@ -34,7 +34,7 @@
from pathspec.patterns.gitwildmatch import GitWildMatchPatternError

from _black_version import version as __version__
from black.cache import Cache, get_cache_info, read_cache, write_cache
from black.cache import Cache
from black.comments import normalize_fmt_off
from black.const import (
DEFAULT_EXCLUDES,
Expand Down Expand Up @@ -775,12 +775,9 @@ def reformat_one(
if format_stdin_to_stdout(fast=fast, write_back=write_back, mode=mode):
changed = Changed.YES
else:
cache: Cache = {}
cache = Cache.read(mode)
if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
cache = read_cache(mode)
res_src = src.resolve()
res_src_s = str(res_src)
if res_src_s in cache and cache[res_src_s] == get_cache_info(res_src):
if not cache.is_changed(src):
changed = Changed.CACHED
if changed is not Changed.CACHED and format_file_in_place(
src, fast=fast, write_back=write_back, mode=mode
Expand All @@ -789,7 +786,7 @@ def reformat_one(
if (write_back is WriteBack.YES and changed is not Changed.CACHED) or (
write_back is WriteBack.CHECK and changed is Changed.NO
):
write_cache(cache, [src], mode)
cache.write([src])
report.done(src, changed)
except Exception as exc:
if report.verbose:
Expand Down
161 changes: 100 additions & 61 deletions src/black/cache.py
@@ -1,21 +1,28 @@
"""Caching of formatted files with feature-based invalidation."""

import hashlib
import os
import pickle
import sys
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterable, Set, Tuple
from typing import Dict, Iterable, NamedTuple, Set, Tuple

from platformdirs import user_cache_dir

from _black_version import version as __version__
from black.mode import Mode

# types
Timestamp = float
FileSize = int
CacheInfo = Tuple[Timestamp, FileSize]
Cache = Dict[str, CacheInfo]
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self


class FileData(NamedTuple):
    """Stat metadata and content hash used to decide whether a file changed.

    ``st_size`` and ``st_mtime`` come from ``os.stat``; ``hash`` is a hex
    digest of the file contents, used as a fallback comparison when the
    mtime differs (see ``Cache.is_changed``).
    """

    st_mtime: float
    st_size: int
    hash: str


def get_cache_dir() -> Path:
Expand All @@ -37,61 +44,93 @@ def get_cache_dir() -> Path:
CACHE_DIR = get_cache_dir()


def read_cache(mode: Mode) -> Cache:
    """Load the pickled cache for *mode*, or an empty cache.

    An absent or corrupt cache file yields an empty mapping; the later
    write_cache call will recreate the file in that case.
    """
    path = get_cache_file(mode)
    if not path.exists():
        return {}

    try:
        with path.open("rb") as stream:
            loaded: Cache = pickle.load(stream)
    except (pickle.UnpicklingError, ValueError, IndexError):
        return {}

    return loaded


def get_cache_file(mode: Mode) -> Path:
    """Return the per-mode cache file path inside CACHE_DIR."""
    key = mode.get_cache_key()
    return CACHE_DIR / f"cache.{key}.pickle"


def get_cache_info(path: Path) -> CacheInfo:
    """Return the (mtime, size) pair used to detect whether *path* changed."""
    st = path.stat()
    return (st.st_mtime, st.st_size)


def filter_cached(cache: Cache, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
    """Split an iterable of paths in `sources` into two sets.

    The first contains paths of files that were modified on disk or are not
    in the cache. The other contains paths to non-modified files.
    """
    todo: Set[Path] = set()
    done: Set[Path] = set()
    for src in sources:
        resolved = src.resolve()
        # A missing entry and a stale entry are treated the same: reformat.
        if cache.get(str(resolved)) == get_cache_info(resolved):
            done.add(src)
        else:
            todo.add(src)
    return todo, done


def write_cache(cache: Cache, sources: Iterable[Path], mode: Mode) -> None:
    """Update the cache file.

    Failures are silently ignored: caching is best-effort and must never
    break formatting itself.
    """
    cache_file = get_cache_file(mode)
    try:
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        updated: Cache = dict(cache)
        for src in sources:
            updated[str(src.resolve())] = get_cache_info(src)
        # Write to a temp file in the same directory, then atomically replace.
        with tempfile.NamedTemporaryFile(dir=str(cache_file.parent), delete=False) as f:
            pickle.dump(updated, f, protocol=4)
        os.replace(f.name, cache_file)
    except OSError:
        pass
@dataclass
class Cache:
    """Keeps track of formatted files to avoid reformatting them again.

    A file is considered unchanged when its cached size matches and either
    its whole-second mtime matches or, failing that, its content hash does.
    The hash fallback means a mere ``touch`` does not invalidate an entry.
    """

    mode: Mode
    cache_file: Path
    file_data: Dict[str, FileData] = field(default_factory=dict)

    @classmethod
    def read(cls, mode: Mode) -> Self:
        """Read the cache if it exists and is well formed.

        If it is not well formed, the call to write later should
        resolve the issue.
        """
        cache_file = get_cache_file(mode)
        if not cache_file.exists():
            return cls(mode, cache_file)

        with cache_file.open("rb") as fobj:
            try:
                data: Dict[str, FileData] = pickle.load(fobj)
            except (pickle.UnpicklingError, ValueError, IndexError):
                # Corrupt or incompatible cache: start fresh.
                return cls(mode, cache_file)

        return cls(mode, cache_file, data)

    @staticmethod
    def hash_digest(path: Path) -> str:
        """Return the SHA-256 hex digest of the file at *path*."""
        with open(path, "rb") as fp:
            data = fp.read()
        return hashlib.sha256(data).hexdigest()

    @staticmethod
    def get_file_data(path: Path) -> FileData:
        """Return the FileData (mtime, size, content hash) for *path*."""
        stat = path.stat()
        # Avoid naming the local `hash` — it would shadow the builtin.
        return FileData(stat.st_mtime, stat.st_size, Cache.hash_digest(path))

    def is_changed(self, source: Path) -> bool:
        """Check if *source* has changed compared to the cached version."""
        res_src = source.resolve()
        old = self.file_data.get(str(res_src))
        if old is None:
            # Never seen before: must be (re)formatted.
            return True

        st = res_src.stat()
        if st.st_size != old.st_size:
            return True
        # mtime resolution varies between filesystems; compare whole seconds
        # and only pay for hashing when the timestamps disagree.
        if int(st.st_mtime) != int(old.st_mtime):
            new_hash = Cache.hash_digest(res_src)
            if new_hash != old.hash:
                return True
        return False

    def filtered_cached(self, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
        """Split an iterable of paths in `sources` into two sets.

        The first contains paths of files that were modified on disk or are
        not in the cache. The other contains paths to non-modified files.
        """
        changed: Set[Path] = set()
        done: Set[Path] = set()
        for src in sources:
            if self.is_changed(src):
                changed.add(src)
            else:
                done.add(src)
        return changed, done

    def write(self, sources: Iterable[Path]) -> None:
        """Update the cache file data and write a new cache file."""
        # Pass a mapping directly instead of unpacking every path as a kwarg.
        self.file_data.update(
            {str(src.resolve()): Cache.get_file_data(src) for src in sources}
        )
        try:
            CACHE_DIR.mkdir(parents=True, exist_ok=True)
            # Write to a sibling temp file, then atomically replace so a
            # concurrent reader never sees a half-written cache.
            with tempfile.NamedTemporaryFile(
                dir=str(self.cache_file.parent), delete=False
            ) as f:
                pickle.dump(self.file_data, f, protocol=4)
            os.replace(f.name, self.cache_file)
        except OSError:
            # Best-effort: a failed cache write must never break formatting.
            pass
9 changes: 4 additions & 5 deletions src/black/concurrency.py
Expand Up @@ -17,7 +17,7 @@
from mypy_extensions import mypyc_attr

from black import WriteBack, format_file_in_place
from black.cache import Cache, filter_cached, read_cache, write_cache
from black.cache import Cache
from black.mode import Mode
from black.output import err
from black.report import Changed, Report
Expand Down Expand Up @@ -133,10 +133,9 @@ async def schedule_formatting(
`write_back`, `fast`, and `mode` options are passed to
:func:`format_file_in_place`.
"""
cache: Cache = {}
cache = Cache.read(mode)
if write_back not in (WriteBack.DIFF, WriteBack.COLOR_DIFF):
cache = read_cache(mode)
sources, cached = filter_cached(cache, sources)
sources, cached = cache.filtered_cached(sources)
for src in sorted(cached):
report.done(src, Changed.CACHED)
if not sources:
Expand Down Expand Up @@ -185,4 +184,4 @@ async def schedule_formatting(
if cancelled:
await asyncio.gather(*cancelled, return_exceptions=True)
if sources_to_cache:
write_cache(cache, sources_to_cache, mode)
cache.write(sources_to_cache)