Error when training Yi-34B with zero3_offload + sequence parallel #589
Comments
Sorry for the inconvenience! I just tested CPU_Offload full-parameter fine-tuning of Yi-200K-34B and could not reproduce your problem. Below are my config file with sequence parallel size 2 and the training Config & Log: In theory, sequence parallelism and CPU Offload should not affect each other, so please first disable sequence parallelism (by setting sequence_parallel_size = 1) and test CPU Offload training again to see whether the error still occurs. If it does, you may need to check whether your environment is installed correctly. Feel free to contact us once you have further results!
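For reference, the switch mentioned above is just a top-level variable in the xtuner config file. A minimal sketch following the naming of the stock yi_34b_200k_full_alpaca_enzh_32k_sp8 config (values here are only illustrative):

```python
# Illustrative excerpt of an xtuner config -- not the full file.
sequence_parallel_size = 1   # 1 disables sequence parallelism for this test
max_length = 32768           # packed sequence length per sample
accumulative_counts = 1      # gradient accumulation, often scaled together with sequence_parallel_size
```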
Hi, thanks for the reply. I tried sp2 with 8k sequence length on my side, but I still hit the same problem. Could you share your runtime environment? My current config file is:
Please first disable sequence parallelism (by setting sequence_parallel_size = 1) and test CPU Offload training to see whether it still errors out. I'm a bit worried the bug is not actually introduced by sequence parallelism.
Hi, I've confirmed the issue on my side: no matter how I changed the sequence-parallel settings, I got the same error. I then downgraded deepspeed from 0.14.0 to 0.12.3 and the problem went away, thanks for the patient help! I have another question though: although training now runs, the number of training steps looks wrong. My settings are as follows:
Could you share your sequence length (`max_length`) setting?
The length is 4096. Does this pack multiple samples together by default? I just changed it to 8192 and the total number of steps is still 32... but GPU memory usage did indeed increase.
Could you try modifying this line of code and see if that helps: https://github.com/InternLM/xtuner/blob/main/xtuner/dataset/huggingface.py#L96
I made the change but there is still no difference. Could you briefly explain how this number 32 is computed? Thanks~
During data preprocessing, we implement sample packing through the map_fn interface of Huggingface datasets, using 32 worker processes by default. Each worker takes 1000 samples at a time by default, concatenates them into several long sequences, and discards whatever is left over. The larger max_length is, the more data gets discarded. That said, 8192 is not very long and shouldn't shrink the dataset this drastically. I'd suggest first clearing the Huggingface datasets cache (by default under ~/.cache/huggingface/datasets/), and then changing this line of code: https://github.com/InternLM/xtuner/blob/main/xtuner/dataset/huggingface.py#L96
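As an illustration of the packing behavior described above (a simplified sketch, not xtuner's actual implementation and not the replacement suggested for the linked line): each batch of tokenized samples is concatenated and cut into max_length chunks, and the leftover tail is dropped, so a larger max_length yields fewer packed samples.

```python
def pack_to_max_length(batch_input_ids, max_length=8192):
    """Simplified illustration: concatenate a batch of tokenized samples,
    split the result into fixed-size chunks, and drop the leftover tail."""
    flat = [tok for ids in batch_input_ids for tok in ids]
    num_chunks = len(flat) // max_length  # remainder tokens are discarded
    return [flat[i * max_length:(i + 1) * max_length] for i in range(num_chunks)]

# With batches of ~1000 samples per map call, doubling max_length roughly halves
# the number of packed samples, and the dropped tail of each batch grows slightly.
```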
OK, I'll give it a try. Looking at the log you posted, it also seems to show 32 steps, so this doesn't look like something specific to my setup.
I just remembered: the 32 steps in my log are because I set:
Earlier, to benchmark speed, I only ran the first 32 iters. Here are my updated config and log: To track down why your training has only 32 iters, could I trouble you to:
If the problem persists, feel free to continue the discussion!
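For context on where this cap lives: a minimal sketch of the two train_cfg variants, assuming the stock xtuner TrainLoop used in the provided configs (names follow those configs; values are illustrative):

```python
from xtuner.engine.runner import TrainLoop

# Iteration-capped run (e.g. for a quick speed test): training stops after 32
# optimizer iters regardless of dataset size.
train_cfg = dict(type=TrainLoop, max_iters=32)

# Epoch-based run, which a normal full training would use instead:
# train_cfg = dict(type=TrainLoop, max_epochs=3)
```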
max_iters was indeed 32. Once I switched it over to the max_epochs value, training runs normally now. Thank you very much for the patient help!
@HIT-cwh Hi, I've run into a new problem... After one epoch finishes, I get an OOM, and it happens consistently: always at the 7th step after the epoch ends, with my sequence parallel size set to 8. At that point no checkpoint is being saved; it's just an ordinary iteration. Could something be going wrong when data from before and after the epoch boundary gets packed together?
What gradient accumulation value did you set? Also, how much memory does your GPU have? The log shows memory usage below 8 GB.
accumulative_counts and sequence_parallel_size are set to the same value. I tried both 8 and 4; in both cases, the OOM always happens at the (accumulative_counts - 1)-th step after an epoch ends.
Just to confirm: your setup is Yi34B + 32k seq length + sequence parallel size 4 (8) + deepspeed zero3 offload, right? I'll try to reproduce your problem on my side.
Yes, my setup is Yi34B + 24k seq length (I also tried 12k) + sequence parallel size 4 (8) + deepspeed zero3 offload. It reproduces even with a very small dataset. Thanks for your help~
OK, thanks for helping look into it~ I don't have 16 GPUs available at the moment, so I can only wait for your progress for now.
I only used the config file yi_34b_200k_full_alpaca_enzh_32k_sp8 provided on GitHub, with the deepspeed option at runtime set to zero3_offload.
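(For reference, launching that config with CPU offload typically looks something like `NPROC_PER_NODE=8 xtuner train yi_34b_200k_full_alpaca_enzh_32k_sp8 --deepspeed deepspeed_zero3_offload`; this is a representative command, not necessarily the exact one used here.)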
But I got the error below. Is sequence parallelism currently not supported with offload, or is there another cause? Thanks.
Traceback (most recent call last):
  File "/opt/ml/job/xtuner/tools/train.py", line 342, in <module>
    main()
  File "/opt/ml/job/xtuner/tools/train.py", line 338, in main
    runner.train()
  File "/usr/local/lib/python3.10/dist-packages/mmengine/runner/_flexible_runner.py", line 1200, in train
    model = self.train_loop.run()  # type: ignore
  File "/usr/local/lib/python3.10/dist-packages/mmengine/runner/loops.py", line 286, in run
    self.run_iter(data_batch)
  File "/usr/local/lib/python3.10/dist-packages/mmengine/runner/loops.py", line 309, in run_iter
    outputs = self.runner.model.train_step(
  File "/usr/local/lib/python3.10/dist-packages/mmengine/_strategy/deepspeed.py", line 135, in train_step
    optim_wrapper.update_params(parsed_loss)
  File "/usr/local/lib/python3.10/dist-packages/mmengine/_strategy/deepspeed.py", line 83, in update_params
    self.step()
  File "/usr/local/lib/python3.10/dist-packages/mmengine/optim/scheduler/param_scheduler.py", line 115, in wrapper
    return wrapped(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mmengine/_strategy/deepspeed.py", line 95, in step
    self.model.step()
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 2169, in step
    self._take_model_step(lr_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py", line 2075, in _take_model_step
    self.optimizer.step()
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/stage3.py", line 2047, in step
    self.unscale_and_clip_grads(sub_group_id, scaled_global_grad_norm)
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/stage3.py", line 2117, in unscale_and_clip_grads
    self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

(The same traceback and RuntimeError were raised on every rank, cuda:0 through cuda:7.)
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 1336) of binary: /usr/bin/python3
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 794, in main
    run(args)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: