Implement H2O for long context inference on summarization tasks #411

Open

Kyriection wants to merge 101 commits into main

Conversation

Kyriection

This adds the implementation of the H2O algorithm for efficient long-context inference with Llama models.
The current implementation is based on Hugging Face transformers and is tested on summarization tasks, including XSUM and CNN/DailyMail.

@facebook-github-bot

Hi @Kyriection!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@facebook-github-bot

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

Contributor

@HamidShojanazeri left a comment


Thanks a lot @Kyriection for the PR!! Just added some quick initial thoughts on the PR; will go deeper on the second round.


The following example runs inference of Llama-2-7b on the XSUM summarization task. We're using `--enable_h2o_generation` to enable the H2O algorithm, which only keeps the heavy-hitter and local KV pairs. Use `--num_heavy_hitter_tokens` to set the number of heavy-hitter KV pairs and `--num_window_length` for the KV cache size. The number of local KV pairs equals `num_window_length - num_heavy_hitter_tokens`. Also, use `--enable_position_rolling` to enable position rolling, which assigns positions based on slots in the KV cache rather than positions in the original sequence. Enabling position rolling is important when the sequence length exceeds the pretrained context window, e.g., 4K in Llama-2.
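
A minimal invocation sketch illustrating the flags described above. The script name `run_summarization.py` and the flag values are assumptions for illustration only; please check the actual entry point added in this PR:

```
# Illustrative only: the script name is hypothetical and the flag values are examples.
# This keeps 128 heavy-hitter tokens plus 128 local tokens (256 - 128).
python run_summarization.py \
    --enable_h2o_generation \
    --num_heavy_hitter_tokens 128 \
    --num_window_length 256 \
    --enable_position_rolling
```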

Contributor

Running the code, it runs into this error; it seems one arg is missing:

 File "/data/users/hamidnazeri/fbsource/llama-recipe-new/llama-recipes/research/long-context-llama/H2O/utils/llama.py", line 290, in enable_h2ocache_forward
    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
TypeError: LlamaModel._update_causal_mask() takes 3 positional arguments but 4 were given

Contributor

Can we also start adding numbers/visuals showing how it looks and what it improves?

Contributor

Can you please try moving the folder under recipes/experimental/long-context/H2O?

@@ -0,0 +1,11 @@
## Run Llama with H2O for long context inference

Contributor

Can you please add a link to the paper and blog post (with your name as the contributor / anyone else you would like to add), and also a brief summary of:

  • what H2O is
  • how it works
  • what its advantages are
  • which models it supports (Llama 7B, 13B, 70B?); it seems we need to work with ecosystem projects like vLLM etc. to get this working for multi-GPU
  • potential limitations

Author

Hi, @HamidShojanazeri. I will update the PR by adding the implementation of the Needle in a Haystack analysis, an example of inference with growing sequence length, as well as results on Llama-2/3. Please check the updated PR. Thanks!


For more details, please refer to the paper: https://arxiv.org/pdf/2306.14048 and the blog: https://allenz.work/?p=11.

### Environments:
Contributor

We can make this a "NOTE" about installing the right version instead of an "Environments" section, since it seems there is only one package whose version matters.


##### **Results**

Expected results on XSUM (Rouge-2 score) from the above scripts on Llama-2/3 models. The input sequence length is ~2k, so a KV cache size larger than 2048 represents full-cache performance.
Contributor

@Kyriection can you please add a bit more explanation of what the baseline is, how someone should interpret the Rouge score, and what each KV cache size means here?

Also, it seems that with a KV cache of 64, Llama-2-7B is significantly worse.

Contributor

Let's highlight that this brings throughput benefits at inference time.


### Evaluation on "Needle in a Haystack" Analysis

The following example runs inference of Llama-3-8b-instruct on the "Needle in a haystack" test. The test is modified from [LLMTest_NeedleInAHaystack](https://github.com/gkamradt/LLMTest_NeedleInAHaystack). Please follow the original repository for installing the necessary packages. We're using `--enable_h2o_generation` to enable the H2O algorithm, which only keeps the heavy-hitter and local KV pairs. Use `--num_heavy_hitter_tokens` to set the number of heavy-hitter KV pairs and `--num_window_length` for the KV cache size. The number of local KV pairs equals `num_window_length - num_heavy_hitter_tokens`. Also, use `--enable_position_rolling` to enable position rolling, which assigns positions based on slots in the KV cache rather than positions in the original sequence. Enabling position rolling is important when the sequence length exceeds the pretrained context window, e.g., 4K in Llama-2.
Contributor

Can we please add some numbers here vs. the baseline to see how it's working?

Contributor

I wonder if, for the data, especially the Paul Graham essays, it makes sense to use a script to download them, or is it easier this way? Especially if there is some postprocessing involved, it might be easier to keep them as is.

| Llama-2-7B | 0.0439 | 0.1127 | 0.1148 | 0.1182 | 0.1170 | 0.1164 | 0.1164 | 0.1164 |
| Llama-2-13B | 0.1180 | 0.1217 | 0.1243 | 0.1291 | 0.1302 | 0.1332 | 0.1332 | 0.1332 |
| Llama-3-8B | 0.1107 | 0.1189 | 0.1200 | 0.1347 | 0.1290 | 0.1311 | 0.1311 | 0.1311 |

Contributor

It also would be great to highlight the idea of streaming LLMs here, to showcase the benefits at inference time.

```
python utils/needle_test/vis.py \
--input-path needle_test_results/huggingface/llama-3-8b-instruct-h2o-4096_eval
```

Contributor

Can we please add some explanation to clarify, especially regarding the tested length of 1M?

### Overview:

Heavy-Hitter Oracle (H2O) is an efficient inference framework for LLMs. During the generative inference of transformers, the size of the KV cache grows linearly with the sequence length (prompt length + generation length) during long-context generation. The KV cache is usually significantly larger than the model parameters and constrains the inference throughput. H2O identifies the critical KV pairs and evicts the unnecessary ones, maintaining a small cache size and thus improving throughput.
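
For intuition, here is a minimal sketch of the kind of eviction policy H2O describes (illustrative PyTorch, not the PR's actual implementation): for each head, keep the `num_heavy_hitter_tokens` cached tokens with the largest accumulated attention scores plus the most recent local tokens, and drop the rest.

```python
import torch

def h2o_evict(keys, values, acc_attn, num_heavy_hitters, window_length):
    # Illustrative sketch only. Assumed shapes for one attention head:
    #   keys, values: [seq_len, head_dim] cached KV pairs
    #   acc_attn:     [seq_len] accumulated attention each cached token has received
    seq_len = keys.shape[0]
    if seq_len <= window_length:
        return keys, values, acc_attn  # cache still fits, nothing to evict

    num_local = window_length - num_heavy_hitters
    # Always keep the most recent (local) tokens.
    local_idx = torch.arange(seq_len - num_local, seq_len)
    # Pick heavy hitters from the remaining tokens by accumulated attention score.
    heavy_idx = torch.topk(acc_attn[: seq_len - num_local], num_heavy_hitters).indices
    keep = torch.sort(torch.cat([heavy_idx, local_idx])).values  # preserve original order
    return keys[keep], values[keep], acc_attn[keep]
```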

Contributor

Let's clearly state the difference between this work and other long-context work that is based on training and publishes a checkpoint. Here the main idea is a KV cache policy.
