support Grouped-Query Attention #1901
Conversation
Thanks for this PR. Do you have an example that uses GQA and currently fails, but would succeed with this change?
I use Qwen2-0.5B, which uses GQA, and tune it with prefix tuning. Running the following test code raises an error.

```python
from peft import PrefixTuningConfig, get_peft_model, TaskType, PeftType
from transformers import Qwen2ForCausalLM
import torch

path = "/public/home/wlchen/twtang/peft-dst/PTM/Qwen2-0.5B"
base = Qwen2ForCausalLM.from_pretrained(path)
print("load base")

peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type=TaskType.CAUSAL_LM)
print(peft_config)
model = get_peft_model(base, peft_config)
print("load peft model")

x = torch.tensor([[1, 2, 3]])
model(x)
```
I found that this error is caused by the following code.

```python
def _prepare_prompt_learning_config(peft_config, model_config):
    if peft_config.num_layers is None:
        if "num_hidden_layers" in model_config:
            num_layers = model_config["num_hidden_layers"]
        elif "num_layers" in model_config:
            num_layers = model_config["num_layers"]
        elif "n_layer" in model_config:
            num_layers = model_config["n_layer"]
        else:
            raise ValueError("Please specify `num_layers` in `peft_config`")
        peft_config.num_layers = num_layers

    if peft_config.token_dim is None:
        if "hidden_size" in model_config:
            token_dim = model_config["hidden_size"]
        elif "n_embd" in model_config:
            token_dim = model_config["n_embd"]
        elif "d_model" in model_config:
            token_dim = model_config["d_model"]
        else:
            raise ValueError("Please specify `token_dim` in `peft_config`")
        peft_config.token_dim = token_dim

    if peft_config.num_attention_heads is None:
        if "num_attention_heads" in model_config:
            num_attention_heads = model_config["num_attention_heads"]
        elif "n_head" in model_config:
            num_attention_heads = model_config["n_head"]
        elif "num_heads" in model_config:
            num_attention_heads = model_config["num_heads"]
        elif "encoder_attention_heads" in model_config:
            num_attention_heads = model_config["encoder_attention_heads"]
        else:
            raise ValueError("Please specify `num_attention_heads` in `peft_config`")
        peft_config.num_attention_heads = num_attention_heads

    if getattr(peft_config, "encoder_hidden_size", None) is None:
        setattr(peft_config, "encoder_hidden_size", peft_config.token_dim)

    return peft_config
```

If the model uses GQA, the number of key and value heads is different from the number of query heads. I added the following code to correct this error, and the test code works now.

```python
if "num_key_value_heads" in model_config:
    num_key_value_heads = model_config["num_key_value_heads"]
    peft_config.token_dim = peft_config.token_dim // peft_config.num_attention_heads * num_key_value_heads
    peft_config.num_attention_heads = num_key_value_heads
```
Thanks a lot for providing the example and the fix. Based on your example, let's add a unit test too. I have a suggestion:
```python
def test_prompt_learning_with_grouped_query_attention():
    # See 1901, fixes a bug with handling GQA
    model_id = "peft-internal-testing/tiny-dummy-qwen2"
    base_model = AutoModelForCausalLM.from_pretrained(model_id)
    peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM")
    model = get_peft_model(base_model, peft_config)
    x = torch.tensor([[1, 2, 3]])
    # does not raise
    model(x)
```

This can be added to the bottom of `tests/test_decoder_models.py`. `PrefixTuningConfig` needs to be imported at the top. Also, could you please run `make style` on your PR?
I have done some more testing and added Qwen2 to our unit test suite, see #1906. There are 4 failing tests, but otherwise Qwen2 works. However, after applying your fix, there are 13 failing tests for me. I didn't have time to dive deeper yet, but this should be addressed before merging the PR.
Most failures are caused by prompt (prefix) tuning. My commit fixes the error caused by prompt (prefix) tuning.
Yes, I applied your fix locally, which resulted in the 13 errors cited above. Without the fix, there are only 4 errors, as in the tests of that PR.
@ttw1018 Do you still plan on working on this?
Thanks a lot for testing further and finding a solution. I could verify that your latest commit now fixes the failing Qwen2 tests except for `test_weighted_combination_of_adapters_65_test_peft_internal_testing_tiny_dummy_qwen2_lora`. Regarding that test, only the SVD variants fail. It is not clear to me why, but I think we can live with that for now. I'll add some code to skip that test for the time being in my PR.
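As a rough illustration of what such a skip could look like (a hypothetical sketch; the real test in the PEFT suite is parametrized differently and the names below are made up):

```python
# Hypothetical sketch of skipping SVD-based weighted-adapter combinations for the tiny Qwen2
# model; names and parametrization are illustrative, not the actual PEFT test-suite code.
import pytest


@pytest.mark.parametrize("combination_type", ["linear", "cat", "svd", "ties_svd"])
def test_weighted_combination_tiny_qwen2(combination_type):
    model_id = "peft-internal-testing/tiny-dummy-qwen2"
    if "svd" in combination_type:
        # add_weighted_adapter fails for Qwen2 when SVD is involved, see #1901
        pytest.skip("SVD-based weighted adapter combination fails for tiny Qwen2")
    # ... the actual combination logic for `model_id` would run here ...
```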
It fails when SVD is involved. See: huggingface#1901 (comment)
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
The linter is raising an error. Could you please run `make style`?
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Thanks a lot for providing this fix!
* [WIP] ENH Add support for Qwen2: add Qwen2 to default target modules, use tiny Qwen2 in tests.
* Add target_modules for FourierFT
* Skip Qwen2 + weighted combination test: it fails when SVD is involved. See: #1901 (comment)

Co-authored-by: BenjaminBossan <b.bossan@gmail.com>
To support Grouped-Query Attention, change `token_dim` to `token_dim // num_attention_heads * num_key_value_heads` and `num_attention_heads` to `num_key_value_heads`.
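As a quick end-to-end sanity check of this change, here is a sketch using the tiny Qwen2 test model mentioned earlier in the thread; it assumes the fix is installed, that the config exposes the usual transformers attributes (`hidden_size`, `num_attention_heads`, `num_key_value_heads`), and that the adapter is registered under the default name.

```python
# Sketch: verify that prefix tuning now sizes its prompt encoder for the KV heads,
# not the query heads. Assumes the PR's fix is applied; config values depend on the checkpoint.
import torch
from transformers import AutoModelForCausalLM
from peft import PrefixTuningConfig, get_peft_model

model_id = "peft-internal-testing/tiny-dummy-qwen2"
base_model = AutoModelForCausalLM.from_pretrained(model_id)
cfg = base_model.config

peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM")
model = get_peft_model(base_model, peft_config)

# With the fix, token_dim is per-head size * number of KV heads, and
# num_attention_heads in the peft config reflects the KV heads.
head_dim = cfg.hidden_size // cfg.num_attention_heads
assert model.peft_config["default"].token_dim == head_dim * cfg.num_key_value_heads
assert model.peft_config["default"].num_attention_heads == cfg.num_key_value_heads

# A forward pass should no longer raise.
model(torch.tensor([[1, 2, 3]]))
```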