Fix generation (#472)
* Fix prompt

* Update chat_templates.py

* fix_untrained_tokens

* Update llama.py

* add tokens

* Update _utils.py

* Update tokenizer_utils.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* pad_token

* Update chat_templates.py

* Update chat_templates.py

* tokenizer

* Update save.py

* Update chat_templates.py

* Update chat_templates.py

* patch tokenizer padding

* Update tokenizer_utils.py

* Update save.py

* Fix: loading models with resized vocabulary (#377)

* new: vocab resize on load

* new: gitignore

* GGUF fix

* Readme (#390)

* Update README.md

* Update README.md

---------

Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com>

* Update README.md

* Delete .gitignore

* Phi-3

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Fix reserved tokens

* Update save.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update chat_templates.py

* Update save.py

* Update _utils.py

* Update chat_templates.py

* Adds dependencies and extras for torch 2.3.0 with new xformers versions (#415)

* Adds dependencies and extras for torch 2.3.0 with new xformers versions

* Add 2.3.0 section to readme

* Support Qwen2 (#428)

* support Qwen2

* support Qwen2

* Delete README.md

* Revert "Delete README.md"

This reverts commit 026b05f.

* Update README.md

* Qwen2 == Mistral

* Update llama.py

* Update __init__.py

* Update README.md

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Update save.py

* Update save.py

* Update _utils.py

* Update save.py

* Update save.py

* Update save.py

* test_hf_gguf_equivalence

* Update chat_templates.py

* Update chat_templates.py

* --pad-vocab

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Unspecified max_seq_length

* possible_pad_token

* Update tokenizer_utils.py

* past_key_values

* Update llama.py

* Update llama.py

* Update llama.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* _wrap_fast_inference

* Update llama.py

* Update llama.py

* flag

---------

Co-authored-by: Igor Kilbas <whitemarsstudios@gmail.com>
Co-authored-by: Michael Han <107991372+shimmyshimmer@users.noreply.github.com>
Co-authored-by: Nathan Azrak <42650258+nathan-az@users.noreply.github.com>
Co-authored-by: Yang JianXin <995462226@qq.com>
5 people committed May 16, 2024
1 parent 47ffd39 commit 25975f9
Showing 4 changed files with 70 additions and 14 deletions.
41 changes: 37 additions & 4 deletions unsloth/models/_utils.py
@@ -15,11 +15,12 @@
import torch
from typing import Union, Optional, List, Any, Callable
import warnings
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub")
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "subprocess")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "huggingface_hub")
import bitsandbytes as bnb
from transformers.models.llama.modeling_llama import logger
from transformers import AutoTokenizer
@@ -388,3 +389,35 @@ def backward(ctx, dY):
pass
pass


"""
Remove warnings about missing kwargs
"""
try:
    from transformers.utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
    from inspect import getsource
    import re
    BitsAndBytesConfig__init__ = getsource(BitsAndBytesConfig.__init__)
    BitsAndBytesConfig__init__ = re.sub(
        r"if[\s]{1,}kwargs\:[\s]{1,}.+?\n",
        "",
        BitsAndBytesConfig__init__,
        flags = re.MULTILINE,
    )
    BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.split("\n")
    length_spaces = len(re.match(r"[\s]{1,}", BitsAndBytesConfig__init__[0]).group(0))
    BitsAndBytesConfig__init__ = "\n".join(x[length_spaces:] for x in BitsAndBytesConfig__init__)
    BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.replace(
        "__init__",
        "_BitsAndBytesConfig__init__",
    )
    exec(BitsAndBytesConfig__init__, globals())

    import transformers.utils.quantization_config
    transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__
except:
    logger.warning_once(
        "Unsloth unsuccessfully patched bitsandbytes. Please file a bug report.\n"\
        "Luckily, your training run will still work in the meantime!"
    )
pass
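For readers unfamiliar with the patch above: it rewrites BitsAndBytesConfig.__init__ at import time by pulling its source with inspect.getsource, deleting the `if kwargs:` warning branch with a regex, re-exec'ing the edited source under a new name, and assigning that function back onto the class. A minimal, self-contained sketch of the same source-rewriting pattern on a toy class (Example and _patched_init are hypothetical names used only for illustration):

import re
from inspect import getsource
from textwrap import dedent

class Example:
    # Hypothetical stand-in for BitsAndBytesConfig
    def __init__(self, **kwargs):
        if kwargs:
            print(f"Unused kwargs passed: {kwargs}")  # the branch we want to drop
        self.ready = True

source = getsource(Example.__init__)
source = re.sub(r"if kwargs:\n.+?\n", "", source)     # delete the warning branch
source = dedent(source)                               # strip the method-level indent
source = source.replace("__init__", "_patched_init")  # avoid clobbering the real name
exec(source, globals())
Example.__init__ = _patched_init

Example(unused_flag = True)   # no warning is printed any more

Note that inspect.getsource only works when the class is defined in a file on disk, which is also why the real patch is wrapped in a try/except that merely logs a warning if anything changes upstream.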
2 changes: 1 addition & 1 deletion unsloth/models/gemma.py
@@ -71,7 +71,7 @@ def GemmaDecoderLayer_fast_forward(
    padding_mask: Optional[torch.LongTensor] = None,
    *args, **kwargs,
):
    if use_cache: #past_key_value is not None:
    if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
        out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda")

        # Self Attention
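The one-line change above (mirrored in llama.py below) is the core of the generation fix: the inference-only path, which allocates a float32 layernorm buffer and skips the training-friendly code, now runs only when `use_cache` is set and the module also carries the `_flag_for_generation` attribute that the generate wrapper attaches. A stripped-down sketch of the gating pattern, with hypothetical helper names, purely for illustration:

def decoder_layer_forward(self, hidden_states, use_cache = False):
    # Fast path only during generate(): the wrapper sets _flag_for_generation beforehand.
    if use_cache and hasattr(self, "_flag_for_generation"):
        return self._fast_inference_path(hidden_states)   # hypothetical fused-inference helper
    # Otherwise fall back to the normal path used for training and plain forward passes.
    return self._regular_path(hidden_states)              # hypothetical standard helper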
32 changes: 27 additions & 5 deletions unsloth/models/llama.py
@@ -407,7 +407,7 @@ def LlamaDecoderLayer_fast_forward(
            (see `past_key_values`).
        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
    """
    if use_cache:
    if use_cache and hasattr(self, "_flag_for_generation"):
        residual = hidden_states
        hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
@@ -789,7 +789,7 @@ def _CausalLM_fast_forward(
        return_dict: Optional[bool] = None,
        *args, **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        if past_key_values is not None:
            outputs = fast_forward_inference(
                self,
@@ -968,12 +968,34 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
        pass


def _wrap_fast_inference(generate, device_type, dtype):
def _wrap_fast_inference(generate, device_type, dtype, model):
    # Wraps inference with bfloat16 / float16
    @torch.inference_mode
    def _fast_generate(*args, **kwargs):

        # Set a flag for generation!
        internal_model = model
        while hasattr(internal_model, "model"):
            internal_model._flag_for_generation = True
            internal_model = internal_model.model
        pass
        internal_model._flag_for_generation = True

        # Autocasted
        with torch.autocast(device_type = device_type, dtype = dtype):
            return generate(*args, **kwargs)
            output = generate(*args, **kwargs)
        pass

        # Unset a flag for generation!
        internal_model = model
        while hasattr(internal_model, "model"):
            if hasattr(internal_model, "_flag_for_generation"): del internal_model._flag_for_generation
            internal_model = internal_model.model
        pass
        if hasattr(internal_model, "_flag_for_generation"): del internal_model._flag_for_generation

        return output
    pass
    return _fast_generate
pass
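The wrapper above walks every nested `.model` attribute, tags each module with `_flag_for_generation`, runs the original `generate` under `torch.inference_mode` and autocast, and then deletes the flags again. One design note: because the cleanup runs after the call rather than in a `finally` block, an exception raised inside `generate` would leave the flags set. A minimal sketch of the same idea with the cleanup made exception-safe (an illustrative variant, not the code in this commit):

def wrap_generate(generate, model):
    # Illustrative sketch: set a per-module flag around generation and always remove it.
    def _generate(*args, **kwargs):
        modules, m = [], model
        while hasattr(m, "model"):
            modules.append(m)
            m = m.model
        modules.append(m)                      # include the innermost module as well
        for mod in modules:
            mod._flag_for_generation = True
        try:
            return generate(*args, **kwargs)
        finally:
            for mod in modules:
                if hasattr(mod, "_flag_for_generation"):
                    del mod._flag_for_generation
    return _generate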

@@ -1787,7 +1809,7 @@ def for_inference(model):

        # Wrap model.generate
        model._unwrapped_old_generate = model.generate
        model.generate = _wrap_fast_inference(model.generate, device_type, dtype)
        model.generate = _wrap_fast_inference(model.generate, device_type, dtype, model)

        # Patch tokenizer to pad to the left
        internal_model = model
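Taken together with `for_inference` above, `model.generate` is now swapped for the wrapped version, so the fast decoder paths are active only while generation runs and ordinary training forwards are untouched. A sketch of typical usage, assuming the standard unsloth loading API (the model name and generation arguments are placeholders; adjust to your setup):

from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-bnb-4bit",   # example 4-bit checkpoint
    max_seq_length = 2048,
    load_in_4bit = True,
)

FastLanguageModel.for_inference(model)            # wraps generate and sets the generation flag
inputs = tokenizer("Write a haiku about GPUs.", return_tensors = "pt").to("cuda")
output = model.generate(**inputs, max_new_tokens = 64, use_cache = True)
print(tokenizer.decode(output[0], skip_special_tokens = True))

FastLanguageModel.for_training(model)             # restores the original generate before further fine-tuning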
9 changes: 5 additions & 4 deletions unsloth/tokenizer_utils.py
@@ -369,10 +369,11 @@ def load_correct_tokenizer(
            cache_dir = cache_dir,
        )
    except:
        print(
            f"Unsloth: {tokenizer_name} has no tokenizer.model file.\n"\
            "Just informing you about this - this is not a critical error."
        )
        pass
        # print(
        #     f"Unsloth: {tokenizer_name} has no tokenizer.model file.\n"\
        #     "Just informing you about this - this is not a critical error."
        # )
    pass

    fast_tokenizer = AutoTokenizer.from_pretrained(
