add alibi position embedding and support baichuan #54

Open
qyccc wants to merge 48 commits into main
Conversation

@qyccc commented Dec 16, 2023

This adds the ALiBi positional-embedding method and a flash-attention version of it (implemented with Triton). It also supports Baichuan model training by porting over the implementation from baichuan-inc/Baichuan2-13B-Base.
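For context, ALiBi replaces learned or rotary position embeddings with a static, head-specific linear bias added to the attention scores. Below is a minimal sketch of how the slopes and bias tensor are typically computed, following the ALiBi paper and the Baichuan2 reference code; the helper names are illustrative and not necessarily the ones used in this PR:

```python
import math
import torch

def get_alibi_slopes(num_heads):
    """Head-specific slopes: a geometric sequence starting at 2^(-8/n)
    when the head count n is a power of two (per the ALiBi paper)."""
    def power_of_2_slopes(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * (start ** i) for i in range(n)]

    if math.log2(num_heads).is_integer():
        return power_of_2_slopes(num_heads)
    # Non-power-of-two head counts: interleave slopes from the
    # surrounding powers of two, as in the reference implementation.
    closest = 2 ** math.floor(math.log2(num_heads))
    return (power_of_2_slopes(closest)
            + power_of_2_slopes(2 * closest)[0::2][: num_heads - closest])

def build_alibi_bias(num_heads, seq_len):
    """bias[h, i, j] = slope_h * (j - i); added to the attention scores
    before the softmax (future positions are removed by the causal mask)."""
    slopes = torch.tensor(get_alibi_slopes(num_heads))
    pos = torch.arange(seq_len)
    rel = pos[None, :] - pos[:, None]               # (seq, seq), value j - i
    return slopes[:, None, None] * rel[None, :, :]  # (heads, seq, seq)
```

Because the bias depends only on the head index and the relative distance j - i, a Triton flash-attention kernel can recompute it tile by tile instead of materializing the full (heads, seq, seq) tensor, which is what makes the flash-attention variant of ALiBi practical at long sequence lengths.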

@CLAassistant commented Dec 16, 2023

CLA assistant check
All committers have signed the CLA.

@li-yi-dong (Collaborator)
Cool! It may take some time to review 🙃

@li-yi-dong self-assigned this Dec 19, 2023
@li-yi-dong (Collaborator) left a review comment

I'll review again once you resolve the comments.

examples/Baichuan_13_standalone.sh (outdated, resolved)
megatron/arguments.py (outdated, resolved)
megatron/fused_kernels/__init__.py (outdated, resolved)
megatron/model/transformer.py (outdated, resolved)
megatron/model/transformer.py (outdated, resolved)
@@ -471,10 +542,24 @@ def __init__(self, init_method,
        self.core_attention = CoreAttention(self.layer_number,
                                            self.attn_mask_type)
        self.checkpoint_core_attention = args.recompute_granularity == 'selective'

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        world_size = mpu.get_tensor_model_parallel_world_size()
Collaborator:

tensor_parallel_size

@qyccc (Author), quoting "tensor_parallel_size":

Sorry, I didn't get it. Do you mean the variable name should be tensor_parallel_size?
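If the suggestion is simply to rename the local variable for clarity, the change would look something like the line below (a guess at the reviewer's intent, not a confirmed diff):

```python
# Hypothetical rename reflecting the review comment: the value is the
# tensor-model-parallel group size, not the global world size.
tensor_parallel_size = mpu.get_tensor_model_parallel_world_size()
```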

megatron/model/transformer.py (outdated, resolved)
megatron/model/transformer.py (outdated, resolved)
megatron/model/transformer.py (outdated, resolved)
megatron/tokenizer/tokenizer.py (outdated, resolved)
@qyccc (Author) commented Dec 20, 2023

@li-yi-dong Thanks for your time and careful review. I have made the necessary changes and addressed the comments you mentioned. Please take another look at the updated version at your convenience.

megatron/training.py (outdated, resolved)
megatron/training.py (outdated, resolved)
@li-yi-dong (Collaborator) left a review comment

Big thanks for your efforts and patience. I added some comments to resolve.

megatron/model/transformer.py (outdated, resolved)
@@ -1222,11 +1286,106 @@ def set_input_tensor(self, input_tensor):
        forward_step_func"""
        self.input_tensor = input_tensor

    def _build_alibi_tensor(self, tensor, max_seq_len, num_attention_heads):
Collaborator:

Place this func together with alibi_mask_func.

@qyccc (Author) replied Jan 2, 2024:

This func requires the internal variable first_run, so it cannot be placed in the utils.
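For readers following the thread, the caching pattern being described looks roughly like the sketch below (modeled on the Baichuan2 reference code; the attribute names and the build_alibi_bias helper from the earlier sketch are illustrative, not the PR's exact code). Because the method mutates per-instance state, it naturally lives on the model class rather than in a stateless utils module:

```python
import torch

class AlibiMixin:
    """Sketch of on-module caching of the ALiBi mask, assuming the
    first_run / max_cache_pos attributes mentioned in the thread."""

    def __init__(self):
        self.first_run = True
        self.max_cache_pos = 0
        self.alibi_mask = None

    def _build_alibi_tensor(self, tensor, max_seq_len, num_attention_heads):
        # Rebuild only on the first call or when a longer sequence arrives;
        # otherwise slice the cached mask to the requested length.
        if self.first_run or max_seq_len > self.max_cache_pos:
            self.first_run = False
            self.max_cache_pos = max_seq_len
            # build_alibi_bias is the illustrative helper sketched earlier.
            self.alibi_mask = build_alibi_bias(
                num_attention_heads, max_seq_len
            ).to(device=tensor.device, dtype=tensor.dtype)
        return self.alibi_mask[:, :max_seq_len, :max_seq_len]
```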

@qyccc requested a review from li-yi-dong on January 3, 2024, 03:25