[Misc]: a question about chunked-prefill in flash-attn backends #4863

Open
HarryWu99 opened this issue May 16, 2024 · 2 comments

@HarryWu99
Contributor

Anything you want to discuss about vllm.

I noticed that in the flash-attn backend, `forward_prefix` and `forward_decode` seem to be executed serially. Does `forward_decode` wait for `forward_prefix` to finish before running? If so, can this still benefit from chunked prefill, where the prefill tokens are in the same batch as the decode tokens?

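```python
# Prefill tokens fill the first num_prefill_tokens rows of `output`.
if prefill_meta := attn_metadata.prefill_metadata:
    output[:num_prefill_tokens] = PagedAttention.forward_prefix(...)

# Decode tokens fill the remaining rows; this call is issued after forward_prefix.
if decode_meta := attn_metadata.decode_metadata:
    output[num_prefill_tokens:] = PagedAttention.forward_decode(...)
```
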
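To make the question concrete, here is a minimal pure-Python sketch of the serial dispatch in the snippet above. It is a toy model, not vLLM code: `Seq`, `attend`, `prefix_attention`, and `serial_forward` are hypothetical names, attention is single-head over tiny vectors, and the KV cache is a plain list. Prefill tokens take the first `num_prefill_tokens` rows of the flattened output and decode tokens take the rest; the `log` list records that the prefill pass runs before the decode pass.

```python
import math
from dataclasses import dataclass

Vec = list[float]


@dataclass
class Seq:
    """Hypothetical per-sequence state: cached K/V plus the tokens added this step."""
    cached_k: list[Vec]
    cached_v: list[Vec]
    new_q: list[Vec]
    new_k: list[Vec]
    new_v: list[Vec]


def attend(q: Vec, keys: list[Vec], values: list[Vec]) -> Vec:
    """Scaled dot-product attention of a single query over keys/values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]


def prefix_attention(seq: Seq) -> list[Vec]:
    """New token i attends to every cached token and to new tokens 0..i (causal)."""
    keys = seq.cached_k + seq.new_k
    values = seq.cached_v + seq.new_v
    ctx = len(seq.cached_k)
    return [attend(q, keys[: ctx + i + 1], values[: ctx + i + 1])
            for i, q in enumerate(seq.new_q)]


def serial_forward(prefills: list[Seq], decodes: list[Seq], log: list[str]) -> list[Vec]:
    """Two branches: the decode pass starts only after the prefill pass returns."""
    num_prefill_tokens = sum(len(s.new_q) for s in prefills)
    output: list[Vec] = [[] for _ in range(num_prefill_tokens + len(decodes))]
    if prefills:  # plays the role of the prefill_metadata branch
        log.append("prefill")
        output[:num_prefill_tokens] = [o for s in prefills for o in prefix_attention(s)]
    if decodes:  # plays the role of the decode_metadata branch (one new token per sequence)
        log.append("decode")
        output[num_prefill_tokens:] = [prefix_attention(s)[0] for s in decodes]
    return output


# A 3-token prompt (empty cache) batched with one decode step over 2 cached tokens.
prefill = Seq(cached_k=[], cached_v=[],
              new_q=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
              new_k=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
              new_v=[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
decode = Seq(cached_k=[[1.0, 0.0], [0.0, 1.0]], cached_v=[[1.0, 0.0], [0.0, 1.0]],
             new_q=[[0.5, 0.5]], new_k=[[0.5, 0.5]], new_v=[[2.0, 2.0]])
log: list[str] = []
out = serial_forward([prefill], [decode], log)
print(log)       # ['prefill', 'decode']
print(len(out))  # 4
```

In this toy the two passes write disjoint rows of `output` and neither reads the other's result, so the ordering comes from issuing them one after the other rather than from a data dependency.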
@HarryWu99 HarryWu99 added the misc label May 16, 2024
@rkooo567
Collaborator

rkooo567 commented May 18, 2024

> I noticed that in the flash-attn backend, `forward_prefix` and `forward_decode` seem to be executed serially. Does `forward_decode` wait for `forward_prefix` to finish before running? If so, can this still benefit from chunked prefill, where the prefill tokens are in the same batch as the decode tokens?

Yeah, right now they run serially. I think after #4681 it should be possible to run them in the same attention kernel, but in our past internal benchmarks it didn't make much difference (we can definitely try again to see how much of a performance improvement it gives). That could be different now, though.
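The single-kernel idea can be illustrated with the same kind of toy model. Assuming a decode step is treated as a sequence with exactly one new token, one causal rule (new token i sees the cached context plus new tokens 0..i) covers prefill chunks and decodes alike, so one loop over the mixed batch produces the same rows as two separate passes. All names here are hypothetical, and the sketch only shows that the outputs agree; it says nothing about speed, which is what the benchmarks mentioned above measured.

```python
import math
from dataclasses import dataclass

Vec = list[float]


@dataclass
class Seq:
    """Hypothetical per-sequence state: cached K/V plus the tokens added this step."""
    cached_k: list[Vec]
    cached_v: list[Vec]
    new_q: list[Vec]
    new_k: list[Vec]
    new_v: list[Vec]


def attend(q: Vec, keys: list[Vec], values: list[Vec]) -> Vec:
    """Scaled dot-product attention of a single query over keys/values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]


def prefix_attention(seq: Seq) -> list[Vec]:
    """New token i sees all cached tokens plus new tokens 0..i, so one rule covers
    a first prompt chunk (empty cache), a later chunk, and a decode step (one new token)."""
    keys = seq.cached_k + seq.new_k
    values = seq.cached_v + seq.new_v
    ctx = len(seq.cached_k)
    return [attend(q, keys[: ctx + i + 1], values[: ctx + i + 1])
            for i, q in enumerate(seq.new_q)]


def two_pass(prefills: list[Seq], decodes: list[Seq]) -> list[Vec]:
    """Reference: a prefill pass, then a separate decode pass, concatenated."""
    out = [o for s in prefills for o in prefix_attention(s)]
    return out + [prefix_attention(s)[0] for s in decodes]


def single_pass(batch: list[Seq]) -> list[Vec]:
    """One loop over the flattened mixed batch, standing in for a single kernel launch."""
    return [o for s in batch for o in prefix_attention(s)]


# Second chunk of a chunked prompt (2 tokens already cached) plus one decode step.
chunk = Seq(cached_k=[[1.0, 0.0], [0.0, 1.0]], cached_v=[[1.0, 2.0], [3.0, 4.0]],
            new_q=[[1.0, 1.0], [0.5, 0.5]], new_k=[[1.0, 1.0], [2.0, 0.0]],
            new_v=[[5.0, 6.0], [7.0, 8.0]])
decode = Seq(cached_k=[[1.0, 0.0], [0.0, 1.0]], cached_v=[[1.0, 0.0], [0.0, 1.0]],
             new_q=[[0.5, 0.5]], new_k=[[0.5, 0.5]], new_v=[[2.0, 2.0]])
mixed = single_pass([chunk, decode])
print(mixed == two_pass([chunk], [decode]))  # True
print(len(mixed))  # 3
```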

Note that this should be done after we re-revert #4820, because running both in the same attention kernel means using the prefix kernel, and the existing prefix kernel is too slow (flash-attn varlen is at least 3x faster than it).
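Putting both kinds of tokens through one varlen-style kernel is largely a matter of metadata. FlashAttention's `flash_attn_varlen_func` takes cumulative offsets (`cu_seqlens_q`, `cu_seqlens_k`) plus `max_seqlen_q` and `max_seqlen_k`, and since flash-attn 2.1, `causal=True` with more keys than queries aligns the mask to the bottom-right corner, so the last query row of a sequence sees all of its keys. The sketch below computes those values for a hypothetical mixed batch and checks the bottom-right rule in pure Python. The helper names and the `(cached_len, chunk_len)` layout are assumptions, it assumes each sequence's keys are laid out contiguously (reading from a paged KV cache is outside this sketch), and it calls neither flash-attn nor vLLM.

```python
from itertools import accumulate


def cu_seqlens(lengths: list[int]) -> list[int]:
    """Prefix sums with a leading 0: sequence j spans rows [cu[j], cu[j + 1])."""
    return [0] + list(accumulate(lengths))


def mixed_batch_metadata(prefill_chunks: list[tuple[int, int]],
                         decode_contexts: list[int]) -> dict:
    """Hypothetical layout. prefill_chunks holds (cached_len, chunk_len) pairs and
    decode_contexts holds cached lengths; prefills come first, then decodes."""
    q_lens = [chunk for _, chunk in prefill_chunks] + [1] * len(decode_contexts)
    k_lens = ([cached + chunk for cached, chunk in prefill_chunks]
              + [cached + 1 for cached in decode_contexts])
    return {
        "cu_seqlens_q": cu_seqlens(q_lens),
        "cu_seqlens_k": cu_seqlens(k_lens),
        "max_seqlen_q": max(q_lens),
        "max_seqlen_k": max(k_lens),
        "num_prefill_tokens": sum(chunk for _, chunk in prefill_chunks),
    }


def visible_keys(q_len: int, k_len: int, i: int) -> int:
    """Bottom-right-aligned causal mask: query row i sees keys 0 .. i + (k_len - q_len)."""
    return i + (k_len - q_len) + 1


# A fresh 4-token prompt chunk, a 3-token chunk after 8 cached tokens, and two decodes.
meta = mixed_batch_metadata([(0, 4), (8, 3)], [5, 9])
print(meta["cu_seqlens_q"])  # [0, 4, 7, 8, 9]
print(meta["cu_seqlens_k"])  # [0, 4, 15, 21, 31]
print(visible_keys(1, 6, 0))  # 6: a decode query sees its whole context
```

With bottom-right alignment, a decode entry is simply a sequence with `q_len == 1`, and a later prompt chunk automatically sees its cached prefix, so no separate decode path is needed in this layout.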

@CrimsonDump

Is there an issue/PR for the "re-revert #4820" that we can track?
