Disable prefix tuning and limit llama adapter #482

Merged (5 commits) on May 6, 2024
2 changes: 1 addition & 1 deletion recipes/finetuning/README.md
@@ -70,7 +70,7 @@ It lets us specify the training settings for everything from `model_name` to `da

* [Datasets config file](../../src/llama_recipes/configs/datasets.py) provides the available options for datasets.

-* [peft config file](../../src/llama_recipes/configs/peft.py) provides the supported PEFT methods and respective settings that can be modified.
+* [peft config file](../../src/llama_recipes/configs/peft.py) provides the supported PEFT methods and respective settings that can be modified. We currently support LoRA and Llama-Adapter. Please note that LoRA is the only technique which is supported in combination with FSDP.

* [FSDP config file](../../src/llama_recipes/configs/fsdp.py) provides FSDP settings such as:

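To make the constraint above concrete, a minimal sketch (import path taken from the peft config file linked above; the comments only restate the support matrix from this PR):

```python
# Both recipe-level PEFT configs can be instantiated directly.
from llama_recipes.configs.peft import lora_config, llama_adapter_config

print(lora_config())           # LoRA: supported standalone and together with FSDP
print(llama_adapter_config())  # Llama-Adapter: supported standalone, but not with FSDP
```
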
3 changes: 2 additions & 1 deletion src/llama_recipes/configs/peft.py
@@ -20,7 +20,8 @@ class llama_adapter_config:
    adapter_layers: int= 30
    task_type: str= "CAUSAL_LM"

+#CAUTION prefix tuning is currently not supported
@dataclass
class prefix_config:
    num_virtual_tokens: int=30
-    task_type: str= "CAUSAL_LM"
+    task_type: str= "CAUSAL_LM"
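
For context, these dataclasses are thin recipe-level mirrors of the corresponding classes in the `peft` library; a minimal sketch of the mapping (this mirrors what `generate_peft_config` in config_utils.py does, and assumes the standard `peft` package is installed):

```python
from dataclasses import asdict

from peft import AdaptionPromptConfig

from llama_recipes.configs.peft import llama_adapter_config

# Recipe dataclass fields are passed straight through to the PEFT config.
params = asdict(llama_adapter_config())  # adapter_len, adapter_layers, task_type
peft_config = AdaptionPromptConfig(**params)
```
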
2 changes: 1 addition & 1 deletion src/llama_recipes/configs/training.py
@@ -29,7 +29,7 @@ class train_config:
    mixed_precision: bool=True
    val_batch_size: int=1
    dataset = "samsum_dataset"
-    peft_method: str = "lora" # None,llama_adapter, prefix
+    peft_method: str = "lora" # None, llama_adapter (Caution: llama_adapter is currently not supported with FSDP)
    use_peft: bool=False
    output_dir: str = "PATH/to/save/PEFT/model"
    freeze_layers: bool = False
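
The comment change above is only documentation; the actual enforcement lives in `generate_peft_config` (next file). A small sketch of how `peft_method` is typically overridden, assuming the `update_config` helper in `llama_recipes.utils.config_utils` (the same mechanism the CLI flags go through):

```python
from llama_recipes.configs import train_config
from llama_recipes.utils.config_utils import update_config

cfg = train_config()
print(cfg.peft_method)  # "lora" by default, as shown above

# Equivalent of passing --use_peft --peft_method llama_adapter on the command line;
# valid only as long as enable_fsdp stays False.
update_config(cfg, use_peft=True, peft_method="llama_adapter")
```
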
9 changes: 8 additions & 1 deletion src/llama_recipes/utils/config_utils.py
@@ -45,7 +45,14 @@ def generate_peft_config(train_config, kwargs):
    peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
    names = tuple(c.__name__.rstrip("_config") for c in configs)

-    assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
+    if train_config.peft_method not in names:
+        raise RuntimeError(f"Peft config not found: {train_config.peft_method}")
+
+    if train_config.peft_method == "prefix":
+        raise RuntimeError("PrefixTuning is currently not supported (see https://github.com/meta-llama/llama-recipes/issues/359#issuecomment-2089350811)")
+
+    if train_config.enable_fsdp and train_config.peft_method == "llama_adapter":
+        raise RuntimeError("Llama_adapter is currently not supported in combination with FSDP (see https://github.com/meta-llama/llama-recipes/issues/359#issuecomment-2089274425)")

    config = configs[names.index(train_config.peft_method)]()

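The new guards turn silent misconfiguration into an immediate failure. A minimal usage sketch of the resulting behavior (import paths as used elsewhere in this repo; both calls are expected to raise):

```python
from llama_recipes.configs import train_config
from llama_recipes.utils.config_utils import generate_peft_config

cfg = train_config()

# Prefix tuning is now rejected outright.
cfg.peft_method = "prefix"
try:
    generate_peft_config(cfg, {})
except RuntimeError as e:
    print(e)  # PrefixTuning is currently not supported ...

# Llama-Adapter is still allowed, just not in combination with FSDP.
cfg.peft_method = "llama_adapter"
cfg.enable_fsdp = True
try:
    generate_peft_config(cfg, {})
except RuntimeError as e:
    print(e)  # Llama_adapter is currently not supported in combination with FSDP ...
```
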
8 changes: 0 additions & 8 deletions src/llama_recipes/utils/fsdp_utils.py
@@ -8,8 +8,6 @@ def fsdp_auto_wrap_policy(model, transformer_layer_name):

    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy

-    from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
-
    def lambda_policy_fn(module):
        if (
            len(list(module.named_children())) == 0
@@ -23,13 +21,7 @@ def lambda_policy_fn(module):
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=(
-            PrefixEncoder,
-            PromptEncoder,
-            PromptEmbedding,
            transformer_layer_name,
-            # FullyShardedDataParallelPlugin.get_module_class_from_name(
-            # model, transformer_layer_name
-            # ),
        ),
    )

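With the prompt/prefix encoder classes removed, only two wrap policies remain. A sketch of roughly what `fsdp_auto_wrap_policy` looks like after this change, reconstructed from the visible hunks (anything outside the hunks is assumed):

```python
def fsdp_auto_wrap_policy(model, transformer_layer_name):
    import functools

    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy

    def lambda_policy_fn(module):
        # Wrap leaf modules that still hold trainable weights (e.g. LoRA layers).
        return (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        )

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    # Only the model's own transformer block class is auto-wrapped now.
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=(transformer_layer_name,),
    )
    return functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
```
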
198 changes: 143 additions & 55 deletions tests/test_finetuning.py
@@ -1,40 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import pytest
from pytest import approx
import os
from unittest.mock import patch

import pytest

import torch
from llama_recipes.data.sampler import LengthBasedBatchSampler

from llama_recipes.finetuning import main
from pytest import approx
from torch.optim import AdamW
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import BatchSampler

from llama_recipes.finetuning import main
from llama_recipes.data.sampler import LengthBasedBatchSampler


def get_fake_dataset():
    return [{
        "input_ids":[1],
        "attention_mask":[1],
        "labels":[1],
    }]

@patch('llama_recipes.finetuning.torch.cuda.is_available')
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
    return [
        {
            "input_ids": [1],
            "attention_mask": [1],
            "labels": [1],
        }
    ]


@patch("llama_recipes.finetuning.torch.cuda.is_available")
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
@patch("llama_recipes.finetuning.optim.AdamW")
@patch("llama_recipes.finetuning.StepLR")
@pytest.mark.parametrize("cuda_is_available", [True, False])
def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
def test_finetuning_no_validation(
    step_lr,
    optimizer,
    get_dataset,
    tokenizer,
    get_model,
    train,
    cuda,
    cuda_is_available,
):
    kwargs = {"run_validation": False}

    get_dataset.return_value = get_fake_dataset()
    cuda.return_value = cuda_is_available

    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]

    main(**kwargs)

    assert train.call_count == 1
@@ -53,20 +69,31 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
        assert get_model.return_value.to.call_count == 0


@patch('llama_recipes.finetuning.torch.cuda.is_available')
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
@patch("llama_recipes.finetuning.torch.cuda.is_available")
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
@patch("llama_recipes.finetuning.optim.AdamW")
@patch("llama_recipes.finetuning.StepLR")
@pytest.mark.parametrize("cuda_is_available", [True, False])
def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
def test_finetuning_with_validation(
    step_lr,
    optimizer,
    get_dataset,
    tokenizer,
    get_model,
    train,
    cuda,
    cuda_is_available,
):
    kwargs = {"run_validation": True}

    get_dataset.return_value = get_fake_dataset()
    cuda.return_value = cuda_is_available

    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]

    main(**kwargs)

    assert train.call_count == 1
@@ -83,22 +110,36 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
    else:
        assert get_model.return_value.to.call_count == 0

@patch('llama_recipes.finetuning.torch.cuda.is_available')
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
@patch('llama_recipes.finetuning.generate_peft_config')
@patch('llama_recipes.finetuning.get_peft_model')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')

@patch("llama_recipes.finetuning.torch.cuda.is_available")
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
@patch("llama_recipes.finetuning.generate_peft_config")
@patch("llama_recipes.finetuning.get_peft_model")
@patch("llama_recipes.finetuning.optim.AdamW")
@patch("llama_recipes.finetuning.StepLR")
@pytest.mark.parametrize("cuda_is_available", [True, False])
def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
def test_finetuning_peft_lora(
    step_lr,
    optimizer,
    get_peft_model,
    gen_peft_config,
    get_dataset,
    tokenizer,
    get_model,
    train,
    cuda,
    cuda_is_available,
):
    kwargs = {"use_peft": True}

    get_dataset.return_value = get_fake_dataset()
    cuda.return_value = cuda_is_available

    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]

    main(**kwargs)

    if cuda_is_available:
@@ -110,21 +151,64 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
    assert get_peft_model.return_value.print_trainable_parameters.call_count == 1


@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
@patch('llama_recipes.finetuning.get_peft_model')
@patch('llama_recipes.finetuning.StepLR')
def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker):
    kwargs = {"weight_decay": 0.01}
@patch("llama_recipes.finetuning.get_peft_model")
@patch("llama_recipes.finetuning.setup")
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
def test_finetuning_peft_llama_adapter(
    get_dataset, tokenizer, get_model, train, setup, get_peft_model
):
    kwargs = {
        "use_peft": True,
        "peft_method": "llama_adapter",
        "enable_fsdp": True,
    }

    get_dataset.return_value = get_fake_dataset()

    model = mocker.MagicMock(name="Model")
    model.parameters.return_value = [torch.ones(1,1)]
    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]

    os.environ["RANK"] = "0"
    os.environ["LOCAL_RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"

    with pytest.raises(
        RuntimeError,
        match="Llama_adapter is currently not supported in combination with FSDP",
    ):
        main(**kwargs)

    GET_ME_OUT = "Get me out of here"
    get_peft_model.side_effect = RuntimeError(GET_ME_OUT)

    kwargs["enable_fsdp"] = False

    with pytest.raises(
        RuntimeError,
        match=GET_ME_OUT,
    ):
        main(**kwargs)


    get_model.return_value = model
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
@patch("llama_recipes.finetuning.get_peft_model")
@patch("llama_recipes.finetuning.StepLR")
def test_finetuning_weight_decay(
    step_lr, get_peft_model, get_dataset, tokenizer, get_model, train
):
    kwargs = {"weight_decay": 0.01}

    get_dataset.return_value = get_fake_dataset()

    get_model.return_value.parameters.return_value = [torch.ones(1, 1)]
    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]

    main(**kwargs)

@@ -139,17 +223,21 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
    assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)


@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
@patch("llama_recipes.finetuning.optim.AdamW")
@patch("llama_recipes.finetuning.StepLR")
def test_batching_strategy(
    step_lr, optimizer, get_dataset, tokenizer, get_model, train
):
    kwargs = {"batching_strategy": "packing"}

    get_dataset.return_value = get_fake_dataset()

    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]

    main(**kwargs)

    assert train.call_count == 1
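
The new `test_finetuning_peft_llama_adapter` test exercises the llama_adapter/FSDP guard and then confirms that the non-FSDP path still reaches `get_peft_model`; it can be run in isolation with something like `pytest tests/test_finetuning.py -k llama_adapter`.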