llama-3 bug fixes (#429)
* Fix prompt

* Update chat_templates.py

* fix_untrained_tokens

* Update llama.py

* add tokens

* Update _utils.py

* Update tokenizer_utils.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* pad_token

* Update chat_templates.py

* Update chat_templates.py

* tokenizer

* Update save.py

* Update chat_templates.py

* Update chat_templates.py

* patch tokenizer padding

* Update tokenizer_utils.py

* Update save.py

* Fix: loading models with resized vocabulary (#377)

* new: vocab resize on load

* new: gitignore

* GGUF fix

* Readme (#390)

* Update README.md

* Update README.md

---------

Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com>

* Update README.md

* Delete .gitignore

* Phi-3

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Fix reserved tokens

* Update save.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update chat_templates.py

* Update save.py

* Update _utils.py

* Update chat_templates.py

---------

Co-authored-by: Igor Kilbas <whitemarsstudios@gmail.com>
Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com>
3 people committed May 7, 2024
1 parent a93a885 commit 4211cc0
Showing 3 changed files with 103 additions and 18 deletions.
22 changes: 22 additions & 0 deletions unsloth/chat_templates.py
@@ -266,6 +266,20 @@
 CHAT_TEMPLATES["llama-3"] = (llama3_template, llama3_template_eos_token,)
 
 
+# Phi-3
+phi3_template = \
+"{{ bos_token }}"\
+"{% for message in messages %}"\
+    "{% if (message['role'] == 'user') %}"\
+        "{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}"\
+    "{% elif (message['role'] == 'assistant') %}"\
+        "{{message['content'] + '<|end|>' + '\n'}}"\
+    "{% endif %}"\
+"{% endfor %}"
+phi3_template_eos_token = "<|end|>"
+CHAT_TEMPLATES["phi-3"] = (phi3_template, phi3_template_eos_token,)
+
+
 def get_chat_template(
     tokenizer,
     chat_template = "chatml",
@@ -595,4 +609,12 @@ def test_chat_templates():
     correct_tokenizer.chat_template = template
     our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
     assert(correct_prompt == our_prompt)
+
+    # Phi-3
+    template = phi3_template
+    correct_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+    correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
+    correct_tokenizer.chat_template = template
+    our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
+    assert(correct_prompt == our_prompt)
 pass
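
For illustration, a minimal usage sketch (not part of this commit) showing how the new "phi-3" entry pairs with a tokenizer via get_chat_template, the helper this file already exports; the messages are placeholder assumptions:

# Sketch: apply the new "phi-3" chat template registered above.
from transformers import AutoTokenizer
from unsloth.chat_templates import get_chat_template

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
tokenizer = get_chat_template(tokenizer, chat_template = "phi-3")

messages = [
    {"role": "user",      "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
    {"role": "user",      "content": "And doubled?"},
]
# The template emits '<|assistant|>\n' right after each user turn,
# so the rendered prompt already ends ready for the model's reply.
prompt = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
print(prompt)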
64 changes: 50 additions & 14 deletions unsloth/models/_utils.py
@@ -144,24 +144,60 @@ def make_inputs_require_grad(module, input, output):
 
 
 def patch_tokenizer(model, tokenizer):
+    """
+        Phi3's pad_token isn't set. We set it to <|placeholder...
+        Llama-3 is <|reserved...
+        Llama-2 is <unk>
+        Check if pad_token is not the same as eos_token otherwise the loss will ignore it!!
+        Fixes https://github.com/unslothai/unsloth/issues/5
+    """
+    possible_reserved_tokens = ("<|reserved", "<|placeholder",)
+
     if model is not None:
         model.config.update({"unsloth_version" : __version__})
-    if not hasattr(tokenizer, "pad_token") or tokenizer.pad_token is None:
-        # Fixes https://github.com/unslothai/unsloth/issues/5
-        if hasattr(tokenizer, "unk_token") and tokenizer.unk_token is not None:
-            tokenizer.add_special_tokens({"pad_token" : tokenizer.unk_token})
-            tokenizer.pad_token = tokenizer.unk_token
-        else:
-            name = model.config._name_or_path if model is not None else "Model"
-            logger.warning_once(
-                f"{name} does not have a padding or unknown token!\n"\
-                f"Will use the EOS token of id {tokenizer.eos_token_id} as padding."
+
+    bad_pad_token = False
+    if hasattr(tokenizer, "pad_token") and tokenizer.pad_token is not None:
+        # Check if pad_token is not the same as eos_token otherwise the loss will ignore it!!
+        bad_pad_token = tokenizer.eos_token == tokenizer.pad_token
+    elif hasattr(tokenizer, "pad_token") and tokenizer.pad_token is None:
+        bad_pad_token = True
+    else:
+        bad_pad_token = False
+    pass
+
+    if bad_pad_token:
+        # Find a better pad token
+        added_tokens = [str(x) for x in tokenizer.added_tokens_decoder.values()]
+        possible_pad_token = None
+        for added_token in added_tokens[::-1]:
+            if added_token.startswith(possible_reserved_tokens):
+                possible_pad_token = added_token
+                break
+        pass
+        pass
+        if possible_pad_token is None:
+            # Try unk_token
+            possible_pad_token = tokenizer.unk_token
+        pass
+        if possible_pad_token is None:
+            # Failure!!
+            raise RuntimeError(
+                "Unsloth: Tokenizer's pad_token cannot be = eos_token, and we couldn't find a\n"\
+                "replacement of either <|reserved... or <|placeholder..."
             )
-            assert(hasattr(tokenizer, "eos_token"))
-            tokenizer.add_special_tokens({"pad_token" : tokenizer.eos_token})
-            tokenizer.pad_token = tokenizer.eos_token
+        pass
+
+        name = model.config._name_or_path if model is not None else "Model"
+        logger.warning_once(
+            f"{name} does not have a padding token! Will use pad_token = {possible_pad_token}."
+        )
+
+        # Edit pad_token
+        tokenizer.add_special_tokens({"pad_token" : possible_pad_token})
+        tokenizer.pad_token = possible_pad_token
         if model is not None:
-            config = model.config.update({"pad_token_id" : tokenizer.eos_token_id})
+            config = model.config.update({"pad_token_id" : tokenizer.pad_token_id})
     pass
     return model, tokenizer
 pass
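
To see what the reserved-token scan above actually picks, here is a standalone sketch (not from the commit) run against a hypothetical Llama-3-style added-token list; the real code reads tokenizer.added_tokens_decoder rather than a literal list:

# Sketch of patch_tokenizer's pad-token scan with a made-up token list.
possible_reserved_tokens = ("<|reserved", "<|placeholder",)

added_tokens = [
    "<|begin_of_text|>",
    "<|end_of_text|>",
    "<|reserved_special_token_0|>",
    "<|reserved_special_token_1|>",
]

# Walk the list backwards so the last (highest-numbered) reserved token wins,
# leaving earlier reserved tokens free for other uses.
possible_pad_token = None
for added_token in added_tokens[::-1]:
    # str.startswith accepts a tuple, so one call covers both prefixes.
    if added_token.startswith(possible_reserved_tokens):
        possible_pad_token = added_token
        break

print(possible_pad_token)  # -> <|reserved_special_token_1|>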
35 changes: 31 additions & 4 deletions unsloth/save.py
@@ -18,6 +18,7 @@
 from typing import Optional, Callable, Union, List
 import torch
 import os
+import shutil
 import pickle
 import gc
 from transformers.models.llama.modeling_llama import logger
@@ -87,6 +88,24 @@ def print_quantization_methods():
 pass
 
 
+def check_if_sentencepiece_model(model, temporary_location = "_unsloth_sentencepiece_temp"):
+    if not hasattr(model, "_saved_temp_tokenizer"): return False
+
+    temp_tokenizer = model._saved_temp_tokenizer
+    sentencepiece_model = False
+    file_location = f"{temporary_location}/{temp_tokenizer.name_or_path}"
+    if not os.path.exists(file_location):
+        os.makedirs(file_location)
+    pass
+    temp_tokenizer.save_pretrained(file_location)
+    if os.path.isfile(f"{file_location}/tokenizer.model"):
+        sentencepiece_model = True
+    pass
+    shutil.rmtree(file_location)
+    return sentencepiece_model
+pass
+
+
 def _free_cached_model(model):
     from huggingface_hub import scan_cache_dir
     cached_repos = list(scan_cache_dir().repos)
@@ -840,6 +859,7 @@ def _fix_gemma_gguf():
 
 def save_to_gguf(
     model_type : str,
+    is_sentencepiece : bool = False,
     model_directory : str = "unsloth_finetuned_model",
     quantization_method : str = "fast_quantized",
     first_conversion : str = "f16",
@@ -856,7 +876,8 @@
 
     # Careful convert.py is only for Llama / Mistral based archs
     use_fast_convert = False
-    if   model_type == "llama":   use_fast_convert = True
+    if not is_sentencepiece:      use_fast_convert = False # Llama-3
+    elif model_type == "llama":   use_fast_convert = True
     elif model_type == "mistral": use_fast_convert = True
     pass
     logger.warning_once(f"Unsloth: Converting {model_type} model. Can use fast conversion = {use_fast_convert}.")
@@ -951,7 +972,7 @@ def save_to_gguf(
             f"--outtype {first_conversion} --concurrency {n_cpus}"
     else:
         # Need to fix convert-hf-to-gguf.py for some models!
-        _fix_gemma_gguf()
+        # _fix_gemma_gguf()
 
         command = f"python llama.cpp/convert-hf-to-gguf.py {model_directory} "\
             f"--outfile {final_location} "\
@@ -1353,7 +1374,10 @@ def unsloth_save_pretrained_gguf(
     gc.collect()
 
     model_type = self.config.model_type
-    file_location = save_to_gguf(model_type, new_save_directory, quantization_method, first_conversion, makefile)
+    is_sentencepiece_model = check_if_sentencepiece_model(self)
+    file_location = save_to_gguf(model_type, is_sentencepiece_model,
+        new_save_directory, quantization_method, first_conversion, makefile,
+    )
 
     if push_to_hub:
         print("Unsloth: Uploading GGUF to Huggingface Hub...")
@@ -1473,7 +1497,10 @@ def unsloth_push_to_hub_gguf(
     gc.collect()
 
     model_type = self.config.model_type
-    file_location = save_to_gguf(model_type, new_save_directory, quantization_method, first_conversion, makefile)
+    is_sentencepiece_model = check_if_sentencepiece_model(self)
+    file_location = save_to_gguf(model_type, is_sentencepiece_model,
+        new_save_directory, quantization_method, first_conversion, makefile,
+    )
 
     print("Unsloth: Uploading GGUF to Huggingface Hub...")
     username = upload_to_huggingface(
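
Putting the save.py changes together: the tokenizer probe now decides which llama.cpp converter save_to_gguf may use. A condensed sketch of that gating (illustrative only; it assumes the fast convert.py path needs a SentencePiece tokenizer.model, which BPE-only models such as Llama-3 lack):

# Sketch: how is_sentencepiece gates the converter choice in save_to_gguf.
def choose_converter(model_type: str, is_sentencepiece: bool) -> str:
    use_fast_convert = False
    if not is_sentencepiece:      use_fast_convert = False  # e.g. Llama-3 (BPE tokenizer)
    elif model_type == "llama":   use_fast_convert = True
    elif model_type == "mistral": use_fast_convert = True
    return "llama.cpp/convert.py" if use_fast_convert else "llama.cpp/convert-hf-to-gguf.py"

print(choose_converter("llama",   True))   # llama.cpp/convert.py (Llama-2 style)
print(choose_converter("llama",   False))  # llama.cpp/convert-hf-to-gguf.py (Llama-3)
print(choose_converter("mistral", True))   # llama.cpp/convert.py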
