Skip to content

Commit

Permalink
Merge pull request #2979 from jdb8/cpu-sched-getaffinity
Browse files Browse the repository at this point in the history
Use os.sched_getaffinity for cpu counts when available
  • Loading branch information
asottile committed Aug 30, 2023
2 parents 9ebda91 + ea8244b commit ac42dc5
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
8 changes: 8 additions & 0 deletions pre_commit/xargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@


def cpu_count() -> int:
try:
# On systems that support it, this will return a more accurate count of
# usable CPUs for the current process, which will take into account
# cgroup limits
return len(os.sched_getaffinity(0))
except AttributeError:
pass

try:
return multiprocessing.cpu_count()
except NotImplementedError:
Expand Down
27 changes: 25 additions & 2 deletions tests/lang_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,19 @@ def fake_expanduser(pth):
yield


@pytest.fixture
def no_sched_getaffinity():
# Simulates an OS without os.sched_getaffinity available (mac/windows)
# https://docs.python.org/3/library/os.html#interface-to-the-scheduler
with mock.patch.object(
os,
'sched_getaffinity',
create=True,
side_effect=AttributeError,
):
yield


def test_exe_exists_does_not_exist(find_exe_mck, homedir_mck):
find_exe_mck.return_value = None
assert lang_base.exe_exists('ruby') is False
Expand Down Expand Up @@ -116,7 +129,17 @@ def test_no_env_noop(tmp_path):
assert before == inside == after


def test_target_concurrency_normal():
def test_target_concurrency_sched_getaffinity(no_sched_getaffinity):
with mock.patch.object(
os,
'sched_getaffinity',
return_value=set(range(345)),
):
with mock.patch.dict(os.environ, clear=True):
assert lang_base.target_concurrency() == 345


def test_target_concurrency_without_sched_getaffinity(no_sched_getaffinity):
with mock.patch.object(multiprocessing, 'cpu_count', return_value=123):
with mock.patch.dict(os.environ, {}, clear=True):
assert lang_base.target_concurrency() == 123
Expand All @@ -134,7 +157,7 @@ def test_target_concurrency_on_travis():
assert lang_base.target_concurrency() == 2


def test_target_concurrency_cpu_count_not_implemented():
def test_target_concurrency_cpu_count_not_implemented(no_sched_getaffinity):
with mock.patch.object(
multiprocessing, 'cpu_count', side_effect=NotImplementedError,
):
Expand Down

0 comments on commit ac42dc5

Please sign in to comment.