diff --git a/elk/run.py b/elk/run.py
index fb8903cc..a3ab2b95 100644
--- a/elk/run.py
+++ b/elk/run.py
@@ -1,5 +1,6 @@
 import os
 import random
+import subprocess
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass
@@ -17,6 +18,8 @@
 from torch import Tensor
 from tqdm import tqdm
 
+import elk
+
 from .debug_logging import save_debug_log
 from .extraction import Extract, extract
 from .extraction.dataset_name import DatasetDictWithName
@@ -31,6 +34,25 @@
 )
 
 
+def fetch_git_hash() -> str | None:
+    """Return the commit hash of the elk checkout, or None if unavailable.
+
+    Runs ``git rev-parse HEAD`` in the directory above the imported ``elk``
+    package. Recording the hash is best-effort: a missing ``git`` executable,
+    an unusable working directory, or a failing git command all yield None.
+    """
+    try:
+        output = subprocess.check_output(
+            ["git", "rev-parse", "HEAD"], cwd=Path(elk.__file__).parent.parent
+        )
+    except (OSError, subprocess.CalledProcessError):
+        # OSError covers a missing git executable (FileNotFoundError) as well
+        # as a bad cwd (NotADirectoryError); CalledProcessError means git ran
+        # but exited with a non-zero status.
+        return None
+    return output.decode("ascii").strip()
+
+
 @dataclass
 class Run(ABC, Serializable):
     data: Extract
@@ -86,15 +108,7 @@ def execute(
         # properly without this flag enabled.
         save(self, self.out_dir / "cfg.yaml", save_dc_types=True)
 
-        path = self.out_dir / "fingerprints.yaml"
-        with open(path, "w") as meta_f:
-            yaml.dump(
-                {
-                    ds_name: {split: ds[split]._fingerprint for split in ds.keys()}
-                    for ds_name, ds in self.datasets
-                },
-                meta_f,
-            )
+        self.write_metadata()
 
         devices = select_usable_devices(self.num_gpus, min_memory=self.min_gpu_mem)
         num_devices = len(devices)
@@ -190,3 +204,27 @@ def apply_to_layers(
                     df.round(4).to_csv(self.out_dir / f"{name}.csv", index=False)
                 if self.debug:
                     save_debug_log(self.datasets, self.out_dir)
+
+    def write_metadata(self):
+        """Write metadata about the run to `metadata.yaml` in `out_dir`.
+
+        The file records the fingerprint of every dataset split under the
+        `datasets` key and, when it can be determined, the git commit hash
+        under `git_hash`.
+        """
+        assert self.out_dir is not None
+
+        # Gather everything before opening the file so that a failure while
+        # collecting it does not leave an empty metadata.yaml behind.
+        metadata: dict[str, object] = {
+            "datasets": {
+                ds_name: {split: ds[split]._fingerprint for split in ds.keys()}
+                for ds_name, ds in self.datasets
+            }
+        }
+        git_hash = fetch_git_hash()
+        if git_hash is not None:
+            metadata["git_hash"] = git_hash
+
+        with open(self.out_dir / "metadata.yaml", "w") as meta_f:
+            yaml.dump(metadata, meta_f)
diff --git a/tests/test_smoke_elicit.py b/tests/test_smoke_elicit.py
index bac0f398..df2d303d 100644
--- a/tests/test_smoke_elicit.py
+++ b/tests/test_smoke_elicit.py
@@ -27,7 +27,7 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path):
     created_file_names = {file.name for file in files}
     expected_files = [
         "cfg.yaml",
-        "fingerprints.yaml",
+        "metadata.yaml",
         "lr_models",
         "reporters",
         "eval.csv",
@@ -58,7 +58,7 @@ def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path):
     created_file_names = {file.name for file in files}
     expected_files = [
         "cfg.yaml",
-        "fingerprints.yaml",
+        "metadata.yaml",
         "lr_models",
         "reporters",
         "eval.csv",
diff --git a/tests/test_smoke_eval.py b/tests/test_smoke_eval.py
index 4efd7112..5d5b74c5 100644
--- a/tests/test_smoke_eval.py
+++ b/tests/test_smoke_eval.py
@@ -9,7 +9,7 @@
 
 EVAL_EXPECTED_FILES = [
     "cfg.yaml",
-    "fingerprints.yaml",
+    "metadata.yaml",
     "eval.csv",
 ]
 