feat(op): support varlen npu flash attention #209
base: develop
Conversation
Force-pushed from a257888 to 14f0106.
Force-pushed from 82758c5 to dbc6869.
@@ -472,11 +604,13 @@ def _qkv_without_cu_seqlens(self, qkv, softmax_scale=None, causal=None, key_padd
         return _torch_fixedlen_qkvpacked_attn(qkv, self.dropout, softmax_scale, causal, key_padding_mask)

     @forward.register(conditions=(str(QKVPackType.KVPACKED), str(CuSeqlenType.WithOut)))
-    def _q_kv_without_cu_seqlens(self, q, kv, softmax_scale=None, causal=None, key_padding_mask=None):
+    def _q_kv_without_cu_seqlens(
+        self, q, kv, softmax_scale=None, causal=None, key_padding_mask=None, use_flash_attn=True
There is no need to make use_flash_attn a parameter of the attention forward. That would force the model to know whether the underlying operator is flash attention, and we don't want to push that complexity onto model developers. This should be handled automatically by the operator-selection system or as part of the configuration.
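A minimal sketch of the reviewer's suggestion, where the backend is chosen from configuration instead of a forward argument. The use_flash_attn config flag, the _select_kvpacked_attn helper, the fixed-length attention function names, and the gpc import path are all assumptions for illustration, not this repository's actual API:

# Illustrative sketch: pick the attention operator from config so the model's
# forward signature never needs a use_flash_attn argument.
from internlm.core.context import global_context as gpc  # assumed import path


def _select_kvpacked_attn():
    # Hypothetical selection logic; the real operator-selection system may differ.
    if getattr(gpc.config.model, "use_flash_attn", False):
        return _flash_fixedlen_kvpacked_attn
    return _torch_fixedlen_kvpacked_attn


@forward.register(conditions=(str(QKVPackType.KVPACKED), str(CuSeqlenType.WithOut)))
def _q_kv_without_cu_seqlens(self, q, kv, softmax_scale=None, causal=None, key_padding_mask=None):
    # The signature stays as before; the backend choice happens internally.
    attn_impl = _select_kvpacked_attn()
    return attn_impl(q, kv, self.dropout, softmax_scale, causal, key_padding_mask)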
@@ -341,7 +341,7 @@ def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tup
         """
         _emb_dim = 2  # [bsz, seqlen, emb_dim]

-        return gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim)
+        return gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim), DUMMY_HANDLE_CONST
Why is DUMMY_HANDLE_CONST added here? This function is used as a register_forward_hook. Same for the ones below.
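For reference on the hook contract behind this comment, a minimal runnable sketch in plain PyTorch (not this repository's code): a non-None return value from a hook registered via register_forward_hook replaces the module's output, so returning a (output, DUMMY_HANDLE_CONST) tuple would change what every caller of the module receives.

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)


def output_hook(module, args, output):
    # Whatever a forward hook returns (if not None) replaces the module output,
    # so returning (output, handle) here would leak the tuple to all callers.
    return output * 2  # illustrative transformation only


hook_handle = layer.register_forward_hook(output_hook)
y = layer(torch.randn(2, 4))  # y is the value returned by output_hook
hook_handle.remove()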
@@ -78,18 +77,15 @@ def _nyi_attn(func_name, *args, **kwargs):  # pylint: disable=W0613


 def _flash_float32_compatibility_wrapper(input_idxs: Tuple, flash_func: Callable, *args, **kwargs):
-    if gpc.config.model.dtype is torch.float32:
Why was the if removed here?
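For context, a minimal sketch of what the float32 compatibility guard under discussion typically does, assuming the flash kernel only accepts fp16/bf16 inputs; gpc is the global context object already referenced in the diff above, and the bf16 cast target is an assumption rather than the PR's exact behavior:

from typing import Callable, Tuple

import torch

# gpc: the repository's global parallel context, assumed to be imported
# the same way as elsewhere in this file.


def _flash_float32_compatibility_wrapper(input_idxs: Tuple, flash_func: Callable, *args, **kwargs):
    # Sketch only: cast the selected tensor args to bf16 before calling the flash
    # kernel and cast the result back, but only when the model runs in float32.
    if gpc.config.model.dtype is torch.float32:
        cast_args = [
            arg.to(torch.bfloat16) if i in input_idxs else arg for i, arg in enumerate(args)
        ]
        return flash_func(*cast_args, **kwargs).to(torch.float32)
    return flash_func(*args, **kwargs)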
-def _npu_varlen_qkvpacked_attn(
-    qkv: torch.Tensor, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None, causal=False  # pylint: disable=W0613
+def __npu_varlen_qkvsplited_attn(
__npu -> _npu: use a single underscore.
Motivation
Support varlen flash attention for torch_npu.
Modification
Removed the padding and unpadding operations applied to q and k when the data is packed (see the sketch below).
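A rough sketch of the idea behind this change, with illustrative shapes and a hypothetical npu_varlen_attn kernel name: a varlen operator consumes the packed q/k/v tensors together with cu_seqlens directly, so the pad/unpad round trip is no longer needed.

import torch

# Illustrative per-sample sequence lengths for three packed samples.
seqlens = torch.tensor([5, 3, 7], dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
max_seqlen = int(seqlens.max())

total_tokens, num_heads, head_dim = int(seqlens.sum()), 8, 64
q = torch.randn(total_tokens, num_heads, head_dim)
k = torch.randn(total_tokens, num_heads, head_dim)
v = torch.randn(total_tokens, num_heads, head_dim)

# With a varlen kernel (hypothetical name npu_varlen_attn), the packed tensors
# are passed straight through together with cu_seqlens:
# out = npu_varlen_attn(q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
#                       max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen)
# Previously the packed data had to be padded to [bsz, max_seqlen, ...] before a
# fixed-length kernel and unpadded afterwards; that round trip is what this PR removes.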
BC-breaking (Optional)
Does the modification introduce changes that break the backward compatibility of the downstream repositories?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.
Use cases (Optional)
If this PR introduces a new feature, it is better to list some use cases here and update the documentation.
Checklist
Before PR:
After PR: