Skip to content

Commit

Permalink
Fixes 256 - ValueError mutable default
Browse files Browse the repository at this point in the history
Fixes issue EleutherAI#256. Error message:

> ValueError: mutable default <class 'elk.training.train.Elicit'> for field run_template is not allowed: use
default_factory
  • Loading branch information
artkpv committed Jun 27, 2023
1 parent ec2b8a0 commit d04f17f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
8 changes: 1 addition & 7 deletions elk/training/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from transformers import AutoConfig

from ..evaluation import Eval
from ..extraction import Extract
from ..files import memorably_named_dir, sweeps_dir
from ..plotting.visualize import visualize_sweep
from ..training.eigen_reporter import EigenReporterConfig
Expand Down Expand Up @@ -52,12 +51,7 @@ class Sweep:
name: str | None = None

# A bit of a hack to add all the command line arguments from Elicit
run_template: Elicit = Elicit(
data=Extract(
model="<placeholder>",
datasets=("<placeholder>",),
)
)
run_template: Elicit = field(default_factory=Elicit.Default)

def __post_init__(self, add_pooled: bool):
if not self.datasets:
Expand Down
10 changes: 10 additions & 0 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from ..metrics import evaluate_preds, to_one_hot
from ..run import Run
from ..extraction import Extract
from ..training.supervised import train_supervised
from ..utils.typing import assert_type
from .ccs_reporter import CcsReporter, CcsReporterConfig
Expand All @@ -34,6 +35,15 @@ class Elicit(Run):
cross-validation. Defaults to "single", which means to train a single classifier
on the training data. "cv" means to use cross-validation."""

@staticmethod
def Default():
return Elicit(
data=Extract(
model="<placeholder>",
datasets=("<placeholder>",),
)
)

def create_models_dir(self, out_dir: Path):
lr_dir = None
lr_dir = out_dir / "lr_models"
Expand Down

0 comments on commit d04f17f

Please sign in to comment.