# Implement H2O for long context inference on summarization tasks (#411)
## Run Llama with H2O for long context inference

### Overview

Heavy-Hitter Oracle (H2O) is an efficient inference framework for LLMs. During generative inference with transformers, the KV cache grows linearly with the sequence length (prompt length + generation length). For long contexts, the KV cache is often significantly larger than the model parameters and constrains inference throughput. H2O identifies the critical KV pairs and evicts the unnecessary ones, keeping the cache small and thus improving throughput.

In addition, LLMs usually generalize poorly to long sequences during inference. H2O addresses this by keeping only the heavy-hitter tokens and the most recent tokens. Combined with the positional rolling strategy (reassigning each KV pair the position of its slot in the KV cache instead of its position in the original sequence), H2O can process sequences much longer than the pretrained context window. Unlike approaches such as [Positional Interpolation](https://arxiv.org/abs/2306.15595), H2O is a KV cache policy and does not involve any training for long context processing.
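
For intuition, here is a minimal NumPy sketch of the eviction policy (illustrative only; the function and variable names are our own, not this repository's API). Heavy hitters are chosen by accumulated attention mass, and the most recent tokens are always kept:

```python
import numpy as np

def h2o_keep_mask(acc_attn: np.ndarray, num_heavy: int, window_len: int) -> np.ndarray:
    """acc_attn: [seq_len] attention mass accumulated per KV position.
    Returns a boolean mask over seq_len marking which KV pairs to keep."""
    seq_len = acc_attn.shape[0]
    if seq_len <= window_len:
        return np.ones(seq_len, dtype=bool)       # cache not full yet: keep everything
    num_recent = window_len - num_heavy           # size of the local window
    mask = np.zeros(seq_len, dtype=bool)
    mask[seq_len - num_recent:] = True            # always keep the most recent tokens
    prefix = acc_attn[: seq_len - num_recent]     # heavy-hitter candidates
    heavy_idx = np.argsort(prefix)[-num_heavy:]   # top-k accumulated attention
    mask[heavy_idx] = True
    return mask

# With positional rolling, the kept entries are re-indexed by cache slot,
# i.e. positions = np.arange(mask.sum()), not their original sequence positions.
```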

The current implementation supports Llama-1/2/3, from 7B to 70B. Since H2O keeps only the most important KV pairs, it may miss some important information from the middle of the context on knowledge-intensive tasks.

For more details, please refer to the paper (**https://arxiv.org/pdf/2306.14048**) and the blog post (**https://allenz.work/?p=11**).

**Note: this implementation is tested with transformers == 4.39.0.**

### Evaluation on Summarization Tasks

The following example runs inference of Llama-2-7b on the XSUM summarization task. `--enable_h2o_generation` enables the H2O algorithm, which keeps only the heavy-hitter and local KV pairs. Use `--num_heavy_hitter_tokens` to set the number of heavy-hitter KV pairs and `--num_window_length` to set the KV cache size; the number of local KV pairs equals `num_window_length - num_heavy_hitter_tokens`. Additionally, `--enable_position_rolling` assigns each KV pair its position in the KV cache instead of its position in the original sequence. Enabling positional rolling is important when the sequence length exceeds the pretrained context window, e.g., 4K in Llama-2.

```
python run_summarization.py \
    --input-path data/summarization/xsum.jsonl \
    --output-path summarization_output/xsum_h2o.jsonl \
    --model-name meta-llama/Llama-2-7b-hf \
    --enable_h2o_generation
```
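
The same script accepts the explicit cache-size flags described above. A hedged example (the output path and the 64/64 heavy-hitter/local split are illustrative; a total window of 128 matches the cache size that recovers full-baseline quality in the results below):

```
python run_summarization.py \
    --input-path data/summarization/xsum.jsonl \
    --output-path summarization_output/xsum_h2o_128.jsonl \
    --model-name meta-llama/Llama-2-7b-hf \
    --enable_h2o_generation \
    --num_heavy_hitter_tokens 64 \
    --num_window_length 128 \
    --enable_position_rolling
```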

#### Results

Expected results on XSUM (Rouge-2 score, the higher the better) from the above script on Llama-2/3 models. The input sequences are ~2K tokens long. Here we constrain the size of the KV cache, allowing only n KV pairs to be written/read after the prefilling stage, with n ranging from **64** to **full** (keeping all KV pairs). With 128 KVs, performance matches the full baseline (~2K KVs), while degradation is observed with 64 KVs. Maintaining a smaller KV cache also reduces the I/O cost of the KVs, so we can achieve better throughput.
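
As a back-of-envelope illustration of the I/O savings (a sketch using the known Llama-2-7B architecture: 32 layers, 32 heads, head dimension 128, fp16):

```python
# Per-token KV bytes = layers * heads * head_dim * 2 (K and V) * 2 (fp16 bytes)
bytes_per_token = 32 * 32 * 128 * 2 * 2        # 524,288 bytes = 0.5 MiB per token
print(bytes_per_token * 2048 / 2**30, "GiB")   # full ~2K cache: ~1.0 GiB per sequence
print(bytes_per_token * 128 / 2**20, "MiB")    # 128-entry H2O cache: 64 MiB
```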

| KV Cache Size | 64 | 128 | 256 | 512 | 1024 | Full |
| ------------- | ------ | ------ | ------ | ------ | ------ | ------ |
| Llama-2-7B | 0.0439 | 0.1127 | 0.1148 | 0.1182 | 0.1170 | 0.1164 |
| Llama-2-13B | 0.1180 | 0.1217 | 0.1243 | 0.1291 | 0.1302 | 0.1332 |
| Llama-3-8B | 0.1107 | 0.1189 | 0.1200 | 0.1347 | 0.1290 | 0.1311 |

> **Review comment:** It would also be great to highlight the idea of streaming LLMs here, to showcase the benefits in inference time.

### Evaluation on "Needle in a Haystack" Analysis

The following example runs inference of Llama-3-8b-instruct on the "Needle in a Haystack" test, adapted from [LLMTest_NeedleInAHaystack](https://github.com/gkamradt/LLMTest_NeedleInAHaystack). Please follow the original repository to install the necessary packages. As above, `--enable_h2o_generation` enables the H2O algorithm, `--num_heavy_hitter_tokens` sets the number of heavy-hitter KV pairs, and `--num_window_length` sets the KV cache size (local KV pairs = `num_window_length - num_heavy_hitter_tokens`). Also, `--enable_position_rolling` assigns positions from the KV cache rather than the original sequence, which is important once the sequence length exceeds the pretrained context window, e.g., 4K in Llama-2.

> **Review comment:** Can we please add some numbers here vs. the baseline to see how it's working?

> **Review comment:** I wonder whether, for the data (especially the Paul Graham essays), it makes sense to use a script to download them, or whether it's easier this way; especially if some postprocessing is involved, it might be easier to have it as is.

```
# Step 1: obtain prompts for evaluation.
# Download the dataset from https://github.com/gkamradt/LLMTest_NeedleInAHaystack/tree/main/needlehaystack/PaulGrahamEssays
# and set the data path in utils/needle_test/config-prompt.yaml (line 3: haystack_dir: "data/PaulGrahamEssays").
python utils/needle_test/prompt.py --model_name meta-llama/Meta-Llama-3-8B-Instruct
# Modify utils/needle_test/config-prompt.yaml to adjust the min/max sequence length for the test.


# Step 2: generate predictions for each prompt.
# Full model:
python run_needle_haystack_test.py \
    --input-path data/needle_test/Huggingface \
    --output-path needle_test_results/huggingface/llama-3-8b-instruct/ \
    --model-name meta-llama/Meta-Llama-3-8B-Instruct

# H2O with a 4096-entry KV cache (2048 heavy-hitter + 2048 local tokens):
python run_needle_haystack_test.py \
    --input-path data/needle_test/Huggingface \
    --output-path needle_test_results/huggingface/llama-3-8b-instruct-h2o-4096/ \
    --model-name meta-llama/Meta-Llama-3-8B-Instruct \
    --enable_h2o_generation \
    --num_window_length 4096 \
    --num_heavy_hitter_tokens 2048


# Step 3: scoring with GPT-4.
# --input-path points to the prediction results from step 2.
export OPENAI_API_KEY=YOUR_API_KEY
python utils/needle_test/eval.py \
    --input-path needle_test_results/huggingface/llama-3-8b-instruct-h2o-4096 \
    --output-path needle_test_results/huggingface/llama-3-8b-instruct-h2o-4096_eval


# Step 4: visualization.
python utils/needle_test/vis.py \
    --input-path needle_test_results/huggingface/llama-3-8b-instruct-h2o-4096_eval
```

> **Review comment:** Can we please add some explanation to clarify, especially on the tested length of 1M?

### One Demo on Streaming to "Infinite" Context Length

The following example demonstrates generation over an "infinite" sequence length. We use MT-Bench data and generate the context sample by sample. The KV cache keeps the KV pairs from the previous samples while maintaining a fixed size. Results can be found in the [Demo](https://allenz.work/?p=11) (Video 1).

```
# Run with the full cache.
# Expected results: 1) normal generation at the early stage; 2) performance
# collapse and slower generation at the middle stage, because the sequence
# length exceeds the context window and the I/O cost of the KV cache
# constrains throughput; 3) OOM errors, then the run stops.
bash src/streaming.sh full

# Run with H2O.
# Expected results: normal generation at all stages.
# Adjust the number of heavy-hitter tokens with --num_heavy_hitter_tokens
# and the KV cache size with --num_window_length in src/streaming.sh.
bash src/streaming.sh h2o
```
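
Conceptually, the streaming loop looks like the following sketch (names such as `encode_and_generate` and `evict_h2o` are hypothetical, for illustration only):

```python
# The KV cache persists across samples; H2O eviction bounds it at a fixed
# size, so memory stays constant no matter how long the stream grows.
cache = []  # kept (key, value) pairs, one entry per retained token
for sample in mt_bench_samples:
    for kv in encode_and_generate(sample, cache):  # hypothetical helper
        cache.append(kv)
        if len(cache) > num_window_length:
            # keep heavy hitters + the most recent tokens, drop the rest
            cache = evict_h2o(cache, num_heavy_hitter_tokens)
```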

> **Review comment:** Let's clearly state the difference between this work and other long-context work, which is based on training and publishes a checkpoint; the main idea here is a KV cache policy.