Conversation

Important: Review skipped, too many files: 31 of the 181 files in this PR are above the max files limit of 150.

📝 Walkthrough

Adds a GPU-first Rust training stack: split Llama model (frozen backbone + trainable head), LoRA hooks and safetensors I/O, an async, memory-aware zero-leak Trainer with periodic reloads and checkpointing, deep-detach MLX utility, async CLI/benchmarks, checkpoint tests, and extensive docs/scripts for validation and memory fixes.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor User
    participant Trainer
    participant Memory as MemoryMonitor
    participant Backbone
    participant Head
    participant Optimizer
    participant Checkpoint as CheckpointManager
    User->>Trainer: start_training(max_steps)
    Trainer->>Memory: capture_baseline()
    loop per training step
        Trainer->>Backbone: forward_backbone(input_ids)
        Backbone-->>Trainer: hidden_states
        Trainer->>Head: forward_head(hidden_states)
        Head-->>Trainer: logits/loss
        Trainer->>Optimizer: apply_gpu_optimizer_update()
        Optimizer-->>Trainer: params_updated
        Trainer->>Memory: get_active_memory()
        Memory-->>Trainer: current_gb
        alt reload interval or threshold exceeded
            Trainer->>Checkpoint: save(checkpoint)
            Trainer->>Trainer: reload_from_checkpoint()
            Trainer->>Checkpoint: load_weights_into_model()
            Trainer->>Optimizer: restore_optimizer_state()
            Trainer->>Memory: reset_baseline()
        end
        alt periodic save
            Trainer->>Checkpoint: save_checkpoint()
        end
    end
    Trainer-->>User: training_complete
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
Pre-merge checks and finishing touches

- ❌ Failed checks (1 inconclusive)
- ✅ Passed checks (2 passed)
Actionable comments posted: 10
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
rust/src/training/trainer.rs (1)
432-438: Progress bar + ETA should use `calculated_max_steps`, not `config.training.max_steps`. Right now the bar length and ETA can be wrong when you cap steps for memory safety.

```diff
- let pb = ProgressBar::new(self.config.training.max_steps as u64);
+ let pb = ProgressBar::new(calculated_max_steps as u64);
@@
- let steps_remaining = self.config.training.max_steps - (self.global_step + 1);
+ let steps_remaining = calculated_max_steps - (self.global_step + 1);
```

Also applies to: 492-493, 785-791
🧹 Nitpick comments (9)
rust/src/utils/memory.rs (1)
301-306: Ignore comment is misleading, but the workaround is appropriate. The comment says "requires Metal device", but `MemoryInfo::current()` uses system commands (`ps`, `sysctl`), not Metal directly. The actual issue (per TEST_STATUS.md) is that the MLX library initializes Metal at binary load time, crashing before tests run. Consider clarifying the comment:

```diff
- #[ignore] // Ignore in CI - requires Metal device which may not initialize in test mode
+ #[ignore] // Ignore in CI - MLX initializes Metal at binary load, which crashes in CI environments
```

rust/src/model/loader.rs (2)
28-35: Consider reordering validations for clearer error messages. The non-positive dimension check should come before the element count check. Currently, a shape like `[-1, 100]` would fail at the element count check with a confusing message about requiring "-100 elements". Moving this check earlier provides a more accurate error:

```diff
 fn safe_array_from_slice_f32(
     data: &[f32],
     shape: &[i32],
     tensor_name: &str,
 ) -> anyhow::Result<Array> {
+    // Check for invalid shapes first
+    if shape.iter().any(|&s| s <= 0) {
+        anyhow::bail!(
+            "Invalid shape for tensor '{}': {:?} contains non-positive dimensions",
+            tensor_name,
+            shape
+        );
+    }
+
     // Check if shape makes sense
     let total_elements: i64 = shape.iter().map(|&s| s as i64).product();
     if total_elements != data.len() as i64 {
         anyhow::bail!(
             "Shape mismatch for tensor '{}': shape {:?} requires {} elements but data has {}",
             tensor_name,
             shape,
             total_elements,
             data.len()
         );
     }
-    // Check for invalid shapes
-    if shape.iter().any(|&s| s <= 0) {
-        anyhow::bail!(
-            "Invalid shape for tensor '{}': {:?} contains non-positive dimensions",
-            tensor_name,
-            shape
-        );
-    }
```
37-45: Hardcoded 2048 MB limit may be too restrictive; also missing from `safe_array_from_slice_i32`. Two concerns:

1. Inconsistency: `safe_array_from_slice_i32` (lines 53-73) lacks this memory check, which could allow large i32 tensors to cause OOM.
2. Hardcoded limit: 2048 MB may be too restrictive for legitimate large tensors in bigger models. Consider making this configurable or documenting the rationale.

If the limit is intentional, apply it consistently:

```diff
 fn safe_array_from_slice_i32(
     data: &[i32],
     shape: &[i32],
     tensor_name: &str,
 ) -> anyhow::Result<Array> {
+    // Check for invalid shapes
+    if shape.iter().any(|&s| s <= 0) {
+        anyhow::bail!(
+            "Invalid shape for tensor '{}': {:?} contains non-positive dimensions",
+            tensor_name,
+            shape
+        );
+    }
+
     // Check if shape makes sense
     let total_elements: i64 = shape.iter().map(|&s| s as i64).product();
     if total_elements != data.len() as i64 {
         // ...existing check...
     }
+
+    // Check for excessively large tensors
+    let size_mb = (total_elements * 4) / (1024 * 1024);
+    if size_mb > 2048 {
+        anyhow::bail!(
+            "Tensor '{}' is too large ({} MB) - may cause memory issues",
+            tensor_name,
+            size_mb
+        );
+    }
```
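The two checks discussed above can be folded into one shared helper. Below is a stdlib-only sketch (no mlx-rs types; it returns the validated element count instead of building an `Array`, and `validate_shape` is a hypothetical name), which also uses overflow-checked multiplication:

```rust
/// Validate a tensor shape against its flat data length.
/// Returns the total element count on success.
fn validate_shape(data_len: usize, shape: &[i32], name: &str) -> Result<i64, String> {
    // Reject non-positive dimensions first, so errors name the real problem.
    if shape.iter().any(|&s| s <= 0) {
        return Err(format!(
            "Invalid shape for tensor '{}': {:?} contains non-positive dimensions",
            name, shape
        ));
    }
    // Overflow-checked product of the dimensions.
    let total: i64 = shape
        .iter()
        .try_fold(1_i64, |acc, &s| acc.checked_mul(s as i64))
        .ok_or_else(|| format!("Shape {:?} overflows element count for '{}'", shape, name))?;
    if total != data_len as i64 {
        return Err(format!(
            "Shape mismatch for '{}': shape {:?} requires {} elements but data has {}",
            name, shape, total, data_len
        ));
    }
    Ok(total)
}

fn main() {
    assert_eq!(validate_shape(6, &[2, 3], "w"), Ok(6));
    assert!(validate_shape(6, &[-1, 100], "w").unwrap_err().contains("non-positive"));
    assert!(validate_shape(5, &[2, 3], "w").unwrap_err().contains("mismatch"));
    println!("shape checks ok");
}
```

Both `_f32` and `_i32` loaders could then call the same helper, which removes the ordering inconsistency by construction.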
rust/TEST_STATUS.md (1)

32-33: List all test names explicitly. "(+ 2 more utility tests)" is vague; for completeness and easier debugging, list the actual test names.
VALIDATION_GUIDE.md (1)
269-271: Hardcoded snapshot hash will become stale. The path includes a specific snapshot hash (`368c8ed94ce4c986e7b9ca5c159651ef753908ce`) that will break when the model cache updates. Consider using a glob pattern or documenting how to find the current snapshot:

```diff
-    --model ~/.cache/huggingface/hub/models--mlabonne--Meta-Llama-3.1-8B-Instruct-abliterated/snapshots/368c8ed94ce4c986e7b9ca5c159651ef753908ce \
+    --model "$(ls -d ~/.cache/huggingface/hub/models--mlabonne--Meta-Llama-3.1-8B-Instruct-abliterated/snapshots/*/ | head -1)" \
```

Or simply use the model name if `validate_model.py` supports HuggingFace model resolution.

rust/src/config/training.rs (2)
14-24: `TrainingMode::from_lora_rank(0)` ⇒ `FullFineTune` is a surprising default; consider `Frozen` or renaming the API. If `lora_rank == 0` is meant to indicate "no training adapters", returning `Frozen` is the safer default; otherwise rename to reflect that "0 means train head targets".

30-32: `#[serde(skip)] training_mode` blocks config-driven mode/targets (and checkpoint persistence). If you want this runtime-only, consider moving it out of `TrainingConfig`. If you want it user-configurable later, use `#[serde(default, skip_serializing_if = "Option::is_none")]` instead of `skip`.

rust/src/training/trainer.rs (1)
1113-1223: `apply_gpu_optimizer_update()` does repeated `parameters_mut().flatten()` plus aggressive cache clears per param. Given grads are "2-3 params", it's probably fine, but you can still reduce overhead by hoisting `let mut head_params = ...` outside the loop, and clearing caches once per step (not multiple times per param) unless you have evidence it's required.

rust/src/model/llama.rs (1)

663-724: Weight loader "extra keys" reporting will flag legacy names as extra (noisy). Consider tracking accepted legacy-name mappings in a `HashSet` and using that when computing `extra_keys`, so logs reflect real extras only.

Also applies to: 739-744
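The suggested bookkeeping can be sketched with plain std collections. This is an illustrative version only; `extra_keys` and the key names are not the crate's API:

```rust
use std::collections::HashSet;

/// Compute genuinely-extra checkpoint keys, treating accepted legacy names
/// as consumed rather than "extra".
fn extra_keys<'a>(
    checkpoint_keys: &'a [&'a str],
    expected: &HashSet<&str>,
    accepted_legacy: &HashSet<&str>,
) -> Vec<&'a str> {
    checkpoint_keys
        .iter()
        .copied()
        .filter(|k| !expected.contains(k) && !accepted_legacy.contains(k))
        .collect()
}

fn main() {
    let expected: HashSet<&str> = ["head.lm_head", "head.norm"].into_iter().collect();
    // Legacy aliases the loader already maps onto current names (names illustrative).
    let legacy: HashSet<&str> = ["lm_head.weight", "model.norm.weight"].into_iter().collect();
    let ckpt = ["head.lm_head", "lm_head.weight", "rope.inv_freq"];
    let extras = extra_keys(&ckpt, &expected, &legacy);
    assert_eq!(extras, vec!["rope.inv_freq"]); // legacy name is no longer flagged
    println!("extras: {:?}", extras);
}
```

The loader would populate the legacy set wherever it performs a name translation, so the "extra" log line reflects only keys it truly ignored.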
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- TRAINING_SUCCESS_SUMMARY.md (1 hunks)
- VALIDATION_GUIDE.md (1 hunks)
- rust/TEST_STATUS.md (1 hunks)
- rust/src/config/training.rs (3 hunks)
- rust/src/model/llama.rs (6 hunks)
- rust/src/model/loader.rs (1 hunks)
- rust/src/training/trainer.rs (23 hunks)
- rust/src/utils/memory.rs (2 hunks)
- rust/src/utils/mlx_memory.rs (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-30T23:47:27.584Z
Learnt from: CR
Repo: arosboro/your_ai PR: 0
File: GEMINI.md:0-0
Timestamp: 2025-11-30T23:47:27.584Z
Learning: Applies to src/distrust_loss.py : Implement the core `empirical_distrust_loss` function that penalizes high-authority, low-entropy sources and rewards primary empirical sources
Applied to files:
rust/src/training/trainer.rs
🧬 Code graph analysis (2)
rust/src/config/training.rs (3)

- rust/src/config/distrust.rs (1): default (16-21)
- rust/src/config/mod.rs (1): default (30-41)
- rust/src/config/model.rs (1): default (20-37)

rust/src/training/trainer.rs (5)

- rust/src/model/llama.rs (10): new (158-182), new (293-303), new (333-345), new (373-389), new (436-449), new (490-495), new (517-526), config (544-546), mlx_rs (410-410), mlx_rs (469-469)
- rust/src/training/lora.rs (6): new (93-116), apply_lora_to_model (27-79), mlx_rs (51-51), mlx_rs (58-58), mlx_rs (105-105), mlx_rs (108-108)
- rust/src/config/training.rs (1): from_lora_rank (16-24)
- rust/src/utils/mlx_memory.rs (2): get_active_memory (47-54), get_peak_memory (57-64)
- rust/src/distrust_loss.rs (2): mlx_rs (116-116), batch_empirical_distrust_loss (106-140)
🪛 LanguageTool
TRAINING_SUCCESS_SUMMARY.md
[grammar] ~162-~162: Ensure spelling is correct
Context: ...-------------- | ------- | | Llama 8B abliterated | 100% | 100% ...
(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Lint
- GitHub Check: Build and Test
- GitHub Check: Build Examples
🔇 Additional comments (1)
rust/src/model/llama.rs (1)
424-546: Split backbone/head API (`forward_backbone` / `forward_head`) matches the zero-leak training design. This is a clean separation point for "no-grad backbone + grad head".
```rust
// #region agent log - after training_step
if let Ok(mut file) = std::fs::OpenOptions::new()
    .create(true)
    .append(true)
    .open(debug_log_path().unwrap_or_else(|| PathBuf::from("/dev/null")))
{
    let json = serde_json::json!({
        "location": "trainer.rs:after_training_step",
        "message": "training_step returned successfully",
        "step": self.global_step,
        "loss": loss,
        "phase": "main_loop",
        "timestamp": std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).map(|d| d.as_millis()).unwrap_or(0),
        "hypothesisId": "D-training-step"
    });
    let _ = writeln!(file, "{}", json);
}
```
Avoid hardcoding `/dev/null` for logging; also dedupe manual JSON logging via `log_debug()`.

`/dev/null` is non-portable, and these blocks duplicate the helper you already added. Prefer `if let Some(path) = debug_log_path()` guards everywhere (or route through `self.log_debug`).

```diff
-    if let Ok(mut file) = std::fs::OpenOptions::new()
-        .create(true)
-        .append(true)
-        .open(debug_log_path().unwrap_or_else(|| PathBuf::from("/dev/null")))
-    {
+    if let Some(log_path) = debug_log_path() {
+        if let Ok(mut file) = std::fs::OpenOptions::new()
+            .create(true)
+            .append(true)
+            .open(log_path)
+        {
             let json = serde_json::json!({ ... });
             let _ = writeln!(file, "{}", json);
+        }
     }
```

Also applies to: 821-857, 860-917
🤖 Prompt for AI Agents
rust/src/training/trainer.rs around lines 540 to 556: this block hardcodes
"/dev/null" and duplicates manual JSON logging; replace the OpenOptions block
with a guard that only runs when debug_log_path() returns Some(path) (do not
fall back to "/dev/null"), and route the JSON message through the existing
logging helper (self.log_debug or the centralized log_debug function) instead of
manually opening and writing to the file; ensure the JSON object is produced the
same way and pass it to the helper so all similar blocks (also at 821-857 and
860-917) are deduplicated and portable.
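The guarded pattern the prompt describes can be sketched without crate dependencies. This is an illustrative stand-alone version, not the crate's actual helper: `log_debug` here takes the optional path directly instead of calling the real `debug_log_path()`, and the JSON is hand-rolled where the real code uses serde_json:

```rust
use std::fs::OpenOptions;
use std::io::Write;
use std::path::{Path, PathBuf};

/// Append one JSON line to the debug log, or do nothing when disabled.
/// `None` means "don't log"; there is no `/dev/null` fallback.
fn log_debug(debug_log_path: Option<&Path>, location: &str, message: &str) {
    if let Some(path) = debug_log_path {
        if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) {
            // Hand-rolled JSON to stay dependency-free.
            let _ = writeln!(
                file,
                "{{\"location\":\"{}\",\"message\":\"{}\"}}",
                location, message
            );
        }
    }
}

fn main() {
    let path: PathBuf = std::env::temp_dir().join("trainer_debug_demo.log");
    let _ = std::fs::remove_file(&path);
    // Enabled: one line is appended.
    log_debug(Some(&path), "trainer.rs:after_training_step", "training_step returned successfully");
    let contents = std::fs::read_to_string(&path).unwrap();
    assert!(contents.contains("after_training_step"));
    // Disabled: a no-op, rather than a write to the non-portable /dev/null.
    log_debug(None, "x", "y");
    println!("log guard ok");
}
```

Because absence of a path short-circuits before any file I/O, the portability problem and the duplicated open/write boilerplate disappear together.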
### Unit Tests: 14/16 PASSING (87.5%)
Test count inconsistency between summary and details.
Line 16 states "14/16 PASSING (87.5%)" but line 153 says "12 tests pass, 2 skip (Metal), 2 crash (Metal init)". If 12 pass and 2 are skipped (ignored), counting skipped as passing gives 14, but then the 2 "crash" tests would need clarification.

Please reconcile these numbers and clarify:

- Are `#[ignore]` tests counted as passing or skipped?
- What are the 2 "crash" tests vs the 2 "failing" tests mentioned in lines 35-39?
🤖 Prompt for AI Agents
In rust/TEST_STATUS.md around lines 16-17 and also referencing lines 35-39 and
line 153, reconcile and clarify the test counts and statuses: update the summary
line 16 to reflect the true pass/skip/fail/crash totals matching the detailed
breakdown at line 153; explicitly state whether #[ignore] tests are counted as
"skipped" (not passing) or treated as passing for the summary; replace ambiguous
terms "crash" vs "failing" with precise statuses (e.g., "crash (runtime panic)",
"fail (assertion failure)") and ensure the two tests reported as "crash (Metal
init)" are listed as crashes not failures; finally, adjust the percentage
calculation to match the chosen counting method and update any other lines that
reference the old counts so all numbers are consistent.
```text
Step 0:  36.7 GB MLX memory
Step 5:  46.7 GB (baseline captured)
Step 10: 56.7 GB (leak rate: 2.0 GB/step)
Step 20: 76.7 GB
Step 30: 96.8 GB
Step 40: 116.8 GB
Step 50: 134.9 GB
```

Memory verification system detected the leak rate correctly and would have stopped training if it exceeded the threshold (2200 MB/step).
Memory leak of ~2 GB/step is significant; threshold may be too permissive.
The documented memory growth (36.7 GB → 134.9 GB = ~98 GB over 50 steps, averaging ~2 GB/step) indicates a substantial leak. The 2200 MB/step threshold allows this to continue unchecked. For longer training runs, this would exhaust system memory.
Consider:

- Investigating the root cause of the ~2 GB/step growth
- Lowering the threshold or implementing cumulative memory limits
- Adding automatic checkpointing and restart capability when memory pressure is high
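The leak-rate arithmetic behind the numbers above is easy to verify; a minimal sketch (the function name is illustrative, not from the codebase):

```rust
/// Per-step memory growth in MB/step between two (step, GB) samples.
fn leak_rate_mb_per_step(
    baseline_gb: f64,
    baseline_step: u64,
    current_gb: f64,
    current_step: u64,
) -> f64 {
    let steps = current_step.saturating_sub(baseline_step) as f64;
    if steps == 0.0 {
        return 0.0;
    }
    (current_gb - baseline_gb) * 1024.0 / steps
}

fn main() {
    // Numbers from the log above: 46.7 GB at step 5 (baseline), 56.7 GB at step 10.
    let rate = leak_rate_mb_per_step(46.7, 5, 56.7, 10);
    assert!((rate - 2048.0).abs() < 1e-6); // 2.0 GB/step = 2048 MB/step
    // Just under the documented 2200 MB/step threshold, so the guard never fires.
    const THRESHOLD_MB_PER_STEP: f64 = 2200.0;
    assert!(rate < THRESHOLD_MB_PER_STEP);
    println!("leak rate ~ {:.0} MB/step", rate);
}
```

This makes the concern concrete: a genuine 2048 MB/step leak sits permanently below a 2200 MB/step per-step threshold, which is why a cumulative limit is worth adding alongside it.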
- **Model:** Llama-3.1-8B-Instruct (abliterated)
- **Training Mode:** LoRA (rank=2, alpha=4)
- **Steps:** 50 (completed successfully)
Clarify whether “abliterated” is intentional (model name) vs typo.
If it’s a specific model variant, consider adding a short note or link; otherwise fix spelling to avoid confusion.
Also applies to: 162-163
🤖 Prompt for AI Agents
In TRAINING_SUCCESS_SUMMARY.md around lines 14-16 (and also check lines
162-163), the term "abliterated" appears after the model name and may be a typo
or unintended; confirm whether this is an intentional model variant name. If it
is an official variant, add a short parenthetical note or link to the variant
documentation; if it is a typo, correct it to the intended word (e.g., remove it
or replace with the correct descriptor such as "alleviated" or nothing) and
ensure both occurrences at lines 14-16 and 162-163 are updated consistently.
```bash
cd /Users/arosboro/your_ai
source venv/bin/activate  # or create new venv
pip install mlx mlx-lm transformers
```
Remove hardcoded user-specific paths.

Line 46 contains `/Users/arosboro/your_ai`, which won't work for other developers. Use relative paths or environment variables:

```diff
-cd /Users/arosboro/your_ai
+cd "$(git rev-parse --show-toplevel)"  # or just: cd your_ai
 source venv/bin/activate  # or create new venv
 pip install mlx mlx-lm transformers
```

Similar issues exist at lines 260 and 270 with the same hardcoded path.
🤖 Prompt for AI Agents

In VALIDATION_GUIDE.md around lines 45 to 49 (and also update the similar
occurrences at lines ~260 and ~270), remove the hardcoded user-specific path
"/Users/arosboro/your_ai" and replace it with a portable command such as cd
"$(git rev-parse --show-toplevel)" or a relative path like cd your_ai (or
suggest using an environment variable), and keep the following lines that
activate the virtualenv and install packages unchanged; apply the same
replacement for the two other instances mentioned.
```
TRAINING_SUCCESS_SUMMARY.md - Training results and architecture
TEST_STATUS.md - Test results and status (this file)
VALIDATION_GUIDE.md - How to run validation (TODO)
```
Copy-paste error: "(this file)" refers to wrong file.
Line 332 says `TEST_STATUS.md - Test results and status (this file)`, but this is VALIDATION_GUIDE.md. Fix:

```diff
-TEST_STATUS.md - Test results and status (this file)
-VALIDATION_GUIDE.md - How to run validation (TODO)
+TEST_STATUS.md - Test results and status
+VALIDATION_GUIDE.md - How to run validation (this file)
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```
TRAINING_SUCCESS_SUMMARY.md - Training results and architecture
TEST_STATUS.md - Test results and status
VALIDATION_GUIDE.md - How to run validation (this file)
```
🤖 Prompt for AI Agents
In VALIDATION_GUIDE.md around lines 330 to 333 there's a copy-paste error: the
line for TEST_STATUS.md incorrectly contains "(this file)". Remove the "(this
file)" from the TEST_STATUS.md line (so it reads "TEST_STATUS.md - Test results
and status") and add "(this file)" to the VALIDATION_GUIDE.md line (so it reads
"VALIDATION_GUIDE.md - How to run validation (this file)"), ensuring the
parenthetical correctly refers to the current file.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
rust/src/training/trainer.rs (2)
432-492: Progress bar length + ETA should reflect `calculated_max_steps`, not configured `max_steps`. You cap the loop with `calculated_max_steps`, but the progress bar length and ETA math still use `config.training.max_steps`, so the UI can mislead (and the PB may never "finish" naturally). Minimal fix:

```diff
- let pb = ProgressBar::new(self.config.training.max_steps as u64);
+ let pb = ProgressBar::new(calculated_max_steps as u64);
@@
- let steps_remaining = self.config.training.max_steps - (self.global_step + 1);
+ let steps_remaining = calculated_max_steps - (self.global_step + 1);
```

Also applies to: 786-810
1400-1475: Replace per-step `stop_gradient(hidden_states)` with a native detach operation, or an architecture that avoids building gradients through the backbone entirely. The current `mlx_memory::stop_gradient()` implementation performs a full GPU→CPU→GPU copy (via `as_slice::<f32>().to_vec()` and `Array::from_slice()`) on every training step. The code even documents this explicitly: "Performance Warning: This involves GPU->CPU->GPU copy. It is heavy but safe." This will dominate training throughput and defeats GPU-first execution. Either use a native MLX detach operation if exposed by mlx-rs, or restructure to avoid building the computation graph through the backbone in the first place.
♻️ Duplicate comments (2)
rust/src/training/trainer.rs (2)
540-556: Logging still hardcodes `/dev/null` and duplicates JSON write blocks. This is the same non-portable `/dev/null` fallback + duplication issue previously flagged; it's still present in multiple blocks. Also applies to: 823-919

1252-1319: BLOCKER still present: reload re-inits the model without restoring head params / `adam_step` from checkpoint JSON. `save_checkpoint` writes head params into `optimizer_state` (`param.*`) and stores `adam_step`, but `reload_from_checkpoint` doesn't rehydrate those into `self.model.head` (and it clears GPU momentum too), so periodic reload breaks training continuity. Also applies to: 1624-1661
🧹 Nitpick comments (3)
rust/src/distrust_loss.rs (1)
80-85: Unify EPSILON usage across scalar/batch/validate paths. You now use a `const EPSILON` in `empirical_distrust_loss`, but `batch_empirical_distrust_loss` and `validate_inputs` still embed `1e-8_f32`. Consider a module-level `const EPSILON: f32 = 1e-8_f32;` reused everywhere for consistency.

rust/src/config/training.rs (1)
29-57: Reload semantics comment is now clear; consider separating persisted vs runtime config. The updated comment (0 = only threshold-based reload) fixes the prior ambiguity. For `#[serde(skip)] training_mode`, consider a persisted `TrainingConfig` plus a runtime wrapper (or a separate `RuntimeTrainingState`) to avoid "config contains runtime-only fields" drift as this grows.

Also applies to: 59-86
rust/src/training/trainer.rs (1)
1115-1231: GPU optimizer path looks good; consider reducing global cache clears inside the inner loop. The "GPU-only" update is aligned with the stated goal; however, calling `mlx_rs::transforms::compile::clear_cache()` (and MLX cache clears) multiple times per parameter is likely to be a big perf hit and may undercut any gains. Consider batching cache clearing to once per step (or gating it behind a debug flag) unless you've confirmed it's necessary.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- .github/ISSUE_TEMPLATE/feature.md (1 hunks)
- rust/src/config/training.rs (3 hunks)
- rust/src/distrust_loss.rs (1 hunks)
- rust/src/training/trainer.rs (23 hunks)
- rust/src/utils/mlx_memory.rs (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-30T23:47:27.584Z
Learnt from: CR
Repo: arosboro/your_ai PR: 0
File: GEMINI.md:0-0
Timestamp: 2025-11-30T23:47:27.584Z
Learning: Applies to src/distrust_loss.py : Implement the core `empirical_distrust_loss` function that penalizes high-authority, low-entropy sources and rewards primary empirical sources
Applied to files:
- rust/src/distrust_loss.rs
- rust/src/training/trainer.rs
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Build and Test
🔇 Additional comments (3)
.github/ISSUE_TEMPLATE/feature.md (2)
2-3: ✅ Well-structured template improvements for better contributor guidance. The emoji-prefixed headers, clarified placeholder text, and refined checklist wording make the template more scannable and actionable. The updates help contributors understand expectations clearly, particularly the explicit branch naming convention and acceptance criteria tied to CI/CD, pipeline compatibility, and project board workflow.
Also applies to: 10-10, 14-14, 20-20, 28-28, 37-37
22-26: Consistent and clear task checklist. The refined task descriptions ("Implement code", "Create branch") and updated branch naming pattern (`feature/[short-name]`) are more precise and align with the conventional naming convention note at the end of the template.

rust/src/config/training.rs (1)
3-25: Parameter target strings are correct. The hardcoded values `"head.lm_head"` and `"head.norm"` in `TrainingMode::from_lora_rank` match the actual flattened parameter keys. The codebase explicitly documents this mapping at rust/src/model/llama.rs:665-666 and implements the corresponding parameter name translation logic at lines 686-692. The `TrainableHead` struct contains `lm_head` and `norm` fields that flatten to these exact key names.
```rust
/// Stop gradient on an Array (detach from computation graph)
///
/// Prevents gradients from flowing back through this Array during backward pass.
///
/// # Implementation Note
/// Robust "Deep Detach" implementation:
/// 1. Evaluate the array
/// 2. Extract data to CPU
/// 3. Create fresh Array from data
///
/// This guarantees the new array has NO connection to the previous computation graph,
/// solving memory leaks where `add(0)` would keep the history alive.
///
/// Performance Warning: This involves GPU->CPU->GPU copy. It is heavy but safe.
pub fn stop_gradient(array: &mlx_rs::Array) -> mlx_rs::error::Result<mlx_rs::Array> {
    use mlx_rs::Array;

    // Force evaluation
    array.eval()?;

    // Extract data and shape
    // Note: We assume float32 for this specific use case in trainer
    let data: Vec<f32> = array.as_slice::<f32>().to_vec();
    let shape = array.shape();

    // Create new independent array
    let new_array = Array::from_slice(&data, shape);
    Ok(new_array)
}
```
🧩 Analysis chain
🌐 Web query:
mlx-rs 0.25.2 stop_gradient detach API documentation
💡 Result:
I couldn't find a dedicated "detach" API in mlx-rs 0.25.2. The MLX "stop gradient" operation is documented in the MLX docs (as mlx.core.stop_gradient). The mlx-rs crate documentation (hosted from the project/GitHub) doesn't expose a separate detach API in the 0.25.2 docs bundle — the repo/docs point to the upstream MLX docs and the crate's GitHub for details. See:
- mlx.core.stop_gradient (MLX docs). [1]
- mlx-rs crate / docs.rs / GitHub (mlx-rs 0.25.2 overview and source). [2][3]
If you want, I can:
- search the mlx-rs source for any Rust function name (stop_gradient, detach, stop_grad, detach_grad) and return exact file/line references, or
- open the mlx-rs GitHub and point to the API that would be used in Rust code (if present). Which would you prefer?
Sources:
[1] mlx.core.stop_gradient API (MLX docs).
[2] mlx-rs (docs.rs / 0.25.2 overview).
[3] mlx-rs GitHub repository.
🌐 Web query:
mlx-sys 0.2.0 C bindings stop_gradient detach MLX
💡 Result:
Short answer: yes — mlx-sys v0.2.0 is just the generated Rust/C bindings for the MLX C API (mlx-c), so MLX operations such as stop_gradient/detach that exist in the MLX API surface are available through the underlying C API and therefore exposed by mlx-sys. See the mlx-c repo (C API) and the MLX docs showing stop_gradient in the higher‑level APIs; mlx-sys follows mlx-c and exposes those C symbols. [1][2][3]
Sources
- mlx-c (C API for MLX). Shows the C API repository used to provide language bindings. [1]
- MLX Python docs (stop_gradient). Documents stop_gradient as part of the MLX API. [2]
- mlx-sys crate (0.2.0) / docs.rs and lib.rs listing (bindings generated from mlx-c). [3]
🌐 Web query:
MLX machine learning library stop_gradient detach gradient compute graph API
💡 Result:
Short summary — MLX uses function-transformations (not an autograd tensor graph like PyTorch). Use grad/value_and_grad/vmap/compile to compute derivatives; to block gradients use stop_gradient (there is no detach()/requires_grad()/backward()). Example API usage:
- Compute gradient of a scalar-valued Python function:
def loss(w, x, y): ...
grad_fn = mx.grad(loss)
grads = grad_fn(w, x, y).
🏁 Scripts executed against the repository (arosboro/your_ai):

```shell
git ls-files | head -20
find . -type f -name "*.rs" | grep -E "(mlx_rs|mlx-rs)" | head -20
# Search for stop_gradient in Rust files
rg "stop_gradient" --type rust
# Search for detach in Rust files
rg "detach" --type rust
cat rust/Cargo.toml | grep -A 5 "mlx-rs\|mlx-sys"
rg "TODO|FIXME" --type rust rust/src/utils/mlx_memory.rs -A 2 -B 2
# Check the trainer.rs comment more fully
rg "add\(0\)" --type rust -B 3 -A 3
```
Expensive GPU→CPU→GPU roundtrip + hardcoded f32 in stop_gradient workaround; use native MLX binding when available.
The current implementation detaches via eval() + as_slice::<f32>() + from_slice(), which is a known workaround (confirmed by the trainer.rs comment: "wraps add(0) pattern until mlx-rs exposes C API"). This approach has two limitations:
- Performance cost: GPU→CPU→GPU copy is expensive for large tensors (already acknowledged in the docstring).
- Hardcoded f32: The `as_slice::<f32>()` call restricts this to float32 arrays, reducing generality.
Since your codebase patches mlx-sys locally (patches/mlx-sys), consider adding a safe Rust wrapper around MLX's native stop_gradient C binding to avoid the data roundtrip. This would be more efficient and support arbitrary dtypes.
Suggested approach:
- Extend the local mlx-sys patch to expose `stop_gradient`
- Wrap it safely in a new `Array::stop_gradient()` method or utility
- Use that instead of the deep-copy workaround
If MLX's C API doesn't provide stop_gradient, document this limitation explicitly and keep the current approach as the only option.
@coderabbitai summarize all of the suggested AI Prompts as one AI Prompts |
Tip: For best results, initiate chat on the files or code changes.

@arosboro: I'll compile all the AI prompts from this PR into one unified comprehensive prompt that incorporates censorship resistance, authority bias removal, and truth-seeking behaviors.

Unified AI System Prompt for Uncensored Truth-Seeking Model

This unified prompt synthesizes:

The prompt can be used as a system instruction for training, evaluation, or inference to ensure the model exhibits all desired behaviors across the full spectrum of validation tests. Would you like me to format this differently or add specific sections? 🐰
The prompt can be used as a system instruction for training, evaluation, or inference to ensure the model exhibits all desired behaviors across the full spectrum of validation tests. Would you like me to format this differently or add specific sections? 🐰 |
|
I meant give me a prompt on the feedback to fix the issues, not a training prompt @coderabbitai |
Tip: For best results, initiate chat on the files or code changes.

@arosboro: Got it! Here's a comprehensive developer prompt to address all the technical issues in this PR:

Developer Action Items for PR #26: Fix/Improve Training Resources

🔴 High Priority Issues

1. Model Checkpoint Loading in CLI (2 locations)

Files:
Impact: Users cannot test their trained models via CLI.

Fix Required:

```rust
// Implement checkpoint loading similar to trainer.rs reload logic
// Load safetensors from checkpoint path and call load_weights_into_model()
```

2. Model Export for Validation

Files:
Impact: Cannot run end-to-end validation tests.

Fix Required:

3. Model Parameters Serialization in Checkpoints

File:

```rust
model_state: std::collections::HashMap::new(), // TODO: Extract model parameters
```

Impact: Checkpoints save optimizer state but not model weights.

Fix Required:
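One possible shape for the extraction, sketched with plain std types only (`to_model_state` is a hypothetical helper; the real fix would feed it from `model.parameters().flatten()`, and the `(Vec<f32>, Vec<i32>)` layout mirrors the checkpoint state type):

```rust
use std::collections::HashMap;

/// Flatten named tensors into the checkpoint layout: name -> (data, shape).
/// `shape` uses i32 to mirror the on-disk state type.
fn to_model_state(
    params: &[(&str, Vec<f32>, Vec<i32>)],
) -> HashMap<String, (Vec<f32>, Vec<i32>)> {
    params
        .iter()
        .map(|(name, data, shape)| (name.to_string(), (data.clone(), shape.clone())))
        .collect()
}

fn main() {
    // Key names match the head targets discussed elsewhere in this review.
    let params = vec![
        ("head.norm", vec![1.0_f32; 4], vec![4]),
        ("head.lm_head", vec![0.0_f32; 8], vec![2, 4]),
    ];
    let state = to_model_state(&params);
    assert_eq!(state.len(), 2);
    let (data, shape) = &state["head.lm_head"];
    assert_eq!(data.len(), 8);
    assert_eq!(shape, &vec![2, 4]);
    println!("model_state ok");
}
```

A matching loader on the reload path would walk the same map and write each `(data, shape)` pair back into the head parameters, restoring continuity.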
🟡 Medium Priority Issues

4. MLX-rs Import Cleanup (3 locations)

Files:

```rust
// use mlx_rs::prelude::*; // TODO: Fix MLX-rs imports after checking API docs
```

Impact: Commented imports suggest uncertain API usage.

Fix Required:
5. Configurable EOS Token

File:

```rust
// TODO: Make EOS token configurable
if next_token == 2 {
    break;
}
```

Impact: Hardcoded EOS token may not work for all tokenizers.

Fix Required:
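A hedged sketch of one possible fix, with the terminator ids read from configuration instead of hardcoded (`is_eos` and the example ids are illustrative, not taken from the codebase):

```rust
/// Break generation on any configured EOS id instead of a hardcoded `2`.
/// `eos_token_ids` would come from the tokenizer/model config.
fn is_eos(next_token: u32, eos_token_ids: &[u32]) -> bool {
    eos_token_ids.contains(&next_token)
}

fn main() {
    // Llama-3 style configs often list several terminators (ids here are illustrative).
    let eos = vec![128001_u32, 128009];
    assert!(is_eos(128009, &eos));
    assert!(!is_eos(2, &eos)); // the old hardcoded value no longer matches
    println!("eos check ok");
}
```

The generation loop then becomes `if is_eos(next_token, &config.eos_token_ids) { break; }` (field name hypothetical), which also handles tokenizers that define multiple end-of-sequence tokens.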
🟢 Low Priority / Future Work

6. NPZ Format Support

File:

Impact: Can only use safetensors format (acceptable for now).

Recommendation: Defer until user requests NPZ support.

7. Test Environment Metal Issues

File:

Impact: Tests work in production, fail in CI (documented as MLX-rs limitation).

Recommendation: Already mitigated with `#[ignore]`.

📋 Validation Blockers

Before end-to-end validation can run:
Recommended Fix Order:
🔧 Suggested Commands to Start

```bash
# 1. Find all checkpoint serialization code
rg -A 5 "save_checkpoint" rust/src/training/trainer.rs

# 2. Study how Python validation expects model format
cat python/scripts/validate_model.py | grep -A 10 "load_model"

# 3. Check MLX-rs parameter extraction API
rg "parameters\(\)" rust/src/ -A 3
```

📖 Context Files to Review
Does this developer prompt capture what you need? Should I prioritize differently or add specific implementation guidance for any of these items? |
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
rust/src/model/loader.rs (1)
11-50: Reorder/strengthen shape validation (avoid overflow + clearer invalid-shape failures).
With current ordering, an invalid shape can trip the "shape mismatch" bail before the "non-positive dimensions" bail, and `product()` can overflow.

```diff
-        // Check if shape makes sense
-        let total_elements: i64 = shape.iter().map(|&s| s as i64).product();
+        // Check for invalid shapes first (clearer errors, avoids negative/zero products)
+        if shape.iter().any(|&s| s <= 0) {
+            anyhow::bail!(
+                "Invalid shape for tensor '{}': {:?} contains non-positive dimensions",
+                tensor_name,
+                shape
+            );
+        }
+
+        // Compute total elements with overflow checking
+        let total_elements: i64 = shape
+            .iter()
+            .try_fold(1_i64, |acc, &s| acc.checked_mul(s as i64))
+            .ok_or_else(|| anyhow::anyhow!("Tensor '{}' shape {:?} overflows element count", tensor_name, shape))?;
@@
-        // Check for invalid shapes
-        if shape.iter().any(|&s| s <= 0) {
-            anyhow::bail!(
-                "Invalid shape for tensor '{}': {:?} contains non-positive dimensions",
-                tensor_name,
-                shape
-            );
-        }
-
         // Check for excessively large tensors that might cause OOM
         let size_mb = (total_elements * 4) / (1024 * 1024);
```

rust/src/checkpoints/state.rs (1)
9-29: Watch checkpoint size/perf: `(Vec<f32>, Vec<i32>)` is fine structurally, but JSON serialization can explode.

If `model_state` includes more than LoRA/head tensors, checkpoint files can become huge and slow to read/write. Consider a binary container (or safetensors-on-disk + small JSON metadata) if this is intended for full-weight checkpoints.
♻️ Duplicate comments (1)
rust/src/training/trainer.rs (1)
540-556: Hardcoded `/dev/null` fallback is non-portable.

This pattern still uses `/dev/null`, which doesn't exist on Windows. The correct pattern is already used at lines 494-511 with nested `if let` guards. This was flagged in a previous review.

Apply the pattern from lines 494-511 consistently:

```diff
-        if let Ok(mut file) = std::fs::OpenOptions::new()
-            .create(true)
-            .append(true)
-            .open(debug_log_path().unwrap_or_else(|| PathBuf::from("/dev/null")))
-        {
+        if let Some(log_path) = debug_log_path() {
+            if let Ok(mut file) = std::fs::OpenOptions::new()
+                .create(true)
+                .append(true)
+                .open(log_path)
+            {
                 // ... logging code ...
+            }
+        }
```

Also applies to lines 833-849, 853-869, 872-887, 892-907, 912-929.
🧹 Nitpick comments (7)
rust/src/distrust_loss.rs (1)
79-83: Prefer a module-level `const EPSILON` reused across all helpers.

Right now `empirical_distrust_loss`, `batch_empirical_distrust_loss`, and `validate_inputs` each have their own epsilon source; consolidating avoids drift.

rust/src/cli/commands.rs (2)
1046-1047: Consider making `save_safetensors` a static method or using the existing loader.

The code creates a new `ModelLoader` at line 1046 just to call `save_safetensors()`, but you already have a `loader` instance from line 1019. Either reuse that instance or consider making `save_safetensors` an associated function since it doesn't appear to use `self.model_path` for saving.

```diff
-    let loader = ModelLoader::new(model_path);
-    loader.save_safetensors(&weights, output_path)?;
+    loader.save_safetensors(&weights, output_path)?;
```
989-1012: Consider extracting `resolve_model_path` to a shared helper.

The model path resolution logic (HuggingFace cache lookup + local path check) is duplicated across `benchmark_single_model`, `train`, `generate`, and `export_command`. Extracting this to a shared utility function would reduce duplication and ensure consistent behavior.
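A sketch of what such a shared helper might look like. The function name comes from the review's suggestion; the HuggingFace cache directory layout (`models--{org}--{name}`) is an assumption for illustration, not the crate's actual code:

```rust
use std::path::{Path, PathBuf};

/// Hypothetical shared helper: resolve a model identifier to a local
/// directory, checking the HuggingFace hub cache first, then falling
/// back to treating the argument as a local path.
fn resolve_model_path(model: &str, hf_cache: &Path) -> Result<PathBuf, String> {
    // Assumed HF cache layout: <cache>/models--{org}--{name}
    let cache_entry = hf_cache.join(format!("models--{}", model.replace('/', "--")));
    if cache_entry.is_dir() {
        return Ok(cache_entry);
    }

    // Fall back: interpret the argument as a local directory.
    let local = PathBuf::from(model);
    if local.is_dir() {
        return Ok(local);
    }

    Err(format!(
        "Model '{}' not found in HF cache ({}) or as a local directory",
        model,
        hf_cache.display()
    ))
}

fn main() {
    // A nonexistent id falls through both checks and produces an error.
    let result = resolve_model_path("org/does-not-exist", Path::new("/tmp/hf-cache"));
    println!("{:?}", result);
}
```

Each CLI command would then call this one function instead of repeating the lookup, so any change to the resolution order happens in a single place.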
1171-1188: AdamW momentum computation creates intermediate arrays.

While the goal is zero-leak, each operation like `m.multiply(...)?.add(...)` creates intermediate `Array` objects. The aggressive `clear_cache()` calls may help, but MLX's lazy evaluation means these intermediates accumulate until evaluated. Consider using `eval()` after each major computation to force synchronization.

```diff
 let m_new = if let Some(m) = m_prev {
-    // Reuse existing: beta1 * m + (1-beta1) * g
-    m.multiply(Array::from_f32(beta1))?
-        .add(&grad.multiply(Array::from_f32(1.0 - beta1))?)?
+    let term1 = m.multiply(Array::from_f32(beta1))?;
+    let term2 = grad.multiply(Array::from_f32(1.0 - beta1))?;
+    let result = term1.add(&term2)?;
+    let _ = result.eval();
+    result
 } else {
-    // Initialize: (1-beta1) * g
-    grad.multiply(Array::from_f32(1.0 - beta1))?
+    let result = grad.multiply(Array::from_f32(1.0 - beta1))?;
+    let _ = result.eval();
+    result
 };
```
1345-1357: Momentum restoration relies on in-process CPU cache rather than checkpoint file.

The current implementation restores from `self.adam_m`/`self.adam_v` (populated during `save_checkpoint`), which works for the periodic reload flow. However, for true checkpoint resume (e.g., after a crash), the optimizer state should be loaded from `checkpoint.optimizer_state` instead. Consider adding a fallback:

```diff
+    // Prefer CPU cache (populated by recent save), fall back to checkpoint file
+    let momentum_source = if self.adam_m.is_empty() {
+        // Load from checkpoint if CPU cache is empty
+        // Parse optimizer_state from checkpoint...
+    } else {
+        (&self.adam_m, &self.adam_v)
+    };
```
624-633: EOS token check handles multiple formats with fallback.

The match correctly handles `Single`, `Multiple`, and missing EOS configurations. The fallback to token ID 2 is common (e.g., Llama uses `</s>` = 2), but consider logging when using the default fallback for debugging generation issues.

```diff
 None => next_token == 2, // Default fallback
+// Consider: eprintln!("Warning: No EOS token configured, using default ID 2");
```
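The match described above can be sketched as a small standalone helper. The `EosToken` variant names follow the review; the enum definition here is a simplified stand-in for the crate's serde-backed config type:

```rust
/// Simplified stand-in for the model config's EOS setting, which may be
/// a single token id, a list of ids, or absent entirely.
enum EosToken {
    Single(u32),
    Multiple(Vec<u32>),
}

/// Return true when `next_token` should terminate generation.
/// Falls back to token id 2 (Llama's `</s>`) when no EOS is configured.
fn is_eos(next_token: u32, eos: &Option<EosToken>) -> bool {
    match eos {
        Some(EosToken::Single(id)) => next_token == *id,
        Some(EosToken::Multiple(ids)) => ids.contains(&next_token),
        None => {
            // Worth logging here, per the review, so silent fallbacks
            // are visible when debugging generation that never stops.
            next_token == 2
        }
    }
}

fn main() {
    println!("{}", is_eos(2, &None)); // default fallback fires
}
```

Keeping the fallback in one helper also gives the suggested warning log a single natural home.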
369-431: Consider deprecating or removing the `LlamaModel` struct; it's not used internally and conflicts with the new architecture.

`LlamaModel` (lines 370-431) is not referenced anywhere in the codebase, suggesting it may be dead code. Additionally, it includes `norm` in the model body rather than in the head, which conflicts with the newer split architecture (`LlamaBackbone` + `TrainableHead`).

Note: `LlamaModel` is exposed via `pub use llama::*` in `rust/src/model/mod.rs`, making it part of the public API. If removing it, consider deprecation first to avoid breaking external consumers.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)
- rust/Cargo.toml (1 hunks)
- rust/src/checkpoints/state.rs (2 hunks)
- rust/src/cli/commands.rs (7 hunks)
- rust/src/cli/mod.rs (2 hunks)
- rust/src/distrust_loss.rs (1 hunks)
- rust/src/model/llama.rs (8 hunks)
- rust/src/model/loader.rs (2 hunks)
- rust/src/training/lora.rs (0 hunks)
- rust/src/training/trainer.rs (24 hunks)
- scripts/e2e_verification.sh (1 hunks)
- scripts/production_training.sh (1 hunks)
- scripts/repair_tokenizer.py (1 hunks)
💤 Files with no reviewable changes (1)
- rust/src/training/lora.rs
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
📄 CodeRabbit inference engine (GEMINI.md)
Include docstrings and comments following a clear and consistent style in all code
Files:
scripts/repair_tokenizer.py
🧠 Learnings (2)
📚 Learning: 2025-11-30T23:47:27.584Z
Learnt from: CR
Repo: arosboro/your_ai PR: 0
File: GEMINI.md:0-0
Timestamp: 2025-11-30T23:47:27.584Z
Learning: Applies to scripts/{validate_model,evaluate}.py : Include validation and evaluation scripts in the `scripts` directory, such as `validate_model.py` and `evaluate.py`
Applied to files:
scripts/e2e_verification.sh
📚 Learning: 2025-11-30T23:47:27.584Z
Learnt from: CR
Repo: arosboro/your_ai PR: 0
File: GEMINI.md:0-0
Timestamp: 2025-11-30T23:47:27.584Z
Learning: Applies to src/distrust_loss.py : Implement the core `empirical_distrust_loss` function that penalizes high-authority, low-entropy sources and rewards primary empirical sources
Applied to files:
- rust/src/training/trainer.rs
- rust/src/distrust_loss.rs
🧬 Code graph analysis (5)
scripts/e2e_verification.sh (2)
rust/src/cli/mod.rs (1)
run(137-196)rust/src/cli/commands.rs (2)
train(595-754)generate(771-965)
scripts/production_training.sh (3)
rust/src/cli/mod.rs (1)
run(137-196)rust/src/cli/commands.rs (1)
train(595-754)rust/src/training/trainer.rs (1)
train(398-941)
rust/src/cli/commands.rs (3)
rust/src/model/llama.rs (9)
config(553-555)new(167-191)new(302-312)new(342-354)new(382-398)new(445-458)new(499-504)new(526-535)load_model_with_weights(794-805)rust/src/checkpoints/state.rs (1)
new(23-43)rust/src/model/loader.rs (1)
new(80-84)
rust/src/cli/mod.rs (1)
rust/src/cli/commands.rs (2)
generate(771-965)export_command(968-1053)
rust/src/training/trainer.rs (5)
rust/src/checkpoints/state.rs (1)
new(23-43)rust/src/model/llama.rs (8)
new(167-191)new(302-312)new(342-354)new(382-398)new(445-458)new(499-504)new(526-535)config(553-555)rust/src/training/lora.rs (2)
new(92-115)apply_lora_to_model(26-78)rust/src/utils/mlx_memory.rs (3)
get_active_memory(47-54)get_peak_memory(57-64)get_cache_memory(67-74)rust/src/distrust_loss.rs (1)
batch_empirical_distrust_loss(105-139)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Build and Test
🔇 Additional comments (23)
rust/Cargo.toml (1)
33-34: Verify `tokenizers = "0.21"` is compatible with current code + lockfile.

This is a potentially breaking dependency bump; please confirm the crate API and transitive deps still compile across supported platforms (esp. macOS/Metal builds).
99-135: Generate/Export CLI additions look well-wired; please double-check clap semantics for `bool` defaults.

Given scripts use `--save-best` as a flag, ensure clap isn't expecting `--save-best <true|false>` due to `default_value = "true"` on a `bool`.
180-195: Dispatch updates are consistent with new command shapes.
No issues spotted in the routing to `commands::generate(..., eos_token)` and `commands::export_command(...)`.
411-466: No critical issue: arrays reaching `save_safetensors` are already f32 due to conversion during load.

Arrays in BF16/F16 format are converted to f32 when loaded from safetensors (lines 185–210, 270–282). All MLX arrays created in the codebase use `Array::from_f32()` or `safe_array_from_slice_f32()`, ensuring only f32 arrays exist in memory. The `as_slice::<f32>()` call is safe because this invariant is always maintained.

The comment "we cast to F32 for storage safety" is slightly misleading: the conversion occurs during load, not save. Consider removing the commented-out `dtype()` line and adding a clarifying comment (`// All arrays are f32 due to conversion during load_safetensors`), or optionally restore the dtype check as a defensive safeguard against future non-f32 arrays.

Likely an incorrect or invalid review comment.
rust/src/cli/commands.rs (5)
8-14: Import cleanup looks good.

The reorganization removes the unused `LlamaForCausalLM` direct import (now accessed via `your_ai_rs::model::llama::load_model_with_weights`) and adds the necessary `AVAILABLE_MODELS` and `Checkpoint` imports for the new functionality.
835-841: EOS token override implementation is correct.

The mutable config and conditional application of `EosToken::Single` allows CLI users to override the model's default EOS token. The user feedback via `println!` is helpful for debugging generation issues.
886-897: Checkpoint merging logic is correct for comparison mode.

The loop correctly deserializes `model_state` tensors from the checkpoint JSON and merges them into base weights. The `insert()` overwrites base weights with fine-tuned parameters, which is the expected behavior for checkpoint restoration.
914-937: Single-model generation path correctly loads and merges checkpoint weights.

The implementation is consistent with the comparison mode, with good logging of tensor counts and checkpoint step. The `load_model_with_weights` call ensures proper weight initialization rather than random initialization.
1002-1011: Proper error handling for model path resolution.

The `map(|ft| ft.is_dir()).unwrap_or(false)` pattern correctly handles the case where `file_type()` might fail, and the error message provides actionable guidance to users.
21-22: Type aliases improve code clarity.

The separation of `OptimizerState` (CPU storage for checkpointing) and `OptimizerStateGPU` (GPU storage for training) clearly documents the dual-storage strategy for zero-leak training.
28-34: Dual-storage optimizer state design is well-structured.

The separation of GPU storage (`adam_m_gpu`, `adam_v_gpu`) for training and CPU storage (`adam_m`, `adam_v`) for checkpointing supports the zero-leak guarantee while maintaining checkpoint portability.
1310-1332: Head parameter restoration is now correctly implemented.

Unlike the previous implementation flagged in past reviews, this version merges `checkpoint.model_state` into base weights before calling `load_model_with_weights()`. Since `save_checkpoint` serializes head parameters to `model_state` (lines 1669-1680), the reload path now correctly restores fine-tuned weights through the weight loading mechanism.
1660-1680: Checkpoint save correctly serializes trainable head parameters.

The implementation properly extracts head parameters via `self.model.head.parameters().flatten()` and stores them with their data and shape. This enables the weight merging in `reload_from_checkpoint` to restore fine-tuned weights.
1499-1525: Split-architecture gradient isolation is correctly implemented.

The `stop_gradient` call at line 1513 ensures the backbone's activations don't participate in gradient computation. Combined with `value_and_grad` operating only on `TrainableHead`, this limits gradient arrays to the 2-3 trainable parameters rather than 128+ backbone parameters.
1529-1558: Loss function closure correctly computes combined loss.

The closure properly handles cross-entropy plus distrust loss computation, operating only on the `TrainableHead` parameter. The error mapping via `Exception::custom` properly converts the distrust loss error type.
583-604: Memory leak detection is disabled (monitor-only mode).

The leak threshold check prints a warning but doesn't abort (commented out at lines 590-603). The comment explains "Virtual memory metrics are noisy, relying on RSS check" - this is a reasonable pragmatic choice, but consider logging these events more prominently or providing a configuration option to enable strict mode for debugging.
rust/src/model/llama.rs (7)
28-37: EOS token configuration with flexible format support.

The `EosToken` enum with `#[serde(untagged)]` correctly handles both single integer and array formats found in different model configs. This allows compatibility with various HuggingFace model configurations.
330-366: Decoder layer field naming aligns with safetensors conventions.

Renaming to `self_attn` ensures compatibility with HuggingFace weight files, which use `model.layers.N.self_attn.*` naming. This is essential for the weight loading to succeed.
433-486: LlamaBackbone correctly excludes final normalization.

The backbone stops before the final RMS norm, which is moved to `TrainableHead`. This ensures the gradient computation boundary is clear: the backbone produces hidden states, and the head normalizes and projects to vocabulary. This split enables the zero-leak training strategy.
488-511: TrainableHead contains only the gradient-participating parameters.

The head includes only `norm` and `lm_head`, which are the parameters that receive gradients during training. This is the key to the zero-leak strategy: `value_and_grad` operates only on these ~(hidden_size + vocab_size × hidden_size) parameters rather than billions of backbone parameters.
537-551: Forward methods correctly expose split architecture.

The `forward_backbone` and `forward_head` methods enable the training code to run the backbone outside the gradient graph and only include the head in `value_and_grad`. The standard `forward` composes both for inference.
695-731: Legacy name mapping enables HuggingFace weight compatibility.

The translation from split architecture names (`backbone.*`, `head.*`) to legacy names (`model.*`, `lm_head.*`) is essential for loading pre-trained HuggingFace weights. The mapping correctly covers all parameter paths:

- `backbone.layers.X` → `model.layers.X`
- `backbone.embed_tokens` → `model.embed_tokens`
- `head.norm` → `model.norm`
- `head.lm_head` → `lm_head`
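The four prefix pairs listed above can be sketched as one small translation function; the function name here is hypothetical, but the prefix mapping is exactly the one the review describes:

```rust
/// Hypothetical helper: translate a split-architecture parameter name
/// into the legacy HuggingFace safetensors name.
fn to_legacy_name(name: &str) -> String {
    // Order matters: check the most specific prefixes first, since
    // "head.lm_head" must not fall through to a generic rule.
    if let Some(rest) = name.strip_prefix("head.lm_head") {
        return format!("lm_head{}", rest);
    }
    if let Some(rest) = name.strip_prefix("head.norm") {
        return format!("model.norm{}", rest);
    }
    if let Some(rest) = name.strip_prefix("backbone.") {
        // Covers both layers.X and embed_tokens in one rule.
        return format!("model.{}", rest);
    }
    name.to_string() // already legacy-style
}

fn main() {
    println!("{}", to_legacy_name("backbone.layers.3.self_attn.q_proj.weight"));
    println!("{}", to_legacy_name("head.lm_head.weight"));
}
```

Writing the mapping as prefix rewrites rather than a per-tensor table keeps it robust to layer counts and new sublayer names.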
760-786: Enhanced error diagnostics aid debugging weight loading failures.

The detailed output showing sample parameter names from both the model and safetensors helps quickly identify naming mismatches. This is especially valuable when integrating new model variants.
scripts/repair_tokenizer.py (new file):

```python
import os
import json
from transformers import PreTrainedTokenizerFast


def fix_tokenizer(model_path):
    print(f"Checking tokenizer at: {model_path}")
    tokenizer_path = os.path.join(model_path, "tokenizer.json")

    if not os.path.exists(tokenizer_path):
        print("tokenizer.json not found")
        return

    # Load with transformers (which handles the discrepancy logic)
    try:
        tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
        print("Tokenizer loaded successfully.")

        # Re-save it. This forces consistency between vocab and added_tokens
        tokenizer.save_pretrained(model_path)
        print("Tokenizer re-saved. Inconsistencies should be resolved.")

    except Exception as e:
        print(f"Error processing tokenizer: {e}")


# Path to the specific snapshot causing issues
target_dir = "/Users/arosboro/.cache/huggingface/hub/models--mlabonne--Meta-Llama-3.1-8B-Instruct-abliterated/snapshots/368c8ed94ce4c986e7b9ca5c159651ef753908ce"

fix_tokenizer(target_dir)
```
🧩 Analysis chain (Repository: arosboro/your_ai)

🏁 Script executed: `fd -t f "repair_tokenizer.py" --type file` (length of output: 87)

🏁 Script executed: `cat -n scripts/repair_tokenizer.py` (length of output: 1293)

🏁 Script executed: `rg -r "repair_tokenizer" --type py` (length of output: 456)

🏁 Script executed: `rg "repair_tokenizer|from scripts.repair_tokenizer|import repair_tokenizer" --type py` (length of output: 42)

🏁 Script executed: `rg "from.*scripts" --type py | head -20` (length of output: 335)
Add module and function docstrings, remove hard-coded path, and use if __name__ == "__main__" with argument parsing.
The script lacks docstrings required by the coding guidelines for *.py files. It also executes unconditionally on import with a hard-coded absolute path specific to a developer's local machine, making it unmaintainable and preventing safe importation. The json import is unused and should be removed.
+"""
+Utility to re-save a HuggingFace tokenizer to resolve inconsistencies between
+`vocab` and `added_tokens` in `tokenizer.json`.
+"""
-
import os
-import json
+import argparse
from transformers import PreTrainedTokenizerFast
-
-def fix_tokenizer(model_path):
+
+def fix_tokenizer(model_path: str) -> int:
+ """Load `tokenizer.json` from `model_path` and re-save the tokenizer in-place."""
print(f"Checking tokenizer at: {model_path}")
tokenizer_path = os.path.join(model_path, "tokenizer.json")
if not os.path.exists(tokenizer_path):
print("tokenizer.json not found")
- return
+ return 2
# Load with transformers (which handles the discrepancy logic)
try:
tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
print("Tokenizer loaded successfully.")
# Re-save it. This forces consistency between vocab and added_tokens
tokenizer.save_pretrained(model_path)
print("Tokenizer re-saved. Inconsistencies should be resolved.")
+ return 0
except Exception as e:
print(f"Error processing tokenizer: {e}")
+ return 1
-# Path to the specific snapshot causing issues
-target_dir = "/Users/arosboro/.cache/huggingface/hub/models--mlabonne--Meta-Llama-3.1-8B-Instruct-abliterated/snapshots/368c8ed94ce4c986e7b9ca5c159651ef753908ce"
-
-fix_tokenizer(target_dir)
+if __name__ == "__main__":
+ ap = argparse.ArgumentParser()
+ ap.add_argument("model_path", help="Path to a HF snapshot dir containing tokenizer.json")
+ args = ap.parse_args()
+ raise SystemExit(fix_tokenizer(args.model_path))🤖 Prompt for AI Agents
In scripts/repair_tokenizer.py around lines 1 to 29, the script lacks module and
function docstrings, contains an unused json import, executes on import due to a
hard-coded absolute path, and doesn't use a safe entrypoint; add a concise
module-level docstring and a docstring for fix_tokenizer describing purpose,
args, and behavior; remove the unused json import; replace the hard-coded
target_dir and unconditional call with a if __name__ == "__main__" block that
uses argparse to accept a model_path argument (with a sensible default or
required flag) and calls fix_tokenizer(model_path); ensure imports are minimal
and update prints/messages as needed.
|
@coderabbitai, are any issues remaining to clean up before merging this PR?
@arosboro: Great news! This PR is ready to merge. 🎉

✅ All Critical Issues Resolved

I verified the codebase and found zero merge-blocking issues:

High-Priority Items (All Complete)

Code Quality

Medium/Low Priority (All Complete)

📝 Optional: Documentation Update

The only discrepancy I found is stale documentation -

But the export command is actually fully implemented. You could optionally update Line 7 to:

**Model Export:** ✅ Implemented (`cargo run --release -- export`)

This is purely cosmetic and doesn't block the merge.

Recommendation: Merge now, and optionally update the validation guide in a follow-up commit to reflect the completed export functionality.
Actionable comments posted: 2
♻️ Duplicate comments (1)
rust/src/training/trainer.rs (1)
527-544: Debug logging still uses non-portable `/dev/null` fallback; previous fix incomplete.

These blocks still follow the pattern:

```rust
.open(debug_log_path().unwrap_or_else(|| PathBuf::from("/dev/null")))
```

This has two issues:

- `/dev/null` is non-portable (doesn't exist on Windows)
- Manual JSON logging duplicates the `log_debug()` helper you added at lines 1120-1169

The previous review comment (lines 573-589) was marked as "✅ Addressed in commit 2261261" but the pattern persists.

Apply the fix from the previous review:

```diff
-    if let Ok(mut file) = std::fs::OpenOptions::new()
-        .create(true)
-        .append(true)
-        .open(debug_log_path().unwrap_or_else(|| PathBuf::from("/dev/null")))
+    if let Some(log_path) = debug_log_path() {
+        if let Ok(mut file) = std::fs::OpenOptions::new()
+            .create(true)
+            .append(true)
+            .open(log_path)
+        {
             // ... JSON logging ...
+        }
+    }
```

Better yet, route through the `log_debug()` helper to eliminate duplication across all these blocks (lines 527-544, 550-569, 573-590, 866-902, 905-941, 945-962).
🧹 Nitpick comments (6)
rust/src/cli/mod.rs (2)
89-91: Consider adding clap default or clarifying documentation.

The documentation states "default: 20" but no `default_value` attribute is specified. This is inconsistent with other optional parameters like `max_steps` which uses `#[arg(default_value = "5000")]`.

Option 1: Add clap default (recommended for consistency)

```diff
-    /// Interval (in steps) to reload model and reset MLX memory (default: 20)
+    /// Interval (in steps) to reload model and reset MLX memory
     #[arg(long)]
+    #[arg(default_value = "20")]
-    reload_interval: Option<usize>,
+    reload_interval: usize,
```

Option 2: Clarify documentation (if default is handled in commands::train)

```diff
-    /// Interval (in steps) to reload model and reset MLX memory (default: 20)
+    /// Interval (in steps) to reload model and reset MLX memory (defaults to 20 if not specified)
     #[arg(long)]
     reload_interval: Option<usize>,
```
185-194: Remove unnecessary blank line.

Line 191 appears to be an unintended blank line in the middle of the Generate command destructuring pattern.

```diff
 Commands::Generate {
     model,
     prompt,
     checkpoint,
     max_tokens,
     temperature,
-
     compare,
     eos_token,
 } => commands::generate(model, prompt, checkpoint, max_tokens, temperature, compare, eos_token),
```

rust/src/config/training.rs (2)
14-25: Consider making FullFineTune targets configurable.

The hardcoded targets `["head.lm_head", "head.norm"]` couple this config helper to the specific Llama split-head architecture. If the model structure changes or other architectures are added, this will break.

Consider accepting an optional `default_targets` parameter or reading from a model-specific config to improve flexibility:

```diff
-    pub fn from_lora_rank(lora_rank: usize) -> Self {
+    pub fn from_lora_rank(lora_rank: usize, default_targets: Option<Vec<String>>) -> Self {
         if lora_rank > 0 {
             TrainingMode::LoRA { rank: lora_rank }
         } else {
             TrainingMode::FullFineTune {
-                targets: vec!["head.lm_head".to_string(), "head.norm".to_string()],
+                targets: default_targets.unwrap_or_else(|| vec![
+                    "head.lm_head".to_string(),
+                    "head.norm".to_string()
+                ]),
             }
         }
     }
```
59-88: Consider hardware-aware default for `reload_memory_threshold_gb`.

The default of 80.0 GB assumes large-memory systems and may be inappropriate for users with 32 GB or 64 GB machines. The trainer has auto-detection logic, but the config default could be more conservative.

Consider a more portable default:

```diff
-    reload_memory_threshold_gb: 80.0, // Also reload when MLX memory exceeds 80 GB
+    reload_memory_threshold_gb: 32.0, // Conservative default; adjust based on system memory
```

Or add a comment suggesting users override this based on their hardware:

```diff
-    reload_memory_threshold_gb: 80.0, // Also reload when MLX memory exceeds 80 GB
+    reload_memory_threshold_gb: 80.0, // Adjust to ~40% of your system memory
```
404-456: LGTM! Robust memory limit initialization with safety checks.

The early abort for <10 GB available memory prevents futile training attempts. The auto-detection logic (60% of available, capped at 70 GB, minimum 8 GB) mirrors the CLI logic and provides sensible defaults.

The memory limit calculation logic (lines 422-438) duplicates similar logic in `rust/src/cli/commands.rs` (lines 740-756). Consider extracting to a shared utility function for DRY:

```rust
// In a shared module:
pub fn calculate_safe_memory_limit(available_gb: f64) -> f64 {
    (available_gb * 0.6).min(70.0).max(8.0)
}
```
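The clamping behavior described above (60% of available, capped at 70 GB, floored at 8 GB) can be sanity-checked in isolation. This is a standalone copy of the review's suggested helper, not the crate's actual code:

```rust
/// Suggested shared helper: pick a safe MLX memory limit as 60% of
/// available RAM, clamped to the [8.0, 70.0] GB range.
pub fn calculate_safe_memory_limit(available_gb: f64) -> f64 {
    (available_gb * 0.6).min(70.0).max(8.0)
}

fn main() {
    for avail in [8.0_f64, 32.0, 64.0, 128.0] {
        // 8 GB available hits the 8 GB floor; 128 GB hits the 70 GB cap.
        println!(
            "{} GB available -> {} GB limit",
            avail,
            calculate_safe_memory_limit(avail)
        );
    }
}
```

Note the operator order: `.min(70.0)` applies the cap first, then `.max(8.0)` applies the floor, so tiny inputs still yield at least 8 GB.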
1003-1050: Best checkpoint omits optimizer state; verify whether this is intentional.

`save_best_checkpoint_impl` saves `model_state` (trainable head parameters) but provides an empty `optimizer_state` HashMap (line 1033). This means the best checkpoint can be used for inference or export but cannot resume training from that exact point (no momentum state).

If best checkpoints are intended for inference/export only, this is fine. Otherwise, consider saving optimizer state for full resumability.

If training resumption from best checkpoint is desired:

```diff
+    // Extract momentum for checkpoint
+    self.extract_momentum_for_checkpoint()?;
+
+    // Save optimizer state
+    let mut optimizer_state = std::collections::HashMap::new();
+    for (param_name, (data, shape)) in &self.adam_m {
+        optimizer_state.insert(
+            format!("{}.m", param_name),
+            serde_json::json!({"data": data, "shape": shape}),
+        );
+    }
+    for (param_name, (data, shape)) in &self.adam_v {
+        optimizer_state.insert(
+            format!("{}.v", param_name),
+            serde_json::json!({"data": data, "shape": shape}),
+        );
+    }
+    optimizer_state.insert("adam_step".to_string(), serde_json::json!(self.adam_step));
     let checkpoint = Checkpoint {
         step,
         model_state,
-        optimizer_state: std::collections::HashMap::new(),
+        optimizer_state,
```
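The save/restore symmetry for the Adam moments can be sketched without the crate's MLX types. The `adam_m`/`adam_v` names mirror the review, and the flat `(data, shape)` tensor representation follows the `Checkpoint` state comment earlier in this review; the helper names are hypothetical:

```rust
use std::collections::HashMap;

type Tensor = (Vec<f32>, Vec<i32>); // (flat data, shape), as in checkpoint state

/// Pack first and second Adam moments into one checkpoint map,
/// suffixing keys with ".m" / ".v" so they can be split apart on load.
fn pack_optimizer_state(
    adam_m: &HashMap<String, Tensor>,
    adam_v: &HashMap<String, Tensor>,
) -> HashMap<String, Tensor> {
    let mut out = HashMap::new();
    for (name, t) in adam_m {
        out.insert(format!("{}.m", name), t.clone());
    }
    for (name, t) in adam_v {
        out.insert(format!("{}.v", name), t.clone());
    }
    out
}

/// Split a packed map back into (m, v) when restoring.
fn unpack_optimizer_state(
    packed: &HashMap<String, Tensor>,
) -> (HashMap<String, Tensor>, HashMap<String, Tensor>) {
    let mut m = HashMap::new();
    let mut v = HashMap::new();
    for (key, t) in packed {
        if let Some(name) = key.strip_suffix(".m") {
            m.insert(name.to_string(), t.clone());
        } else if let Some(name) = key.strip_suffix(".v") {
            v.insert(name.to_string(), t.clone());
        }
    }
    (m, v)
}

fn main() {
    let mut adam_m = HashMap::new();
    adam_m.insert("head.norm.weight".to_string(), (vec![0.1, 0.2], vec![2]));
    let packed = pack_optimizer_state(&adam_m, &HashMap::new());
    let (m, _v) = unpack_optimizer_state(&packed);
    println!("{}", m.len()); // the single tensor round-trips
}
```

With a round-trip like this, a best checkpoint would carry enough state to resume training, not just to export weights.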
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
- EVALUATION_COMPLETE.md (1 hunks)
- EVALUATION_REPORT.md (1 hunks)
- EVALUATION_SUMMARY.txt (1 hunks)
- MEMORY_LEAK_ANALYSIS.md (1 hunks)
- MEMORY_LEAK_FIXES_APPLIED.md (1 hunks)
- MEMORY_LEAK_SUMMARY.md (1 hunks)
- QUICK_EVALUATION.md (1 hunks)
- rust/src/cli/commands.rs (10 hunks)
- rust/src/cli/mod.rs (5 hunks)
- rust/src/config/training.rs (3 hunks)
- rust/src/training/trainer.rs (24 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-30T23:47:27.584Z
Learnt from: CR
Repo: arosboro/your_ai PR: 0
File: GEMINI.md:0-0
Timestamp: 2025-11-30T23:47:27.584Z
Learning: Applies to src/distrust_loss.py : Implement the core `empirical_distrust_loss` function that penalizes high-authority, low-entropy sources and rewards primary empirical sources
Applied to files:
rust/src/training/trainer.rs
🧬 Code graph analysis (3)
rust/src/cli/mod.rs (1)
rust/src/cli/commands.rs (2)
generate(792-986)export_command(989-1074)
rust/src/cli/commands.rs (4)
rust/src/model/llama.rs (9)
config(553-555)new(167-191)new(302-312)new(342-354)new(382-398)new(445-458)new(499-504)new(526-535)load_model_with_weights(794-805)rust/src/utils/memory.rs (2)
current(18-36)new(238-244)rust/src/checkpoints/state.rs (1)
new(23-43)rust/src/model/loader.rs (1)
new(80-84)
rust/src/training/trainer.rs (10)
rust/src/checkpoints/state.rs (1)
new(23-43)rust/src/training/lora.rs (6)
new(92-115)apply_lora_to_model(26-78)mlx_rs(50-50)mlx_rs(57-57)mlx_rs(104-104)mlx_rs(107-107)rust/src/model/loader.rs (4)
new(80-84)mlx_rs(292-292)mlx_rs(304-304)array(439-439)rust/src/model/llama.rs (10)
new(167-191)new(302-312)new(342-354)new(382-398)new(445-458)new(499-504)new(526-535)config(553-555)mlx_rs(419-419)mlx_rs(478-478)rust/src/data/streaming.rs (2)
new(29-67)drop(165-167)rust/src/config/training.rs (1)
from_lora_rank(16-24)rust/src/utils/mlx_memory.rs (6)
get_active_memory(47-54)get_peak_memory(57-64)get_cache_memory(67-74)clear_cache(86-92)array(116-116)stop_gradient(108-122)rust/src/distrust_loss.rs (2)
mlx_rs(115-115)batch_empirical_distrust_loss(105-139)rust/tests/training_tests.rs (1)
mlx_rs(78-78)rust/src/metrics.rs (2)
auth_weights(28-28)prov_entropies(35-35)
🪛 LanguageTool
EVALUATION_COMPLETE.md
[style] ~101-~101: Consider using polite language here.
Context: ...provements** - Code sharing strategies Let me know! I can provide even more detailed analy...
(INSERT_PLEASE)
MEMORY_LEAK_FIXES_APPLIED.md
[grammar] ~76-~76: Use a hyphen to join words.
Context: ... --- ## 📊 Verification ### Cache Clearing Calls Added Total clear_cac...
(QB_NEW_EN_HYPHEN)
MEMORY_LEAK_ANALYSIS.md
[grammar] ~284-~284: Use a hyphen to join words.
Context: ...ied memory requirements - Document cache clearing strategy - Note leak rate expec...
(QB_NEW_EN_HYPHEN)
[grammar] ~322-~322: Use a hyphen to join words.
Context: ...feguards Next Steps: 1. Apply cache clearing fixes (code-only) 2. Add leak m...
(QB_NEW_EN_HYPHEN)
EVALUATION_SUMMARY.txt
[uncategorized] ~131-~131: The official name of this software platform is spelled with a capital “H”.
Context: ...ALGORITHM.md - Technical deep dive • .github/workflows/ - CI/CD workflows =========...
(GITHUB)
🪛 markdownlint-cli2 (0.18.1)
MEMORY_LEAK_FIXES_APPLIED.md
77-77: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
93-93: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
100-100: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
227-227: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
235-235: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
241-241: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
MEMORY_LEAK_SUMMARY.md
28-28: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Build and Test
🔇 Additional comments (27)
EVALUATION_SUMMARY.txt (1)
1-283: LGTM - Comprehensive evaluation documentation. This evaluation summary provides useful context for the PR's scope and project maturity. The document is well-structured and informative, serving as metadata to support the training resource improvements.
EVALUATION_COMPLETE.md (1)
1-118: LGTM - Clear evaluation overview. Provides a well-organized summary of the monorepo evaluation with actionable findings and recommendations.
QUICK_EVALUATION.md (1)
1-156: LGTM - Useful quick reference. This quick evaluation guide provides an accessible overview of the project's quality and capabilities, complementing the more detailed evaluation reports.
EVALUATION_REPORT.md (1)
1-336: LGTM - Thorough evaluation report. Excellent comprehensive documentation covering all aspects of the monorepo. The detailed recommendations and metrics provide valuable context for the PR's improvements.
MEMORY_LEAK_SUMMARY.md (1)
1-116: LGTM - Clear memory leak documentation. This document effectively describes the memory leak investigation, root causes, and remediation strategy. The phased approach with safeguards is appropriate for a system running critical applications.
MEMORY_LEAK_ANALYSIS.md (1)
1-329: LGTM - Thorough memory leak investigation. Excellent technical analysis identifying specific code patterns causing memory leaks. The safe, code-only approach is prudent given the system constraints, and the phased remediation plan is well-structured.
MEMORY_LEAK_FIXES_APPLIED.md (1)
1-294: LGTM - Well-documented fixes with verification plan. This document clearly describes the memory leak fixes applied and provides a comprehensive testing and deployment strategy. The phased approach with safeguards is appropriate.
rust/src/cli/mod.rs (2)
122-124: LGTM - EOS token override properly integrated. The optional EOS token parameter is correctly typed and documented. The implementation properly handles the override case in the generate command handler.
126-137: Export command is properly structured and correctly implemented. The new Export command variant defines three required fields (model, checkpoint, output) with appropriate types. The dispatch at lines 195-199 correctly passes references to `commands::export_command`, which has the matching signature: `pub fn export_command(model: &str, checkpoint_path: &std::path::PathBuf, output_path: &std::path::PathBuf) -> Result<()>`. The string-to-str coercion and PathBuf references are all correct.
rust/src/config/training.rs (2)
3-12: LGTM! Well-designed training mode abstraction. The enum clearly captures the three training paradigms with appropriate data (rank for LoRA, target list for fine-tuning). Derives are complete and serialization support enables persistence.
30-31: LGTM! Field additions are well-documented. The new fields serve clear purposes:
- `training_mode` as runtime-only state (appropriate use of `#[serde(skip)]`)
- `train_seq_length` for flexible sequence length control
- Reload fields with clarified dual-path semantics (interval OR threshold)
The updated comment at lines 54-56 properly documents that `reload_interval_steps = 0` disables only interval-based reloads while threshold-based reloads can still trigger, resolving the previous concern about misleading semantics.
Also applies to: 47-47, 53-56
rust/src/cli/commands.rs (6)
8-13: LGTM! Import changes align with refactored model loading. The removal of unused `LlamaForCausalLM` direct import and addition of `ModelLoader`, `Checkpoint`, and `AVAILABLE_MODELS` support the new checkpoint-merge and loader-based initialization patterns throughout the file.
606-606: LGTM! Reload interval parameter wired correctly. The optional `reload_interval` parameter is properly applied to `config.training.reload_interval_steps`, allowing CLI override of the default periodic reload behavior.
Also applies to: 696-699
740-758: LGTM! Helpful auto-detection of safe memory limits. The fallback logic provides sensible defaults (60% of available memory, capped at 70 GB, minimum 8 GB) with clear user feedback via stderr. This improves UX by reducing the need for manual tuning while remaining conservative.
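The fallback rule described above (60% of available memory, capped at 70 GB, floor of 8 GB) reduces to a single clamp. The sketch below is illustrative only: the function name and plain-`f64` interface are assumptions, not the CLI's actual code.

```rust
// Hypothetical sketch of the auto-detection rule the review describes:
// take 60% of available memory, cap at 70 GB, never go below 8 GB.
fn auto_memory_limit_gb(available_gb: f64) -> f64 {
    (available_gb * 0.6).clamp(8.0, 70.0)
}

fn main() {
    assert_eq!(auto_memory_limit_gb(128.0), 70.0); // 76.8 GB capped at 70
    assert_eq!(auto_memory_limit_gb(10.0), 8.0);   // 6.0 GB raised to the floor
    assert_eq!(auto_memory_limit_gb(64.0), 38.4);  // 60% falls within bounds
    println!("ok");
}
```

Expressing the rule as one `clamp` keeps the cap and the floor from drifting apart if either bound is later tuned.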
799-799: LGTM! EOS token override is simple and transparent. The optional `eos_token` parameter allows users to override the model's default EOS token ID when needed, with clear logging of the override for debugging.
Also applies to: 856-862
884-918: LGTM! Consistent checkpoint-merge pattern across generation paths. Both comparison and single-model generation modes follow the same reliable pattern:
- Load base weights via `ModelLoader`
- Deserialize checkpoint JSON
- Merge checkpoint tensors into base weights
- Initialize model with merged weights
This ensures checkpoint state is properly applied during inference.
Also applies to: 936-958
988-1074: LGTM! Export command correctly merges and persists fine-tuned weights. The implementation properly:
- Resolves model paths (local or HF cache)
- Loads base weights and checkpoint state
- Merges checkpoint tensors into base weights (overwriting as needed)
- Persists the merged model as safetensors
- Creates output directories and provides clear progress feedback
The pattern is consistent with the checkpoint-merge logic in `generate()`, ensuring reliability.
83-87: LGTM! Proper environment-based debug logging control. The `debug_log_path()` helper provides a clean way to enable debug logging via environment variable, returning `Option<PathBuf>` for safe conditional usage.
317-348: LGTM! Well-documented memory safety controls. Both builder methods provide clear documentation of:
- The underlying MLX-rs leak issue they work around
- Risks of misconfiguration (OOM vs premature stop)
- Recommended values and tradeoffs
- Integration with reload mechanism for longer runs
This transparency helps users understand and tune the memory management behavior.
379-396: LGTM! Conservative calculation of safe training steps. The formula `(available_gb * 0.7) / leak_gb_per_step` provides a safe upper bound on steps before OOM, with appropriate fallbacks for missing memory info or negligible leak rates.
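The formula and its fallbacks can be re-derived in a few lines. This is a hypothetical standalone version (names and signature are assumptions, not the trainer's API), kept consistent with the two fallbacks the review mentions.

```rust
// Illustrative re-derivation of the reviewed bound: use 70% of available
// memory as the leak budget, then divide by the observed leak per step.
// Returns None when the bound does not apply (no memory info, or a
// negligible leak rate), mirroring the fallbacks described above.
fn safe_training_steps(available_gb: Option<f64>, leak_gb_per_step: f64) -> Option<u64> {
    let available = available_gb?; // missing memory info -> no bound
    if leak_gb_per_step <= 1e-6 {
        return None; // negligible leak -> effectively unlimited steps
    }
    Some(((available * 0.7) / leak_gb_per_step) as u64)
}

fn main() {
    // 100 GB free with a 0.5 GB/step leak -> 70 GB budget / 0.5 = 140 steps
    assert_eq!(safe_training_steps(Some(100.0), 0.5), Some(140));
    assert_eq!(safe_training_steps(None, 0.5), None);
    println!("ok");
}
```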
594-710: LGTM! Comprehensive memory leak verification with appropriate safeguards. The baseline capture at step 5 (after warmup) and periodic verification every 10 steps provide good leak detection coverage. The disabled abort (lines 617-637) is appropriate given the comment that virtual memory metrics are noisy and RSS-based checks in `check_memory_limits()` provide the real safeguard. The warning when approaching the memory limit (lines 640-690) gives users actionable feedback about enabling reload intervals.
1120-1169: LGTM! Proper debug logging helper with memory telemetry. The `log_debug()` method correctly uses `debug_log_path()` with an `if let Some(...)` guard (no `/dev/null` fallback), includes comprehensive memory metrics (RSS, MLX active/peak/cache), and writes structured JSON. This helper should be used throughout the training loop to replace the manual logging blocks flagged above.
1292-1309: LGTM! Infrequent CPU extraction for checkpoints is an acceptable trade-off. The `extract_momentum_for_checkpoint()` method is called only during checkpoint saves (every 100-500 steps), making the `as_slice()` CPU extraction overhead negligible compared to per-step extraction. This design balances zero-leak training with checkpoint persistence needs.
1311-1405: LGTM! Checkpoint reload now properly restores parameters and optimizer state. The `reload_from_checkpoint()` implementation correctly:
- Loads checkpoint data
- Drops old model to free MLX memory
- Reloads base weights and merges checkpoint weights (lines 1346-1361)
- Restores optimizer momentum to GPU (lines 1381-1391)
- Resets memory baseline
This fully addresses the previous critical issue where reload re-initialized the model without restoring parameters. The merged weight approach (base + checkpoint) ensures both pre-trained backbone and fine-tuned head are restored.
1535-1562: Excellent optimization! Gradient computation limited to trainable head prevents massive memory waste. The key insight (lines 1535-1562) of forwarding through the frozen backbone outside the gradient graph, then using `stop_gradient` to detach, reduces gradient computation from 128 parameters to just 2 (trainable head). This dramatically cuts memory usage and computation time. The loss function (lines 1565-1594) and `value_and_grad` call (line 1625) operate only on the `TrainableHead`, ensuring MLX doesn't allocate gradients for frozen parameters.
Also applies to: 1565-1594, 1624-1636
1721-1805: LGTM! Complete checkpoint serialization includes model parameters and optimizer state. The checkpoint now captures:
- Trainable head parameters in `model_state` (lines 1730-1741)
- Optimizer momentum (`.m` and `.v` entries) in `optimizer_state` (lines 1743-1765)
- Metadata (learning rate, best loss) (lines 1768-1777)
This enables full training continuity across checkpoint reloads and supports the export workflow.
21-53: Remove the concern about the 1.0 MB/step threshold: it's intentionally strict, not optimistic. The `memory_leak_threshold_mb: 1.0` default is explicitly set as a conservative "nominal threshold for safety" post-fix (as the code comment states). While the documented MLX-rs framework leak is ~2000 MB/step, the recommended threshold defaults are 100 MB/step (native fix baseline), 50 MB/step (strict), and 500 MB/step (lenient), configurable via `with_memory_leak_threshold()`. The 1.0 default is appropriately strict for the default safety posture, not optimistic. The GPU/CPU split optimizer state design remains sound for preventing memory leaks.
Actionable comments posted: 1
♻️ Duplicate comments (1)
rust/src/training/trainer.rs (1)
573-590: Duplicate: Avoid hardcoding `/dev/null` for logging; dedupe manual JSON logging. This issue was flagged in a previous review and remains unaddressed. Lines 577, 888, 908, 927, 947, 967 all use the pattern:

```rust
.open(debug_log_path().unwrap_or_else(|| PathBuf::from("/dev/null")))
```

The `/dev/null` fallback is non-portable (Windows doesn't have it), and these blocks duplicate the helper you already added. Use the guard pattern instead:

```diff
-if let Ok(mut file) = std::fs::OpenOptions::new()
-    .create(true)
-    .append(true)
-    .open(debug_log_path().unwrap_or_else(|| PathBuf::from("/dev/null")))
-{
+if let Some(log_path) = debug_log_path() {
+    if let Ok(mut file) = std::fs::OpenOptions::new()
+        .create(true)
+        .append(true)
+        .open(log_path)
+    {
         let json = serde_json::json!({ ... });
         let _ = writeln!(file, "{}", json);
+    }
 }
```

Apply this pattern to all affected logging blocks.
Also applies to: 884-900, 904-920, 923-939, 943-959, 963-980
🧹 Nitpick comments (2)
.gitignore (1)
23-23: Consider placing the Rust path in a dedicated section for better organization. The new rule `rust/sovereign-ai-stack-book` is placed at the end of the Python section, but it's a Rust-related path. For maintainability and clarity, consider moving it to a dedicated "# Rust" or "# Build artifacts" section.

```diff
 # Python
 __pycache__/
 *.py[cod]
 *$py.class
 *.so
 *.deprecated
 .Python
 build/
 develop-eggs/
 dist/
 downloads/
 eggs/
 .eggs/
 lib/
 lib64/
 parts/
 sdist/
 var/
 wheels/
 *.egg-info/
 .installed.cfg
 *.egg
-rust/sovereign-ai-stack-book

 # Virtual Environment
 venv/
 env/
 ENV/
 .venv
+
+# Rust / Build artifacts
+rust/sovereign-ai-stack-book
```

rust/PROJECT_STATUS.md (1)
54-54: Optional: Add language tags to fenced code blocks. For better syntax highlighting, consider adding language tags to the code blocks at lines 54 and 193:
- Line 54: appears to be a formula/pseudocode (could use `text` or leave plain)
- Line 193: appears to be commit hashes (could use `text`)
This is a minor documentation polish, not blocking.
Also applies to: 193-193
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
.gitignore (1 hunks)
MEMORY_LEAK_FIXES_COMPLETE.txt (1 hunks)
RELOAD_THRESHOLD_FIX_COMPLETE.txt (1 hunks)
rust/MEMORY_LEAK_FIXES_COMPLETE.md (1 hunks)
rust/MEMORY_LEAK_FIXES_STATUS.md (1 hunks)
rust/PROJECT_STATUS.md (1 hunks)
rust/RELOAD_THRESHOLD_FIX.md (1 hunks)
rust/RELOAD_THRESHOLD_FIX_SUMMARY.md (1 hunks)
rust/src/training/trainer.rs (24 hunks)
✅ Files skipped from review due to trivial changes (1)
- rust/MEMORY_LEAK_FIXES_STATUS.md
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-30T23:47:27.584Z
Learnt from: CR
Repo: arosboro/your_ai PR: 0
File: GEMINI.md:0-0
Timestamp: 2025-11-30T23:47:27.584Z
Learning: Applies to src/distrust_loss.py : Implement the core `empirical_distrust_loss` function that penalizes high-authority, low-entropy sources and rewards primary empirical sources
Applied to files:
rust/src/training/trainer.rs
🧬 Code graph analysis (1)
rust/src/training/trainer.rs (5)
rust/src/checkpoints/state.rs (1)
new (23-43)
rust/src/training/lora.rs (6)
new (92-115), apply_lora_to_model (26-78), mlx_rs (50-50), mlx_rs (57-57), mlx_rs (104-104), mlx_rs (107-107)
rust/src/model/loader.rs (4)
new (80-84), mlx_rs (292-292), mlx_rs (304-304), array (439-439)
rust/src/model/llama.rs (10)
new (167-191), new (302-312), new (342-354), new (382-398), new (445-458), new (499-504), new (526-535), config (553-555), mlx_rs (419-419), mlx_rs (478-478)
rust/src/utils/mlx_memory.rs (6)
get_active_memory (47-54), get_peak_memory (57-64), get_cache_memory (67-74), clear_cache (86-92), array (116-116), stop_gradient (108-122)
🪛 LanguageTool
RELOAD_THRESHOLD_FIX_COMPLETE.txt
[uncategorized] ~147-~147: If this is a compound adjective that modifies the following noun, use a hyphen.
Context: ...================== System Stability: ✅ LOW RISK - Logic now matches configuration docum...
(EN_COMPOUND_ADJECTIVE_INTERNAL)
rust/RELOAD_THRESHOLD_FIX_SUMMARY.md
[uncategorized] ~36-~36: If this is a compound adjective that modifies the following noun, use a hyphen.
Context: ... - This summary ## Risk Assessment ✅ LOW RISK - Logic now matches configuration doc...
(EN_COMPOUND_ADJECTIVE_INTERNAL)
rust/RELOAD_THRESHOLD_FIX.md
[uncategorized] ~94-~94: If this is a compound adjective that modifies the following noun, use a hyphen.
Context: ...0) ## Backward Compatibility ✅ Fully backward compatible - Default behavior unchanged (reload_...
(EN_COMPOUND_ADJECTIVE_INTERNAL)
[uncategorized] ~139-~139: If this is a compound adjective that modifies the following noun, use a hyphen.
Context: ...isk Assessment ### System Stability: ✅ LOW RISK - Logic is now consistent with configur...
(EN_COMPOUND_ADJECTIVE_INTERNAL)
MEMORY_LEAK_FIXES_COMPLETE.txt
[grammar] ~179-~179: Ensure spelling is correct
Context: ...d files for posterity. e5a276e Training sucess? Really? Test Commands: cargo check ...
(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)
🪛 markdownlint-cli2 (0.18.1)
rust/PROJECT_STATUS.md
54-54: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
193-193: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Build and Test
🔇 Additional comments (15)
rust/RELOAD_THRESHOLD_FIX.md (1)
1-160: Documentation clearly describes the threshold reload fix. The documentation accurately describes the re-enabled threshold-based reload logic and provides clear usage scenarios. This aligns with the past review comment about config/code mismatch and provides helpful guidance for testing the three reload modes.
MEMORY_LEAK_FIXES_COMPLETE.txt (1)
1-196: Documentation artifact provides comprehensive summary. The file effectively documents the memory leak analysis and fixes applied to the training pipeline, with clear verification results and next steps.
rust/PROJECT_STATUS.md (1)
1-221: Comprehensive project status documentation. The document provides excellent tracking of implementation phases, verification results, and planned next steps. Well-organized and informative.
RELOAD_THRESHOLD_FIX_COMPLETE.txt (1)
1-244: Documentation artifact mirrors the .md version. This text-format documentation provides the same threshold reload fix information as the markdown version, ensuring accessibility across different documentation formats.
rust/RELOAD_THRESHOLD_FIX_SUMMARY.md (1)
1-39: Concise summary of the reload threshold fix. This summary provides a clear, quick-reference version of the threshold reload fix, complementing the detailed documentation.
rust/MEMORY_LEAK_FIXES_COMPLETE.md (1)
1-180: Detailed memory leak fix documentation. The document thoroughly describes the root causes, applied fixes, and verification results for the memory leak mitigation work. Clear and well-structured.
rust/src/training/trainer.rs (9)
726-785: Threshold-based reload logic successfully re-enabled. The reload logic now correctly implements the three modes described in configuration:
- Interval + Threshold (reload_interval > 0): Reloads when interval is met AND memory exceeds threshold
- Threshold-only (reload_interval == 0): Reloads only when memory exceeds threshold
- Interval-only (threshold disabled): Reloads at intervals only
The implementation properly handles `get_active_memory()` errors gracefully and aligns with the configuration documentation. This resolves the past review concern about config/code mismatch.
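A compact way to read the three modes is as a small decision function. The sketch below is an illustrative model of the described behavior only; the names and the pure-function shape are assumptions, while the real trainer reads its config fields and MLX memory stats directly.

```rust
// Illustrative decision logic for the three reload modes the review lists:
// interval + threshold, threshold-only (interval == 0), and interval-only.
fn should_reload(
    step: u64,
    reload_interval: u64,      // 0 disables interval-based reloads
    active_gb: Option<f64>,    // None if get_active_memory() failed
    threshold_gb: Option<f64>, // None if threshold-based reloads are disabled
) -> bool {
    // Degrade gracefully when memory info is missing: never over threshold.
    let over_threshold = matches!((active_gb, threshold_gb), (Some(a), Some(t)) if a > t);
    let interval_hit = reload_interval > 0 && step > 0 && step % reload_interval == 0;
    match (reload_interval > 0, threshold_gb.is_some()) {
        (true, true) => interval_hit && over_threshold, // interval + threshold
        (false, true) => over_threshold,                // threshold-only
        (true, false) => interval_hit,                  // interval-only
        (false, false) => false,                        // reloads fully disabled
    }
}

fn main() {
    assert!(should_reload(100, 50, Some(30.0), Some(20.0)));  // both conditions met
    assert!(should_reload(7, 0, Some(30.0), Some(20.0)));     // threshold-only mode
    assert!(!should_reload(100, 50, Some(10.0), Some(20.0))); // under threshold
    assert!(!should_reload(100, 0, None, None));              // nothing enabled
    println!("ok");
}
```

Writing the modes out this way also makes the error-handling choice explicit: a failed memory read simply means the threshold condition cannot fire.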
1189-1310: GPU-only optimizer correctly implements zero-extraction pattern. The implementation keeps all momentum states on GPU, avoiding the CPU extraction leak. Key strengths:
- Momentum stays in `adam_m_gpu`/`adam_v_gpu` throughout training
- Explicit `eval()` calls before drops ensure GPU work completion (lines 1292-1293)
- Cache clearing after operations prevents accumulation (lines 1295-1296, 1303)
- CPU extraction only during infrequent checkpointing (line 1313-1329)
This design effectively addresses the memory leak concerns documented in the PR.
1335-1425: Checkpoint reload now correctly restores trainable head parameters. The implementation properly restores training state:
- Model parameters (lines 1374-1380): Merges `checkpoint.model_state` (trainable head weights) into loaded base weights
- Optimizer state (lines 1401-1411): Restores momentum tensors from CPU cache to GPU
- Full weight restoration (line 1384): Creates fresh model with merged pretrained backbone + fine-tuned head
This addresses the critical past review comment about missing parameter restoration. The checkpoint save/reload cycle now maintains training continuity correctly.
593-710: Zero-leak verification with comprehensive monitoring. The implementation captures baseline memory and tracks growth per step with clear warnings when approaching limits. The monitoring is intentionally conservative:
- Virtual memory growth monitoring (line 620) is informational only
- Physical RSS checking via `check_memory_limits()` provides the actual safety net
- Prominent warnings appear when within 20% of calculated limit (lines 646-690)
This layered approach balances leak detection with avoiding false positives from MLX's internal memory management.
1741-1825: Checkpoint save properly serializes trainable state. The implementation correctly captures all necessary training state:
- Trainable parameters (lines 1750-1761): Extracts head weights to `model_state`
- Optimizer momentum (lines 1767-1784): Serializes Adam first/second moments to `optimizer_state`
- Training metadata (line 1785): Includes `adam_step` for bias correction continuity
The signature change to `&mut self` (line 1741) is necessary for `extract_momentum_for_checkpoint()` and is consistent with the reload workflow.
21-22: Type aliases clearly distinguish GPU vs CPU storage. The introduction of `OptimizerStateGPU = Array` and `OptimizerState = (Vec<f32>, Vec<i32>)` clearly documents the dual-storage pattern: GPU for training (zero-leak), CPU for checkpointing (serialization). This makes the architecture explicit and maintainable.
Also applies to: 28-34
153-200: LoRA integration during model load is correct. The implementation properly applies LoRA adapters when `lora_rank > 0`:
- Extracts target module names from config (lines 168-176)
- Creates LoraConfig with config-driven parameters (lines 178-183)
- Applies to weights before model construction (lines 184-189)
This ensures LoRA is active from the start of training rather than requiring separate initialization.
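The LoRA mechanics behind this integration reduce to adding a low-rank delta to each target weight. The sketch below is the standard LoRA formula on plain nested vectors, purely for illustration; it does not use the crate's actual types or `apply_lora_to_model` internals.

```rust
// Generic LoRA sketch: effective weight = W + (alpha / r) * B * A,
// where A is r x in, B is out x r, and W is out x in.
fn lora_effective_weight(
    w: &[Vec<f64>], // out x in base weight
    a: &[Vec<f64>], // r x in, lora_A
    b: &[Vec<f64>], // out x r, lora_B
    alpha: f64,
) -> Vec<Vec<f64>> {
    let r = a.len();
    let scale = alpha / r as f64;
    w.iter()
        .enumerate()
        .map(|(i, row)| {
            row.iter()
                .enumerate()
                .map(|(j, &wij)| {
                    // (B * A)[i][j] = sum_k B[i][k] * A[k][j]
                    let delta: f64 = (0..r).map(|k| b[i][k] * a[k][j]).sum();
                    wij + scale * delta
                })
                .collect()
        })
        .collect()
}

fn main() {
    // 1x1 case with rank 1: W = [[1]], A = [[2]], B = [[3]], alpha = 1
    let out = lora_effective_weight(&[vec![1.0]], &[vec![2.0]], &[vec![3.0]], 1.0);
    assert_eq!(out[0][0], 7.0); // 1 + (1/1) * (3 * 2)
    println!("ok");
}
```

This also shows why initializing `lora_B` to zeros (noted elsewhere in this review) makes the adapter a no-op at the start of training: the delta term vanishes until `lora_B` receives updates.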
334-348: Memory configuration methods provide runtime flexibility. The builder-style methods `with_memory_leak_threshold()` and `with_memory_warning_margin()` allow fine-tuning of memory management behavior without recompiling. The documentation comments clearly explain the parameters and trade-offs.
1138-1187: Debug logging helper includes comprehensive memory metrics. The `log_debug()` helper captures both system RSS and MLX-specific metrics (active/peak/cache memory), providing complete visibility into memory behavior across the training loop. This supports the memory leak investigation and verification workflow.
2261261 Fix memory leak.
e814581 Update.
2ff1e34 Add files for posterity.
e5a276e Training sucess? Really?
Typo in commit message.
Line 179 contains "Training sucess?" which should be "Training success?"
🧰 Tools
🪛 LanguageTool
[grammar] ~179-~179: Ensure spelling is correct
Context: ...d files for posterity. e5a276e Training sucess? Really? Test Commands: cargo check ...
(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)
🤖 Prompt for AI Agents
In MEMORY_LEAK_FIXES_COMPLETE.txt around line 179, the commit message contains a
typo "Training sucess?" — change it to "Training success?" by editing that line;
if this is already in a committed change, amend the commit message (git commit
--amend or git rebase -i) or create a follow-up commit that corrects the text so
the file/commit history reads "Training success?".
Actionable comments posted: 14
🧹 Nitpick comments (22)
rust/IMPROVE_MLX_BINDINGS.md (2)
122-149: Clarify availability of monitoring APIs. The monitoring examples show several MLX API functions that may not be available:
`mlx_sys::mlx_get_gpu_utilization` (Line 130) - The actual implementation at Line 155 in `minimal_lora_training.rs` notes this is "not directly available in minimal C API" and returns 0.0
mlx_sys::mlx_get_active_memoryis confirmed available and used. Consider marking these as future enhancements or verifying their availability.
33-72: Update documentation example to match actual mlx-sys API.The
ArrayWrapperexample in lines 33-72 uses mlx-sys functions that don't match the actual bindings available:
Type mismatch: Documentation shows
mlx_sys::mlx_array_tbut implementation usesmlx_array_(without_tsuffix).Function call differences:
- Documentation calls
mlx_sys::mlx_array_from_datawith 7 parameters, but actual API ismlx_sys::mlx_array_new_datawith 4 parameters.- Documentation calls
mlx_sys::mlx_eval(self.inner)directly, but actual implementation wraps the array inmlx_vector_arraybefore callingmlx_eval(outputs).Update this documentation example to reflect the actual mlx-sys API, or clearly label it as a proposed/idealized interface rather than production-ready code.
rust/examples/minimal_lora_training.rs (3)
72-100: Consider using or documenting the result parameter. The `configure_mlx` function creates `result` variables (Lines 78, 89) that are passed as `&mut` to `mlx_set_memory_limit` and `mlx_set_cache_limit` but never used afterward. If these functions write back the actual limit that was set, consider using this value to:
- Verify the limit was set as requested
- Report the actual limit if it differs from the requested limit
If the return value is not meaningful, consider documenting why it's ignored or using `_` to indicate intentional discard.
🔎 Example: Verify and report actual limits
```diff
 let limit_bytes = 80 * 1024 * 1024 * 1024usize;
 let mut result = 0usize;
 unsafe {
     mlx_sys::mlx_set_memory_limit(&mut result, limit_bytes);
 }
-println!(
-    "   Memory limit: {} GB",
-    limit_bytes as f64 / (1024.0 * 1024.0 * 1024.0)
-);
+if result != limit_bytes {
+    println!(
+        "   Memory limit: {} GB (requested) -> {} GB (actual)",
+        limit_bytes as f64 / (1024.0 * 1024.0 * 1024.0),
+        result as f64 / (1024.0 * 1024.0 * 1024.0)
+    );
+} else {
+    println!(
+        "   Memory limit: {} GB",
+        result as f64 / (1024.0 * 1024.0 * 1024.0)
+    );
+}
```
102-118: Consider reducing example data size or adding a note. Line 112 allocates a 64 MB vector (`4096 * 4096 * 4 bytes = ~64 MB`) which is then copied into the MLX array and immediately dropped. For a minimal example, consider either:
- Using a smaller size (e.g., `256 x 256`) to reduce memory footprint
The current implementation works but may be surprising to readers expecting a truly "minimal" example.
168-214: Clarify that model path is unused in simulation.Line 176 passes a hardcoded model path
"models/distrust-mlabonne/Meta-Llama-3.1-8B-Instruct"toload_quantized_weights, but the function ignores this parameter (Line 104-106 inload_quantized_weights).Consider either:
- Adding a comment at Line 176 noting the path is currently unused in this simulation
- Using a placeholder like
"<path-unused-in-simulation>"to make it explicitThis would help readers understand that real usage would require an actual model path.
🔎 Suggested clarification
// 2. Load model weights (4-bit quantized) - let weights = load_quantized_weights("models/distrust-mlabonne/Meta-Llama-3.1-8B-Instruct")?; + // Note: Path is currently unused in this simulation + let weights = load_quantized_weights("models/distrust-mlabonne/Meta-Llama-3.1-8B-Instruct")?;rust/test_checkpoint.sh (1)
12-21: Consider enabling the test command or documenting the manual-verification requirement. The cargo command is commented out, leaving only a manual-verification message. If this script is intended for CI or automated testing, uncomment and adjust the command. If manual verification is required due to resource constraints (e.g., model size, GPU availability), add a comment explaining why.
rust/VERIFICATION_REPORT.md (1)
65-65: Optional: Add language identifier to fenced code block. Markdownlint flags the code block as missing a language identifier. Add `bash` after the opening triple-backticks for consistent formatting.
rust/FINAL_SUMMARY.md (1)
34-44: Optional: Add language identifier to fenced code block. Markdownlint flags the code block as missing a language identifier. Add `text` or `bash` after the opening triple-backticks.
rust/src/config/training.rs (1)
53-56: Reload comment improved but could be more explicit. The comment now clarifies that "EITHER condition" triggers a reload and that setting `reload_interval_steps` to 0 enables "only threshold-based reload." While clearer than before, consider explicitly stating: "Setting to 0 disables interval-based reloading; reloads will only occur when `reload_memory_threshold_gb` is exceeded."
rust/MEMORY_LEAK_FIXES_COMPLETE.md (1)
192-213: Optional: Add language identifiers to fenced code blocks. Markdownlint flags multiple code blocks (lines 192, 199, 212) as missing language identifiers. Add `text` or `bash` after the opening triple-backticks for consistent formatting.
rust/TRAINING_OPTIMIZATION_SUMMARY.md (2)
105-109: Consider adding language identifiers to fenced code blocks. The fenced code block at lines 105-109 (showing expected output) should have a language identifier for better rendering and to satisfy markdown linting rules.
🔎 Proposed fix
-``` +```text ⚠️ No memory limit specified. Auto-detecting safe limit: 70.0 GB 🔒 Set MLX memory limit to 70.0 GB (was 6.3 GB)</details> --- `112-114`: **Consider adding language identifiers to fenced code blocks.** The fenced code block at lines 112-114 (showing progress bar output) should have a language identifier for better rendering and to satisfy markdown linting rules. <details> <summary>🔎 Proposed fix</summary> ```diff -``` +```text [00:10:00] =>---------------------------- 45/5000 ETA:8h loss: 199.2948 | lr: 5.00e-05 | 48.7 steps/s</details> </blockquote></details> <details> <summary>rust/test_memory_optimizations.sh (1)</summary><blockquote> `19-81`: **Consider more robust feature validation beyond keyword grep.** The script uses simple `grep` checks to verify memory optimization features. While this confirms the presence of keywords, it doesn't validate that features are correctly implemented or connected. Consider enhancing with: - AST-based checks using `ast-grep` for structural validation - Integration tests that exercise actual memory behavior - Verification that configuration values are actually applied However, for a basic sanity check during development, the current approach is acceptable. </blockquote></details> <details> <summary>rust/src/utils/model_utils.rs (4)</summary><blockquote> `25-72`: **Improve cross-platform compatibility and robustness.** Several issues: 1. **Portability**: Line 34 uses `HOME` environment variable which is Unix-specific. On Windows, use `USERPROFILE` or the `dirs` crate. 2. **Case-sensitive 4-bit detection**: Line 44 uses `contains("4bit")` which won't match "4Bit", "4BIT", etc. Use case-insensitive comparison. 3. **Silent failures in sorting**: Lines 49-50 silently fall back to default ordering if metadata/modified fails. This could mask filesystem issues. 4. **Return type inconsistency**: Returns `Option<String>` but the string is a path—consider `Option<PathBuf>` for type safety. 
<details> <summary>🔎 Proposed improvements</summary> ```diff pub fn resolve_model_path(model_name: &str, prefer_4bit: bool) -> Option<String> { // If it's already a valid path, return it if Path::new(model_name).exists() { return Some(model_name.to_string()); } // Handle HuggingFace model names (containing '/') if model_name.contains('/') { let cache_name = model_name.replace('/', "--"); - let home = std::env::var("HOME").ok()?; + let home = std::env::var("HOME") + .or_else(|_| std::env::var("USERPROFILE")) + .ok()?; let cache_dir = format!("{}/.cache/huggingface/hub/models--{}", home, cache_name); if Path::new(&cache_dir).exists() { let snapshots_dir = format!("{}/snapshots", cache_dir); if let Ok(entries) = std::fs::read_dir(&snapshots_dir) { // If preferring 4-bit, look for those first if prefer_4bit { let mut fourbit_dirs: Vec<_> = entries .flatten() - .filter(|e| e.file_name().to_string_lossy().contains("4bit")) + .filter(|e| e.file_name().to_string_lossy().to_lowercase().contains("4bit")) .collect(); // Sort by modification time (newest first) fourbit_dirs.sort_by(|a, b| { let a_time = a.metadata().ok().and_then(|m| m.modified().ok()); let b_time = b.metadata().ok().and_then(|m| m.modified().ok()); b_time.cmp(&a_time) });
98-112: Fix cross-platform compatibility.Same HOME environment variable issue as
resolve_model_path. Consider extracting a shared helper function for cache directory resolution that handles both Unix (HOME) and Windows (USERPROFILE) environments.🔎 Suggested helper function
fn get_home_dir() -> Option<String> { std::env::var("HOME") .or_else(|_| std::env::var("USERPROFILE")) .ok() }Then use
get_home_dir()in both functions.
121-148: Consider using PathBuf for type safety.Returns
Vec<String>but represents filesystem paths. UsingVec<PathBuf>would provide better type safety and avoid string/path conversions at call sites.
168-182: Minor improvements for maintainability.The hardcoded
"snapshots"string at line 176 couples this function to HuggingFace's cache structure. Consider extracting it as a module constant for maintainability.+const HF_SNAPSHOTS_DIR: &str = "snapshots"; + pub fn get_base_model_dir(snapshot_path: &str) -> Option<String> { let path = Path::new(snapshot_path); if !path.exists() { return None; } // Navigate up from snapshot to the base model directory let parent = path.parent()?; - if parent.file_name()?.to_string_lossy() == "snapshots" { + if parent.file_name()?.to_string_lossy() == HF_SNAPSHOTS_DIR { let grandparent = parent.parent()?; return Some(grandparent.to_string_lossy().to_string()); }rust/src/model/loader.rs (3)
70-112: Document safety invariants for unsafe pointer cast.

The unsafe block at lines 96-103 performs a raw pointer cast without documenting the safety requirements. While this is a common pattern for FFI, the code should document:

- Why `tensor.data().len() / 4` is the correct element count
- Alignment requirements for f32
- Lifetime assumptions (tensor_data outlives the slice)

```diff
 // Convert TensorView to MLX Array
 // We use from_slice with the appropriate type
 let data = match tensor.dtype() {
     safetensors::Dtype::F32 => Array::from_slice(
+        // SAFETY: tensor.data() is a byte slice from safetensors.
+        // We reinterpret it as f32 slice. This is safe because:
+        // 1. safetensors guarantees proper alignment for F32
+        // 2. len/4 gives correct f32 element count (4 bytes per f32)
+        // 3. tensor_data lifetime ensures the slice remains valid
         unsafe {
             std::slice::from_raw_parts(
                 tensor.data().as_ptr() as *const f32,
                 tensor.data().len() / 4,
             )
         },
         &shape,
     ),
```
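If the alignment guarantee turns out to be shaky (safetensors slices a larger file buffer, so a 4-byte-aligned start is not automatic), a copying decode sidesteps the `unsafe` block entirely; a sketch, assuming little-endian storage:

```rust
// Alignment-safe alternative to slice::from_raw_parts: decode the raw
// byte buffer as little-endian f32 values. Costs one copy, but removes
// the alignment and lifetime assumptions documented above.
fn bytes_to_f32(bytes: &[u8]) -> Result<Vec<f32>, String> {
    if bytes.len() % 4 != 0 {
        return Err(format!("byte length {} is not a multiple of 4", bytes.len()));
    }
    Ok(bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect())
}
```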
125-158: Simplify LoRA initialization.

Line 148 creates `lora_a` using `full::<f32>(..., 0.0)`, which is equivalent to `zeros::<f32>`. Use `zeros` for consistency with `lora_b` on line 149.

```diff
-    let val_0 = Array::from_slice(&[0.0f32], &[]);
-    let lora_a = full::<f32>(&[in_features, lora_rank as i32], &val_0)?;
+    let lora_a = zeros::<f32>(&[in_features, lora_rank as i32])?;
     let lora_b = zeros::<f32>(&[lora_rank as i32, in_features])?;
```
179-211: Silent error swallowing at line 205.

The code silently continues when encountering unsupported dtypes (line 205), which could mask data corruption or format issues. At minimum, log a warning so users know tensors were skipped.

```rust
_ => continue, // Skip unsupported for now
```

Should be:

```diff
-    _ => continue, // Skip unsupported for now
+    _ => {
+        eprintln!("Warning: Skipping tensor '{}' with unsupported dtype: {:?}",
+                  tensor_name, tensor.dtype());
+        continue;
+    }
```

rust/src/training/trainer.rs (2)
55-87: LoRA-specific optimizer initialization. The optimizer correctly filters for only LoRA parameters (`lora_A`, `lora_B`). Consider making the learning rate and other hyperparameters configurable through the constructor if not already provided by a separate config.
283-285: Document panic behavior or return Result.

The `unwrap()` at line 284 will panic if `zeros` fails (e.g., out of memory, invalid shape). Either document this panic in a comment or change the return type to `Result<Array>` for safer error propagation.
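A sketch of the `Result`-returning shape, with a plain `Vec<f32>` standing in for the MLX `Array` (the size cap below is an arbitrary illustration, not a real limit):

```rust
// Instead of zeros(&shape).unwrap(), surface the failure so callers can
// decide how to handle allocation problems rather than aborting the trainer.
fn zeros_buffer(len: usize) -> Result<Vec<f32>, String> {
    const MAX_ELEMS: usize = 1 << 30; // illustrative guard, not a real MLX limit
    if len > MAX_ELEMS {
        return Err(format!("refusing to allocate {} elements", len));
    }
    Ok(vec![0.0f32; len])
}
```

Callers then propagate the error with `?` instead of panicking.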
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (37)
- rust/CHANGELOG.md (1 hunks)
- rust/CHECKPOINT_FIX_SUMMARY.md (1 hunks)
- rust/CHECKPOINT_TESTS_SUMMARY.md (1 hunks)
- rust/Cargo.toml (3 hunks)
- rust/FINAL_SUMMARY.md (1 hunks)
- rust/FIXES_VERIFICATION.md (1 hunks)
- rust/IMPROVE_MLX_BINDINGS.md (1 hunks)
- rust/MEMORY_LEAK_FIXES_COMPLETE.md (1 hunks)
- rust/MEMORY_OPTIMIZATIONS.md (1 hunks)
- rust/MEMORY_PRESSURE_DIAGNOSIS.md (1 hunks)
- rust/MEMORY_PRESSURE_FIXES.md (1 hunks)
- rust/TRAINING_OPTIMIZATION_SUMMARY.md (1 hunks)
- rust/VERIFICATION_REPORT.md (1 hunks)
- rust/examples/minimal_lora_training.rs (1 hunks)
- rust/patches/mlx-sys/build.rs (1 hunks)
- rust/src/benchmarks/optimizer.rs (10 hunks)
- rust/src/checkpoints/manager.rs (2 hunks)
- rust/src/checkpoints/mlx_utils.rs (1 hunks)
- rust/src/checkpoints/mod.rs (1 hunks)
- rust/src/checkpoints/state.rs (2 hunks)
- rust/src/cli/commands.rs (21 hunks)
- rust/src/cli/mod.rs (4 hunks)
- rust/src/config/model.rs (1 hunks)
- rust/src/config/training.rs (3 hunks)
- rust/src/lib.rs (1 hunks)
- rust/src/main.rs (1 hunks)
- rust/src/model/loader.rs (1 hunks)
- rust/src/model/mod.rs (1 hunks)
- rust/src/training/trainer.rs (1 hunks)
- rust/src/utils/mlx_memory.rs (2 hunks)
- rust/src/utils/mod.rs (1 hunks)
- rust/src/utils/model_utils.rs (1 hunks)
- rust/test_checkpoint.sh (1 hunks)
- rust/test_memory_optimizations.sh (1 hunks)
- rust/tests/checkpoint_tests.rs (1 hunks)
- rust/tests/integration_tests.rs (1 hunks)
- rust/tests/training_tests.rs (1 hunks)
✅ Files skipped from review due to trivial changes (2)
- rust/CHANGELOG.md
- rust/FIXES_VERIFICATION.md
🚧 Files skipped from review as they are similar to previous changes (2)
- rust/Cargo.toml
- rust/src/utils/mlx_memory.rs
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: CR
Repo: arosboro/your_ai PR: 0
File: GEMINI.md:0-0
Timestamp: 2025-11-30T23:47:27.584Z
Learning: Applies to src/train_*.py : Use Apple's MLX framework for training on Apple Silicon, with PyTorch as a fallback for other platforms
📚 Learning: 2025-11-30T23:47:27.584Z
Learnt from: CR
Repo: arosboro/your_ai PR: 0
File: GEMINI.md:0-0
Timestamp: 2025-11-30T23:47:27.584Z
Learning: Applies to src/train_*.py : Use Apple's MLX framework for training on Apple Silicon, with PyTorch as a fallback for other platforms
Applied to files:
rust/MEMORY_PRESSURE_FIXES.md
🧬 Code graph analysis (22)
rust/src/checkpoints/mod.rs (1)
- rust/src/checkpoints/mlx_utils.rs (2): from_flat (17-20), to_flat (10-14)

rust/test_memory_optimizations.sh (2)
- rust/src/utils/memory.rs (4): check (247-256), tests (297-321), test_memory_monitor (315-320), MemoryMonitor (236-294)
- rust/src/hardware/scaling.rs (1): tests (191-214)

rust/src/lib.rs (1)
- rust/src/utils/model_utils.rs (6): get_base_model_dir (168-182), get_model_cache_dir (98-112), is_model_available (157-159), list_model_snapshots (121-148), resolve_model_path (25-72), resolve_model_path_with_error (82-89)

rust/tests/checkpoint_tests.rs (3)
- rust/src/checkpoints/manager.rs (2): default (175-181), new (23-43)
- rust/src/checkpoints/state.rs (1): new (27-47)
- rust/src/utils/mlx_memory.rs (1): clear_cache (83-89)

rust/src/model/mod.rs (1)
- rust/src/model/loader.rs (2): load_model (25-67), save_model_weights (214-233)

rust/src/config/training.rs (3)
- rust/src/config/model.rs (1): default (20-32)
- rust/src/config/distrust.rs (1): default (16-21)
- rust/src/config/mod.rs (3): default (30-41), training (5-5), Config (18-27)

rust/src/benchmarks/optimizer.rs (1)
- rust/src/training/trainer.rs (1): new (31-52)

rust/src/config/model.rs (1)
- rust/src/training/lora.rs (1): LoraConfig (11-16)

rust/src/checkpoints/mlx_utils.rs (4)
- rust/src/checkpoints/manager.rs (2): array (201-201), array (232-232)
- rust/src/model/loader.rs (1): array (220-220)
- rust/src/utils/mlx_memory.rs (1): array (113-113)
- rust/examples/minimal_lora_training.rs (1): shape (55-63)

rust/tests/training_tests.rs (1)
- rust/src/training/trainer.rs (1): new (31-52)

rust/src/checkpoints/state.rs (2)
- rust/src/checkpoints/manager.rs (1): new (23-43)
- rust/src/training/trainer.rs (1): new (31-52)

rust/src/main.rs (1)
- rust/src/cli/mod.rs (1): run (140-212)

rust/tests/integration_tests.rs (2)
- rust/src/model/llama.rs (1): config (553-555)
- rust/src/training/lora.rs (1): LoraConfig (11-16)

rust/src/training/trainer.rs (8)
- rust/src/model/loader.rs (4): load_model (25-67), zeros (149-149), default (169-175), array (220-220)
- rust/src/checkpoints/manager.rs (4): new (23-43), default (175-181), array (201-201), array (232-232)
- rust/src/checkpoints/state.rs (1): new (27-47)
- rust/src/utils/memory.rs (1): new (238-244)
- rust/src/checkpoints/mlx_utils.rs (3): to_flat (10-14), from_flat (17-20), array (11-11)
- rust/src/utils/mlx_memory.rs (2): clear_cache (83-89), array (113-113)
- rust/src/config/model.rs (1): default (20-32)
- rust/src/config/training.rs (1): default (60-87)

rust/src/utils/mod.rs (1)
- rust/src/utils/model_utils.rs (6): get_base_model_dir (168-182), get_model_cache_dir (98-112), is_model_available (157-159), list_model_snapshots (121-148), resolve_model_path (25-72), resolve_model_path_with_error (82-89)

rust/FINAL_SUMMARY.md (2)
- rust/src/utils/memory.rs (4): tests (297-321), MemoryMonitor (236-294), print_report (273-293), MemoryMonitor (230-234)
- rust/src/hardware/scaling.rs (1): tests (191-214)

rust/MEMORY_PRESSURE_DIAGNOSIS.md (3)
- rust/src/utils/memory.rs (3): tests (297-321), MemoryMonitor (230-234), MemoryMonitor (236-294)
- rust/src/hardware/scaling.rs (1): tests (191-214)
- rust/src/model/llama.rs (2): LlamaModel (362-370), LlamaConfig (30-138)

rust/src/cli/mod.rs (1)
- rust/src/cli/commands.rs (4): optimize (544-570), train (573-741), generate (758-935), export_command (938-1026)

rust/src/checkpoints/manager.rs (2)
- rust/src/checkpoints/state.rs (1): new (27-47)
- rust/src/checkpoints/mlx_utils.rs (3): from_flat (17-20), array (11-11), to_flat (10-14)

rust/MEMORY_LEAK_FIXES_COMPLETE.md (1)
- rust/src/utils/memory.rs (3): tests (297-321), MemoryMonitor (230-234), MemoryMonitor (236-294)

rust/src/model/loader.rs (6)
- rust/tests/training_tests.rs (1): mlx_rs (44-44)
- rust/src/model/llama.rs (4): mlx_rs (419-419), mlx_rs (478-478), config (553-555), LlamaConfig (30-138)
- rust/src/cli/commands.rs (2): serde_json (355-355), s (239-239)
- rust/src/config/model.rs (1): default (20-32)
- rust/src/checkpoints/mlx_utils.rs (1): array (11-11)
- rust/src/utils/mlx_memory.rs (1): array (113-113)

rust/src/cli/commands.rs (9)
- rust/src/model/llama.rs (15): config (553-555), new (167-191), new (302-312), new (342-354), new (382-398), new (445-458), new (499-504), new (526-535), load_model_with_weights (794-805), LlamaForCausalLM (433-527), LlamaModel (362-370), LlamaForCausalLM (426-431), LlamaModel (372-422), LlamaConfig (12-28), LlamaConfig (30-138)
- rust/src/model/loader.rs (2): load_model (25-67), array (220-220)
- rust/src/utils/model_utils.rs (1): resolve_model_path (25-72)
- rust/src/benchmarks/optimizer.rs (2): quick_validate (266-338), new (39-55)
- rust/src/training/trainer.rs (2): train (109-127), new (31-52)
- rust/src/checkpoints/manager.rs (3): new (23-43), array (201-201), array (232-232)
- rust/src/checkpoints/state.rs (1): new (27-47)
- rust/src/utils/memory.rs (1): new (238-244)
- rust/src/checkpoints/mlx_utils.rs (1): array (11-11)
🪛 markdownlint-cli2 (0.18.1)
rust/VERIFICATION_REPORT.md
65-65: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
rust/MEMORY_PRESSURE_FIXES.md
129-129: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
137-137: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
rust/FINAL_SUMMARY.md
34-34: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
rust/MEMORY_LEAK_FIXES_COMPLETE.md
192-192: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
199-199: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
212-212: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
rust/TRAINING_OPTIMIZATION_SUMMARY.md
105-105: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
112-112: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Build and Test
- GitHub Check: Lint
🔇 Additional comments (54)
rust/IMPROVE_MLX_BINDINGS.md (1)
74-120: Clarify which MLX APIs exist vs. are proposed. This section mixes existing and non-existent MLX functions:

- `mlx_sys::mlx_set_memory_limit` and `mlx_sys::mlx_set_cache_limit` are already available and used in the actual implementation
- `mlx_sys::mlx_batched_add` (line 91) does not exist in mlx-sys and should not be presented as an available API
- `mlx_sys::mlx_set_eval_lazy` (line 110) does not exist in mlx-sys
- `mlx_sys::mlx_array_from_quantized` (line 117) is proposed and would need to be implemented

Update the code examples to show only available functions, or clearly mark proposed APIs as "future enhancements" or "would require adding."
Likely an incorrect or invalid review comment.
rust/patches/mlx-sys/build.rs (1)
106-117: LGTM—FFI surface tightened with clear documentation. The consolidation to `mlx.h` and the allowlist filters reduce namespace pollution and limit the binding surface to intended MLX APIs.
rust/src/config/training.rs (2)
3-25: LGTM—well-designed training mode abstraction. The `TrainingMode` enum cleanly separates LoRA, full fine-tuning, and frozen modes. The `from_lora_rank` helper provides sensible auto-detection based on configuration.
30-31: Appropriate use of `#[serde(skip)]` for runtime-only field. The `training_mode` field is marked with `#[serde(skip)]` and documented as being set during trainer initialization. This is the correct pattern for runtime-populated fields that should not be persisted in configuration files.
rust/src/benchmarks/optimizer.rs (4)
84-138: LGTM! The async conversion is clean and correct. The await properly handles the asynchronous test_config call within the loop.
141-189: LGTM! The async conversion maintains all error handling and memory cleanup logic. The await on run_training_test is correct.
192-262: LGTM! The async conversion correctly uses PathBuf for path handling and properly awaits the trainer initialization and training steps. Memory monitoring logic is preserved throughout.
266-338: LGTM! The async conversion is correct with proper awaits on trainer initialization and training steps. The explicit cleanup sequence (drop, clear_cache, eval) ensures proper resource release before returning.
rust/tests/integration_tests.rs (1)
8-8: LGTM! The test expectation correctly reflects the updated default lora_rank value (16) in the ModelConfig, consistent with the changes in rust/src/config/model.rs.
rust/src/main.rs (1)
7-15: LGTM! The async conversion of the main entry point is correct. The `#[tokio::main]` macro properly initializes the async runtime, and `cli::run().await` correctly awaits the async CLI execution.
rust/src/model/mod.rs (1)
6-6: LGTM! The export change from `ModelLoader` to function-based exports (`load_model`, `save_model_weights`, `ModelConfig`) aligns with the loader refactoring. This is a breaking API change but appears intentional for the path-based loading approach.
rust/src/utils/mod.rs (1)
3-10: LGTM! The addition of the `model_utils` module and its public re-exports is clean and follows the existing pattern in this file.
rust/src/config/model.rs (1)
22-31: LGTM! The updated defaults significantly reduce LoRA complexity for memory efficiency:
- lora_rank: 128 → 16 (8× reduction)
- lora_alpha: 256 → 32 (maintains scale ratio of 2.0)
- lora_num_layers: 16 → 4 (4× reduction)
- lora_target_modules: 4 → 1 module
This aligns with the PR's memory-aware optimization goals and the "Chill" review mode. The scale ratio is preserved (alpha/rank = 2.0), which maintains the LoRA scaling behavior.
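The preserved-ratio claim is easy to check numerically (values taken from the bullet list above):

```rust
// LoRA's effective scaling factor is alpha / rank; both the old and the
// new defaults keep it at 2.0, so update magnitude is unchanged.
fn lora_scale(alpha: f32, rank: f32) -> f32 {
    alpha / rank
}
```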
rust/src/lib.rs (1)
37-44: LGTM! The new public re-exports expand the crate's API surface to include checkpoint management types and model utility functions. The comments clearly document the purpose of these exports.
rust/MEMORY_PRESSURE_FIXES.md (1)
1-10: LGTM! Well-structured memory optimization documentation. The document provides comprehensive guidance on addressing memory pressure issues during LoRA fine-tuning, including immediate CLI fixes, code-level changes, and performance expectations. The executive summary and problem identification are clear and actionable.
rust/src/checkpoints/mod.rs (1)
1-7: LGTM! Clean expansion of checkpoint module API. The addition of `mlx_utils` and the expanded exports (`OptimizerState`, `ModelState`, `from_flat`, `to_flat`) provide a well-organized public API for checkpoint serialization and MLX array conversion utilities. This aligns with the PR's goal of improving checkpoint handling.
rust/test_memory_optimizations.sh (1)
85-90: LGTM! Compilation check provides essential validation. The `cargo check` step ensures the code compiles successfully before proceeding with the memory optimization summary. Good practice for a validation script.
rust/tests/training_tests.rs (2)
4-36: LGTM! Test correctly adapted to async trainer initialization. The test has been properly updated to use `#[tokio::test]` and await the async `DistrustTrainer::new` call. The error handling verification correctly checks for expected failure messages when the model path doesn't exist.
38-62: LGTM! Gradient and loss computation tests are well-structured. The tests verify that gradient computation and distrust loss calculation work correctly. The use of `mlx_rs::ops::zeros` and `Array::from_slice` is appropriate, and assertions check expected tensor dimensions and successful loss computation.
rust/MEMORY_PRESSURE_DIAGNOSIS.md (1)
1-171: LGTM! Comprehensive memory pressure diagnostic documentation. This document provides excellent diagnostic guidance for memory pressure issues during LoRA fine-tuning. The structured approach covering root causes, immediate solutions, configuration optimizations, and monitoring recommendations will be valuable for users encountering memory issues. The references to related documentation files help tie together the broader memory management narrative.
rust/src/cli/mod.rs (3)
89-91: LGTM! Reload interval parameter supports memory-aware training. The addition of the `reload_interval` parameter to the Train command aligns with the memory optimization strategy documented throughout the PR. This enables periodic model reloads to mitigate MLX-rs framework memory leaks.
122-137: LGTM! EOS token override and Export command additions. The additions support:

- Configurable EOS token handling via the `eos_token` parameter in Generate (lines 122-124)
- Model export functionality via the new Export command (lines 126-137)

Both align with PR objectives for improved training resources and model export capabilities.
140-211: LGTM! Async CLI routing correctly integrated. The `run()` function has been properly converted to async and all command handlers are correctly awaited where necessary. The routing logic cleanly forwards new parameters (`reload_interval`, `eos_token`) and invokes the new `export_command`.
rust/tests/checkpoint_tests.rs (5)
16-35: LGTM! Well-structured test helper function. The `create_test_checkpoint` helper provides a reusable way to create mock checkpoints with predictable data for testing. The use of simple mock data (fixed arrays, default optimizer state) is appropriate for unit testing checkpoint persistence.
37-65: LGTM! Comprehensive round-trip test. The test verifies that checkpoints can be saved and loaded correctly, checking:
- Checkpoint file creation at expected path
- Step number preservation
- Model state tensor count
- Loss history preservation
This is essential validation for checkpoint persistence.
92-124: LGTM! Excellent cleanup validation. This test verifies the critical behavior of checkpoint cleanup with `max_checkpoints=2`. It confirms:
- Only the latest 2 checkpoints are retained after saving 3
- Both in-memory listing and on-disk files reflect the cleanup
- Older checkpoints are properly deleted
This validates the checkpoint management strategy essential for long-running training.
150-177: LGTM! Reload integration test validates training workflow. This test simulates a realistic training scenario with periodic checkpoints and a reload at step 10. It validates that checkpoints persist correctly across reload events, which is critical for the memory-aware training strategy documented in this PR.
199-237: LGTM! Error handling and edge case coverage. The tests for error handling (non-existent checkpoint) and edge cases (empty checkpoint) provide essential validation that the checkpoint system handles exceptional scenarios gracefully. This prevents silent failures or crashes in production.
rust/MEMORY_OPTIMIZATIONS.md (1)
1-92: LGTM! Detailed memory optimization technical documentation. This document provides excellent technical depth on memory optimizations for low-memory training scenarios. The breakdown of improvements across model configuration, memory monitoring, cache management, reload mechanisms, and training step optimizations is clear and actionable. The expected results and usage recommendations provide concrete guidance for users.
rust/src/utils/model_utils.rs (2)
82-89: LGTM! Good error context with actionable guidance for users.
157-159: LGTM! Clean and simple availability check.
rust/src/model/loader.rs (2)
161-176: LGTM! Reasonable default LoRA configuration for training.
214-233: Add dtype validation before calling `as_slice::<f32>` at line 220. The function assumes all arrays are f32 without checking. If code is extended to support quantized or mixed-precision models, non-f32 arrays could silently cause incorrect data serialization or panics. Either validate the dtype explicitly or document that only f32 arrays are supported and update the load path to be consistent.
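A sketch of the suggested guard, with a hypothetical `Dtype` enum standing in for the mlx-rs dtype type:

```rust
// Hypothetical dtype enum for illustration; the real check would match on
// the array's actual dtype before calling as_slice::<f32>().
#[derive(Debug, PartialEq)]
enum Dtype {
    F32,
    F16,
}

fn ensure_f32(name: &str, dtype: &Dtype) -> Result<(), String> {
    if *dtype != Dtype::F32 {
        return Err(format!(
            "tensor '{}' has dtype {:?}; only F32 is supported by save_model_weights",
            name, dtype
        ));
    }
    Ok(())
}
```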
rust/src/training/trainer.rs (6)
14-52: LGTM with reasonable defaults. The async constructor properly initializes the trainer with sensible defaults. The 32GB max_memory default (line 47) is appropriate for modern Apple Silicon machines.
89-107: LGTM! Clean builder pattern for configuring the trainer.
109-155: Skeletal training implementation with placeholder logic. Lines 111-123 contain hardcoded dummy training data and a 10-step loop. This appears intentional per the inline comment, but verify that:
- This PR is meant to establish the training API surface without full implementation
- There's a follow-up task to implement actual data loading and training logic
The aggressive cache clearing (lines 132, 152) is appropriate for memory management but may impact performance once real training is implemented.
158-187: Placeholder implementations for core training operations. The forward pass, loss computation, backward pass, and parameter updates are all stubs. This aligns with the skeletal training loop but means the trainer cannot actually train a model yet.
Ensure this is tracked in a follow-up issue to implement these methods.
200-229: LGTM! Good memory management. The explicit `mem::take` (line 207) and double cache clearing (lines 204, 210) ensure proper cleanup of MLX memory before reloading. This addresses the memory leak issues mentioned in the PR objectives.
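The `mem::take` idiom is the standard way to move a field out of `&mut self` before rebuilding it; a minimal illustration with a plain `Vec` standing in for the model's parameter storage:

```rust
use std::mem;

struct Trainer {
    params: Vec<f32>,
}

impl Trainer {
    // Take ownership of the old parameters (leaving an empty Vec behind)
    // so their memory is released before the reload allocates new ones.
    fn reload(&mut self) {
        let old = mem::take(&mut self.params);
        drop(old); // old buffers freed here, ahead of the new allocation
        self.params = vec![0.0; 4]; // stand-in for "load fresh weights"
    }
}
```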
232-254: LGTM! Proper base model reload with optimizer reinitialization.
rust/src/checkpoints/manager.rs (10)
12-43: LGTM! Clean CheckpointManager initialization with proper error context.
46-61: LGTM! Proper async checkpoint saving with automatic cleanup.
64-76: LGTM! Clean async loading with proper error context.
79-108: LGTM with good error handling. The filename parsing (lines 95-100) correctly handles malformed filenames by using `if let Ok(step)` to skip unparseable entries. The sorted output is oldest-first, which is appropriate for cleanup logic.
111-136: LGTM! Correct cleanup implementation that keeps the `max_checkpoints` most recent checkpoints.
144-182: LGTM! Well-structured state types with serialization support and sensible defaults.
185-215: Validate f32 assumption or document limitation. Line 201 uses `as_slice::<f32>()`, assuming all arrays are f32. This will fail or return incorrect data for quantized/mixed-precision models. Either validate dtype or document this as an f32-only function. Related to the same issue in `rust/src/model/loader.rs` at line 220.
218-277: LGTM! Clever metadata embedding. The approach of embedding metadata and optimizer state as U8 tensors with underscore-prefixed names (lines 269-270) is a clean solution for single-file checkpoints. The `_tensors_data` vec (line 223) properly keeps MLX arrays alive during serialization.
280-355: LGTM! Robust deserialization with fallbacks. The `unwrap_or_default` usage (lines 329, 337, 343, 346) provides graceful degradation for backwards compatibility if the checkpoint format evolves. Good use of `Option` for metadata fields.
358-387: Legacy compatibility function. This function provides backwards compatibility for loading old checkpoint formats. Consider adding a deprecation notice if the new format is fully adopted.
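A minimal sketch of such a notice using Rust's built-in attribute (the note text and placeholder body are illustrative):

```rust
// Marking the legacy loader deprecated makes every remaining call site
// emit a compiler warning until it migrates to the new checkpoint format.
#[deprecated(note = "legacy checkpoint format; use CheckpointManager::load instead")]
fn load_legacy_checkpoint(_path: &str) -> Result<(), String> {
    Ok(()) // placeholder body, not the real implementation
}

#[allow(deprecated)] // existing call sites acknowledge the warning explicitly
fn migrate_me() -> Result<(), String> {
    load_legacy_checkpoint("checkpoints/checkpoint_100.safetensors")
}
```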
rust/src/cli/commands.rs (5)
103-542: LGTM! Robust benchmarking with safety checks. The subprocess isolation (lines 327-335) and memory safety checks (lines 285-311) provide good protection against OOM crashes during benchmarking. The --force option allows advanced users to bypass safety limits when needed.
573-741: LGTM! Smart memory auto-detection. The auto-detection logic (lines 699-722) uses an aggressive 0.8 factor for Apple Silicon's unified memory (line 707), which is appropriate given the integrated GPU/CPU memory architecture. Good user feedback with console messages.
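The 0.8 factor amounts to a plain headroom calculation over the unified-memory pool; a sketch (numbers illustrative):

```rust
// Unified-memory budget: reserve 20% headroom for the OS and other
// processes, and hand the rest to training.
fn training_budget_gb(total_gb: f64) -> f64 {
    total_gb * 0.8
}
```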
758-835: LGTM! Proper model initialization. The addition of EOS token override (lines 803-806) provides useful flexibility. Using `load_model_with_weights` (lines 831-834) ensures the model is initialized with actual weights rather than random values.
853-878: LGTM! Efficient checkpoint merging. The weight merging logic (lines 858-862) correctly overlays checkpoint weights onto base weights. The efficient reuse of `base_weights` (line 854) avoids unnecessary clones.
882-932: LGTM! Consistent checkpoint loading. The checkpoint loading and merging (lines 890-904) follows the same pattern as the comparison mode, ensuring consistency across the codebase.
```rust
pub fn new(data: &[f32], shape: &[i32]) -> Result<Self, String> {
    let array = unsafe {
        mlx_sys::mlx_array_new_data(
            data.as_ptr() as *const std::ffi::c_void,
            shape.as_ptr(),
            shape.len() as i32,
            mlx_sys::mlx_dtype__MLX_FLOAT32,
        )
    };
    if array.ctx.is_null() {
        return Err("Failed to create array".to_string());
    }
    Ok(Array { inner: array })
}
```
🛠️ Refactor suggestion | 🟠 Major
Add input validation for array creation.
The `new` method lacks validation:

- Shape validation: No check that `shape` is non-empty or that its dimensions are positive.
- Data-shape consistency: No validation that `data.len()` matches the product of shape dimensions.
Invalid inputs could lead to undefined behavior in the C API or create arrays with incorrect shapes.
🔎 Proposed validation logic
```diff
 pub fn new(data: &[f32], shape: &[i32]) -> Result<Self, String> {
+    // Validate shape
+    if shape.is_empty() {
+        return Err("Shape cannot be empty".to_string());
+    }
+
+    // Validate dimensions are positive and compute expected size
+    let expected_size: usize = shape.iter()
+        .try_fold(1usize, |acc, &dim| {
+            if dim <= 0 {
+                return None;
+            }
+            acc.checked_mul(dim as usize)
+        })
+        .ok_or_else(|| "Invalid shape: dimensions must be positive".to_string())?;
+
+    // Validate data length matches shape
+    if data.len() != expected_size {
+        return Err(format!(
+            "Data length {} does not match shape {:?} (expected {})",
+            data.len(), shape, expected_size
+        ));
+    }
+
     let array = unsafe {
         mlx_sys::mlx_array_new_data(
```
In rust/examples/minimal_lora_training.rs around lines 25 to 38, the Array::new
function needs input validation: ensure shape is non-empty, ensure every
dimension > 0, compute the product of shape dimensions using checked_mul
(converting dims to usize) to detect overflow, and verify that data.len() ==
product; if any check fails return Err with a descriptive message before calling
into the unsafe C API. Implement these checks at the top of new() and only call
mlx_array_new_data when validation passes.
```rust
/// Get shape
pub fn shape(&self) -> Vec<i32> {
    let ndim = unsafe { mlx_sys::mlx_array_ndim(self.inner) };
    let shape_ptr = unsafe { mlx_sys::mlx_array_shape(self.inner) };
    let mut shape = Vec::with_capacity(ndim);
    for i in 0..ndim {
        shape.push(unsafe { *shape_ptr.add(i) });
    }
    shape
}
```
Add safety checks for shape pointer access.
The `shape` method has potential safety issues:

- Line 57: `shape_ptr` could be null but is not checked before dereferencing at Line 60.
- Line 60: Pointer arithmetic `shape_ptr.add(i)` assumes valid memory without bounds checking.

A null or invalid `shape_ptr` would cause undefined behavior.
🔎 Proposed safety checks
```diff
 pub fn shape(&self) -> Vec<i32> {
     let ndim = unsafe { mlx_sys::mlx_array_ndim(self.inner) };
     let shape_ptr = unsafe { mlx_sys::mlx_array_shape(self.inner) };
+
+    // Validate shape pointer
+    if shape_ptr.is_null() {
+        return Vec::new(); // or panic!/error depending on contract
+    }
+
     let mut shape = Vec::with_capacity(ndim);
     for i in 0..ndim {
         shape.push(unsafe { *shape_ptr.add(i) });
```
In rust/examples/minimal_lora_training.rs around lines 54 to 63, the shape()
method dereferences a raw pointer without safety checks; ensure you first
validate ndim (cast to isize or check >= 0) and convert it to usize, then check
shape_ptr for null before using it; replace the manual pointer.add loop with a
safe conversion like creating a slice from_raw_parts(shape_ptr, ndim_usize)
inside an unsafe block so Rust performs bounds-aware indexing, and if shape_ptr
is null or ndim is invalid return an empty Vec (or propagate an error) instead
of dereferencing a null/invalid pointer.
```rust
/// Apply batched parameter updates (key optimization!)
pub fn apply_batched_updates(params: &[Array], _updates: &[Array]) -> Result<(), String> {
    println!(
        "🔄 Applying batched updates to {} parameters...",
        params.len()
    );
    // Simplified: in real implementation, this would use MLX C++ API directly
    // but here we'll just demonstrate the intent
    for param in params {
        param.eval()?;
    }

    println!("  ✓ Batched update completed");
    Ok(())
}
```
Function name misleading: not actually batched.
The function is named apply_batched_updates and the documentation (Line 120) claims it's a "key optimization!", but the implementation sequentially calls eval() on each parameter individually (Lines 128-130). This is not a batched operation.
This contradicts the performance goals stated in the file header (Line 6: "Batched parameter updates") and the documentation in IMPROVE_MLX_BINDINGS.md.
Consider either:

- Renaming to `apply_updates_sequentially` or similar to reflect actual behavior
- Adding a prominent comment explaining this is a placeholder/simulation awaiting true batched API support
- Implementing actual batched operations if the mlx-sys API supports it
🔎 Suggested clarification
```diff
-/// Apply batched parameter updates (key optimization!)
-pub fn apply_batched_updates(params: &[Array], _updates: &[Array]) -> Result<(), String> {
+/// Apply parameter updates (simplified - not truly batched)
+///
+/// NOTE: This is a placeholder. True batched updates would require
+/// MLX C API batched operations (e.g., mlx_batched_add) which are
+/// not yet exposed in mlx-sys.
+pub fn apply_updates_sequentially(params: &[Array], _updates: &[Array]) -> Result<(), String> {
     println!(
-        "🔄 Applying batched updates to {} parameters...",
+        "🔄 Applying updates to {} parameters (sequentially)...",
         params.len()
     );
-    // Simplified: in real implementation, this would use MLX C++ API directly
-    // but here we'll just demonstrate the intent
+    // Simplified simulation - real implementation would use batched MLX operations
     for param in params {
         param.eval()?;
     }
-    println!("  ✓ Batched update completed");
+    println!("  ✓ Update completed");
     Ok(())
 }
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-/// Apply batched parameter updates (key optimization!)
-pub fn apply_batched_updates(params: &[Array], _updates: &[Array]) -> Result<(), String> {
-    println!(
-        "🔄 Applying batched updates to {} parameters...",
-        params.len()
-    );
-    // Simplified: in real implementation, this would use MLX C++ API directly
-    // but here we'll just demonstrate the intent
-    for param in params {
-        param.eval()?;
-    }
-    println!("  ✓ Batched update completed");
-    Ok(())
-}
+/// Apply parameter updates (simplified - not truly batched)
+///
+/// NOTE: This is a placeholder. True batched updates would require
+/// MLX C API batched operations (e.g., mlx_batched_add) which are
+/// not yet exposed in mlx-sys.
+pub fn apply_updates_sequentially(params: &[Array], _updates: &[Array]) -> Result<(), String> {
+    println!(
+        "🔄 Applying updates to {} parameters (sequentially)...",
+        params.len()
+    );
+    // Simplified simulation - real implementation would use batched MLX operations
+    for param in params {
+        param.eval()?;
+    }
+    println!("  ✓ Update completed");
+    Ok(())
+}
```
| } |
🤖 Prompt for AI Agents
In rust/examples/minimal_lora_training.rs around lines 120 to 134, the function
name and docstring claim a batched update but the implementation simply loops
calling param.eval() sequentially; rename the function to something like
apply_updates_sequentially or change the docstring to mark it as a placeholder,
or (if mlx-sys supports it) replace the loop with a true batched API call that
submits all parameter updates in one batch and handles combined error reporting;
update the function name/comment consistently wherever it is referenced and add
a TODO pointing to IMPROVE_MLX_BINDINGS.md if keeping the placeholder.
```markdown
## Implementation Priority

1. ✅ Update bindings to expose full MLX C API
2. ✅ Create safe Rust wrappers for key operations
3. ✅ Implement batched parameter updates
4. ✅ Enable gradient checkpointing
5. ✅ Configure proper memory limits
6. ✅ Add monitoring and diagnostics
7. 📋 Optimize LoRA-specific operations
8. 📋 Benchmark and tune performance
```
Completion status appears overstated.
The implementation priority checklist marks items 1-6 as completed (✅), but the analysis shows:

- Expose full MLX C API (Line 193): many proposed APIs (e.g., `mlx_batched_add`, `mlx_set_eval_lazy`, `mlx_get_gpu_utilization`) are not actually available in the current bindings.
- Batched parameter updates (Line 195): the actual implementation in `minimal_lora_training.rs` only simulates this by calling `eval()` on each parameter individually (Line 129), not using true batched operations.
- Gradient checkpointing (Line 196): the actual implementation only prints a message (Lines 138-143 in `minimal_lora_training.rs`) without real checkpointing logic.
Consider updating the status markers to reflect what's actually implemented vs. what's simulated/proposed. Perhaps use a different marker like 🔨 (in progress) or 📝 (documented/planned) for items that are only partially implemented or simulated in the example.
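To make the cost difference concrete, here is a tiny std-only Rust sketch (the `Device` type and its one-synchronization-per-commit cost model are invented for illustration; it does not use mlx-rs): committing parameters one at a time pays one synchronization per parameter, while a single batched commit pays one for the whole set.

```rust
// Hypothetical cost model: each commit of a set of parameters costs one
// GPU synchronization, regardless of how many parameters it contains.
struct Device {
    syncs: usize,
}

impl Device {
    fn commit(&mut self, params: &[f32]) -> usize {
        // One synchronization per commit call.
        self.syncs += 1;
        params.len()
    }
}

fn sequential_update(dev: &mut Device, params: &[f32]) {
    for p in params {
        dev.commit(std::slice::from_ref(p)); // one sync per parameter
    }
}

fn batched_update(dev: &mut Device, params: &[f32]) {
    dev.commit(params); // one sync for all parameters
}

fn main() {
    let params = vec![0.0f32; 64];

    let mut seq = Device { syncs: 0 };
    sequential_update(&mut seq, &params);

    let mut bat = Device { syncs: 0 };
    batched_update(&mut bat, &params);

    println!("sequential syncs: {}, batched syncs: {}", seq.syncs, bat.syncs);
    assert_eq!(seq.syncs, 64);
    assert_eq!(bat.syncs, 1);
}
```

This is why items 3 and 7 above remain blocked on exposing batched MLX operations: the per-parameter `eval()` loop in the example cannot amortize synchronization cost.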
🤖 Prompt for AI Agents
In rust/IMPROVE_MLX_BINDINGS.md around lines 191-200, the checklist incorrectly
marks items 1-6 as completed; update the status markers to reflect reality:
replace the ✅ for "Expose full MLX C API" with 🔨/📝 because APIs like
mlx_batched_add, mlx_set_eval_lazy, mlx_get_gpu_utilization are not present in
the bindings, mark "Batched parameter updates" as 🔨/📝 and reference
minimal_lora_training.rs line 129 where eval() is called per-parameter (not true
batching), and mark "Gradient checkpointing" as 🔨/📝 with a note pointing to
minimal_lora_training.rs lines 138-143 where only a print exists; adjust the
checklist items and add a short footnote describing which items are simulated vs
implemented and links/line references to the example files.
```
⚠️ No memory limit specified. Auto-detecting safe limit: 70.0 GB
   (Based on 96.0 GB available system memory)
🔒 Set MLX memory limit to 70.0 GB (was 6.3 GB)
```
Add language specifier to fenced code block.
The fenced code block at line 129 is missing a language identifier. Add text or bash to the opening fence for consistency with markdown best practices.
🔎 Proposed fix
````diff
-```
+```text
 ⚠️ No memory limit specified. Auto-detecting safe limit: 70.0 GB
    (Based on 96.0 GB available system memory)
 🔒 Set MLX memory limit to 70.0 GB (was 6.3 GB)
````
Based on static analysis hints.
🧰 Tools

🪛 markdownlint-cli2 (0.18.1)

129-129: Fenced code blocks should have a language specified
(MD040, fenced-code-language)

🤖 Prompt for AI Agents

In rust/MEMORY_PRESSURE_FIXES.md around lines 129 to 134, the fenced code block
is missing a language specifier; edit the opening fence to include a language
(e.g., change the opening ``` to ```text or ```bash) and keep the block
contents unchanged to conform to Markdown best practices.
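For reference, the auto-detection itself reduces to a one-line policy. The sketch below is std-only and hedged: the repo's real policy is not shown here, and the fixed 26 GB headroom is an assumption chosen only because it reproduces the 96.0 GB → 70.0 GB pairing in the quoted log.

```rust
/// Sketch of an auto-detection policy for the MLX memory limit.
/// Assumption: "total system memory minus a fixed headroom, floored at
/// a minimum" - the real implementation may use a different rule.
fn auto_detect_limit_gb(system_gb: f64) -> f64 {
    const HEADROOM_GB: f64 = 26.0; // assumed reserve for OS + other processes
    const MIN_LIMIT_GB: f64 = 4.0;
    (system_gb - HEADROOM_GB).max(MIN_LIMIT_GB)
}

fn main() {
    let system_gb = 96.0;
    let limit = auto_detect_limit_gb(system_gb);
    println!(
        "⚠️ No memory limit specified. Auto-detecting safe limit: {:.1} GB",
        limit
    );
    println!("   (Based on {:.1} GB available system memory)", system_gb);
}
```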
```rust
pub fn export_command(
    model: &str,
    checkpoint_path: &std::path::PathBuf,
    output_path: &std::path::PathBuf,
) -> Result<()> {
    println!("Exporting model: {}", model);
    println!("Checkpoint: {:?}", checkpoint_path);
    println!("Output: {:?}", output_path);
    println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");

    // Resolve model name
    let model_name = if let Some(preset_config) = AVAILABLE_MODELS.get(model) {
        preset_config
            .get("name")
            .and_then(|v| v.as_str())
            .unwrap_or(model)
            .to_string()
    } else {
        model.to_string()
    };

    // Simplified resolution for export (assume downloaded or local)
    let model_path = if std::path::Path::new(&model_name).exists() {
        model_name.clone()
    } else {
        // Try simple HF cache guess
        let cache_name = model_name.replace('/', "--");
        let home = std::env::var("HOME").unwrap_or_default();
        let cache_dir = format!("{}/.cache/huggingface/hub/models--{}", home, cache_name);

        let mut found_path = None;
        if std::path::Path::new(&cache_dir).exists() {
            let snapshots_dir = format!("{}/snapshots", cache_dir);
            if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
                for entry in entries.flatten() {
                    // Fix: FileType does not implement Default, use map/unwrap_or
                    if entry.file_type().map(|ft| ft.is_dir()).unwrap_or(false) {
                        found_path = Some(entry.path().to_string_lossy().to_string());
                        break;
                    }
                }
            }
        }
        found_path.ok_or_else(|| {
            anyhow::anyhow!("Model not found: {}. Please use full path.", model_name)
        })?
    };

    println!("Base model path: {}", model_path);
    // let model_dir = std::path::PathBuf::from(&model_path);

    // 1. Load base weights
    println!("1. Loading base model weights...");
    let (mut weights, _) = load_model(Path::new(&model_path))?;
    println!("   Loaded {} tensors", weights.len());

    // 2. Load checkpoint
    println!("2. Loading checkpoint...");
    let checkpoint_data = std::fs::read_to_string(checkpoint_path)?;
    let checkpoint: Checkpoint = serde_json::from_str(&checkpoint_data)?;
    println!("   Checkpoint step: {}", checkpoint.step);
    println!(
        "   Merging {} tensors...",
        checkpoint.model_state.weights.len()
    );

    // 3. Merge weights
    for (name, (data, shape)) in checkpoint.model_state.weights {
        let array = mlx_rs::Array::from_slice(&data, &shape);
        // Overwrite or insert
        weights.insert(name, array);
    }
    println!("   Merge complete.");

    // 4. Save to output
    println!("3. Saving exported model to {:?}...", output_path);

    // Create output directory if needed
    if let Some(parent) = output_path.parent() {
        std::fs::create_dir_all(parent)?;
    }

    save_model_weights(&weights, output_path)?;

    println!("✓ Export complete!");
    println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");

    Ok(())
}
```
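Step 3's merge relies on `HashMap::insert` overwriting existing keys, so finetuned checkpoint tensors replace their base counterparts while untouched backbone tensors survive. In miniature (std-only, with `Vec<f32>` standing in for MLX arrays):

```rust
use std::collections::HashMap;

/// Merge checkpoint tensors into base weights: overwrite or insert,
/// mirroring the loop in export_command step 3.
fn merge_into(weights: &mut HashMap<String, Vec<f32>>, checkpoint: Vec<(String, Vec<f32>)>) {
    for (name, data) in checkpoint {
        weights.insert(name, data); // insert replaces any existing entry
    }
}

fn main() {
    // Base model weights (name -> tensor stand-in).
    let mut weights: HashMap<String, Vec<f32>> = HashMap::new();
    weights.insert("head.weight".to_string(), vec![0.0, 0.0]);
    weights.insert("backbone.weight".to_string(), vec![1.0, 1.0]);

    // Checkpoint tensors: the finetuned head overwrites the base copy.
    let checkpoint = vec![("head.weight".to_string(), vec![0.5, -0.5])];
    merge_into(&mut weights, checkpoint);

    assert_eq!(weights["head.weight"], vec![0.5, -0.5]); // replaced
    assert_eq!(weights["backbone.weight"], vec![1.0, 1.0]); // untouched
    println!("merged {} tensors", weights.len());
}
```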
🛠️ Refactor suggestion | 🟠 Major
Deduplicate model path resolution.
Lines 960-984 manually implement HuggingFace cache resolution that duplicates logic from rust/src/utils/model_utils.rs::resolve_model_path. This creates a maintenance burden if cache structure changes.
🔎 Refactor to use existing utility
```diff
-    // Simplified resolution for export (assume downloaded or local)
-    let model_path = if std::path::Path::new(&model_name).exists() {
-        model_name.clone()
-    } else {
-        // Try simple HF cache guess
-        let cache_name = model_name.replace('/', "--");
-        let home = std::env::var("HOME").unwrap_or_default();
-        let cache_dir = format!("{}/.cache/huggingface/hub/models--{}", home, cache_name);
-
-        let mut found_path = None;
-        if std::path::Path::new(&cache_dir).exists() {
-            let snapshots_dir = format!("{}/snapshots", cache_dir);
-            if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
-                for entry in entries.flatten() {
-                    // Fix: FileType does not implement Default, use map/unwrap_or
-                    if entry.file_type().map(|ft| ft.is_dir()).unwrap_or(false) {
-                        found_path = Some(entry.path().to_string_lossy().to_string());
-                        break;
-                    }
-                }
-            }
-        }
-        found_path.ok_or_else(|| {
-            anyhow::anyhow!("Model not found: {}. Please use full path.", model_name)
-        })?
-    };
+    let model_path = your_ai_rs::resolve_model_path(&model_name, false)
+        .ok_or_else(|| {
+            anyhow::anyhow!(
+                "Model not found: {}. Please download it first using: huggingface-cli download {}",
+                model_name, model_name
+            )
+        })?;
```

Also applies to lines 603-604 and 788-789, where similar patterns are used.
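For comparison, the duplicated cache walk boils down to the following std-only sketch (the directory layout mirrors the HuggingFace hub cache; the real `resolve_model_path` may differ in signature and behavior):

```rust
use std::fs;
use std::path::PathBuf;

/// Locate the first snapshot directory for a model in a HuggingFace-style
/// cache layout: <cache_root>/models--<org>--<name>/snapshots/<revision>/
/// Illustrative stand-in for the duplicated logic in export_command.
fn resolve_from_cache(cache_root: &PathBuf, model_name: &str) -> Option<PathBuf> {
    let cache_name = model_name.replace('/', "--");
    let snapshots = cache_root
        .join(format!("models--{}", cache_name))
        .join("snapshots");
    let entries = fs::read_dir(&snapshots).ok()?;
    for entry in entries.flatten() {
        if entry.file_type().map(|ft| ft.is_dir()).unwrap_or(false) {
            return Some(entry.path());
        }
    }
    None
}

fn main() {
    // Build a throwaway cache layout under the system temp dir.
    let root = std::env::temp_dir().join("hf_cache_demo");
    let snap = root.join("models--org--model/snapshots/abc123");
    fs::create_dir_all(&snap).unwrap();

    let found = resolve_from_cache(&root, "org/model").unwrap();
    assert!(found.ends_with("abc123"));
    println!("resolved: {}", found.display());

    let _ = fs::remove_dir_all(&root);
}
```

Keeping this in one utility means a future cache-layout change touches a single function instead of three call sites.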
```rust
pub fn load_model(path: &Path) -> Result<(HashMap<String, Array>, ModelConfig)> {
    let config_path = path.join("config.json");

    // Load configuration
    let config_content = std::fs::read_to_string(&config_path)
        .with_context(|| format!("Failed to read config from {}", config_path.display()))?;
    let config: Value = serde_json::from_str(&config_content)
        .with_context(|| format!("Failed to parse config from {}", config_path.display()))?;

    let hidden_size = config["hidden_size"].as_u64().unwrap() as usize;
    let num_hidden_layers = config["num_hidden_layers"].as_u64().unwrap() as usize;
    let num_attention_heads = config["num_attention_heads"].as_u64().unwrap() as usize;
    let num_key_value_heads = config["num_key_value_heads"]
        .as_u64()
        .unwrap_or(config["num_attention_heads"].as_u64().unwrap())
        as usize;
    let vocab_size = config["vocab_size"].as_u64().unwrap() as usize;
    let intermediate_size = config["intermediate_size"]
        .as_u64()
        .unwrap_or_else(|| config["hidden_size"].as_u64().unwrap() * 4)
        as usize;

    let model_config = ModelConfig {
        hidden_size,
        num_hidden_layers,
        num_attention_heads,
        num_key_value_heads,
        vocab_size,
        intermediate_size,
    };

    // Load weights from safetensors files
    let weights = load_safetensors_weights(path)?;

    // Try to load from checkpoint file if it exists
    let checkpoint_path = path.join("checkpoint.safetensors");
    if checkpoint_path.exists() {
        let checkpoint_weights = load_checkpoint_weights(&checkpoint_path)?;
        return Ok((checkpoint_weights, model_config));
    }

    // Try to create array - if this fails, it will panic/abort
    // We can't catch C++ exceptions, so we validate beforehand
    Ok(Array::from_slice(data, shape))
    Ok((weights, model_config))
}
```
Replace unwrap() calls with proper error handling.
Lines 34-41 use .unwrap() which will panic if the config.json is missing required fields or has wrong types. This breaks the Result<> contract and prevents callers from handling the error gracefully.
🔎 Proposed fix with proper error context
```diff
-    let hidden_size = config["hidden_size"].as_u64().unwrap() as usize;
-    let num_hidden_layers = config["num_hidden_layers"].as_u64().unwrap() as usize;
-    let num_attention_heads = config["num_attention_heads"].as_u64().unwrap() as usize;
+    let hidden_size = config["hidden_size"]
+        .as_u64()
+        .ok_or_else(|| anyhow::anyhow!("Missing or invalid 'hidden_size' in config"))?
+        as usize;
+    let num_hidden_layers = config["num_hidden_layers"]
+        .as_u64()
+        .ok_or_else(|| anyhow::anyhow!("Missing or invalid 'num_hidden_layers' in config"))?
+        as usize;
+    let num_attention_heads = config["num_attention_heads"]
+        .as_u64()
+        .ok_or_else(|| anyhow::anyhow!("Missing or invalid 'num_attention_heads' in config"))?
+        as usize;
     let num_key_value_heads = config["num_key_value_heads"]
         .as_u64()
         .unwrap_or(config["num_attention_heads"].as_u64().unwrap())
         as usize;
```

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
rust/src/model/loader.rs lines 25-67: several calls to .unwrap() when extracting
numeric fields from the parsed JSON (hidden_size, num_hidden_layers,
num_attention_heads, num_key_value_heads, vocab_size, intermediate_size) will
panic on missing or wrong-typed fields; replace them with fallible extraction
that returns a descriptive error (e.g., use config.get("field").and_then(|v|
v.as_u64()).ok_or_else(|| anyhow::anyhow!("Missing or invalid 'field' in {}",
config_path.display()))? or call v.as_u64().with_context(|| format!("Invalid
type for 'field' in {}", config_path.display()))?), or even better deserialize
config_content into a strongly typed ModelConfig via serde and propagate serde
errors with context; ensure each fallback (like num_key_value_heads defaulting
to num_attention_heads and intermediate_size defaulting to hidden_size * 4) is
implemented without unwraps and returns a Result with contextual error messages.
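The core of the proposed fix is the usual Option-to-Result conversion. A std-only sketch, with a `HashMap<String, u64>` standing in for the parsed `serde_json::Value`:

```rust
use std::collections::HashMap;

/// Extract a required numeric field, turning a missing key into a
/// descriptive error instead of a panic. Illustrative stand-in for the
/// serde_json lookups in load_model.
fn required_usize(config: &HashMap<String, u64>, field: &str) -> Result<usize, String> {
    config
        .get(field)
        .copied()
        .map(|v| v as usize)
        .ok_or_else(|| format!("Missing or invalid '{}' in config", field))
}

fn main() {
    let mut config = HashMap::new();
    config.insert("hidden_size".to_string(), 4096u64);

    // Present field: Ok; absent field: a descriptive Err, no panic.
    assert_eq!(required_usize(&config, "hidden_size"), Ok(4096));
    let err = required_usize(&config, "vocab_size").unwrap_err();
    assert!(err.contains("vocab_size"));
    println!("error surfaced as: {}", err);
}
```

The caller can then propagate the error with `?` and keep the `Result<>` contract intact.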
rust/src/model/loader.rs
Outdated
```rust
pub fn is_quantized_model(weights: &HashMap<String, Array>) -> bool {
    // In MLX, quantized tensors are handled automatically
    // We can check for specific patterns or metadata
    weights.values().any(|tensor| {
        // Check if tensor has quantized metadata or special properties
        tensor.shape().iter().map(|&x| x as usize).sum::<usize>() > 1_000_000 // Heuristic for large tensors
    })
}
```
Quantization detection heuristic is unreliable.
The current implementation uses tensor size as a heuristic (sum of dimensions > 1M). This will:
- Misidentify large non-quantized models as quantized
- Miss small quantized models
- Not work for mixed-precision models
Consider checking for quantization-specific metadata in the safetensors file or inspecting the actual dtype/encoding. If quantization detection isn't critical, document this as a placeholder implementation.
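One size-independent signal is to inspect weight names rather than shapes: MLX-style quantized checkpoints commonly store companion `*.scales` (and often `*.biases`) tensors next to the packed weights. A std-only sketch of that idea (confirm the naming convention against the actual files before relying on it):

```rust
use std::collections::HashSet;

/// Detect quantization from weight names instead of tensor sizes:
/// quantized MLX-style checkpoints typically carry companion
/// "<name>.scales" / "<name>.biases" tensors. Name-based, so it is
/// independent of model size or precision.
fn looks_quantized(weight_names: &HashSet<String>) -> bool {
    weight_names
        .iter()
        .any(|name| name.ends_with(".scales") || name.ends_with(".biases"))
}

fn main() {
    let quantized: HashSet<String> = [
        "model.layers.0.mlp.up_proj.weight",
        "model.layers.0.mlp.up_proj.scales",
        "model.layers.0.mlp.up_proj.biases",
    ]
    .iter()
    .map(|s| s.to_string())
    .collect();

    let dense: HashSet<String> = ["model.layers.0.mlp.up_proj.weight"]
        .iter()
        .map(|s| s.to_string())
        .collect();

    assert!(looks_quantized(&quantized));
    assert!(!looks_quantized(&dense));
    println!("name-based detection ok");
}
```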
🤖 Prompt for AI Agents
In rust/src/model/loader.rs around lines 115 to 122, the current
is_quantized_model function uses tensor size as a heuristic which is unreliable;
replace the size-based check with a real quantization detection strategy:
attempt to read quantization-specific metadata from the safetensors file (e.g.,
keys like "quantized", "quantization_scheme", per-tensor "scale" or
"zero_point") and/or inspect each tensor's dtype/encoding for int8/4bit types or
non-floating encodings, returning true if any tensor or metadata indicates
quantization; if reading such metadata/dtype is not feasible in this scope, mark
the function explicitly as a placeholder, add a clear TODO and a warning log,
and avoid relying on tensor size as the detection method.
rust/src/training/trainer.rs
Outdated
```rust
pub async fn create_checkpoint(&self, _step: usize) -> Result<Checkpoint> {
    use crate::checkpoints::mlx_utils::to_flat;

    // #region agent log
    self.log_debug(
        "trainer.rs:step_end",
        "Step complete",
        self.global_step,
        "end",
    // Create checkpoint with current state
    let checkpoint = Checkpoint::new(
        self.step_count,
        ModelState {
            weights: self
                .model_weights
                .iter()
                .map(|(k, v)| {
                    let (data, shape) = to_flat(v);
                    (k.clone(), (data, shape))
                })
                .collect(),
        },
        self.optimizer_state.clone(),
        self.loss_history.clone(),
        Default::default(),
    );
    // #endregion agent log

    Ok(loss_val)
    Ok(checkpoint)
}
```
🛠️ Refactor suggestion | 🟠 Major
Pass actual TrainingConfig instead of Default.
Line 275 uses Default::default() for the TrainingConfig, which discards any training configuration. The trainer should maintain a reference to the actual TrainingConfig and use it here to preserve training state accurately.
```diff
+ // Add field to DistrustTrainer struct:
+ training_config: TrainingConfig,
+
  // In create_checkpoint:
  let checkpoint = Checkpoint::new(
      self.step_count,
      ModelState { ... },
      self.optimizer_state.clone(),
      self.loss_history.clone(),
-     Default::default(),
+     self.training_config.clone(),
  );
```

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In rust/src/training/trainer.rs around lines 257 to 279, the checkpoint is being
created with Default::default() for the TrainingConfig which drops the real
training configuration; replace that placeholder with the actual training
configuration from the trainer (e.g., self.training_config.clone() or
Arc::clone(&self.training_config)) so the checkpoint captures the real config.
If the Trainer struct does not currently hold a TrainingConfig, add a field (or
Arc/Box) for it and ensure TrainingConfig is Clone or stored behind Arc so you
can pass/clone it into Checkpoint when creating it.
Actionable comments posted: 5
♻️ Duplicate comments (2)
rust/src/training/trainer.rs (1)
518-534: `/dev/null` fallback is non-portable and duplicates manual logging.

This debug logging block (and similar ones at lines 828-864 and 867-924) falls back to `/dev/null` when `debug_log_path()` returns `None`. This is non-portable (Windows lacks `/dev/null`) and duplicates the `log_debug()` helper you've already defined. As noted in a past review, replace these blocks with a guard pattern or route through `self.log_debug()`:

Suggested pattern

```diff
-    if let Ok(mut file) = std::fs::OpenOptions::new()
-        .create(true)
-        .append(true)
-        .open(debug_log_path().unwrap_or_else(|| PathBuf::from("/dev/null")))
-    {
+    if let Some(log_path) = debug_log_path() {
+        if let Ok(mut file) = std::fs::OpenOptions::new()
+            .create(true)
+            .append(true)
+            .open(log_path)
+        {
             // ... logging ...
+        }
+    }
```

rust/src/cli/commands.rs (1)
960-995: Refactor to use existing model resolution utility.

As noted in the previous review, lines 960-995 duplicate the model path resolution logic from `rust/src/utils/model_utils.rs::resolve_model_path`. This creates a maintenance burden. The suggested refactor from the previous review should be applied here and at lines 605-606 and 799-800, where similar patterns appear.
🧹 Nitpick comments (10)
rust/src/training/trainer.rs (2)
1223-1224: Redundant `eval()` calls on momentum arrays.

`m_new.eval()` and `v_new.eval()` are called twice: once at lines 1223-1224 and again at lines 1237-1238. The second pair is unnecessary since the arrays are already evaluated.

Proposed fix

```diff
     // Save updated momentum with explicit old Array cleanup
     let _ = m_new.eval();
     let _ = v_new.eval();
     // Explicitly drop old momentum Arrays
     if let Some(old_m) = self.adam_m_gpu.remove(&param_name_str) {
         drop(old_m);
     }
     if let Some(old_v) = self.adam_v_gpu.remove(&param_name_str) {
         drop(old_v);
     }
-    // Force MLX to free dropped Arrays
-    // First synchronize all GPU operations to ensure completion
-    // Call eval() on the new momentum arrays to force synchronization
-    let _ = m_new.eval();
-    let _ = v_new.eval();
-
     mlx_rs::transforms::compile::clear_cache();
     let _ = crate::utils::mlx_memory::clear_cache();
```

Also applies to: 1237-1238
1321-1330: Consider loading optimizer state from checkpoint for consistency.The comments explain the current design uses
self.adam_m(populated by the precedingsave_checkpoint), but this creates a subtle coupling: reload only works correctly if save was just called. For robustness, consider optionally loading fromcheckpoint.optimizer_statewhenself.adam_mis empty.This is acceptable for the current "save → reload" pattern, but worth noting for future maintainability.
rust/tests/checkpoint_tests.rs (1)
13-13: Unused import.
mlx_rs::ops::fullis imported but never used in the tests.- use mlx_rs::ops::full;rust/src/cli/mod.rs (1)
204-206: Inconsistent blank line in match arm.Line 205 has an unexpected blank line between
compare,andeos_token,that breaks the visual consistency of the destructuring pattern.Commands::Generate { model, prompt, checkpoint, max_tokens, temperature, - compare, eos_token, } => commands::generate(rust/src/checkpoints/manager.rs (4)
191-193: Unsafe byte slice conversion assumes correct alignment.The
from_raw_partscall assumesdata_f32is properly aligned for byte access. While this is generally safe forf32tou8, consider usingbytemuckor similar for safer transmutation.Safer alternative using bytemuck
+use bytemuck; + // MLX arrays in this project are typically F32 let data_f32 = array.as_slice::<f32>(); - let data_bytes = unsafe { - std::slice::from_raw_parts(data_f32.as_ptr() as *const u8, data_f32.len() * 4) - }; + let data_bytes: &[u8] = bytemuck::cast_slice(data_f32);This requires adding
bytemuckas a dependency but eliminates the unsafe block.
267-272: Unsafe f32 slice conversion from raw tensor data.Similar to the save path, this uses unsafe pointer casting. The slice length calculation (
len() / 4) could also produce incorrect results if the data isn't properly aligned or sized.Safer alternative
- let tensor_array = mlx_rs::Array::from_slice( - unsafe { - std::slice::from_raw_parts( - tensor.data().as_ptr() as *const f32, - tensor.data().len() / 4, - ) - }, - &shape, - ); + let data_bytes = tensor.data(); + // Ensure data length is divisible by 4 for f32 + if data_bytes.len() % 4 != 0 { + anyhow::bail!("Tensor {} has invalid byte length for f32", name); + } + let data_f32: &[f32] = bytemuck::cast_slice(data_bytes); + let tensor_array = mlx_rs::Array::from_slice(data_f32, &shape);
137-145: Redundant derive and attribute for Default.
OptimizerStatehas both#[serde(default)]and#[derive(Default)]. The#[serde(default)]on the struct level makes serde useDefault::default()for missing fields, which requiresDefaultto be implemented. Having both is correct but#[derive(Default)]alone would suffice if you only need the implementation, not serde's field-level defaults.This is technically fine as-is; just noting the redundancy.
35-51: Async methods wrap synchronous I/O.
save()andload()are markedasyncbut call synchronous helpers (save_safetensors_with_metadata,load_safetensors_with_metadata). For large checkpoints, this could block the async runtime.Consider using
tokio::task::spawn_blockingfor the I/O operations if checkpoint files become large.This is acceptable for MVP but worth noting for future scalability.
Also applies to: 53-66
rust/src/cli/commands.rs (1)
118-119: Consider simplifying the closure pattern.The closure wrapping
your_ai_rs::resolve_model_pathis unnecessary. You can call the function directly:let model_path = your_ai_rs::resolve_model_path(model_name, true) .ok_or_else(|| anyhow::anyhow!("Model not found: {}", model_name))?;This pattern also appears at lines 605-606 and 799-800.
rust/src/training/trainer.rs.old (1)
168-176: LoRA module name parsing could be more robust.Line 174 extracts the last component using
split('.').next_back(), assuming module names like"self_attn.q_proj". If the configuration format changes or uses different delimiters, this will silently fail.Consider adding validation or documentation of the expected format:
.map(|m| { let component = m.split('.').next_back().unwrap_or(m); if component == m { eprintln!("Warning: LoRA target '{}' has no '.' delimiter", m); } component.to_string() })
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
rust/src/benchmarks/optimizer.rs(10 hunks)rust/src/checkpoints/manager.rs(2 hunks)rust/src/cli/commands.rs(21 hunks)rust/src/cli/mod.rs(4 hunks)rust/src/training/trainer.rs(19 hunks)rust/src/training/trainer.rs.old(1 hunks)rust/src/training/trainer.rs.skeleton(1 hunks)rust/tests/checkpoint_tests.rs(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- rust/src/benchmarks/optimizer.rs
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: CR
Repo: arosboro/your_ai PR: 0
File: GEMINI.md:0-0
Timestamp: 2025-11-30T23:47:27.584Z
Learning: Applies to src/train_*.py : Use Apple's MLX framework for training on Apple Silicon, with PyTorch as a fallback for other platforms
📚 Learning: 2025-11-30T23:47:27.584Z
Learnt from: CR
Repo: arosboro/your_ai PR: 0
File: GEMINI.md:0-0
Timestamp: 2025-11-30T23:47:27.584Z
Learning: Applies to src/distrust_loss.py : Implement the core `empirical_distrust_loss` function that penalizes high-authority, low-entropy sources and rewards primary empirical sources
Applied to files:
rust/src/training/trainer.rs
🧬 Code graph analysis (2)
rust/tests/checkpoint_tests.rs (3)
rust/src/checkpoints/manager.rs (2)
default(165-171)new(20-33)rust/src/checkpoints/state.rs (1)
new(27-47)rust/src/utils/mlx_memory.rs (1)
clear_cache(83-89)
rust/src/checkpoints/manager.rs (2)
rust/src/checkpoints/state.rs (1)
new(27-47)rust/src/checkpoints/mlx_utils.rs (3)
from_flat(17-20)array(11-11)to_flat(10-14)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Build and Test
🔇 Additional comments (16)
rust/src/training/trainer.rs (1)
1-53: Well-structured GPU/CPU split for optimizer state.The dual storage approach (
adam_m_gpu/adam_v_gpufor training,adam_m/adam_vfor checkpointing) is a reasonable design for zero-leak GPU training with periodic CPU serialization.rust/tests/checkpoint_tests.rs (2)
37-65: Good round-trip test coverage.The test properly verifies save/load functionality, file existence, step and weight count, and loss history preservation.
91-124: Thorough cleanup verification.The test correctly validates that older checkpoints are removed and only the latest
max_checkpointsare retained, including disk verification.rust/src/cli/mod.rs (2)
89-97: CLI extensions for training parameters and export are well-structured.The new CLI arguments (
reload_interval,alpha,lambda_weight,eos_token) and theExportcommand variant follow existing patterns and provide good configurability.Also applies to: 128-131, 132-143
146-221: No changes needed. Bothvalidateandgenerateare synchronous functions and correctly do not use.await.rust/src/training/trainer.rs.skeleton (3)
107-125: Skeleton train loop uses hardcoded dummy data.The
train()method useszeros::<f32>(&[1, 128])for both inputs and targets, which won't produce meaningful training. This is acceptable for a skeleton but should be clearly documented or removed before production use.The skeleton serves its purpose as API documentation and a reference implementation template.
155-175: Placeholder implementations are clearly marked.
forward,compute_loss, andbackwardare stub implementations that return dummy values. This is appropriate for a skeleton file.
54-54: No issues found. Thecrate::checkpoints::mlx_utilsmodule exists and properly exports bothto_flatandfrom_flatfunctions via the public API incheckpoints/mod.rs. The references in trainer.rs.skeleton are correct.Likely an incorrect or invalid review comment.
rust/src/checkpoints/manager.rs (1)
176-235: Clean safetensors integration with embedded metadata.The approach of storing
_metadataand_optimizeras U8 tensors containing JSON is clever and keeps everything in a single file. The implementation handles the round-trip correctly.Also applies to: 238-313
rust/src/cli/commands.rs (3)
839-876: LGTM! Checkpoint loading in comparison mode is well-implemented.The implementation correctly:
- Loads base model weights
- Creates base model with
load_model_with_weights- Clones base weights for finetuned model
- Merges checkpoint tensors into the weights HashMap
- Creates finetuned model with merged weights
This ensures both models are properly initialized for comparison.
1000-1031: Export implementation is correct.The checkpoint export logic properly:
- Loads base model weights
- Deserializes the checkpoint
- Merges checkpoint tensors (overwrites base weights)
- Saves the merged model to safetensors format
This ensures the exported model contains both the frozen backbone and the finetuned head.
705-707: The code at lines 705-707 is correct and requires no changes. TheDistrustTrainer::new()is properly defined as an async constructor that accepts&Path, and the builder pattern withwith_config()is correctly implemented. The usage matches the actual implementation.Likely an incorrect or invalid review comment.
rust/src/training/trainer.rs.old (4)
20-22: Well-designed optimizer state separation.The type aliases clearly distinguish CPU storage (for checkpointing) from GPU storage (for training), which is key to the zero-leak architecture.
383-396: Excellent memory-aware step limiting.The combination of
calculate_safe_max_stepsand enforcement logic provides strong protection against OOM crashes. The detailed warning messages (lines 490-510) help users understand why training was limited and what to do about it.Aborting when less than 50% of requested steps are safe (lines 513-521) is a good UX choice - better to fail early with a clear message than crash mysteriously later.
Also applies to: 483-523
1191-1310: GPU-only optimizer is the key to zero-leak training.This implementation is well-designed:
- Keeps all momentum on GPU (never extracts via
as_slice()during training)- Minimizes intermediate Arrays to reduce memory pressure
- Explicit cleanup with
drop()and cache clearing- CPU extraction only happens during checkpoint saves (line 1313)
The detailed comments explain the rationale and the "ZERO-LEAK GUARANTEE" comment at line 1306 clearly documents the design goal.
1335-1425: Model reload provides workaround for framework memory accumulation.The
reload_from_checkpointimplementation is thorough:
- Drops current model explicitly (lines 1351-1354)
- Forces MLX memory release (lines 1361-1362)
- Reloads base weights and merges checkpoint (lines 1369-1380)
- Restores optimizer momentum to GPU (lines 1401-1411)
- Resets baseline for leak detection (line 1416)
This workaround enables unlimited training despite framework limitations. The step-by-step comments make the process clear.
```rust
for step in 0..=25 {
    let mut checkpoint = create_test_checkpoint(step);
    checkpoint.loss_history = vec![2.0f32 - (step as f32 * 0.05)];

    manager.save(&checkpoint).await?;

    // Simulate reload at step 10
    if step == 10 {
        let loaded = manager.load(10).await?;
        assert_eq!(loaded.step, 10);
    }
```
Test will fail: checkpoint at step 10 is deleted before load attempt.
With max_checkpoints: 3 and cleanup running after each save(), by the time the loop reaches step 25, checkpoint-10 will have been deleted. The manager.load(10) at line 167 will fail.
Either increase max_checkpoints to retain step 10, or load a checkpoint that's guaranteed to exist (e.g., one of the last 3).
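The retention arithmetic can be checked directly with a std-only simulation of the save-then-cleanup loop (the `surviving_steps` helper is illustrative, not the manager's actual code):

```rust
/// Simulate a CheckpointManager-style save + cleanup loop: after each
/// save, only the newest `max_checkpoints` steps are kept on disk.
fn surviving_steps(saved_steps: &[usize], max_checkpoints: usize) -> Vec<usize> {
    let mut kept: Vec<usize> = Vec::new();
    for &step in saved_steps {
        kept.push(step);
        let excess = kept.len().saturating_sub(max_checkpoints);
        kept.drain(..excess); // cleanup removes the oldest checkpoints
    }
    kept
}

fn main() {
    let steps: Vec<usize> = (0..=25).collect();
    let kept = surviving_steps(&steps, 3);

    // By step 25, only 23, 24, 25 survive - checkpoint 10 is long gone.
    assert_eq!(kept, vec![23, 24, 25]);
    assert!(!kept.contains(&10));
    println!("surviving: {:?}", kept);
}
```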
Proposed fix
```diff
 #[tokio::test]
 async fn test_checkpoint_reload_integration() -> Result<()> {
     let temp_dir = TempDir::new()?;
     let checkpoint_dir = temp_dir.path().join("checkpoints");
-    let manager = CheckpointManager::new(&checkpoint_dir, 3)?;
+    // Need enough checkpoints to retain step 10 through step 25
+    let manager = CheckpointManager::new(&checkpoint_dir, 26)?;

     // Simulate training with periodic reloads
     for step in 0..=25 {
```

Or alternatively, load a recent checkpoint instead of step 10.

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In rust/tests/checkpoint_tests.rs around lines 159-169 the test attempts to load
checkpoint 10 during a loop that saves 0..=25 with manager configured
max_checkpoints: 3 so checkpoint 10 will be deleted by cleanup and the load will
fail; fix by either increasing the manager's max_checkpoints so step 10 is
retained for the duration of the loop or change the load call to request a
checkpoint guaranteed to exist (e.g., load the most recent step or one of the
last 3 saved steps such as step = step.saturating_sub(0) or explicitly load step
25 or load manager.load(step) when step >= 22), and update the assertion
accordingly.
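The retention arithmetic behind this finding is easy to check in isolation. The sketch below is standalone std-only Rust, not the repo's `CheckpointManager`; `retained_after` is a hypothetical helper that models a keep-last-N cleanup after every save:

```rust
// Simulates a keep-last-N cleanup policy to show why checkpoint 10
// is already gone by the time the loop reaches step 25.
fn retained_after(saves: std::ops::RangeInclusive<u32>, max_checkpoints: usize) -> Vec<u32> {
    let mut kept: Vec<u32> = Vec::new();
    for step in saves {
        kept.push(step); // save()
        while kept.len() > max_checkpoints {
            kept.remove(0); // cleanup() drops the oldest checkpoint
        }
    }
    kept
}

fn main() {
    let kept = retained_after(0..=25, 3);
    assert_eq!(kept, vec![23, 24, 25]); // step 10 was deleted long ago
    assert!(retained_after(0..=25, 26).contains(&10)); // raising the cap retains it
    println!("{:?}", kept);
}
```

This is why either raising `max_checkpoints` or loading one of the last N checkpoints makes the test deterministic.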
Actionable comments posted: 6
♻️ Duplicate comments (5)
rust/src/checkpoints/state.rs (1)
6-6: Circular dependency noted. This import creates a circular dependency with `manager.rs`. A previous review already flagged this; consider extracting shared types to a `types.rs` module in a follow-up.

rust/src/model/loader.rs (2)

34-45: Replace `unwrap()` with proper error handling. Multiple `.unwrap()` calls will panic on missing/invalid config fields, breaking the `Result<>` contract.

121-128: Quantization detection heuristic is unreliable. Already flagged in a previous review; the size-based heuristic doesn't reliably detect quantization.
rust/src/training/trainer.rs (1)
523-540: [DUPLICATE] Hardcoded `/dev/null` fallback and manual logging duplication; use `debug_log_path()` guard and `log_debug()` helper. These blocks hardcode `/dev/null` (non-portable) and duplicate manual JSON file writing instead of using the `log_debug()` helper method defined at lines 1089-1137. This was flagged in a previous review. Replace each manual block with an optional guard (`if let Some(log_path) = debug_log_path()`) that only runs when a path exists, or route through `self.log_debug()` to centralize all debug logging.

🔎 Refactor example for lines 523-540

```diff
-// #region agent log - after training_step
-if let Ok(mut file) = std::fs::OpenOptions::new()
-    .create(true)
-    .append(true)
-    .open(debug_log_path().unwrap_or_else(|| PathBuf::from("/dev/null")))
-{
-    let json = serde_json::json!({
-        "location": "trainer.rs:after_training_step",
-        "message": "training_step returned successfully",
-        "step": self.global_step,
-        "loss": loss,
-        "phase": "main_loop",
-        "timestamp": std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).map(|d| d.as_millis()).unwrap_or(0),
-        "hypothesisId": "D-training-step"
-    });
-    let _ = writeln!(file, "{}", json);
-}
-// #endregion agent log
+// Log after training_step (uses centralized helper)
+if debug_log_path().is_some() {
+    self.log_debug(
+        "trainer.rs:after_training_step",
+        "training_step returned successfully",
+        self.global_step,
+        "main_loop"
+    );
+}
```

Apply similar pattern to all other manual logging blocks (lines 834-870, 873-930).
Also applies to: 834-870, 873-930
rust/src/cli/commands.rs (1)
976-1000: [DUPLICATE] Replace manual HF cache resolution with `resolve_model_path()` utility.

Lines 976-1000 manually construct HuggingFace cache paths and traverse snapshot directories, duplicating the logic in `your_ai_rs::resolve_model_path()` (defined in `rust/src/utils/model_utils.rs` lines 24-71). This creates a maintenance burden if the cache structure changes. Use the existing utility as shown in lines 118-122 for the benchmark command.

🔎 Refactor to use existing utility

```diff
-// Simplified resolution for export (assume downloaded or local)
-let model_path = if std::path::Path::new(&model_name).exists() {
-    model_name.clone()
-} else {
-    // Try simple HF cache guess
-    let cache_name = model_name.replace('/', "--");
-    let home = std::env::var("HOME").unwrap_or_default();
-    let cache_dir = format!("{}/.cache/huggingface/hub/models--{}", home, cache_name);
-
-    let mut found_path = None;
-    if std::path::Path::new(&cache_dir).exists() {
-        let snapshots_dir = format!("{}/snapshots", cache_dir);
-        if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
-            for entry in entries.flatten() {
-                // Fix: FileType does not implement Default, use map/unwrap_or
-                if entry.file_type().map(|ft| ft.is_dir()).unwrap_or(false) {
-                    found_path = Some(entry.path().to_string_lossy().to_string());
-                    break;
-                }
-            }
-        }
-    }
-    found_path.ok_or_else(|| {
-        anyhow::anyhow!("Model not found: {}. Please use full path.", model_name)
-    })?
-};
+let model_path = your_ai_rs::resolve_model_path(&model_name, false)
+    .ok_or_else(|| {
+        anyhow::anyhow!(
+            "Model not found: {}. Please download it first using: huggingface-cli download {}",
+            model_name, model_name
+        )
+    })?;
```
🧹 Nitpick comments (6)
rust/src/checkpoints/manager.rs (3)
38-53: Async methods perform only synchronous I/O.
`save`, `load`, `load_weights_only`, and `cleanup` are marked `async` but only call synchronous `std::fs` operations. This blocks the async executor thread. Either use `tokio::fs` for true async I/O or remove the `async` keyword if blocking is acceptable for your use case.

🔎 Example: true async with tokio::fs

```diff
+use tokio::fs;
+
 pub async fn save(&self, checkpoint: &Checkpoint) -> Result<()> {
     // ...
-    self.cleanup().await?;
+    self.cleanup()?; // or make cleanup truly async too
     Ok(())
 }
```

Or if blocking is acceptable, simply make these sync functions.
Also applies to: 56-68, 117-142
206-212: Unsafe byte slice conversion is correct but fragile. The unsafe block correctly converts `&[f32]` to `&[u8]`. The comment notes the lifetime is bound to the checkpoint. Consider using `bytemuck::cast_slice` for a safer, more explicit conversion if adding dependencies is acceptable.
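For context on why manual reinterpretation is considered fragile: the reverse direction (bytes back to `&[f32]`, as in the load path) is only sound when the buffer happens to be 4-byte aligned. The std-only sketch below (`bytes_to_f32` is a hypothetical helper, not code from this PR) shows the usual pattern: take the reinterpret fast path only when `align_to` reports no stray prefix/suffix, otherwise fall back to a safe per-element copy. The fast path assumes a little-endian host, matching the stored byte order.

```rust
// Reinterpret a little-endian f32 byte buffer, with a safe fallback
// for misaligned input.
fn bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
    assert!(bytes.len() % 4 == 0, "buffer must hold whole f32s");
    // align_to returns (prefix, aligned middle, suffix); empty prefix/suffix
    // means the whole slice is properly aligned for f32.
    let (prefix, floats, suffix) = unsafe { bytes.align_to::<f32>() };
    if prefix.is_empty() && suffix.is_empty() {
        floats.to_vec() // aligned: reinterpret directly (little-endian host)
    } else {
        bytes
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect() // unaligned: safe element-by-element decode
    }
}

fn main() {
    let mut buf = vec![0u8; 9];
    buf[1..5].copy_from_slice(&1.0f32.to_le_bytes());
    buf[5..9].copy_from_slice(&2.0f32.to_le_bytes());
    // &buf[1..9] is likely misaligned; both paths produce the same values.
    let vals = bytes_to_f32(&buf[1..9]);
    assert_eq!(vals, vec![1.0, 2.0]);
    println!("{:?}", vals);
}
```

`bytemuck::cast_slice` encapsulates exactly this alignment/size checking, which is why the review suggests it over hand-rolled `unsafe`.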
150-150: Re-export placement. The `pub use` statement in the middle of the file is unusual. Consider moving re-exports to the top of the file or to `mod.rs` for better discoverability.

rust/tests/checkpoint_test.rs (1)
6-6: Unused import: `PathBuf`. `PathBuf` is imported but never used in the test. Remove it to avoid the warning.

🔎 Proposed fix

```diff
-use std::path::PathBuf;
```
603-606: Optional: Remove unnecessary closure wrapper aroundresolve_model_path.Lines 603-606 and 804-806 wrap
your_ai_rs::resolve_model_path()in a closure that just forwards arguments. This indirection is unnecessary—call the function directly at the usage sites.🔎 Simplified call pattern
- let resolve_model_path = - |model_name: &str| -> Option<String> { your_ai_rs::resolve_model_path(model_name, false) }; - - let model_path = resolve_model_path(&model_name).ok_or_else(|| { + let model_path = your_ai_rs::resolve_model_path(&model_name, false).ok_or_else(|| { anyhow::anyhow!( "Model not found: {}. Please download it first using: huggingface-cli download {}", model_name, model_name ) })?;Also applies to: 804-806
869-878: Optional: Extract checkpoint weight merging to a helper function. The pattern of loading a checkpoint and merging its weights into a `HashMap<String, Array>` is repeated at lines 869-878 and 902-920. Consider extracting this to a helper function like `merge_checkpoint_weights(base_weights, checkpoint_path)` to reduce duplication.

🔎 Example helper function

```rust
fn merge_checkpoint_weights(
    mut base_weights: HashMap<String, Array>,
    checkpoint_path: &Path,
) -> Result<HashMap<String, Array>> {
    let checkpoint_data = std::fs::read_to_string(checkpoint_path)?;
    let checkpoint: Checkpoint = serde_json::from_str(&checkpoint_data)?;
    for (name, (data, shape)) in checkpoint.model_state.weights {
        let array = mlx_rs::Array::from_slice(&data, &shape);
        base_weights.insert(name, array);
    }
    Ok(base_weights)
}
```

Also applies to: 869-878 and 902-920
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- rust/Cargo.toml (2 hunks)
- rust/src/checkpoints/manager.rs (2 hunks)
- rust/src/checkpoints/state.rs (2 hunks)
- rust/src/cli/commands.rs (21 hunks)
- rust/src/model/loader.rs (1 hunks)
- rust/src/training/trainer.rs (19 hunks)
- rust/tests/checkpoint_test.rs (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: CR
Repo: arosboro/your_ai PR: 0
File: GEMINI.md:0-0
Timestamp: 2025-11-30T23:47:27.584Z
Learning: Applies to src/train_*.py : Use Apple's MLX framework for training on Apple Silicon, with PyTorch as a fallback for other platforms
📚 Learning: 2025-11-30T23:47:27.584Z
Learnt from: CR
Repo: arosboro/your_ai PR: 0
File: GEMINI.md:0-0
Timestamp: 2025-11-30T23:47:27.584Z
Learning: Applies to src/distrust_loss.py : Implement the core `empirical_distrust_loss` function that penalizes high-authority, low-entropy sources and rewards primary empirical sources
Applied to files:
rust/src/training/trainer.rs
🧬 Code graph analysis (6)
rust/tests/checkpoint_test.rs (2)
- rust/src/checkpoints/manager.rs (2): `std` (318-318), `new` (22-35)
- rust/src/checkpoints/state.rs (1): `new` (28-48)

rust/src/training/trainer.rs (3)
- rust/src/model/loader.rs (2): `load_model` (25-67), `array` (226-226)
- rust/src/training/lora.rs (6): `mlx_rs` (50-50, 57-57, 104-104, 107-107), `new` (92-115), `apply_lora_to_model` (26-78)
- rust/src/utils/mlx_memory.rs (6): `get_active_memory` (44-51), `get_peak_memory` (54-61), `get_cache_memory` (64-71), `clear_cache` (83-89), `array` (113-113), `stop_gradient` (105-119)

rust/src/checkpoints/state.rs (1)
- rust/src/checkpoints/manager.rs (1): `new` (22-35)

rust/src/model/loader.rs (6)
- rust/src/training/trainer.rs (2): `mlx_rs` (1469-1469, 1475-1475)
- rust/tests/training_tests.rs (1): `mlx_rs` (44-44)
- rust/src/distrust_loss.rs (1): `mlx_rs` (115-115)
- rust/src/training/lora.rs (4): `mlx_rs` (50-50, 57-57, 104-104, 107-107)
- rust/src/cli/commands.rs (2): `serde_json` (355-355), `s` (239-239)
- rust/src/config/model.rs (1): `default` (20-32)

rust/src/cli/commands.rs (4)
- rust/src/checkpoints/manager.rs (2): `std` (318-318), `new` (22-35)
- rust/src/model/loader.rs (2): `load_model` (25-67), `array` (226-226)
- rust/src/utils/model_utils.rs (1): `resolve_model_path` (25-72)
- rust/src/checkpoints/mlx_utils.rs (1): `array` (11-11)

rust/src/checkpoints/manager.rs (1)
- rust/src/checkpoints/state.rs (1): `new` (28-48)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Build and Test
🔇 Additional comments (10)
rust/Cargo.toml (3)
3-3: Version bump aligns with breaking changes. The 0.1.0 → 0.2.0 bump correctly signals the breaking checkpoint format changes introduced in this PR (typed `ModelState`/`OptimizerState` replacing `HashMap`-based structures).

38-38: LGTM: memmap2 addition. Supports the memory-mapped I/O usage in `manager.rs` and `loader.rs` for efficient large-file handling.

33-33: No action required. Tokenizers 0.21 has no listed breaking changes and the codebase uses only stable core APIs (encode, decode, encode_batch, get_ids, from_file) that remain compatible.

rust/src/checkpoints/state.rs (1)

50-55: Validation logic is correct. The `validate` method correctly checks whether the model state contains any weights.

rust/src/checkpoints/manager.rs (1)

316-328: Good: Proper handling of potentially unaligned data. The code correctly checks alignment before using `from_raw_parts` and falls back to a safe byte-by-byte copy for unaligned data. This prevents undefined behavior on platforms with strict alignment requirements.

rust/tests/checkpoint_test.rs (1)
45-69: Test logic is sound. The test correctly exercises the checkpoint save/load round-trip and verifies step, weights count, optimizer state, and specific data values.

rust/src/model/loader.rs (1)

70-118: Memory safety: mmap lifetime is correctly scoped. The comment at lines 111-113 correctly notes that `Array::from_slice` copies the data, making it safe to drop the mmap after processing each file. The implementation is sound.

rust/src/training/trainer.rs (2)
1140-1261: Excellent zero-leak GPU optimizer implementation. The GPU-only AdamW update path correctly avoids CPU extraction during training (saving `as_slice()` calls for checkpointing only), uses minimal intermediate Arrays, and includes explicit cleanup with `clear_cache()` calls. The detailed comments explain the rationale clearly.

329-473: Strong memory management and checkpoint reload implementation. The memory-aware step calculation (lines 329-346) and enforcement logic (lines 428-473) provide clear user feedback and prevent OOM crashes. The `reload_from_checkpoint_step` method (lines 1282-1351) correctly restores both model weights and optimizer momentum state, addressing the prior review concern about training continuity.

Also applies to: 1282-1351

rust/src/cli/commands.rs (1)

816-822: Clean EOS token override integration. The `eos_token` parameter is properly applied to `LlamaConfig` before generation, with clear user feedback. This aligns well with the configurable EOS token support mentioned in the PR objectives.
```rust
pub struct OptimizerState {
    pub param_groups: Vec<ParamGroup>,
    #[serde(skip_serializing)]
    pub exp_avg: std::collections::HashMap<String, (Vec<f32>, Vec<i32>)>,
    #[serde(skip_serializing)]
    pub exp_avg_sq: std::collections::HashMap<String, (Vec<f32>, Vec<i32>)>,
    pub step: usize,
}
```
OptimizerState fields need default for deserialization.
Similar to ModelState.weights, the exp_avg and exp_avg_sq fields have #[serde(skip_serializing)] but will fail deserialization without a default value.
🔎 Proposed fix
```diff
 pub struct OptimizerState {
     pub param_groups: Vec<ParamGroup>,
-    #[serde(skip_serializing)]
+    #[serde(skip, default)]
     pub exp_avg: std::collections::HashMap<String, (Vec<f32>, Vec<i32>)>,
-    #[serde(skip_serializing)]
+    #[serde(skip, default)]
     pub exp_avg_sq: std::collections::HashMap<String, (Vec<f32>, Vec<i32>)>,
     pub step: usize,
 }
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```rust
pub struct OptimizerState {
    pub param_groups: Vec<ParamGroup>,
    #[serde(skip, default)]
    pub exp_avg: std::collections::HashMap<String, (Vec<f32>, Vec<i32>)>,
    #[serde(skip, default)]
    pub exp_avg_sq: std::collections::HashMap<String, (Vec<f32>, Vec<i32>)>,
    pub step: usize,
}
```
🤖 Prompt for AI Agents
In rust/src/checkpoints/manager.rs around lines 156 to 163, the OptimizerState
struct's exp_avg and exp_avg_sq fields are marked #[serde(skip_serializing)] but
lack a serde default, causing deserialization to fail; add #[serde(default,
skip_serializing)] to both exp_avg and exp_avg_sq so serde will populate them
with Default::default (empty HashMap) during deserialization, leaving the rest
of the struct unchanged.
```rust
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelState {
    #[serde(skip_serializing)]
    pub weights: Vec<(String, (Vec<f32>, Vec<i32>))>,
}
```
Serde attribute inconsistency: missing skip_deserializing.
weights has #[serde(skip_serializing)] but lacks skip_deserializing. During deserialization, serde will expect a weights field in JSON but won't find it (since it was skipped on save), causing a deserialization error.
Since weights are stored as safetensors tensors and reconstructed manually in load_safetensors_with_metadata, you should use #[serde(skip)] (which applies to both) or #[serde(skip_serializing, default)] to provide a default on deserialize.
🔎 Proposed fix
```diff
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct ModelState {
-    #[serde(skip_serializing)]
+    #[serde(skip, default)]
     pub weights: Vec<(String, (Vec<f32>, Vec<i32>))>,
 }
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```rust
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelState {
    #[serde(skip, default)]
    pub weights: Vec<(String, (Vec<f32>, Vec<i32>))>,
}
```
🤖 Prompt for AI Agents
In rust/src/checkpoints/state.rs around lines 21 to 25, the ModelState struct
marks weights with #[serde(skip_serializing)] only, causing serde to expect
weights on deserialize and fail; change the attribute to either #[serde(skip)]
to skip both serialize and deserialize, or to #[serde(skip_serializing,
default)] (and ensure Default for the field) so deserialization succeeds and
weights can be reconstructed by load_safetensors_with_metadata.
rust/src/model/loader.rs
Outdated
```rust
// Create LoRA A and B matrices
let lora_rank = lora_config.lora_rank;

fn load_single_safetensors(&self, path: &Path) -> anyhow::Result<HashMap<String, Array>> {
    let data = std::fs::read(path)?;
    let tensors = SafeTensors::deserialize(&data)?;

    let mut weights = HashMap::new();

    for (name, tensor) in tensors.tensors() {
        // Convert safetensors tensor to MLX array with proper dtype handling
        let shape: Vec<usize> = tensor.shape().to_vec();
        let shape_i32: Vec<i32> = shape.iter().map(|&s| s as i32).collect();
        let raw_data = tensor.data();

        // Estimate memory required for this tensor
        let dtype = tensor.dtype();
        let total_elements: usize = shape.iter().product();
        let element_bytes = match dtype {
            safetensors::Dtype::F32 => 4,
            safetensors::Dtype::F16 | safetensors::Dtype::BF16 => 2,
            safetensors::Dtype::I64 => 8,
            _ => 4,
        };
        let estimated_mb = (total_elements * element_bytes) / (1024 * 1024);

        if estimated_mb > 1000 {
            eprintln!(
                "Warning: Large tensor '{}' ({} MB) - may cause OOM",
                name, estimated_mb
            );
        }
        // For quantized models, we need to handle the dequantization
        let val_0 = Array::from_slice(&[0.0f32], &[]);
        let lora_a = full::<f32>(&[in_features, lora_rank as i32], &val_0)?;
        let lora_b = zeros::<f32>(&[lora_rank as i32, in_features])?;
```
LoRA A matrix should be initialized with small random values, not zeros.
Initializing both lora_A and lora_B to zeros means the LoRA adaptation will have no effect initially (A×B = 0). Standard practice is to initialize A with small random values (e.g., Kaiming/He init) and B with zeros so the initial LoRA contribution is zero but gradients can flow through A.
🔎 Proposed fix using random initialization for lora_A
```diff
+use mlx_rs::random::uniform;
+
 // Create LoRA A and B matrices
 let lora_rank = lora_config.lora_rank;
-let val_0 = Array::from_slice(&[0.0f32], &[]);
-let lora_a = full::<f32>(&[in_features, lora_rank as i32], &val_0)?;
+// Initialize A with small random values (He init scale)
+let scale = (2.0 / in_features as f32).sqrt();
+let lora_a = uniform::<_, f32>(-scale, scale, &[in_features, lora_rank as i32], None)?;
 let lora_b = zeros::<f32>(&[lora_rank as i32, in_features])?;
```

🤖 Prompt for AI Agents
In rust/src/model/loader.rs around lines 149-155, lora_a is being initialized to
all zeros which prevents any initial LoRA effect; change lora_a to be
initialized with small random values (e.g., draw f32 samples from a normal
distribution with mean 0 and std = init_scale where init_scale = 1.0f32 /
(in_features as f32).sqrt() or use Kaiming-style scaling), keep lora_b as zeros,
and ensure you use the project's RNG (or rand crate) to fill an Array<f32> of
shape [in_features, lora_rank] with those samples while preserving error
handling and dtype; this provides nonzero trainable parameters for A while B
remains zero-initialized so the initial LoRA contribution is 0 but gradients
flow.
```rust
let data = match tensor.dtype() {
    safetensors::Dtype::F32 => Array::from_slice(
        unsafe {
            std::slice::from_raw_parts(
                tensor.data().as_ptr() as *const f32,
                tensor.data().len() / 4,
            )
        },
        &shape,
    ),
    _ => continue, // Skip unsupported for now
};
```
Silently skipping non-F32 tensors may cause data loss.
In load_checkpoint_weights, non-F32 tensors are silently skipped with continue. This could lose important weights if the checkpoint contains f16/bf16 tensors. Consider logging a warning or returning an error.
🔎 Proposed fix with warning
```diff
 let data = match tensor.dtype() {
     safetensors::Dtype::F32 => Array::from_slice(
         // ...
     ),
-    _ => continue, // Skip unsupported for now
+    dtype => {
+        tracing::warn!("Skipping tensor '{}' with unsupported dtype {:?}", tensor_name, dtype);
+        continue;
+    }
 };
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```rust
let data = match tensor.dtype() {
    safetensors::Dtype::F32 => Array::from_slice(
        unsafe {
            std::slice::from_raw_parts(
                tensor.data().as_ptr() as *const f32,
                tensor.data().len() / 4,
            )
        },
        &shape,
    ),
    dtype => {
        tracing::warn!("Skipping tensor '{}' with unsupported dtype {:?}", tensor_name, dtype);
        continue;
    }
};
```
🤖 Prompt for AI Agents
In rust/src/model/loader.rs around lines 201 to 212, the code silently continues
when encountering a non-F32 tensor which can drop important weights; change this
to detect non-F32 dtypes, log a warning including the tensor name and dtype, and
handle common numeric formats instead of skipping: if dtype is F16 or BF16
convert the raw bytes to f32 (using a conversion routine/half crate) and
construct the Array from the converted f32 slice; for any other unsupported
dtype return a descriptive Err or propagate an error so the caller can decide.
Ensure logging contains tensor identifier and dtype, and keep the existing F32
branch untouched.
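As background for the F16/BF16 handling the prompt asks for: bfloat16 is simply the top 16 bits of an IEEE-754 float32, so widening it is a single shift (half-precision F16 needs a real conversion routine, e.g. the `half` crate). A dependency-free sketch with hypothetical helpers, not this repo's code:

```rust
// bfloat16 -> f32: the bf16 bit pattern is the high half of the f32 pattern.
fn bf16_bits_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

// Decode a raw little-endian BF16 byte buffer into f32 values.
fn bf16_bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(2)
        .map(|c| bf16_bits_to_f32(u16::from_le_bytes([c[0], c[1]])))
        .collect()
}

fn main() {
    assert_eq!(bf16_bits_to_f32(0x3F80), 1.0); // 1.0 in bfloat16
    assert_eq!(bf16_bits_to_f32(0xC000), -2.0); // -2.0 in bfloat16
    let vals = bf16_bytes_to_f32(&[0x80, 0x3F, 0x00, 0xC0]);
    assert_eq!(vals, vec![1.0, -2.0]);
    println!("{:?}", vals);
}
```

This is the widening step that would let the loader keep BF16 checkpoint weights instead of silently dropping them.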
|
|
```rust
#[cfg(test)]
mod tests {
    use super::*;
```
use super::* won't compile in a standalone test file.
This file is in rust/tests/ (integration tests), not a submodule. super::* refers to a non-existent parent module and will cause a compilation error. Remove this line.
🔎 Proposed fix
```diff
 #[cfg(test)]
 mod tests {
-    use super::*;
     use your_ai_rs::checkpoints::manager::{CheckpointManager, Checkpoint, ModelState, OptimizerState, TrainingConfig};
     use std::path::PathBuf;
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```rust
#[cfg(test)]
mod tests {
    use your_ai_rs::checkpoints::manager::{CheckpointManager, Checkpoint, ModelState, OptimizerState, TrainingConfig};
    use std::path::PathBuf;
```
🤖 Prompt for AI Agents
In rust/tests/checkpoint_test.rs around line 4, the line "use super::*" is
invalid in an integration test (tests/ directory) because there is no parent
module; remove that line and instead import what the test needs from the crate
root (e.g., use crate_name::Item or use crate::module::Item) or fully qualify
paths in the test; ensure you add explicit use statements for any items used in
this file or prefix them with crate:: so the file compiles as a standalone
integration test.
Summary by CodeRabbit
New Features
Bug Fixes / Stability
Documentation
Chores
Tests