Skip to content

Commit

Permalink
FIX add class_weight="balanced_subsample" to the forests to keep backward compatibility to 0.16
Browse files Browse the repository at this point in the history
(full title: keep backward compatibility to 0.16)
  • Loading branch information
amueller committed May 12, 2015
1 parent 4faa782 commit 6f320fe
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 12 deletions.
40 changes: 28 additions & 12 deletions sklearn/ensemble/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ class calls the ``fit`` method of each sub-estimator on random samples

from __future__ import division

import warnings
from warnings import warn

from abc import ABCMeta, abstractmethod

import numpy as np
Expand Down Expand Up @@ -89,6 +91,10 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees,
curr_sample_weight *= sample_counts

if class_weight == 'subsample':
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
curr_sample_weight *= compute_sample_weight('auto', y, indices)
elif class_weight == 'balanced_subsample':
curr_sample_weight *= compute_sample_weight('balanced', y, indices)

tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False)
Expand Down Expand Up @@ -414,30 +420,40 @@ def _validate_y_class_weight(self, y):
self.n_classes_.append(classes_k.shape[0])

if self.class_weight is not None:
valid_presets = ('auto', 'balanced', 'subsample', 'auto')
valid_presets = ('auto', 'balanced', 'balanced_subsample', 'subsample', 'auto')
if isinstance(self.class_weight, six.string_types):
if self.class_weight not in valid_presets:
raise ValueError('Valid presets for class_weight include '
'"balanced" and "subsample". Given "%s".'
'"balanced" and "balanced_subsample". Given "%s".'
% self.class_weight)
if self.class_weight == "subsample":
warn("class_weight='subsample' is deprecated and will be removed in 0.18."
" It was replaced by class_weight='balanced_subsample' "
"using the balanced strategy.", DeprecationWarning)
if self.warm_start:
warn('class_weight presets "balanced" or "subsample" are '
warn('class_weight presets "balanced" or "balanced_subsample" are '
'not recommended for warm_start if the fitted data '
'differs from the full dataset. In order to use '
'"auto" weights, use compute_class_weight("balanced", '
'"balanced" weights, use compute_class_weight("balanced", '
'classes, y). In place of y you can use a large '
'enough sample of the full training set target to '
'properly estimate the class frequency '
'distributions. Pass the resulting weights as the '
'class_weight parameter.')

if self.class_weight != 'subsample' or not self.bootstrap:
if (self.class_weight not in ['subsample', 'balanced_subsample'] or
not self.bootstrap):
if self.class_weight == 'subsample':
class_weight = 'balanced'
class_weight = 'auto'
elif self.class_weight == "balanced_subsample":
class_weight = "balanced"
else:
class_weight = self.class_weight
expanded_class_weight = compute_sample_weight(class_weight,
y_original)
with warnings.catch_warnings():
if class_weight == "auto":
warnings.simplefilter('ignore', DeprecationWarning)
expanded_class_weight = compute_sample_weight(class_weight,
y_original)

return y, expanded_class_weight

Expand Down Expand Up @@ -758,7 +774,7 @@ class RandomForestClassifier(ForestClassifier):
and add more estimators to the ensemble, otherwise, just fit a whole
new forest.
class_weight : dict, list of dicts, "balanced", "subsample" or None, optional
class_weight : dict, list of dicts, "balanced", "balanced_subsample" or None, optional
Weights associated with classes in the form ``{class_label: weight}``.
If not given, all classes are supposed to have weight one. For
Expand All @@ -769,7 +785,7 @@ class RandomForestClassifier(ForestClassifier):
weights inversely proportional to class frequencies in the input data
as ``n_samples / (n_classes * np.bincount(y))``
The "subsample" mode is the same as "balanced" except that weights are
The "balanced_subsample" mode is the same as "balanced" except that weights are
computed based on the bootstrap sample for every tree grown.
For multi-output, the weights of each column of y will be multiplied.
Expand Down Expand Up @@ -1101,7 +1117,7 @@ class ExtraTreesClassifier(ForestClassifier):
and add more estimators to the ensemble, otherwise, just fit a whole
new forest.
class_weight : dict, list of dicts, "balanced", "subsample" or None, optional
class_weight : dict, list of dicts, "balanced", "balanced_subsample" or None, optional
Weights associated with classes in the form ``{class_label: weight}``.
If not given, all classes are supposed to have weight one. For
Expand All @@ -1112,7 +1128,7 @@ class ExtraTreesClassifier(ForestClassifier):
weights inversely proportional to class frequencies in the input data
as ``n_samples / (n_classes * np.bincount(y))``
The "subsample" mode is the same as "balanced" except that weights are
The "balanced_subsample" mode is the same as "balanced" except that weights are
computed based on the bootstrap sample for every tree grown.
For multi-output, the weights of each column of y will be multiplied.
Expand Down
5 changes: 5 additions & 0 deletions sklearn/ensemble/tests/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from sklearn.utils.testing import assert_greater_equal
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_warns
from sklearn.utils.testing import assert_warns_message
from sklearn.utils.testing import ignore_warnings

from sklearn import datasets
Expand Down Expand Up @@ -802,7 +803,11 @@ def check_class_weight_balanced_and_bootstrap_multi_output(name):
clf = ForestClassifier(class_weight=[{-1: 0.5, 1: 1.}, {-2: 1., 2: 1.}],
random_state=0)
clf.fit(X, _y)
# smoke test for subsample and balanced subsample
clf = ForestClassifier(class_weight='balanced_subsample', random_state=0)
clf.fit(X, _y)
clf = ForestClassifier(class_weight='subsample', random_state=0)
#assert_warns_message(DeprecationWarning, "balanced_subsample", clf.fit, X, _y)
clf.fit(X, _y)


Expand Down

0 comments on commit 6f320fe

Please sign in to comment.