Quantized: Use cublas for prompt #238

Merged
15 commits merged on May 15, 2024
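This change routes quantized matmuls through an f16 GEMM (cuBLAS on CUDA) during prompt processing, when the sequence length exceeds 32, while single-token decoding keeps the existing quantized kernels. Reduced-precision f16/bf16 GEMM is enabled by default on CUDA; the new with_gemm_full_precision_f16 builder flag opts back into full precision. A minimal usage sketch, assuming the builder is otherwise constructed as usual (pipeline and method are placeholder arguments not shown in this diff):

let runner = MistralRsBuilder::new(pipeline, method)
    .with_gemm_full_precision_f16(true) // request full-precision f16 GEMM (skips the reduced-precision setting)
    .build();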
21 changes: 20 additions & 1 deletion mistralrs-core/src/lib.rs
@@ -78,6 +78,7 @@ pub struct MistralRsBuilder {
no_prefix_cache: Option<bool>,
prefix_cache_n: Option<usize>,
disable_eos_stop: Option<bool>,
gemm_full_precision_f16: Option<bool>,
}

impl MistralRsBuilder {
@@ -91,9 +92,9 @@ impl MistralRsBuilder {
no_prefix_cache: None,
prefix_cache_n: None,
disable_eos_stop: None,
gemm_full_precision_f16: None,
}
}

pub fn with_log(mut self, log: String) -> Self {
self.log = Some(log);
self
@@ -122,12 +123,25 @@ impl MistralRsBuilder {
self.disable_eos_stop = Some(disable_eos_stop);
self
}
pub fn with_gemm_full_precision_f16(mut self, gemm_full_precision: bool) -> Self {
self.gemm_full_precision_f16 = Some(gemm_full_precision);
self
}

pub fn build(self) -> Arc<MistralRs> {
MistralRs::new(self)
}
}

#[cfg(feature = "cuda")]
fn set_gemm_reduced_precision_f16() {
candle_core::cuda::set_gemm_reduced_precision_f16(true);
candle_core::cuda::set_gemm_reduced_precision_bf16(true);
}

#[cfg(not(feature = "cuda"))]
fn set_gemm_reduced_precision_f16() {}

impl MistralRs {
fn new(config: MistralRsBuilder) -> Arc<Self> {
let MistralRsBuilder {
@@ -139,8 +153,13 @@ impl MistralRs {
no_prefix_cache,
prefix_cache_n,
disable_eos_stop,
gemm_full_precision_f16,
} = config;

if !gemm_full_precision_f16.unwrap_or(false) {
set_gemm_reduced_precision_f16();
}

let truncate_sequence = truncate_sequence.unwrap_or(false);
let no_kv_cache = no_kv_cache.unwrap_or(false);
let no_prefix_cache = no_prefix_cache.unwrap_or(false);
73 changes: 53 additions & 20 deletions mistralrs-core/src/models/quantized_llama.rs
@@ -12,19 +12,27 @@ use crate::DeviceMapMetadata;

const MAX_SEQ_LEN: u32 = 4096;

fn quantized_mat_mul(xs: &Tensor, w: &QMatMul, via_f16: bool) -> Result<Tensor> {
if via_f16 {
w.forward_via_f16(xs)
} else {
w.forward(xs)
}
}

#[derive(Debug, Clone)]
struct Mlp {
feed_forward_w1: QMatMul,
feed_forward_w2: QMatMul,
feed_forward_w3: QMatMul,
}

impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let w1 = self.feed_forward_w1.forward(xs)?;
let w3 = self.feed_forward_w3.forward(xs)?;
self.feed_forward_w2
.forward(&(candle_nn::ops::silu(&w1)? * w3)?)
impl Mlp {
fn forward(&self, xs: &Tensor, via_f16: bool) -> Result<Tensor> {
let w1 = quantized_mat_mul(xs, &self.feed_forward_w1, via_f16)?;
let w3 = quantized_mat_mul(xs, &self.feed_forward_w3, via_f16)?;
let y = &(candle_nn::ops::silu(&w1)? * w3)?;
quantized_mat_mul(y, &self.feed_forward_w2, via_f16)
}
}

@@ -38,8 +46,8 @@ enum MlpOrMoe {
},
}

impl Module for MlpOrMoe {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
impl MlpOrMoe {
fn forward(&self, xs: &Tensor, via_f16: bool) -> Result<Tensor> {
match self {
Self::MoE {
feed_forward_gate_inp,
@@ -94,7 +102,7 @@ impl Module for MlpOrMoe {
// states by `routing_weights` on the corresponding tokens (top-1 and top-2)
let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
// current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
let current_hidden_states = expert_layer.forward(&current_state)?;
let current_hidden_states = expert_layer.forward(&current_state, via_f16)?;
let current_hidden_states =
current_hidden_states.broadcast_mul(&selected_rws)?;
ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
@@ -103,7 +111,7 @@ impl Module for MlpOrMoe {
let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
Ok(ys)
}
Self::Mlp(mlp) => mlp.forward(xs),
Self::Mlp(mlp) => mlp.forward(xs, via_f16),
}
}
}
@@ -132,11 +140,13 @@ impl LayerWeights {
start_offsets: &[usize],
start_offsets_kernel: Tensor,
kv_cache: &mut Option<(Tensor, Tensor)>,
via_f16: bool,
) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.attention_wq.forward(x)?;
let k = self.attention_wk.forward(x)?;
let v = self.attention_wv.forward(x)?;

let q = quantized_mat_mul(x, &self.attention_wq, via_f16)?;
let k = quantized_mat_mul(x, &self.attention_wk, via_f16)?;
let v = quantized_mat_mul(x, &self.attention_wv, via_f16)?;

let mut q = q.reshape((b_sz * seq_len, self.n_head, self.head_dim))?;
let mut k = k.reshape((b_sz * seq_len, self.n_kv_head, self.head_dim))?;
@@ -159,16 +169,32 @@ impl LayerWeights {

let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;

let k = repeat_kv(k, self.n_head / self.n_kv_head)?.contiguous()?;
let v = repeat_kv(v, self.n_head / self.n_kv_head)?.contiguous()?;
let k = repeat_kv(k, self.n_head / self.n_kv_head)?;
let v = repeat_kv(v, self.n_head / self.n_kv_head)?;
let att = if via_f16 {
let mm = q
.to_dtype(DType::F16)?
.matmul(&k.to_dtype(DType::F16)?.t()?)?;

((mm / (self.head_dim as f64).sqrt())?).to_dtype(DType::F32)?
} else {
let k = k.contiguous()?;
(q.contiguous()?.matmul(&k.t()?.contiguous()?)? / (self.head_dim as f64).sqrt())?
};

let att = (q.contiguous()?.matmul(&k.t()?.contiguous()?)? / (self.head_dim as f64).sqrt())?;
let att = CausalMasker.apply_mask(mask, att, &self.neg_inf)?;
let att = candle_nn::ops::softmax_last_dim(&att)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = if via_f16 {
att.to_dtype(DType::F16)?
.matmul(&v.to_dtype(DType::F16)?)?
.to_dtype(DType::F32)?
} else {
att.matmul(&v.contiguous()?)?
};

let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = self.attention_wo.forward(&y)?;
let y = quantized_mat_mul(&y, &self.attention_wo, via_f16)?;
Ok(y)
}
}
@@ -386,6 +412,9 @@ impl ModelWeights {
start_offsets_kernel: Tensor,
context_lens: Vec<(usize, usize)>,
) -> Result<Tensor> {
let (_bz, seq_len, _) = x.dims3()?;
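// Only the prompt (seq_len > 32) takes the f16 (cuBLAS) matmul path; decode steps use the quantized kernels directly.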
let via_f16 = seq_len > 32;

let mut layer_in = self.tok_embeddings.forward(x)?;
let mut cache = self.cache.lock();
let mask = CausalMasker.make_causal_mask(x, &cache)?;
@@ -402,18 +431,22 @@ impl ModelWeights {
start_offsets,
start_offsets_kernel.clone(),
&mut cache[i],
via_f16,
)?;
let x = (attn + residual)?;

// MLP
let residual = &x;
let x = layer.ffn_norm.forward(&x)?;
let x = layer.mlp_or_moe.forward(&x)?;
let x = layer.mlp_or_moe.forward(&x, via_f16)?;
let x = (x + residual)?;
layer_in = x;
}
let layer_in = layer_in.to_device(&self.device)?;
let x = self.norm.forward(&layer_in)?;
extract_logits(&self.output.forward(&x.contiguous()?)?, context_lens)
extract_logits(
&quantized_mat_mul(&x.contiguous()?, &self.output, via_f16)?,
context_lens,
)
}
}