Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rust/src/model/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ impl ModelLoader {
fn load_lora_target_layers(&self, path: &Path) -> anyhow::Result<HashMap<String, Array>> {
// Initialize MLX by creating a small test array to ensure Metal backend is ready
let _init_test = mlx_rs::ops::zeros::<f32>(&[1_i32])?;

let data = std::fs::read(path)?;
let tensors = SafeTensors::deserialize(&data)?;

Expand Down
40 changes: 17 additions & 23 deletions rust/src/training/trainer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,29 +130,23 @@ impl DistrustTrainer {
llama_config.num_attention_heads
);

// TEMPORARY: Skip weight loading due to MLX/Metal stability issues
// Using random initialization for testing training loop optimizations
println!("Using random initialization (weight loading disabled for testing)");
let model = LlamaForCausalLM::new(llama_config)?;

// TODO: Re-enable weight loading once MLX stability issues are resolved
// let loader = ModelLoader::new(&config.paths.model_path);
// let weights = loader.load_safetensors().unwrap_or_else(|e| {
// println!("Warning: Could not load weights from safetensors: {}", e);
// println!("Model will use random initialization");
// std::collections::HashMap::new()
// });
//
// let model = if !weights.is_empty() {
// println!(
// "Loading model with {} pre-trained weight tensors",
// weights.len()
// );
// crate::model::llama::load_model_with_weights(llama_config, weights)?
// } else {
// println!("Initializing model with random weights");
// LlamaForCausalLM::new(llama_config)?
// };
let loader = ModelLoader::new(&config.paths.model_path);
let weights = loader.load_safetensors().unwrap_or_else(|e| {
println!("Warning: Could not load weights from safetensors: {}", e);
println!("Model will use random initialization");
std::collections::HashMap::new()
});

let model = if !weights.is_empty() {
println!(
"Loading model with {} pre-trained weight tensors",
weights.len()
);
crate::model::llama::load_model_with_weights(llama_config, weights)?
} else {
println!("Initializing model with random weights");
LlamaForCausalLM::new(llama_config)?
};

// Load tokenizer
let tokenizer_path = model_dir.join("tokenizer.json");
Expand Down
Loading