Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion rust/src/config/training.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,34 @@
use serde::{Deserialize, Serialize};

/// Training mode determines how gradients are computed
/// Training mode determines how gradients are computed.
///
/// Selected during trainer initialization (see [`TrainingMode::from_lora_rank`])
/// and stored on `TrainingConfig::training_mode`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum TrainingMode {
    /// LoRA: Low-Rank Adaptation — only small adapter matrices of the given
    /// rank are trained; base weights stay frozen.
    LoRA { rank: usize },
    /// FullFineTune: train only the named parameter paths
    /// (e.g. "head.lm_head", "head.norm"); everything else stays frozen.
    FullFineTune { targets: Vec<String> },
    /// Inference only — no training, no parameters receive gradients.
    Frozen,
}

impl TrainingMode {
/// Auto-detect training mode from lora_rank parameter
pub fn from_lora_rank(lora_rank: usize) -> Self {
if lora_rank > 0 {
TrainingMode::LoRA { rank: lora_rank }
} else {
TrainingMode::FullFineTune {
targets: vec!["head.lm_head".to_string(), "head.norm".to_string()],
}
}
}
}

/// Training configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
#[serde(skip)]
pub training_mode: Option<TrainingMode>,
pub batch_size: usize,
pub gradient_accumulation_steps: usize,
pub max_steps: usize,
Expand All @@ -23,13 +49,17 @@ pub struct TrainingConfig {
pub thermal_throttle: f32,
pub alpha: f32, // Distrust loss alpha parameter
pub lambda_weight: f32, // Weight for distrust loss term
// Periodic reload to work around MLX-rs memory leak
pub reload_interval_steps: usize, // Reload model every N steps (0 = disabled)
pub reload_memory_threshold_gb: f64, // Or reload when MLX memory exceeds this
}

