SGD module renaming #25

Closed
wants to merge 3 commits into from
8 changes: 4 additions & 4 deletions doc/modules/classes.rst
@@ -80,8 +80,8 @@ Stochastic Gradient Descent
:toctree: generated/
:template: class.rst

sgd.ClassifierSGD
sgd.RegressorSGD
sgd.SGDClassifier
sgd.SGDRegressor

For sparse data
---------------
@@ -90,8 +90,8 @@ For sparse data
:toctree: generated/
:template: class.rst

sgd.sparse.ClassifierSGD
sgd.sparse.RegressorSGD
sgd.sparse.SGDClassifier
sgd.sparse.SGDRegressor


Bayesian Regression
26 changes: 13 additions & 13 deletions doc/modules/sgd.rst
@@ -2,7 +2,7 @@
Stochastic Gradient Descent
===========================

.. currentmodule:: scikits.learn.sgd
.. currentmodule:: scikits.learn.linear_model.stochastic_gradient

**Stochastic Gradient Descent (SGD)** is a simple yet very efficient approach
to discriminative learning of linear classifiers under convex loss functions
@@ -33,7 +33,7 @@ The disadvantages of Stochastic Gradient Descent include:
Classification
==============

The class :class:`ClassifierSGD` implements a plain stochastic gradient descent
The class :class:`SGDClassifier` implements a plain stochastic gradient descent
learning routine which supports different loss functions and penalties for
classification.

@@ -47,12 +47,12 @@ an array X of size [n_samples, n_features] holding the training
samples, and an array Y of size [n_samples] holding the target values
(class labels) for the training samples::

>>> from scikits.learn import sgd
>>> from scikits.learn.linear_model import stochastic_gradient
>>> X = [[0., 0.], [1., 1.]]
>>> y = [0, 1]
>>> clf = sgd.ClassifierSGD(loss="hinge", penalty="l2")
>>> clf = stochastic_gradient.SGDClassifier(loss="hinge", penalty="l2")
>>> clf.fit(X, y)
ClassifierSGD(loss='hinge', n_jobs=1, shuffle=False, verbose=0, n_iter=5,
SGDClassifier(loss='hinge', n_jobs=1, shuffle=False, verbose=0, n_iter=5,
fit_intercept=True, penalty='l2', rho=1.0, alpha=0.0001)

After being fitted, the model can then be used to predict new values::
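
>>> clf.predict([[2., 2.]])  # illustrative sketch; the original output lines are elided from this diff view
array([1])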
@@ -81,7 +81,7 @@ To get the signed distance to the hyperplane use `decision_function`:

.. warning:: Make sure you permute (shuffle) your training data before fitting the model, or use `shuffle=True` to shuffle after each iteration.

The concrete loss function can be set via the `loss` parameter. :class:`ClassifierSGD` supports the
The concrete loss function can be set via the `loss` parameter. :class:`SGDClassifier` supports the
following loss functions:

- `loss="hinge"`: (soft-margin) linear Support Vector Machine.
@@ -95,7 +95,7 @@ Log loss, on the other hand, provides probability estimates.
In the case of binary classification and `loss="log"` you get a probability
estimate P(y=C|x) using `predict_proba`, where `C` is the largest class label:

>>> clf = sgd.ClassifierSGD(loss="log").fit(X, y)
>>> clf = stochastic_gradient.SGDClassifier(loss="log").fit(X, y)
>>> clf.predict_proba([[1., 1.]])
array([ 0.99999949])

@@ -111,7 +111,7 @@ driving most coefficients to zero. The Elastic Net solves some deficiencies of
the L1 penalty in the presence of highly correlated attributes. The parameter `rho`
has to be specified by the user.
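
A minimal sketch of selecting the Elastic Net penalty (the parameter values shown are illustrative only, not prescribed by this patch)::

>>> clf = stochastic_gradient.SGDClassifier(penalty="elasticnet", rho=0.85)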

:class:`ClassifierSGD` supports multi-class classification by combining multiple
:class:`SGDClassifier` supports multi-class classification by combining multiple
binary classifiers in a "one versus all" (OVA) scheme. For each of the `K` classes,
a binary classifier is learned that discriminates between that and all other `K-1`
classes. At testing time, we compute the confidence score (i.e. the signed distances
@@ -138,7 +138,7 @@ class; classes are indexed in ascending order (see member `classes`).
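
As an illustrative sketch of a multi-class fit (the iris data is used here only for demonstration and is not part of this changeset)::

>>> from scikits.learn import datasets
>>> iris = datasets.load_iris()
>>> clf = stochastic_gradient.SGDClassifier(alpha=0.001, n_iter=100).fit(iris.data, iris.target)
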
Regression
==========

The class :class:`RegressorSGD` implements a plain stochastic gradient descent learning
The class :class:`SGDRegressor` implements a plain stochastic gradient descent learning
routine which supports different loss functions and penalties to fit linear regression
models.

@@ -147,7 +147,7 @@ models.
:align: center
:scale: 75

The concrete loss function can be set via the `loss` parameter. :class:`RegressorSGD` supports the
The concrete loss function can be set via the `loss` parameter. :class:`SGDRegressor` supports the
following loss functions:

- `loss="squared_loss"`: Ordinary least squares.
@@ -157,7 +157,7 @@ following loss functions:

* :ref:`example_sgd_plot_ols_sgd.py`,
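
A brief usage sketch in the spirit of the example referenced above (the data and parameter values are illustrative only)::

>>> import numpy as np
>>> from scikits.learn.linear_model.stochastic_gradient import SGDRegressor
>>> X = np.linspace(0, 1, 10).reshape(10, 1)
>>> y = 2 * X.ravel()
>>> clf = SGDRegressor(alpha=0.1, n_iter=20).fit(X, y)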

.. currentmodule:: scikits.learn.sgd.sparse
.. currentmodule:: scikits.learn.linear_model.stochastic_gradient.sparse


Stochastic Gradient Descent for sparse data
@@ -173,7 +173,7 @@ For maximum efficiency, use the CSR matrix format as defined in
`scipy.sparse.csr_matrix
<http://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html>`_.

Implemented classes are :class:`ClassifierSGD` and :class:`RegressorSGD`.
Implemented classes are :class:`SGDClassifier` and :class:`SGDRegressor`.
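
A minimal sketch of the sparse variant, assuming the module path introduced by this patch (the tiny CSR matrix is just for illustration)::

>>> from scipy import sparse
>>> from scikits.learn.linear_model.stochastic_gradient.sparse import SGDClassifier
>>> X_csr = sparse.csr_matrix([[0., 0.], [1., 1.]])
>>> clf = SGDClassifier().fit(X_csr, [0, 1])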

.. topic:: Examples:

@@ -273,7 +273,7 @@ optimization problems. In contrast to (batch) gradient descent, SGD
approximates the true gradient of :math:`E(w,b)` by considering a
single training example at a time.

The class :class:`ClassifierSGD` implements a first-order SGD learning routine.
The class :class:`SGDClassifier` implements a first-order SGD learning routine.
The algorithm iterates over the training examples and for each example
updates the model parameters according to the update rule given by

6 changes: 3 additions & 3 deletions examples/document_classification_20newsgroups.py
@@ -42,7 +42,7 @@
from scikits.learn.datasets import load_files
from scikits.learn.feature_extraction.text.sparse import Vectorizer
from scikits.learn.svm.sparse import LinearSVC
from scikits.learn.sgd.sparse import ClassifierSGD
from scikits.learn.linear_model.stochastic_gradient.sparse import SGDClassifier
from scikits.learn import metrics

# parse commandline arguments
@@ -162,11 +162,11 @@ def benchmark(clf):
dual=False, eps=1e-3))

# Train SGD model
sgd_results = benchmark(ClassifierSGD(alpha=.0001, n_iter=50,
sgd_results = benchmark(SGDClassifier(alpha=.0001, n_iter=50,
penalty=penalty))

# Train SGD with Elastic Net penalty
print 80*'='
print "Elastic-Net penalty"
sgd_results = benchmark(ClassifierSGD(alpha=.0001, n_iter=50,
sgd_results = benchmark(SGDClassifier(alpha=.0001, n_iter=50,
penalty="elasticnet"))
4 changes: 2 additions & 2 deletions examples/grid_search_text_feature_extraction.py
@@ -54,7 +54,7 @@
from scikits.learn.datasets import load_files
from scikits.learn.feature_extraction.text.sparse import CountVectorizer
from scikits.learn.feature_extraction.text.sparse import TfidfTransformer
from scikits.learn.sgd.sparse import ClassifierSGD
from scikits.learn.linear_model.stochastic_gradient.sparse import SGDClassifier
from scikits.learn.grid_search import GridSearchCV
from scikits.learn.pipeline import Pipeline

@@ -101,7 +101,7 @@
pipeline = Pipeline([
('vect', CountVectorizer()),
('tfidf', TfidfTransformer()),
('clf', ClassifierSGD()),
('clf', SGDClassifier()),
])

parameters = {
4 changes: 2 additions & 2 deletions examples/mlcomp_sparse_document_classification.py
@@ -46,7 +46,7 @@

from scikits.learn.datasets import load_mlcomp
from scikits.learn.feature_extraction.text.sparse import Vectorizer
from scikits.learn.sgd.sparse import ClassifierSGD
from scikits.learn.linear_model.stochastic_gradient.sparse import SGDClassifier
from scikits.learn.metrics import confusion_matrix
from scikits.learn.metrics import classification_report

@@ -81,7 +81,7 @@
}
print "parameters:", parameters
t0 = time()
clf = ClassifierSGD(**parameters).fit(X_train, y_train)
clf = SGDClassifier(**parameters).fit(X_train, y_train)
print "done in %fs" % (time() - t0)
print "Percentage of non zeros coef: %f" % (np.mean(clf.coef_ != 0) * 100)

2 changes: 1 addition & 1 deletion examples/sgd/README.txt
@@ -3,4 +3,4 @@
Stochastic Gradient Descent
---------------------------

Examples concerning the `scikits.learn.sgd` package.
Examples concerning the `scikits.learn.linear_model.stochastic_gradient` package.
4 changes: 2 additions & 2 deletions examples/sgd/covertype_dense_sgd.py
@@ -56,7 +56,7 @@
import numpy as np

from scikits.learn.svm import LinearSVC
from scikits.learn.sgd import ClassifierSGD
from scikits.learn.linear_model.stochastic_gradient import SGDClassifier
from scikits.learn.naive_bayes import GNB
from scikits.learn import metrics

@@ -167,7 +167,7 @@ def benchmark(clf):
'alpha': 0.001,
'n_iter': 2,
}
sgd_err, sgd_train_time, sgd_test_time = benchmark(ClassifierSGD(
sgd_err, sgd_train_time, sgd_test_time = benchmark(SGDClassifier(
**sgd_parameters))

######################################################################
5 changes: 3 additions & 2 deletions examples/sgd/plot_iris.py
@@ -12,7 +12,8 @@

import numpy as np
import pylab as pl
from scikits.learn import sgd, datasets
from scikits.learn import datasets
from scikits.learn.linear_model.stochastic_gradient import SGDClassifier

# import some data to play with
iris = datasets.load_iris()
@@ -35,7 +36,7 @@

h = .02 # step size in the mesh

clf = sgd.ClassifierSGD(alpha=0.001, n_iter=100).fit(X, y)
clf = SGDClassifier(alpha=0.001, n_iter=100).fit(X, y)

# create a mesh to plot in
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
5 changes: 3 additions & 2 deletions examples/sgd/plot_loss_functions.py
@@ -3,13 +3,14 @@
SGD: Convex Loss Functions
==========================

Plot the convex loss functions supported by `scikits.learn.sgd`.
Plot the convex loss functions supported by `scikits.learn.linear_model.stochastic_gradient`.
"""
print __doc__

import numpy as np
import pylab as pl
from scikits.learn.sgd.sgd_fast import Hinge, ModifiedHuber, SquaredLoss
from scikits.learn.linear_model.stochastic_gradient.sgd_fast import Hinge, \
ModifiedHuber, SquaredLoss

###############################################################################
# Define loss functions
4 changes: 2 additions & 2 deletions examples/sgd/plot_ols_sgd.py
@@ -12,7 +12,7 @@
import numpy as np
import pylab as pl

from scikits.learn import sgd
from scikits.learn.linear_model.stochastic_gradient import SGDRegressor

# this is our test set, it's just a straight line with some
# gaussian noise
@@ -23,7 +23,7 @@
+ np.random.randn(n_samples, 1).ravel()

# run the classifier
clf = sgd.RegressorSGD(alpha=0.1, n_iter=20)
clf = SGDRegressor(alpha=0.1, n_iter=20)
clf.fit(X, Y)

# and plot the result
2 changes: 1 addition & 1 deletion examples/sgd/plot_penalties.py
@@ -3,7 +3,7 @@
SGD: Penalties
==============

Plot the contours of the three penalties supported by `scikits.learn.sgd`.
Plot the contours of the three penalties supported by `scikits.learn.linear_model.stochastic_gradient`.

"""
from __future__ import division
4 changes: 2 additions & 2 deletions examples/sgd/plot_separating_hyperplane.py
@@ -11,15 +11,15 @@

import numpy as np
import pylab as pl
from scikits.learn.sgd import ClassifierSGD
from scikits.learn.linear_model.stochastic_gradient import SGDClassifier

# we create 40 separable points
np.random.seed(0)
X = np.r_[np.random.randn(20, 2) - [2,2], np.random.randn(20, 2) + [2, 2]]
Y = [0]*20 + [1]*20

# fit the model
clf = ClassifierSGD(loss="hinge", alpha = 0.01, n_iter=50,
clf = SGDClassifier(loss="hinge", alpha = 0.01, n_iter=50,
fit_intercept=True)
clf.fit(X, Y)

6 changes: 3 additions & 3 deletions examples/sgd/plot_weighted_classes.py
@@ -11,7 +11,7 @@

import numpy as np
import pylab as pl
from scikits.learn import sgd
from scikits.learn.linear_model.stochastic_gradient import SGDClassifier

# we create 40 separable points
np.random.seed(0)
@@ -29,7 +29,7 @@
X = (X - mean) / std

# fit the model and get the separating hyperplane
clf = sgd.ClassifierSGD(n_iter=100, alpha=0.01)
clf = SGDClassifier(n_iter=100, alpha=0.01)
clf.fit(X, y)

w = clf.coef_
@@ -39,7 +39,7 @@


# get the separating hyperplane using weighted classes
wclf = sgd.ClassifierSGD(n_iter=100, alpha=0.01)
wclf = SGDClassifier(n_iter=100, alpha=0.01)
wclf.fit(X, y, class_weight={1: 10})

ww = wclf.coef_
2 changes: 1 addition & 1 deletion scikits/learn/__init__.py
@@ -35,7 +35,7 @@ def test(self, label='fast', verbose=1, extra_argv=['--exe'],

__all__ = ['cross_val', 'ball_tree', 'cluster', 'covariance', 'datasets', 'gmm',
'linear_model', 'logistic', 'lda', 'metrics', 'svm', 'features', 'clone',
'test', 'sgd', 'gaussian_process']
'test', 'gaussian_process']

__version__ = '0.6.git'

1 change: 1 addition & 0 deletions scikits/learn/linear_model/__init__.py
@@ -19,3 +19,4 @@
from .logistic import LogisticRegression

from . import sparse
from . import stochastic_gradient
1 change: 1 addition & 0 deletions scikits/learn/linear_model/setup.py
@@ -36,6 +36,7 @@ def configuration(parent_package='', top_path=None):
# add other directories
config.add_subpackage('tests')
config.add_subpackage('sparse')
config.add_subpackage('stochastic_gradient')

return config

@@ -6,5 +6,5 @@
"""

from . import sparse
from .sgd import ClassifierSGD, RegressorSGD
from .sgd import SGDClassifier, SGDRegressor
from .base import Log, ModifiedHuber, Hinge, SquaredLoss, Huber
@@ -5,7 +5,7 @@

import numpy as np

from ..base import BaseEstimator, ClassifierMixin, RegressorMixin
from ...base import BaseEstimator, ClassifierMixin, RegressorMixin
from .sgd_fast import Hinge, Log, ModifiedHuber, SquaredLoss, Huber


@@ -46,14 +46,14 @@ def _get_penalty_type(self):
raise ValueError("Penalty %s is not supported. " % self.penalty)


class ClassifierBaseSGD(BaseSGD, ClassifierMixin):
class BaseSGDClassifier(BaseSGD, ClassifierMixin):
"""Base class for dense and sparse classification using SGD.
"""

def __init__(self, loss="hinge", penalty='l2', alpha=0.0001,
rho=0.85, fit_intercept=True, n_iter=5, shuffle=False,
verbose=0, n_jobs=1):
super(ClassifierBaseSGD, self).__init__(loss=loss, penalty=penalty,
super(BaseSGDClassifier, self).__init__(loss=loss, penalty=penalty,
alpha=alpha, rho=rho,
fit_intercept=fit_intercept,
n_iter=n_iter, shuffle=shuffle,
@@ -132,14 +132,14 @@ def predict_proba(self, X):
"this functionality" % self.loss)


class RegressorBaseSGD(BaseSGD, RegressorMixin):
class BaseSGDRegressor(BaseSGD, RegressorMixin):
"""Base class for dense and sparse regression using SGD.
"""
def __init__(self, loss="squared_loss", penalty="l2", alpha=0.0001,
rho=0.85, fit_intercept=True, n_iter=5, shuffle=False,
verbose=0, epsilon=0.1):
self.epsilon=float(epsilon)
super(RegressorBaseSGD, self).__init__(loss=loss, penalty=penalty,
super(BaseSGDRegressor, self).__init__(loss=loss, penalty=penalty,
alpha=alpha, rho=rho,
fit_intercept=fit_intercept,
n_iter=n_iter, shuffle=shuffle,
@@ -5,7 +5,7 @@
def configuration(parent_package='', top_path=None):
from numpy.distutils.misc_util import Configuration
from numpy.distutils.system_info import get_standard_file
config = Configuration('sgd', parent_package, top_path)
config = Configuration('stochastic_gradient', parent_package, top_path)

site_cfg = ConfigParser()
site_cfg.read(get_standard_file('site.cfg'))