
ggml : add Flash Attention #5021

Merged
merged 145 commits into master on Apr 30, 2024
Conversation

ggerganov
Owner

@ggerganov ggerganov commented Jan 18, 2024

ref #3365

Setting up what's needed for Flash Attention support in ggml and llama.cpp

The proposed operator performs:

// new
res = ggml_flash_attn(ctx, q, k, v, kq_mask, kq_scale);

// fused scale + mask + soft_max (old)
kq  = ggml_mul_mat     (ctx, k,  q);
kq  = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);
kqv = ggml_mul_mat     (ctx, v,  kq);
kqv = ggml_permute     (ctx, kqv, 0, 2, 1, 3);
res = ggml_cont_2d     (ctx, kqv, n_embd_head_k*n_head, n_tokens);

// unfused (old)
kq  = ggml_mul_mat (ctx, k,  q);
kq  = ggml_scale   (ctx, kq, kq_scale);
kq  = ggml_add     (ctx, kq, kq_mask);
kq  = ggml_soft_max(ctx, kq);
kqv = ggml_mul_mat (ctx, v,  kq);
kqv = ggml_permute (ctx, kqv, 0, 2, 1, 3);
res = ggml_cont_2d (ctx, kqv, n_embd_head_k*n_head, n_tokens);
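
For intuition (not part of the proposed API), here is a minimal scalar sketch, assuming a single head and a single query row, of the online-softmax accumulation that a fused flash-attention kernel performs - this is what lets it avoid materializing KQ. The helper attn_online_row and its arguments are hypothetical, for illustration only:

// Minimal single-head, single-query sketch of online (streaming) softmax
// attention: a running max M, a running normalizer S and an output accumulator
// are updated per KV position, so the full KQ row is never stored.
// Hypothetical helper for illustration only - not ggml code.
#include <math.h>

void attn_online_row(const float * q,    // [d]        query row
                     const float * k,    // [n_kv][d]  keys
                     const float * v,    // [n_kv][d]  values
                     const float * mask, // [n_kv]     additive mask, 0 or -INFINITY
                     float scale, int n_kv, int d,
                     float * out /* [d] */) {
    float M = -INFINITY, S = 0.0f;
    for (int i = 0; i < d; ++i) out[i] = 0.0f;

    for (int j = 0; j < n_kv; ++j) {
        float s = 0.0f;
        for (int i = 0; i < d; ++i) s += q[i]*k[j*d + i];
        s = s*scale + mask[j];
        if (isinf(s) && s < 0) continue;          // masked-out cell

        const float M_new = s > M ? s : M;
        const float corr  = expf(M - M_new);      // rescale previous accumulators
        const float p     = expf(s - M_new);

        S = S*corr + p;
        for (int i = 0; i < d; ++i) out[i] = out[i]*corr + p*v[j*d + i];
        M = M_new;
    }
    for (int i = 0; i < d; ++i) out[i] /= S;      // assumes at least one unmasked cell
}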

Suggestions and comments for the API are welcome.
Looking for help in implementing efficient GPU kernels - please open a PR against this branch if you have proposals.

Changes to ggml/llama

Things to consider

  • Pass KQ list with/instead of KQ mask
  • Pass block-wise KQ mask
  • Support ALiBi
  • Eventually express ALiBi as a ggml_add() on the mask? (low-prio)
  • No longer store transposed V-cache (gg/flash-attn-online)

Testing

./tests/test-backend-ops -o FLASH_ATTN_EXT
  • main, server: add -fa
  • llama-bench: add -fa 1

Benchmark

Baseline:

# CUDA
LLAMA_CUBLAS=1 make -j tests && ./tests/test-backend-ops -o ATTN -b CUDA0 perf

# Metal
make -j tests && ./tests/test-backend-ops -o ATTN -b Metal perf

FA kernel:

# CUDA
LLAMA_CUBLAS=1 make -j tests && ./tests/test-backend-ops -o FLASH_ATTN_EXT -b CUDA0 perf

# Metal
make -j tests && ./tests/test-backend-ops -o FLASH_ATTN_EXT -b Metal perf

Text-generation after long prompt:

# without flash attention
./batched-bench models/mistral-instruct-7b-v0.2/ggml-model-f16.gguf 10000 2048 512 0 1 99 8192 256 1

# with flash attention
./batched-bench models/mistral-instruct-7b-v0.2/ggml-model-f16.gguf 10000 2048 512 1 1 99 8192 256 1

References

@ggerganov ggerganov added help wanted Extra attention is needed performance Speed related topics labels Jan 18, 2024
@ggerganov ggerganov closed this Jan 18, 2024
@ggerganov ggerganov reopened this Jan 18, 2024
@ggerganov ggerganov marked this pull request as draft January 18, 2024 17:09
@slaren
Collaborator

slaren commented Jan 18, 2024

Since we are doing this from scratch, wouldn't it be better to remove the custom attention mask entirely and pass a list of KV cells used in each sequence? Considering our implementation of batching, I think we should be looking at implementing something closer to paged attention rather than flash attention. I suppose it is possible to convert the mask to a list of sequences in the kernels, but it would be less efficient.

@ggerganov
Owner Author

ggerganov commented Jan 18, 2024

Yes, we can pass a list instead of a mask. I am not sure about the format though - if each list has a different length, I feel it will hinder GPU performance.

Edit: I just got an idea - we can pass both the kq_mask as it is, plus a second boolean tensor that tells each token which KV blocks it should attend to. For example, we split the KV cache into blocks of 128 (or some other round number), and a token (i.e. a row in q) attends to a block if at least one of the cells in it belongs to the token's sequence. This way, we can skip entire blocks of the KV cache that do not belong to the current sequence and keep the problem parallel-friendly. Thoughts?
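
A minimal CPU sketch of that block-mask idea (hypothetical helper, assumed block size of 128): a block is marked visible to a token if at least one of its cells is not masked out in the existing kq_mask.

// Hypothetical sketch: collapse the per-cell KQ mask into a per-block mask.
// kq_mask[t*n_kv + c] == -INFINITY means token t must not attend to KV cell c.
#include <math.h>
#include <stdbool.h>
#include <stddef.h>

#define KV_BLOCK 128  // assumed block size

static void build_block_mask(const float * kq_mask, bool * block_mask,
                             int n_tokens, int n_kv) {
    const int n_blocks = (n_kv + KV_BLOCK - 1)/KV_BLOCK;
    for (int t = 0; t < n_tokens; ++t) {
        for (int b = 0; b < n_blocks; ++b) {
            const int c_end = (b + 1)*KV_BLOCK < n_kv ? (b + 1)*KV_BLOCK : n_kv;
            bool visible = false;
            for (int c = b*KV_BLOCK; c < c_end && !visible; ++c) {
                visible = !isinf(kq_mask[(size_t)t*n_kv + c]); // any unmasked cell => attend to the whole block
            }
            block_mask[(size_t)t*n_blocks + b] = visible;
        }
    }
}

The kernel could then iterate only over blocks whose flag is set, while still applying the fine-grained kq_mask inside each block.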

@slaren
Collaborator

slaren commented Jan 18, 2024

We could use a vector with dimension [num_seqs] that contains the length of the sequences, and a 2D tensor with dimensions [max_seq_len, num_seqs] that contains the KV cells in each sequence, padded to the length of the longest sequence.
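
A rough sketch of that layout (hypothetical struct and names):

// Hypothetical sketch of the proposed layout: per-sequence lengths plus a
// padded 2D table of KV cell indices, so every row has the same stride.
#include <stdint.h>

struct seq_kv_index {
    int32_t   num_seqs;
    int32_t   max_seq_len;   // length of the longest sequence (padding target)
    int32_t * seq_len;       // [num_seqs]               actual length of each sequence
    int32_t * kv_cells;      // [num_seqs][max_seq_len]  KV cell ids, padded (e.g. with -1)
};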

@slaren
Collaborator

slaren commented Jan 18, 2024

It seems that vLLM has added a new version of paged attention since I looked into the implementation (vllm-project/vllm#1348). I am not sure what the changes are, but I think it is worth looking into what they are doing. The kernel is in https://github.com/vllm-project/vllm/blob/main/csrc/attention/attention_kernels.cu

@slaren
Collaborator

slaren commented Jan 18, 2024

Alibi could also be done in this kernel.

@ggerganov
Owner Author

Regarding ALiBi, I feel reinterpreting it as a KQ_mask via ggml_add() is a more general solution - we would avoid having a ggml_alibi() operator and explicit support for it in the kernels that we write (as in vLLM).

It remains to be seen though whether the KQ_mask will be a bottleneck - my feeling is that just avoiding the extra read/write of KQ will bring us close to optimal performance, even with the existing "cross-KV compute" drawback.

I will take a look at the vLLM code, and I've updated the description with some of the points from this discussion.
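
A hypothetical CPU sketch of what "ALiBi as a mask" would mean in practice: the per-head bias is pre-added to a (now per-head) additive mask, so the attention kernel only ever sees one mask tensor. Names and layout here are assumptions, not the ggml API.

// Hypothetical sketch: fold the ALiBi bias into an additive per-head KQ mask,
// so no dedicated ggml_alibi() operator is needed in the kernels.
#include <math.h>
#include <stddef.h>

static void fold_alibi_into_mask(float       * kq_mask, // [n_head][n_tokens][n_kv], 0 or -INFINITY
                                 const float * slopes,  // [n_head]   ALiBi slope per head
                                 const int   * q_pos,   // [n_tokens] absolute query positions
                                 const int   * kv_pos,  // [n_kv]     absolute KV cell positions
                                 int n_head, int n_tokens, int n_kv) {
    for (int h = 0; h < n_head; ++h) {
        for (int t = 0; t < n_tokens; ++t) {
            for (int c = 0; c < n_kv; ++c) {
                float * m = &kq_mask[((size_t)h*n_tokens + t)*n_kv + c];
                if (!isinf(*m)) {
                    // linear penalty growing with the query/key distance
                    *m -= slopes[h]*(float)(q_pos[t] - kv_pos[c]);
                }
            }
        }
    }
}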

@calvintwr

@ggerganov @slaren Together with @JohannesGaessler and @FSSRepo, we are working on the same thing over at Pints-App#1, which we intend to submit as a PR to llama.cpp once the work is done.

However, I think we will converge with this one. Given the amount of work here, @ggerganov @slaren, how do you want to organise this? The three of us are actually in a temporary Discord group to coordinate - perhaps we can use that?

What are your thoughts?

@ggerganov
Owner Author

ggerganov commented Jan 19, 2024

Discord is not an option for me - I prefer to communicate over Github issues / discussions / e-mail.

Happy to see you have started work on the CUDA implementation. Please take into account the API proposed here - note that it is still a WIP and can change. I can review your implementation when you think it is in a good state. I would prefer PRs that are compatible with this branch, so we can verify correctness using test-backend-ops and keep support for all backends.

@calvintwr

@ggerganov Got it. Let us work on a plan to converge with this PR.

@cebtenzzre
Collaborator

cebtenzzre commented Jan 20, 2024

test-backend-ops -o FLASH_ATTN_EXT fails for Metal on my M2 Pro - is this known?
edit: I see, not implemented yet.

@JianbangZ

Any performance numbers?

@ggerganov ggerganov merged commit 9c67c27 into master Apr 30, 2024
69 checks passed
@ochafik
Collaborator

ochafik commented Apr 30, 2024

So excited to see this land!! 🎉

(edit: can confirm the crash is fixed)

I did quick llama-bench runs on Metal and I'm seeing a ~5% increase in pp t/s and ~2.5% increase in tg t/s on an M3 Pro (18-core GPU), and an 8-9% increase in pp t/s and ~6% increase in tg t/s on an M1 Ultra (64-core GPU).

llama-bench command and results:
./llama-bench -m models/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf -fa 0,1 -p 512,1024 -n 128,256,512,1024

M3 Pro 36gb 18gpu

| model | size | params | backend | ngl | fa | test | t/s |
| --- | ---: | ---: | --- | --: | -: | --- | ---: |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 0 | pp 512 | 308.84 ± 1.40 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 0 | tg 128 | 24.88 ± 0.14 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 0 | tg 256 | 24.88 ± 0.16 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 1 | pp 512 | 324.32 ± 1.02 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 1 | tg 128 | 25.75 ± 0.12 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 1 | tg 256 | 25.50 ± 0.33 |

M1 Ultra 128gb 64gpu

| model | size | params | backend | ngl | fa | test | t/s |
| --- | ---: | ---: | --- | --: | -: | --- | ---: |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 0 | pp 512 | 824.02 ± 13.65 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 0 | pp 1024 | 837.59 ± 5.64 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 0 | tg 128 | 66.80 ± 0.10 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 0 | tg 256 | 66.99 ± 0.12 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 0 | tg 512 | 66.48 ± 0.03 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 0 | tg 1024 | 65.23 ± 0.07 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 1 | pp 512 | 905.47 ± 0.70 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 1 | pp 1024 | 895.86 ± 0.62 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 1 | tg 128 | 70.67 ± 0.08 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 1 | tg 256 | 70.63 ± 0.03 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 1 | tg 512 | 70.31 ± 0.04 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 1 | tg 1024 | 69.68 ± 0.06 |

I also ran a custom benchmark which suggests the speed-up on Metal increases with both prompt length and generation length (plateauing at about 9.5% faster), but I might have got my script wrong.

Meta-Llama-3-8B-Instruct-Q4_K_M.gguf

M1 Ultra 128gb 64gpu

| prompt \ n | 10 | 100 | 1000 |
| ---: | ---: | ---: | ---: |
| 10 | -0.68% | -1.43% | -5.33% |
| 100 | -0.70% | -3.88% | -5.57% |
| 1000 | -1.41% | -4.81% | -9.26% |

M3 Pro 36gb 18gpu

| prompt \ n | 10 | 100 | 1000 |
| ---: | ---: | ---: | ---: |
| 10 | -2.26% | -2.39% | -9.45% |
| 100 | -0.85% | -2.41% | -7.69% |
| 1000 | -5.23% | -8.13% | -9.61% |

@ggerganov
Owner Author

@ochafik On Metal you can get some extra performance by adding -t 4 - both with and without FA

@segmond

segmond commented Apr 30, 2024

This is great. The performance is about the same in terms of tokens per second, but the increase in usable context window is insane!

@ExtReMLapin
Contributor

With default settings on an RTX 4090 and Meta-Llama-3-8B-Instruct-Q5_K_M.gguf, we go from 120 t/s to 129 t/s.

@ExtReMLapin
Contributor

This is great. The performance is about the same in terms of tokens per second, but the increase in usable context window is insane!

Am I missing something? Does it only increase t/s, and not reduce VRAM usage per ctx size?

@slaren
Collaborator

slaren commented Apr 30, 2024

KQ doesn't need to be materialized in global memory with flash attention, and with large contexts that was often the biggest tensor in the compute buffer. So it should reduce the size of the compute buffer substantially with large contexts.
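
As a rough, hypothetical illustration of the savings (numbers assumed, not measured): the materialized F32 KQ tensor scales with n_head * n_tokens * n_kv, which dominates the compute buffer at long contexts.

// Back-of-the-envelope sketch (assumed example numbers): size of the F32 KQ
// tensor that the unfused path materializes for one ubatch and that flash
// attention keeps in registers/shared memory instead.
#include <stdio.h>
#include <stdint.h>

int main(void) {
    const int64_t n_head   = 32;    // heads (e.g. an 8B llama model)
    const int64_t n_tokens = 512;   // ubatch size
    const int64_t n_kv     = 32768; // KV cache cells (32K context)

    const int64_t kq_bytes = n_head*n_tokens*n_kv*(int64_t)sizeof(float);
    printf("KQ: %.2f GiB\n", kq_bytes/(1024.0*1024.0*1024.0)); // -> 2.00 GiB
    return 0;
}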

@sorasoras

sorasoras commented Apr 30, 2024

 .\test-backend-ops.exe -o FLASH_ATTN_EXT -b ROCm0 perf
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon RX 7900 XTX, compute capability 11.0, VMM: no
Testing 2 backends

Backend 1/2 (CPU)
  Skipping
Backend 2/2 (ROCm0)
  Backend name: ROCm0
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1):                       8098 runs -     9.54 us/run -     4144 kB/run -  414.18 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2):                       8066 runs -     6.14 us/run -     4160 kB/run -  645.83 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=4):                       8005 runs -     5.78 us/run -     4192 kB/run -  691.90 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=8):                       7885 runs -     5.90 us/run -     4256 kB/run -  688.21 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1):                      4057 runs -     8.48 us/run -     8272 kB/run -  930.48 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2):                      4049 runs -     5.79 us/run -     8288 kB/run - 1364.52 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=4):                      4033 runs -     6.28 us/run -     8320 kB/run - 1263.13 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=8):                      4003 runs -     6.36 us/run -     8384 kB/run - 1257.08 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=512,nb=1):                       6488 runs -     6.34 us/run -     5172 kB/run -  777.90 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=512,nb=2):                       6463 runs -     6.45 us/run -     5192 kB/run -  767.88 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=512,nb=4):                       6414 runs -     6.53 us/run -     5232 kB/run -  763.68 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=512,nb=8):                       6317 runs -     6.56 us/run -     5312 kB/run -  772.46 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=1024,nb=1):                      3251 runs -     6.24 us/run -    10324 kB/run - 1578.33 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=1024,nb=2):                      3244 runs -     6.49 us/run -    10344 kB/run - 1520.91 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=1024,nb=4):                      3232 runs -     6.50 us/run -    10384 kB/run - 1523.31 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=1024,nb=8):                      3207 runs -     6.37 us/run -    10464 kB/run - 1567.57 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=1):                      4065 runs -    10.07 us/run -     8256 kB/run -  781.87 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=2):                      4049 runs -     5.57 us/run -     8288 kB/run - 1419.22 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=4):                      4018 runs -     6.38 us/run -     8352 kB/run - 1248.24 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=8):                      3957 runs -     3.74 us/run -     8480 kB/run - 2164.12 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=1):                     2037 runs -     9.49 us/run -    16480 kB/run - 1656.13 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=2):                     2033 runs -     5.75 us/run -    16512 kB/run - 2738.56 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=4):                     2025 runs -     6.51 us/run -    16576 kB/run - 2429.16 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=8):                     2009 runs -     6.00 us/run -    16704 kB/run - 2653.27 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=512,nb=1):                      2037 runs -    11.38 us/run -    16480 kB/run - 1381.55 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=512,nb=2):                      2029 runs -     6.82 us/run -    16544 kB/run - 2314.56 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=512,nb=4):                      2013 runs -     5.87 us/run -    16672 kB/run - 2710.54 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=512,nb=8):                      1983 runs -     7.21 us/run -    16928 kB/run - 2238.37 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=1024,nb=1):                     1021 runs -    13.04 us/run -    32896 kB/run - 2406.53 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=1024,nb=2):                     1019 runs -     3.45 us/run -    32960 kB/run - 9122.85 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=1024,nb=4):                     1015 runs -     8.03 us/run -    33088 kB/run - 3929.40 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=1024,nb=8):                     1007 runs -     8.58 us/run -    33344 kB/run - 3705.81 GB/s
  Backend ROCm0: OK

2/2 backends passed
OK

.\test-backend-ops.exe -o ATTN -b ROCm0 perf
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon RX 6700 XT, compute capability 10.3, VMM: no
Testing 2 backends

Backend 1/2 (CPU)
  Skipping
Backend 2/2 (ROCm0)
  Backend name: ROCm0
  Backend ROCm0: OK

2/2 backends passed
OK

@ggerganov FA is not enabled for ROCm in general, right?

@ochafik
Collaborator

ochafik commented Apr 30, 2024

@ochafik On Metal you can get some extra performance by adding -t 4 - both with and without FA

@ggerganov Actually, on the M3 Pro performance seems to peak at -t 2. It might have to do with it being an oddball "6 performance + 6 efficiency" core processor.

Edit: actually I'm now seemingly getting the same results for all -t values; the previous run was probably jittery.

(jittery) llama-bench output for various -t values
./llama-bench -m models/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf -fa 1,0 -p 512 -t 1,2,3,4,5,6,7,8
| model | size | params | backend | ngl | threads | fa | test | t/s |
| --- | ---: | ---: | --- | --: | ------: | -: | --- | ---: |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 1 | 1 | pp 512 | 291.74 ± 8.98 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 1 | 1 | tg 128 | 23.52 ± 0.24 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 2 | 1 | pp 512 | 315.41 ± 1.52 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 2 | 1 | tg 128 | 22.78 ± 0.33 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 3 | 1 | pp 512 | 311.22 ± 1.09 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 3 | 1 | tg 128 | 22.83 ± 0.09 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 4 | 1 | pp 512 | 311.42 ± 1.34 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 4 | 1 | tg 128 | 22.98 ± 0.19 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 5 | 1 | pp 512 | 311.71 ± 0.52 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 5 | 1 | tg 128 | 23.04 ± 0.34 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 6 | 1 | pp 512 | 288.59 ± 3.48 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 6 | 1 | tg 128 | 24.56 ± 1.07 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 7 | 1 | pp 512 | 319.85 ± 5.38 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 7 | 1 | tg 128 | 25.48 ± 0.40 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 8 | 1 | pp 512 | 314.46 ± 6.37 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 8 | 1 | tg 128 | 24.26 ± 0.81 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 1 | 0 | pp 512 | 308.55 ± 4.98 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 1 | 0 | tg 128 | 23.26 ± 0.86 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 2 | 0 | pp 512 | 311.89 ± 6.15 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 2 | 0 | tg 128 | 23.91 ± 0.35 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 3 | 0 | pp 512 | 297.60 ± 2.46 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 3 | 0 | tg 128 | 22.63 ± 0.27 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 4 | 0 | pp 512 | 298.97 ± 2.00 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 4 | 0 | tg 128 | 21.84 ± 1.41 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 5 | 0 | pp 512 | 271.76 ± 0.75 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 5 | 0 | tg 128 | 19.50 ± 0.13 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 6 | 0 | pp 512 | 271.03 ± 0.84 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 6 | 0 | tg 128 | 19.79 ± 0.24 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 7 | 0 | pp 512 | 274.13 ± 1.29 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 7 | 0 | tg 128 | 21.10 ± 1.07 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 8 | 0 | pp 512 | 304.75 ± 7.67 |
| llama 8B Q4_K - Medium | 4.58 GiB | 8.03 B | Metal | 99 | 8 | 0 | tg 128 | 22.77 ± 0.93 |

@slaren
Collaborator

slaren commented Apr 30, 2024

Metal and other GPU backends with full offload only use one thread; however, in Metal the number of threads is also used as the number of command buffers.

@ochafik
Collaborator

ochafik commented Apr 30, 2024

@slaren ah ok, thanks for the explanation! I'm not seeing any effect of -t anymore (not sure what happened w/ previous run). Looks like the M3 Pro GPU gets 100% busy already w/ -t 1.

@jdecourval
Contributor

@ggerganov FA is not enabled for ROCm in general, right?

It's currently disabled, yes.
It needs cleanup, but I re-enabled the feature in this branch if you want to have a look:
#7011
I'll post some numbers soon. Don't expect much additional speed, but the VRAM savings can be significant.

@Dampfinchen

Dampfinchen commented Apr 30, 2024

Sadly I'm not seeing any benefit from this. No reduction in VRAM usage, no speedup, even when fully offloading.

In fact, I'm only seeing slower speeds when using partial offloading.

@strawberrymelonpanda
Contributor

strawberrymelonpanda commented Apr 30, 2024

For me (Windows, CUDA, 24GB VRAM) the difference is definitely there, but it depends on the model and I have best results with a large amount of context data.

The most pronounced case for me is Mixtral-8x7B-Instruct-v0.1-requant-imat-IQ3_XS, which I can fully offload. It nearly doubles in speed according to the timings, and I was able to raise the ctx from 16K to 32K.

Edit: I saw the "old timings" below across at least 4 runs each last night, but today without FA I'm hitting close to 39-40 t/s, so that must have been an edge case - still, FA seemed to help with it.

With FA:

llama_print_timings:        load time =    9259.83 ms
llama_print_timings:      sample time =      65.38 ms /   328 runs   (    0.20 ms per token,  5017.13 tokens per second)
llama_print_timings: prompt eval time =    7894.93 ms /  7840 tokens (    1.01 ms per token,   993.04 tokens per second)
llama_print_timings:        eval time =    7317.74 ms /   327 runs   (   22.38 ms per token,    44.69 tokens per second)
llama_print_timings:       total time =   15517.04 ms /  8167 tokens

Without FA: (updated)

llama_print_timings:        load time =    9860.39 ms
llama_print_timings:      sample time =      25.49 ms /   128 runs   (    0.20 ms per token,  5021.18 tokens per second)
llama_print_timings: prompt eval time =    7806.70 ms /  7169 tokens (    1.09 ms per token,   918.31 tokens per second)
llama_print_timings:        eval time =    3115.48 ms /   127 runs   (   24.53 ms per token,    40.76 tokens per second)
llama_print_timings:       total time =   11127.62 ms /  7296 tokens
old timings without FA
llama_print_timings:        load time =    9722.33 ms
llama_print_timings:      sample time =      64.05 ms /   322 runs   (    0.20 ms per token,  5027.24 tokens per second)
llama_print_timings: prompt eval time =   12847.00 ms /  7840 tokens (    1.64 ms per token,   610.26 tokens per second)
llama_print_timings:        eval time =   14036.72 ms /   321 runs   (   43.73 ms per token,    22.87 tokens per second)
llama_print_timings:       total time =   27208.66 ms /  8161 tokens

Other models are less remarkable, but I'm able to store a lot more context.

New tests:

llama-bench with -p 512,1024 is less dramatic but measurable: TG goes from ~46 to ~50 t/s:
./llama-bench -m Mixtral-8x7B-Instruct-v0.1-requant-imat-IQ3_XS.gguf -fa 0,1 -p 512,1024 -n 128,256,512,1024
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
| model                          |       size |     params | backend    | ngl |         fa | test       |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------: | ---------- | ---------------: |
| llama 8x7B IQ3_XS - 3.3 bpw    |  34.96 GiB |    91.80 B | CUDA       |  99 |          0 | pp 512     |   1146.77 ± 9.12 |
| llama 8x7B IQ3_XS - 3.3 bpw    |  34.96 GiB |    91.80 B | CUDA       |  99 |          0 | pp 1024    |   1130.55 ± 4.81 |
| llama 8x7B IQ3_XS - 3.3 bpw    |  34.96 GiB |    91.80 B | CUDA       |  99 |          0 | tg 128     |     46.81 ± 0.58 |
| llama 8x7B IQ3_XS - 3.3 bpw    |  34.96 GiB |    91.80 B | CUDA       |  99 |          0 | tg 256     |     47.07 ± 0.16 |
| llama 8x7B IQ3_XS - 3.3 bpw    |  34.96 GiB |    91.80 B | CUDA       |  99 |          0 | tg 512     |     46.50 ± 0.70 |
| llama 8x7B IQ3_XS - 3.3 bpw    |  34.96 GiB |    91.80 B | CUDA       |  99 |          0 | tg 1024    |     46.46 ± 0.44 |
| llama 8x7B IQ3_XS - 3.3 bpw    |  34.96 GiB |    91.80 B | CUDA       |  99 |          1 | pp 512     |  1159.59 ± 10.34 |
| llama 8x7B IQ3_XS - 3.3 bpw    |  34.96 GiB |    91.80 B | CUDA       |  99 |          1 | pp 1024    |   1155.55 ± 5.00 |
| llama 8x7B IQ3_XS - 3.3 bpw    |  34.96 GiB |    91.80 B | CUDA       |  99 |          1 | tg 128     |     51.22 ± 0.06 |
| llama 8x7B IQ3_XS - 3.3 bpw    |  34.96 GiB |    91.80 B | CUDA       |  99 |          1 | tg 256     |     50.83 ± 0.17 |
| llama 8x7B IQ3_XS - 3.3 bpw    |  34.96 GiB |    91.80 B | CUDA       |  99 |          1 | tg 512     |     50.82 ± 0.13 |
| llama 8x7B IQ3_XS - 3.3 bpw    |  34.96 GiB |    91.80 B | CUDA       |  99 |          1 | tg 1024    |     50.28 ± 0.35 |

build: a68a1e7e (2772)
The differences are more obvious at -p 8096, 16192, 32384: PP goes from 819 to 1005 t/s @ 16K, and from OOM to 879 t/s @ 32K.
| model                          |       size |     params | backend    | ngl |         fa | test       |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------: | ---------- | ---------------: |
| llama 8x7B IQ3_XS - 3.3 bpw    |  34.96 GiB |    91.80 B | CUDA       |  99 |          0 | pp 1024    |   1128.12 ± 9.99 |
| llama 8x7B IQ3_XS - 3.3 bpw    |  34.96 GiB |    91.80 B | CUDA       |  99 |          0 | pp 8096    |    961.26 ± 4.45 |
| llama 8x7B IQ3_XS - 3.3 bpw    |  34.96 GiB |    91.80 B | CUDA       |  99 |          0 | pp 16192   |    819.92 ± 1.58 |
| llama 8x7B IQ3_XS - 3.3 bpw    |  34.96 GiB |    91.80 B | CUDA       |  99 |          0 | pp 32384   |    OOM           |
| llama 8x7B IQ3_XS - 3.3 bpw    |  34.96 GiB |    91.80 B | CUDA       |  99 |          1 | pp 1024    |   1150.17 ± 5.60 |
| llama 8x7B IQ3_XS - 3.3 bpw    |  34.96 GiB |    91.80 B | CUDA       |  99 |          1 | pp 8096    |   1075.53 ± 2.02 |
| llama 8x7B IQ3_XS - 3.3 bpw    |  34.96 GiB |    91.80 B | CUDA       |  99 |          1 | pp 16192   |   1005.41 ± 1.65 |
| llama 8x7B IQ3_XS - 3.3 bpw    |  34.96 GiB |    91.80 B | CUDA       |  99 |          1 | pp 32384   |    879.37 ± 1.96 |

@ddh0

ddh0 commented Apr 30, 2024

Performance on a MacBook Air M2 (24 GB) using the latest llama.cpp, before and after adding the -fa argument:

Without Flash Attention:

./llama.cpp/main -m ./models/Meta-Llama-3-8B-Instruct-q8_0.gguf -c 8192 -n 4096 -t 4 -tb 8 -b 512 -ngl 999

main: build = 2774 (f364eb6f)
main: built with Apple clang version 15.0.0 (clang-1500.3.9.4) for arm64-apple-darwin23.4.0
main: seed  = 1714507110

llama_kv_cache_init:      Metal KV buffer size =  1024.00 MiB
llama_new_context_with_model: KV self size  = 1024.00 MiB, K (f16):  512.00 MiB, V (f16):  512.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.49 MiB
llama_new_context_with_model:      Metal compute buffer size =   560.00 MiB
llama_new_context_with_model:        CPU compute buffer size =    24.01 MiB

llama_print_timings:        load time =     542.92 ms
llama_print_timings:      sample time =     200.91 ms /  4096 runs   (    0.05 ms per token, 20387.14 tokens per second)
llama_print_timings: prompt eval time =       0.00 ms /     1 tokens (    0.00 ms per token,      inf tokens per second)
llama_print_timings:        eval time =  470814.64 ms /  4096 runs   (  114.94 ms per token,     8.70 tokens per second)
llama_print_timings:       total time =  472380.66 ms /  4097 tokens

With Flash Attention:

./llama.cpp/main -m ./models/Meta-Llama-3-8B-Instruct-q8_0.gguf -c 8192 -n 4096 -t 4 -tb 8 -b 512 -ngl 999 -fa

main: build = 2774 (f364eb6f)
main: built with Apple clang version 15.0.0 (clang-1500.3.9.4) for arm64-apple-darwin23.4.0
main: seed  = 1714507110

llama_kv_cache_init:      Metal KV buffer size =  1024.00 MiB
llama_new_context_with_model: KV self size  = 1024.00 MiB, K (f16):  512.00 MiB, V (f16):  512.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.49 MiB
llama_new_context_with_model:      Metal compute buffer size =   258.50 MiB
llama_new_context_with_model:        CPU compute buffer size =    24.01 MiB

llama_print_timings:        load time =     543.86 ms
llama_print_timings:      sample time =     188.17 ms /  4096 runs   (    0.05 ms per token, 21767.32 tokens per second)
llama_print_timings: prompt eval time =       0.00 ms /     1 tokens (    0.00 ms per token,      inf tokens per second)
llama_print_timings:        eval time =  422651.03 ms /  4096 runs   (  103.19 ms per token,     9.69 tokens per second)
llama_print_timings:       total time =  424131.56 ms /  4097 tokens

TL;DR: Generation speed increases from 8.70 t/s to 9.69 t/s and memory usage decreases slightly (Metal compute buffer: 560.00 MiB -> 258.50 MiB); prompt processing was not tested in this case.

@x4080

x4080 commented Apr 30, 2024

Hi, does the server have flash attention yet? Or does it use flash attention automatically?

edit: just add -fa to the server as well - got it

@LostRuins
Collaborator

Hi, I am having issues building this on CUDA 11.4 after this PR.

Notably, I am getting error: identifier "__hmax" is undefined and error: identifier "__hmax2" is undefined within fattn.cu.

This is not the first time this has happened: previously we added #define CUDART_HMAX 11070 and wrapped the __hmax and __hmax2 functionality behind CUDART_VERSION >= CUDART_HMAX. However, this time that is not the case, so the compile fails.
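
For reference, a hedged sketch of the guard pattern being described (not the actual fix in #7019): only call __hmax on toolkits that meet the version threshold used by the existing guard, otherwise fall back to a comparison via float. CUDART_HMAX is already defined in the existing code; it is repeated here only so the sketch is self-contained.

// Hypothetical sketch of the version guard for __hmax (CUDA C):
#include <cuda_runtime.h>
#include <cuda_fp16.h>

#define CUDART_HMAX 11070  // version threshold used by the existing guard

static __device__ __forceinline__ half hmax_compat(const half a, const half b) {
#if CUDART_VERSION >= CUDART_HMAX
    return __hmax(a, b);                              // native half max
#else
    return __half2float(a) > __half2float(b) ? a : b; // fallback for older toolkits
#endif
}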

@JohannesGaessler
Collaborator

@LostRuins can you check whether this fix #7019 works?

@Dampfinchen

Dampfinchen commented May 3, 2024

Sadly I'm not seeing any benefit from this. No reduction in VRAM usage, no speedup, even when fully offloading.

In fact, I'm only seeing slower speeds when using partial offloading.

It seems this only applies at low context, like 4K.

Testing a very small LLM on my system with a context size of 13,000 tokens and no GQA, the difference is massive.

VRAM savings from 2.8 to 1.2 GB, text generation from 37 to 71 tokens/s, pp from 1300 tokens/s to 2300 tokens/s.

Great work!

@dagbdagb

dagbdagb commented May 4, 2024

From the dialogue above, I think I understand that support for -fa needs to be implemented per backend. Can someone confirm that? I'm not having much luck using -fa with the Vulkan backend. I do not expect said support to materialize soon either; I just want to clarify.

@JohannesGaessler
Collaborator

It does need to be implemented per backend.

nopperl pushed a commit to nopperl/llama.cpp that referenced this pull request May 5, 2024
* ggml : add ggml_flash_attn_ext API

* ggml : fix GQA support in ggml_flash_attn_ext

* ggml : online attention (CPU)

* metal : initial implementation

* metal : f16 precision

* metal : reduce branches

* metal : specialize for head size

* wip : 8 rows per simd group

* wip : 4 rows per simd group

* wip : template for rows per warp

* metal : parallelize across KV size

* metal : parallel reduce across heads

* metal : efficient flash_attn_f16 implementation

* metal : avoid redundant loads of the attention

* metal : scale and mask in matrix form

* metal : fix comment

* llama : avoid ggml_cast, use F32 query

* metal : add parallel reduce version (disabled)

* metal : move output into local memory + optimize

- the result from each simdgroup now stays in the registers
- significantly reduced SRAM usage
- more efficient skipping of -INF blocks
- avoid simdgroup barrier in hot loop
- add comments

* metal : add tests, fix scaling, support C > 32

* metal : improve precision

* ggml : fix f16 mad

* metal : minor

* metal : support Q > 8

* tests : add ATTN tests

* metal : disable buffer allocation logs

* tests : more

* metal : faster inner loop for C == 32

* metal : fix array initialization

* tests : ifdef

* ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext

* ggml : fix ggml_soft_max mask requirement

* cuda : fix soft_max to use correct mask size

* cuda : add flash_attn kernel (wip)

* metal : optimize softmax for C > 32

* metal : optimize softmax

* tests : minor fix

* cuda : avoid zeroing fragments

* tests : update dims

* cuda : fix __hisinf() result check

* cuda : avoid warp_reduce for smax

* cuda : use int instead of int64_t

Noticeably improves performance (thanks to Johannes)

* cuda : make loops use the same loop values

Thanks Johannes again for the tip

* cuda : unroll some of the loops

* cuda : avoid __hisinf branches

* cuda : use half2 in softmax

* cuda : switch to 1 warp for bs > 16

* cuda : speed-up reduce part of the kernel

* cuda : unroll Q*K^T loop

* cuda : fix -INF block check

* cuda : simplify softmax

* cuda : fix matrix names

* cuda : minor

* llama : adapt to F16 KQ_pos

* llama : adapt new models to F16 KQ_mask

* ggml : fix F16 store (ARM NEON)

* llama : fix type of KQ_mask and KQ_pos

* ggml : fix CPU soft_max

* tests : add hs=256

* cuda : fix build

* metal : improve perf via smaller int registers

* cuda : adapt soft_max to F16 mask and pos

* CUDA: faster FlashAttention, kernel for bs == 1

* 16 cols for Phi-2

* no vec for hs, no hs==256 ncols==32 for Volta

* adjust kernel selection logic

* 4 warps, 256 stride for all D

* no ncols == 64

* Multiple parallel blocks for batch size 1

* fix compile warnings

* fix excessive KQ_b loads

* fix cmake build

* fix KV cache padding, NaN from INFINITY (ggerganov#6438)

* llama : flash_attn cparam + fix defrag

* server: support flash_attn param

* server: bench: enable flash_attn param

* CUDA: refactor host code, dyn. par. blocks

* fix flash_attn_vec_f16 race condition

* flush softmax exp below threshold to 0

* store temp KQ in registers

* Calculate KQ as FP32 if KQV has GGML_PREC_F32

* Add __hgt2_mask implementation for CUDA 11

* fix KQ FP32 precision for parallel_blocks > 1

* llama-bench : add -fa,--flash-attn arg

* metal : add BS=1 kernel for flash attention (ggerganov#6508)

* metal : add BS=1 kernel for flash attention (wip)

* metal : support more than 1 warps

* metal : opts

* metal : opt

* metal : switch to parallel reduce

* metal : reduce registers

* metal : simplify

* metal : initial FA vec kernel

* metal : use F32 attention accumulators

* batched-bench : add fattn arg

* llama : simplify llama_build_kv_store

ggml-ci

* llama : adapt build_olmo to changes

* ggml : fix arm fp16 store on windows

* metal : clean-up

* metal : clean-up kernel code

* metal : minor

* tests : remove benchmarks

ggml-ci

* ggml : fix avx512 const correctness

ggml-ci

* ggml : fix soft_max with bias on CPU

ggml-ci

* common : print --flash-attn in help

* ggml : fix num dimensions in ggml_flash_attn_ext

* llama : force disable flash attention for incompatible models

* ggml : ggml_soft_max support F16/F32 mask/pos

ggml-ci

* cuda : uint -> uint32_t

* cuda : "constexpr dim3" -> "const dim3"

ggml-ci

* cuda : try to fix __hgt2_mask

ggml-ci

* ggml : add TODO's for F16/F32 mask/pos support in other backends

* llama : replace bool need_kq_pos with use_alibi

* llama : prep ALiBi support for BERT models

ggml-ci

* llama : fix n_batch requirements

ggml-ci

* cont

* server : add help for --flash-attn arg

* llama : disable FA for AMD

* tests : remove TMP_ATTN_BENCH

ggml-ci

* llama : support save/load state with FA enabled

ggml-ci

* ci : add CUDA save-load-state tests

ggml-ci

* llama : llama_kv_cache_clear zeroes data + fix save-load seq

ggml-ci

* llama : fix copy-paste errors, add TODO

* llama : disallow incompatible states

* llama : update llama_state_get_size after v_trans field

* metal : remove tmp log

* llama : add static reminder for llama_state_get_size

* metal : fix max nsg

ggml-ci

* ci : fix arg order

ggml-ci

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Co-authored-by: Pierrick HYMBERT <pierrick.hymbert@gmail.com>