
Add RingFlashAttention for context parallel #8383

Open · wants to merge 6 commits into develop from ring_flash_attention

Conversation

@zhangyuqin1998 (Contributor) commented May 7, 2024

PR types

New features

PR changes

Models

Description

Add ring flash attention support for fleet's context parallel.

Paddle compatibility:
Reuses the existing sep group in Paddle, so no changes to Paddle itself are required.

Convergence:
cp is compared against sep. In theory the two should produce exactly the same convergence results; in testing, the convergence of sep and cp is nearly identical. Green is cp, blue is sep.
[figure: cp vs. sep loss curves]

Performance:
Tested with a small model on a single machine with 8 GPUs; at a sequence length of 20k, the performance comparison is shown in the figure below. Green is cp, blue is sep.
[figure: cp vs. sep performance comparison]
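For readers wondering why cp can match sep exactly: ring flash attention processes the KV sequence block by block and merges the partial results with a log-sum-exp rescaling, which recovers full softmax attention exactly (up to floating-point error). Below is a minimal single-process sketch of that merging rule in plain Paddle; it is illustrative only and not the code added in this PR (function names and shapes are made up).

import paddle
import paddle.nn.functional as F

def update_out_and_lse(out, lse, block_out, block_lse):
    # new_lse = logaddexp(lse, block_lse); rescale both contributions to the new scale
    new_lse = paddle.maximum(lse, block_lse) + paddle.log1p(paddle.exp(-paddle.abs(lse - block_lse)))
    out = out * paddle.exp(lse - new_lse) + block_out * paddle.exp(block_lse - new_lse)
    return out, new_lse

def blockwise_attention(q, k_blocks, v_blocks):
    # q: [s_q, d]; k_blocks / v_blocks: per-step KV chunks, as if received around the ring
    scale = q.shape[-1] ** -0.5
    out = paddle.zeros_like(q)
    lse = paddle.full([q.shape[0], 1], -1e9, dtype=q.dtype)
    for k, v in zip(k_blocks, v_blocks):
        scores = paddle.matmul(q, k, transpose_y=True) * scale       # [s_q, s_kv]
        block_lse = paddle.logsumexp(scores, axis=-1, keepdim=True)  # [s_q, 1]
        block_out = paddle.matmul(F.softmax(scores, axis=-1), v)     # [s_q, d]
        out, lse = update_out_and_lse(out, lse, block_out, block_lse)
    return out

q, k, v = paddle.randn([128, 64]), paddle.randn([512, 64]), paddle.randn([512, 64])
full = paddle.matmul(F.softmax(paddle.matmul(q, k, transpose_y=True) * 64**-0.5, axis=-1), v)
ring = blockwise_attention(q, paddle.split(k, 4, axis=0), paddle.split(v, 4, axis=0))
assert paddle.allclose(full, ring, atol=1e-4)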

paddle-bot bot commented May 7, 2024

Thanks for your contribution!

CLAassistant commented May 7, 2024

CLA assistant check
All committers have signed the CLA.

codecov bot commented May 9, 2024

Codecov Report

Attention: Patch coverage is 15.98361%, with 205 lines in your changes missing coverage. Please review.

Project coverage is 53.87%. Comparing base (773497e) to head (e7c4b1e).
Report is 8 commits behind head on develop.

Files Patch % Lines
paddlenlp/transformers/ring_flash_attention.py 15.47% 142 Missing ⚠️
paddlenlp/transformers/context_parallel_utils.py 10.00% 27 Missing ⚠️
paddlenlp/transformers/llama/fusion_ops.py 14.28% 12 Missing ⚠️
paddlenlp/trainer/training_args.py 8.33% 11 Missing ⚠️
paddlenlp/transformers/llama/modeling.py 35.71% 9 Missing ⚠️
paddlenlp/trainer/trainer.py 20.00% 4 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #8383      +/-   ##
===========================================
- Coverage    54.22%   53.87%   -0.35%     
===========================================
  Files          617      620       +3     
  Lines        96203    97065     +862     
===========================================
+ Hits         52164    52297     +133     
- Misses       44039    44768     +729     



# if step != cp_size - 1:
#     comm_buffer.wait()
paddle.device.synchronize()
Member commented:

TODO: under the batch_isend_irecv async stream, wait cannot be called; this needs to be fixed, and it hurts performance.

Contributor Author (zhangyuqin1998) replied:

> TODO: under the batch_isend_irecv async stream, wait cannot be called; this needs to be fixed, and it hurts performance.

done~
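For context, a rough sketch of what one ring send/recv step with batch_isend_irecv plus explicit task waits might look like. This assumes paddle.distributed.P2POp and batch_isend_irecv follow their usual semantics and is not the PR's actual RingCommunicator.

# Run under: python -m paddle.distributed.launch --gpus "0,1,2,3" ring_step_sketch.py
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
rank, world_size = dist.get_rank(), dist.get_world_size()
send_to, recv_from = (rank + 1) % world_size, (rank - 1) % world_size

kv_send = paddle.full([4, 8], float(rank))  # this rank's current KV block
kv_recv = paddle.empty_like(kv_send)        # buffer for the neighbour's block

ops = [
    dist.P2POp(dist.isend, kv_send, send_to),
    dist.P2POp(dist.irecv, kv_recv, recv_from),
]
tasks = dist.batch_isend_irecv(ops)

# The local flash-attention compute for this step would overlap here, before waiting.

for task in tasks:
    task.wait()  # per-request wait; paddle.device.synchronize() is the blunt fallback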

@zhangyuqin1998 force-pushed the ring_flash_attention branch 2 times, most recently from cf7d334 to 88bc460 on May 27, 2024 07:58
block_out, _, block_lse, _ = _C_ops.flash_attn(
    local_query,
    block_k[:, : local_q_seq_len // 2, :, :],
    block_v[:, : local_q_seq_len // 2, :, :],
Member commented:

Calling it this way may be slow. See whether it can be invoked directly as an op.

Contributor Author (zhangyuqin1998) replied:

> Calling it this way may be slow. See whether it can be invoked directly as an op.

done~

if attn_mask is not None:
    attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3)
if is_causal:
    local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :].clone().contiguous()
Member commented:

contiguous? This is probably not needed. Prefer the split APIs rather than operator-overload slicing.

Contributor Author (zhangyuqin1998) replied:

> contiguous? This is probably not needed. Prefer the split APIs rather than operator-overload slicing.

done~
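For illustration, the kind of replacement the reviewer is asking for might look like the snippet below (shapes are made up; this is a sketch, not the merged code).

import paddle

# [batch, local_seq, num_heads, head_dim] -- illustrative shape only
local_query = paddle.randn([1, 8, 2, 16])

# Instead of local_query[:, local_q_seq_len // 2 :, :, :].clone().contiguous(),
# take the two halves of the sequence dimension with the split API:
local_query_first_chunk, local_query_second_chunk = paddle.split(
    local_query, num_or_sections=2, axis=1
)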

paddlenlp/transformers/ring_flash_attention.py (outdated review comment, resolved)
grad_comm_buffer = RingCommunicator(group, key_grad_buffer, value_grad_buffer)

if is_causal:
    local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :].clone().contiguous()
Member commented:

This has already been computed in the forward pass; can it be optimized so it is not recomputed here?

Contributor Author (zhangyuqin1998) replied:

> This has already been computed in the forward pass; can it be optimized so it is not recomputed here?

Optimized.
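One way to avoid the recomputation is to stash the chunk produced in the forward pass and read it back in backward. The sketch below assumes paddle.autograd.PyLayer's save_for_backward / saved_tensor API and a toy forward function; the PR's actual optimization may be different.

import paddle
from paddle.autograd import PyLayer

class SecondChunkDemo(PyLayer):
    @staticmethod
    def forward(ctx, local_query):
        first_chunk, second_chunk = paddle.split(local_query, num_or_sections=2, axis=1)
        ctx.save_for_backward(second_chunk)   # cache: no re-slicing needed in backward
        return second_chunk.sum()             # stand-in for the real attention math

    @staticmethod
    def backward(ctx, grad_out):
        (second_chunk,) = ctx.saved_tensor()  # reuse the cached forward result
        grad_second = paddle.ones_like(second_chunk) * grad_out
        grad_first = paddle.zeros_like(second_chunk)
        return paddle.concat([grad_first, grad_second], axis=1)

x = paddle.randn([1, 8, 2, 16])
x.stop_gradient = False
SecondChunkDemo.apply(x).backward()
print(x.grad.sum())  # gradient flows only to the second half of the sequence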

def wait(self):
    # for req in self._reqs:
    #     req.wait()
    # self._reqs = None
Member commented:

Change this to a TODO instead of leaving commented-out code.

Contributor Author (zhangyuqin1998) replied:

> Change this to a TODO instead of leaving commented-out code.

done~
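The shape of the suggested change would be roughly the following (a sketch; the merged method body may differ).

import paddle

class RingCommunicatorSketch:
    def wait(self):
        # TODO: replace this device-wide sync with per-request waits once waiting
        # on the batch_isend_irecv async comm stream is fixed.
        paddle.device.synchronize()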

@ForFishes (Member) left a comment:

LGTM

@sneaxiy previously approved these changes on May 31, 2024
@@ -583,6 +587,15 @@ class TrainingArguments:
)
},
)
cp_parallel_degree: int = field(
Contributor commented:

Rename it to context_parallel_degree.

Suggested change:
-    cp_parallel_degree: int = field(
+    context_parallel_degree: int = field(
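Applied to the TrainingArguments dataclass, the renamed field would look roughly like this; the help text below is illustrative, not the PR's exact wording.

from dataclasses import dataclass, field

@dataclass
class TrainingArgumentsSketch:
    context_parallel_degree: int = field(
        default=-1,
        metadata={
            "help": (
                "The degree of the context parallel strategy: the input sequence is split "
                "across this many ranks and attention is computed with ring flash attention. "
                "-1 means disabled."
            )
        },
    )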

@@ -763,6 +764,8 @@ def train(
trainable_numel = int(trainable_numel_tensor.item()) // self.args.dataset_world_size
if self.args.sep_parallel_degree > 0:
    trainable_numel = trainable_numel // self.args.sep_parallel_degree
if self.args.cp_parallel_degree > 0:
    trainable_numel = trainable_numel // self.args.cp_parallel_degree
Contributor commented:

Which parameters does cp_parallel_degree split?

@@ -230,6 +230,10 @@ class TrainingArguments:
The paddle sequence parallel strategy. It can reduce the GPU memory of activation to 1/sep, and it is orthogonal to
data parallel, sharding stage1, tensor parallel and pipeline parallel strategy.
)
cp_parallel_degree (`int`, *optional*, defaults to `-1`)(
Contributor commented:

Please also add this argument to the docs/trainer.md documentation.
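A documentation entry modeled on the sep_parallel_degree text above could read roughly as follows (illustrative wording, not the PR's actual docs/trainer.md text).

context_parallel_degree (`int`, *optional*, defaults to `-1`)(
    The paddle context parallel strategy. It splits the input sequence across cp ranks and computes attention with
    ring flash attention; it is orthogonal to data parallel, sharding stage1, tensor parallel and pipeline parallel
    strategy.
)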

self.tensor_parallel_degree
* self.sep_parallel_degree
* self.cp_parallel_degree
* self.pipeline_parallel_degree
Contributor commented:

Has checkpoint saving been taken into account? Does an additional communication group need to be created?
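For reference, the accounting implied by the degree product above is roughly the following. The values are just an example, and deriving the data-parallel degree this way is an assumption about how training_args.py typically computes it, not a statement about this PR's code.

world_size = 8  # example: one machine, 8 GPUs

tensor_parallel_degree = 2
sep_parallel_degree = 1
cp_parallel_degree = 2
pipeline_parallel_degree = 2

non_data_parallel_ranks = (
    tensor_parallel_degree
    * sep_parallel_degree
    * cp_parallel_degree
    * pipeline_parallel_degree
)
assert world_size % non_data_parallel_ranks == 0
data_parallel_degree = world_size // non_data_parallel_ranks  # -> 1 in this example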
