Implement H2O for long context inference on summarization tasks #411
base: main
Conversation
Hi @Kyriection! Thank you for your pull request and welcome to our community.

**Action Required**

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

**Process**

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with `CLA signed`. If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
Thanks a lot @Kyriection for the PR!! Just added some quick initial thoughts on the PR; will go deeper on the second round.
The following example runs inference of Llama-2-7b on the XSUM summarization task. We use `--enable_h2o_generation` to enable the H2O algorithm, which keeps only the heavy-hitter and local KV pairs. Use `--num_heavy_hitter_tokens` to set the number of heavy-hitter KV pairs and `--num_window_length` for the total KV cache size; the number of local KV pairs equals `num_window_length - num_heavy_hitter_tokens`. Also, use `--enable_position_rolling` to enable position rolling, which assigns positions based on slots in the KV cache instead of positions in the original sequence. Enabling position rolling is important when the sequence length exceeds the pretrained context window, e.g., 4K in Llama-2.
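As a rough illustration of how the two flags partition the fixed KV-cache budget (a sketch with hypothetical values; variable names mirror the flags, but this is not the script's actual code):

```
# Sketch: how the H2O flags partition the KV-cache budget.
num_window_length = 256        # total KV-cache budget
num_heavy_hitter_tokens = 128  # slots reserved for heavy-hitter KV pairs
num_local = num_window_length - num_heavy_hitter_tokens  # recent tokens kept

# Every cached KV pair is either a heavy hitter or a recent (local) token.
assert num_heavy_hitter_tokens + num_local == num_window_length
```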
Running the code runs into this error; it seems one arg is missing:
File "/data/users/hamidnazeri/fbsource/llama-recipe-new/llama-recipes/research/long-context-llama/H2O/utils/llama.py", line 290, in enable_h2ocache_forward
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
TypeError: LlamaModel._update_causal_mask() takes 3 positional arguments but 4 were given
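For what it's worth, this looks like a transformers version mismatch: some releases of `LlamaModel._update_causal_mask` accept a `cache_position` argument while others do not. One possible compatibility guard, sketched here as a hypothetical helper (not part of the PR):

```
import inspect

def call_update_causal_mask(model, attention_mask, inputs_embeds, cache_position):
    # Pass cache_position only if the installed transformers version accepts it.
    params = inspect.signature(model._update_causal_mask).parameters
    if "cache_position" in params:
        return model._update_causal_mask(attention_mask, inputs_embeds, cache_position)
    return model._update_causal_mask(attention_mask, inputs_embeds)
```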
Can we also start adding numbers/visuals showing how it looks and what it improves?
Can you please try moving the folder under `recipes/experimental/long-context/H2O`?
@@ -0,0 +1,11 @@
## Run Llama with H2O for long context inference
Can you please add a link to the paper and blog post (with your name as the contributor, and anyone else you'd like to add), and also a brief summary of:
- what H2O is
- how it works
- what its advantages are
- which models it supports (Llama 7B, 13B, 70B?); it seems we need to work with ecosystem projects like vLLM etc. to get it working for multi-GPU?
- potential limitations
Hi @HamidShojanazeri, I have updated the PR by adding the implementation of a Needle in a Haystack analysis, an example of inference with growing sequence length, and results on Llama-2/3. Please check the updated PR. Thanks!
For more details, please refer to the paper: https://arxiv.org/pdf/2306.14048 and the blog: https://allenz.work/?p=11.
### Environments:
We can make it a "NOTE" about installing the right version, instead of an "Environments" section, since it seems only one package's version is important.
##### **Results**
Expected results on XSUM (Rouge-2 score) from the above scripts on Llama-2/3 models. The input sequence length is ~2k, so a KV cache size larger than 2048 represents full-cache performance.
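For reference, Rouge-2 measures bigram overlap between a generated summary and the reference (higher is better). A minimal scoring sketch using the `rouge-score` package (illustrative; not necessarily the evaluation code used by the scripts):

```
from rouge_score import rouge_scorer

# Rouge-2 compares bigrams of the reference and the generated summary.
scorer = rouge_scorer.RougeScorer(["rouge2"], use_stemmer=True)
scores = scorer.score(
    "the cat sat on the mat",          # reference summary
    "the cat was sitting on the mat",  # generated summary
)
print(scores["rouge2"].fmeasure)  # bigram-overlap F1, in [0, 1]
```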
@Kyriection can you please add a bit more explanation of what the baseline is and how someone should interpret Rouge? What does each KV cache size mean here?
Also, it seems that with a KV cache of 64, Llama-2-7B is significantly worse.
Let's highlight that this brings throughput benefits at inference time.
### Evaluation on "Needle in a Haystack" Analysis
The following example runs inference of Llama-3-8b-instruct on the "Needle in a Haystack" test, modified from [LLMTest_NeedleInAHaystack](https://github.com/gkamradt/LLMTest_NeedleInAHaystack); please follow the original repository to install the necessary packages. The H2O flags behave as described above: `--enable_h2o_generation` keeps only the heavy-hitter and local KV pairs, `--num_heavy_hitter_tokens` and `--num_window_length` set the budget split, and `--enable_position_rolling` is important when the sequence length exceeds the pretrained context window, e.g., 4K in Llama-2.
Can we please add some numbers here vs. the baseline to see how it's working?
I wonder, for the data (especially the Paul Graham essays), whether it makes sense to use a script to download them, or whether it's easier this way? Especially if some postprocessing is involved, it might be easier to have it as is.
| Model | 64 | 128 | 256 | 512 | 1024 | 2048 | 4096 | 8192 |
|---|---|---|---|---|---|---|---|---|
| Llama-2-7B | 0.0439 | 0.1127 | 0.1148 | 0.1182 | 0.1170 | 0.1164 | 0.1164 | 0.1164 |
| Llama-2-13B | 0.1180 | 0.1217 | 0.1243 | 0.1291 | 0.1302 | 0.1332 | 0.1332 | 0.1332 |
| Llama-3-8B | 0.1107 | 0.1189 | 0.1200 | 0.1347 | 0.1290 | 0.1311 | 0.1311 | 0.1311 |
It would also be great to highlight the idea of streaming LLMs here, to showcase the benefits at inference time.
```
python utils/needle_test/vis.py \
    --input-path needle_test_results/huggingface/llama-3-8b-instruct-h2o-4096_eval
```
Can we please add some explanation to clarify, especially regarding the tested length of 1M?
### Overview:
Heavy-Hitter Oracle (H2O) is an efficient inference framework for LLMs. During the generative inference of transformers, the KV cache grows linearly with the sequence length (prompt length + generation length), and during long-context generation its size is usually significantly larger than the model parameters, constraining inference throughput. H2O identifies the critical KV pairs and evicts the unnecessary ones, maintaining a small cache size and thus improving throughput.
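A minimal sketch of the eviction rule, assuming the policy described in the paper (keep the most recent tokens plus the tokens with the largest accumulated attention scores); the PR's actual implementation in `utils/llama.py` may differ:

```
import torch

def h2o_keep_mask(attn_weights: torch.Tensor, num_heavy: int, num_recent: int) -> torch.Tensor:
    """Select which KV pairs to keep for one attention head.

    attn_weights: (num_queries, seq_len) attention probabilities observed so far.
    Returns a boolean mask over seq_len: True = keep this KV pair.
    """
    seq_len = attn_weights.shape[-1]
    keep = torch.zeros(seq_len, dtype=torch.bool)
    keep[-num_recent:] = True                    # local window: most recent tokens

    scores = attn_weights.sum(dim=0)             # accumulated attention per KV pair
    scores[-num_recent:] = float("-inf")         # local window is already kept
    keep[scores.topk(num_heavy).indices] = True  # heavy hitters: top accumulated mass
    return keep
```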
Let's clearly state the difference between this work and other long-context work, which is based on training and publishes a checkpoint. Here, the main idea is a KV cache policy.
This adds an implementation of the H2O algorithm for efficient long-context inference with Llama models. The current implementation is based on Hugging Face transformers and is tested on summarization tasks, including XSUM and CNN/DailyMail.