Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 55 additions & 9 deletions apps/inference/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import gc
import json
import os
from dataclasses import dataclass

import pytest
import torch
Expand All @@ -14,12 +15,53 @@
from neuronpedia_inference.server import app, initialize
from neuronpedia_inference.shared import Model

BOS_TOKEN_STR = "<|endoftext|>"

@dataclass
class ModelTestConfig:
    """Configuration bundle for one model under integration test.

    Groups the identifiers the test suite needs to load a model together
    with its SAE sources, plus model-specific values used in assertions.
    """

    model_id: str
    sae_source_set: str
    sae_selected_sources: list[str]
    bos_token_str: str
    # Feature index known to exist in the selected SAE
    steer_feature_index: int
    # Model embedding dimension (residual stream width)
    dim_model: int


# Model configurations for testing
MODEL_CONFIGS = {
    "gpt2-small": ModelTestConfig(
        model_id="gpt2-small",
        sae_source_set="res-jb",
        sae_selected_sources=["7-res-jb"],
        bos_token_str="<|endoftext|>",
        steer_feature_index=5,
        dim_model=768,
    ),
    "gemma-3-270m": ModelTestConfig(
        model_id="google/gemma-3-270m",
        sae_source_set="gemmascope-2-res-16k",
        sae_selected_sources=["5-gemmascope-2-res-16k"],
        bos_token_str="<bos>",
        steer_feature_index=5,
        dim_model=1152,
    ),
}

# Select which model to test via environment variable, default to gpt2-small.
# Validate eagerly so a typo in TEST_MODEL fails with a clear message that
# lists the valid options, instead of a bare KeyError at collection time.
_model_key = os.environ.get("TEST_MODEL", "gpt2-small")
if _model_key not in MODEL_CONFIGS:
    raise ValueError(
        f"Unknown TEST_MODEL {_model_key!r}; valid options: {sorted(MODEL_CONFIGS)}"
    )
ACTIVE_MODEL_CONFIG = MODEL_CONFIGS[_model_key]

# Export constants for backward compatibility with existing tests
BOS_TOKEN_STR = ACTIVE_MODEL_CONFIG.bos_token_str
TEST_PROMPT = "Hello, world!"
X_SECRET_KEY = "cat"
MODEL_ID = "gpt2-small"
SAE_SOURCE_SET = "res-jb"
SAE_SELECTED_SOURCES = ["7-res-jb"]
# Model/SAE identifiers re-exported from the active config so existing tests
# can keep importing flat module-level constants.
MODEL_ID = ACTIVE_MODEL_CONFIG.model_id
SAE_SOURCE_SET = ACTIVE_MODEL_CONFIG.sae_source_set
SAE_SELECTED_SOURCES = ACTIVE_MODEL_CONFIG.sae_selected_sources
DIM_MODEL = ACTIVE_MODEL_CONFIG.dim_model
# Shared numeric tolerances / generation settings used across integration tests.
ABS_TOLERANCE = 0.1
N_COMPLETION_TOKENS = 10
TEMPERATURE = 0
Expand All @@ -28,7 +70,7 @@
FREQ_PENALTY = 0.0
SEED = 42
STEER_SPECIAL_TOKENS = False
STEER_FEATURE_INDEX = 5
STEER_FEATURE_INDEX = ACTIVE_MODEL_CONFIG.steer_feature_index
INVALID_SAE_SOURCE = "fake-source"


Expand All @@ -40,18 +82,22 @@ def initialize_models():
This fixture will be run once per test session and will be available to all tests
that need an initialized model. It uses the same initialization logic as the
/initialize endpoint.

The model to test can be selected via TEST_MODEL environment variable:
TEST_MODEL=gpt2-small pytest ... (default)
TEST_MODEL=gemma-3-270m pytest ...
"""
# Set environment variables for testing
# Set environment variables for testing using the active model config
os.environ.update(
{
"MODEL_ID": "gpt2-small",
"SAE_SETS": json.dumps(["res-jb"]),
"MODEL_ID": ACTIVE_MODEL_CONFIG.model_id,
"SAE_SETS": json.dumps([ACTIVE_MODEL_CONFIG.sae_source_set]),
"MODEL_DTYPE": "float16",
"SAE_DTYPE": "float32",
"TOKEN_LIMIT": "500",
"DEVICE": "cpu",
"INCLUDE_SAE": json.dumps(
["7-res-jb"]
ACTIVE_MODEL_CONFIG.sae_selected_sources
), # Only load the specific SAE we want
"EXCLUDE_SAE": json.dumps([]),
"MAX_LOADED_SAES": "1",
Expand Down
152 changes: 65 additions & 87 deletions apps/inference/tests/integration/test_activation_all.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import math

from fastapi.testclient import TestClient
from neuronpedia_inference_client.models.activation_all_post200_response import (
ActivationAllPost200Response,
Expand All @@ -8,7 +9,6 @@
)

from tests.conftest import (
ABS_TOLERANCE,
BOS_TOKEN_STR,
MODEL_ID,
SAE_SELECTED_SOURCES,
Expand All @@ -22,15 +22,23 @@

def test_activation_all(client: TestClient):
"""
Test basic functionality of the /activation/all endpoint with a simple request.
Test the /activation/all endpoint returns valid SAE feature activations.

