
Fix flash attention for ROCm #7011

Draft
jdecourval wants to merge 2 commits into master from fixflashattn2
Conversation

@jdecourval (Contributor) commented Apr 30, 2024

llama-bench

model size params backend ngl fa test t/s
command-r 35B IQ3_XS - 3.3 bpw 15.65 GiB 37.08 B ROCm 99 1 pp 4096 605.22 ± 0.75
command-r 35B IQ3_XS - 3.3 bpw 15.65 GiB 37.08 B ROCm 99 1 tg 128 26.82 ± 0.01
command-r 35B IQ3_XS - 3.3 bpw 15.65 GiB 37.08 B ROCm 99 0 pp 4096 604.84 ± 0.23
command-r 35B IQ3_XS - 3.3 bpw 15.65 GiB 37.08 B ROCm 99 0 tg 128 26.80 ± 0.01
llama 8B Q5_K - Medium 4.78 GiB 7.24 B ROCm 99 1 pp 4096 2448.01 ± 2.25
llama 8B Q5_K - Medium 4.78 GiB 7.24 B ROCm 99 1 tg 128 86.25 ± 0.03
llama 8B Q5_K - Medium 4.78 GiB 7.24 B ROCm 99 0 pp 4096 2446.30 ± 1.53
llama 8B Q5_K - Medium 4.78 GiB 7.24 B ROCm 99 0 tg 128 86.31 ± 0.01
llama 8x7B IQ3_XXS - 3.0625 bpw 33.27 GiB 91.80 B ROCm 99 1 pp 4096 1033.32 ± 1.28
llama 8x7B IQ3_XXS - 3.0625 bpw 33.27 GiB 91.80 B ROCm 99 1 tg 128 53.41 ± 0.02
llama 8x7B IQ3_XXS - 3.0625 bpw 33.27 GiB 91.80 B ROCm 99 0 pp 4096 1033.59 ± 2.31
llama 8x7B IQ3_XXS - 3.0625 bpw 33.27 GiB 91.80 B ROCm 99 0 tg 128 53.37 ± 0.01
llama 8B Q6_K 5.53 GiB 7.24 B ROCm 99 1 pp 4096 2486.02 ± 1.37
llama 8B Q6_K 5.53 GiB 7.24 B ROCm 99 1 tg 128 84.43 ± 0.02
llama 8B Q6_K 5.53 GiB 7.24 B ROCm 99 0 pp 4096 2481.60 ± 1.73
llama 8B Q6_K 5.53 GiB 7.24 B ROCm 99 0 tg 128 84.41 ± 0.01
llama ?B Q4_K - Small 17.59 GiB 33.34 B ROCm 99 1 pp 4096 610.69 ± 0.36
llama ?B Q4_K - Small 17.59 GiB 33.34 B ROCm 99 1 tg 128 26.62 ± 0.00
llama ?B Q4_K - Small 17.59 GiB 33.34 B ROCm 99 0 pp 4096 610.17 ± 0.20
llama ?B Q4_K - Small 17.59 GiB 33.34 B ROCm 99 0 tg 128 26.60 ± 0.00
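For context on how a table like this is produced: llama-bench exposes flash attention as a sweepable parameter, so one run can cover both fa=0 and fa=1. The sketch below is only an approximation of such an invocation; the model path is a placeholder and the flag spellings are recalled from llama-bench around this time, not taken from this PR.

```sh
# Hedged reproduction sketch, not the exact command used for the numbers above.
# -ngl 99 offloads all layers, -fa 0,1 sweeps flash attention off and on,
# -p 4096 and -n 128 correspond to the pp 4096 and tg 128 rows.
./llama-bench -m models/llama-8b-q6_k.gguf -ngl 99 -fa 0,1 -p 4096 -n 128
```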
./batched-bench $model 10000 2048 512 $fa 1 99 8192 256 1

| T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s | fa | model | buffer MiB |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 3.891 | 2105.51 | 5.662 | 45.22 | 9.553 | 884.37 | 0 | 7B.Q6_K | 692 |
| 3.820 | 2144.46 | 5.614 | 45.60 | 9.434 | 895.48 | 1 | 7B.Q6_K | 94 |
| OOM | OOM | OOM | OOM | OOM | OOM | 0 | 33B.Q4_K_S | 1196 |
| 14.452 | 566.86 | 14.465 | 17.70 | 28.916 | 292.15 | 1 | 33B.Q4_K_S | 123 |
| 3.862 | 2121.45 | 5.773 | 44.34 | 9.635 | 876.83 | 0 | 8B.Q6_K | 692 |
| 3.822 | 2143.62 | 5.648 | 45.33 | 9.469 | 892.14 | 1 | 8B.Q6_K | 267 |
| 7.315 | 1119.94 | 8.936 | 28.65 | 16.251 | 519.85 | 0 | 8x7B.IQ3_S | 692 |
| 6.860 | 1194.17 | 8.754 | 29.24 | 15.614 | 541.05 | 1 | 8x7B.IQ3_S | 150 |
| 13.400 | 611.34 | 15.499 | 16.52 | 28.899 | 292.33 | 0 | 32B.Q4_K_M | 860 |
| 13.189 | 621.10 | 14.526 | 17.62 | 27.715 | 304.81 | 1 | 32B.Q4_K_M | 307 |

buffer = ROCm0 compute buffer size
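For readers unfamiliar with batched-bench, the command above passes everything positionally. The sketch below spells out one plausible reading of those arguments (n_kv_max, n_batch, n_ubatch, fattn, is_pp_shared, n_gpu_layers, then pp/tg/pl), which is consistent with the times and speeds in the table; the model path and loop are illustrative, not the author's actual script.

```sh
# Hedged sketch of the sweep behind the table above.
model=models/llama-8b-q6_k.gguf   # placeholder path
for fa in 0 1; do
    # 10000=n_kv_max 2048=n_batch 512=n_ubatch $fa=flash_attn 1=is_pp_shared
    # 99=n_gpu_layers, then pp=8192 tg=256 pl=1
    ./batched-bench "$model" 10000 2048 512 "$fa" 1 99 8192 256 1
done
```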

@jdecourval mentioned this pull request on Apr 30, 2024.
Automated server benchmark report:

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 555 iterations 🚀

Expand details (performance-related PRs only)
  • Concurrent users: 8, duration: 10m
  • HTTP request: avg=8469.29ms p(95)=20510.1ms fails=, finish reason: stop=490 truncated=65
  • Prompt processing (pp): avg=95.98tk/s p(95)=401.72tk/s
  • Token generation (tg): avg=32.95tk/s p(95)=48.94tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=fixflashattn2 commit=3e560c8665d4ea627be920a26da6d83811fde3b4 (an equivalent server invocation is sketched below)
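
A hedged sketch of the server invocation implied by that configuration line; the binary name and flag spellings reflect llama.cpp in this period, and the model path is simply whatever the bench harness downloaded, so treat this as approximate rather than the harness's literal command.

```sh
# Approximate equivalent of the bench configuration above.
./server -m ggml-model-q4_0.gguf \
    --parallel 8 --ctx-size 16384 --n-gpu-layers 33 \
    --batch-size 2048 --ubatch-size 256
```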

prompt_tokens_seconds

[Time-series chart omitted: "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3, duration=10m, 555 iterations", y-axis llamacpp:prompt_tokens_seconds; non-zero samples roughly in the 644-867 range.]
predicted_tokens_seconds

[Time-series chart omitted: same run, y-axis llamacpp:predicted_tokens_seconds; non-zero samples roughly in the 30.5-35.6 range.]

Details

kv_cache_usage_ratio

[Time-series chart omitted: same run, y-axis llamacpp:kv_cache_usage_ratio; non-zero samples roughly in the 0.08-0.54 range.]
requests_processing

[Time-series chart omitted: same run, y-axis llamacpp:requests_processing; non-zero samples between 1 and 8 concurrent requests.]

@JohannesGaessler (Collaborator) commented:

I didn't close that other PR by accident. As I said before, I don't think we should be adding a dependency on rocWMMA when the performance is no better than master and we have no dev to test and support it. And I will be doing an implementation of FlashAttention without any tensor cores at all, which may end up being faster anyway.

@sorasoras commented:

I don't know how to get this to compile on Windows :(

@jdecourval (Contributor, Author) commented:

> I didn't close that other PR by accident. As I said before, I don't think we should be adding a dependency on rocWMMA when the performance is no better than master and we have no dev to test and support it. And I will be doing an implementation of FlashAttention without any tensor cores at all, which may end up being faster anyway.

Sorry, I didn't realize it had been closed on purpose. Is the dependency that bad, though? rocWMMA is header-only, so there is no link-time requirement, and it lets us share the existing CUDA code. The performance is not better, but the VRAM savings can be very significant, 1 GB in one case. The PR is not ready to merge as-is anyway: I need to disable flash-attn in CMake by default for AMD GPUs, or enable it only if rocWMMA is detected as installed. I may not be a ROCm expert, but I am a C++ dev and I own a 7900 XTX; if this isn't merged, I might maintain the fork anyway.

Of course, if you already plan to work on that other implementation soon, this whole comment is moot, but having access to a rocWMMA-based version as a comparison could still be useful. Please let me know what you think.
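
As a rough illustration of the gating described above, one could check for the rocWMMA header at configure time and only then opt into the WMMA flash-attention path. This is purely a sketch: the LLAMA_FLASH_ATTN_WMMA option name is hypothetical and not necessarily what this PR defines, and the header path assumes a default ROCm install.

```sh
# Hypothetical configure-time gate for the rocWMMA-based path.
if [ -f /opt/rocm/include/rocwmma/rocwmma.hpp ]; then
    cmake -B build -DLLAMA_HIPBLAS=ON -DLLAMA_FLASH_ATTN_WMMA=ON   # hypothetical option name
else
    cmake -B build -DLLAMA_HIPBLAS=ON   # fall back to a build without the WMMA path
fi
cmake --build build --config Release -j
```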

@sorasoras commented:

> I didn't close that other PR by accident. As I said before, I don't think we should be adding a dependency on rocWMMA when the performance is no better than master and we have no dev to test and support it. And I will be doing an implementation of FlashAttention without any tensor cores at all, which may end up being faster anyway.
>
> Sorry, I didn't realize it had been closed on purpose. Is the dependency that bad, though? rocWMMA is header-only, so there is no link-time requirement, and it lets us share the existing CUDA code. The performance is not better, but the VRAM savings can be very significant, 1 GB in one case. The PR is not ready to merge as-is anyway: I need to disable flash-attn in CMake by default for AMD GPUs, or enable it only if rocWMMA is detected as installed. I may not be a ROCm expert, but I am a C++ dev and I own a 7900 XTX; if this isn't merged, I might maintain the fork anyway.
>
> Of course, if you already plan to work on that other implementation soon, this whole comment is moot, but having access to a rocWMMA-based version as a comparison could still be useful. Please let me know what you think.

I haven't been able to test flash attention on Windows with a 7900 XTX yet.
I wonder if there is any difference in power consumption with and without fa.

@mofosyne added the labels "enhancement (New feature or request)" and "review complexity: high (Generally requires in-depth knowledge of LLMs or GPUs)" on May 9, 2024.