Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG]: Use coordinate_descent_gram when precompute is True | auto #3220

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/modules/linear_model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ for another implementation::
>>> clf = linear_model.Lasso(alpha = 0.1)
>>> clf.fit([[0, 0], [1, 1]], [0, 1])
Lasso(alpha=0.1, copy_X=True, fit_intercept=True, max_iter=1000,
normalize=False, positive=False, precompute='auto', tol=0.0001,
normalize=False, positive=False, precompute=False, tol=0.0001,
warm_start=False)
>>> clf.predict([[1, 1]])
array([ 0.8])
Expand Down
2 changes: 1 addition & 1 deletion doc/tutorial/statistical_inference/model_selection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ automatically by cross-validation::
>>> lasso.fit(X_diabetes, y_diabetes)
LassoCV(alphas=None, copy_X=True, cv=None, eps=0.001, fit_intercept=True,
max_iter=1000, n_alphas=100, n_jobs=1, normalize=False, positive=False,
precompute='auto', tol=0.0001, verbose=False)
precompute=False, tol=0.0001, verbose=False)
>>> # The estimator chose automatically its lambda:
>>> lasso.alpha_ # doctest: +ELLIPSIS
0.01229...
Expand Down
2 changes: 1 addition & 1 deletion doc/tutorial/statistical_inference/supervised_learning.rst
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ application of Occam's razor: `prefer simpler models`.
>>> regr.alpha = best_alpha
>>> regr.fit(diabetes_X_train, diabetes_y_train)
Lasso(alpha=0.025118864315095794, copy_X=True, fit_intercept=True,
max_iter=1000, normalize=False, positive=False, precompute='auto',
max_iter=1000, normalize=False, positive=False, precompute=False,
tol=0.0001, warm_start=False)
>>> print(regr.coef_)
[ 0. -212.43764548 517.19478111 313.77959962 -160.8303982 -0.
Expand Down
7 changes: 7 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,13 @@ API changes summary
Previously it was broken for input of non-int dtype and the weighted
array that was returned was wrong. By `Manoj Kumar`_.

- Change default value of precompute in :class:`ElasticNet`,
:class:`ElasticNetCV`, :class:`Lasso` and :class:`LassoCV` from "auto"
to False. Setting precompute to "auto" was found to be slower, since
computing the Gram matrix is expensive and that cost outweighs the
benefit of fitting with the precomputed Gram matrix.
By `Manoj Kumar`_.

.. _changes_0_14:

0.14
Expand Down
53 changes: 27 additions & 26 deletions sklearn/linear_model/coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,12 +467,19 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
coef_, l1_reg, l2_reg, X.data, X.indices,
X.indptr, y, X_sparse_scaling,
max_iter, tol, positive)
elif not multi_output:
elif multi_output:
model = cd_fast.enet_coordinate_descent_multi_task(
coef_, l1_reg, l2_reg, X, y, max_iter, tol)
elif isinstance(precompute, np.ndarray):
model = cd_fast.enet_coordinate_descent_gram(
coef_, l1_reg, l2_reg, precompute, Xy, y, max_iter,
tol, positive)
elif precompute is False:
model = cd_fast.enet_coordinate_descent(
coef_, l1_reg, l2_reg, X, y, max_iter, tol, positive)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

do you confirm a speed up with n_samples >> n_features?

else:
model = cd_fast.enet_coordinate_descent_multi_task(
coef_, l1_reg, l2_reg, X, y, max_iter, tol)
raise ValueError("Precompute should be one of True, False, "
"'auto' or array-like")
coef_, dual_gap_, eps_ = model
coefs[..., i] = coef_
dual_gaps[i] = dual_gap_
Expand Down Expand Up @@ -615,7 +622,7 @@ class ElasticNet(LinearModel, RegressorMixin):
path = staticmethod(enet_path)

def __init__(self, alpha=1.0, l1_ratio=0.5, fit_intercept=True,
normalize=False, precompute='auto', max_iter=1000,
normalize=False, precompute=False, max_iter=1000,
copy_X=True, tol=1e-4, warm_start=False, positive=False):
self.alpha = alpha
self.l1_ratio = l1_ratio
Expand Down Expand Up @@ -802,7 +809,7 @@ class Lasso(ElasticNet):
>>> clf = linear_model.Lasso(alpha=0.1)
>>> clf.fit([[0,0], [1, 1], [2, 2]], [0, 1, 2])
Lasso(alpha=0.1, copy_X=True, fit_intercept=True, max_iter=1000,
normalize=False, positive=False, precompute='auto', tol=0.0001,
normalize=False, positive=False, precompute=False, tol=0.0001,
warm_start=False)
>>> print(clf.coef_)
[ 0.85 0. ]
Expand All @@ -828,7 +835,7 @@ class Lasso(ElasticNet):
path = staticmethod(enet_path)

