
add chatglm3-6b model support
huggingface model: https://hf-mirror.com/THUDM/chatglm3-6b

Signed-off-by: XingXing Qiao <qiaoxx@dingdao.com>
xingxingqiao committed May 15, 2024
1 parent 9f77348 commit 398fecb
Showing 6 changed files with 413 additions and 7 deletions.
162 changes: 161 additions & 1 deletion convert-hf-to-gguf.py
@@ -78,7 +78,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
if not self.is_safetensors:
self.part_names = Model.get_model_part_names(self.dir_model, ".bin")
self.hparams = Model.load_hparams(self.dir_model)
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
self.tensor_names = None
if self.ftype == gguf.LlamaFileType.GUESSED:
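
For context, find_hparam returns the value of the first candidate key present in the loaded config.json, so appending "num_layers" lets ChatGLM configs resolve the block count without disturbing other architectures. A minimal sketch of that fallback behavior (a standalone illustrative helper, not the class method itself):

def find_hparam_sketch(hparams: dict, keys: list[str]):
    # return the value of the first candidate key found in the config dict
    for key in keys:
        if key in hparams:
            return hparams[key]
    raise KeyError(f"none of {keys} present")

# chatglm3-6b's config.json spells the layer count "num_layers"
assert find_hparam_sketch({"num_layers": 28},
                          ["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) == 28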
@@ -2388,6 +2388,166 @@ def set_vocab(self, *args, **kwargs):
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(True)

@Model.register("ChatGLMModel")
class ChatGLMModel(Model):
model_arch = gguf.MODEL_ARCH.CHATGLM

def set_vocab(self):
dir_model = self.dir_model
hparams = self.hparams
tokens: list[bytearray] = []
toktypes: list[int] = []
scores: list[float] = []

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab()))
assert max(tokenizer.get_vocab().values()) < vocab_size

reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.get_vocab().items()}

for token_id in range(vocab_size):
piece = tokenizer._convert_id_to_token(token_id)
if token_id == 0:
piece = "<unk>"
elif token_id == 1:
piece = "<bos>"
elif token_id == 2:
piece = "<eos>"

text = piece.encode("utf-8")
score = 0.0
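# note: ids below 64789 (the sentencepiece vocab size for chatglm3) carry real
# sp_model scores; the padded/special ids at or above that boundary do not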
if len(piece) != 0 and token_id < 64789:
score = tokenizer.tokenizer.sp_model.get_score(token_id)

if len(piece) == 0:
text = f"[PAD{token_id}]".encode("utf-8")

if token_id >= 64789:
toktype = SentencePieceTokenTypes.UNKNOWN
tokens.append(text)
scores.append(score)
toktypes.append(toktype)
continue

toktype = SentencePieceTokenTypes.NORMAL
if tokenizer.tokenizer.sp_model.is_unknown(token_id):
toktype = SentencePieceTokenTypes.UNKNOWN
elif tokenizer.tokenizer.sp_model.is_control(token_id):
toktype = SentencePieceTokenTypes.CONTROL
elif tokenizer.tokenizer.sp_model.is_unused(token_id):
toktype = SentencePieceTokenTypes.UNUSED
elif tokenizer.tokenizer.sp_model.is_byte(token_id):
toktype = SentencePieceTokenTypes.BYTE

tokens.append(text)
scores.append(score)
toktypes.append(toktype)

self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)

def set_gguf_parameters(self):
self.gguf_writer.add_name("ChatGLM-6b-chat")
n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
n_head_kv = self.hparams.get("multi_query_group_num", n_head)
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
self.gguf_writer.add_embedding_length(n_embed)
self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", 4 * n_embed))
self.gguf_writer.add_block_count(self.hparams["num_layers"])
self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head_kv)
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layernorm_epsilon"])
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_rope_dimension_count(64)
self.gguf_writer.add_add_bos_token(False)
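
For reference, these lookups resolve against the upstream chatglm3-6b config.json roughly as follows (values quoted here only to make the mapping concrete; this dict is not read by the code above):

# chatglm3-6b config.json -> GGUF metadata (illustrative)
hparams = {
    "hidden_size": 4096,           # embedding_length
    "num_attention_heads": 32,     # head_count
    "multi_query_group_num": 2,    # head_count_kv (multi-query attention)
    "seq_length": 8192,            # context_length
    "ffn_hidden_size": 13696,      # feed_forward_length
    "num_layers": 28,              # block_count
    "layernorm_epsilon": 1e-05,    # layer_norm_rms_eps
}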

def write_tensors(self):
block_count = self.hparams["num_layers"]
tensors = dict(self.get_tensors())
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
has_lm_head = True
n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))

for name, data_torch in tensors.items():
if name.endswith(".rotary_pos_emb.inv_freq"):
continue

if "lm_head.weight" not in tensors.keys() and "output.weight" not in tensors.keys():
has_lm_head = False

name = re.sub(r'transformer\.', '', name)

old_dtype = data_torch.dtype

# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)

data = data_torch.squeeze().numpy()

if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name):
# Map bloom-style qkv_linear to gpt-style qkv_linear
# bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252 # noqa
# gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312 # noqa
qkv_weights = data.reshape((n_head, 3, n_embed // n_head, n_embed))
data = np.concatenate(
(
qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
qkv_weights[:, 2, :, :].reshape((-1, n_embed)),
),
axis=0,
)
print("re-format attention.linear_qkv.weight")
elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name):
qkv_bias = data.reshape((n_head, 3, n_embed // n_head))
data = np.concatenate(
(
qkv_bias[:, 0, :].reshape((n_embed,)),
qkv_bias[:, 1, :].reshape((n_embed,)),
qkv_bias[:, 2, :].reshape((n_embed,)),
),
axis=0,
)
print("re-format attention.linear_qkv.bias")

# map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()

n_dims = len(data.shape)
data_dtype = data.dtype
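# (for context: self.ftype follows gguf.LlamaFileType, where 0 = ALL_F32 and 1 = MOSTLY_F16)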

# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)

# TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)

# if f16 desired, convert any float32 2-dim weight tensors to float16
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)

print(f"=> {new_name}, shape = {data.shape}, {old_dtype} --> {data.dtype}")

self.gguf_writer.add_tensor(new_name, data)

if not has_lm_head and name == "word_embeddings.weight":
self.gguf_writer.add_tensor("output.weight", data)
print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
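
The query_key_value re-layout above converts per-head interleaved [q, k, v] row blocks into the gpt-2-style concatenated [Q; K; V] layout. A toy numpy sketch of the same reshape-and-concat (dimensions invented for illustration):

import numpy as np

n_head, n_embed = 2, 8
head_dim = n_embed // n_head
# fused weight as stored bloom-style: rows grouped per head as [q, k, v]
fused = np.arange(3 * n_embed * n_embed, dtype=np.float32).reshape(3 * n_embed, n_embed)

w = fused.reshape(n_head, 3, head_dim, n_embed)
deinterleaved = np.concatenate(
    (
        w[:, 0].reshape(-1, n_embed),  # all query rows first
        w[:, 1].reshape(-1, n_embed),  # then all key rows
        w[:, 2].reshape(-1, n_embed),  # then all value rows
    ),
    axis=0,
)
assert deinterleaved.shape == fused.shape  # same storage, new row order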


###### CONVERSION LOGIC ######

17 changes: 17 additions & 0 deletions gguf-py/gguf/constants.py
@@ -139,6 +139,7 @@ class MODEL_ARCH(IntEnum):
COMMAND_R = auto()
DBRX = auto()
OLMO = auto()
CHATGLM = auto()


class MODEL_TENSOR(IntEnum):
@@ -217,6 +218,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.DBRX: "dbrx",
MODEL_ARCH.OLMO: "olmo",
MODEL_ARCH.CHATGLM: "chatglm",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -743,6 +745,18 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.CHATGLM: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
# TODO
}

@@ -779,6 +793,9 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
MODEL_ARCH.CHATGLM: [
MODEL_TENSOR.ROPE_FREQS,
],
}

#
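A quick sanity check of the new registration, assuming the gguf-py package from this tree is importable:

from gguf.constants import MODEL_ARCH, MODEL_ARCH_NAMES, MODEL_TENSOR, MODEL_TENSORS

assert MODEL_ARCH_NAMES[MODEL_ARCH.CHATGLM] == "chatglm"
assert MODEL_TENSOR.ATTN_QKV in MODEL_TENSORS[MODEL_ARCH.CHATGLM]
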
18 changes: 14 additions & 4 deletions gguf-py/gguf/tensor_mapping.py
@@ -24,6 +24,7 @@ class TensorNameMap:
"backbone.embedding", # mamba
"backbone.embeddings", # mamba-hf
"transformer.in_out_embed", # Grok
"embedding.word_embeddings", # chatglm
),

# Token type embeddings
@@ -52,6 +53,7 @@ class TensorNameMap:
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
"output_layer", # chatglm
),

# Output norm
@@ -68,11 +70,13 @@ class TensorNameMap:
"model.norm_f", # mamba-qbert
"backbone.norm_f", # mamba
"transformer.rms_norm", # Grok
"encoder.final_layernorm", # chatglm
),

# Rope frequencies
MODEL_TENSOR.ROPE_FREQS: (
"rope.freqs", # llama-pth
"rotary_pos_emb.inv_freq", # chatglm
),
}

@@ -97,6 +101,7 @@ class TensorNameMap:
"backbone.layers.{bid}.norm", # mamba
"transformer.decoder_layer.{bid}.rms_norm", # Grok
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
"encoder.layers.{bid}.input_layernorm", # chatglm
),

# Attention norm 2
@@ -117,7 +122,8 @@ class TensorNameMap:
"h.{bid}.attn.c_attn", # gpt2
"transformer.h.{bid}.mixer.Wqkv", # phi2
"encoder.layers.{bid}.attn.Wqkv", # nomic-bert
"model.layers.{bid}.self_attn.qkv_proj" # phi3
"model.layers.{bid}.self_attn.qkv_proj", # phi3
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
),

# Attention query
@@ -128,7 +134,7 @@ class TensorNameMap:
"transformer.h.{bid}.attn.q_proj", # gpt-j
"model.layers.layers.{bid}.self_attn.q_proj", # plamo
"model.layers.{bid}.attention.wq", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.query" # Grok
"transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
),

# Attention key
@@ -140,7 +146,7 @@ class TensorNameMap:
"transformer.h.{bid}.attn.k", # refact
"model.layers.layers.{bid}.self_attn.k_proj", # plamo
"model.layers.{bid}.attention.wk", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok
"transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
),

# Attention value
@@ -152,7 +158,7 @@ class TensorNameMap:
"transformer.h.{bid}.attn.v", # refact
"model.layers.layers.{bid}.self_attn.v_proj", # plamo
"model.layers.{bid}.attention.wv", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok
"transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok
),

# Attention output
@@ -175,6 +181,7 @@ class TensorNameMap:
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
"transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
"encoder.layers.{bid}.self_attention.dense", # chatglm
),

# Attention output norm
@@ -206,6 +213,7 @@ class TensorNameMap:
"h.{bid}.ln_2", # gpt2
"model.layers.{bid}.ffn_norm", # internlm2
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
),

MODEL_TENSOR.FFN_GATE_INP: (
@@ -244,6 +252,7 @@ class TensorNameMap:
"encoder.layers.{bid}.mlp.fc11", # nomic-bert
"model.layers.{bid}.mlp.c_fc", # starcoder2
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
),

MODEL_TENSOR.FFN_UP_EXP: (
@@ -306,6 +315,7 @@ class TensorNameMap:
"encoder.layers.{bid}.mlp.fc2", # nomic-bert
"model.layers.{bid}.mlp.c_proj", # starcoder2
"encoder.layer.{bid}.mlp.wo", # jina-bert-v2
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
),

MODEL_TENSOR.FFN_DOWN_EXP: (
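With these entries in place, the converter's lookup should resolve chatglm tensor names once the leading "transformer." prefix has been stripped. A small sketch (expected output shown as a comment, assuming the gguf-py package from this tree):

import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CHATGLM, 28)  # chatglm3-6b has 28 blocks
name = tmap.get_name("encoder.layers.0.self_attention.query_key_value.weight",
                     try_suffixes=(".weight", ".bias"))
print(name)  # expected: blk.0.attn_qkv.weight
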
2 changes: 1 addition & 1 deletion gguf-py/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "gguf"
version = "0.9.0"
version = "0.9.1"
description = "Read and write ML models in GGUF for GGML"
authors = ["GGML <ggml@ggml.ai>"]
packages = [
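With the mapping, constants, and converter changes in place, conversion is expected to run through the existing entry point, along the lines of (path illustrative; --outfile and --outtype are the script's usual flags):

python convert-hf-to-gguf.py /path/to/chatglm3-6b --outfile chatglm3-6b-f16.gguf --outtype f16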
