
support Grouped-Query Attention #1901

Merged (7 commits) on Jul 22, 2024

Conversation

@ttw1018 (Contributor) commented Jul 2, 2024

To support Grouped-Query Attention, change token_dim to token_dim // num_attention_heads * num_key_value_heads and num_attention_heads to num_key_value_heads.
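
As a rough illustration of what this adjustment does (the numbers below are an assumption based on Qwen2-0.5B's config of hidden_size 896, 14 attention heads, and 2 key/value heads; other GQA models differ):

# Illustrative arithmetic only; config values assumed for Qwen2-0.5B.
token_dim = 896              # hidden_size
num_attention_heads = 14     # query heads
num_key_value_heads = 2      # key/value heads (GQA)

head_dim = token_dim // num_attention_heads    # 64
kv_dim = head_dim * num_key_value_heads        # 128

# The prompt-learning config would then use token_dim=128 and num_attention_heads=2,
# so the prefix keys/values match the width of k_proj/v_proj.
print(head_dim, kv_dim)  # 64 128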

@BenjaminBossan (Member)

Thanks for this PR, do you have an example that uses GQA and currently fails, but would succeed with this change?

@ttw1018 (Contributor, Author) commented Jul 3, 2024

Thanks for this PR, do you have an example that uses GQA and currently fails, but would succeed with this change?

I use Qwen2-0.5B, which uses GQA, and tune it with prefix tuning. Running the following code produces the error shown below.

test code

from peft import PrefixTuningConfig, get_peft_model, TaskType, PeftType
from transformers import Qwen2ForCausalLM
import torch

path = "/public/home/wlchen/twtang/peft-dst/PTM/Qwen2-0.5B"
base = Qwen2ForCausalLM.from_pretrained(path)

print("load base")

peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type=TaskType.CAUSAL_LM)

print(peft_config)
model = get_peft_model(base, peft_config)
print("load peft model")

x = torch.tensor([[1, 2, 3]])
model(x)

error

(DST) [wlchen@gpu01 src]$ python test.py
load base
PrefixTuningConfig(peft_type=<PeftType.PREFIX_TUNING: 'PREFIX_TUNING'>, auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, num_virtual_tokens=10, token_dim=None, num_transformer_submodules=None, num_attention_heads=None, num_layers=None, encoder_hidden_size=None, prefix_projection=False)
load peft model
Traceback (most recent call last):
  File "/public/home/wlchen/twtang/code/peft/src/test.py", line 18, in <module>
    model(x)
  File "/public/home/wlchen/miniconda3/envs/DST/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/public/home/wlchen/twtang/code/peft/src/peft/peft_model.py", line 1554, in forward
    return self.base_model(
           ^^^^^^^^^^^^^^^^
  File "/public/home/wlchen/miniconda3/envs/DST/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/public/home/wlchen/miniconda3/envs/DST/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 1169, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/public/home/wlchen/miniconda3/envs/DST/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/public/home/wlchen/miniconda3/envs/DST/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 1054, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/public/home/wlchen/miniconda3/envs/DST/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/public/home/wlchen/miniconda3/envs/DST/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 768, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/public/home/wlchen/miniconda3/envs/DST/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/public/home/wlchen/miniconda3/envs/DST/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 277, in forward
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/public/home/wlchen/miniconda3/envs/DST/lib/python3.11/site-packages/transformers/cache_utils.py", line 146, in update
    self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 14 but got size 2 for tensor number 1 in the list.
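
The mismatch can be reproduced in isolation; the shapes below are a sketch assuming Qwen2-0.5B's config (14 query heads, 2 key/value heads, head_dim 64), not the actual PEFT internals:

import torch

# The prefix cache is built with num_attention_heads (14) heads...
prefix_keys = torch.zeros(1, 14, 10, 64)  # (batch, heads, num_virtual_tokens, head_dim)
# ...but Qwen2's k_proj/v_proj produce num_key_value_heads (2) heads.
new_keys = torch.zeros(1, 2, 3, 64)       # (batch, kv_heads, seq_len, head_dim)
torch.cat([prefix_keys, new_keys], dim=-2)  # RuntimeError: expected size 14 but got size 2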

I found that this error is caused by the following code.

def _prepare_prompt_learning_config(peft_config, model_config):
    if peft_config.num_layers is None:
        if "num_hidden_layers" in model_config:
            num_layers = model_config["num_hidden_layers"]
        elif "num_layers" in model_config:
            num_layers = model_config["num_layers"]
        elif "n_layer" in model_config:
            num_layers = model_config["n_layer"]
        else:
            raise ValueError("Please specify `num_layers` in `peft_config`")
        peft_config.num_layers = num_layers

    if peft_config.token_dim is None:
        if "hidden_size" in model_config:
            token_dim = model_config["hidden_size"]
        elif "n_embd" in model_config:
            token_dim = model_config["n_embd"]
        elif "d_model" in model_config:
            token_dim = model_config["d_model"]
        else:
            raise ValueError("Please specify `token_dim` in `peft_config`")
        peft_config.token_dim = token_dim

    if peft_config.num_attention_heads is None:
        if "num_attention_heads" in model_config:
            num_attention_heads = model_config["num_attention_heads"]
        elif "n_head" in model_config:
            num_attention_heads = model_config["n_head"]
        elif "num_heads" in model_config:
            num_attention_heads = model_config["num_heads"]
        elif "encoder_attention_heads" in model_config:
            num_attention_heads = model_config["encoder_attention_heads"]
        else:
            raise ValueError("Please specify `num_attention_heads` in `peft_config`")
        peft_config.num_attention_heads = num_attention_heads

    if getattr(peft_config, "encoder_hidden_size", None) is None:
        setattr(peft_config, "encoder_hidden_size", peft_config.token_dim)

    return peft_config

If the model uses GQA, the number of key and value heads differs from the number of query heads. I added the following code to correct this error, and the test code now works.

    if "num_key_value_heads" in model_config:
        num_key_value_heads = model_config["num_key_value_heads"]
        peft_config.token_dim = peft_config.token_dim // peft_config.num_attention_heads * num_key_value_heads 
        peft_config.num_attention_heads = num_key_value_heads 
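
As a counterpart to the failing concatenation above, a quick sanity check with the adjusted head count (again assuming Qwen2-0.5B's 2 key/value heads and head_dim 64):

import torch

prefix_keys = torch.zeros(1, 2, 10, 64)  # prefix cache now built with num_key_value_heads
new_keys = torch.zeros(1, 2, 3, 64)      # k_proj output for 3 input tokens
merged = torch.cat([prefix_keys, new_keys], dim=-2)
print(merged.shape)  # torch.Size([1, 2, 13, 64]) -> 10 virtual tokens + 3 input tokens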

@BenjaminBossan (Member) left a comment

Thanks a lot for providing the example and the fix. Based on your example, let's add a unit test too. I have a suggestion:

def test_prompt_learning_with_grouped_query_attention():
    # See 1901, fixes a bug with handling GQA
    model_id = "peft-internal-testing/tiny-dummy-qwen2"
    base_model = AutoModelForCausalLM.from_pretrained(model_id)
    peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM")
    model = get_peft_model(base_model, peft_config)
    x = torch.tensor([[1, 2, 3]])
    # does not raise
    model(x)

This can be added to the bottom of tests/test_decoder_models.py. PrefixTuningConfig needs to be imported at the top.
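
For completeness, the imports the test relies on could look like this at the top of tests/test_decoder_models.py (the file's actual import block may differ):

import torch
from transformers import AutoModelForCausalLM
from peft import PrefixTuningConfig, get_peft_model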

Also, could you please run make style on your PR?

tiwent added 2 commits July 4, 2024 19:17

@BenjaminBossan (Member)

I have done some more testing and added qwen2 to our unit test suite, see #1906. There are 4 failing tests but otherwise qwen2 works. However, after applying your fix, there are 13 failing tests for me:

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┓
┃  File                          ┃  Function                                                                                                             ┃  Function Line  ┃  Error Line  ┃  Error         ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━┩
│  tests/test_decoder_models.py  │  PeftDecoderModelTester.test_disable_adapter_66_test_peft_internal_testing_tiny_dummy_qwen2_prompt_encoder            │  314            │  620         │  RuntimeError  │
│  tests/test_decoder_models.py  │  PeftDecoderModelTester.test_disable_adapter_67_test_peft_internal_testing_tiny_dummy_qwen2_prompt_tuning             │  314            │  620         │  RuntimeError  │
│  tests/test_decoder_models.py  │  PeftDecoderModelTester.test_generate_66_test_peft_internal_testing_tiny_dummy_qwen2_prompt_encoder                   │  226            │  620         │  RuntimeError  │
│  tests/test_decoder_models.py  │  PeftDecoderModelTester.test_generate_67_test_peft_internal_testing_tiny_dummy_qwen2_prompt_tuning                    │  226            │  620         │  RuntimeError  │
│  tests/test_decoder_models.py  │  PeftDecoderModelTester.test_generate_pos_args_66_test_peft_internal_testing_tiny_dummy_qwen2_prompt_encoder          │  230            │  620         │  RuntimeError  │
│  tests/test_decoder_models.py  │  PeftDecoderModelTester.test_generate_pos_args_67_test_peft_internal_testing_tiny_dummy_qwen2_prompt_tuning           │  230            │  620         │  RuntimeError  │
│  tests/test_decoder_models.py  │  PeftDecoderModelTester.test_inference_safetensors_66_test_peft_internal_testing_tiny_dummy_qwen2_prompt_encoder      │  259            │  620         │  RuntimeError  │
│  tests/test_decoder_models.py  │  PeftDecoderModelTester.test_inference_safetensors_67_test_peft_internal_testing_tiny_dummy_qwen2_prompt_tuning       │  259            │  620         │  RuntimeError  │
│  tests/test_decoder_models.py  │  PeftDecoderModelTester.test_passing_input_embeds_works_66_test_peft_internal_testing_tiny_dummy_qwen2_prompt_encod…  │  341            │  620         │  RuntimeError  │
│  tests/test_decoder_models.py  │  PeftDecoderModelTester.test_passing_input_embeds_works_67_test_peft_internal_testing_tiny_dummy_qwen2_prompt_tuning  │  341            │  620         │  RuntimeError  │
│  tests/test_decoder_models.py  │  PeftDecoderModelTester.test_training_prompt_learning_tasks_67_test_peft_internal_testing_tiny_dummy_qwen2_prompt_e…  │  310            │  620         │  RuntimeError  │
│  tests/test_decoder_models.py  │  PeftDecoderModelTester.test_training_prompt_learning_tasks_68_test_peft_internal_testing_tiny_dummy_qwen2_prompt_t…  │  310            │  620         │  RuntimeError  │
│  tests/test_decoder_models.py  │  PeftDecoderModelTester.test_weighted_combination_of_adapters_65_test_peft_internal_testing_tiny_dummy_qwen2_lora     │  296            │  620         │  RuntimeError  │
└────────────────────────────────┴───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┴─────────────────┴──────────────┴────────────────┘

I didn't have time to dive deeper yet, but this should be addressed before merging the PR.

@ttw1018 (Contributor, Author) commented Jul 4, 2024

I have done some more testing and added qwen2 to our unit test suite, see #1906. There are 4 failing tests but otherwise qwen2 works. However, after applying your fix, there are 13 failing tests for me:

[table of the 13 failing tests, quoted from the previous comment]

I didn't have time to dive deeper yet, but this should be addressed before merging the PR.

Most of the failures are caused by prompt (prefix) tuning, which is exactly the error my commit fixes.
I looked at the commit history in #1906, but it doesn't apply my fix.

@BenjaminBossan (Member)

Most of the failures are caused by prompt (prefix) tuning, which is exactly the error my commit fixes.
I looked at the commit history in #1906, but it doesn't apply my fix.

Yes, I applied your fix locally, which resulted in the 13 errors cited above. Without the fix, there are only 4 errors, as in the tests of that PR.

@BenjaminBossan (Member)

@ttw1018 Do you still plan on working on this?

@ttw1018 (Contributor, Author) commented Jul 18, 2024

I fixed most bugs except test_weighted_combination_of_adapters_65_test_peft_internal_testing_tiny_dummy_qwen2_lora, which fails when LoRA is applied to k_proj. I think Grouped-Query Attention causes this, because the test passes when the target module is q_proj.

There are four errors in test_weighted_combination_of_adapters_65_test_peft_internal_testing_tiny_dummy_qwen2_lora.

[screenshot of the failing test output]

I don't understand when and why the weight of multi_adapter_svd_reweighting changed. It should be an 8x20 matrix, but it is actually an 8x8 matrix. I'm sure this mismatch is what causes the failures.

[screenshots of the adapter weights]

@ttw1018 closed this on Jul 18, 2024
@ttw1018 reopened this on Jul 18, 2024
@BenjaminBossan (Member) left a comment

Thanks a lot for testing further and finding a solution. I could verify that your latest commit now fixes the failing Qwen2 tests except for test_weighted_combination_of_adapters_65_test_peft_internal_testing_tiny_dummy_qwen2_lora. Regarding that test, only the SVD variants fail. It is not clear to me why, but I think we can live with that for now. I'll add some code to skip that test for the time being in my PR.
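
A hypothetical sketch of such a skip (the actual code added in #1906 may differ; the helper name and arguments here are made up for illustration):

import pytest

def maybe_skip_svd_combination(model_id: str, combination_type: str) -> None:
    # SVD-based weighted adapter combination currently fails for this GQA model, see #1901.
    if "qwen2" in model_id and "svd" in combination_type:
        pytest.skip("SVD-based weighted combination fails for tiny-dummy-qwen2, see #1901")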

BenjaminBossan added a commit to BenjaminBossan/peft that referenced this pull request Jul 18, 2024
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan (Member)

The linter is raising an error. Could you please run make style? Ensure that you have ruff version 0.4.10 installed.

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
@BenjaminBossan (Member) left a comment

Thanks a lot for providing this fix!

@BenjaminBossan merged commit 6472061 into huggingface:main on Jul 22, 2024
14 checks passed
sayakpaul pushed a commit that referenced this pull request Jul 23, 2024
* [WIP] ENH Add support for Qwen2

Add Qwen2 to default target modules, use tiny Qwen2 in tests.

* Add target_modules for FourierFT

* Skip Qwen2 + weighted combination test

It fails when SVD is involved. See:
#1901 (comment)

---------

Co-authored-by: BenjaminBossan <b.bossan@gmail.com>