Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Optimizing Memory Usage during ZeRO Checkpoint Convert #582

Merged
2 commits merged on May 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions xtuner/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,16 +293,16 @@ def guess_load_checkpoint(pth_model):
state_dict = state_dict['state_dict']
elif osp.isdir(pth_model):
try:
from deepspeed.utils.zero_to_fp32 import \
get_fp32_state_dict_from_zero_checkpoint
from xtuner.utils.zero_to_any_dtype import \
get_state_dict_from_zero_checkpoint
except ImportError:
raise ImportError(
'The provided PTH model appears to be a DeepSpeed checkpoint. '
'However, DeepSpeed library is not detected in current '
'environment. This suggests that DeepSpeed may not be '
'installed or is incorrectly configured. Please verify your '
'setup.')
state_dict = get_fp32_state_dict_from_zero_checkpoint(
state_dict = get_state_dict_from_zero_checkpoint(
osp.dirname(pth_model), osp.basename(pth_model))
else:
raise FileNotFoundError(f'Cannot find {pth_model}')
Expand Down
14 changes: 12 additions & 2 deletions xtuner/tools/model_converters/pth_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
import argparse
import os.path as osp
import shutil
import warnings

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from mmengine.config import Config, DictAction
from mmengine.fileio import PetrelBackend, get_file_backend
from tqdm import tqdm

from xtuner.configs import cfgs_name_path
from xtuner.model.utils import guess_load_checkpoint
Expand Down Expand Up @@ -62,7 +67,11 @@ def main():
if 'LLaVAModel' in model_name:
cfg.model.pretrained_pth = None

model = BUILDER.build(cfg.model)
with init_empty_weights():
with warnings.catch_warnings():
warnings.filterwarnings(
'ignore', message='.*non-meta.*', category=UserWarning)
model = BUILDER.build(cfg.model)

backend = get_file_backend(args.pth_model)
if isinstance(backend, PetrelBackend):
Expand All @@ -72,7 +81,8 @@ def main():
else:
state_dict = guess_load_checkpoint(args.pth_model)

model.load_state_dict(state_dict, strict=False)
for name, param in tqdm(state_dict.items(), desc='Load State Dict'):
set_module_tensor_to_device(model, name, 'cpu', param, torch.float16)
print(f'Load PTH model from {args.pth_model}')

if 'LLaVAModel' in model_name:
Expand Down