[WIP] [Feature] Ensure Full Conversation Data #287

Open · wants to merge 21 commits into base: main
4 changes: 2 additions & 2 deletions xtuner/dataset/huggingface.py
@@ -26,7 +26,7 @@ def process(dataset,
             pack_to_max_length=True,
             input_ids_with_output=True,
             with_image_token=False,
-            map_num_proc=32):
+            map_num_proc=1):
     """Post-process the dataset loaded from the Hugging Face Hub, or a local
     dataset.

@@ -60,8 +60,8 @@ def process(dataset,
             of VLM.
         map_num_proc: Max number of processes when mapping the dataset.
     """
-
     if isinstance(dataset, DatasetDict):
+        print("Available splits in the dataset:", dataset.keys())
        dataset = dataset[split]
     elif isinstance(dataset, dict) or isinstance(
             dataset, Config) or isinstance(dataset, ConfigDict):
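Note: with a DatasetDict, process() now reports the available splits before indexing into one, which makes a wrong split name easier to diagnose. A minimal sketch of that behaviour with a toy dataset (the data is made up):

from datasets import Dataset, DatasetDict

ds = DatasetDict({
    'train': Dataset.from_list(
        [{'conversation': [{'input': 'hi', 'output': 'hello'}]}])
})
print('Available splits in the dataset:', ds.keys())  # dict_keys(['train'])
train_split = ds['train']  # what process(..., split='train') then selects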
49 changes: 49 additions & 0 deletions xtuner/dataset/packer_demo.py
@@ -0,0 +1,49 @@
from transformers import AutoTokenizer

from xtuner.dataset.utils import Packer, encode_fn

# Create a tokenizer instance with AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "/public/home/lvshuhang/model_space/workspace/THUDM_chatglm3-6b-base/",
    trust_remote_code=True)

# Toy input data
example = {
    'conversation': [
        {
            'input': '今天的天气怎么样?',
            'output': '今天的天气非常不错,是晴天。'
        },
        {
            'input': '你今天吃饭了吗?',
            'output': '我今天吃饭了,吃的蔬菜沙拉。'
        },
        {
            'input': '这是最后一句话,需要足够长以达到最大长度限制,以触发特殊处理逻辑。',
            'output': '确实,这句话确实足够长了。'
        },
        {
            'input': 'Give three tips for staying healthy.',
            'output': '1.Eat a balanced diet xxx'
        }
    ]
}

# Toy maximum length; adjust to the actual use case
max_length = 10

# Encode the data
encoded_example = encode_fn(
    example, tokenizer, max_length, input_ids_with_output=True,
    with_image_token=False)

# Print the encoded result
print("Encoded example:", encoded_example, len(encoded_example['input_ids']))

# Pack the encoded data; Packer consumes a batch dict of lists,
# not a list of examples
packer = Packer(chunk_size=max_length)
packed_data = packer({
    'input_ids': [encoded_example['input_ids']],
    'labels': [encoded_example['labels']]
})

# Print the packed result
print("Packed data:", packed_data)
227 changes: 159 additions & 68 deletions xtuner/dataset/utils.py
@@ -5,10 +5,32 @@
 from io import BytesIO
 from itertools import chain
 
+import numpy as np
 import requests
 from PIL import Image
 
-from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
+from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_PAD_TOKEN_INDEX
+
+
+def get_bos_eos_token_ids(tokenizer):
+    if tokenizer.__class__.__name__ in [
+            'QWenTokenizer', 'QWen2Tokenizer', 'Qwen2TokenizerFast'
+    ]:
+        bos_token_id = []
+        eos_token_id = tokenizer.eos_token_id
+        assert eos_token_id is not None, \
+            'Please set eos_token for Qwen tokenizer!'
+    elif tokenizer.__class__.__name__ == 'ChatGLMTokenizer':
+        bos_token_id = [64790, 64792]
+        eos_token_id = tokenizer.eos_token_id
+    else:
+        bos_token_id = tokenizer.bos_token_id
+        eos_token_id = tokenizer.eos_token_id
+    if isinstance(bos_token_id, int):
+        bos_token_id = [bos_token_id]
+    if isinstance(eos_token_id, int):
+        eos_token_id = [eos_token_id]
+    return bos_token_id, eos_token_id
 
 
 def encode_fn(example,
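get_bos_eos_token_ids centralises the per-tokenizer quirks: Qwen tokenizers have no BOS (only an EOS, which must be set), ChatGLM uses the fixed two-token prefix [64790, 64792], and everything else falls back to the tokenizer's own bos/eos ids, normalised to lists. A rough usage sketch (the checkpoint name is an assumption):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(
    'THUDM/chatglm3-6b-base', trust_remote_code=True)
bos_token_id, eos_token_id = get_bos_eos_token_ids(tok)
print(bos_token_id, eos_token_id)  # e.g. [64790, 64792] [2] for ChatGLM3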
@@ -46,47 +68,60 @@ def encode_fn(example,
         }
     ]
     """
-    if tokenizer.__class__.__name__ == 'QWenTokenizer':
-        bos_token_id = []
-        eos_token_id = tokenizer.eos_token_id
-    elif tokenizer.__class__.__name__ == 'ChatGLMTokenizer':
-        bos_token_id = [64790, 64792]
-        eos_token_id = tokenizer.eos_token_id
-    else:
-        bos_token_id = tokenizer.bos_token_id
-        eos_token_id = tokenizer.eos_token_id
-    if isinstance(bos_token_id, int):
-        bos_token_id = [bos_token_id]
-    if isinstance(eos_token_id, int):
-        eos_token_id = [eos_token_id]
+    bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
     is_multi_turn_conversation = len(example['conversation']) > 1
     if is_multi_turn_conversation:
         assert input_ids_with_output
 
     input_ids, labels = [], []
+    next_needs_bos_token = True
     for single_turn_conversation in example['conversation']:
         input = single_turn_conversation['input']
         if DEFAULT_IMAGE_TOKEN in input and with_image_token:
             chunk_encode = [
-                tokenizer(chunk, add_special_tokens=False)
-                for chunk in input.split('<image>')
+                tokenizer.encode(chunk, add_special_tokens=False)
+                for chunk in input.split(DEFAULT_IMAGE_TOKEN)
             ]
             assert len(chunk_encode) == 2
-            input_encode = {'input_ids': []}
+            input_encode = []
             for idx, cur_chunk_encode in enumerate(chunk_encode):
-                input_encode['input_ids'].extend(cur_chunk_encode['input_ids'])
+                input_encode.extend(cur_chunk_encode)
                 if idx != len(chunk_encode) - 1:
-                    input_encode['input_ids'].append(IMAGE_TOKEN_INDEX)
+                    input_encode.append(IMAGE_TOKEN_INDEX)
         else:
-            input_encode = tokenizer(f'{input}', add_special_tokens=False)
-        input_ids += bos_token_id + input_encode['input_ids']
-        labels += [IGNORE_INDEX] * (
-            len(bos_token_id + input_encode['input_ids']))
+            input_encode = tokenizer.encode(input, add_special_tokens=False)
+        if next_needs_bos_token:
+            input_ids += bos_token_id
+            labels += [IGNORE_INDEX] * len(bos_token_id)
+        input_ids += input_encode
+        labels += [IGNORE_INDEX] * len(input_encode)
         if input_ids_with_output:
+            # Add output
+            output_with_loss = single_turn_conversation.get(
+                'output_with_loss', True)
             output = single_turn_conversation['output']
-            output_encode = tokenizer(f'{output}', add_special_tokens=False)
-            input_ids += output_encode['input_ids'] + eos_token_id
-            labels += copy.deepcopy(output_encode['input_ids'] + eos_token_id)
+            output_encode = tokenizer.encode(output, add_special_tokens=False)
+            input_ids += output_encode
+            if output_with_loss:
+                labels += copy.deepcopy(output_encode)
+            else:
+                labels += [IGNORE_INDEX] * len(output_encode)
+            # Add EOS_TOKEN (with loss)
+            if single_turn_conversation.get('need_eos_token', True):
+                next_needs_bos_token = True
+                input_ids += eos_token_id
+                if output_with_loss:
+                    labels += copy.deepcopy(eos_token_id)
+                else:
+                    labels += [IGNORE_INDEX] * len(eos_token_id)
+            else:
+                next_needs_bos_token = False
+            # Add SEP (without loss)
+            sep = single_turn_conversation.get('sep', '')
+            if sep != '':
+                sep_encode = tokenizer.encode(sep, add_special_tokens=False)
+                input_ids += sep_encode
+                labels += [IGNORE_INDEX] * len(sep_encode)
 
     if len(input_ids) > max_length:
         input_ids = input_ids[:max_length]
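The rewritten loop makes loss masking configurable per turn: output_with_loss masks a turn's response from the loss, need_eos_token=False suppresses the EOS and also skips the BOS of the following turn, and sep is encoded and appended after the response, always loss-masked. A sketch of a turn record using the new optional fields (the values are illustrative):

example = {
    'conversation': [{
        'input': 'Give three tips for staying healthy.',
        'output': '1.Eat a balanced diet xxx',
        'output_with_loss': True,  # set False to mask the response from loss
        'need_eos_token': False,   # no EOS here; next turn gets no BOS either
        'sep': '\n',               # appended after the response, never trained
    }]
}
# encode_fn(example, tokenizer, max_length=2048,
#           input_ids_with_output=True, with_image_token=False)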
@@ -95,69 +130,125 @@
 
 
 class Packer:
-    # modified from
-    # https://github.com/facebookresearch/llama-recipes/blob/main/ft_datasets/utils.py
+    """Pack multiple pieces of data into one."""
 
-    def __init__(self, chunk_size=2048):
+    def __init__(self,
+                 chunk_size=2048,
+                 use_varlen_attn=False,
+                 drop_last=False):
         self.chunk_size = chunk_size
         self.residual = {'input_ids': [], 'labels': []}
+        self.use_varlen_attn = use_varlen_attn
+        self.drop_last = drop_last
+        if use_varlen_attn:
+            self.residual_cumulative_len = [0]
+
+    def get_cumulative_len(self, chunk_num):
+        ptr_l = 0
+        cumulative_len = []
+        for chunk_idx in range(chunk_num):
+            length_train = (chunk_idx + 1) * self.chunk_size
+            ptr_r = np.searchsorted(
+                self.residual_cumulative_len, length_train, side='left')
+            if self.residual_cumulative_len[ptr_r] == length_train:
+                cumulative_len_cur = \
+                    self.residual_cumulative_len[ptr_l:ptr_r + 1]
+                ptr_l = ptr_r + 1
+            else:
+                cumulative_len_cur = self.residual_cumulative_len[
+                    ptr_l:ptr_r] + [length_train]
+                ptr_l = ptr_r
+            cumulative_len_cur = [
+                num - chunk_idx * self.chunk_size for num in cumulative_len_cur
+            ]
+            if cumulative_len_cur[0] != 0:
+                cumulative_len_cur = [0] + cumulative_len_cur
+
+            cumulative_len.append(cumulative_len_cur)
+
+        self.residual_cumulative_len = [
+            num - length_train for num in self.residual_cumulative_len[ptr_l:]
+        ]
+        if len(self.residual_cumulative_len) == 0:
+            self.residual_cumulative_len = [0]
+        elif self.residual_cumulative_len[0] != 0:
+            self.residual_cumulative_len = [0] + self.residual_cumulative_len
+
+        return cumulative_len
+
+    def get_indexes(self, cumulative_len):
+        indexes = []
+        for cumulative_len_cur in cumulative_len:
+            index_cur = []
+            for i in range(len(cumulative_len_cur) - 1):
+                index_cur.extend(
+                    list(
+                        range(cumulative_len_cur[i + 1] -  # noqa: W504
+                              cumulative_len_cur[i])))
+            indexes.append(index_cur)
+        return indexes
 
     def __call__(self, batch):
         concatenated_samples = {
             k: v + list(chain(*batch[k]))
             for k, v in self.residual.items()
         }
 
+        if self.use_varlen_attn:
+            for input_id in batch['input_ids']:
+                self.residual_cumulative_len.append(
+                    self.residual_cumulative_len[-1] + len(input_id))
+
         total_length = len(concatenated_samples[list(
             concatenated_samples.keys())[0]])
 
         if total_length >= self.chunk_size:
             chunk_num = total_length // self.chunk_size
-            result = {
-                k: [
-                    v[i:i + self.chunk_size]
-                    for i in range(0, chunk_num *
-                                   self.chunk_size, self.chunk_size)
-                ]
-                for k, v in concatenated_samples.items()
-            }
-            self.residual = {
-                k: v[(chunk_num * self.chunk_size):]
-                for k, v in concatenated_samples.items()
-            }
-        else:
-            result = {k: [v] for k, v in concatenated_samples.items()}
-            self.residual = {k: [] for k in concatenated_samples.keys()}
+            result = {k: [] for k in concatenated_samples.keys()}
+            # Length of the last raw sample in the incoming batch
+            last_sample_len = len(batch['input_ids'][-1])
+
+            # Iterate over each chunk
+            for i in range(chunk_num):
+                start_idx = i * self.chunk_size
+                end_idx = start_idx + self.chunk_size
+
+                # Check whether the cut point falls inside the last sample
+                if end_idx > total_length - last_sample_len \
+                        and end_idx < total_length:
+                    # If so, insert padding before the cut point so the last
+                    # sample is not split (chunk_size is the max length here)
+                    padding_length = self.chunk_size - (
+                        end_idx - (total_length - last_sample_len))
+                    padding = [DEFAULT_PAD_TOKEN_INDEX] * padding_length
+                    result['input_ids'].append(
+                        padding + concatenated_samples['input_ids']
+                        [start_idx - padding_length:end_idx])
+                    # Padding positions must carry no loss
+                    result['labels'].append(
+                        [IGNORE_INDEX] * padding_length +
+                        concatenated_samples['labels']
+                        [start_idx - padding_length:end_idx])
+                else:
+                    result['input_ids'].append(
+                        concatenated_samples['input_ids'][start_idx:end_idx])
+                    result['labels'].append(
+                        concatenated_samples['labels'][start_idx:end_idx])
+
+            # Handle the residual part
+            self.residual = {
+                'input_ids':
+                concatenated_samples['input_ids'][chunk_num *
+                                                  self.chunk_size:],
+                'labels':
+                concatenated_samples['labels'][chunk_num * self.chunk_size:]
+            }
+
+            if self.use_varlen_attn:
+                cumulative_len = self.get_cumulative_len(chunk_num)
+                result['cumulative_len'] = cumulative_len
+                result['indexes'] = self.get_indexes(cumulative_len)
+        else:
+            if self.drop_last:
+                result = {k: [] for k, v in concatenated_samples.items()}
+            else:
+                result = {k: [v] for k, v in concatenated_samples.items()}
+
+            self.residual = {k: [] for k in concatenated_samples.keys()}
+
+            if self.use_varlen_attn:
+                result['cumulative_len'] = [] if self.drop_last else [
+                    self.residual_cumulative_len
+                ]
+                result['indexes'] = [] if self.drop_last else self.get_indexes(
+                    [self.residual_cumulative_len])
+                self.residual_cumulative_len = [0]
 
         return result
-
-
-class InternRepoPacker:
-    """Only used for packing data in InternLM repo
-    (https://github.com/InternLM/InternLM) format."""
-
-    def __init__(self, chunk_size=2048):
-        self.chunk_size = chunk_size
-        self.residual = []
-
-    def __call__(self, batch):
-        concatenated_samples = self.residual + list(chain(*batch['input_ids']))
-
-        total_length = len(concatenated_samples)
-
-        if total_length >= self.chunk_size:
-            chunk_num = total_length // self.chunk_size
-            input_ids = [
-                concatenated_samples[i:i + self.chunk_size]
-                for i in range(0, chunk_num * self.chunk_size, self.chunk_size)
-            ]
-            result = {'input_ids': input_ids}
-            self.residual = concatenated_samples[(chunk_num *
-                                                  self.chunk_size):]
-        else:
-            input_ids = [concatenated_samples]
-            result = {'input_ids': input_ids}
-            self.residual = []
-
-        return result
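To see what the varlen bookkeeping yields, here is a small sketch with toy token ids; the sample boundaries are chosen to land exactly on the chunk edge so the new padding branch is not exercised:

packer = Packer(chunk_size=4, use_varlen_attn=True)
out = packer({
    'input_ids': [[1, 2], [3, 4], [5, 6]],
    'labels': [[1, 2], [3, 4], [5, 6]],
})
print(out['input_ids'])       # [[1, 2, 3, 4]]; [5, 6] stays in the residual
print(out['cumulative_len'])  # [[0, 2, 4]]: sample boundaries in the chunk
print(out['indexes'])         # [[0, 1, 0, 1]]: position ids reset per sample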

@@ -188,4 +279,4 @@ def load_image(image_file):
 def decode_base64_to_image(base64_string):
     image_data = base64.b64decode(base64_string)
     image = Image.open(io.BytesIO(image_data))
-    return image
\ No newline at end of file
+    return image
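For completeness, a round-trip sketch of the base64 helper (the image path is hypothetical):

import base64

with open('example.png', 'rb') as f:  # any local image file
    b64_string = base64.b64encode(f.read()).decode('utf-8')

image = decode_base64_to_image(b64_string)
print(image.size)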