fix/optimized-train#22

Open
arosboro wants to merge 2 commits into main from fix/optimized-train

Conversation

@arosboro
Owner

@arosboro arosboro commented Dec 8, 2025

Summary by CodeRabbit

  • New Features

    • Early stopping enabled by default to halt training on plateaus or gradient issues
    • Automatic validation during training with best-model checkpointing
    • Auto-resume to recover interrupted runs
    • Enhanced progress monitoring (ETA, moving averages, memory/gradient health)
  • Documentation

    • Added Quick Training Guide covering workflow, troubleshooting, and defaults
    • Added deeper technical training improvements document
  • Improvements

    • Faster defaults (max_steps, warmup_steps, checkpoint interval) and more robust checkpoint recovery


@coderabbitai
Contributor

coderabbitai bot commented Dec 8, 2025

Walkthrough

Adds early stopping, periodic validation with best-model checkpointing, automatic checkpoint recovery/auto-resume, more robust checkpoint saves, enhanced progress monitoring (ETA, moving averages, TensorBoard), updated training defaults, and documentation. Introduces EarlyStopping and a validate() method on DistrustTrainer.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| **Documentation**<br>`QUICK_TRAINING_GUIDE.md`, `TRAINING_IMPROVEMENTS.md` | New guides describing the faster training workflow, features (Auto-Resume, Early Stopping, validation, monitoring), defaults, troubleshooting, and tuning recommendations. |
| **Configuration Updates**<br>`src/config.py` | Added early stopping and gradient-spike configuration fields (`early_stopping_enabled`, `early_stopping_patience`, `early_stopping_min_delta`, `grad_spike_threshold`, `grad_spike_patience`); changed defaults: max_steps 5000→2000, warmup_steps 100→50, checkpoint_interval 500→250. |
| **Checkpoint Management**<br>`src/checkpoints/checkpoint_manager.py` | Pre-evaluates model/optimizer arrays before saving, uses logger warnings, implements a two-phase save with per-parameter fallback and partial-save/rename behavior on failure, preserves checkpoint structure on optimizer errors, and adds detailed logging. |
| **Training Framework**<br>`src/train_qlora.py` | Added public `EarlyStopping` class; added `validate(self, val_data) -> Dict[str, float]` to `DistrustTrainer`; integrated early stopping, periodic validation, best-model checkpointing, auto-resume/interactive resume, lazy validation loading, TensorBoard logging, ETA/moving-average metrics, and enhanced progress/gradient-health checks into the training loop. |

Sequence Diagram

sequenceDiagram
    participant User
    participant Main as main()
    participant Trainer as DistrustTrainer
    participant ChkptMgr as checkpoint_manager
    participant EarlyStopping
    participant Validation
    participant TensorBoard

    User->>Main: run training (maybe --auto-resume)
    Main->>Trainer: initialize (config, auto_resume)
    Trainer->>ChkptMgr: scan/load latest checkpoint (if any)
    ChkptMgr-->>Trainer: checkpoint state (model, opt, step)
    Trainer->>EarlyStopping: initialize with config

    rect rgb(230, 245, 255)
    loop Training loop (per batch/step)
        Trainer->>Trainer: forward/backward -> loss, grad_norm
        Trainer->>EarlyStopping: should_stop(loss, grad_norm, step)?
        EarlyStopping-->>Trainer: continue / trigger stop

        alt validation step
            Trainer->>Validation: validate(val_data)
            Validation-->>Trainer: val_loss, metrics
            Trainer->>ChkptMgr: save checkpoint if best
            Trainer->>TensorBoard: log validation metrics
        end

        Trainer->>TensorBoard: log training metrics, ETA, progress
    end
    end

    alt early stop or complete
        Trainer->>ChkptMgr: final checkpoint save (robust two-phase save)
        Trainer->>User: report final summary (best checkpoint)
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~70 minutes

  • Areas to focus:
    • EarlyStopping logic (patience, warmup, grad-spike detection) and state transitions.
    • validate() integration with streaming vs. non-streaming validation and best-model checkpointing.
    • checkpoint_manager two-phase save and partial-save/rename error paths.
    • auto-resume and interactive resume paths across main/trainer/checkpoint_manager.
    • TensorBoard/logging and moving-average/ETA calculations for correctness and overhead.
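For the moving-average and ETA calculations flagged for review, here is a minimal sketch of how such progress metrics are commonly computed — the class name, field names, and window size are assumptions for illustration, not the PR's actual implementation:

```python
import time
from collections import deque


class ProgressTracker:
    """Moving-average loss and ETA estimate for a training progress bar."""

    def __init__(self, total_steps: int, window: int = 100):
        self.total_steps = total_steps
        self.losses = deque(maxlen=window)      # recent losses only
        self.step_times = deque(maxlen=window)  # recent per-step durations
        self._last = time.monotonic()

    def update(self, loss: float) -> None:
        now = time.monotonic()
        self.step_times.append(now - self._last)
        self._last = now
        self.losses.append(loss)

    def loss_avg(self) -> float:
        return sum(self.losses) / len(self.losses)

    def eta_hours(self, step: int) -> float:
        remaining = max(self.total_steps - step, 0)
        per_step = sum(self.step_times) / len(self.step_times)
        return remaining * per_step / 3600.0
```

The bounded `deque` keeps both metrics O(window) in memory and makes the ETA reflect recent throughput rather than the whole run, which matters when step time drifts (e.g., memory pressure late in training).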

Possibly related PRs

  • Feat/training improvements #15 — Overlaps in modifications to src/train_qlora.py (training loop, resume/optimizer behavior); may conflict or complement resume and gradient-handling changes.

Poem

🐰
I hopped through checkpoints, soft and spry,
Caught gradients that spiked and waved goodbye,
I whispered "resume" when night fell deep,
Kept the best model safe while you sleep,
Now training wakes with hops of joy—hooray!

Pre-merge checks and finishing touches

❌ Failed checks (1 inconclusive)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Title check | ❓ Inconclusive | The title 'fix/optimized-train' is vague and does not clearly convey the main changes, which involve substantial training framework improvements including early stopping, validation, and auto-resume features. | Use a more descriptive title that summarizes the primary improvement, such as 'Add early stopping, validation, and auto-resume to training framework' or 'Improve training reliability with early stopping and validation features'. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 90.91% which is sufficient. The required threshold is 80.00%. |
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch fix/optimized-train

Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🧹 Nitpick comments (5)
TRAINING_IMPROVEMENTS.md (2)

68-71: Add language to the progress format code block (MD040).

Markdownlint is complaining about the fenced block without a language. Consider tagging it as plain text for consistency:

-``` 
+```text
 Training: 45% | loss=3.2 | loss_avg=3.4 | eta_h=6.5 | ckpt=-150 | memory_mb=14051 | mem_delta=+245

---

167-174: Surround the “Expected Results” table with blank lines (MD058).

To satisfy markdownlint and improve readability, add blank lines before and after the table:



-### Expected Results
-| Metric | Before | After | Improvement |
+### Expected Results
+
+| Metric | Before | After | Improvement |
 ...
-| **User experience** | Poor visibility | Clear progress | Much better |
-## Testing Recommendations
+| **User experience** | Poor visibility | Clear progress | Much better |
+
+## Testing Recommendations
src/train_qlora.py (3)

603-680: Close streaming validation datasets to mirror training dataset cleanup.

When val_data is a StreamingDataset, it’s iterated in validate() but never closed, whereas train_data.close() is called after training. To avoid leaked file handles/buffers, consider tracking whether validation is streaming and closing it at the end of train().

@@
-        # Load validation data if available
-        val_data = None
+        # Load validation data if available
+        val_data = None
+        val_is_streaming = False
@@
-            print("Loading validation data...")
-            val_data = self.load_data(str(val_file_path))
-            print(f"Validation data loaded")
+            print("Loading validation data...")
+            val_data = self.load_data(str(val_file_path))
+            print("Validation data loaded")
+            from data.streaming_dataset import StreamingDataset  # if not already imported at top
+            val_is_streaming = isinstance(val_data, StreamingDataset)
@@
-        pbar.close()
-
-        # Print training summary
+        pbar.close()
+
+        # Cleanup streaming validation
+        if val_is_streaming:
+            val_data.close()
+
+        # Print training summary

(Adjust the import if StreamingDataset is already in scope.)

Also applies to: 697-705, 851-853, 962-977


740-771: Avoid duplicating final checkpoints and off‑by‑one step metadata on early stop.

When early stopping triggers, you save a final checkpoint inside the loop with step but break before incrementing self.global_step. Later, train() unconditionally calls self.save_checkpoint(self.global_step, is_final=True) again, so you both (a) write the same checkpoint-{step}-final twice and (b) keep global_step one behind the last executed update. It would be cleaner to either:

  • Set self.global_step = step (or step + 1, depending on your convention) before the break, and/or
  • Track an early_stopped flag and skip the extra final save when it’s already been done.

Also applies to: 855-885, 964-977
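The fix suggested above can be sketched with an `early_stopped` flag; everything here is illustrative (the real trainer's loop, state, and save calls differ):

```python
saves = []  # stand-in for checkpoint writes; records the step saved


def save_final(step: int) -> None:
    saves.append(step)


class LoopState:
    """Minimal stand-in for the trainer's step-tracking state."""

    def __init__(self):
        self.global_step = 0
        self.early_stopped = False


def run_steps(state: LoopState, steps, should_stop) -> None:
    for step in steps:
        if should_stop(step):
            state.global_step = step   # sync step metadata before breaking
            state.early_stopped = True
            save_final(step)           # final save happens inside the loop
            break
        state.global_step = step + 1
    if not state.early_stopped:        # skip the duplicate final save
        save_final(state.global_step)
```

With this shape, exactly one final checkpoint is written whether the loop runs to completion or stops early, and the recorded step matches the last executed update.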


1111-1135: Sync CLI --max-steps default with new config/docs default.

TrainingConfig.max_steps now defaults to 2000 and the docs describe 2000 as the new default, but the CLI parser still sets --max-steps default to 5000, overriding the dataclass default for typical runs. To avoid surprises, consider making the parser default None and only overriding when the user passes the flag, or change the default here to 2000.

-    train_group.add_argument("--max-steps", type=int, default=5000, help="Max training steps")
+    train_group.add_argument(
+        "--max-steps",
+        type=int,
+        default=None,
+        help="Max training steps (default: 2000 from config unless overridden)",
+    )
@@
-    config.training.max_steps = args.max_steps
+    if args.max_steps is not None:
+        config.training.max_steps = args.max_steps

Also applies to: 248-252
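The `default=None` pattern proposed above can be sketched in isolation; `TrainingConfig` here is a cut-down illustrative stand-in for the dataclass in src/config.py:

```python
import argparse
from dataclasses import dataclass


@dataclass
class TrainingConfig:
    # Illustrative stand-in for the real dataclass in src/config.py
    max_steps: int = 2000


parser = argparse.ArgumentParser()
# default=None distinguishes "flag omitted" from "flag explicitly set"
parser.add_argument("--max-steps", type=int, default=None,
                    help="Max training steps (default: 2000 from config)")


def resolve(argv):
    args = parser.parse_args(argv)
    config = TrainingConfig()
    if args.max_steps is not None:  # override only when explicitly passed
        config.max_steps = args.max_steps
    return config
```

This keeps the dataclass as the single source of truth for defaults, so docs and CLI cannot drift apart again.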

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3eb6b90 and d3d1b1b.

📒 Files selected for processing (5)
  • QUICK_TRAINING_GUIDE.md (1 hunks)
  • TRAINING_IMPROVEMENTS.md (1 hunks)
  • src/checkpoints/checkpoint_manager.py (2 hunks)
  • src/config.py (2 hunks)
  • src/train_qlora.py (13 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.py

📄 CodeRabbit inference engine (GEMINI.md)

Include docstrings and comments following a clear and consistent style in all code

Files:

  • src/checkpoints/checkpoint_manager.py
  • src/config.py
  • src/train_qlora.py
src/config.py

📄 CodeRabbit inference engine (GEMINI.md)

Use dataclasses for configuration management, defined in src/config.py

Files:

  • src/config.py
src/train_*.py

📄 CodeRabbit inference engine (GEMINI.md)

Use Apple's MLX framework for training on Apple Silicon, with PyTorch as a fallback for other platforms

Files:

  • src/train_qlora.py
🧠 Learnings (1)
📚 Learning: 2025-11-30T23:47:27.584Z
Learnt from: CR
Repo: arosboro/your_ai PR: 0
File: GEMINI.md:0-0
Timestamp: 2025-11-30T23:47:27.584Z
Learning: Organize code into logical modules with single responsibilities, such as `distrust_loss.py` for the core algorithm, `train_qlora.py` for the training loop, and `prepare_data_curated.py` for data preparation

Applied to files:

  • src/train_qlora.py
🧬 Code graph analysis (1)
src/train_qlora.py (4)
src/checkpoints/checkpoint_manager.py (2)
  • validate (399-457)
  • close (541-547)
src/data/streaming_dataset.py (2)
  • get_progress (201-222)
  • close (266-272)
src/distrust_loss.py (1)
  • batch_empirical_distrust_loss (129-187)
src/config.py (1)
  • val_file (330-331)
🪛 GitHub Actions: CI
src/checkpoints/checkpoint_manager.py

[error] 168-168: F841 Local variable 'test_dict' is assigned to but never used. Remove assignment to unused variable 'test_dict'.

src/train_qlora.py

[error] 703-703: F541 [*] f-string without any placeholders. Remove extraneous f prefix.


[error] 1455-1455: F541 [*] f-string without any placeholders. Remove extraneous f prefix.


[error] 1467-1467: F541 [*] f-string without any placeholders. Remove extraneous f prefix.

🪛 markdownlint-cli2 (0.18.1)
TRAINING_IMPROVEMENTS.md

68-68: Fenced code blocks should have a language specified

(MD040, fenced-code-language)


167-167: Tables should be surrounded by blank lines

(MD058, blanks-around-tables)

🔇 Additional comments (2)
src/checkpoints/checkpoint_manager.py (1)

190-227: Optimizer pre‑eval + empty‑file fallback looks robust.

The pattern of pre‑evaluating optimizer arrays, logging failures, and writing an empty optimizer.npz to preserve checkpoint structure is sound and should play nicely with _load_from_path’s reconstruction logic.

src/config.py (1)

248-266: Early‑stopping and checkpoint defaults are well‑integrated with existing config plumbing.

The reduced max_steps/warmup_steps, added early‑stopping and grad‑spike fields, and tighter checkpoint_interval slot cleanly into the dataclass + to_dict/from_dict flow, so checkpoints and metadata stay backward‑compatible.

Also applies to: 356-361

Comment on lines +155 to +264
class EarlyStopping:
    """
    Early stopping to prevent wasted training time.

    Monitors loss and gradient health, stopping when:
    - Loss plateaus for patience checks
    - Gradient norms become unstable
    - Validation loss diverges from training loss (overfitting)
    """

    def __init__(
        self,
        patience: int = 5,
        min_delta: float = 0.01,
        warmup_steps: int = 200,
        grad_spike_threshold: float = 1000.0,
        grad_spike_patience: int = 3,
    ):
        """
        Initialize early stopping.

        Args:
            patience: Number of checks without improvement before stopping
            min_delta: Minimum change to qualify as improvement
            warmup_steps: Don't check early stopping until this many steps
            grad_spike_threshold: Gradient norm threshold for instability
            grad_spike_patience: Consecutive spikes before aborting
        """
        self.patience = patience
        self.min_delta = min_delta
        self.warmup_steps = warmup_steps
        self.grad_spike_threshold = grad_spike_threshold
        self.grad_spike_patience = grad_spike_patience

        self.best_loss = float('inf')
        self.counter = 0
        self.grad_spike_counter = 0
        self.stopped_reason = None

    def check_loss(self, loss: float, step: int) -> bool:
        """
        Check if training should stop based on loss plateau.

        Args:
            loss: Current training loss
            step: Current training step

        Returns:
            True if should stop, False otherwise
        """
        if step < self.warmup_steps:
            return False

        if loss < self.best_loss - self.min_delta:
            # Improvement
            self.best_loss = loss
            self.counter = 0
            return False
        else:
            # No improvement
            self.counter += 1
            if self.counter >= self.patience:
                self.stopped_reason = f"Loss plateau: no improvement for {self.patience} checks"
                return True

        return False

    def check_gradient_health(self, grad_norm: float) -> bool:
        """
        Check if gradients are stable.

        Args:
            grad_norm: Current gradient norm

        Returns:
            True if should abort due to instability, False otherwise
        """
        if grad_norm > self.grad_spike_threshold:
            self.grad_spike_counter += 1
            if self.grad_spike_counter >= self.grad_spike_patience:
                self.stopped_reason = f"Gradient instability: {self.grad_spike_counter} consecutive spikes > {self.grad_spike_threshold}"
                return True
        else:
            # Reset counter on stable gradient
            self.grad_spike_counter = 0

        return False

    def should_stop(self, loss: float, grad_norm: float, step: int) -> bool:
        """
        Unified check for all stopping conditions.

        Args:
            loss: Current training loss
            grad_norm: Current gradient norm
            step: Current training step

        Returns:
            True if should stop, False otherwise
        """
        # Check gradient health first (higher priority)
        if self.check_gradient_health(grad_norm):
            return True

        # Check loss plateau
        if self.check_loss(loss, step):
            return True

        return False


⚠️ Potential issue | 🟡 Minor

Clarify EarlyStopping behavior vs docstring (no validation‑loss logic yet).

EarlyStopping currently stops based on training‑loss plateau and gradient spikes only; despite the docstring mentioning validation‑loss divergence, no validation metrics are passed in or checked. Either extend should_stop to accept validation loss and implement that condition, or narrow the docstring to match current behavior. You might also consider extracting EarlyStopping into a small helper module if you plan to reuse it in other training scripts. Based on learnings, this keeps responsibilities modular.

🤖 Prompt for AI Agents
In src/train_qlora.py around lines 155-264, the class docstring mentions
stopping on validation-loss divergence but the implementation only checks
training loss and gradient spikes; either implement validation-loss checking or
update the docstring to match current behavior. To implement: add an optional
val_loss argument to should_stop (and propagate through check_loss), track an
overfit counter comparing val_loss to training loss (e.g., val_loss > loss +
min_delta or val_loss / loss > 1 + some_ratio) with its own patience and
stopped_reason, reset the counter on improvement, and include this check
alongside the existing ones; update the constructor to accept overfit
patience/threshold params and update the docstring accordingly. Alternatively,
to keep minimal changes, remove the validation-loss mention from the docstring
and optionally extract EarlyStopping into a separate helper module if reuse is
planned.
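One hedged way to implement the suggested validation-divergence condition is sketched below — a hypothetical companion to EarlyStopping with illustrative thresholds and names, not code from this PR:

```python
class OverfitDetector:
    """Track train/val divergence for early stopping.

    Hypothetical companion to EarlyStopping, as suggested in the review;
    the ratio threshold and patience are illustrative, not the repo's API.
    """

    def __init__(self, ratio_threshold: float = 1.2, patience: int = 3):
        self.ratio_threshold = ratio_threshold
        self.patience = patience
        self.counter = 0
        self.stopped_reason = None

    def check(self, train_loss: float, val_loss: float) -> bool:
        """Return True when val loss has diverged from train loss."""
        if val_loss > train_loss * self.ratio_threshold:
            self.counter += 1
            if self.counter >= self.patience:
                self.stopped_reason = (
                    f"Overfitting: val/train ratio > {self.ratio_threshold} "
                    f"for {self.patience} consecutive checks"
                )
                return True
        else:
            self.counter = 0  # reset on a healthy ratio
        return False
```

Because it carries its own counter and `stopped_reason`, it composes with the existing `should_stop` without entangling the plateau and gradient-spike state.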

Comment on lines +699 to 705
val_file_path = Path(self.config.paths.val_file)
if val_file_path.exists():
print("Loading validation data...")
val_data = self.load_data(str(val_file_path))
print(f"Validation data loaded")

# Training loop

⚠️ Potential issue | 🟠 Major

Fix F541 flake8 errors from f‑strings without placeholders.

Three print calls are using f‑strings with no {} placeholders, which flake8 flags as F541 and is breaking CI. They can be plain string literals.

-            print(f"Validation data loaded")
+            print("Validation data loaded")
@@
-                print(f"   Auto-resume enabled - will resume from latest checkpoint")
+                print("   Auto-resume enabled - will resume from latest checkpoint")
@@
-                print(f"   This suggests a previous training run was interrupted.")
+                print("   This suggests a previous training run was interrupted.")

Also applies to: 1453-1456, 1466-1467

🧰 Tools
🪛 GitHub Actions: CI

[error] 703-703: F541 [*] f-string without any placeholders. Remove extraneous f prefix.

🤖 Prompt for AI Agents
In src/train_qlora.py around lines 699-705 (and also at 1453-1456 and
1466-1467), there are print statements written as f-strings but containing no {}
placeholders which triggers flake8 F541; change those f-strings to plain string
literals by removing the leading "f" (e.g., replace print(f"Validation data
loaded") with print("Validation data loaded")) so the messages remain the same
but no longer use f-strings without placeholders.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/checkpoints/checkpoint_manager.py (1)

1-1: Run ruff format to fix formatting.

The CI pipeline is failing because ruff formatting would reformat this file. Run ruff format src/checkpoints/checkpoint_manager.py to fix.

src/train_qlora.py (1)

1-1: Run ruff format to fix formatting.

The CI pipeline is failing because ruff formatting would reformat this file. Run ruff format src/train_qlora.py to fix.

🧹 Nitpick comments (5)
src/checkpoints/checkpoint_manager.py (1)

217-225: Consider logging more context on optimizer save failure.

The fallback to an empty optimizer file preserves checkpoint structure, which is reasonable for recovery. However, losing optimizer state silently (beyond the error log) could cause unexpected behavior on resume (e.g., learning rate schedule reset). Consider adding a warning that mentions the impact on training resume.

                     except Exception as e:
                         logger.error(f"Failed to save optimizer state: {e}")
                         # Save empty file to maintain checkpoint structure
                         mx.savez(str(optimizer_path))
-                        logger.warning("Saved empty optimizer state due to error")
+                        logger.warning(
+                            "Saved empty optimizer state due to error. "
+                            "Resuming from this checkpoint will reset optimizer state."
+                        )
TRAINING_IMPROVEMENTS.md (1)

78-83: Add language specifier to fenced code block.

The code block showing the progress bar format should have a language specifier for consistency. Use text or plaintext for non-code output.

 **Display format**:

-```
+```text
 Training: 45% | loss=3.2 | loss_avg=3.4 | eta_h=6.5 | ckpt=-150 | memory_mb=14051 | mem_delta=+245

src/train_qlora.py (1)

`765-888`: **Streaming training loop is well-implemented but has code duplication.**

The streaming loop correctly integrates early stopping (lines 789-796), validation with lazy loading (lines 798-828), and enhanced progress monitoring (lines 830-874). The fallback to non-streaming on failure (lines 739-750) is good defensive coding.

However, the training loop logic is largely duplicated between streaming (lines 765-888) and non-streaming (lines 891-1006) modes. Consider extracting the common step logic into a helper method in a future refactor.

QUICK_TRAINING_GUIDE.md (2)

`69-97`: **Add language specifiers to terminal output code blocks.**

Several code blocks showing terminal output lack language specifiers. Use `text` for consistency and to satisfy markdown linting.

````diff
 **What you'll see**:
-```
+```text
 🛑 Early stopping triggered at step 847
    Reason: Loss plateau: no improvement for 5 checks
 ✓ Best model saved at step 750 (val_loss: 2.341)

...

 **What you'll see**:
-```
+```text
 📊 Running validation at step 250...
 Val Loss: 2.543 (Train: 2.891)
 ✓ New best model! (val_loss: 2.543)

...

 **New progress bar format**:
-```
+```text
 Training:  45% | loss=3.2 | loss_avg=3.4 | eta_h=6.5 | grad_norm=0.45 |
            ckpt=-150 | memory_mb=14051 | mem_delta=+245
````

---

`302-308`: **Add language specifier to completion output code block.**

Same as previous - add `text` language specifier for consistency.

````diff
 After training completes, you'll see:
-```
+```text
 Training complete!
 ✓ Best model saved at step 1247 (val_loss: 2.134)
 TensorBoard logs saved to: models/distrust-hermes-2-pro-mistral-7b/logs/run_2025-12-08_14-23-45
````

📜 Review details

**Configuration used**: CodeRabbit UI

**Review profile**: CHILL

**Plan**: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d3d1b1b8fd7528a6680a4a7dd3a2bb9efe77355d and cf601e73d4bb35b2c90331e07748d226494ab70d.

📒 Files selected for processing (4)
  • QUICK_TRAINING_GUIDE.md (1 hunks)
  • TRAINING_IMPROVEMENTS.md (1 hunks)
  • src/checkpoints/checkpoint_manager.py (2 hunks)
  • src/train_qlora.py (11 hunks)

🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (GEMINI.md)

Include docstrings and comments following a clear and consistent style in all code

Files:

  • src/checkpoints/checkpoint_manager.py
  • src/train_qlora.py

src/train_*.py

📄 CodeRabbit inference engine (GEMINI.md)

Use Apple's MLX framework for training on Apple Silicon, with PyTorch as a fallback for other platforms

Files:

  • src/train_qlora.py

🧠 Learnings (1)
📚 Learning: 2025-11-30T23:47:27.584Z

Learnt from: CR
Repo: arosboro/your_ai PR: 0
File: GEMINI.md:0-0
Timestamp: 2025-11-30T23:47:27.584Z
Learning: Organize code into logical modules with single responsibilities, such as distrust_loss.py for the core algorithm, train_qlora.py for the training loop, and prepare_data_curated.py for data preparation


Applied to files:

  • src/train_qlora.py

🪛 GitHub Actions: CI
src/checkpoints/checkpoint_manager.py
[error] 1-1: ruff formatting would reformat this file. Run 'ruff format src/checkpoints/checkpoint_manager.py' to fix.

src/train_qlora.py

[error] 1-1: ruff formatting would reformat this file. Run 'ruff format src/train_qlora.py' to fix.

🪛 markdownlint-cli2 (0.18.1)
QUICK_TRAINING_GUIDE.md

70-70: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

---

85-85: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

---

94-94: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

---

303-303: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

---

349-349: Emphasis used instead of a heading

(MD036, no-emphasis-as-heading)

TRAINING_IMPROVEMENTS.md

80-80: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

🔇 Additional comments (13)
src/checkpoints/checkpoint_manager.py (1)

`149-182`: **LGTM - Batch evaluation and partial save recovery implemented correctly.**

The batch evaluation of arrays before saving (line 152) ensures full materialization, preventing lazy evaluation issues. The fallback to per-parameter saves (lines 165-182) provides robust recovery. The previously flagged unused `test_dict` variable has been removed.

TRAINING_IMPROVEMENTS.md (2)

`25-43`: **Documentation now correctly reflects early stopping behavior.**

The text at line 40 clarifies that "Validation loss tracked separately for best model selection (not for early stopping)" - this addresses the previous review concern about the documentation mentioning validation-loss divergence that wasn't implemented. The current documentation accurately describes the implementation.

---

`1-6`: **Well-structured documentation for training improvements.**

The document provides comprehensive coverage of the changes including problems fixed, new features, configuration changes, migration notes, and testing recommendations. The expected results table (lines 193-199) gives clear before/after comparisons.

src/train_qlora.py (8)

`155-166`: **Docstring now accurately describes EarlyStopping behavior.**

The docstring (lines 156-165) correctly states that EarlyStopping monitors training loss and gradient health, and explicitly notes that validation loss is tracked separately for best model selection. This addresses the previous review concern about mismatched documentation.

---

`196-265`: **LGTM - EarlyStopping logic is well-implemented.**

The implementation correctly handles:
- Warmup period skipping (line 207-208)
- Loss improvement tracking with configurable min_delta (lines 210-214)
- Gradient spike detection with consecutive spike counting (lines 234-241)
- Priority ordering (gradient health checked first for faster abort)

---

`605-682`: **LGTM - Validation method is well-structured.**

The validate() method correctly:
- Handles both streaming and non-streaming data (lines 623-629)
- Limits validation to 50 batches to avoid long pauses (line 620)
- Computes loss without gradient updates
- Returns `inf` on failure which prevents false "best model" saves (line 676)
- Includes per-batch exception handling with continuation (lines 671-673)
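The bullets above can be sketched in miniature; the signature and names below are illustrative, not the repo's actual `validate()`:

```python
def validate(model_loss_fn, val_batches, max_batches=50):
    """Average loss over up to max_batches validation batches.

    Illustrative sketch of the reviewed behavior: caps validation length,
    skips failing batches, and returns inf on total failure so a broken
    run is never recorded as the "best" model.
    """
    total, count = 0.0, 0
    for i, batch in enumerate(val_batches):
        if i >= max_batches:  # cap validation time
            break
        try:
            total += float(model_loss_fn(batch))
            count += 1
        except Exception:
            continue  # skip bad batches, keep validating
    if count == 0:
        return float("inf")  # prevents false "best model" saves
    return total / count
```

Returning `inf` instead of raising keeps the training loop running while guaranteeing that a failed validation pass can never win the best-model comparison.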

---

`699-706`: **F541 issues have been addressed.**

The previous review flagged f-strings without placeholders. The current code uses regular strings where there are no placeholders (e.g., line 705: `print("  (will load on first validation run)")`).

---

`889-1006`: **Non-streaming loop mirrors streaming logic correctly.**

The non-streaming path correctly mirrors the streaming path's features including early stopping, validation, and progress monitoring. As noted above, this duplication could be refactored in the future.

---

`1161-1161`: **max_steps default updated to 2000.**

The CLI default for `--max-steps` is now 2000 (line 1161), which matches the documentation in QUICK_TRAINING_GUIDE.md. This addresses the previous review concern about mismatched defaults.

---

`1492-1519`: **Auto-resume logic is well-implemented.**

The auto-resume feature correctly:
- Activates only when `--auto-resume` is set and no explicit resume flags are provided (line 1494)
- Falls back to interactive prompting for non-auto-resume runs (lines 1505-1518)
- Filters out final checkpoints when detecting incomplete runs (line 1510)

---

`292-308`: **LGTM - EarlyStopping and validation state initialization.**

The use of `getattr` with defaults (lines 293-300) provides backward compatibility with configs that don't have the new early stopping fields. Validation tracking state is properly initialized.

QUICK_TRAINING_GUIDE.md (2)

`329-337`: **Defaults table now matches CLI implementation.**

The table correctly shows `max_steps` default as 2000, which now matches the CLI implementation in `train_qlora.py` (line 1161). The previous review concern has been addressed.

---

`1-19`: **Well-structured quick reference guide.**

The TL;DR section effectively summarizes the key improvements with clear before/after comparisons. The document provides practical examples, troubleshooting guidance, and hardware-specific recommendations that will be valuable for users.


Comment on lines +58 to +62
Training will automatically stop if:
- Loss plateaus for 5 consecutive checks (no improvement)
- Gradient norm spikes >1000 for 3 consecutive steps
- Validation loss starts increasing (overfitting)


⚠️ Potential issue | 🟡 Minor

Incorrect claim: validation loss does not trigger early stopping.

Line 61 states "Validation loss starts increasing (overfitting)" as an early stopping condition, but the EarlyStopping class only monitors training loss plateau and gradient spikes - not validation loss. Validation metrics are tracked separately for best model selection only.

Remove this incorrect item to match the actual implementation.

 Training will automatically stop if:
 - Loss plateaus for 5 consecutive checks (no improvement)
 - Gradient norm spikes >1000 for 3 consecutive steps
-- Validation loss starts increasing (overfitting)
🤖 Prompt for AI Agents
In QUICK_TRAINING_GUIDE.md around lines 58 to 62, the bullet "Validation loss
starts increasing (overfitting)" is incorrect because EarlyStopping only
monitors training loss plateau and gradient spikes; remove that bullet so the
list matches the implementation and replace or append a short note clarifying
that validation metrics are tracked separately for best-model selection (not
used for early stopping).
