Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX make dataset fetchers accept os.Pathlike for data_home #27468

Merged
merged 9 commits into from Sep 28, 2023
4 changes: 4 additions & 0 deletions doc/whats_new/v1.4.rst
Expand Up @@ -182,6 +182,10 @@ Changelog
which returns a dense numpy ndarray as before.
:pr:`27438` by :user:`Yao Xiao <Charlie-XIAO>`.

- |Fix| All dataset fetchers now accept `data_home` as any object that implements
the :class:`os.PathLike` interface, for instance, :class:`pathlib.Path`.
:pr:`27468` by :user:`Yao Xiao <Charlie-XIAO>`.

:mod:`sklearn.decomposition`
............................

Expand Down
4 changes: 2 additions & 2 deletions sklearn/datasets/_base.py
Expand Up @@ -57,7 +57,7 @@ def get_data_home(data_home=None) -> str:
----------
data_home : str or path-like, default=None
The path to scikit-learn data directory. If `None`, the default path
is `~/sklearn_learn_data`.
is `~/scikit_learn_data`.
Returns
-------
Expand All @@ -84,7 +84,7 @@ def clear_data_home(data_home=None):
----------
data_home : str or path-like, default=None
The path to scikit-learn data directory. If `None`, the default path
is `~/sklearn_learn_data`.
is `~/scikit_learn_data`.
"""
data_home = get_data_home(data_home)
shutil.rmtree(data_home)
Expand Down
6 changes: 3 additions & 3 deletions sklearn/datasets/_california_housing.py
Expand Up @@ -23,7 +23,7 @@

import logging
import tarfile
from os import makedirs, remove
from os import PathLike, makedirs, remove
from os.path import exists

import joblib
Expand Down Expand Up @@ -53,7 +53,7 @@

@validate_params(
{
"data_home": [str, None],
"data_home": [str, PathLike, None],
"download_if_missing": ["boolean"],
"return_X_y": ["boolean"],
"as_frame": ["boolean"],
Expand All @@ -76,7 +76,7 @@ def fetch_california_housing(
Parameters
----------
data_home : str, default=None
data_home : str or path-like, default=None
Specify another download and cache folder for the datasets. By default
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
Expand Down
4 changes: 2 additions & 2 deletions sklearn/datasets/_covtype.py
Expand Up @@ -65,7 +65,7 @@

@validate_params(
{
"data_home": [str, None],
"data_home": [str, os.PathLike, None],
"download_if_missing": ["boolean"],
"random_state": ["random_state"],
"shuffle": ["boolean"],
Expand Down Expand Up @@ -98,7 +98,7 @@ def fetch_covtype(
Parameters
----------
data_home : str, default=None
data_home : str or path-like, default=None
Specify another download and cache folder for the datasets. By default
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
Expand Down
4 changes: 2 additions & 2 deletions sklearn/datasets/_kddcup99.py
Expand Up @@ -50,7 +50,7 @@
@validate_params(
{
"subset": [StrOptions({"SA", "SF", "http", "smtp"}), None],
"data_home": [str, None],
"data_home": [str, os.PathLike, None],
"shuffle": ["boolean"],
"random_state": ["random_state"],
"percent10": ["boolean"],
Expand Down Expand Up @@ -92,7 +92,7 @@ def fetch_kddcup99(
To return the corresponding classical subsets of kddcup 99.
If None, return the entire kddcup 99 dataset.
data_home : str, default=None
data_home : str or path-like, default=None
Specify another download and cache folder for the datasets. By default
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
Expand Down
10 changes: 5 additions & 5 deletions sklearn/datasets/_lfw.py
Expand Up @@ -10,7 +10,7 @@

import logging
from numbers import Integral, Real
from os import listdir, makedirs, remove
from os import PathLike, listdir, makedirs, remove
from os.path import exists, isdir, join

import numpy as np
Expand Down Expand Up @@ -234,7 +234,7 @@ def _fetch_lfw_people(

@validate_params(
{
"data_home": [str, None],
"data_home": [str, PathLike, None],
"funneled": ["boolean"],
"resize": [Interval(Real, 0, None, closed="neither"), None],
"min_faces_per_person": [Interval(Integral, 0, None, closed="left"), None],
Expand Down Expand Up @@ -272,7 +272,7 @@ def fetch_lfw_people(
Parameters
----------
data_home : str, default=None
data_home : str or path-like, default=None
Specify another download and cache folder for the datasets. By default
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
Expand Down Expand Up @@ -431,7 +431,7 @@ def _fetch_lfw_pairs(
@validate_params(
{
"subset": [StrOptions({"train", "test", "10_folds"})],
"data_home": [str, None],
"data_home": [str, PathLike, None],
"funneled": ["boolean"],
"resize": [Interval(Real, 0, None, closed="neither"), None],
"color": ["boolean"],
Expand Down Expand Up @@ -480,7 +480,7 @@ def fetch_lfw_pairs(
official evaluation set that is meant to be used with a 10-folds
cross validation.
data_home : str, default=None
data_home : str or path-like, default=None
Specify another download and cache folder for the datasets. By
default all scikit-learn data is stored in '~/scikit_learn_data'
subfolders.
Expand Down
6 changes: 3 additions & 3 deletions sklearn/datasets/_olivetti_faces.py
Expand Up @@ -13,7 +13,7 @@
# Copyright (c) 2011 David Warde-Farley <wardefar at iro dot umontreal dot ca>
# License: BSD 3 clause

from os import makedirs, remove
from os import PathLike, makedirs, remove
from os.path import exists

import joblib
Expand All @@ -36,7 +36,7 @@

@validate_params(
{
"data_home": [str, None],
"data_home": [str, PathLike, None],
"shuffle": ["boolean"],
"random_state": ["random_state"],
"download_if_missing": ["boolean"],
Expand Down Expand Up @@ -67,7 +67,7 @@ def fetch_olivetti_faces(
Parameters
----------
data_home : str, default=None
data_home : str or path-like, default=None
Specify another download and cache folder for the datasets. By default
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
Expand Down
6 changes: 3 additions & 3 deletions sklearn/datasets/_openml.py
Expand Up @@ -749,7 +749,7 @@ def _valid_data_column_names(features_list, target_columns):
"name": [str, None],
"version": [Interval(Integral, 1, None, closed="left"), StrOptions({"active"})],
"data_id": [Interval(Integral, 1, None, closed="left"), None],
"data_home": [str, None],
"data_home": [str, os.PathLike, None],
"target_column": [str, list, None],
"cache": [bool],
"return_X_y": [bool],
Expand All @@ -769,7 +769,7 @@ def fetch_openml(
*,
version: Union[str, int] = "active",
data_id: Optional[int] = None,
data_home: Optional[str] = None,
data_home: Optional[Union[str, os.PathLike]] = None,
target_column: Optional[Union[str, List]] = "default-target",
cache: bool = True,
return_X_y: bool = False,
Expand Down Expand Up @@ -815,7 +815,7 @@ def fetch_openml(
dataset. If data_id is not given, name (and potential version) are
used to obtain a dataset.
data_home : str, default=None
data_home : str or path-like, default=None
Specify another download and cache folder for the data sets. By default
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
Expand Down
6 changes: 3 additions & 3 deletions sklearn/datasets/_rcv1.py
Expand Up @@ -10,7 +10,7 @@

import logging
from gzip import GzipFile
from os import makedirs, remove
from os import PathLike, makedirs, remove
from os.path import exists, join

import joblib
Expand Down Expand Up @@ -74,7 +74,7 @@

@validate_params(
{
"data_home": [str, None],
"data_home": [str, PathLike, None],
"subset": [StrOptions({"train", "test", "all"})],
"download_if_missing": ["boolean"],
"random_state": ["random_state"],
Expand Down Expand Up @@ -111,7 +111,7 @@ def fetch_rcv1(
Parameters
----------
data_home : str, default=None
data_home : str or path-like, default=None
Specify another download and cache folder for the datasets. By default
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
Expand Down
6 changes: 3 additions & 3 deletions sklearn/datasets/_species_distributions.py
Expand Up @@ -39,7 +39,7 @@

import logging
from io import BytesIO
from os import makedirs, remove
from os import PathLike, makedirs, remove
from os.path import exists

import joblib
Expand Down Expand Up @@ -136,7 +136,7 @@ def construct_grids(batch):


@validate_params(
{"data_home": [str, None], "download_if_missing": ["boolean"]},
{"data_home": [str, PathLike, None], "download_if_missing": ["boolean"]},
prefer_skip_nested_validation=True,
)
def fetch_species_distributions(*, data_home=None, download_if_missing=True):
Expand All @@ -146,7 +146,7 @@ def fetch_species_distributions(*, data_home=None, download_if_missing=True):
Parameters
----------
data_home : str, default=None
data_home : str or path-like, default=None
Specify another download and cache folder for the datasets. By default
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
Expand Down
8 changes: 4 additions & 4 deletions sklearn/datasets/_twenty_newsgroups.py
Expand Up @@ -153,7 +153,7 @@ def strip_newsgroup_footer(text):

@validate_params(
{
"data_home": [str, None],
"data_home": [str, os.PathLike, None],
"subset": [StrOptions({"train", "test", "all"})],
"categories": ["array-like", None],
"shuffle": ["boolean"],
Expand Down Expand Up @@ -191,7 +191,7 @@ def fetch_20newsgroups(
Parameters
----------
data_home : str, default=None
data_home : str or path-like, default=None
Specify a download and cache folder for the datasets. If None,
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
Expand Down Expand Up @@ -351,7 +351,7 @@ def fetch_20newsgroups(
{
"subset": [StrOptions({"train", "test", "all"})],
"remove": [tuple],
"data_home": [str, None],
"data_home": [str, os.PathLike, None],
"download_if_missing": ["boolean"],
"return_X_y": ["boolean"],
"normalize": ["boolean"],
Expand Down Expand Up @@ -411,7 +411,7 @@ def fetch_20newsgroups_vectorized(
ends of posts that look like signatures, and 'quotes' removes lines
that appear to be quoting another post.
data_home : str, default=None
data_home : str or path-like, default=None
Specify an download and cache folder for the datasets. If None,
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
Expand Down
18 changes: 17 additions & 1 deletion sklearn/datasets/tests/test_base.py
Expand Up @@ -3,6 +3,7 @@
import tempfile
import warnings
from functools import partial
from pathlib import Path
from pickle import dumps, loads

import numpy as np
Expand Down Expand Up @@ -31,6 +32,16 @@
from sklearn.utils.fixes import _is_resource


class _DummyPath:
"""Minimal class that implements the os.PathLike interface."""

def __init__(self, path):
self.path = path

def __fspath__(self):
return self.path


def _remove_dir(path):
if os.path.isdir(path):
shutil.rmtree(path)
Expand Down Expand Up @@ -67,13 +78,18 @@ def test_category_dir_2(load_files_root):
_remove_dir(test_category_dir2)


def test_data_home(data_home):
@pytest.mark.parametrize("path_container", [None, Path, _DummyPath])
def test_data_home(path_container, data_home):
# get_data_home will point to a pre-existing folder
if path_container is not None:
data_home = path_container(data_home)
data_home = get_data_home(data_home=data_home)
assert data_home == data_home
assert os.path.exists(data_home)

# clear_data_home will delete both the content and the folder it-self
if path_container is not None:
data_home = path_container(data_home)
clear_data_home(data_home=data_home)
assert not os.path.exists(data_home)

Expand Down