Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX get config from dispatcher thread in delayed by default #25242

Closed
wants to merge 16 commits into from
Closed
15 changes: 15 additions & 0 deletions doc/whats_new/v1.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,24 @@ Version 1.2.1

**In Development**

Changes impacting all modules
-----------------------------

- |Fix| Fix a bug that was ignoring the global configuration in estimators using
  `n_jobs > 1`. This bug was triggered for tasks not dispatched by the main
  thread in `joblib` since :func:`sklearn.get_config` uses thread-local
  configuration. :pr:`25242` by :user:`Guillaume Lemaitre <glemaitre>`.

Changelog
---------

:mod:`sklearn`
..............

- |Enhancement| :func:`sklearn.get_config` takes a parameter `thread` allowing to
glemaitre marked this conversation as resolved.
Show resolved Hide resolved
retrieve the local configuration of this specific `thread`.
:pr:`25242` by :user:`Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.base`
...................

Expand Down
32 changes: 25 additions & 7 deletions sklearn/_config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Global configuration state and functions for management
"""
import os
from contextlib import contextmanager as contextmanager
import threading

_global_config = {
from contextlib import contextmanager as contextmanager
glemaitre marked this conversation as resolved.
Show resolved Hide resolved
from typing import Dict # noqa
from weakref import WeakKeyDictionary

_global_config_default = {
"assume_finite": bool(os.environ.get("SKLEARN_ASSUME_FINITE", False)),
"working_memory": int(os.environ.get("SKLEARN_WORKING_MEMORY", 1024)),
"print_changed_only": True,
Expand All @@ -17,19 +20,31 @@
"transform_output": "default",
}
_threadlocal = threading.local()
_thread_config = WeakKeyDictionary() # type: WeakKeyDictionary[threading.Thread, Dict]


def _get_threadlocal_config():
    """Return this thread's **mutable** configuration dictionary.

    On first access from a given thread, the dictionary is initialised with a
    copy of the default global configuration and registered in the module-level
    ``_thread_config`` mapping, keyed by a weak reference to the thread object
    so that other threads can look the configuration up.
    """
    try:
        return _threadlocal.global_config
    except AttributeError:
        # First access from this thread: initialise from the defaults.
        config = _global_config_default.copy()
        _threadlocal.global_config = config
        # Weak key: the entry disappears once the Thread object is collected.
        _thread_config[threading.current_thread()] = config
        return config


def get_config():
def get_config(thread=None):
    """Retrieve current values for configuration set by :func:`set_config`.

    Parameters
    ----------
    thread : threading.Thread, default=None
        The thread for which to retrieve the configuration. If None, the
        configuration of the current thread is returned.

    Returns
    -------
    config : dict
        Keys are parameter names that can be passed to :func:`set_config`.
    """
    # Return a copy of the threadlocal configuration so that users will
    # not be able to modify the configuration with the returned dict.
    threadlocal_config = _get_threadlocal_config()
    if thread is None:
        return threadlocal_config.copy()
    # A thread that never read nor set the configuration is not registered in
    # `_thread_config`. Such a thread would be initialised with the defaults
    # anyway, so fall back to them instead of raising a KeyError.
    return _thread_config.get(thread, _global_config_default).copy()


def set_config(
Expand Down
43 changes: 36 additions & 7 deletions sklearn/tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
import threading
from concurrent.futures import ThreadPoolExecutor

from joblib import Parallel
Expand Down Expand Up @@ -120,29 +121,57 @@ def test_config_threadsafe_joblib(backend):
should be the same as the value passed to the function. In other words,
it is not influenced by the other job setting assume_finite to True.
"""
assume_finites = [False, True]
sleep_durations = [0.1, 0.2]
assume_finites = [False, True, False, True]
sleep_durations = [0.1, 0.2, 0.1, 0.2]

