[LoRA] fix cross_attention_kwargs problems and tighten tests #7388
Conversation
Cc: @younesbelkada for viz.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Will also wait for @BenjaminBossan to approve it. And then I will proceed.
Nice catch! Thanks! One could also use `get` to avoid copying the kwargs at each forward!
The problem with
ok makes sense! thanks for explaining!
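(For readers following along, here is a minimal, plain-Python sketch of the two options discussed above. It only illustrates the dictionary mechanics and is not the actual diffusers code.)

```python
# Plain-Python illustration of the two approaches discussed above;
# not the actual diffusers code.
cross_attention_kwargs = {"scale": 0.5}

# `get` reads the value without mutating the dict, so the caller's
# "scale" entry survives later calls.
scale_via_get = cross_attention_kwargs.get("scale", 1.0)
assert "scale" in cross_attention_kwargs

# The approach taken in this PR: shallow-copy first, then pop from the
# copy, which also leaves the caller's dict untouched.
kwargs_copy = cross_attention_kwargs.copy()
scale_via_pop = kwargs_copy.pop("scale", 1.0)
assert "scale" in cross_attention_kwargs
assert scale_via_get == scale_via_pop == 0.5
```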
Thanks for fixing this bug, I think the copy solution is solid.
* debugging
* let's see the numbers
* let's see the numbers
* let's see the numbers
* restrict tolerance.
* increase inference steps.
* shallow copy of cross_attention_kwargs
* remove print
What does this PR do?
First of all, I would like to apologize for not being rigorous enough with #7338. This was actually breaking:
This is because `pop()` removes the requested key from the underlying dictionary on the first call and then uses the default value on every subsequent call. Since the `unet` within a `DiffusionPipeline` is called iteratively, this creates a lot of unexpected consequences. As a result, the above-mentioned test fails. Here are the `lora_scale` values:

Notice how it defaults to 1.0 after the first denoising step.
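To make the failure mode concrete, here is a stripped-down sketch; the `forward_*` functions below are hypothetical stand-ins for the UNet forward pass, not the real diffusers code:

```python
# Hypothetical stand-ins for the UNet forward pass; not the real diffusers code.
def forward_buggy(cross_attention_kwargs):
    # pop() removes "scale" from the caller's dict on the first call,
    # so every later call silently falls back to the default of 1.0.
    return cross_attention_kwargs.pop("scale", 1.0)

def forward_fixed(cross_attention_kwargs):
    # The fix in this PR: shallow-copy before popping so the caller's dict survives.
    kwargs = cross_attention_kwargs.copy()
    return kwargs.pop("scale", 1.0)

kwargs = {"scale": 0.5}
print([forward_buggy(kwargs) for _ in range(3)])  # [0.5, 1.0, 1.0] -> scale is lost
kwargs = {"scale": 0.5}
print([forward_fixed(kwargs) for _ in range(3)])  # [0.5, 0.5, 0.5]
```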
A simple solution is to create a shallow copy of `cross_attention_kwargs` so that the original one is left untouched. This is what this PR does.

Additionally, you may wonder why the following tests PASS:
```bash
pytest tests/lora/test_lora_layers_peft.py -k "test_simple_inference_with_text_unet_lora_and_scale"
```
My best guess is that we use too few `num_inference_steps` to validate things. To see if my hunch was right, I increased `num_inference_steps` to 5 here and ran these tests WITHOUT the changes introduced in this PR (i.e., the shallow copy). All of those tests failed. With the changes, they pass.

Once this PR is merged, I will take care of making another patch release.
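As an aside, here is a toy simulation (plain Python, nothing diffusers-specific, with made-up step math) of why a handful of steps can hide the bug: the error introduced by the wrong scale compounds, so the gap between the correct and the buggy trajectory is small after two steps and noticeably larger after five.

```python
# Toy stand-in for a denoising loop: each step nudges the latent by an amount
# proportional to the LoRA scale. The step math is invented for illustration only.
def run(scales):
    latent = 1.0
    for scale in scales:
        latent -= 0.1 * scale * latent
    return latent

for num_steps in (2, 5):
    correct = run([0.5] * num_steps)              # scale honoured at every step
    buggy = run([0.5] + [1.0] * (num_steps - 1))  # scale silently resets to 1.0
    print(num_steps, abs(correct - buggy))        # the gap roughly triples from 2 to 5 steps
```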
Once again, I am genuinely sorry for the oversight on my end.