
[bnb] Fix bnb skip modules #24043

Merged

Conversation

@younesbelkada (Contributor) commented Jun 6, 2023

What does this PR do?

Fixes #24037

#23479 mistakenly removed the logic introduced in #21579 for handling modules that should not be converted.

The PR also adds a test to make sure this never happens again.
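For context, a minimal sketch of the usage this PR restores (the model name and skipped module are illustrative, and device_map="auto" assumes an available GPU):

from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig

# Load in 8-bit, but keep the classification head un-quantized.
# "classifier" is an illustrative module name to skip.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=["classifier"],
)

model = AutoModelForSequenceClassification.from_pretrained(
    "roberta-large-mnli",
    quantization_config=quantization_config,
    device_map="auto",
)

# Skipped modules should stay as plain nn.Linear ...
print(type(model.classifier.dense))
# ... while other linear layers should be converted to bitsandbytes Linear8bitLt.
print(type(model.roberta.encoder.layer[0].output.dense))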

@younesbelkada changed the title from "[bnb] Fix bnbskip modules" to "[bnb] Fix bnb skip modules" on Jun 6, 2023
@HuggingFaceDocBuilderDev commented Jun 6, 2023

The documentation is not available anymore as the PR was closed or merged.

@amyeroberts (Collaborator) left a comment


Thanks for fixing!

)
self.assertTrue(isinstance(seq_classification_model.classifier.dense, nn.Linear))
self.assertTrue(isinstance(seq_classification_model.classifier.out_proj, nn.Linear))

Collaborator

We should also check that at least one other layer not in llm_int8_skip_modules is loaded in 8-bit. Ideally one which will effectively check the recursion logic.
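For illustration, such an assertion might look like the following (a sketch; the exact nested module path is an assumption based on the RoBERTa architecture):

import bitsandbytes as bnb

# A layer nested deep inside the encoder and not listed in llm_int8_skip_modules
# should have been converted to Linear8bitLt; asserting on a nested module also
# exercises the recursion logic.
self.assertTrue(
    isinstance(
        seq_classification_model.roberta.encoder.layer[0].output.dense,
        bnb.nn.Linear8bitLt,
    )
)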

Contributor Author

Awesome, yes, agreed! Will add that now.

seq_classification_model = AutoModelForSequenceClassification.from_pretrained(
    "roberta-large-mnli", quantization_config=quantization_config
)
self.assertTrue(isinstance(seq_classification_model.classifier.dense, nn.Linear))
Collaborator

Just for my own understanding (not a comment to address): here we're checking that the layers of the classifier are nn.Linear. In test_linear_are_8bit, we also check that the layers are nn.Linear and that their dtype is torch.int8 (I didn't know this was possible!). Are we certain that this means these layers are loaded correctly? Do we need a dtype check on the weights?

Contributor Author

You are right, we also need a dtype check on the weights! Linear8bitLt has nn.Linear as a superclass. Adding new tests!
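A possible shape for that extra check (a sketch, with illustrative module paths):

# Linear8bitLt subclasses nn.Linear, so an isinstance check alone cannot tell a
# skipped floating-point layer from a quantized one; assert on the weight dtype too.
self.assertEqual(
    seq_classification_model.roberta.encoder.layer[0].output.dense.weight.dtype,
    torch.int8,
)
self.assertNotEqual(
    seq_classification_model.classifier.dense.weight.dtype,
    torch.int8,
)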

@younesbelkada younesbelkada merged commit 4795219 into huggingface:main Jun 7, 2023
22 checks passed
@younesbelkada younesbelkada deleted the fix-bnb-skip-llm-modules branch June 7, 2023 13:27
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
* fix skip modules test

* oops

* address comments
Successfully merging this pull request may close these issues.

BitsAndBytesConfig llm_int8_skip_modules does not work in the new version