[Bug] fix internlm2 flash attn (#693)
* fix internlm2 flash attn

* fix SUPPORT_FLASH2
HIT-cwh committed May 17, 2024
1 parent ed844be commit b1099fe
Showing 3 changed files with 40 additions and 41 deletions.
11 changes: 2 additions & 9 deletions xtuner/model/modules/dispatch/__init__.py
@@ -8,6 +8,7 @@
from mmengine import print_log
from mmengine.utils import digit_version
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils.import_utils import is_flash_attn_2_available

from .baichuan import (baichuan2_norm_head_forward, baichuan_7b_attn_forward,
baichuan_13b_attn_forward)
@@ -33,15 +34,7 @@
# Transformers requires torch version >= 2.1.1 when using Torch SDPA.
# Refer to https://github.com/huggingface/transformers/blob/caa5c65db1f4db617cdac2ad667ba62edf94dd98/src/transformers/modeling_utils.py#L1611 # noqa: E501
SUPPORT_FLASH1 = digit_version(torch.__version__) >= digit_version('2.1.1')
SUPPORT_FLASH2 = False

try:
from flash_attn import flash_attn_func # pre-check # noqa: F401

SUPPORT_FLASH2 = True
except ImportError:
pass

SUPPORT_FLASH2 = is_flash_attn_2_available()
SUPPORT_FLASH = SUPPORT_FLASH1 or SUPPORT_FLASH2

USE_TRITON_KERNEL = bool(os.getenv('USE_TRITON_KERNEL', default=0))
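The hunk above swaps xtuner's hand-rolled import probe for the detection helper that ships with Transformers, which (among other checks) also considers the installed flash-attn version and whether a usable GPU is present, rather than only whether the package imports. A minimal side-by-side sketch; the SUPPORT_FLASH2_OLD name is invented here purely for contrast:

from transformers.utils.import_utils import is_flash_attn_2_available

# New check: defer to Transformers' own capability probe.
SUPPORT_FLASH2 = is_flash_attn_2_available()

# Old check (the block removed above): a bare import probe that only proves
# the package can be imported.
try:
    from flash_attn import flash_attn_func  # noqa: F401
    SUPPORT_FLASH2_OLD = True
except ImportError:
    SUPPORT_FLASH2_OLD = False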
48 changes: 23 additions & 25 deletions xtuner/model/sft.py
@@ -150,29 +150,27 @@ def init_weights(self):
@staticmethod
def _prepare_for_long_context_training(cfg, llm_cfg,
max_position_embeddings):
if not hasattr(llm_cfg, 'rope_scaling'):
print_log('Current model does not support RoPE scaling.',
'current')
return

current_max_length = getattr(llm_cfg, 'max_position_embeddings', None)
if current_max_length and max_position_embeddings > current_max_length:
print_log(
f'Enlarge max model length from {current_max_length} '
f'to {max_position_embeddings}.', 'current')
scaling_factor = float(
math.ceil(max_position_embeddings / current_max_length))
else:
print_log(
'The input `max_position_embeddings` is smaller than '
'origin max length. Consider increase input length.',
'current')
scaling_factor = 1.0
cfg.rope_scaling = {'type': 'linear', 'factor': scaling_factor}

orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
if orig_rope_scaling is None:
orig_rope_scaling = {'factor': 1}

orig_rope_scaling_factor = orig_rope_scaling[
'factor'] if 'factor' in orig_rope_scaling.keys() else 1
orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
if orig_ctx_len:
orig_ctx_len *= orig_rope_scaling_factor
if max_position_embeddings > orig_ctx_len:
scaling_factor = float(
math.ceil(max_position_embeddings / orig_ctx_len))
llm_cfg.rope_scaling = {
'type': 'linear',
'factor': scaling_factor
}

# hardcode for internlm2
llm_cfg.attn_implementation = 'flash_attention_2'
cfg.config = llm_cfg

return cfg, llm_cfg
return cfg

@staticmethod
def _prepare_for_flash_attn(cfg, llm_cfg):
@@ -201,7 +199,7 @@ def _prepare_for_flash_attn(cfg, llm_cfg):
elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
cfg.attn_implementation = 'sdpa'

return cfg, llm_cfg
return cfg

@staticmethod
def _prepare_for_qlora_zero3(cfg):
@@ -225,9 +223,9 @@ def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
llm_cfg = AutoConfig.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=True)
cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
if max_position_embeddings is not None:
cfg, llm_cfg = self._prepare_for_long_context_training(
cfg = self._prepare_for_long_context_training(
cfg, llm_cfg, max_position_embeddings)
return cfg

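Both variants of _prepare_for_long_context_training in the first hunk of this file derive a linear RoPE scaling factor by comparing the requested max_position_embeddings against the model's configured context length, and the same hunk touches the internlm2-specific attn_implementation hardcode named in the commit title. The variant that accounts for a pre-existing rope_scaling factor behaves roughly as sketched below; this is a standalone illustration, and linear_rope_factor is an invented name, not xtuner API:

import math

def linear_rope_factor(target_len, orig_ctx_len, orig_factor=1.0):
    # Effective context the checkpoint already covers: original window
    # times any linear scaling factor already present in the config.
    effective_ctx = orig_ctx_len * orig_factor
    if target_len <= effective_ctx:
        return None  # no extra scaling needed
    # Round up so the enlarged window always covers target_len.
    return float(math.ceil(target_len / effective_ctx))

# Stretching a 4k-context model with no prior scaling to 32k:
assert linear_rope_factor(32768, 4096) == 8.0
# A model already linearly scaled by 2x only needs another 4x:
assert linear_rope_factor(32768, 4096, orig_factor=2.0) == 4.0

The resulting factor is what lands in the {'type': 'linear', 'factor': scaling_factor} dict on the model config.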
22 changes: 15 additions & 7 deletions xtuner/model/utils.py
@@ -3,7 +3,6 @@
from typing import List, Optional

import torch
from mmengine import print_log
from mmengine.utils.misc import get_object_from_string
from peft import PeftType
from torch import nn
@@ -18,19 +17,28 @@ def set_obj_dtype(d):
d[key] = getattr(torch, value.split('.')[-1])


def try_build_module(cfg):
builder = cfg['type']
if isinstance(builder, str):
builder = get_object_from_string(builder)
if builder is None:
# support handling cfg with key 'type' can not be built, such as
# {'rope_scaling': {'type': 'linear', 'factor': 2.0}}
return cfg
cfg.pop('type')
module_built = builder(**cfg)
return module_built


def traverse_dict(d):
if isinstance(d, dict):
set_obj_dtype(d)
for key, value in d.items():
if isinstance(value, dict):
traverse_dict(value)
if 'type' in value:
builder = value.pop('type')
if isinstance(builder, str):
builder = get_object_from_string(builder)
new_value = builder(**value)
d[key] = new_value
print_log(f'{key} convert to {builder}')
module_built = try_build_module(value)
d[key] = module_built
elif isinstance(d, list):
for element in d:
traverse_dict(element)
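The new try_build_module helper lets traverse_dict keep dictionaries whose 'type' key is plain data rather than an importable class path, such as the {'type': 'linear', 'factor': 2.0} rope_scaling spec mentioned in its comment; previously the unresolvable string left builder as None and the builder(**value) call failed. A rough usage sketch, assuming this revision of xtuner is installed and that get_object_from_string resolves dotted paths like 'torch.nn.Linear':

from xtuner.model.utils import try_build_module

# A 'type' that resolves to an importable object is built into an instance.
linear_cfg = {'type': 'torch.nn.Linear', 'in_features': 4, 'out_features': 2}
layer = try_build_module(linear_cfg)  # an initialised nn.Linear module

# A plain-data dict is returned untouched instead of raising.
rope_cfg = {'type': 'linear', 'factor': 2.0}
assert try_build_module(rope_cfg) == {'type': 'linear', 'factor': 2.0}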
