Skip to content

Commit

Permalink
Test multilabel classifier on random dataset.
Browse files Browse the repository at this point in the history
  • Loading branch information
mblondel committed Dec 20, 2011
1 parent 1648c21 commit 117be7e
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions sklearn/tests/test_multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from numpy.testing import assert_array_almost_equal
from numpy.testing import assert_array_equal
from nose.tools import assert_equal
from nose.tools import assert_almost_equal
from nose.tools import assert_true
from nose.tools import assert_raises

Expand All @@ -23,6 +24,30 @@
n_classes = 3


# FIXME: - should use sets
# - should move to metrics module
def multilabel_precision(Y_true, Y_pred):
n_predictions = 0
n_correct = 0
for i in range(len(Y_true)):
n_predictions += len(Y_pred[i])
for label in Y_pred[i]:
if label in Y_true[i]:
n_correct += 1
return float(n_correct) / n_predictions


def multilabel_recall(Y_true, Y_pred):
n_labels = 0
n_correct = 0
for i in range(len(Y_true)):
n_labels += len(Y_true[i])
for label in Y_pred[i]:
if label in Y_true[i]:
n_correct += 1
return float(n_correct) / n_labels


def test_ovr_exceptions():
ovr = OneVsRestClassifier(LinearSVC())
assert_raises(ValueError, ovr.predict, [])
Expand Down Expand Up @@ -64,6 +89,22 @@ def test_ovr_multilabel():
assert_array_equal(y_pred, [0, 1, 1])


def test_ovr_multilabel_dataset():
base_clf = MultinomialNB()
X, Y = datasets.make_multilabel_classification(n_samples=100,
n_features=20,
n_classes=5,
n_labels=2,
length=50,
random_state=0)
X_train, Y_train = X[:80], Y[:80]
X_test, Y_test = X[80:], Y[80:]
clf = OneVsRestClassifier(base_clf).fit(X_train, Y_train)
Y_pred = clf.predict(X_test)
assert_almost_equal(multilabel_precision(Y_test, Y_pred), 0.80, places=2)
assert_almost_equal(multilabel_recall(Y_test, Y_pred), 0.80, places=2)


def test_ovr_gridsearch():
ovr = OneVsRestClassifier(LinearSVC())
Cs = [0.1, 0.5, 0.8]
Expand Down

0 comments on commit 117be7e

Please sign in to comment.