[Feature] add HFCheckpointHook to auto save hf model after the whole training phase (#621)

* add HFCheckpointHook to auto save hf model after the whole training phase

* refine HFCheckpointHook

* fix lint

* delete useless code

* fix bugs

* support non-dist training
HIT-cwh committed May 10, 2024
1 parent 648d63a commit fc41943
Showing 2 changed files with 55 additions and 1 deletion.
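For context, a minimal sketch (not part of this commit) of how the new hook could be enabled in a training config, assuming the usual MMEngine `custom_hooks` convention used by XTuner configs; the `out_dir` path is illustrative:

from xtuner.engine.hooks import HFCheckpointHook

# Hypothetical config snippet: register the hook so an HF-format model is
# exported automatically after training. `out_dir` is optional and, per the
# hook below, defaults to `<work_dir>/hf_model` when left unset.
custom_hooks = [
    dict(type=HFCheckpointHook, out_dir='./work_dirs/demo/hf_model'),
]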
3 changes: 2 additions & 1 deletion xtuner/engine/hooks/__init__.py
@@ -1,10 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .dataset_info_hook import DatasetInfoHook
 from .evaluate_chat_hook import EvaluateChatHook
+from .hf_checkpoint_hook import HFCheckpointHook
 from .throughput_hook import ThroughputHook
 from .varlen_attn_args_to_messagehub_hook import VarlenAttnArgsToMessageHubHook
 
 __all__ = [
     'EvaluateChatHook', 'DatasetInfoHook', 'ThroughputHook',
-    'VarlenAttnArgsToMessageHubHook'
+    'VarlenAttnArgsToMessageHubHook', 'HFCheckpointHook'
 ]
53 changes: 53 additions & 0 deletions xtuner/engine/hooks/hf_checkpoint_hook.py
@@ -0,0 +1,53 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from pathlib import Path
from typing import Optional, Union

import torch.distributed as dist
from mmengine._strategy import DeepSpeedStrategy
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.runner import FlexibleRunner

DATA_BATCH = Optional[Union[dict, tuple, list]]


class HFCheckpointHook(Hook):

    priority = 95  # lower than CheckpointHook in MMEngine

    def __init__(self, out_dir: Optional[Union[str, Path]] = None) -> None:
        self.out_dir = out_dir

    def after_run(self, runner) -> None:
        assert isinstance(runner,
                          FlexibleRunner), 'Runner should be `FlexibleRunner`'
        assert isinstance(
            runner.strategy,
            DeepSpeedStrategy), 'Strategy should be `DeepSpeedStrategy`'

        if self.out_dir is None:
            self.out_dir = osp.join(runner.work_dir, 'hf_model')

        wrapped_model = runner.strategy.model
        if wrapped_model.zero_optimization_partition_weights():
            assert wrapped_model.zero_gather_16bit_weights_on_model_save(), \
                ('Please set `gather_16bit_weights_on_model_save=True` '
                 'in your DeepSpeed config.')
            state_dict = wrapped_model._zero3_consolidated_16bit_state_dict()
        else:
            state_dict = wrapped_model.module_state_dict(
                exclude_frozen_parameters=runner.strategy.
                exclude_frozen_parameters)

        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        llm = model.llm
        if (not dist.is_initialized()) or dist.get_rank() == 0:
            # keys in state_dict are prefixed with 'llm.'
            keys = list(state_dict.keys())
            for k in keys:
                val = state_dict.pop(k)
                state_dict[k[4:]] = val
            llm.save_pretrained(self.out_dir, state_dict=state_dict)
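Once the run finishes, the directory written by the hook is a standard Hugging Face checkpoint, so it can be loaded with the usual `transformers` API; the path below is illustrative and matches the hook's default of `<work_dir>/hf_model`:

from transformers import AutoModelForCausalLM

# Illustrative only: load the HF-format weights exported by HFCheckpointHook.
model = AutoModelForCausalLM.from_pretrained('./work_dirs/demo/hf_model')

Note that the hook saves only the LLM weights and config via `save_pretrained`; the tokenizer is not exported and would need to be saved separately if required.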
