
feat(op): support varlen npu flash attention #209

Open
SolenoidWGT wants to merge 12 commits into develop
Conversation

@SolenoidWGT (Collaborator) commented Apr 19, 2024

Motivation

Support torch_npu's varlen flash attention.

Modification

Removed the padding and unpadding of q and k that was performed when the data is packed (a sketch of the packed layout follows).
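
For context, here is a minimal sketch of the packed ("varlen") layout this change relies on; the names below (lengths, cu_seqlens) are illustrative and not taken from the PR. Variable-length sequences are concatenated along the token dimension and described by cumulative sequence lengths, so q and k no longer need to be padded to a fixed length and unpadded again around the attention call.

    import torch

    # Hypothetical example: three sequences of different lengths packed into one tensor.
    lengths = torch.tensor([3, 5, 2])                                 # tokens per sequence
    cu_seqlens = torch.cat(
        [torch.zeros(1, dtype=torch.int32), lengths.cumsum(0).to(torch.int32)]
    )                                                                 # [0, 3, 8, 10]
    total_tokens, num_heads, head_dim = int(lengths.sum()), 8, 64

    # q/k/v are stored as (total_tokens, num_heads, head_dim) with no padding;
    # cu_seqlens tells a varlen attention kernel where each sequence starts and ends.
    q = torch.randn(total_tokens, num_heads, head_dim)
    k = torch.randn(total_tokens, num_heads, head_dim)
    v = torch.randn(total_tokens, num_heads, head_dim)

    # Tokens of sequence i are q[cu_seqlens[i] : cu_seqlens[i + 1]].
    max_seqlen = int(lengths.max())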

BC-breaking (Optional)

Does the modification introduce changes that break the backward compatibility of the downstream repositories?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.

Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here and update the documentation.

Checklist

Before PR:

  • Pre-commit or other linting tools are used to fix the potential lint issues.
  • Bug fixes are fully covered by unit tests; the case that caused the bug should be added to the unit tests.
  • The modification is covered by complete unit tests. If not, please add more unit tests to ensure correctness.
  • The documentation has been modified accordingly, like docstring or example tutorials.

After PR:

  • If the modification has potential influence on downstream or other related projects, this PR should be tested with those projects.
  • CLA has been signed and all committers have signed the CLA in this PR.

@gaoyang07 requested review from sallyjunjun and removed the review request for gaoyang07 on April 25, 2024 03:31
@gaoyang07 assigned SolenoidWGT and unassigned ZwwWayne on Apr 25, 2024
@gaoyang07 added the enhancement (New feature or request) label on Apr 25, 2024
@@ -472,11 +604,13 @@ def _qkv_without_cu_seqlens(self, qkv, softmax_scale=None, causal=None, key_padd
        return _torch_fixedlen_qkvpacked_attn(qkv, self.dropout, softmax_scale, causal, key_padding_mask)

    @forward.register(conditions=(str(QKVPackType.KVPACKED), str(CuSeqlenType.WithOut)))
-   def _q_kv_without_cu_seqlens(self, q, kv, softmax_scale=None, causal=None, key_padding_mask=None):
+   def _q_kv_without_cu_seqlens(
+       self, q, kv, softmax_scale=None, causal=None, key_padding_mask=None, use_flash_attn=True
Contributor:

Please don't make use_flash_attn a parameter of the attention forward. That would mean the model has to be aware of whether the underlying operator is flash attention, and we don't want to introduce that complexity for model developers. This should just be part of the operator-selection system's automatic logic, or part of the configuration.
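
As a rough illustration of what this comment suggests, here is a minimal sketch of config-driven operator selection; every name in it (select_attn_op, _torch_attn, _npu_flash_attn, the config keys) is hypothetical and not taken from this repository. The point is that the backend is chosen once, outside the model's forward path, so forward signatures never carry use_flash_attn.

    # Hypothetical sketch only -- none of these names come from the PR.
    from typing import Callable

    def _torch_attn(q, k, v, **kwargs):
        ...  # plain PyTorch reference implementation

    def _npu_flash_attn(q, k, v, **kwargs):
        ...  # torch_npu-backed flash attention implementation

    def select_attn_op(config: dict) -> Callable:
        # Decided once from configuration / device probing, not per forward call,
        # so model code simply calls `attention(...)` and never sees the backend.
        if config.get("device") == "npu" and config.get("use_flash_attn", True):
            return _npu_flash_attn
        return _torch_attn

    attention = select_attn_op({"device": "npu", "use_flash_attn": True})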

@@ -341,7 +341,7 @@ def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tup
        """
        _emb_dim = 2  # [bsz, seqlen, emb_dim]

-       return gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)
+       return gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim), DUMMY_HANDLE_CONST
Contributor:

Why is DUMMY_HANDLE_CONST added here? This function is used as a register_forward_hook. Same for the similar changes below.
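
For reference, a small self-contained example (not taken from the PR) of how PyTorch treats a forward hook's return value: if the hook returns anything other than None, that value replaces the module's output, which is why returning an extra constant alongside the tensor would change what every caller of the module receives.

    import torch
    import torch.nn as nn

    def output_hook(module, args, output):
        # Whatever a forward hook returns (if not None) becomes the module's output.
        return output * 2

    layer = nn.Linear(4, 4)
    handle = layer.register_forward_hook(output_hook)
    y = layer(torch.randn(1, 4))   # y is already the hook-modified output
    handle.remove()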

@@ -78,18 +77,15 @@ def _nyi_attn(func_name, *args, **kwargs):  # pylint: disable=W0613


def _flash_float32_compatibility_wrapper(input_idxs: Tuple, flash_func: Callable, *args, **kwargs):
-   if gpc.config.model.dtype is torch.float32:
Contributor:

Why is the if removed here?
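
For context, a hedged sketch of what a float32-compatibility wrapper of this shape typically does; this is an illustration, not the repository's actual implementation (in the diff the guard is `if gpc.config.model.dtype is torch.float32:`). The check ensures only fp32 runs pay the cast-down/cast-back cost, while other dtypes call the flash kernel directly, which is presumably why removing the `if` is being questioned.

    from typing import Callable, Tuple
    import torch

    MODEL_DTYPE = torch.float32  # stand-in for gpc.config.model.dtype in this sketch

    def _flash_float32_compatibility_wrapper(input_idxs: Tuple, flash_func: Callable, *args, **kwargs):
        # Flash kernels generally expect fp16/bf16, so fp32 inputs are cast down
        # before the call and the result is cast back up; other dtypes pass through.
        if MODEL_DTYPE is torch.float32:
            args = list(args)
            for i in input_idxs:
                if isinstance(args[i], torch.Tensor):
                    args[i] = args[i].to(torch.bfloat16)
            return flash_func(*args, **kwargs).to(torch.float32)
        return flash_func(*args, **kwargs)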



def _npu_varlen_qkvpacked_attn(
    qkv: torch.Tensor, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None, causal=False  # pylint: disable=W0613
def __npu_varlen_qkvsplited_attn(
Contributor:

__npu -> _npu: use a single underscore.

Labels: enhancement (New feature or request)
Projects: none yet
Linked issues: none yet
Participants: 4