impl Default for TrainingConfig {
fn default() -> Self {
Self {
training_mode: None, // Set during trainer initialization based on lora_rank
batch_size: 1, // Reduced from 2 for better memory efficiency
gradient_accumulation_steps: 8,
gradient_accumulation_steps: 1,
max_steps: 5000,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Default gradient_accumulation_steps change alters effective batch/training dynamics; consider surfacing in docs/CLI help.
Going from 8 → 1 changes the effective batch size (and can change loss scale, LR sensitivity, and checkpoint/metric comparability). If this is primarily for OOM avoidance, consider also documenting a recommended “old behavior” equivalent (e.g., keep batch size small but set accumulation back to 8).

save_steps: 500,
eval_steps: 250,
Expand All @@ -48,6 +78,8 @@ impl Default for TrainingConfig {
thermal_throttle: 0.0,
alpha: 2.7, // Brian Roemmele's recommended alpha
lambda_weight: 1.0, // Balance between CE and distrust loss
reload_interval_steps: 40, // Reload every 40 steps to reset MLX memory
reload_memory_threshold_gb: 80.0, // Or reload when memory exceeds 80 GB
}
}
}
202 changes: 174 additions & 28 deletions rust/src/model/llama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ impl LlamaMLP {
#[derive(Debug, Clone, DeriveModuleParameters)]
pub struct LlamaDecoderLayer {
#[param]
pub attention: LlamaAttention,
pub self_attn: LlamaAttention,
#[param]
pub mlp: LlamaMLP,
#[param]
Expand All @@ -331,13 +331,13 @@ pub struct LlamaDecoderLayer {

impl LlamaDecoderLayer {
pub fn new(config: &LlamaConfig) -> Result<Self, Exception> {
let attention = LlamaAttention::new(config)?;
let self_attn = LlamaAttention::new(config)?;
let mlp = LlamaMLP::new(config)?;
let input_layernorm = RmsNorm::new(config.hidden_size)?;
let post_attention_layernorm = RmsNorm::new(config.hidden_size)?;

Ok(Self {
attention,
self_attn,
mlp,
input_layernorm,
post_attention_layernorm,
Expand All @@ -347,7 +347,7 @@ impl LlamaDecoderLayer {
pub fn forward(&mut self, x: &Array, mask: Option<&Array>) -> Result<Array, Exception> {
// Pre-norm attention with residual
let normed = self.input_layernorm.forward(x)?;
let attn_output = self.attention.forward(&normed, mask)?;
let attn_output = self.self_attn.forward(&normed, mask)?;
let x = x.add(&attn_output)?;

// Pre-norm MLP with residual
Expand Down Expand Up @@ -421,30 +421,128 @@ impl LlamaModel {
}
}

/// Llama model for causal language modeling
/// Frozen backbone - never participates in gradient computation
/// This prevents MLX from allocating gradient Arrays for frozen parameters
#[derive(Debug, Clone, DeriveModuleParameters)]
pub struct LlamaForCausalLM {
pub struct LlamaBackbone {
#[param]
pub embed_tokens: Embedding,
#[param]
pub layers: Vec<LlamaDecoderLayer>,
pub config: LlamaConfig,
}

impl LlamaBackbone {
    /// Build the frozen backbone (token embedding + full decoder stack)
    /// from `config`.
    pub fn new(config: LlamaConfig) -> Result<Self, Exception> {
        let embed_tokens = Embedding::new(config.vocab_size, config.hidden_size)?;

        let mut layers = Vec::new();
        for _ in 0..config.num_hidden_layers {
            layers.push(LlamaDecoderLayer::new(&config)?);
        }

        Ok(Self {
            embed_tokens,
            layers,
            config,
        })
    }

    /// Forward pass through frozen backbone (for use outside gradient graph).
    ///
    /// Embeds `input_ids` (shape `(batch, seq)` — inferred from the `dim(1)`
    /// call below; confirm against callers), applies a causal mask, and runs
    /// every decoder layer, returning the final hidden states.
    pub fn forward(&mut self, input_ids: &Array) -> Result<Array, Exception> {
        // Embed tokens
        let mut hidden_states = self.embed_tokens.forward(input_ids)?;

        // Create causal mask sized to the sequence length
        let seq_len = input_ids.dim(1);
        let mask = self.create_causal_mask(seq_len)?;

        // Pass through all decoder layers
        for layer in &mut self.layers {
            hidden_states = layer.forward(&hidden_states, Some(&mask))?;
        }

        Ok(hidden_states)
    }

    /// Build an additive causal mask: 0.0 where attention is allowed
    /// (key position <= query position) and -1e9 where it must be blocked
    /// (key position strictly after the query position).
    ///
    /// NOTE(review): this assumes `LlamaAttention` ADDS the mask to the raw
    /// attention scores (the standard additive-mask convention) — confirm.
    fn create_causal_mask(&self, seq_len: i32) -> Result<Array, Exception> {
        let indices = mlx_rs::ops::arange::<_, f32>(0, seq_len, 1)?;
        // After broadcasting to (seq, seq): row(i, j) = j (key position),
        // col(i, j) = i (query position).
        let row = mlx_rs::ops::expand_dims(&indices, 0)?;
        let col = mlx_rs::ops::expand_dims(&indices, 1)?;
        // BUG FIX: was `row.lt(&col)` (true where j < i), which blocked the
        // PAST and left the FUTURE visible — the inverse of a causal mask.
        // Block future keys instead: true where j > i.
        let mask = row.gt(&col)?;
        let mask = mask.as_type::<f32>()?;
        // Large negative value stands in for -inf in the additive mask.
        let neg_inf = Array::from_f32(-1e9_f32);
        mask.multiply(&neg_inf)
    }
}

/// Trainable head - only these parameters get gradients
/// This is the KEY to zero memory leaks - value_and_grad only sees these params
#[derive(Debug, Clone, DeriveModuleParameters)]
pub struct TrainableHead {
#[param]
pub model: LlamaModel,
pub norm: RmsNorm,
#[param]
pub lm_head: Linear,
}

impl TrainableHead {
    /// Construct the trainable head (final RMSNorm + LM projection) for `config`.
    pub fn new(config: &LlamaConfig) -> Result<Self, Exception> {
        Ok(Self {
            norm: RmsNorm::new(config.hidden_size)?,
            lm_head: Linear::new(config.hidden_size, config.vocab_size)?,
        })
    }

    /// Project hidden states to vocabulary logits: RMSNorm, then the linear
    /// LM head. Intended for use inside the gradient computation.
    pub fn forward(&mut self, hidden_states: &Array) -> Result<Array, Exception> {
        let x = self.norm.forward(hidden_states)?;
        self.lm_head.forward(&x)
    }
}

/// Llama model for causal language modeling with split architecture
/// Backbone is frozen, only head (or LoRA adapters) participate in gradients
/// Llama model for causal language modeling with split architecture.
/// Backbone is frozen, only head (or LoRA adapters) participate in gradients.
#[derive(Debug, Clone, DeriveModuleParameters)]
pub struct LlamaForCausalLM {
    /// Frozen embedding + decoder stack; run outside the gradient graph.
    #[param]
    pub backbone: LlamaBackbone,
    /// Trainable final norm + LM projection; the only gradient recipients.
    #[param]
    pub head: TrainableHead,
    // LoRA adapters will be added later
    /// LoRA rank; 0 (the value set in `new`) means no LoRA adapters.
    pub lora_rank: usize,
}

impl LlamaForCausalLM {
pub fn new(config: LlamaConfig) -> Result<Self, Exception> {
let model = LlamaModel::new(config.clone())?;
let lm_head = Linear::new(config.hidden_size, config.vocab_size)?;
let backbone = LlamaBackbone::new(config.clone())?;
let head = TrainableHead::new(&config)?;

Ok(Self { model, lm_head })
Ok(Self {
backbone,
head,
lora_rank: 0,
})
}

pub fn forward(&mut self, input_ids: &Array) -> Result<Array, Exception> {
let hidden_states = self.model.forward(input_ids)?;
self.lm_head.forward(&hidden_states)
let hidden_states = self.backbone.forward(input_ids)?;
self.head.forward(&hidden_states)
}

/// Forward through backbone only (returns hidden states before head).
/// Use this outside gradient computation to prevent memory leaks: the
/// frozen backbone parameters then never enter the gradient graph.
pub fn forward_backbone(&mut self, input_ids: &Array) -> Result<Array, Exception> {
    self.backbone.forward(input_ids)
}

/// Forward through head only (for use in gradient computation).
/// Pass hidden states produced by `forward_backbone`; per the split-
/// architecture design, only the head's parameters receive gradients.
pub fn forward_head(&mut self, hidden_states: &Array) -> Result<Array, Exception> {
    self.head.forward(hidden_states)
}

pub fn config(&self) -> &LlamaConfig {
&self.model.config
&self.backbone.config
}

/// Generate text autoregressively from input token IDs
Expand Down Expand Up @@ -562,30 +660,56 @@ pub fn load_weights_into_model(
let mut parameters = model.parameters_mut().flatten();

// Load weights from safetensors into model parameters
// Handle name translation for split architecture:
// - "model.layers.X" → "backbone.layers.X"
// - "model.norm" → "head.norm"
// - "lm_head" → "head.lm_head"
// - "model.embed_tokens" → "backbone.embed_tokens"
for (param_name, param) in parameters.iter_mut() {
let param_name_str = param_name.to_string();

// Try direct match first
if let Some(weight_array) = weights.get(&param_name_str) {
// Verify shape matches
if weight_array.shape() != param.shape() {
if weight_array.shape() == param.shape() {
**param = weight_array.clone();
let _ = param.eval();
loaded_count += 1;
continue;
}
}

// Try legacy name mapping for split architecture compatibility
let legacy_name = if param_name_str.starts_with("backbone.") {
// "backbone.layers.X" → "model.layers.X"
// "backbone.embed_tokens" → "model.embed_tokens"
param_name_str.replace("backbone.", "model.")
} else if param_name_str == "head.norm" {
"model.norm".to_string()
} else if param_name_str == "head.lm_head" {
"lm_head".to_string()
} else {
param_name_str.clone()
};

if let Some(weight_array) = weights.get(&legacy_name) {
if weight_array.shape() == param.shape() {
**param = weight_array.clone();
let _ = param.eval();
loaded_count += 1;
continue;
} else {
eprintln!(
"Warning: Shape mismatch for {}: expected {:?}, got {:?}",
"Warning: Shape mismatch for {} (legacy: {}): expected {:?}, got {:?}",
param_name_str,
legacy_name,
param.shape(),
weight_array.shape()
);
missing_keys.push(param_name_str);
continue;
}

// Set the parameter value using double dereference
// This is the same pattern used in trainer.rs for parameter updates
**param = weight_array.clone();
let _ = param.eval(); // Materialize on GPU
loaded_count += 1;
} else {
missing_keys.push(param_name_str);
}

// Not found with either name
missing_keys.push(param_name_str);
}

// Find extra keys in weights that don't match any model parameters
Expand All @@ -601,21 +725,43 @@ pub fn load_weights_into_model(
parameters.len()
);

if !missing_keys.is_empty() && missing_keys.len() < 10 {
if !missing_keys.is_empty() {
println!(
"Missing keys (first 10): {:?}",
&missing_keys[..missing_keys.len().min(10)]
);
}

if !extra_keys.is_empty() && extra_keys.len() < 10 {
if !extra_keys.is_empty() {
println!(
"Extra keys in safetensors (first 10): {:?}",
&extra_keys[..extra_keys.len().min(10)]
);
}

if loaded_count == 0 {
// Enhanced debugging: print sample parameter names and safetensors keys
eprintln!("\nERROR: Parameter name mismatch detected!");
eprintln!("No weights were successfully loaded into the model.");

if weights.is_empty() {
eprintln!("\nThe weights HashMap is empty!");
eprintln!("This should have been caught by the caller - please use random initialization instead.");
} else {
let param_names: Vec<String> = parameters.keys().map(|k| k.to_string()).collect();
let weight_keys: Vec<String> = weights.keys().cloned().collect();

eprintln!("\nSample model parameter names (first 5):");
for name in param_names.iter().take(5) {
eprintln!(" - {}", name);
}

eprintln!("\nSample safetensors keys (first 5):");
for key in weight_keys.iter().take(5) {
eprintln!(" - {}", key);
}
}

anyhow::bail!(
"Failed to load any weights - parameter names may not match safetensors keys"
);
Expand Down
19 changes: 19 additions & 0 deletions rust/src/model/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,25 @@ fn safe_array_from_slice_f32(
);
}

// Check for invalid shapes
if shape.iter().any(|&s| s <= 0) {
anyhow::bail!(
"Invalid shape for tensor '{}': {:?} contains non-positive dimensions",
tensor_name,
shape
);
}

// Check for excessively large tensors that might cause OOM
let size_mb = (total_elements * 4) / (1024 * 1024);
if size_mb > 2048 {
anyhow::bail!(
"Tensor '{}' is too large ({} MB) - may cause memory issues",
tensor_name,
size_mb
);
}

// Try to create array - if this fails, it will panic/abort
// We can't catch C++ exceptions, so we validate beforehand
Ok(Array::from_slice(data, shape))
Expand Down
Loading
Loading