
[WIP] Addition of Dora #936

Open
wants to merge 12 commits into main

Conversation

Prakyathkantharaju

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?
I have added DoRA support to the LoRALinear module, as requested in issue #893.

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
    • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

pytorch-bot commented May 5, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/936

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on May 5, 2024
Contributor

@ebsmothers left a comment

Thanks for opening the PR! I am still not sure about the correctness of the implementation though. Can you run forward on the same input tensor and confirm you get the same results with a known correct implementation (e.g. the one from PEFT referenced in your L138 of lora.py).

There are some other forward-looking considerations as well: specifically how we expose DoRA in our higher-level model builders (could potentially be similar to what we do for QLoRA), how we will merge weights when DoRA is applied, determining to what extent we want to support enabling and disabling DoRA adapters (this functionality is used in e.g. our DPO recipe). But for now the main thing is to make sure the linear component itself is correct and well-tested.
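A minimal sketch of such a parity check follows. Rather than reproducing PEFT's module internals (which vary across versions), it compares the PR's `LoRALinear(use_dora=True)` against a hand-rolled version of the formula from the DoRA paper; the `use_dora` kwarg and the `dora_init()` call are taken from this PR's diff and may not match the final API.

```python
import torch
import torch.nn as nn

# NOTE: import path and the use_dora kwarg follow this PR's diff; both may change.
from torchtune.modules.peft import LoRALinear

torch.manual_seed(0)
in_dim, out_dim, rank, alpha = 16, 32, 4, 8.0

dora = LoRALinear(in_dim=in_dim, out_dim=out_dim, rank=rank, alpha=alpha, use_dora=True)
if hasattr(dora, "dora_init"):  # the PR sets m from the weight norm after construction
    dora.dora_init()
# Make lora_b non-zero so the check is not trivially equal to the frozen base output.
nn.init.normal_(dora.lora_b.weight, std=0.02)

def reference_dora(x, weight, a, b, m, scaling):
    # DoRA forward per the paper: apply m / ||W + s*BA|| row-wise to (W + s*BA) x
    adapted = weight + scaling * (b @ a)
    norm = torch.linalg.norm(adapted, dim=1, keepdim=True)  # shape (out_dim, 1)
    return x @ ((m.view(-1, 1) / norm) * adapted).T

x = torch.randn(2, 6, in_dim)
expected = reference_dora(
    x, dora.weight, dora.lora_a.weight, dora.lora_b.weight, dora.m, alpha / rank
)
torch.testing.assert_close(dora(x), expected, atol=1e-5, rtol=1e-5)
```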

+ (self.alpha / self.rank)
* (self.lora_a.weight.T @ self.lora_b.weight.T).T,
dim=1,
).to(self.weight.dtype)
Contributor

What about the detach used in the PEFT implementation referenced in your comment?
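For reference, the detach in PEFT stops gradients from flowing through the norm itself, so the norm is treated as a constant during backprop (the DoRA paper motivates this as a memory/compute saving). A hedged sketch of that pattern using this PR's tensor shapes:

```python
import torch

def dora_mag_norm_scale(weight, lora_a, lora_b, m, alpha, rank):
    """m / ||W + (alpha/rank) * B A||, with the norm detached so it is treated as a
    constant in the backward pass, as the PEFT implementation does."""
    adapted = weight + (alpha / rank) * (lora_b @ lora_a)
    weight_norm = torch.linalg.norm(adapted, dim=1).detach()
    return (m.view(-1) / weight_norm).view(1, -1)
```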

dim=1,
).to(self.weight.dtype)
mag_norm_scale = (self.m / weight_norm - 1).view(1, -1)
return mag_norm_scale * out + mag_norm_scale * lora_out
Contributor

Sorry maybe I am still missing the point but I don't think this actually matches the version that was ultimately added to PEFT. Ref
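For comparison, a hedged sketch of the combination that ultimately landed in PEFT: the `- 1` applies only to the base-path term and the scaling multiplies the LoRA term. Shapes follow this PR's conventions; this is illustrative, not the PR's current code.

```python
import torch

def dora_forward(x, weight, lora_a, lora_b, m, alpha, rank):
    """out = Wx + (m/||W'|| - 1) * Wx + (m/||W'||) * s * B A x, with W' = W + s * B A."""
    scaling = alpha / rank
    base_out = x @ weight.T
    lora_out = (x @ lora_a.T) @ lora_b.T
    weight_norm = torch.linalg.norm(weight + scaling * (lora_b @ lora_a), dim=1).detach()
    mag_norm_scale = (m.view(-1) / weight_norm).view(1, -1)
    return base_out + (mag_norm_scale - 1) * base_out + mag_norm_scale * scaling * lora_out
```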

weight_norm = torch.linalg.norm(
self.weight
+ (self.alpha / self.rank)
* (self.lora_a.weight.T @ self.lora_b.weight.T).T,
Contributor

Do we need all the transposes here? Isn't this just self.lora_b.weight @ self.lora_a.weight?
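A quick numerical check of that identity, for illustration:

```python
import torch

a = torch.randn(4, 16)   # lora_a.weight: (rank, in_dim)
b = torch.randn(32, 4)   # lora_b.weight: (out_dim, rank)
# (A^T B^T)^T == B A, so the extra transposes can be dropped.
torch.testing.assert_close((a.T @ b.T).T, b @ a)
```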

@@ -67,6 +72,7 @@ def __init__(
self.dropout = nn.Dropout(p=dropout)
self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
self.m = nn.Parameter(torch.ones(1, out_dim)) if self.use_dora else None
Contributor

This may necessitate some extra logic for checkpoint loading
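One possible shape for that extra logic (a sketch, not the PR's code; it assumes the magnitude parameter is named `m` and that `dora_init()` re-derives it, both per this PR's diff): load a LoRA-style checkpoint non-strictly, then re-derive the missing magnitudes from the loaded weights.

```python
import torch.nn as nn

def load_lora_ckpt_into_dora_model(model: nn.Module, state_dict: dict) -> None:
    """Load a checkpoint that has no DoRA magnitude entries into a DoRA-enabled model,
    then re-derive each magnitude via the PR's dora_init() hook."""
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    assert not unexpected, f"unexpected keys in checkpoint: {unexpected}"
    # The only keys allowed to be missing are the new DoRA magnitudes ("m" in this PR).
    assert all(key.rsplit(".", 1)[-1] == "m" for key in missing), missing
    for module in model.modules():
        if hasattr(module, "dora_init"):
            module.dora_init()
```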

Member

@rohan-varma left a comment

In addition to @ebsmothers' comments, let's have comprehensive testing to ensure implementation parity with a well-known implementation, such as the one offered by HF PEFT. This will ensure we have the appropriate correctness guarantees in place.

Thanks so much for working on this!

@@ -97,6 +110,12 @@ def test_forward(self, inputs, lora_linear, out_dim) -> None:
assert actual.shape == (BSZ, SEQ_LEN, out_dim)
torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-6)

def test_dora_forward(self, inputs, dora_linear, out_dim) -> None:
expected = torch.tensor(EXPECTED_VAL)
Member

Does this mean the expected value for DoRA is the same as for LoRA? Why is that? Intuitively, if the result is exactly the same, I'm not sure how DoRA leads to different training than LoRA. I'm probably missing something basic here, but it would be good to clarify.

Author

Here is the explanation:
$$\mathrm{DoRA}(x) = Wx + \left(\frac{m}{\lVert W' \rVert} - 1\right) Wx + \frac{m}{\lVert W' \rVert} \cdot \mathrm{scaling} \cdot \mathrm{lora}_b(\mathrm{lora}_a(x)), \qquad W' = W + \mathrm{scaling} \cdot BA$$

The m vector is initialized so that self.m equals the weight norm $\lVert W' \rVert$, so the ratio $\frac{m}{\lVert W' \rVert}$ is 1 on the first iteration.
Hence LoRA == DoRA for the first pass.

Better explanation from the author: huggingface/peft#1474 (comment)
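A small numerical illustration of that argument, under the stated initialization (lora_b starts at zero and m starts at the weight norm):

```python
import torch

torch.manual_seed(0)
out_dim, in_dim, rank, scaling = 8, 4, 2, 2.0
weight = torch.randn(out_dim, in_dim)
lora_a = torch.randn(rank, in_dim)
lora_b = torch.zeros(out_dim, rank)                 # LoRA init: B = 0
weight_norm = torch.linalg.norm(weight + scaling * (lora_b @ lora_a), dim=1)
m = weight_norm.clone()                             # DoRA init: m = ||W'||

x = torch.randn(3, in_dim)
lora_out = x @ weight.T + scaling * (x @ lora_a.T) @ lora_b.T
dora_out = (m / weight_norm) * lora_out             # m / ||W'|| == 1 at initialization
torch.testing.assert_close(dora_out, lora_out)      # identical on the first forward pass
```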

@@ -32,6 +33,8 @@ class LoRALinear(nn.Module, AdapterModule):
rank (int): rank of the low-rank approximation
alpha (float): scaling factor for the low-rank approximation
dropout (float): dropout probability. Default: 0.0
use_dora (bool): whether to use DORA (weight-Decomposed Low-Rank Adaptation).
Member

nit: link to paper.

Author

Updated in the new commit.

@@ -32,6 +33,8 @@ class LoRALinear(nn.Module, AdapterModule):
rank (int): rank of the low-rank approximation
alpha (float): scaling factor for the low-rank approximation
dropout (float): dropout probability. Default: 0.0
use_dora (bool): whether to use DORA (weight-Decomposed Low-Rank Adaptation).
Default: False
Member

How do we want to expose DoRA? I see us following the pattern of how we enabled QLoRA and just passing in a use_dora flag. Curious about the tradeoffs compared to a dedicated DoraLinear layer, though. @ebsmothers any thoughts?

Contributor

Yeah this is a good question. I do like the analogy with QLoRA and using a flag as in this PR matches what's done in PEFT too. But one place the analogy with QLoRA breaks down is that DoRA is actually introducing a new parameter, which QLoRA does not do. This can potentially make stuff like checkpointing a bit trickier. So we may need to think about this a bit more.
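For concreteness, the QLoRA-style exposure under discussion would look roughly like the following. Whether lora_llama3_8b should grow a use_dora kwarg (versus a dedicated DoRALinear) is exactly the open question, so treat this as a sketch of one option, not the decided API.

```python
from functools import partial

from torchtune.models.llama3 import lora_llama3_8b

# Mirrors how qlora_llama3_8b is defined as partial(lora_llama3_8b, quantize_base=True);
# assumes the LoRA builder forwards use_dora down to each LoRALinear, as this PR proposes.
dora_llama3_8b = partial(lora_llama3_8b, use_dora=True)

dora_llama3_8b.__doc__ = """
Builder for creating a Llama3 8B model with DoRA applied to the selected linear layers.
Please see `lora_llama3_8b` for the full list of arguments.
"""
```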

@@ -111,6 +118,19 @@ def adapter_params(self) -> List[str]:
adapter_params = ["lora_a.weight", "lora_b.weight"]
return adapter_params

def dora_init(self) -> None:
Author

Hello @rohan-varma and @ebsmothers, I am wondering if there is a more efficient approach to the problem I am facing. I have followed the PEFT method, but I am running into an issue with the initialization of self.m (the magnitude vector) in torchtune, because it requires the base weights to already be present. Currently I am using a flag (self.dora_initialized) to address this, but I believe this may not be the best solution. I would appreciate any suggestions you may have.
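One alternative to guarding forward with a self.dora_initialized flag (a sketch, assuming the dora_init() method from this PR): have the recipe walk the model once, right after the base weights are loaded, and set every magnitude there.

```python
import torch.nn as nn

def initialize_dora_magnitudes(model: nn.Module) -> None:
    """Call dora_init() on every DoRA-enabled adapter once the base weights are loaded,
    so self.m is derived from the real weight norm exactly once, outside of forward."""
    for module in model.modules():
        if getattr(module, "use_dora", False) and hasattr(module, "dora_init"):
            module.dora_init()

# Illustrative placement in the recipe, after checkpoint loading:
#   model.load_state_dict(base_model_state_dict, strict=False)
#   initialize_dora_magnitudes(model)
```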

@Prakyathkantharaju
Author

Thanks for opening the PR! I am still not sure about the correctness of the implementation though. Can you run forward on the same input tensor and confirm you get the same results with a known correct implementation (e.g. the one from PEFT referenced in your L138 of lora.py).

There are some other forward-looking considerations as well: specifically how we expose DoRA in our higher-level model builders (could potentially be similar to what we do for QLoRA), how we will merge weights when DoRA is applied, determining to what extent we want to support enabling and disabling DoRA adapters (this functionality is used in e.g. our DPO recipe). But for now the main thing is to make sure the linear component itself is correct and well-tested.

Hello @ebsmothers, there were some bugs in my initial commit; I have fixed them (the missing detach, the extra transposes, and the -1 in the mag_norm_scale calculation). I also added a dora_init function, which initializes the self.m vector similarly to PEFT. However, I think this approach might not be the best one. If you have a better idea for the initialization, please let me know.

@ebsmothers
Contributor

Hi @Prakyathkantharaju, sorry for the delay in responding here. I think beyond just exposing the DoRA logic in LoRALinear, we probably want to think about the overall design, interaction with other parts of the library, and thorough testing (basically some of the points in my comment here and @rohan-varma's comment here). Since it's a fair amount of effort to do all of this, I am gonna tag in @calvinpelletier to help out here. Let me know if this works for you; we'd love to have your collaboration on design and code reviews.

@Prakyathkantharaju
Author

Prakyathkantharaju commented May 25, 2024

Hello, and thank you for your response. I apologize for not updating this issue for a while. I am currently comparing the performance of this DoRA implementation against the PEFT one, which was the piece missing from the requests raised by @ebsmothers and @rohan-varma. As you suggested, I welcome input from @calvinpelletier and anyone else willing to help move this forward. Let me know if you have any other requests.

Here are the scripts I am using to generate the PEFT model: https://gist.github.com/Prakyathkantharaju/53777b5997b9fc14ba6f40c9b5788b6a

Here is the comparison between the PEFT Dora and the Torchtune Dora loss: https://api.wandb.ai/links/continuous-optimization/991283uj.

Comment on lines 20 to 25
_component_: torchtune.models.llama3.lora_llama3_8b
lora_attn_modules: ['q_proj', 'v_proj', 'k_proj']
apply_lora_to_mlp: True
apply_lora_to_output: False
lora_rank: 8
lora_alpha: 16
Contributor

This isn't using DoRA?

Author

I have updated the configs and recipe. There is no longer a DoRA-specific recipe; everything goes through the LoRA recipe, where I check whether use_dora is set and construct the corresponding class.

log = utils.get_logger("DEBUG")


class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface):
Contributor

Why do we need an entirely new recipe for DoRA? In my mind this is analogous to QLoRA: an extension of LoRA that shouldn't require fundamental changes to the training loop. Am I missing something here?

Author

Updated: everything is now done through the LoRA recipe.

Comment on lines 190 to 192
Initialize DORA m to ones.
"""
nn.init.zeros_(x)
Contributor

This isn't accurate?

Author

Updated to use ones. These values are overwritten later in the dora_init code. I wanted to keep the initialization of the new weights as similar to LoRA as possible, so I followed this format.

def _dora_weight_norm(self) -> Tensor:
    if self._quantize_base:
        # Convert NF4Tensor to regular Tensor for computation TODO(prakyath): Fix this.
        weight = to_regular_tensor(self.weight)
Contributor

This isn't defined?

Author

Updated: the weight norm is now calculated by dequantizing the weights. Please let me know if you have a better/faster way to do this.
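For reference, a hedged sketch of the dequantize-then-norm path; it assumes torchao's NF4Tensor exposes get_original_weight() for dequantization (treat that call, and its upcasting cost, as assumptions to verify).

```python
import torch
from torchao.dtypes.nf4tensor import NF4Tensor

def dora_weight_norm(weight, lora_a, lora_b, scaling):
    # Dequantize the NF4 base weight before computing ||W + s * B A|| row-wise.
    if isinstance(weight, NF4Tensor):
        weight = weight.get_original_weight()
    adapted = weight + scaling * (lora_b @ lora_a)
    return torch.linalg.norm(adapted, dim=1).detach()
```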

@Prakyathkantharaju
Author

Hello everyone,

I apologize for the delayed response, and I appreciate your review of my changes.

I have addressed the comments made by @ebsmothers and updated how DoRA is initialized (I have kept the initialization as similar to QLoRA as possible):

  1. I added a partial DoRA builder, similar to QLoRA, where DoRA is enabled via the use_dora option. You can find the link to the change here.
  2. I updated the LoRA recipe with DoRA initialization. You can find the link to the changes here.
  3. I updated the Llama 3 LoRA builders with the DoRA-specific option. The links to the changed files are here.

Additionally, I would like feedback from @rohan-varma, @kostmo, and @ebsmothers on one point: I also added a clamp to avoid a zero denominator in the weight norm (link to the code here; see the sketch below). This is not done in PEFT or other DoRA implementations, so if you feel it is not necessary, I can remove it.
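For reference, the clamp in question amounts to something like this; the epsilon value is illustrative:

```python
import torch

def clamped_weight_norm(adapted_weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Guard against an all-zero row producing a zero norm, and hence a divide-by-zero in m / norm.
    return torch.linalg.norm(adapted_weight, dim=1).clamp_min(eps)
```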

Please let me know if you need any changes. I am willing to work on them. Moreover, if you have any additional feedback, I am happy to incorporate it as well.

Contributor

@calvinpelletier left a comment

Hi @Prakyathkantharaju, thanks for contributing! I left some comments.

One thing that you are currently missing is updating the merging logic at torchtune/module/peft/peft_utils.py::get_merged_lora_ckpt.

I'll work on setting up a direct comparison to the huggingface implementation to verify that the loss graph, memory usage, and training speed are similar.
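For context, the DoRA-aware merge differs from the plain LoRA merge by a row-wise magnitude rescaling; a hedged sketch of the per-layer math (not torchtune's get_merged_lora_ckpt code):

```python
import torch

def merge_dora_weight(weight, lora_a, lora_b, m, alpha, rank):
    """W_merged = (m / ||W + s * B A||) * (W + s * B A), with s = alpha / rank.
    Plain LoRA merging is just W + s * B A; DoRA additionally rescales each row."""
    scaling = alpha / rank
    adapted = weight + scaling * (lora_b @ lora_a)
    norm = torch.linalg.norm(adapted, dim=1, keepdim=True)
    return (m.view(-1, 1) / norm) * adapted
```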

@@ -106,6 +106,10 @@ class Recipe:
name="gemma/2B_qlora_single_device",
file_path="gemma/2B_qlora_single_device.yaml",
),
Config(
name="llama3/*B_dora_single_device",
Contributor

*B -> 8B


dora_llama3_8b.__doc__ = """
Builder for creating a Llama3 model with DORA enabled. Base model weights in linear layers
that DORA is applied to are quantized per the Dora paper: https://arxiv.org/abs/2402.09353.
Contributor

nit: proper capitalization of DoRA in comments and docstrings

Comment on lines +1 to +14
# Config for single device QLoRA with lora_finetune_single_device.py
# using a Llama3 8b Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Meta-Llama-3-8b-Instruct --output-dir /tmp/Meta-Llama-3-8b-Instruct --hf-token <HF_TOKEN>
#
# To launch on a single device, run the following command from root:
# tune run lora_finetune_single_device --config llama3/8b_qlora_single_device
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run lora_finetune_single_device --config llama3/8b_qlora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
Contributor

qlora -> dora

# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
path: /teamspace/studios/this_studio/models/Meta-Llama-3-8b-Instruct/original/tokenizer.model
Contributor

no rush on this, but before merging make sure to change the paths to /tmp/, the metric logger to DiskLogger, etc.

Comment on lines +132 to +133
@property
def _dora_weight_norm(self) -> Tensor:
Contributor

this should be a regular function instead of a property IMO

def activate_dora_params(model: nn.Module) -> nn.Module:
    for k, v in model.named_modules():
        if hasattr(v, "adapter_params") and callable(v.adapter_params):
            current_adapter_params = v.adapter_params()
Contributor

this line doesn't do anything

norm = torch.linalg.norm(result, dim=1)

# Clamp the norm to avoid division by zero
# TODO(Prakyath): Check with torchtune team whether this should be a parameter ?
Contributor

what is this question referring to?
