Skip to content

Commit

Permalink
Use more natural class_weight="auto" heuristic
Browse files Browse the repository at this point in the history
  • Loading branch information
amueller committed Mar 5, 2015
1 parent 8dbe3f8 commit 518ca64
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
5 changes: 3 additions & 2 deletions sklearn/utils/class_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ def compute_class_weight(class_weight, classes, y):
raise ValueError("classes should have valid labels that are in y")

# inversely proportional to the number of samples in the class
recip_freq = 1. / bincount(y_ind)
weight = recip_freq[le.transform(classes)] / np.mean(recip_freq)
recip_freq = len(y) / (len(le.classes_) *
bincount(y_ind).astype(np.float64))
weight = recip_freq[le.transform(classes)]
else:
# user-defined dictionary
weight = np.ones(classes.shape[0], dtype=np.float64, order='C')
Expand Down
5 changes: 2 additions & 3 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,10 +905,9 @@ def check_class_weight_auto_linear_classifier(name, Classifier):
coef_auto = classifier.fit(X, y).coef_.copy()

# Count each label occurrence to reweight manually
mean_weight = (1. / 3 + 1. / 2) / 2
class_weight = {
1: 1. / 3 / mean_weight,
-1: 1. / 2 / mean_weight,
1: 5. / (2 * 3),
-1: 5. / (2 * 2)
}
classifier.set_params(class_weight=class_weight)
coef_manual = classifier.fit(X, y).coef_.copy()
Expand Down

0 comments on commit 518ca64

Please sign in to comment.