Skip to content

Commit

Permalink
Migrate CLI to simple_parsing, allowing for YAML based configs (#92)
Browse files Browse the repository at this point in the history
* Initial commit

* Formatting

* PromptCollator -> PromptDataset

* elk elicit seems to work again now

* Remove vestigial argparse code

* Fix bugs

* Added stochastic rounding test

* Remove typo

* Add pytest.mark.cpu
  • Loading branch information
norabelrose authored Feb 20, 2023
1 parent 5bd59e4 commit 946372d
Show file tree
Hide file tree
Showing 17 changed files with 376 additions and 542 deletions.
4 changes: 3 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
"-l 88"
],
"python.formatting.provider": "black",
"python.testing.pytestArgs": [],
"python.testing.pytestArgs": [
"tests"
],
"python.testing.pytestEnabled": true,
"python.testing.unittestEnabled": false,
}
2 changes: 1 addition & 1 deletion elk/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .extraction import PromptCollator
from .extraction import extract_hiddens, ExtractionConfig, PromptDataset
59 changes: 14 additions & 45 deletions elk/__main__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Main entry point for `elk`."""

from .argparsers import add_train_args, get_extraction_parser
from .files import args_to_uuid
from .extraction import ExtractionConfig
from .list import list_runs
from argparse import ArgumentParser
from .training import RunConfig
from contextlib import nullcontext, redirect_stdout
from pathlib import Path
from simple_parsing import ArgumentParser
import logging
import warnings


def run():
Expand All @@ -16,18 +16,20 @@ def run():
subparsers.add_parser(
"extract",
help="Extract hidden states from a model.",
parents=[get_extraction_parser()],
)
).add_arguments(ExtractionConfig, dest="extraction")

elicit_parser = subparsers.add_parser(
"elicit",
help=(
"Extract and train a set of ELK reporters "
"on hidden states from `elk extract`. "
),
parents=[get_extraction_parser()],
conflict_handler="resolve",
)
add_train_args(elicit_parser)
elicit_parser.add_arguments(RunConfig, dest="run")
elicit_parser.add_argument(
"--output", "-o", type=Path, help="Path to save checkpoints to."
)

subparsers.add_parser(
"eval", help="Evaluate a set of ELK reporters generated by `elk train`."
Expand All @@ -40,30 +42,6 @@ def run():
list_runs(args)
return

from transformers import AutoConfig, PretrainedConfig

config = AutoConfig.from_pretrained(args.model)
assert isinstance(config, PretrainedConfig)

num_layers = getattr(config, "num_layers", config.num_hidden_layers)
assert isinstance(num_layers, int)

if args.layers and args.layer_stride > 1:
raise ValueError(
"Cannot use both --layers and --layer-stride. Please use only one."
)
elif args.layer_stride > 1:
args.layers = list(range(0, num_layers, args.layer_stride))

# TODO: Remove this once the extraction refactor is finished
if args.layers and args.layers != list(range(num_layers)):
warnings.warn(
"Warning: hidden states are not labeled by layer index, and reporter "
"checkpoints generated by `elk elicit` will be incorrectly named; "
"e.g. `layer_1` instead of `layer_2` for the 3rd transformer layer "
"when `--layer-stride` is 2. This will be fixed in a future release."
)

# Import here and not at the top to speed up `elk list`
from .extraction.extraction_main import run as run_extraction
from .training.train import train
Expand All @@ -77,31 +55,22 @@ def run():
local_rank = int(local_rank)

with redirect_stdout(None) if local_rank else nullcontext():
# Print CLI arguments to stdout
for key, value in vars(args).items():
print(f"{key}: {value}")

if local_rank:
logging.getLogger("transformers").setLevel(logging.CRITICAL)

if args.command == "extract":
run_extraction(args)
run_extraction(args.run.data)
elif args.command == "elicit":
# The user can specify a name for the run, but by default we use the
# MD5 hash of the arguments to ensure the name is unique
if not args.name:
args.name = args_to_uuid(args)

try:
train(args)
train(args.run, args.output)
except (EOFError, FileNotFoundError):
run_extraction(args)
run_extraction(args.run.data)

# Ensure the extraction is finished before starting training
if dist.is_initialized():
dist.barrier()

train(args)
train(args.run, args.output)

elif args.command == "eval":
# TODO: Implement evaluation script
Expand Down
213 changes: 0 additions & 213 deletions elk/argparsers.py

This file was deleted.

4 changes: 2 additions & 2 deletions elk/extraction/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .extraction import extract_hiddens
from .prompt_collator import PromptCollator
from .extraction import ExtractionConfig, extract_hiddens
from .prompt_dataset import PromptDataset, PromptConfig
Loading

0 comments on commit 946372d

Please sign in to comment.