FIX get config from dispatcher thread in delayed by default #25242

Closed · wants to merge 16 commits
15 changes: 15 additions & 0 deletions doc/whats_new/v1.2.rst
@@ -9,9 +9,24 @@ Version 1.2.1

**In Development**

Changes impacting all modules
-----------------------------

- |Fix| Fix a bug that caused the global configuration to be ignored in estimators
  using `n_jobs > 1`. The bug was triggered for tasks not dispatched by the main
  thread in `joblib`, since :func:`sklearn.get_config` relies on a thread-local
  configuration.
  :pr:`25242` by :user:`Guillaume Lemaitre <glemaitre>`.

Changelog
---------

:mod:`sklearn`
..............

- |Enhancement| :func:`sklearn.get_config` takes a `thread` parameter, allowing
  retrieval of the local configuration of that specific `thread`.
  :pr:`25242` by :user:`Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.base`
...................

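For context, the |Fix| entry above corresponds to the following minimal reproduction sketch (loosely adapted from the regression test added in this PR; the numbers are illustrative): inside a `config_context`, tasks dispatched by joblib from a non-main thread used to see the default configuration instead of the caller's.

import sklearn
from joblib import Parallel
from sklearn.utils.fixes import delayed


def get_working_memory():
    return sklearn.get_config()["working_memory"]


with sklearn.config_context(working_memory=123):
    results = Parallel(n_jobs=2, pre_dispatch="2*n_jobs")(
        delayed(get_working_memory)() for _ in range(10)
    )

# Expected: [123] * 10.  Before this fix, tasks dispatched from joblib's
# dispatcher/callback thread could instead report the default of 1024,
# because sklearn.get_config() reads a thread-local configuration.
print(results)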
36 changes: 23 additions & 13 deletions sklearn/_config.py
@@ -1,10 +1,13 @@
"""Global configuration state and functions for management
"""
import os
from contextlib import contextmanager as contextmanager
import threading

_global_config = {
from contextlib import contextmanager as contextmanager
from typing import Dict # noqa
from weakref import WeakKeyDictionary

_global_config_default = {
"assume_finite": bool(os.environ.get("SKLEARN_ASSUME_FINITE", False)),
"working_memory": int(os.environ.get("SKLEARN_WORKING_MEMORY", 1024)),
"print_changed_only": True,
Expand All @@ -16,15 +19,22 @@
"array_api_dispatch": False,
"transform_output": "default",
}
_threadlocal = threading.local()
_thread_config = WeakKeyDictionary() # type: WeakKeyDictionary[threading.Thread, Dict]


def _get_threadlocal_config():
    """Get a threadlocal **mutable** configuration. If the configuration
    does not exist, copy the default global configuration."""
    if not hasattr(_threadlocal, "global_config"):
        _threadlocal.global_config = _global_config.copy()
    return _threadlocal.global_config
def _get_thread_config(thread=None):
    """Get a thread **mutable** configuration.

    If the configuration does not exist, copy the default global configuration.
    The configuration is also registered to a global dictionary where the keys
    are weak references to the thread objects.
    """
    if thread is None:
        thread = threading.current_thread()

    if thread not in _thread_config:
        _thread_config[thread] = _global_config_default.copy()
    return _thread_config[thread]


def get_config():
@@ -40,9 +50,9 @@ def get_config():
    config_context : Context manager for global scikit-learn configuration.
    set_config : Set global scikit-learn configuration.
    """
    # Return a copy of the threadlocal configuration so that users will
    # not be able to modify the configuration with the returned dict.
    return _get_threadlocal_config().copy()
    # Return a copy of the configuration so that users will not be able to
    # modify the configuration with the returned dict.
    return _get_thread_config().copy()


def set_config(
@@ -139,7 +149,7 @@ def set_config(
    config_context : Context manager for global scikit-learn configuration.
    get_config : Retrieve current values of the global configuration.
    """
    local_config = _get_threadlocal_config()
    local_config = _get_thread_config()

    if assume_finite is not None:
        local_config["assume_finite"] = assume_finite
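As a sanity check of the new private helper above, a short usage sketch (mirroring the test added below): each thread gets its own mutable copy of the defaults, keyed by a weak reference to the `Thread` object, and that copy can be inspected from another thread as long as the `Thread` object is still alive.

import threading

from sklearn import set_config
from sklearn._config import _get_thread_config


def worker():
    # Only mutates this worker thread's copy of the configuration.
    set_config(assume_finite=True)


t = threading.Thread(target=worker)
t.start()
t.join()

# The worker's config remains readable from the main thread because `t` keeps
# the Thread object alive; once it is garbage-collected, the WeakKeyDictionary
# entry disappears as well.
assert _get_thread_config(thread=t)["assume_finite"] is True
assert _get_thread_config()["assume_finite"] is False  # main thread unaffected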
35 changes: 28 additions & 7 deletions sklearn/tests/test_config.py
@@ -1,9 +1,12 @@
import time
import threading
from concurrent.futures import ThreadPoolExecutor

from joblib import Parallel
import pytest

from sklearn._config import _get_thread_config

from sklearn import get_config, set_config, config_context
from sklearn.utils.fixes import delayed

@@ -120,29 +123,47 @@ def test_config_threadsafe_joblib(backend):
    should be the same as the value passed to the function. In other words,
    it is not influenced by the other job setting assume_finite to True.
    """
    assume_finites = [False, True]
    sleep_durations = [0.1, 0.2]
    assume_finites = [False, True, False, True]
    sleep_durations = [0.1, 0.2, 0.1, 0.2]

    items = Parallel(backend=backend, n_jobs=2)(
    items = Parallel(backend=backend, n_jobs=2, pre_dispatch=2)(
        delayed(set_assume_finite)(assume_finite, sleep_dur)
        for assume_finite, sleep_dur in zip(assume_finites, sleep_durations)
    )

    assert items == [False, True]
    assert items == [False, True, False, True]


def test_config_threadsafe():
"""Uses threads directly to test that the global config does not change
between threads. Same test as `test_config_threadsafe_joblib` but with
`ThreadPoolExecutor`."""

assume_finites = [False, True]
sleep_durations = [0.1, 0.2]
assume_finites = [False, True, False, True]
sleep_durations = [0.1, 0.2, 0.1, 0.2]

with ThreadPoolExecutor(max_workers=2) as e:
items = [
output
for output in e.map(set_assume_finite, assume_finites, sleep_durations)
]

assert items == [False, True]
assert items == [False, True, False, True]


def test_get_thread_config():
"""Check that we can retrieve the config file from a specific thread."""
glemaitre marked this conversation as resolved.
Show resolved Hide resolved

def set_definitive_assume_finite(assume_finite, sleep_duration):
set_config(assume_finite=assume_finite)
time.sleep(sleep_duration)
return _get_thread_config()["assume_finite"]

thread = threading.Thread(target=set_definitive_assume_finite, args=(True, 0.1))
glemaitre marked this conversation as resolved.
Show resolved Hide resolved
thread.start()
thread.join()

thread_specific_config = _get_thread_config(thread=thread)
assert thread_specific_config["assume_finite"] is True
main_thread_config = _get_thread_config()
assert main_thread_config["assume_finite"] is False
25 changes: 17 additions & 8 deletions sklearn/utils/fixes.py
@@ -14,13 +14,14 @@
from importlib import resources
import functools
import sys
import threading

import sklearn
import numpy as np
import scipy
import scipy.stats
import threadpoolctl
from .._config import config_context, get_config
from .._config import config_context, _get_thread_config
from ..externals._packaging.version import parse as parse_version


@@ -107,22 +108,30 @@ def _eigh(*args, **kwargs):


# remove when https://github.com/joblib/joblib/issues/1071 is fixed
def delayed(function):
def delayed(func):
    """Decorator used to capture the arguments of a function."""
    return _delayed(func)

    @functools.wraps(function)
    def delayed_function(*args, **kwargs):
        return _FuncWrapper(function), args, kwargs

    return delayed_function
def _delayed(func, thread=threading.current_thread()):
@ogrisel (Member) commented on Jan 4, 2023:

Capturing the current thread here is problematic because it makes the behavior of scikit-learn dependent on which thread first imported scikit-learn, so scikit-learn's behavior is no longer thread-symmetric.

I tried changing this to:

def _delayed(func, thread=None):
    if thread is None:
        thread = threading.current_thread()

but this does not work either. Thread inspection happens too late when calling Parallel on a generator expression, which is the canonical way to use joblib. Instead, we should capture the state of the config of the thread that calls Parallel just before the call happens and ship it to all the dispatched tasks.

We just had a live pair-debugging/programming session with @glemaitre on Discord, and I think we came up with a better solution that is a bit more verbose but also much more explicit (and correct ;). He will open a PR soon and we will be able to have a more informed technical discussion there.

For the longer term we could expose a hook in joblib to better handle this kind of configuration propagation but having a stopgap fix in scikit-learn makes it possible to decouple scikit-learn from the version of joblib.
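A tiny, self-contained illustration of the pitfall described above (not scikit-learn specific, for illustration only): a default argument is evaluated once, when the `def` statement runs, so `thread=threading.current_thread()` permanently binds whichever thread executed the import.

import threading


def which_thread(thread=threading.current_thread()):
    # The default was evaluated at definition time, in the defining thread.
    return thread.name


captured = []
worker = threading.Thread(target=lambda: captured.append(which_thread()), name="worker-1")
worker.start()
worker.join()

print(captured)  # ['MainThread'], not ['worker-1']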

@glemaitre (Member, Author) replied:

Refer to #25290 for the better solution

A contributor replied:

👍 I was afraid it would come to this kind of more verbose solution. Maybe at the same time this is merged to enable the fix, a separate issue could be opened to discuss the ins and outs of the per-thread config? Unless the behavior that is enforced and supported is already clear, but that didn't seem to be the case (the PR where the behavior was enabled does not discuss it much).

@glemaitre (Member, Author) replied:

It is true that we did not discuss this here, but we did during the pair-programming session.

@ogrisel agreed that we should keep the current behavior, where a thread should not modify the config while other threads may be using it. It is a bit counter-intuitive given that threads share memory, but the side effects within scikit-learn could be bad. For instance, you could get non-reproducible random errors because they would depend on the config state at a particular moment.

A member replied:

I think we should discuss the two options at the next dev meeting:

  • option 1: revert sklearn.get/set_config to always use a single shared config without per-thread isolation, as was the case prior to 1.0.
  • option 2: make thread isolation work correctly at the cost of a bit of verbosity, as implemented in "FIX pass explicit configuration to delayed" (#25290).

I am in favor of option 2, but I think it is worth discussing this with others.
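For reference, a rough sketch of what option 2 amounts to, assuming the configuration is captured in the thread that calls Parallel and shipped explicitly with every task. The names below (`_ParallelWithConfig`, `_with_config`) are hypothetical and for illustration only; the actual implementation lives in #25290.

from joblib import Parallel

from sklearn import config_context, get_config


def _with_config(func, config):
    """Run `func` under the configuration captured at dispatch time."""

    def run(*args, **kwargs):
        with config_context(**config):
            return func(*args, **kwargs)

    return run


class _ParallelWithConfig(Parallel):
    def __call__(self, iterable):
        # get_config() runs here, in the thread that calls Parallel, before
        # joblib starts consuming the generator in another thread.
        config = get_config()
        iterable_with_config = (
            (_with_config(func, config), args, kwargs)
            for func, args, kwargs in iterable
        )
        return super().__call__(iterable_with_config)


# Usage sketch (all tasks see working_memory=123, whichever thread dispatches):
#     with config_context(working_memory=123):
#         _ParallelWithConfig(n_jobs=2)(joblib.delayed(get_config)() for _ in range(8))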

A member replied:

Option 1 would resolve the issue for multi-threading, but I think the issue will remain for multiprocessing or loky.

I am okay with Option 2. My main concern is how third-party developers using joblib would need to update their code to use utils.fixes.delayed to work with scikit-learn's config.

A member replied:

From the developer call, joblib uses another thread to get jobs from a generator, which means Option 1 with a thread local configuration would resolve the current issue.

"""Private function to expose the thread argument."""

def decorate(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
return _FuncWrapper(func, thread=thread), args, kwargs

return wrapper

return decorate(func)


class _FuncWrapper:
    """Load the global configuration before calling the function."""

    def __init__(self, function):
    def __init__(self, function, thread):
        self.function = function
        self.config = get_config()
        self.config = _get_thread_config(thread=thread)
        update_wrapper(self, self.function)

    def __call__(self, *args, **kwargs):
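To see the dispatcher-thread behavior mentioned in the review discussion above, here is a standalone sketch (assuming the threading backend and joblib's default `pre_dispatch='2*n_jobs'`): generator items beyond the pre-dispatched batch are consumed by joblib's callback machinery rather than by the main thread, which is exactly where the old `delayed` captured the wrong configuration.

import threading

from joblib import Parallel, delayed


def identity(name):
    return name


# `threading.current_thread().name` is evaluated when each generator item is
# consumed, i.e. in whichever thread performs the dispatch.
names = Parallel(n_jobs=2, backend="threading")(
    delayed(identity)(threading.current_thread().name) for _ in range(16)
)
print(sorted(set(names)))  # typically more than just 'MainThread'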
41 changes: 39 additions & 2 deletions sklearn/utils/tests/test_fixes.py
@@ -4,15 +4,18 @@
# License: BSD 3 clause

import math
import threading

import numpy as np
import pytest
import scipy.stats

from joblib import Parallel

import sklearn
from sklearn.utils._testing import assert_array_equal

from sklearn.utils.fixes import _object_dtype_isnan
from sklearn.utils.fixes import loguniform
from sklearn.utils.fixes import _delayed, _object_dtype_isnan, loguniform


@pytest.mark.parametrize("dtype, val", ([object, 1], [object, "a"], [float, 1]))
@@ -46,3 +49,37 @@ def test_loguniform(low, high, base):
    assert loguniform(base**low, base**high).rvs(random_state=0) == loguniform(
        base**low, base**high
    ).rvs(random_state=0)


def test_delayed_fetching_right_config():
"""Check that `_delayed` function fetches the right config associated to
the main thread.

Non-regression test for:
https://github.com/scikit-learn/scikit-learn/issues/25239
"""

def get_working_memory():
return sklearn.get_config()["working_memory"]

n_iter = 10

# by default, we register the main thread and we should retrieve the
# parameters defined within the context manager
with sklearn.config_context(working_memory=123):
results = Parallel(n_jobs=2, pre_dispatch=4)(
_delayed(get_working_memory)() for _ in range(n_iter)
)

assert results == [123] * n_iter

# simulate that we refer to another thread
local_thread = threading.Thread(target=sklearn.get_config)
local_thread.start()
local_thread.join()
with sklearn.config_context(working_memory=123):
results = Parallel(n_jobs=2, pre_dispatch=4)(
_delayed(get_working_memory, thread=local_thread)() for _ in range(n_iter)
)

assert results == [get_working_memory()] * n_iter
A member commented:

If the default working_memory on the main thread were to change, this test would fail:

sklearn.set_config(working_memory=140)
# the following fails
assert results == [get_working_memory()] * n_iter

The less fragile assertion would be to check that the local_thread uses the default global config:

from sklearn._config import _global_config_default
assert results == [_global_config_default["working_memory"]] * n_iter

A member replied:

Reviewing this test, it seems that I have found a bug in the way the default value of the thread argument of delayed is defined. I am working on a PR against this PR.