Skip to content

Commit b171eb3

Browse files
corbt authored and claude committed
Add experiment 228 with learning rate warmup and cooldown
- Add warmup_length and cooldown_length fields to ProjectPolicyConfig
- Update train.py to use adjust_lr function with batch-specific learning rates
- Add experiment 228 with 20-step warmup and cooldown starting at step 20

This experiment will test whether warmup/cooldown improves training compared to our baseline constant learning rate approach.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent 9466881 commit b171eb3

File tree

3 files changed

+19
-4
lines changed

3 files changed

+19
-4
lines changed

examples/art-e/all_experiments.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,4 +212,9 @@
212212

213213
models["227"] = models["008"].model_copy(deep=True)
214214
models["227"].name = "email-agent-227"
215-
models["220"].base_model = "willcb/Qwen3-14B"
215+
models["227"].base_model = "willcb/Qwen3-14B"
216+
217+
models["228"] = models["008"].model_copy(deep=True)
218+
models["228"].name = "email-agent-228"
219+
models["228"].config.warmup_length = 20
220+
models["228"].config.cooldown_length = -20

examples/art-e/art_e/project_types.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pydantic import BaseModel
2-
from typing import Literal
2+
from typing import Literal, Union
33

44

55
class ProjectPolicyConfig(BaseModel):
@@ -13,6 +13,8 @@ class ProjectPolicyConfig(BaseModel):
1313
trajectories_per_group: int = 6
1414
groups_per_step: int = 1
1515
learning_rate: float = 1.2e-5
16+
warmup_length: Union[int, float] = 0
17+
cooldown_length: Union[int, float] = 0
1618
eval_steps: int = 30
1719
val_set_size: int = 100
1820
training_dataset_size: int = 4000

examples/art-e/art_e/train.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from art_e.data.query_iterators import load_synthetic_queries
88
from art_e.data.types_enron import SyntheticQuery
99
from art_e.data.local_email_db import generate_database
10-
from art.utils import iterate_dataset
10+
from art.utils import iterate_dataset, adjust_lr
1111
from art_e.project_types import ProjectPolicyConfig
1212
from art_e.evaluate.benchmark import benchmark_model
1313
import os
@@ -139,9 +139,17 @@ async def judge_after_each(
139139
)
140140
continue # Proceed to next batch/epoch without training.
141141

142+
# Calculate learning rate for this batch
143+
current_lr = adjust_lr(
144+
batch,
145+
learning_rate=model.config.learning_rate,
146+
warmup_length=model.config.warmup_length,
147+
cooldown_length=model.config.cooldown_length,
148+
)
149+
142150
await model.train(
143151
groups,
144-
config=art.TrainConfig(learning_rate=model.config.learning_rate),
152+
config=art.TrainConfig(learning_rate=current_lr),
145153
_config=art.dev.TrainConfig(
146154
allow_training_without_logprobs=model.config.messages_only,
147155
precalculate_logprobs=model.config.precalculate_logprobs,

0 commit comments

Comments (0)