diff --git a/Cargo.toml b/Cargo.toml index 17e7e4ba57..fe6d4f7e01 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,7 +66,7 @@ serde = { version = "1.0.171", features = ["derive"] } serde_plain = "1.0.2" serde_json = "1.0.99" thiserror = "1" -tokenizers = { version = "0.19.1", default-features = false } +tokenizers = { version = "0.21.0", default-features = false } tracing = "0.1.37" tracing-chrome = "0.7.1" tracing-subscriber = "0.3.7" diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs index 0f7e9c118c..4cbddfb9c5 100644 --- a/candle-core/src/quantized/ggml_file.rs +++ b/candle-core/src/quantized/ggml_file.rs @@ -183,6 +183,18 @@ pub fn qtensor_from_ggml( GgmlDType::Q6K => { from_raw_data::(raw_data, size_in_bytes, dims, device) } + GgmlDType::Q8K => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::Q2b0 => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::QI8 => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::Q2b1 => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"), } } diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs index ccbd59eb5c..cfad05d009 100644 --- a/candle-core/src/quantized/gguf_file.rs +++ b/candle-core/src/quantized/gguf_file.rs @@ -174,6 +174,26 @@ impl Value { } } + pub fn from_u8(v: u8) -> Self { + Self::U8(v) + } + + pub fn from_u64(v: u64) -> Self { + Self::U64(v) + } + + pub fn from_u32(v: u32) -> Self { + Self::U32(v) + } + + pub fn from_f32(v: f32) -> Self { + Self::F32(v) + } + + pub fn from_string(v: String) -> Self { + Self::String(v) + } + pub fn to_u8(&self) -> Result { match self { Self::U8(v) => Ok(*v), @@ -489,7 +509,7 @@ fn write_string(w: &mut W, str: &str) -> Result<()> { pub fn write( w: &mut W, - metadata: &[(&str, &Value)], + metadata: &[(&str, Value)], tensors: &[(&str, &QTensor)], ) -> Result<()> { w.write_u32::(0x46554747)?; diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 1d3e053898..5dd9595773 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -6,6 +6,7 @@ use super::GgmlDType; use crate::Result; use byteorder::{ByteOrder, LittleEndian}; use half::f16; +use num_traits::real::Real; use rayon::prelude::*; // Default to QK_K 256 rather than 64. 
@@ -18,6 +19,9 @@ pub const QK5_0: usize = 32; pub const QK5_1: usize = 32; pub const QK8_0: usize = 32; pub const QK8_1: usize = 32; +pub const Q2B_0: usize = 32; +pub const Q2B_1: usize = 32; +pub const QI8: usize = 32; pub trait GgmlType: Sized + Clone + Send + Sync { const DTYPE: GgmlDType; @@ -154,6 +158,29 @@ pub struct BlockQ8K { } const _: () = assert!(4 + QK_K + QK_K / 16 * 2 == std::mem::size_of::()); +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockQ2b0 { + pub(crate) qs: [u8; Q2B_0 / 8], // Every single bit represents positive values, is a vector of {0, 1} + pub(crate) qd: [u8; Q2B_0 / 8], // Every single bit represents negatives values, is a vector of {0, 1} +} + +const _: () = assert!(Q2B_0 / 4 == std::mem::size_of::()); + +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockQ2b1 { + pub(crate) qs: [u8; Q2B_0 / 4], // Every single 2-bit represents {-1, 0, 1} +} +const _: () = assert!(Q2B_0 / 4 == std::mem::size_of::()); + +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockQI8 { + pub(crate) qs: [i8; QI8], +} +const _: () = assert!(std::mem::size_of::() == QI8); + impl GgmlType for BlockQ4_0 { const DTYPE: GgmlDType = GgmlDType::Q4_0; const BLCK_SIZE: usize = QK4_0; @@ -1838,6 +1865,326 @@ impl GgmlType for BlockQ8K { } } +impl GgmlType for BlockQ2b0 { + const DTYPE: GgmlDType = GgmlDType::Q2b0; + const BLCK_SIZE: usize = Q2B_0; + type VecDotType = BlockQI8; + + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_q2b0_qi8(n, xs, ys); + + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if n % Q2B_0 != 0 { + crate::bail!("vec_dot_q2b0_q8k: {n} is not divisible by {Q2B_0}"); + } + let mut sumf = 0.0; + for (x, y) in xs.iter().zip(ys.iter()) { + let mut isum = 0i32; + for i in 0..Q2B_0 / 8 { + let qs = x.qs[i]; + let qd = x.qd[i]; + let mut y_cache = [0i32; 8]; + y_cache.copy_from_slice( + &y.qs[i * 8..(i + 1) * 8] + .iter() + .map(|&x| x as i32) + .collect::>()[..], + ); + + let pos_sum: i32 = (0..8) + .map(|bit| { + let mask = 1 << bit; + let is_active = ((qs & mask) >> bit) as i32; + is_active * y_cache[bit] + }) + .sum(); + + let neg_sum: i32 = (0..8) + .map(|bit| { + let mask = 1 << bit; + let is_active = ((qd & mask) >> bit) as i32; + is_active * y_cache[bit] + }) + .sum(); + + isum += pos_sum - neg_sum; + } + sumf += isum as f32; + } + Ok(sumf as f32) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + if xs.len() % Q2B_0 != 0 { + crate::bail!( + "quantize_row_q2b0: size mismatch {} not divisible by {}", + xs.len(), + Q2B_0 + ); + } + + for (block, x) in ys.iter_mut().zip(xs.chunks_exact(Q2B_0)) { + for (i, chunk) in x.chunks_exact(8).enumerate() { + let mut qs = 0u8; + let mut qd = 0u8; + + for (b, &value) in chunk.iter().enumerate() { + if value > 0.0 { + qs |= 1 << b; + } else if value < 0.0 { + qd |= 1 << b; + } + } + block.qs[i] = qs; + block.qd[i] = qd; + } + } + Ok(()) + } + + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + if ys.len() % Q2B_0 != 0 { + crate::bail!( + "dequantize_row_q2b0: size mismatch {} not divisible by {}", + ys.len(), + Q2B_0 + ); + } + + for (block, y) in xs.iter().zip(ys.chunks_exact_mut(Q2B_0)) { + for (i, chunk) in y.chunks_exact_mut(8).enumerate() { + let qs = block.qs[i]; + let qd = block.qd[i]; + + for b in 0..8 { + chunk[b] = if (qs >> b) & 1 != 0 { + 1.0 + } else if (qd >> b) & 1 != 0 { + -1.0 + } else { + 0.0 + }; + 
} + } + } + Ok(()) + } +} + +const fn build_decode_q2b1_lut_i8() -> [[i8; 4]; 256] { + let mut table = [[0i8; 4]; 256]; + let mut i = 0; + while i < 256 { + let byte = i as u8; + let mut dec = [0i8; 4]; + let mut b = 0; + while b < 4 { + let code = (byte >> (2 * b)) & 0b11; + dec[b as usize] = match code { + 0b00 => 0, + 0b01 => 1, + 0b10 => -1, + 0b11 => 0, + _ => unreachable!(), + }; + b += 1; + } + table[i] = dec; + i += 1; + } + table +} + +static LUT_DECODE_Q2B1_I8: [[i8; 4]; 256] = build_decode_q2b1_lut_i8(); +const fn build_decode_q2b1_lut_f32() -> [[f32; 4]; 256] { + let mut table = [[0.0_f32; 4]; 256]; + let mut i = 0; + while i < 256 { + let byte = i as u8; + let mut dec = [0.0_f32; 4]; + let mut b = 0; + while b < 4 { + let code = (byte >> (2 * b)) & 0b11; + dec[b as usize] = match code { + 0b00 => 0.0, + 0b01 => 1.0, + 0b10 => -1.0, + 0b11 => 0.0, + _ => unreachable!(), + }; + b += 1; + } + table[i] = dec; + i += 1; + } + table +} + +static LUT_DECODE_Q2B1_F32: [[f32; 4]; 256] = build_decode_q2b1_lut_f32(); +impl GgmlType for BlockQ2b1 { + const DTYPE: GgmlDType = GgmlDType::Q2b1; + const BLCK_SIZE: usize = Q2B_0; + type VecDotType = BlockQI8; + + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_q2b1_qi8(n, xs, ys); + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if n % Q2B_0 != 0 { + crate::bail!("vec_dot_q2b1_qi8: n = {n} is not divisible by {Q2B_0}"); + } + + let mut sumf = 0.0; + + for (block_x, block_y) in xs.iter().zip(ys.iter()) { + let mut isum = 0i32; + + for i in 0..(Q2B_0 / 4) { + let enc_x = block_x.qs[i]; + let y_slice = &block_y.qs[i * 4..(i + 1) * 4]; + + let dec_x = &LUT_DECODE_Q2B1_I8[enc_x as usize]; + + for b in 0..4 { + let x_val = dec_x[b] as i32; + let y_val = y_slice[b] as i32; + isum += x_val * y_val; + } + } + sumf += isum as f32; + } + + Ok(sumf) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + if xs.len() % Q2B_0 != 0 { + crate::bail!( + "quantize_row_q2b1: size {} is not divisible by {}", + xs.len(), + Q2B_0 + ); + } + + for (block, chunk) in ys.iter_mut().zip(xs.chunks_exact(Q2B_0)) { + for (i, subchunk) in chunk.chunks_exact(4).enumerate() { + let mut encoded: u8 = 0; + for (b, &val) in subchunk.iter().enumerate() { + let bits = if val > 0.0 { + 0b01 + } else if val < 0.0 { + 0b10 + } else { + 0b00 + }; + encoded |= bits << (2 * b); + } + block.qs[i] = encoded; + } + } + + Ok(()) + } + + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + if ys.len() % Q2B_0 != 0 { + crate::bail!( + "dequantize_row_q2b1: size {} is not divisible by {}", + ys.len(), + Q2B_0 + ); + } + + for (block, out_chunk) in xs.iter().zip(ys.chunks_exact_mut(Q2B_0)) { + for (i, subchunk) in out_chunk.chunks_exact_mut(4).enumerate() { + let enc = block.qs[i]; + let dec = &LUT_DECODE_Q2B1_F32[enc as usize]; + subchunk.copy_from_slice(dec); + } + } + + Ok(()) + } +} + +impl GgmlType for BlockQI8 { + const DTYPE: GgmlDType = GgmlDType::QI8; + const BLCK_SIZE: usize = QI8; + type VecDotType = BlockQI8; + + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QI8 != 0 { + crate::bail!("dequantize_row_qi8: {k} is not divisible by {QI8}"); + } + + let nb = k / QI8; + + for i in 0..nb { + for j in 0..QI8 { + ys[i * QI8 + j] = xs[i].qs[j] as f32; + } + } + Ok(()) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + // quantize_row_q8_0 + let k = xs.len(); + 
if k % Self::BLCK_SIZE != 0 { + crate::bail!("{k} is not divisible by {}", Self::BLCK_SIZE); + }; + let nb = k / Self::BLCK_SIZE; + if ys.len() != nb { + crate::bail!( + "size mismatch {} {} {}", + xs.len(), + ys.len(), + Self::BLCK_SIZE + ) + } + for (i, ys) in ys.iter_mut().enumerate() { + let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE]; + for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) { + *y = x as i8; + } + } + Ok(()) + } + + #[allow(unreachable_code)] + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + let qk = QK8_0; + if n % QI8 != 0 { + crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}") + } + + // Generic implementation. + let mut sumf = 0f32; + for (xs, ys) in xs.iter().zip(ys.iter()) { + let sum_i = xs + .qs + .iter() + .zip(ys.qs.iter()) + .map(|(&x, &y)| x as i32 * y as i32) + .sum::(); + sumf += sum_i as f32; + } + Ok(sumf) + } +} + // https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L10605 pub fn matmul( mkn: (usize, usize, usize), diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index f7f5b68ac2..426721db8f 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -103,6 +103,18 @@ impl QMetalStorage { let vec: Vec = read_to_vec(&buffer, block_len); crate::quantized::BlockQ8K::to_float(&vec, &mut out)?; } + GgmlDType::Q2b0 => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ2b0::to_float(&vec, &mut out)?; + } + GgmlDType::Q2b1 => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ2b1::to_float(&vec, &mut out)?; + } + GgmlDType::QI8 => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockQI8::to_float(&vec, &mut out)?; + } } let buffer = self.device.new_buffer_with_data(&out)?; @@ -225,6 +237,9 @@ impl From for candle_metal_kernels::GgmlDType { GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K, GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16, GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32, + GgmlDType::Q2b0 => candle_metal_kernels::GgmlDType::Q2b0, + GgmlDType::Q2b1 => candle_metal_kernels::GgmlDType::Q2b1, + GgmlDType::QI8 => todo!(), } } } diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 236f5a9811..4912f0f33d 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -146,6 +146,9 @@ pub enum GgmlDType { Q5K, Q6K, Q8K, + Q2b0, + Q2b1, + QI8, } impl GgmlDType { @@ -165,6 +168,9 @@ impl GgmlDType { 13 => Self::Q5K, 14 => Self::Q6K, 15 => Self::Q8K, + 40 => Self::Q2b0, + 41 => Self::QI8, + 42 => Self::Q2b1, _ => crate::bail!("unknown dtype for tensor {u}"), }; Ok(dtype) @@ -186,6 +192,9 @@ impl GgmlDType { Self::Q5K => 13, Self::Q6K => 14, Self::Q8K => 15, + Self::Q2b0 => 40, + Self::QI8 => 41, + Self::Q2b1 => 42, } } @@ -206,6 +215,9 @@ impl GgmlDType { Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]), Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]), Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]), + Self::Q2b0 => Box::new(vec![BlockQ2b0::zeros(); elem_count / BlockQ2b0::BLCK_SIZE]), + Self::QI8 => Box::new(vec![BlockQI8::zeros(); elem_count / BlockQI8::BLCK_SIZE]), + Self::Q2b1 => Box::new(vec![BlockQ2b1::zeros(); elem_count / 
BlockQ2b1::BLCK_SIZE]), } } /// The type size for blocks in bytes. @@ -227,6 +239,9 @@ impl GgmlDType { Self::Q5K => std::mem::size_of::(), Self::Q6K => std::mem::size_of::(), Self::Q8K => std::mem::size_of::(), + Self::Q2b0 => std::mem::size_of::(), + Self::QI8 => std::mem::size_of::(), + Self::Q2b1 => std::mem::size_of::(), } } @@ -241,6 +256,9 @@ impl GgmlDType { Self::Q5_1 => k_quants::QK5_1, Self::Q8_0 => k_quants::QK8_0, Self::Q8_1 => k_quants::QK8_1, + Self::Q2b0 => k_quants::Q2B_0, + Self::Q2b1 => k_quants::Q2B_1, + Self::QI8 => k_quants::QI8, Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K, } } diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index c4d5d6f41a..53de2b9007 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -1,5 +1,9 @@ -use super::k_quants::{ - BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K, +use super::{ + k_quants::{ + BlockQ2K, BlockQ2b0, BlockQ2b1, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, + BlockQ8K, BlockQ8_0, Q2B_0, QK8_0, QK_K, + }, + BlockQI8, }; use crate::Result; use byteorder::{ByteOrder, LittleEndian}; @@ -11,6 +15,7 @@ use core::arch::arm::*; #[allow(unused_imports)] #[cfg(target_arch = "aarch64")] use core::arch::aarch64::*; +use std::ptr; #[inline(always)] unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t { @@ -517,6 +522,136 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res Ok(sumf) } +#[inline(always)] +pub(crate) fn vec_dot_q2b0_qi8(n: usize, xs: &[BlockQ2b0], ys: &[BlockQI8]) -> crate::Result { + if n % Q2B_0 != 0 { + crate::bail!("vec_dot_q2b0_q8k: {n} is not divisible by {QK_K}") + } + + let mut sumf = 0.0_f32; + + unsafe { + for (x, y) in xs.iter().zip(ys.iter()) { + let mut isum = 0_i32; + + for i in 0..(Q2B_0 / 8) { + let qs = x.qs[i]; + let qd = x.qd[i]; + + // Load y_cache: Load 8 i8 values from y.qs[i * 8..(i + 1) * 8] + let y_qs_ptr = y.qs.as_ptr().add(i * 8); + let y_cache_i8x8 = vld1_s8(y_qs_ptr); + + // Extend y_cache_i8x8 to int16x8_t vector + let y_cache_i16x8 = vmovl_s8(y_cache_i8x8); + + // Prepare shift amounts: [0, -1, -2, -3, -4, -5, -6, -7] + let shift_vec_data: [i8; 8] = [0, -1, -2, -3, -4, -5, -6, -7]; + let shift_vec = vld1_s8(shift_vec_data.as_ptr()); + + // Duplicate qs and qd into vectors + let qs_vec = vdup_n_u8(qs); + let qd_vec = vdup_n_u8(qd); + + // Shift to bring bits into LSB + let qs_shifted = vshl_u8(qs_vec, shift_vec); + let qd_shifted = vshl_u8(qd_vec, shift_vec); + + // Mask LSB to get bits + let one_vec = vdup_n_u8(1); + let qs_bits = vand_u8(qs_shifted, one_vec); + let qd_bits = vand_u8(qd_shifted, one_vec); + + // Convert bits to int16x8_t + let qs_bits_i16x8 = vreinterpretq_s16_u16(vmovl_u8(qs_bits)); + let qd_bits_i16x8 = vreinterpretq_s16_u16(vmovl_u8(qd_bits)); + + // Multiply and accumulate + let pos_sum = vaddvq_s16(vmulq_s16(qs_bits_i16x8, y_cache_i16x8)); + let neg_sum = vaddvq_s16(vmulq_s16(qd_bits_i16x8, y_cache_i16x8)); + + isum += pos_sum as i32 - neg_sum as i32; + } + + sumf += isum as f32; + } + } + + Ok(sumf) +} + +static LUT_DECODE_Q2B1_I8: [[i8; 4]; 256] = { + const fn build_decode_table() -> [[i8; 4]; 256] { + let mut table = [[0i8; 4]; 256]; + let mut i = 0; + while i < 256 { + let byte = i as u8; + let mut dec = [0i8; 4]; + let mut b = 0; + while b < 4 { + let code = (byte >> (2 * b)) & 0b11; + dec[b as usize] = match code { + 0b00 => 0, + 0b01 => 1, + 0b10 => -1, + 0b11 => 0, 
+ _ => 0, + }; + b += 1; + } + table[i] = dec; + i += 1; + } + table + } + build_decode_table() +}; + +unsafe fn decode_q2b1_16(input: &[u8]) -> int8x16_t { + debug_assert_eq!(input.len(), 4, "input must be 4 bytes long"); + let mut tmp = [0i8; 16]; + + for (i, &byte) in input.iter().enumerate() { + let decoded4 = LUT_DECODE_Q2B1_I8[byte as usize]; + tmp[i * 4..i * 4 + 4].copy_from_slice(&decoded4); + } + + vld1q_s8(tmp.as_ptr()) +} + +#[inline(always)] +pub fn vec_dot_q2b1_qi8(n: usize, xs: &[BlockQ2b1], ys: &[BlockQI8]) -> crate::Result { + let blocks = n / 32; + + let mut total_sum = 0i32; + + unsafe { + for i in 0..blocks { + let x_block = &xs[i]; + let y_block = &ys[i]; + + let x_dec_lo = decode_q2b1_16(&x_block.qs[0..4]); + let x_dec_hi = decode_q2b1_16(&x_block.qs[4..8]); + + let y_lo = vld1q_s8(y_block.qs[0..16].as_ptr()); + let y_hi = vld1q_s8(y_block.qs[16..32].as_ptr()); + + let mut acc0 = vdupq_n_s32(0); + let mut acc1 = vdupq_n_s32(0); + + acc0 = vaddq_s32(acc0, vdotq_s32(x_dec_lo, y_lo)); + acc1 = vaddq_s32(acc1, vdotq_s32(x_dec_hi, y_hi)); + + let sum0 = vaddvq_s32(acc0); + let sum1 = vaddvq_s32(acc1); + + total_sum += sum0 + sum1; + } + } + + Ok(total_sum as f32) +} + #[inline(always)] pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result { if n % QK_K != 0 { diff --git a/candle-examples/examples/quantized-bitnet/main.rs b/candle-examples/examples/quantized-bitnet/main.rs new file mode 100644 index 0000000000..3196ab881e --- /dev/null +++ b/candle-examples/examples/quantized-bitnet/main.rs @@ -0,0 +1,461 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use clap::{Parser, ValueEnum}; +use std::io::Write; +use tokenizers::{AddedToken, Tokenizer}; +use tracing_subscriber::fmt::time::FormatTime; + +use candle::quantized::{ggml_file, gguf_file}; +use candle::Tensor; +use candle_transformers::generation::{LogitsProcessor, Sampling}; + +use candle_examples::token_output_stream::TokenOutputStream; +use candle_transformers::models::quantized_llama_bitnet as model; +use model::ModelWeights; + +const DEFAULT_PROMPT: &str = "My favorite theorem is "; + +#[derive(Debug)] +enum Prompt { + Interactive, + Chat, + One(String), +} + +#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] +enum Which { + #[value(name = "falcon3-1b-instruct-1.58")] + Falcon3_1bInstruct1_58, + #[value(name = "falcon3-3b-instruct-1.58")] + Falcon3_3bInstruct1_58, + #[value(name = "falcon3-3b-1.58")] + Falcon3_3b1_58, + #[value(name = "falcon3-7b-instruct-1.58")] + Falcon3_7bInstruct1_58, + #[value(name = "falcon3-7b-1.58")] + Falcon3_7b1_58, + #[value(name = "falcon3-10b-instruct-1.58")] + Falcon3_10bInstruct1_58, + #[value(name = "falcon3-10b-1.58")] + Falcon3_10b1_58, + #[value(name = "llama3-8b-1.58")] + Llama3_8b1_58, +} + +impl Which { + fn tokenizer_repo(&self) -> &'static str { + match self { + Self::Falcon3_1bInstruct1_58 => "nebuxcloud/Falcon3-1B-Instruct-1.58bit-GGUF", + Self::Falcon3_3bInstruct1_58 => "nebuxcloud/Falcon3-3B-Instruct-1.58bit-GGUF", + Self::Falcon3_3b1_58 => "nebuxcloud/Falcon3-3B-Base-1.58bit-GGUF", + Self::Falcon3_7bInstruct1_58 => "nebuxcloud/Falcon3-7B-Instruct-1.58bit-GGUF", + Self::Falcon3_10b1_58 => "nebuxcloud/Falcon3-10B-Base-1.58bit-GGUF", + Self::Falcon3_10bInstruct1_58 => "nebuxcloud/Falcon3-10B-Instruct-1.58bit-GGUF", + Self::Falcon3_7b1_58 => "nebuxcloud/Falcon3-7B-Base-1.58bit-GGUF", + Self::Llama3_8b1_58 => "nebuxcloud/Llama3-8B-1.58-100B-tokens-GGUF", + } + } +} + 
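+// Illustrative invocation (not part of this patch): the flag names below come from the
+// `Args` struct that follows, and the `--which` / `--prompt` values are just the defaults
+// defined in this file, shown here only as an example of running the bitnet example:
+//
+//   cargo run --example quantized-bitnet --release -- \
+//     --which falcon3-1b-instruct-1.58 \
+//     --prompt "My favorite theorem is "
+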
+#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// GGML/GGUF file to load, typically a .bin/.gguf file generated by the quantize command from l + /// lama.cpp + #[arg(long)] + model: Option, + + /// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way + /// and 'chat' for an interactive model where history of previous prompts and generated tokens + /// is preserved. + #[arg(long)] + prompt: Option, + + /// The length of the sample to generate (in tokens). + #[arg(short = 'n', long, default_value_t = 1000)] + sample_len: usize, + + /// The tokenizer config in json format. + #[arg(long)] + tokenizer: Option, + + /// The temperature used to generate samples, use 0 for greedy sampling. + #[arg(long, default_value_t = 0.2)] + temperature: f64, + + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option, + + /// Only sample among the top K samples. + #[arg(long)] + top_k: Option, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + /// Display the token for the specified prompt. + #[arg(long)] + verbose_prompt: bool, + + /// Process prompt elements separately. + #[arg(long)] + split_prompt: bool, + + /// Run on CPU rather than GPU even if a GPU is available. + #[arg(long)] + cpu: bool, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.5)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, + + /// The model size to use. + #[arg(long, default_value = "falcon3-1b-instruct-1.58")] + which: Which, + + /// Group-Query Attention, use 8 for the 70B version of LLaMAv2. + #[arg(long)] + gqa: Option, + + /// Use the slower dmmv cuda kernel. + #[arg(long)] + force_dmmv: bool, +} + +impl Args { + fn tokenizer(&self) -> anyhow::Result { + let tokenizer_path = match &self.tokenizer { + Some(config) => std::path::PathBuf::from(config), + None => { + let api = hf_hub::api::sync::Api::new()?; + let repo = self.which.tokenizer_repo(); + let api = api.model(repo.to_string()); + api.get("tokenizer.json")? 
+ } + }; + Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg) + } + + fn model(&self) -> anyhow::Result { + let model_path = match &self.model { + Some(config) => std::path::PathBuf::from(config), + None => { + let (repo, filename) = match self.which { + Which::Falcon3_1bInstruct1_58 => ( + "nebuxcloud/Falcon3-1B-Instruct-1.58bit-GGUF", + "Falcon3-1B-Instruct-1.58bit-q2b1.gguf", + ), + Which::Falcon3_3bInstruct1_58 => ( + "nebuxcloud/Falcon3-3B-Instruct-1.58bit-GGUF", + "Falcon3-3B-Instruct-1.58bit-q2b1.gguf", + ), + Which::Falcon3_3b1_58 => ( + "nebuxcloud/Falcon3-3B-Base-1.58bit-GGUF", + "Falcon3-3B-Base-1.58bit-q2b1.gguf", + ), + Which::Falcon3_7bInstruct1_58 => ( + "nebuxcloud/Falcon3-7B-Instruct-1.58bit-GGUF", + "Falcon3-7B-Instruct-1.58bit-q2b1.gguf", + ), + Which::Falcon3_7b1_58 => ( + "nebuxcloud/Falcon3-7B-Base-1.58bit-GGUF", + "Falcon3-7B-Base-1.58bit-q2b1.gguf", + ), + Which::Falcon3_10b1_58 => ( + "nebuxcloud/Falcon3-10B-Base-1.58bit-GGUF", + "Falcon3-10B-Base-1.58bit-q2b1.gguf", + ), + Which::Falcon3_10bInstruct1_58 => ( + "nebuxcloud/Falcon3-10B-Instruct-1.58bit-GGUF", + "Falcon3-10B-Instruct-1.58bit-q2b1.gguf", + ), + Which::Llama3_8b1_58 => ( + "nebuxcloud/Llama3-8B-1.58-100B-tokens-GGUF", + "Llama3-8B-1.58-100B-tokens-q2b1.gguf", + ), + }; + let revision = "main"; + let api = hf_hub::api::sync::Api::new()?; + api.repo(hf_hub::Repo::with_revision( + repo.to_string(), + hf_hub::RepoType::Model, + revision.to_string(), + )) + .get(filename)? + } + }; + Ok(model_path) + } +} + +fn format_size(size_in_bytes: usize) -> String { + if size_in_bytes < 1_000 { + format!("{}B", size_in_bytes) + } else if size_in_bytes < 1_000_000 { + format!("{:.2}KB", size_in_bytes as f64 / 1e3) + } else if size_in_bytes < 1_000_000_000 { + format!("{:.2}MB", size_in_bytes as f64 / 1e6) + } else { + format!("{:.2}GB", size_in_bytes as f64 / 1e9) + } +} + +fn main() -> anyhow::Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + + #[cfg(feature = "cuda")] + candle::quantized::cuda::set_force_dmmv(args.force_dmmv); + + candle::cuda::set_gemm_reduced_precision_f16(true); + candle::cuda::set_gemm_reduced_precision_bf16(true); + + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + println!( + "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", + args.temperature, args.repeat_penalty, args.repeat_last_n + ); + + let model_path = args.model()?; + let mut file = std::fs::File::open(&model_path)?; + let start = std::time::Instant::now(); + let device = candle_examples::device(args.cpu)?; + + let mut model = match model_path.extension().and_then(|v| v.to_str()) { + Some("gguf") => { + let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?; + let mut total_size_in_bytes = 0; + for (_, tensor) in model.tensor_infos.iter() { + let elem_count = tensor.shape.elem_count(); + total_size_in_bytes += + elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size(); + } + println!( + "loaded {:?} tensors ({}) in {:.2}s", + model.tensor_infos.len(), + &format_size(total_size_in_bytes), + start.elapsed().as_secs_f32(), + ); + ModelWeights::from_gguf(model, &mut file, 
&device)? + } + Some("ggml" | "bin") | Some(_) | None => { + let model = ggml_file::Content::read(&mut file, &device) + .map_err(|e| e.with_path(model_path))?; + let mut total_size_in_bytes = 0; + for (_, tensor) in model.tensors.iter() { + let elem_count = tensor.shape().elem_count(); + total_size_in_bytes += + elem_count * tensor.dtype().type_size() / tensor.dtype().block_size(); + } + println!( + "loaded {:?} tensors ({}) in {:.2}s", + model.tensors.len(), + &format_size(total_size_in_bytes), + start.elapsed().as_secs_f32(), + ); + println!("params: {:?}", model.hparams); + let default_gqa = 0; + ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))? + } + }; + println!("model built"); + + let tokenizer = args.tokenizer()?; + + let mut tos = TokenOutputStream::new(tokenizer); + let prompt = match args.prompt.as_deref() { + Some("chat") => Prompt::Chat, + Some("interactive") => Prompt::Interactive, + Some(s) => Prompt::One(s.to_string()), + None => Prompt::One(DEFAULT_PROMPT.to_string()), + }; + + let mut pre_prompt_tokens = vec![]; + for prompt_index in 0.. { + let prompt_str = match &prompt { + Prompt::One(prompt) => prompt.clone(), + Prompt::Interactive | Prompt::Chat => { + let is_interactive = matches!(prompt, Prompt::Interactive); + print!("> "); + std::io::stdout().flush()?; + let mut prompt = String::new(); + std::io::stdin().read_line(&mut prompt)?; + if prompt.ends_with('\n') { + prompt.pop(); + if prompt.ends_with('\r') { + prompt.pop(); + } + } + + prompt.clone() + } + }; + + print!("{}", &prompt_str); + let tokens = tos + .tokenizer() + .encode(prompt_str, true) + .map_err(anyhow::Error::msg)?; + if args.verbose_prompt { + for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) { + let token = token.to_string().replace('▁', " ").replace("<0x0A>", "\n"); + println!("{id:7} -> '{token}'"); + } + } + + let prompt_tokens = [&pre_prompt_tokens, tokens.get_ids()].concat(); + let to_sample = args.sample_len.saturating_sub(1); + let prompt_tokens = if prompt_tokens.len() + to_sample > model::MAX_SEQ_LEN - 10 { + let to_remove = prompt_tokens.len() + to_sample + 10 - model::MAX_SEQ_LEN; + prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec() + } else { + prompt_tokens + }; + let mut all_tokens = vec![]; + let mut logits_processor = { + let temperature = args.temperature; + let sampling = if temperature <= 0. { + Sampling::ArgMax + } else { + match (args.top_k, args.top_p) { + (None, None) => Sampling::All { temperature }, + (Some(k), None) => Sampling::TopK { k, temperature }, + (None, Some(p)) => Sampling::TopP { p, temperature }, + (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature }, + } + }; + LogitsProcessor::from_sampling(args.seed, sampling) + }; + + let start_prompt_processing = std::time::Instant::now(); + let mut next_token = if !args.split_prompt { + let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?; + let logits = model.forward(&input, 0)?; + let logits = logits.squeeze(0)?; + logits_processor.sample(&logits)? + } else { + let mut next_token = 0; + for (pos, token) in prompt_tokens.iter().enumerate() { + let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?; + let logits = model.forward(&input, pos)?; + let logits = logits.squeeze(0)?; + next_token = logits_processor.sample(&logits)? + } + next_token + }; + let prompt_dt = start_prompt_processing.elapsed(); + all_tokens.push(next_token); + if let Some(t) = tos.next_token(next_token)? 
{ + print!("{t}"); + std::io::stdout().flush()?; + } + + let eos_tokens = match args.which { + Which::Falcon3_10b1_58 + | Which::Falcon3_10bInstruct1_58 + | Which::Falcon3_7bInstruct1_58 + | Which::Falcon3_7b1_58 + | Which::Falcon3_3bInstruct1_58 + | Which::Falcon3_3b1_58 + | Which::Falcon3_1bInstruct1_58 => { + vec!["<|endoftext|>"] + } + Which::Llama3_8b1_58 => { + vec!["<|eot_id|>", "<|end_header_id|>", "<|start_header_id|>"] + } + }; + + let eos_tokens: Vec = eos_tokens + .iter() + .map(|token| { + *tos.tokenizer() + .get_vocab(true) + .get(*token) + .unwrap_or_else(|| panic!("EoS token not found: {}", token)) + }) + .collect(); + + let start_post_prompt = std::time::Instant::now(); + let mut sampled = 0; + for index in 0..to_sample { + let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?; + let logits = model.forward(&input, prompt_tokens.len() + index)?; + let logits = logits.squeeze(0)?; + let logits = if args.repeat_penalty == 1. { + logits + } else { + let start_at = all_tokens.len().saturating_sub(args.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + args.repeat_penalty, + &all_tokens[start_at..], + )? + }; + next_token = logits_processor.sample(&logits)?; + all_tokens.push(next_token); + if let Some(t) = tos.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + sampled += 1; + + if eos_tokens.contains(&next_token) { + break; + }; + } + if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? { + print!("{rest}"); + } + std::io::stdout().flush()?; + let dt = start_post_prompt.elapsed(); + println!( + "\n\n{:4} prompt tokens processed: {:.2} token/s", + prompt_tokens.len(), + prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(), + ); + println!( + "{sampled:4} tokens generated: {:.2} token/s", + sampled as f64 / dt.as_secs_f64(), + ); + + match prompt { + Prompt::One(_) => break, + Prompt::Interactive => {} + Prompt::Chat => { + pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat() + } + } + } + + Ok(()) +} diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index 2b537aac9e..a089b05380 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -198,7 +198,8 @@ impl Which { #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { - /// GGML/GGUF file to load, typically a .bin/.gguf file generated by the quantize command from llama.cpp + /// GGML/GGUF file to load, typically a .bin/.gguf file generated by the quantize command from l + /// lama.cpp #[arg(long)] model: Option, diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 5f948cbf4c..7e683596c2 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -2164,6 +2164,8 @@ pub enum GgmlDType { Q8K, F16, F32, + Q2b0, + Q2b1, } #[allow(clippy::too_many_arguments)] @@ -2229,6 +2231,12 @@ pub fn call_quantized_matmul_mv_t( let align = 4; (nth0, nth1, align) } + GgmlDType::Q2b1 | GgmlDType::Q2b0 => { + let nth0 = 8; + let nth1 = 8; + let align = 8; + (nth0, nth1, align) + } GgmlDType::Q3K | GgmlDType::Q5K => { let nth0 = 2; let nth1 = 32; @@ -2253,7 +2261,7 @@ pub fn call_quantized_matmul_mv_t( let nth1 = 1; let align = 8; (nth0, nth1, align) - } + }, }; let thread_groups_count = MTLSize { width: divide(ne01 as usize, align), @@ -2280,6 +2288,8 @@ pub fn call_quantized_matmul_mv_t( GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32", GgmlDType::F16 => 
"kernel_mul_mv_f16_f32", GgmlDType::F32 => "kernel_mul_mv_f32_f32", + GgmlDType::Q2b0 => "kernel_mul_mv_q2b0_f32", + GgmlDType::Q2b1 => "kernel_mul_mv_q2b1_f32" }; let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?; diff --git a/candle-metal-kernels/src/quantized.metal b/candle-metal-kernels/src/quantized.metal index fef6ac54f8..73a3744556 100644 --- a/candle-metal-kernels/src/quantized.metal +++ b/candle-metal-kernels/src/quantized.metal @@ -42,6 +42,22 @@ typedef struct { int8_t qs[QK8_0]; // quants } block_q8_0; +#define Q2B_0 32 +typedef struct { + uint8_t qs[Q2B_0 / 8]; // Every single bit represents positive values, is a vector of {0, 1} + uint8_t qd[Q2B_0 / 8]; // Every single bit represents negative values, is a vector of {0, 1} +} block_q2b_0; + +#define Q2B_1 32 +typedef struct { + uint8_t qs[Q2B_1 / 4]; // Every single 2-bit represents {-1, 0, 1} +} block_q2b_1; + +#define QI8 32 +typedef struct { + int8_t qs[QI8]; // quants +} block_qi8; + #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 enum ggml_sort_order { @@ -3469,6 +3485,267 @@ kernel void kernel_mul_mv_q6_K_f32( kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); } +#define NB_Q2B_0 8 +void kernel_mul_mv_q2b0_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nr = N_DST; + const int nsg = N_SIMDGROUP; + const int nw = N_SIMDWIDTH; + const int nb = ne00 / Q2B_0; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * nsg + sgitg) * nr; + const uint i12 = im % ne12; + const uint i13 = im / ne12; + + const uint offset0 = first_row * nb + + (i12 / r2) * (nb * ne01) + + (i13 / r3) * (nb * ne01 * ne02); + + device const block_q2b_0 * x = (device const block_q2b_0 *) src0 + offset0; + device const float * y = (device const float *) src1 + + r1 * ne10 + + im * ne00 * ne1; + + float yl[NB_Q2B_0]; + float sumf[nr]; + for (int i = 0; i < nr; ++i) { + sumf[i] = 0.0f; + } + + const int ix = tiisg / 4; + const int il = tiisg % 4; + + device const float * yb = y + ix * Q2B_0 + NB_Q2B_0 * il; + + for (int ib = ix; ib < nb; ib += (nw / 4)) { + for (int i = 0; i < NB_Q2B_0; ++i) { + yl[i] = yb[i]; + } + + for (int row = 0; row < nr; row++) { + device const block_q2b_0 * bx = x + ib + row * nb; + + float sumq = 0.f; + const int startBit = NB_Q2B_0 * il; + + for (int iBit = 0; iBit < NB_Q2B_0; iBit++) { + int bit = startBit + iBit; + int bByte = bit >> 3; + int bMask = 1 << (bit & 7); + if ((bx->qs[bByte] & bMask) != 0) { + sumq += yl[iBit]; + } else if ((bx->qd[bByte] & bMask) != 0) { + sumq -= yl[iBit]; + } + } + + sumf[row] += sumq; + } + + yb += NB_Q2B_0 * nw; + } + + for (int row = 0; row < nr; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && (first_row + row) < ne01) { + dst[r1 * ne0 + im * ne0 * ne1 + (first_row + row)] = tot; + } + } +} + + +[[host_name("kernel_mul_mv_q2b0_f32")]] +kernel void kernel_mul_mv_q2b0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + 
constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q2b0_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + +#define NB_Q2B_1 8 +constant float code_lut[4] = { 0.0f, 1.0f, -1.0f, 0.0f }; + +inline void kernel_mul_mv_q2b1_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig [[threadgroup_position_in_grid]], + uint tiisg [[thread_index_in_simdgroup]], + uint sgitg [[simdgroup_index_in_threadgroup]] +) { + // These come from your headers or #defines + const int nr = N_DST; // number of "rows" each thread processes + const int nsg = N_SIMDGROUP; // number of simdgroups per dimension + const int nw = N_SIMDWIDTH; // simd width + + const int nb = ne00 / Q2B_0; // number of Q2B_0 blocks in a row of X + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + // Each simdgroup processes 'nr' rows, so figure out which "chunk" we do: + const int first_row = (r0 * nsg + sgitg) * nr; + + // Flatten z index using ne12 + const uint i12 = im % ne12; + const uint i13 = im / ne12; + + // Compute offset into src0 (the quantized blocks array) + const uint offset0 = + first_row * nb + + (i12 / r2) * (nb * ne01) + + (i13 / r3) * (nb * ne01 * ne02); + + // Pointer to the first quantized block + device const block_q2b_1 * x = (device const block_q2b_1 *)src0 + offset0; + + // Pointer to the appropriate row of src1 + device const float * y = src1 + + r1 * ne10 // stride in y dimension + + im * ne00 * ne1; // stride in z dimension + + // Accumulators for partial sums, one per row + float sumf[nr]; + for (int i = 0; i < nr; i++) { + sumf[i] = 0.0f; + } + + // Figure out which quarter of the thread is ours + const int ix = tiisg / 4; + const int il = tiisg % 4; + + // This pointer yb will move through src1 in steps of NB_Q2B_1*nw + device const float * yb = y + ix * Q2B_0 + NB_Q2B_1 * il; + + // Main loop: each thread processes some subset of 'nb' blocks + for (int ib = ix; ib < nb; ib += (nw / 4)) { + + // Load 8 floats (NB_Q2B_0) into local array to keep them in registers + float yl[NB_Q2B_0]; + { + // Compiler usually unrolls such a small loop automatically + // but you can force it: + #pragma unroll 8 + for (int i = 0; i < NB_Q2B_0; i++) { + yl[i] = yb[i]; + } + } + + // For each row in [0..nr), compute partial dot-product + // with quantized data from 'x + ib + row * nb' + for (int row = 0; row < nr; row++) { + device const block_q2b_1 * bq = x + ib + row * nb; + + float sumq = 0.0f; + + // Each Q2B_0 = 8 bits, but we do them in steps of 2 + // 'startBit' is the first bit for the code. + // We unroll this loop as well. 
+ const int startBit = NB_Q2B_1 * il; + #pragma unroll 8 + for (int iBit = 0; iBit < NB_Q2B_0; iBit++) { + const int bit = startBit + iBit; + const int bByte = bit >> 2; // bit / 4 + const int shift = 2 * (bit & 3); // (bit % 4)*2 + const int code = (bq->qs[bByte] >> shift) & 0x3; + + // Use the LUT to get +1 / -1 / 0 + sumq += code_lut[code] * yl[iBit]; + } + + sumf[row] += sumq; + } + + // Advance yb to the next group of 8 floats + yb += NB_Q2B_1 * nw; + } + + // Reduction across the simdgroup: each row's sum -> simd_sum(...) + // Then store to output if we're the "first lane" (tiisg == 0) + for (int row = 0; row < nr; row++) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && (first_row + row) < ne01) { + dst[r1 * ne0 + im * ne0 * ne1 + (first_row + row)] = tot; + } + } +} + + +[[host_name("kernel_mul_mv_q2b1_f32")]] +kernel void kernel_mul_mv_q2b1_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q2b1_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + //============================= templates and their specializations ============================= // NOTE: this is not dequantizing - we are simply fitting the template diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index b8695cc8a0..3f12cbe8ac 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1456,7 +1456,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) let converted_metadata: Vec<_> = metadata .iter() - .map(|(name, value)| (name.as_str(), value)) + .map(|(name, value)| (name.as_str(), value.clone())) .collect(); let converted_tensors: Vec<_> = tensors diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index be1f15c413..0fca20e2d4 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -75,6 +75,7 @@ pub mod quantized_blip; pub mod quantized_blip_text; pub mod quantized_llama; pub mod quantized_llama2_c; +pub mod quantized_llama_bitnet; pub mod quantized_metavoice; pub mod quantized_mistral; pub mod quantized_mixformer; diff --git a/candle-transformers/src/models/quantized_llama_bitnet.rs b/candle-transformers/src/models/quantized_llama_bitnet.rs new file mode 100644 index 0000000000..ed745f622e --- /dev/null +++ b/candle-transformers/src/models/quantized_llama_bitnet.rs @@ -0,0 +1,601 @@ +//! Quantized llama model implementation. +//! +//! This provides a quantized implementation of the llama language model architecture. +//! The model implements parameter efficient quantization for reduced memory usage +//! while maintaining model quality. +//! +//! Key characteristics: +//! - Transformer decoder architecture +//! - Support for 2/3/4/8-bit quantization +//! - Optimized memory usage through quantization +//! - Configurable model sizes and parameter counts +//! +//! 
- 💻 [GH Link](https://github.com/facebookresearch/llama) +//! - 📝 [Paper](https://arxiv.org/abs/2302.13971) +//! +//! ![](https://raw.githubusercontent.com/huggingface/candle/main/candle-examples/examples/quantized/assets/aoc.gif) +//! + +use std::collections::HashMap; + +use crate::quantized_nn::RmsNorm; +use candle::quantized::QTensor; +use candle::quantized::{ggml_file, gguf_file}; +use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn::{Embedding, Module}; + +pub const MAX_SEQ_LEN: usize = 4096; + +// QMatMul wrapper adding some tracing. +#[derive(Debug, Clone)] +struct QMatMul { + inner: candle::quantized::QMatMul, + span: tracing::Span, +} + +impl QMatMul { + fn from_qtensor(qtensor: QTensor) -> Result { + let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?; + let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); + Ok(Self { inner, span }) + } + + fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + +// BitQMatMul wrapper adding some tracing. +#[derive(Debug, Clone)] +struct BitQMatMul { + inner: candle::quantized::QMatMul, + span: tracing::Span, + weight_scale: Tensor, +} + +impl BitQMatMul { + fn from_qtensor(qtensor: QTensor, weight_scale: QTensor) -> Result { + let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?; + let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); + let weight_scale = weight_scale.dequantize(&weight_scale.device())?; + Ok(Self { + inner, + span, + weight_scale, + }) + } + + fn activation_quant(&self, x: &Tensor) -> Result<(Tensor, Tensor)> { + let scale = x + .abs()? + .max_keepdim(D::Minus1)? + .clamp(1e-5, f32::INFINITY)?; + let scale = (127.0 / scale)?; + + let y = (x.broadcast_mul(&scale))?.round()?.clamp(-128., 127.)?; + + Ok((y, scale)) + } + + fn forward(&self, x: &Tensor) -> Result { + let (x, xscale) = self.activation_quant(x)?; + let _enter = self.span.enter(); + let scale = self.weight_scale.broadcast_mul(&xscale)?; + self.inner.forward(&x)?.broadcast_div(&scale) + } +} + +#[derive(Debug, Clone)] +struct Mlp { + feed_forward_w1: BitQMatMul, + feed_forward_w2: BitQMatMul, + feed_forward_w3: BitQMatMul, +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result { + let w1 = self.feed_forward_w1.forward(xs)?; + let w3 = self.feed_forward_w3.forward(xs)?; + self.feed_forward_w2 + .forward(&(candle_nn::ops::silu(&w1)? * w3)?) + } +} + +#[derive(Debug, Clone)] +enum MlpOrMoe { + Mlp(Mlp), + MoE { + n_expert_used: usize, + feed_forward_gate_inp: QMatMul, + experts: Vec, + }, +} + +impl Module for MlpOrMoe { + fn forward(&self, xs: &Tensor) -> Result { + match self { + Self::MoE { + feed_forward_gate_inp, + experts, + n_expert_used, + } => { + let (b_size, seq_len, hidden_dim) = xs.dims3()?; + let xs = xs.reshape(((), hidden_dim))?; + let router_logits = feed_forward_gate_inp.forward(&xs)?; + let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?; + + // In order to extract topk, we extract the data from the tensor and manipulate it + // directly. Maybe we will want to use some custom ops instead at some point. + let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::()?; + + // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + // top_x contains the row indexes to evaluate for each expert. 
+ let mut top_x = vec![vec![]; experts.len()]; + let mut selected_rws = vec![vec![]; experts.len()]; + for (row_idx, rw) in routing_weights.iter().enumerate() { + let mut dst = (0..rw.len() as u32).collect::>(); + dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize])); + let mut sum_routing_weights = 0f32; + for &expert_idx in dst.iter().take(*n_expert_used) { + let expert_idx = expert_idx as usize; + let routing_weight = rw[expert_idx]; + sum_routing_weights += routing_weight; + top_x[expert_idx].push(row_idx as u32); + } + for &expert_idx in dst.iter().take(*n_expert_used) { + let expert_idx = expert_idx as usize; + let routing_weight = rw[expert_idx]; + selected_rws[expert_idx].push(routing_weight / sum_routing_weights) + } + } + + // routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + let mut ys = xs.zeros_like()?; + for (expert_idx, expert_layer) in experts.iter().enumerate() { + let top_x = &top_x[expert_idx]; + if top_x.is_empty() { + continue; + } + let top_x = Tensor::new(top_x.as_slice(), xs.device())?; + let selected_rws = + Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())? + .reshape(((), 1))?; + // Index the correct hidden states and compute the expert hidden state for + // the current expert. We need to make sure to multiply the output hidden + // states by `routing_weights` on the corresponding tokens (top-1 and top-2) + let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?; + // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None]) + let current_hidden_states = expert_layer.forward(¤t_state)?; + let current_hidden_states = + current_hidden_states.broadcast_mul(&selected_rws)?; + ys = ys.index_add(&top_x, ¤t_hidden_states, 0)?; + } + + let ys = ys.reshape((b_size, seq_len, hidden_dim))?; + Ok(ys) + } + Self::Mlp(mlp) => mlp.forward(xs), + } + } +} + +#[derive(Debug, Clone)] +struct LayerWeights { + attention_wq: BitQMatMul, + attention_wk: BitQMatMul, + attention_wv: BitQMatMul, + attention_wo: BitQMatMul, + attention_norm: RmsNorm, + mlp_or_moe: MlpOrMoe, + ffn_norm: RmsNorm, + n_head: usize, + n_kv_head: usize, + head_dim: usize, + cos: Tensor, + sin: Tensor, + neg_inf: Tensor, + kv_cache: Option<(Tensor, Tensor)>, + span_attn: tracing::Span, + span_rot: tracing::Span, + span_mlp: tracing::Span, +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result { + let shape = mask.shape(); + let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?; + Ok(m) +} + +impl LayerWeights { + fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result { + let _enter = self.span_rot.enter(); + let (_b_sz, _n_head, seq_len, _n_embd) = x.dims4()?; + let cos = self.cos.narrow(0, index_pos, seq_len)?; + let sin = self.sin.narrow(0, index_pos, seq_len)?; + // The call to contiguous below is only necessary when processing the prompt. + // When the seq_len is 1 in the inference loop, this is a no-op. + candle_nn::rotary_emb::rope_i(&x.contiguous()?, &cos, &sin) + } + + fn forward_attn( + &mut self, + x: &Tensor, + mask: Option<&Tensor>, + index_pos: usize, + ) -> Result { + let _enter = self.span_attn.enter(); + let (b_sz, seq_len, n_embd) = x.dims3()?; + let q = self.attention_wq.forward(x)?; + let k = self.attention_wk.forward(x)?; + let v = self.attention_wv.forward(x)?; + + let q = q + .reshape((b_sz, seq_len, self.n_head, self.head_dim))? 
+ .transpose(1, 2)?; + let k = k + .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? + .transpose(1, 2)?; + let v = v + .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? + .transpose(1, 2)? + // This call to contiguous ensures that the fast kernel can be called below. It's + // actually a no-op except when processing the initial prompt so has no significant + // impact on performance. + .contiguous()?; + + let q = self.apply_rotary_emb(&q, index_pos)?; + let k = self.apply_rotary_emb(&k, index_pos)?; + + let (k, v) = match &self.kv_cache { + None => (k, v), + Some((k_cache, v_cache)) => { + if index_pos == 0 { + (k, v) + } else { + let k = Tensor::cat(&[k_cache, &k], 2)?; + let v = Tensor::cat(&[v_cache, &v], 2)?; + (k, v) + } + } + }; + self.kv_cache = Some((k.clone(), v.clone())); + + let y = if q.device().is_metal() && seq_len == 1 { + // SDPA will do MQA for us + candle_nn::ops::sdpa(&q, &k, &v, 1. / (self.head_dim as f32).sqrt(), 1.)? + } else { + // Support for MQA, useful for 70B models and mistral. + let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?; + let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?; + + let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; + let att = match mask { + None => att, + Some(mask) => { + let mask = mask.broadcast_as(att.shape())?; + masked_fill(&att, &mask, &self.neg_inf)? + } + }; + let att = candle_nn::ops::softmax_last_dim(&att)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + att.matmul(&v.contiguous()?)? + }; + + let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; + let y = self.attention_wo.forward(&y)?; + Ok(y) + } +} + +#[derive(Debug, Clone)] +pub struct ModelWeights { + tok_embeddings: Embedding, + layers: Vec, + norm: RmsNorm, + output: QMatMul, + masks: HashMap, + span: tracing::Span, + span_output: tracing::Span, +} + +fn precomput_freqs_cis( + head_dim: usize, + freq_base: f32, + device: &Device, +) -> Result<(Tensor, Tensor)> { + let theta: Vec<_> = (0..head_dim) + .step_by(2) + .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32)) + .collect(); + let theta = Tensor::new(theta.as_slice(), device)?; + let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)? + .to_dtype(DType::F32)? + .reshape((MAX_SEQ_LEN, 1))? 
+ .matmul(&theta.reshape((1, theta.elem_count()))?)?; + let cos = idx_theta.cos()?; + let sin = idx_theta.sin()?; + Ok((cos, sin)) +} + +impl ModelWeights { + pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result { + let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; + let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?; + let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?; + let tok_embeddings = ct.remove("tok_embeddings.weight")?; + let tok_embeddings = tok_embeddings.dequantize(&ct.device)?; + let norm = RmsNorm::from_qtensor(ct.remove("norm.weight")?, 1e-5)?; + let output = ct.remove("output.weight")?; + let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize); + for layer_idx in 0..ct.hparams.n_layer { + let prefix = format!("layers.{layer_idx}"); + let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?; + let attention_wq_ws = ct.remove(&format!("{prefix}.attention.wq.weight_scale"))?; + let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?; + let attention_wk_ws = ct.remove(&format!("{prefix}.attention.wk.weight_scale"))?; + let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?; + let attention_wv_ws = ct.remove(&format!("{prefix}.attention.wv.weight_scale"))?; + let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?; + let attention_wo_ws = ct.remove(&format!("{prefix}.attention.wo.weight_scale"))?; + let mlp_or_moe = { + let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?; + let feed_forward_w1_ws = + ct.remove(&format!("{prefix}.feed_forward.w1.weight_scale"))?; + let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?; + let feed_forward_w2_ws = + ct.remove(&format!("{prefix}.feed_forward.w2.weight_scale"))?; + let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?; + let feed_forward_w3_ws = + ct.remove(&format!("{prefix}.feed_forward.w3.weight_scale"))?; + MlpOrMoe::Mlp(Mlp { + feed_forward_w1: BitQMatMul::from_qtensor(feed_forward_w1, feed_forward_w1_ws)?, + feed_forward_w2: BitQMatMul::from_qtensor(feed_forward_w2, feed_forward_w2_ws)?, + feed_forward_w3: BitQMatMul::from_qtensor(feed_forward_w3, feed_forward_w3_ws)?, + }) + }; + let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?; + let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?; + let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); + let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); + let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); + layers.push(LayerWeights { + attention_wq: BitQMatMul::from_qtensor(attention_wq, attention_wq_ws)?, + attention_wk: BitQMatMul::from_qtensor(attention_wk, attention_wk_ws)?, + attention_wv: BitQMatMul::from_qtensor(attention_wv, attention_wv_ws)?, + attention_wo: BitQMatMul::from_qtensor(attention_wo, attention_wo_ws)?, + attention_norm: RmsNorm::from_qtensor(attention_norm, 1e-5)?, + mlp_or_moe, + ffn_norm: RmsNorm::from_qtensor(ffn_norm, 1e-5)?, + n_head: ct.hparams.n_head as usize, + n_kv_head: ct.hparams.n_head as usize / gqa, + head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize, + cos: cos.clone(), + sin: sin.clone(), + neg_inf: neg_inf.clone(), + kv_cache: None, + span_attn, + span_rot, + span_mlp, + }) + } + let span = tracing::span!(tracing::Level::TRACE, "model"); + let span_output = tracing::span!(tracing::Level::TRACE, "output"); + Ok(Self { + tok_embeddings: Embedding::new(tok_embeddings, 
ct.hparams.n_embd as usize), + layers, + norm, + output: QMatMul::from_qtensor(output)?, + masks: HashMap::new(), + span, + span_output, + }) + } + + pub fn from_gguf( + ct: gguf_file::Content, + reader: &mut R, + device: &Device, + ) -> Result { + let md_get = |s: &str| match ct.metadata.get(s) { + None => candle::bail!("cannot find {s} in metadata"), + Some(v) => Ok(v), + }; + + // Parameter extraction from metadata. + let n_expert = md_get("llama.expert_count") + .and_then(|v| v.to_u32()) + .unwrap_or(0) as usize; + let n_expert_used = md_get("llama.expert_used_count") + .and_then(|v| v.to_u32()) + .unwrap_or(0) as usize; + let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize; + let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize; + let block_count = md_get("llama.block_count")?.to_u32()? as usize; + let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize; + let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize; + // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default. + let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? as f64; + + let rope_freq_base = md_get("llama.rope.freq_base") + .and_then(|m| m.to_f32()) + .unwrap_or(10000f32); + let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?; + let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?; + + let tok_embeddings_q = ct.tensor(reader, "token_embd.weight", device)?; + let tok_embeddings = tok_embeddings_q.dequantize(device)?; + let norm = RmsNorm::from_qtensor( + ct.tensor(reader, "output_norm.weight", device)?, + rms_norm_eps, + )?; + let output = match ct.tensor(reader, "output.weight", device) { + Ok(tensor) => tensor, + Err(_) => tok_embeddings_q, + }; + let mut layers = Vec::with_capacity(block_count); + for layer_idx in 0..block_count { + let prefix = format!("blk.{layer_idx}"); + let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?; + let attention_wq_ws = + ct.tensor(reader, &format!("{prefix}.attn_q.weight_scale"), device)?; + let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?; + let attention_wk_ws = + ct.tensor(reader, &format!("{prefix}.attn_k.weight_scale"), device)?; + let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?; + let attention_wv_ws = + ct.tensor(reader, &format!("{prefix}.attn_v.weight_scale"), device)?; + let attention_wo = + ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?; + let attention_wo_ws = ct.tensor( + reader, + &format!("{prefix}.attn_output.weight_scale"), + device, + )?; + let mlp_or_moe = if n_expert <= 1 { + let feed_forward_w1 = + ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?; + let feed_forward_w1_ws = + ct.tensor(reader, &format!("{prefix}.ffn_gate.weight_scale"), device)?; + let feed_forward_w2 = + ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?; + let feed_forward_w2_ws = + ct.tensor(reader, &format!("{prefix}.ffn_down.weight_scale"), device)?; + let feed_forward_w3 = + ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?; + let feed_forward_w3_ws = + ct.tensor(reader, &format!("{prefix}.ffn_up.weight_scale"), device)?; + MlpOrMoe::Mlp(Mlp { + feed_forward_w1: BitQMatMul::from_qtensor(feed_forward_w1, feed_forward_w1_ws)?, + feed_forward_w2: BitQMatMul::from_qtensor(feed_forward_w2, feed_forward_w2_ws)?, + feed_forward_w3: BitQMatMul::from_qtensor(feed_forward_w3, 
+                    feed_forward_w3: BitQMatMul::from_qtensor(feed_forward_w3, feed_forward_w3_ws)?,
+                })
+            } else {
+                let feed_forward_gate_inp =
+                    ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?;
+                let mut experts = Vec::with_capacity(n_expert);
+                for i in 0..n_expert {
+                    let feed_forward_w1 =
+                        ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?;
+                    let feed_forward_w1_ws = ct.tensor(
+                        reader,
+                        &format!("{prefix}.ffn_gate.{i}.weight_scale"),
+                        device,
+                    )?;
+                    let feed_forward_w2 =
+                        ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
+                    let feed_forward_w2_ws = ct.tensor(
+                        reader,
+                        &format!("{prefix}.ffn_down.{i}.weight_scale"),
+                        device,
+                    )?;
+                    let feed_forward_w3 =
+                        ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
+                    let feed_forward_w3_ws =
+                        ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight_scale"), device)?;
+
+                    experts.push(Mlp {
+                        feed_forward_w1: BitQMatMul::from_qtensor(
+                            feed_forward_w1,
+                            feed_forward_w1_ws,
+                        )?,
+                        feed_forward_w2: BitQMatMul::from_qtensor(
+                            feed_forward_w2,
+                            feed_forward_w2_ws,
+                        )?,
+                        feed_forward_w3: BitQMatMul::from_qtensor(
+                            feed_forward_w3,
+                            feed_forward_w3_ws,
+                        )?,
+                    })
+                }
+                MlpOrMoe::MoE {
+                    n_expert_used,
+                    feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
+                    experts,
+                }
+            };
+            let attention_norm =
+                ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
+            let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
+            let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
+            let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
+            let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
+            layers.push(LayerWeights {
+                attention_wq: BitQMatMul::from_qtensor(attention_wq, attention_wq_ws)?,
+                attention_wk: BitQMatMul::from_qtensor(attention_wk, attention_wk_ws)?,
+                attention_wv: BitQMatMul::from_qtensor(attention_wv, attention_wv_ws)?,
+                attention_wo: BitQMatMul::from_qtensor(attention_wo, attention_wo_ws)?,
+                attention_norm: RmsNorm::from_qtensor(attention_norm, rms_norm_eps)?,
+                mlp_or_moe,
+                ffn_norm: RmsNorm::from_qtensor(ffn_norm, rms_norm_eps)?,
+                n_head: head_count,
+                n_kv_head: head_count_kv,
+                head_dim: embedding_length / head_count,
+                cos: cos.clone(),
+                sin: sin.clone(),
+                neg_inf: neg_inf.clone(),
+                kv_cache: None,
+                span_attn,
+                span_rot,
+                span_mlp,
+            })
+        }
+        let span = tracing::span!(tracing::Level::TRACE, "model");
+        let span_output = tracing::span!(tracing::Level::TRACE, "output");
+        Ok(Self {
+            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
+            layers,
+            norm,
+            output: QMatMul::from_qtensor(output)?,
+            masks: HashMap::new(),
+            span,
+            span_output,
+        })
+    }
+
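+    /// Return a cached causal mask of shape `(t, t)` where entry `(i, j)` is 1
+    /// whenever key position `j` lies after query position `i`.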
+    fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
+        if let Some(mask) = self.masks.get(&t) {
+            Ok(mask.clone())
+        } else {
+            let mask: Vec<_> = (0..t)
+                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
+                .collect();
+            let mask = Tensor::from_slice(&mask, (t, t), device)?;
+            self.masks.insert(t, mask.clone());
+            Ok(mask)
+        }
+    }
+
+    pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+        let (_b_sz, seq_len) = x.dims2()?;
+        let mask = if seq_len == 1 {
+            None
+        } else {
+            Some(self.mask(seq_len, x.device())?)
+        };
+        let _enter = self.span.enter();
+        let mut layer_in = self.tok_embeddings.forward(x)?;
+        for layer in self.layers.iter_mut() {
+            let x = layer_in;
+            let residual = &x;
+            let x = layer.attention_norm.forward(&x)?;
+            let attn = layer.forward_attn(&x, mask.as_ref(), index_pos)?;
+            let x = (attn + residual)?;
+
+            // MLP
+            let _enter = layer.span_mlp.enter();
+            let residual = &x;
+            let x = layer.ffn_norm.forward(&x)?;
+            let x = layer.mlp_or_moe.forward(&x)?;
+            let x = (x + residual)?;
+            layer_in = x
+        }
+        let x = self.norm.forward(&layer_in)?;
+        // Only the last position is needed to predict the next token.
+        let x = x.i((.., seq_len - 1, ..))?;
+        let _enter = self.span_output.enter();
+        self.output.forward(&x)
+    }
+}
diff --git a/tensor-tools/Cargo.toml b/tensor-tools/Cargo.toml
index eecd7e4353..b48e81cf1d 100644
--- a/tensor-tools/Cargo.toml
+++ b/tensor-tools/Cargo.toml
@@ -14,3 +14,4 @@ candle = { workspace = true }
 clap = { workspace = true }
 rayon = { workspace = true }
 safetensors = { workspace = true }
+serde_json = { workspace = true }
\ No newline at end of file
diff --git a/tensor-tools/src/main.rs b/tensor-tools/src/main.rs
index 0bda36d524..78ba68840c 100644
--- a/tensor-tools/src/main.rs
+++ b/tensor-tools/src/main.rs
@@ -1,7 +1,10 @@
+use candle::op::Op;
 use candle::quantized::{gguf_file, GgmlDType, QTensor};
-use candle::{Device, Result};
+use candle::{Device, Result, Tensor};
 use clap::{Parser, Subcommand, ValueEnum};
 use rayon::prelude::*;
+use safetensors::tensor;
+use serde_json;
 
 #[derive(ValueEnum, Debug, Clone)]
 enum QuantizationMode {
@@ -11,7 +14,13 @@ enum QuantizationMode {
 }
 
 impl QuantizationMode {
-    fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType) -> Result<QTensor> {
+    fn quantize(
+        &self,
+        name: &str,
+        tensor: QTensor,
+        dtype: GgmlDType,
+        bitnet_mode: bool,
+    ) -> Result<QTensor> {
         match self {
             Self::Llama => {
                 // Same behavior as the llama.cpp quantization.
@@ -45,6 +54,12 @@ enum Quantization {
     Q8_0,
     #[value(name = "q8_1")]
     Q8_1,
+    #[value(name = "q2b0")]
+    Q2b0,
+    #[value(name = "q2b1")]
+    Q2b1,
+    #[value(name = "qi8")]
+    QI8,
     Q2k,
     Q3k,
     Q4k,
@@ -72,6 +87,9 @@ impl Quantization {
             Quantization::Q8k => GgmlDType::Q8K,
             Quantization::F16 => GgmlDType::F16,
             Quantization::F32 => GgmlDType::F32,
+            Quantization::Q2b0 => GgmlDType::Q2b0,
+            Quantization::QI8 => GgmlDType::QI8,
+            Quantization::Q2b1 => GgmlDType::Q2b1,
         }
     }
 }
@@ -143,6 +161,13 @@ enum Command {
         #[arg(long)]
         out_file: std::path::PathBuf,
 
+        #[arg(long, short, action)]
+        bitnet_mode: bool,
+
+        // The BitNet quantization to use for ternary weights when bitnet_mode is enabled.
+        #[arg(long, value_enum)]
+        bitnet_quantization: Option<Quantization>,
+
         /// The quantization schema to apply.
         #[arg(long, value_enum)]
         quantization: Quantization,
@@ -285,10 +310,15 @@ fn run_print(
             println!("==== {name} ====");
             match content.tensor(&mut file, name, device) {
                 Ok(tensor) => {
+                    let dtype = tensor.dtype();
+
                     let tensor = tensor.dequantize(device)?;
-                    println!("{tensor}")
+                    println!("{tensor} {dtype:?}")
+                }
+                Err(e) => {
+                    eprintln!("error: {e}");
+                    println!("not found")
                 }
-                Err(_) => println!("not found"),
             }
         }
     }
@@ -395,40 +425,226 @@ fn run_ls(
     Ok(())
 }
 
+/// Unpack BitNet ternary weights: every `u8` stores four 2-bit values, and each
+/// value `v` in `{0, 1, 2}` maps to `v - 1`, i.e. `{-1.0, 0.0, 1.0}`. The k-th
+/// 2-bit slice of row `i` lands in output row `k * rows + i`, so the result has
+/// shape `(rows * 4, cols)`.
+fn unpack_bitnet_weights(tensor: &Tensor) -> Result<Tensor> {
+    let packed_vec = tensor.to_vec2::<u8>()?;
+
+    let rows = tensor.dim(0)?;
+    let cols = tensor.dim(1)?;
+
+    let mut unpacked_vec = vec![0f32; rows * 4 * cols];
+
+    for i in 0..rows {
+        for j in 0..cols {
+            let packed = packed_vec[i][j];
+
+            for k in 0..4 {
+                let bits = ((packed >> (k * 2)) & 0b11) as i8 - 1;
+                let index = (k * rows + i) * cols + j;
+                unpacked_vec[index] = bits as f32;
+            }
+        }
+    }
+
+    let unpacked_tensor = Tensor::from_vec(unpacked_vec, (rows * 4, cols), tensor.device())?;
+    Ok(unpacked_tensor)
+}
+
+use core::num;
+use rayon::prelude::*;
+use serde_json::Value;
+use std::collections::HashMap;
+use std::fs::File;
+use std::path::PathBuf;
+
+/// Reorder the rows of an attention Q/K projection by splitting each head's rows
+/// into two halves and interleaving them (the same permutation applied by
+/// llama.cpp's conversion scripts when going from Hugging Face to GGUF layout).
+fn permute(weights: &Tensor, n_head: usize, n_head_kv: Option<usize>) -> Result<Tensor> {
+    let n_head = match n_head_kv {
+        Some(n_head_kv) if n_head != n_head_kv => n_head_kv,
+        _ => n_head,
+    };
+
+    let shape = weights.shape();
+    let shape0 = shape.dims()[0];
+    if shape0 % (n_head * 2) != 0 {
+        candle::bail!("weights.shape()[0] is not divisible by (n_head * 2)");
+    }
+
+    let mut new_shape = vec![n_head, 2, shape0 / (n_head * 2)];
+    new_shape.extend_from_slice(&shape.dims()[1..]);
+
+    let permuted = weights
+        .reshape(new_shape)?
+        .transpose(1, 2)?
+        .reshape(weights.shape())?;
+
+    Ok(permuted)
+}
+
 fn run_quantize_safetensors(
-    in_files: &[std::path::PathBuf],
-    out_file: std::path::PathBuf,
+    in_files: &[PathBuf],
+    out_file: PathBuf,
     q: Quantization,
+    bq: Option<Quantization>,
+    bitnet_mode: bool,
 ) -> Result<()> {
-    let mut out_file = std::fs::File::create(out_file)?;
-    let mut tensors = std::collections::HashMap::new();
-    for in_file in in_files.iter() {
-        let in_tensors = candle::safetensors::load(in_file, &Device::Cpu)?;
-        tensors.extend(in_tensors)
-    }
-    println!("tensors: {}", tensors.len());
-
+    let mut out_file = File::create(out_file)?;
     let dtype = q.dtype();
     let block_size = dtype.block_size();
-    let qtensors = tensors
-        .into_par_iter()
-        .map(|(name, tensor)| {
-            let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
-            println!("  quantizing {name} {tensor:?} {should_quantize}");
-            let tensor = if should_quantize {
-                QTensor::quantize(&tensor, dtype)?
-            } else {
-                QTensor::quantize(&tensor, GgmlDType::F32)?
-            };
-            Ok((name, tensor))
-        })
-        .collect::<Result<Vec<_>>>()?;
+    let metadata_file = in_files
+        .iter()
+        .find(|f| f.to_string_lossy().ends_with("config.json"));
+
+    let mut qtensors = Vec::new();
+
+    let mut num_attention_heads = 0;
+    let mut num_key_value_heads = 0;
+    let mut architecture = String::new();
+
+    let gguf_metadata = if let Some(metadata_file) = metadata_file {
+        let metadata_content = std::fs::read_to_string(metadata_file)?;
+        let metadata: serde_json::Value = serde_json::from_str(&metadata_content).unwrap();
+
+        num_attention_heads = metadata["num_attention_heads"].as_u64().unwrap();
+        num_key_value_heads = metadata["num_key_value_heads"].as_u64().unwrap();
+        architecture = metadata["model_type"].as_str().unwrap().to_string();
+
+        vec![
+            (
+                "llama.attention.head_count",
+                gguf_file::Value::from_u32(num_attention_heads as u32),
+            ),
+            (
+                "llama.attention.head_count_kv",
+                gguf_file::Value::from_u32(num_key_value_heads as u32),
+            ),
+            (
+                "llama.block_count",
+                gguf_file::Value::from_u32(metadata["num_hidden_layers"].as_u64().unwrap() as u32),
+            ),
+            (
+                "llama.embedding_length",
+                gguf_file::Value::from_u32(metadata["hidden_size"].as_u64().unwrap() as u32),
+            ),
+            (
+                "llama.attention.layer_norm_rms_epsilon",
+                gguf_file::Value::from_f32(metadata["rms_norm_eps"].as_f64().unwrap() as f32),
+            ),
+            (
+                "llama.rope.dimension_count",
+                gguf_file::Value::from_u32(
+                    (metadata["hidden_size"].as_u64().unwrap() as u32)
+                        / (metadata["num_attention_heads"].as_u64().unwrap() as u32),
+                ),
+            ),
+            (
+                "llama.rope.freq_base",
+                gguf_file::Value::from_f32(metadata["rope_theta"].as_f64().unwrap() as f32),
+            ),
+            (
+                "general.architecture",
+                gguf_file::Value::from_string(architecture.clone()),
+            ),
+        ]
+    } else {
+        vec![]
+    };
+
+    for in_file in in_files {
+        // Skip the config.json entry: it only provides metadata, not tensors.
+        if Some(in_file) == metadata_file {
+            continue;
+        }
+
+        println!("Loading tensors from file: {:?}", in_file);
+        let in_tensors = candle::safetensors::load(in_file, &Device::Cpu)?;
+
+        let processed_tensors = in_tensors
+            .into_par_iter()
+            .map(|(mut name, tensor)| {
+                let mut local_dtype = dtype.clone();
+                let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
+                let mut tensor = tensor;
+
+                if should_quantize && bitnet_mode {
+                    let is_bitnet_weight = name.contains("self_attn.v_proj")
+                        || name.contains("self_attn.q_proj")
+                        || name.contains("self_attn.o_proj")
+                        || name.contains("self_attn.k_proj")
+                        || name.contains("mlp.down_proj")
+                        || name.contains("mlp.up_proj")
+                        || name.contains("mlp.gate_proj");
+
+                    if is_bitnet_weight {
+                        println!("  unpacking {name} {tensor:?} {should_quantize}");
+                        tensor = unpack_bitnet_weights(&tensor)?;
+                        local_dtype = bq.clone().unwrap().dtype();
+                    }
+                }
+
+                if name == "lm_head.weight" {
+                    local_dtype = GgmlDType::Q6K;
+                }
+
+                // Apply architecture-specific transformations to the tensors.
+                match architecture.as_str() {
+                    "llama" => {
+                        if name.ends_with("self_attn.q_proj.weight") {
+                            tensor = permute(
+                                &tensor,
+                                num_attention_heads as usize,
+                                Some(num_attention_heads as usize),
+                            )?;
+                        }
+                        if name.ends_with("self_attn.k_proj.weight") {
+                            tensor = permute(
+                                &tensor,
+                                num_attention_heads as usize,
+                                Some(num_key_value_heads as usize),
+                            )?;
+                        }
+                    }
+                    _ => {}
+                }
+
+                println!("  quantizing {name} {tensor:?} {should_quantize}");
+                let tensor = if should_quantize {
+                    QTensor::quantize(&tensor, local_dtype)?
+                } else {
+                    QTensor::quantize(&tensor, GgmlDType::F32)?
+                };
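+
+                // Map the Hugging Face tensor names onto the GGUF llama naming scheme used by the loaders.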
+                if name == "model.embed_tokens.weight" {
+                    name = "token_embd.weight".to_string();
+                } else if name == "model.norm.weight" {
+                    name = "output_norm.weight".to_string();
+                } else if name == "lm_head.weight" {
+                    name = "output.weight".to_string();
+                }
+
+                name = name.replace("model.layers.", "blk.");
+                name = name.replace("self_attn.q_proj", "attn_q");
+                name = name.replace("self_attn.k_proj", "attn_k");
+                name = name.replace("self_attn.v_proj", "attn_v");
+                name = name.replace("self_attn.o_proj", "attn_output");
+                name = name.replace("mlp.gate_proj", "ffn_gate");
+                name = name.replace("mlp.down_proj", "ffn_down");
+                name = name.replace("mlp.up_proj", "ffn_up");
+                name = name.replace("input_layernorm", "attn_norm");
+                name = name.replace("post_attention_layernorm", "ffn_norm");
+
+                Ok((name, tensor))
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        qtensors.extend(processed_tensors);
+    }
+
     let qtensors = qtensors
         .iter()
         .map(|(k, v)| (k.as_str(), v))
         .collect::<Vec<_>>();
-    gguf_file::write(&mut out_file, &[], &qtensors)?;
+
+    gguf_file::write(&mut out_file, &gguf_metadata, &qtensors)?;
     Ok(())
 }
@@ -454,11 +670,16 @@ fn run_quantize(
     out_file: std::path::PathBuf,
     q: Quantization,
     qmode: QuantizationMode,
+    bq: Option<Quantization>,
+    bitnet_mode: bool,
     device: &Device,
 ) -> Result<()> {
     if in_files.is_empty() {
         candle::bail!("no specified input files")
     }
+    if bitnet_mode && bq.is_none() {
+        candle::bail!("bitnet mode requires a bitnet quantization")
+    }
     if let Some(extension) = out_file.extension() {
         if extension == "safetensors" {
             candle::bail!("the generated file cannot use the safetensors extension")
@@ -466,7 +687,7 @@ fn run_quantize(
         }
     }
     if let Some(extension) = in_files[0].extension() {
         if extension == "safetensors" {
-            return run_quantize_safetensors(in_files, out_file, q);
+            return run_quantize_safetensors(in_files, out_file, q, bq, bitnet_mode);
         }
     }
@@ -488,7 +709,7 @@ fn run_quantize(
             println!("  quantizing {name}");
             let mut in_file = std::fs::File::open(&in_files[0])?;
             let tensor = content.tensor(&mut in_file, name, device)?;
-            let tensor = qmode.quantize(name, tensor, dtype)?;
+            let tensor = qmode.quantize(name, tensor, dtype, bitnet_mode)?;
             Ok((name, tensor))
         })
         .collect::<Result<Vec<_>>>()?;
@@ -500,7 +721,7 @@ fn run_quantize(
     let metadata = content
         .metadata
         .iter()
-        .map(|(k, v)| (k.as_str(), v))
+        .map(|(k, v)| (k.as_str(), v.clone()))
        .collect::<Vec<_>>();
     gguf_file::write(&mut out_file, metadata.as_slice(), &qtensors)?;
     Ok(())
 }
@@ -534,8 +755,18 @@ fn main() -> anyhow::Result<()> {
             in_file,
             out_file,
             quantization,
+            bitnet_quantization,
+            mode,
+            bitnet_mode,
+        } => run_quantize(
+            &in_file,
+            out_file,
+            quantization,
             mode,
-        } => run_quantize(&in_file, out_file, quantization, mode, &device)?,
+            bitnet_quantization,
+            bitnet_mode,
+            &device,
+        )?,
         Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?,
     }
     Ok(())