
Balance vision model weights on multi gpus #1591

Merged (12 commits, May 23, 2024)

Conversation

@irexyc (Collaborator) commented May 14, 2024

TODO

  • hang issue when using nccl (turbomind).
  • docs & cli

#1563

@irexyc changed the title from "[WIP] Balance vision model weights on multi gpus" to "Balance vision model weights on multi gpus" on May 14, 2024
@@ -279,7 +279,7 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh

     /// TODO: this stream handle is leaked
     cudaStream_t stream{};
-    ft::check_cuda_error(cudaStreamCreate(&stream));
+    ft::check_cuda_error(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
Collaborator:

Non blocking, why?

Collaborator (Author):

If the stream is not non-blocking, synchronizing it also synchronizes with the default stream. I suspect there is a conflict when nccl and the Python-side default stream synchronize, which then causes the hang.

Comment on lines -98 to -103
@contextmanager
def cuda_ctx(device_id):
    old_device = torch.cuda.current_device()
    torch.cuda.set_device(device_id)
    yield
    torch.cuda.set_device(old_device)
Collaborator:

Just out of curiosity, why was it added and then removed?

Collaborator (Author):

The pybind functions all call cudaSetDevice on the C++ side, so it no longer seems necessary.
Previously the forward thread did not call cudaSetDevice on the C++ side, but it runs fine for me either with or without cuda_ctx.

 self.model = load_vl_model(model_path)
-self.max_batch_size = max_batch_size
+self.max_batch_size = (1 if vision_config is None else
Collaborator:

Why change the default value from 16 to 1 here?

Collaborator (Author):

Because an issue reported that the server could start but GPU memory would overflow once multiple requests came in. Also, some vision models' memory usage is quite sensitive to the batch size.
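
For reference, a minimal sketch of overriding that default through vision_config, assuming lmdeploy exposes VisionConfig with a max_batch_size field as this diff suggests; the model path and batch value below are placeholders:

from lmdeploy import pipeline, TurbomindEngineConfig, VisionConfig

# Placeholder model path; raise max_batch_size only if the GPUs have headroom,
# since vision memory use is sensitive to the batch size.
pipe = pipeline('liuhaotian/llava-v1.5-7b',
                backend_config=TurbomindEngineConfig(tp=2),
                vision_config=VisionConfig(max_batch_size=4))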

@buaadf commented May 21, 2024

@lzhangzz waiting for your review...

@rTrQqgH74lc2PT5k

Really looking forward to this feature.

@RunningLeon mentioned this pull request on May 22, 2024
@AllentDan (Collaborator) left a comment:

Does the current tp setting mean that every GPU accessible through CUDA_VISIBLE_DEVICES will be used? That is, even if tp==2 is specified, all four cards will be used if four are accessible?

@irexyc (Collaborator, Author) commented May 22, 2024

> Does the current tp setting mean that every GPU accessible through CUDA_VISIBLE_DEVICES will be used? That is, even if tp==2 is specified, all four cards will be used if four are accessible?

Yes.

@RunningLeon (Collaborator) left a comment:

Tested OK with the VL pipeline at tp=2 for these available models:

  • llava-v1.5-7b
  • llava-v1.6-vicuna-7b
  • llava-v1.6-34b
  • deepseek-vl-7b-chat
  • Qwen-VL-Chat
  • Yi-VL-6B
  • Mini-Gemini-7B

@RunningLeon (Collaborator) left a comment:

LGTM

@lvhan028 (Collaborator) commented May 23, 2024

The minimum required version of accelerate should be specified in runtime.txt.

@lzhangzz (Collaborator) left a comment:

LGTM

@lvhan028 merged commit 2f28531 into InternLM:main on May 23, 2024
5 of 9 checks passed
@covdvoyager

Hi, when I run my Python file I get the following error. What could be causing it?

  File "D:\新建文件夹\InternDog-master\app_cli.py", line 3, in <module>
    from agent.model import chat
  File "D:\新建文件夹\InternDog-master\agent\model.py", line 2, in <module>
    from lmdeploy import turbomind as tm
  File "C:\Users\86186\anaconda3\envs\pythonProject2\Lib\site-packages\lmdeploy\turbomind\__init__.py", line 24, in <module>
    from .turbomind import TurboMind  # noqa: E402
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\86186\anaconda3\envs\pythonProject2\Lib\site-packages\lmdeploy\turbomind\turbomind.py", line 26, in <module>
    from .deploy.converter import (get_model_format, supported_formats,
  File "C:\Users\86186\anaconda3\envs\pythonProject2\Lib\site-packages\lmdeploy\turbomind\deploy\converter.py", line 16, in <module>
    from .target_model.base import OUTPUT_MODELS, TurbomindModelConfig
  File "C:\Users\86186\anaconda3\envs\pythonProject2\Lib\site-packages\lmdeploy\turbomind\deploy\target_model\__init__.py", line 3, in <module>
    from .w4 import TurbomindW4Model  # noqa: F401
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\86186\anaconda3\envs\pythonProject2\Lib\site-packages\lmdeploy\turbomind\deploy\target_model\w4.py", line 17, in <module>
    import _turbomind as _tm  # noqa: E402
    ^^^^^^^^^^^^^^^^^^^^^^^^
ImportError: DLL load failed while importing _turbomind: The specified module could not be found.

@irexyc (Collaborator, Author) commented May 23, 2024

@covdvoyager

Please check whether this helps:
#1146 (comment)

@buaadf commented May 24, 2024

@irexyc Does multi-GPU parallelism require a power-of-two number of cards? I cannot get it to run with three A30s.

@irexyc (Collaborator, Author) commented May 24, 2024

@buaadf

The tp in backend_config needs to be a power of two.

@buaadf commented May 24, 2024

> @buaadf
>
> The tp in backend_config needs to be a power of two.

How is the tp setting related to the number of cards, and how many A30s (24 GB) are needed at minimum to run it?

@irexyc (Collaborator, Author) commented May 24, 2024

@buaadf

The tp used to split the LM model must be a power of two. tp=2 means the LM needs two cards; it will pick cards 0 and 1 from the visible cards.

If you are running a VLM model with tp set to 2 in backend_config and CUDA_VISIBLE_DEVICES="0,1,2", the vision model's weights are balanced across all three cards, while the LM model is split evenly across the first two.

Whether it runs depends on which model you are running. In terms of weights alone (without quantization), a 7B model needs roughly 14 GB of GPU memory and a 20B model about 40 GB. Beyond the weights, the KV cache also takes memory, which affects session_len and the batch size; its size can be controlled with cache_max_entry_count.
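
For reference, a minimal sketch of the setup described above, assuming the lmdeploy pipeline API; the model path, image URL, and cache ratio are placeholders rather than values taken from this thread:

import os
# Make three cards visible before importing lmdeploy: vision weights are balanced
# over all of them, while the LM with tp=2 uses only cards 0 and 1.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'

from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image

pipe = pipeline(
    'OpenGVLab/InternVL-Chat-V1-5',          # placeholder VLM
    backend_config=TurbomindEngineConfig(
        tp=2,                                # LM weights split across two cards
        cache_max_entry_count=0.5))          # shrink the KV cache if memory is tight

image = load_image('https://example.com/demo.jpg')  # placeholder image
print(pipe(('describe this image', image)))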

@ysyx2008

With tp=2, two 4090 cards still cannot run the int8 version of InternVL (25 GB of weight files); GPU memory usage blows up.
Any advice would be appreciated.

(internvl) yushen@YuShen-Work:~/ai/InternVL$ python gradio_InternVL.py
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Exception in thread Thread-3 (_create_weight_func):
Traceback (most recent call last):
  File "/home/yushen/micromamba/envs/internvl/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/home/yushen/micromamba/envs/internvl/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/home/yushen/micromamba/envs/internvl/lib/python3.10/site-packages/lmdeploy/turbomind/turbomind.py", line 199, in _create_weight_func
    model_comm.create_shared_weights(device_id, rank)
RuntimeError: [TM][ERROR] CUDA runtime error: out of memory /lmdeploy/src/turbomind/utils/memory_utils.cu:32

Exception in thread Thread-5 (_get_params):
Traceback (most recent call last):
  File "/home/yushen/micromamba/envs/internvl/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/home/yushen/micromamba/envs/internvl/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/home/yushen/micromamba/envs/internvl/lib/python3.10/site-packages/lmdeploy/turbomind/turbomind.py", line 215, in _get_params
    out = model_comm.get_params(device_id, rank)
RuntimeError: [TM][ERROR] Assertion fail: /lmdeploy/src/turbomind/triton_backend/llama/LlamaTritonModel.cc:417

@irexyc (Collaborator, Author) commented May 27, 2024

@ysyx2008

We do not support loading bnb-style int8 models. Please quantize the model with our quantization tool, which should be available in 0.4.2.

The quantization guide is written for LLMs, but it applies to VLMs as well; just replace the model in the demo with a VLM model.

An article on VLM quantization will also be published soon, so keep an eye out for it.

@ysyx2008

> @ysyx2008
>
> We do not support loading bnb-style int8 models. Please quantize the model with our quantization tool, which should be available in 0.4.2.
>
> The quantization guide is written for LLMs, but it applies to VLMs as well; just replace the model in the demo with a VLM model.
>
> An article on VLM quantization will also be published soon, so keep an eye out for it.

Thanks a lot. My own quantization attempt failed earlier, and I just realized that pip installs 0.4.1 by default. I will look into upgrading to 0.4.2 and try again. Thanks again.

@sshuair commented May 27, 2024

@irexyc I get an error with the internlm/internlm-xcomposer2-4khd-7b model:

Dummy Resized
Traceback (most recent call last):
  File "/opt/py38/bin/lmdeploy", line 11, in <module>
    load_entry_point('lmdeploy', 'console_scripts', 'lmdeploy')()
  File "/opt/lmdeploy/lmdeploy/cli/entrypoint.py", line 37, in run
    args.run(args)
  File "/opt/lmdeploy/lmdeploy/cli/serve.py", line 303, in api_server
    run_api_server(args.model_path,
  File "/opt/lmdeploy/lmdeploy/serve/openai/api_server.py", line 1191, in serve
    VariableInterface.async_engine = pipeline_class(
  File "/opt/lmdeploy/lmdeploy/serve/vl_async_engine.py", line 20, in __init__
    self.vl_encoder = ImageEncoder(model_path, vision_config)
  File "/opt/lmdeploy/lmdeploy/vl/engine.py", line 69, in __init__
    self.model = load_vl_model(model_path)
  File "/opt/lmdeploy/lmdeploy/vl/model/builder.py", line 40, in load_vl_model
    return Xcomposer2VisionModel(model_path, with_llm)
  File "/opt/lmdeploy/lmdeploy/vl/model/xcomposer2.py", line 42, in __init__
    self.build_model()
  File "/opt/lmdeploy/lmdeploy/vl/model/xcomposer2.py", line 76, in build_model
    max_memory = get_balanced_memory(
UnboundLocalError: local variable 'get_balanced_memory' referenced before assignment

@irexyc (Collaborator, Author) commented May 27, 2024

@sshuair

Should be fixed in #1661.
