[Perf] Improve MLA on V1 #14540
Signed-off-by: simon-mo <simon.mo@hey.com>
LGTM
LGTM, left one nit. Thanks for working on this! (Sorry this fell on your plate.)
Good catch on number 2! My bad for not catching this. I was wondering whether it would be better to compute this on the CPU in V1, but I didn't keep pushing on that. I'll try to be more careful about reviewing CPU->GPU transfers in the future.
decode_q_pe_input = (decode_q_pe.clone().contiguous()
                     if not decode_q_pe.is_contiguous() else
                     decode_q_pe)
nit: do we need `clone` here? My understanding is that `.contiguous()` implicitly makes a copy if the tensor is not contiguous and is a no-op if it already is:
>>> x1 = torch.rand((4,4))
>>> x2 = x1.t()
>>> x1.is_contiguous()
True
>>> x2.is_contiguous()
False
>>> x1.data_ptr()
94306274798528
>>> x1.contiguous().data_ptr()
94306274798528
>>> x2.data_ptr()
94306274798528
>>> x2.contiguous().data_ptr()
94306363886080
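The session above covers `.contiguous()` on its own. The sketch below adds the `clone()` path from the code under review, again by comparing `data_ptr()` values. It assumes only stock PyTorch on CPU, and `shares_storage` is a hypothetical helper written for this illustration. Because `.clone()` always allocates, `clone().contiguous()` pays for a copy even when the input was already contiguous.

```python
import torch


def shares_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Hypothetical helper: two tensors alias the same memory here
    # if their first elements live at the same address.
    return a.data_ptr() == b.data_ptr()


x = torch.rand(4, 4)  # freshly allocated, contiguous
xt = x.t()            # transposed view: same storage, non-contiguous strides

# .contiguous() returns the tensor itself when it is already contiguous...
print(shares_storage(x, x.contiguous()))          # True
# ...and only copies when the layout actually requires it.
print(shares_storage(xt, xt.contiguous()))        # False
# .clone() always allocates new storage, so clone().contiguous()
# copies even an already-contiguous tensor.
print(shares_storage(x, x.clone().contiguous()))  # False
```

In other words, `.contiguous()` alone already gives the "copy only if needed" behavior that the conditional expression was trying to express.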
I.e., I think we can drop this line and just do:
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
    attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
    decode_k_pe)
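Here is why the `[...]` write-back matters. The minimal sketch below uses a hypothetical `toy_rotary_emb`, a GPT-NeoX-style "rotate half" RoPE written only for illustration. It is not vLLM's `rotary_emb`, and the tensor shapes are assumptions. The rope slice of `q` is a strided view, so `.contiguous()` hands the function a copy. Assigning through `q_pe[...]` then copies the rotated values back into `q`'s storage.

```python
import torch


def toy_rotary_emb(positions, q, k, base=10000.0):
    # Hypothetical stand-in for rotary_emb: GPT-NeoX-style "rotate half" RoPE.
    # positions: [T], q: [T, H, D], k: [T, 1, D]; D must be even.
    d = q.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
    freqs = positions.float()[:, None] * inv_freq[None, :]           # [T, D/2]
    cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)[:, None, :]  # [T, 1, D]
    sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)[:, None, :]

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    # Returns new tensors; the inputs are not modified.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


T, H, nope_dim, rope_dim = 3, 2, 4, 4
q = torch.randn(T, H, nope_dim + rope_dim)
k_pe = torch.randn(T, 1, rope_dim)
positions = torch.arange(T)

# The rope part of q is a strided slice of a larger tensor: not contiguous.
q_pe = q[..., nope_dim:]
q_before = q.clone()

# .contiguous() makes a copy here; the [...] assignments write the results
# back into the original storage (q for q_pe, k_pe in place).
q_pe[...], k_pe[...] = toy_rotary_emb(positions, q_pe.contiguous(), k_pe)
```

After the assignment, the non-rope half of `q` is untouched, position 0 is unchanged (a zero-angle rotation), and later positions are rotated in place inside `q`.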
Yup, great point, and I verified the perf. The clone was left over from previous debugging, but your solution is great!
Signed-off-by: simon-mo <simon.mo@hey.com>
Hi @simon-mo, thanks for your great work! Speaking of D2H operations, I notice that…
Good find. A fix is welcome!
This PR helps V1 mostly match, and in most cases exceed, V0's performance for MLA. It does so mainly through two changes:
1. `rotary_emb` specialization (see [Perf] Reduce MLA CPU overheads in V1 #14384, Revert "[Perf] Reduce MLA CPU overheads in V1 (#14384)" #14471, and [Bugfix] DeepSeek Accuracy #14476) to reduce CPU overhead.
2. Removing overhead from the `build` function, which was costing quite a bit in my timing (p99 tail latency up to 1 ms).

All of the following benchmarks were run on 8xH200.
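One way to collect a number like the p99 figure above is a small host-side timing harness. This is a sketch under stated assumptions: `build_metadata` is a hypothetical stand-in for the `build` function, and the harness only measures wall-clock time on the CPU. CUDA kernels run asynchronously, so for code that launches GPU work you would need to synchronize (for example with `torch.cuda.synchronize()`) before reading the clock.

```python
import statistics
import time


def latency_ms(fn, iters=1000, warmup=10):
    """Call fn() repeatedly; return (median, p99) wall-clock latency in ms."""
    for _ in range(warmup):
        fn()  # warm caches / lazy init before measuring
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e3)
    # quantiles(n=100) returns 99 cut points; index 98 is the 99th percentile.
    cuts = statistics.quantiles(samples, n=100)
    return statistics.median(samples), cuts[98]


def build_metadata():
    # Hypothetical stand-in for per-step metadata construction.
    return sum(range(1000))


median, p99 = latency_ms(build_metadata, iters=200)
print(f"median={median:.4f} ms  p99={p99:.4f} ms")
```

Tracking the p99 alongside the median matters here because an occasional blocking transfer shows up as tail latency even when the typical step looks cheap.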
Performance Test (R1)
V1 is still slightly behind V0 at shorter inputs (about 4% lower throughput at 3k input) but significantly ahead at longer inputs: a 64% throughput boost at 6k input.
VLLM_USE_V1=1 python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --load-format dummy --trust-remote-code --input-len 3000 --output-len 1000 --num-prompts 50 --tensor-parallel-size 8
Throughput: 1.09 requests/s, 4342.27 total tokens/s, 1085.57 output tokens/s
VLLM_USE_V1=0 python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --load-format dummy --trust-remote-code --input-len 3000 --output-len 1000 --num-prompts 50 --tensor-parallel-size 8
Throughput: 1.13 requests/s, 4536.67 total tokens/s, 1134.17 output tokens/s
VLLM_USE_V1=1 python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --load-format dummy --trust-remote-code --input-len 6000 --output-len 1000 --num-prompts 50 --tensor-parallel-size 8
Throughput: 0.87 requests/s, 6060.61 total tokens/s, 865.80 output tokens/s
VLLM_USE_V1=0 python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --load-format dummy --trust-remote-code --input-len 6000 --output-len 1000 --num-prompts 50 --tensor-parallel-size 8
Throughput: 0.53 requests/s, 3692.82 total tokens/s, 527.55 output tokens/s
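The headline percentages can be recomputed from the output-token throughput reported above. This sketch just divides V1 by V0 for each input length. The numbers are copied from the runs, and nothing is re-measured.

```python
# Output tokens/s copied from the R1 runs above (8xH200, TP=8, 1k output).
r1 = {
    3000: {"v1": 1085.57, "v0": 1134.17},
    6000: {"v1": 865.80, "v0": 527.55},
}


def relative_change(new: float, old: float) -> float:
    """Percent change of `new` relative to `old`."""
    return (new / old - 1.0) * 100.0


for input_len, tput in r1.items():
    change = relative_change(tput["v1"], tput["v0"])
    print(f"{input_len} input tokens: V1 vs V0 = {change:+.1f}%")
# 3000 input tokens: V1 vs V0 = -4.3%
# 6000 input tokens: V1 vs V0 = +64.1%
```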
Performance Test (Small)
On the small model, V1 is about 15% faster than V0 at 3k input (3841 vs. 3319 output tokens/s, +15.7%).
VLLM_USE_V1=1 python benchmarks/benchmark_throughput.py --model deepseek-ai/DeepSeek-V2-Lite --load-format dummy --trust-remote-code --input-len 3000 --output-len 1000 --num-prompts 50
Throughput: 3.84 requests/s, 15364.27 total tokens/s, 3841.07 output tokens/s
VLLM_USE_V1=0 python benchmarks/benchmark_throughput.py --model deepseek-ai/DeepSeek-V2-Lite --load-format dummy --trust-remote-code --input-len 3000 --output-len 1000 --num-prompts 50
Throughput: 3.32 requests/s, 13275.67 total tokens/s, 3318.92 output tokens/s
VLLM_USE_V1=0 python benchmarks/benchmark_throughput.py --model deepseek-ai/DeepSeek-V2-Lite --load-format dummy --trust-remote-code --input-len 3000 --output-len 1000 --num-prompts 50 --enable-chunked-prefill false
Throughput: 3.32 requests/s, 13264.68 total tokens/s, 3316.17 output tokens/s
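The same check works for the small model. This sketch recomputes V1's gain over V0 and measures how much the V0 baseline moves when chunked prefill is disabled. The values are copied from the output above.

```python
# Output tokens/s copied from the DeepSeek-V2-Lite runs above (3k in, 1k out).
v1 = 3841.07
v0 = 3318.92
v0_no_chunked_prefill = 3316.17  # V0 with --enable-chunked-prefill false

speedup = v1 / v0
print(f"V1 over V0: {speedup:.3f}x ({(speedup - 1) * 100:.1f}%)")

# Disabling chunked prefill barely changes V0 throughput, so either
# V0 run is a comparable baseline for V1.
delta = (v0_no_chunked_prefill / v0 - 1) * 100
print(f"V0 without chunked prefill vs V0: {delta:+.2f}%")
```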
Accuracy Test
No regression.