
CUDA: generalize FP16 fattn vec kernel #7061

Merged
merged 7 commits into ggerganov:master on May 9, 2024

Conversation

JohannesGaessler (Collaborator)

This PR adds a FlashAttention kernel that only uses regular FP16 arithmetic and no tensor core operations, by generalizing the kernel I wrote for batch size 1 to larger batch sizes. It works reasonably well for batch sizes <= 8. The target hardware is the NVIDIA P100 and AMD RX 5000/6000 GPUs, although the code also makes it possible to run FlashAttention on AMD RX 7000 (with likely much worse performance) and on other Pascal GPUs (with effectively unusable performance). On my RX 6800 the performance changes as follows:

| GPU | Model | n_batch | fa | test | t/s |
| --- | --- | ---: | -: | --- | ---: |
| RX 6800 | Mistral 7b q4_0 | 1 | 0 | pp 4096 | 29.22 |
| RX 6800 | Mistral 7b q4_0 | 1 | 1 | pp 4096 | 38.37 |
| RX 6800 | Mistral 7b q4_0 | 2 | 0 | pp 4096 | 55.52 |
| RX 6800 | Mistral 7b q4_0 | 2 | 1 | pp 4096 | 79.58 |
| RX 6800 | Mistral 7b q4_0 | 4 | 0 | pp 4096 | 90.93 |
| RX 6800 | Mistral 7b q4_0 | 4 | 1 | pp 4096 | 113.53 |
| RX 6800 | Mistral 7b q4_0 | 8 | 0 | pp 4096 | 132.21 |
| RX 6800 | Mistral 7b q4_0 | 8 | 1 | pp 4096 | 135.40 |
| RX 6800 | Mistral 7b q4_0 | 16 | 0 | pp 4096 | 122.97 |
| RX 6800 | Mistral 7b q4_0 | 16 | 1 | pp 4096 | 90.23 |
| RX 6800 | Mistral 7b q4_0 | 32 | 0 | pp 4096 | 224.92 |
| RX 6800 | Mistral 7b q4_0 | 32 | 1 | pp 4096 | 175.97 |
| RX 6800 | Mistral 7b q4_0 | 64 | 0 | pp 4096 | 365.67 |
| RX 6800 | Mistral 7b q4_0 | 64 | 1 | pp 4096 | 247.05 |
| RX 6800 | Mistral 7b q4_0 | 128 | 0 | pp 4096 | 464.88 |
| RX 6800 | Mistral 7b q4_0 | 128 | 1 | pp 4096 | 317.34 |
| RX 6800 | Mistral 7b q4_0 | 256 | 0 | pp 4096 | 552.33 |
| RX 6800 | Mistral 7b q4_0 | 256 | 1 | pp 4096 | 376.56 |
| RX 6800 | Mistral 7b q4_0 | 512 | 0 | pp 4096 | 534.94 |
| RX 6800 | Mistral 7b q4_0 | 512 | 1 | pp 4096 | 390.29 |
| RX 6800 | Mistral 7b q4_0 | 1024 | 0 | pp 4096 | 532.45 |
| RX 6800 | Mistral 7b q4_0 | 1024 | 1 | pp 4096 | 388.57 |
| RX 6800 | Mistral 7b q4_0 | 2048 | 0 | pp 4096 | 531.10 |
| RX 6800 | Mistral 7b q4_0 | 2048 | 1 | pp 4096 | 387.92 |
| RX 6800 | Mistral 7b q4_0 | 4096 | 0 | pp 4096 | 530.76 |
| RX 6800 | Mistral 7b q4_0 | 4096 | 1 | pp 4096 | 387.56 |
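
For context on what "only regular FP16 arithmetic" means here: the vec kernel keeps a running softmax maximum and denominator per query column and accumulates the output with plain half/float math instead of tensor-core matrix multiplications. The following is a heavily simplified, hypothetical sketch of that idea for a single query column; names, memory layout, and launch configuration do not match the actual ggml-cuda kernel.

```cuda
// Simplified online-softmax FlashAttention for one query column, no tensor cores.
#include <cuda_fp16.h>
#include <cmath>

template <int D> // head size, e.g. 64 or 128
__global__ void fattn_vec_f16_sketch(
        const half * __restrict__ Q,   // [D]        one query column
        const half * __restrict__ K,   // [n_kv, D]  keys, row-major
        const half * __restrict__ V,   // [n_kv, D]  values, row-major
        float      * __restrict__ dst, // [D]        output
        const int    n_kv,
        const float  scale) {
    const int tid = threadIdx.x;  // launched with blockDim.x == D for simplicity

    float kqmax = -INFINITY;      // running maximum of the scaled KQ values
    float kqsum = 0.0f;           // running softmax denominator
    float vo    = 0.0f;           // running numerator for this thread's output element

    for (int k = 0; k < n_kv; ++k) {
        // Dot product Q*K[k]. The real kernel does this cooperatively with half2
        // math and a warp reduction; here every thread simply redoes the full sum.
        float kq = 0.0f;
        for (int i = 0; i < D; ++i) {
            kq += __half2float(Q[i]) * __half2float(K[k*D + i]);
        }
        kq *= scale;

        // Online softmax update: rescale the running sums if a new maximum appears.
        const float kqmax_new = fmaxf(kqmax, kq);
        const float diff      = expf(kqmax - kqmax_new);
        const float p         = expf(kq    - kqmax_new);

        kqsum = kqsum*diff + p;
        vo    = vo   *diff + p*__half2float(V[k*D + tid]);

        kqmax = kqmax_new;
    }

    dst[tid] = vo / kqsum; // final normalization
}
```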

This PR also rearranges the order of some definitions in ggml-cuda/common.cuh, because on master NO_DEVICE_CODE is broken on AMD due to the definitions being in the wrong order. I also implemented some warp reduce functions that previously did not work on AMD.
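
The warp reductions in question are the usual shuffle-based loops. A generic sketch follows (not necessarily the exact ggml-cuda/common.cuh implementation); the AMD/HIP port has to account for the 64-wide wavefront and different shuffle intrinsics, which is exactly the kind of detail that has to be handled for these reductions to work on both vendors.

```cuda
#define WARP_SIZE 32 // NVIDIA warp width; AMD CDNA/GCN wavefronts are 64 lanes

// Butterfly (XOR) reduction: after log2(WARP_SIZE) steps every lane holds the sum.
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, WARP_SIZE);
    }
    return x;
}

// Same pattern, but keeping the maximum instead of the sum.
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, WARP_SIZE));
    }
    return x;
}
```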

@jdecourval when you worked on #6773, did you check that the code actually produces correct results? I encountered an issue where, due to missing implementations in ggml-cuda/common.cuh, the compiler would optimize out part of the kernel, which resulted in significantly faster but useless code.

@sorasoras

This PR:

.\test-backend-ops.exe -o FLASH_ATTN_EXT -b ROCm0 perf
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon RX 7900 XTX, compute capability 11.0, VMM: no
Testing 2 backends

Backend 1/2 (CPU)
  Skipping
Backend 2/2 (ROCm0)
  Backend name: ROCm0
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1):                       8098 runs -    21.60 us/run -     4144 kB/run -  182.94 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2):                       8066 runs -    33.36 us/run -     4160 kB/run -  118.92 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=4):                       8005 runs -    75.89 us/run -     4192 kB/run -   52.68 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=8):                       7885 runs -   141.53 us/run -     4256 kB/run -   28.68 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1):                      4057 runs -    52.41 us/run -     8272 kB/run -  150.51 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2):                      4049 runs -    78.71 us/run -     8288 kB/run -  100.42 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=4):                      4033 runs -   154.86 us/run -     8320 kB/run -   51.24 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=8):                      4003 runs -   286.94 us/run -     8384 kB/run -   27.87 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=512,nb=1):                  GGML_ASSERT: W:/git/test/Johannes/llama.cpp/ggml-cuda/fattn.cu:827: Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128."

Main branch, for reference:

 .\test-backend-ops.exe -o FLASH_ATTN_EXT -b ROCm0 perf
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon RX 7900 XTX, compute capability 11.0, VMM: no
Testing 2 backends

Backend 1/2 (CPU)
  Skipping
Backend 2/2 (ROCm0)
  Backend name: ROCm0
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1):                       8098 runs -     6.48 us/run -     4144 kB/run -  609.86 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2):                       8066 runs -     2.73 us/run -     4160 kB/run - 1452.77 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=4):                       8005 runs -     2.55 us/run -     4192 kB/run - 1568.59 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=8):                       7885 runs -     2.84 us/run -     4256 kB/run - 1430.41 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1):                      4057 runs -     5.11 us/run -     8272 kB/run - 1542.40 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2):                      4049 runs -     2.78 us/run -     8288 kB/run - 2846.27 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=4):                      4033 runs -     2.97 us/run -     8320 kB/run - 2673.81 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=8):                      4003 runs -     2.90 us/run -     8384 kB/run - 2755.61 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=512,nb=1):                       6488 runs -     3.03 us/run -     5172 kB/run - 1627.99 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=512,nb=2):                       6463 runs -     2.94 us/run -     5192 kB/run - 1683.13 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=512,nb=4):                       6414 runs -     2.77 us/run -     5232 kB/run - 1800.17 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=512,nb=8):                       6317 runs -     2.59 us/run -     5312 kB/run - 1954.16 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=1024,nb=1):                      3251 runs -     3.14 us/run -    10324 kB/run - 3138.70 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=1024,nb=2):                      3244 runs -     3.31 us/run -    10344 kB/run - 2984.09 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=1024,nb=4):                      3232 runs -     3.35 us/run -    10384 kB/run - 2959.17 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=1024,nb=8):                      3207 runs -     3.24 us/run -    10464 kB/run - 3081.70 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=1):                      4065 runs -     5.39 us/run -     8256 kB/run - 1460.06 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=2):                      4049 runs -     3.91 us/run -     8288 kB/run - 2021.32 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=4):                      4018 runs -     3.66 us/run -     8352 kB/run - 2176.98 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=8):                      3957 runs -     3.31 us/run -     8480 kB/run - 2445.43 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=1):                     2037 runs -     4.96 us/run -    16480 kB/run - 3169.76 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=2):                     2033 runs -     5.08 us/run -    16512 kB/run - 3100.01 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=4):                     2025 runs -     4.61 us/run -    16576 kB/run - 3426.25 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=8):                     2009 runs -     4.94 us/run -    16704 kB/run - 3225.86 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=512,nb=1):                      2037 runs -     6.62 us/run -    16480 kB/run - 2372.86 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=512,nb=2):                      2029 runs -     4.72 us/run -    16544 kB/run - 3340.92 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=512,nb=4):                      2013 runs -     5.77 us/run -    16672 kB/run - 2757.24 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=512,nb=8):                      1983 runs -     5.78 us/run -    16928 kB/run - 2795.18 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=1024,nb=1):                     1021 runs -     9.35 us/run -    32896 kB/run - 3354.72 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=1024,nb=2):                     1019 runs -     6.51 us/run -    32960 kB/run - 4826.75 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=1024,nb=4):                     1015 runs -     4.92 us/run -    33088 kB/run - 6417.25 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=1024,nb=8):                     1007 runs -     5.24 us/run -    33344 kB/run - 6069.35 GB/s
  Backend ROCm0: OK

@JohannesGaessler (Collaborator, Author)

There is no ROCm FlashAttention implementation on master. The only reason the test is passing is that NO_DEVICE_CODE is broken so the kernel does nothing instead of throwing an error.
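
To illustrate the failure mode (with hypothetical macro and guard names, not the actual ggml-cuda definitions): the fallback branch of an architecture guard is supposed to trap at run time, but if the macro ends up empty because of the definition order, the kernel silently compiles to a no-op and a simple pass/fail test cannot tell the difference.

```cuda
// Purely illustrative; macro and guard names are made up for this sketch.
#ifdef BROKEN_DEFINITION_ORDER
#define NO_DEVICE_CODE_EXAMPLE              // broken: expands to nothing, kernel becomes a no-op
#else
#define NO_DEVICE_CODE_EXAMPLE __trap()     // intended: abort device execution
#endif

__global__ void fattn_stub_example() {
#if defined(FP16_AVAILABLE_EXAMPLE)         // hypothetical feature check
    // ... the real FlashAttention body would go here ...
#else
    NO_DEVICE_CODE_EXAMPLE;
#endif
}
```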

github-actions bot (Contributor) commented May 3, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 557 iterations 🚀

Details (for performance-related PRs only):
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8410.82ms p(95)=20154.98ms fails=, finish reason: stop=491 truncated=66
  • Prompt processing (pp): avg=99.68tk/s p(95)=419.18tk/s
  • Token generation (tg): avg=33.73tk/s p(95)=48.84tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=cuda-fa-no-tc-5 commit=fece1fe48253368cbf06653c2f98dac3d843ddf3

Charts from the benchmark run (time-series plots): prompt_tokens_seconds, predicted_tokens_seconds, kv_cache_usage_ratio, requests_processing.

@8XXD8 commented May 4, 2024

I get an error if I try main with -fa.
The GPU is a Radeon Pro VII (gfx906), running ROCm 6.1.

ggml-cuda/fattn.cu:219: ERROR: HIP kernel flash_attn_vec_ext_f16 has no device code compatible with HIP arch 1300.
:0:rocdevice.cpp            :2879: 4907811378 us: [pid:98972 tid:0x7f7aad2006c0] Callback: Queue 0x7f766fd00000 aborting with error : HSA_STATUS_ERROR_EXCEPTION: An HSAIL operation resulted in a hardware exception. code: 0x1016

@JohannesGaessler (Collaborator, Author)

The issue was that the logic for determining when to compile the kernel was incorrect. Does it work now?

@8XXD8 commented May 4, 2024

The issue was that the logic for determining when to compile the kernel was incorrect. Does it work now?

Now it runs, but the output is broken when FlashAttention is enabled and the --batch-size is 8 or smaller.

--batch-size 16

<|begin_of_text|><|start_header_id|>user<|end_header_id|>Tell me a joke<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Here's one:

Why couldn't the bicycle stand up by itself?

(wait for it...)

Because it was two-tired!

Hope that made you laugh!<|eot_id|> [end of text]

--batch-size 8

<|begin_of_text|><|start_header_id|>user<|end_header_id|>Tell me a joke<|eot_id|><|start_header_id|>assistant<|end_header_id|> 8 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2

@JohannesGaessler (Collaborator, Author)

Thank you for the bug report. The issue was that I wrote back the data in the wrong order for batch sizes 2-8, and I didn't notice because those batch sizes were not being used when I tested for correctness. It should now be fixed.

@sorasoras commented May 5, 2024

Can confirm this PR is quite a bit slower for my workload: about 1400 t/s vs. 1600 t/s on the main branch for a 7900 XTX (RDNA3) with an IQ4_XS 14B model.

I have another question: would there be a workaround for GPUs like the P40, i.e. an FP32 kernel, or is it just not worth it?

@vonjackustc

Can confirm this PR is quite a bit slower for my workload: about 1400 t/s vs. 1600 t/s on the main branch for a 7900 XTX (RDNA3) with an IQ4_XS 14B model. I have another question: would there be a workaround for GPUs like the P40, i.e. an FP32 kernel, or is it just not worth it?

This PR works perfectly on my P40.

@sorasoras

This PR works perfectly on my P40.

But it's going to be slow, right? The P40 is pretty slow at FP16, as it only has fast INT8/FP32.

@JohannesGaessler (Collaborator, Author)

I have another question. Would there be a workaround for GPU like P40 i.e a kernel for FP32 or it's just not worth it.

The current PR will produce correct results but for usable performance P40s will need a dedicated FP32 kernel which I will also add.

@jdecourval (Contributor) commented May 5, 2024

@jdecourval when you worked on #6773, did you check that the code actually produces correct results? I encountered an issue where, due to missing implementations in ggml-cuda/common.cuh, the compiler would optimize out part of the kernel, which resulted in significantly faster but useless code.

At that time, I did not try to quantify the difference in quality because comparing the two versions side by side seemed to produce very similar results. If what you said is also true for me, it should translate into a higher perplexity, right? Here is what I get:

My PR:
No FA: 6.8455 +/- 0.04312
FA: 6.8443 +/- 0.04311
Your PR:
No FA: 6.8455 +/- 0.04312
FA: 6.8437 +/- 0.04310

Same result. This is with a random Q6_K model over wikitext.

About performance, here is what I get.

My PR:

| model | size | params | backend | ngl | fa | test | t/s |
| --- | ---: | ---: | --- | --: | -: | --- | ---: |
| command-r 35B IQ3_XS - 3.3 bpw | 15.65 GiB | 37.08 B | ROCm | 99 | 1 | pp 4096 | 615.30 ± 4.76 |
| command-r 35B IQ3_XS - 3.3 bpw | 15.65 GiB | 37.08 B | ROCm | 99 | 1 | tg 128 | 26.92 ± 0.03 |
| command-r 35B IQ3_XS - 3.3 bpw | 15.65 GiB | 37.08 B | ROCm | 99 | 0 | pp 4096 | 601.02 ± 0.30 |
| command-r 35B IQ3_XS - 3.3 bpw | 15.65 GiB | 37.08 B | ROCm | 99 | 0 | tg 128 | 26.73 ± 0.08 |
| llama 8B Q5_K - Medium | 4.78 GiB | 7.24 B | ROCm | 99 | 1 | pp 4096 | 2510.93 ± 2.23 |
| llama 8B Q5_K - Medium | 4.78 GiB | 7.24 B | ROCm | 99 | 1 | tg 128 | 86.70 ± 2.86 |
| llama 8B Q5_K - Medium | 4.78 GiB | 7.24 B | ROCm | 99 | 0 | pp 4096 | 2444.75 ± 1.94 |
| llama 8B Q5_K - Medium | 4.78 GiB | 7.24 B | ROCm | 99 | 0 | tg 128 | 86.43 ± 0.07 |
| llama 8x7B IQ3_XXS - 3.0625 bpw | 33.27 GiB | 91.80 B | ROCm | 99 | 1 | pp 4096 | 1037.43 ± 7.21 |
| llama 8x7B IQ3_XXS - 3.0625 bpw | 33.27 GiB | 91.80 B | ROCm | 99 | 1 | tg 128 | 54.56 ± 0.03 |
| llama 8x7B IQ3_XXS - 3.0625 bpw | 33.27 GiB | 91.80 B | ROCm | 99 | 0 | pp 4096 | 1008.34 ± 6.56 |
| llama 8x7B IQ3_XXS - 3.0625 bpw | 33.27 GiB | 91.80 B | ROCm | 99 | 0 | tg 128 | 52.83 ± 0.25 |
| llama 8B Q6_K | 5.53 GiB | 7.24 B | ROCm | 99 | 1 | pp 4096 | 2546.25 ± 1.66 |
| llama 8B Q6_K | 5.53 GiB | 7.24 B | ROCm | 99 | 1 | tg 128 | 86.34 ± 0.12 |
| llama 8B Q6_K | 5.53 GiB | 7.24 B | ROCm | 99 | 0 | pp 4096 | 2478.36 ± 3.83 |
| llama 8B Q6_K | 5.53 GiB | 7.24 B | ROCm | 99 | 0 | tg 128 | 84.46 ± 0.02 |
| llama ?B Q4_K - Small | 17.59 GiB | 33.34 B | ROCm | 99 | 1 | pp 4096 | 640.28 ± 0.79 |
| llama ?B Q4_K - Small | 17.59 GiB | 33.34 B | ROCm | 99 | 1 | tg 128 | 27.02 ± 0.01 |
| llama ?B Q4_K - Small | 17.59 GiB | 33.34 B | ROCm | 99 | 0 | pp 4096 | 609.36 ± 0.18 |
| llama ?B Q4_K - Small | 17.59 GiB | 33.34 B | ROCm | 99 | 0 | tg 128 | 26.46 ± 0.00 |

./bin/batched-bench qwen1.5-32b-chat-imat-Q4_K_M.gguf 10000 2048 512 1 1 99 8192 256 1

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| ---: | ---: | --: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| 8192 | 256 | 1 | 8448 | 13.103 | 625.19 | 14.516 | 17.64 | 27.619 | 305.88 |

./bin/test-backend-ops -o FLASH_ATTN_EXT perf

Backend name: ROCm0
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1):                       8098 runs -    23.44 us/run -     4144 kB/run -  168.61 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2):                       8066 runs -    14.61 us/run -     4160 kB/run -  271.63 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=4):                       8005 runs -    14.59 us/run -     4192 kB/run -  274.01 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=8):                       7885 runs -    14.80 us/run -     4256 kB/run -  274.31 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1):                      4057 runs -    67.69 us/run -     8272 kB/run -  116.54 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2):                      4049 runs -    23.09 us/run -     8288 kB/run -  342.27 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=4):                      4033 runs -    22.58 us/run -     8320 kB/run -  351.35 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=8):                      4003 runs -    23.53 us/run -     8384 kB/run -  339.73 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=512,nb=1):                       6488 runs -    19.41 us/run -     5172 kB/run -  254.16 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=512,nb=2):                       6463 runs -    19.35 us/run -     5192 kB/run -  255.88 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=512,nb=4):                       6414 runs -    19.26 us/run -     5232 kB/run -  259.03 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=512,nb=8):                       6317 runs -    20.24 us/run -     5312 kB/run -  250.28 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=1024,nb=1):                      3251 runs -    27.26 us/run -    10324 kB/run -  361.18 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=1024,nb=2):                      3244 runs -    27.47 us/run -    10344 kB/run -  359.05 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=1024,nb=4):                      3232 runs -    27.41 us/run -    10384 kB/run -  361.26 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=1024,nb=8):                      3207 runs -    26.86 us/run -    10464 kB/run -  371.56 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=1):                      4065 runs -    23.89 us/run -     8256 kB/run -  329.52 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=2):                      4049 runs -    21.80 us/run -     8288 kB/run -  362.56 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=4):                      4018 runs -    21.87 us/run -     8352 kB/run -  364.20 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=8):                      3957 runs -    22.68 us/run -     8480 kB/run -  356.55 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=1):                     2037 runs -    40.49 us/run -    16480 kB/run -  388.15 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=2):                     2033 runs -    35.27 us/run -    16512 kB/run -  446.51 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=4):                     2025 runs -    35.21 us/run -    16576 kB/run -  448.92 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=8):                     2009 runs -    36.00 us/run -    16704 kB/run -  442.52 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=512,nb=1):                      2037 runs -    28.47 us/run -    16480 kB/run -  552.06 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=512,nb=2):                      2029 runs -    54.75 us/run -    16544 kB/run -  288.20 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=512,nb=4):                      2013 runs -    56.29 us/run -    16672 kB/run -  282.45 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=512,nb=8):                      1983 runs -    60.91 us/run -    16928 kB/run -  265.03 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=1024,nb=1):                     1021 runs -    34.46 us/run -    32896 kB/run -  910.51 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=1024,nb=2):                     1019 runs -    92.64 us/run -    32960 kB/run -  339.31 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=1024,nb=4):                     1015 runs -    92.20 us/run -    33088 kB/run -  342.25 GB/s
  FLASH_ATTN_EXT(hs=256,nh=32,kv=1024,nb=8):                     1007 runs -    96.85 us/run -    33344 kB/run -  328.34 GB/s

This PR:

| model | size | params | backend | ngl | fa | test | t/s |
| --- | ---: | ---: | --- | --: | -: | --- | ---: |
| command-r 35B IQ3_XS - 3.3 bpw | 15.65 GiB | 37.08 B | ROCm | 99 | 1 | pp 4096 | 419.25 ± 0.48 |
| command-r 35B IQ3_XS - 3.3 bpw | 15.65 GiB | 37.08 B | ROCm | 99 | 1 | tg 128 | 26.84 ± 0.00 |
| command-r 35B IQ3_XS - 3.3 bpw | 15.65 GiB | 37.08 B | ROCm | 99 | 0 | pp 4096 | 602.15 ± 0.30 |
| command-r 35B IQ3_XS - 3.3 bpw | 15.65 GiB | 37.08 B | ROCm | 99 | 0 | tg 128 | 26.76 ± 0.01 |
| llama 8B Q5_K - Medium | 4.78 GiB | 7.24 B | ROCm | 99 | 1 | pp 4096 | 1391.26 ± 0.92 |
| llama 8B Q5_K - Medium | 4.78 GiB | 7.24 B | ROCm | 99 | 1 | tg 128 | 85.85 ± 0.02 |
| llama 8B Q5_K - Medium | 4.78 GiB | 7.24 B | ROCm | 99 | 0 | pp 4096 | 2447.13 ± 1.94 |
| llama 8B Q5_K - Medium | 4.78 GiB | 7.24 B | ROCm | 99 | 0 | tg 128 | 86.34 ± 0.08 |
| llama 8x7B IQ3_XXS - 3.0625 bpw | 33.27 GiB | 91.80 B | ROCm | 99 | 1 | pp 4096 | 786.37 ± 1.16 |
| llama 8x7B IQ3_XXS - 3.0625 bpw | 33.27 GiB | 91.80 B | ROCm | 99 | 1 | tg 128 | 53.55 ± 0.01 |
| llama 8x7B IQ3_XXS - 3.0625 bpw | 33.27 GiB | 91.80 B | ROCm | 99 | 0 | pp 4096 | 1035.53 ± 2.19 |
| llama 8x7B IQ3_XXS - 3.0625 bpw | 33.27 GiB | 91.80 B | ROCm | 99 | 0 | tg 128 | 53.27 ± 0.02 |
| llama 8B Q6_K | 5.53 GiB | 7.24 B | ROCm | 99 | 1 | pp 4096 | 1403.10 ± 1.77 |
| llama 8B Q6_K | 5.53 GiB | 7.24 B | ROCm | 99 | 1 | tg 128 | 84.30 ± 0.03 |
| llama 8B Q6_K | 5.53 GiB | 7.24 B | ROCm | 99 | 0 | pp 4096 | 2484.03 ± 1.21 |
| llama 8B Q6_K | 5.53 GiB | 7.24 B | ROCm | 99 | 0 | tg 128 | 84.62 ± 0.02 |
| llama ?B Q4_K - Small | 17.59 GiB | 33.34 B | ROCm | 99 | 1 | pp 4096 | 374.89 ± 0.36 |
| llama ?B Q4_K - Small | 17.59 GiB | 33.34 B | ROCm | 99 | 1 | tg 128 | 26.58 ± 0.01 |
| llama ?B Q4_K - Small | 17.59 GiB | 33.34 B | ROCm | 99 | 0 | pp 4096 | 609.28 ± 0.24 |
| llama ?B Q4_K - Small | 17.59 GiB | 33.34 B | ROCm | 99 | 0 | tg 128 | 26.55 ± 0.02 |

./bin/batched-bench qwen1.5-32b-chat-imat-Q4_K_M.gguf 10000 2048 512 1 1 99 8192 256 1

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| ---: | ---: | --: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| 8192 | 256 | 1 | 8448 | 25.681 | 318.98 | 15.392 | 16.63 | 41.074 | 205.68 |

./bin/test-backend-ops -o FLASH_ATTN_EXT perf

Backend name: ROCm0
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1):                       8098 runs -    24.90 us/run -     4144 kB/run -  158.72 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2):                       8066 runs -    36.88 us/run -     4160 kB/run -  107.56 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=4):                       8005 runs -    56.78 us/run -     4192 kB/run -   70.41 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=8):                       7885 runs -   105.29 us/run -     4256 kB/run -   38.55 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1):                      4057 runs -    64.10 us/run -     8272 kB/run -  123.08 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2):                      4049 runs -    85.10 us/run -     8288 kB/run -   92.88 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=4):                      4033 runs -   116.12 us/run -     8320 kB/run -   68.33 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=8):                      4003 runs -   211.50 us/run -     8384 kB/run -   37.80 GB/s
  FLASH_ATTN_EXT(hs=80,nh=32,kv=512,nb=1):                  GGML_ASSERT: /home/jerome/Prog/online/llama.cpp/ggml-cuda/fattn.cu:828: Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128."

@sorasoras

My PR: No FA: 6.8455 +/- 0.04312, FA: 6.8443 +/- 0.04311
Your PR: No FA: 6.8455 +/- 0.04312, FA: 6.8437 +/- 0.04310

Same result. This is with a random Q6_K model over wikitext.

So with FA the PPL is consistently slightly better. Does that make sense?

@JohannesGaessler (Collaborator, Author)

@jdecourval thank you for the data. I'll still try to write a kernel without any tensor cores that is more optimized for large batch sizes first but given this data I think that rocWMMA is a viable option.

@sorasoras I don't think that enabling/disabling FlashAttention should significantly affect results since the differences should just be due to floating point rounding error.

@sorasoras

.\llama-bench.exe -m W:\model\sakura0.9_13B_Qwen1.5_Q5KS_1.3.gguf -fa 0,1 -ngl 99 -p 4096
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   no
ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon RX 7900 XTX, compute capability 11.0, VMM: no
| model                          |       size |     params | backend    | ngl |         fa | test       |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------: | ---------- | ---------------: |
| qwen2 13B Q5_K - Small         |   9.33 GiB |    14.17 B | ROCm       |  99 |          0 | pp 4096    |  1431.86 ± 29.41 |
| qwen2 13B Q5_K - Small         |   9.33 GiB |    14.17 B | ROCm       |  99 |          0 | tg 128     |     54.97 ± 1.62 |
| qwen2 13B Q5_K - Small         |   9.33 GiB |    14.17 B | ROCm       |  99 |          1 | pp 4096    |    879.78 ± 8.36 |
| qwen2 13B Q5_K - Small         |   9.33 GiB |    14.17 B | ROCm       |  99 |          1 | tg 128     |     54.53 ± 0.42 |

build: 57bde8c2 (2786)

RDNA3 is going to slow down. Are there any plans to address this?

Backend 2/2 (ROCm0)
  Backend name: ROCm0
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=1):                       8098 runs -    21.32 us/run -     4144 kB/run -  185.40 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=2):                       8066 runs -    31.09 us/run -     4160 kB/run -  127.62 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=4):                       8005 runs -    49.55 us/run -     4192 kB/run -   80.69 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=512,nb=8):                       7885 runs -    89.97 us/run -     4256 kB/run -   45.11 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=1):                      4057 runs -    52.51 us/run -     8272 kB/run -  150.24 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=2):                      4049 runs -    70.54 us/run -     8288 kB/run -  112.05 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=4):                      4033 runs -   104.56 us/run -     8320 kB/run -   75.89 GB/s
  FLASH_ATTN_EXT(hs=64,nh=32,kv=1024,nb=8):                      4003 runs -   183.65 us/run -     8384 kB/run -   43.54 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=1):                      4065 runs -    23.07 us/run -     8256 kB/run -  341.24 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=2):                      4049 runs -    32.23 us/run -     8288 kB/run -  245.22 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=4):                      4018 runs -    42.93 us/run -     8352 kB/run -  185.54 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=512,nb=8):                      3957 runs -    64.24 us/run -     8480 kB/run -  125.89 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=1):                     2037 runs -    43.27 us/run -    16480 kB/run -  363.22 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=2):                     2033 runs -    62.55 us/run -    16512 kB/run -  251.76 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=4):                     2025 runs -    77.58 us/run -    16576 kB/run -  203.76 GB/s
  FLASH_ATTN_EXT(hs=128,nh=32,kv=1024,nb=8):                     2009 runs -   122.27 us/run -    16704 kB/run -  130.29 GB/s
  Backend ROCm0: OK

2/2 backends passed
OK

@JohannesGaessler (Collaborator, Author)

RDNA3 is going to slow down. Are there any plans to address this?

Just don't use FlashAttention until it's faster?

@slaren (Collaborator) commented May 6, 2024

Is there something in this PR that could cause a slowdown in TG?

| GPU | Model | Model Size [GiB] | Test | t/s master | t/s cuda-fa-no-tc-5 | Speedup |
| --- | --- | ---: | --- | ---: | ---: | ---: |
| RTX 3090 Ti | llama 7B F16 | 12.55 | pp512 | 6055.53 | 6031.66 | 1.00 |
| RTX 3090 Ti | llama 7B F16 | 12.55 | tg128 | 57.08 | 56.52 | 0.99 |
| RTX 3090 Ti | llama 7B Q4_0 | 3.56 | pp512 | 4955.19 | 4942.04 | 1.00 |
| RTX 3090 Ti | llama 7B Q4_0 | 3.56 | tg128 | 138.63 | 136.01 | 0.98 |

There are also some warnings:

ggml-cuda/fattn.cu(215): warning #128-D: loop is not reachable
      for (int j = 0; j < ncols; ++j) {
      ^
          detected during:
            instantiation of "void flash_attn_vec_ext_f16<D,ncols,parallel_blocks>(const char *, const char *, const char *, const char *, float *, float2 *, float, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int) [with D=64, ncols=8, parallel_blocks=1]" at line 714
            instantiation of "void launch_fattn_vec_f16<D,cols_per_block,parallel_blocks>(const ggml_tensor *, const ggml_tensor *, const ggml_tensor *, ggml_tensor *, const ggml_tensor *, ggml_cuda_pool &, cudaStream_t) [with D=64, cols_per_block=8, parallel_blocks=1]" at line 902

Remark: The warnings can be suppressed with "-diag-suppress <warning-number>"

ggml-cuda/fattn.cu(215): warning #128-D: loop is not reachable
      for (int j = 0; j < ncols; ++j) {
      ^
          detected during:
            instantiation of "void flash_attn_vec_ext_f16<D,ncols,parallel_blocks>(const char *, const char *, const char *, const char *, float *, float2 *, float, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int) [with D=128, ncols=8, parallel_blocks=1]" at line 714
            instantiation of "void launch_fattn_vec_f16<D,cols_per_block,parallel_blocks>(const ggml_tensor *, const ggml_tensor *, const ggml_tensor *, ggml_tensor *, const ggml_tensor *, ggml_cuda_pool &, cudaStream_t) [with D=128, cols_per_block=8, parallel_blocks=1]" at line 905

@JohannesGaessler (Collaborator, Author)

Thank you for pointing this out. I did not think to test this, but the new kernel is, for whatever reason, massively slower than the one on master for a batch size of 1. According to NSight Compute, the runtime of the kernel that I used for testing increased from 70 µs to 92 µs. I really don't understand why this is happening: the compiler should be able to unroll all loops over ncols, at which point the kernels should be identical. I checked the kernels multiple times and all I could find are minor differences that only affect the performance in small ways.

Should we just keep both kernel versions?

@slaren (Collaborator) commented May 8, 2024

Personally I am not interested in the hardware that benefits from this change, and I don't expect to make any changes to this code, so if you want to maintain another copy of the flash attn kernel that's entirely up to you. It would be nice to keep the same performance with tensor cores as on master, but IMO the difference in overall performance is small enough that it can be ignored.

More results with the latest commit:

| GPU | Model | Model Size [GiB] | Num. of Parameters | Test | t/s master | t/s cuda-fa-no-tc-5 | Speedup |
| --- | --- | ---: | ---: | --- | ---: | ---: | ---: |
| RTX 3090 Ti | llama 13B Q4_0 | 6.86 | 13015864320 | pp512 | 2942.95 | 2928.26 | 1.00 |
| RTX 3090 Ti | llama 13B Q4_0 | 6.86 | 13015864320 | pp1024 | 2926.21 | 2907.86 | 0.99 |
| RTX 3090 Ti | llama 13B Q4_0 | 6.86 | 13015864320 | pp2048 | 2861.99 | 2842.56 | 0.99 |
| RTX 3090 Ti | llama 13B Q4_0 | 6.86 | 13015864320 | tg128 | 85.20 | 84.30 | 0.99 |
| RTX 3090 Ti | llama 30B Q4_0 | 17.09 | 32528943616 | pp512 | 1253.18 | 1246.15 | 0.99 |
| RTX 3090 Ti | llama 30B Q4_0 | 17.09 | 32528943616 | pp1024 | 1240.20 | 1235.61 | 1.00 |
| RTX 3090 Ti | llama 30B Q4_0 | 17.09 | 32528943616 | pp2048 | 1217.37 | 1212.61 | 1.00 |
| RTX 3090 Ti | llama 30B Q4_0 | 17.09 | 32528943616 | tg128 | 39.34 | 39.09 | 0.99 |
| RTX 3090 Ti | llama 7B Q4_0 | 3.56 | 6738415616 | pp512 | 4960.42 | 4954.14 | 1.00 |
| RTX 3090 Ti | llama 7B Q4_0 | 3.56 | 6738415616 | pp1024 | 4921.20 | 4894.45 | 0.99 |
| RTX 3090 Ti | llama 7B Q4_0 | 3.56 | 6738415616 | pp2048 | 4787.11 | 4776.69 | 1.00 |
| RTX 3090 Ti | llama 7B Q4_0 | 3.56 | 6738415616 | tg128 | 138.61 | 135.68 | 0.98 |

@ggerganov (Owner)

Thank you for pointing this out, I did not think to test this but the new kernel is for whatever reason massively slower than the one on master for a batch size of 1.

If you start reverting the for (int j = ...) loops one by one, is there a specific one that restores the performance back to normal?

@JohannesGaessler (Collaborator, Author)

Intuitively I would think the issue has to do with the data types (array vs. regular variables) rather than the loops.

@jdecourval (Contributor) commented May 8, 2024

FWIW, I took a quick look at the performance on my 7900 XTX and was able to bring it back to the level of your previous kernel by removing most of the AMD checks, which meant re-enabling rocWMMA. Compare with the results above:

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
| ---: | ---: | --: | ---: | ---: | ---: | ---: | ---: | ---: | ---: |
| 8192 | 256 | 1 | 8448 | 13.125 | 624.14 | 14.230 | 17.99 | 27.355 | 308.82 |

| model | size | params | backend | ngl | fa | test | t/s |
| --- | ---: | ---: | --- | --: | -: | --- | ---: |
| llama 8B Q5_K - Medium | 4.78 GiB | 7.24 B | ROCm | 99 | 1 | pp 4096 | 2532.53 ± 2.73 |
| llama 8B Q5_K - Medium | 4.78 GiB | 7.24 B | ROCm | 99 | 1 | tg 128 | 90.70 ± 0.03 |
| llama 8B Q5_K - Medium | 4.78 GiB | 7.24 B | ROCm | 99 | 0 | pp 4096 | 2472.87 ± 1.98 |
| llama 8B Q5_K - Medium | 4.78 GiB | 7.24 B | ROCm | 99 | 0 | tg 128 | 89.04 ± 0.50 |
| llama 8x7B IQ3_XXS - 3.0625 bpw | 33.27 GiB | 91.80 B | ROCm | 99 | 1 | pp 4096 | 1055.12 ± 1.81 |
| llama 8x7B IQ3_XXS - 3.0625 bpw | 33.27 GiB | 91.80 B | ROCm | 99 | 1 | tg 128 | 54.58 ± 0.02 |
| llama 8x7B IQ3_XXS - 3.0625 bpw | 33.27 GiB | 91.80 B | ROCm | 99 | 0 | pp 4096 | 1037.52 ± 1.89 |
| llama 8x7B IQ3_XXS - 3.0625 bpw | 33.27 GiB | 91.80 B | ROCm | 99 | 0 | tg 128 | 53.86 ± 0.02 |
| llama 8B Q6_K | 5.53 GiB | 7.24 B | ROCm | 99 | 1 | pp 4096 | 2561.00 ± 2.06 |
| llama 8B Q6_K | 5.53 GiB | 7.24 B | ROCm | 99 | 1 | tg 128 | 86.16 ± 0.02 |
| llama 8B Q6_K | 5.53 GiB | 7.24 B | ROCm | 99 | 0 | pp 4096 | 2503.39 ± 1.86 |
| llama 8B Q6_K | 5.53 GiB | 7.24 B | ROCm | 99 | 0 | tg 128 | 84.87 ± 0.01 |
| llama ?B Q4_K - Small | 17.59 GiB | 33.34 B | ROCm | 99 | 1 | pp 4096 | 641.68 ± 0.48 |
| llama ?B Q4_K - Small | 17.59 GiB | 33.34 B | ROCm | 99 | 1 | tg 128 | 28.95 ± 0.01 |
| llama ?B Q4_K - Small | 17.59 GiB | 33.34 B | ROCm | 99 | 0 | pp 4096 | 612.67 ± 0.47 |
| llama ?B Q4_K - Small | 17.59 GiB | 33.34 B | ROCm | 99 | 0 | tg 128 | 28.65 ± 0.00 |

It looks like this (pile of hacks): jdecourval@b82649d

@JohannesGaessler (Collaborator, Author)

I can confirm that the performance regression is caused by half var vs. half var[] and/or half2 var vs. half2 var[]. I'll investigate more but I think the way to move forward is to define both types and to choose which one to use based on ncols. The compiler should then simply optimize out the ones that aren't actually being used.

In any case, I think this is worth asking my contact at NVIDIA about. This could be a legitimate compiler bug.
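
A minimal sketch of the workaround described above (illustrative names, not the actual ggml-cuda code): keep both a plain variable and an array for the new KQ maximum and select between them based on the compile-time ncols, so the ncols == 1 instantiation can use the same registers as the original batch-size-1 kernel while the unused variant is optimized out.

```cuda
#include <cuda_fp16.h>

// Hypothetical helper: update the per-column running maxima in kqmax with the
// new KQ values in kq. For ncols == 1 the intermediate lives in a plain half
// variable; for larger ncols it lives in an array. The branch not taken is dead
// code for a given template instantiation and is removed by the compiler.
template <int ncols>
__device__ __forceinline__ void update_kqmax_sketch(half * kqmax, const half * kq) {
    half kqmax_new;              // used only when ncols == 1
    half kqmax_new_arr[ncols];   // used only when ncols >  1

    if (ncols == 1) {
        kqmax_new = __hgt(kq[0], kqmax[0]) ? kq[0] : kqmax[0];
        kqmax[0]  = kqmax_new;
    } else {
#pragma unroll
        for (int j = 0; j < ncols; ++j) {
            kqmax_new_arr[j] = __hgt(kq[j], kqmax[j]) ? kq[j] : kqmax[j];
        }
#pragma unroll
        for (int j = 0; j < ncols; ++j) {
            kqmax[j] = kqmax_new_arr[j];
        }
    }
}
```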

@mofosyne added the "enhancement" (New feature or request) label on May 9, 2024
@JohannesGaessler (Collaborator, Author)

From what I can tell the issue is caused specifically by kqmax_new, and the performance regression can be fixed with minimal changes. @slaren please confirm whether or not this also fixes the performance regression on your machine.

@mofosyne added the "review complexity : high" (Generally require in-depth knowledge of LLMs or GPUs) label on May 9, 2024
@slaren (Collaborator) commented May 9, 2024

I think so.

| GPU | Model | Model Size [GiB] | Num. of Parameters | Test | t/s master | t/s cuda-fa-no-tc-5 | Speedup |
| --- | --- | ---: | ---: | --- | ---: | ---: | ---: |
| RTX 3090 Ti | llama 13B Q4_0 | 6.86 | 13015864320 | tg128 | 95.56 | 95.89 | 1.00 |
| RTX 3090 Ti | llama 30B Q4_0 | 17.09 | 32528943616 | tg128 | 42.69 | 42.69 | 1.00 |
| RTX 3090 Ti | llama 7B Q4_0 | 3.56 | 6738415616 | tg128 | 164.25 | 165.74 | 1.01 |

@JohannesGaessler merged commit a743d76 into ggerganov:master on May 9, 2024
57 of 62 checks passed
Labels: enhancement (New feature or request), review complexity : high (Generally require in-depth knowledge of LLMs or GPUs)

8 participants