def __init__(self, alpha=1.0, fit_intercept=True, normalize=False,
precompute='auto', copy_X=True, max_iter=1000,
precompute=False, copy_X=True, max_iter=1000,
tol=1e-4, warm_start=False, positive=False):
super(Lasso, self).__init__(
alpha=alpha, l1_ratio=1.0, fit_intercept=fit_intercept,
Expand Down Expand Up @@ -889,7 +896,13 @@ def _path_residuals(X, y, train, test, path, path_params, alphas=None,
y_test = y[test]
fit_intercept = path_params['fit_intercept']
normalize = path_params['normalize']
precompute = path_params['precompute']

if y.ndim == 1:
precompute = path_params['precompute']
else:
# No Gram variant of multi-task exists right now.
# Fall back to default enet_multitask
precompute = False

X_train, y_train, X_mean, y_mean, X_std, precompute, Xy = \
_pre_fit(X_train, y_train, None, precompute, normalize, fit_intercept,
Expand Down Expand Up @@ -1218,7 +1231,7 @@ class LassoCV(LinearModelCV, RegressorMixin):
path = staticmethod(lasso_path)

def __init__(self, eps=1e-3, n_alphas=100, alphas=None, fit_intercept=True,
normalize=False, precompute='auto', max_iter=1000, tol=1e-4,
normalize=False, precompute=False, max_iter=1000, tol=1e-4,
copy_X=True, cv=None, verbose=False, n_jobs=1,
positive=False):
super(LassoCV, self).__init__(
Expand Down Expand Up @@ -1345,7 +1358,7 @@ class ElasticNetCV(LinearModelCV, RegressorMixin):
path = staticmethod(enet_path)

def __init__(self, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
fit_intercept=True, normalize=False, precompute='auto',
fit_intercept=True, normalize=False, precompute=False,
max_iter=1000, tol=1e-4, cv=None, copy_X=True,
verbose=0, n_jobs=1, positive=False):
self.l1_ratio = l1_ratio
Expand Down Expand Up @@ -1638,11 +1651,6 @@ class MultiTaskElasticNetCV(LinearModelCV, RegressorMixin):
List of alphas where to compute the models.
If not provided, set automatically.

precompute : True | False | 'auto' | array-like
Whether to use a precomputed Gram matrix to speed up
calculations. If set to ``'auto'`` let us decide. The Gram
matrix can also be passed as argument.

n_alphas : int, optional
Number of alphas along the regularization path

Expand Down Expand Up @@ -1716,8 +1724,7 @@ class MultiTaskElasticNetCV(LinearModelCV, RegressorMixin):
... #doctest: +NORMALIZE_WHITESPACE
MultiTaskElasticNetCV(alphas=None, copy_X=True, cv=None, eps=0.001,
fit_intercept=True, l1_ratio=0.5, max_iter=1000, n_alphas=100,
n_jobs=1, normalize=False, precompute='auto', tol=0.0001,
verbose=0)
n_jobs=1, normalize=False, tol=0.0001, verbose=0)
>>> print(clf.coef_)
[[ 0.52875032 0.46958558]
[ 0.52875032 0.46958558]]
Expand All @@ -1740,7 +1747,7 @@ class MultiTaskElasticNetCV(LinearModelCV, RegressorMixin):
path = staticmethod(enet_path)

def __init__(self, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
fit_intercept=True, normalize=False, precompute='auto',
fit_intercept=True, normalize=False,
max_iter=1000, tol=1e-4, cv=None, copy_X=True,
verbose=0, n_jobs=1):
self.l1_ratio = l1_ratio
Expand All @@ -1749,7 +1756,6 @@ def __init__(self, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
self.alphas = alphas
self.fit_intercept = fit_intercept
self.normalize = normalize
self.precompute = precompute
self.max_iter = max_iter
self.tol = tol
self.cv = cv
Expand Down Expand Up @@ -1781,11 +1787,6 @@ class MultiTaskLassoCV(LinearModelCV, RegressorMixin):
List of alphas where to compute the models.
If not provided, set automatically.

precompute : True | False | 'auto' | array-like
Whether to use a precomputed Gram matrix to speed up
calculations. If set to ``'auto'`` let us decide. The Gram
matrix can also be passed as argument.

n_alphas : int, optional
Number of alphas along the regularization path

Expand Down Expand Up @@ -1856,10 +1857,10 @@ class MultiTaskLassoCV(LinearModelCV, RegressorMixin):
path = staticmethod(lasso_path)

def __init__(self, eps=1e-3, n_alphas=100, alphas=None, fit_intercept=True,
normalize=False, precompute='auto', max_iter=1000, tol=1e-4,
copy_X=True, cv=None, verbose=False, n_jobs=1):
normalize=False, max_iter=1000, tol=1e-4, copy_X=True,
cv=None, verbose=False, n_jobs=1):
super(MultiTaskLassoCV, self).__init__(
eps=eps, n_alphas=n_alphas, alphas=alphas,
fit_intercept=fit_intercept, normalize=normalize,
precompute=precompute, max_iter=max_iter, tol=tol, copy_X=copy_X,
max_iter=max_iter, tol=tol, copy_X=copy_X,
cv=cv, verbose=verbose, n_jobs=n_jobs)
7 changes: 7 additions & 0 deletions sklearn/linear_model/tests/test_coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,13 @@ def test_sparse_input_dtype_enet_and_lassocv():
assert_almost_equal(clf.coef_, clf1.coef_, decimal=6)


def test_precompute_invalid_argument():
    """CV estimators must reject an unrecognised ``precompute`` value."""
    X, y, _, _ = build_dataset()
    # Both estimators are built up front with the same bogus option.
    bad_estimators = (ElasticNetCV(precompute="invalid"),
                      LassoCV(precompute="invalid"))
    for estimator in bad_estimators:
        # fit is expected to raise ValueError for the bogus option.
        assert_raises(ValueError, estimator.fit, X, y)


if __name__ == '__main__':
    # Run this module's tests through nose when executed as a script.
    from nose import runmodule
    runmodule()