items = Parallel(backend=backend, n_jobs=2)(
items = Parallel(backend=backend, n_jobs=2, pre_dispatch=2)(
delayed(set_assume_finite)(assume_finite, sleep_dur)
for assume_finite, sleep_dur in zip(assume_finites, sleep_durations)
)

assert items == [False, True]
assert items == [False, True, False, True]


def test_config_threadsafe():
    """Uses threads directly to test that the global config does not change
    between threads. Same test as `test_config_threadsafe_joblib` but with
    `ThreadPoolExecutor`."""
    expected = [False, True, False, True]
    durations = [0.1, 0.2, 0.1, 0.2]

    with ThreadPoolExecutor(max_workers=2) as executor:
        results = list(executor.map(set_assume_finite, expected, durations))

    assert results == expected


def test_get_config_thread_dependent():
    """Check that we can retrieve the config of a specific thread."""
    from sklearn._config import _thread_config

    def set_definitive_assume_finite(assume_finite, sleep_duration):
        set_config(assume_finite=assume_finite)
        time.sleep(sleep_duration)
        return get_config()["assume_finite"]

    thread = threading.Thread(target=set_definitive_assume_finite, args=(True, 0.1))
    thread.start()
    thread.join()

    thread_specific_config = get_config(thread=thread)
    assert thread_specific_config["assume_finite"] is True
    main_thread_config = get_config()
    assert main_thread_config["assume_finite"] is False

    # The spawned thread must be registered in the weakref dictionary. Do not
    # assert on the absolute number of entries: other threads (e.g. from tests
    # running in parallel) may be registered as well.
    assert thread in _thread_config
    n_registered = len(_thread_config)

    # Dropping the last strong reference to the thread removes its entry since
    # `_thread_config` holds only weak references to its keys (CPython's
    # reference counting collects the Thread object immediately).
    del thread
    assert len(_thread_config) == n_registered - 1
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The failure with ARM shows that we register more threads than expected because we run the test in parallel. It is probably best to rely upon that the weakref dictionary does the job when joblib kill the thread.

18 changes: 11 additions & 7 deletions sklearn/utils/fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from importlib import resources
import functools
import sys
import threading

import sklearn
import numpy as np
Expand Down Expand Up @@ -107,22 +108,25 @@ def _eigh(*args, **kwargs):


# remove when https://github.com/joblib/joblib/issues/1071 is fixed
def delayed(func, thread=None):
    """Decorator used to capture the arguments of a function.

    Parameters
    ----------
    func : callable
        The function whose call is to be delayed.
    thread : threading.Thread, default=None
        The thread whose scikit-learn configuration must be used when `func`
        is eventually executed by a joblib worker. If None, the thread calling
        `delayed` is used.

    Returns
    -------
    callable
        A wrapper returning ``(_FuncWrapper(func, thread=thread), args,
        kwargs)`` in the format expected by :class:`joblib.Parallel`.
    """
    # The default must be resolved at call time: writing
    # `thread=threading.current_thread()` in the signature would evaluate it
    # once at import time and permanently capture the importing thread.
    if thread is None:
        thread = threading.current_thread()

    @functools.wraps(func)
    def delayed_function(*args, **kwargs):
        return _FuncWrapper(func, thread=thread), args, kwargs

    return delayed_function


class _FuncWrapper:
""" "Load the global configuration before calling the function."""

def __init__(self, function):
def __init__(self, function, thread):
self.function = function
self.config = get_config()
self.config = get_config(thread=thread)
glemaitre marked this conversation as resolved.
Show resolved Hide resolved
update_wrapper(self, self.function)

def __call__(self, *args, **kwargs):
Expand Down
41 changes: 39 additions & 2 deletions sklearn/utils/tests/test_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@
# License: BSD 3 clause

import math
import threading

import numpy as np
import pytest
import scipy.stats

from joblib import Parallel

import sklearn
from sklearn.utils._testing import assert_array_equal

from sklearn.utils.fixes import _object_dtype_isnan
from sklearn.utils.fixes import loguniform
from sklearn.utils.fixes import _object_dtype_isnan, delayed, loguniform


@pytest.mark.parametrize("dtype, val", ([object, 1], [object, "a"], [float, 1]))
Expand Down Expand Up @@ -46,3 +49,37 @@ def test_loguniform(low, high, base):
assert loguniform(base**low, base**high).rvs(random_state=0) == loguniform(
base**low, base**high
).rvs(random_state=0)


def test_delayed_fetching_right_config():
    """Check that `delayed` function fetches the right config associated to
    the main thread.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/25239
    """
    from sklearn._config import _global_config_default

    def get_working_memory():
        return sklearn.get_config()["working_memory"]

    n_iter = 10

    # by default, the config of the thread calling `delayed` (here the main
    # thread) is propagated to the workers: the value set inside the context
    # manager must be observed by every task
    with sklearn.config_context(working_memory=123):
        results = Parallel(n_jobs=2, pre_dispatch=4)(
            delayed(get_working_memory)() for _ in range(n_iter)
        )

    assert results == [123] * n_iter

    # simulate that we refer to another thread: that thread only initialised
    # its configuration with the defaults, so the workers must observe the
    # default value and not the one set in the main thread's context manager
    local_thread = threading.Thread(target=sklearn.get_config)
    local_thread.start()
    local_thread.join()
    with sklearn.config_context(working_memory=123):
        results = Parallel(n_jobs=2, pre_dispatch=4)(
            delayed(get_working_memory, thread=local_thread)() for _ in range(n_iter)
        )

    # Compare against the immutable default rather than calling
    # `get_working_memory()` here: the latter would silently follow any change
    # made to the main thread's configuration, making the test fragile.
    assert results == [_global_config_default["working_memory"]] * n_iter
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the default working_memory on the main thread were to change this test would fail:

sklearn.set_config(working_memory=140)
# the following fails
assert results == [get_working_memory()] * n_iter

The less fragile assertion would be check that the local_thread uses the default global config:

from sklearn._config import _global_config_default
assert results == [_global_config_default["working_memory"]] * n_iter

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewing this test, it seems that I have found a bug in the way the default value of the thread argument of delayed is defined. I am working on a PR against this PR.