
[Bug]: llm_engine_example.py (more requests) get stuck #4904

Open
CsRic opened this issue May 19, 2024 · 1 comment
Labels
bug Something isn't working

Comments

CsRic commented May 19, 2024

Your current environment

Collecting environment information...                                                                                 
PyTorch version: 2.3.0+cu121                                                                                          
Is debug build: False                                                                                                 
CUDA used to build PyTorch: 12.1                                                                                      
ROCM used to build PyTorch: N/A                                                                                       
                                                                                                                      
OS: Debian GNU/Linux 11 (bullseye) (x86_64)                                                                           
GCC version: (Debian 10.2.1-6) 10.2.1 20210110                                                                        
Clang version: Could not collect                                                                                      
CMake version: version 3.29.2                                                                                         
Libc version: glibc-2.31                                                                                              
                                                                                                                      
Python version: 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0] (64-bit runtime)                                    
Python platform: Linux-5.10.0-23-amd64-x86_64-with-glibc2.31                                                          
Is CUDA available: True                                                                                               
CUDA runtime version: 12.1.66                                                                                         
CUDA_MODULE_LOADING set to: LAZY                                                                                      
GPU models and configuration:                                                                                         
GPU 0: NVIDIA RTX A5000                                                                                               
GPU 1: NVIDIA RTX A5000                                                                                               
GPU 2: NVIDIA RTX A5000                                                                                               
GPU 3: NVIDIA RTX A5000                                                                                               
                                                                                                                      
Nvidia driver version: 545.23.08                                                                                      
cuDNN version: Could not collect                                                                                      
HIP runtime version: N/A                                                                                              
MIOpen runtime version: N/A                                                                                           
Is XNNPACK available: True                                                                                            
                                                                                                                      
CPU:                                                                                                                  
Architecture:                    x86_64                                                                               
CPU op-mode(s):                  32-bit, 64-bit                                                                       
Byte Order:                      Little Endian                                                                        
Address sizes:                   43 bits physical, 48 bits virtual                                                    
CPU(s):                          32                                                                                   
On-line CPU(s) list:             0-31                                                                                 
Thread(s) per core:              1                                                                                    
Core(s) per socket:              16                                                                                   
Socket(s):                       2                                                                                    
NUMA node(s):                    8                                                                                    
Vendor ID:                       AuthenticAMD                            
CPU family:                      23
Model:                           49
Model name:                      AMD EPYC 7302 16-Core Processor
Stepping:                        0
Frequency boost:                 enabled
CPU MHz:                         1730.206
CPU max MHz:                     3310.5459
CPU min MHz:                     1500.0000
BogoMIPS:                        5988.84
Virtualization:                  AMD-V
L1d cache:                       1 MiB
L1i cache:                       1 MiB
L2 cache:                        16 MiB
L3 cache:                        256 MiB
NUMA node0 CPU(s):               0-3
NUMA node1 CPU(s):               4-7
NUMA node2 CPU(s):               8-11
NUMA node3 CPU(s):               12-15
NUMA node4 CPU(s):               16-19
NUMA node5 CPU(s):               20-23
NUMA node6 CPU(s):               24-27
NUMA node7 CPU(s):               28-31
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Mitigation; untrained return thunk; SMT disabled
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, IBPB conditional, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate sme ssbd mba sev ibrs ibpb stibp vmmcall sev_es fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif umip rdpid overflow_recov succor smca

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] torch==2.3.0
[pip3] triton==2.3.0
[pip3] vllm-nccl-cu12==2.18.1.0.4.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
[conda] torch                     2.3.0                    pypi_0    pypi
[conda] triton                    2.3.0                    pypi_0    pypi
[conda] vllm-nccl-cu12            2.18.1.0.4.0             pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.4.2
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    GPU2    GPU3    NIC0    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      SYS     SYS     SYS     SYS     12-15           N/A             N/A
GPU1    SYS      X      SYS     SYS     SYS     8-11    2               N/A
GPU2    SYS     SYS      X      SYS     SYS     28-31           N/A             N/A
GPU3    SYS     SYS     SYS      X      SYS     20-23   5               N/A
NIC0    SYS     SYS     SYS     SYS      X 

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0

🐛 Describe the bug

I modified examples/llm_engine_example.py to test a large number of requests. With 200 requests of 32 random tokens each, the engine gets stuck and never produces a full answer.

llm_engine_example_heavy.py:

import argparse
from typing import List, Tuple
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
import numpy as np


def create_test_prompt_nonsense(args: argparse.Namespace) -> Tuple[List[int], SamplingParams]:
    length = args.input_len
    return (np.random.randint(10000, size=(length)).tolist(),
            SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1))


def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[List[int], SamplingParams]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0

    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            prompt_token_ids, sampling_params = test_prompts.pop(0)
            engine.add_request(str(request_id), None, sampling_params, prompt_token_ids)
            request_id += 1

        request_outputs: List[RequestOutput] = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                print(request_output)


def initialize_engine(args: argparse.Namespace) -> LLMEngine:
    """Initialize the LLMEngine from the command line arguments."""
    engine_args = EngineArgs.from_cli_args(args)
    return LLMEngine.from_engine_args(engine_args)


def main(args: argparse.Namespace):
    """Main function that sets up and runs the prompt processing."""
    engine = initialize_engine(args)
    test_prompts = []
    for i in range(args.test_num):
        test_prompts.append(create_test_prompt_nonsense(args))
    process_requests(engine, test_prompts)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Demo on using the LLMEngine class directly')
    parser = EngineArgs.add_cli_args(parser)
    parser.add_argument('--input-len', type=int, default=128)
    parser.add_argument('--test-num', type=int, default=100)
    args = parser.parse_args()
    main(args)

A successful run:

python llm_engine_example_heavy.py --model facebook/opt-125m \
 --input-len 32 \
 --test-num 10

output:

INFO 05-18 22:55:28 llm_engine.py:103] Initializing an LLM engine (v0.4.2) with config: model='facebook/opt-125m', speculative_config=None, tokenizer='facebook/opt-125m', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=facebook/opt-125m)
INFO 05-18 22:55:31 selector.py:37] Using FlashAttention-2 backend.
INFO 05-18 22:55:31 weight_utils.py:199] Using model weights format ['*.bin']
INFO 05-18 22:55:31 model_runner.py:145] Loading model weights took 0.2389 GB
INFO 05-18 22:55:32 gpu_executor.py:83] # GPU blocks: 36865, # CPU blocks: 7281
INFO 05-18 22:55:34 model_runner.py:824] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 05-18 22:55:34 model_runner.py:828] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 05-18 22:55:37 model_runner.py:894] Graph capturing finished in 3 secs.
RequestOutput(request_id=0, prompt=None, prompt_token_ids=[2732, 9845, 3264, 4859, 9225, 7891, 4373, 5874, 6744, 3468, 705, 2599, 2222, 7768, 2897, 9893, 537, 6216, 6921, 6036, 2163, 5072, 4851, 7877, 2046, 1871, 7599, 2496, 8291, 755, 797, 659], prompt_logprobs=[None, {9845: Logprob(logprob=-12.68840217590332, rank=7125, decoded_token=' emerge'), 4: Logprob(logprob=-1.9188706874847412, rank=1, decoded_token='.')}, {3264: Logprob(logprob=-14.21910285949707, rank=6125, decoded_token=' accept'), 31: Logprob(logprob=-1.1546498537063599, rank=1, decoded_token=' from')}, {4859: Logprob(logprob=-12.61778450012207, rank=4317, decoded_token=' buyers'), 4735: Logprob(logprob=-0.4537220001220703, rank=1, decoded_token='ably')}, {9225: Logprob(logprob=-16.543548583984375, rank=25278, decoded_token=' Kurdish'),

...

I omitted the rest; all answers were printed and the program terminated normally.

A failed run, changing --test-num from 10 to 200:

python llm_engine_example_heavy.py --model facebook/opt-125m \
 --input-len 32 \
 --test-num 200

output:

INFO 05-18 22:58:38 llm_engine.py:103] Initializing an LLM engine (v0.4.2) with config: model='facebook/opt-125m', speculative_config=None, tokenizer='facebook/opt-125m', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=facebook/opt-125m)
INFO 05-18 22:58:42 selector.py:37] Using FlashAttention-2 backend.
INFO 05-18 22:58:43 weight_utils.py:199] Using model weights format ['*.bin']
INFO 05-18 22:58:43 model_runner.py:145] Loading model weights took 0.2389 GB
INFO 05-18 22:58:43 gpu_executor.py:83] # GPU blocks: 36865, # CPU blocks: 7281
INFO 05-18 22:58:45 model_runner.py:824] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 05-18 22:58:45 model_runner.py:828] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 05-18 22:58:48 model_runner.py:894] Graph capturing finished in 3 secs.
^C[rank0]: Traceback (most recent call last):
[rank0]:   File "/[my workpath]/vllm/examples/llm_engine_example_heavy.py", line 54, in <module>
[rank0]:     main(args)
[rank0]:   File "/[my workpath]/vllm/examples/llm_engine_example_heavy.py", line 44, in main
[rank0]:     process_requests(engine, test_prompts)
[rank0]:   File "/[my workpath]/vllm/examples/llm_engine_example_heavy.py", line 25, in process_requests
[rank0]:     request_outputs: List[RequestOutput] = engine.step()
[rank0]:                                            ^^^^^^^^^^^^^
[rank0]:   File "/[my workpath]/vllm/vllm/engine/llm_engine.py", line 686, in step
[rank0]:     request_outputs = self._process_model_outputs(
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/[my workpath]/vllm/vllm/engine/llm_engine.py", line 599, in _process_model_outputs
[rank0]:     self.output_processor.process_prompt_logprob(seq_group, outputs)
[rank0]:   File "/[my workpath]/vllm/vllm/engine/output_processor/single_step.py", line 65, in process_prompt_logprob
[rank0]:     self.detokenizer.decode_prompt_logprobs_inplace(
[rank0]:   File "/[my workpath]/vllm/vllm/transformers_utils/detokenizer.py", line 60, in decode_prompt_logprobs_inplace
[rank0]:     new_read_offset) = detokenize_incrementally(
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/[my workpath]/vllm/vllm/transformers_utils/detokenizer.py", line 287, in detokenize_incrementally
[rank0]:     prefix_text = tokenizer.convert_tokens_to_string(
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/[my conda env path]/lib/python3.11/site-packages/transformers/tokenization_utils_fast.py", line 612, in convert_tokens_to_string
[rank0]:     return self.backend_tokenizer.decoder.decode(tokens)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: KeyboardInterrupt

Before I hit Ctrl+C, the program had been stuck for 1 hour with GPU activity at 0%.
The traceback always shows self.backend_tokenizer.decoder.decode(tokens) as the latest position.
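
A quick, hypothetical sanity check (not part of the original run) is to time the raw tokenizer call on random token IDs outside vLLM, to see whether the tokenizers decoder itself is slow or whether the time is spent in vLLM's incremental detokenization loop. This sketch assumes the same facebook/opt-125m tokenizer and prompt shape; the loop count is arbitrary:

import time

import numpy as np
from transformers import AutoTokenizer

# Same model and 32-token random prompts as in the failing run.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
token_ids = np.random.randint(10000, size=32).tolist()
tokens = tokenizer.convert_ids_to_tokens(token_ids)

start = time.perf_counter()
for _ in range(1000):
    tokenizer.convert_tokens_to_string(tokens)
print(f"1000 decodes took {time.perf_counter() - start:.3f}s")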

CsRic added the bug (Something isn't working) label on May 19, 2024
zifeitong (Contributor) commented May 30, 2024

Minimal reproducer:

import numpy as np

from vllm import LLM, SamplingParams


def main():
    llm = LLM(model="facebook/opt-125m")
    test_prompts = np.random.randint(10000, size=(200, 32)).tolist()
    outputs = llm.generate(
        prompt_token_ids=test_prompts,
        sampling_params=SamplingParams(
            temperature=0.0, logprobs=1, prompt_logprobs=1
        ),
    )

    for output in outputs:
        print(output)


if __name__ == "__main__":
    main()

This affects using EleutherAI/lm-evaluation-harness with vLLM.
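
As a possible workaround sketch (an assumption based on the traceback, not a confirmed fix): since the hang goes through decode_prompt_logprobs_inplace, dropping prompt_logprobs from SamplingParams skips the prompt-logprob detokenization path entirely and may avoid the hang when prompt logprobs are not actually needed:

import numpy as np

from vllm import LLM, SamplingParams


def main():
    llm = LLM(model="facebook/opt-125m")
    test_prompts = np.random.randint(10000, size=(200, 32)).tolist()
    # Same reproducer as above, but without prompt_logprobs, so the
    # prompt-logprob detokenization path shown in the traceback is never entered.
    outputs = llm.generate(
        prompt_token_ids=test_prompts,
        sampling_params=SamplingParams(temperature=0.0, logprobs=1),
    )
    for output in outputs:
        print(output)


if __name__ == "__main__":
    main()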
