
[Perf] Improve MLA on V1 #14540

Merged · 3 commits merged into vllm-project:main on Mar 10, 2025
Conversation

@simon-mo (Collaborator) commented on Mar 10, 2025

This PR brings V1 mostly up to par with V0's MLA performance, and exceeds it in most cases. It does so through two changes:

  1. Fix @LucasWilkinson's rotary_emb specialization ([Perf] Reduce MLA CPU overheads in V1 #14384, Revert "[Perf] Reduce MLA CPU overheads in V1 (#14384)" #14471, [Bugfix] DeepSeek Accuracy #14476) to reduce CPU overhead.
  • Identified that the 0 GSM8K score was caused by the CUDA kernel requiring contiguous input.
  • Fixed it by making the input contiguous where possible. A better fix would be to change the kernel (help wanted).
  2. Reordered some operations in the build function that were costing quite a bit of overhead in my timing (up to 1 ms of p99 tail latency).
  • This is done by ensuring there is no GPU -> CPU communication; CPU -> GPU is fine. (See the sketch after this list.)
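To illustrate point 2, here is a minimal sketch of the pattern with hypothetical names (not the actual vLLM build() code): reading a CUDA tensor from Python, e.g. via .item(), forces a blocking GPU -> CPU sync on the hot path, whereas doing the scalar work on a CPU copy and shipping the result to the GPU asynchronously does not stall the host.

```python
import torch

# Hypothetical sketch -- illustrative names only, not vLLM's actual code.

def build_metadata_with_sync(seq_lens_gpu: torch.Tensor) -> int:
    # .item() on a CUDA tensor blocks the host until the GPU catches up
    # (an implicit GPU -> CPU transfer on the hot path).
    return int(seq_lens_gpu.max().item())

def build_metadata_without_sync(seq_lens_cpu: torch.Tensor,
                                device: torch.device):
    # Do the scalar reduction on the CPU copy we already have...
    max_seq_len = int(seq_lens_cpu.max())
    # ...then ship the tensor to the GPU asynchronously; CPU -> GPU
    # transfers do not stall the host the way GPU -> CPU reads do.
    seq_lens_gpu = seq_lens_cpu.pin_memory().to(device, non_blocking=True)
    return max_seq_len, seq_lens_gpu
```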

All of the following benchmarks were run on 8xH200.

Performance Test (R1)

We are still slightly worse at shorter input lengths but significantly better at longer ones: a 64% throughput boost for 6k input.

VLLM_USE_V1=1 python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --load-format dummy --trust-remote-code --input-len 3000 --output-len 1000 --num-prompts 50 --tensor-parallel-size 8
Throughput: 1.09 requests/s, 4342.27 total tokens/s, 1085.57 output tokens/s

VLLM_USE_V1=0 python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --load-format dummy --trust-remote-code --input-len 3000 --output-len 1000 --num-prompts 50 --tensor-parallel-size 8
Throughput: 1.13 requests/s, 4536.67 total tokens/s, 1134.17 output tokens/s

VLLM_USE_V1=1 python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --load-format dummy --trust-remote-code --input-len 6000 --output-len 1000 --num-prompts 50 --tensor-parallel-size 8
Throughput: 0.87 requests/s, 6060.61 total tokens/s, 865.80 output tokens/s

VLLM_USE_V1=0 python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --load-format dummy --trust-remote-code --input-len 6000 --output-len 1000 --num-prompts 50 --tensor-parallel-size 8
Throughput: 0.53 requests/s, 3692.82 total tokens/s, 527.55 output tokens/s
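Summarizing the R1 runs above:

| Input len | Output len | V1 throughput (req/s) | V0 throughput (req/s) | V1 vs V0 |
|----------:|-----------:|----------------------:|----------------------:|----------|
| 3000 | 1000 | 1.09 | 1.13 | ~3.5% slower |
| 6000 | 1000 | 0.87 | 0.53 | ~64% faster |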

Performance Test (Small)

We are 15% better for the small model with 3k input.

VLLM_USE_V1=1 python benchmarks/benchmark_throughput.py --model deepseek-ai/DeepSeek-V2-Lite --load-format dummy --trust-remote-code --input-len 3000 --output-len 1000 --num-prompts 50
Throughput: 3.84 requests/s, 15364.27 total tokens/s, 3841.07 output tokens/s

VLLM_USE_V1=0 python benchmarks/benchmark_throughput.py --model deepseek-ai/DeepSeek-V2-Lite --load-format dummy --trust-remote-code --input-len 3000 --output-len 1000 --num-prompts 50
Throughput: 3.32 requests/s, 13275.67 total tokens/s, 3318.92 output tokens/s

VLLM_USE_V1=0 python benchmarks/benchmark_throughput.py --model deepseek-ai/DeepSeek-V2-Lite --load-format dummy --trust-remote-code --input-len 3000 --output-len 1000 --num-prompts 50 --enable-chunked-prefill false
Throughput: 3.32 requests/s, 13264.68 total tokens/s, 3316.17 output tokens/s

Accuracy Test

No regression.

VLLM_USE_V1="1" lm_eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384 --task gsm8k --num_fewshot=5 --limit 100 --log_samples --output_path lmeval-results

vllm (pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384), gen_kwargs: (None), limit: 100.0, num_fewshot: 5, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.65|±  |0.0479|
|     |       |strict-match    |     5|exact_match|↑  | 0.64|±  |0.0482|


VLLM_USE_V1="0" lm_eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384 --task gsm8k --num_fewshot=5 --limit 100 --log_samples --output_path lmeval-results

vllm (pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384), gen_kwargs: (None), limit: 100.0, num_fewshot: 5, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.66|±  |0.0476|
|     |       |strict-match    |     5|exact_match|↑  | 0.66|±  |0.0476|

Signed-off-by: simon-mo <simon.mo@hey.com>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify bot added the v1 label on Mar 10, 2025
@simon-mo requested a review from LucasWilkinson on Mar 10, 2025, 05:51
@simon-mo added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Mar 10, 2025
fix lint
Signed-off-by: simon-mo <simon.mo@hey.com>
@tlrmchlsmth (Collaborator) left a comment


LGTM

@LucasWilkinson (Collaborator) left a comment


LGTM, left 1 nit. Thanks for working on this! (Sorry this fell on your plate.)

Good catch on number 2! My bad for not catching this. I was wondering whether it would be better to compute this on the CPU in V1, but didn't keep pushing on that. I'll try to be more careful about reviewing CPU->GPU transfers in the future.


```python
decode_q_pe_input = (decode_q_pe.clone().contiguous()
                     if not decode_q_pe.is_contiguous() else
                     decode_q_pe)
```

nit: do we need the clone here? My understanding is that .contiguous() will implicitly do a clone if it's not contiguous and is a no-op if it already is:

```python
>>> x1 = torch.rand((4,4))
>>> x2 = x1.t()
>>> x1.is_contiguous()
True
>>> x2.is_contiguous()
False
>>> x1.data_ptr()
94306274798528
>>> x1.contiguous().data_ptr()
94306274798528
>>> x2.data_ptr()
94306274798528
>>> x2.contiguous().data_ptr()
94306363886080
```


i.e. I think we can drop this line and just do:

```python
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
    attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
    decode_k_pe)
```

@simon-mo (Author) replied:

Yup, great point, and I verified the perf. The clone was left over from previous debugging; your solution is great!

Signed-off-by: simon-mo <simon.mo@hey.com>
@simon-mo enabled auto-merge (squash) on March 10, 2025, 16:13
@simon-mo disabled auto-merge on March 10, 2025, 19:06
@simon-mo merged commit fb0acb6 into vllm-project:main on Mar 10, 2025
29 of 31 checks passed
@ZhongYingMatrix (Contributor) commented:

Hi @simon-mo, thanks for your great work! Speaking of D2H operations: I notice that has_context here would be a single-element bool tensor, which incurs a device sync in the following conditional operation. Would that have an impact on performance?
cc @LucasWilkinson

@simon-mo (Author) replied:

Good find. Fix welcomed!
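For illustration only, a minimal sketch of the kind of hidden sync the comment describes (hypothetical names, not vLLM's actual has_context code): using a single-element CUDA bool tensor in Python control flow forces a blocking device sync, whereas deriving the flag from CPU-side metadata keeps it a plain Python bool.

```python
import torch

# Hypothetical sketch -- illustrative names only, not vLLM's actual code.

def has_context_with_sync(context_lens_gpu: torch.Tensor) -> bool:
    # `(tensor > 0).any()` produces a single-element CUDA bool tensor;
    # converting it to a Python bool (e.g. in an `if`) blocks the host
    # until the GPU result is available.
    return bool((context_lens_gpu > 0).any())

def has_context_without_sync(context_lens_cpu: torch.Tensor) -> bool:
    # Deriving the flag from the CPU-side copy of the lengths avoids any
    # device synchronization.
    return bool((context_lens_cpu > 0).any())
```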
