
Commit

mulmat via f16
lucasavila00 committed Apr 29, 2024
1 parent 0347130 commit e069667
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions mistralrs-core/src/models/quantized_llama.rs
@@ -192,8 +192,16 @@ impl LayerWeights {
 
         let k = repeat_kv(k, self.n_head / self.n_kv_head)?.contiguous()?;
         let v = repeat_kv(v, self.n_head / self.n_kv_head)?.contiguous()?;
+        let att = if is_prompt {
+            (q.to_dtype(DType::F16)?
+                .contiguous()?
+                .matmul(&k.to_dtype(DType::F16)?.t()?.contiguous()?)?
+                .to_dtype(DType::F32)?
+                / (self.head_dim as f64).sqrt())?
+        } else {
+            (q.contiguous()?.matmul(&k.t()?.contiguous()?)? / (self.head_dim as f64).sqrt())?
+        };
 
-        let att = (q.contiguous()?.matmul(&k.t()?.contiguous()?)? / (self.head_dim as f64).sqrt())?;
         let att = match mask {
             None => att,
             Some(mask) => {
@@ -203,7 +211,14 @@ impl LayerWeights {
         };
         let att = candle_nn::ops::softmax_last_dim(&att)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
-        let y = att.matmul(&v.contiguous()?)?;
+        let y = if is_prompt {
+            att.to_dtype(DType::F16)?
+                .matmul(&v.to_dtype(DType::F16)?.contiguous()?)?
+                .to_dtype(DType::F32)?
+        } else {
+            att.matmul(&v.contiguous()?)?
+        };
+
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
         let y = quantized_mat_mul(&y, &self.attention_wo, is_prompt)?;
         Ok(y)
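For context, a minimal sketch of the "mulmat via f16" pattern this commit applies during prompt processing: cast the f32 operands to F16, run the matmul in half precision, and cast the product back to F32. This assumes the candle_core crate; the matmul_via_f16 helper and the toy tensors below are illustrative only, not part of the commit.

// Minimal sketch, assuming candle_core; `matmul_via_f16` is a hypothetical helper.
use candle_core::{DType, Device, Result, Tensor};

fn matmul_via_f16(a: &Tensor, b: &Tensor) -> Result<Tensor> {
    // Cast both f32 operands to f16, multiply, then cast the product back to f32.
    a.to_dtype(DType::F16)?
        .contiguous()?
        .matmul(&b.to_dtype(DType::F16)?.contiguous()?)?
        .to_dtype(DType::F32)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let a = Tensor::randn(0f32, 1.0, (4, 8), &dev)?;
    let b = Tensor::randn(0f32, 1.0, (8, 4), &dev)?;
    let y = matmul_via_f16(&a, &b)?;
    println!("product shape: {:?}", y.dims()); // [4, 4]
    Ok(())
}

In the commit itself this pattern is only taken when is_prompt is true (large batched matmuls), while single-token decoding keeps the original f32 path.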
