
[Bug] Long context: Decoding performance degradation #1608

Closed
1 of 2 tasks
DayDayupupupup opened this issue May 17, 2024 · 9 comments

@DayDayupupupup

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.

Describe the bug

Compared with vLLM 0.4.2, when the input length is 200k, lmdeploy's decoding time increases significantly. TPOP below denotes the average time per output token during decoding.

  • prompt = 1k: lmdeploy TPOP = 7.3 ms, vLLM TPOP = 8.66 ms
  • prompt = 200k: lmdeploy TPOP = 63 ms, vLLM TPOP = 30 ms

Reproduction

TEST Environment

  • IMAGE: openmmlab/lmdeploy:v0.4.1
  • GPU: H800, Driver Version: 535.129.03, CUDA Version: 12.2
  • TP: 1

Model

  • Meta-Llama-3-8B-Instruct: change max_position_embeddings in config.json to 409600 (see the sketch below)
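
For convenience, a minimal sketch of the config change above, assuming the model is stored locally under Meta-Llama-3-8B-Instruct/ (the path and value are taken from the reproduction steps; this is just one way to apply the edit):

import json

# Bump max_position_embeddings so the model accepts a 200k-token prompt.
config_path = 'Meta-Llama-3-8B-Instruct/config.json'
with open(config_path) as f:
    config = json.load(f)

config['max_position_embeddings'] = 409600

with open(config_path, 'w') as f:
    json.dump(config, f, indent=2)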

Testing script

  • llama3_lmdeploy.py
import time
from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig
from lmdeploy.model import ChatTemplateConfig
# Sample prompts.
# chat template length: 10, 
max_input_len = 200*1024-10
print(f'max input len = {max_input_len//1024}k       ')
model_path = 'Meta-Llama-3-8B-Instruct'
prompts = "hi" * (max_input_len - 1)

engine_config = TurbomindEngineConfig(quant_policy=0,
                                      tp=1,
                                      #cache_max_entry_count=0.8,
                                      session_len=max_input_len+2000)
#chat_template_config = ChatTemplateConfig.from_json("llama3_chat_template_config.json")
pipe = pipeline(model_path=model_path,
                backend_config=engine_config,
                #chat_template_config=chat_template_config,
                log_level='WARNING') 
gen_config = GenerationConfig(max_new_tokens=1023, 
                              temperature=0.6,
                              top_p=0.9,
                              ignore_eos=True,
                              stop_words=['<|end_of_text|>','<|eot_id|>'],
                              skip_special_tokens=False)
# Warm-up run.
start_time = time.perf_counter()
outputs = pipe([prompts], gen_config=gen_config)
end_time = time.perf_counter()
latency = end_time - start_time
print(f'warmup time: {latency*1000:.2f} ms')

start_time = time.perf_counter()
outputs = pipe([prompts], gen_config=gen_config)
end_time = time.perf_counter()
latency = end_time - start_time
print(f'infer time: {latency*1000:.2f} ms')

# Print the outputs.
for output in outputs:
    generated_text = output.text
    #print(f"Generated text: {generated_text!r}")
    print(f'prompt_tokens={output.input_token_len} gen_tokens={output.generate_token_len}')
    print(f'finish_reason: {output.finish_reason}')
  • llama3_vllm.py
from vllm import LLM, SamplingParams
import time
from transformers import AutoTokenizer
import torch
# Sample prompts.
max_input_len = 200*1024-10
print(f'                 max input len = {max_input_len//1024}k       ')
model_path = 'Meta-Llama-3-8B-Instruct'

prompts = "hi" * (max_input_len - 1)
chat_prompt = '''<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n''' + prompts + '''<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n'''
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.6,
                                 top_p=0.9,
                                 ignore_eos=True,
                                 max_tokens=1024,
                                 )
                                 #stop=['<|end_of_text|>','<|eot_id|>'])
# Create an LLM.
llm = LLM(model=model_path,
          tensor_parallel_size=1,
          trust_remote_code=True,
          dtype='bfloat16',
          max_model_len=max_input_len+2000,
          #kv_cache_dtype="fp8",
          )

start_time = time.perf_counter()
outputs = llm.generate(chat_prompt, sampling_params)
end_time = time.perf_counter()
latency = end_time - start_time
print(f'warmup time: {latency*1000:.2f} ms')

start_time = time.perf_counter()
outputs = llm.generate(chat_prompt, sampling_params=sampling_params)
end_time = time.perf_counter()
latency = end_time - start_time
print(f'infer time: {latency*1000:.2f} ms')
# Print the outputs. 
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    gen_length = len(output.outputs[0].token_ids)
    print(f'finish_reason={output.outputs[0].finish_reason}')
    print(f'prompt_tokens={len(output.prompt_token_ids)} gen_tokens={gen_length}')

TEST MODE

llama3_lmdeploy.py

  1. Set max_new_tokens=0, get TTFT
  2. Set max_new_tokens=1023, TPOP = (infer_time - TTFT) / 1024

llama3_vllm.py

  1. Set max_tokens=1, get TTFT
  2. Set max_tokens=1024, TPOP = (infer_time - TTFT) / 1024 (see the sketch below)
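
To make the metric unambiguous, here is a minimal sketch of the calculation described above (ttft_run_latency and full_run_latency are placeholder names for the two measured infer times, in seconds):

def compute_tpop(ttft_run_latency, full_run_latency, num_new_tokens=1024):
    # Average decoding time per output token, in milliseconds.
    ttft = ttft_run_latency                     # first run approximates time-to-first-token
    decode_time = full_run_latency - ttft       # remaining time is spent on decoding
    return decode_time * 1000 / num_new_tokens  # ms per generated token

# For example, with the A100 numbers reported in the comment below (before the fix):
# compute_tpop(66.457, 152.454) gives roughly 84 ms per token.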

Performance comparison

[image: lmdeploy vs vLLM performance comparison]

Environment

sys.platform: linux
Python: 3.8.10 (default, Nov 22 2023, 10:22:35) [GCC 9.4.0]
CUDA available: True
MUSA available: False
numpy_random_seed: 2147483648
GPU 0,1,2,3,4,5,6,7: NVIDIA H800
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 11.8, V11.8.89
GCC: x86_64-linux-gnu-gcc (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
PyTorch: 2.1.0+cu118
PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX512
  - CUDA Runtime 11.8
  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_90,code=sm_90
  - CuDNN 8.7
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.8, CUDNN_VERSION=8.7.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-invalid-partial-specialization -Wno-unused-private-field -Wno-aligned-allocation-unavailable -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.1.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,

TorchVision: 0.16.0+cu118
LMDeploy: 0.4.1+14e9953
transformers: 4.40.2
gradio: 3.50.2
fastapi: 0.111.0
pydantic: 2.7.1
triton: 2.1.0

Error traceback

No response

@DayDayupupupup changed the title to "[Bug] Long context: Decoding performance degradation" on May 17, 2024
@lvhan028
Collaborator

@lzhangzz is working on this issue. The related PR is #1606.
@zhulinJulia24, please pay attention to the test case.

@lzhangzz
Collaborator

lzhangzz commented May 17, 2024

With the script above, on A100-SXM4-80G

        output=1023     output=1      
before: 152453.82 ms - 66456.73 ms = 85997.09 ms
 #1606:  94043.37 ms - 68165.34 ms = 25878.03 ms

@DayDayupupupup
Author

Updated performance comparison

[image: lmdeploy performance with PR #1606]

With PR#1606, decoding performance has improved significantly! 👍🏻

         TPOP (input=200k, output=1024)
before:  63.1 ms
 #1606:  12.24 ms

Thanks! @lvhan028 @lzhangzz

@DayDayupupupup
Author

@lzhangzz Could you test the above llama3_lmdeploy.py script again?
For this case (200k input, 1024 output), please compare PR #1606 with lmdeploy 0.4.2.

I noticed that the decoding speed dropped after upgrading to the new version 0.4.2.

[image: lmdeploy 0.4.2 performance chart]

@lzhangzz
Collaborator

lzhangzz commented May 30, 2024

@DayDayupupupup I got almost the same result on v0.4.2 compared with #1606

               output=1023  output=1 
- F16  KV cache: 94145.74 - 66536.14 = 27609.6  ms (run 1)
                 94387.54 - 66691.50 = 27696.04 ms (run 2)
- INT8 KV cache: 87000.81 - 65821.41 = 21179.4  ms
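
For anyone reproducing the INT8 KV cache numbers, a minimal sketch of the engine configuration, assuming quant_policy=8 selects the 8-bit KV cache in this lmdeploy version (quant_policy=0, as in the original script, keeps the F16 KV cache):

from lmdeploy import pipeline, TurbomindEngineConfig

max_input_len = 200*1024 - 10
engine_config = TurbomindEngineConfig(quant_policy=8,   # assumption: 8 = INT8 KV cache, 0 = F16
                                      tp=1,
                                      session_len=max_input_len + 2000)
pipe = pipeline(model_path='Meta-Llama-3-8B-Instruct',
                backend_config=engine_config)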

@datalee

datalee commented Jun 4, 2024

mark

@DayDayupupupup
Author

@lzhangzz Thanks for your reply.

In fact, the decoding performance is consistent between 0.4.2 and 0.4.1 with PR #1606.

To correct the data above: the 0.4.1 + PR #1606 numbers were wrong. With 0.4.1, I made some changes to support head_dim=64 and then merged PR #1606, so the TPOP of 12.24 ms at 200k is not accurate.

// src/turbomind/kernels/attention/decoding.cu, line 32
// I modified 128 to 64, so I got TPOP = 12.24 ms (0.4.1 + my changes + PR #1606 + Llama 3)
// Using the original 128, TPOP = 16.56 ms (0.4.1 + PR #1606 + Llama 3)
static constexpr std::integral_constant<int, 128> kHeadDim{};

Another question: when is head_dim=64 expected to be supported?
Although I added head_dim=64 support myself and the inference results are correct, the performance is poor for long sequences. 😂
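
For context, a minimal sketch of how head_dim can be read from a Hugging Face config.json, assuming the usual convention head_dim = hidden_size / num_attention_heads when the config does not set head_dim explicitly (Meta-Llama-3-8B-Instruct has hidden_size=4096 and 32 attention heads, i.e. head_dim=128, which is why the stock kHeadDim=128 path covers it):

import json

def get_head_dim(model_dir):
    # Read the local config and derive head_dim if it is not stored explicitly.
    with open(f'{model_dir}/config.json') as f:
        cfg = json.load(f)
    return cfg.get('head_dim', cfg['hidden_size'] // cfg['num_attention_heads'])

print(get_head_dim('Meta-Llama-3-8B-Instruct'))  # 4096 // 32 = 128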

@lzhangzz
Collaborator

lzhangzz commented Jun 6, 2024

when is head_dim=64 expected to be supported?

Likely in July.

Although I added head_dim=64 support myself and the inference results are correct, the performance is poor for long sequences

You may try to find the bottleneck using Nsight Compute.

@DayDayupupupup
Author

Thanks, I'll give it a try.
