Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 55 additions & 30 deletions rust/src/model/llama.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use mlx_macros::ModuleParameters as DeriveModuleParameters;
use mlx_rs::builder::Builder;
use mlx_rs::error::Exception;
use mlx_rs::module::Module;
use mlx_rs::module::{Module, ModuleParameters};
use mlx_rs::nn::{Embedding, Linear, RmsNorm, Rope, RopeBuilder};
use mlx_rs::Array;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -549,39 +549,57 @@ fn sample_categorical(probs: &[f32]) -> i32 {
/// Loads pre-trained weights into a LlamaForCausalLM model.
/// This function maps safetensors weight names to model parameters.
pub fn load_weights_into_model(
_model: &mut LlamaForCausalLM,
model: &mut LlamaForCausalLM,
weights: HashMap<String, Array>,
) -> anyhow::Result<()> {
println!("Loading {} weight tensors into model...", weights.len());

let _loaded_count = 0;
let missing_keys: Vec<String> = Vec::new();

// TODO: Weight Loading API - Needs mlx-rs parameter setting documentation
// The model derives ModuleParameters (via #[derive(ModuleParameters)] and #[param] attributes),
// which provides access to parameters via model.parameters() returning a NestedHashMap.
//
// To load weights, we need to:
// 1. Iterate over model.parameters() to get parameter names and references
// 2. Match safetensors keys to parameter names (handling name mapping)
// 3. Set parameter values using the appropriate mlx-rs API
//
// Expected pattern (needs mlx-rs API confirmation):
// for (name, param) in model.parameters().flatten() {
// if let Some(weight_array) = weights.get(&name) {
// param.set_value(weight_array)?; // or similar API
// loaded_count += 1;
// } else {
// missing_keys.push(name.clone());
// }
// }
//
// For now, report weights loaded from file without setting them.

println!("Loaded {} weight tensors from safetensors", weights.len());
println!("Weight loading into model structure needs mlx-rs parameter update API");

let loaded_count = weights.len();
let mut loaded_count = 0;
let mut missing_keys: Vec<String> = Vec::new();
let mut extra_keys: Vec<String> = Vec::new();

// Get mutable access to model parameters
let mut parameters = model.parameters_mut().flatten();

// Load weights from safetensors into model parameters
for (param_name, param) in parameters.iter_mut() {
let param_name_str = param_name.to_string();

if let Some(weight_array) = weights.get(&param_name_str) {
// Verify shape matches
if weight_array.shape() != param.shape() {
eprintln!(
"Warning: Shape mismatch for {}: expected {:?}, got {:?}",
param_name_str,
param.shape(),
weight_array.shape()
);
missing_keys.push(param_name_str);
continue;
}

// Set the parameter value using double dereference
// This is the same pattern used in trainer.rs for parameter updates
**param = weight_array.clone();
let _ = param.eval(); // Materialize on GPU
loaded_count += 1;
} else {
missing_keys.push(param_name_str);
}
}

// Find extra keys in weights that don't match any model parameters
for weight_key in weights.keys() {
if !parameters.contains_key(weight_key.as_str()) {
extra_keys.push(weight_key.clone());
}
}

println!(
"Successfully loaded {} / {} weight tensors into model",
loaded_count,
parameters.len()
);

if !missing_keys.is_empty() && missing_keys.len() < 10 {
println!(
Expand All @@ -590,6 +608,13 @@ pub fn load_weights_into_model(
);
}

if !extra_keys.is_empty() && extra_keys.len() < 10 {
println!(
"Extra keys in safetensors (first 10): {:?}",
&extra_keys[..extra_keys.len().min(10)]
);
}

if loaded_count == 0 {
anyhow::bail!(
"Failed to load any weights - parameter names may not match safetensors keys"
Expand Down
7 changes: 5 additions & 2 deletions rust/src/model/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ impl ModelLoader {
fn load_lora_target_layers(&self, path: &Path) -> anyhow::Result<HashMap<String, Array>> {
// Initialize MLX by creating a small test array to ensure Metal backend is ready
let _init_test = mlx_rs::ops::zeros::<f32>(&[1_i32])?;

let data = std::fs::read(path)?;
let tensors = SafeTensors::deserialize(&data)?;

Expand Down Expand Up @@ -315,7 +315,10 @@ impl ModelLoader {
let estimated_mb = (total_elements * element_bytes) / (1024 * 1024);

// Log every tensor we're about to load
print!(" Loading '{}' ({:?}, {} MB)... ", name, shape, estimated_mb);
print!(
" Loading '{}' ({:?}, {} MB)... ",
name, shape, estimated_mb
);
std::io::stdout().flush().ok();

if estimated_mb > 500 {
Expand Down
64 changes: 32 additions & 32 deletions rust/src/training/trainer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub struct DistrustTrainer {
adam_step: usize, // Step counter for bias correction
// Gradient accumulation state
accumulated_gradients: std::collections::HashMap<String, OptimizerState>, // Accumulated gradients
accumulation_step: usize, // Current micro-step in accumulation
accumulation_step: usize, // Current micro-step in accumulation
dataset: Option<StreamingDataset>,
global_step: usize,
loss_history: Vec<f32>,
Expand Down Expand Up @@ -130,29 +130,23 @@ impl DistrustTrainer {
llama_config.num_attention_heads
);

// TEMPORARY: Skip weight loading due to MLX/Metal stability issues
// Using random initialization for testing training loop optimizations
println!("Using random initialization (weight loading disabled for testing)");
let model = LlamaForCausalLM::new(llama_config)?;

// TODO: Re-enable weight loading once MLX stability issues are resolved
// let loader = ModelLoader::new(&config.paths.model_path);
// let weights = loader.load_safetensors().unwrap_or_else(|e| {
// println!("Warning: Could not load weights from safetensors: {}", e);
// println!("Model will use random initialization");
// std::collections::HashMap::new()
// });
//
// let model = if !weights.is_empty() {
// println!(
// "Loading model with {} pre-trained weight tensors",
// weights.len()
// );
// crate::model::llama::load_model_with_weights(llama_config, weights)?
// } else {
// println!("Initializing model with random weights");
// LlamaForCausalLM::new(llama_config)?
// };
let loader = ModelLoader::new(&config.paths.model_path);
let weights = loader.load_safetensors().unwrap_or_else(|e| {
println!("Warning: Could not load weights from safetensors: {}", e);
println!("Model will use random initialization");
std::collections::HashMap::new()
});

let model = if !weights.is_empty() {
println!(
"Loading model with {} pre-trained weight tensors",
weights.len()
);
crate::model::llama::load_model_with_weights(llama_config, weights)?
} else {
println!("Initializing model with random weights");
LlamaForCausalLM::new(llama_config)?
};

// Load tokenizer
let tokenizer_path = model_dir.join("tokenizer.json");
Expand Down Expand Up @@ -860,7 +854,8 @@ impl DistrustTrainer {
}
} else {
// First accumulation - initialize
self.accumulated_gradients.insert(param_name_str, (grad_data, grad_shape));
self.accumulated_gradients
.insert(param_name_str, (grad_data, grad_shape));
}
}

Expand Down Expand Up @@ -906,7 +901,11 @@ impl DistrustTrainer {
let mut frozen_params = 0usize;

// Get parameter names from accumulated gradients
let param_names: Vec<String> = self.accumulated_gradients.keys().map(|k| k.to_string()).collect();
let param_names: Vec<String> = self
.accumulated_gradients
.keys()
.map(|k| k.to_string())
.collect();

// Scale factor for accumulated gradients
let grad_scale = 1.0 / grad_accum_steps as f32;
Expand All @@ -932,12 +931,13 @@ impl DistrustTrainer {
}

// Get accumulated gradient and scale it
let grad_data: Vec<f32> = if let Some((acc_grad, _)) = self.accumulated_gradients.get(&param_name) {
// Scale by 1/N to get average gradient
acc_grad.iter().map(|&g| g * grad_scale).collect()
} else {
continue;
};
let grad_data: Vec<f32> =
if let Some((acc_grad, _)) = self.accumulated_gradients.get(&param_name) {
// Scale by 1/N to get average gradient
acc_grad.iter().map(|&g| g * grad_scale).collect()
} else {
continue;
};

// Get current parameter value and materialize it
let (param_data, param_shape): (Vec<f32>, Vec<i32>) = {
Expand Down