
SGDClassifier under/overflow #3040

Closed
worldveil opened this issue Apr 4, 2014 · 16 comments
@worldveil

Example code:

from sklearn.linear_model import SGDClassifier
from sklearn.datasets import load_iris
from sklearn import cross_validation

iris = load_iris()

hyperparameter_choices = [
    # some examples, by no means exhaustive
    {u'loss': 'modified_huber', u'shuffle': True, u'n_iter': 25.0, 
    u'l1_ratio': 0.5, u'learning_rate': 'constant', u'fit_intercept': 0.0, 
    u'penalty': 'l2', u'alpha': 1000.0, u'eta0': 0.1, u'class_weight': None},

    {u'loss': 'squared_hinge', u'shuffle': True, u'n_iter': 25.0, u'l1_ratio': 0.5, 
    u'learning_rate': 'optimal', u'fit_intercept': 0.0, u'penalty': 'elasticnet', 
    u'alpha': 0.001, u'eta0': 0.1, u'class_weight': None},

    {u'loss': 'squared_hinge', u'shuffle': True, u'n_iter': 100.0, u'l1_ratio': 0.5, 
    u'learning_rate': 'optimal', u'fit_intercept': 0.0, u'penalty': 'l2', u'alpha': 0.001, 
    u'eta0': 0.001, u'class_weight': None}
]

for params in hyperparameter_choices:
    try:
        clf = SGDClassifier(**params)
        scores = cross_validation.cross_val_score(clf, iris.data, iris.target, cv=5)
    except ValueError as ve:
        print("ValueError: %s" % ve)

I'm not sure whether these are simply bad hyperparameters for SGD in general. Otherwise it looks like a numerical stability bug.

The under/overflow above happens even when the data is scaled first.

@pprett
Member

pprett commented Apr 4, 2014

I assume it fails on alpha=1000?

@worldveil
Author

Two of the three examples fail with alpha = 0.001, so no, not exclusively.

@larsmans
Member

larsmans commented Apr 6, 2014

Failures:

ValueError('Floating-point under-/overflow occurred at epoch #2. Scaling input data with StandardScaler or MinMaxScaler might help.',)
{u'alpha': 1000.0,
 u'class_weight': None,
 u'eta0': 0.1,
 u'fit_intercept': 0.0,
 u'l1_ratio': 0.5,
 u'learning_rate': 'constant',
 u'loss': 'modified_huber',
 u'n_iter': 25.0,
 u'penalty': 'l2',
 u'shuffle': True}

ValueError('Floating-point under-/overflow occurred at epoch #3. Scaling input data with StandardScaler or MinMaxScaler might help.',)
{u'alpha': 0.001,
 u'class_weight': None,
 u'eta0': 0.1,
 u'fit_intercept': 0.0,
 u'l1_ratio': 0.5,
 u'learning_rate': 'optimal',
 u'loss': 'squared_hinge',
 u'n_iter': 25.0,
 u'penalty': 'elasticnet',
 u'shuffle': True}

ValueError('Floating-point under-/overflow occurred at epoch #3. Scaling input data with StandardScaler or MinMaxScaler might help.',)
{u'alpha': 0.001,
 u'class_weight': None,
 u'eta0': 0.001,
 u'fit_intercept': 0.0,
 u'l1_ratio': 0.5,
 u'learning_rate': 'optimal',
 u'loss': 'squared_hinge',
 u'n_iter': 100.0,
 u'penalty': 'l2',
 u'shuffle': True}

@worldveil
Author

MinMaxScaler on the (0,1) range does help, but not always:

from sklearn.linear_model import SGDClassifier
from sklearn.datasets import load_iris
from sklearn import cross_validation
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, MinMaxScaler

iris = load_iris()

hyperparameter_choices = [
    # some examples, by no means exhaustive
    {u'loss': 'modified_huber', u'shuffle': True, u'n_iter': 25.0, 
    u'l1_ratio': 0.5, u'learning_rate': 'constant', u'fit_intercept': 0.0, 
    u'penalty': 'l2', u'alpha': 1000.0, u'eta0': 0.1, u'class_weight': None},

    {u'loss': 'squared_hinge', u'shuffle': True, u'n_iter': 25.0, u'l1_ratio': 0.5, 
    u'learning_rate': 'optimal', u'fit_intercept': 0.0, u'penalty': 'elasticnet', 
    u'alpha': 0.001, u'eta0': 0.1, u'class_weight': None},

    {u'loss': 'squared_hinge', u'shuffle': True, u'n_iter': 100.0, u'l1_ratio': 0.5, 
    u'learning_rate': 'optimal', u'fit_intercept': 0.0, u'penalty': 'l2', u'alpha': 0.001, 
    u'eta0': 0.001, u'class_weight': None}
]

print("\nWith Standard scaling...")
for params in hyperparameter_choices:
    try:
        pipeline = Pipeline([
            ("standard_scaler", StandardScaler()),
            ("sgd", SGDClassifier(**params))
        ])
        scores = cross_validation.cross_val_score(pipeline, iris.data, iris.target, cv=5)
    except ValueError as ve:
        print("ValueError: %s" % ve)

print("\nWith MinMax scaling...")
for params in hyperparameter_choices:
    try:
        pipeline = Pipeline([
            ("minmax_scaler", MinMaxScaler()),
            ("sgd", SGDClassifier(**params))
        ])
        scores = cross_validation.cross_val_score(pipeline, iris.data, iris.target, cv=5)
        print("Success with: %s" % params)
    except ValueError as ve:
        print("ValueError: %s" % ve)

@larsmans
Member

I just did some testing with this example and I think the parameters are just not good. With too high a learning rate, gradient descent will overshoot its target; that's an inherent risk with this algorithm. Scaling with StandardScaler and using the optimal learning rate fixes the first problem.

What remains are the other two cases. Here, the squared hinge loss goes off to infinity and its dloss becomes huge, and so do the updates. Regularizing more or using the vanilla hinge loss solves this.
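The difference between the two losses is visible directly in their derivatives: the plain hinge has a bounded dloss, while the squared hinge's dloss grows linearly with the margin violation, so bad iterates produce ever-larger updates. A minimal numeric sketch (the function names here are illustrative, not scikit-learn internals):

```python
def dloss_hinge(p, y):
    # hinge: L = max(0, 1 - p*y), so |dL/dp| <= 1 no matter how wrong p is
    return -y if p * y < 1.0 else 0.0

def dloss_squared_hinge(p, y):
    # squared hinge: L = max(0, 1 - p*y)**2, so dL/dp = -2*y*max(0, 1 - p*y),
    # which grows linearly with the margin violation
    return -2.0 * y * max(0.0, 1.0 - p * y)

for p in (-1.0, -10.0, -1000.0):  # increasingly wrong scores for y = +1
    print(p, dloss_hinge(p, 1.0), dloss_squared_hinge(p, 1.0))
```

Once the weights overshoot, each squared-hinge update is proportional to the (growing) violation, so the iterates diverge and eventually overflow.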

@ogrisel
Member

ogrisel commented Nov 24, 2014

I think many SGD practitioners from the deep learning community clip the gradient norms (or the norm of the weights) to e.g. [-100, 100] to avoid such numerical stability issues in practice. That might be worth trying.
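A sketch of what that could look like for a scalar SGD update, clipping the loss derivative before it scales the feature vector (the `CLIP` bound and the function are hypothetical, not scikit-learn code):

```python
import numpy as np

CLIP = 100.0  # hypothetical bound, following the [-100, 100] suggestion above

def sgd_step_clipped(w, x, y, eta, dloss):
    """One SGD step where the scalar loss derivative is clipped to
    [-CLIP, CLIP] before it multiplies the feature vector."""
    p = float(np.dot(w, x))
    g = dloss(p, y)
    g = max(-CLIP, min(CLIP, g))  # keep the update bounded
    return w - eta * g * x

# With a squared-hinge-style derivative the raw gradient here would be -2002,
# but the clipped step moves the weights by at most eta * CLIP per feature:
dloss = lambda p, y: -2.0 * y * max(0.0, 1.0 - p * y)
w = np.array([-500.0, -500.0])
x = np.array([1.0, 1.0])
w_new = sgd_step_clipped(w, x, 1.0, eta=0.1, dloss=dloss)
```

The clipping is O(1) per sample, so the per-step overhead is just one comparison pair on the scalar dloss.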

@pprett
Member

pprett commented Nov 24, 2014

true -- we should experiment with this -- it's a major annoyance during grid search

@agramfort
Member

We could add a param max_grad? Or clip_grad?

@ogrisel
Member

ogrisel commented Nov 24, 2014

true -- we should experiment with this -- it's a major annoyance during grid search

I am hacking my copy of sgd_fast to clip dloss to [-100, 100], it helps for some cases on @worldveil's script but not all: the huber loss case is still unstable. Needs more investigation.

@ogrisel
Member

ogrisel commented Nov 24, 2014

Note that the problem disappears when clipping dloss to [-100, 100] and preventing alpha from being larger than 10.

@amueller
Member

Do you need to check the gradient in every step then? Doesn't that impact performance quite a bit?

@jnothman
Member

true -- we should experiment with this -- it's a major annoyance during grid search

But at least in the dev version it's possible to ask grid search to catch the error and return a 0 score...
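For reference, that grid-search behaviour is exposed through the `error_score` parameter. A sketch using the modern import path (`sklearn.model_selection`; at the time of this thread the class lived in `sklearn.grid_search`):

```python
# Let grid search swallow fit failures: parameter combinations that raise
# during fit score `error_score` (here 0) instead of aborting the search.
from sklearn.datasets import load_iris
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import GridSearchCV

iris = load_iris()
param_grid = {
    "alpha": [0.001, 1.0, 1000.0],
    "loss": ["hinge", "squared_hinge"],
}
search = GridSearchCV(
    SGDClassifier(max_iter=25),
    param_grid,
    cv=5,
    error_score=0,  # failing fits count as score 0 rather than raising
)
search.fit(iris.data, iris.target)
print(search.best_params_)
```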

@pprett
Member

pprett commented Nov 24, 2014

But at least in the dev version it's possible to ask grid search to catch the error and return a 0 score...

completely missed that -- that's great!


@ogrisel
Member

ogrisel commented Nov 25, 2014

#3883 has an implementation that seems to work. Please have a look. I have not yet had time to run benchmarks to measure the computational overhead vs. master.

@ogrisel
Member

ogrisel commented Nov 25, 2014

@worldveil #3883 seems to fix all the problems you reported. Please feel free to test on more cases on your own data and report any remaining issues.

@larsmans
Member

Should be fixed by f5e0ea0, closing.
