diff --git a/elk/training/sweep.py b/elk/training/sweep.py index e4aca5a0..bbd761c9 100755 --- a/elk/training/sweep.py +++ b/elk/training/sweep.py @@ -5,7 +5,6 @@ from datasets import get_dataset_config_info from transformers import AutoConfig -from ..evaluation import Eval from ..extraction import Extract from ..files import memorably_named_dir, sweeps_dir from ..plotting.visualize import visualize_sweep @@ -134,19 +133,19 @@ def execute(self): data = replace( self.run_template.data, model=model, datasets=train_datasets ) - run = replace(self.run_template, data=data, out_dir=out_dir) + elicit = replace(self.run_template, data=data, out_dir=out_dir) if var_weight is not None and neg_cov_weight is not None: - assert isinstance(run.net, EigenFitterConfig) - run.net.var_weight = var_weight - run.net.neg_cov_weight = neg_cov_weight + assert isinstance(elicit.net, EigenFitterConfig) + elicit.net.var_weight = var_weight + elicit.net.neg_cov_weight = neg_cov_weight # Add hyperparameter values to output directory if needed - assert run.out_dir is not None - run.out_dir /= f"var_weight={var_weight:.2f}" - run.out_dir /= f"neg_cov_weight={neg_cov_weight:.2f}" + assert elicit.out_dir is not None + elicit.out_dir /= f"var_weight={var_weight:.2f}" + elicit.out_dir /= f"neg_cov_weight={neg_cov_weight:.2f}" try: - run.execute() + elicit.execute() except torch.linalg.LinAlgError as e: print(colorize(f"LinAlgError: {e}", "red")) continue @@ -161,17 +160,7 @@ def execute(self): if eval_dataset in train_datasets: continue - assert run.out_dir is not None - eval = Eval( - data=replace( - run.data, model=model, datasets=(eval_dataset,) - ), - source=run.out_dir, - out_dir=run.out_dir / "transfer" / eval_dataset, - num_gpus=run.num_gpus, - min_gpu_mem=run.min_gpu_mem, - skip_supervised=run.supervised == "none", - ) + eval = elicit.make_eval(model, eval_dataset) eval.execute(highlight_color="green") if self.visualize: diff --git a/elk/training/train.py b/elk/training/train.py index 
def make_eval(self, model, eval_dataset):
    """Construct an `Eval` for `eval_dataset` sourced from this run's output.

    The returned `Eval` uses `self.out_dir` as its `source` and writes to
    `self.out_dir / "transfer" / eval_dataset`. Its data config is a copy of
    `self.data` with `model` and a one-element `datasets` tuple substituted
    in. The GPU, prompt-index, layer-offset, debug and cache settings are
    copied from `self`, and `skip_supervised` is set when `self.supervised`
    equals `"none"`.

    Args:
        model: Value stored in the `model` field of the copied data config.
        eval_dataset: Dataset to evaluate on; also used as the final
            component of the output path.

    Returns:
        The configured `Eval` instance. It is not executed here.

    Raises:
        AssertionError: If `self.out_dir` is `None`.
    """
    source_dir = self.out_dir
    assert source_dir is not None

    # Same data settings as this run, narrowed to the single target dataset.
    eval_data = replace(self.data, model=model, datasets=(eval_dataset,))
    transfer_dir = source_dir / "transfer" / eval_dataset
    skip_supervised = self.supervised == "none"

    # No `datasets` keyword is forwarded: per the original author's note it
    # isn't needed because it's immediately overwritten.
    return Eval(
        data=eval_data,
        source=source_dir,
        out_dir=transfer_dir,
        num_gpus=self.num_gpus,
        min_gpu_mem=self.min_gpu_mem,
        skip_supervised=skip_supervised,
        prompt_indices=self.prompt_indices,
        concatenated_layer_offset=self.concatenated_layer_offset,
        debug=self.debug,
        disable_cache=self.disable_cache,
    )