Implement H2O for long context inference on summarization tasks #411

Open

Wants to merge 108 commits into base: main

Commits (108)
864f0d9
benchmark_summarization
Kyriection Mar 18, 2024
532e88b
Update exp.sh
Kyriection Mar 18, 2024
da6f5eb
Update exp.sh
Kyriection Mar 18, 2024
6ebd289
Update exp.sh
Kyriection Mar 18, 2024
0216b20
Update generation.py
Kyriection Mar 18, 2024
e226672
Update exp.sh
Kyriection Mar 18, 2024
a34a2e2
Update generation.py
Kyriection Mar 18, 2024
f50818a
Update generation.py
Kyriection Mar 18, 2024
77427ea
test
Kyriection Mar 18, 2024
ae85ea9
Update utils_llama.py
Kyriection Mar 18, 2024
9320185
test
Kyriection Mar 23, 2024
4c08c68
Create setup.sh
Kyriection Mar 23, 2024
4292fdb
Update utils_llama.py
Kyriection Mar 23, 2024
53b917b
Update utils_llama.py
Kyriection Mar 23, 2024
a346cd6
Update cache_utils.py
Kyriection Mar 23, 2024
9788b1e
x
Kyriection Mar 23, 2024
1245dca
Update generation.py
Kyriection Mar 23, 2024
266d5dc
x
Kyriection Mar 23, 2024
07efbea
Update utils_llama.py
Kyriection Mar 23, 2024
bc5d495
Update cache_utils.py
Kyriection Mar 23, 2024
8422806
Update utils_llama.py
Kyriection Mar 23, 2024
3bdf7db
Update utils_llama.py
Kyriection Mar 23, 2024
4cbce19
Update utils_llama.py
Kyriection Mar 23, 2024
bc1047e
Update utils_llama.py
Kyriection Mar 23, 2024
a220760
Update utils_llama.py
Kyriection Mar 23, 2024
b3571b7
Update utils_llama.py
Kyriection Mar 23, 2024
eb8b28c
Update utils_llama.py
Kyriection Mar 23, 2024
a6691c5
Update utils_llama.py
Kyriection Mar 23, 2024
fe82f2a
Update cache_utils.py
Kyriection Mar 23, 2024
905e546
Update cache_utils.py
Kyriection Mar 23, 2024
e2a4c74
Update utils_llama.py
Kyriection Mar 23, 2024
67f3666
test
Kyriection Mar 23, 2024
71d579e
Update exp.sh
Kyriection Mar 23, 2024
d3932c8
test
Kyriection Mar 23, 2024
f5404d8
Update cache_utils.py
Kyriection Mar 23, 2024
6b447d6
test
Kyriection Mar 24, 2024
fb9373b
Update cache_utils.py
Kyriection Mar 24, 2024
a0240ea
Update cache_utils.py
Kyriection Mar 24, 2024
89103e5
Update cache_utils.py
Kyriection Mar 24, 2024
d0eeba3
Update cache_utils.py
Kyriection Mar 24, 2024
1a433bb
Update utils_llama.py
Kyriection Mar 24, 2024
d9b9857
x
Kyriection Mar 24, 2024
ad2ed66
x
Kyriection Mar 24, 2024
5d6dc23
Update utils_llama.py
Kyriection Mar 24, 2024
db9e842
test
Kyriection Mar 24, 2024
1bd6480
Update utils_llama.py
Kyriection Mar 24, 2024
9adb645
Update cache_utils.py
Kyriection Mar 24, 2024
e024001
Update cache_utils.py
Kyriection Mar 24, 2024
ab4eee2
Update cache_utils.py
Kyriection Mar 24, 2024
7b619fb
Update cache_utils.py
Kyriection Mar 24, 2024
749174d
Update cache_utils.py
Kyriection Mar 24, 2024
cd94a39
Update cache_utils.py
Kyriection Mar 24, 2024
60faf6f
Update cache_utils.py
Kyriection Mar 24, 2024
a3e3c91
Update cache_utils.py
Kyriection Mar 24, 2024
3dd3c8f
Update utils_llama.py
Kyriection Mar 24, 2024
9b642b8
Update generation.py
Kyriection Mar 24, 2024
aba329b
Update generation.py
Kyriection Mar 24, 2024
e61b4c1
Update generation.py
Kyriection Mar 24, 2024
4676deb
Update utils_llama.py
Kyriection Mar 24, 2024
9daddd4
Update utils_llama.py
Kyriection Mar 24, 2024
4cbd593
Update utils_llama.py
Kyriection Mar 24, 2024
d860109
x
Kyriection Mar 24, 2024
0ba99ca
Update utils_llama.py
Kyriection Mar 24, 2024
603cf5c
Update utils_llama.py
Kyriection Mar 24, 2024
66bf383
upload
Kyriection Mar 24, 2024
7f007cb
test
Kyriection Mar 24, 2024
228d710
test
Kyriection Mar 24, 2024
dfd56f7
Update utils_llama.py
Kyriection Mar 24, 2024
174d1c5
Update utils_llama.py
Kyriection Mar 24, 2024
769e93e
Update cache_utils.py
Kyriection Mar 24, 2024
89e576b
x
Kyriection Mar 24, 2024
0993102
Update utils_llama.py
Kyriection Mar 24, 2024
0824531
Update utils_llama.py
Kyriection Mar 24, 2024
0dc84a4
Update utils_llama.py
Kyriection Mar 24, 2024
9e2072f
test
Kyriection Mar 24, 2024
8f955aa
Update utils_llama.py
Kyriection Mar 24, 2024
008238b
Update utils_llama.py
Kyriection Mar 24, 2024
36109c8
Update utils_llama.py
Kyriection Mar 24, 2024
57d1f6d
Update utils_llama.py
Kyriection Mar 24, 2024
036620e
Update utils_llama.py
Kyriection Mar 24, 2024
0832c06
Update utils_llama.py
Kyriection Mar 24, 2024
5b50ca5
Update utils_llama.py
Kyriection Mar 24, 2024
84ddf52
Update utils_llama.py
Kyriection Mar 24, 2024
ab07cdc
Update utils_llama.py
Kyriection Mar 24, 2024
b33b68c
Update utils_llama.py
Kyriection Mar 24, 2024
43e8599
Update utils_llama.py
Kyriection Mar 24, 2024
a694fe8
Update utils_llama.py
Kyriection Mar 24, 2024
5affb02
Update utils_llama.py
Kyriection Mar 24, 2024
38cec86
Update utils_llama.py
Kyriection Mar 24, 2024
9441e0c
Update exp.sh
Kyriection Mar 24, 2024
f2802dd
Update generation.py
Kyriection Mar 24, 2024
9fb1080
test
Kyriection Mar 24, 2024
428e8e8
test
Kyriection Mar 24, 2024
cedb89b
Update generation.py
Kyriection Mar 24, 2024
115e930
Update cache.py
Kyriection Mar 24, 2024
525de54
version 1.0, inference with h2o on summarization tasks
Kyriection Mar 24, 2024
7772734
add hazy-analysis, summarization results, streaming demo
Kyriection Apr 28, 2024
0892bf4
Update README.md
Kyriection Apr 28, 2024
28c811a
Update README.md
Kyriection Apr 28, 2024
36eaf36
Update README.md
Kyriection Apr 28, 2024
ec8842f
Merge branch 'meta-llama:main' into main
Kyriection Apr 28, 2024
7ef694c
Merge branch 'meta-llama:main' into main
Kyriection May 30, 2024
492eac7
test
Kyriection May 30, 2024
61cdf88
Merge branch 'main' of https://github.com/Kyriection/llama-recipes
Kyriection May 30, 2024
636f874
Create debug.sh
Kyriection May 30, 2024
1f11a37
update README
Kyriection May 30, 2024
15b6cc1
Update README.md
Kyriection May 30, 2024
7084c1b
upload
Kyriection May 30, 2024
91 changes: 91 additions & 0 deletions recipes/experimental/long-context/H2O/README.md
@@ -0,0 +1,91 @@
## Run Llama with H2O for long context inference

### Overview:

Heavy-Hitter Oracle (H2O) is an efficient inference framework for LLMs. During the generative inference of transformers, the size of the KV cache grows linearly with the sequence length (prompt length + generation length) during long-context generation, and the KV cache is usually significantly larger than the model parameters, which constrains inference throughput. H2O identifies the critical KV pairs and evicts the unnecessary ones, maintaining a small cache size and thus improving throughput.
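
To give a rough sense of why the KV cache, rather than the weights, becomes the bottleneck for long inputs and large batches, here is a back-of-envelope estimate for a Llama-2-7B-style configuration (32 layers, 32 heads, head dimension 128, fp16); the exact numbers depend on the model config, batch size, and precision, so treat this as an illustration rather than a measurement:

```
# Back-of-envelope KV cache size for a Llama-2-7B-style config (assumed values).
num_layers, num_heads, head_dim = 32, 32, 128
bytes_per_elem = 2  # fp16

def kv_cache_gb(seq_len, batch_size=1):
    per_token = 2 * num_layers * num_heads * head_dim * bytes_per_elem  # 2x for keys and values
    return per_token * seq_len * batch_size / 1e9

print(kv_cache_gb(4096))      # ~2.1 GB for a single 4k-token sequence
print(kv_cache_gb(4096, 64))  # ~137 GB at batch size 64, vs. ~13.5 GB of fp16 weights
```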

**Reviewer comment (Contributor):** Let's clearly state the difference between this work and other long-context work, which is based on training and publishes a checkpoint. Here the main idea is a KV cache policy.

Besides, LLMs usually generate poorly on long sequences during inference. H2O handles this issue by maintaining only the heavy-hitter tokens and the most recent tokens. Combined with the positional rolling strategy (reassigning each KV pair the position of its slot in the KV cache instead of its position in the original sequence), H2O can process sequences much longer than the pretrained context window. Unlike other approaches, such as [Positional Interpolation](https://arxiv.org/abs/2306.15595), H2O is a KV cache policy and does not involve any training for long-context processing.
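
To make the cache policy concrete, below is a minimal, illustrative sketch of the eviction step; it is not the actual implementation in this PR's `cache_utils.py`/`utils_llama.py`, and the tensor shapes and names are hypothetical. The idea is to keep the `num_heavy_hitter_tokens` cached entries with the largest accumulated attention scores plus the most recent tokens, so the cache never grows beyond `num_window_length` entries:

```
import torch

def h2o_evict(keys, values, acc_attn_scores, num_heavy_hitter_tokens, num_window_length):
    """Illustrative H2O eviction for one attention head.

    keys/values: [seq_len, head_dim]; acc_attn_scores: [seq_len], the attention
    mass each cached token has received so far (its "heavy-hitter" score).
    """
    seq_len = keys.shape[0]
    if seq_len <= num_window_length:
        return keys, values, acc_attn_scores  # cache still fits, nothing to evict

    num_local = num_window_length - num_heavy_hitter_tokens
    # Heavy hitters: among the non-local tokens, keep those with the largest scores.
    heavy_idx = torch.topk(acc_attn_scores[: seq_len - num_local], num_heavy_hitter_tokens).indices
    # Local window: the most recent tokens are always kept.
    local_idx = torch.arange(seq_len - num_local, seq_len)
    keep = torch.cat([heavy_idx.sort().values, local_idx])
    return keys[keep], values[keep], acc_attn_scores[keep]
```

With position rolling enabled, the rotary positions for the kept entries would be taken from their slots in the compressed cache (0 to `num_window_length - 1`) rather than from their original sequence positions, which is what keeps positions inside the pretrained context window.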

The current implementation supports Llama 1/2/3, from 7B to 70B. Since H2O only maintains the most important KV pairs, it might miss some important information in the middle of the content for knowledge-intensive tasks.

For more details, please refer to the paper: **https://arxiv.org/pdf/2306.14048** and the blog: **https://allenz.work/?p=11**.

**Note: this implementation is tested with transformers == 4.39.0**

### Evaluation on Summarization Tasks

The following example runs inference of Llama-2-7b on the XSUM summarization task. We use `--enable_h2o_generation` to enable the H2O algorithm, which keeps only the heavy-hitter and local KV pairs. Use `--num_heavy_hitter_tokens` to set the number of heavy-hitter KV pairs and `--num_window_length` for the KV cache size; the number of local KV pairs equals `num_window_length - num_heavy_hitter_tokens`. Also, use `--enable_position_rolling` to enable position rolling, which assigns each KV pair its position in the KV cache instead of its position in the original sequence. Enabling position rolling is important when the sequence length exceeds the pretrained context window, e.g., 4K in Llama-2.

```
python run_summarization.py \
--input-path data/summarization/xsum.jsonl \
--output-path summarization_output/xsum_h2o.jsonl \
--model-name meta-llama/Llama-2-7b-hf \
--enable_h2o_generation
```

#### Results

Expected results on XSUM (Rouge-2 score, the higher the better) from the above script on Llama-2/3 models. The sequence length of the inputs is ~2k. Here we constrain the size of the KV cache, allowing only n KV pairs to be written/read after the prefilling stage, where n ranges from **64** to **full** (maintaining all KV pairs). With 128 KVs, the performance matches the full baseline (~2k KVs), while performance degradation is observed with 64 KVs. Maintaining a smaller KV cache also reduces the I/O cost of KVs, so we can achieve better throughput.

| KV Cache Size | 64 | 128 | 256 | 512 | 1024 | Full |
| ------------- | ------ | ------ | ------ | ------ | ------ | ------ |
| Llama-2-7B | 0.0439 | 0.1127 | 0.1148 | 0.1182 | 0.1170 | 0.1164 |
| Llama-2-13B | 0.1180 | 0.1217 | 0.1243 | 0.1291 | 0.1302 | 0.1332 |
| Llama-3-8B | 0.1107 | 0.1189 | 0.1200 | 0.1347 | 0.1290 | 0.1311 |

**Reviewer comment (Contributor):** It would also be great to highlight the idea of streaming LLMs here, to showcase the benefits at inference time.

### Evaluation on "Needle in a Haystack" Analysis

The following example runs inference of Llama-3-8b-instruct on the "Needle in a Haystack" test. The test is modified from [LLMTest_NeedleInAHaystack](https://github.com/gkamradt/LLMTest_NeedleInAHaystack); please follow the original repository to install the necessary packages. We use `--enable_h2o_generation` to enable the H2O algorithm, which keeps only the heavy-hitter and local KV pairs. Use `--num_heavy_hitter_tokens` to set the number of heavy-hitter KV pairs and `--num_window_length` for the KV cache size; the number of local KV pairs equals `num_window_length - num_heavy_hitter_tokens`. Also, use `--enable_position_rolling` to enable position rolling, which assigns each KV pair its position in the KV cache instead of its position in the original sequence. Enabling position rolling is important when the sequence length exceeds the pretrained context window, e.g., 4K in Llama-2.

**Reviewer comment (Contributor):** Can we please add some numbers here vs. the baseline to see how it's working?

**Reviewer comment (Contributor):** I wonder whether, for the data (especially the Paul Graham essays), it makes sense to use a script to download them, or whether it's easier this way? Especially if some postprocessing is involved, it might be easier to have it as is.


```
# step 1: obtain prompts for evaluation
# download the dataset from https://github.com/gkamradt/LLMTest_NeedleInAHaystack/tree/main/needlehaystack/PaulGrahamEssays
# modify the data-path in utils/needle_test/config-prompt.yaml (line 3: haystack_dir: "data/PaulGrahamEssays")
python utils/needle_test/prompt.py --model_name meta-llama/Meta-Llama-3-8B-Instruct
# modify utils/needle_test/config-prompt.yaml to adjust the min/max sequence length for the test


# step 2: generate predictions for each prompt
# full model
python run_needle_haystack_test.py \
--input-path data/needle_test/Huggingface \
--output-path needle_test_results/huggingface/llama-3-8b-instruct/ \
--model-name meta-llama/Meta-Llama-3-8B-Instruct

# h2o with a 4096 kv cache (2048 heavy-hitter + 2048 local KV pairs)
python run_needle_haystack_test.py \
--input-path data/needle_test/Huggingface \
--output-path needle_test_results/huggingface/llama-3-8b-instruct-h2o-4096/ \
--model-name meta-llama/Meta-Llama-3-8B-Instruct \
--enable_h2o_generation \
--num_window_length 4096 \
--num_heavy_hitter_tokens 2048


# step 3: scoring with gpt4
export OPENAI_API_KEY=YOUR_API_KEY
# --input-path: path for the prediction results from step 2
python utils/needle_test/eval.py \
--input-path needle_test_results/huggingface/llama-3-8b-instruct-h2o-4096 \
--output-path needle_test_results/huggingface/llama-3-8b-instruct-h2o-4096_eval


# step 4: visualization
python utils/needle_test/vis.py \
--input-path needle_test_results/huggingface/llama-3-8b-instruct-h2o-4096_eval
```

**Reviewer comment (Contributor):** Can we please add some explanation to clarify, especially on the tested length of 1M?

### A Demo of Streaming to "Infinite" Context Length

The following example demonstrates generation with "infinite" sequence length. We use MT-Bench data and run generation sample by sample. The KV cache keeps the KV pairs from previous samples while maintaining a fixed size. Results can be found in the [Demo](https://allenz.work/?p=11) (Video 1).
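
As a tiny, self-contained simulation of the streaming behavior (hypothetical scores and sizes, purely to illustrate the bounded-cache idea rather than the PR's actual `src/streaming.sh` pipeline), the loop below appends one cache entry per decoding step and applies heavy-hitter + local eviction so the cache size stays fixed no matter how long the stream runs:

```
import random

num_window_length, num_heavy_hitter_tokens = 8, 4   # tiny sizes, for illustration only
num_local = num_window_length - num_heavy_hitter_tokens

cache = []  # each entry: [token_position, accumulated_attention_score]
for step in range(50):              # stand-in for token-by-token decoding across samples
    cache.append([step, 0.0])
    for entry in cache:             # stand-in for adding this step's attention scores
        entry[1] += random.random()
    if len(cache) > num_window_length:
        older, recent = cache[:-num_local], cache[-num_local:]
        heavy = sorted(older, key=lambda e: e[1], reverse=True)[:num_heavy_hitter_tokens]
        cache = sorted(heavy, key=lambda e: e[0]) + recent
    # with position rolling, the positions fed to the model would be 0..len(cache)-1

print(len(cache))                   # stays at num_window_length (8) regardless of stream length
```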

```
# run with full cache
# expected results: 1) normal generation at the early stage; 2) performance collapse and generation slowdown at the middle stage, because the sequence length exceeds the context window and the I/O cost of the KV cache constrains throughput; 3) OOM errors, then the run stops.
bash src/streaming.sh full

# run with h2o
# expected results: normal generation at all stages.
# adjust the number of heavy-hitter tokens with --num_heavy_hitter_tokens and size of KV cache with --num_window_length in src/streaming.sh
bash src/streaming.sh h2o
```