Commit e745a0e

[Enhancement] Optimizing Memory Usage during ZeRO Checkpoint Convert (#582)

optimize memory usage

Co-authored-by: Zhihao Lin <36994684+LZHgrla@users.noreply.github.com>
pppppM and LZHgrla committed May 12, 2024
1 parent aab528c commit e745a0e
Showing 3 changed files with 741 additions and 5 deletions. Only two of the files appear below; the bulk of the additions is the new module xtuner/utils/zero_to_any_dtype.py that this commit introduces (inferred from the new import in xtuner/model/utils.py), whose contents are not reproduced here.
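Taken together, the visible changes attack peak memory from two sides: the consolidated ZeRO state dict no longer has to be float32, and the model skeleton is no longer built with real (randomly initialized) weights before the checkpoint is copied in. A rough, illustrative peak-RAM estimate for a 7B-parameter model follows; the byte counts are assumptions for the sketch, not measurements from this commit.

    # Back-of-the-envelope peak-RAM comparison (illustrative only).
    # Old path: fp32 state dict gathered by DeepSpeed + fp32 model built
    # normally, both fully resident while load_state_dict copies weights.
    # New path: model built on the meta device (~0 bytes) and tensors
    # materialized in fp16; the placed tensors can alias the gathered
    # state dict, so roughly one fp16 copy is resident at peak.
    params = 7e9
    old_peak_bytes = params * 4 + params * 4
    new_peak_bytes = params * 2
    print(f'old ~{old_peak_bytes / 2**30:.0f} GiB, '
          f'new ~{new_peak_bytes / 2**30:.0f} GiB')
    # old ~52 GiB, new ~13 GiB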
xtuner/model/utils.py (3 additions, 3 deletions)
@@ -293,16 +293,16 @@ def guess_load_checkpoint(pth_model):
             state_dict = state_dict['state_dict']
     elif osp.isdir(pth_model):
         try:
-            from deepspeed.utils.zero_to_fp32 import \
-                get_fp32_state_dict_from_zero_checkpoint
+            from xtuner.utils.zero_to_any_dtype import \
+                get_state_dict_from_zero_checkpoint
         except ImportError:
             raise ImportError(
                 'The provided PTH model appears to be a DeepSpeed checkpoint. '
                 'However, DeepSpeed library is not detected in current '
                 'environment. This suggests that DeepSpeed may not be '
                 'installed or is incorrectly configured. Please verify your '
                 'setup.')
-        state_dict = get_fp32_state_dict_from_zero_checkpoint(
+        state_dict = get_state_dict_from_zero_checkpoint(
             osp.dirname(pth_model), osp.basename(pth_model))
     else:
         raise FileNotFoundError(f'Cannot find {pth_model}')
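The change above swaps DeepSpeed's get_fp32_state_dict_from_zero_checkpoint, which always consolidates ZeRO shards into float32, for an xtuner-local variant; the new module's name, zero_to_any_dtype, indicates the consolidated weights can stay in lower-precision dtypes, halving the gathered state dict for fp16/bf16 runs. The unchanged ImportError message still mentions DeepSpeed, consistent with the new module building on DeepSpeed's checkpoint utilities and thus still requiring it at runtime (an inference, not stated in the diff). A minimal usage sketch mirroring the call in the diff, with a hypothetical checkpoint path:

    import os.path as osp

    from xtuner.utils.zero_to_any_dtype import \
        get_state_dict_from_zero_checkpoint

    # A DeepSpeed ZeRO checkpoint is a directory; the (root, tag) calling
    # convention matches DeepSpeed's get_fp32_state_dict_from_zero_checkpoint.
    pth_model = 'work_dirs/my_run/iter_500.pth'  # hypothetical path
    state_dict = get_state_dict_from_zero_checkpoint(
        osp.dirname(pth_model), osp.basename(pth_model))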
xtuner/tools/model_converters/pth_to_hf.py (13 additions, 2 deletions)
@@ -2,9 +2,14 @@
 import argparse
 import os.path as osp
 import shutil
+import warnings
 
+import torch
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
 from mmengine.config import Config, DictAction
 from mmengine.fileio import PetrelBackend, get_file_backend
+from tqdm import tqdm
 
 from xtuner.configs import cfgs_name_path
 from xtuner.model.utils import guess_load_checkpoint
@@ -66,7 +71,11 @@ def main():
     if 'LLaVAModel' in model_name:
         cfg.model.pretrained_pth = None
 
-    model = BUILDER.build(cfg.model)
+    with init_empty_weights():
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                'ignore', message='.*non-meta.*', category=UserWarning)
+            model = BUILDER.build(cfg.model)
 
     backend = get_file_backend(args.pth_model)
     if isinstance(backend, PetrelBackend):
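init_empty_weights builds the model on PyTorch's meta device, so BUILDER.build allocates parameter shapes and dtypes but no storage; the warnings filter only silences the '.*non-meta.*' UserWarnings PyTorch emits when pretrained-weight loading touches meta parameters during the build. A self-contained sketch of the mechanism, using plain accelerate and torch rather than xtuner's BUILDER:

    import torch.nn as nn
    from accelerate import init_empty_weights

    with init_empty_weights():
        layer = nn.Linear(4096, 4096)  # parameters land on the meta device

    print(layer.weight.device)  # meta
    # Shape and dtype exist, but no memory is allocated until each tensor
    # is materialized explicitly (see the loading loop in the next hunk).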
@@ -76,7 +85,9 @@
     else:
         state_dict = guess_load_checkpoint(args.pth_model)
 
-    model.load_state_dict(state_dict, strict=False)
+    for name, param in tqdm(state_dict.items(), desc='Load State Dict'):
+        set_module_tensor_to_device(model, name, 'cpu', param, torch.float16)
 
     model.llm.config.use_cache = True
 
     print(f'Load PTH model from {args.pth_model}')
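The removed model.load_state_dict(state_dict, strict=False) copied every checkpoint tensor into an already fully allocated model, so two complete sets of weights coexisted at peak. The replacement loop instead materializes each meta parameter directly from the gathered state dict with set_module_tensor_to_device, casting to torch.float16 on the way. A runnable miniature of the same pattern; the tiny model and state dict are toy stand-ins, while set_module_tensor_to_device is the real accelerate helper used in the diff:

    import torch
    import torch.nn as nn
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    with init_empty_weights():
        model = nn.Sequential(nn.Linear(8, 8))  # weights start on meta

    state_dict = {'0.weight': torch.randn(8, 8), '0.bias': torch.randn(8)}

    for name, param in state_dict.items():
        # Walks the dotted path (e.g. '0.weight') and replaces the meta
        # tensor with `param` on CPU, cast to fp16, one tensor at a time.
        set_module_tensor_to_device(model, name, 'cpu', param, torch.float16)

    print(model[0].weight.dtype, model[0].weight.device)  # torch.float16 cpu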
