Skip to content

Commit

Permalink
Support Qwen2 (#428)
Browse files Browse the repository at this point in the history
* support Qwen2

* support Qwen2

* Delete README.md

* Revert "Delete README.md"

This reverts commit 026b05f.

* Update README.md

* Qwen2 == Mistral

* Update llama.py

* Update __init__.py

* Update README.md

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
  • Loading branch information
yangjianxin1 and danielhanchen committed May 10, 2024
1 parent 7c53652 commit cf83fe3
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 4 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ All notebooks are **beginner friendly**! Add your dataset, click "Run All", and
- This [text completion notebook](https://colab.research.google.com/drive/1ef-tab5bhkvWmBOObepl1WgJvfvSzn5Q?usp=sharing) is for continued pretraining / raw text.

## 🦥 Unsloth.ai News
- 📣 NEW! Qwen1.5-7B, Qwen1.5-14B, Qwen1.5-32B, Qwen1.5-72B now work, courtesy of Firefly's PR [#428](https://github.com/unslothai/unsloth/pull/428)
- 📣 NEW! [Llama-3 8b](https://colab.research.google.com/drive/135ced7oHytdxu3N2DNe1Z0kqjyYIkDXp?usp=sharing) now works! Llama-3 70b also works (change the model name in the notebook).
- 📣 NEW! [ORPO support](https://colab.research.google.com/drive/11t4njE3c4Lxl-07OD8lJSMKkfyJml3Tn?usp=sharing) is here!
- 📣 NEW! [Phi-3 3.8b support](https://colab.research.google.com/drive/1NvkBmkHfucGO3Ve9s1NKZvMNlw5p83ym?usp=sharing) is here!
Expand Down
7 changes: 4 additions & 3 deletions unsloth/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .loader import FastLanguageModel
from .llama import FastLlamaModel
from .loader import FastLanguageModel
from .llama import FastLlamaModel
from .mistral import FastMistralModel
from .dpo import PatchDPOTrainer
from .qwen2 import FastQwen2Model
from .dpo import PatchDPOTrainer
1 change: 1 addition & 0 deletions unsloth/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1605,6 +1605,7 @@ def patch_peft_model(

if model_type == "llama": apply_lora_mlp = apply_lora_mlp_swiglu
elif model_type == "mistral": apply_lora_mlp = apply_lora_mlp_swiglu
elif model_type == "qwen2": apply_lora_mlp = apply_lora_mlp_swiglu
elif model_type == "gemma": apply_lora_mlp = apply_lora_mlp_geglu_approx
else:
raise NotImplementedError(f"Unsloth: {model_type} is not yet implemented!")
Expand Down
3 changes: 3 additions & 0 deletions unsloth/models/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from .llama import FastLlamaModel, logger
from .mistral import FastMistralModel
from .qwen2 import FastQwen2Model
from transformers import AutoConfig
from transformers import __version__ as transformers_version
from peft import PeftConfig, PeftModel
Expand Down Expand Up @@ -119,6 +120,8 @@ def from_pretrained(
f"to obtain the latest transformers build, then restart this session."\
)
dispatch_model = FastGemmaModel
elif model_type == "qwen2":
dispatch_model = FastQwen2Model
else:
raise NotImplementedError(
f"Unsloth: {model_name} not supported yet!\n"\
Expand Down
2 changes: 1 addition & 1 deletion unsloth/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def from_pretrained(
# Mistral does NOT support RoPE Scaling sadly so we have to error out.
if max_seq_length > model_max_seq_length:
raise RuntimeError(
"Unsloth: Unfortunately Mistral type models do not support RoPE scaling!\n"\
f"Unsloth: Unfortunately {model_patcher.__name__[4:-5]} type models do not support RoPE scaling!\n"\
f"The maximum sequence length supported is {model_max_seq_length}.",
)
pass
Expand Down
91 changes: 91 additions & 0 deletions unsloth/models/qwen2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .llama import *
from .mistral import FastMistralModel
import os
from ._utils import __version__

from transformers.models.qwen2.modeling_qwen2 import (
Qwen2Attention,
Qwen2DecoderLayer,
Qwen2Model,
Qwen2ForCausalLM,
)
# For Pytorch 2.1.1
try:
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2SdpaAttention,
Qwen2FlashAttention2,
)
except:
Qwen2SdpaAttention = Qwen2Attention
Qwen2FlashAttention2 = Qwen2Attention
pass


class FastQwen2Model(FastLlamaModel):

@staticmethod
def pre_patch():
Qwen2Attention .forward = LlamaAttention_fast_forward
Qwen2SdpaAttention .forward = LlamaAttention_fast_forward
Qwen2FlashAttention2.forward = LlamaAttention_fast_forward
Qwen2DecoderLayer .forward = LlamaDecoderLayer_fast_forward
Qwen2Model .forward = LlamaModel_fast_forward
Qwen2ForCausalLM .forward = CausalLM_fast_forward(LlamaModel_fast_forward_inference)
PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward

# Solves https://github.com/unslothai/unsloth/issues/168
# Static KV Cache was introduced in 4.38.0, causing training to be much slower.
# Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings.
# https://github.com/huggingface/transformers/pull/27931
# https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
import transformers.models.qwen2.modeling_qwen2
transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding = LlamaRotaryEmbedding
return
pass


@staticmethod
def from_pretrained(
model_name = "Qwen/Qwen1.5-7B",
max_seq_length = 4096,
dtype = None,
load_in_4bit = True,
token = None,
device_map = "sequential",
rope_scaling = None, # Qwen2 does not support RoPE scaling
fix_tokenizer = True,
model_patcher = None,
tokenizer_name = None,
trust_remote_code = False,
**kwargs,
):
return FastMistralModel.from_pretrained(
model_name = model_name,
max_seq_length = max_seq_length,
dtype = dtype,
load_in_4bit = load_in_4bit,
token = token,
device_map = device_map,
rope_scaling = rope_scaling,
fix_tokenizer = fix_tokenizer,
model_patcher = FastQwen2Model,
tokenizer_name = tokenizer_name,
trust_remote_code = trust_remote_code,
**kwargs,
)
pass
pass

0 comments on commit cf83fe3

Please sign in to comment.