This test verifies:
- API returns 200 and valid response structure
- Correct number of activations returned
- Each activation has valid structure and sensible values
- Tokenization matches expected behavior
- Results are sorted by max activation value (descending)
"""
num_results = 5
request = ActivationAllPostRequest(
prompt=TEST_PROMPT,
model=MODEL_ID,
source_set=SAE_SOURCE_SET,
selected_sources=SAE_SELECTED_SOURCES,
sort_by_token_indexes=[],
num_results=5,
num_results=num_results,
ignore_bos=True,
)

Expand All @@ -42,93 +50,63 @@ def test_activation_all(client: TestClient):

assert response.status_code == 200

# Validate the structure with Pydantic model
# This will check all required fields are present with correct types
# Validate response structure with Pydantic model
data = response.json()
response_model = ActivationAllPost200Response(**data)

# Expected data based on the provided response
expected_activations_data = [
{
"source": "7-res-jb",
"index": 16653,
"values": [0.0, 46.481327056884766, 11.279630661010742, 0.0, 0.0],
"max_value": 46.481327056884766,
"max_value_index": 1,
},
{
"source": "7-res-jb",
"index": 11553,
"values": [
0.0,
0.0,
3.798774480819702,
6.36670446395874,
8.832769393920898,
],
"max_value": 8.832769393920898,
"max_value_index": 4,
},
{
"source": "7-res-jb",
"index": 9810,
"values": [
0.0,
8.095728874206543,
3.749096632003784,
4.03702449798584,
6.3894195556640625,
],
"max_value": 8.095728874206543,
"max_value_index": 1,
},
{
"source": "7-res-jb",
"index": 14806,
"values": [
0.0,
0.7275917530059814,
6.788952827453613,
5.938947677612305,
0.0,
],
"max_value": 6.788952827453613,
"max_value_index": 2,
},
{
"source": "7-res-jb",
"index": 16488,
"values": [
0.0,
3.8083033561706543,
2.710123062133789,
6.348649501800537,
2.1380198001861572,
],
"max_value": 6.348649501800537,
"max_value_index": 3,
},
]

# Verify we have the expected number of activations
assert len(response_model.activations) == len(expected_activations_data)

# Check each activation against expected data
for i, (actual, expected) in enumerate(
zip(response_model.activations, expected_activations_data)
):
assert actual.source == expected["source"], f"Activation {i}: source mismatch"
assert actual.index == expected["index"], f"Activation {i}: index mismatch"
# Verify we got the requested number of activations
assert len(response_model.activations) == num_results

# Check tokenization is correct
expected_tokens = [BOS_TOKEN_STR, "Hello", ",", " world", "!"]
assert (
response_model.tokens == expected_tokens
), f"Tokenization mismatch: expected {expected_tokens}, got {response_model.tokens}"

# Verify each activation has valid structure and values
prev_max_value = float("inf")
for i, activation in enumerate(response_model.activations):
# Source should match the requested SAE
expected_source = SAE_SELECTED_SOURCES[0]
assert (
pytest.approx(actual.values, abs=ABS_TOLERANCE) == expected["values"]
), f"Activation {i}: values mismatch"
activation.source == expected_source
), f"Activation {i}: expected source '{expected_source}', got '{activation.source}'"

# Feature index should be a valid non-negative integer
assert (
pytest.approx(actual.max_value, abs=ABS_TOLERANCE) == expected["max_value"]
), f"Activation {i}: max_value mismatch"
isinstance(activation.index, int) and activation.index >= 0
), f"Activation {i}: invalid feature index {activation.index}"

# Values should match token count and contain no NaN/inf
assert (
actual.max_value_index == expected["max_value_index"]
), f"Activation {i}: max_value_index mismatch"
len(activation.values) == len(expected_tokens)
), f"Activation {i}: expected {len(expected_tokens)} values, got {len(activation.values)}"
for j, val in enumerate(activation.values):
assert math.isfinite(
val
), f"Activation {i}, token {j}: non-finite value {val}"
assert val >= 0, f"Activation {i}, token {j}: negative activation {val}"

# Check expected tokens sequence
expected_tokens = [BOS_TOKEN_STR, "Hello", ",", " world", "!"]
assert response_model.tokens == expected_tokens
# max_value should equal the maximum of values
computed_max = max(activation.values)
assert (
abs(activation.max_value - computed_max) < 1e-5
), f"Activation {i}: max_value {activation.max_value} != max(values) {computed_max}"

# max_value_index should point to the max value
assert (
activation.values[activation.max_value_index] == computed_max
), f"Activation {i}: max_value_index {activation.max_value_index} doesn't point to max"

# Results should be sorted by max_value descending
assert activation.max_value <= prev_max_value, (
f"Activation {i}: results not sorted descending by max_value "
f"({activation.max_value} > {prev_max_value})"
)
prev_max_value = activation.max_value

# Top activation should have a reasonably high value (sanity check that SAE is working)
top_activation = response_model.activations[0]
assert (
top_activation.max_value > 1.0
), f"Top activation value {top_activation.max_value} is suspiciously low"
Loading