Skip to content

Commit

Permalink
adjust some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
amueller committed Mar 9, 2015
1 parent 74fde88 commit b2cdc55
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions sklearn/utils/tests/test_class_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ def test_compute_class_weight():
y = np.asarray([2, 2, 2, 3, 3, 4])
classes = np.unique(y)
cw = compute_class_weight("auto", classes, y)
assert_almost_equal(cw.sum(), classes.shape)
class_counts = np.bincount(y)[2:]
# total effect of samples is preserved
assert_almost_equal(np.dot(cw, class_counts), y.shape[0])
assert_true(cw[0] < cw[1] < cw[2])


Expand Down Expand Up @@ -63,19 +65,21 @@ def test_compute_class_weight_auto_negative():
# Test with unbalanced class labels.
y = np.asarray([-1, 0, 0, -2, -2, -2])
cw = compute_class_weight("auto", classes, y)
assert_almost_equal(cw.sum(), classes.shape)
class_counts = np.bincount(y + 2)
assert_almost_equal(np.dot(cw, class_counts), y.shape[0])
assert_equal(len(cw), len(classes))
assert_array_almost_equal(cw, np.array([0.545, 1.636, 0.818]), decimal=3)
assert_array_almost_equal(cw, [2. / 3, 2., 1.])


def test_compute_class_weight_auto_unordered():
"""Test compute_class_weight when classes are unordered"""
classes = np.array([1, 0, 3])
y = np.asarray([1, 0, 0, 3, 3, 3])
cw = compute_class_weight("auto", classes, y)
assert_almost_equal(cw.sum(), classes.shape)
class_counts = np.bincount(y)[classes]
assert_almost_equal(np.dot(cw, class_counts), y.shape[0])
assert_equal(len(cw), len(classes))
assert_array_almost_equal(cw, np.array([1.636, 0.818, 0.545]), decimal=3)
assert_array_almost_equal(cw, [2., 1., 2. / 3])


def test_compute_sample_weight():
Expand All @@ -97,8 +101,8 @@ def test_compute_sample_weight():
# Test with unbalanced classes
y = np.asarray([1, 1, 1, 2, 2, 2, 3])
sample_weight = compute_sample_weight("auto", y)
expected = np.asarray([.6, .6, .6, .6, .6, .6, 1.8])
assert_array_almost_equal(sample_weight, expected)
expected = np.array([0.7777, 0.7777, 0.7777, 0.7777, 0.7777, 0.7777, 2.3333])
assert_array_almost_equal(sample_weight, expected, decimal=4)

# Test with `None` weights
sample_weight = compute_sample_weight(None, y)
Expand All @@ -117,7 +121,7 @@ def test_compute_sample_weight():
# Test with multi-output of unbalanced classes
y = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1], [3, -1]])
sample_weight = compute_sample_weight("auto", y)
assert_array_almost_equal(sample_weight, expected ** 2)
assert_array_almost_equal(sample_weight, expected ** 2, decimal=3)


def test_compute_sample_weight_with_subsample():
Expand All @@ -135,12 +139,13 @@ def test_compute_sample_weight_with_subsample():
# Test with a subsample
y = np.asarray([1, 1, 1, 2, 2, 2])
sample_weight = compute_sample_weight("auto", y, range(4))
assert_array_almost_equal(sample_weight, [.5, .5, .5, 1.5, 1.5, 1.5])
assert_array_almost_equal(sample_weight, [2. / 3, 2. / 3,
2. / 3, 2., 2., 2.])

# Test with a bootstrap subsample
y = np.asarray([1, 1, 1, 2, 2, 2])
sample_weight = compute_sample_weight("auto", y, [0, 1, 1, 2, 2, 3])
expected = np.asarray([1 / 3., 1 / 3., 1 / 3., 5 / 3., 5 / 3., 5 / 3.])
expected = np.asarray([0.6, 0.6, 0.6, 3., 3., 3.])
assert_array_almost_equal(sample_weight, expected)

# Test with a bootstrap subsample for multi-output
Expand Down

0 comments on commit b2cdc55

Please sign in to comment.