diff --git a/rust/src/model/llama.rs b/rust/src/model/llama.rs
index d8f5fa5..d3fc0e6 100644
--- a/rust/src/model/llama.rs
+++ b/rust/src/model/llama.rs
@@ -1,7 +1,7 @@
 use mlx_macros::ModuleParameters as DeriveModuleParameters;
 use mlx_rs::builder::Builder;
 use mlx_rs::error::Exception;
-use mlx_rs::module::Module;
+use mlx_rs::module::{Module, ModuleParameters};
 use mlx_rs::nn::{Embedding, Linear, RmsNorm, Rope, RopeBuilder};
 use mlx_rs::Array;
 use serde::{Deserialize, Serialize};
@@ -549,39 +549,57 @@ fn sample_categorical(probs: &[f32]) -> i32 {
 /// Loads pre-trained weights into a LlamaForCausalLM model.
 /// This function maps safetensors weight names to model parameters.
 pub fn load_weights_into_model(
-    _model: &mut LlamaForCausalLM,
+    model: &mut LlamaForCausalLM,
     weights: HashMap<String, Array>,
 ) -> anyhow::Result<()> {
     println!("Loading {} weight tensors into model...", weights.len());
 
-    let _loaded_count = 0;
-    let missing_keys: Vec<String> = Vec::new();
-
-    // TODO: Weight Loading API - Needs mlx-rs parameter setting documentation
-    // The model derives ModuleParameters (via #[derive(ModuleParameters)] and #[param] attributes),
-    // which provides access to parameters via model.parameters() returning a NestedHashMap.
-    //
-    // To load weights, we need to:
-    // 1. Iterate over model.parameters() to get parameter names and references
-    // 2. Match safetensors keys to parameter names (handling name mapping)
-    // 3. Set parameter values using the appropriate mlx-rs API
-    //
-    // Expected pattern (needs mlx-rs API confirmation):
-    // for (name, param) in model.parameters().flatten() {
-    //     if let Some(weight_array) = weights.get(&name) {
-    //         param.set_value(weight_array)?; // or similar API
-    //         loaded_count += 1;
-    //     } else {
-    //         missing_keys.push(name.clone());
-    //     }
-    // }
-    //
-    // For now, report weights loaded from file without setting them.
-
-    println!("Loaded {} weight tensors from safetensors", weights.len());
-    println!("Weight loading into model structure needs mlx-rs parameter update API");
-
-    let loaded_count = weights.len();
+    let mut loaded_count = 0;
+    let mut missing_keys: Vec<String> = Vec::new();
+    let mut extra_keys: Vec<String> = Vec::new();
+
+    // Get mutable access to model parameters
+    let mut parameters = model.parameters_mut().flatten();
+
+    // Load weights from safetensors into model parameters
+    for (param_name, param) in parameters.iter_mut() {
+        let param_name_str = param_name.to_string();
+
+        if let Some(weight_array) = weights.get(&param_name_str) {
+            // Verify shape matches
+            if weight_array.shape() != param.shape() {
+                eprintln!(
+                    "Warning: Shape mismatch for {}: expected {:?}, got {:?}",
+                    param_name_str,
+                    param.shape(),
+                    weight_array.shape()
+                );
+                missing_keys.push(param_name_str);
+                continue;
+            }
+
+            // Set the parameter value using double dereference
+            // This is the same pattern used in trainer.rs for parameter updates
+            **param = weight_array.clone();
+            let _ = param.eval(); // Materialize on GPU
+            loaded_count += 1;
+        } else {
+            missing_keys.push(param_name_str);
+        }
+    }
+
+    // Find extra keys in weights that don't match any model parameters
+    for weight_key in weights.keys() {
+        if !parameters.contains_key(weight_key.as_str()) {
+            extra_keys.push(weight_key.clone());
+        }
+    }
+
+    println!(
+        "Successfully loaded {} / {} weight tensors into model",
+        loaded_count,
+        parameters.len()
+    );
 
     if !missing_keys.is_empty() && missing_keys.len() < 10 {
         println!(
@@ -590,6 +608,13 @@ pub fn load_weights_into_model(
         );
     }
 
+    if !extra_keys.is_empty() && extra_keys.len() < 10 {
+        println!(
+            "Extra keys in safetensors (first 10): {:?}",
+            &extra_keys[..extra_keys.len().min(10)]
+        );
+    }
+
     if loaded_count == 0 {
         anyhow::bail!(
             "Failed to load any weights - parameter names may not match safetensors keys"
diff --git a/rust/src/model/loader.rs b/rust/src/model/loader.rs
index 99481d0..bc585a8 100644
--- a/rust/src/model/loader.rs
+++ b/rust/src/model/loader.rs
@@ -283,7 +283,7 @@ impl ModelLoader {
     fn load_lora_target_layers(&self, path: &Path) -> anyhow::Result<Vec<String>> {
         // Initialize MLX by creating a small test array to ensure Metal backend is ready
         let _init_test = mlx_rs::ops::zeros::<f32>(&[1_i32])?;
-        
+
         let data = std::fs::read(path)?;
         let tensors = SafeTensors::deserialize(&data)?;
 
@@ -315,7 +315,10 @@ impl ModelLoader {
         let estimated_mb = (total_elements * element_bytes) / (1024 * 1024);
 
         // Log every tensor we're about to load
-        print!(" Loading '{}' ({:?}, {} MB)... ", name, shape, estimated_mb);
+        print!(
+            " Loading '{}' ({:?}, {} MB)... ",
+            name, shape, estimated_mb
+        );
         std::io::stdout().flush().ok();
 
         if estimated_mb > 500 {
diff --git a/rust/src/training/trainer.rs b/rust/src/training/trainer.rs
index c8f200e..2a6c598 100644
--- a/rust/src/training/trainer.rs
+++ b/rust/src/training/trainer.rs
@@ -30,7 +30,7 @@ pub struct DistrustTrainer {
    adam_step: usize, // Step counter for bias correction
    // Gradient accumulation state
    accumulated_gradients: std::collections::HashMap<String, (Vec<f32>, Vec<i32>)>, // Accumulated gradients
-    accumulation_step: usize,      // Current micro-step in accumulation
+    accumulation_step: usize, // Current micro-step in accumulation
    dataset: Option,
    global_step: usize,
    loss_history: Vec<f32>,
@@ -130,29 +130,23 @@ impl DistrustTrainer {
             llama_config.num_attention_heads
         );
 
-        // TEMPORARY: Skip weight loading due to MLX/Metal stability issues
-        // Using random initialization for testing training loop optimizations
-        println!("Using random initialization (weight loading disabled for testing)");
-        let model = LlamaForCausalLM::new(llama_config)?;
-
-        // TODO: Re-enable weight loading once MLX stability issues are resolved
-        // let loader = ModelLoader::new(&config.paths.model_path);
-        // let weights = loader.load_safetensors().unwrap_or_else(|e| {
-        //     println!("Warning: Could not load weights from safetensors: {}", e);
-        //     println!("Model will use random initialization");
-        //     std::collections::HashMap::new()
-        // });
-        //
-        // let model = if !weights.is_empty() {
-        //     println!(
-        //         "Loading model with {} pre-trained weight tensors",
-        //         weights.len()
-        //     );
-        //     crate::model::llama::load_model_with_weights(llama_config, weights)?
-        // } else {
-        //     println!("Initializing model with random weights");
-        //     LlamaForCausalLM::new(llama_config)?
-        // };
+        let loader = ModelLoader::new(&config.paths.model_path);
+        let weights = loader.load_safetensors().unwrap_or_else(|e| {
+            println!("Warning: Could not load weights from safetensors: {}", e);
+            println!("Model will use random initialization");
+            std::collections::HashMap::new()
+        });
+
+        let model = if !weights.is_empty() {
+            println!(
+                "Loading model with {} pre-trained weight tensors",
+                weights.len()
+            );
+            crate::model::llama::load_model_with_weights(llama_config, weights)?
+        } else {
+            println!("Initializing model with random weights");
+            LlamaForCausalLM::new(llama_config)?
+        };
 
         // Load tokenizer
         let tokenizer_path = model_dir.join("tokenizer.json");
@@ -860,7 +854,8 @@ impl DistrustTrainer {
             }
         } else {
             // First accumulation - initialize
-            self.accumulated_gradients.insert(param_name_str, (grad_data, grad_shape));
+            self.accumulated_gradients
+                .insert(param_name_str, (grad_data, grad_shape));
         }
     }
 
@@ -906,7 +901,11 @@ impl DistrustTrainer {
         let mut frozen_params = 0usize;
 
         // Get parameter names from accumulated gradients
-        let param_names: Vec<String> = self.accumulated_gradients.keys().map(|k| k.to_string()).collect();
+        let param_names: Vec<String> = self
+            .accumulated_gradients
+            .keys()
+            .map(|k| k.to_string())
+            .collect();
 
         // Scale factor for accumulated gradients
         let grad_scale = 1.0 / grad_accum_steps as f32;
@@ -932,12 +931,13 @@ impl DistrustTrainer {
         }
 
         // Get accumulated gradient and scale it
-        let grad_data: Vec<f32> = if let Some((acc_grad, _)) = self.accumulated_gradients.get(&param_name) {
-            // Scale by 1/N to get average gradient
-            acc_grad.iter().map(|&g| g * grad_scale).collect()
-        } else {
-            continue;
-        };
+        let grad_data: Vec<f32> =
+            if let Some((acc_grad, _)) = self.accumulated_gradients.get(&param_name) {
+                // Scale by 1/N to get average gradient
+                acc_grad.iter().map(|&g| g * grad_scale).collect()
+            } else {
+                continue;
+            };
 
         // Get current parameter value and materialize it
         let (param_data, param_shape): (Vec<f32>, Vec<i32>) = {