[WIP] adding fsdp-qlora in progress #471

Open · wants to merge 2 commits into base: main
2 changes: 1 addition & 1 deletion src/llama_recipes/configs/__init__.py
@@ -1,7 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from llama_recipes.configs.peft import lora_config, llama_adapter_config, prefix_config
from llama_recipes.configs.peft import lora_config, llama_adapter_config, prefix_config, qlora_config
from llama_recipes.configs.fsdp import fsdp_config
from llama_recipes.configs.training import train_config
from llama_recipes.configs.wandb import wandb_config
17 changes: 16 additions & 1 deletion src/llama_recipes/configs/peft.py
@@ -23,4 +23,19 @@ class llama_adapter_config:
@dataclass
class prefix_config:
    num_virtual_tokens: int=30
    task_type: str= "CAUSAL_LM"


@dataclass
class qlora_config:
    r: int=8
    lora_alpha: int=32
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
    bias: str = "none"
    task_type: str = "CAUSAL_LM"
    lora_dropout: float=0.05
    inference_mode: bool = False
    bnb_4bit_quant_type: str = "nf4"  # bitsandbytes expects "nf4" or "fp4" here, not a float dtype
    bnb_4bit_compute_dtype: str = "bfloat16"  # torch dtype name, resolved via getattr(torch, ...)
    bnb_4bit_quant_storage: str = "bfloat16"  # torch dtype name; matching the model dtype helps FSDP flattening
    use_nested_quant: bool = False
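
For clarity, the sketch below shows how the LoRA fields of qlora_config could be separated from the bnb_* quantization fields, which are not valid LoraConfig arguments. The helper name split_qlora_config is illustrative and not part of this diff; it assumes peft's LoraConfig and a transformers version whose BitsAndBytesConfig supports bnb_4bit_quant_storage.

```python
from dataclasses import asdict

import torch
from peft import LoraConfig
from transformers import BitsAndBytesConfig

from llama_recipes.configs.peft import qlora_config


def split_qlora_config(cfg: qlora_config):
    """Illustrative helper: split qlora_config into a peft LoraConfig and a BitsAndBytesConfig."""
    params = asdict(cfg)
    # The quantization fields are not LoraConfig arguments; pop them out so the
    # remaining keys map one-to-one onto LoraConfig's constructor.
    quant_keys = ("bnb_4bit_quant_type", "bnb_4bit_compute_dtype",
                  "bnb_4bit_quant_storage", "use_nested_quant")
    quant = {k: params.pop(k) for k in quant_keys}

    lora_cfg = LoraConfig(**params)
    bnb_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type=quant["bnb_4bit_quant_type"],
        bnb_4bit_compute_dtype=getattr(torch, quant["bnb_4bit_compute_dtype"]),
        bnb_4bit_quant_storage=getattr(torch, quant["bnb_4bit_quant_storage"]),
        bnb_4bit_use_double_quant=quant["use_nested_quant"],
    )
    return lora_cfg, bnb_cfg


lora_cfg, bnb_cfg = split_qlora_config(qlora_config())
```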
9 changes: 8 additions & 1 deletion src/llama_recipes/configs/training.py
@@ -29,12 +29,14 @@ class train_config:
    mixed_precision: bool=True
    val_batch_size: int=1
    dataset = "samsum_dataset"
    peft_method: str = "lora" # None , llama_adapter, prefix
    peft_method: str = "lora" # None , qlora, llama_adapter, prefix
    use_peft: bool=False
    output_dir: str = "PATH/to/save/PEFT/model"
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    use_4bit_quantization: bool = False
    use_8bit_quantization: bool = False
    one_gpu: bool = False
    save_model: bool = True
    dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP
@@ -43,3 +45,8 @@ class train_config:
    use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, makes use of Flash Attention and Xformer memory-efficient kernels
    use_wandb: bool = False # Enable wandb for experiment tracking
    save_metrics: bool = False # saves training metrics to a json file for later plotting
    inference_mode: bool = False
    bnb_4bit_quant_type: str = "nf4"  # bitsandbytes 4-bit quantization type: "nf4" or "fp4"
    bnb_4bit_compute_dtype: str = "bfloat16"
    bnb_4bit_quant_storage: str = "bfloat16"
    use_nested_quant: bool = True
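
As a usage sketch (not part of the diff), the new train_config fields can be overridden through the existing keyword-override mechanism of llama_recipes.finetuning.main; the checkpoint name and flag values below are illustrative, and in practice the script is launched with torchrun so FSDP can initialize the process group.

```python
from llama_recipes.finetuning import main

# Illustrative FSDP + QLoRA configuration; normally launched under torchrun.
main(
    model_name="meta-llama/Llama-2-7b-hf",  # illustrative checkpoint
    enable_fsdp=True,
    use_peft=True,
    peft_method="qlora",
    quantization=True,
    use_4bit_quantization=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="bfloat16",
    bnb_4bit_quant_storage="bfloat16",
    use_nested_quant=True,
)
```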
8 changes: 6 additions & 2 deletions src/llama_recipes/finetuning.py
@@ -46,6 +46,7 @@
    clear_gpu_cache,
    print_model_size,
    get_policies,
    set_quantization_settings,
)
from accelerate.utils import is_xpu_available

@@ -100,6 +101,9 @@ def main(**kwargs):

    # Load the pre-trained model and setup its configuration
    use_cache = False if train_config.enable_fsdp else None
    # Build the BitsAndBytes config only when quantization is requested, so that
    # quantization_config=None is passed to from_pretrained otherwise.
    bnb_config = set_quantization_settings(train_config) if train_config.quantization else None

    if train_config.enable_fsdp and train_config.low_cpu_fsdp:
        """
        for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
@@ -110,7 +114,7 @@
        if rank == 0:
            model = LlamaForCausalLM.from_pretrained(
                train_config.model_name,
                load_in_8bit=True if train_config.quantization else None,
                quantization_config=bnb_config,
                device_map="auto" if train_config.quantization else None,
                use_cache=use_cache,
                attn_implementation="sdpa" if train_config.use_fast_kernels else None,
@@ -124,7 +128,7 @@
        else:
            model = LlamaForCausalLM.from_pretrained(
                train_config.model_name,
                load_in_8bit=True if train_config.quantization else None,
                quantization_config=bnb_config,
                device_map="auto" if train_config.quantization else None,
                use_cache=use_cache,
                attn_implementation="sdpa" if train_config.use_fast_kernels else None,
4 changes: 2 additions & 2 deletions src/llama_recipes/utils/config_utils.py
@@ -14,7 +14,7 @@
from transformers import default_data_collator
from transformers.data import DataCollatorForSeq2Seq

from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
from llama_recipes.configs import datasets, lora_config, qlora_config, llama_adapter_config, prefix_config, train_config
from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
from llama_recipes.utils.dataset_utils import DATASET_PREPROC

@@ -41,7 +41,7 @@ def update_config(config, **kwargs):


def generate_peft_config(train_config, kwargs):
    configs = (lora_config, llama_adapter_config, prefix_config)
    configs = (lora_config, qlora_config, llama_adapter_config, prefix_config)
    peft_configs = (LoraConfig, LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)  # qlora maps onto peft's LoraConfig as well, keeping this tuple aligned with `configs`
    names = tuple(c.__name__.rstrip("_config") for c in configs)

32 changes: 32 additions & 0 deletions src/llama_recipes/utils/train_utils.py
@@ -513,3 +513,35 @@ def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_
    }
    with open(output_filename, "w") as f:
        json.dump(metrics_data, f)


def set_quantization_settings(train_config):
    """
    Configures and returns quantization settings based on the training configuration.

    Parameters:
    - train_config: training config object, expected to include "quantization" together with
      "use_4bit_quantization" or "use_8bit_quantization", and the 4-bit settings
      "bnb_4bit_quant_type", "bnb_4bit_compute_dtype", "bnb_4bit_quant_storage",
      and "use_nested_quant".

    Returns:
    - A BitsAndBytesConfig object configured with the specified settings.
    """
    from transformers import BitsAndBytesConfig

    if train_config.use_4bit_quantization:
        # Resolve dtype names such as "bfloat16" to the corresponding torch dtypes.
        compute_dtype = getattr(torch, train_config.bnb_4bit_compute_dtype)
        quant_storage_dtype = getattr(torch, train_config.bnb_4bit_quant_storage)

        # Initialize BitsAndBytesConfig with 4-bit quantization settings.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=train_config.use_4bit_quantization,
            bnb_4bit_quant_type=train_config.bnb_4bit_quant_type,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=train_config.use_nested_quant,
            bnb_4bit_quant_storage=quant_storage_dtype,
        )
    # Initialize BitsAndBytesConfig with the 8-bit quantization flag.
    elif train_config.use_8bit_quantization:
        bnb_config = BitsAndBytesConfig(load_in_8bit=train_config.use_8bit_quantization)
    else:
        raise ValueError(
            "quantization is enabled but neither use_4bit_quantization nor use_8bit_quantization is set"
        )
    return bnb_config
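
A minimal usage sketch, mirroring the finetuning.py change above; the SimpleNamespace stand-in for train_config and the checkpoint name are illustrative only.

```python
from types import SimpleNamespace

from transformers import LlamaForCausalLM

train_config = SimpleNamespace(
    quantization=True,
    use_4bit_quantization=True,
    use_8bit_quantization=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="bfloat16",
    bnb_4bit_quant_storage="bfloat16",
    use_nested_quant=True,
)

bnb_config = set_quantization_settings(train_config)
model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # illustrative checkpoint
    quantization_config=bnb_config,
    device_map="auto",
)
```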