diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index d4f92548ba0ac..a9ea738beca91 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -182,6 +182,10 @@ Changelog which returns a dense numpy ndarray as before. :pr:`27438` by :user:`Yao Xiao <Charlie-XIAO>`. +- |Fix| All dataset fetchers now accept `data_home` as any object that implements + the :class:`os.PathLike` interface, for instance, :class:`pathlib.Path`. + :pr:`27468` by :user:`Yao Xiao <Charlie-XIAO>`. + :mod:`sklearn.decomposition` ............................ diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py index b2d198ecf8c2f..5675798137824 100644 --- a/sklearn/datasets/_base.py +++ b/sklearn/datasets/_base.py @@ -57,7 +57,7 @@ def get_data_home(data_home=None) -> str: ---------- data_home : str or path-like, default=None The path to scikit-learn data directory. If `None`, the default path - is `~/sklearn_learn_data`. + is `~/scikit_learn_data`. Returns ------- @@ -84,7 +84,7 @@ def clear_data_home(data_home=None): ---------- data_home : str or path-like, default=None The path to scikit-learn data directory. If `None`, the default path - is `~/sklearn_learn_data`. + is `~/scikit_learn_data`. 
""" data_home = get_data_home(data_home) shutil.rmtree(data_home) diff --git a/sklearn/datasets/_california_housing.py b/sklearn/datasets/_california_housing.py index b48e7e10bdc4b..3153f0dd03f72 100644 --- a/sklearn/datasets/_california_housing.py +++ b/sklearn/datasets/_california_housing.py @@ -23,7 +23,7 @@ import logging import tarfile -from os import makedirs, remove +from os import PathLike, makedirs, remove from os.path import exists import joblib @@ -53,7 +53,7 @@ @validate_params( { - "data_home": [str, None], + "data_home": [str, PathLike, None], "download_if_missing": ["boolean"], "return_X_y": ["boolean"], "as_frame": ["boolean"], @@ -76,7 +76,7 @@ def fetch_california_housing( Parameters ---------- - data_home : str, default=None + data_home : str or path-like, default=None Specify another download and cache folder for the datasets. By default all scikit-learn data is stored in '~/scikit_learn_data' subfolders. diff --git a/sklearn/datasets/_covtype.py b/sklearn/datasets/_covtype.py index 557899bc88e97..7620e08c5ec92 100644 --- a/sklearn/datasets/_covtype.py +++ b/sklearn/datasets/_covtype.py @@ -65,7 +65,7 @@ @validate_params( { - "data_home": [str, None], + "data_home": [str, os.PathLike, None], "download_if_missing": ["boolean"], "random_state": ["random_state"], "shuffle": ["boolean"], @@ -98,7 +98,7 @@ def fetch_covtype( Parameters ---------- - data_home : str, default=None + data_home : str or path-like, default=None Specify another download and cache folder for the datasets. By default all scikit-learn data is stored in '~/scikit_learn_data' subfolders. 
diff --git a/sklearn/datasets/_kddcup99.py b/sklearn/datasets/_kddcup99.py index 17c49161c3bc2..444bd01737901 100644 --- a/sklearn/datasets/_kddcup99.py +++ b/sklearn/datasets/_kddcup99.py @@ -50,7 +50,7 @@ @validate_params( { "subset": [StrOptions({"SA", "SF", "http", "smtp"}), None], - "data_home": [str, None], + "data_home": [str, os.PathLike, None], "shuffle": ["boolean"], "random_state": ["random_state"], "percent10": ["boolean"], @@ -92,7 +92,7 @@ def fetch_kddcup99( To return the corresponding classical subsets of kddcup 99. If None, return the entire kddcup 99 dataset. - data_home : str, default=None + data_home : str or path-like, default=None Specify another download and cache folder for the datasets. By default all scikit-learn data is stored in '~/scikit_learn_data' subfolders. diff --git a/sklearn/datasets/_lfw.py b/sklearn/datasets/_lfw.py index 345f56e89a03b..d06d29f21d0a5 100644 --- a/sklearn/datasets/_lfw.py +++ b/sklearn/datasets/_lfw.py @@ -10,7 +10,7 @@ import logging from numbers import Integral, Real -from os import listdir, makedirs, remove +from os import PathLike, listdir, makedirs, remove from os.path import exists, isdir, join import numpy as np @@ -234,7 +234,7 @@ def _fetch_lfw_people( @validate_params( { - "data_home": [str, None], + "data_home": [str, PathLike, None], "funneled": ["boolean"], "resize": [Interval(Real, 0, None, closed="neither"), None], "min_faces_per_person": [Interval(Integral, 0, None, closed="left"), None], @@ -272,7 +272,7 @@ def fetch_lfw_people( Parameters ---------- - data_home : str, default=None + data_home : str or path-like, default=None Specify another download and cache folder for the datasets. By default all scikit-learn data is stored in '~/scikit_learn_data' subfolders. 
@@ -431,7 +431,7 @@ def _fetch_lfw_pairs( @validate_params( { "subset": [StrOptions({"train", "test", "10_folds"})], - "data_home": [str, None], + "data_home": [str, PathLike, None], "funneled": ["boolean"], "resize": [Interval(Real, 0, None, closed="neither"), None], "color": ["boolean"], @@ -480,7 +480,7 @@ def fetch_lfw_pairs( official evaluation set that is meant to be used with a 10-folds cross validation. - data_home : str, default=None + data_home : str or path-like, default=None Specify another download and cache folder for the datasets. By default all scikit-learn data is stored in '~/scikit_learn_data' subfolders. diff --git a/sklearn/datasets/_olivetti_faces.py b/sklearn/datasets/_olivetti_faces.py index 51710faccc417..8e1b3c91e254b 100644 --- a/sklearn/datasets/_olivetti_faces.py +++ b/sklearn/datasets/_olivetti_faces.py @@ -13,7 +13,7 @@ # Copyright (c) 2011 David Warde-Farley # License: BSD 3 clause -from os import makedirs, remove +from os import PathLike, makedirs, remove from os.path import exists import joblib @@ -36,7 +36,7 @@ @validate_params( { - "data_home": [str, None], + "data_home": [str, PathLike, None], "shuffle": ["boolean"], "random_state": ["random_state"], "download_if_missing": ["boolean"], @@ -67,7 +67,7 @@ def fetch_olivetti_faces( Parameters ---------- - data_home : str, default=None + data_home : str or path-like, default=None Specify another download and cache folder for the datasets. By default all scikit-learn data is stored in '~/scikit_learn_data' subfolders. 
diff --git a/sklearn/datasets/_openml.py b/sklearn/datasets/_openml.py index 1c36dc8a25ce1..c9d09dc3ce46a 100644 --- a/sklearn/datasets/_openml.py +++ b/sklearn/datasets/_openml.py @@ -749,7 +749,7 @@ def _valid_data_column_names(features_list, target_columns): "name": [str, None], "version": [Interval(Integral, 1, None, closed="left"), StrOptions({"active"})], "data_id": [Interval(Integral, 1, None, closed="left"), None], - "data_home": [str, None], + "data_home": [str, os.PathLike, None], "target_column": [str, list, None], "cache": [bool], "return_X_y": [bool], @@ -769,7 +769,7 @@ def fetch_openml( *, version: Union[str, int] = "active", data_id: Optional[int] = None, - data_home: Optional[str] = None, + data_home: Optional[Union[str, os.PathLike]] = None, target_column: Optional[Union[str, List]] = "default-target", cache: bool = True, return_X_y: bool = False, @@ -815,7 +815,7 @@ def fetch_openml( dataset. If data_id is not given, name (and potential version) are used to obtain a dataset. - data_home : str, default=None + data_home : str or path-like, default=None Specify another download and cache folder for the data sets. By default all scikit-learn data is stored in '~/scikit_learn_data' subfolders. 
diff --git a/sklearn/datasets/_rcv1.py b/sklearn/datasets/_rcv1.py index a807d8e311466..d9f392d872216 100644 --- a/sklearn/datasets/_rcv1.py +++ b/sklearn/datasets/_rcv1.py @@ -10,7 +10,7 @@ import logging from gzip import GzipFile -from os import makedirs, remove +from os import PathLike, makedirs, remove from os.path import exists, join import joblib @@ -74,7 +74,7 @@ @validate_params( { - "data_home": [str, None], + "data_home": [str, PathLike, None], "subset": [StrOptions({"train", "test", "all"})], "download_if_missing": ["boolean"], "random_state": ["random_state"], @@ -111,7 +111,7 @@ def fetch_rcv1( Parameters ---------- - data_home : str, default=None + data_home : str or path-like, default=None Specify another download and cache folder for the datasets. By default all scikit-learn data is stored in '~/scikit_learn_data' subfolders. diff --git a/sklearn/datasets/_species_distributions.py b/sklearn/datasets/_species_distributions.py index 0bfc4bb0fdaf5..a1e654d41e071 100644 --- a/sklearn/datasets/_species_distributions.py +++ b/sklearn/datasets/_species_distributions.py @@ -39,7 +39,7 @@ import logging from io import BytesIO -from os import makedirs, remove +from os import PathLike, makedirs, remove from os.path import exists import joblib @@ -136,7 +136,7 @@ def construct_grids(batch): @validate_params( - {"data_home": [str, None], "download_if_missing": ["boolean"]}, + {"data_home": [str, PathLike, None], "download_if_missing": ["boolean"]}, prefer_skip_nested_validation=True, ) def fetch_species_distributions(*, data_home=None, download_if_missing=True): @@ -146,7 +146,7 @@ def fetch_species_distributions(*, data_home=None, download_if_missing=True): Parameters ---------- - data_home : str, default=None + data_home : str or path-like, default=None Specify another download and cache folder for the datasets. By default all scikit-learn data is stored in '~/scikit_learn_data' subfolders. 
diff --git a/sklearn/datasets/_twenty_newsgroups.py b/sklearn/datasets/_twenty_newsgroups.py index 637cf8e4fc8d4..5973e998c34b9 100644 --- a/sklearn/datasets/_twenty_newsgroups.py +++ b/sklearn/datasets/_twenty_newsgroups.py @@ -153,7 +153,7 @@ def strip_newsgroup_footer(text): @validate_params( { - "data_home": [str, None], + "data_home": [str, os.PathLike, None], "subset": [StrOptions({"train", "test", "all"})], "categories": ["array-like", None], "shuffle": ["boolean"], @@ -191,7 +191,7 @@ def fetch_20newsgroups( Parameters ---------- - data_home : str, default=None + data_home : str or path-like, default=None Specify a download and cache folder for the datasets. If None, all scikit-learn data is stored in '~/scikit_learn_data' subfolders. @@ -351,7 +351,7 @@ def fetch_20newsgroups( { "subset": [StrOptions({"train", "test", "all"})], "remove": [tuple], - "data_home": [str, None], + "data_home": [str, os.PathLike, None], "download_if_missing": ["boolean"], "return_X_y": ["boolean"], "normalize": ["boolean"], @@ -411,7 +411,7 @@ def fetch_20newsgroups_vectorized( ends of posts that look like signatures, and 'quotes' removes lines that appear to be quoting another post. - data_home : str, default=None + data_home : str or path-like, default=None Specify an download and cache folder for the datasets. If None, all scikit-learn data is stored in '~/scikit_learn_data' subfolders. 
diff --git a/sklearn/datasets/tests/test_base.py b/sklearn/datasets/tests/test_base.py index f31f20636c0c1..f84c275d67cf9 100644 --- a/sklearn/datasets/tests/test_base.py +++ b/sklearn/datasets/tests/test_base.py @@ -3,6 +3,7 @@ import tempfile import warnings from functools import partial +from pathlib import Path from pickle import dumps, loads import numpy as np @@ -31,6 +32,16 @@ from sklearn.utils.fixes import _is_resource +class _DummyPath: + """Minimal class that implements the os.PathLike interface.""" + + def __init__(self, path): + self.path = path + + def __fspath__(self): + return self.path + + def _remove_dir(path): if os.path.isdir(path): shutil.rmtree(path) @@ -67,13 +78,18 @@ def test_category_dir_2(load_files_root): _remove_dir(test_category_dir2) -def test_data_home(data_home): +@pytest.mark.parametrize("path_container", [None, Path, _DummyPath]) +def test_data_home(path_container, data_home): # get_data_home will point to a pre-existing folder + if path_container is not None: + data_home = path_container(data_home) data_home = get_data_home(data_home=data_home) assert data_home == data_home assert os.path.exists(data_home) # clear_data_home will delete both the content and the folder it-self + if path_container is not None: + data_home = path_container(data_home) clear_data_home(data_home=data_home) assert not os.path.exists(data_home)