Skip to content

Commit

Permalink
FIX make dataset fetchers accept os.Pathlike for data_home (scikit-learn#27468)
Browse files Browse the repository at this point in the history

Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
  • Loading branch information
Charlie-XIAO and glemaitre committed Oct 17, 2023
1 parent d99b728 commit a4803c5
Show file tree
Hide file tree
Showing 11 changed files with 47 additions and 31 deletions.
4 changes: 2 additions & 2 deletions sklearn/datasets/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def get_data_home(data_home=None) -> str:
----------
data_home : str or path-like, default=None
The path to scikit-learn data directory. If `None`, the default path
is `~/sklearn_learn_data`.
is `~/scikit_learn_data`.
Returns
-------
Expand All @@ -84,7 +84,7 @@ def clear_data_home(data_home=None):
----------
data_home : str or path-like, default=None
The path to scikit-learn data directory. If `None`, the default path
is `~/sklearn_learn_data`.
is `~/scikit_learn_data`.
"""
data_home = get_data_home(data_home)
shutil.rmtree(data_home)
Expand Down
6 changes: 3 additions & 3 deletions sklearn/datasets/_california_housing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import logging
import tarfile
from os import makedirs, remove
from os import PathLike, makedirs, remove
from os.path import exists

import joblib
Expand Down Expand Up @@ -53,7 +53,7 @@

@validate_params(
{
"data_home": [str, None],
"data_home": [str, PathLike, None],
"download_if_missing": ["boolean"],
"return_X_y": ["boolean"],
"as_frame": ["boolean"],
Expand All @@ -76,7 +76,7 @@ def fetch_california_housing(
Parameters
----------
data_home : str, default=None
data_home : str or path-like, default=None
Specify another download and cache folder for the datasets. By default
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
Expand Down
4 changes: 2 additions & 2 deletions sklearn/datasets/_covtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@

@validate_params(
{
"data_home": [str, None],
"data_home": [str, os.PathLike, None],
"download_if_missing": ["boolean"],
"random_state": ["random_state"],
"shuffle": ["boolean"],
Expand Down Expand Up @@ -98,7 +98,7 @@ def fetch_covtype(
Parameters
----------
data_home : str, default=None
data_home : str or path-like, default=None
Specify another download and cache folder for the datasets. By default
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
Expand Down
4 changes: 2 additions & 2 deletions sklearn/datasets/_kddcup99.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
@validate_params(
{
"subset": [StrOptions({"SA", "SF", "http", "smtp"}), None],
"data_home": [str, None],
"data_home": [str, os.PathLike, None],
"shuffle": ["boolean"],
"random_state": ["random_state"],
"percent10": ["boolean"],
Expand Down Expand Up @@ -92,7 +92,7 @@ def fetch_kddcup99(
To return the corresponding classical subsets of kddcup 99.
If None, return the entire kddcup 99 dataset.
data_home : str, default=None
data_home : str or path-like, default=None
Specify another download and cache folder for the datasets. By default
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
Expand Down
10 changes: 5 additions & 5 deletions sklearn/datasets/_lfw.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import logging
from numbers import Integral, Real
from os import listdir, makedirs, remove
from os import PathLike, listdir, makedirs, remove
from os.path import exists, isdir, join

import numpy as np
Expand Down Expand Up @@ -234,7 +234,7 @@ def _fetch_lfw_people(

@validate_params(
{
"data_home": [str, None],
"data_home": [str, PathLike, None],
"funneled": ["boolean"],
"resize": [Interval(Real, 0, None, closed="neither"), None],
"min_faces_per_person": [Interval(Integral, 0, None, closed="left"), None],
Expand Down Expand Up @@ -272,7 +272,7 @@ def fetch_lfw_people(
Parameters
----------
data_home : str, default=None
data_home : str or path-like, default=None
Specify another download and cache folder for the datasets. By default
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
Expand Down Expand Up @@ -431,7 +431,7 @@ def _fetch_lfw_pairs(
@validate_params(
{
"subset": [StrOptions({"train", "test", "10_folds"})],
"data_home": [str, None],
"data_home": [str, PathLike, None],
"funneled": ["boolean"],
"resize": [Interval(Real, 0, None, closed="neither"), None],
"color": ["boolean"],
Expand Down Expand Up @@ -480,7 +480,7 @@ def fetch_lfw_pairs(
official evaluation set that is meant to be used with a 10-folds
cross validation.
data_home : str, default=None
data_home : str or path-like, default=None
Specify another download and cache folder for the datasets. By
default all scikit-learn data is stored in '~/scikit_learn_data'
subfolders.
Expand Down
6 changes: 3 additions & 3 deletions sklearn/datasets/_olivetti_faces.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# Copyright (c) 2011 David Warde-Farley <wardefar at iro dot umontreal dot ca>
# License: BSD 3 clause

from os import makedirs, remove
from os import PathLike, makedirs, remove
from os.path import exists

import joblib
Expand All @@ -36,7 +36,7 @@

@validate_params(
{
"data_home": [str, None],
"data_home": [str, PathLike, None],
"shuffle": ["boolean"],
"random_state": ["random_state"],
"download_if_missing": ["boolean"],
Expand Down Expand Up @@ -67,7 +67,7 @@ def fetch_olivetti_faces(
Parameters
----------
data_home : str, default=None
data_home : str or path-like, default=None
Specify another download and cache folder for the datasets. By default
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
Expand Down
6 changes: 3 additions & 3 deletions sklearn/datasets/_openml.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ def _valid_data_column_names(features_list, target_columns):
"name": [str, None],
"version": [Interval(Integral, 1, None, closed="left"), StrOptions({"active"})],
"data_id": [Interval(Integral, 1, None, closed="left"), None],
"data_home": [str, None],
"data_home": [str, os.PathLike, None],
"target_column": [str, list, None],
"cache": [bool],
"return_X_y": [bool],
Expand All @@ -769,7 +769,7 @@ def fetch_openml(
*,
version: Union[str, int] = "active",
data_id: Optional[int] = None,
data_home: Optional[str] = None,
data_home: Optional[Union[str, os.PathLike]] = None,
target_column: Optional[Union[str, List]] = "default-target",
cache: bool = True,
return_X_y: bool = False,
Expand Down Expand Up @@ -815,7 +815,7 @@ def fetch_openml(
dataset. If data_id is not given, name (and potential version) are
used to obtain a dataset.
data_home : str, default=None
data_home : str or path-like, default=None
Specify another download and cache folder for the data sets. By default
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
Expand Down
6 changes: 3 additions & 3 deletions sklearn/datasets/_rcv1.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import logging
from gzip import GzipFile
from os import makedirs, remove
from os import PathLike, makedirs, remove
from os.path import exists, join

import joblib
Expand Down Expand Up @@ -74,7 +74,7 @@

@validate_params(
{
"data_home": [str, None],
"data_home": [str, PathLike, None],
"subset": [StrOptions({"train", "test", "all"})],
"download_if_missing": ["boolean"],
"random_state": ["random_state"],
Expand Down Expand Up @@ -111,7 +111,7 @@ def fetch_rcv1(
Parameters
----------
data_home : str, default=None
data_home : str or path-like, default=None
Specify another download and cache folder for the datasets. By default
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
Expand Down
6 changes: 3 additions & 3 deletions sklearn/datasets/_species_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

import logging
from io import BytesIO
from os import makedirs, remove
from os import PathLike, makedirs, remove
from os.path import exists

import joblib
Expand Down Expand Up @@ -136,7 +136,7 @@ def construct_grids(batch):


@validate_params(
{"data_home": [str, None], "download_if_missing": ["boolean"]},
{"data_home": [str, PathLike, None], "download_if_missing": ["boolean"]},
prefer_skip_nested_validation=True,
)
def fetch_species_distributions(*, data_home=None, download_if_missing=True):
Expand All @@ -146,7 +146,7 @@ def fetch_species_distributions(*, data_home=None, download_if_missing=True):
Parameters
----------
data_home : str, default=None
data_home : str or path-like, default=None
Specify another download and cache folder for the datasets. By default
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
Expand Down
8 changes: 4 additions & 4 deletions sklearn/datasets/_twenty_newsgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def strip_newsgroup_footer(text):

@validate_params(
{
"data_home": [str, None],
"data_home": [str, os.PathLike, None],
"subset": [StrOptions({"train", "test", "all"})],
"categories": ["array-like", None],
"shuffle": ["boolean"],
Expand Down Expand Up @@ -191,7 +191,7 @@ def fetch_20newsgroups(
Parameters
----------
data_home : str, default=None
data_home : str or path-like, default=None
Specify a download and cache folder for the datasets. If None,
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
Expand Down Expand Up @@ -351,7 +351,7 @@ def fetch_20newsgroups(
{
"subset": [StrOptions({"train", "test", "all"})],
"remove": [tuple],
"data_home": [str, None],
"data_home": [str, os.PathLike, None],
"download_if_missing": ["boolean"],
"return_X_y": ["boolean"],
"normalize": ["boolean"],
Expand Down Expand Up @@ -411,7 +411,7 @@ def fetch_20newsgroups_vectorized(
ends of posts that look like signatures, and 'quotes' removes lines
that appear to be quoting another post.
data_home : str, default=None
data_home : str or path-like, default=None
Specify a download and cache folder for the datasets. If None,
all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
Expand Down
18 changes: 17 additions & 1 deletion sklearn/datasets/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import tempfile
import warnings
from functools import partial
from pathlib import Path
from pickle import dumps, loads

import numpy as np
Expand Down Expand Up @@ -31,6 +32,16 @@
from sklearn.utils.fixes import _is_resource


class _DummyPath:
    """Minimal class that implements the os.PathLike interface.

    The wrapped value is stored as-is and handed back unchanged by
    ``__fspath__``, so an instance behaves as a path-like object without
    being a ``str`` or a ``pathlib.Path``.

    Parameters
    ----------
    path : str
        The path returned by ``__fspath__``.
    """

    def __init__(self, path):
        self.path = path

    def __fspath__(self):
        """Return the wrapped path (the ``os.PathLike`` protocol hook)."""
        return self.path


def _remove_dir(path):
if os.path.isdir(path):
shutil.rmtree(path)
Expand Down Expand Up @@ -67,13 +78,18 @@ def test_category_dir_2(load_files_root):
_remove_dir(test_category_dir2)


def test_data_home(data_home):
@pytest.mark.parametrize("path_container", [None, Path, _DummyPath])
def test_data_home(path_container, data_home):
    """Check get_data_home and clear_data_home with str and path-like input.

    ``path_container`` is either ``None`` (``data_home`` is passed through
    unchanged) or a callable that wraps it in a path-like object.
    """
    # Plain form of the requested folder, captured before any rebinding so
    # the returned value can be compared against it.
    # NOTE(review): assumes get_data_home returns the given folder unchanged
    # apart from conversion to its plain form -- confirm against
    # get_data_home.
    expected = os.fspath(data_home)

    # get_data_home will point to a pre-existing folder
    requested = data_home if path_container is None else path_container(data_home)
    resolved = get_data_home(data_home=requested)
    # The earlier `data_home == data_home` check compared the result with
    # itself and could never fail; compare with the requested folder instead.
    assert resolved == expected
    assert os.path.exists(resolved)

    # clear_data_home will delete both the content and the folder itself
    to_clear = resolved if path_container is None else path_container(resolved)
    clear_data_home(data_home=to_clear)
    assert not os.path.exists(resolved)

Expand Down

0 comments on commit a4803c5

Please sign in to comment.