Cast to fp32 if using bf16 weights on cpu during merge_and_unload #1978

Merged: 2 commits merged into huggingface:main on Jul 31, 2024

Conversation

@snarayan21 (Contributor)

Should address #1977

@snarayan21 (Contributor, Author)

@BenjaminBossan Our team found this bug through LoRA runs that were hanging for a very long time on certain CPU types, and we addressed it with this simple fix. It would be great if you could take a look. Thanks!

@BenjaminBossan (Member)

Thanks for working on this fix. I think this should be okay to merge, but I would like to check if I can replicate the slowness. Do you have an example that I could check?

@snarayan21 (Contributor, Author) commented Jul 30, 2024

It's dependent on the CPU, but instantiating any LoraModel, converting it to fp16 (should be fast) and then to bf16 (should be slow), and calling model.merge_and_unload() should do the trick... @BenjaminBossan

@snarayan21 (Contributor, Author) commented Jul 30, 2024

Actually @BenjaminBossan, here's a script which worked for me. Running locally on my Mac (current peft v0.12.0), I got:

Merge and unload with dtype torch.float32 took 0.13617515563964844 seconds
Merge and unload with dtype torch.bfloat16 took 8.84065294265747 seconds
Merge and unload with dtype torch.float16 took 0.15977191925048828 seconds

Clearly, bf16 takes way longer due to the lack of fast bf16 matmul support on many CPUs.

Script:

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
import time
import torch

model_id = "facebook/opt-350m"


def merge_adapters(dtype):
    # Load the base model and attach a fresh LoRA adapter.
    model = AutoModelForCausalLM.from_pretrained(model_id)
    config = LoraConfig(r=256)
    model = get_peft_model(model, config)

    # Cast both the base and adapter weights to the target dtype.
    model = model.to(dtype=dtype)

    # Time how long merging the adapter into the base weights takes.
    start = time.time()
    model = model.merge_and_unload()
    end = time.time()

    print(f'Merge and unload with dtype {dtype} took {end - start} seconds')

if __name__ == '__main__':
    merge_adapters(torch.float32)
    merge_adapters(torch.bfloat16)
    merge_adapters(torch.float16)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan (Member)

Thanks for providing the example. On one machine I could not see any slowdown with bf16, but on the other it was roughly 100x slower.

Regarding the comment above the code:

        # In case users wants to merge the adapter weights that are in
        # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
        # (b)float16 because the `@` and matmul operation in general is not supported in torch + cpu + fp16/bf16.
        cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)

it is not quite correct anymore, right? Originally, we had to cast to fp32 because there was an error with fp16 on CPU (which I think is fixed in newer PyTorch versions). Would it make sense to add a comment about accelerating the operation on some CPUs?
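
For context, here is a minimal sketch of the cast-to-fp32 pattern being discussed, not the actual peft implementation; the helper name, argument names, and the scaled B @ A update are illustrative assumptions:

import torch

def merge_lora_weight(base_weight, lora_A, lora_B, scaling):
    # Illustrative only: on CPU, fp16/bf16 matmuls are unsupported or slow,
    # so upcast the adapter weights to fp32, merge, then cast back.
    device, dtype = base_weight.device, base_weight.dtype
    cast_to_fp32 = device.type == "cpu" and dtype in (torch.float16, torch.bfloat16)
    if cast_to_fp32:
        lora_A, lora_B = lora_A.float(), lora_B.float()
    delta = (lora_B @ lora_A) * scaling
    if cast_to_fp32:
        delta = delta.to(dtype)
    return base_weight + delta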

@snarayan21 (Contributor, Author)

Sure, let me change that.

@snarayan21 (Contributor, Author)

@BenjaminBossan updated!

@BenjaminBossan (Member) left a comment

Thanks for this improvement, LGTM.

@BenjaminBossan merged commit 52a4ac9 into huggingface:main on Jul 31, 2024. 14 checks passed.
@snarayan21 (Contributor, Author)

@BenjaminBossan when will the next release be?

@BenjaminBossan (Member)

> when will the next release be?

Sorry, no release soon, as we just had a release last week. You could install from main (optionally pinning the commit hash) if you need this right now.
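
For example, assuming a standard pip VCS install (the ref to pin is up to you, e.g. the merge commit above):

pip install git+https://github.com/huggingface/peft.git
# or pinned to a specific commit:
pip install git+https://github.com/huggingface/peft.git@52a4ac9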
