
[Feature] Support vl models quantization #1553

Merged: 29 commits merged into InternLM:main on May 24, 2024

Conversation

AllentDan
Collaborator

@AllentDan AllentDan commented May 7, 2024

  • deepseek vl
  • llava
  • internvl
  • xcomposer (plora not quantized)
  • minigemini
  • yi
  • qwen
  • internvl-llava

@AllentDan AllentDan changed the title [WIP] Support vl models quantization [Feature] Support vl models quantization May 13, 2024
@AllentDan AllentDan removed the WIP label May 13, 2024
@lvhan028 lvhan028 added the enhancement New feature or request label May 13, 2024
@lvhan028 lvhan028 requested review from irexyc and pppppM May 13, 2024 03:15
@irexyc
Collaborator

irexyc commented May 14, 2024

When quantizing xcomposer2, the weight_type is int4, so LlamaLinear.h needs to be changed as well; otherwise the computation only goes through forwardInt4 and never through plora.

Comment on lines +47 to +50
if (n < N) {
    C[n * M + m] += ((half2&)data).x;
    C[n * M + m + 1] += ((half2&)data).y;
}
Collaborator Author

this is weird since the following implementation failed:

        if (n < N) {
            (half2&)C[n * M + m] += (half2&)data;
        }

Collaborator

In my test

if (n < N) { 
    (half2&)C[n * M + m] += (half2&)data; 
}

works fine.

@lvhan028
Collaborator

If xcomposer2's quantization turns out to be complicated, I suggest handling it in a separate PR. Otherwise it may slow down the review and may also conflict with other PRs.

@AllentDan
Collaborator Author

If xcomposer2's quantization turns out to be complicated, I suggest handling it in a separate PR. Otherwise it may slow down the review and may also conflict with other PRs.

It works now, but I ran into something strange; see the comment above.

@lvhan028
Collaborator

Please resolve the conflicts.

@lvhan028 lvhan028 requested a review from lzhangzz May 19, 2024 11:37
max_shard_size='2GB',
safe_serialization=False)
if vl_model:
    save_vl_model(vl_model, model_path, work_dir)
Collaborator

Can vl_model reuse model.save_pretrained?
After quantization, can vl_model still be loaded with transformers?

Collaborator Author

Besides the LLM part, vl_model also has a vision part that needs to be saved together, so model.save_pretrained cannot be used directly.
In theory yes, because only the LLM part is quantized, and that part is already compatible with transformers.
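
For illustration only, here is a hypothetical sketch of what such a helper could do; it is not the PR's actual save_vl_model, and it assumes vl_model is a regular transformers module whose state holds both the vision tower and the quantized LLM (save_vl_model_sketch is an invented name):

import os
import shutil


def save_vl_model_sketch(vl_model, model_path: str, work_dir: str):
    # Save the whole VLM (vision tower + projector + quantized LLM weights);
    # saving only the LLM with save_pretrained would drop the vision part.
    vl_model.save_pretrained(work_dir,
                             max_shard_size='2GB',
                             safe_serialization=False)
    # Carry over auxiliary files (configs, tokenizer files, remote code)
    # from the original checkpoint so the vision side can still be loaded.
    for name in os.listdir(model_path):
        if name.endswith(('.json', '.py', '.model', '.txt')):
            dst = os.path.join(work_dir, name)
            if not os.path.exists(dst):
                shutil.copy(os.path.join(model_path, name), dst)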

Collaborator

This probably needs more work. We first need to understand how well transformers supports VLM quantization. Let's put it on the TODO list and follow up later.

Collaborator Author

I tried it: transformers cannot run our quantized VL models. Transformers' AWQ quantizes all layers, including the vision part, whereas we leave the vision part unquantized. Only the LLM part works; conversations that include images will fail.

Collaborator Author

The llava AWQ model did not run. Qwen-VL uses GPTQ, which is a bit different.

@@ -51,6 +51,20 @@ def load_model_from_weight_files(model: nn.Module, folder: str) -> None:
    model.load_state_dict(state_dict, strict=False)


def buffers_aware_empty(model: nn.Module, device: str = 'cpu'):
Collaborator

What is the motivation?

Collaborator Author

Buffer tensors are zeroed out after to_empty. They are not part of the weights, so they are never saved or loaded and can only be produced at initialization. So the only option is to back them up first and copy them back after to_empty.
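
As a minimal sketch of the backup-and-restore idea described above (not necessarily the PR's actual buffers_aware_empty), assuming the buffers still hold real data while the parameters sit on the meta device:

from torch import nn


def buffers_aware_empty_sketch(model: nn.Module, device: str = 'cpu'):
    # Buffers are not part of the saved weights, so back them up before
    # to_empty() replaces their storage with uninitialized memory.
    backup = {name: buf.detach().clone() for name, buf in model.named_buffers()}
    model.to_empty(device=device)
    # Copy the original buffer values back into the freshly allocated storage.
    for name, buf in model.named_buffers():
        buf.copy_(backup[name].to(device))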

Collaborator

I ran into a similar issue. You can write a to_empty callback directly, so no copying is needed. For example:

def to_empty(m, device='cpu', include_buffers=False):
    fn = lambda t: torch.empty_like(t, device=device)
    # Replace parameters with uninitialized tensors on the target device.
    for key, param in m._parameters.items():
        if param is None:
            continue
        m._parameters[key] = fn(m._parameters[key])
    # Keep buffer values by default; only empty them when include_buffers=True.
    for key, buf in m._buffers.items():
        if buf is not None:
            if include_buffers:
                m._buffers[key] = fn(buf)
            else:
                m._buffers[key] = m._buffers[key].to(device)


model = model.apply(to_empty)

Collaborator

The current initialization logic is meta model -> cpu empty -> cuda model.

In the balanced vision model code I already stopped using to_empty. Can this be changed to the following?

from accelerate import load_checkpoint_and_dispatch
load_checkpoint_and_dispatch(
                model=model,
                checkpoint=self.model_path,
                device_map='auto',
                dtype=torch.half,
                strict=False)

Collaborator

Can quantization only be done on a single GPU?

Collaborator

If load_checkpoint_and_dispatch can fully meet our needs, there is no reason to reinvent the wheel. @AllentDan

Collaborator Author

Yes. I had assumed that once one of the two PRs was merged, the other would take care of this. Should it be fixed in this PR?

Collaborator

I suggest fixing it in this PR. @irexyc does that sound reasonable to you?

Collaborator

I suggest fixing it in this PR. @irexyc does that sound reasonable to you?

Since this changes how the model is loaded, I also think it is better to change and test it here; otherwise this spot would have to be tested again later.

Collaborator Author

How about merging the TP PR first? Simply replacing buffers_aware_empty with

from accelerate import load_checkpoint_and_dispatch
load_checkpoint_and_dispatch(
                model=model,
                checkpoint=self.model_path,
                device_map='auto',
                dtype=torch.half,
                strict=False)

is not enough.

@@ -289,6 +289,8 @@ def main(model_name: str,
    if inferred_model_format.find('awq') != -1:
        cfg.weight_type = 'int4'
        output_format = 'w4'
        if 'xcomposer2' in inferred_model_format:
Collaborator

I feel output_format is becoming somewhat hard to maintain.
converter.py and turbomind.py both contain similar changes and checks.
Let it pass in this PR, but the maintainability of this part needs to be improved later. I will take that work.

@RunningLeon RunningLeon mentioned this pull request May 22, 2024
int* lora_mask)
{
FT_CHECK(type == kGemm);
// output = lora(x) * scale
Collaborator

The lora(x) * scale, mask(), and x*W + output computation is the same as in FpLora.
Could this part be hoisted into forward and computed first, and then dispatched to different fused operations based on the data type?

Collaborator Author
@AllentDan AllentDan May 22, 2024

Should it go before the switch on type in forward? Dispatching only on the data type is not enough: the final Gemm in forwardFpLora differs from forwardFp, and the final gemm_s4_f16_.Run in forwardInt4Lora differs from forwardInt4.

Collaborator

This is what I meant:

    void forward(T*                         output_data,
                 const T*                   input_data,
                 int                        batch_size,
                 const LlamaDenseWeight<T>& weight,
                 Type                       type      = kGemm,
                 int*                       lora_mask = nullptr)
    {
        if (weight.lora.r == 0) {
            switch (weight.type) {
                case WeightType::kFP16:
                case WeightType::kFP32:
                case WeightType::kBF16:
                    forwardFp(output_data, input_data, batch_size, weight, type);
                    break;
                case WeightType::kINT4:
                    forwardInt4(output_data, input_data, batch_size, weight, type);
                    break;
                default:
                    FT_CHECK(0);
            }
        } else if (lora_mask != nullptr && weight.lora.r > 0) {
            FT_CHECK(type == kGemm);
            // output = lora(x) * scale
            // output = mask(output)
            // output = x*W + output
            cublas_wrapper_->Gemm(CUBLAS_OP_N,
                                  CUBLAS_OP_N,
                                  weight.lora.r,                                  // m
                                  batch_size,                                     // n
                                  weight.input_dims,                              // k
                                  (const T*)weight.lora.a,                        // A
                                  weight.lora.r,                                  // lda
                                  input_data,                                     // B
                                  weight.input_dims,                              // ldb
                                  output_data + batch_size * weight.output_dims,  // C
                                  weight.lora.r);                                 // ldc

            cublas_wrapper_->Gemm(CUBLAS_OP_N,
                                  CUBLAS_OP_N,
                                  weight.output_dims,                             // m
                                  batch_size,                                     // n
                                  weight.lora.r,                                  // k
                                  (const T*)weight.lora.b,                        // A
                                  weight.output_dims,                             // lda
                                  output_data + batch_size * weight.output_dims,  // B
                                  weight.lora.r,                                  // ldb
                                  output_data,                                    // C
                                  weight.output_dims,                             // ldc
                                  weight.lora.scale,                              // alpha
                                  0.0f);                                          // beta

            invokeMask(output_data, lora_mask, batch_size, weight.output_dims, stream_);
            switch (weight.type) {
                case WeightType::kFP16:
                case WeightType::kFP32:
                case WeightType::kBF16:
                    forwardFpLora(output_data, input_data, batch_size, weight, type);
                    break;
                case WeightType::kINT4:
                    forwardInt4Lora(output_data, input_data, batch_size, weight, type);
                    break;
                default:
                    FT_CHECK(0);
            }
        } else {
            FT_CHECK(0);
        }
    }

Let forward do the LoRA merging part, i.e.

// output = lora(x) * scale
// output = mask(output)
// output = x*W + output

and let forwardFpLora and forwardInt4Lora do only the GEMM part.

Collaborator

@lzhangzz what's your opinion?

Collaborator

If we do this, can we merge forwardInt4Lora and forwardInt4, forwardFpLora and forwardFp respectively?

Collaborator
@lzhangzz lzhangzz May 24, 2024

Actually, adapters are independent of the dispatching of mixed-precision GEMMs.

    void forward(T*                         output_data,
                 const T*                   input_data,
                 int                        batch_size,
                 const LlamaDenseWeight<T>& weight,
                 Type                       type      = kGemm,
                 int*                       lora_mask = nullptr)
    {

        if (lora_mask && weight.lora.r > 0) {
            FT_CHECK(type == kGemm);
            // output = lora(x) * scale
            // output = mask(output)
            // output = x*W + output
            cublas_wrapper_->Gemm(CUBLAS_OP_N,
                                  CUBLAS_OP_N,
                                  weight.lora.r,                                  // m
                                  batch_size,                                     // n
                                  weight.input_dims,                              // k
                                  (const T*)weight.lora.a,                        // A
                                  weight.lora.r,                                  // lda
                                  input_data,                                     // B
                                  weight.input_dims,                              // ldb
                                  output_data + batch_size * weight.output_dims,  // C
                                  weight.lora.r);                                 // ldc

            cublas_wrapper_->Gemm(CUBLAS_OP_N,
                                  CUBLAS_OP_N,
                                  weight.output_dims,                             // m
                                  batch_size,                                     // n
                                  weight.lora.r,                                  // k
                                  (const T*)weight.lora.b,                        // A
                                  weight.output_dims,                             // lda
                                  output_data + batch_size * weight.output_dims,  // B
                                  weight.lora.r,                                  // ldb
                                  output_data,                                    // C
                                  weight.output_dims,                             // ldc
                                  weight.lora.scale,                              // alpha
                                  0.0f);                                          // beta

            invokeMask(output_data, lora_mask, batch_size, weight.output_dims, stream_);

            type = kAdd;
        }

        switch (weight.type) {
            case WeightType::kFP16:
            case WeightType::kFP32:
            case WeightType::kBF16:
                forwardFp(output_data, input_data, batch_size, weight, type);
                break;
            case WeightType::kINT4:
                forwardInt4(output_data, input_data, batch_size, weight, type);
                break;
            default:
                FT_CHECK(0);
        }
    }

Collaborator

I think checking first whether there is a LoRA is better than before.

Collaborator

Actually, adapters are independent of the dispatching of mixed-precision GEMMs. [...]

OK. Following the comments from @lzhangzz and @irexyc, let's go with @lzhangzz's suggestion. @AllentDan

Conflicts:
	lmdeploy/vl/model/deepseek.py
	lmdeploy/vl/model/internvl.py
	lmdeploy/vl/model/internvl_llava.py
	lmdeploy/vl/model/llava.py
	lmdeploy/vl/model/mini_gemeni.py
	lmdeploy/vl/model/qwen.py
	lmdeploy/vl/model/xcomposer2.py
	lmdeploy/vl/model/yi.py
@AllentDan
Collaborator Author

Performance tested OK.

@lvhan028 lvhan028 merged commit b940995 into InternLM:main May 24, 2024
9 checks passed
@sshuair

sshuair commented May 27, 2024

@AllentDan Using the latest code to quantize the model internlm/internlm-xcomposer2-4khd-7b with the following command, I got this error:

  • command: lmdeploy lite auto_awq internlm/internlm-xcomposer2-4khd-7b --work-dir /data/quant/internlm-xcomposer2-4khd-7b-4bit
can't find model from local_path internlm/internlm-xcomposer2-4khd-7b, try to download from remote                                                                                                           
Fetching 22 files: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:00<00:00, 102641.48it/s]
You are using a model of type internlmxcomposer2 to instantiate a model of type internlm2. This is not supported for all configurations of models and can yield errors.                                      
Set max length to 16384                                                                                                                                                                                      
Dummy Resized                                                                                                                                                                                                
Move model.tok_embeddings to GPU.
Move model.layers.0 to CPU.                                                                                                                                                                                  
Move model.layers.1 to CPU.                                                                                                                                                                                  
Move model.layers.2 to CPU.                                                                                                                                                                                  
Move model.layers.3 to CPU.                                                                                                                                                                                  
Move model.layers.4 to CPU.                                                                                                                                                                                  
Move model.layers.5 to CPU.                                                                                                                                                                                  
Move model.layers.6 to CPU.                                                                                                                                                                                  
Move model.layers.7 to CPU.                                                                                                                                                                                  
Move model.layers.8 to CPU.                                                                                                                                                                                  
Move model.layers.9 to CPU.                                                                                                                                                                                  
Move model.layers.10 to CPU.                                                                                                                                                                                 
Move model.layers.11 to CPU.                                                                                                                                                                                 
Move model.layers.12 to CPU.                                                                                                                                                                                 
Move model.layers.13 to CPU.                                                                                                                                                                                 
Move model.layers.14 to CPU.                                                                                                                                                                                 
Move model.layers.15 to CPU.                                                                                                                                                                                 
Move model.layers.16 to CPU.                                                                                                                                                                                 
Move model.layers.17 to CPU.                                                                                                                                                                                 
Move model.layers.18 to CPU.                                                                                                                                                                                 
Move model.layers.19 to CPU.                                                                                                                                                                                 
Move model.layers.20 to CPU.                                                                                                                                                                                 
Move model.layers.21 to CPU.                                                                                                                                                                                 
Move model.layers.22 to CPU.                                                                                                                                                                                 
Move model.layers.23 to CPU.                                                                                                                                                                                 
Move model.layers.24 to CPU.                                                                                                                                                                                 
Move model.layers.25 to CPU.                                                                                                                                                                                 
Move model.layers.26 to CPU.                                                                                                                                                                                 
Move model.layers.27 to CPU.                                    
Move model.layers.28 to CPU.
Move model.layers.29 to CPU.
Move model.layers.30 to CPU.
Move model.layers.31 to CPU.
Move model.norm to GPU.
Move output to CPU.
Move vit to GPU.
Move vision_proj to GPU.
Loading calibrate dataset ...
Traceback (most recent call last):
  File "/opt/py38/bin/lmdeploy", line 11, in <module>
    load_entry_point('lmdeploy', 'console_scripts', 'lmdeploy')()
  File "/opt/lmdeploy/lmdeploy/cli/entrypoint.py", line 37, in run
    args.run(args)
  File "/opt/lmdeploy/lmdeploy/cli/lite.py", line 137, in auto_awq
    auto_awq(**kwargs)
  File "/opt/lmdeploy/lmdeploy/lite/apis/auto_awq.py", line 96, in auto_awq
    vl_model, model, tokenizer, work_dir = calibrate(model,
  File "/opt/lmdeploy/lmdeploy/lite/apis/calibrate.py", line 235, in calibrate
    calib_ctx.calibrate(all_data)
  File "/opt/lmdeploy/lmdeploy/lite/quantization/calibration.py", line 315, in calibrate
    _ = model(data.to(self.device))
  File "/opt/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/a2c222ebd3a723c3dff00232e4f5cc6429f472d1/modeling_internlm2.py", line 958, in forward
    layer_outputs = decoder_layer(
  File "/opt/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/lite/quantization/calibration.py", line 195, in _forward
    out = self._ori_forwards[mod](*batch_args[i],
  File "/root/.cache/huggingface/modules/transformers_modules/a2c222ebd3a723c3dff00232e4f5cc6429f472d1/modeling_internlm2.py", line 659, in forward
    hidden_states, self_attn_weights, present_key_value = self.attention(
  File "/opt/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/a2c222ebd3a723c3dff00232e4f5cc6429f472d1/modeling_internlm2.py", line 361, in forward
    qkv_states = self.wqkv(hidden_states, im_mask)
  File "/opt/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/a2c222ebd3a723c3dff00232e4f5cc6429f472d1/build_mlp.py", line 204, in forward
    res[:1] += self.Plora_B(self.Plora_A(
  File "/opt/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1557, in _call_impl
    args_result = hook(self, args)
  File "/opt/lmdeploy/lmdeploy/lite/quantization/calibration.py", line 125, in _input_hook
    obs.observe(inp[0])
  File "/opt/py38/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/lmdeploy/lmdeploy/lite/quantization/activation/observer.py", line 104, in observe
    assert len(x.shape) == 3
AssertionError

@AllentDan
Collaborator Author

It seems we have not tested this model yet. I will add support for it in a follow-up. @sshuair

@AllentDan
Collaborator Author

@sshuair I added a new PR #1666 to support it. You may give it a try.
