diff --git a/apps/inference/neuronpedia_inference/config.py b/apps/inference/neuronpedia_inference/config.py index 53639a3cd..5c2f4c838 100644 --- a/apps/inference/neuronpedia_inference/config.py +++ b/apps/inference/neuronpedia_inference/config.py @@ -228,6 +228,7 @@ def get_sae_lens_ids_from_neuronpedia_id( (df_exploded["model"] == model_id) & (df_exploded["neuronpedia_id"].str.endswith(f"/{neuronpedia_id}")) ] + assert ( tmp_df.shape[0] == 1 ), f"Found {tmp_df.shape[0]} entries when searching for {model_id}/{neuronpedia_id}" diff --git a/apps/inference/neuronpedia_inference/endpoints/steer/completion.py b/apps/inference/neuronpedia_inference/endpoints/steer/completion.py index 1e6f126c6..896882cd9 100644 --- a/apps/inference/neuronpedia_inference/endpoints/steer/completion.py +++ b/apps/inference/neuronpedia_inference/endpoints/steer/completion.py @@ -204,7 +204,7 @@ def steering_hook(activations: torch.Tensor, hook: Any) -> torch.Tensor: # noqa editing_hooks = [ ( ( - sae_manager.get_sae_hook(feature.source) + sae_manager.get_decoder_hook(feature.source) if isinstance(feature, NPSteerFeature) else feature.hook ), @@ -248,7 +248,7 @@ def steering_hook(activations: torch.Tensor, hook: Any) -> torch.Tensor: # noqa editing_hooks = [ ( ( - sae_manager.get_sae_hook(feature.source) + sae_manager.get_decoder_hook(feature.source) if isinstance(feature, NPSteerFeature) else feature.hook ), diff --git a/apps/inference/neuronpedia_inference/endpoints/steer/completion_chat.py b/apps/inference/neuronpedia_inference/endpoints/steer/completion_chat.py index 5ffbd1bb4..97c972f95 100644 --- a/apps/inference/neuronpedia_inference/endpoints/steer/completion_chat.py +++ b/apps/inference/neuronpedia_inference/endpoints/steer/completion_chat.py @@ -265,7 +265,7 @@ def steering_hook(activations: torch.Tensor, hook: Any) -> torch.Tensor: # noqa editing_hooks = [ ( ( - sae_manager.get_sae_hook(feature.source) + sae_manager.get_decoder_hook(feature.source) if isinstance(feature, 
NPSteerFeature) else feature.hook ), @@ -308,7 +308,7 @@ def steering_hook(activations: torch.Tensor, hook: Any) -> torch.Tensor: # noqa editing_hooks = [ ( ( - sae_manager.get_sae_hook(feature.source) + sae_manager.get_decoder_hook(feature.source) if isinstance(feature, NPSteerFeature) else feature.hook ), diff --git a/apps/inference/neuronpedia_inference/sae_manager.py b/apps/inference/neuronpedia_inference/sae_manager.py index 1457fcef2..ebb3c6670 100644 --- a/apps/inference/neuronpedia_inference/sae_manager.py +++ b/apps/inference/neuronpedia_inference/sae_manager.py @@ -109,7 +109,7 @@ def load_sae(self, model_id: str, sae_id: str) -> None: df_exploded=get_saelens_neuronpedia_directory_df(), ) - loaded_sae, hook_name = SaeLensSAE.load( + loaded_sae, hook_in, hook_out = SaeLensSAE.load( release=sae_lens_release, sae_id=sae_lens_id, device=self.device, @@ -118,7 +118,8 @@ def load_sae(self, model_id: str, sae_id: str) -> None: self.sae_data[sae_id] = { "sae": loaded_sae, - "hook": hook_name, + "hook": hook_in, + "hook_out": hook_out, "neuronpedia_id": loaded_sae.cfg.neuronpedia_id, "type": SAE_TYPE.SAELENS, # TODO: this should be in SAELens @@ -129,7 +130,7 @@ def load_sae(self, model_id: str, sae_id: str) -> None: or DFA_ENABLED_NP_ID_SEGMENT_ALT in loaded_sae.cfg.neuronpedia_id ) ), - "transcoder": False, # You might want to set this based on some condition + "transcoder": hook_out is not None, } self.loaded_saes[sae_id] = None # We're using OrderedDict as an OrderedSet @@ -261,6 +262,10 @@ def get_sae_type(self, sae_id: str) -> str: def get_sae_hook(self, sae_id: str) -> str: return self.sae_data.get(sae_id, {}).get("hook") + + def get_decoder_hook(self, sae_id: str) -> str: + data = self.sae_data.get(sae_id, {}) + return data.get("hook_out") or data.get("hook") def is_dfa_enabled(self, sae_id: str) -> bool: return self.sae_data.get(sae_id, {}).get("dfa_enabled", False) diff --git a/apps/inference/neuronpedia_inference/saes/saelens.py 
b/apps/inference/neuronpedia_inference/saes/saelens.py index d06bdde45..ffcd6befd 100644 --- a/apps/inference/neuronpedia_inference/saes/saelens.py +++ b/apps/inference/neuronpedia_inference/saes/saelens.py @@ -1,24 +1,79 @@ -import torch -from sae_lens.sae import SAE +"""SAE/Transcoder loader wrapper for Neuronpedia Inference. -from neuronpedia_inference.saes.base import BaseSAE +This module previously supported only the vanilla SAE objects exposed by +`sae_lens.sae.SAE`. We now extend the functionality to transparently load +three different artifact classes coming from the sae-lens code-base: + +* SAE (classic auto-encoder) +* Transcoder +* SkipTranscoder + +The heavy lifting is delegated to `load_artifact_from_pretrained`, a new helper +published upstream that inspects the YAML metadata of a given release/sae_id +and automatically returns an instance of the correct class. + +For Neuronpedia Inference we treat each artifact uniformly – the caller only +needs the instantiated object and the hook names. For classic SAEs the single +hook `cfg.hook_name` is sufficient. Transcoders additionally come with +`cfg.hook_name_out`, the location where the decoder output should be steered. -DTYPE_MAP = { - "float16": torch.float16, - "float32": torch.float32, - "bfloat16": torch.bfloat16, -} +The `load` method therefore now returns **three** values: + + (artifact, hook_name_in, hook_name_out) + +`hook_name_out` is `None` for plain SAEs so users can branch on a simple +truthiness check to detect Transcoder-like artifacts. 
+""" + +from neuronpedia_inference.saes.base import BaseSAE +from sae_lens.toolkit.pretrained_sae_loaders import ( # type: ignore + load_artifact_from_pretrained, +) +from sae_lens.config import DTYPE_MAP # type: ignore class SaeLensSAE(BaseSAE): @staticmethod - def load(release: str, sae_id: str, device: str, dtype: str) -> tuple["SAE", str]: - loaded_sae, _, _ = SAE.from_pretrained( + def load(release: str, sae_id: str, device: str, dtype: str): + """Load an artifact (SAE / Transcoder / SkipTranscoder). + + Args: + release: The named release on the HF hub (e.g. "sae_lens") + sae_id: The specific SAE/Transcoder identifier inside *release*. + device: Torch device string, forwarded to the loader. + dtype: One of {"float16", "float32", "bfloat16"} – we convert + the loaded weights to this dtype after loading. + + Returns: + artifact: The initialised model instance (type depends on + YAML `type` field). + hook_name_in: Where to read encoder activations from. + hook_name_out: Where to *write* decoder deltas to when steering. + `None` for classic SAEs. + """ + + artifact, _cfg_dict, _sparsity = load_artifact_from_pretrained( release=release, sae_id=sae_id, device=device, ) - loaded_sae.to(device, dtype=DTYPE_MAP[dtype]) - loaded_sae.fold_W_dec_norm() - loaded_sae.eval() - return loaded_sae, loaded_sae.cfg.hook_name + + # Ensure correct dtype & eval mode + artifact.to(device, dtype=DTYPE_MAP[dtype]) + + # Some classes (SAE, Transcoder, SkipTranscoder) expose this helper – + # if it does not exist we silently ignore the attribute. + if hasattr(artifact, "fold_W_dec_norm"): + try: + artifact.fold_W_dec_norm() + except Exception: + # Folding is a convenience optimization, not critical – we do + # not want loading to fail if it is not implemented. 
+ pass + + artifact.eval() + + hook_name_in = artifact.cfg.hook_name + hook_name_out = getattr(artifact.cfg, "hook_name_out", None) or None + + return artifact, hook_name_in, hook_name_out \ No newline at end of file diff --git a/apps/inference/neuronpedia_inference/server.py b/apps/inference/neuronpedia_inference/server.py index bd7bb595e..d6f9ef641 100644 --- a/apps/inference/neuronpedia_inference/server.py +++ b/apps/inference/neuronpedia_inference/server.py @@ -110,11 +110,13 @@ async def initialize( def load_model_and_sae(): # Validate inputs df = get_saelens_neuronpedia_directory_df() + models = df["model"].unique() sae_sets = df["neuronpedia_set"].unique() if args.model_id not in models: logger.error( - f"Error: Invalid model_id '{args.model_id}'. Use --list_models to see available options." + f"Error: Invalid model_id '{args.model_id}'. " + "Use --list_models to see available options." ) exit(1) # iterate through sae_sets and split them by spaces @@ -126,7 +128,8 @@ def load_model_and_sae(): invalid_sae_sets = set(args_sae_sets) - set(sae_sets) if invalid_sae_sets: logger.error( - f"Error: Invalid SAE set(s): {', '.join(invalid_sae_sets)}. Use --list_models to see available options." + f"Error: Invalid SAE set(s): {', '.join(invalid_sae_sets)}. " + "Use --list_models to see available options." ) exit(1) @@ -203,7 +206,8 @@ def load_model_and_sae(): config.set_steer_special_token_ids(special_token_ids) # type: ignore logger.info( - f"Loaded {config.CUSTOM_HF_MODEL_ID if config.CUSTOM_HF_MODEL_ID else config.OVERRIDE_MODEL_ID} on {args.device}" + f"Loaded {config.CUSTOM_HF_MODEL_ID or config.OVERRIDE_MODEL_ID} " + f"on {args.device}" ) checkCudaError()