
[MRG + 1] Add class_weight to PA Classifier, remove from PA Regressor #4767

Merged

Conversation

trevorstephens
Contributor

Was browsing Landscape.io and noticed this strange one: class_weight is a zombie param for the PassiveAggressiveRegressor, yet not present for the PassiveAggressiveClassifier O_o

Removed it from the regressor (didn't think a deprecation cycle was necessary, since the parameter was never used and makes no sense for regression anyway), and implemented it for the classifier, with some tests.
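For readers skimming the thread, a minimal sketch of what the new classifier parameter enables (illustrative data and settings, not code from the diff):

from sklearn.linear_model import PassiveAggressiveClassifier

X = [[0.0, 0.1], [0.1, 0.0], [0.9, 1.0], [1.0, 0.9]]
y = [0, 0, 1, 1]

# 'balanced' reweights classes inversely to their frequency; an explicit
# mapping such as {0: 2.0, 1: 1.0} also works.
clf = PassiveAggressiveClassifier(class_weight='balanced', random_state=0)
clf.fit(X, y)
print(clf.predict([[0.05, 0.05]]))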

@@ -125,6 +126,77 @@ def test_classifier_undefined_methods():
assert_raises(AttributeError, lambda x: getattr(clf, x), meth)


def test_class_weights():
# Test class weights.
Member

is this taken from the SGDClassifier tests?

Contributor Author

Most of it, with some mods here and there. Some others might have been adapted from the decision tree tests, if I recall correctly.
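For context, this is roughly the shape of the SGD-style class-weight test being adapted (data, seed, and assertions here are illustrative, not quoted from the PR; n_iter is the API name of the era, later renamed max_iter):

import numpy as np
from numpy.testing import assert_array_equal
from sklearn.linear_model import PassiveAggressiveClassifier

X = np.array([[-1.0, -1.0], [-1.0, 0.0], [-0.8, -1.0],
              [1.0, 1.0], [1.0, 0.0]])
y = [1, 1, 1, -1, -1]

# With no class weighting, the probe point lands on class 1...
clf = PassiveAggressiveClassifier(C=0.1, n_iter=100, class_weight=None,
                                  random_state=100)
clf.fit(X, y)
assert_array_equal(clf.predict([[0.2, -1.0]]), np.array([1]))

# ...while a tiny weight on class 1 rotates the hyperplane and flips
# the prediction on that same point.
clf = PassiveAggressiveClassifier(C=0.1, n_iter=100, class_weight={1: 0.001},
                                  random_state=100)
clf.fit(X, y)
assert_array_equal(clf.predict([[0.2, -1.0]]), np.array([-1]))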

Member

It would be great to reuse the SGD tests for PA (since they share the implementation). That would be more work, though, so I don't think this should be a showstopper for this PR. @amueller what do you think?

Contributor Author

@vene, there are certainly things that may need to be factored out into common tests for the class/sample weight classifiers; that seems like a different PR to me, though... ping @amueller ?

Contributor Author

e.g., here: #4838 (comment)

@amueller
Member

LGTM. Removing it from the regressor seems fine to me, as silently ignoring it was a bug, and breaking there seems OK.

@amueller amueller changed the title [MRG] Add class_weight to PA Classifier, remove from PA Regressor [MRG + 1] Add class_weight to PA Classifier, remove from PA Regressor May 27, 2015
@trevorstephens
Contributor Author

Note that I need to rebase this on top of the newly merged #4347

@trevorstephens
Contributor Author

Changed the class_weight preset from 'auto' to 'balanced' following the merge of #4347. @amueller, assuming your +1 stands; anyone else feel like taking a look?

@amueller
Member

amueller commented Jun 4, 2015

yeah still lgtm.

@trevorstephens
Contributor Author

@amueller ... While I'm at it, should I also add a sample_weight parameter to both PA-Reg/Cls? I'm reading over the reference literature: http://jmlr.csail.mit.edu/papers/volume7/crammer06a/crammer06a.pdf and it states to initialize all sample weights to zero, which wouldn't work with our SGD implementation... Not sure now if the whole class_weight thing is straying too far from the literature, as it isn't mentioned anywhere in there, or if consistency with the rest of the API is a valid reason to implement it. 😕 Your thoughts?

raise ValueError("class_weight 'balanced' is not supported for "
"partial_fit. In order to use 'balanced' "
"weights, use "
"compute_class_weight('balanced', classes, y). "
Member

Should this error message indicate that compute_class_weight can be found in sklearn.utils?

Contributor Author

Yes, that would be clearer; I'll update it. I'd simply copied it from SGD's partial_fit message. I can also update the other instances where this message appears around the code base in a separate PR.
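For reference, a rough sketch of the workaround the updated message points to (illustrative, not the PR's code; compute_class_weight lives in sklearn.utils as discussed above):

import numpy as np
from sklearn.utils import compute_class_weight
from sklearn.linear_model import PassiveAggressiveClassifier

# Stand-in for a large-enough subset of the stream's labels.
y = np.array([0, 0, 0, 1])
classes = np.unique(y)

# 'balanced' weights come out to n_samples / (n_classes * np.bincount(y)).
weights = compute_class_weight(class_weight='balanced', classes=classes, y=y)
class_weight = dict(zip(classes, weights))

X = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
clf = PassiveAggressiveClassifier(class_weight=class_weight)
clf.partial_fit(X, y, classes=classes)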

@vene
Member

vene commented Jun 4, 2015

it states to initialize all sample weights to zero

I skimmed through it and couldn't find this, could you point it out? They do talk about initializing the linear model's weight vector (the coefficients) to 0. Is there something I missed in the cost-sensitive learning section?

I thought sample weights (and class weights) are a generally applicable concept, which is why people don't really mention them. But I'd leave sample weights for a different PR.
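For reference, a rough sketch of the PA-I classification update from the paper (the zero initialization applies to the coefficient vector w, as noted above; names here are illustrative):

import numpy as np

def pa1_step(w, x, y, C=1.0):
    """One PA-I update for a labeled example (x, y), with y in {-1, +1}."""
    # Hinge loss of the current model; w itself starts at the zero vector.
    loss = max(0.0, 1.0 - y * np.dot(w, x))
    # Step size: the loss scaled by ||x||^2, capped at the aggressiveness C.
    tau = min(C, loss / np.dot(x, x))
    return w + tau * y * x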

@amueller
Member

amueller commented Jun 4, 2015

I thought sample weights (and class weights) are a generally applicable concept, which is why people don't really mention them. But I'd leave sample weights for a different PR.

+1 on both accounts

@trevorstephens
Contributor Author

I skimmed through it and couldn't find this, could you point it out? They do talk about initializing the linear model's weight vector (the coefficients) to 0. Is there something I missed in the cost-sensitive learning section?

Ah yes, I think I skimmed through it too fast :/

@trevorstephens
Contributor Author

OK, error message updated and commits squashed.

@trevorstephens
Contributor Author

@vene does that message look better to you?

"partial_fit. In order to use 'balanced' "
"weights, from the sklearn.utils module use "
"compute_class_weight('balanced', classes, y). "
"In place of y you can us a large enough sample "
Member

you can us -> you can use
I'd say "subset" instead of "sample", because it's less ambiguous. (consider n_samples).

Contributor Author

  1. Ha, yeah OK. Cut-and-paste job, shall fix. And as I mentioned, this error message exists in several other places; once we get this settled I'll propagate the fix around the code base.
  2. Fair call, shall also change that.

@vene
Member

vene commented Jun 6, 2015

I really like the new suggestion of using a subset of the labels to estimate class priors.

rebase on top of scikit-learn#4347

improve error message

update error msg
@trevorstephens
Contributor Author

@vene, I think I have all your comments incorporated.

@trevorstephens
Contributor Author

ping @vene, does this look good to you now?

assert_array_equal(clf.predict([[0.2, -1.0]]), np.array([1]))

# we give a small weight to class 1
clf = PassiveAggressiveClassifier(C=0.1, n_iter=100,
Member

I'm not entirely convinced it's better, but I can see some reasons for just doing clf.set_params(class_weight={1: 0.001}) here. It makes it explicit that the rest of the parameters shouldn't be changed, in case somebody modifies the test in the future.

Contributor Author

Agree it might be slightly clearer @vene, but git grep shows this paradigm only very rarely in other tests... Do you think it's necessary for merge?

Member

It's not important, it just seems slightly better to me from a maintenance point of view.

Member

I don't have a strong opinion on this. Either way would be fine.
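For what it's worth, the two styles under discussion side by side (values echo the diff above; n_iter is the API name of the era, later renamed max_iter):

from sklearn.linear_model import PassiveAggressiveClassifier

# Style 1: restate every hyperparameter when the test builds a new estimator.
clf = PassiveAggressiveClassifier(C=0.1, n_iter=100, class_weight={1: 0.001})

# Style 2 (the suggestion above): keep the estimator and mutate only the
# parameter under test, so the remaining settings can't silently drift.
clf.set_params(class_weight={1: 0.001})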

@GaelVaroquaux
Member

I reviewed the code. The changes are good. There are some pending comments, but they're minor details, and merging right now seems the best way to provide value to users.

Given that there is already a 👍 I am merging.

GaelVaroquaux added a commit that referenced this pull request Aug 30, 2015
[MRG + 1] Add class_weight to PA Classifier, remove from PA Regressor
@GaelVaroquaux GaelVaroquaux merged commit 6e735f5 into scikit-learn:master Aug 30, 2015
@trevorstephens trevorstephens deleted the passive-aggressive_cw branch August 30, 2015 15:31
@trevorstephens
Contributor Author

Thanks for the reviews all!

@GaelVaroquaux
Member

GaelVaroquaux commented Aug 30, 2015 via email
