Skip to content

Commit

Permalink
fix batch size 2-8
Browse files (browse the repository at this point in the history)
  • Loading branch information
JohannesGaessler committed May 6, 2024
1 parent 09f1768 commit 57bde8c
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions ggml-cuda/fattn.cu
Original file line number | Diff line number | Diff line change
Expand Up
@@ -196,15 +196,16 @@ static __global__ void flash_attn_vec_ext_f16(
  __syncthreads();

  #pragma unroll
- for (int j = 0; j < ncols; ++j) {
- kqsum[j] = kqsum_shared[j][threadIdx.x];
- kqsum[j] = warp_reduce_sum(kqsum[j]);
+ for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+ kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
+ kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);

- half dst_val = (__low2half(VKQ[j]) + __high2half(VKQ[j]));
+ half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
  if (parallel_blocks == 1) {
- dst_val /= kqsum[j];
+ dst_val /= kqsum[j_VKQ];
  }
- dst[D*gridDim.y*(blockIdx.x*ncols + j) + D*blockIdx.y + tid] = dst_val;
+ const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+ dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
  }

  if (parallel_blocks == 1 || tid != 0) {
Expand Down

0 comments on commit 57bde8c

Please sign in to comment.