[Fix] ZeRO2 Checkpoint Convert Bug (#684)
fix z2 convert
pppppM committed May 16, 2024
1 parent e745a0e commit 3b14f48
Showing 1 changed file with 21 additions and 50 deletions.
71 changes: 21 additions & 50 deletions xtuner/utils/zero_to_any_dtype.py
@@ -21,16 +21,23 @@
from dataclasses import dataclass

import torch
from deepspeed.checkpoint.constants import (
BUFFER_NAMES, DS_VERSION, FP32_FLAT_GROUPS, FROZEN_PARAM_FRAGMENTS,
FROZEN_PARAM_SHAPES, OPTIMIZER_STATE_DICT, PARAM_SHAPES, PARTITION_COUNT,
SINGLE_PARTITION_OF_FP32_GROUPS, ZERO_STAGE)
# yapf: disable
from deepspeed.checkpoint.constants import (BUFFER_NAMES, DS_VERSION,
FP32_FLAT_GROUPS,
FROZEN_PARAM_FRAGMENTS,
FROZEN_PARAM_SHAPES,
OPTIMIZER_STATE_DICT, PARAM_SHAPES,
PARTITION_COUNT,
SINGLE_PARTITION_OF_FP32_GROUPS,
ZERO_STAGE)
# while this script doesn't use deepspeed to recover data, since the
# checkpoints are pickled with DeepSpeed data structures it has to be
# available in the current python environment.
from deepspeed.utils import logger
from tqdm import tqdm

# yapf: enable


@dataclass
class zero_model_state:
@@ -150,7 +157,7 @@ def parse_model_states(files, dtype=DEFAULT_DTYPE):
return zero_model_states


@torch.no_grad
@torch.no_grad()
def parse_optim_states(files, ds_checkpoint_dir, dtype=DEFAULT_DTYPE):

zero_stage = None
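The hunk above also swaps the bare `@torch.no_grad` for the instantiated `@torch.no_grad()`, which is the documented decorator form. A minimal sketch of the behavior it guarantees (the `double` function below is a hypothetical stand-in):

```python
import torch

@torch.no_grad()  # instantiate the context manager, then wrap the function
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

out = double(torch.ones(2, requires_grad=True))
print(out.requires_grad)  # False: gradient tracking is disabled inside the call
```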
@@ -179,7 +186,7 @@ def parse_optim_states(files, ds_checkpoint_dir, dtype=DEFAULT_DTYPE):
state_dict['optimizer_state_dict'].pop('optimizer_state_dict', None)
fp32_groups = state_dict['optimizer_state_dict'].pop(fp32_groups_key)
if zero_stage <= 2:
flat_groups.append(fp32_groups.to(dtype))
flat_groups.append([param.to(dtype) for param in fp32_groups])
elif zero_stage == 3:
# if there is more than one param group, there will be multiple
# flattened tensors - one flattened tensor per group - for
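The actual ZeRO-2 bug lives in the hunk above: for stage <= 2 the `SINGLE_PARTITION_OF_FP32_GROUPS` entry typically holds a Python list of flat FP32 tensors, one per optimizer param group, so the old `fp32_groups.to(dtype)` call failed because a list has no `.to()` method. A minimal sketch of the data shape and the per-tensor cast (the `fake_osd` dict is a hypothetical stand-in for one rank's optimizer shard):

```python
import torch

# Hypothetical stand-in for one rank's ZeRO-2 optimizer shard: the fp32 groups
# key maps to a list of flat tensors, one flat partition per param group.
fake_osd = {
    'single_partition_of_fp32_groups': [
        torch.randn(12),  # flat fp32 partition for param group 0
        torch.randn(5),   # flat fp32 partition for param group 1
    ]
}

fp32_groups = fake_osd['single_partition_of_fp32_groups']

# Old code: fp32_groups.to(dtype) -> AttributeError: 'list' object has no attribute 'to'
# Fixed code casts each flat partition individually:
flat_group = [param.to(torch.float16) for param in fp32_groups]
print([t.dtype for t in flat_group])  # [torch.float16, torch.float16]
```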
@@ -435,9 +442,9 @@ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
state_dict[name] = torch.cat(param_frags, 0).narrow(
0, 0, unpartitioned_numel).view(shape) # noqa: E501

partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(
unpartitioned_numel, world_size)

_partitioned = zero3_partitioned_param_info(unpartitioned_numel,
world_size)
partitioned_numel, partitioned_padding_numel = _partitioned
if debug:
print(f'Frozen params: {total_params} {name} full shape: {shape} '
f'partition0 numel={partitioned_numel} '
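The hunk above only reflows the `zero3_partitioned_param_info` call, but the two quantities it returns are worth spelling out. A minimal sketch of the per-rank partition arithmetic, assuming the helper follows the logic of DeepSpeed's `zero_to_fp32` script (the `partition_info` name below is a stand-in, not the real helper):

```python
import math

def partition_info(unpartitioned_numel: int, world_size: int):
    # Pad the flat parameter so it divides evenly across ranks.
    remainder = unpartitioned_numel % world_size
    padding_numel = (world_size - remainder) if remainder else 0
    partitioned_numel = math.ceil(unpartitioned_numel / world_size)
    return partitioned_numel, padding_numel

# e.g. a parameter with 10 elements sharded over 4 ranks:
print(partition_info(10, 4))  # (3, 2): each rank stores 3 elements, 2 of them padding
```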
@@ -555,6 +562,7 @@ def get_state_dict_from_zero_checkpoint(checkpoint_dir,
tag=None,
exclude_frozen_parameters=False,
dtype=DEFAULT_DTYPE):
# flake8: noqa
"""Convert ZeRO 2 or 3 checkpoint into a single consolidated state_dict
that can be loaded with ``load_state_dict()`` and used for training without
DeepSpeed or shared with others, for example via a model hub.
@@ -591,6 +599,7 @@ def get_state_dict_from_zero_checkpoint(checkpoint_dir,
If you want it all done for you, use
``load_state_dict_from_zero_checkpoint`` instead.
"""
# flake8: noqa
if tag is None:
latest_path = os.path.join(checkpoint_dir, 'latest')
if os.path.isfile(latest_path):
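As the docstring above notes, `get_state_dict_from_zero_checkpoint` consolidates a ZeRO-2/3 checkpoint into a plain `state_dict` usable without DeepSpeed. A hedged usage sketch (the checkpoint path is a placeholder):

```python
import torch
from xtuner.utils.zero_to_any_dtype import get_state_dict_from_zero_checkpoint

# 'work_dirs/demo/checkpoint-12' is a placeholder DeepSpeed checkpoint folder.
state_dict = get_state_dict_from_zero_checkpoint(
    'work_dirs/demo/checkpoint-12', dtype=torch.float16)

# The consolidated dict can then be loaded without DeepSpeed, e.g.:
# model.load_state_dict(state_dict)
```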
@@ -640,8 +649,8 @@ def load_state_dict_from_zero_checkpoint(model,
checkpoint_dir,
tag=None,
dtype=DEFAULT_DTYPE):
# yapf: disable

# flake8: noqa
"""
1. Put the provided model to cpu
2. Convert ZeRO 2 or 3 checkpoint into a single consolidated ``state_dict``
@@ -675,7 +684,7 @@ def load_state_dict_from_zero_checkpoint(model,
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic
from it.
"""
# yapf: enable
# flake8: noqa
logger.info(f'Extracting {dtype} weights')
state_dict = get_state_dict_from_zero_checkpoint(
checkpoint_dir, tag, dtype=dtype)
@@ -685,41 +694,3 @@ def load_state_dict_from_zero_checkpoint(model,
model.load_state_dict(state_dict, strict=False)

return model
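For the convenience wrapper documented above, a minimal usage sketch (the model and path are placeholders):

```python
import torch.nn as nn
from xtuner.utils.zero_to_any_dtype import load_state_dict_from_zero_checkpoint

# Placeholder: in practice this is the same model definition that was trained
# with DeepSpeed; any nn.Module stands in here since loading uses strict=False.
model = nn.Linear(8, 8)

# Puts the model on cpu, consolidates the ZeRO shards, and loads the weights.
model = load_state_dict_from_zero_checkpoint(model, 'work_dirs/demo/checkpoint-12')
```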


if __name__ == '__main__':

parser = argparse.ArgumentParser()
parser.add_argument(
'checkpoint_dir',
type=str,
help='path to the desired checkpoint folder, e.g., path/checkpoint-12')
parser.add_argument(
'output_file',
type=str,
help=
'path to the pytorch state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)' # noqa: E501
)
parser.add_argument(
'-t',
'--tag',
type=str,
default=None,
help=
'checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1' # noqa: E501
)
parser.add_argument(
'--exclude_frozen_parameters',
action='store_true',
help='exclude frozen parameters')
parser.add_argument(
'-d', '--debug', action='store_true', help='enable debug')
args = parser.parse_args()

debug = args.debug

convert_zero_checkpoint_to_state_dict(
args.checkpoint_dir,
args.output_file,
tag=args.tag,
exclude_frozen_parameters=args.exclude_frozen_parameters)
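The removed `__main__` block above dropped the script's CLI entry point; the same conversion remains available programmatically through `convert_zero_checkpoint_to_state_dict`, whose arguments mirror the deleted argparse flags. A hedged equivalent of the old command-line invocation (both paths are placeholders):

```python
from xtuner.utils.zero_to_any_dtype import convert_zero_checkpoint_to_state_dict

# Equivalent of the deleted CLI call; paths are placeholders.
convert_zero_checkpoint_to_state_dict(
    'work_dirs/demo/checkpoint-12',       # checkpoint_dir
    'work_dirs/demo/pytorch_model.bin',   # output_file
    tag=None,                             # e.g. 'global_step1'
    exclude_frozen_parameters=False)
```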
